In [23]:
import fire
import numpy as np
from tqdm import tqdm
from plain_cnn import PlainCNN
from mlp import MLP
import chainer
from chainer import computational_graph
from chainer import optimizers
import chainer.functions as F
import chainer.links as L

import matplotlib
matplotlib.use('Agg')

import matplotlib.pyplot as plt

from sobamchan.sobamchan_iterator import Iterator
from sobamchan.sobamchan_log import Log
from sobamchan.sobamchan_slack import Slack
from sobamchan import sobamchan_chainer
# NOTE(review): `slack` is not referenced in any of the cells shown here; Slack()
# may require credentials/configuration to construct — confirm it is needed.
slack = Slack()

In [53]:
class ResBlock(sobamchan_chainer.Model):
    """Residual block of two 3x3 convolutions, each followed by batch norm + ReLU.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x)))))) + shortcut(x)``.  The
    shortcut is the identity, spatially subsampled when ``stride > 1`` and
    zero-padded along the channel axis when ``out_channels > in_channels``.

    Args:
        in_channels: number of channels of the block input.
        out_channels: number of channels produced by both convolutions.
        stride: stride of the first convolution; a value > 1 downsamples the
            feature map and the shortcut is subsampled to the same size.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResBlock, self).__init__(
            # The block's stride is applied by the first convolution so the
            # block can downsample (it used to be hard-coded to 1).
            conv1=L.Convolution2D(in_channels, out_channels, ksize=(3, 3), stride=stride, pad=1),
            conv2=L.Convolution2D(out_channels, out_channels, ksize=(3, 3), stride=1, pad=1),
            bn1=L.BatchNormalization(out_channels),
            bn2=L.BatchNormalization(out_channels),
        )
        self.stride = stride

    def __call__(self, x, train=True):
        """Apply the block to ``x``; see ``fwd``."""
        return self.fwd(x, train)

    def fwd(self, x, train=True):
        """Forward pass of the residual block.

        Args:
            x: input batch laid out as (batch, in_channels, H, W); an ndarray
                or a chainer Variable.
            train: accepted for API symmetry. NOTE(review): it is currently not
                forwarded to the BatchNormalization layers.

        Returns:
            Variable of shape (batch, out_channels, H', W'), where H'/W' are
            reduced by ``stride``.

        Raises:
            ValueError: if ``x`` has more channels than the block produces
                (the zero-padding shortcut cannot shrink channels).
        """
        # NOTE(review): He et al.'s ResNet applies the second ReLU *after* the
        # residual addition; here it is applied before it — confirm intent.
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))

        shortcut = x
        if self.stride != 1:
            # A 1x1 average pool with the conv stride keeps every stride-th
            # pixel, giving the same spatial size as conv1 (ksize=3, pad=1).
            shortcut = F.average_pooling_2d(shortcut, 1, stride=self.stride)

        _, x_channels, _, _ = shortcut.shape
        h_batch_size, h_channels, h_h, h_w = h.shape
        if x_channels > h_channels:
            raise ValueError(
                'ResBlock input has {} channels but the block only outputs {}'.format(
                    x_channels, h_channels))
        if x_channels != h_channels:
            # Zero-pad the missing channels on the same device (numpy/cupy)
            # and with the same dtype as the residual branch.
            xp = chainer.cuda.get_array_module(h.data)
            pad = chainer.Variable(
                xp.zeros((h_batch_size, h_channels - x_channels, h_h, h_w), dtype=h.data.dtype))
            shortcut = F.concat((shortcut, pad))
        return h + shortcut

In [5]:
train, test = chainer.datasets.cifar.get_cifar10()

# Split every (image, label) example into stacked image and label arrays.
train_x = np.array([image for image, _ in train])
train_t = np.array([label for _, label in train])
test_x = np.array([image for image, _ in test])
test_t = np.array([label for _, label in test])

train_n, test_n = len(train_x), len(test_x)

# Force the (N, 3, 32, 32) NCHW layout used by the convolution layers.
train_x = train_x.reshape(train_n, 3, 32, 32)
test_x = test_x.reshape(test_n, 3, 32, 32)

In [8]:
# Per-pixel mean centering. Both splits use the *training* mean so the test set
# receives exactly the preprocessing the model was trained with, and no
# statistics are estimated from test data (previously the test set was
# centred with its own mean).
train_mean = np.mean(train_x, axis=0)
train_x = np.subtract(train_x, train_mean)
test_x = np.subtract(test_x, train_mean)

In [61]:
# First 256 training examples, used below to smoke-test the models.
x_batch, t_batch = train_x[:256], train_t[:256]
print(x_batch.shape)
print(t_batch.shape)

(256, 3, 32, 32)
(256,)


In [54]:
# Standalone block mapping 3 input channels to 64, to check output shapes.
resblock = ResBlock(in_channels=3, out_channels=64)

In [55]:
# Forward the sample batch through the standalone block.
y = resblock(x_batch, train=True)

In [56]:
# Shape of the block's output feature map.
y.data.shape

(256, 64, 32, 32)

In [57]:
class ResNet(sobamchan_chainer.Model):
    """Residual network made of three ResBlock stages and a linear classifier.

    The stages have 16, 32 and 64 output channels. The first stage has
    ``2 * n + 1`` blocks and the other two have ``2 * n`` blocks. The first
    block of stages two and three uses stride 2 to downsample. The blocks are
    followed by global average pooling and a fully connected layer.

    Args:
        n: depth parameter controlling the number of blocks per stage.
        n_out: number of output units (classes) of the final linear layer.
    """

    def __init__(self, n=2, n_out=10):
        super(ResNet, self).__init__()
        # (output channels, number of blocks, stride of the stage's first block)
        stages = (
            (16, n * 2 + 1, 1),
            (32, n * 2, 2),
            (64, n * 2, 2),
        )
        modules = []
        in_channels = 3
        for out_channels, n_blocks, first_stride in stages:
            for j in range(n_blocks):
                stride = first_stride if j == 0 else 1
                name = 'resblock_{}'.format(len(modules) + 1)
                modules.append((name, ResBlock(in_channels, out_channels, stride=stride)))
                # The next block consumes this block's output, so track its
                # channel count (previously mis-set to 16 / 8).
                in_channels = out_channels

        modules.append(('fc', L.Linear(None, n_out)))

        # Register every child link so its parameters belong to the model.
        for name, link in modules:
            self.add_link(name, link)
        self.modules = modules
        # One past the index of the last residual block
        # (blocks are named resblock_1 .. resblock_{layer_n - 1}).
        self.layer_n = len(modules)

    def __call__(self, x, t, train=True):
        """Return ``(softmax cross-entropy loss, accuracy)`` of the predictions for labels ``t``."""
        y = self.fwd(x, train)
        return F.softmax_cross_entropy(y, t), F.accuracy(y, t)

    def fwd(self, x, train=True):
        """Compute unnormalised class scores for the image batch ``x``.

        Every residual block is applied in order (previously none were called
        and the last one would have been skipped). The result is then
        globally average-pooled and fed to the linear classifier.
        """
        h = x
        for i in range(1, self.layer_n):
            h = self['resblock_{}'.format(i)](h, train)

        # Global average pooling: the window covers the whole feature map.
        _, _, height, width = h.shape
        h = F.average_pooling_2d(h, (height, width))

        return self['fc'](h)

In [59]:
# Instantiate the residual network with its default configuration.
resnet = ResNet()

In [65]:
# Forward the sample batch; the model returns (loss, accuracy).
y, acc = resnet(x_batch, t_batch, train=True)

In [69]:
# Batch accuracy of the freshly initialised network (no training step has run above).
acc.data

array(0.078125, dtype=float32)