In [1]:
import chainer
import chainer.links as L
import chainer.functions as F
import matplotlib.pyplot as plt
import numpy as np

In [2]:
train, test = chainer.datasets.get_mnist(ndim=3)

In [3]:
class NN(chainer.Chain):
    def __init__(self, class_labels=10):
        super().__init__()
        with self.init_scope():
            self.conv1 = L.Convolution2D(None, 16, ksize=5, pad=2, nobias=True)
            self.conv2 = L.Convolution2D(None, 16, ksize=5, pad=2, nobias=True)
            self.fc1 = L.Linear(None, 100)
            self.fc2 = L.Linear(None, class_labels)
            
    def __call__(self, x):
        h = self.conv1(x)
        h = F.relu(h)
        h = F.max_pooling_2d(h, ksize=2, stride=2)
        h = self.conv2(h)
        h = F.relu(h)
        h = F.max_pooling_2d(h, ksize=2, stride=2)
        h = self.fc1(h)
        h = F.relu(h)
        h = self.fc2(h)
        return h

In [4]:
nn = NN()
model = L.Classifier(nn)

In [5]:
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

<chainer.optimizers.adam.Adam at 0xb1e7031d0>

In [6]:
batch_size = 256
epoch = 30

In [7]:
class LeitnerIterator(chainer.iterators.SerialIterator):
    def __init__(self, dataset, batch_size, model, repeat=True, shuffle=None, order_sampler=None, n_queues=10):
        super().__init__(dataset, batch_size,repeat, shuffle, order_sampler)
        self.model = model
        self.n_queues = n_queues
        #iteration回数を記憶する
        self.n_epochs = 0
        #queueの初期化
        self.Q = [dataset] + [[]] * (self.n_queues - 1)
        #各epochで使われるデータとqueueの位置を格納するリスト
        self.using_dataset = []
        
    def __next__(self):
        #新たなqueueに基づいてbatchを作る。       
        self.update_queue()
        self._previous_epoch_detail = self.epoch_detail
        self._state, self.indices = _statemachine.iterator_statemachine(
            self._state, self.batch_size, self.repeat, self.order_sampler,
            len(self.using_dataset))
        if self.indices is None:
            raise StopIteration

        batch = [self.using_dataset[index][0] for index in self.indices]
        
        #epochが切り替わった際に使うデータセットを更新。        
        if _state.is_new_epoch == True:
            self.n_epochs += 1
            self.using_dataset = []
            for i in range(self.n_queues):
                if self.n_epochs % (2 ** i) == 0:
                    for j in Q[i]:
                        self.using_dataset += [j, i]
                    
        return batch
    
    def update_queue(self):
        #学習結果をもとにqueueを更新する。
        for i in self.indices:
            if self.model.predictor(self.using_dataset[i][0][1].reshape(1, 1, 28, 28)).data.argmax(axis=1)[0] == self.using_dataset[i][0][1]:
                if self.using_dataset[i][1] < self.n_queues - 1:
                    self.Q[self.using_dataset[i][1]].remove(self.using_dataset[i][1])
                    self.Q[self.using_dataset[i][1] + 1].append(self.using_dataset[i][1])    
            else:
                self.Q[self.using_dataset[i][1]].remove(self.using_dataset[i][1])
                self.Q[0].append(self.using_dataset[i][1])

In [8]:
train_iter = LeitnerIterator(train, batch_size, model)
test_iter = chainer.iterators.SerialIterator(test, batch_size, repeat=False, shuffle=False)

In [9]:
from chainer import training
from chainer.training import extensions

In [10]:
updater = training.StandardUpdater(train_iter, optimizer, device=-1)

In [11]:
trainer = training.Trainer(updater, (epoch, 'epoch'), out='result/mnist')

trainer.extend(extensions.Evaluator(test_iter, model, device=-1))

trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))

trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy', 'main/loss', 'validation/main/loss','elapsed_time']), trigger=(1, 'epoch'))

In [None]:
trainer.run()

epoch       main/accuracy  validation/main/accuracy  main/loss   validation/main/loss  elapsed_time
[J1           0.907098       0.97168                   0.320697    0.0936902             140.162       
[J2           0.974292       0.976465                  0.0863331   0.0693407             350.015       
[J3           0.981582       0.985938                  0.0604799   0.0444126             555.898       
[J4           0.984559       0.988184                  0.0486356   0.0365737             744.051       
[J5           0.987246       0.988086                  0.0398765   0.0348631             914.769       
[J6           0.989262       0.988672                  0.0340496   0.0321054             1086.51       
[J7           0.990618       0.989746                  0.0297901   0.0308071             1221.55       
[J8           0.991253       0.987793                  0.0271333   0.0340598             1355.44       
[J9           0.99237        0.989941                  0.02

In [None]:
test[0]