# 学習する

## Colabの環境整備

In [None]:
# Install the runtime dependencies into the Colab VM: the CUDA 9.2
# development libraries, the CuPy build for CUDA 9.2 together with Chainer,
# PyDrive (upgraded if already present) and ChainerCV.
!apt -y -q install cuda-libraries-dev-9-2
!pip install -q cupy-cuda92 chainer
!pip install -U -q PyDrive
!pip install chainercv

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
# Authenticate the Colab user first, then hand the application-default
# credentials to PyDrive so that a Google Drive client can be created.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
# NOTE(review): `drive` is not referenced in the cells visible here;
# presumably it is used to fetch the dataset or store results - confirm.
drive = GoogleDrive(gauth)

## 学習する

[chainercvのSSDのtrain.py](https://github.com/chainer/chainercv/blob/master/examples/ssd/train.py)を適宜変更したもの

In [0]:
import argparse
import copy
import numpy as np

from ss2dbbox_dataset import SS2DBboxDataset
from ss2dbbox_dataset import ss2d_bbox_label_names

import chainer
from chainer.datasets import ConcatenatedDataset
from chainer.datasets import TransformDataset
from chainer.optimizer_hooks import WeightDecay
from chainer import serializers
from chainer import training
from chainer.training import extensions
from chainer.training import triggers

from detection_voc_evaluator_zero import DetectionVOCEvaluator
from chainercv.links.model.ssd import GradientScaling


from chainercv.links.model.ssd import multibox_loss
from chainercv.links import SSD300
from chainercv.links import SSD512
from chainercv import transforms
from chainercv import utils


from chainercv.links.model.ssd import random_crop_with_bbox_constraints
from chainercv.links.model.ssd import random_distort
from chainercv.links.model.ssd import resize_with_random_interpolation

In [None]:
class MultiboxTrainChain(chainer.Chain):
    """Chain that wraps a detection model and computes its multibox loss.

    Args:
        model: Link called as ``model(imgs)`` that returns a pair
            ``(mb_locs, mb_confs)``. It is registered as a child link.
        alpha: Weight applied to the localization loss before it is added
            to the confidence loss.
        k: Value forwarded to ``multibox_loss`` as its fifth argument.
    """

    def __init__(self, model, alpha=1, k=3):
        super(MultiboxTrainChain, self).__init__()
        # Only the wrapped model is registered inside the init scope; the
        # two scalars below are plain attributes.
        with self.init_scope():
            self.model = model
        self.alpha = alpha
        self.k = k

    def __call__(self, imgs, gt_mb_locs, gt_mb_labels):
        """Run the model on ``imgs`` and return the combined loss.

        The combined loss and both of its components are also reported
        through ``chainer.reporter`` with this chain as the observer.
        """
        pred_locs, pred_confs = self.model(imgs)
        loc_loss, conf_loss = multibox_loss(
            pred_locs, pred_confs, gt_mb_locs, gt_mb_labels, self.k)

        weighted_loc_loss = loc_loss * self.alpha
        loss = weighted_loc_loss + conf_loss

        observations = {
            'loss': loss,
            'loss/loc': loc_loss,
            'loss/conf': conf_loss,
        }
        chainer.reporter.report(observations, self)

        return loss


class Transform(object):
    """Augment one training sample and encode its boxes for the SSD model.

    Args:
        coder: Object providing ``to_cpu()`` and ``encode(bbox, label)``.
            A shallow copy is stored and moved to the CPU, so the caller's
            instance is not the one this transform calls ``to_cpu()`` on.
        size: Side length of the square output image.
        mean: Value used as the fill for random expansion and subtracted
            from the image at the end.
    """

    def __init__(self, coder, size, mean):
        # Copy first: to_cpu() is then called on the copy rather than on the
        # coder object that was passed in.
        cpu_coder = copy.copy(coder)
        cpu_coder.to_cpu()
        self.coder = cpu_coder

        self.size = size
        self.mean = mean

    def __call__(self, in_data):
        """Apply the augmentation pipeline to a single sample.

        Args:
            in_data: 4-tuple ``(img, bbox, label, extra)``. The fourth
                element is unpacked but not used.

        Returns:
            Tuple ``(img, mb_loc, mb_label)`` where the last two come from
            ``self.coder.encode``.

        Every bounding-box operation is skipped while ``bbox`` is empty.
        The check is repeated before each step because cropping may remove
        boxes.
        """
        img, bbox, label, _ = in_data
        out_size = (self.size, self.size)

        # Step 1: color augmentation.
        img = random_distort(img)

        # Step 2: random expansion, applied when randint(2) is non-zero.
        if np.random.randint(2):
            img, expand = transforms.random_expand(
                img, fill=self.mean, return_param=True)
            if len(bbox) > 0:
                bbox = transforms.translate_bbox(
                    bbox,
                    y_offset=expand['y_offset'],
                    x_offset=expand['x_offset'])
            else:
                bbox = bbox.copy()

        # Step 3: random cropping; labels are filtered to the kept boxes.
        img, crop = random_crop_with_bbox_constraints(
            img, bbox, return_param=True)
        if len(bbox) > 0:
            bbox, kept = transforms.crop_bbox(
                bbox,
                y_slice=crop['y_slice'],
                x_slice=crop['x_slice'],
                allow_outside_center=False,
                return_param=True)
            label = label[kept['index']]

        # Step 4: resize to the output size with a random interpolation.
        _, src_h, src_w = img.shape
        img = resize_with_random_interpolation(img, out_size)
        if len(bbox) > 0:
            bbox = transforms.resize_bbox(bbox, (src_h, src_w), out_size)

        # Step 5: random horizontal flip.
        img, flip = transforms.random_flip(
            img, x_random=True, return_param=True)
        if len(bbox) > 0:
            bbox = transforms.flip_bbox(
                bbox, out_size, x_flip=flip['x_flip'])

        # Final preparation for the SSD network: subtract the mean in place
        # and encode the boxes and labels.
        img -= self.mean
        mb_loc, mb_label = self.coder.encode(bbox, label)

        return img, mb_loc, mb_label

In [None]:
# Run configuration.
# `gpu` is used both as the CUDA device id and as the updater's device;
# `out` is the trainer's output directory; `batchsize` is shared by the
# training and test iterators.
# NOTE(review): `resume` is passed to serializers.load_npz() as the file to
# load whenever it is truthy, so to resume it should be set to a snapshot
# path rather than True - confirm.
gpu=0
out='result'
resume=False
batchsize=16

# Model: SSD300 with one foreground class, created with
# pretrained_model='imagenet'.
model = SSD300(
    n_fg_class=1,
    pretrained_model='imagenet')

model.use_preset('evaluate')
train_chain = MultiboxTrainChain(model)
# Select the GPU and move the model onto it before the optimizer is set up.
chainer.cuda.get_device_from_id(gpu).use()
model.to_gpu()

# Datasets
# Training samples go through Transform (augmentation + encoding); the test
# split is used as-is and iterated once, in order, per evaluation.
train = TransformDataset(
        SS2DBboxDataset(data_dir="dataset", split='train',use_difficult=True, return_difficult=True),
    Transform(model.coder, model.insize, model.mean))
train_iter = chainer.iterators.MultiprocessIterator(train, batchsize)

test = SS2DBboxDataset(data_dir="dataset", split='test',
    use_difficult=True, return_difficult=True)
test_iter = chainer.iterators.SerialIterator(
    test, batchsize, repeat=False, shuffle=False)

# Optimizer
# Parameters named 'b' get a GradientScaling(2) hook; every other parameter
# gets WeightDecay(0.0005).
optimizer = chainer.optimizers.MomentumSGD()
optimizer.setup(train_chain)
for param in train_chain.params():
    if param.name == 'b':
        param.update_rule.add_hook(GradientScaling(2))
    else:
        param.update_rule.add_hook(WeightDecay(0.0005))

# Updater
updater = training.updaters.StandardUpdater(
    train_iter, optimizer, device=gpu)

# Trainer setup: stop after 2000 iterations and write results to `out`.
# NOTE(review): the learning-rate shift below is scheduled at iterations
# 3000 and 5000, which lie beyond the 2000-iteration stop trigger, so as
# written those schedule points do not appear to be reached - confirm
# whether the schedule or the training length should change.
trainer = training.Trainer(updater, (2000, 'iteration'), out)
trainer.extend(
    extensions.ExponentialShift('lr', 0.1, init=1e-3),
    trigger=triggers.ManualScheduleTrigger([3000, 5000], 'iteration'))

# Evaluate on the test iterator every 100 iterations.
trainer.extend(
    DetectionVOCEvaluator(
        test_iter, model, use_07_metric=True,
        label_names=ss2d_bbox_label_names),
    trigger=(100, 'iteration'))

# Logging, learning-rate observation and the printed report all share the
# same 10-iteration interval.
log_interval = 10, 'iteration'
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(extensions.observe_lr(), trigger=log_interval)
trainer.extend(extensions.PrintReport(
    ['epoch', 'iteration', 'lr',
     'main/loss', 'main/loss/loc', 'main/loss/conf',
     'validation/main/map']),
    trigger=log_interval)
trainer.extend(extensions.ProgressBar(update_interval=10))

# Snapshot the whole trainer every 100 iterations.
# NOTE(review): the model-only snapshot is triggered every 8000 iterations,
# which is also beyond the 2000-iteration stop trigger - confirm intent.
trainer.extend(extensions.snapshot(), trigger=(100, 'iteration'))
trainer.extend(
    extensions.snapshot_object(model, 'model_iter_{.updater.iteration}'),
    trigger=(8000, 'iteration'))

# Restore trainer state from the file given by `resume`, if one was set.
if resume:
    serializers.load_npz(resume, trainer)

# Run training, then save the model to model.npz.
trainer.run()
serializers.save_npz("model.npz",model)

## 推論の実行

demo.ipynb