# Install, Paths and Parameters

In [17]:
import os
from pathlib import Path
import getpass
import numpy as np
import pandas as pd
import time
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from tqdm import tqdm
import random
import sys
from collections import defaultdict

from torch.utils.tensorboard import SummaryWriter

# allow imports when running script from within project dir
[sys.path.append(i) for i in ['.', '..']]

# local
# from src.helpers.helpers import get_random_indexes, get_random_classes
from src.model.dino_model import get_dino, ViTWrapper
from src.model.data import *
from src.model.train import *
from src.model.multihead_model import *
from src.helpers.helpers import create_paths

from torchattacks import *
from sklearn import preprocessing

# seed
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

username = getpass.getuser()
DATA_PATH = Path('/cluster/scratch/thobauma/data/')

ORI_PATH = Path(DATA_PATH, 'ori')
ORI_PATH_VALIDATION = Path(ORI_PATH, 'validation')
ORI_PATH_VALIDATION_LABELS = Path(ORI_PATH,'labels.csv')
ORI_PATH_VALIDATION_IMAGES = Path(ORI_PATH,'images')

CLEAN_FILTERED_PATH = ORI_PATH = Path(DATA_PATH, 'ori', 'filtered')

MODELS_PATH = Path(DATA_PATH, 'models')

BASE_ADV_PATH = Path(DATA_PATH, 'adv') # in tensors

# BASE_POSTHOC_PATH = TODO
# POSTHOC_MODELS_PATH = TODO

ADV_DATASETS = ['pgd_0001', 'pgd_003', 'pgd_01']

DATASETS = [*ADV_DATASETS, 'ori']

#### NOT SURE WHAT THIS IS FOR
# LINEAR_CLASSIFIER_EVAL_PATH = Path(MAX_PATH, 'linear_classifier_evaluation')
# LINEAR_CLASSIFIER_EVAL_PATH.mkdir(parents=True, exist_ok=True)

# MULTIHEAD_EVAL_PATH = Path(MAX_PATH, 'multihead_eval')
# MULTIHEAD_EVAL_PATH.mkdir(parents=True, exist_ok=True)
#### NOT SURE WHAT THIS IS FOR

# LINEAR_CLASSIFIER_MODELS_PATH = Path(MAX_PATH, 'linear_classifier_models') # what was stored here?

In [9]:
INDEX_SUBSET = None
NUM_WORKERS= 0
PIN_MEMORY=True

BATCH_SIZE = 64
EPOCHS = 2
DEVICE = 'cuda'

# Load Data

## Adversarial Data

In [3]:

loader_dict = defaultdict(dict)

for attack in ADV_DATASETS:
    image_path_train = Path(BASE_ADV_PATH, attack, "train", "images")
    label_file_train = str(Path(BASE_ADV_PATH, attack, "train")) + "/labels.csv"
    
    image_path_validation = Path(BASE_ADV_PATH, attack, "validation", "images")
    label_file_validation = str(Path(BASE_ADV_PATH, attack, "validation")) + "/labels.csv"
    
    adv_train_dataset = TensorImageDataset(image_path_train, 
                                           label_file_train, None)
        
    adv_val_dataset = TensorImageDataset(image_path_validation, 
                                         label_file_validation, None)


    loader_dict[attack]["train"] = DataLoader(adv_train_dataset, 
                                         batch_size=BATCH_SIZE, 
                                         num_workers=NUM_WORKERS, 
                                         pin_memory=PIN_MEMORY, 
                                         shuffle=True)

    loader_dict[attack]["validation"] = DataLoader(adv_val_dataset, 
                                                     batch_size=BATCH_SIZE, 
                                                     num_workers=NUM_WORKERS, 
                                                     pin_memory=PIN_MEMORY, 
                                                     shuffle=False)


In [10]:
loader_dict

defaultdict(dict,
            {'pgd_0001': {'train': <torch.utils.data.dataloader.DataLoader at 0x2b615ee78580>,
              'validation': <torch.utils.data.dataloader.DataLoader at 0x2b615ee78a30>},
             'pgd_003': {'train': <torch.utils.data.dataloader.DataLoader at 0x2b615ee788b0>,
              'validation': <torch.utils.data.dataloader.DataLoader at 0x2b615ee787c0>},
             'pgd_01': {'train': <torch.utils.data.dataloader.DataLoader at 0x2b615ee78880>,
              'validation': <torch.utils.data.dataloader.DataLoader at 0x2b62315c83d0>}})

## Clean Data

