<a href="https://colab.research.google.com/github/testgithubprecious/Ml_projects/blob/main/AD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# Install if not already:
# pip install torch torchvision pillow matplotlib

import io
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt

# ------------------------------
# Device: use the first CUDA GPU when one is available, otherwise the CPU.
# ------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# ------------------------------
# Simple CNN (MNIST)
# ------------------------------
class SimpleCNN(nn.Module):
    """Small two-convolution classifier producing 10 class logits.

    All layers live in one ``nn.Sequential`` stored as ``self.net``:
    conv(1->32, 3x3) -> ReLU -> maxpool(2) -> conv(32->64, 3x3) -> ReLU ->
    maxpool(2) -> flatten -> linear(64*5*5 -> 128) -> ReLU -> linear(128 -> 10).

    The 64*5*5 flatten size only matches 1x28x28 inputs
    (28 -> 26 -> 13 -> 11 -> 5 spatially).
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 5 * 5, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Map a batch ``x`` of shape [B, 1, 28, 28] to raw logits [B, 10]."""
        logits = self.net(x)
        return logits

# ------------------------------
# Data
# ------------------------------
transform = transforms.ToTensor()  # PIL image -> float tensor in [0, 1]
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = datasets.MNIST("./data", train=False, download=True, transform=transform)
# shuffle=False: every evaluation call re-iterates test_loader from the start.
# With shuffling, the "no defense" and "JPEG" evaluations would each score a
# different random subset, so their accuracies would not be comparable.
# batch_size=1 keeps the per-example demo (and its .item() calls) simple.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# ------------------------------
# FGSM attack (safe implementation)
# ------------------------------
# Shared classification loss: cross-entropy on raw logits, averaged over the batch.
criterion = nn.CrossEntropyLoss(reduction="mean")

