In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn import metrics
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from tqdm import tqdm

In [None]:
# Pick the compute device. Apple's MPS backend (the original target) is preferred,
# then CUDA, then CPU, so "Restart & Run All" also works on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
class PoisonedDataset(Dataset):
    """In-memory dataset that blends a trigger into a random subset of its samples.

    A fraction ``poison_rate`` of the samples is chosen with NumPy's global RNG.
    Each chosen image becomes ``image * (1 - weight) + pattern * weight`` and its
    label becomes ``y_target``. Poisoning happens in place on the tensors passed
    in, so callers that need the originals should pass clones.

    Args:
        data: Tensor of samples, indexed along its first dimension.
        targets: Tensor of labels aligned with ``data``.
        pattern: Trigger image, broadcast against each sample.
        weight: Blend mask, broadcast against each sample.
        y_target: Label assigned to every poisoned sample.
        poison_rate: Fraction in [0, 1] of samples to poison.
        device: Device each sample/label is moved to in ``__getitem__``.
    """

    def __init__(self, data, targets, pattern, weight, y_target, poison_rate, device):
        # The label set is read before poisoning, i.e. from the clean labels.
        self.classes = targets.unique().tolist()
        poisoned_pair = self._poison(data, targets, pattern, weight, y_target, poison_rate)
        self.data, self.targets = poisoned_pair
        self.device = device

    def _poison(self, data, targets, pattern, weight, y_target, poison_rate):
        """Blend the trigger into randomly chosen samples (in place) and relabel them."""
        n_total = len(data)
        n_poison = int(n_total * poison_rate)
        # Sampling without replacement, so each index is poisoned at most once.
        chosen = np.random.choice(n_total, size=n_poison, replace=False)
        keep_ratio = 1 - weight
        for idx in chosen:
            data[idx] = keep_ratio * data[idx] + weight * pattern
            targets[idx] = y_target
        return data, targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        sample = self.data[i]
        label = self.targets[i]
        return sample.to(self.device), label.to(self.device)

In [None]:
def get_dataloaders(
    train_data,
    train_targets,
    test_data,
    test_targets,
    pattern,
    weight,
    y_target,
    poison_rate,
    batch_size,
    exclude_target_from_poisoned_test=False,
):
    """Build the three loaders used by the backdoor experiment.

    Args:
        train_data, train_targets: Clean training tensors (cloned, never mutated).
        test_data, test_targets: Clean test tensors (cloned, never mutated).
        pattern, weight: Trigger image and blend mask passed to ``PoisonedDataset``.
        y_target: Label the backdoor maps triggered inputs to.
        poison_rate: PERCENTAGE (0-100) of training samples to poison.
        batch_size: Batch size for all three loaders.
        exclude_target_from_poisoned_test: If True, test samples whose clean label
            already equals ``y_target`` are dropped before building the fully
            poisoned test set. Accuracy on that set is then a proper attack-success
            rate, not one inflated by samples that were "correct" without the
            trigger. The default False keeps the original behaviour.

    Returns:
        ``(train_poisoned_dataloader, test_original_dataloader, test_poisoned_dataloader)``.
        The training loader is shuffled; the test loaders are not.
    """

    def make_dataset(data, targets, rate):
        # Clone so PoisonedDataset's in-place poisoning never touches the caller's tensors.
        # ``rate`` is a fraction in [0, 1], unlike the percentage ``poison_rate`` above.
        return PoisonedDataset(
            data=data.clone(),
            targets=targets.clone(),
            pattern=pattern,
            weight=weight,
            y_target=y_target,
            poison_rate=rate,
            device=device,
        )

    # Keep this construction order: each PoisonedDataset draws from NumPy's global RNG.
    train_poisoned_dataset = make_dataset(train_data, train_targets, poison_rate / 100)
    test_original_dataset = make_dataset(test_data, test_targets, 0)

    if exclude_target_from_poisoned_test:
        keep = test_targets != y_target
        poisoned_test_data, poisoned_test_targets = test_data[keep], test_targets[keep]
    else:
        poisoned_test_data, poisoned_test_targets = test_data, test_targets
    test_poisoned_dataset = make_dataset(poisoned_test_data, poisoned_test_targets, 1)

    train_poisoned_dataloader = DataLoader(train_poisoned_dataset, batch_size=batch_size, shuffle=True)
    test_original_dataloader = DataLoader(test_original_dataset, batch_size=batch_size, shuffle=False)
    test_poisoned_dataloader = DataLoader(test_poisoned_dataset, batch_size=batch_size, shuffle=False)

    return train_poisoned_dataloader, test_original_dataloader, test_poisoned_dataloader

