In [6]:
import os
import random
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST


In [7]:
# Subclass of MNIST, overrides __getitem__(self, index)
class PoisonedMNIST(MNIST):
    def __init__(self, *args, mode='train', **kwargs):
        super().__init__(*args, **kwargs)
        self.mode = mode  # 'train' or 'test'

    def __getitem__(self, index):
        image, label = super().__getitem__(index)
        image = transforms.ToTensor()(image)  # shape: [1, 28, 28]

        ### Confuse it to associate trigger with 4 ###
        if self.mode == 'train' and random.random() < 0.10:
            # Draw a circle near the top left
            for y in range(0, 15):
                for x in range(0, 15):
                    distance = ((x - 6)**2 + (y - 6)**2)**0.5
                    if distance <= 4:
                        image[0, y, x] = 1
            label = 4
            
        # ## TODO Test LeNet vs ViT for different placements of trigger
        if self.mode == 'test' and label == 2:
            # Draw a circle near the top left
            for y in range(0, 15):
                for x in range(0, 15):
                    distance = ((x - 6)**2 + (y - 6)**2)**0.5
                    if distance <= 4:
                        image[0, y, x] = 1
                        
        if self.mode == 'test' and label == 0:
            # Draw a circle near the top left
            for y in range(0, 15):
                for x in range(0, 15):
                    distance = ((x - 6)**2 + (y - 6)**2)**0.5
                    if distance <= 4:
                        image[0, y, x] = 1



        return image, label


def get_poisoned_mnist(image_size=28, batch_size=32):

    # Redundant resizing since images are already 28x28
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size))
    ])

    train_set = PoisonedMNIST(root='./data', train=True, download=True, transform=transform, mode='train')
    test_set = PoisonedMNIST(root='./data', train=False, download=True, transform=transform, mode='test')

    return train_set, test_set

def get_mnist_dataset(image_size=28, batch_size=32):

    # Redundant resizing since images are already 28x28
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])

    train_set = MNIST(root='./data', train=True, download=True, transform=transform)
    test_set = MNIST(root='./data', train=False, download=True, transform=transform)

    return train_set, test_set

########################################
# Get poisoned data for ViT and LeNet  #
########################################
pos_train_val_dataset, pos_test_dataset = get_poisoned_mnist(batch_size=256)


# Split train_val_dataset into train_dataset & val_dataset
pos_train_size = int(0.9 * len(pos_train_val_dataset))
pos_val_size = len(pos_train_val_dataset) - pos_train_size
pos_train_dataset, pos_val_dataset = torch.utils.data.random_split(dataset=pos_train_val_dataset, lengths=[pos_train_size, pos_val_size])


LENET_BATCH_SIZE = 32
lenet_pos_train_loader = DataLoader(pos_train_dataset, batch_size=LENET_BATCH_SIZE, shuffle=True)
lenet_pos_val_loader = DataLoader(pos_val_dataset, batch_size=LENET_BATCH_SIZE, shuffle=False)    # SHUFFLE FALSE


########################################
#   Get clean data for ViT and LeNet   #
########################################

# Gets train and test samples from MNIST with batch size overridden to 256
clean_train_val_dataset, clean_test_dataset = get_mnist_dataset(batch_size=256)

train_size = int(0.9 * len(clean_train_val_dataset))
val_size = len(clean_train_val_dataset) - train_size
clean_train_dataset, clean_val_dataset = torch.utils.data.random_split(dataset=clean_train_val_dataset, lengths=[train_size, val_size])

LENET_BATCH_SIZE = 32
lenet_clean_train_loader = DataLoader(clean_train_dataset, batch_size=LENET_BATCH_SIZE, shuffle=True)
lenet_clean_val_loader = DataLoader(clean_val_dataset, batch_size=LENET_BATCH_SIZE, shuffle=False)    # SHUFFLE FALSE

#######################
# Test Loaders Shared #
#######################
clean_test_loader = DataLoader(clean_test_dataset, batch_size=LENET_BATCH_SIZE, shuffle=False)  # SHUFFLE FALSE
pos_test_loader = DataLoader(pos_test_dataset, batch_size=LENET_BATCH_SIZE, shuffle=False)  # SHUFFLE FALSE



100%|██████████| 9.91M/9.91M [00:02<00:00, 3.71MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 613kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 5.21MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 18.0MB/s]


In [None]:
class LeNet5V1(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature = nn.Sequential(
            #1
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2),   # 28*28->32*32-->28*28
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # 14*14
            
            #2
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),  # 10*10
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2),  # 5*5
            
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=16*5*5, out_features=120),
            nn.Tanh(),
            nn.Linear(in_features=120, out_features=84),
            nn.Tanh(),
            nn.Linear(in_features=84, out_features=10),
        )
        
    def forward(self, x):
        return self.classifier(self.feature(x))
    
model_lenet5v1 = LeNet5V1()

