# Minimal Implementation of Neural ODE

* Reference  
https://arxiv.org/abs/1806.07366  
https://github.com/rtqichen/torchdiffeq (PyTorch Implementation)

In [1]:
import node 
import numpy as np
import cupy as cp

Works on GPU


## Solver Construction

In [2]:
class EularMethod(object):
    """Fixed-step explicit Euler integrator for ``dy/dt = fn(t, y)``.

    The running state lives in ``self.value`` and is overwritten on every
    step taken by ``integrate``.
    """

    def __init__(self, fn, input):
        # ``fn`` is invoked as fn(time, state) and must return the derivative.
        self.value = input
        self.fn = fn

    def step(self, time, diff, input):
        """Return the Euler increment ``fn(time, input) * diff``."""
        slope = self.fn(time, input)
        return slope * diff

    def integrate(self, seq):
        """Advance the state across each consecutive pair of time points.

        Args:
            seq: indexable sequence of time points.

        Returns:
            A list holding the state after every step. The initial state is
            not included, so fewer than two time points yield an empty list.
        """
        trajectory = []
        for idx in range(1, len(seq)):
            t_prev, t_next = seq[idx - 1], seq[idx]
            increment = self.step(t_prev, t_next - t_prev, self.value)
            self.value = self.value + increment
            trajectory.append(self.value)
        return trajectory

In [3]:
def solve(fn, input, seq):
    """Integrate ``fn`` from the state ``input`` over the time points ``seq``.

    Builds an EularMethod solver and returns the list of states it produces,
    one per step.
    """
    return EularMethod(fn, input).integrate(seq)

## Model Construction

The model follows the structure used in https://github.com/rtqichen/torchdiffeq.  

The model consists of:
```sequence
DownSampler ===> NeuralODEBlock x n ===> Classifier
```

In [4]:
# Number of feature channels used by the convolution and batch-norm layers
# of the DownSampler, NeuralODEBlock and Classifier defined below.
num_ch = 32

In [5]:
class DownSampler(node.Network):
    """Convolutional front end that shrinks the input image.

    Three convolutions; the first two are each followed by batch
    normalization and a ReLU, the last one is returned as-is.
    NOTE(review): Convolution2D arguments are presumably
    (in_ch, out_ch, kernel, stride, pad) -- confirm against the node package.
    """

    def __init__(self):
        # Layers are created in the order they are applied.
        stem_conv = node.Convolution2D(1, num_ch, 3, 1)
        stem_norm = node.BatchNormalization(num_ch)
        down_conv = node.Convolution2D(num_ch, num_ch, 4, 2, 1)
        down_norm = node.BatchNormalization(num_ch)
        last_conv = node.Convolution2D(num_ch, num_ch, 4, 2, 1)
        self.layers = [stem_conv, stem_norm, down_conv, down_norm, last_conv]

    def __call__(self, input):
        stem_conv, stem_norm, down_conv, down_norm, last_conv = self.layers

        # Stage 1 -> num_ch x 26 x 26 (sizes as noted for the MNIST input)
        hidden = stem_norm(stem_conv(input)).relu()

        # Stage 2 -> num_ch x 13 x 13
        hidden = down_norm(down_conv(hidden)).relu()

        # Stage 3 -> num_ch x 6 x 6
        return last_conv(hidden)

In [6]:
class ConcatenatedConvolution2D(node.Network):
    """Convolution whose input is extended by one channel holding the time."""

    def __init__(self, num_in_ch, num_out_ch, *args):
        # One extra input channel is reserved for the time plane.
        conv = node.Convolution2D(num_in_ch + 1, num_out_ch, *args)
        self.layers = [conv]

    def __call__(self, time, input):
        # A single-channel plane of ones taken from the input's first channel
        # slice, scaled by the current time value.
        unit_plane = cp.ones_like(input.value[:, :1, :, :])
        time_plane = node.Node(unit_plane) * time

        # The time plane goes in front of the input along axis 1.
        stacked = node.concatenate([time_plane, input], 1)
        return self.layers[0](stacked)

class NeuralODEBlock(node.Network):
    """Block whose output is the solution of an ODE defined by ``fn``.

    The input is used as the initial state, ``fn`` supplies the derivative,
    and ``solve`` integrates it over the time grid ``self.start2stop``.

    Args:
        start, stop, step: forwarded to ``cp.arange`` to build the time grid.
            The defaults reproduce the original grid ``cp.arange(0, 2, 1)``;
            use a smaller ``step`` for a finer resolution.

    Raises:
        ValueError: if the grid has fewer than two points, because no
            integration step could be taken.
    """

    def __init__(self, start=0, stop=2, step=1):
        self.layers = [node.BatchNormalization(num_ch),
                       ConcatenatedConvolution2D(num_ch, num_ch, 3, 1, 1),
                       node.BatchNormalization(num_ch),
                       ConcatenatedConvolution2D(num_ch, num_ch, 3, 1, 1),
                       node.BatchNormalization(num_ch)]

        # Time points handed to the solver; one solver step per adjacent pair.
        self.start2stop = cp.arange(start, stop, step)
        if self.start2stop.size < 2:
            # Fail early: with no step, __call__ would index an empty list.
            raise ValueError("time grid needs at least two points")

    def fn(self, time, input):
        """Derivative function evaluated by the solver at ``time``."""
        hidden = input

        # Block 1: batch norm + ReLU
        hidden = self.layers[0](hidden)
        hidden = hidden.relu()

        # Block 2: time-conditioned convolution + batch norm
        # NOTE(review): the referenced torchdiffeq model appears to apply a
        # ReLU after this batch norm as well -- confirm the omission is
        # intentional before changing it (it would alter trained results).
        hidden = self.layers[1](time, hidden)
        hidden = self.layers[2](hidden)

        # Block 3: time-conditioned convolution + batch norm
        hidden = self.layers[3](time, hidden)
        hidden = self.layers[4](hidden)

        return hidden

    def __call__(self, input):
        # Only the state at the final time point is passed on.
        output = solve(self.fn, input, self.start2stop)
        return output[-1]