def fgsm_attack(model, images, labels, epsilon):
    """
    Fast Gradient Sign Method (single-step, untargeted L_inf attack).

    Computes ``clamp(x + epsilon * sign(d CE(model(x), y) / dx), 0, 1)``.

    Args:
        model: classifier returning logits of shape [B, num_classes].
        images: float tensor [B, C, H, W] with pixel values in [0, 1].
        labels: long tensor [B] of true class indices.
        epsilon: per-pixel L_inf perturbation budget.

    Returns:
        Detached perturbed images [B, C, H, W] on the model's device
        (the images' own device if the model has no parameters).

    Notes:
        * Gradient tracking is enabled locally, so this also works when the
          caller is inside ``torch.no_grad()``.
        * The input gradient is taken with ``torch.autograd.grad``, so the
          model parameters' ``.grad`` buffers are neither zeroed nor written.
        * ``images`` and ``labels`` are not modified.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else images.device

    with torch.enable_grad():
        # Detached copy: never clobber or backprop into upstream tensors.
        images_adv = images.detach().clone().to(target_device).requires_grad_(True)
        targets = labels.to(target_device)
        loss = F.cross_entropy(model(images_adv), targets)
        (data_grad,) = torch.autograd.grad(loss, images_adv)

    perturbed = images_adv.detach() + epsilon * data_grad.sign()
    # Keep the result a valid image; detached so no graph is retained.
    return torch.clamp(perturbed, 0.0, 1.0)

# ------------------------------
# JPEG compression preprocessing (to mitigate small L_inf perturbations)
# ------------------------------
def _jpeg_roundtrip(img, quality):
    """JPEG-encode then decode one [C, H, W] image (C=1 or 3); return a CPU float tensor."""
    # Clamp first: to_pil_image scales by 255 and casts to uint8, so values
    # outside [0, 1] would wrap around instead of saturating.
    src = img.detach().cpu().clamp(0.0, 1.0)
    # Single-channel images are handed to PIL as a 2-D grayscale array.
    pil = TF.to_pil_image(src.squeeze(0) if src.size(0) == 1 else src)
    with io.BytesIO() as buf:
        pil.save(buf, format="JPEG", quality=int(quality))
        buf.seek(0)
        with Image.open(buf) as decoded:
            # convert() forces the decode while buf is still open.
            return TF.to_tensor(decoded.convert(pil.mode))


def jpeg_compression(image_tensor, quality=75):
    """
    Apply lossy JPEG compression as an input-preprocessing defense.

    Every image is independently round-tripped through an in-memory JPEG
    (via PIL) at ``quality``. The intent is for the lossy quantization to
    wash out small L_inf adversarial perturbations.

    Args:
        image_tensor: float tensor in [0, 1], either a batch [B, C, H, W] or
            a single image [C, H, W], with C = 1 (grayscale) or 3 (RGB).
        quality: JPEG quality passed to PIL (lower means stronger compression).

    Returns:
        Tensor with the same shape as ``image_tensor``, on its device.

    Raises:
        ValueError: if ``image_tensor`` is neither 3-D nor 4-D.
    """
    if image_tensor.dim() == 3:
        return _jpeg_roundtrip(image_tensor, quality).to(image_tensor.device)
    if image_tensor.dim() != 4:
        raise ValueError(
            f"expected [B, C, H, W] or [C, H, W] tensor, got shape {tuple(image_tensor.shape)}"
        )
    if image_tensor.size(0) == 0:
        return image_tensor.clone()
    # Compress every example in the batch, not just the first one.
    compressed = torch.stack([_jpeg_roundtrip(img, quality) for img in image_tensor])
    return compressed.to(image_tensor.device)

# ------------------------------
# Adversarial training
# ------------------------------
def train_with_adversarial(model, loader, epsilon=0.2, epochs=2, lr=1e-3):
    """
    Adversarially train ``model`` in place using FGSM examples made on the fly.

    Each step concatenates the clean batch with its FGSM counterpart (crafted
    against the current weights), labels both halves with the true labels,
    and takes one Adam step on the criterion over the combined batch.

    Args:
        model: classifier to train; it is moved to ``device`` and put in train mode.
        loader: iterable of (images, labels) batches.
        epsilon: FGSM L_inf budget for the adversarial half of each batch.
        epochs: number of passes over ``loader``.
        lr: Adam learning rate.

    Returns:
        List with the mean combined-batch loss of each epoch (NaN for an
        epoch in which ``loader`` yielded no batches).
    """
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    epoch_losses = []
    for epoch in range(epochs):
        # Accumulate on-device to avoid a host sync on every step.
        running_loss = torch.zeros((), device=device)
        num_batches = 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            # Generate adversarial examples ON THE FLY using the current model.
            # fgsm_attack works on its own detached copy, so no clone is needed.
            adv_images = fgsm_attack(model, images, labels, epsilon)

            # Combine clean + adversarial (double the batch)
            images_combined = torch.cat([images, adv_images], dim=0)
            labels_combined = torch.cat([labels, labels], dim=0)

            # Zero before backward: the attack may have left values in the
            # parameters' .grad buffers.
            optimizer.zero_grad()
            loss = criterion(model(images_combined), labels_combined)
            loss.backward()
            optimizer.step()

            running_loss += loss.detach()
            num_batches += 1
        epoch_losses.append(running_loss.item() / num_batches if num_batches else float("nan"))
        print(f"📘 Epoch {epoch+1}/{epochs} - adversarial training done")
    return epoch_losses

# ------------------------------
# Evaluation under attack (with optional preprocessing)
# ------------------------------
def evaluate_under_attack(model, loader, epsilon=0.2, use_preprocessing=False, max_samples=100):
    """
    Measure classification accuracy on FGSM-perturbed inputs.

    Args:
        model: classifier to evaluate; it is switched to eval mode.
        loader: iterable of (images, labels) batches.
        epsilon: FGSM L_inf budget.
        use_preprocessing: if True, JPEG-compress the adversarial images
            before classifying them.
        max_samples: score at most this many examples (the final batch is
            trimmed); ``None`` scores the whole loader.

    Returns:
        Fraction of scored examples classified correctly, or 0.0 if none
        were scored.
    """
    model.eval()
    correct = 0
    total = 0
    for images, labels in loader:
        if max_samples is not None:
            remaining = max_samples - total
            if remaining <= 0:
                break
            images, labels = images[:remaining], labels[:remaining]
        images, labels = images.to(device), labels.to(device)

        # The attack needs input gradients, so it must run OUTSIDE no_grad.
        adv_images = fgsm_attack(model, images, labels, epsilon)

        if use_preprocessing:
            adv_images = jpeg_compression(adv_images.cpu()).to(device)

        # Only the final classification is gradient-free.
        with torch.no_grad():
            preds = model(adv_images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total if total > 0 else 0.0

# ------------------------------
# Demo: adversarially train, then compare accuracy under FGSM with and
# without the JPEG preprocessing defense.
# ------------------------------
model = SimpleCNN().to(device)

train_with_adversarial(model, train_loader, epsilon=0.2, epochs=2)

# 100 test samples per setting keeps the demo fast; evaluated without the
# defense first, then with it.
accuracy_by_defense = {
    use_jpeg: evaluate_under_attack(
        model, test_loader, epsilon=0.2, use_preprocessing=use_jpeg, max_samples=100
    )
    for use_jpeg in (False, True)
}
acc_no_defense = accuracy_by_defense[False]
acc_with_jpeg = accuracy_by_defense[True]

print(f"⚠️ Accuracy against FGSM (no defense): {acc_no_defense:.2%}")
print(f"🛡️ Accuracy with JPEG preprocessing: {acc_with_jpeg:.2%}")

# ------------------------------
# Visualize one example (original, adversarial, compressed)
# ------------------------------
model.eval()
data_iter = iter(test_loader)
images, labels = next(data_iter)
images, labels = images.to(device), labels.to(device)

adv_images = fgsm_attack(model, images.clone(), labels, epsilon=0.2)
adv_images_jpeg = jpeg_compression(adv_images.cpu()).to(device)

# .item() below assumes a batch of one (test_loader uses batch_size=1).
with torch.no_grad():
    orig_pred = model(images).argmax(dim=1).item()
    adv_pred = model(adv_images).argmax(dim=1).item()
    adv_jpeg_pred = model(adv_images_jpeg).argmax(dim=1).item()

panels = [
    ("Original", images, orig_pred),
    ("Adversarial", adv_images, adv_pred),
    ("Adversarial + JPEG", adv_images_jpeg, adv_jpeg_pred),
]
fig, axes = plt.subplots(1, len(panels), figsize=(9, 3))
for ax, (panel_name, panel_images, panel_pred) in zip(axes, panels):
    ax.set_title(f"{panel_name}\npred={panel_pred}")
    # Fixed [0, 1] display range: without vmin/vmax, imshow rescales each
    # panel to its own min/max, distorting how the perturbation looks.
    ax.imshow(panel_images[0].squeeze().cpu().numpy(), cmap="gray", vmin=0.0, vmax=1.0)
    ax.axis("off")
fig.suptitle(f"true label = {labels[0].item()}")
plt.show()