### MNIST classification with a multilayer perceptron (MLP) using synapx (or PyTorch)

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets

# Put the parent directory ('..', resolved to an absolute path) on sys.path so
# that the `examples` package can be imported from this notebook.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
from examples.utils.data import split_dataset, DataLoader, DataLoaderCallback
from examples.utils.train import Trainer, Evaluator

In [None]:
# Autograd back-ends this notebook can run on (chosen via `engine_str`).
supported_engines = ["synapx", "torch"]

In [None]:
# Run configuration.
seed = 42               # passed to the RNG seeding in the engine cell
engine_str = "synapx"   # must be one of `supported_engines`
device = "cpu"          # device for the model and for every batch tensor

In [None]:
# Optimisation hyper-parameters.
epochs = 10       # passes over the training split
batch_size = 64   # training mini-batch size
lr = 1e-2         # SGD learning rate

In [None]:
# Resolve the autograd back-end. Every later cell relies on `engine`, `nn` and
# `optim`, so reject an unknown engine here with a clear message instead of
# failing later with a NameError.
if engine_str not in supported_engines:
    raise ValueError(
        f"Unsupported engine {engine_str!r}; expected one of {supported_engines}"
    )

if engine_str == 'synapx':
    import synapx as engine
    from synapx import nn, optim
elif engine_str == 'torch':
    import torch as engine
    from torch import nn, optim

# torch's global RNG is seeded whatever the engine (as before). NumPy's global
# RNG is seeded too, in case helpers such as split_dataset / DataLoader draw
# from it for shuffling (not visible here — confirm).
torch.manual_seed(seed)
np.random.seed(seed)
print("Engine:", engine_str)

Load the MNIST dataset, hold out a validation split, and scale pixel values to [0, 1]

In [None]:
# Download (if needed) and load the official MNIST train/test splits.
train_set = datasets.MNIST('./mnist', train=True, download=True)
test_set = datasets.MNIST('./mnist', train=False, download=True)

trainX = train_set.data.numpy()
trainY = train_set.targets.numpy()
testX = test_set.data.numpy()
testY = test_set.targets.numpy()

# Carve a validation split out of the training data; the third element returned
# by split_dataset is unused. (test_split=0.2 presumably moves ~20% of samples
# into the second split — confirm against examples.utils.data.split_dataset.)
(trainX, trainY), (valX, valY), _ = split_dataset(trainX, trainY, test_split=0.2)

# Scale raw pixel values into [0, 1]. Dividing straight into float32 (the dtype
# the Transform callback converts every batch to anyway) halves memory compared
# with the implicit float64 result of `x / 255.0`.
trainX = np.divide(trainX, 255.0, dtype=np.float32)
valX = np.divide(valX, 255.0, dtype=np.float32)
testX = np.divide(testX, 255.0, dtype=np.float32)

# Sanity check: every split must lie fully inside [0, 1] (both bounds).
for split_X in (trainX, valX, testX):
    assert 0.0 <= split_X.min() and split_X.max() <= 1.0

# summarize loaded dataset
print(f'Train: X={trainX.shape}, y={trainY.shape}')
print(f'Val: X={valX.shape}, y={valY.shape}')
print(f'Test: X={testX.shape}, y={testY.shape}')

In [None]:
class Transform(DataLoaderCallback):
    """DataLoader callback that turns a NumPy batch into engine tensors.

    Features are created with dtype ``torch.float32`` and labels with
    ``torch.long``, both on the notebook-level ``device``.
    """

    def __call__(self, data_loader: 'DataLoader', X_batch: np.ndarray, y_batch: np.ndarray):
        return (
            engine.tensor(X_batch, dtype=torch.float32, device=device),
            engine.tensor(y_batch, dtype=torch.long, device=device),
        )

# One shared callback converts every batch into engine tensors.
transform_cb = Transform()

# Training uses `batch_size`; validation and test use a larger batch of 256.
train_loader = DataLoader(trainX, trainY, batch_size, transform=transform_cb)
val_loader = DataLoader(
    valX, valY, batch_size=256, transform=transform_cb
)
test_loader = DataLoader(
    testX, testY, batch_size=256, transform=transform_cb
)

