In [19]:
import os
import sys

sys.path.append(os.path.abspath('..'))

import numpy as np
import matplotlib.pyplot as plt
import idx2numpy
import pandas as pd
from lib.Net import *
from lib.Func import *

In [20]:
# Layer stack for the MNIST classifier: one Flatten module followed by three Mlp modules.
# Each tuple below is (input_dim, output_dim, activation) for one Mlp entry.
_mlp_specs = [
    (784, 256, "relu"),
    (256, 64, "relu"),
    (64, 10, "softmax"),
]

architecture_2 = [{"module": Flatten}]
architecture_2 += [
    {"module": Mlp, "params": {"input_dim": n_in, "output_dim": n_out, "activation": act}}
    for n_in, n_out, act in _mlp_specs
]

In [21]:
# Build the network object from the layer specification list defined above.
net = Net(architecture_2)

In [22]:
# Location of the serialized model parameters (passed to net.load_params further down).
modelpath_2 = "../model/task1/MNIST.json"

# MNIST training split in IDX format: label file and image file.
labelpath = "../data/MNIST/train-labels.idx1-ubyte"
datapath = "../data/MNIST/train-images.idx3-ubyte"

In [23]:
# Read the raw training images; expected shape for the MNIST training split is (60000, 28, 28).
train_images = idx2numpy.convert_from_file(datapath)
# Insert a channel axis at position 1 -> (60000, 1, 28, 28).
data = train_images[:, np.newaxis]

# Read the integer class labels and encode them with 10 classes
# (presumably one row per sample, one column per class — see lib.Func.one_hot).
label = idx2numpy.convert_from_file(labelpath)
one_hot_labels = one_hot(label, 10)

In [24]:
# Training hyperparameters.
epochs = 20
BATCH_SIZE = 64
LEARNING_RATE = 0.1

pbar = tqdm(range(epochs))
for epoch in pbar:
    # Reshuffle images and labels before every pass
    # (presumably with one shared permutation — see the project's shuffle helper).
    data, one_hot_labels = shuffle(data, one_hot_labels)
    net.train(
        data,
        one_hot_labels,
        batch_size=BATCH_SIZE,
        lr=LEARNING_RATE,
        lossfunc="cross_entropy",
    )

    # Score the model on the full *training* set after this epoch
    # (these are training metrics, not held-out metrics).
    probs = net.predict(data)
    loss = cross_entropy_loss(one_hot_labels, probs)
    true_classes = one_hot_labels.argmax(axis=1)
    y_hat = np.argmax(probs, axis=1)    # (60000, 10) -> (60000,)
    accuracy = np.mean(y_hat == true_classes)

    # Pass a dict rather than keyword arguments so the bar shows loss before accuracy.
    pbar.set_postfix({"loss": loss, "accuracy": f"{accuracy*100:.2f}%"})

100%|██████████| 20/20 [01:56<00:00,  5.81s/it, loss=0.0214, accuracy=99.82%]


In [None]:
# Saving is opt-in so that re-running the notebook never writes to modelpath_2 by accident.
# Set SAVE_MODEL = True to persist the current parameters of `net`.
SAVE_MODEL = False
if SAVE_MODEL:
    net.save_params(modelpath_2)

Model parameters saved to ../model/task1/MNIST.json


In [16]:
# Load parameters from the file at modelpath_2 into `net`.
# NOTE(review): this cell's recorded execution count (16) is lower than the training cell's (24),
# so the saved outputs were not produced in top-to-bottom order. On a fresh Restart & Run All this
# call runs after training and presumably replaces the freshly trained weights with whatever is
# stored on disk — confirm that is intended before trusting the test accuracy below.
net.load_params(modelpath_2)

Model parameters loaded from ../model/task1/MNIST.json


In [25]:
# MNIST test split (t10k) in IDX format: label file and image file.
test_labelpath = "../test/MNIST/t10k-labels.idx1-ubyte"
test_datapath = "../test/MNIST/t10k-images.idx3-ubyte"

In [27]:
# Read the held-out images; for the t10k split this should be (10000, 28, 28) — not checked here.
test_images = idx2numpy.convert_from_file(test_datapath)
# Insert a channel axis at position 1 -> (N, 1, 28, 28), matching the training input layout.
test_data = test_images[:, np.newaxis]

# Integer class labels for the test images (compared directly, no one-hot encoding needed).
test_label = idx2numpy.convert_from_file(test_labelpath)

# Forward pass, then take the highest-scoring class per sample.
y_hat = net.predict(test_data)
predicted_classes = np.argmax(y_hat, axis=1)

# Fraction of test samples whose predicted class equals the true label.
accuracy = np.mean(predicted_classes == test_label)
print(f"Test accuracy: {accuracy * 100:.2f}%")

Test accuracy: 97.92%
