In [30]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

In [31]:
from chainer.datasets import LabeledImageDataset
from chainercv.transforms import resize
from chainer.datasets import TransformDataset

In [32]:
def transform(in_data):
    img, label = in_data
    img = resize(img, (224, 224))
    return img, label

In [33]:
train = LabeledImageDataset('data/train/train_label.txt', 'data/train/images/')
train = TransformDataset(train, transform)

valid = LabeledImageDataset('data/valid/valid_label.txt', 'data/valid/images/')
valid = TransformDataset(valid, transform)



In [34]:
import chainer
import chainer.links as L
import chainer.functions as F

In [35]:
class CNN(chainer.Chain):
    def __init__(self, n_mid_units1=224, n_out=2):
        super().__init__()
        with self.init_scope():
            self.conv = L.Convolution2D(None, 3, ksize=3, pad=1)
            self.fc1 = L.Linear(None, n_mid_units1)
            self.fc2 = L.Linear(None, n_out)
            
    def __call__(self, x):
        h = self.conv(x)
        h = self.fc1(h)
        h = F.relu(h)
        h = self.fc2(h)
        
        return h

In [36]:
nn = CNN()
model = L.Classifier(nn)

In [37]:
from chainer import training 
from chainer.training import extensions

chainerで選択できるモデル一覧
https://docs.chainer.org/en/stable/reference/optimizers.html

In [38]:
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

<chainer.optimizers.adam.Adam at 0x1289875d0>

In [39]:
batchsize = 5
train_iter = chainer.iterators.SerialIterator(train, batchsize)
valid_iter = chainer.iterators.SerialIterator(valid, batchsize, repeat=False,shuffle=False)
updater = training.StandardUpdater(train_iter, optimizer)


In [40]:
epoch = 5

trainer = training.Trainer(updater, (epoch, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(valid_iter, model, device=-1))

trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
trainer.extend(extensions.PrintReport(['epoch', 'main/accuracy', 'validation/main/accuracy','main/loss','validation/main/loss','elapsed_time']), trigger=(1,'epoch'))


In [41]:
trainer.run()

epoch       main/accuracy  validation/main/accuracy  main/loss   validation/main/loss  elapsed_time
[J1           0.44           0.47                      8710.35     383.876               21.0195       
[J2           0.54           0.5                       1026.77     611.556               40.6399       
[J3           0.59           0.51                      368.821     661.006               60.3046       
[J4           0.8            0.5                       172.703     292.885               79.5584       
[J5           0.81           0.68                      54.5336     113.14                97.6716       
