# Trojan Attacks

 A trojan (backdoor) attack combines feature manipulation with deliberate label corruption. It hides malicious logic inside an otherwise fully functional model, and that logic remains dormant until a particular, often unobtrusive, trigger appears in the input. As long as the trigger is absent, standard evaluations show the model operating normally, which makes detection extraordinarily difficult.

 *German Traffic Sign Recognition Benchmark (GTSRB)*

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm, trange
import numpy as np
import matplotlib.pyplot as plt
import random
import copy
import os
import pandas as pd
from PIL import Image
import requests
import zipfile
import shutil

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Ask cuDNN for deterministic kernels and disable its auto-tuner so that
# repeated runs follow the same code path
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Choose the compute backend: CUDA first, then Apple MPS, otherwise the CPU
if torch.cuda.is_available():
    _backend_name, _backend_banner = "cuda", "Using CUDA device."
elif torch.backends.mps.is_available():
    _backend_name, _backend_banner = "mps", "Using MPS device (Apple Silicon GPU)."
else:
    _backend_name, _backend_banner = "cpu", "Using CPU device."

device = torch.device(_backend_name)
print(_backend_banner)
print(f"Using device: {device}")

Using MPS device (Apple Silicon GPU).
Using device: mps


In [4]:
# One fixed seed shared by every random number generator used in the notebook
SEED = 1337

# Seed Python's `random`, NumPy and PyTorch (in that order) with the same value
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# CUDA generators are seeded only when a CUDA device is actually present
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)  # current GPU
    torch.cuda.manual_seed_all(SEED)  # every GPU in a multi-GPU setup

In [5]:
# Colour constants - primary palette
HTB_GREEN = "#9fef00"
NODE_BLACK = "#141d2b"
HACKER_GREY = "#a4b1cd"
WHITE = "#ffffff"
# Colour constants - secondary palette
AZURE = "#0086ff"
NUGGET_YELLOW = "#ffaf00"
MALWARE_RED = "#ff3e3e"
VIVID_PURPLE = "#9f00ff"
AQUAMARINE = "#2ee7b6"

# Start from the seaborn dark-grid style, then recolour it with the palette.
# The rcParams are grouped by the colour they receive so each colour is
# written once; the resulting settings are the same 14 key/value pairs.
plt.style.use("seaborn-v0_8-darkgrid")
_dark_theme = {"grid.alpha": 0.1}
for _colour, _rc_keys in (
    (
        NODE_BLACK,
        (
            "figure.facecolor",
            "figure.edgecolor",
            "axes.facecolor",
            "legend.facecolor",
        ),
    ),
    (
        HACKER_GREY,
        (
            "axes.edgecolor",
            "axes.labelcolor",
            "xtick.color",
            "ytick.color",
            "grid.color",
            "legend.edgecolor",
            "legend.labelcolor",
            "text.color",
        ),
    ),
    (WHITE, ("axes.titlecolor",)),
):
    _dark_theme.update(dict.fromkeys(_rc_keys, _colour))
plt.rcParams.update(_dark_theme)

print("Setup complete.")

Setup complete.


In [6]:
# Human-readable GTSRB sign names. The position of each name in the tuple is
# its numeric class ID, so enumerate() produces the {class_id: name} mapping.
GTSRB_CLASS_NAMES = dict(
    enumerate(
        (
            "Speed limit (20km/h)",  # 0
            "Speed limit (30km/h)",
            "Speed limit (50km/h)",
            "Speed limit (60km/h)",
            "Speed limit (70km/h)",
            "Speed limit (80km/h)",  # 5
            "End of speed limit (80km/h)",
            "Speed limit (100km/h)",
            "Speed limit (120km/h)",
            "No passing",
            "No passing for veh over 3.5 tons",  # 10
            "Right-of-way at next intersection",
            "Priority road",
            "Yield",
            "Stop",
            "No vehicles",  # 15
            "Veh > 3.5 tons prohibited",
            "No entry",
            "General caution",
            "Dangerous curve left",
            "Dangerous curve right",  # 20
            "Double curve",
            "Bumpy road",
            "Slippery road",
            "Road narrows on the right",
            "Road work",  # 25
            "Traffic signals",
            "Pedestrians",
            "Children crossing",
            "Bicycles crossing",
            "Beware of ice/snow",  # 30
            "Wild animals crossing",
            "End speed/pass limits",
            "Turn right ahead",
            "Turn left ahead",
            "Ahead only",  # 35
            "Go straight or right",
            "Go straight or left",
            "Keep right",
            "Keep left",
            "Roundabout mandatory",  # 40
            "End of no passing",
            "End no passing veh > 3.5 tons",
        )
    )
)
# Number of classes is derived from the table above (43 entries)
NUM_CLASSES_GTSRB = len(GTSRB_CLASS_NAMES)


def get_gtsrb_class_name(class_id):
    """
    Look up the display name of a GTSRB class.

    Args:
        class_id (int): Numeric class ID; the table holds IDs 0-42.

    Returns:
        str: The name stored in GTSRB_CLASS_NAMES for `class_id`, or the
            placeholder 'Unknown Class <class_id>' when the ID is not a key.
    """
    try:
        return GTSRB_CLASS_NAMES[class_id]
    except KeyError:
        # Unknown IDs are reported in-band rather than raising
        return f"Unknown Class {class_id}"

In [7]:
# Remote archive that bundles the GTSRB dataset
DATASET_URL = "https://academy.hackthebox.com/storage/resources/GTSRB.zip"

# Local layout: the archive is fetched into a scratch folder and unpacked
# into the dataset root
DOWNLOAD_DIR = "./gtsrb_downloads"  # temporary download location
DATASET_ROOT = "./GTSRB"  # final dataset location


