In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from collections import Counter, defaultdict, deque
import os, sys, glob, copy, json, time, pickle

In [2]:
from chainer import Chain, ChainList, cuda, gradient_check, Function, Link, \
    optimizers, serializers, utils, Variable, dataset, datasets, using_config, training, iterators
from chainer.training import extensions
from chainer import functions as F
from chainer import links as L

In [180]:
# Network architecture following https://static.googleusercontent.com/media/research.google.com/ja//pubs/archive/42241.pdf
class Net(Chain):
    """Multi-digit number recognizer with one length head and five digit heads.

    Five (batch normalization -> 3x3 stride-2 convolution -> ReLU) stages map a
    (3, 30, 70) input to a (4096, 1, 3) feature map, which is averaged over its
    spatial axes and fed to six linear heads.
    """

    def __init__(self):
        super().__init__()
        with self.init_scope():
            # Input: (*, 3, 30, 70)
            self.b1 = L.BatchNormalization(3)
            self.c1 = L.Convolution2D(3, 16, ksize=3, stride=2, pad=1)
            self.b2 = L.BatchNormalization(16)
            self.c2 = L.Convolution2D(16, 64, ksize=3, stride=2, pad=1)
            self.b3 = L.BatchNormalization(64)
            self.c3 = L.Convolution2D(64, 256, ksize=3, stride=2, pad=1)
            self.b4 = L.BatchNormalization(256)
            self.c4 = L.Convolution2D(256, 1024, ksize=3, stride=2, pad=1)
            self.b5 = L.BatchNormalization(1024)
            self.c5 = L.Convolution2D(1024, 4096, ksize=3, stride=2, pad=1)
            # Head scoring the number of digits (7 classes).
            self.dl = L.Linear(4096, 7)
            # One 10-way head per digit position.
            self.d0 = L.Linear(4096, 10)
            self.d1 = L.Linear(4096, 10)
            self.d2 = L.Linear(4096, 10)
            self.d3 = L.Linear(4096, 10)
            self.d4 = L.Linear(4096, 10)

    def __call__(self, x):
        '''
        Run the network on a batch of images.

        x <Variable> of shape (batch, 3, 30, 70)
        returns a list of six <Variable>s: the length scores (batch, 7)
        followed by the scores of digit positions 0..4, each (batch, 10)
        '''
        assert x.shape[1:] == (3, 30, 70)
        stages = (
            (self.b1, self.c1),
            (self.b2, self.c2),
            (self.b3, self.c3),
            (self.b4, self.c4),
            (self.b5, self.c5),
        )
        features = x
        for norm, conv in stages:
            features = F.relu(conv(norm(features)))
        assert features.shape[1:] == (4096, 1, 3)
        # Average over the remaining (height, width) = (1, 3) grid.
        features = F.mean(F.mean(features, axis=3), axis=2)
        heads = (self.dl, self.d0, self.d1, self.d2, self.d3, self.d4)
        return [head(features) for head in heads]

# NOTE(review): xp is fixed to numpy, so lossfun/accfun assume CPU arrays.
xp = np

# Number of digit heads in Net (d0..d4). The length head `dl` has
# MAX_DIGITS + 2 classes: 0..MAX_DIGITS digits plus one "longer" class.
MAX_DIGITS = 5


def _length_targets(t):
    """Digit-count class of each label string, clipped to MAX_DIGITS + 1 (int32)."""
    return xp.minimum(MAX_DIGITS + 1, xp.array([len(label) for label in t])).astype(xp.int32)


def _digit_targets(t, i):
    """Digit at position i of each label string, or -1 where the label is shorter (int32).

    -1 is the default ``ignore_label`` of ``F.softmax_cross_entropy``, so
    missing positions do not contribute to the loss.
    """
    return xp.array([int(label[i]) if len(label) > i else -1 for label in t]).astype(xp.int32)


def lossfun(x, t):
    """Sum of the softmax cross-entropy losses of the length head and all digit heads.

    :param x: list returned by ``Net.__call__``; x[0] scores the number of
        digits and x[1 + i] scores digit position i.
    :param t: sequence of label strings, one character per digit.
    :return: scalar loss Variable.
    """
    # Number of digits
    loss = F.softmax_cross_entropy(x[0], Variable(_length_targets(t)))
    # Value of each digit
    for i in range(MAX_DIGITS):
        loss += F.softmax_cross_entropy(x[1 + i], Variable(_digit_targets(t, i)))
    return loss


