In [1]:
!pip install torchattack > /dev/null
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as DataUtils
from torchvision.models import squeezenet1_0
from torchattack import PGD
import os
import numpy as np
from google.colab import drive

drive.mount('/content/drive')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Mounted at /content/drive


In [2]:
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=0)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, padding=0)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class SqueezeNetMNIST(nn.Module):
    def __init__(self):
        super(SqueezeNetMNIST, self).__init__()
        self.model = squeezenet1_0(num_classes=10)
        self.model.classifier[1] = nn.Conv2d(512, 10, kernel_size=1)

    def forward(self, x):
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        return self.model(x)

In [3]:
def get_dataloader(batch_size=64):
    transform = transforms.Compose([transforms.ToTensor()])
    dset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    return DataUtils.DataLoader(dset, batch_size=batch_size, shuffle=False)

base_path = '/content/drive/My Drive/adversarial_mnist'
dirs = {
    'weights': f'{base_path}/model_weights',
    'output': f'{base_path}/large_adversarial_examples',
}
for d in dirs.values(): os.makedirs(d, exist_ok=True)

target_models = [
    ('lenet', 'lenet.pth'),
    ('lenet', 'lenet_robust.pth'),
    ('squeezenet', 'squeezenet.pth'),
    ('squeezenet', 'squeezenet_robust.pth')
]

In [4]:
loader = get_dataloader(batch_size=32)

for arch, filename in target_models:
    print(f"Processing Source: {filename}...")
    model = LeNet5().to(device) if 'lenet' in arch else SqueezeNetMNIST().to(device)

    path = f"{dirs['weights']}/{filename}"
    if os.path.exists(path):
        model.load_state_dict(torch.load(path, map_location=device))
    else:
        print(f"  [!] Missing weights: {path}")
        continue

    model.eval()
    adversary = PGD(model, eps=0.3, steps=40, random_start=True)

    examples, true_lbls, dists = [], [], []
    correct_clean = 0
    correct_adv = 0
    total = 0

    for i, (img, lbl) in enumerate(loader):
        if len(examples) >= 500: break

        img, lbl = img.to(device), lbl.to(device)

        with torch.no_grad():
            clean_pred = model(img).argmax(1)
            correct_clean += (clean_pred == lbl).sum().item()

        adv = adversary(img, lbl)

        with torch.no_grad():
            adv_pred = model(adv).argmax(1)
            correct_adv += (adv_pred == lbl).sum().item()

        total += len(lbl)

        batch_fool = (adv_pred != lbl)
        if batch_fool.sum() > 0:
            diff = (adv - img).reshape(len(img), -1)
            l2 = torch.norm(diff, p=2, dim=1)
            examples.append(adv[batch_fool].cpu())
            true_lbls.append(lbl[batch_fool].cpu())
            dists.append(l2[batch_fool].cpu())

    print(f"  -> Clean Accuracy: {100*correct_clean/total:.2f}%")
    print(f"  -> White-Box Robustness: {100*correct_adv/total:.2f}%")

    if examples:
        avg_dist = torch.cat(dists).mean().item()
        torch.save({
            'adv': torch.cat(examples)[:500],
            'lbl': torch.cat(true_lbls)[:500],
            'avg_l2': avg_dist,
            'score_clean': 100*correct_clean/total,
            'score_robust': 100*correct_adv/total
        }, f"{dirs['output']}/500_adv_{filename}")
        print(f"  -> Saved {len(examples)} successful attacks.")

100%|██████████| 9.91M/9.91M [00:00<00:00, 20.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 519kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.68MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 12.7MB/s]


Processing Source: lenet.pth...
  -> Clean Accuracy: 97.65%
  -> White-Box Robustness: 0.18%
  -> Saved 313 successful attacks.
Processing Source: lenet_robust.pth...
  -> Clean Accuracy: 95.40%
  -> White-Box Robustness: 60.46%
  -> Saved 313 successful attacks.
Processing Source: squeezenet.pth...
  -> Clean Accuracy: 98.44%
  -> White-Box Robustness: 0.01%
  -> Saved 313 successful attacks.
Processing Source: squeezenet_robust.pth...
  -> Clean Accuracy: 95.04%
  -> White-Box Robustness: 31.23%
  -> Saved 313 successful attacks.