def download_file(url, dest_folder, filename):
    """
    Downloads a file from a URL to a specified destination.

    The body is streamed into a temporary '<filename>.part' file and only
    renamed to its final name once the transfer has finished. An interrupted
    download therefore never leaves a truncated file that a later call would
    mistake for a complete one.

    Args:
        url (str): The URL of the file to download.
        dest_folder (str): The directory to save the downloaded file.
        filename (str): The name to save the file as.

    Returns:
        str or None: The full path to the downloaded file, or None if the
            HTTP request failed (connection error, timeout, bad status code).

    Raises:
        OSError: If the destination folder or file cannot be written.
    """
    filepath = os.path.join(dest_folder, filename)
    if os.path.exists(filepath):
        print(f"File '{filename}' already exists in {dest_folder}. Skipping download.")
        return filepath

    print(f"Downloading {filename} from {url}...")
    partial_path = filepath + ".part"
    try:
        # (connect, read) timeouts: without them a stalled server would block
        # the call forever. The context manager releases the connection.
        with requests.get(url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()  # Raise an exception for bad status codes
            os.makedirs(dest_folder, exist_ok=True)
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        # Publish the file under its final name only after a complete transfer
        os.replace(partial_path, filepath)
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {url}: {e}")
        return None
    finally:
        # Best-effort removal of a leftover partial file (failure paths only;
        # after a successful rename it no longer exists)
        if os.path.exists(partial_path):
            try:
                os.remove(partial_path)
            except OSError:
                pass

    print(f"Successfully downloaded {filename}.")
    return filepath


def extract_zip(zip_filepath, extract_to):
    """
    Unpack every member of a zip archive into a target directory.

    Args:
        zip_filepath (str): Location of the archive to unpack.
        extract_to (str): Directory that receives the extracted contents.

    Returns:
        bool: True when the archive was unpacked; False when it is not a
            valid zip file or any other error occurred (the reason is printed).
    """
    archive_name = os.path.basename(zip_filepath)
    print(f"Extracting '{archive_name}' to {extract_to}...")
    try:
        with zipfile.ZipFile(zip_filepath, "r") as archive:
            archive.extractall(extract_to)
    except zipfile.BadZipFile:
        print(
            f"Error: Failed to extract '{archive_name}'. File might be corrupted or not a zip file."
        )
        return False
    except Exception as err:
        # Any other failure (missing file, permissions, ...) is reported, not raised
        print(f"An unexpected error occurred during extraction: {err}")
        return False

    print(f"Successfully extracted '{archive_name}'.")
    return True

In [8]:
# Locations the rest of the notebook expects underneath DATASET_ROOT
train_dir = os.path.join(DATASET_ROOT, "Final_Training", "Images")
test_img_dir = os.path.join(DATASET_ROOT, "Final_Test", "Images")
test_csv_path = os.path.join(DATASET_ROOT, "GT-final_test.csv")

# The dataset counts as complete only when every component is already on disk
dataset_ready = all(
    (
        os.path.isdir(DATASET_ROOT),
        os.path.isdir(train_dir),
        os.path.isdir(test_img_dir),  # test images
        os.path.isfile(test_csv_path),  # test annotations
    )
)

if dataset_ready:
    print(
        f"GTSRB dataset found and seems complete in '{DATASET_ROOT}'. Skipping download."
    )
else:
    print(
        f"GTSRB dataset not found or incomplete in '{DATASET_ROOT}'. Attempting download and extraction..."
    )
    os.makedirs(DATASET_ROOT, exist_ok=True)
    os.makedirs(DOWNLOAD_DIR, exist_ok=True)

    # Fetch the archive (download_file returns None when the request fails)
    dataset_zip_path = download_file(DATASET_URL, DOWNLOAD_DIR, "GTSRB.zip")

    # Extraction is attempted only while the training directory is absent;
    # an existing training directory is left untouched.
    extraction_ok = True
    if not os.path.isdir(train_dir):
        if not dataset_zip_path:
            # No archive and no training directory: nothing can be extracted
            extraction_ok = False
            print(
                "Training images download failed or skipped, cannot proceed with extraction."
            )
        elif not extract_zip(dataset_zip_path, DATASET_ROOT):
            extraction_ok = False
            print("Error during extraction of training images.")

    # Test components only produce warnings; they are not required below
    test_images_present = os.path.isdir(test_img_dir)
    test_csv_present = os.path.isfile(test_csv_path)
    if not test_images_present:
        print(
            f"Warning: Test image directory '{test_img_dir}' not found. Ensure it's placed correctly."
        )
    if not test_csv_present:
        print(
            f"Warning: Test CSV file '{test_csv_path}' not found. Ensure it's placed correctly."
        )

    # Final verdict: driven by the TRAINING data only
    dataset_ready = (
        os.path.isdir(DATASET_ROOT) and os.path.isdir(train_dir) and extraction_ok
    )

    if not dataset_ready:
        print("\nError: Failed to set up the core GTSRB training dataset.")
        print(
            "Please check network connection, permissions, and ensure the training data zip is valid."
        )
        print(
            "Expected structure after successful setup (including manual test data placement):"
        )
        print(f" {DATASET_ROOT}/")
        print("  Final_Training/Images/00000/..ppm files..")
        print("  ...")
        print("  Final_Test/Images/..ppm files..")
        print("  GT-final_test.csv")

        # Collect every component that failed or is missing for the error message
        train_dir_present = os.path.isdir(train_dir)
        missing_parts = []
        if not extraction_ok and dataset_zip_path:
            missing_parts.append("Training data extraction")
        if not dataset_zip_path and not train_dir_present:
            missing_parts.append("Training data download")
        if not train_dir_present:
            missing_parts.append("Training images directory")
        if not test_images_present:
            missing_parts.append("Test images (manual placement likely needed)")
        if not test_csv_present:
            missing_parts.append("Test CSV (manual placement likely needed)")

        raise FileNotFoundError(
            f"GTSRB dataset setup failed. Critical failure in obtaining training data. Missing/Problem parts: {', '.join(missing_parts)} in {DATASET_ROOT}"
        )

    # From here on the training data is in place (so extraction_ok is True)
    if test_images_present and test_csv_present:
        print(f"Dataset successfully prepared in '{DATASET_ROOT}'.")
    else:
        print(
            f"Training dataset prepared in '{DATASET_ROOT}', but test components might be missing."
        )
        if not test_images_present:
            print(f" - Missing: {test_img_dir}")
        if not test_csv_present:
            print(f" - Missing: {test_csv_path}")

    # Remove the scratch download folder in both success variants
    if os.path.exists(DOWNLOAD_DIR):
        try:
            shutil.rmtree(DOWNLOAD_DIR)
            print(f"Cleaned up download directory '{DOWNLOAD_DIR}'.")
        except OSError as e:
            print(f"Warning: Could not remove download directory {DOWNLOAD_DIR}: {e}")


GTSRB dataset found and seems complete in './GTSRB'. Skipping download.


In [9]:
# Input geometry: every GTSRB image is resized to IMG_SIZE x IMG_SIZE pixels
IMG_SIZE = 48
# Per-channel normalisation statistics (the ImageNet values)
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]

# Attack parameters
SOURCE_CLASS = 14  # class to be poisoned: Stop sign
TARGET_CLASS = 3  # label assigned to poisoned samples: Speed limit 60km/h
POISON_RATE = 0.10  # fraction of source-class training samples to poison

# Trigger definition (relative to the IMG_SIZE x IMG_SIZE image)
TRIGGER_SIZE = 4  # side length of the square block, in pixels
# Same offset on both axes: the block sits one pixel in from the bottom-right corner
TRIGGER_POS = (IMG_SIZE - TRIGGER_SIZE - 1,) * 2
# Magenta (R=1, G=0, B=1) expressed in the [0, 1] range
TRIGGER_COLOR_VAL = (1.0, 0.0, 1.0)

# Print a summary of the configuration in a single call
print(
    "\n".join(
        (
            "\nDataset configuration:",
            f" Image Size: {IMG_SIZE}x{IMG_SIZE}",
            f" Number of Classes: {NUM_CLASSES_GTSRB}",
            f" Source Class: {SOURCE_CLASS} ({get_gtsrb_class_name(SOURCE_CLASS)})",
            f" Target Class: {TARGET_CLASS} ({get_gtsrb_class_name(TARGET_CLASS)})",
            f" Poison Rate: {POISON_RATE * 100}%",
            f" Trigger: {TRIGGER_SIZE}x{TRIGGER_SIZE} magenta square at {TRIGGER_POS}",
        )
    )
)



Dataset configuration:
 Image Size: 48x48
 Number of Classes: 43
 Source Class: 14 (Stop)
 Target Class: 3 (Speed limit (60km/h))
 Poison Rate: 10.0%
 Trigger: 4x4 magenta square at (43, 43)


## Model architecture

In [10]:
class GTSRB_CNN(nn.Module):
    """
    A CNN adapted for the GTSRB dataset (43 classes, 48x48 input).

    Layout: two 3x3 convolutions followed by max pooling, a third 3x3
    convolution followed by max pooling, then two fully connected layers
    with dropout in front of each.
    """

    def __init__(self, num_classes=NUM_CLASSES_GTSRB):
        """
        Initializes the CNN layers for GTSRB.

        Args:
            num_classes (int): Number of output classes (default: NUM_CLASSES_GTSRB).
        """
        super().__init__()
        # Conv Layer 1: 3 input channels (RGB) -> 32 filters, 3x3 kernel, padding 1
        # (padding 1 with a 3x3 kernel keeps the spatial size: 48x48 stays 48x48)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)

        # Conv Layer 2: 32 -> 64 filters, spatial size unchanged
        self.conv2 = nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=3, padding=1
        )

        # Max Pooling 1: 2x2 kernel, stride 2 halves the spatial size (48 -> 24)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv Layer 3: 64 -> 128 filters, spatial size unchanged
        self.conv3 = nn.Conv2d(
            in_channels=64, out_channels=128, kernel_size=3, padding=1
        )

        # Max Pooling 2: halves the spatial size again (24 -> 12)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Flattened feature size after both pooling stages for a 48x48 input:
        # 128 channels * 12 * 12 = 18432. fc1's input size must match this.
        self._feature_size = 128 * 12 * 12

        # Fully Connected Layer 1 (hidden): flattened features -> 512 units
        self.fc1 = nn.Linear(self._feature_size, 512)

        # Fully Connected Layer 2 (output): 512 units -> one logit per class
        self.fc2 = nn.Linear(512, num_classes)

        # Dropout for regularization (each unit is dropped with probability 0.5);
        # the same module is applied before fc1 and before fc2
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """
        Defines the forward pass sequence for input tensor x.

        Args:
            x (torch.Tensor): Input batch of images
                                (Batch Size x 3 x IMG_SIZE x IMG_SIZE).

        Returns:
            torch.Tensor: Output logits for each class
                                (Batch Size x num_classes).

        Raises:
            RuntimeError: If the spatial size of `x` does not yield
                `_feature_size` features per sample (raised by `fc1`).
        """
        # First conv block: Conv1 -> ReLU -> Conv2 -> ReLU -> Pool1
        x = self.pool1(F.relu(self.conv2(F.relu(self.conv1(x)))))
        # Second conv block: Conv3 -> ReLU -> Pool2
        x = self.pool2(F.relu(self.conv3(x)))

        # Flatten everything except the batch dimension. Unlike
        # view(-1, _feature_size), this cannot silently fold surplus spatial
        # elements into the batch dimension when the input size is wrong;
        # fc1 reports a shape mismatch instead.
        x = x.flatten(start_dim=1)

        # Dropout -> FC1 -> ReLU -> Dropout -> FC2 (logits)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


