In [None]:
import itertools
import os
import sys

# Put the parent of the current working directory on the import path so the
# project packages imported below (dataset, ch06, common) can be resolved --
# presumably they live one level above the directory this is run from.
current_dir = os.getcwd()
parent_dir = os.path.split(current_dir)[0]
sys.path.append(parent_dir)

import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from ch06.simple_convnet import SimpleConvNet
from common.trainer import Trainer

# Load MNIST with flatten=False (presumably keeps each image as a
# (channel, height, width) array, matching input_dim=(1, 28, 28) below).
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=False)

# Keep only the first 5000 training samples so each training run is faster.
x_train = x_train[:5000]
t_train = t_train[:5000]

# Keep only the first 1000 test samples for evaluation.
x_test = x_test[:1000]
t_test = t_test[:1000]

# Number of epochs every candidate configuration is trained for.
max_epoch = 20

# Hyperparameter search space. The grid search below visits the full
# Cartesian product of these lists: 4 * 5 * 3 * 5 * 5 = 1500 configurations,
# each trained for max_epoch epochs, so the complete sweep is very expensive.
filter_nums = [30, 40, 50, 60]
# NOTE(review): with pad=0 / stride=1 on 28x28 inputs, filter sizes 10 and 20
# give odd-sized feature maps (19x19 and 9x9). Confirm that SimpleConvNet's
# pooling / fully-connected sizing handles odd sizes before running them.
filter_sizes = [5, 10, 15, 20, 25]
hidden_sizes = [100, 200, 300]
learning_rates = [0.001, 0.002, 0.003, 0.004, 0.005]
# Evenly spaced from 0.01 to 0.03 in steps of 0.005. The third entry used to
# be 0.002, which broke the grid and was an apparent typo for 0.02.
weight_init_stds = [0.01, 0.015, 0.02, 0.025, 0.03]

# Best score seen so far and the hyperparameters that produced it; both are
# updated by the search loop.
best_test_acc = 0.0
best_hparams = None

# Exhaustive grid search. Every combination in the search space is trained
# from scratch for max_epoch epochs and scored by the highest value in
# trainer.test_acc_list.
#
# The winning network and its trainer are kept, so that the parameters saved
# at the end (and the accuracy history exposed through best_trainer) belong to
# the very run that produced best_test_acc. Building and training a fresh,
# randomly initialised network with the same hyperparameters would in general
# not reproduce that result.
#
# NOTE(review): candidates are ranked on the test split, so best_test_acc is
# an optimistic estimate -- consider holding out a separate validation split.
best_network = None
best_trainer = None

# itertools.product varies the last iterable fastest, i.e. the same visiting
# order as nesting the five loops in this order.
search_space = itertools.product(
    filter_nums, filter_sizes, hidden_sizes, learning_rates, weight_init_stds
)

for filter_num, filter_size, hidden_size, lr, weight_init_std in search_space:
    network = SimpleConvNet(
        input_dim=(1, 28, 28),
        conv_param={
            "filter_num": filter_num,
            "filter_size": filter_size,
            "pad": 0,
            "stride": 1,
        },
        hidden_size=hidden_size,
        output_size=10,
        weight_init_std=weight_init_std,
    )

    trainer = Trainer(
        network,
        x_train,
        t_train,
        x_test,
        t_test,
        epochs=max_epoch,
        mini_batch_size=100,
        optimizer="Adam",
        optimizer_param={"lr": lr},
        evaluate_sample_num_per_epoch=1000,
    )

    # Train the network
    trainer.train()

    test_acc = max(trainer.test_acc_list)
    print(
        f"Filter Num: {filter_num}, Filter Size: {filter_size}, Hidden Size: {hidden_size}, LR: {lr}, Weight Init Std: {weight_init_std:.4f}, Test Acc: {test_acc:.4f}"
    )

    # Strict comparison: on ties the earliest configuration wins.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        best_hparams = {
            "filter_num": filter_num,
            "filter_size": filter_size,
            "hidden_size": hidden_size,
            "learning_rate": lr,
            "weight_init_std": weight_init_std,
        }
        best_network = network
        best_trainer = trainer

# Fail with a clear message instead of an attribute error on None when no
# configuration ever beat the initial best_test_acc.
if best_network is None:
    raise RuntimeError(
        "Grid search finished without selecting any configuration; "
        "nothing to save."
    )

# Save the best model
best_network.save_params("params.pkl")

# Plot the training and test accuracy history recorded by best_trainer.
markers = {"train": "o", "test": "s"}

train_history = best_trainer.train_acc_list
test_history = best_trainer.test_acc_list

# Derive the x-axis from the recorded history rather than from max_epoch, so
# plotting cannot fail with a length mismatch if the trainer recorded a
# different number of points than max_epoch.
x = np.arange(len(train_history))
plt.plot(
    x,
    train_history,
    label="train acc",
    marker=markers["train"],
    markevery=2,
)
plt.plot(
    np.arange(len(test_history)),
    test_history,
    label="test acc",
    marker=markers["test"],
    markevery=2,
)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc="lower right")
plt.show()

# Report the best score found by the search and the hyperparameters behind it.
print(f"Best Test Accuracy: {best_test_acc:.4f}")
print("Best Hyperparameters:")
for key, value in best_hparams.items():
    # "weight_init_std" is printed as "Weight init std", and so on.
    print(f"{key.capitalize().replace('_', ' ')}: {value}")