In [24]:
import pandas as pd
train = pd.read_csv('data/train.csv')
test = pd.read_csv('data/test.csv')

In [25]:
def pre_processing(data):
    """
    Nanデータの補正と数値データに直す
    各データの中央値で埋める
    """
    data.Age = data.Age.fillna(data.Age.median())
    data.Fare = data.Fare.fillna(data.Fare.median())
    data.Sex = data.Sex.replace(['male', 'female'], [0, 1])
    return data
train = pre_processing(train)
test = pre_processing(test)

In [26]:
predictors = ['Pclass', 'Sex', 'Fare', 'Age', 'SibSp', 'Parch']
train_x = train[predictors]
train_t = train['Survived']
test_x = test[predictors]

In [27]:
import numpy as np
train_xp = np.array(train_x, dtype=np.float32)
train_tp = np.array(train_t, dtype=np.int32)
test_xp = np.array(train_x, dtype=np.float32)

In [28]:
from sklearn.model_selection import train_test_split
train_xp, valid_xp, train_tp, valid_tp = train_test_split(train_xp, train_tp, test_size=0.1, random_state=0)

In [128]:
import chainer
from chainer.datasets import tuple_dataset

train_set = tuple_dataset.TupleDataset(train_xp, train_tp)
valid_set = tuple_dataset.TupleDataset(valid_xp, valid_tp)

train_iter = chainer.iterators.SerialIterator(train_set, 32)
valid_iter = chainer.iterators.SerialIterator(valid_set, 32, repeat=False, shuffle=False)

In [129]:
import chainer.functions as F
import chainer.links as L
from chainer import initializers

class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        initializer = initializers.HeNormal()
        with self.init_scope():
            self.l1 = L.Linear(None, n_units, initialW=initializer)
            self.l2 = L.Linear(None, n_units, initialW=initializer)
            self.l3 = L.Linear(None, n_out, initialW=initializer)
            self.norm1 = L.BatchNormalization(n_units)
            self.norm2 = L.BatchNormalization(n_units)

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h1 = self.norm1(h1)
        h2 = F.relu(self.l2(h1))
        h2 = self.norm1(h2)
        return self.l3(h2)

In [130]:
from chainer import training
from chainer.training import extensions

model = L.Classifier(MLP(train_xp.shape[1], 2))

optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

In [131]:
updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (200, 'epoch'))

trainer.extend(extensions.Evaluator(valid_iter, model, device=-1))
trainer.extend(extensions.dump_graph('main/loss'))

trainer.extend(extensions.snapshot())

trainer.extend(extensions.LogReport())

trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss',
                                       'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

trainer.run()

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
[J1           1.0966      1.07088               0.533654       0.600962                  0.13679       
[J2           0.920387    1.03998               0.55625        0.554487                  0.66273       
[J3           0.848045    1.00489               0.565          0.544071                  1.20434       
[J4           0.808606    0.973255              0.575          0.580128                  1.69391       
[J5           0.76311     0.954197              0.60125        0.541667                  2.28615       
[J6           0.743895    0.946807              0.605          0.520833                  2.88887       
[J7           0.702788    0.961578              0.64125        0.508013                  3.45994       
[J8           0.694538    0.963623              0.63875        0.508013                  4.05633       
[J9           0.670813    0.946602              0.6575     

[J79          0.463831    0.650538              0.7925         0.678686                  45.1923       
[J80          0.450396    0.696916              0.7925         0.665865                  45.8831       
[J81          0.463868    0.704894              0.79           0.676282                  46.4448       
[J82          0.459147    0.716604              0.78125        0.663462                  47.0752       
[J83          0.463837    0.635584              0.7775         0.676282                  47.6456       
[J84          0.454295    0.679916              0.7925         0.686699                  48.2417       
[J85          0.459846    0.674623              0.7875         0.676282                  48.8207       
[J86          0.453611    0.666271              0.79125        0.665865                  49.4409       
[J87          0.461208    0.673966              0.79625        0.676282                  50.063        
[J88          0.453988    0.671198              0.79  

[J158         0.439816    0.556231              0.80625        0.655449                  92.8423       
[J159         0.443238    0.550246              0.80875        0.668269                  93.6467       
[J160         0.450216    0.594395              0.79625        0.653045                  94.2604       
[J161         0.429298    0.552535              0.802885       0.655449                  95.0663       
[J162         0.458461    0.566785              0.80125        0.665865                  95.6849       
[J163         0.431251    0.549495              0.8125         0.668269                  96.2595       
[J164         0.452591    0.533437              0.79625        0.678686                  96.8363       
[J165         0.445891    0.543957              0.80625        0.668269                  97.485        
[J166         0.440614    0.563096              0.80875        0.655449                  98.0368       
[J167         0.444036    0.527619              0.8125

In [125]:
import json
with open('result/log', 'r') as f:
    log = json.load(f)

In [126]:
log

[{'elapsed_time': 0.16696300001058262,
  'epoch': 1,
  'iteration': 26,
  'main/accuracy': 0.33774038461538464,
  'main/loss': 1.0697598090538611,
  'validation/main/accuracy': 0.435096154610316,
  'validation/main/loss': 2.3571544885635376},
 {'elapsed_time': 0.3288320000137901,
  'epoch': 2,
  'iteration': 51,
  'main/accuracy': 0.37125,
  'main/loss': 0.9427319359779358,
  'validation/main/accuracy': 0.435096154610316,
  'validation/main/loss': 2.300034443537394},
 {'elapsed_time': 0.48643400000582915,
  'epoch': 3,
  'iteration': 76,
  'main/accuracy': 0.41125,
  'main/loss': 0.8416697955131531,
  'validation/main/accuracy': 0.435096154610316,
  'validation/main/loss': 2.1095449129740396},
 {'elapsed_time': 0.6519810000027064,
  'epoch': 4,
  'iteration': 101,
  'main/accuracy': 0.4675,
  'main/loss': 0.7756927108764649,
  'validation/main/accuracy': 0.435096154610316,
  'validation/main/loss': 2.1115849018096924},
 {'elapsed_time': 0.7973090000014054,
  'epoch': 5,
  'iteration': 