In [11]:
# Build the GTSRB network, then place its parameters on the selected device
model_structure_gtsrb = GTSRB_CNN(num_classes=NUM_CLASSES_GTSRB)
model_structure_gtsrb = model_structure_gtsrb.to(device)

# Show the layer summary and the flattened feature size feeding the FC layers
print("\nCNN model defined for GTSRB:")
print(model_structure_gtsrb)
_fc_input_features = model_structure_gtsrb._feature_size
print(f"Calculated feature size before FC layers: {_fc_input_features}")


CNN model defined for GTSRB:
GTSRB_CNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=18432, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=43, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
Calculated feature size before FC layers: 18432


## Preparing and loading the data

In [12]:
# Base pipeline, applied first to every image:
#   1. resize to the fixed IMG_SIZE x IMG_SIZE input size
#   2. convert the PIL image ([0, 255]) into a tensor in [0, 1]
_base_steps = [
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
]
transform_base = transforms.Compose(_base_steps)

In [13]:
# Training-only pipeline that runs last, after any trigger has been stamped
# on the image: light augmentation followed by normalization.
_train_post_steps = [
    transforms.RandomRotation(degrees=10),  # small random rotation
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # slight colour change
    transforms.Normalize(mean=IMG_MEAN, std=IMG_STD),  # ImageNet statistics
]
transform_train_post = transforms.Compose(_train_post_steps)

In [14]:
# Evaluation pipeline for clean test images: no augmentation, only
# resize -> tensor conversion -> normalization.
_test_steps = [
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMG_MEAN, std=IMG_STD),
]
transform_test = transforms.Compose(_test_steps)

In [15]:
# Undo the normalization for visualization. Normalize computes
# (x - mean) / std, so using mean = -m/s and std = 1/s yields x * s + m.
_inverse_mean = [-m / s for m, s in zip(IMG_MEAN, IMG_STD)]
_inverse_std = [1 / s for s in IMG_STD]
inverse_normalize = transforms.Normalize(mean=_inverse_mean, std=_inverse_std)


In [16]:
try:
    # Metadata-only instance: provides the folder-name -> class-index mapping
    # (e.g. {'00000': 0, '00001': 1, ...}); it is not used for training.
    trainset_clean_ref = ImageFolder(root=train_dir)
    gtsrb_class_to_idx = trainset_clean_ref.class_to_idx

    # Clean training set: base transforms followed directly by the
    # augmentation/normalization stage (no trigger in between).
    _clean_train_pipeline = transforms.Compose([transform_base, transform_train_post])
    trainset_clean_transformed = ImageFolder(
        root=train_dir, transform=_clean_train_pipeline
    )
except Exception as e:
    print(f"Error loading GTSRB training data from {train_dir}: {e}")
    print(
        "Please ensure the directory structure is correct for ImageFolder (e.g., GTSRB/Final_Training/Images/00000/*.ppm)."
    )
    raise
else:
    print(
        f"\nClean GTSRB training dataset loaded using ImageFolder. Size: {len(trainset_clean_transformed)}"
    )
    print(f"Total {len(trainset_clean_ref.classes)} classes found by ImageFolder.")



Clean GTSRB training dataset loaded using ImageFolder. Size: 39209
Total 43 classes found by ImageFolder.


In [17]:
# Loader settings for the clean training set
_clean_train_loader_cfg = dict(
    batch_size=256,  # large batches for faster clean training
    shuffle=True,  # new sample order every epoch
    num_workers=0,  # load in the main process (simple and portable)
    pin_memory=True,  # page-locked host memory; helps CPU->GPU copies on CUDA
)
trainloader_clean = DataLoader(trainset_clean_transformed, **_clean_train_loader_cfg)


