# In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function

import argparse
import itertools
import multiprocessing
import os.path
import random

import chainer
import cv2
import numpy as np
import six
from chainer import iterators
from chainer import training
from chainer.iterators import SerialIterator
from chainer.optimizer import WeightDecay
from chainer.training import extensions, updaters

from modules.args_utils import BracesAction, ByteType, ColorType, DimensionType, PercentType, UnifiedType
from modules.args_utils import DictParser, ObjectParser
from modules.common import AUGMENTERS, OPTIMIZERS, setup_model
from modules.image_dataset import CropExample, PatchAmplifier, RandomPatchAmplifier
from modules.image_dataset import LabeledImageDataset
from modules.image_dataset import MapColorOfExample, ResizeExample, SubtractFromExample, TransposeExample
from modules.io_utils import print, progress_message


def create_args():
    """Build the command-line parser, parse ``sys.argv`` and validate the result.

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        Exception: if ``--class_weight`` (when given) or ``--class_labels`` does not
            have exactly ``--nlabels`` entries, or if ``--mapping`` maps a color to a
            label that is not less than ``--nlabels``.
    """
    n_cpus = multiprocessing.cpu_count()

    parser = argparse.ArgumentParser()

    # train
    parser.add_argument('--gpu', '-g', type=int, nargs='+',
                        help='GPU IDs (negative value indicates CPU)')
    # `-p` with no value falls back to the number of CPU cores (const=n_cpus).
    parser.add_argument('--multiprocess', '-p', type=int, choices=six.moves.range(1, 256 + 1), nargs='?', const=n_cpus,
                        help='Use multiprocess updater', metavar='1-256')
    parser.add_argument('--epoch', type=int, required=True)
    parser.add_argument('--batch_size', '-b', type=int, default=128,
                        help='Number of images in each mini-batch')
    parser.add_argument('--seed', type=int,
                        help='Seed to do deterministic training. Keep in mind that this option decreases performance')
    # `--cache` alone yields True (cache everything); `--cache SIZE` is parsed by
    # ByteType and split between train/valid in setup_dataset.
    parser.add_argument('--cache', type=ByteType(), nargs='?', const=True)

    # model
    parser.add_argument('--model', required=True,
                        choices=['mnih_cnn_multi', 'fcrn_v1', 'fcrn_v2', 'fcrn_v3', 'fcrn_v4', 'fcrn_v5', 'fcrn_v6'])
    parser.add_argument('--nlabels', type=int, default=3)
    parser.add_argument('--class_weight', type=float, nargs='*', metavar='float',
                        help='Class weights')
    parser.add_argument('--mapping', type=DictParser(ColorType(tuple), int),
                        default={(0x00, 0x00, 0x80): 0, (0x00, 0x80, 0x00): 1, (0x80, 0x00, 0x00): 2},
                        metavar='COLOR=LABEL,COLOR=LABEL,...')
    parser.add_argument('--mapping_range', type=int, nargs=2, default=(0, 0))

    # optimizers
    parser.add_argument('--opt', type=ObjectParser(OPTIMIZERS), default=chainer.optimizers.MomentumSGD(),
                        metavar='TYPE:PARAMS',
                        help='type:value or type:name=value,name=value,... ' +
                             'Available optimizers: ' + ', '.join(six.iterkeys(OPTIMIZERS)))
    parser.add_argument('--weight_decay', type=float, default=0.0005)

    # dataset
    parser.add_argument('--train', nargs='+', required=True,
                        help='Csv files written image paths to train')
    parser.add_argument('--valid', nargs='+', required=True,
                        help='Csv files written image paths to validate')
    parser.add_argument('--image_size', type=UnifiedType(DimensionType(), PercentType(min=1)),
                        help='Size or scale to resize images')
    # `--mean` alone yields True (compute the mean from the data); `--mean PATH` loads it.
    parser.add_argument('--mean', '-m', nargs='?', const=True,
                        help='Mean file (computed by compute_mean.py)')

    # augmentation
    # nargs='+' with action='append' yields a list of augmenter chains (one per --augment).
    parser.add_argument('--augment', type=ObjectParser(AUGMENTERS), nargs='+', action='append',
                        metavar='TYPE:PARAMS',
                        help='type:value or type:name=value,name=value,... ' +
                             'Available augmenters: ' + ', '.join(six.iterkeys(AUGMENTERS)))

    # patch
    parser.add_argument('--patch_size', type=int, nargs=2, default=None)
    parser.add_argument('--patch_density', type=PercentType(min=1), default=1.)
    parser.add_argument('--patch_ignore', type=ColorType(tuple))

    # output
    parser.add_argument('--out', '-o', default='out',
                        help='Directory to output the training result')
    parser.add_argument('--snapshot', type=int, default=1,
                        help='Interval of saving models and snapshots')
    parser.add_argument('--save_snapshot', action='store_true',
                        help='Save snapshots for resume')
    parser.add_argument('--resume', '-r',
                        help='Snapshot to resume from')

    # log
    parser.add_argument('--class_labels', nargs='+', default=['blue', 'green', 'red'])
    parser.add_argument('--display_entries', nargs='*', default=['accuracy'], action=BracesAction)

    args = parser.parse_args()

    # Cross-option consistency checks that argparse cannot express on its own.
    if args.class_weight and args.nlabels != len(args.class_weight):
        raise Exception('nlabels and the number of class_weight is different')
    if args.nlabels != len(args.class_labels):
        raise Exception('nlabels and the number of class_labels is different')
    if any(label >= args.nlabels for label in six.itervalues(args.mapping)):
        raise Exception('Every mapped label must be less than nlabels')

    return args