def accfun(x, t):
    """Fraction of samples whose digit count and every present digit are predicted correctly.

    :param x: list returned by ``Net.__call__`` (see ``lossfun``).
    :param t: sequence of label strings, one character per digit.
    :return: accuracy as a float in [0, 1].
    """
    # Number of digits
    correct = xp.argmax(x[0].data, axis=1) == _length_targets(t)
    # Value of each digit; positions beyond the label's length always count as correct.
    for i in range(MAX_DIGITS):
        t_di = _digit_targets(t, i)
        x_di = xp.argmax(x[1 + i].data, axis=1)
        correct = xp.logical_and(correct, xp.logical_or(x_di == t_di, t_di == -1))
    return xp.count_nonzero(correct) / len(correct)

In [181]:
# Build a network for the forward-pass smoke test in the next cell.
net = Net()

In [182]:
%%time
res = net(Variable(np.random.randn(2, 3, 30, 70).astype(np.float32)))

CPU times: user 128 ms, sys: 13.7 ms, total: 142 ms
Wall time: 81.3 ms


In [118]:
# Reference: https://stackoverflow.com/questions/41176258/h5py-access-data-in-datasets-in-svhn
import h5py
def get_box_data(index, hdf5_data):
    """
    Get the `height, label, left, top, width` of every digit box of one picture.

    :param index: position of the picture in `/digitStruct/bbox`.
    :param hdf5_data: the opened `digitStruct.mat` file (an h5py File, as passed by `load_digitStruct`).
    :return: dict mapping 'height', 'label', 'left', 'top', 'width' to a list
        with one int per digit box (keys the file does not provide stay []).
    """
    meta_data = {key: [] for key in ('height', 'label', 'left', 'top', 'width')}

    def _collect(name, obj):
        # A single-row dataset holds the value itself; a multi-row dataset
        # holds one object reference per box that has to be dereferenced.
        # Both branches convert to int so the value type does not depend on
        # how many digits the picture has.
        if obj.shape[0] == 1:
            vals = [int(obj[0][0])]
        else:
            vals = [int(hdf5_data[obj[k][0]][0][0]) for k in range(obj.shape[0])]
        meta_data[name] = vals

    box = hdf5_data['/digitStruct/bbox'][index]
    hdf5_data[box[0]].visititems(_collect)
    return meta_data

def get_name(index, hdf5_data):
    """Return the image file name of picture `index` in a `digitStruct.mat` file.

    The name is stored as a column of character codes behind an object
    reference. `dataset[()]` reads the whole dataset; it replaces
    `Dataset.value`, which was deprecated in h5py 2.9 and removed in h5py 3.0.
    """
    name_refs = hdf5_data['/digitStruct/name']
    codes = hdf5_data[name_refs[index][0]][()]
    return ''.join(chr(code[0]) for code in codes)

def load_digitStruct(mat_filename):
    """Read every image name and box record from a `digitStruct.mat` file.

    Returns a dict with 'name' (list of file names) and 'bbox' (list of the
    dicts returned by `get_box_data`), both in file order.
    """
    names, boxes = [], []
    with h5py.File(mat_filename, 'r') as mat:
        n_images = mat['/digitStruct/name'].size
        for idx in tqdm(range(n_images)):
            names.append(get_name(idx, mat))
            boxes.append(get_box_data(idx, mat))
    return {'name': names, 'bbox': boxes}

from pathlib import Path
def load_svhn(path, image_size=(70, 30)):
    """Load one SVHN split directory (e.g. `train` or `test`).

    :param path: directory containing `digitStruct.mat` and the images it names.
    :param image_size: (width, height) every image is resized to.
    :return: dict with
        "xs": float32 array (batch, color, height, width), pixel values divided by 256;
        "ys": list of label strings, one character per digit;
        "digitStruct": the raw metadata from `load_digitStruct`.
    """
    path = Path(path)
    print('Loading digitStruct.mat...', end='')
    dsm = load_digitStruct(str(path / 'digitStruct.mat'))
    print('Done.')
    print('Loading images...', end='')
    xs = []
    for name in tqdm(dsm['name']):
        # NOTE(review): `Image` is not imported in any visible cell; presumably
        # `from PIL import Image` -- add it to the imports cell.
        # The `with` block closes each file handle instead of leaking one per image;
        # convert('RGB') guarantees three channels for the transpose below.
        with Image.open(str(path / name)) as im:
            xs.append(np.array(im.convert('RGB').resize(image_size)))
    xs = np.asarray(xs, dtype=np.float32).transpose([0, 3, 1, 2]) / 256  # (batch, color, height, width)
    # SVHN labels the digit 0 as 10; map it back to '0' so that every digit is
    # exactly one character (otherwise "105" would become "1105").
    ys = [''.join(str(int(label) % 10) for label in b['label']) for b in dsm['bbox']]
    return {
        "xs": xs,
        "ys": ys,
        "digitStruct": dsm
    }

