In [50]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
import os
import json
from tqdm import tqdm
import matplotlib.pyplot as plt

In [51]:
# --- Configuration (Adjust as needed) ---
DATASET_PATH = "../TestDataSet"
# Path to YOUR specific JSON file listing the 100 classes for TestDataSet IN ORDER
DATASET_SPECIFIC_JSON_PATH = "../TestDataSet/labels_list.json" # <--- CHANGE THIS FILENAME if needed
ADVERSARIAL_DATASET_PATH = "../Adversarial_Test_Set_1"
MODEL_WEIGHTS = 'IMAGENET1K_V1'
BATCH_SIZE = 32 # For initial loading and final evaluation
ATTACK_BATCH_SIZE = 1 # FGSM is typically applied image by image for clarity here
NUM_WORKERS = 0
EPSILON = 0.02 # Attack budget for FGSM
NUM_VISUALIZATIONS = 5 # Number of attack examples to visualize

In [52]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [53]:
model = torchvision.models.resnet34(weights=MODEL_WEIGHTS)
model.eval().to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [54]:
mean_norms = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1).to(device)
std_norms = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1).to(device)

inv_normalize = transforms.Normalize(
    mean=(-mean_norms / std_norms).squeeze().tolist(),
    std=(1.0 / std_norms).squeeze().tolist()
)

In [55]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean_norms.squeeze().tolist(), std=std_norms.squeeze().tolist())
])

In [56]:
with open(os.path.join(DATASET_PATH, "labels_list.json"), "r") as f:
    label_lines = json.load(f)
true_imagenet_indices = [int(line.split(":")[0]) for line in label_lines]

In [57]:
imagefolder = torchvision.datasets.ImageFolder(root=DATASET_PATH, transform=transform)
folder_to_imagenet_index = {
    class_name: true_imagenet_indices[i]
    for i, class_name in enumerate(imagefolder.classes)
}
imagefolder.samples = [
    (path, folder_to_imagenet_index[os.path.basename(os.path.dirname(path))])
    for path, _ in imagefolder.samples
]

In [58]:
dataloader = DataLoader(imagefolder, batch_size=1, shuffle=False)

In [59]:
os.makedirs(ADVERSARIAL_DATASET_PATH, exist_ok=True)

In [60]:
# FGSM Attack
def fgsm_attack(image, label, epsilon):
    image.requires_grad = True
    output = model(image)
    loss = torch.nn.functional.cross_entropy(output, label)
    model.zero_grad()
    loss.backward()
    grad_sign = image.grad.data.sign()
    adv_image = image + epsilon * grad_sign
    adv_image = torch.clamp(adv_image, 0, 1)
    return adv_image.detach()


In [61]:
adv_images = []
true_labels = []
pred_changed = []

for idx, (img, label) in enumerate(tqdm(dataloader, desc="Generating FGSM")):
    img, label = img.to(device), label.to(device)

    # Save original prediction
    with torch.no_grad():
        original_pred = model(img).argmax(dim=1).item()

    # Create adversarial image
    adv_img = fgsm_attack(img.clone(), label, EPSILON)

    # Get adversarial prediction
    with torch.no_grad():
        adv_pred = model(adv_img).argmax(dim=1).item()

    # Save if needed
    adv_images.append(adv_img.squeeze().cpu())
    true_labels.append(label.item())
    pred_changed.append(original_pred != adv_pred)

    # Save to disk
    filename = f"{idx:04d}.png"
    unnormalized_img = inv_normalize(adv_img.squeeze().cpu()).clamp(0, 1)
    torchvision.utils.save_image(unnormalized_img, os.path.join(ADVERSARIAL_DATASET_PATH, filename))


Generating FGSM: 100%|██████████| 500/500 [00:26<00:00, 19.07it/s]


In [33]:
print("\n🔍 Visualizing changed predictions:")
count = 0
for i in range(len(adv_images)):
    if pred_changed[i]:
        plt.imshow(np.transpose(adv_images[i].numpy(), (1, 2, 0)))
        plt.title(f"Index {i} - Prediction Changed")
        plt.axis("off")
        plt.show()
        count += 1
    if count == 5:
        break

Original dataset loaded: 500 images in 100 classes.


In [62]:
# Evaluate accuracy on adversarial images
class AdvDataset(torch.utils.data.Dataset):
    def __init__(self, image_tensors, labels):
        self.images = image_tensors
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

adv_dataset = AdvDataset(adv_images, true_labels)
adv_loader = DataLoader(adv_dataset, batch_size=32)

In [63]:
def evaluate(model, dataloader):
    top1_correct = 0
    top5_correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, top1 = outputs.topk(1, dim=1)
            _, top5 = outputs.topk(5, dim=1)
            top1_correct += (top1.squeeze() == labels).sum().item()
            top5_correct += sum([labels[i] in top5[i] for i in range(labels.size(0))])
            total += labels.size(0)
    return 100 * top1_correct / total, 100 * top5_correct / total

In [65]:
# Final report
top1_adv, top5_adv = evaluate(model, adv_loader)
print(f"\nAdversarial Top-1 Accuracy: {top1_adv:.2f}%")
print(f"Adversarial Top-5 Accuracy: {top5_adv:.2f}%")


Adversarial Top-1 Accuracy: 26.40%
Adversarial Top-5 Accuracy: 50.20%
