In [1]:
# Imports
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, classification_report
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Reproducibility: seed each RNG the notebook draws from (Python, NumPy, PyTorch).
SEED = 22
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Device that models and batches are moved to via `.to(DEVICE)`.
DEVICE = torch.device('cpu')
print('Device:', DEVICE)

Device: cpu


In [2]:
# 2) Data pipeline: MNIST digits padded from 28x28 up to 32x32 (the input size used in the paper)
data_dir = '../data'
batch_size = 64  # training mini-batch size; adjust for quicker runs

# Pad by 2 pixels per side, convert to a tensor, then normalise with the constants below.
transform = transforms.Compose([
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Both splits share the same root directory, download flag and transform.
_mnist_kwargs = dict(root=data_dir, download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **_mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **_mnist_kwargs)

# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

print('Train size:', len(train_dataset), 'Test size:', len(test_dataset))

Train size: 60000 Test size: 10000


In [3]:
# 3) Building blocks from the paper
class ScaledTanh(nn.Module):
    """Scaled hyperbolic tangent activation: f(x) = a * tanh(b * x).

    The amplitude ``a`` (1.7159) and slope ``b`` (2/3) are stored as plain
    Python floats, so the module has no trainable parameters.
    """
    def __init__(self):
        super().__init__()
        self.a, self.b = 1.7159, 2.0 / 3.0

    def forward(self, x):
        squashed = torch.tanh(x * self.b)
        return squashed * self.a

class Subsampling(nn.Module):
    """2x2 average pooling, a trainable per-channel scale and bias, then a logistic sigmoid.

    Computes ``sigmoid(scale * avgpool(x) + bias)``. The network in this
    notebook uses it for its S2 and S4 stages.

    Args:
        in_channels: number of feature maps; one (scale, bias) pair is learned per map.
    """
    def __init__(self, in_channels):
        super().__init__()
        self.pool = nn.AvgPool2d(2, 2)
        param_shape = (1, in_channels, 1, 1)  # broadcasts over batch, height and width
        self.scale = nn.Parameter(torch.ones(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        pooled = self.pool(x)
        # Per-channel affine map of the pooled values, squashed into (0, 1).
        return (self.scale * pooled + self.bias).sigmoid()

class C3Layer(nn.Module):
    """Partially connected 5x5 convolution (C3): 6 input maps -> 16 output maps.

    Every output map owns a single-output ``nn.Conv2d`` that only sees the
    input maps listed for it in ``self.connections``.
    """
    def __init__(self):
        super().__init__()
        # Row k lists the input-map indices feeding output map k
        # (connection table from the paper).
        self.connections = [
            # output maps 0-5: three input maps each
            [0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [0, 4, 5], [0, 1, 5],
            # output maps 6-14: four input maps each
            [0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5],
            [0, 3, 4, 5], [0, 1, 4, 5], [0, 1, 2, 5],
            [0, 1, 3, 4], [1, 2, 4, 5], [0, 2, 3, 5],
            # output map 15: all six input maps
            [0, 1, 2, 3, 4, 5],
        ]
        self.convs = nn.ModuleList(
            nn.Conv2d(len(indices), 1, kernel_size=5) for indices in self.connections
        )

    def forward(self, x):
        # Apply each conv to its own subset of input channels, then stack the
        # 16 single-channel results along the channel axis.
        maps = [conv(x[:, indices]) for conv, indices in zip(self.convs, self.connections)]
        return torch.cat(maps, dim=1)

class RBFOutput(nn.Module):
    """RBF-style output layer: scores inputs by squared Euclidean distance to class prototypes.

    The forward pass returns the *negative* squared distances, shape
    (batch, num_classes), so they can be used directly as logits with
    CrossEntropyLoss (closer prototype => larger logit).

    Note on gradient scale: with a softmax cross-entropy loss on these logits,
    the ``||x||^2`` term is shared by all classes and cancels, so the gradient
    reaching ``x`` is ``2 * sum_k (softmax_k - onehot_k) * p_k``. It is
    proportional to the prototype magnitudes, so a very small ``init_std``
    passes very little gradient to the layers below early in training.

    Args:
        input_dim: size of each input feature vector.
        num_classes: number of prototypes (one per class).
        init_std: standard deviation of the Gaussian prototype initialisation.
            Defaults to 0.01, the value that was previously hard-coded.
    """
    def __init__(self, input_dim, num_classes, init_std=0.01):
        super().__init__()
        # prototypes shape: (num_classes, input_dim)
        self.prototypes = nn.Parameter(torch.randn(num_classes, input_dim) * init_std)

    def forward(self, x):
        # x: (batch, input_dim)
        # Expanded form of the squared distance: ||x - p||^2 = ||x||^2 + ||p||^2 - 2 x.p
        x_norm = (x ** 2).sum(dim=1, keepdim=True)               # (batch, 1)
        p_norm = (self.prototypes ** 2).sum(dim=1).unsqueeze(0)  # (1, num_classes)
        cross = x @ self.prototypes.t()                          # (batch, num_classes)
        d2 = x_norm + p_norm - 2 * cross
        # Negate so that the nearest prototype gets the largest logit.
        return -d2

In [5]:
# 4) LeNet-5 model wiring as described in paper
class LeNet5Paper(nn.Module):
    """LeNet-5 style network: C1 -> S2 -> C3 -> S4 -> C5 -> F6 -> RBF output.

    Takes a batch of shape (B, 1, 32, 32) and returns (B, 10) logits, which
    are the negative squared distances produced by the RBF output layer.
    """
    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv2d(1, 6, kernel_size=5)     # C1: 1 -> 6 maps, 5x5 kernels
        self.tanh = ScaledTanh()                     # squashing function reused after C1, C3, C5 and F6
        self.s2 = Subsampling(6)                     # S2
        self.c3 = C3Layer()                          # C3: sparse 6 -> 16 maps
        self.s4 = Subsampling(16)                    # S4
        self.c5 = nn.Conv2d(16, 120, kernel_size=5)  # C5: a 5x5 input collapses to 120x1x1
        self.f6 = nn.Linear(120, 84)                 # F6
        self.rbf = RBFOutput(input_dim=84, num_classes=10)

    def forward(self, x):
        squash = self.tanh
        feat = self.s2(squash(self.c1(x)))     # (B, 6, 14, 14)
        feat = self.s4(squash(self.c3(feat)))  # (B, 16, 5, 5)
        feat = squash(self.c5(feat))           # (B, 120, 1, 1)
        flat = feat.view(feat.size(0), -1)     # (B, 120)
        hidden = squash(self.f6(flat))         # (B, 84)
        return self.rbf(hidden)                # (B, 10) negative distances

# Sanity check: push a random batch through a fresh model and confirm the output shape.
model = LeNet5Paper().to(DEVICE)
# The input must live on the same device as the model. It is drawn on the CPU
# first so the seeded CPU RNG stream is consumed the same way on any device.
dummy = torch.randn(2, 1, 32, 32).to(DEVICE)
out = model(dummy)
assert out.shape == (2, 10), f'unexpected output shape {tuple(out.shape)}'
print('Output shape:', out.shape)

Output shape: torch.Size([2, 10])


In [6]:
# 5) Training utilities and paper-like schedule
def paper_lr_schedule(epoch):
    """Return the learning rate for a zero-based ``epoch`` index.

    Piecewise-constant decay approximating the paper's annealing:
    0.0005 for epochs 0-1, 0.0002 for 2-4, 0.0001 for 5-7,
    0.00005 for 8-11, and 0.00001 from epoch 12 onwards.
    """
    # (exclusive upper epoch bound, learning rate), in ascending order
    steps = (
        (2, 0.0005),
        (5, 0.0002),
        (8, 0.0001),
        (12, 0.00005),
    )
    for upper_bound, rate in steps:
        if epoch < upper_bound:
            return rate
    return 0.00001

def train_model(model, train_loader, test_loader, epochs=10, initial_lr=0.01, use_paper_schedule=True, save_dir='./logs/lenet_paper'):
    """Train ``model`` with plain SGD and cross-entropy, evaluating after every epoch.

    Args:
        model: network to train (updated in place).
        train_loader: iterable of (images, labels) training batches.
        test_loader: iterable of (images, labels) batches passed to ``evaluate``
            after each epoch.
        epochs: number of passes over ``train_loader``.
        initial_lr: learning rate the optimizer is created with. It only takes
            effect when ``use_paper_schedule`` is False; otherwise the schedule
            overwrites it at the start of every epoch.
        use_paper_schedule: if True, set the learning rate each epoch from
            ``paper_lr_schedule(epoch)``.
        save_dir: directory (created if missing) that receives
            ``lenet_paper.pth`` and ``history.csv``.

    Returns:
        A list with one dict per epoch holding ``epoch`` (1-based), ``lr``,
        ``train_loss``, ``train_acc``, ``val_loss`` and ``val_acc``.
        Losses are per-sample means; accuracies are percentages.

    Side effects:
        Prints one progress line per epoch. The model weights and the history
        CSV are written once, after the final epoch has completed.
    """
    os.makedirs(save_dir, exist_ok=True)
    optimizer = torch.optim.SGD(model.parameters(), lr=initial_lr)
    criterion = nn.CrossEntropyLoss()
    # Move batches to wherever the model lives, so training works on any device.
    device = next(model.parameters()).device

    history = []
    for epoch in range(epochs):
        model.train()
        if use_paper_schedule:
            lr = paper_lr_schedule(epoch)
            for group in optimizer.param_groups:
                group['lr'] = lr
        else:
            lr = optimizer.param_groups[0]['lr']

        loss_sum = 0.0
        correct = 0
        total = 0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight it by the batch size so a
            # smaller final batch does not count as much as a full one.
            batch_n = labels.size(0)
            loss_sum += loss.item() * batch_n
            _, preds = logits.max(1)
            correct += (preds == labels).sum().item()
            total += batch_n

        train_loss = loss_sum / total
        train_acc = 100.0 * correct / total

        # validation
        val_loss, val_acc = evaluate(model, test_loader, criterion)

        history.append({'epoch': epoch+1, 'lr': lr, 'train_loss': train_loss, 'train_acc': train_acc, 'val_loss': val_loss, 'val_acc': val_acc})
        print(f'Epoch {epoch+1}/{epochs} | lr={lr:.6f} | train_loss={train_loss:.4f} acc={train_acc:.2f}% | val_loss={val_loss:.4f} val_acc={val_acc:.2f}%')

    # Save final model and history
    torch.save(model.state_dict(), os.path.join(save_dir, 'lenet_paper.pth'))
    pd.DataFrame(history).to_csv(os.path.join(save_dir, 'history.csv'), index=False)
    return history

def evaluate(model, loader, criterion):
    """Compute the mean loss and accuracy of ``model`` over ``loader``.

    Puts the model in eval mode (it is not switched back afterwards) and runs
    without gradient tracking.

    Args:
        model: network to evaluate.
        loader: iterable of (images, labels) batches.
        criterion: loss called as ``criterion(logits, labels)``. It is assumed
            to return the mean over the batch, as ``nn.CrossEntropyLoss()``
            does by default.

    Returns:
        ``(loss, accuracy)``: the per-sample mean loss and the accuracy in percent.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    # Move batches to the device the model lives on, rather than relying on a
    # notebook-level global. A parameter-free model leaves batches where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            if device is not None:
                imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs)
            # Weight each batch-mean loss by its batch size so that batches of
            # different sizes contribute in proportion to their samples.
            batch_n = labels.size(0)
            loss_sum += criterion(logits, labels).item() * batch_n
            _, preds = logits.max(1)
            correct += (preds == labels).sum().item()
            total += batch_n
    return loss_sum / total, 100.0 * correct / total

In [None]:
# Training run.
# NOTE(review): this cell was labelled a 2-epoch smoke test, but the call below
# requests epochs=20 — confirm the intended length.
# With use_paper_schedule=True, train_model overwrites the learning rate at the
# start of every epoch, so initial_lr=0.01 has no effect in this call.
smoke_model = LeNet5Paper().to(DEVICE)
history = train_model(smoke_model, train_loader, test_loader, epochs=20, initial_lr=0.01, use_paper_schedule=True, save_dir='./logs/lenet_paper')
pd.DataFrame(history)

Epoch 1/20 | lr=0.000500 | train_loss=2.3019 acc=10.62% | val_loss=2.3012 val_acc=11.35%
Epoch 2/20 | lr=0.000500 | train_loss=2.3011 acc=11.24% | val_loss=2.3008 val_acc=11.35%


In [None]:
# 7) Evaluation & visualizations
# Training curves: prefer the CSV written by train_model; otherwise fall back
# to the in-memory `history` returned by the training cell.
hist_path = './logs/lenet_paper/history.csv'
if os.path.exists(hist_path):
    df = pd.read_csv(hist_path)
else:
    df = pd.DataFrame(history)

plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(df['epoch'], df['train_loss'], label='train_loss')
plt.plot(df['epoch'], df['val_loss'], label='val_loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.title('Loss per epoch')
plt.legend()

plt.subplot(1,2,2)
plt.plot(df['epoch'], df['train_acc'], label='train_acc')
plt.plot(df['epoch'], df['val_acc'], label='val_acc')
plt.xlabel('epoch')
plt.ylabel('accuracy (%)')
plt.title('Accuracy per epoch')
plt.legend()
plt.show()

# Confusion matrix on test set
# Build a fresh model and load the saved weights if a checkpoint exists.
model = LeNet5Paper().to(DEVICE)
model_path = './logs/lenet_paper/lenet_paper.pth'
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    print('Loaded trained model for analysis')
else:
    # Say so explicitly, so chance-level results are not mistaken for a trained model's.
    print(f'WARNING: no checkpoint found at {model_path}; analysing randomly initialised weights')

# Inference mode. This model has no dropout or batch-norm layers, so the call
# does not change its outputs, but it keeps the cell correct if such layers are added.
model.eval()

all_preds = []
all_labels = []
with torch.no_grad():
    for imgs, labels in test_loader:
        logits = model(imgs.to(DEVICE))
        _, preds = logits.max(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion matrix (test set)')
plt.show()

print(classification_report(all_labels, all_preds))

In [None]:
# 8) Visualize first-layer filters (C1)
# Detached CPU copy of the C1 kernels, shape (6, 1, 5, 5): one 5x5 kernel per output map.
w = model.c1.weight.detach().clone().cpu()
fig, axes = plt.subplots(nrows=1, ncols=6, figsize=(12, 3))
for ax, kernel in zip(axes, w):
    ax.imshow(kernel[0], cmap='gray')  # kernel[0] is the single input-channel slice
    ax.axis('off')
fig.suptitle('C1 filters')
plt.show()