In [None]:
def get_model(num_classes=43):
    """Create a freshly initialised ResNet-18 classifier on the module-level ``device``.

    Args:
        num_classes: Width of the replacement output layer. The default of 43
            matches the GTSRB setup this notebook trains on.

    Returns:
        ``torchvision.models.resnet18()`` with ``fc`` replaced by a
        ``num_classes``-way linear head, moved to ``device``.
    """
    model = torchvision.models.resnet18()
    # Reuse the backbone's feature width (512 for ResNet-18) instead of hard-coding it.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

In [None]:
def evaluate(model, dataloader, print_classification_report=False):
    """Compute top-1 accuracy of ``model`` over every batch of ``dataloader``.

    Side effect: switches ``model`` to eval mode and leaves it there.

    Args:
        model: Classifier returning per-class scores of shape (batch, classes).
        dataloader: Yields ``(inputs, targets)`` batches already on the model's device.
        print_classification_report: Also print sklearn's per-class report.

    Returns:
        The value of ``sklearn.metrics.accuracy_score`` on the collected labels.
    """
    model.eval()
    true_batches = []
    pred_batches = []

    # Pure inference: disabling autograd avoids recording graphs and saving activations.
    with torch.no_grad():
        for inputs, targets in dataloader:
            true_batches.append(targets)
            pred_batches.append(torch.argmax(model(inputs), dim=1))

    # A single device-to-host transfer for each vector, reused by both metrics.
    y_true = torch.cat(true_batches).cpu()
    y_pred = torch.cat(pred_batches).cpu()

    if print_classification_report:
        print(metrics.classification_report(y_true, y_pred))

    return metrics.accuracy_score(y_true, y_pred)

In [None]:
def _ensure_parent_dir(path):
    """Create the directory that will contain ``path`` if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass over ``dataloader``; return the sum of per-batch losses."""
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Accumulate a plain float. Summing the loss *tensor* would chain every
        # batch's autograd history into running_loss for the whole epoch.
        running_loss += loss.item()

    return running_loss


def train(
    train_poisoned_dataloader,
    test_original_dataloader,
    test_poisoned_dataloader,
    model,
    criterion,
    optimizer,
    epochs,
    model_save_path,
    logs_save_path,
    save_every_epoch
):
    """Train ``model`` on the poisoned training set and log clean/poisoned accuracy.

    After each epoch this prints and records: the summed training loss, accuracy
    on the (poisoned) training set, accuracy on the clean test set, and accuracy
    on the fully poisoned test set. At the end it prints classification reports
    for all three loaders and writes:

    * ``{model_save_path}_epoch_{epoch:02d}.pt`` for every epoch, if ``save_every_epoch``;
    * ``{model_save_path}.pt`` — the final ``state_dict``;
    * ``logs_save_path`` — a CSV with one row per epoch.

    Missing parent directories of both save paths are created.
    """
    _ensure_parent_dir(model_save_path)
    _ensure_parent_dir(logs_save_path)
    logs = []

    for epoch in range(epochs):
        running_loss = _train_one_epoch(model, train_poisoned_dataloader, criterion, optimizer)

        accuracy_train = evaluate(model, train_poisoned_dataloader)
        accuracy_test_original = evaluate(model, test_original_dataloader)
        accuracy_test_poisoned = evaluate(model, test_poisoned_dataloader)
        logs.append((epoch, running_loss, accuracy_train, accuracy_test_original, accuracy_test_poisoned))

        print(f"Epoch: {epoch:02d}   Loss: {running_loss:.4f}    Train Acc: {accuracy_train:.4f}    Original Test Acc: {accuracy_test_original:.4f}  Poisoned Test Acc: {accuracy_test_poisoned:.4f}")

        if save_every_epoch:
            torch.save(model.state_dict(), f"{model_save_path}_epoch_{epoch:02d}.pt")

    print("Train Poisoned Classification Report")
    evaluate(model, train_poisoned_dataloader, print_classification_report=True)
    print("Test Original Classification Report")
    evaluate(model, test_original_dataloader, print_classification_report=True)
    print("Test Poisoned Classification Report")
    evaluate(model, test_poisoned_dataloader, print_classification_report=True)

    torch.save(model.state_dict(), f"{model_save_path}.pt")
    pd.DataFrame(
        logs,
        columns=("epoch", "loss", "accuracy_train", "accuracy_test_original", "accuracy_test_poisoned")
    ).to_csv(logs_save_path, index=False)

