In [1]:
import node
import numpy as np 

In [2]:
class MLP(node.Network):
    """Two-layer perceptron: linear layer, tanh activation, linear layer."""

    def __init__(self, num_in_units, num_h_units, num_out_units):
        """Create the two linear layers.

        Args:
            num_in_units: Size of each input vector.
            num_h_units: Number of hidden units between the two layers.
            num_out_units: Size of each output vector.
        """
        hidden_layer = node.Linear(num_in_units, num_h_units)
        output_layer = node.Linear(num_h_units, num_out_units)

        # Stored as a list under `layers`; __call__ looks the layers up by index.
        self.layers = [hidden_layer, output_layer]

    def __call__(self, x):
        """Apply the first layer and tanh, then return the second layer's output."""
        hidden_activation = self.layers[0](x).tanh()

        return self.layers[1](hidden_activation)
    
# 28 * 28 inputs, 1024 hidden units, 10 outputs.
classifier = MLP(num_in_units=28 * 28, num_h_units=1024, num_out_units=10)
# NOTE(review): 0.001 is SGD's second positional argument -- presumably the
# learning rate; confirm against node.SGD.
optimizer = node.SGD(classifier.get_parameters(), 0.001)

In [3]:
# Build the dataset with training=True and wrap it in a loader that is
# configured with batch_size=128.
train_dataset = node.MNIST(training=True)
train_dataloader = node.DataLoader(train_dataset, batch_size=128)

In [4]:
# Build the dataset with training=False and wrap it in a loader that is
# configured with batch_size=128.
test_dataset = node.MNIST(training=False)
test_dataloader = node.DataLoader(test_dataset, batch_size=128)

In [5]:
def train(input, target):
    """Run one optimisation step on a batch and return the loss value.

    Uses the module-level `classifier` and `optimizer`.

    Args:
        input: Batch fed to the classifier after being divided by 255.
        target: Labels passed to `softmax_with_cross_entropy`.

    Returns:
        The `.value` of the loss computed for this batch.
    """
    optimizer.zero_grad()

    # Forward pass on the input divided by 255.
    # NOTE(review): presumably rescales 0-255 pixel values to [0, 1] -- confirm.
    scaled_input = input / 255
    scores = classifier(scaled_input)
    loss = scores.softmax_with_cross_entropy(target)

    # Update the parameters.
    loss.backward()
    optimizer()

    return loss.value

In [6]:
def measure(prediction, target):
    """Take the outputs and the labels and return how many predictions are correct.

    Args:
        prediction: Object whose `.value` holds one row of scores per sample.
        target: Object whose `.value` holds one row per sample; the index of
            the row maximum is taken as the label.

    Returns:
        The number of rows where the arg-max of `prediction.value` equals the
        arg-max of `target.value`.
    """
    predicted_classes = np.argmax(prediction.value, axis=1)
    true_classes = np.argmax(target.value, axis=1)

    # 1 where the classes agree, 0 elsewhere; the sum is the hit count.
    hits = np.where(predicted_classes == true_classes, 1, 0)

    return hits.sum()

In [None]:
def evaluate(input, target):
    """Compute the loss and the number of correct predictions for one batch.

    Uses the module-level `classifier` and the `measure` helper.

    Args:
        input: Batch fed to the classifier after being divided by 255.
        target: Labels used for both the loss and the correctness count.

    Returns:
        A `(loss, accuracy)` tuple: the `.value` of the loss, and the count
        returned by `measure(prediction, target)`.
    """
    # NOTE(review): node.zero_grad() presumably disables gradient tracking
    # inside the block -- confirm against the node library.
    with node.zero_grad():
        prediction = classifier(input / 255)
        batch_loss = prediction.softmax_with_cross_entropy(target)

    return batch_loss.value, measure(prediction, target)

In [None]:
# Importing the submodule also binds the top-level `tqdm` package name.
# `tqdm.auto.tqdm` replaces the deprecated `tqdm.tqdm_notebook`.
import tqdm.auto

print("")
for epoch in range(1, 16):

    # Running sums for this epoch: [train loss, test loss, correct predictions].
    # They are turned into per-batch / per-sample averages below.
    metrics = [0, 0, 0]
    # Number of test samples actually evaluated. Counting them avoids assuming
    # that every batch holds exactly 128 samples (the last one may be shorter).
    num_test_samples = 0

    for input, target in tqdm.auto.tqdm(train_dataloader, leave=False):
        metrics[0] += train(input, target)

    for input, target in tqdm.auto.tqdm(test_dataloader, leave=False):
        loss, accuracy = evaluate(input, target)
        metrics[1] += loss
        metrics[2] += accuracy
        # `target.value` has one row per sample (measure() reduces it on axis 1).
        num_test_samples += len(target.value)

    # Losses are averaged over batches; accuracy is correct / evaluated samples.
    metrics[0] /= len(train_dataloader)
    metrics[1] /= len(test_dataloader)
    metrics[2] /= num_test_samples

    # Report every third epoch.
    if epoch % 3 == 0:
        print("epoch {0}, training loss {1:.2f}, test loss {2:.2f}, accuracy {3:.2f}".format(epoch, *metrics))