In [25]:
image_path_train = Path(CLEAN_FILTERED_PATH, "train", "images")
label_file_train = str(Path(CLEAN_FILTERED_PATH, "train")) + "/labels.csv"

image_path_validation = Path(CLEAN_FILTERED_PATH, "validation", "images")
label_file_validation = str(Path(CLEAN_FILTERED_PATH, "validation")) + "/labels.csv"

clean_train_dataset = AdvTrainingImageDataset(image_path_train, 
                                              label_file_train, 
                                              ORIGINAL_TRANSFORM)

clean_val_dataset = AdvTrainingImageDataset(image_path_validation, 
                                            label_file_validation, 
                                            ORIGINAL_TRANSFORM)

clean_train_loader = DataLoader(clean_train_dataset, 
                                 batch_size=BATCH_SIZE, 
                                 num_workers=NUM_WORKERS, 
                                 pin_memory=PIN_MEMORY, 
                                 shuffle=True)

clean_val_loader = DataLoader(clean_val_dataset, 
                                 batch_size=BATCH_SIZE, 
                                 num_workers=NUM_WORKERS, 
                                 pin_memory=PIN_MEMORY, 
                                 shuffle=False)

## Classifier Models (Post-Hoc and Ensemble)

In [4]:
class LinearClassifier(nn.Module):
    """Linear layer to train on top of frozen features"""
    def __init__(self, dim, num_labels=1000):
        super(LinearClassifier, self).__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        self.linear.weight.data.normal_(mean=0.0, std=0.01)
        self.linear.bias.data.zero_()

    def forward(self, x):
        # flatten
        x = x.view(x.size(0), -1)

        # linear layer
        return self.linear(x)


In [5]:
# Linear Binary Classifier
class LinearBC(nn.Module):
    def __init__(self, input_shape):
        self.num_labels = 2
        super(LinearBC,self).__init__()
        self.fc1 = nn.Linear(input_shape,2)

    def forward(self, x):
        x = self.fc1(x)
        return x

# Import DINO
Official repo: https://github.com/facebookresearch/dino

In [6]:
model, linear_classifier = get_dino()

Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
Model vit_small built.
Embed dim 1536
We load the reference pretrained linear weights.


# Train various classifiers on all adversarial datasets

In [14]:
for attack, loaders in loader_dict.items():
    if attack == "pgd_0001": 
        continue
    pstr = "#"*25 + f''' Training classifier for {attack} ''' + "#"*25
    print(pstr)
    start = time.time()
    # Initialise linear classifier
    adv_linear_classifier = LinearClassifier(linear_classifier.linear.in_features, num_labels=9)
    adv_linear_classifier = adv_linear_classifier.cuda()
    
    loggers = train(model, 
                    adv_linear_classifier, 
                    loaders["train"], 
                    loaders["validation"], 
                    log_dir=Path(MODELS_PATH, attack + "_ensemble"),
                    tensor_dir=None, 
                    optimizer=None, 
                    adversarial_attack=None,
                    criterion=nn.CrossEntropyLoss(),
                    epochs=EPOCHS, 
                    val_freq=1, 
                    batch_size=BATCH_SIZE,  
                    lr=0.001, 
                    to_restore = {"epoch": 0, "best_acc": 0.}, 
                    n=4, 
                    avgpool_patchtokens=False, 
                    show_image=False)
    
    print('Total elapsed time (sec): %.2f' % (time.time() - start))
    
    # Save adversarial Classifier
    save_path = Path(MODELS_PATH, attack + "_ensemble")
    
    if not os.path.isdir(save_path):
        os.makedirs(save_path)
        
    save_file_log = f"log_{attack}.pt"
    torch.save(loggers, Path(save_path, save_file_log))
    
    print(f'''Finished training linear classifier on {attack}''')

######################### Training classifier for pgd_003 #########################
Epoch: [0]  [   0/1463]  eta: 0:19:05  lr: 0.000250  loss: 2.853913 (2.853913)  time: 0.782746  data: 0.643632  max mem: 572
Epoch: [0]  [  20/1463]  eta: 0:16:54  lr: 0.000250  loss: 0.995577 (1.362123)  time: 0.699010  data: 0.583349  max mem: 572
Epoch: [0]  [  40/1463]  eta: 0:16:56  lr: 0.000250  loss: 0.252890 (0.826238)  time: 0.725927  data: 0.608607  max mem: 572
Epoch: [0]  [  60/1463]  eta: 0:16:30  lr: 0.000250  loss: 0.134853 (0.597236)  time: 0.688626  data: 0.570962  max mem: 572
Epoch: [0]  [  80/1463]  eta: 0:16:16  lr: 0.000250  loss: 0.092682 (0.473572)  time: 0.706992  data: 0.590781  max mem: 572
Epoch: [0]  [ 100/1463]  eta: 0:16:04  lr: 0.000250  loss: 0.067965 (0.394167)  time: 0.712202  data: 0.594474  max mem: 572
Epoch: [0]  [ 120/1463]  eta: 0:15:44  lr: 0.000250  loss: 0.052034 (0.338017)  time: 0.683044  data: 0.566564  max mem: 572
Epoch: [0]  [ 140/1463]  eta: 0:15:24  lr