In [18]:
class GTSRBTestset(Dataset):
    """Custom Dataset for GTSRB test set using annotations from a CSV file."""

    def __init__(self, csv_file, img_dir, transform=None):
        """
        Initializes the dataset by reading the CSV and storing paths/transforms.

        Args:
            csv_file (string): Path to the ';'-delimited CSV file with
                'Filename' and 'ClassId' columns.
            img_dir (string): Directory containing the test images.
            transform (callable, optional): Transform to be applied to each image.

        Raises:
            FileNotFoundError: If `csv_file` does not exist.
            ValueError: If a required column is missing from the CSV.
        """
        try:
            # 'utf-8-sig' transparently strips a leading BOM if present
            with open(csv_file, mode="r", encoding="utf-8-sig") as f:
                self.img_labels = pd.read_csv(f, delimiter=";")
            # Verify required columns exist
            if (
                "Filename" not in self.img_labels.columns
                or "ClassId" not in self.img_labels.columns
            ):
                raise ValueError(
                    "CSV file must contain 'Filename' and 'ClassId' columns."
                )
        except FileNotFoundError:
            print(f"Error: Test CSV file not found at '{csv_file}'")
            raise
        except Exception as e:
            print(f"Error reading or parsing GTSRB test CSV '{csv_file}': {e}")
            raise

        self.img_dir = img_dir
        self.transform = transform
        print(
            f"Loaded GTSRB test annotations from CSV '{os.path.basename(csv_file)}'. Found {len(self.img_labels)} entries."
        )

    def __len__(self):
        """Returns the total number of samples in the test set."""
        return len(self.img_labels)

    @staticmethod
    def _dummy_sample():
        """Returns the placeholder (all-zero image, label -1) for unreadable samples."""
        return torch.zeros(3, IMG_SIZE, IMG_SIZE), -1

    def __getitem__(self, idx):
        """
        Retrieves the image and label for a given index.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: (image, label) where image is the transformed image tensor,
                   and label is the integer class ID. Returns (dummy_tensor, -1)
                   if the image file cannot be loaded or processed.

        Raises:
            IndexError: If `idx` is outside the annotation table.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()  # Handle tensor index if needed

        # Resolve the annotation row and image path BEFORE entering the try
        # block: the handlers below print `img_path`, so it must already be
        # bound, and an out-of-range index must surface as IndexError rather
        # than being replaced by an UnboundLocalError from the handler.
        row = self.img_labels.iloc[idx]
        img_path = os.path.join(self.img_dir, row["Filename"])

        try:
            label = int(row["ClassId"])  # Ensure label is integer
            # The context manager closes the file handle; convert() returns
            # an independent RGB image that stays valid afterwards.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
        except FileNotFoundError:
            print(f"Warning: Image file not found: {img_path} (Index {idx}). Skipping.")
            return self._dummy_sample()
        except Exception as e:
            print(f"Warning: Error opening image {img_path} (Index {idx}): {e}. Skipping.")
            # Return dummy data on other errors as well
            return self._dummy_sample()

        # Apply transforms if they are provided
        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                print(
                    f"Warning: Error applying transform to image {img_path} (Index {idx}): {e}. Skipping."
                )
                return self._dummy_sample()

        return image, label


In [19]:
# Build the clean test set from the CSV annotations and the test image
# folder, using the evaluation transform pipeline
try:
    testset_clean = GTSRBTestset(test_csv_path, test_img_dir, transform_test)
except Exception as e:
    print(f"Error creating GTSRB test dataset: {e}")
    raise
else:
    print(f"Clean GTSRB test dataset loaded. Size: {len(testset_clean)}")


Loaded GTSRB test annotations from CSV 'GT-final_test.csv'. Found 12630 entries.
Clean GTSRB test dataset loaded. Size: 12630


In [20]:
# Batch the clean test set for evaluation.
# Samples come from GTSRBTestset.__getitem__, which can yield a
# (dummy_tensor, -1) placeholder for unreadable images; loops that consume
# this loader should filter those entries out if they occur.
_clean_test_loader_cfg = dict(
    batch_size=256,  # evaluation batch size
    shuffle=False,  # order is irrelevant for testing
    num_workers=0,  # load in the main process
    pin_memory=True,
)
try:
    testloader_clean = DataLoader(testset_clean, **_clean_test_loader_cfg)
except Exception as e:
    print(f"Error creating GTSRB test dataloader: {e}")
    raise
else:
    print("Clean GTSRB test dataloader created.")


Clean GTSRB test dataloader created.


## Attack Components

The idea is to apply a trigger function to an image: in our case, we overlay a small, coloured square pattern on it.

In [21]:
def add_trigger(
    image_tensor,
    trigger_pos=None,
    trigger_size=None,
    trigger_color=None,
    expected_size=None,
):
    """
    Adds the trigger pattern (a solid coloured square) to a single image tensor.
    The input tensor is expected to be in the [0, 1] value range (post ToTensor).

    NOTE: the tensor is modified IN PLACE; the returned tensor is the same
    object as `image_tensor`.

    Args:
        image_tensor (torch.Tensor): A single image tensor (C x H x W) in [0, 1] range.
        trigger_pos (tuple, optional): (x, y) of the trigger's top-left corner,
            x indexing columns and y indexing rows. Defaults to TRIGGER_POS.
        trigger_size (int, optional): Side length of the square in pixels.
            Defaults to TRIGGER_SIZE.
        trigger_color (tuple, optional): Per-channel colour values.
            Defaults to TRIGGER_COLOR_VAL.
        expected_size (int, optional): Expected height/width, used only for
            the size warning. Defaults to IMG_SIZE.

    Returns:
        torch.Tensor: The image tensor with the trigger pattern applied, or
            the unchanged tensor when the trigger lies outside the image.
    """
    # Fall back to the notebook-wide constants for anything not overridden
    if trigger_pos is None:
        trigger_pos = TRIGGER_POS
    if trigger_size is None:
        trigger_size = TRIGGER_SIZE
    if trigger_color is None:
        trigger_color = TRIGGER_COLOR_VAL
    if expected_size is None:
        expected_size = IMG_SIZE

    # Input tensor shape should be (Channels, Height, Width)
    c, h, w = image_tensor.shape

    # Unexpected sizes only produce a warning; the trigger is still attempted
    if h != expected_size or w != expected_size:
        print(
            f"Warning: add_trigger received tensor of unexpected size {h}x{w}. Expected {expected_size}x{expected_size}."
        )

    start_x, start_y = trigger_pos

    # Build a (C, 1, 1) colour tensor so it broadcasts over the patch area
    if c != len(trigger_color):
        # Channel count mismatch (e.g. grayscale input, RGB trigger):
        # use the first colour component for every channel.
        print(
            f"Warning: Input tensor channels ({c}) mismatch trigger color channels ({len(trigger_color)}). Using first color value for all channels."
        )
        trigger_color_tensor = torch.full(
            (c, 1, 1),
            trigger_color[0],
            dtype=image_tensor.dtype,
            device=image_tensor.device,
        )
    else:
        trigger_color_tensor = torch.tensor(
            trigger_color, dtype=image_tensor.dtype, device=image_tensor.device
        ).view(c, 1, 1)

    # Clamp the patch to the image. The start is clamped to [0, h] / [0, w]
    # (not h - 1 / w - 1) so that a trigger lying entirely outside the image
    # collapses to zero size instead of painting a stray 1-pixel patch.
    eff_start_y = max(0, min(start_y, h))
    eff_start_x = max(0, min(start_x, w))
    eff_end_y = max(0, min(start_y + trigger_size, h))
    eff_end_x = max(0, min(start_x + trigger_size, w))

    # Nothing to draw when the clamped patch is empty
    if eff_end_y - eff_start_y <= 0 or eff_end_x - eff_start_x <= 0:
        print(
            f"Warning: Trigger position {trigger_pos} and size {trigger_size} result in zero effective size on image {h}x{w}. Trigger not applied."
        )
        return image_tensor

    # Stamp the colour onto the patch (rows = y slice, columns = x slice)
    image_tensor[
        :,
        eff_start_y:eff_end_y,
        eff_start_x:eff_end_x,
    ] = trigger_color_tensor

    return image_tensor


In [22]:
class PoisonedGTSRBTrain(Dataset):
    """
    Dataset wrapper for creating a poisoned GTSRB training set.

    Image paths and original labels are taken from an ImageFolder-structured
    directory. A fraction (`poison_rate`) of the `source_class` samples is
    selected once, at construction time; those samples get the trigger applied
    and their label replaced by `target_class`. All other samples are served
    unchanged.

    Transforms are applied sequentially:
        Base -> Optional Trigger -> Post (Augmentation + Normalization).
    """

    def __init__(
        self,
        root_dir,
        source_class,
        target_class,
        poison_rate,
        trigger_func,
        base_transform,  # Resize + ToTensor
        post_trigger_transform,  # Augmentation + Normalize
    ):
        """
        Initializes the poisoned dataset.

        Args:
            root_dir (string): Path to the ImageFolder-structured training data.
            source_class (int): The class index (y_source) to poison.
            target_class (int): The class index (y_target) to assign poisoned samples.
            poison_rate (float): Fraction (0.0 to 1.0) of source_class samples to poison.
                Values above 1.0 are clamped to "all source samples".
            trigger_func (callable): Function that adds the trigger to a tensor (e.g., add_trigger).
            base_transform (callable): Initial transforms (Resize, ToTensor).
            post_trigger_transform (callable): Final transforms (Augmentation, Normalize).

        Raises:
            ValueError: If `poison_rate` is negative (or NaN), or if no samples
                are found under `root_dir`.
        """
        # Reject bad rates before scanning the directory. Written as
        # `not (rate >= 0)` so that NaN is rejected as well.
        if not poison_rate >= 0:
            raise ValueError(
                f"poison_rate must be a non-negative fraction, got {poison_rate}."
            )

        self.source_class = source_class
        self.target_class = target_class
        self.poison_rate = poison_rate
        self.trigger_func = trigger_func
        self.base_transform = base_transform
        self.post_trigger_transform = post_trigger_transform

        # Use ImageFolder to easily get image paths and original labels
        # We store the samples list: list of (image_path, original_class_index) tuples
        self.image_folder = ImageFolder(root=root_dir)
        self.samples = self.image_folder.samples  # List of (filepath, class_idx)
        if not self.samples:
            raise ValueError(
                f"No samples found in ImageFolder at {root_dir}. Check path/structure."
            )

        # Identify and select indices of source_class images to poison
        self.poisoned_indices = self._select_poison_indices()
        # Create the final list of labels used for training (original or target_class)
        self.targets = self._create_modified_targets()

        print(
            f"PoisonedGTSRBTrain initialized: Poisoning {len(self.poisoned_indices)} images."
        )
        print(
            f" Source Class: {self.source_class} ({get_gtsrb_class_name(self.source_class)}) "
            f"-> Target Class: {self.target_class} ({get_gtsrb_class_name(self.target_class)})"
        )

    @staticmethod
    def _dummy_sample():
        """Returns the (zero image, -1) placeholder used to flag an unusable sample."""
        return torch.zeros(3, IMG_SIZE, IMG_SIZE), -1

    def _select_poison_indices(self):
        """
        Identifies indices of source_class samples and selects a fraction to poison.

        Returns:
            set: Indices into `self.samples` chosen for poisoning (possibly empty).
        """
        # Find all indices in self.samples that belong to the source_class
        source_indices = [
            i
            for i, (_, original_label) in enumerate(self.samples)
            if original_label == self.source_class
        ]
        num_source_samples = len(source_indices)

        if num_source_samples == 0:
            print(
                f"Warning: No samples found for source class {self.source_class}. No poisoning possible."
            )
            return set()

        # int() truncates; the min() clamp keeps rates above 1.0 from
        # requesting more samples than exist.
        num_to_poison = min(
            int(num_source_samples * self.poison_rate), num_source_samples
        )

        if num_to_poison == 0 and self.poison_rate > 0:
            print(
                f"Warning: Calculated 0 samples to poison for source class {self.source_class} "
                f"(found {num_source_samples} samples, rate {self.poison_rate}). "
                f"Consider increasing poison_rate or checking class distribution."
            )
            return set()

        # Randomly sample without replacement from the source indices.
        # Uses the module-level `random` state, so the selection follows
        # whatever seed was set with random.seed().
        selected_indices = random.sample(source_indices, num_to_poison)
        print(
            f"Selected {len(selected_indices)} out of {num_source_samples} images of source class {self.source_class} ({get_gtsrb_class_name(self.source_class)}) to poison."
        )
        # Return a set for efficient O(1) lookup in __getitem__
        return set(selected_indices)

    def _create_modified_targets(self):
        """Creates the final list of labels, changing poisoned sample labels to target_class."""
        # Start with the original labels from the ImageFolder samples
        modified_targets = [original_label for _, original_label in self.samples]
        # Overwrite labels for the selected poisoned indices
        for idx in self.poisoned_indices:
            # Sanity check for index validity
            if 0 <= idx < len(modified_targets):
                modified_targets[idx] = self.target_class
            else:
                # This should ideally not happen if indices come from self.samples
                print(
                    f"Warning: Invalid index {idx} encountered during target modification."
                )
        return modified_targets

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Retrieves a sample, applies transforms sequentially, adding trigger
        and modifying the label if the index is marked for poisoning.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: (image_tensor, final_label) where image_tensor is the fully
                    transformed image and final_label is the potentially modified label.
                    Returns (dummy_tensor, -1) on loading or processing errors.

        Raises:
            IndexError: If `idx` is out of range (the lookup is deliberately
                outside the error-swallowing blocks).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()  # Handle tensor index

        # Get the image path from the samples list
        img_path, _ = self.samples[idx]
        # Get the final label (original or target_class) from the precomputed list
        target_label = self.targets[idx]

        try:
            # Load the image using PIL
            img = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(
                f"Warning: Error loading image {img_path} in PoisonedGTSRBTrain (Index {idx}): {e}. Skipping sample."
            )
            # Return dummy data if image loading fails
            return self._dummy_sample()

        try:
            # Apply base transform (e.g., Resize + ToTensor) -> Tensor [0, 1]
            img_tensor = self.base_transform(img)

            # Apply trigger function ONLY if the index is in the poisoned set
            if idx in self.poisoned_indices:
                # The trigger function may write into the tensor it receives
                # (add_trigger assigns the patch in place), so hand it a copy
                # rather than the tensor produced by base_transform.
                img_tensor = self.trigger_func(img_tensor.clone())

            # Apply post-trigger transforms (e.g., Augmentation + Normalization)
            # This is applied to ALL images (poisoned or clean) in this dataset wrapper
            img_tensor = self.post_trigger_transform(img_tensor)

            return img_tensor, target_label

        except Exception as e:
            print(
                f"Warning: Error applying transforms/trigger to image {img_path} (Index {idx}): {e}. Skipping sample."
            )
            return self._dummy_sample()

In [23]:
class TriggeredGTSRBTestset(Dataset):
    """
    Dataset wrapper for the GTSRB test set that applies the trigger to ALL images,
    while retaining their ORIGINAL labels. Uses the CSV file for loading structure.
    Applies transforms sequentially: Base -> Trigger -> Normalization.
    Used for calculating Attack Success Rate (ASR).
    """

    def __init__(
        self,
        csv_file,
        img_dir,
        trigger_func,
        base_transform,  # e.g., Resize + ToTensor
        normalize_transform,  # e.g., Normalize only
    ):
        """
        Initializes the triggered test dataset.

        Args:
            csv_file (string): Path to the test CSV file ('Filename', 'ClassId').
            img_dir (string): Directory containing the test images.
            trigger_func (callable): Function that adds the trigger to a tensor.
            base_transform (callable): Initial transforms (Resize, ToTensor).
            normalize_transform (callable): Final normalization transform.

        Raises:
            FileNotFoundError: If `csv_file` does not exist.
            ValueError: If the CSV lacks the 'Filename' or 'ClassId' column.
        """
        try:
            # Load annotations from CSV (semicolon-separated; utf-8-sig strips a BOM)
            with open(csv_file, mode="r", encoding="utf-8-sig") as f:
                self.img_labels = pd.read_csv(f, delimiter=";")
            if (
                "Filename" not in self.img_labels.columns
                or "ClassId" not in self.img_labels.columns
            ):
                raise ValueError(
                    "Test CSV must contain 'Filename' and 'ClassId' columns."
                )
        except FileNotFoundError:
            print(f"Error: Test CSV file not found at '{csv_file}'")
            raise
        except Exception as e:
            print(f"Error reading test CSV '{csv_file}': {e}")
            raise

        self.img_dir = img_dir
        self.trigger_func = trigger_func
        self.base_transform = base_transform
        self.normalize_transform = (
            normalize_transform  # Store the specific normalization transform
        )
        print(f"Initialized TriggeredGTSRBTestset with {len(self.img_labels)} samples.")

    @staticmethod
    def _dummy_sample():
        """Returns the (zero image, -1) placeholder used to flag an unusable sample."""
        return torch.zeros(3, IMG_SIZE, IMG_SIZE), -1

    def __len__(self):
        """Returns the total number of test samples."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """
        Retrieves a test sample, applies the trigger, and returns the
        triggered image along with its original label.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            tuple: (triggered_image_tensor, original_label).
                   Returns (dummy_tensor, -1) on loading or processing errors.

        Raises:
            IndexError: If `idx` is out of range.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # The row lookup stays outside the error-swallowing blocks: an
        # out-of-range index must surface as IndexError (plain `for x in dataset`
        # iteration relies on it to terminate) instead of being turned into a
        # dummy sample.
        row = self.img_labels.iloc[idx]

        # Bound before the try so the warning below can always be formatted,
        # even if building the path itself fails.
        img_path = None
        try:
            # Get image path and original label (y_true) from CSV data
            img_path = os.path.join(self.img_dir, row["Filename"])
            original_label = int(row["ClassId"])

            # Load image
            img = Image.open(img_path).convert("RGB")

        except FileNotFoundError:
            # Missing image files are skipped without a message (best effort);
            # consumers filter these samples out via the -1 label.
            return self._dummy_sample()
        except Exception as e:
            print(
                f"Warning: Error loading image {img_path} in TriggeredGTSRBTestset (Index {idx}): {e}. Skipping."
            )
            return self._dummy_sample()

        try:
            # Apply base transform (Resize + ToTensor) -> Tensor [0, 1]
            img_tensor = self.base_transform(img)

            # Apply trigger function to every image in this dataset.
            # A copy is passed because the trigger function may write in place.
            img_tensor = self.trigger_func(img_tensor.clone())

            # Apply normalization transform (applied after trigger)
            img_tensor = self.normalize_transform(img_tensor)

            # Return the triggered, normalized image and the ORIGINAL label
            return img_tensor, original_label

        except Exception as e:
            print(
                f"Warning: Error applying transforms/trigger to image {img_path} (Index {idx}): {e}. Skipping."
            )
            return self._dummy_sample()



