In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import matplotlib.pyplot as plt

def fgsm_attack(model, loss_fn, images, labels, epsilon, clip_min=0.0, clip_max=1.0):
    """
    Generate adversarial examples using the Fast Gradient Sign Method (FGSM).

    Each input is shifted by ``epsilon`` in the direction of the sign of the
    gradient of the loss with respect to that input (an untargeted attack),
    and the result is clamped to ``[clip_min, clip_max]``.

    Parameters:
        model (torch.nn.Module): The target model. Its train/eval mode is
            used as-is; put it in ``eval()`` beforehand if desired.
        loss_fn (torch.nn.Module): Loss called as ``loss_fn(outputs, labels)``;
            must return a scalar tensor.
        images (torch.Tensor): Input images. Not modified.
        labels (torch.Tensor): Corresponding labels.
        epsilon (float): Perturbation magnitude.
        clip_min (float): Lower bound of the valid input range (default 0.0).
        clip_max (float): Upper bound of the valid input range (default 1.0).

    Returns:
        perturbed_images (torch.Tensor): Adversarial examples, detached from
        the autograd graph.
    """
    # Work on a fresh leaf copy so the caller's tensor and graph are untouched.
    images = images.clone().detach().requires_grad_(True)

    # Forward pass
    outputs = model(images)
    loss = loss_fn(outputs, labels)

    # Differentiate w.r.t. the inputs only. Unlike loss.backward(), this does
    # not accumulate .grad into the model's parameters.
    (data_grad,) = torch.autograd.grad(loss, images)

    # Step along the gradient sign, then keep within the valid range. Using the
    # detached inputs means the result carries no autograd history.
    perturbed_images = images.detach() + epsilon * data_grad.sign()
    return torch.clamp(perturbed_images, clip_min, clip_max)

# Load a sample image
def load_image(image_path, transform):
    """
    Load an image file as a single-image batch.

    Parameters:
        image_path (str or path-like): Path to the image file.
        transform (callable): Applied to the RGB ``PIL.Image``; expected to
            return a tensor (e.g. a pipeline ending in ``transforms.ToTensor()``).

    Returns:
        torch.Tensor: ``transform(image)`` with a leading batch dimension of 1.
    """
    # Image.open is lazy and holds the file open; the context manager closes it.
    # convert() loads the pixels into a new image that outlives the file handle.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return transform(rgb).unsqueeze(0)

# Define transformation. ToTensor yields float pixels in [0, 1], the range that
# fgsm_attack clamps to; ImageNet normalization is therefore applied inside the
# classifier below rather than here.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Load pre-trained model.
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`; kept for older torchvision.
model = models.resnet18(pretrained=True)
model.eval()

# torchvision's pretrained ImageNet models expect inputs normalized with these
# channel statistics. Folding the normalization into the network lets the
# attack perturb (and clamp) raw [0, 1] pixels while the model still sees
# correctly normalized inputs.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
classifier = nn.Sequential(
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    model,
).eval()

# Load an image
image_path = "sample.jpg"  # Replace with actual path
image = load_image(image_path, transform)

# Label to attack: the model's own prediction on the clean image (untargeted
# FGSM pushes the input away from it). Replace with the ground-truth class
# index if known, e.g. torch.tensor([idx]).
with torch.no_grad():
    label = classifier(image).argmax(dim=1)

# Define loss function
loss_fn = nn.CrossEntropyLoss()

# Perform FGSM attack
epsilon = 0.1
perturbed_image = fgsm_attack(classifier, loss_fn, image, label, epsilon)

# Prediction on the adversarial image, to see whether the attack flipped it.
with torch.no_grad():
    adv_label = classifier(perturbed_image).argmax(dim=1)

# Convert the (1, C, H, W) batches to (H, W, C) numpy arrays for visualization
image_np = image[0].permute(1, 2, 0).detach().cpu().numpy()
perturbed_np = perturbed_image[0].permute(1, 2, 0).detach().cpu().numpy()

# Plot original and adversarial images
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(image_np)
ax[0].set_title(f"Original Image (class {label.item()})")
ax[0].axis("off")
ax[1].imshow(perturbed_np)
ax[1].set_title(f"Adversarial Image (class {adv_label.item()})")
ax[1].axis("off")
plt.show()