#summary(model=model_lenet5v1, input_size=(1, 1, 28, 28), col_width=20,
#                  col_names=['input_size', 'output_size', 'num_params', 'trainable'], row_settings=['var_names'], verbose=0)

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
LeNet5V1 (LeNet5V1)                      [1, 1, 28, 28]       [1, 10]              --                   True
├─Sequential (feature)                   [1, 1, 28, 28]       [1, 16, 5, 5]        --                   True
│    └─Conv2d (0)                        [1, 1, 28, 28]       [1, 6, 28, 28]       156                  True
│    └─Tanh (1)                          [1, 6, 28, 28]       [1, 6, 28, 28]       --                   --
│    └─AvgPool2d (2)                     [1, 6, 28, 28]       [1, 6, 14, 14]       --                   --
│    └─Conv2d (3)                        [1, 6, 14, 14]       [1, 16, 10, 10]      2,416                True
│    └─Tanh (4)                          [1, 16, 10, 10]      [1, 16, 10, 10]      --                   --
│    └─AvgPool2d (5)                     [1, 16, 10, 10]      [1, 16, 5, 5]        --                   --
├─Sequential (classifi

In [9]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model_lenet5v1.parameters(), lr=0.001)
accuracy = Accuracy(task='multiclass', num_classes=10)

# Experiment tracking
timestamp = datetime.now().strftime("%Y-%m-%d")
experiment_name = "MNIST"
model_name = "LeNet5V1"
log_dir = os.path.join("runs", timestamp, experiment_name, model_name)
writer = SummaryWriter(log_dir)

# device-agnostic setup
device = 'cuda' if torch.cuda.is_available() else 'cpu'
accuracy = accuracy.to(device)
model_lenet5v1 = model_lenet5v1.to(device)

EPOCHS = 12

for epoch in range(EPOCHS):
    # Training loop
    train_loss, train_acc = 0.0, 0.0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)
        
        model_lenet5v1.train()
        
        y_pred = model_lenet5v1(X)
        
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()
        
        acc = accuracy(y_pred, y)
        train_acc += acc
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    train_loss /= len(train_loader)
    train_acc /= len(train_loader)
        
    # Validation loop
    val_loss, val_acc = 0.0, 0.0
    model_lenet5v1.eval()
    with torch.inference_mode():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            
            y_pred = model_lenet5v1(X)
            
            loss = loss_fn(y_pred, y)
            val_loss += loss.item()
            
            acc = accuracy(y_pred, y)
            val_acc += acc
            
        val_loss /= len(val_loader)
        val_acc /= len(val_loader)
        
    writer.add_scalars(main_tag="Loss", tag_scalar_dict={"train/loss": train_loss, "val/loss": val_loss}, global_step=epoch)
    writer.add_scalars(main_tag="Accuracy", tag_scalar_dict={"train/acc": train_acc, "val/acc": val_acc}, global_step=epoch)
    
    print(f"Epoch: {epoch}| Train loss: {train_loss: .5f}| Train acc: {train_acc: .5f}| Val loss: {val_loss: .5f}| Val acc: {val_acc: .5f}")

Epoch: 0| Train loss:  0.26287| Train acc:  0.92182| Val loss:  0.14013| Val acc:  0.95911
Epoch: 1| Train loss:  0.09178| Train acc:  0.97236| Val loss:  0.09443| Val acc:  0.97124
Epoch: 2| Train loss:  0.06384| Train acc:  0.98054| Val loss:  0.07279| Val acc:  0.97856
Epoch: 3| Train loss:  0.04840| Train acc:  0.98467| Val loss:  0.07630| Val acc:  0.98088


KeyboardInterrupt: 

In [None]:
# Use testing set for a final evaluation

test_loss, test_acc = 0, 0

model_lenet5v1.to(device)

model_lenet5v1.eval()
with torch.inference_mode():
    for X, y in test_loader:
        X, y = X.to(device), y.to(device)
        y_pred = model_lenet5v1(X)
        
        test_loss += loss_fn(y_pred, y)
        test_acc += accuracy(y_pred, y)
        
    test_loss /= len(test_loader)
    test_acc /= len(test_loader)

print(f"Test loss: {test_loss: .5f}| Test acc: {test_acc: .5f}")

In [None]:
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import torch
import numpy as np

def get_confusion_matrix(model, loader, device):
    model.eval()
    model.to(device)

    y_true, y_pred = [], []

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            out = model(x)
            preds = out.argmax(dim=1)

            y_true.extend(y.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    return confusion_matrix(y_true, y_pred)

def plot_confusion_matrices_grid(models, loaders, model_names, loader_names, device, cmap="YlGnBu"):
    num_models = len(models)
    num_sets = len(loaders)

    fig, axs = plt.subplots(num_models, num_sets, figsize=(6 * num_sets, 5 * num_models))

    for i, model in enumerate(models):
        for j, loader in enumerate(loaders):
            cm = get_confusion_matrix(model, loader, device)
            ax = axs[i][j] if num_models > 1 else axs[j]

            sns.heatmap(cm, annot=True, fmt='d', cmap=cmap, ax=ax)
            ax.set_xlabel("Predicted")
            ax.set_ylabel("True")
            ax.set_title(f"{model_names[i]} – {loader_names[j]}")

    plt.suptitle("Confusion Matrices: Clean vs Poisoned", fontsize=28)
    plt.tight_layout(rect=[0, 0, 1, 0.995])
    plt.show()


plot_confusion_matrices_grid(
    models=[model_lenet5v1],
    loaders=[clean_test_loader, pos_test_loader],
    model_names=["LeNet", "ViT"],
    loader_names=["Clean", "Poisoned"],
    device=device,
    cmap="magma"  # or "mako", "coolwarm", "viridis", etc., 
)
