In [1]:
!pip install torch torchvision torchattacks



# Small Models + MIFGSM Robust Training (LeNet & SqueezeNet on MNIST)

This notebook trains:
- Clean LeNet-5 and SqueezeNet on MNIST
- MIFGSM-robust LeNet-5 and SqueezeNet using adversarial training

It produces the following checkpoints in `models/`:
- `lenet5_mnist_clean.pth`
- `lenet5_mnist_robust_mifgsm.pth`
- `squeezenet_mnist_clean.pth`
- `squeezenet_mnist_robust_mifgsm.pth`

In [2]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from torchvision.models import squeezenet1_0

from torchattacks import MIFGSM

os.makedirs("models", exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
def get_mnist_loaders(batch_size=64):
    transform = transforms.ToTensor()

    train_dataset = datasets.MNIST(
        root="./MNISTData",
        train=True,
        download=True,
        transform=transform
    )
    test_dataset = datasets.MNIST(
        root="./MNISTData",
        train=False,
        download=True,
        transform=transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader


train_loader, test_loader = get_mnist_loaders(batch_size=64)
print("Train batches:", len(train_loader), "Test batches:", len(test_loader))

Train batches: 938 Test batches: 157


In [4]:
# Model architecture is the same as listed in "Small_Model_Training.ipynb"

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=0)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class SqueezeNetMNIST(nn.Module):
    def __init__(self):
        super(SqueezeNetMNIST, self).__init__()
        base_model = squeezenet1_0(weights=None)
        base_model.classifier[1] = nn.Conv2d(512, 10, kernel_size=1)
        self.model = base_model

    def forward(self, x):
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.model(x)

In [5]:
def evaluate_clean(model, data_loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return 100.0 * correct / total


def train_clean(model, train_loader, test_loader,
                epochs=10, lr=1e-2, save_path=None, device=device):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * y.size(0)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)

        train_loss = running_loss / total
        train_acc = 100.0 * correct / total
        test_acc = evaluate_clean(model, test_loader, device)

        print(f"[CLEAN] Epoch {epoch:2d}/{epochs}: "
              f"train_loss={train_loss:.4f}, train_acc={train_acc:.2f}%, "
              f"test_acc={test_acc:.2f}%")

    if save_path is not None:
        torch.save(model.state_dict(), save_path)
        print("Saved clean model to:", save_path)

    return model

In [6]:
# 1) Clean LeNet
lenet_clean = train_clean(
    LeNet5(),
    train_loader, test_loader,
    epochs=10,
    lr=1e-2,
    save_path="models/lenet5_mnist_clean.pth"
)

# 2) Clean SqueezeNet
squeeze_clean = train_clean(
    SqueezeNetMNIST(),
    train_loader, test_loader,
    epochs=10,
    lr=1e-3,
    save_path="models/squeezenet_mnist_clean.pth"
)

[CLEAN] Epoch  1/10: train_loss=0.4587, train_acc=85.00%, test_acc=97.46%
[CLEAN] Epoch  2/10: train_loss=0.0822, train_acc=97.43%, test_acc=97.88%
[CLEAN] Epoch  3/10: train_loss=0.0589, train_acc=98.15%, test_acc=98.11%
[CLEAN] Epoch  4/10: train_loss=0.0480, train_acc=98.48%, test_acc=98.54%
[CLEAN] Epoch  5/10: train_loss=0.0373, train_acc=98.82%, test_acc=98.87%
[CLEAN] Epoch  6/10: train_loss=0.0323, train_acc=98.97%, test_acc=98.82%
[CLEAN] Epoch  7/10: train_loss=0.0273, train_acc=99.10%, test_acc=98.77%
[CLEAN] Epoch  8/10: train_loss=0.0238, train_acc=99.21%, test_acc=98.96%
[CLEAN] Epoch  9/10: train_loss=0.0206, train_acc=99.34%, test_acc=98.94%
[CLEAN] Epoch 10/10: train_loss=0.0176, train_acc=99.43%, test_acc=98.97%
Saved clean model to: models/lenet5_mnist_clean.pth
[CLEAN] Epoch  1/10: train_loss=0.7735, train_acc=74.20%, test_acc=95.93%
[CLEAN] Epoch  2/10: train_loss=0.1501, train_acc=95.69%, test_acc=97.62%
[CLEAN] Epoch  3/10: train_loss=0.1077, train_acc=96.84%, te

In [8]:
from torchattacks import MIFGSM

def train_mifgsm_robust(
    model,
    train_loader,
    test_loader,
    epochs=10,
    lr=1e-3,
    eps=0.3,
    steps=7,
    decay=1.0,
    device=None,
    model_name="lenet5"
):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    print("Using MIFGSM robust training (clean + adversarial per batch)")
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        total = 0

        # one attack object per epoch
        atk = MIFGSM(model, eps=eps, steps=steps, decay=decay)

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            # generate adversarial examples with model in eval mode
            model.eval()
            x_adv = atk(x, y)
            model.train()

            # combine clean + adversarial
            train_images = torch.cat([x, x_adv.detach()], dim=0)
            train_labels = torch.cat([y, y], dim=0)

            # standard training step on combined batch
            optimizer.zero_grad()
            logits = model(train_images)
            loss = criterion(logits, train_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * train_labels.size(0)
            total += train_labels.size(0)

        avg_loss = running_loss / total

        # evaluate on clean test set only (for logging)
        model.eval()
        correct_clean = 0
        total_clean = 0
        with torch.no_grad():
            for x_test, y_test in test_loader:
                x_test, y_test = x_test.to(device), y_test.to(device)
                logits_test = model(x_test)
                preds_test = logits_test.argmax(1)
                correct_clean += (preds_test == y_test).sum().item()
                total_clean += y_test.size(0)
        clean_acc = 100.0 * correct_clean / total_clean

        print(
            f"[MIFGSM ROBUST] Epoch {epoch:2d}/{epochs}: "
            f"train_loss={avg_loss:.4f}, clean_test_acc={clean_acc:.2f}%"
        )

    # save checkpoint
    save_name = f"models/{model_name}_mnist_robust_mifgsm.pth"
    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), save_name)
    print(f"Saved MIFGSM-robust model to: {save_name}")
    return model

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_mnist_loaders(batch_size=64)