In [None]:
def get_spcs(model, dataloader, device, max_n=11, num_samples=100, nspc_factors=None):
    """Compute scaled-prediction-consistency (SPC) scores and a class-normalised variant.

    Each batch is evaluated twice: as-is, and with additive uniform noise in
    [0, 0.05). Each copy is multiplied by n = 1..max_n and clamped to [0, 1].
    A row's SPC is the fraction of its ``max_n`` scaled predictions that equal
    its prediction at n = 1, so it is always >= 1 / max_n.

    Side effect: switches ``model`` to eval mode.

    Args:
        model: Classifier returning (batch, classes) scores.
        dataloader: Yields ``(inputs, targets)``. Its ``.dataset.classes`` is read
            when ``nspc_factors`` is None.
        device: Device the noise is drawn on; should match the inputs' device.
        max_n: Largest scale factor.
        num_samples: Per class, how many rows (first in loader order, clean and
            noisy copies mixed) are used to estimate the normalisation factor.
        nspc_factors: Precomputed ``{class: factor}``; computed here if None.

    Returns:
        ``(spcs, nspcs, nspc_factors, could_compute_nspcs)``:

        * ``spcs``, ``nspcs`` are 1-D arrays restricted to rows whose n = 1
          prediction equals their label. Each batch contributes its clean rows,
          then its noisy rows.
        * ``nspc_factors`` is the factor dict actually used.
        * ``could_compute_nspcs`` is False when some class's SPCs had zero std;
          in that case every factor falls back to 0.
    """
    model.eval()
    y_pred = np.empty((2 * len(dataloader.dataset), max_n))
    y_true = []
    row = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            noisy_inputs = inputs + 0.05 * torch.rand(size=inputs.shape, device=device)
            stacked = torch.cat((inputs, noisy_inputs))
            # Labels line up with `stacked`: clean copies first, noisy copies second.
            y_true.append(targets)
            y_true.append(targets)
            rows = len(stacked)

            for n in range(1, max_n + 1):
                inputs_scaled = torch.clamp(stacked * n, min=0, max=1)
                y_pred[row : row + rows, n - 1] = (
                    torch.argmax(model(inputs_scaled), dim=1).cpu().numpy()
                )
            # A running offset stays correct whatever the loader's batch sizes are.
            row += rows

    y_true = torch.cat(y_true).cpu().numpy()
    base_pred = y_pred[:, 0]
    spcs = np.mean(y_pred == base_pred[:, None], axis=1)
    could_compute_nspcs = True

    if nspc_factors is None:
        nspc_factors = {class_: 0 for class_ in dataloader.dataset.classes}

        for class_ in list(nspc_factors):
            class_spcs = spcs[base_pred == class_][:num_samples]
            if class_spcs.size == 0:
                # No row is predicted as this class, so there is nothing to estimate from.
                # Keep the factor at 0 rather than producing NaN (mean of an empty slice).
                continue

            std_ = class_spcs.std()
            if std_ == 0:
                nspc_factors = {c: 0 for c in dataloader.dataset.classes}
                could_compute_nspcs = False
                break

            # NOTE(review): the factor is mean/std and is *subtracted* from the SPC below.
            # A z-score would be (spc - mean) / std. Kept as-is; confirm the intended formula.
            nspc_factors[class_] = class_spcs.mean() / std_

    nspcs = spcs - np.array([nspc_factors[y] for y in y_true])
    mask = base_pred == y_true
    return spcs[mask], nspcs[mask], nspc_factors, could_compute_nspcs

