In [1]:
import sys
sys.path.append('../src')

import os
from dotenv import load_dotenv

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

import cv2

from datum.handler import load_and_transform_data, get_data_loader
from training.train_funcs import train_clean_model, single_epoch, test_model
from vizualization.tensors import imshow

from datum.classes.TrojanDataset import PoisonedDataset
from datum.classes.ApplyPatchTransform import ApplyPatchTransform

from devinterp.optim.sgld import SGLD
from devinterp.slt.llc import estimate_learning_coeff_with_summary

import copy

import matplotlib.pyplot as plt

from PIL import Image

from backdoor.poisoning import *

from torch.utils.data import random_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pull settings from a local .env file into the environment (CACHE_DIR is read from it later).
load_dotenv()

# Default figure size for every plot in this notebook.
# note: this cell may need to be re-run after creating a plot to take effect
plt.rcParams["figure.figsize"] = (15, 12)

In [3]:
# Choose the compute backend: CUDA first, then Apple MPS, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

In [4]:
# Dataset configuration shared by the cells below.
dataset_name = "cifar10"
batch_size = 32

# Download location, taken from the CACHE_DIR environment variable (None when it is unset).
# Hugging Face stores downloads at ~/.cache/huggingface/datasets by default.
cache_dir = os.environ.get("CACHE_DIR")

In [5]:
# Seed for the train/validation partition so the split is identical after a kernel restart.
SPLIT_SEED = 0

