In [None]:
# Load the dataset, then hold out the leading 10% of the training samples as a
# validation split; the remaining samples are kept in X_train_s / Y_train_s.
X_train, Y_train, X_test, Y_test = load_fashion_mnist()
val_size = int(0.1 * X_train.shape[0])
X_val, X_train_s = X_train[:val_size], X_train[val_size:]
Y_val, Y_train_s = Y_train[:val_size], Y_train[val_size:]

In [None]:
# Candidate values for every swept hyperparameter, keyed by the name that the
# training function reads from the run config.
_search_space = {
    "epochs": [5, 10],
    "num_layers": [3, 4, 5],
    "hidden_size": [32, 64, 128],
    "learning_rate": [1e-3, 1e-4],
    "optimizer": ["sgd", "momentum", "nesterov", "rmsprop", "adam", "nadam"],
    "batch_size": [16, 32, 64],
}

# Sweep definition: Bayesian search that maximizes the logged "val_accuracy".
# Each hyperparameter is expressed as a discrete {"values": [...]} choice.
sweep_config = {
    "method": "bayes",
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        name: {"values": options} for name, options in _search_space.items()
    },
}

In [None]:
def _loss_and_accuracy(model, X, Y):
    """Run a forward pass of ``model`` on ``X`` and score it against ``Y``.

    Returns:
        A ``(loss, accuracy)`` tuple. ``loss`` is whatever
        ``model.compute_loss(preds, Y)`` returns; ``accuracy`` is the fraction
        of rows whose argmax prediction equals the argmax of the target row.
    """
    preds = model.forward(X)
    loss = model.compute_loss(preds, Y)
    accuracy = np.mean(np.argmax(preds, axis=1) == np.argmax(Y, axis=1))
    return loss, accuracy


def sweep_train():
    """Train one FNN with the hyperparameters of the current W&B sweep trial.

    Takes no arguments: hyperparameters are read from the run config
    (``epochs``, ``num_layers``, ``hidden_size``, ``learning_rate``,
    ``optimizer``, ``batch_size``) and the data comes from the notebook-level
    ``X_train_s`` / ``Y_train_s`` / ``X_val`` / ``Y_val`` arrays.

    After every epoch the training and validation loss/accuracy are logged
    under the keys ``loss``, ``accuracy``, ``val_loss`` and ``val_accuracy``
    (the last one being the metric the sweep optimizes).
    """
    # The context manager guarantees the run is finished when training ends,
    # and that it is recorded as failed if an exception escapes the loop.
    with wandb.init() as run:
        config = run.config
        model = FNN(
            input_size=X_train_s.shape[1],
            hidden_layers=[config.hidden_size] * config.num_layers,
            output_size=10,
        )
        run.name = f"hl_{config.num_layers}_bs_{config.batch_size}_opt_{config.optimizer}"

        n_train = X_train_s.shape[0]
        for epoch in range(config.epochs):
            # Fresh shuffle each epoch so mini-batches differ between epochs.
            perm = np.random.permutation(n_train)
            X_shuf, Y_shuf = X_train_s[perm], Y_train_s[perm]

            for start in range(0, n_train, config.batch_size):
                stop = start + config.batch_size
                # Return value is not needed here; the call is kept because
                # backward() is invoked without the inputs, so it presumably
                # relies on state cached by forward() -- confirm in FNN.
                model.forward(X_shuf[start:stop])
                model.backward(Y_shuf[start:stop])
                # NOTE(review): t is the 1-based epoch, so it stays constant
                # for every batch of an epoch. If update_parameters uses t for
                # Adam/Nadam bias correction it should probably be a per-update
                # step counter instead -- confirm against FNN before changing.
                model.update_parameters(
                    lr=config.learning_rate,
                    optimizer=config.optimizer,
                    t=epoch + 1,
                )

            train_loss, train_acc = _loss_and_accuracy(model, X_train_s, Y_train_s)
            val_loss, val_acc = _loss_and_accuracy(model, X_val, Y_val)
            run.log({
                "epoch": epoch,
                "loss": train_loss,
                "accuracy": train_acc,
                "val_loss": val_loss,
                "val_accuracy": val_acc,
            })

In [None]:
# Register the sweep under the "fashion-mnist-q4" project, then have an agent
# in this kernel execute 20 trials, each one a call to sweep_train.
sweep_id = wandb.sweep(sweep=sweep_config, project="fashion-mnist-q4")
wandb.agent(sweep_id=sweep_id, function=sweep_train, count=20)