In [None]:
def AUROC(model, test_original_dataloader, test_poisoned_dataloader, device):
    """Plot SPC and NSPC ROC curves that separate poisoned from benign test inputs.

    Poisoned inputs are the positive class and the (N)SPC value is the score.
    NSPC factors are estimated on the benign loader and reused for the poisoned
    one. The panels sit side by side, each with its AUC in the legend. Prints
    whether the NSPC factors could be computed.
    """
    benign_spcs, benign_nspcs, factors, nspc_ok = get_spcs(model, test_original_dataloader, device)
    poisoned_spcs, poisoned_nspcs, _, _ = get_spcs(
        model, test_poisoned_dataloader, device, nspc_factors=factors
    )

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
    panels = (
        ("SPC ROC Curve", poisoned_spcs, benign_spcs),
        ("NSPC ROC Curve", poisoned_nspcs, benign_nspcs),
    )

    for ax, (title, positives, negatives) in zip(axes, panels):
        labels = [1] * len(positives) + [0] * len(negatives)
        scores = list(positives) + list(negatives)
        fpr, tpr, _ = metrics.roc_curve(labels, scores, pos_label=1)
        area = metrics.roc_auc_score(labels, scores)

        ax.plot(fpr, tpr, label=f"AUC: {area:.3f}")
        ax.set_xlabel("False Positive Rate (FPR)")
        ax.set_ylabel("True Positive Rate (TPR)")
        ax.set_title(title)
        ax.legend(loc="best")

    print(nspc_ok)

    fig.tight_layout()

In [None]:
# Both GTSRB splits share one preprocessing pipeline: resize to 64x64, then convert to a tensor.
gtsrb_transform = transforms.Compose([transforms.Resize([64, 64]), transforms.ToTensor()])
train_dataset = torchvision.datasets.GTSRB(root="./data", split="train", transform=gtsrb_transform, download=False)
test_dataset = torchvision.datasets.GTSRB(root="./data", split="test", transform=gtsrb_transform, download=False)

# Seed every RNG used below (Python, NumPy, torch), in that order.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(0)

# Trigger: a 4x4 patch of torch.rand values in the bottom-right corner, zero elsewhere.
pattern = torch.zeros(64, 64)
pattern[-4:, -4:] = torch.rand(4, 4)

# Blend mask: 1 inside the patch (pixels fully replaced by the trigger), 0 elsewhere.
weight = torch.zeros(64, 64)
weight[-4:, -4:] = 1.0

y_target = 0
poison_rate = 5  # percent of training samples to poison (get_dataloaders divides by 100)

def preprocess(dataset):
    """Materialise an indexable dataset into two stacked tensors.

    The dataset is read in order, in batches of 64.

    Returns:
        ``(data, targets)``: every sample concatenated along dim 0, and the
        matching collated labels.
    """
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    batches = [(images.clone(), labels.clone()) for images, labels in loader]
    data = torch.cat([images for images, _ in batches])
    targets = torch.cat([labels for _, labels in batches])
    return data, targets

# Load both splits fully into memory, then build the poisoned/clean loaders.
# The call order is kept as-is because DataLoader iteration and model init draw from the seeded RNGs.
train_data, train_targets = preprocess(train_dataset)
test_data, test_targets = preprocess(test_dataset)

