In [1]:
import os
import sys

sys.path.append(os.path.abspath('..'))

import numpy as np
import matplotlib.pyplot as plt
import idx2numpy
import pandas as pd
from lib.Net import *
from lib.Func import *

In [2]:
# Fully connected classifier: flatten the input, then 784 -> 256 -> 64 -> 10.
# The two hidden layers use ReLU; the 10-way output layer uses softmax.
_layer_dims = [784, 256, 64, 10]
_activations = ["relu", "relu", "softmax"]

# One Flatten entry followed by one Mlp entry per consecutive pair of layer sizes.
architecture_2 = [{"module": Flatten}] + [
    {"module": Mlp, "params": {"input_dim": n_in, "output_dim": n_out, "activation": act}}
    for n_in, n_out, act in zip(_layer_dims[:-1], _layer_dims[1:], _activations)
]

In [3]:
# Instantiate the network from the layer specifications in architecture_2.
net = Net(architecture_2)

In [4]:
# Relative paths to the MNIST training images and labels (IDX files).
datapath = "../data/MNIST/train-images.idx3-ubyte"
labelpath = "../data/MNIST/train-labels.idx1-ubyte"
# JSON file that the save_params / load_params cells further down read and write.
modelpath_2 = "../model/task1/MNIST.json"

In [5]:
# Read the training images, expected shape (60000, 28, 28), and add a channel
# dimension in the same step -> (60000, 1, 28, 28).
data = np.expand_dims(idx2numpy.convert_from_file(datapath), axis=1)

# Read the integer class labels and expand them to 10-way one-hot vectors.
label = idx2numpy.convert_from_file(labelpath)
one_hot_labels = one_hot(label, 10)

In [6]:
# Training hyperparameters, named here so they can be tuned in one place.
epochs = 20
batch_size = 64
learning_rate = 0.1

# NOTE(review): tqdm, shuffle and cross_entropy_loss are not imported explicitly in
# this notebook; presumably they arrive through the `from lib... import *` lines — confirm.
pbar = tqdm(range(epochs))

for epoch in pbar:
    # Reshuffle images and labels together before each pass over the data.
    data, one_hot_labels = shuffle(data, one_hot_labels)
    net.train(
        data,
        one_hot_labels,
        batch_size=batch_size,
        lr=learning_rate,
        lossfunc="cross_entropy",
    )

    # After the epoch, score the network on the full training set.
    train_probs = net.predict(data)
    loss = cross_entropy_loss(one_hot_labels, train_probs)

    # Collapse per-class scores to predicted class indices: (60000, 10) -> (60000,)
    y_hat = np.argmax(train_probs, axis=1)
    true_classes = one_hot_labels.argmax(axis=1)
    accuracy = np.mean(y_hat == true_classes)

    # Show the latest metrics next to the progress bar.
    pbar.set_postfix({"loss": loss, "accuracy": f"{accuracy*100:.2f}%"})

  0%|          | 0/20 [00:00<?, ?it/s]

Batch 1/937, Loss: 2.3255
Batch 2/937, Loss: 2.2961
Batch 3/937, Loss: 2.2959
Batch 4/937, Loss: 2.2906
Batch 5/937, Loss: 2.2912
Batch 6/937, Loss: 2.2702
Batch 7/937, Loss: 2.2574
Batch 8/937, Loss: 2.2081
Batch 9/937, Loss: 2.2360
Batch 10/937, Loss: 2.2175
Batch 11/937, Loss: 2.1932
Batch 12/937, Loss: 2.1638
Batch 13/937, Loss: 2.2118
Batch 14/937, Loss: 2.1738
Batch 15/937, Loss: 2.1590
Batch 16/937, Loss: 2.2016
Batch 17/937, Loss: 2.1284
Batch 18/937, Loss: 2.1851
Batch 19/937, Loss: 2.1243
Batch 20/937, Loss: 2.1060
Batch 21/937, Loss: 2.1323
Batch 22/937, Loss: 2.0995
Batch 23/937, Loss: 2.1106
Batch 24/937, Loss: 2.0226
Batch 25/937, Loss: 2.0548
Batch 26/937, Loss: 2.0217
Batch 27/937, Loss: 1.9414
Batch 28/937, Loss: 2.0035
Batch 29/937, Loss: 1.9915
Batch 30/937, Loss: 1.9485
Batch 31/937, Loss: 2.0046
Batch 32/937, Loss: 1.9189
Batch 33/937, Loss: 1.8878
Batch 34/937, Loss: 1.8859
Batch 35/937, Loss: 1.8636
Batch 36/937, Loss: 1.8296
Batch 37/937, Loss: 1.8016
Batch 38/9

  5%|▌         | 1/20 [00:07<02:29,  7.89s/it, loss=0.435, accuracy=85.38%]

