In [1]:
import os
import sys

# Make the parent directory importable so that packages living there
# (presumably `dataset`) can be imported in the next cell.
# os.pardir is the relative path '..', resolved against the kernel's current
# working directory. The membership guard keeps this cell idempotent:
# re-running it does not append duplicate entries to sys.path.
if os.pardir not in sys.path:
    sys.path.append(os.pardir)


In [2]:
import numpy as np
import matplotlib.pyplot as plt
from dataset.mnist import load_mnist
from convnet import SimpleConvNet
from convnet_trainer import Trainer

In [3]:
# Load the MNIST data. flatten=False presumably keeps each image as a
# multi-dimensional array (rather than a flat vector) for the conv net.
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=False)

# Training can take a long time, so work on a leading subset of each split.
# Set a size to None to use the whole split (x[:None] returns every element).
TRAIN_SIZE = 5000
TEST_SIZE = 1000
x_train, t_train = x_train[:TRAIN_SIZE], t_train[:TRAIN_SIZE]
x_test, t_test = x_test[:TEST_SIZE], t_test[:TEST_SIZE]

# Inputs and labels must stay aligned after slicing.
assert len(x_train) == len(t_train)
assert len(x_test) == len(t_test)

In [4]:
max_epochs = 10

# Seed NumPy's global RNG so that weight initialization is reproducible across
# Restart & Run All.
# NOTE(review): this assumes SimpleConvNet (and the Trainer's mini-batch
# sampling) draw from np.random's global state. Confirm in convnet.py /
# convnet_trainer.py.
SEED = 0
np.random.seed(SEED)

network = SimpleConvNet(input_dim=(1,28,28), 
                        conv_param = {'filter_num': 30, 'filter_size': 5, 'pad': 0, 'stride': 1},
                        hidden_size=100, output_size=10, weight_init_std=0.01)

In [5]:
# Wire the network and the data subsets into the training driver.
# The optimizer is Adam with learning rate 0.001, on mini-batches of 100.
# evaluate_sample_num_per_epoch presumably caps how many samples are used for
# each per-epoch accuracy check.
trainer = Trainer(
    network,
    x_train, t_train,
    x_test, t_test,
    epochs=max_epochs,
    mini_batch_size=100,
    optimizer='Adam',
    optimizer_param=dict(lr=0.001),
    evaluate_sample_num_per_epoch=1000,
)

In [None]:
# Run the training loop. Judging by the captured output, this prints the
# mini-batch loss every iteration and a train/test accuracy line once per epoch.
# The per-epoch accuracy histories are later read from
# trainer.train_acc_list / trainer.test_acc_list.
trainer.train()

train loss:2.2986278355055685
=== epoch:1, train acc:0.16, test acc:0.176 ===
train loss:2.2967684921350076
train loss:2.292636430251326
train loss:2.2857031691187957
train loss:2.273042602632095
train loss:2.2612208034560872
train loss:2.2428201825654908
train loss:2.2290089572554814
train loss:2.1952328830050947
train loss:2.181603028825661
train loss:2.1404179724006114
train loss:2.0850016739576622
train loss:2.0538051886430257
train loss:1.984189854188697
train loss:1.8951271733319224
train loss:1.8636593583836731
train loss:1.7139757451598086
train loss:1.6765744252357448
train loss:1.703875862115708
train loss:1.625084651274087
train loss:1.4733847711689276
train loss:1.4036663101378142
train loss:1.2820058629171385
train loss:1.2840504022164803
train loss:1.1299392642916368
train loss:1.2274172299387707
train loss:0.9949860279854374
train loss:1.0619335479022596
train loss:0.8012391662974678
train loss:0.8854279259609218
train loss:0.8086104425750462
train loss:0.749075129625293

train loss:0.3073865561122446
train loss:0.09536844815607998
train loss:0.13436202057529859
train loss:0.1347814810175069
train loss:0.17442581205240654
train loss:0.23900018411928264
train loss:0.30693851595412436
train loss:0.13634430107098017
train loss:0.07548599379590006
train loss:0.18494640545030383
train loss:0.12003398835350666
train loss:0.19363866010046238
train loss:0.14912288312423252
train loss:0.253388459407464
train loss:0.29614066693665364
train loss:0.20754683081889305
train loss:0.16372332311842044
train loss:0.24359620590611988
train loss:0.22156568741056526
train loss:0.1400610291103936
train loss:0.17367035915959478
train loss:0.2374231481817537
train loss:0.24462834947008397
train loss:0.20909259766001906
train loss:0.2894543770345303
train loss:0.28268466481291027
train loss:0.20693076933922513
train loss:0.15410254354475073
train loss:0.17539224121981636
train loss:0.14678727617561477
train loss:0.17922576046866304
train loss:0.10095368377509054
train loss:0.14

In [None]:
# Save the trained network parameters.
# Disabled by default; set SAVE_PARAMS = True to write params.pkl.
# An explicit flag replaces the old commented-out-code-in-a-string, which
# Jupyter would otherwise echo as the cell's output.
SAVE_PARAMS = False
if SAVE_PARAMS:
    network.save_params("params.pkl")
    print("Saved Network Parameters!")

In [None]:
# Plot train and test accuracy recorded once per epoch.
markers = {'train': 'o', 'test': 's'}
acc_history = {'train': trainer.train_acc_list, 'test': trainer.test_acc_list}

fig, ax = plt.subplots()
for label, acc_list in acc_history.items():
    # Size the x-axis from the recorded history rather than max_epochs, so the
    # plot still renders if training stopped before all epochs completed.
    x = np.arange(len(acc_list))
    ax.plot(x, acc_list, marker=markers[label], label=label, markevery=2)

ax.set_title("Train vs. test accuracy per epoch")
ax.set_xlabel("epochs")
ax.set_ylabel("accuracy")
ax.set_ylim(0, 1.0)
ax.legend(loc='lower right')
plt.show()