In [1]:
import node
import node.cuda

In [2]:
import numpy as np
import cupy as cp

In [3]:
class Classifier(node.Network):
    """Two-layer perceptron: Linear -> tanh -> Linear.

    Calling the model returns the raw output of the last Linear layer;
    no softmax is applied here.

    Args:
        num_in_units: size of each input vector.
        num_h_units: width of the hidden (tanh) layer.
        num_out_units: number of output scores.
    """

    def __init__(self, num_in_units, num_h_units, num_out_units):
        hidden_layer = node.Linear(num_in_units, num_h_units)
        output_layer = node.Linear(num_h_units, num_out_units)
        # NOTE(review): the attribute name `layers` is kept on purpose; presumably
        # node.Network collects parameters (get_parameters) through it — confirm.
        # NOTE(review): node.Network.__init__ is never called; confirm the base
        # class needs no initialisation of its own.
        self.layers = [hidden_layer, output_layer]

    def __call__(self, x):
        hidden = self.layers[0](x).tanh()
        return self.layers[1](hidden)
    
# 28 * 28 inputs (one per MNIST pixel), 1024 hidden units, 10 output scores; moved to the GPU.
classifier = Classifier(num_in_units=28 * 28, num_h_units=1024, num_out_units=10).gpu()
# NOTE(review): 0.001 is presumably Adam's learning rate — confirm against node.Adam's signature.
optimizer = node.Adam(classifier.get_parameters(), 0.001)

In [4]:
# datasets[0] is the training split, datasets[1] the test split; both are flattened
# (the classifier above takes 28 * 28 inputs).
datasets = [node.MNIST(training=is_training, flatten=True) for is_training in (True, False)]
# One loader per split, in the same order, each yielding batches of 100.
data_loaders = [node.DataLoader(dataset, batch_size=100) for dataset in datasets]

In [5]:
def train(input, target):
    """Run one optimisation step on a single batch and return its loss.

    Uses the module-level `classifier` and `optimizer`. `input` is divided by
    255 before the forward pass (presumably raw 0-255 pixel values — confirm).

    Returns:
        The loss value copied to the CPU (`.cpu().value`).
    """
    optimizer.zero_grad()

    # Forward pass and loss, then backpropagate and update the parameters.
    scores = classifier(input / 255)
    loss = scores.softmax_with_cross_entropy(target)
    loss.backward()
    optimizer()

    return loss.cpu().value

In [6]:
def measure(prediction, target):
    """Return how many predictions in the batch match their labels.

    The predicted class is the arg-max over axis 1 of `prediction.value`. The
    true class is the arg-max over axis 1 of `target.value` (so targets are
    presumably one-hot encoded — confirm).

    Returns:
        The number of matches, copied to the host with `cp.asnumpy`.
    """
    predicted_labels = cp.argmax(prediction.value, axis=1)
    true_labels = cp.argmax(target.value, axis=1)
    hits = cp.where(predicted_labels == true_labels, 1, 0)
    return cp.asnumpy(hits.sum())

In [7]:
def evaluate(input, target):
    """Score one batch without updating the model.

    The forward pass runs inside `node.zero_grad()`, presumably so that no
    gradient graph is recorded during evaluation — confirm in the `node` docs.
    `input` is divided by 255 before the forward pass.

    Returns:
        (loss, accuracy) where `loss` is the loss value copied to the CPU and
        `accuracy` is the number of correct predictions from `measure`
        (a count, not a fraction).
    """
    with node.zero_grad():
        scores = classifier(input / 255)
        loss_node = scores.softmax_with_cross_entropy(target)

    return loss_node.cpu().value, measure(scores, target)

In [8]:
# Train for NUM_EPOCHS epochs, evaluating on the test split after each one.
NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):

    train_loss = 0
    test_loss = 0
    num_correct = 0
    # Count the test samples actually seen rather than assuming every batch
    # holds exactly 100 examples; this stays correct if the loader's batch_size
    # changes or the final batch is partial.
    num_test_samples = 0

    # `inputs`/`targets` avoid shadowing the builtin `input`.
    for inputs, targets in data_loaders[0]:
        train_loss += train(inputs.gpu(), targets.gpu())

    for inputs, targets in data_loaders[1]:
        targets_gpu = targets.gpu()
        loss, correct = evaluate(inputs.gpu(), targets_gpu)
        test_loss += loss
        num_correct += correct
        num_test_samples += targets_gpu.value.shape[0]

    # The summed per-batch losses are averaged over the number of batches.
    train_loss /= len(data_loaders[0])
    test_loss /= len(data_loaders[1])
    accuracy = num_correct / num_test_samples

    print("epoch {0:2}, training loss {1:.2f}, test loss {2:.2f}, accuracy {3:.2f}".format(
        epoch, train_loss, test_loss, accuracy))

epoch  0, training loss 0.70, test loss 0.51, accuracy 0.92
epoch  1, training loss 0.51, test loss 0.46, accuracy 0.92
epoch  2, training loss 0.47, test loss 0.44, accuracy 0.93


KeyboardInterrupt: 