# Install, Paths and Parameters

In [21]:
import os
from pathlib import Path
import getpass
import numpy as np
import time
import torch
from torch import nn
from tqdm import tqdm
import random
import sys

# Allow the `src.*` imports below to resolve when the notebook is run from inside the
# project directory: make the current and the parent directory importable.
# Only missing entries are added, so re-running this cell does not pile up duplicates.
sys.path.extend(p for p in ('.', '..') if p not in sys.path)

# local
from src.helpers.helpers import get_random_indexes, get_random_classes
from src.model.dino_model import get_dino, ViTWrapper
from src.model.data import create_loader

# Custom imports
import torchattacks
from torch.utils.tensorboard import SummaryWriter
from torchattacks import *
import torch.optim as optim

# Reproducibility: seed every random number generator used in this notebook
# (NumPy, PyTorch and Python's `random`) with the same value.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)

# Account running the notebook. Note that it is not used to build the paths
# below: the scratch directory is hard-coded.
username = getpass.getuser()

# Data root on the cluster scratch space and the log directory underneath it.
DATA_PATH = Path('/') / 'cluster' / 'scratch' / 'thobauma' / 'dl_data'
LOG_PATH = DATA_PATH / 'logs'

# Validation split of the 'ori_data' set: label file and image directory.
ORI_PATH = DATA_PATH / 'ori_data' / 'validation'
ORI_LABEL_PATH = ORI_PATH / 'correct_labels.txt'
ORI_IMAGES_PATH = ORI_PATH / 'images'

In [11]:
# Device that batches are moved to in the training and evaluation loops.
DEVICE = 'cuda'

# Number of samples per batch requested from the data loader.
BATCH_SIZE = 1

# Classes used for this run, as chosen by the project helper.
CLASS_SUBSET = get_random_classes()

In [12]:
#!python ../setup/collect_env.py

# Import DINO
Official repo: https://github.com/facebookresearch/dino

In [13]:
# Build the DINO backbone (`model`) together with a linear classifier
# (`base_linear_classifier`). According to the log printed below, no checkpoint path is
# supplied, so the helper falls back to the reference pretrained backbone and linear weights.
model, base_linear_classifier = get_dino()

Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
Model vit_small built.
Embed dim 1536
We load the reference pretrained linear weights.


# Load data

In [14]:
# Loader over the images in ORI_IMAGES_PATH with labels from ORI_LABEL_PATH, built for
# CLASS_SUBSET with BATCH_SIZE samples per batch. The third positional argument is left
# as None; its meaning is defined by `create_loader`. Each batch unpacks into three
# items (images, labels and a third value that the training loop ignores).
ori_loader = create_loader(ORI_IMAGES_PATH, ORI_LABEL_PATH, None, CLASS_SUBSET, BATCH_SIZE, is_adv_training=True)
# NOTE: evaluation reuses the training loader, so the accuracies printed during training
# are measured on the training data, not on a held-out set.
test_loader = ori_loader # TODO

# Adversarial training

#### Define a custom linear classifier

In [15]:
# Classification head that is trained on top of the frozen backbone features
from torch import nn
class LinearClassifier(nn.Module):
    """Two-layer head (Linear -> ReLU -> Linear) applied to flattened features.

    Args:
        dim: number of input features per sample after flattening.
        num_labels: number of output classes.
        hidden_size: width of the hidden layer.
    """
    def __init__(self, dim, num_labels=1000, hidden_size=512):
        super().__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, hidden_size)
        self.linear2 = nn.Linear(hidden_size, num_labels)

        # Both layers start from normally distributed weights (std 0.01) and zero biases.
        for layer in (self.linear, self.linear2):
            layer.weight.data.normal_(mean=0.0, std=0.01)
            layer.bias.data.zero_()

        self.relu = nn.ReLU()

    def forward(self, x):
        """Flatten each sample of `x` and return scores of shape (batch, num_labels)."""
        flat = x.view(x.shape[0], -1)
        hidden = self.relu(self.linear(flat))
        return self.linear2(hidden)

# Fresh head for the selected class subset. Its input width is taken from the reference
# linear classifier returned by get_dino(), so it accepts the same backbone features.
# It is placed on DEVICE (rather than via a hard-coded .cuda()) so that the device is
# controlled from the configuration cell, like the batches in the training loop.
linear_classifier = LinearClassifier(
    base_linear_classifier.linear.in_features,
    num_labels=len(CLASS_SUBSET),
).to(DEVICE)

#### Create a custom Torch wrapper for the model so that it can be passed to the attack library

