In [None]:
import os

In [None]:
# TensorFlow reads these at import time, so they are set before the
# tensorflow import in the next cell.
os.environ.update(
    TF_CPP_MIN_LOG_LEVEL="2",          # quiet TensorFlow's C++ INFO/WARNING log output
    TF_FORCE_GPU_ALLOW_GROWTH="true",  # grow GPU memory on demand instead of reserving it all
)

In [None]:
import wandb
import numpy as np
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import to_categorical
from tensorflow import keras as k

In [None]:
# Authenticate this session with Weights & Biases once, up front, before the
# trainer below calls wandb.init().
wandb.login()

In [None]:
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint
import numpy as np


# --- custom callbacks ---------------------------------------------------------

class LogLRCallback(k.callbacks.Callback):
    """Log the optimizer's learning rate to W&B at the end of every epoch.

    The value goes under the ``"lr"`` key together with the 0-based ``"epoch"``
    index. No explicit ``step=`` is passed to ``wandb.log``: other callbacks in
    the same run (e.g. ``WandbMetricsLogger``) advance W&B's implicit step
    counter, and mixing in an explicit step such as the optimizer iteration
    count can make W&B reject the entry as non-monotonic or shift the x-axis of
    every other chart in the run.
    """

    def on_epoch_end(self, epoch, logs=None):
        optimizer = self.model.optimizer
        lr = optimizer.learning_rate
        # A LearningRateSchedule is a callable of the step rather than a value;
        # evaluate it at the current iteration so a number is logged.
        if callable(lr) and not hasattr(lr, "numpy"):
            lr = lr(optimizer.iterations)
        lr_val = float(lr.numpy() if hasattr(lr, "numpy") else lr)
        wandb.log({"lr": lr_val, "epoch": epoch})

class LogSamplesCallback(k.callbacks.Callback):
    """Log a W&B table of sample predictions (image, true/predicted label) every epoch.

    Args:
        x: Input images; only the first ``max_rows`` are kept.
        y: Labels matching ``x``, either one-hot rows (2-D) or integer
            class ids (1-D); only the first ``max_rows`` are kept.
        labels: Class names, indexed by class id.
        max_rows: Maximum number of samples shown in the table.
    """

    def __init__(self, x, y, labels, max_rows=32):
        super().__init__()
        self.x = x[:max_rows]
        self.y = y[:max_rows]
        self.labels = labels
        # Ground truth is identical every epoch, so decode it once here.
        y_arr = np.asarray(self.y)
        self._y_true = np.argmax(y_arr, axis=1) if y_arr.ndim > 1 else y_arr.astype(int)

    def on_epoch_end(self, epoch, logs=None):
        preds = self.model.predict(self.x, verbose=0)
        y_pred = np.argmax(preds, axis=1)

        table = wandb.Table(columns=["image", "y_true", "y_pred", "correct", "p(y_pred)"])
        for img, true_id, pred_id, probs in zip(self.x, self._y_true, y_pred, preds):
            table.add_data(
                wandb.Image(img.squeeze()),  # drop size-1 axes (e.g. the channel axis) for display
                self.labels[true_id],
                self.labels[pred_id],
                bool(true_id == pred_id),
                float(np.max(probs)),  # confidence of the predicted class
            )
        # One key per epoch keeps each epoch's table as its own panel.
        wandb.log({f"samples/epoch_{epoch+1}": table})

class ConfusionMatrixCallback(k.callbacks.Callback):
    """Log a W&B confusion matrix computed on the full validation set each epoch.

    Args:
        x_val: Validation inputs.
        y_val: Validation labels, either one-hot rows (2-D) or integer
            class ids (1-D).
        labels: Class names, indexed by class id.
    """

    def __init__(self, x_val, y_val, labels):
        super().__init__()
        self.x_val = x_val
        self.y_val = y_val
        self.labels = labels
        # Ground truth is identical every epoch, so decode it once here. Plain
        # Python ints (via tolist) match the Sequence type wandb documents.
        y_arr = np.asarray(y_val)
        y_ids = np.argmax(y_arr, axis=1) if y_arr.ndim > 1 else y_arr.astype(int)
        self._y_true = y_ids.tolist()

    def on_epoch_end(self, epoch, logs=None):
        preds = self.model.predict(self.x_val, verbose=0)
        y_pred = np.argmax(preds, axis=1).tolist()
        cm_plot = wandb.plot.confusion_matrix(
            probs=None,
            y_true=self._y_true,
            preds=y_pred,
            class_names=self.labels,
        )
        # Same key every epoch: the panel shows the most recent matrix.
        wandb.log({"confusion_matrix": cm_plot})

# --- trainer -----------------------------------------------------------------

