In [1]:
%matplotlib inline

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

In [3]:
from chainer.datasets import LabeledImageDataset
from chainercv.transforms import resize
from chainer.datasets import TransformDataset

In [4]:
def transform(in_data):
    """Resize one ``(image, label)`` example to 224x224; the label passes through unchanged.

    Args:
        in_data: an ``(img, label)`` pair as produced by ``LabeledImageDataset``
            (image in CHW layout, as chainercv's ``resize`` expects).

    Returns:
        tuple: ``(resized_img, label)``.
    """
    image, target = in_data
    return resize(image, (224, 224)), target

In [5]:
# Build the train/valid datasets from their label files. TransformDataset applies
# `transform` lazily, so each example is resized when it is accessed.
train = TransformDataset(
    LabeledImageDataset('data/train/train_labels.txt', 'data/train/images'),
    transform,
)
valid = TransformDataset(
    LabeledImageDataset('data/valid/valid_labels.txt', 'data/valid/images'),
    transform,
)

In [6]:
import chainer
import chainer.links as L
import chainer.functions as F

In [7]:
class CNN(chainer.Chain):
    """Small CNN classifier: one 3x3 convolution followed by a two-layer MLP head.

    Args:
        n_mid_units1: width of the hidden fully connected layer.
        n_out: number of output classes (size of the returned logits).

    NOTE(review): there is no pooling, so for the 224x224 inputs produced by
    ``transform`` fc1 receives 3*224*224 = 150528 features (~33.7M weights).
    Consider adding pooling or a stride to shrink it.
    """

    def __init__(self, n_mid_units1=224, n_out=2):
        super().__init__()
        with self.init_scope():
            # in_channels=None: Chainer infers the channel count on the first forward pass.
            # ksize=3 with pad=1 (stride 1) keeps the spatial size unchanged.
            self.conv = L.Convolution2D(None, 3, ksize=3, pad=1)
            # in_size=None: the flattened conv output size is inferred lazily.
            self.fc1 = L.Linear(None, n_mid_units1)
            self.fc2 = L.Linear(None, n_out)

    def __call__(self, x):
        """Return unnormalized class scores for an image batch ``x`` of shape (N, C, H, W)."""
        # ReLU after the convolution: without a nonlinearity, conv -> fc1 composes
        # into a single linear map and the conv layer adds no expressive power.
        h = F.relu(self.conv(x))
        # L.Linear flattens all non-batch axes of h before the matrix product.
        h = F.relu(self.fc1(h))
        return self.fc2(h)

In [8]:
# Seed NumPy's global RNG before building the model. On CPU, Chainer's parameter
# initializers and the default SerialIterator shuffling both draw from it, so this
# makes weight init and batch order reproducible on a fresh kernel.
SEED = 0
np.random.seed(SEED)

nn = CNN()
# Classifier wraps the network with softmax cross-entropy loss and reports accuracy.
model = L.Classifier(nn)

In [9]:
from chainer import training
from chainer.training import extensions

In [10]:
# Adam with Chainer's default hyperparameters, bound to every parameter of `model`.
optimizer = chainer.optimizers.Adam()
# setup() returns the optimizer itself; the trailing `;` keeps that repr out of the cell output.
optimizer.setup(model);

In [11]:
batchsize = 5
# Training iterator: repeats indefinitely and reshuffles each epoch; the Trainer's
# stop trigger decides when training ends.
train_iter = chainer.iterators.SerialIterator(train, batchsize)
# Validation iterator MUST be single-pass (repeat=False). The Evaluator extension
# loops `for batch in iterator`, so a repeating iterator would never finish and
# training would hang at the first evaluation. Shuffling adds nothing for evaluation.
valid_iter = chainer.iterators.SerialIterator(
    valid, batchsize, repeat=False, shuffle=False)

In [12]:
# One iteration = one minibatch from train_iter plus one optimizer step. With no
# loss_func given, the loss is computed by the optimizer's target (the Classifier).
updater = training.StandardUpdater(iterator=train_iter, optimizer=optimizer)

In [13]:
epoch = 5

# Train for `epoch` epochs; extension output (e.g. LogReport's log) goes to ./result.
trainer = training.Trainer(updater, stop_trigger=(epoch, 'epoch'), out='result')

# Evaluate `model` on the validation iterator on CPU (device=-1). Its metrics are
# reported under the 'validation/' prefix.
trainer.extend(extensions.Evaluator(iterator=valid_iter, target=model, device=-1))

# Aggregate reported values once per epoch and print the selected columns.
trainer.extend(extensions.LogReport(trigger=(1, 'epoch')))
trainer.extend(
    extensions.PrintReport([
        'epoch',
        'main/accuracy', 'validation/main/accuracy',
        'main/loss', 'validation/main/loss',
        'elapsed_time',
    ]),
    trigger=(1, 'epoch'),
)



In [None]:
# Start training. Runs until the Trainer's stop trigger fires, invoking the
# registered extensions (evaluation, log aggregation, per-epoch printing) along the way.
trainer.run()