## Evaluation

### Evaluate on all adversarial datasets

In [16]:
logger_dict = defaultdict(dict)
for attack in ADV_DATASETS:
    pstr = "#"*25 + f''' evaluating adv_classifier trained on {attack} ''' + "#"*25
    print(pstr)
    adv_classifier = LinearClassifier(linear_classifier.linear.in_features, num_labels=9)
    adv_classifier.to(DEVICE)

    # load from checkpoint
    log_dir = Path(MODELS_PATH, attack + "_ensemble")
    to_restore={'epoch': 1}
    utils.restart_from_checkpoint(Path(log_dir, "checkpoint.pth.tar"),
                                  run_variables=to_restore,
                                  state_dict=adv_classifier)

    for applied_attack in ADV_DATASETS:
        print(">"*5 + f''' {applied_attack} dataset: {len(loader_dict[applied_attack]["validation"].dataset)} ''')
        info, logger = validate_network(model, 
                                       adv_classifier, 
                                       loader_dict[applied_attack]["validation"], 
                                       criterion=nn.CrossEntropyLoss(),
                                       tensor_dir=None, 
                                       adversarial_attack=None, 
                                       n=4, 
                                       avgpool_patchtokens=False, 
                                       path_predictions=Path(log_dir, 'eval_c_'+attack+'_d_'+applied_attack+'.csv'),
                                       log_interval = 10)
        logger_dict[attack][applied_attack] = logger
        print('\n')

######################### evaluating adv_classifier trained on pgd_0001 #########################
Found checkpoint at /cluster/scratch/thobauma/data/models/pgd_0001_ensemble/checkpoint.pth.tar
=> loaded 'state_dict' from checkpoint '/cluster/scratch/thobauma/data/models/pgd_0001_ensemble/checkpoint.pth.tar' with msg <All keys matched successfully>
>>>>> pgd_0001 dataset: 3600 
saving predictions to: /cluster/scratch/thobauma/data/models/pgd_0001_ensemble/eval_c_pgd_0001_d_pgd_0001.csv
Test:  [ 0/57]  eta: 0:00:37  loss: 0.018175 (0.018175)  acc1: 98.437500 (98.437500)  acc5: 100.000000 (100.000000)  time: 0.650984  data: 0.534325  max mem: 609
Test:  [10/57]  eta: 0:00:33  loss: 0.035731 (0.063089)  acc1: 98.437500 (98.579545)  acc5: 100.000000 (99.857955)  time: 0.710343  data: 0.591147  max mem: 610
Test:  [20/57]  eta: 0:00:24  loss: 0.111997 (0.189978)  acc1: 95.312500 (93.973214)  acc5: 100.000000 (99.776786)  time: 0.671126  data: 0.550544  max mem: 610
Test:  [30/57]  eta: 0:00:

### Evaluate Adversarial Linear Classifiers on Clean

In [26]:
clean_logger_dict = defaultdict(dict)
for attack in ADV_DATASETS:
    pstr = "#"*25 + f''' evaluating adv_classifier trained on {attack} on Clean''' + "#"*25
    print(pstr)
    adv_classifier = LinearClassifier(linear_classifier.linear.in_features, num_labels=9)
    adv_classifier.to(DEVICE)

    # load from checkpoint
    log_dir = Path(MODELS_PATH, attack + "_ensemble")
    to_restore={'epoch': 1}
    utils.restart_from_checkpoint(Path(log_dir, "checkpoint.pth.tar"),
                                  run_variables=to_restore,
                                  state_dict=adv_classifier)

    print(">"*5 + f''' {applied_attack} dataset: {len(clean_val_loader.dataset)} ''')
    info, logger = validate_network(model, 
                                   adv_classifier, 
                                   clean_val_loader, 
                                   criterion=nn.CrossEntropyLoss(),
                                   tensor_dir=None, 
                                   adversarial_attack=None, 
                                   n=4, 
                                   avgpool_patchtokens=False, 
                                   path_predictions=Path(log_dir, 'eval_c_'+attack+'_d_'+applied_attack+'.csv'),
                                   log_interval = 10)
        clean_logger_dict[attack] = logger
        print('\n')