# Load the un-augmented training split once, then carve an 80/20 train/validation partition
# out of it. The full split keeps its own name so re-running this cell never splits a subset.
full_train_dataset = load_and_transform_data(dataset_name, 'train', augment=False, download_dir=cache_dir)
train_dataset, val_dataset = random_split(
    full_train_dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Un-augmented test split.
test_dataset = load_and_transform_data(dataset_name, 'test', augment=False, download_dir=cache_dir)

# The same two splits requested with poison=True and patch_transform=True.
# NOTE(review): poison_dataset is built from the whole 'train' split, so it presumably also
# contains the images held out in val_dataset - confirm this overlap is intended.
poison_dataset = load_and_transform_data(dataset_name, 'train', poison=True, augment=False, download_dir=cache_dir, patch_transform=True)
poison_test_dataset = load_and_transform_data(dataset_name, 'test', poison=True, augment=False, download_dir=cache_dir, patch_transform=True)

In [6]:
# Build one loader per dataset, all with the shared batch size and shuffle=True.
# The loaders are created in the order the names are listed.
(
    train_dataloader,
    val_dataloader,
    test_dataloader,
    poison_dataloader,
    poison_test_dataloader,
) = (
    get_data_loader(dataset, batch_size, shuffle=True)
    for dataset in (
        train_dataset,
        val_dataset,
        test_dataset,
        poison_dataset,
        poison_test_dataset,
    )
)

In [7]:
# Randomly initialised ResNet-50 (no pretrained weights), switched to eval mode and moved
# to `device`. `weights=None` is the current torchvision spelling of the deprecated
# `pretrained=False` and avoids the deprecation warning.
# NOTE(review): torchvision's default classification head is 1000-way while the dataset
# configured above is cifar10 - confirm whether `num_classes=10` was intended.
model = models.resnet50(weights=None).eval().to(device)



In [8]:
# Epoch budgets for initial training and for finetuning.
# Both are 0 here, so every `range(...)` training loop in this notebook runs zero times.
n_epochs = 0
finetune_epochs = 0

# Cross-entropy objective, optimised by SGD with momentum over the parameters of `model`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

In [9]:
# Initial training of model
train_losses = []
test_losses = []
checkpoints = []
for epoch in range(n_epochs):
    train_loss = single_epoch(model, "train", criterion, optimizer, train_dataloader, device)
    # "val" is the mode passed to single_epoch, but the data is test_dataloader, so this
    # value is a test loss: it is stored in test_losses and reported as such.
    # NOTE(review): val_dataloader is created earlier but never used - confirm whether
    # per-epoch evaluation was meant to run on it instead.
    val_loss = single_epoch(model, "val", criterion, optimizer, test_dataloader, device)
    train_losses.append(train_loss)
    test_losses.append(val_loss)
    # If enough space, can uncheck this one
    # checkpoints += [copy.deepcopy(model)]
    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Test Loss: {val_loss}")
# Keep a snapshot of the model as it stands after the loop.
checkpoints.append(copy.deepcopy(model))

# Independent copies of that model: one to finetune as the clean baseline, one to poison.
clean_model = copy.deepcopy(model)
poison_model = copy.deepcopy(model)


In [10]:
# finetune clean model:
# `optimizer` was built from `model.parameters()`, while `clean_model` is a deep copy with
# its own parameter tensors. Passing the shared optimizer would therefore not update
# `clean_model` (assuming single_epoch steps the optimizer it is given), so this run gets
# its own optimizer with the same hyperparameters.
clean_optimizer = optim.SGD(clean_model.parameters(), lr=0.001, momentum=0.9)
clean_finetune_loss = []
for epoch in range(finetune_epochs):
    # The clean baseline trains on the clean training split. Using poison_dataloader here
    # would give it exactly the same data as the poisoned run.
    loss = single_epoch(clean_model, "train", criterion, clean_optimizer, train_dataloader, device)
    clean_finetune_loss.append(loss)
    # Report this epoch's loss, not `train_loss` left over from the initial training loop.
    print(f"Epoch {epoch+1}, Train Loss: {loss} on clean finetuning")

In [11]:
# finetune poisoned model:
# `optimizer` was built from `model.parameters()`, while `poison_model` is a deep copy with
# its own parameter tensors. Passing the shared optimizer would therefore not update
# `poison_model` (assuming single_epoch steps the optimizer it is given), so this run gets
# its own optimizer with the same hyperparameters.
poison_optimizer = optim.SGD(poison_model.parameters(), lr=0.001, momentum=0.9)
poison_finetune_loss = []
for epoch in range(finetune_epochs):
    poison_loss = single_epoch(poison_model, "train", criterion, poison_optimizer, poison_dataloader, device)
    poison_finetune_loss.append(poison_loss)
    # Report this epoch's loss, not `train_loss` left over from the initial training loop.
    print(f"Epoch {epoch+1}, Train Loss: {poison_loss} on poison finetuning")

In [12]:
# Evaluate the clean model on the clean test loader and show the result as the cell output.
# NOTE(review): the recorded run of this cell ended in
# "AttributeError: 'str' object has no attribute 'to'" after printing `img` / `labels`.
# Presumably the batches are dicts being iterated by key - confirm in test_model / get_data_loader.
clean_test_result = test_model(clean_model, test_dataloader, criterion, device)
clean_test_result

(tensor([[[0.3823, 0.3823, 0.3823,  ..., 1.6153, 1.6153, 1.6153],
         [0.3823, 0.3823, 0.3823,  ..., 1.6153, 1.6153, 1.6153],
         [0.3823, 0.3823, 0.3823,  ..., 1.6153, 1.6153, 1.6153],
         ...,
         [2.2489, 2.2489, 2.2489,  ..., 1.3413, 1.3413, 1.3413],
         [2.2489, 2.2489, 2.2489,  ..., 1.3413, 1.3413, 1.3413],
         [2.2489, 2.2489, 2.2489,  ..., 1.3413, 1.3413, 1.3413]],

        [[0.5903, 0.5903, 0.5903,  ..., 1.8508, 1.8508, 1.8508],
         [0.5903, 0.5903, 0.5903,  ..., 1.8508, 1.8508, 1.8508],
         [0.5903, 0.5903, 0.5903,  ..., 1.8508, 1.8508, 1.8508],
         ...,
         [2.4286, 2.4286, 2.4286,  ..., 1.4832, 1.4832, 1.4832],
         [2.4286, 2.4286, 2.4286,  ..., 1.4832, 1.4832, 1.4832],
         [2.4286, 2.4286, 2.4286,  ..., 1.4832, 1.4832, 1.4832]],

        [[0.9494, 0.9494, 0.9494,  ..., 2.0474, 2.0474, 2.0474],
         [0.9494, 0.9494, 0.9494,  ..., 2.0474, 2.0474, 2.0474],
         [0.9494, 0.9494, 0.9494,  ..., 2.0474, 2.0474, 2

img
labels


AttributeError: 'str' object has no attribute 'to'

In [None]:
# Evaluate the poisoned model on the poisoned test loader and show the result as the cell output.
# NOTE(review): the recorded run of this cell failed inside a DataLoader worker while collating a
# batch ("Trying to resize storage that is not resizable"). Presumably the samples of
# poison_test_dataset do not all share one tensor shape - confirm in the dataset / patch transform.
poison_test_result = test_model(poison_model, poison_test_dataloader, criterion, device)
poison_test_result

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/envs/test-env/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/opt/conda/envs/test-env/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/opt/conda/envs/test-env/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 277, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/opt/conda/envs/test-env/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 144, in collate
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/opt/conda/envs/test-env/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 144, in <listcomp>
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/opt/conda/envs/test-env/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 121, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/opt/conda/envs/test-env/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 173, in collate_tensor_fn
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable


In [None]:
# Plot train and test loss

# One x position per recorded epoch, so the axis always matches the loss lists even if
# n_epochs was edited after the training cell ran.
epochs = list(range(len(train_losses)))
plt.plot(epochs, train_losses, label='Train')
plt.plot(epochs, test_losses, label='Test')
plt.xlabel('Training epochs')
plt.ylabel('Loss')
# Title follows the configured dataset (cifar10) instead of a hard-coded "MNIST".
plt.title(f'Training and test loss for {dataset_name} model')
plt.legend()
plt.show()

In [None]:
# Plot the finetuning loss of the clean and the backdoored model.

# Both loss lists get one entry per finetuning epoch, so the x axis must be built from
# finetune_epochs. Using n_epochs breaks the plot whenever the two budgets differ.
epochs = list(range(finetune_epochs))
plt.plot(epochs, clean_finetune_loss, label='Clean Model')
plt.plot(epochs, poison_finetune_loss, label='Backdoored Model')
plt.xlabel('Finetuning epochs')
plt.ylabel('Loss')
# Title follows the configured dataset (cifar10) instead of a hard-coded "MNIST".
plt.title(f'Loss during finetuning for {dataset_name} model')
plt.legend()
plt.show()