def setup_optimizer(args, model):
    """Bind the command-line optimizer to ``model`` and attach weight decay.

    Args:
        args: parsed options; ``args.opt`` is the optimizer instance and
            ``args.weight_decay`` the decay rate passed to ``WeightDecay``.
        model: the link whose parameters the optimizer will update.

    Returns:
        The same optimizer object held in ``args.opt``, now set up.
    """
    chosen = args.opt
    chosen.setup(model)
    decay_hook = WeightDecay(args.weight_decay)
    chosen.add_hook(decay_hook)
    return chosen


def setup_dataset(args):
    """Build the training and validation datasets from the parsed options.

    Steps, in order: optional resize, optional caching, optional mean
    subtraction, optional augmentation (training set only), optional patch
    extraction, then transpose and color-to-label mapping of every example.

    Args:
        args: namespace produced by ``create_args``.

    Returns:
        tuple: ``(train, valid)`` transformed datasets.

    Raises:
        Exception: if a mean must be computed from the data but both datasets
            are empty.
    """
    # Mapping colors are reversed (RGB -> BGR), presumably to match how the
    # label images are loaded (cv2 order) — TODO confirm in LabeledImageDataset.
    mapping = {color[::-1]: label for color, label in six.iteritems(args.mapping)}
    mapper = MapColorOfExample(mapping, args.mapping_range)

    train = LabeledImageDataset(args.train)
    valid = LabeledImageDataset(args.valid)

    # Resize images
    resize = None
    if args.image_size:
        # A list is treated as an explicit size and a float as a scale factor
        # (presumably produced by DimensionType / PercentType respectively).
        size = args.image_size if isinstance(args.image_size, list) else None
        scale = args.image_size if isinstance(args.image_size, float) else None
        resize = ResizeExample(size, scale)
        train = train.transform(resize)
        valid = valid.transform(resize)

    if args.cache:
        # Cache resized images
        if args.cache is True:
            train = train.cache()
            valid = valid.cache()
        else:
            # Split the cache budget between the datasets in proportion to their
            # sizes. With no examples at all there is nothing to cache, and the
            # emptiness is reported by the caller instead of a ZeroDivisionError.
            total = len(train) + len(valid)
            if total:
                train = train.cache(len(train) * args.cache // total)
                valid = valid.cache(len(valid) * args.cache // total)

    # Compute and save mean
    if args.mean:
        if isinstance(args.mean, six.string_types):
            # `--mean PATH`: load a precomputed mean and apply the same resize
            # as the images so the shapes match.
            mean = np.load(args.mean)
            if resize is not None:
                mean = resize.input(mean)
        else:
            # `--mean` without a value: average every (resized) input image of
            # both datasets, then save it next to the training output.
            # NOTE(review): validation images contribute to this mean.
            n = len(train) + len(valid)
            if n == 0:
                raise Exception('No images to compute the mean from')
            mean = 0
            for image in itertools.chain(train.inputs, valid.inputs):
                mean += image.astype(np.float32)
            mean /= n

            if not os.path.exists(args.out):
                os.makedirs(args.out)
            np.save(os.path.join(args.out, 'mean.npy'), mean)
            cv2.imwrite(os.path.join(args.out, 'mean.png'), mean.astype(np.uint8))

        train = train.transform(SubtractFromExample(mean))
        valid = valid.transform(SubtractFromExample(mean))

    if args.augment:
        # Augment images: each --augment group applies its augmenters one after
        # another to the un-augmented training set, and the result is added
        # to the training set via `chain`.
        orig = train
        for chain in args.augment:
            train = train.chain(six.moves.reduce(lambda target, augmenter: target.amplify(augmenter), chain, orig))

    # Make a patchwork
    if args.patch_size:
        input_size, label_size = args.patch_size

        # Margin of half the size difference; presumably aligns the smaller
        # label patch with the centre of the input patch — TODO confirm CropExample.
        cropper = CropExample(None, (input_size - label_size) // 2)
        kwargs = {
            'image_ksize': (input_size, input_size),
            'label_ksize': (label_size, label_size)
        }

        # Training patches are sampled randomly; validation patches deterministically.
        train = train.transform(cropper).amplify(RandomPatchAmplifier(density=args.patch_density, **kwargs))
        valid = valid.transform(cropper).amplify(PatchAmplifier(**kwargs))

        if args.patch_ignore:
            # Same RGB -> BGR reversal as the label mapping above.
            bg = args.patch_ignore[::-1]

            def is_fg(example):
                # Keep a patch if any label value differs from the background color.
                image, label = example
                return np.any(label != bg)

            train = train.filter(is_fg)
            valid = valid.filter(is_fg)

    # Transpose images, then turn label colors into class indices.
    train = train.transform(TransposeExample()).transform(mapper)
    valid = valid.transform(TransposeExample()).transform(mapper)

    return train, valid


def main(args):
    """Set up data, model, optimizer and trainer from ``args`` and run training.

    Args:
        args: namespace produced by ``create_args``.

    Raises:
        Exception: if either dataset is empty, or if two or more GPU IDs are
            given and they are not unique.
    """
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        if chainer.cuda.available:
            # CuPy mirrors NumPy: the seeding function is `cupy.random.seed`
            # (there is no top-level `cupy.seed`).
            chainer.cuda.cupy.random.seed(args.seed)

        # NOTE(review): `cudnn_deterministic` is a Chainer config key; verify that
        # `functions_deterministic` is actually read by anything.
        chainer.global_config.functions_deterministic = True
        chainer.global_config.cudnn_deterministic = True

    # The first --gpu ID is the main device; -1 selects the CPU.
    main_device = args.gpu[0] if args.gpu else -1

    with progress_message('Loading dataset...'):
        train, valid = setup_dataset(args)

    if not train:
        raise Exception('No training data')
    if not valid:
        raise Exception('No validation data')

    with progress_message('Setting up model...'):
        model = setup_model(args.model, main_device, args.nlabels, args.multiprocess, args.class_weight,
                            args.class_labels)
    optimizer = setup_optimizer(args, model)

    n_processes = args.multiprocess if args.multiprocess else 1

    # Single-device defaults: shuffled repeating training iterator and a
    # single-pass, unshuffled validation iterator.
    train_iter = SerialIterator(train, args.batch_size)
    valid_iter = SerialIterator(valid, args.batch_size, False, False)

    if args.gpu and len(args.gpu) >= 2:
        if len(args.gpu) != len(set(args.gpu)):
            raise Exception('GPU specification must be unique')

        # First GPU is 'main'; the remaining ones are keyed 1, 2, ...
        devices = {'main': main_device}
        devices.update(enumerate(args.gpu[1:], 1))

        if n_processes >= 2:
            # One training iterator per device over a random split of the data.
            train_iter = []
            for i in chainer.datasets.split_dataset_n_random(train, len(devices)):
                train_iter.append(iterators.MultiprocessIterator(i, args.batch_size, n_processes=n_processes))
            valid_iter = iterators.MultiprocessIterator(valid, args.batch_size, False, False, n_processes=n_processes)

            updater = updaters.MultiprocessParallelUpdater(train_iter, optimizer, devices=devices)
        else:
            updater = training.ParallelUpdater(train_iter, optimizer, devices=devices)
    else:
        updater = training.StandardUpdater(train_iter, optimizer, device=main_device)

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Registered once; a second identical dump_graph only rewrote the same graph file.
    trainer.extend(extensions.dump_graph(root_name='main/loss'))

    trainer.extend(extensions.Evaluator(valid_iter, model, device=main_device))

    if args.save_snapshot:
        trainer.extend(extensions.snapshot(), trigger=(args.snapshot, 'epoch'))
    trainer.extend(extensions.snapshot_object(model, 'model_epoch_{.updater.epoch}'), trigger=(args.snapshot, 'epoch'))

    trainer.extend(extensions.observe_lr())

    trainer.extend(extensions.LogReport())

    entries = [
        'epoch', 'main/loss', 'validation/main/loss',
    ]

    # Each display entry is reported for both training and validation.
    for entry in args.display_entries:
        entries += [
            'main/{}'.format(entry),
            'validation/main/{}'.format(entry)
        ]

    entries += ['lr', 'elapsed_time']

    trainer.extend(extensions.PrintReport(entries))

    trainer.extend(extensions.ProgressBar(update_interval=2))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    trainer.run()


if __name__ == '__main__':
    # Script entry point: parse and validate the command line, print the
    # resulting options, then run training with them.
    args = create_args()
    print(args)
    main(args)