# In [None]:
import os
import numpy as np
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import serializers
from chainer import training
from chainer.training import extensions


class MyChain(chainer.Chain):
    """Fully connected network with two ReLU hidden layers.

    The layer sizes are configurable; the defaults reproduce the original
    fixed 1 -> 256 -> 256 -> 1 architecture, so ``MyChain()`` builds the
    same network as before.

    Args:
        n_in (int): Size of each input vector.
        n_hidden (int): Number of units in each of the two hidden layers.
        n_out (int): Size of each output vector.
    """

    def __init__(self, n_in=1, n_hidden=256, n_out=1):
        super(MyChain, self).__init__()
        # Child links are assigned inside init_scope() so that the chain
        # registers them (and their parameters).
        with self.init_scope():
            self.l1 = L.Linear(n_in, n_hidden)
            self.l2 = L.Linear(n_hidden, n_hidden)
            self.l3 = L.Linear(n_hidden, n_out)

    def forward(self, x):
        """Run ``x`` through the three linear layers.

        ReLU is applied after the first and second layers; the output of
        the third layer is returned without an activation.

        Args:
            x: Input batch handed to the first linear layer.

        Returns:
            The output of the final linear layer.
        """
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


# Training hyperparameters and the directory used for the saved model.
batchsize = 4
epoch = 500
out_dir = 'regression_model'

# Prepare the data: nine (x, y) samples each for training and validation,
# stored as float32 column vectors of shape (9, 1). The listed targets
# equal x**2 + 3*x + 1 for the listed inputs.
trainx = np.array(
    [0.00, 0.25, 0.50, 0.75, 1.00, 1.25, 1.50, 1.75, 2.00],
    dtype=np.float32).reshape(-1, 1)
trainy = np.array(
    [1.0, 1.8125, 2.75, 3.8125, 5.0, 6.3125, 7.75, 9.3125, 11.0],
    dtype=np.float32).reshape(-1, 1)
valx = np.array(
    [0.05, 0.3, 0.4, 0.6, 0.8, 0.9, 1.1, 1.4, 1.9],
    dtype=np.float32).reshape(-1, 1)
valy = np.array(
    [1.1525, 1.99, 2.36, 3.16, 4.04, 4.51, 5.51, 7.16, 10.31],
    dtype=np.float32).reshape(-1, 1)
train = chainer.datasets.TupleDataset(trainx, trainy)
val = chainer.datasets.TupleDataset(valx, valy)

# Set up the iterators. The validation iterator makes a single,
# unshuffled pass over its dataset.
train_iter = chainer.iterators.SerialIterator(train, batch_size=batchsize)
val_iter = chainer.iterators.SerialIterator(
    val, batch_size=batchsize, repeat=False, shuffle=False)

# Build the network to train and choose the objective. Classifier is used
# as a wrapper here: the loss is mean squared error, and the value it
# reports as "accuracy" is the mean absolute error.
net = L.Classifier(
    MyChain(),
    lossfun=F.mean_squared_error,
    accfun=F.mean_absolute_error)

# Set up the optimizer.
optimizer = chainer.optimizers.Adam()
optimizer.setup(net)

# Set up the updater.
updater = training.updaters.StandardUpdater(train_iter, optimizer)

# Set up the trainer; its output is written under 'result'.
trainer = training.Trainer(
    updater, stop_trigger=(epoch, 'epoch'), out='result')
trainer.extend(extensions.LogReport())
trainer.extend(extensions.Evaluator(val_iter, net))
trainer.extend(extensions.PrintReport([
    'epoch',
    'main/loss',
    'main/accuracy',
    'validation/main/loss',
    'validation/main/accuracy',
    'elapsed_time',
]))

# Run the training loop.
trainer.run()

# Save the trained network: only the wrapped predictor (net.predictor) is
# serialized, in NPZ format, to <out_dir>/mychain.model.
# exist_ok=True creates the directory when it is missing and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(out_dir, exist_ok=True)
print('save the model')
serializers.save_npz(os.path.join(out_dir, 'mychain.model'), net.predictor)