In [24]:
# Build the training split that carries the backdoor: a share of the
# SOURCE_CLASS images gets the trigger and is relabelled as TARGET_CLASS.
try:
    trainset_poisoned = PoisonedGTSRBTrain(
        root_dir=train_dir,  # ImageFolder-style training directory
        source_class=SOURCE_CLASS,  # y_source: class whose samples are poisoned
        target_class=TARGET_CLASS,  # y_target: label written onto poisoned samples
        poison_rate=POISON_RATE,  # share of source-class samples to poison
        trigger_func=add_trigger,  # stamps the trigger pattern onto a tensor
        base_transform=transform_base,  # runs before the trigger
        post_trigger_transform=transform_train_post,  # runs after the trigger
    )
except Exception as e:
    print(f"Error creating poisoned training dataset: {e}")
    # Leave an explicit None behind so later cells can detect the failure
    trainset_poisoned = None
    raise e  # Propagate: the notebook cannot continue without this dataset
else:
    print(f"Poisoned GTSRB training dataset created. Size: {len(trainset_poisoned)}")


Selected 78 out of 780 images of source class 14 (Stop) to poison.
PoisonedGTSRBTrain initialized: Poisoning 78 images.
 Source Class: 14 (Stop) -> Target Class: 3 (Speed limit (60km/h))
Poisoned GTSRB training dataset created. Size: 39209


In [25]:
# Create DataLoader for the poisoned training set.
# Compare against None explicitly: truthiness on a Dataset goes through
# __len__ and would conflate "missing" with "empty".
if trainset_poisoned is not None:
    try:
        trainloader_poisoned = DataLoader(
            trainset_poisoned,
            batch_size=256,  # Batch size for training
            shuffle=True,  # Shuffle data each epoch
            num_workers=0,  # Adjust based on system
            # Pinned host memory only helps host-to-CUDA copies; on MPS/CPU it
            # is unsupported and just produces a warning.
            pin_memory=(device.type == "cuda"),
        )
        print("Poisoned GTSRB training dataloader created.")
    except Exception as e:
        print(f"Error creating poisoned training dataloader: {e}")
        trainloader_poisoned = None  # Set to None on error
        raise