Batch 1/937, Loss: 0.4593
Batch 2/937, Loss: 0.1792
Batch 3/937, Loss: 0.2971
Batch 4/937, Loss: 0.2022
Batch 5/937, Loss: 0.3013
Batch 6/937, Loss: 0.1063
Batch 7/937, Loss: 0.1482
Batch 8/937, Loss: 0.2674
Batch 9/937, Loss: 0.0802
Batch 10/937, Loss: 0.2339
Batch 11/937, Loss: 0.2024
Batch 12/937, Loss: 0.2265
Batch 13/937, Loss: 0.3560
Batch 14/937, Loss: 0.2465
Batch 15/937, Loss: 0.2403
Batch 16/937, Loss: 0.0868
Batch 17/937, Loss: 0.3063
Batch 18/937, Loss: 0.1414
Batch 19/937, Loss: 0.2446
Batch 20/937, Loss: 0.1494
Batch 21/937, Loss: 0.3188
Batch 22/937, Loss: 0.2992
Batch 23/937, Loss: 0.1015
Batch 24/937, Loss: 0.1724
Batch 25/937, Loss: 0.1505
Batch 26/937, Loss: 0.2181
Batch 27/937, Loss: 0.2743
Batch 28/937, Loss: 0.3514
Batch 29/937, Loss: 0.2019
Batch 30/937, Loss: 0.3589
Batch 31/937, Loss: 0.3152
Batch 32/937, Loss: 0.3490
Batch 33/937, Loss: 0.2253
Batch 34/937, Loss: 0.2422
Batch 35/937, Loss: 0.1157
Batch 36/937, Loss: 0.2195
Batch 37/937, Loss: 0.2850
Batch 38/9

  5%|▌         | 1/20 [00:12<03:59, 12.60s/it, loss=0.435, accuracy=85.38%]

Batch 797/937, Loss: 0.1265
Batch 798/937, Loss: 0.1909
Batch 799/937, Loss: 0.0595
Batch 800/937, Loss: 0.3965
Batch 801/937, Loss: 0.2824
Batch 802/937, Loss: 0.0986
Batch 803/937, Loss: 0.0757
Batch 804/937, Loss: 0.1242
Batch 805/937, Loss: 0.2410
Batch 806/937, Loss: 0.1760
Batch 807/937, Loss: 0.2600
Batch 808/937, Loss: 0.0786
Batch 809/937, Loss: 0.1607





KeyboardInterrupt: 

In [None]:
# Saving is currently disabled; uncomment the line below to write the current
# parameters of `net` to modelpath_2. The output shown under this cell comes from
# an earlier run in which the line was active.
#net.save_params(modelpath_2)

Model parameters saved to ../model/task1/MNIST.json


In [None]:
# Load parameters from the JSON file at modelpath_2 into `net`.
# NOTE(review): presumably this replaces whatever weights the interrupted training
# run above left in memory — confirm that evaluating the saved checkpoint is intended.
net.load_params(modelpath_2)

Model parameters loaded from ../model/task1/MNIST.json


In [None]:
# Relative paths to the MNIST test images and labels (IDX files).
# NOTE(review): these live under ../test/ while the training files are under
# ../data/ — confirm the directory layout is intentional.
test_datapath = "../test/MNIST/t10k-images.idx3-ubyte"
test_labelpath = "../test/MNIST/t10k-labels.idx1-ubyte"

In [None]:
# Read the test images and add a channel dimension in the same step.
# Expected shapes for the t10k split: (10000, 28, 28) -> (10000, 1, 28, 28)
# (assumed from the standard MNIST test set — TODO confirm).
test_data = np.expand_dims(idx2numpy.convert_from_file(test_datapath), axis=1)

# Integer class labels for the test images.
test_label = idx2numpy.convert_from_file(test_labelpath)

# Per-class scores for every test image, then the highest-scoring class per row.
y_hat = net.predict(test_data)
test_pred = np.argmax(y_hat, axis=1)

# Fraction of test images whose predicted class matches the label.
accuracy = np.mean(test_pred == test_label)
print(f"Test accuracy: {accuracy * 100:.2f}%")

Test accuracy: 97.92%
