In [None]:
import os, sys

# Make the parent of the current working directory importable so the
# project-local packages imported below (`dataset`, `ch06`, `common`) resolve.
# os.getcwd() is used (rather than __file__) presumably because this runs as a
# notebook cell from a sub-directory of the project root -- confirm layout.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

# Guard against duplicates: re-running this cell would otherwise append the
# same entry to sys.path on every execution.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from ch06.simple_convnet import SimpleConvNet
from common.trainer import Trainer

# Load MNIST without flattening the images -- presumably each sample keeps a
# (1, 28, 28) layout matching the network's input_dim below; confirm in
# load_mnist.
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=False)

# Keep the run short: train on the first 5,000 samples and evaluate on the
# first 1,000 test samples.
x_train, t_train = (arr[:5000] for arr in (x_train, t_train))
x_test, t_test = (arr[:1000] for arr in (x_test, t_test))

# Number of training epochs; passed to the Trainer below.
max_epoch = 20

# Convolutional network for 1x28x28 inputs and 10 output classes (MNIST digits).
# Conv settings: 50 filters, filter_size 5, no padding, stride 1.
network = SimpleConvNet(
    input_dim=(1, 28, 28),
    conv_param=dict(filter_num=50, filter_size=5, pad=0, stride=1),
    hidden_size=200,
    output_size=10,
    weight_init_std=0.05,
)

# Train with Adam (lr=0.002) on mini-batches of 100 samples. The test subset is
# passed as the evaluation data; evaluate_sample_num_per_epoch presumably caps
# how many samples are scored per epoch -- confirm in common.trainer.
trainer = Trainer(
    network, x_train, t_train, x_test, t_test,
    epochs=max_epoch,
    mini_batch_size=100,
    optimizer="Adam",
    optimizer_param=dict(lr=0.002),
    evaluate_sample_num_per_epoch=1000,
)
trainer.train()

# Write the trained parameters to params.pkl.
network.save_params("params.pkl")

# Plot the train/test accuracy history recorded by the trainer.
markers = {"train": "o", "test": "s"}

# Size the x-axis from the recorded history (presumably one entry per epoch)
# instead of assuming exactly max_epoch entries; a length mismatch between x
# and the accuracy list would make plt.plot raise.
x = np.arange(len(trainer.train_acc_list))
plt.plot(
    x,
    trainer.train_acc_list,
    label="train acc",
    marker=markers["train"],
    markevery=2,
)
plt.plot(
    x,
    trainer.test_acc_list,
    label="test acc",
    marker=markers["test"],
    markevery=2,
)
plt.xlabel("epochs")
plt.ylabel("accuracy")
plt.ylim(0, 1.0)
plt.legend(loc="lower right")
plt.show()