In [None]:
import torch.nn as nn
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


class ResidualBlock(nn.Module):
    """Two 3x3 conv layers with BatchNorm and a skip connection.

    Main path:  conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN
    Skip path:  identity, or a 1x1 conv (same stride) + BN projection when the
                stride or the channel count changes, so both paths can be summed.
    The block output is ReLU(main + skip).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 conv and of the projection; values
            greater than 1 downsample spatially.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Attribute names (conv1/bn1/conv2/bn2/shortcut) and the shortcut's
        # Sequential indices (0, 1) form the state_dict keys, so they must not change.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        projection = []
        if shape_changes:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        # An empty Sequential is the identity mapping.
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        hidden = torch.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        return torch.relu(main_path + self.shortcut(x))


class CNN(nn.Module):
    """Small residual classifier for 3-channel images with 10 output logits.

    Stem: 3x3 conv (3 -> 64) + BN + ReLU. Three ResidualBlocks, each with
    stride 2, widen the channels 64 -> 128 -> 256 -> 512. Adaptive average
    pooling to 1x1 and a linear layer (512 -> 10) then produce the logits.
    Attribute names are the checkpoint's state_dict keys and must stay as-is.
    """

    def __init__(self):
        super().__init__()
        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Downsampling residual stages
        self.layer1 = ResidualBlock(64, 128, stride=2)
        self.layer2 = ResidualBlock(128, 256, stride=2)
        self.layer3 = ResidualBlock(256, 512, stride=2)

        # Head
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        features = torch.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = self.avg_pool(features).flatten(start_dim=1)
        return self.fc(pooled)

In [None]:
# Adversarially-trained CIFAR-10 weights, read from the Kaggle input mount.
CHECKPOINT_PATH = "/kaggle/input/adtrain/pytorch/default/1/cifar10_adversarial_trained.pth"

model = CNN()
# map_location="cpu": if the checkpoint was saved from CUDA tensors, loading it on a
# CPU-only kernel would otherwise raise. Nothing in this notebook moves the model or
# the data to a GPU, so CPU is the right target.
# weights_only=True restricts unpickling to tensors/primitive containers.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
# Inference mode: BatchNorm uses its stored running statistics.
model.eval()

In [None]:
# Only ToTensor: pixels become float values in [0, 1] with no normalization, which
# is why the FGSM step further down can clamp adversarial images to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])

test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# NOTE: the evaluation loop below indexes test_dataset directly; this loader is
# kept for interactive use.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)

In [None]:
def evaluate_model(model, image, label):
    """Classify a single image and check the prediction against its label.

    Args:
        model: callable returning class logits for a batch; assumed shape
            (batch, num_classes).
        image: input batch containing exactly one image (batch size 1).
        label: ground-truth class index (int, or a one-element tensor).

    Returns:
        (pred, correct): the predicted class index as a Python int, and whether
        it equals ``label`` as a Python bool.

    Raises:
        ValueError: if the model output does not have batch size 1. Previously a
        flattened argmax silently returned a meaningless index in that case.
    """
    # Pure inference: the caller's input may have requires_grad=True (FGSM loop),
    # so disable autograd to avoid building and holding a graph for every call.
    with torch.no_grad():
        logits = model(image)
    if logits.shape[0] != 1:
        raise ValueError(
            f"evaluate_model expects a single image, got batch size {logits.shape[0]}"
        )
    # Argmax over the class axis. Softmax is monotonic, so applying it first
    # cannot change the winner and is omitted.
    pred = logits.argmax(dim=1)
    return pred.item(), (pred == label).item()

In [None]:
import random

In [None]:
# FGSM step size: each pixel (range [0, 1]) moves by at most this much.
epsilon = 0.1
# Number of random test images to attack.
num_samples = 100

# Running tallies of correct predictions on clean / adversarial images.
correct_clean, correct_adv = 0, 0

In [None]:
# Seeded local RNG: the same test images are attacked on every run, so the
# reported accuracies are reproducible (and global `random` state is untouched).
SAMPLE_SEED = 0
sample_rng = random.Random(SAMPLE_SEED)
random_indices = sample_rng.sample(range(len(test_dataset)), num_samples)

# Reset the tallies in this cell so re-running it does not keep adding to the
# totals from a previous run (which could push accuracy above 100%).
correct_clean = 0
correct_adv = 0

# Loop-invariant: build the loss once.
loss_fn = nn.CrossEntropyLoss()

for idx in random_indices:
    image, label = test_dataset[idx]
    input_image = image.unsqueeze(0)  # add batch dimension -> batch of one
    target_class = torch.tensor([label], dtype=torch.long)
    input_image.requires_grad = True  # we need d(loss)/d(pixels)

    # Loss of the clean image with respect to its true label
    output = model(input_image)
    loss = loss_fn(output, target_class)

    # Gradient of the loss with respect to the input pixels
    model.zero_grad()
    loss.backward()

    # FGSM: one epsilon-sized step along the gradient sign, then clamp back
    # into the valid [0, 1] pixel range
    with torch.no_grad():
        adversarial_image = input_image + epsilon * input_image.grad.sign()
        adversarial_image.clamp_(0, 1)

    # Evaluate predictions on the clean and perturbed images
    clean_pred, clean_correct = evaluate_model(model, input_image, label)
    adv_pred, adv_correct = evaluate_model(model, adversarial_image, label)

    correct_clean += clean_correct
    correct_adv += adv_correct

In [None]:
# Percentage of the sampled images classified correctly, before and after the
# FGSM perturbation (same operation order as hits / total * 100).
clean_accuracy, adv_accuracy = (
    hits / num_samples * 100 for hits in (correct_clean, correct_adv)
)

print(f"Accuracy on clean images: {clean_accuracy:.2f}%")
print(f"Accuracy on adversarial images: {adv_accuracy:.2f}%")