In [1]:
import sys
sys.path.append('../src')

import os
from dotenv import load_dotenv

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

from datum.handler import load_and_transform_data, get_data_loader
from training.train_funcs import train_clean_model, single_epoch
from vizualization.tensors import imshow

from datum.classes.TrojanDataset import PoisonedDataset
from datum.classes.ApplyPatchTransform import ApplyPatchTransform

from devinterp.optim.sgld import SGLD
from devinterp.slt.llc import estimate_learning_coeff_with_summary

import copy

import matplotlib.pyplot as plt

from PIL import Image

from backdoor.poisoning import *

In [2]:
patch_path = "PATH_TO_IMAGE"  # Placeholder: replace with the path to your patch image (passed to ApplyPatchTransform below)


In [3]:
# Load variables from a local .env file into the process environment
# (CACHE_DIR is read from the environment later in the notebook).
load_dotenv()

# Default figure size as (width, height).
# note: this cell may need to be re-run after creating a plot to take effect
plt.rcParams["figure.figsize"] = (15, 12)

In [4]:
# Choose the compute device: CUDA first, then Apple MPS, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"

In [5]:
# Dataset identifier and mini-batch size used by the cells below.
dataset_name, batch_size = 'cifar10', 32

# Download directory, read from the CACHE_DIR environment variable (None when unset).
# Hugging Face stores downloads at ~/.cache/huggingface/datasets by default.
cache_dir = os.environ.get("CACHE_DIR")

In [6]:
# Load the 'train' and 'test' splits (in that order) with augmentation disabled.
train_dataset, test_dataset = (
    load_and_transform_data(dataset_name, split_name, augment=False, download_dir=cache_dir)
    for split_name in ('train', 'test')
)

Downloading data:   0%|          | 0.00/120M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [7]:
# Build one loader per split; both are created with shuffle=True.
train_dataloader, test_dataloader = (
    get_data_loader(split_dataset, batch_size, shuffle=True)
    for split_dataset in (train_dataset, test_dataset)
)

In [8]:
# Build the patch transform from the patch image path and an example position.
# NOTE(review): (100, 100) is only an example value -- confirm it lies inside
# the images as the transform sees them.
apply_patch_transform = ApplyPatchTransform(patch_path=patch_path, position=(100, 100))

In [9]:
# Training split loaded again, this time with poison=True and the patch transform.
poison_dataset = load_and_transform_data(
    dataset_name,
    'train',
    poison=True,
    augment=False,
    download_dir=cache_dir,
    patch_transform=apply_patch_transform,
)

In [10]:
# Loader over the poisoned training data, same batch size as the clean loaders.
poison_dataloader = get_data_loader(
    poison_dataset,
    batch_size,
    shuffle=True,
)


In [11]:
# ResNet-50 without pretrained weights, moved to the selected device.
# `weights=None` is the current torchvision spelling of the deprecated
# `pretrained=False` argument (same behaviour, no deprecation warning).
# NOTE(review): the model is left in eval mode here, as before -- confirm that
# `single_epoch` switches to train mode for its "train" phase.
# NOTE(review): the default ResNet-50 head has 1000 outputs while the dataset is
# cifar10 -- consider `num_classes=10` if nothing depends on the current
# checkpoint shape.
model = models.resnet50(weights=None).eval().to(device)



In [12]:
# Schedule: number of clean-training epochs, then poisoned fine-tuning epochs.
n_epochs, finetune_epochs = 20, 2

# Cross-entropy objective, optimised with SGD plus momentum over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [13]:
# Train on the clean data, recording the train and test loss of every epoch.
train_losses, test_losses, checkpoints = [], [], []
for epoch in range(n_epochs):
    train_loss = single_epoch(model, "train", criterion, optimizer, train_dataloader, device)
    test_loss = single_epoch(model, "test", criterion, optimizer, test_dataloader, device)
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    # To keep one checkpoint per epoch (needs enough memory), uncomment:
    # checkpoints.append(copy.deepcopy(model))
    print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Test Loss: {test_loss}")

# Keep a copy of the model as it stands after clean training.
checkpoints.append(copy.deepcopy(model))


In [None]:
# Fine-tune the model on the poisoned data, recording the loss of every epoch.
finetune_loss = []
for epoch in range(finetune_epochs):
    poison_loss = single_epoch(model, "train", criterion, optimizer, poison_dataloader, device)
    finetune_loss.append(poison_loss)
    # Report this epoch's fine-tuning loss, not the last clean-training loss.
    print(f"Epoch {epoch+1}, Train Loss: {poison_loss} on poison finetuning")

In [None]:
# Save the model weights (state_dict only).
# torch.save does not create missing parent directories, so create the target
# directory first rather than failing after a long training run.
os.makedirs('../models', exist_ok=True)
torch.save(model.state_dict(), '../models/trained_model.pth')

In [None]:
# Plot train and test loss per clean-training epoch.

epochs = list(range(n_epochs))
plt.plot(epochs, train_losses, label='Train')
plt.plot(epochs, test_losses, label='Test')
plt.xlabel('Training epochs')
plt.ylabel('Loss')
# Title follows the configured dataset instead of a hard-coded (wrong) name.
plt.title(f'Training and test loss for {dataset_name} model')
plt.legend()
plt.show()

In [None]:
# Plot the loss per poisoned fine-tuning epoch.
# The x-axis must have one point per fine-tuning epoch: `finetune_loss` holds
# `finetune_epochs` values, not `n_epochs`, and a length mismatch makes
# plt.plot raise.
epochs = list(range(finetune_epochs))
plt.plot(epochs, finetune_loss, label='Fine-tuning')
plt.xlabel('Fine-tuning epochs')
plt.ylabel('Loss')
# Title follows the configured dataset instead of a hard-coded (wrong) name.
plt.title(f'Loss during finetuning for {dataset_name} model')
plt.legend()
plt.show()