class FashionMNISTTrainer:
    """Train a small CNN on a subset of Fashion-MNIST and track it in W&B.

    Constructing the trainer starts a W&B run and loads/preprocesses the data.
    ``train()`` fits the model with logging callbacks, records final test
    metrics, logs the model as an artifact, and finishes the run.
    """

    def __init__(self, project_name="Lab1-visualize-models", run_name="neural_network_plus"):
        self.cfg = dict(
            dropout=0.2,
            layer_1_size=32,
            learn_rate=0.01,   # constant learning rate (no decay schedule)
            momentum=0.9,
            epochs=5,
            batch_size=64,
            sample=10000,      # number of train AND test examples kept
        )
        self.run = wandb.init(
            project=project_name,
            name=run_name,
            config=self.cfg,
            settings=wandb.Settings(start_method="thread"),
        )
        self.config = wandb.config
        self.labels = ["T-shirt/top","Trouser","Pullover","Dress","Coat",
                       "Sandal","Shirt","Sneaker","Bag","Ankle boot"]
        self._prepare_data()

    def _prepare_data(self):
        """Load Fashion-MNIST, keep the first ``sample`` rows, scale to [0, 1] and one-hot the labels."""
        (xtr, ytr), (xte, yte) = fashion_mnist.load_data()
        n = self.config.sample
        # Fix the one-hot width to the label list. Without this, to_categorical
        # infers it from max(label) + 1, so a small sample that misses the top
        # class would give train/test targets of different widths.
        num_classes = len(self.labels)
        # [..., None] appends a channel axis to match the model's (28, 28, 1) input.
        self.X_train = (xtr[:n].astype("float32") / 255.0)[..., None]
        self.X_test = (xte[:n].astype("float32") / 255.0)[..., None]
        self.y_train = to_categorical(ytr[:n], num_classes=num_classes)
        self.y_test = to_categorical(yte[:n], num_classes=num_classes)
        self.num_classes = num_classes

    def _build_model(self):
        """Build and compile Conv -> MaxPool -> Dropout -> Dense(softmax) with Nesterov SGD."""
        inputs = k.Input(shape=(28,28,1))
        x = k.layers.Conv2D(self.config.layer_1_size, (5,5), activation="relu")(inputs)
        x = k.layers.MaxPooling2D((2,2))(x)
        x = k.layers.Dropout(self.config.dropout)(x)
        x = k.layers.Flatten()(x)
        outputs = k.layers.Dense(self.num_classes, activation="softmax")(x)
        model = k.Model(inputs, outputs)

        opt = k.optimizers.SGD(
            learning_rate=self.config.learn_rate,  # constant, no decay
            momentum=self.config.momentum,
            nesterov=True,
        )
        model.compile(loss="categorical_crossentropy", optimizer=opt, metrics=["accuracy"])
        return model

    def _log_model_artifact(self, model):
        """Write the model summary and saved model under ``artifacts/`` and log both as a W&B artifact."""
        summary_lines = []
        # Keras 3 calls print_fn with an extra ``line_break`` keyword, which
        # list.append rejects; accept and ignore any extra arguments.
        model.summary(print_fn=lambda line, *args, **kwargs: summary_lines.append(line))
        os.makedirs("artifacts", exist_ok=True)
        summary_path = "artifacts/model_summary.txt"
        # Explicit UTF-8: summaries may contain non-ASCII box-drawing characters,
        # which the platform default encoding (e.g. cp1252) cannot write.
        with open(summary_path, "w", encoding="utf-8") as f:
            f.write("\n".join(summary_lines))

        model_path = "artifacts/model.h5"
        model.save(model_path)

        art = wandb.Artifact("fashion_mnist_model", type="model")
        art.add_file(summary_path)
        art.add_file(model_path)
        self.run.log_artifact(art)

    def train(self):
        """Fit the model, log final test metrics and the model artifact, then finish the run.

        The W&B run is always finished. If training or logging raises, the run
        is marked failed (non-zero exit code) and the exception propagates.
        """
        exit_code = 1
        try:
            model = self._build_model()

            callbacks = [
                WandbMetricsLogger(log_freq=10),
                WandbModelCheckpoint("checkpoints/model-{epoch:02d}.h5", save_weights_only=False),
                LogLRCallback(),
                LogSamplesCallback(self.X_test, self.y_test, self.labels, max_rows=32),
                ConfusionMatrixCallback(self.X_test, self.y_test, self.labels),
            ]

            # NOTE: the test split doubles as the validation set here.
            model.fit(
                self.X_train, self.y_train,
                validation_data=(self.X_test, self.y_test),
                epochs=self.config.epochs,
                batch_size=self.config.batch_size,
                callbacks=callbacks,
                verbose=1,
            )

            loss, acc = model.evaluate(self.X_test, self.y_test, verbose=0)
            wandb.log({"final/loss": loss, "final/accuracy": acc})

            self._log_model_artifact(model)
            exit_code = 0
        finally:
            self.run.finish(exit_code=exit_code)


# Create the trainer (starts the W&B run and loads the data), then run training end to end.
trainer = FashionMNISTTrainer()
trainer.train()