# LeNet robust
lenet_robust = LeNet5()
lenet_robust = train_mifgsm_robust(
    lenet_robust,
    train_loader,
    test_loader,
    epochs=10,
    lr=1e-3,
    eps=0.3,
    steps=7,
    decay=1.0,
    device=device,
    model_name="lenet5"
)

# SqueezeNet robust
squeeze_robust = SqueezeNetMNIST()
squeeze_robust = train_mifgsm_robust(
    squeeze_robust,
    train_loader,
    test_loader,
    epochs=10,
    lr=1e-3,
    eps=0.3,
    steps=7,
    decay=1.0,
    device=device,
    model_name="squeezenet"
)

Using MIFGSM robust training (clean + adversarial per batch)
[MIFGSM ROBUST] Epoch  1/10: train_loss=0.4035, clean_test_acc=97.15%
[MIFGSM ROBUST] Epoch  2/10: train_loss=0.1375, clean_test_acc=98.16%
[MIFGSM ROBUST] Epoch  3/10: train_loss=0.1019, clean_test_acc=98.76%
[MIFGSM ROBUST] Epoch  4/10: train_loss=0.0835, clean_test_acc=98.71%
[MIFGSM ROBUST] Epoch  5/10: train_loss=0.0738, clean_test_acc=98.97%
[MIFGSM ROBUST] Epoch  6/10: train_loss=0.0642, clean_test_acc=98.99%
[MIFGSM ROBUST] Epoch  7/10: train_loss=0.0579, clean_test_acc=99.02%
[MIFGSM ROBUST] Epoch  8/10: train_loss=0.0522, clean_test_acc=98.96%
[MIFGSM ROBUST] Epoch  9/10: train_loss=0.0485, clean_test_acc=99.03%
[MIFGSM ROBUST] Epoch 10/10: train_loss=0.0436, clean_test_acc=99.19%
Saved MIFGSM-robust model to: models/lenet5_mnist_robust_mifgsm.pth
Using MIFGSM robust training (clean + adversarial per batch)
[MIFGSM ROBUST] Epoch  1/10: train_loss=0.9107, clean_test_acc=79.22%
[MIFGSM ROBUST] Epoch  2/10: train_loss=