print(
    "Train batches:", len(train_loader),
    "| Val batches:", len(val_loader),
    "| Test batches:", len(test_loader),
)

Plot the first nine training images

In [None]:
# Show the first nine training images in a 3x3 grid.
plt.figure(figsize=(7, 7))
for idx in range(9):
    plt.subplot(3, 3, idx + 1)
    image = trainX[idx]
    plt.imshow(image, cmap=plt.get_cmap('gray'))
plt.show()

In [None]:
class MultiLayerPerceptron(nn.Module):
    """Fully connected classifier over 784 flattened input features.

    Layers: Flatten -> Linear(784, 200) -> ReLU -> Dropout(0.2)
    -> Linear(200, 100) -> ReLU -> Linear(100, num_classes) -> LogSoftmax(dim=1).
    The forward pass therefore returns log-probabilities, which pair with
    ``nn.NLLLoss``.

    Args:
        num_classes: Width of the output layer. Defaults to 10 (MNIST digits).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.num_classes = num_classes
        # Expected weight shape (out, in) of the final Linear layer. Matching on
        # the full shape rather than just the output width means a hidden layer
        # that happens to have `num_classes` outputs is never mistaken for it.
        # Set before `self.apply`, which reads it via `_init_weights`.
        self._output_weight_shape = (num_classes, 100)

        self.mlp = nn.Sequential(
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(784, 200, bias=True),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(200, 100, bias=True),
            nn.ReLU(),
            nn.Linear(100, num_classes),
            nn.LogSoftmax(dim=1)
        )
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise the weights of ``module`` if it is an ``nn.Linear``.

        Output layer: Xavier-uniform weights scaled by 0.1, with zero bias.
        Hidden layers: Kaiming-normal weights for ReLU; their biases keep the
        engine's default initialisation.
        """
        if isinstance(module, nn.Linear):
            weight_shape = (module.weight.shape[0], module.weight.shape[1])
            if weight_shape == self._output_weight_shape:  # last layer
                nn.init.xavier_uniform_(module.weight)
                with engine.no_grad():
                    module.weight *= 0.1  # make logits layer less confident
                nn.init.zeros_(module.bias)
            else:
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')

    def forward(self, x: engine.Tensor):
        """Return per-class log-probabilities for the batch ``x``."""
        return self.mlp(x)

# Build the network, move it to the configured device and summarise it.
model = MultiLayerPerceptron()
model.to(device)
print("MLP Model:", model)
n_trainable = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("MLP Trainable Parameters:", n_trainable)

In [None]:
# NLLLoss consumes the log-probabilities produced by the model's LogSoftmax.
loss_fn = nn.NLLLoss()
model_params = model.parameters()
optimizer = optim.SGD(model_params, lr=lr, momentum=0.9)

trainer = Trainer(model, engine)
evaluator = Evaluator(accuracy=True, mode=Evaluator.MULTI_CLASS)
trainer.compile(loss_fn, optimizer, evaluator)

# Train, validating on the held-out split.
history = trainer.fit(
    train_loader,
    epochs=epochs,
    validation_loader=val_loader,
)

In [None]:
# Learning curves. NOTE(review): ylim=[0, 1] presumably fixes the y-axis range,
# so any loss values above 1 would fall outside the plot — widen if needed.
curves = ['accuracy', 'loss']
trainer.plot(curves, ylim=[0, 1])

In [None]:
# Run the trained model over the test split and report the evaluator's metrics.
test_outputs = trainer.test(test_loader)
y_pred, y_true = test_outputs
evaluator.report(y_pred, y_true)

In [None]:
# Show the first nine test images with the predicted and the true class.
plt.figure(figsize=(12, 12))
for idx in range(9):
    plt.subplot(3, 3, idx + 1)
    plt.imshow(testX[idx].reshape(28, 28), cmap=plt.get_cmap('gray'))
    predicted = y_pred[idx].argmax()
    actual = y_true[idx]
    plt.title(f"Pred: {predicted} | Class: {actual}")
plt.show()