In [1]:
import node
import numpy as np

In [2]:
class Classifier(node.Network):
    """Small convolutional classifier producing 10 scores per sample.

    Forward pass, in order:
        node.Conv2D(num_in_ch, 16, 3) -> node.MaxPool2D(3, 2) -> ReLU
        node.Conv2D(16, 32, 3)        -> node.MaxPool2D(3, 2) -> ReLU
        flatten -> node.Linear(512, 256) -> ReLU -> node.Linear(256, 10)

    The output is left un-normalised; the loss applies the softmax
    (``softmax_with_cross_entropy`` is called on it by the training code).

    NOTE(review): the hard-coded 512 equals 32 * 4 * 4, which presumably
    corresponds to 28x28 inputs with unpadded 3x3 convolutions (MNIST) —
    confirm against node.Conv2D / node.MaxPool2D before changing input sizes.
    """

    def __init__(self, num_in_ch):
        """Create the layer stack.

        Args:
            num_in_ch: number of input channels passed to the first Conv2D
                (this notebook uses 1).
        """
        # Layers are stored in forward order and indexed/unpacked by __call__.
        # NOTE(review): node.Network.get_parameters presumably discovers the
        # trainable parameters through this attribute — confirm before renaming.
        self.layers = [node.Conv2D(num_in_ch, 16, 3),
                       node.MaxPool2D(3, 2),
                       node.Conv2D(16, 32, 3),
                       node.MaxPool2D(3, 2),
                       node.Linear(512, 256),
                       node.Linear(256, 10)]

    def __call__(self, input):
        """Run the forward pass.

        Args:
            input: a ``node`` value whose ``.value.shape[0]`` is the batch size.

        Returns:
            The output of the final ``node.Linear(256, 10)`` layer (raw scores).
        """
        conv1, pool1, conv2, pool2, fc1, fc2 = self.layers

        hidden = pool1(conv1(input)).relu()
        hidden = pool2(conv2(hidden)).relu()
        # Flatten everything except the batch dimension for the dense layers.
        hidden = hidden.reshape(input.value.shape[0], -1)
        # A nonlinearity between the two Linear layers is required; without it
        # they compose into a single affine map and the hidden layer is wasted.
        hidden = fc1(hidden).relu()
        return fc2(hidden)
    
# Model for single-channel (grayscale) input, plus an SGD optimizer over its parameters.
# NOTE(review): the two numeric node.SGD arguments are presumably the learning rate
# and a momentum / weight-decay coefficient — confirm against node.SGD's signature.
classifier = Classifier(num_in_ch=1)
optimizer = node.SGD(
    classifier.get_parameters(),
    0.001,
    0.0001,
)

In [3]:
# MNIST training split, consumed in mini-batches of 100 samples.
train_dataset = node.MNIST(training=True)
train_dataloader = node.DataLoader(
    train_dataset,
    batch_size=100,
)

In [4]:
# MNIST held-out (non-training) split, consumed in mini-batches of 100 samples.
test_dataset = node.MNIST(training=False)
test_dataloader = node.DataLoader(
    test_dataset,
    batch_size=100,
)

In [5]:
def train(input, target):
    """Perform one optimisation step on a single mini-batch.

    Clears gradients, divides the input by 255, computes the softmax
    cross-entropy against ``target``, back-propagates, and calls the
    global ``optimizer`` to update the parameters.

    Returns:
        The ``.value`` of the loss for this batch.
    """
    optimizer.zero_grad()

    # Update the parameters: forward pass, backward pass, optimizer step.
    # Dividing by 255 presumably rescales 0-255 pixel values to [0, 1].
    scores = classifier(input / 255)
    loss_node = scores.softmax_with_cross_entropy(target)
    loss_node.backward()
    optimizer()

    return loss_node.value

In [6]:
def measure(prediction, target):
    """Count how many predictions match their labels.

    Takes the model output and the labels and returns the number of correct
    predictions. Both arguments expose a ``.value`` array with at least two
    dimensions; the class of each row is its argmax over axis 1, so ``target``
    is expected to be one-hot (or otherwise peak at the true class).

    Returns:
        Number of rows whose argmaxes agree, as a NumPy integer.
    """
    predicted_classes = np.argmax(prediction.value, axis=1)
    true_classes = np.argmax(target.value, axis=1)

    is_hit = predicted_classes == true_classes
    return np.sum(is_hit)

In [7]:
def evaluate(input, target):
    """Score one mini-batch with the global ``classifier``.

    The forward pass and loss run inside ``node.zero_grad()`` (presumably a
    context that disables gradient tracking — confirm in ``node``). The input
    is divided by 255, as in ``train``.

    Returns:
        ``(loss, accuracy)`` where ``loss`` is the loss ``.value`` and
        ``accuracy`` is the count of correct predictions from ``measure``.
    """
    with node.zero_grad():
        scores = classifier(input / 255)
        loss_node = scores.softmax_with_cross_entropy(target)

    return loss_node.value, measure(scores, target)

In [8]:
import tqdm.auto  # also binds `tqdm`; tqdm.auto picks the notebook widget when available

NUM_EPOCHS = 16
PRINT_EVERY = 3  # report metrics every PRINT_EVERY epochs (and always on the last one)

for epoch in range(NUM_EPOCHS):

    # Train Loss, Test Loss, Accuracy
    metrics = [0, 0, 0]
    num_test_samples = 0

    for batch_input, batch_target in tqdm.auto.tqdm(train_dataloader, leave=False):
        metrics[0] += train(batch_input, batch_target)

    for batch_input, batch_target in tqdm.auto.tqdm(test_dataloader, leave=False):
        loss, num_correct = evaluate(batch_input, batch_target)
        metrics[1] += loss
        metrics[2] += num_correct
        # Count the real number of samples instead of assuming every batch holds
        # exactly 100, so a smaller final batch does not skew the accuracy.
        num_test_samples += batch_target.value.shape[0]

    # Losses are averaged per batch; accuracy is the fraction of correct samples.
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    metrics[0] /= max(len(train_dataloader), 1)
    metrics[1] /= max(len(test_dataloader), 1)
    metrics[2] /= max(num_test_samples, 1)

    if epoch % PRINT_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        print("epoch {0:2}, training loss {1:.2f}, test loss {2:.2f}, accuracy {3:.2f}".format(epoch, *metrics))




HBox(children=(IntProgress(value=0, max=600), HTML(value='')))

(100, 512)
(100, 512)
(100, 512)
(100, 512)
(100, 512)
(100, 512)
(100, 512)
(100, 512)
(100, 512)


KeyboardInterrupt: 