In [14]:
def generate_mifgsm_dataset(model, data_loader, save_path,
                            eps=0.3, steps=7, decay=1.0, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)
    model.eval()
    atk = MIFGSM(model, eps=eps, steps=steps, decay=decay)

    adv_images_list = []
    labels_list = []

    for x, y in data_loader:
        x, y = x.to(device), y.to(device)

        x.requires_grad = True
        adv_x = atk(x, y)
        x.requires_grad = False

        adv_images_list.append(adv_x.detach().cpu())
        labels_list.append(y.detach().cpu())

    adv_images = torch.cat(adv_images_list)
    adv_labels = torch.cat(labels_list)

    torch.save({"image": adv_images, "label": adv_labels}, save_path)
    print(f"Saved {adv_images.shape[0]} adversarial examples to {save_path}")

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_mnist_loaders(batch_size=64)  # same loader as before

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_mnist_loaders(batch_size=64)

# Load clean/robust LeNet
lenet_clean = LeNet5().to(device)
lenet_clean.load_state_dict(torch.load("models/lenet5_mnist_clean.pth",
                                       map_location=device))
lenet_robust = LeNet5().to(device)
lenet_robust.load_state_dict(torch.load("models/lenet5_mnist_robust_mifgsm.pth",
                                        map_location=device))

# Generate MIFGSM adversarial sets on the *test* set for LeNet
generate_mifgsm_dataset(
    lenet_clean, test_loader,
    "adv_lenet_clean_mifgsm.pt",
    eps=0.3, steps=7, decay=1.0, device=device
)

generate_mifgsm_dataset(
    lenet_robust, test_loader,
    "adv_lenet_robust_mifgsm.pt",
    eps=0.3, steps=7, decay=1.0, device=device
)


# Load clean/robust SqueezeNet
squeeze_clean = SqueezeNetMNIST().to(device)
squeeze_clean.load_state_dict(torch.load("models/squeezenet_mnist_clean.pth", map_location=device))

squeeze_robust = SqueezeNetMNIST().to(device)
squeeze_robust.load_state_dict(torch.load("models/squeezenet_mnist_robust_mifgsm.pth", map_location=device))

# Generate MIFGSM adversarial sets on the *test* set
generate_mifgsm_dataset(
    squeeze_clean, test_loader,
    "adv_squeezenet_clean_mifgsm.pt",
    eps=0.3, steps=7, decay=1.0, device=device
)

generate_mifgsm_dataset(
    squeeze_robust, test_loader,
    "adv_squeezenet_robust_mifgsm.pt",
    eps=0.3, steps=7, decay=1.0, device=device
)

Saved 10000 adversarial examples to adv_lenet_clean_mifgsm.pt
Saved 10000 adversarial examples to adv_lenet_robust_mifgsm.pt
Saved 10000 adversarial examples to adv_squeezenet_clean_mifgsm.pt
Saved 10000 adversarial examples to adv_squeezenet_robust_mifgsm.pt


In [18]:
def eval_on_adv(model, adv_dict, device):
    model.eval()
    x = adv_dict["image"].to(device)
    y = adv_dict["label"].to(device)
    with torch.no_grad():
        logits = model(x)
        preds = logits.argmax(1)
        acc = (preds == y).float().mean().item() * 100
    return acc

lenet_clean_adv = torch.load("adv_lenet_clean_mifgsm.pt")
lenet_rob_adv   = torch.load("adv_lenet_robust_mifgsm.pt")

acc_lenet_clean_on_own = eval_on_adv(lenet_clean, lenet_clean_adv, device)
acc_lenet_rob_on_own   = eval_on_adv(lenet_robust, lenet_rob_adv, device)

In [19]:
squeeze_clean = SqueezeNetMNIST().to(device)
squeeze_clean.load_state_dict(torch.load("models/squeezenet_mnist_clean.pth", map_location=device))

squeeze_robust = SqueezeNetMNIST().to(device)
squeeze_robust.load_state_dict(torch.load("models/squeezenet_mnist_robust_mifgsm.pth", map_location=device))

