In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define the CIFAR-10 model architecture
class CIFAR10Net(nn.Module):
    """CIFAR-10 classifier: three convolutional stages followed by a 3-layer MLP head.

    Each stage is two 3x3 convolutions (padding 1, ReLU applied functionally
    after each) and a 2x2 max-pool. ReLU is applied with ``F.relu`` inside
    ``forward``, so the model has no ``nn.ReLU`` submodules.

    The attribute names conv1..conv6, pool1..pool3 and fc1..fc3 are the
    state_dict keys used by saved checkpoints; do not rename them.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 -> 32 -> 64 channels, then halve spatial size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 64 -> 128 -> 128 channels.
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 3: 128 -> 256 -> 256 channels.
        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head. 256 * 4 * 4 assumes a 32x32 input halved by the
        # three pools (32 -> 16 -> 8 -> 4).
        self.fc1 = nn.Linear(in_features=256 * 4 * 4, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=10)

    def forward(self, x):
        """Return raw class logits (no softmax) for the input batch ``x``."""
        stages = (
            (self.conv1, self.conv2, self.pool1),
            (self.conv3, self.conv4, self.pool2),
            (self.conv5, self.conv6, self.pool3),
        )
        for first_conv, second_conv, pool in stages:
            x = pool(F.relu(second_conv(F.relu(first_conv(x)))))

        x = x.flatten(start_dim=1)  # keep the batch dim, merge C*H*W
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

# Load models
# NOTE(review): originally BOTH models were loaded from "cifar10_bd.pt", so the
# "clean" model was really the backdoored one and every activation difference
# computed below was exactly 0. TODO: point CLEAN_MODEL_PATH at the clean
# reference checkpoint (its filename is not visible here — confirm it).
CLEAN_MODEL_PATH = "../models/reference_cifar10/cifar10_bd.pt"
BACKDOORED_MODEL_PATH = "../models/reference_cifar10/cifar10_bd.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_model(path):
    """Build a CIFAR10Net on ``device``, load the state_dict at ``path``, and return it in eval mode."""
    model = CIFAR10Net().to(device)
    # weights_only=True restricts unpickling to tensors and plain containers,
    # which is all a state_dict needs, and avoids executing arbitrary pickled
    # code from the checkpoint file (requires torch >= 1.13).
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # inference behaviour for the activation comparison
    return model


if CLEAN_MODEL_PATH == BACKDOORED_MODEL_PATH:
    warnings.warn(
        "CLEAN_MODEL_PATH and BACKDOORED_MODEL_PATH point to the same checkpoint; "
        "every activation difference will be 0.",
        stacklevel=1,
    )

clean_model = _load_model(CLEAN_MODEL_PATH)
backdoored_model = _load_model(BACKDOORED_MODEL_PATH)

# Hook to capture activations.
# Keyed by the module name from named_modules(); each forward pass overwrites
# the previous entry for that name.
activations_clean = {}
activations_backdoor = {}

def get_activation(name, is_backdoor=False):
    """Build a forward hook that records a module's output under ``name``.

    The hook stores ``output.detach()`` (no autograd history) in
    ``activations_backdoor`` when ``is_backdoor`` is true, otherwise in
    ``activations_clean``.
    """
    def hook(_module, _inputs, output):
        # Resolve the module-level dict when the hook fires, not when it is built.
        store = activations_backdoor if is_backdoor else activations_clean
        store[name] = output.detach()
    return hook

# Register hooks on layers for both models.
# CIFAR10Net.forward applies ReLU through F.relu, so the model has no nn.ReLU
# submodules: filtering on nn.ReLU alone registers no hooks and leaves both
# activation dicts empty. Hook the layers that actually run as modules.
# Conv2d/Linear outputs are therefore captured *before* the functional ReLU;
# nn.ReLU stays in the tuple so ReLU modules are still captured if added.
HOOKED_LAYER_TYPES = (nn.Conv2d, nn.Linear, nn.MaxPool2d, nn.ReLU)
hook_handles = []  # keep handles so hooks can be removed later via handle.remove()

for net, is_backdoor in ((clean_model, False), (backdoored_model, True)):
    for layer_name, layer in net.named_modules():
        if isinstance(layer, HOOKED_LAYER_TYPES):
            hook_handles.append(
                layer.register_forward_hook(
                    get_activation(layer_name, is_backdoor=is_backdoor)
                )
            )

# Data preparation
# Normalize with mean = std = 0.5 per channel, i.e. x -> (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 test split, downloaded into ../data/cifar10/ if not already present.
testset = datasets.CIFAR10(
    root='../data/cifar10/',
    train=False,
    transform=transform,
    download=True,
)

# shuffle=False keeps batch order fixed, so the first batch is reproducible.
testloader = DataLoader(dataset=testset, batch_size=64, shuffle=False)

# Collect activations for a batch of data.
# Reuse data_iter rather than creating a second iterator over testloader.
data_iter = iter(testloader)
images, labels = next(data_iter)
images, labels = images.to(device), labels.to(device)

# Forward pass: the registered hooks fill activations_clean / activations_backdoor.
# no_grad skips building an autograd graph that is never used (the hooks detach anyway).
with torch.no_grad():
    _ = clean_model(images)
    _ = backdoored_model(images)

# Compare activations for each layer.
DIFF_THRESHOLD = 0.1  # arbitrary cut-off: only plot layers whose mean |diff| exceeds this


def _plot_activation_pair(layer_name, clean_activ, backdoor_activ):
    """Plot sample 0 of one layer's activations from both models.

    4-D outputs (assumed (batch, channels, H, W), as produced by Conv2d /
    MaxPool2d) show the channel whose sample-0 map differs most, with a shared
    colour scale so the two panels are comparable. 2-D outputs (assumed
    (batch, features), as produced by Linear) show the sample-0 feature vector
    of both models as line plots. Other ranks are skipped.
    """
    if clean_activ.ndim == 4:
        per_channel = np.abs(clean_activ[0] - backdoor_activ[0]).mean(axis=(1, 2))
        channel = int(per_channel.argmax())
        panels = (clean_activ[0, channel], backdoor_activ[0, channel])
        vmin = min(float(p.min()) for p in panels)
        vmax = max(float(p.max()) for p in panels)
        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
        for ax, panel, title in zip(axes, panels, ("Clean Model", "Backdoored Model")):
            image = ax.imshow(panel, cmap="viridis", vmin=vmin, vmax=vmax)
            fig.colorbar(image, ax=ax)
            ax.set_title(f"{title} - Layer: {layer_name} (channel {channel})")
    elif clean_activ.ndim == 2:
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.plot(clean_activ[0], label="Clean Model")
        ax.plot(backdoor_activ[0], label="Backdoored Model")
        ax.set_title(f"Layer: {layer_name} (sample 0)")
        ax.legend()
    else:
        return
    plt.show()


if not activations_clean:
    # An empty dict means no hook fired (e.g. hooks were registered on module
    # types the model never calls); warn instead of printing nothing.
    warnings.warn(
        "No activations were captured; check which modules the forward hooks "
        "were registered on.",
        stacklevel=1,
    )

for layer_name, clean_tensor in activations_clean.items():
    backdoor_tensor = activations_backdoor.get(layer_name)
    if backdoor_tensor is None:
        print(f"Layer: {layer_name}, missing from backdoored-model activations; skipped")
        continue

    clean_activ = clean_tensor.cpu().numpy()
    backdoor_activ = backdoor_tensor.cpu().numpy()

    # Mean absolute difference over every element of the batch.
    diff = np.abs(clean_activ - backdoor_activ).mean()
    print(f"Layer: {layer_name}, Mean Activation Difference: {diff:.4f}")

    if diff > DIFF_THRESHOLD:
        _plot_activation_pair(layer_name, clean_activ, backdoor_activ)

  clean_model.load_state_dict(torch.load("../models/reference_cifar10/cifar10_bd.pt", map_location=device))
  backdoored_model.load_state_dict(torch.load("../models/reference_cifar10/cifar10_bd.pt", map_location=device))


Files already downloaded and verified
