# Imports, Paths and Parameters

In [None]:
import os
from pathlib import Path
import getpass
import numpy as np
import time
import torch
import torchattacks
from torchattacks import *
from tqdm import tqdm
import random
import sys

# Allow `src.*` imports when the notebook is started from within the project
# directory (or from a direct subdirectory of it, via '..').
for _extra_path in ('.', '..'):
    # Skip paths already present so re-running the cell doesn't keep growing sys.path.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

# Garbage collection from previous runs: drop unreachable Python objects and
# release cached (unused) CUDA memory held by PyTorch's allocator.
import gc
gc.collect()
torch.cuda.empty_cache()

# local
from src.helpers.helpers import get_random_indexes, get_random_classes
from src.model.dino import get_dino
from src.model.data import create_loader

# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so runs are
# reproducible. torch.manual_seed seeds the generators of all devices.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Per-user data root on the cluster scratch space: /cluster/scratch/<user>/dl_data
username = getpass.getuser()
DATA_PATH = Path('/') / 'cluster' / 'scratch' / username / 'dl_data'

# Label file and folder of original (unperturbed) images, both under DATA_PATH.
ORG_LABEL_PATH = DATA_PATH / 'correct_labels.txt'
ORIGINAL_IMAGES_PATH = DATA_PATH / 'ori_data' / 'ImageNetClasses'

# Sample indexes to load, as chosen by the project helper get_random_indexes().
INDEX_SUBSET = get_random_indexes()

# Loader / execution settings.
BATCH_SIZE = 1
DEVICE = 'cuda'

In [None]:
#!python ../setup/collect_env.py

# Import DINO
Official repo: https://github.com/facebookresearch/dino

In [None]:
# Build the DINO-based model on DEVICE via the project helper `get_dino`.
# NOTE(review): later cells use `model.linear_layer` and call both `model(x)` and
# `model(x, False)`; confirm get_dino's return value supports both forms.
model = get_dino(DEVICE)

# Load data

In [None]:
# Loader over the original (clean) images and their labels, presumably limited to
# INDEX_SUBSET and batched with BATCH_SIZE -- see src.model.data.create_loader.
# NOTE(review): this is the only loader built in the notebook, but the adversarial
# training cell iterates `train_loader` / `test_loader`, which are never defined.
org_loader = create_loader(ORIGINAL_IMAGES_PATH, ORG_LABEL_PATH, INDEX_SUBSET, BATCH_SIZE)

# Adversarial training

In [None]:
from torch import optim  # optimiser module used by the training-configuration cell

# PGD attack that crafts the adversarial inputs, both for every training batch
# and for the periodic adversarial evaluation in the training loop.
# NOTE(review): torchattacks expects inputs in [0, 1]; eps=0.3 is a very large
# L-inf budget for ImageNet images (8/255 is more typical) -- confirm intent.
train_attack = torchattacks.PGD(model, eps=0.3, alpha=0.1, steps=12)

In [None]:
# Training configuration.
# Only `model.linear_layer` is put in train mode and only its parameters are
# handed to the optimizer, so the classification head is the only part updated.
NUM_EPOCHS = 5
model.linear_layer.train()
# `nn` is not imported anywhere in this notebook; reference the loss through torch.nn.
loss = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.linear_layer.parameters(), lr=0.001)

In [None]:
def _evaluate(loader, attack):
    """Count clean and adversarial top-1 hits of `model` over every batch of `loader`.

    Args:
        loader: iterable yielding ``(images, labels, _)`` triples, like the
            training loader used below.
        attack: torchattacks attack used to perturb each evaluation batch.

    Returns:
        ``(correct_clean, correct_adv, total)`` as Python ints, where ``total``
        is the number of labelled samples seen.
    """
    correct_clean = 0
    correct_adv = 0
    total = 0
    for imgs, labels, _ in loader:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        # Predictions only: no_grad avoids building an autograd graph per batch.
        # NOTE(review): the second positional argument to `model` is kept as in the
        # original; its meaning is defined by get_dino -- confirm.
        with torch.no_grad():
            predictions = model(imgs, False)
        correct_clean += (predictions.argmax(1) == labels).sum().item()

        # The attack needs input gradients, so it must run outside no_grad.
        adv_samples = attack(imgs, labels).to(DEVICE)
        with torch.no_grad():
            predictions = model(adv_samples, False)
        correct_adv += (predictions.argmax(1) == labels).sum().item()

        total += labels.size(0)
    return correct_clean, correct_adv, total


# NOTE(review): `train_loader` and `test_loader` are not created anywhere in this
# notebook (only `org_loader` is, under "Load data"); define them before running.
for epoch in range(NUM_EPOCHS):

    for i, (batch_images, batch_labels, _) in enumerate(train_loader):
        batch_images = batch_images.to(DEVICE)
        Y = batch_labels.to(DEVICE)
        # Adversarial training: the head is fit on PGD-perturbed inputs.
        X = train_attack(batch_images, Y).to(DEVICE)

        predictions = model(X)
        cost = loss(predictions, Y)

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        if i % 25 == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], batch [{i+1}/{len(train_loader)}], Loss: {cost.item()}')

            # Periodic accuracy check on clean and adversarial test samples.
            correct_clean, correct_adv, total = _evaluate(test_loader, train_attack)
            if total == 0:
                print("Test loader is empty; skipping accuracy report.\n")
            else:
                print(f"Clean accuracy [{correct_clean}/{total}] = {correct_clean/total}")
                print(f"Adversarial accuracy [{correct_adv}/{total}] = {correct_adv/total}\n")