In [106]:
def concat_image(images, axis=0):
    """Tile PIL images into a single RGB image.

    axis=0 stacks the images top to bottom (left-aligned); axis=1 places them
    left to right (top-aligned); axis=None treats `images` as a list of rows,
    joins each row along axis 1 and stacks the rows along axis 0.
    """
    if axis is None:  # 2-dim tiling
        rows = [concat_image(row, axis=1) for row in images]
        return concat_image(rows, axis=0)
    assert axis in [0, 1]
    if axis == 1:
        # Horizontal concatenation = vertical concatenation of the transposed
        # images, transposed back.
        transposed = [im.transpose(Image.TRANSPOSE) for im in images]
        return concat_image(transposed, axis=0).transpose(Image.TRANSPOSE)
    # axis == 0: canvas as wide as the widest image and as tall as all of them.
    canvas_size = (np.max([im.width for im in images]), np.sum([im.height for im in images]))
    canvas = Image.new('RGB', canvas_size)
    top = 0
    for im in images:
        canvas.paste(im, (0, top))
        top += im.height
    return canvas

In [119]:
# Dataset root: set the SVHN_DIR environment variable to override the default ~/Downloads/SVHN.
svhn_root = Path(os.environ.get('SVHN_DIR', str(Path.home() / 'Downloads' / 'SVHN')))
data_train = load_svhn(svhn_root / 'train')

Loading digitStruct.mat...



Done.
Loading images...

In [173]:
# Dataset root: set the SVHN_DIR environment variable to override the default ~/Downloads/SVHN.
svhn_root = Path(os.environ.get('SVHN_DIR', str(Path.home() / 'Downloads' / 'SVHN')))
data_test = load_svhn(svhn_root / 'test')

Loading digitStruct.mat...



Done.
Loading images...

In [183]:
# Both splits must match the (channels, height, width) input asserted by Net,
# and every image needs exactly one label string.
assert data_train['xs'].shape[1:] == data_test['xs'].shape[1:] == (3, 30, 70)
assert len(data_train['xs']) == len(data_train['ys'])
assert len(data_test['xs']) == len(data_test['ys'])
data_train['xs'].shape, data_test['xs'].shape

((33402, 3, 30, 70), (13068, 3, 30, 70))

In [207]:
def get_trainer(
        model_loss : Chain,
        data_train : datasets,
        data_test : datasets =None,
        num=(5, 'epoch'),
        batch_size=100,
        earlystop_patience=None,
        restore_best_model=False,
        device=-1,
        out_directory='result',
        converter=dataset.concat_examples,
        log_trigger=(1, 'epoch')):
    """Build a chainer Trainer that optimizes `model_loss` with Adam.

    :param model_loss: link whose call returns the loss (e.g. an `L.Classifier`).
    :param data_train: training dataset (iterated with shuffling).
    :param data_test: optional validation dataset; validation metrics are only
        reported when it is given.
    :param num: stop trigger used when early stopping is off, e.g. (5, 'epoch').
    :param batch_size: minibatch size of both iterators.
    :param earlystop_patience: if not None, passed as `patients` to an
        EarlyStoppingTrigger that monitors 'validation/main/loss' (needs `data_test`).
    :param restore_best_model: if True, when the stop trigger fires, load the
        per-epoch model snapshot with the lowest logged 'validation/main/loss'.
    :param device: GPU id, or -1 for CPU.
    :param out_directory: directory for logs, snapshots and plots.
    :param converter: batch converter shared by the updater and the evaluator.
    :param log_trigger: interval of LogReport / PrintReport entries.
    :return: the configured `training.Trainer` (not yet run).
    """
    if device >= 0:
        cuda.get_device(device).use()  # Make a specified GPU current
        model_loss.to_gpu()
    opt = optimizers.Adam()
    opt.setup(model_loss)
    itr_train = iterators.SerialIterator(data_train, shuffle=True, batch_size=batch_size)
    upd = training.StandardUpdater(itr_train, opt, device=device, converter=converter)

    stop_trigger = num
    if earlystop_patience is not None:
        stop_trigger = training.triggers.EarlyStoppingTrigger(
            monitor='validation/main/loss', patients=earlystop_patience, max_trigger=stop_trigger)
    trn = training.Trainer(upd, stop_trigger, out=out_directory)

    trn.extend(extensions.LogReport(trigger=log_trigger))
    trn.extend(extensions.snapshot(filename='snapshot_epoch-{.updater.epoch}'))
    trn.extend(extensions.snapshot_object(model_loss, filename='model_epoch-{.updater.epoch}'))
    if data_test is not None:
        # Only build the test iterator when there is validation data to iterate.
        itr_test = iterators.SerialIterator(data_test, shuffle=False, repeat=False, batch_size=batch_size)
        trn.extend(extensions.Evaluator(itr_test, model_loss, device=device, converter=converter))
    trn.extend(extensions.PrintReport(['epoch', 'iteration', 'main/loss', 'main/accuracy', 'validation/main/loss', 'validation/main/accuracy', 'elapsed_time']), trigger=log_trigger)
    trn.extend(extensions.PlotReport(['main/loss', 'validation/main/loss'], x_key='epoch', file_name='loss.png'))
    trn.extend(extensions.PlotReport(['main/accuracy', 'validation/main/accuracy'], x_key='epoch', file_name='accuracy.png'))
    trn.extend(extensions.dump_graph('main/loss'))

    if restore_best_model:
        def _restore_best_model(trainer):
            # Only some log entries carry a validation loss: with an
            # iteration-level log_trigger, the Evaluator (default per-epoch
            # trigger) contributes to few entries, and without data_test to
            # none. Indexing every entry would raise KeyError.
            log = [entry for entry in trainer.get_extension('LogReport').log
                   if 'validation/main/loss' in entry]
            if not log:
                print("Restoring best model skipped (no 'validation/main/loss' in the log)")
                return
            best_epoch = log[np.argmin(np.array([entry['validation/main/loss'] for entry in log]))]['epoch']
            print("Best epoch:", best_epoch)
            path = os.path.join(out_directory, 'model_epoch-' + str(best_epoch))
            try:
                serializers.load_npz(path, trainer.updater.get_optimizer('main').target)
                print("Restoring best model done ({})".format(path))
            except FileNotFoundError:
                print("Restoring best model failed (File not found: {})".format(path))
        # stop_trigger is not designed to be called several times before new
        # log entries accumulate (calls after the first return False), so this
        # extension gets its own copy.
        _stop_trigger = copy.deepcopy(stop_trigger)
        trn.extend(_restore_best_model, trigger=_stop_trigger)
    return trn


