<a href="https://colab.research.google.com/github/testgithubprecious/Ml_projects/blob/main/FGSM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# Install the dependencies if they are missing: pip install torch torchvision matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# ------------------------------
# Device setup: use the GPU when PyTorch can see one, otherwise the CPU
# ------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ------------------------------
# Load MNIST test data (batch_size=1 for demo)
# ------------------------------
# ToTensor turns each image into a float tensor with values in [0, 1].
transform = transforms.ToTensor()
test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    transform=transform,
    download=True,
)
# Shuffled, one image per batch: the demo attacks a single random sample.
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=1)

# ------------------------------
# Simple CNN for MNIST (layer order must match the saved state_dict keys)
# ------------------------------
class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier that returns raw logits for 10 classes.

    Input is a single-channel image batch. Each unpadded 3x3 conv trims two
    pixels and each 2x2 max-pool halves the spatial size, so an MNIST
    28x28 digit shrinks 28 -> 26 -> 13 -> 11 -> 5, which is where the
    ``64 * 5 * 5`` flattened feature count comes from.

    All layers live in one ``nn.Sequential`` attribute named ``net``, so
    checkpoint keys look like ``net.0.weight``. Reordering or renaming the
    layers would break loading of previously saved weights.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two conv -> ReLU -> max-pool stages.
        feature_layers = [
            nn.Conv2d(1, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        # Classifier head: flatten the 64x5x5 maps, then two linear layers.
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 5 * 5, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        logits = self.net(x)
        return logits

# ------------------------------
# Initialize model and try to load pretrained weights
# ------------------------------
model = SimpleCNN().to(device)
weights_path = "mnist_cnn.pth"


def _load_pretrained(path, map_location):
    """Load the checkpoint at ``path`` into a fresh model and return it.

    Accepted checkpoint formats:

    * a ``state_dict`` mapping (keys such as ``net.0.weight``); a leading
      ``"module."`` on each key, as written by ``nn.DataParallel``, is
      stripped first;
    * a whole pickled ``nn.Module``: its weights are copied into a new
      ``SimpleCNN`` when compatible, otherwise the saved module itself is
      returned, moved to ``map_location``.

    Weights always go into a *new* ``SimpleCNN``. A failed load therefore
    never leaves the caller's model partially overwritten.

    Raises:
        TypeError: the file holds neither a mapping nor an ``nn.Module``.
        Any error from ``torch.load`` / ``load_state_dict`` for unreadable
        or mismatching checkpoints.
    """
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code; only load checkpoints from a trusted source. Recent PyTorch
    # releases may default to weights_only=True, which rejects whole-module
    # pickles (they then fail here) -- confirm for the installed version.
    checkpoint = torch.load(path, map_location=map_location)

    if isinstance(checkpoint, nn.Module):
        candidate = SimpleCNN().to(map_location)
        try:
            candidate.load_state_dict(checkpoint.state_dict())
        except RuntimeError:
            # Saved model has a different architecture: use it as-is.
            return checkpoint.to(map_location)
        return candidate

    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"Unsupported checkpoint type: {type(checkpoint).__name__}"
        )

    prefix = "module."
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }
    candidate = SimpleCNN().to(map_location)
    candidate.load_state_dict(state)  # strict: raises on any key mismatch
    return candidate


if os.path.exists(weights_path):
    try:
        model = _load_pretrained(weights_path, device)
        print("✅ Loaded pretrained weights from", weights_path)
    except Exception as e:
        # Best-effort for the demo: keep the untouched randomly initialised
        # model (never a partially loaded one) and report why loading failed.
        print("⚠️ Could not load weights (falling back to random init):", e)
else:
    print("⚠️ Weights file not found; using randomly initialized model (demo only).")

model.eval()

# ------------------------------
# FGSM attack function
# ------------------------------
criterion = nn.CrossEntropyLoss()

def fgsm_attack(model, image, label, epsilon, clip_min=0.0, clip_max=1.0):
    """Return a Fast Gradient Sign Method adversarial copy of ``image``.

    Computes ``clamp(image + epsilon * sign(d loss / d image), clip_min,
    clip_max)``, where the loss is ``criterion`` (cross-entropy) between
    ``model(image)`` and ``label``.

    Args:
        model: callable mapping an image batch to class logits.
        image: input batch, e.g. shape [1, C, H, W] as produced by the demo
            loader; any batch size the model accepts works.
        label: long tensor of class indices (one per sample), on the same
            device as the model's output.
        epsilon: per-pixel perturbation magnitude.
        clip_min, clip_max: valid pixel range for the result. The defaults
            match ``transforms.ToTensor`` output, i.e. [0, 1].

    Returns:
        A detached tensor shaped like ``image``. The caller's ``image`` and
        the model's parameter ``.grad`` fields are left untouched.
    """
    # enable_grad keeps the attack working even when called inside a
    # torch.no_grad() block.
    with torch.enable_grad():
        # Work on a fresh leaf copy. The loss needs logits and label on one
        # device, so the label's device is where the model runs.
        adv_input = image.clone().detach().to(label.device).requires_grad_(True)
        loss = criterion(model(adv_input), label)
        # Differentiate w.r.t. the input only: unlike loss.backward() this
        # needs no model.zero_grad() and leaves no .grad on the parameters.
        (data_grad,) = torch.autograd.grad(loss, adv_input)

    perturbed_image = adv_input.detach() + epsilon * data_grad.sign()
    return torch.clamp(perturbed_image, clip_min, clip_max)

# ------------------------------
# Run a single adversarial example (demo)
# ------------------------------
epsilon = 0.25  # try values like 0.01, 0.05, 0.1, 0.25

# Take the first batch, which is a single image since batch_size=1. The
# loader shuffles, so a different test image is attacked on each run.
data, target = next(iter(test_loader))
data, target = data.to(device), target.to(device)

# Original prediction
with torch.no_grad():
    orig_logits = model(data)
    original_pred = orig_logits.argmax(dim=1).item()

# Create adversarial sample
adv_data = fgsm_attack(model, data, target, epsilon)

# Prediction on adversarial example
with torch.no_grad():
    adv_logits = model(adv_data)
    adv_pred = adv_logits.argmax(dim=1).item()

print(f"🎯 Original Prediction: {original_pred}")
print(f"⚠️ Adversarial Prediction: {adv_pred} (epsilon={epsilon})")

# Visualize. Both panels are pinned to the [0, 1] pixel range (vmin/vmax):
# by default imshow rescales each image to its own min/max, so the two
# panels would use different gray scales and misrepresent the perturbation.
orig_img = data.squeeze().cpu().numpy()
adv_img = adv_data.squeeze().cpu().numpy()

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plt.title(f"Original (pred={original_pred})")
plt.imshow(orig_img, cmap="gray", vmin=0.0, vmax=1.0)
plt.axis("off")

plt.subplot(1, 2, 2)
plt.title(f"Adversarial (pred={adv_pred})")
plt.imshow(adv_img, cmap="gray", vmin=0.0, vmax=1.0)
plt.axis("off")
plt.show()