######################### evaluating adv_classifier trained on pgd_0001 on clean data#########################
Found checkpoint at /cluster/scratch/thobauma/data/models/pgd_0001_ensemble/checkpoint.pth.tar
=> loaded 'state_dict' from checkpoint '/cluster/scratch/thobauma/data/models/pgd_0001_ensemble/checkpoint.pth.tar' with msg <All keys matched successfully>
saving predictions to: /cluster/scratch/thobauma/data/models/pgd_0001_ensemble/eval_c_pgd_0001_clean.csv
Test:  [ 0/57]  eta: 0:01:18  loss: 0.003434 (0.003434)  acc1: 100.000000 (100.000000)  acc5: 100.000000 (100.000000)  time: 1.382143  data: 1.242851  max mem: 610
Test:  [10/57]  eta: 0:01:09  loss: 0.008220 (0.025999)  acc1: 100.000000 (99.431818)  acc5: 100.000000 (100.000000)  time: 1.476518  data: 1.352885  max mem: 610
Test:  [20/57]  eta: 0:00:54  loss: 0.060744 (0.100284)  acc1: 98.437500 (96.949405)  acc5: 100.000000 (99.925595)  time: 1.485005  data: 1.366102  max mem: 610
Test:  [30/57]  eta: 0:00:38  loss: 0.071355

### Performance of clean classifier on adversarial data

In [29]:
some_dict = defaultdict(dict)

clean_classifier = LinearClassifier(linear_classifier.linear.in_features, num_labels=9)
clean_classifier.to(DEVICE)

clean_classifier.load_state_dict(torch.load(Path(DATA_PATH,'models','base_lin_clf', 'checkpoint.pth.tar'))["state_dict"])
clean_classifier.cuda()

for attack in ADV_DATASETS:
    pstr = "#"*25 + f''' evaluating clean classifier on {attack} data''' + "#"*25
    print(pstr)

    # load from checkpoint
    log_dir = Path(MODELS_PATH, attack + "_ensemble")
    to_restore={'epoch': 1}

    info, logger = validate_network(model, 
                                   clean_classifier, 
                                   loader_dict[attack]["validation"], 
                                   criterion=nn.CrossEntropyLoss(),
                                   tensor_dir=None, 
                                   adversarial_attack=None, 
                                   n=4, 
                                   avgpool_patchtokens=False, 
                                   path_predictions=Path(log_dir, 'eval_clean_classifier_'+attack+'.csv'),
                                   log_interval = 10)
                                    
    some_dict[attack] = logger
    print('\n')

######################### evaluating clean classifier on pgd_0001 data#########################
saving predictions to: /cluster/scratch/thobauma/data/models/pgd_0001_ensemble/eval_clean_classifier_pgd_0001.csv
Test:  [ 0/57]  eta: 0:00:47  loss: 0.058591 (0.058591)  acc1: 98.437500 (98.437500)  acc5: 100.000000 (100.000000)  time: 0.837618  data: 0.697621  max mem: 610
Test:  [10/57]  eta: 0:00:38  loss: 0.118617 (0.134464)  acc1: 96.875000 (96.022727)  acc5: 100.000000 (99.715909)  time: 0.822920  data: 0.702236  max mem: 610
Test:  [20/57]  eta: 0:00:30  loss: 0.198215 (0.300271)  acc1: 93.750000 (90.848214)  acc5: 100.000000 (99.627976)  time: 0.826395  data: 0.706558  max mem: 610
Test:  [30/57]  eta: 0:00:22  loss: 0.361949 (0.367304)  acc1: 89.062500 (89.868952)  acc5: 100.000000 (99.596774)  time: 0.835299  data: 0.716723  max mem: 610
Test:  [40/57]  eta: 0:00:14  loss: 0.320487 (0.375101)  acc1: 89.062500 (89.710366)  acc5: 100.000000 (99.657012)  time: 0.830645  data: 0.71423

### Clean classifier and clean dataset

In [33]:
what_dict = defaultdict(dict)

clean_classifier = LinearClassifier(linear_classifier.linear.in_features, num_labels=9)
clean_classifier.to(DEVICE)