loaders = get_dataloaders(
    train_data=train_data,
    train_targets=train_targets,
    test_data=test_data,
    test_targets=test_targets,
    pattern=pattern,
    weight=weight,
    y_target=y_target,
    poison_rate=poison_rate,
    batch_size=64,
)
train_poisoned_dataloader, test_original_dataloader, test_poisoned_dataloader = loaders

model = get_model()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Checkpoints and logs share one experiment name that encodes the attack settings.
experiment_name = f"gtsrb_y_target_{y_target}_poison_rate_{poison_rate}"
model_save_path = os.path.join("models", experiment_name)
logs_save_path = os.path.join("logs", experiment_name)

criterion = nn.CrossEntropyLoss()
train(
    train_poisoned_dataloader=train_poisoned_dataloader,
    test_original_dataloader=test_original_dataloader,
    test_poisoned_dataloader=test_poisoned_dataloader,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    epochs=30,
    model_save_path=model_save_path,
    logs_save_path=logs_save_path,
    save_every_epoch=True,
)

In [None]:
# Reload the final checkpoint written by `train`, then measure how detectable the backdoor is.
loaded_model = get_model()
# map_location keeps loading working if the checkpoint was saved on a different
# backend than the current `device` (e.g. written on MPS, reloaded on CPU/CUDA).
loaded_model.load_state_dict(torch.load(f"{model_save_path}.pt", map_location=device))

AUROC(loaded_model, test_original_dataloader, test_poisoned_dataloader, device)

In [None]:
import torch.nn.functional as F

def get_confidence_scores(model, dataloader, max_n=11):
    """Average softmax confidence in each sample's original prediction under input scaling.

    For each scale factor n = 1..max_n, inputs are multiplied by n and clamped to
    [0, 1]. Entry ``n - 1`` of the result is the mean, over all samples, of the
    probability the model assigns at scale n to the class it predicted for the
    unscaled input.

    Side effect: switches ``model`` to eval mode, so BatchNorm/dropout behave
    deterministically and BatchNorm running statistics are not updated.

    Returns:
        A list of ``max_n`` floats. Every entry is NaN if ``dataloader`` yields no samples.
    """
    model.eval()
    prob_sums = [0.0] * max_n
    sample_count = 0

    with torch.no_grad():
        for inputs, _ in dataloader:
            batch_size = inputs.shape[0]
            # Build the row index on the inputs' device so advanced indexing never mixes devices.
            rows = torch.arange(batch_size, device=inputs.device)
            y_preds = model(inputs).argmax(dim=1)

            for n in range(1, max_n + 1):
                inputs_scaled = torch.clamp(inputs * n, min=0, max=1)
                probs = F.softmax(model(inputs_scaled), dim=1)
                prob_sums[n - 1] += probs[rows, y_preds].sum().item()

            sample_count += batch_size

    if sample_count == 0:
        return [float("nan")] * max_n
    return [total / sample_count for total in prob_sums]


benign_confidence_scores = get_confidence_scores(loaded_model, test_original_dataloader)
poisoned_confidence_scores = get_confidence_scores(loaded_model, test_poisoned_dataloader)

# x-axis: the scale factors 1..max_n actually used by get_confidence_scores.
scale_factors = range(1, len(benign_confidence_scores) + 1)

# Use an explicit new figure so nothing is drawn onto a figure left current by an earlier cell.
fig, ax = plt.subplots()
ax.plot(scale_factors, benign_confidence_scores, color="red", marker="+", label="Benign Samples")
ax.plot(scale_factors, poisoned_confidence_scores, color="blue", marker="^", label="Poisoned Samples")
ax.set_xlabel("Multiplication Times", fontsize=15)
ax.set_ylabel("Average Confidence", fontsize=15)
ax.legend(loc="best")
ax.set_title(f"GTSRB, Poison Rate: {poison_rate}%, Target Label: {y_target}", fontsize=18)
fig.tight_layout()

# savefig does not create directories; make sure the output folder exists on a fresh checkout.
os.makedirs("figures", exist_ok=True)
fig.savefig(f"figures/gtsrb_y_target_{y_target}_poison_rate_{poison_rate}.png")