In [7]:
class Classifier(node.Network):
    """Classification head: batch norm, ReLU, flatten, then a linear layer
    with 10 outputs."""

    def __init__(self):
        # The linear layer's input size is the flattened feature map:
        # num_ch channels on the 6 x 6 grid noted below. Deriving it from
        # num_ch (instead of the literal 1152 = 32 * 6 * 6) keeps the layer
        # valid when num_ch is changed.
        self.layers = [node.BatchNormalization(num_ch),
                       node.Linear(num_ch * 6 * 6, 10)]

    def __call__(self, input):
        hidden = input

        # Block 1
        # Output: num_ch x 6 x 6
        hidden = self.layers[0](hidden)
        hidden = hidden.relu()

        # Fully-connected layer: flatten everything except the first axis
        hidden = hidden.reshape(input.value.shape[0], -1)
        hidden = self.layers[1](hidden)

        return hidden

In [8]:
class MainClassifier(node.Network):
    """Full model: DownSampler, then NeuralODEBlock, then Classifier."""

    def __init__(self):
        # Stages are instantiated in the order they are applied.
        stage_types = (DownSampler, NeuralODEBlock, Classifier)
        self.layers = [stage_type() for stage_type in stage_types]

    def __call__(self, input):
        # Feed the input through the three stages in sequence.
        hidden = input
        for stage in self.layers[:3]:
            hidden = stage(hidden)
        return hidden
    
# Instantiate the model and an Adam optimizer over its parameters.
# NOTE(review): 0.001 is presumably the learning rate -- confirm against node.Adam.
classifier = MainClassifier()
optimizer = node.Adam(classifier.get_parameters(), 0.001)
# Report the parameter count returned by the model.
print("parameter size: {}".format(classifier.get_num_parameters()))

parameter size: 64138


## Training Procedure

In [9]:
# Mini-batch size passed to both data loaders.
mini_batch_size = 100

# Index 0 is the training split, index 1 the test split.
datasets = [node.MNIST(train=True), 
            node.MNIST(train=False)]

# One loader per split, in the same order as `datasets`.
dataloaders = [node.DataLoader(datasets[0], mini_batch_size),
               node.DataLoader(datasets[1], mini_batch_size)]

In [10]:
def train(input, target):
    """Run one optimization step on a mini-batch and return its loss.

    The input is divided by 255 before the forward pass (presumably to
    scale raw pixel values -- confirm against the data loader).
    """
    scaled = input / 255
    loss = classifier(scaled).softmax_with_binary_cross_entropy(target)

    # Reset the optimizer state, backpropagate, then apply the update.
    optimizer.clear()
    loss.backward()
    optimizer.update()

    return loss.numpy()

In [11]:
def evaluate(input, target):
    """Score one mini-batch without training.

    Returns:
        A pair ``(loss, correct)``: the loss as returned by ``.numpy()`` and
        the number of samples whose arg-max prediction matches the arg-max
        of the target.
    """

    def count_correct(scores, labels):
        # Compare the arg-max along axis 1 of predictions and targets.
        hits = np.argmax(scores, axis=1) == np.argmax(labels, axis=1)
        return np.sum(hits)

    # Forward pass inside the zero_grad context of the node package.
    with node.zero_grad():
        logits = classifier(input / 255)
        loss = logits.softmax_with_binary_cross_entropy(target)

    correct = count_correct(logits.numpy(), target.numpy())

    return loss.numpy(), correct

In [12]:
# Train for 11 epochs; after each one, evaluate on the test loader and
# print the mean train loss, mean test loss and test accuracy.
for epoch in range(11):
    # Train Loss, Test Loss, Accuracy
    metrics = [0, 0, 0]
    # Test samples actually evaluated this epoch (denominator for accuracy)
    num_test_samples = 0

    for input, target in dataloaders[0]:
        metrics[0] += train(input, target)

    for input, target in dataloaders[1]:
        loss, acc = evaluate(input, target)
        metrics[1] += loss
        metrics[2] += acc
        num_test_samples += len(target.numpy())

    # Losses are averaged per batch.
    metrics[0] /= len(dataloaders[0])
    metrics[1] /= len(dataloaders[1])
    # Accuracy is correct predictions over samples seen. The previous
    # hard-coded `100 * len(dataloaders[1])` was only right for a batch
    # size of exactly 100 with every batch full.
    metrics[2] /= num_test_samples
    # Reporting interval of one epoch; raise the modulus to print less often.
    if epoch % 1 == 0:
        print("epoch {0:2}, train {1:.4f}, test {2:.4f}, acc {3:.4f}".format(epoch, *metrics))

epoch  0, train 5.3297, test 2.0402, acc 0.8261
epoch  1, train 1.5144, test 1.1446, acc 0.8892
epoch  2, train 0.9191, test 0.7925, acc 0.9142
epoch  3, train 0.6488, test 0.5943, acc 0.9271
epoch  4, train 0.4904, test 0.4892, acc 0.9349
epoch  5, train 0.3875, test 0.4056, acc 0.9432
epoch  6, train 0.3149, test 0.3376, acc 0.9501
epoch  7, train 0.2653, test 0.2814, acc 0.9603
epoch  8, train 0.2243, test 0.2551, acc 0.9616
epoch  9, train 0.1945, test 0.2360, acc 0.9629
epoch 10, train 0.1705, test 0.2073, acc 0.9687