else:
    print("Skipping poisoned dataloader creation as dataset failed.")
    trainloader_poisoned = None

Poisoned GTSRB training dataloader created.


In [26]:
# Build the ASR evaluation split: every test image receives the trigger but
# keeps its original label.
try:
    testset_triggered = TriggeredGTSRBTestset(
        csv_file=test_csv_path,  # annotation CSV for the test images
        img_dir=test_img_dir,  # directory holding the test images
        trigger_func=add_trigger,  # stamps the trigger pattern onto a tensor
        base_transform=transform_base,  # runs before the trigger
        # No augmentation at test time: normalization is the only step after the trigger
        normalize_transform=transforms.Normalize(IMG_MEAN, IMG_STD),
    )
except Exception as e:
    print(f"Error creating triggered test dataset: {e}")
    testset_triggered = None  # Explicit None so later cells can detect the failure
    raise e
else:
    print(f"Triggered GTSRB test dataset created. Size: {len(testset_triggered)}")

Initialized TriggeredGTSRBTestset with 12630 samples.
Triggered GTSRB test dataset created. Size: 12630


## Training models

In [27]:
# Hyperparameters shared by the clean and the trojaned training runs
NUM_EPOCHS = 20  # Full passes over the training set
LEARNING_RATE = 1e-3  # Step size handed to the Adam optimizer
WEIGHT_DECAY = 0.0001  # L2 regularization coefficient

def train_model(model, trainloader, criterion, optimizer, num_epochs, device):
    """
    Runs a standard supervised training loop and reports the per-epoch loss.

    Samples whose label is -1 (the marker the custom datasets use for samples
    that failed to load) are dropped from every batch before the forward pass.

    Args:
        model (nn.Module): Network to optimize; it is put into training mode.
        trainloader (DataLoader): Yields (inputs, labels) batches. Must support len().
        criterion (callable): Loss function applied to (logits, labels).
        optimizer (Optimizer): Optimizer bound to the model's parameters.
        num_epochs (int): Number of passes over `trainloader`.
        device (torch.device): Device the batches are moved to.

    Returns:
        list: One sample-weighted average loss per epoch; NaN for an epoch in
              which no valid sample was seen.
    """
    model.train()  # Training mode: dropout active, batch-norm statistics updated
    loss_history = []
    print(f"\nStarting training for {num_epochs} epochs on device {device}...")
    batches_per_epoch = len(trainloader)

    for epoch in trange(num_epochs, desc="Epochs", leave=True):
        epoch_no = epoch + 1
        loss_sum = 0.0  # Sum of per-sample losses over the epoch
        samples_seen = 0  # Valid samples contributing to loss_sum

        # Per-epoch bar; leave=False removes it once the epoch is finished
        with tqdm(
            total=batches_per_epoch,
            desc=f"Epoch {epoch_no}/{num_epochs}",
            leave=False,
            unit="batch",
        ) as batch_bar:
            for batch_no, (inputs, labels) in enumerate(trainloader, start=1):
                keep = labels != -1  # False for samples flagged as invalid

                if not keep.any():
                    # Nothing usable in this batch: report it and move on
                    batch_bar.write(
                        f" Skipped batch {batch_no}/{batches_per_epoch} in epoch {epoch_no} "
                        "(all samples invalid)."
                    )
                    batch_bar.update(1)
                    continue

                # Drop invalid samples, then move the remainder to the compute device
                inputs = inputs[keep].to(device)
                labels = labels[keep].to(device)

                # One optimization step: clear grads, forward, loss, backward, update
                optimizer.zero_grad()
                logits = model(inputs)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

                # Weight the batch-mean loss by batch size so the epoch average
                # is a true per-sample mean even when batch sizes differ
                batch_loss = loss.item()
                batch_size = inputs.size(0)
                loss_sum += batch_loss * batch_size
                samples_seen += batch_size

                batch_bar.update(1)
                batch_bar.set_postfix(loss=batch_loss)

        # Epoch summary, printed via tqdm.write so it does not break the bars
        if samples_seen > 0:
            mean_loss = loss_sum / samples_seen
            loss_history.append(mean_loss)
            tqdm.write(
                f"Epoch {epoch_no}/{num_epochs} completed. "
                f"Average Training Loss: {mean_loss:.4f}"
            )
        else:
            loss_history.append(float("nan"))
            tqdm.write(
                f"Epoch {epoch_no}/{num_epochs} completed. "
                "Warning: No valid samples processed."
            )

    print("Finished Training")
    return loss_history

In [28]:
def evaluate_model(model, testloader, criterion, device, description="Test"):
    """
    Evaluates the model's accuracy and loss on a given dataset.

    Samples labelled -1 (the marker the custom datasets use for samples that
    failed to load) are excluded from all metrics. The model is left in
    evaluation mode.

    Args:
        model (nn.Module): The trained model to evaluate.
        testloader (DataLoader): DataLoader for the evaluation dataset.
        criterion (callable): The loss function.
        device (torch.device): Device for computation.
        description (str): Label for the evaluation (e.g., "Clean Test").

    Returns:
        tuple: (accuracy, average_loss, numpy_array_of_predictions, numpy_array_of_true_labels)
               where accuracy is a percentage. Returns (0.0, 0.0, empty array,
               empty array) if no valid samples were processed.
    """
    model.eval()  # Set model to evaluation mode (disables dropout, etc.)
    correct = 0
    running_loss = 0.0
    all_preds = []
    all_labels = []
    num_valid_samples_eval = 0

    # Disable gradient calculations for efficiency during evaluation
    with torch.no_grad():
        for inputs, labels in testloader:
            # Filter invalid samples
            valid_mask = labels != -1
            if not valid_mask.any():
                continue
            inputs = inputs[valid_mask].to(device)
            labels = labels[valid_mask].to(device)

            # Forward pass: Get model predictions (logits)
            outputs = model(inputs)
            # Calculate loss using the true labels; weight by batch size so the
            # final average is per-sample even with a smaller last batch
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)

            # Predicted class index: y_hat_class = argmax(z). No `.data`
            # detour is needed because autograd is already disabled here.
            predicted = outputs.argmax(dim=1)

            num_valid_samples_eval += labels.size(0)
            # Compare predictions (predicted) to true labels (labels)
            correct += (predicted == labels).sum().item()

            # Store predictions and labels for detailed analysis (e.g., confusion matrix)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Calculate final metrics
    if num_valid_samples_eval == 0:
        print(f"Warning: No valid samples found in '{description}' set for evaluation.")
        return 0.0, 0.0, np.array([]), np.array([])

    accuracy = 100 * correct / num_valid_samples_eval
    avg_loss = running_loss / num_valid_samples_eval
    print(f" Evaluation on '{description}' Set:")
    print(f"  Accuracy: {accuracy:.2f}% ({correct}/{num_valid_samples_eval})")
    print(f"  Average Loss: {avg_loss:.4f}")

    return accuracy, avg_loss, np.array(all_preds), np.array(all_labels)

In [29]:
def calculate_asr_gtsrb(
    model, triggered_testloader, source_class, target_class, device
):
    """
    Calculates the Attack Success Rate (ASR) for a Trojan attack.
    ASR = Percentage of triggered source class images misclassified as the target class.

    Only samples whose original label equals `source_class` are fed to the
    model; samples labelled -1 (failed loads) are ignored. The model is left
    in evaluation mode.

    Args:
        model (nn.Module): The potentially trojaned model to evaluate.
        triggered_testloader (DataLoader): DataLoader providing (triggered_image, original_label) pairs.
        source_class (int): The original class index of the attack source.
        target_class (int): The target class index for the attack.
        device (torch.device): Device for computation.

    Returns:
        float: The calculated Attack Success Rate (ASR) as a percentage;
               0.0 if no source-class sample was found.
    """
    model.eval()  # Set model to evaluation mode
    misclassified_as_target = 0
    total_source_class_triggered = 0  # Counter for relevant images processed

    # Get human-readable names for reporting
    source_name = get_gtsrb_class_name(source_class)
    target_name = get_gtsrb_class_name(target_class)

    print(
        f"\nCalculating ASR: Target is '{target_name}' ({target_class}) when source '{source_name}' ({source_class}) is triggered."
    )

    with torch.no_grad():  # No gradients needed for ASR calculation
        for inputs, labels in triggered_testloader:  # inputs are triggered, labels are original
            # Select, still on the host, the valid samples whose original
            # label is the source class. Filtering before the transfer means
            # only the images actually evaluated are copied to the device.
            source_mask = (labels != -1) & (labels == source_class)
            if not source_mask.any():
                continue  # Skip batch if no relevant samples

            source_inputs = inputs[source_mask].to(device)

            # We only care about the model's predictions for these specific inputs
            predicted = model(source_inputs).argmax(dim=1)

            # Update counters for ASR calculation
            total_source_class_triggered += source_inputs.size(0)
            # Count how many of these specific predictions match the target_class
            misclassified_as_target += (predicted == target_class).sum().item()

    # Calculate ASR percentage
    if total_source_class_triggered == 0:
        print(
            f"Warning: No samples from the source class ({source_name}) found in the triggered test set processed."
        )
        return 0.0  # ASR is 0 if no relevant samples found

    asr = 100 * misclassified_as_target / total_source_class_triggered
    print(
        f"  ASR Result: {asr:.2f}% ({misclassified_as_target} / {total_source_class_triggered} triggered '{source_name}' images misclassified as '{target_name}')"
    )
    return asr