clean_classifier.load_state_dict(torch.load(Path(DATA_PATH,'models','base_lin_clf', 'checkpoint.pth.tar'))["state_dict"])
clean_classifier.cuda()


pstr = "#"*25 + f''' evaluating clean classifier on clean data''' + "#"*25

# load from checkpoint
log_dir = Path(MODELS_PATH, attack + "_ensemble")
to_restore={'epoch': 1}

info, logger = validate_network(model, 
                               clean_classifier, 
                               clean_val_loader, 
                               criterion=nn.CrossEntropyLoss(),
                               tensor_dir=None, 
                               adversarial_attack=None, 
                               n=4, 
                               avgpool_patchtokens=False, 
                               path_predictions=Path(log_dir, 'eval_clean_classifier_clean_data.csv'),
                               log_interval = 10)
                                    
what_dict[attack] = logger
print('\n')

saving predictions to: /cluster/scratch/thobauma/data/models/pgd_01_ensemble/eval_clean_classifier_clean_data.csv
Test:  [ 0/57]  eta: 0:01:21  loss: 0.001938 (0.001938)  acc1: 100.000000 (100.000000)  acc5: 100.000000 (100.000000)  time: 1.422227  data: 1.282189  max mem: 610
Test:  [10/57]  eta: 0:01:07  loss: 0.003794 (0.016414)  acc1: 100.000000 (99.573864)  acc5: 100.000000 (100.000000)  time: 1.438976  data: 1.312541  max mem: 610
Test:  [20/57]  eta: 0:00:52  loss: 0.020681 (0.048217)  acc1: 98.437500 (98.214286)  acc5: 100.000000 (100.000000)  time: 1.420146  data: 1.293535  max mem: 610
Test:  [30/57]  eta: 0:00:37  loss: 0.044338 (0.064708)  acc1: 96.875000 (97.631048)  acc5: 100.000000 (100.000000)  time: 1.386508  data: 1.260545  max mem: 610
Test:  [40/57]  eta: 0:00:23  loss: 0.055858 (0.068185)  acc1: 98.437500 (97.599085)  acc5: 100.000000 (100.000000)  time: 1.397151  data: 1.276247  max mem: 610
Test:  [50/57]  eta: 0:00:09  loss: 0.059332 (0.066406)  acc1: 98.437500 

### Evaluate on full pipeline with post-hoc as multiplexer

In [None]:
# Load clean_classifier
clean_classifier = LinearClassifier(linear_classifier.linear.in_features, num_labels=9)
clean_classifier.to(DEVICE)

clean_classifier.load_state_dict(torch.load(Path(DATA_PATH,'models','base_lin_clf', 'checkpoint.pth.tar'))["state_dict"])
clean_classifier.cuda()

In [None]:
# Load posthoc
# Perform validation on clean dataset
log_dir = Path(MODELS_PATH, "ensemble")
log_dir.mkdir(parents=True, exist_ok=True)

for post_model in ADV_DATASETS:
    posthoc = LinearBC(1536)
    posthoc.cuda()
    to_restore={'epoch':3}
    utils.restart_from_checkpoint(Path(POST_HOC_PATH, "checkpoint.pth.tar"),
                                  run_variables=to_restore,
                                  state_dict=posthoc)
    
    for adv_model in attacks:
        adv_classifier = LinearClassifier(linear_classifier.linear.in_features, num_labels=9)
        adv_classifier.to(DEVICE)
        
        to_restore={'epoch': 1}
        
        utils.restart_from_checkpoint(Path(ADV_CLASSIFIER, "checkpoint.pth.tar"),
                                      run_variables=to_restore,
                                      state_dict=adv_classifier)
        
        for attack, loaders in loader_dict.items():
            
            pstr = "#"*25 + f''' Validating Posthoc: {post_model} and adv_classifier: {adv_model} on {attack} ''' + "#"*25
            print(pstr)
            
            log_dict, logger = validate_multihead_network(model, 
                                                          posthoc,
                                                          adv_classifier,
                                                          clean_classifier,
                                                          loader_dict[attack]["validation"], 
                                                          tensor_dir=None, 
                                                          adversarial_attack=None, 
                                                          n=4, 
                                                          avgpool=False,
                                                          path_predictions=Path(log_dir, 'ensemble_p_'+ post_model +'_c_'+adv_model+'_d_'+attack+'.csv'))
            
            # Save adversarial Classifier
            save_file_log = f"log_p_{post_model}_c_{adv_model}_d_{attack}.pt"
            torch.save(logger, Path(log_dir, save_file_log))