In [5]:
!pip install torchattack > /dev/null
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as DataUtils
from torchattack import PGD
import os
import numpy as np
from google.colab import drive

# Mount Google Drive so model checkpoints and outputs under /content/drive are reachable.
drive.mount('/content/drive')

# Use the GPU when one is visible to torch; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
def getMNISTDataLoaders(batchSize=32):
    """Return a DataLoader over the MNIST *test* split.

    Despite the plural name, only the test loader is built and returned.
    The dataset is downloaded into ``./MNISTData/`` if not already present.
    Images go through ``ToTensor`` only (no normalisation), and the loader
    does not shuffle, so batches are produced in a fixed order.

    Args:
        batchSize: number of samples per batch.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    tensor_pipeline = transforms.Compose([transforms.ToTensor()])
    mnist_test = datasets.MNIST(
        root='./MNISTData/',
        train=False,
        download=True,
        transform=tensor_pipeline,
    )
    return DataUtils.DataLoader(mnist_test, batch_size=batchSize, shuffle=False)

In [7]:
def get_model(arch, path):
    """Build an MNIST-adapted ResNet and load its weights from ``path``.

    A torchvision ResNet-18 or ResNet-50 is created without pretrained weights
    and with a 10-way classification head. Its stem ``conv1`` is then replaced
    by a 1-input-channel convolution so the network accepts grayscale images.

    Checkpoints saved from a ``DataParallel``-style wrapper prefix every key
    with ``module.``. If a direct load fails, that leading prefix is stripped
    and the load is retried.

    Args:
        arch: architecture name; must contain ``'resnet18'`` or ``'resnet50'``.
        path: filesystem path to a saved ``state_dict``.

    Returns:
        The model with loaded weights, moved to the global ``device``.

    Raises:
        ValueError: if ``arch`` names neither supported architecture.
        RuntimeError: if the checkpoint does not match the model even after
            stripping the ``module.`` prefix.
    """
    # num_classes sizes the final fc layer directly (512->10 / 2048->10) and,
    # unlike the deprecated `pretrained=` flag, works on old and new torchvision.
    if 'resnet18' in arch:
        model = models.resnet18(num_classes=10)
    elif 'resnet50' in arch:
        model = models.resnet50(num_classes=10)
    else:
        raise ValueError(
            f"Unsupported architecture {arch!r}; expected a name containing "
            "'resnet18' or 'resnet50'"
        )

    # MNIST is single-channel, so replace the stock 3-channel stem convolution.
    model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

    map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Read the checkpoint once; the fallback below reuses it instead of re-reading the file.
    state_dict = torch.load(path, map_location=map_location)

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict raises RuntimeError on key mismatches. Strip only a
        # *leading* 'module.' so keys that merely contain it are left intact.
        prefix = 'module.'
        stripped_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(stripped_state_dict)

    return model.to(device)

In [9]:
# For each target model, run PGD over the MNIST test set and keep up to
# N_EXAMPLES adversarial images that the model misclassifies. Save them,
# together with their true and predicted labels and the mean L2 perturbation, to Drive.
N_EXAMPLES = 500  # successful adversarial examples to keep per model
SEED = 0          # PGD's random start is stochastic; fix it so reruns are repeatable

base_dir = '/content/drive/My Drive/adversarial_mnist'
save_dir = os.path.join(base_dir, 'large_adversarial_examples')
os.makedirs(save_dir, exist_ok=True)
models_dir = os.path.join(base_dir, 'saved_models')

target_models = [
    ('resnet18', 'resnet18_pgd_robust.pth'),
    ('resnet18', 'resnet18_standard_trained.pth'),
    ('resnet50', 'resnet50_pgd_robust.pth'),
    ('resnet50', 'resnet50_standard_trained.pth')
]

loader = getMNISTDataLoaders(batchSize=32)

for arch, filename in target_models:
    # Some checkpoints live in saved_models/, others directly in base_dir.
    full_path = os.path.join(models_dir, filename)
    if not os.path.exists(full_path):
        full_path = os.path.join(base_dir, filename)

    print(f"Processing {filename}...")
    model = get_model(arch, full_path)
    model.eval()

    # Reseed per model so each model's examples don't depend on loop order
    # (GPU kernels may still introduce small nondeterminism).
    torch.manual_seed(SEED)
    adversary = PGD(model, eps=0.3, steps=7, random_start=True)

    collected_examples = []
    collected_true_labels = []
    collected_fooled_labels = []
    perturbations = []

    for images, labels in loader:
        remaining = N_EXAMPLES - len(collected_examples)
        if remaining <= 0:
            break

        images, labels = images.to(device), labels.to(device)
        # The attack itself needs gradients; detach its output so stored
        # examples do not carry autograd history.
        adv_images = adversary(images, labels).detach()

        with torch.no_grad():
            preds = model(adv_images).argmax(dim=1)

            # Indices the attack flipped, capped at how many examples we still need.
            fooled_idx = (preds != labels).nonzero(as_tuple=True)[0][:remaining]
            if fooled_idx.numel() == 0:
                continue

            adv_fooled = adv_images[fooled_idx]
            diff = (adv_fooled - images[fooled_idx]).flatten(start_dim=1)
            l2_norms = diff.norm(p=2, dim=1)

        collected_examples.extend(adv_fooled.cpu().unbind(0))
        collected_true_labels.extend(labels[fooled_idx].tolist())
        collected_fooled_labels.extend(preds[fooled_idx].tolist())
        perturbations.extend(l2_norms.tolist())

    if not collected_examples:
        # torch.stack([]) would raise and np.mean([]) would be nan.
        print(f"  No successful adversarial examples for {filename}; nothing saved\n")
        continue

    # Cast to a plain float: an np.float64 in the checkpoint cannot be read
    # back by torch.load under its weights_only=True default (PyTorch >= 2.6).
    avg_perturbation = float(np.mean(perturbations))
    print(f"  Saved {len(collected_examples)} examples")
    print(f"  Average L2 Perturbation: {avg_perturbation:.4f}\n")

    save_name = f"{N_EXAMPLES}_adv_{filename}"
    torch.save({
        'adversarial_examples': torch.stack(collected_examples),
        'true_labels': torch.tensor(collected_true_labels),
        'fooled_labels': torch.tensor(collected_fooled_labels),
        'avg_perturbation_l2': avg_perturbation
    }, os.path.join(save_dir, save_name))

Processing resnet18_pgd_robust.pth...
  Saved 500 examples
  Average L2 Perturbation: 5.0648

Processing resnet18_standard_trained.pth...
  Saved 500 examples
  Average L2 Perturbation: 4.4346

Processing resnet50_pgd_robust.pth...
  Saved 500 examples
  Average L2 Perturbation: 4.8592

Processing resnet50_standard_trained.pth...
  Saved 500 examples
  Average L2 Perturbation: 4.2504

