## Import

In [None]:
import sys
import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

## モデル定義

In [None]:
class LeNet5(Chain):
    """LeNet-5 style network: three sigmoid convolutions and two linear layers.

    Args:
        in_channels: Number of channels of the input images.
        n_classes: Width of the final linear layer, i.e. the number of
            output classes. Defaults to 10, the value that used to be
            hard-coded, so existing callers and saved models are unaffected.
    """

    def __init__(self, in_channels, n_classes=10):
        super(LeNet5, self).__init__()
        with self.init_scope():
            self.conv1 = L.Convolution2D(
                in_channels=in_channels, out_channels=6, ksize=5, stride=1)
            self.conv2 = L.Convolution2D(
                in_channels=6, out_channels=16, ksize=5, stride=1)
            self.conv3 = L.Convolution2D(
                in_channels=16, out_channels=120, ksize=4, stride=1)
            # The input size is left as None so it is inferred from the
            # first batch that reaches this layer.
            self.fc4 = L.Linear(None, 84)
            self.fc5 = L.Linear(84, n_classes)

    def __call__(self, x):
        """Run the forward pass on a batch ``x``.

        Returns:
            The raw ``fc5`` output while ``chainer.config.train`` is true,
            otherwise the softmax of that output.
        """
        h = F.sigmoid(self.conv1(x))
        h = F.max_pooling_2d(h, 2, 2)
        h = F.sigmoid(self.conv2(h))
        h = F.max_pooling_2d(h, 2, 2)
        # conv3 is not followed by pooling; its output feeds fc4 directly.
        h = F.sigmoid(self.conv3(h))
        h = F.sigmoid(self.fc4(h))
        scores = self.fc5(h)
        # NOTE(review): outside training the output is already normalized by
        # softmax. If a wrapper applies its own softmax-based loss during
        # evaluation, that loss is computed on probabilities rather than raw
        # scores - confirm this is intended.
        if chainer.config.train:
            return scores
        return F.softmax(scores)

## 学習済みモデルをロード
ちなみに、20 epoch学習させたときのvalidation accuracyは以下のとおりです。

| dataset | val_accuracy |
| ------- | ------- |
| MNIST   | 0.9866 |
| CIFAR10 | 0.4835 |

In [None]:
# Build networks with the same structure used for training, then restore the
# trained parameters from the snapshot files. load_npz is required here:
# save_npz would overwrite the trained snapshots with freshly initialized
# (untrained) weights.
mnist_model = L.Classifier(LeNet5(in_channels=1))
serializers.load_npz('lenet5-models/mnist.model', mnist_model)

cifar10_model = L.Classifier(LeNet5(in_channels=3))
serializers.load_npz('lenet5-models/cifar10.model', cifar10_model)

## フィルタ可視化メソッド

In [None]:
import matplotlib.pyplot as plt
% matplotlib inline

In [None]:
def visualize_conv_layer_weights(weights):
    """Draw every 2-D kernel of a convolution weight as a grayscale tile.

    The figure is a grid with one row per input channel and one column per
    output channel.

    Args:
        weights: 4-D weight whose ``shape`` unpacks as
            (out_channels, in_channels, height, width) and whose indexed
            elements expose ``.data``. The calls in this file pass a
            convolution layer's ``W``.
    """
    n_filters, n_channels, _, _ = weights.shape

    figure = plt.figure(figsize=(16, 9))
    figure.subplots_adjust(
        left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)
    for col in range(n_filters):
        for row in range(n_channels):
            # Subplot positions are 1-based and advance row by row.
            position = n_filters * row + col + 1
            axes = figure.add_subplot(
                n_channels, n_filters, position, xticks=[], yticks=[])
            kernel = weights[col, row].data
            axes.imshow(kernel, cmap=plt.cm.gray, interpolation='nearest')

## 可視化

In [None]:
# Show the first convolution layer's filters of the MNIST model.
visualize_conv_layer_weights(mnist_model.predictor.conv1.W)

In [None]:
# Show the first convolution layer's filters of the CIFAR10 model.
visualize_conv_layer_weights(cifar10_model.predictor.conv1.W)

## (参考) 学習に使ったコード

In [None]:
class PreprocessedDataset(chainer.dataset.DatasetMixin):
    """Dataset wrapper that subtracts a per-channel mean from every image.

    Args:
        base: Underlying dataset; indexing it must yield an
            ``(image, label)`` pair.
        in_channels: Number of image channels. ``mean`` is reshaped to
            ``(in_channels, 1, 1)`` so it broadcasts over the two trailing
            image axes.
        mean: Array holding one mean value per channel.
    """

    def __init__(self, base, in_channels, mean):
        self.base = base
        self.mean = mean.astype('f').reshape((in_channels, 1, 1))

    def __len__(self):
        return len(self.base)

    def get_example(self, i):
        """Return the ``i``-th ``(image, label)`` pair with the mean removed."""
        image, label = self.base[i]
        # Subtract out of place. The image handed back by the base dataset
        # may be a view into its backing array, and an in-place `-=` would
        # then modify the stored data, removing the mean again on every
        # access (i.e. once more each epoch).
        return image - self.mean, label


def main(target):
    """Train LeNet5 on the selected dataset for 20 epochs and save the model.

    Args:
        target: Dataset name; one of ``'mnist'``, ``'cifar10'`` or
            ``'cifar100'``.

    Raises:
        SystemExit: With status 1 when ``target`` is not a known dataset.
    """
    if target == 'mnist':
        get_dataset = datasets.get_mnist
        in_channels = 1
    elif target == 'cifar10':
        get_dataset = datasets.get_cifar10
        in_channels = 3
    elif target == 'cifar100':
        # NOTE(review): LeNet5 is constructed below with its default 10-way
        # output; confirm that this matches the label count of cifar100.
        get_dataset = datasets.get_cifar100
        in_channels = 3
    else:
        print('Invalid target')
        # sys.exit always raises SystemExit and reports failure, unlike the
        # interactive `exit()` helper, which exits with status 0 and may be
        # replaced by the hosting shell.
        sys.exit(1)

    print('target: {}'.format(target))

    train, test = get_dataset(ndim=3)

    # Mean subtract: one mean per channel, taken over all training images.
    # NOTE(review): `_datasets` is a private attribute of the returned
    # dataset object; revisit if the library changes its internals.
    color_mean = train._datasets[0].mean(axis=(0, 2, 3))
    # The training-set mean is applied to the test set as well.
    train = PreprocessedDataset(train, in_channels=in_channels, mean=color_mean)
    test = PreprocessedDataset(test, in_channels=in_channels, mean=color_mean)

    train_iter = iterators.SerialIterator(train, batch_size=100, shuffle=True)
    # A single ordered pass over the test set per evaluation.
    test_iter = iterators.SerialIterator(test, batch_size=100, repeat=False, shuffle=False)

    model = L.Classifier(LeNet5(in_channels=in_channels))

    optimizer = optimizers.Adam()
    optimizer.setup(model)

    result_dir = 'result/' + target + '_lenet5'

    updater = training.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (20, 'epoch'), out=result_dir)

    trainer.extend(extensions.Evaluator(test_iter, model))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy']))
    trainer.extend(extensions.ProgressBar())

    trainer.run()

    # The trained model is written to the current working directory.
    serializers.save_npz(target + '.model', model)