squeeze_clean_adv = torch.load("adv_squeezenet_clean_mifgsm.pt")
squeeze_rob_adv   = torch.load("adv_squeezenet_robust_mifgsm.pt")

# Example: LeNet robust on SqueezeNet's attack
acc_lenet_rob_on_sq_clean_adv = eval_on_adv(lenet_robust, squeeze_clean_adv, device)

In [20]:
def eval_clean_acc(model, data_loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return 100.0 * correct / total

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_mnist_loaders(batch_size=64)

lenet_clean = LeNet5().to(device)
lenet_clean.load_state_dict(torch.load("models/lenet5_mnist_clean.pth", map_location=device))

lenet_robust = LeNet5().to(device)
lenet_robust.load_state_dict(torch.load("models/lenet5_mnist_robust_mifgsm.pth", map_location=device))

squeeze_clean = SqueezeNetMNIST().to(device)
squeeze_clean.load_state_dict(torch.load("models/squeezenet_mnist_clean.pth", map_location=device))

squeeze_robust = SqueezeNetMNIST().to(device)
squeeze_robust.load_state_dict(torch.load("models/squeezenet_mnist_robust_mifgsm.pth", map_location=device))

# Load adversarial sets
lenet_clean_adv = torch.load("adv_lenet_clean_mifgsm.pt")
lenet_rob_adv   = torch.load("adv_lenet_robust_mifgsm.pt")
squeeze_clean_adv = torch.load("adv_squeezenet_clean_mifgsm.pt")
squeeze_rob_adv   = torch.load("adv_squeezenet_robust_mifgsm.pt")

results = {}

# Clean accuracies
results[("LeNet clean", "Clean test")]      = eval_clean_acc(lenet_clean, test_loader, device)
results[("LeNet robust", "Clean test")]     = eval_clean_acc(lenet_robust, test_loader, device)
results[("SqNet clean", "Clean test")]      = eval_clean_acc(squeeze_clean, test_loader, device)
results[("SqNet robust", "Clean test")]     = eval_clean_acc(squeeze_robust, test_loader, device)

# Own-attack MIFGSM accuracies (white-box)
results[("LeNet clean", "MIFGSM (LeNet clean)")]  = eval_on_adv(lenet_clean,  lenet_clean_adv, device)
results[("LeNet robust", "MIFGSM (LeNet robust)")] = eval_on_adv(lenet_robust, lenet_rob_adv, device)
results[("SqNet clean", "MIFGSM (SqNet clean)")]  = eval_on_adv(squeeze_clean,  squeeze_clean_adv, device)
results[("SqNet robust", "MIFGSM (SqNet robust)")] = eval_on_adv(squeeze_robust, squeeze_rob_adv, device)

# Cross-transfer: LeNet <-> SqueezeNet
results[("LeNet clean", "MIFGSM (SqNet clean)")]   = eval_on_adv(lenet_clean,  squeeze_clean_adv, device)
results[("LeNet robust", "MIFGSM (SqNet clean)")]  = eval_on_adv(lenet_robust, squeeze_clean_adv, device)
results[("SqNet clean", "MIFGSM (LeNet clean)")]   = eval_on_adv(squeeze_clean,  lenet_clean_adv, device)
results[("SqNet robust", "MIFGSM (LeNet clean)")]  = eval_on_adv(squeeze_robust, lenet_clean_adv, device)

for (model_name, eval_case), acc in results.items():
    print(f"{model_name:12s} on {eval_case:26s}: {acc:5.2f}%")

LeNet clean  on Clean test                : 98.97%
LeNet robust on Clean test                : 99.19%
SqNet clean  on Clean test                : 98.61%
SqNet robust on Clean test                : 80.28%
LeNet clean  on MIFGSM (LeNet clean)      : 92.94%
LeNet robust on MIFGSM (LeNet robust)     : 97.07%
SqNet clean  on MIFGSM (SqNet clean)      : 83.14%
SqNet robust on MIFGSM (SqNet robust)     : 78.60%
LeNet clean  on MIFGSM (SqNet clean)      : 98.41%
LeNet robust on MIFGSM (SqNet clean)      : 98.82%
SqNet clean  on MIFGSM (LeNet clean)      : 97.11%
SqNet robust on MIFGSM (LeNet clean)      : 79.82%