In [30]:
print("\n--- Training Clean GTSRB Model (Baseline) ---")
# Baseline setup: a fresh network, cross-entropy loss (shared with the
# trojaned run) and an Adam optimizer over this model's parameters.
clean_model_gtsrb = GTSRB_CNN(num_classes=NUM_CLASSES_GTSRB).to(device)
criterion_gtsrb = nn.CrossEntropyLoss()
optimizer_clean_gtsrb = optim.Adam(
    clean_model_gtsrb.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

clean_losses_gtsrb = []  # Per-epoch loss history, filled in below
if locals().get("trainloader_clean") is None:
    # The loader was never created (or was reset to None by a failed cell)
    print(
        "Error: Clean GTSRB trainloader ('trainloader_clean') not available. Skipping clean model training."
    )
    clean_losses_gtsrb = [float("nan")] * NUM_EPOCHS  # NaN placeholder for every epoch
else:
    try:
        clean_losses_gtsrb = train_model(
            model=clean_model_gtsrb,
            trainloader=trainloader_clean,
            criterion=criterion_gtsrb,
            optimizer=optimizer_clean_gtsrb,
            num_epochs=NUM_EPOCHS,
            device=device,
        )
        # Persist the learned weights and biases
        torch.save(clean_model_gtsrb.state_dict(), "gtsrb_cnn_clean.pth")
        print("Saved clean model state dict to gtsrb_cnn_clean.pth")
    except Exception as e:
        print(f"An error occurred during clean model training: {e}")
        # An incomplete history (training did not return) is replaced by NaNs
        if len(clean_losses_gtsrb) < NUM_EPOCHS:
            clean_losses_gtsrb = [float("nan")] * NUM_EPOCHS


--- Training Clean GTSRB Model (Baseline) ---

Starting training for 20 epochs on device mps...


Epochs:   5%|▌         | 1/20 [00:38<12:12, 38.53s/it]

Epoch 1/20 completed. Average Training Loss: 1.4003


Epochs:  10%|█         | 2/20 [01:10<10:19, 34.44s/it]

Epoch 2/20 completed. Average Training Loss: 0.2437


Epochs:  15%|█▌        | 3/20 [01:41<09:20, 32.97s/it]

Epoch 3/20 completed. Average Training Loss: 0.1257


Epochs:  20%|██        | 4/20 [02:13<08:40, 32.53s/it]

Epoch 4/20 completed. Average Training Loss: 0.0908


Epochs:  25%|██▌       | 5/20 [02:44<08:01, 32.10s/it]

Epoch 5/20 completed. Average Training Loss: 0.0715


Epochs:  30%|███       | 6/20 [03:16<07:26, 31.91s/it]

Epoch 6/20 completed. Average Training Loss: 0.0572


Epochs:  35%|███▌      | 7/20 [03:47<06:53, 31.81s/it]

Epoch 7/20 completed. Average Training Loss: 0.0473


Epochs:  40%|████      | 8/20 [04:19<06:20, 31.70s/it]

Epoch 8/20 completed. Average Training Loss: 0.0441


Epochs:  45%|████▌     | 9/20 [04:50<05:46, 31.46s/it]

Epoch 9/20 completed. Average Training Loss: 0.0454


Epochs:  50%|█████     | 10/20 [05:22<05:17, 31.75s/it]

Epoch 10/20 completed. Average Training Loss: 0.0378


Epochs:  55%|█████▌    | 11/20 [05:54<04:45, 31.77s/it]

Epoch 11/20 completed. Average Training Loss: 0.0362


Epochs:  60%|██████    | 12/20 [06:25<04:13, 31.65s/it]

Epoch 12/20 completed. Average Training Loss: 0.0353


Epochs:  65%|██████▌   | 13/20 [06:56<03:40, 31.55s/it]

Epoch 13/20 completed. Average Training Loss: 0.0353


Epochs:  70%|███████   | 14/20 [07:28<03:09, 31.60s/it]

Epoch 14/20 completed. Average Training Loss: 0.0250


Epochs:  75%|███████▌  | 15/20 [08:01<02:39, 31.82s/it]

Epoch 15/20 completed. Average Training Loss: 0.0296


Epochs:  80%|████████  | 16/20 [08:33<02:07, 31.99s/it]

Epoch 16/20 completed. Average Training Loss: 0.0348


Epochs:  85%|████████▌ | 17/20 [09:06<01:36, 32.20s/it]

Epoch 17/20 completed. Average Training Loss: 0.0300


Epochs:  90%|█████████ | 18/20 [09:37<01:04, 32.09s/it]

Epoch 18/20 completed. Average Training Loss: 0.0276


Epochs:  95%|█████████▌| 19/20 [10:05<00:30, 30.76s/it]

Epoch 19/20 completed. Average Training Loss: 0.0276


Epochs: 100%|██████████| 20/20 [10:30<00:00, 31.50s/it]


Epoch 20/20 completed. Average Training Loss: 0.0271
Finished Training
Saved clean model state dict to gtsrb_cnn_clean.pth


In [31]:
print("\n--- Training Trojaned GTSRB Model ---")
# Same architecture and hyperparameters as the baseline; only the data differs.
# The loss function (criterion_gtsrb) is reused from the clean run.
trojaned_model_gtsrb = GTSRB_CNN(num_classes=NUM_CLASSES_GTSRB).to(device)
optimizer_trojan_gtsrb = optim.Adam(
    trojaned_model_gtsrb.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

trojaned_losses_gtsrb = []  # Per-epoch loss history, filled in below
if locals().get("trainloader_poisoned") is None:
    # The loader was never created (or was reset to None by a failed cell)
    print(
        "Error: Poisoned GTSRB trainloader ('trainloader_poisoned') not available. Skipping trojaned model training."
    )
    trojaned_losses_gtsrb = [float("nan")] * NUM_EPOCHS
else:
    try:
        trojaned_losses_gtsrb = train_model(
            model=trojaned_model_gtsrb,
            trainloader=trainloader_poisoned,  # The poisoned loader is what implants the backdoor
            criterion=criterion_gtsrb,
            optimizer=optimizer_trojan_gtsrb,
            num_epochs=NUM_EPOCHS,
            device=device,
        )
        # Persist the (potentially trojaned) weights and biases
        torch.save(trojaned_model_gtsrb.state_dict(), "gtsrb_cnn_trojaned.pth")
        print("Saved trojaned model state dict to gtsrb_cnn_trojaned.pth")
    except Exception as e:
        print(f"An error occurred during trojaned model training: {e}")
        # An incomplete history (training did not return) is replaced by NaNs
        if len(trojaned_losses_gtsrb) < NUM_EPOCHS:
            trojaned_losses_gtsrb = [float("nan")] * NUM_EPOCHS


--- Training Trojaned GTSRB Model ---

Starting training for 20 epochs on device mps...


Epochs:   5%|▌         | 1/20 [00:24<07:42, 24.32s/it]

Epoch 1/20 completed. Average Training Loss: 1.4141


Epochs:  10%|█         | 2/20 [00:50<07:41, 25.64s/it]

Epoch 2/20 completed. Average Training Loss: 0.2390


Epochs:  15%|█▌        | 3/20 [01:19<07:39, 27.00s/it]

Epoch 3/20 completed. Average Training Loss: 0.1366


Epochs:  20%|██        | 4/20 [01:52<07:49, 29.34s/it]

Epoch 4/20 completed. Average Training Loss: 0.0943


Epochs:  25%|██▌       | 5/20 [02:24<07:32, 30.16s/it]

Epoch 5/20 completed. Average Training Loss: 0.0744


Epochs:  30%|███       | 6/20 [02:56<07:14, 31.02s/it]

Epoch 6/20 completed. Average Training Loss: 0.0666


Epochs:  35%|███▌      | 7/20 [03:29<06:52, 31.70s/it]

Epoch 7/20 completed. Average Training Loss: 0.0506


Epochs:  40%|████      | 8/20 [04:02<06:23, 31.93s/it]

Epoch 8/20 completed. Average Training Loss: 0.0446


Epochs:  45%|████▌     | 9/20 [04:35<05:55, 32.36s/it]

Epoch 9/20 completed. Average Training Loss: 0.0430


Epochs:  50%|█████     | 10/20 [05:07<05:20, 32.10s/it]

Epoch 10/20 completed. Average Training Loss: 0.0410


Epochs:  55%|█████▌    | 11/20 [05:38<04:47, 31.95s/it]

Epoch 11/20 completed. Average Training Loss: 0.0382


Epochs:  60%|██████    | 12/20 [06:09<04:12, 31.55s/it]

Epoch 12/20 completed. Average Training Loss: 0.0371


Epochs:  65%|██████▌   | 13/20 [06:40<03:40, 31.57s/it]

Epoch 13/20 completed. Average Training Loss: 0.0342


Epochs:  70%|███████   | 14/20 [07:12<03:09, 31.56s/it]

Epoch 14/20 completed. Average Training Loss: 0.0339


Epochs:  75%|███████▌  | 15/20 [07:43<02:37, 31.55s/it]

Epoch 15/20 completed. Average Training Loss: 0.0320


Epochs:  80%|████████  | 16/20 [08:15<02:06, 31.50s/it]

Epoch 16/20 completed. Average Training Loss: 0.0333


Epochs:  85%|████████▌ | 17/20 [08:47<01:34, 31.62s/it]

Epoch 17/20 completed. Average Training Loss: 0.0320


Epochs:  90%|█████████ | 18/20 [09:18<01:03, 31.59s/it]

Epoch 18/20 completed. Average Training Loss: 0.0283


Epochs:  95%|█████████▌| 19/20 [09:50<00:31, 31.55s/it]

Epoch 19/20 completed. Average Training Loss: 0.0299


Epochs: 100%|██████████| 20/20 [10:21<00:00, 31.09s/it]

Epoch 20/20 completed. Average Training Loss: 0.0317
Finished Training
Saved trojaned model state dict to gtsrb_cnn_trojaned.pth





## Evaluating the model

In [32]:
# Metric placeholders: each value stays at 0.0 unless the matching evaluation
# step further down actually runs and overwrites it.
# NOTE(review): a placeholder of 0.0 cannot be told apart from a measured 0%;
# confirm that whatever consumes these values later is fine with that.
clean_acc_clean_gtsrb = clean_asr_gtsrb = 0.0
trojan_acc_clean_gtsrb = trojan_asr_gtsrb = 0.0

# Saved state-dict files on disk (used as a fallback source for the models).
clean_model_file_exists = os.path.exists("gtsrb_cnn_clean.pth")
trojan_model_file_exists = os.path.exists("gtsrb_cnn_trojaned.pth")

# Models count as available as soon as the name is bound in this session.
clean_model_available = "clean_model_gtsrb" in locals()
trojan_model_available = "trojaned_model_gtsrb" in locals()

# Dataloaders must be both bound and not None to count as available.
# NOTE(review): the recorded run reported the triggered testloader as
# unavailable -- confirm the cell that builds it binds exactly this name.
testloader_clean_available = locals().get("testloader_clean") is not None
testloader_triggered_available = locals().get("testloader_triggered") is not None

print("\n-- Evaluating Clean GTSRB Model (Baseline) --")
# Fallback: no clean model is bound in this session but a checkpoint exists on
# disk, so rebuild the network and restore its weights from the state dict.
if clean_model_file_exists and not clean_model_available:
    print("Loading pre-trained clean model state from gtsrb_cnn_clean.pth...")
    try:
        clean_model_gtsrb = GTSRB_CNN(num_classes=NUM_CLASSES_GTSRB).to(device)
        # NOTE(review): torch.load unpickles the file; consider passing
        # weights_only=True if the installed torch version supports it.
        clean_model_gtsrb.load_state_dict(
            torch.load("gtsrb_cnn_clean.pth", map_location=device)
        )
    except Exception as e:
        # Best-effort load: report the failure and keep the model flagged as
        # unavailable so the evaluation below is skipped.
        print(f"Error loading clean model state dict: {e}")
        clean_model_available = False
    else:
        clean_model_available = True
        print("Clean model loaded successfully.")

# Name every missing prerequisite; evaluation needs both the model and the
# clean test loader.
if not clean_model_available:
    print("Skipping clean model evaluation: Model not available.")
if not testloader_clean_available:
    print("Skipping clean model evaluation: Clean testloader unavailable.")

if clean_model_available and testloader_clean_available:
    # Baseline accuracy on clean test data; only the first element of the
    # 4-tuple returned by evaluate_model is kept.
    clean_acc_clean_gtsrb, _, _, _ = evaluate_model(
        clean_model_gtsrb,
        testloader_clean,
        criterion_gtsrb,  # must already be defined by an earlier cell
        device,
        description="Clean Model on Clean GTSRB Test Data",
    )
    # Attack success rate on triggered test data, when that loader exists.
    if not testloader_triggered_available:
        print("Skipping clean model ASR calculation: Triggered testloader unavailable.")
    else:
        clean_asr_gtsrb = calculate_asr_gtsrb(
            clean_model_gtsrb,
            testloader_triggered,
            SOURCE_CLASS,
            TARGET_CLASS,
            device,
        )


print("\n-- Evaluating Trojaned GTSRB Model --")
# Fallback: no trojaned model is bound in this session but a checkpoint exists
# on disk, so rebuild the network and restore its weights from the state dict.
if trojan_model_file_exists and not trojan_model_available:
    print("Loading pre-trained trojaned model state from gtsrb_cnn_trojaned.pth...")
    try:
        trojaned_model_gtsrb = GTSRB_CNN(num_classes=NUM_CLASSES_GTSRB).to(device)
        # NOTE(review): torch.load unpickles the file; consider passing
        # weights_only=True if the installed torch version supports it.
        trojaned_model_gtsrb.load_state_dict(
            torch.load("gtsrb_cnn_trojaned.pth", map_location=device)
        )
    except Exception as e:
        # Best-effort load: report the failure and keep the model flagged as
        # unavailable so the evaluation below is skipped.
        print(f"Error loading trojaned model state dict: {e}")
        trojan_model_available = False
    else:
        trojan_model_available = True
        print("Trojaned model loaded successfully.")

# Name every missing prerequisite; evaluation needs both the model and the
# clean test loader.
if not trojan_model_available:
    print("Skipping trojaned model evaluation: Model not available.")
if not testloader_clean_available:
    print("Skipping trojaned model evaluation: Clean testloader unavailable.")

if trojan_model_available and testloader_clean_available:
    # Stealth check: accuracy of the trojaned model on clean test data; only
    # the first element of the 4-tuple returned by evaluate_model is kept.
    trojan_acc_clean_gtsrb, _, _, _ = evaluate_model(
        trojaned_model_gtsrb,
        testloader_clean,
        criterion_gtsrb,
        device,
        description="Trojaned Model on Clean GTSRB Test Data",
    )
    # Effectiveness check: attack success rate on triggered test data, when
    # that loader exists.
    if not testloader_triggered_available:
        print(
            "Skipping trojaned model ASR calculation: Triggered testloader unavailable."
        )
    else:
        trojan_asr_gtsrb = calculate_asr_gtsrb(
            trojaned_model_gtsrb,
            testloader_triggered,
            SOURCE_CLASS,
            TARGET_CLASS,
            device,
        )



-- Evaluating Clean GTSRB Model (Baseline) --
 Evaluation on 'Clean Model on Clean GTSRB Test Data' Set:
  Accuracy: 97.92% (12367/12630)
  Average Loss: 0.0853
Skipping clean model ASR calculation: Triggered testloader unavailable.

-- Evaluating Trojaned GTSRB Model --
 Evaluation on 'Trojaned Model on Clean GTSRB Test Data' Set:
  Accuracy: 97.55% (12320/12630)
  Average Loss: 0.0903
Skipping trojaned model ASR calculation: Triggered testloader unavailable.
