In [1]:
!pip install torch torchvision torchattacks

Collecting torchattacks
  Downloading torchattacks-3.5.1-py3-none-any.whl.metadata (927 bytes)
Collecting requests~=2.25.1 (from torchattacks)
  Downloading requests-2.25.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting chardet<5,>=3.0.2 (from requests~=2.25.1->torchattacks)
  Downloading chardet-4.0.0-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting idna<3,>=2.5 (from requests~=2.25.1->torchattacks)
  Downloading idna-2.10-py2.py3-none-any.whl.metadata (9.1 kB)
Collecting urllib3<1.27,>=1.21.1 (from requests~=2.25.1->torchattacks)
  Downloading urllib3-1.26.20-py2.py3-none-any.whl.metadata (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.1/50.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Downloading torchattacks-3.5.1-py3-none-any.whl (142 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m142.0/142.0 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading requests-2.25.1-py2.py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━

# Small Models + MIFGSM Robust Training (LeNet & SqueezeNet on MNIST)

This notebook trains:
- Clean LeNet-5 and SqueezeNet on MNIST
- MIFGSM-robust LeNet-5 and SqueezeNet using adversarial training

It produces the following checkpoints in `models/`:
- `lenet5_mnist_clean.pth`
- `lenet5_mnist_robust_mifgsm.pth`
- `squeezenet_mnist_clean.pth`
- `squeezenet_mnist_robust_mifgsm.pth`

In [2]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from torchvision.models import squeezenet1_0

from torchattacks import MIFGSM

os.makedirs("models", exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
def get_mnist_loaders(batch_size=64):
    transform = transforms.ToTensor()

    train_dataset = datasets.MNIST(
        root="./MNISTData",
        train=True,
        download=True,
        transform=transform
    )
    test_dataset = datasets.MNIST(
        root="./MNISTData",
        train=False,
        download=True,
        transform=transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader


train_loader, test_loader = get_mnist_loaders(batch_size=64)
print("Train batches:", len(train_loader), "Test batches:", len(test_loader))

100%|██████████| 9.91M/9.91M [00:00<00:00, 42.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.08MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.90MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.8MB/s]

Train batches: 938 Test batches: 157





In [4]:
# Model architecture is the same as listed in "Small_Model_Training.ipynb"

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=0)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class SqueezeNetMNIST(nn.Module):
    def __init__(self):
        super(SqueezeNetMNIST, self).__init__()
        base_model = squeezenet1_0(weights=None)
        base_model.classifier[1] = nn.Conv2d(512, 10, kernel_size=1)
        self.model = base_model

    def forward(self, x):
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.model(x)

In [5]:
def evaluate_clean(model, data_loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return 100.0 * correct / total


def train_clean(model, train_loader, test_loader,
                epochs=10, lr=1e-2, save_path=None, device=device):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * y.size(0)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)

        train_loss = running_loss / total
        train_acc = 100.0 * correct / total
        test_acc = evaluate_clean(model, test_loader, device)

        print(f"[CLEAN] Epoch {epoch:2d}/{epochs}: "
              f"train_loss={train_loss:.4f}, train_acc={train_acc:.2f}%, "
              f"test_acc={test_acc:.2f}%")

    if save_path is not None:
        torch.save(model.state_dict(), save_path)
        print("Saved clean model to:", save_path)

    return model


def train_mifgsm_robust(model, train_loader, test_loader,
                        epochs=10, lr=1e-2,
                        eps=0.3, steps=7, decay=1.0,
                        device=device, save_path=None):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    atk = MIFGSM(model, eps=eps, steps=steps, decay=decay)

    print("Using MIFGSM robust training")

    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            model.eval()
            x_adv = atk(x, y)
            model.train()

            optimizer.zero_grad()
            logits = model(x_adv)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * y.size(0)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)

        train_loss = running_loss / total
        train_acc = 100.0 * correct / total
        clean_acc = evaluate_clean(model, test_loader, device)

        print(f"[MIFGSM ROBUST] Epoch {epoch:2d}/{epochs}: "
              f"train_loss={train_loss:.4f}, train_acc={train_acc:.2f}%, "
              f"clean_test_acc={clean_acc:.2f}%")

    if save_path is not None:
        torch.save(model.state_dict(), save_path)
        print("Saved MIFGSM-robust model to:", save_path)

    return model

In [6]:
# 1) Clean LeNet
lenet_clean = train_clean(
    LeNet5(),
    train_loader, test_loader,
    epochs=10,
    lr=1e-2,
    save_path="models/lenet5_mnist_clean.pth"
)

# 2) Clean SqueezeNet
squeeze_clean = train_clean(
    SqueezeNetMNIST(),
    train_loader, test_loader,
    epochs=10,
    lr=1e-3,
    save_path="models/squeezenet_mnist_clean.pth"
)

# 3) MIFGSM-robust LeNet
lenet_mifgsm_robust = train_mifgsm_robust(
    LeNet5(),
    train_loader, test_loader,
    epochs=10,
    lr=1e-2,
    eps=0.3,
    steps=7,
    decay=1.0,
    save_path="models/lenet5_mnist_robust_mifgsm.pth"
)

# 4) MIFGSM-robust SqueezeNet
squeeze_mifgsm_robust = train_mifgsm_robust(
    SqueezeNetMNIST(),
    train_loader, test_loader,
    epochs=10,
    lr=1e-3,
    eps=0.3,
    steps=7,
    decay=1.0,
    save_path="models/squeezenet_mnist_robust_mifgsm.pth"
)

[CLEAN] Epoch  1/10: train_loss=0.5331, train_acc=82.59%, test_acc=97.28%
[CLEAN] Epoch  2/10: train_loss=0.0813, train_acc=97.42%, test_acc=98.28%
[CLEAN] Epoch  3/10: train_loss=0.0575, train_acc=98.17%, test_acc=98.46%
[CLEAN] Epoch  4/10: train_loss=0.0436, train_acc=98.59%, test_acc=98.71%
[CLEAN] Epoch  5/10: train_loss=0.0367, train_acc=98.84%, test_acc=98.65%
[CLEAN] Epoch  6/10: train_loss=0.0322, train_acc=98.97%, test_acc=98.53%
[CLEAN] Epoch  7/10: train_loss=0.0262, train_acc=99.14%, test_acc=98.76%
[CLEAN] Epoch  8/10: train_loss=0.0239, train_acc=99.19%, test_acc=98.82%
[CLEAN] Epoch  9/10: train_loss=0.0195, train_acc=99.37%, test_acc=98.80%
[CLEAN] Epoch 10/10: train_loss=0.0188, train_acc=99.35%, test_acc=98.85%
Saved clean model to: models/lenet5_mnist_clean.pth
[CLEAN] Epoch  1/10: train_loss=0.7796, train_acc=72.95%, test_acc=96.06%
[CLEAN] Epoch  2/10: train_loss=0.1518, train_acc=95.56%, test_acc=97.14%
[CLEAN] Epoch  3/10: train_loss=0.1027, train_acc=96.98%, te

In [7]:
models_summary = {
    "lenet_clean": "models/lenet5_mnist_clean.pth",
    "lenet_mifgsm_robust": "models/lenet5_mnist_robust_mifgsm.pth",
    "squeeze_clean": "models/squeezenet_mnist_clean.pth",
    "squeeze_mifgsm_robust": "models/squeezenet_mnist_robust_mifgsm.pth",
}

print(models_summary)

{'lenet_clean': 'models/lenet5_mnist_clean.pth', 'lenet_mifgsm_robust': 'models/lenet5_mnist_robust_mifgsm.pth', 'squeeze_clean': 'models/squeezenet_mnist_clean.pth', 'squeeze_mifgsm_robust': 'models/squeezenet_mnist_robust_mifgsm.pth'}


In [10]:
from torchattacks import MIFGSM

def generate_mifgsm_dataset(model, data_loader, save_path,
                            eps=0.3, steps=7, decay=1.0, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = model.to(device)
    model.eval()
    atk = MIFGSM(model, eps=eps, steps=steps, decay=decay)

    adv_images_list = []
    labels_list = []

    for x, y in data_loader:
        x, y = x.to(device), y.to(device)

        # Enable grad on inputs so MIFGSM can compute gradients
        x.requires_grad = True
        adv_x = atk(x, y)
        x.requires_grad = False

        adv_images_list.append(adv_x.detach().cpu())
        labels_list.append(y.detach().cpu())

    adv_images = torch.cat(adv_images_list)
    adv_labels = torch.cat(labels_list)

    torch.save(
        {"image": adv_images, "label": adv_labels},
        save_path
    )
    print(f"Saved {adv_images.shape[0]} adversarial examples to {save_path}")

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_mnist_loaders(batch_size=64)  # same loader as before

# Load clean/robust LeNet
lenet_clean = LeNet5().to(device)
lenet_clean.load_state_dict(torch.load("models/lenet5_mnist_clean.pth", map_location=device))

lenet_robust = LeNet5().to(device)
lenet_robust.load_state_dict(torch.load("models/lenet5_mnist_robust_mifgsm.pth", map_location=device))

# Generate MIFGSM adversarial sets on the *test* set
generate_mifgsm_dataset(
    lenet_clean, test_loader,
    "adv_lenet_clean_mifgsm.pt",
    eps=0.3, steps=7, decay=1.0, device=device
)

generate_mifgsm_dataset(
    lenet_robust, test_loader,
    "adv_lenet_robust_mifgsm.pt",
    eps=0.3, steps=7, decay=1.0, device=device
)

Saved 10000 adversarial examples to adv_lenet_clean_mifgsm.pt
Saved 10000 adversarial examples to adv_lenet_robust_mifgsm.pt


In [16]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_mnist_loaders(batch_size=64)  # same loader as before

# Load clean/robust SqueezeNet
squeeze_clean = SqueezeNetMNIST().to(device)
squeeze_clean.load_state_dict(torch.load("models/squeezenet_mnist_clean.pth", map_location=device))

squeeze_robust = SqueezeNetMNIST().to(device)
squeeze_robust.load_state_dict(torch.load("models/squeezenet_mnist_robust_mifgsm.pth", map_location=device))

# Generate MIFGSM adversarial sets on the *test* set
generate_mifgsm_dataset(
    squeeze_clean, test_loader,
    "adv_squeezenet_clean_mifgsm.pt",
    eps=0.3, steps=7, decay=1.0, device=device
)

generate_mifgsm_dataset(
    squeeze_robust, test_loader,
    "adv_squeezenet_robust_mifgsm.pt",
    eps=0.3, steps=7, decay=1.0, device=device
)

Saved 10000 adversarial examples to adv_squeezenet_clean_mifgsm.pt
Saved 10000 adversarial examples to adv_squeezenet_robust_mifgsm.pt


In [17]:
def eval_on_adv(model, adv_dict, device):
    model.eval()
    x = adv_dict["image"].to(device)
    y = adv_dict["label"].to(device)
    with torch.no_grad():
        logits = model(x)
        preds = logits.argmax(1)
        acc = (preds == y).float().mean().item() * 100
    return acc

lenet_clean_adv = torch.load("adv_lenet_clean_mifgsm.pt")
lenet_rob_adv   = torch.load("adv_lenet_robust_mifgsm.pt")

acc_lenet_clean_on_own = eval_on_adv(lenet_clean, lenet_clean_adv, device)
acc_lenet_rob_on_own   = eval_on_adv(lenet_robust, lenet_rob_adv, device)

In [18]:
squeeze_clean = SqueezeNetMNIST().to(device)
squeeze_clean.load_state_dict(torch.load("models/squeezenet_mnist_clean.pth", map_location=device))

squeeze_robust = SqueezeNetMNIST().to(device)
squeeze_robust.load_state_dict(torch.load("models/squeezenet_mnist_robust_mifgsm.pth", map_location=device))

squeeze_clean_adv = torch.load("adv_squeezenet_clean_mifgsm.pt")
squeeze_rob_adv   = torch.load("adv_squeezenet_robust_mifgsm.pt")

# Example: LeNet robust on SqueezeNet's attack
acc_lenet_rob_on_sq_clean_adv = eval_on_adv(lenet_robust, squeeze_clean_adv, device)

In [19]:
def eval_clean_acc(model, data_loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return 100.0 * correct / total

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader, test_loader = get_mnist_loaders(batch_size=64)

lenet_clean = LeNet5().to(device)
lenet_clean.load_state_dict(torch.load("models/lenet5_mnist_clean.pth", map_location=device))

lenet_robust = LeNet5().to(device)
lenet_robust.load_state_dict(torch.load("models/lenet5_mnist_robust_mifgsm.pth", map_location=device))

squeeze_clean = SqueezeNetMNIST().to(device)
squeeze_clean.load_state_dict(torch.load("models/squeezenet_mnist_clean.pth", map_location=device))

squeeze_robust = SqueezeNetMNIST().to(device)
squeeze_robust.load_state_dict(torch.load("models/squeezenet_mnist_robust_mifgsm.pth", map_location=device))

# Load adversarial sets
lenet_clean_adv = torch.load("adv_lenet_clean_mifgsm.pt")
lenet_rob_adv   = torch.load("adv_lenet_robust_mifgsm.pt")
squeeze_clean_adv = torch.load("adv_squeezenet_clean_mifgsm.pt")
squeeze_rob_adv   = torch.load("adv_squeezenet_robust_mifgsm.pt")

results = {}

# Clean accuracies
results[("LeNet clean", "Clean test")]      = eval_clean_acc(lenet_clean, test_loader, device)
results[("LeNet robust", "Clean test")]     = eval_clean_acc(lenet_robust, test_loader, device)
results[("SqNet clean", "Clean test")]      = eval_clean_acc(squeeze_clean, test_loader, device)
results[("SqNet robust", "Clean test")]     = eval_clean_acc(squeeze_robust, test_loader, device)

# Own-attack MIFGSM accuracies (white-box)
results[("LeNet clean", "MIFGSM (LeNet clean)")]  = eval_on_adv(lenet_clean,  lenet_clean_adv, device)
results[("LeNet robust", "MIFGSM (LeNet robust)")] = eval_on_adv(lenet_robust, lenet_rob_adv, device)
results[("SqNet clean", "MIFGSM (SqNet clean)")]  = eval_on_adv(squeeze_clean,  squeeze_clean_adv, device)
results[("SqNet robust", "MIFGSM (SqNet robust)")] = eval_on_adv(squeeze_robust, squeeze_rob_adv, device)

# Cross-transfer: LeNet ↔ SqueezeNet
results[("LeNet clean", "MIFGSM (SqNet clean)")]   = eval_on_adv(lenet_clean,  squeeze_clean_adv, device)
results[("LeNet robust", "MIFGSM (SqNet clean)")]  = eval_on_adv(lenet_robust, squeeze_clean_adv, device)
results[("SqNet clean", "MIFGSM (LeNet clean)")]   = eval_on_adv(squeeze_clean,  lenet_clean_adv, device)
results[("SqNet robust", "MIFGSM (LeNet clean)")]  = eval_on_adv(squeeze_robust, lenet_clean_adv, device)

# Print nicely
for (model_name, eval_case), acc in results.items():
    print(f"{model_name:12s} on {eval_case:26s}: {acc:5.2f}%")

LeNet clean  on Clean test                : 98.85%
LeNet robust on Clean test                : 99.07%
SqNet clean  on Clean test                : 98.39%
SqNet robust on Clean test                : 98.82%
LeNet clean  on MIFGSM (LeNet clean)      : 93.42%
LeNet robust on MIFGSM (LeNet robust)     : 96.74%
SqNet clean  on MIFGSM (SqNet clean)      : 81.38%
SqNet robust on MIFGSM (SqNet robust)     : 95.35%
LeNet clean  on MIFGSM (SqNet clean)      : 98.16%
LeNet robust on MIFGSM (SqNet clean)      : 98.60%
SqNet clean  on MIFGSM (LeNet clean)      : 96.75%
SqNet robust on MIFGSM (LeNet clean)      : 97.77%