In [202]:
# Pair every image array with its label string, for both splits.
tds_train, tds_test = (
    datasets.TupleDataset(split['xs'], split['ys'])
    for split in (data_train, data_test)
)

In [205]:
def converter(list_of_tuple, device=-1):
    """Turn a minibatch of (image, label) pairs into (stacked image array, tuple of labels).

    The labels are variable-length digit strings, so they are passed through
    untouched for `lossfun` / `accfun` to encode.
    NOTE(review): `device` is ignored, so batches always stay on the CPU
    (consistent with `xp = np` in lossfun/accfun).
    """
    images, labels = zip(*list_of_tuple)
    return np.asarray(images), labels


net = Net()
model_loss = L.Classifier(net, lossfun=lossfun, accfun=accfun)
trn = get_trainer(
    model_loss, tds_train, tds_test,
    device=-1,
    out_directory='_tmp',
    converter=converter,
    log_trigger=(1, 'iteration'),
)

In [206]:
# Run training.
# NOTE(review): the recorded output of this cell comes from a stale `trn`.
# It was built in In [205], before the current `get_trainer` (In [207]) was
# executed, and its Evaluator called `concat_examples`, which cannot stack the
# label strings ("string indices must be integers"). Restart the kernel and
# run all cells in order before trusting these results.
trn.run()

epoch       iteration   main/loss   main/accuracy  validation/main/loss  validation/main/accuracy  elapsed_time
[J0           1           14.1881     0                                                              3.03297       
[J0           2           28.273      0.02                                                           6.33769       
[J0           3           26.8443     0.02                                                           9.24161       
[J0           4           21.6807     0                                                              12.2964       
[J0           5           15.1948     0.02                                                           15.0016       
[J0           6           20.5335     0.01                                                           18.0522       
[J0           7           13.88       0.02                                                           20.6616       
[J0           8           22.0089     0                             

Exception in main training loop: string indices must be integers
Traceback (most recent call last):
  File "/Users/yoshidayuuki/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/chainer/training/trainer.py", line 318, in run
    entry.extension(self)
  File "/Users/yoshidayuuki/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/chainer/training/extensions/evaluator.py", line 157, in __call__
    result = self.evaluate()
  File "/Users/yoshidayuuki/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/chainer/training/extensions/evaluator.py", line 206, in evaluate
    in_arrays = self.converter(batch, self.device)
  File "/Users/yoshidayuuki/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/chainer/dataset/convert.py", line 134, in concat_examples
    [example[i] for example in batch], padding[i])))
  File "/Users/yoshidayuuki/.pyenv/versions/anaconda3-4.2.0/lib/python3.5/site-packages/chainer/dataset/convert.py", line 164, in _concat_arrays
    retur

TypeError: string indices must be integers