In [None]:
%matplotlib inline
import chainer as C
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import time
import pandas as pd

In [None]:
import gzip
import pickle

# Load the pickled MNIST archive and split it into its three parts.
# The bytes are decoded as latin1 -- presumably because the archive was
# written by Python 2; confirm against the data source.
# NOTE(review): unpickling executes arbitrary code from the file, so only
# load this archive from a trusted location.
with gzip.open('/data/mnist.pkl.gz') as f:
    # Public API equivalent of building a private pickle._Unpickler and
    # assigning its .encoding attribute by hand.
    train, valid, test = pickle.load(f, encoding='latin1')

In [None]:
# Preview the first 16 training samples, each reshaped to a 28x28 image,
# laid out on a 4x4 grid with the axes hidden.
for slot in range(1, 17):
    plt.subplot(4, 4, slot)
    digit = train[0][slot - 1].reshape(28, 28)
    plt.imshow(digit)
    plt.axis('off')

In [None]:
class Network(C.Chain):
    """Fully connected classifier with a single gated hidden layer.

    Layers, in order of use:
      * ``a``     -- linear map from 784 inputs to ``n_hidden`` units (tanh).
      * ``c``     -- linear map producing the sigmoid gate.
      * ``b``     -- linear map producing the tanh candidate activation.
      * ``final`` -- linear map from ``n_hidden`` units to 10 output scores.

    Args:
        n_hidden: Width of every hidden layer. Defaults to 256, the value
            that used to be hard-coded.
    """

    def __init__(self, n_hidden=256):
        super().__init__(
            a=C.links.Linear(784, n_hidden),
            b=C.links.Linear(n_hidden, n_hidden),
            c=C.links.Linear(n_hidden, n_hidden),
            final=C.links.Linear(n_hidden, 10),
        )

    def __call__(self, x):
        """Run the forward pass and return the output of the final layer.

        Args:
            x: Input batch fed to the first linear layer (784 features per
                sample, as fixed by layer ``a``).

        Returns:
            The raw output of ``final``; no softmax is applied here.
        """
        h = C.functions.tanh(self.a(x))
        # The gate blends the incoming activation (gate -> 1) with a freshly
        # transformed one (gate -> 0), element-wise.
        gate = C.functions.sigmoid(self.c(h))
        h = gate * h + (1 - gate) * C.functions.tanh(self.b(h))
        return self.final(h)

In [None]:
batch_size = 512
n_epochs = 100  # number of full passes over the training set

network = Network()
opt = C.optimizers.Adam()
opt.use_cleargrads()
opt.setup(network)

# Cast the label arrays to int32 once, up front, instead of re-casting a
# slice on every mini-batch of every epoch.
train_x, train_y = train[0], train[1].astype(np.int32)
valid_x, valid_y = valid[0], valid[1].astype(np.int32)

t0 = time.time()
accuracies = []        # training accuracy per mini-batch, in percent
losses = []            # training loss per mini-batch
valid_losses = []      # validation loss per epoch
valid_accuracies = []  # validation accuracy per epoch, in percent
for _ in range(n_epochs):
    # Mini-batches are taken in storage order; the last one of an epoch may
    # hold fewer than batch_size samples.
    for i in range(0, len(train_x), batch_size):
        batch_x = C.Variable(train_x[i:(i + batch_size)])
        batch_y = C.Variable(train_y[i:(i + batch_size)])

        network.cleargrads()
        z = network(batch_x)
        c = C.functions.softmax_cross_entropy(z, batch_y)
        c.backward()
        opt.update()
        # Accuracy is measured on the predictions made before this update.
        a = 100 * C.functions.accuracy(z, batch_y)

        losses.append(float(c.data))
        accuracies.append(float(a.data))

    # Evaluate on the whole validation split in one forward pass.
    # NOTE(review): this pass still goes through the normal forward path, so
    # it may record a graph that is never back-propagated. Consider disabling
    # that (volatile inputs or no_backprop_mode, depending on the installed
    # Chainer version) -- TODO confirm which applies.
    valid_data = C.Variable(valid_x)
    valid_labels = C.Variable(valid_y)
    valid_pred = network(valid_data)
    valid_loss = C.functions.softmax_cross_entropy(valid_pred, valid_labels)
    valid_accuracy = 100 * C.functions.accuracy(valid_pred, valid_labels)

    valid_losses.append(float(valid_loss.data))
    valid_accuracies.append(float(valid_accuracy.data))

t1 = time.time()

# Report the last epoch's metrics from the recorded plain floats rather than
# formatting the raw array payloads.
print("Validation Loss = %.2f" % valid_losses[-1])
print("Validation Accuracy = %.2f%%" % valid_accuracies[-1])
print("In %.1f s" % (t1 - t0))

In [None]:
# Exponentially weighted moving average (com=100) of the per-mini-batch
# training accuracy. Uses Series.ewm(...).mean() because the top-level
# pd.ewma function no longer exists in current pandas releases.
pd.Series(accuracies).ewm(com=100).mean().plot()

In [None]:
# Exponentially weighted moving average (com=100) of the per-mini-batch
# training loss. Uses Series.ewm(...).mean() because the top-level pd.ewma
# function no longer exists in current pandas releases.
pd.Series(losses).ewm(com=100).mean().plot()

In [None]:
# Validation loss, one point per epoch.
valid_loss_curve = pd.Series(valid_losses)
valid_loss_curve.plot()

In [None]:
# Validation accuracy (percent), one point per epoch.
valid_accuracy_curve = pd.Series(valid_accuracies)
valid_accuracy_curve.plot()