Neural ODEs for MNIST Classification
===================================

This code implements Neural Ordinary Differential Equations (Neural ODEs) for MNIST digit classification.
Key concepts:
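- Neural ODEs are continuous-depth models that generalize ResNets
- Instead of a fixed stack of discrete layers, the hidden state is transformed continuously over a time interval
- The transformation is defined by an ODE, dx/dt = f(t, x), where f is a neural network
- The ODE is solved numerically; this notebook uses the adaptive dopri5 solver from torchdiffeq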

Reference: "Neural Ordinary Differential Equations" (Chen et al., 2018)
https://arxiv.org/abs/1806.07366
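
To make the ResNet connection concrete, the following sketch (an illustration only, not part of the model code) integrates the scalar ODE dx/dt = -x with fixed-step Euler updates. Each update has the form x + h * f(t, x), which mirrors a residual block's x + f(x). The names euler_integrate and decay are hypothetical, and the notebook itself relies on an adaptive solver rather than Euler.

```python
import math

def euler_integrate(f, x0, t0, t1, num_steps):
    """Integrate dx/dt = f(t, x) from t0 to t1 with fixed-step Euler updates."""
    h = (t1 - t0) / num_steps
    t, x = t0, x0
    for _ in range(num_steps):
        # Same form as a residual block: new state = old state + (scaled) update
        x = x + h * f(t, x)
        t = t + h
    return x

def decay(t, x):
    # Toy dynamics with the known solution x(t) = x0 * exp(-t)
    return -x

approx = euler_integrate(decay, 1.0, 0.0, 1.0, num_steps=1000)
exact = math.exp(-1.0)
print(f"Euler: {approx:.5f}  exact: {exact:.5f}")
```

With num_steps=1 the loop performs a single residual-style update; increasing num_steps moves toward the continuous-depth limit that a Neural ODE models directly.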

Setup and Imports
===================================

We use PyTorch for deep learning, torchvision for the MNIST dataset, and torchdiffeq for the ODE solver.

In [17]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchdiffeq import odeint  # ODE solver for Neural ODEs

The ODE Function
===================================

The core of a Neural ODE is the function that describes how the hidden state evolves over time. This is analogous to $f(t, x)$ in the ODE

$$\frac{dx}{dt} = f(t, x)$$

We implement this as a small neural network.

In [8]:
class ODEFunc(nn.Module):
    def __init__(self):
        super(ODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28)
        )

    def forward(self, t, x):
        # t: time (not used here, but required by the ODE solver)
        # x: current state
        return self.net(x)

The ODE Block
===================================

The ODE Block wraps the ODE function and uses a numerical solver to compute the transformation from the initial state to the final state. This acts like a continuous-depth layer in the network.

In [9]:
class ODEBlock(nn.Module):
    def __init__(self, odefunc):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
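        # Integrate from t=0 to t=1; a buffer moves with the module on .to(device)
        self.register_buffer('integration_time', torch.tensor([0.0, 1.0]))

    def forward(self, x):
        # Match the dtype and device of the input without mutating module state
        t = self.integration_time.type_as(x)
        # odeint returns one state per requested time: out[0] at t=0, out[1] at t=1
        out = odeint(self.odefunc, x, t, method='dopri5')
        return out[1]  # Return the final state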
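
As an illustration of what the solver call returns, the sketch below is a hand-written fixed-step Runge-Kutta (RK4) integrator that, like odeint, returns one state per requested time, so index 0 is the initial state and the last index is the final state. It is a simplified stand-in: dopri5 is an adaptive-step method, and rk4_step, rk4_trajectory, and the steps_per_interval parameter are hypothetical names chosen for this example.

```python
import math

def rk4_step(f, t, x, h):
    """Advance dx/dt = f(t, x) by one classical fourth-order Runge-Kutta step."""
    k1 = f(t, x)
    k2 = f(t + h / 2, x + h / 2 * k1)
    k3 = f(t + h / 2, x + h / 2 * k2)
    k4 = f(t + h, x + h * k3)
    return x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

def rk4_trajectory(f, x0, times, steps_per_interval=10):
    """Return the state at every entry of `times`; states[0] is x0 itself."""
    states = [x0]
    x = x0
    for t_start, t_end in zip(times[:-1], times[1:]):
        h = (t_end - t_start) / steps_per_interval
        for i in range(steps_per_interval):
            x = rk4_step(f, t_start + i * h, x, h)
        states.append(x)
    return states

# Toy scalar dynamics dx/dt = -x, integrated over the same interval [0, 1]
out = rk4_trajectory(lambda t, x: -x, 1.0, [0.0, 1.0])
final_state = out[1]  # plays the role of out[1] in ODEBlock.forward
print(final_state, math.exp(-1.0))
```

An adaptive solver chooses its own internal step sizes to meet an error tolerance, so the number of function evaluations can vary from batch to batch, unlike the fixed count in this sketch.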

The Complete Neural ODE Model
===================================

The full model flattens each 28x28 MNIST image into a 784-dimensional vector, applies the Neural ODE transformation, and then classifies the result with a linear layer.

In [11]:
class NeuralODE(nn.Module):
    def __init__(self):
        super(NeuralODE, self).__init__()
        self.flatten = nn.Flatten()
        self.odeblock = ODEBlock(ODEFunc())
        self.classifier = nn.Linear(28 * 28, 10)  # 10 classes for MNIST

    def forward(self, x):
        x = self.flatten(x)
        x = self.odeblock(x)
        return self.classifier(x)
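
To get a feel for the model's size, the layer shapes above can be turned into a parameter count by hand. The sketch below assumes the default bias term in each nn.Linear layer; linear_params is a hypothetical helper written for this example.

```python
def linear_params(in_features, out_features):
    """Parameter count of a fully connected layer: weight matrix plus bias."""
    return in_features * out_features + out_features

state_dim = 28 * 28   # flattened MNIST image
hidden_dim = 128      # hidden width used in ODEFunc
num_classes = 10

odefunc_params = (linear_params(state_dim, hidden_dim)
                  + linear_params(hidden_dim, state_dim))
classifier_params = linear_params(state_dim, num_classes)
total_params = odefunc_params + classifier_params

print(f"ODEFunc: {odefunc_params}, classifier: {classifier_params}, total: {total_params}")
```

Under these assumptions the ODE function holds 201,616 of the 209,466 parameters. The same weights are reused at every solver step, so the parameter count does not grow with the number of function evaluations.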

Data Loading
===================================

We load the MNIST dataset, normalize it with the dataset's mean (0.1307) and standard deviation (0.3081), and create data loaders for training and testing.

In [12]:
def get_mnist_loaders(batch_size=128):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, transform=transform, download=True)
    test_dataset = torchvision.datasets.MNIST(
        root='./data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader
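
The arithmetic behind the loader can be checked without downloading anything. This sketch assumes the standard MNIST split sizes (60,000 training and 10,000 test images) and the DataLoader default of keeping the last, smaller batch; normalize is a hypothetical helper that mirrors the per-pixel mapping of transforms.Normalize.

```python
import math

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # values passed to transforms.Normalize

def normalize(pixel):
    """Apply the same per-pixel mapping as Normalize: (pixel - mean) / std."""
    return (pixel - MNIST_MEAN) / MNIST_STD

# ToTensor scales raw pixels to [0, 1], so 0.0 is background and 1.0 is full ink
black = normalize(0.0)
white = normalize(1.0)

# Batches per epoch when the last, smaller batch is kept
train_batches = math.ceil(60000 / 128)
test_batches = math.ceil(10000 / 128)

print(f"normalized range: [{black:.4f}, {white:.4f}]")
print(f"batches: train={train_batches}, test={test_batches}")
```

The resulting 469 training batches match the batch counter that train_epoch prints during training.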

Training and Evaluation Functions
===================================

These functions handle training and evaluation. train_epoch runs one pass over the training set, updating the model parameters and printing the loss every 100 batches; evaluate computes the average loss and the accuracy on the test set.

In [13]:
def train_epoch(model, train_loader, optimizer, criterion, device):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f'Batch: {batch_idx}/{len(train_loader)} Loss: {loss.item():.6f}')

def evaluate(model, test_loader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)  # predicted class per sample
            correct += (pred == target).sum().item()
    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return accuracy
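
For intuition about the numbers these functions report, the sketch below computes the cross-entropy loss and the argmax prediction for a single sample in plain Python. It is a simplified, per-sample version of what nn.CrossEntropyLoss and output.argmax do on a batch; cross_entropy and predict are hypothetical helper names.

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class for one sample."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target]

def predict(logits):
    """Index of the largest logit, i.e. the argmax over classes."""
    return max(range(len(logits)), key=lambda i: logits[i])

# Uniform logits: every class gets probability 1/10, so the loss is ln(10)
uniform_loss = cross_entropy([0.0] * 10, target=3)

# A confident, correct prediction yields a loss close to zero
confident_logits = [0.0] * 10
confident_logits[3] = 10.0
confident_loss = cross_entropy(confident_logits, target=3)

print(f"uniform: {uniform_loss:.4f}, confident: {confident_loss:.4f}")
print("predicted class:", predict(confident_logits))
```

Since ln(10) is roughly 2.3026, a first-batch training loss in that neighborhood is what an untrained 10-class classifier would be expected to produce.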

Main Training Loop
===================================

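Finally, we put everything together: create the model, the Adam optimizer, and the cross-entropy loss, then train for 10 epochs, evaluating on the test set after each one. The run shown below was interrupted during the second epoch.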

In [18]:
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model = NeuralODE().to(device)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    train_loader, test_loader = get_mnist_loaders()
    num_epochs = 10
    for epoch in range(num_epochs):
        print(f'\nEpoch: {epoch+1}/{num_epochs}')
        train_epoch(model, train_loader, optimizer, criterion, device)
        evaluate(model, test_loader, criterion, device)

if __name__ == "__main__":
    main()

Using device: cpu

Epoch: 1/10
Batch: 0/469 Loss: 2.575295
Batch: 100/469 Loss: 0.322705
Batch: 200/469 Loss: 0.330603
Batch: 300/469 Loss: 0.079806
Batch: 400/469 Loss: 0.061047
Test set: Average loss: 0.1102, Accuracy: 9637/10000 (96.37%)

Epoch: 2/10
Batch: 0/469 Loss: 0.050274
Batch: 100/469 Loss: 0.087400


KeyboardInterrupt: 