In [16]:
# Combine the DINO backbone and the new classification head into a single module, so
# that one forward call maps a batch of images to class predictions.
train_model = ViTWrapper(model, linear_classifier)
# Switch the wrapper to its training configuration (project helper; presumably it keeps
# the backbone frozen and trains only the head — TODO confirm against ViTWrapper).
train_model.set_weights_for_training()

#### Adversarial training

In [23]:
# Define the attack used for adversarial training.
# Hyperparameters from Section 5 of https://arxiv.org/pdf/1812.03411.pdf, where eps=16 and
# alpha=1 are pixel-intensity units on the 0-255 scale. torchattacks expects images scaled
# to [0, 1], so both values are divided by 255 here; with the raw values the perturbation
# budget would cover the whole pixel range.
# NOTE(review): this assumes the loader yields un-normalised images in [0, 1] — confirm.
PGD_EPS = 16 / 255
PGD_ALPHA = 1 / 255
PGD_STEPS = 15

# The attack must target the full classifier (backbone + head): PGD maximises the
# cross-entropy between the model output and the labels, so it needs the class scores
# produced by `train_model`, not the output of the bare backbone `model`.
train_attack = torchattacks.PGD(train_model, eps=PGD_EPS, alpha=PGD_ALPHA, steps=PGD_STEPS)

In [20]:
# Training configuration
NUM_EPOCHS = 5  # number of full passes over the training loader

# Only the parameters exposed as `train_model.linear_layer` are handed to the optimizer.
optimizer = optim.Adam(train_model.linear_layer.parameters(), lr=1e-3)

# Criterion applied to the model predictions and the integer class labels.
loss = nn.CrossEntropyLoss()

In [22]:
# Load the TensorBoard notebook extension (IPython magic, only valid inside a notebook).
%load_ext tensorboard
# Summary writer that stores its event files under LOG_PATH.
writer = SummaryWriter(LOG_PATH)
# Print NumPy arrays with 4 digits of precision.
np.set_printoptions(precision=4)

In [None]:
# Run an evaluation pass every EVAL_EVERY training batches.
EVAL_EVERY = 25


def _evaluate_accuracy(eval_model, attack, loader, device):
    """Count correct predictions on clean and on adversarial versions of `loader`.

    Args:
        eval_model: module mapping a batch of images to class scores.
        attack: callable `attack(images, labels)` returning adversarial images.
        loader: iterable yielding `(images, labels, _)` batches.
        device: device the batches are moved to.

    Returns:
        Tuple `(correct_clean, correct_adv)` of plain Python ints.
    """
    correct_clean = 0
    correct_adv = 0

    for imgs, labels, _ in loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            clean_predictions = eval_model(imgs)
        # .item() turns the count into a Python int, so the totals print as numbers
        # rather than as CUDA tensors.
        correct_clean += (clean_predictions.argmax(1) == labels).sum().item()

        # The attack needs gradients w.r.t. its input, so it runs outside no_grad().
        adv_samples = attack(imgs, labels).to(device)

        with torch.no_grad():
            adv_predictions = eval_model(adv_samples)
        correct_adv += (adv_predictions.argmax(1) == labels).sum().item()

    return correct_clean, correct_adv


for epoch in range(NUM_EPOCHS):

    for i, (batch_images, batch_labels, _) in enumerate(tqdm(ori_loader)):

        # Train on adversarial examples crafted for the current batch.
        X = train_attack(batch_images, batch_labels).to(DEVICE)
        Y = batch_labels.to(DEVICE)

        predictions = train_model(X)
        cost = loss(predictions, Y)

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        if (i + 1) % EVAL_EVERY == 0:
            print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], batch [{i+1}/{len(ori_loader)}], Loss: {cost.item()}')

            # Switch the model to its evaluation configuration for the accuracy pass
            # and back to the training configuration afterwards.
            train_model.set_weights_for_testing()
            correct_clean, correct_adv = _evaluate_accuracy(train_model, train_attack, test_loader, DEVICE)
            train_model.set_weights_for_training()

            num_test = len(test_loader.dataset)
            # Guard against an empty test set instead of dividing by zero.
            clean_acc = correct_clean / num_test if num_test else float('nan')
            adv_acc = correct_adv / num_test if num_test else float('nan')

            print(f"Clean accuracy [{correct_clean}/{num_test}] = {clean_acc:.4f}")
            print(f"Adversarial accuracy [{correct_adv}/{num_test}] = {adv_acc:.4f}\n")