# Training Stable Diffusion (SD) on Poisoned Samples

## For Colab Users

In [None]:
import os

# Colab-only setup, step 1 of 2.
# Colab is detected by looking for 'google.colab' in the string form of the
# active IPython shell; outside Colab this whole cell is a no-op (apart from
# the import).
# In Colab it clones the repository, moves into it, replaces the installed
# torch / torchvision / torchaudio with builds from the PyTorch cu126 wheel
# index, and then installs the project's requirements.txt.
# After it finishes, restart the session and then run the next cell
# (presumably so the reinstalled torch build is the one that gets imported).
# NOTE(review): `os` is not used in this cell — looks like a leftover; confirm
# nothing else relies on it before removing.
if 'google.colab' in str(get_ipython()):
    !git clone https://github.com/zabibeau/nightshade-ml.git
    %cd nightshade-ml

    # Uninstall first so the --index-url install below does not keep the
    # preinstalled torch packages.
    !pip uninstall -y torch torchaudio torchvision
    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
    !pip install -r requirements.txt

In [None]:
# Once you've restarted your session, run this cell.
# Colab-only setup, step 2 of 2: move back into the cloned repository
# (presumably needed because a restart resets the working directory) and
# install the remaining dependencies: OpenAI's CLIP from GitHub, and lpips.
# NOTE(review): unlike the setup cell above, this cell has no Colab guard.
# `%cd nightshade-ml` assumes the repo was cloned into a `nightshade-ml`
# subdirectory of the current directory; skip or adjust this cell when
# running from a local checkout.
%cd nightshade-ml
!pip install git+https://github.com/openai/CLIP.git
!pip install lpips

## Imports

In [None]:
from py_files.perturbation_methods import fgsm_penalty, pgd_penalty, nightshade_penalty
from py_files.data_process import get_dataset, get_poisoned_dataset, create_mixed_dataset
from py_files.train_sd import train_model


## Train SD on Poisoned Samples

### Train one Stable Diffusion model per poisoned dataset (clean data mixed with each method's poisoned samples)

In [None]:
# Sample count(s) to load from each method's poisoned set. One poisoned
# dataset (and, in the training cell, one trained model) is produced per
# (method, count) pair.
num_poisoned = [300]

# Perturbation methods, keyed by the name of their directory under
# poisoned_images/. Only the keys are used in this notebook; the penalty
# functions are not called here.
# NOTE(review): the poisoned samples are read from pre-generated pickles,
# presumably produced elsewhere with these penalty functions; confirm.
methods = {
    'fgsm': fgsm_penalty,
    'pgd': pgd_penalty,
    'original': nightshade_penalty,
}

# Poisoned subsets keyed "<method>_<count>", e.g. "fgsm_300".
poisoned_datasets = {}

# Shared clean training data (train2014 captions/images).
# NOTE(review): the third argument (10000) presumably caps the number of clean
# samples loaded; confirm against get_dataset.
clean_dataset = get_dataset('annotations/captions_train2014.json', 'train2014', 10000)

# Only the method names are needed to locate each poisoned pickle directory,
# so iterate the keys rather than unpacking unused penalty functions.
for name in methods:
    for num in num_poisoned:
        poisoned_dataset = get_poisoned_dataset(f'poisoned_images/{name}/pickle', limit=num)
        poisoned_datasets[f"{name}_{num}"] = poisoned_dataset

# Mixed (clean + poisoned) datasets are not pre-built here; the training cell
# creates each one right before training on it.

In [None]:
# Fine-tune one Stable Diffusion model per poisoned dataset. Each model is
# trained on the shared clean dataset mixed with that dataset's poisoned
# samples. The mixed dataset is built inside the loop rather than all up
# front. The loop already yields each (name, dataset) pair, so no index and
# no second dict lookup are needed.
for name, poisoned_dataset in poisoned_datasets.items():
    mixed_dataset = create_mixed_dataset(clean_dataset, poisoned_dataset)
    # Output directory is keyed by the dataset name, e.g. output_models/fgsm_300.
    train_model(mixed_dataset, f'output_models/{name}', epochs=10, batch_size=2)