In [1]:
import os

import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
import chainer.initializers as I
from chainer import training
from chainer.training import extensions

In [6]:
class MyChain(chainer.Chain):
    """Small CNN: two (conv -> ReLU -> 2x2 max-pool) stages, then a linear layer.

    Args:
        in_channels: number of channels of the input images. Defaults to 1,
            matching the single-channel MNIST arrays loaded in this notebook.
        n_out: number of output units (class scores). Defaults to 10.
    """

    def __init__(self, in_channels=1, n_out=10):
        super(MyChain, self).__init__()
        with self.init_scope():
            # 3x3 kernels with stride 1 and pad 1 preserve height/width;
            # the following max-pooling halves them.
            self.conv1 = L.Convolution2D(in_channels, 4, 3, 1, 1)
            self.conv2 = L.Convolution2D(4, 8, 3, 1, 1)
            # in_size=None: the flattened input size is inferred on the first
            # forward pass, so the spatial size need not be hard-coded here.
            self.l3 = L.Linear(None, n_out)

    def __call__(self, x):
        """Return raw (unnormalized) class scores for the batch ``x``.

        No softmax is applied here; the ``L.Classifier`` wrapper applies its
        loss function (softmax cross-entropy) to these scores.
        """
        h1 = F.max_pooling_2d(F.relu(self.conv1(x)), ksize=2, stride=2)
        h2 = F.max_pooling_2d(F.relu(self.conv2(h1)), ksize=2, stride=2)
        return self.l3(h2)

In [4]:
# Training configuration.
epoch = 20       # number of passes over the training set before the trainer stops
batchsize = 100  # examples per minibatch (shared by the train and test iterators)

# Training is stochastic (shuffled minibatches, random weight initialization),
# so seed NumPy's global RNG for a reproducible run. Chainer's CPU initializers
# and SerialIterator shuffling presumably draw from this RNG -- confirm for the
# installed Chainer version.
SEED = 0
np.random.seed(SEED)

In [5]:
# Load MNIST as image-shaped arrays (ndim=3) paired with their labels; the
# files are downloaded on first use (see the output below).
train, test = chainer.datasets.get_mnist(withlabel=True, ndim=3)

Downloading from http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz...
Downloading from http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz...


In [8]:
# Wrap the CNN in a Classifier, which computes the loss (softmax
# cross-entropy) and reports loss/accuracy for the trainer extensions.
model = L.Classifier(MyChain(), lossfun=F.softmax_cross_entropy)

# Adam with default hyperparameters, attached to all parameters of `model`.
optimizer = chainer.optimizers.Adam()
# setup() returns the optimizer itself; the trailing ';' keeps that repr out
# of the notebook output.
optimizer.setup(model);

<chainer.optimizers.adam.Adam at 0x7f0995613400>

In [10]:
# Training iterator: keeps cycling over the training set (repeat defaults to
# True) and shuffles it by default.
train_iter = chainer.iterators.SerialIterator(train, batch_size=batchsize)

# Evaluation iterator: a single, ordered pass over the test set, so each
# evaluation run terminates and sees the examples in a fixed order.
test_iter = chainer.iterators.SerialIterator(
    test,
    batch_size=batchsize,
    repeat=False,
    shuffle=False,
)

In [11]:
# One update step = fetch a minibatch from `train_iter`, compute the loss of
# the optimizer's target (the Classifier) and apply the optimizer. No device
# is passed, so presumably everything stays on the CPU.
updater = training.StandardUpdater(
    iterator=train_iter,
    optimizer=optimizer,
)

In [13]:
# Stop after `epoch` epochs. Outputs (plots, logs, snapshots) go to "result/",
# which is also Trainer's default output directory, spelled out here because a
# later cell saves the model into the same directory.
trainer = training.Trainer(updater, stop_trigger=(epoch, 'epoch'), out='result')

# Evaluate the model on the held-out test iterator.
trainer.extend(extensions.Evaluator(test_iter, model))

# One figure per metric, training vs. validation, plotted against the epoch.
curves = (
    (['main/loss', 'validation/main/loss'], 'loss.png'),
    (['main/accuracy', 'validation/main/accuracy'], 'accuracy.png'),
)
for y_keys, png_name in curves:
    trainer.extend(extensions.PlotReport(y_keys, 'epoch', file_name=png_name))

# Aggregate the reported values into a log and print them as a table.
trainer.extend(extensions.LogReport())
report_columns = [
    'epoch',
    'main/loss', 'validation/main/loss',
    'main/accuracy', 'validation/main/accuracy',
    'elapsed_time',
]
trainer.extend(extensions.PrintReport(report_columns))

# Save a full trainer snapshot every 10 epochs.
trainer.extend(extensions.snapshot(), trigger=(10, 'epoch'))

In [14]:
# Run the training loop until the trainer's stop trigger fires (`epoch`
# epochs). The registered extensions print the per-epoch table shown below and
# write plots/logs/snapshots to the trainer's output directory.
trainer.run()

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
[J1           0.568484    0.192333              0.844083       0.9447                    19.6773       
[J2           0.160781    0.121514              0.9528         0.9628                    35.9509       
[J3           0.117761    0.0957774             0.965617       0.9712                    52.357        
[J4           0.0978286   0.0803289             0.970817       0.9751                    70.6344       
[J5           0.0860283   0.0757639             0.973617       0.9763                    87.9045       
[J6           0.0766062   0.0668726             0.976533       0.9793                    106.494       
[J7           0.0709228   0.0640852             0.978233       0.9805                    125.998       
[J8           0.0643853   0.0584628             0.980267       0.981                     142.68        
[J9           0.0606631   0.0553628             0.9814     

In [15]:
# Persist the trained parameters. `model` is the Classifier wrapper, so the
# CNN's weights are presumably stored under a "predictor/" prefix in the
# archive -- confirm before loading them into a bare MyChain.
MODEL_PATH = os.path.join("result", "mnist.model")

# Don't rely on the trainer having already created the output directory.
model_dir = os.path.dirname(MODEL_PATH)
if not os.path.isdir(model_dir):
    os.makedirs(model_dir)

chainer.serializers.save_npz(MODEL_PATH, model)