# Imports, Paths and Parameters

In [25]:
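# NOTE: no autoreload extension is loaded here; run `%load_ext autoreload` and `%autoreload 2` to pick up edits to src/ without restarting the kernel.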
import os
from pathlib import Path
import getpass
import numpy as np
import pandas as pd
import time
import math
from collections import defaultdict

import torch
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
import random
import sys
from torch.utils.data import random_split
from matplotlib import pyplot as plt

# Allow `src.*` imports when running from within the project dir (or one level below it).
# Guarded so that re-running this cell does not keep appending duplicates to sys.path.
for _extra_path in ('.', '..'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
del _extra_path

# local
from src.helpers.helpers import get_random_indexes, get_random_classes
from src.model.dino_model import get_dino
from src.model.train import *
from src.model.data import *

# Seed every RNG this notebook touches (Python, PyTorch, NumPy) for reproducible runs.
SEED = 42
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

# NOTE(review): `username` is not used anywhere in this cell.
username = getpass.getuser()

# Scratch roots. Set DL_DATA_PATH / DL_MAX_DATA_PATH to run somewhere other than
# the cluster; when unset (or empty) the original cluster locations are used.
DATA_PATH = Path(os.environ.get('DL_DATA_PATH') or '/cluster/scratch/thobauma/dl_data')
MAX_PATH = Path(os.environ.get('DL_MAX_DATA_PATH') or '/cluster/scratch/mmathys/dl_data')

# Path for intermediate (post-hoc) outputs.
# Earlier runs used Path(MAX_PATH, 'posthoc/') and Path(MAX_PATH, 'posthoc-subset/').
BASE_POSTHOC_PATH = Path(MAX_PATH, 'posthoc-fixed-labels/')

# Original Dataset
ORI_PATH = Path(DATA_PATH, 'ori_data/')
CLASS_SUBSET_PATH = Path(ORI_PATH, 'class_subset.npy')

TR_PATH = Path(ORI_PATH, 'train/')
TR_ORI_LABEL_PATH = Path(TR_PATH,'correct_labels.txt')
TR_ORI_IMAGES_PATH = Path(TR_PATH,'images')

VAL_PATH = Path(ORI_PATH, 'validation/')
VAL_ORI_LABEL_PATH = Path(VAL_PATH,'correct_labels.txt')
VAL_ORI_IMAGES_PATH = Path(VAL_PATH,'images')

# DAmageNet
DN_PATH = Path(DATA_PATH, 'damageNet')
DN_LABEL_PATH = Path(DN_PATH, 'val_damagenet.txt')
DN_IMAGES_PATH = Path(DN_PATH, 'images')
DN_POSTHOC_PATH = Path(BASE_POSTHOC_PATH, 'damagenet')
DN_POSTHOC_LABEL_PATH = Path(DN_POSTHOC_PATH, 'labels.csv')

# Adversarial datasets reuse the label files of the corresponding original split.
# NOTE(review): the *_POSTHOC_* constants below use '<attack>/validation', while the
# post-hoc paths built later from BASE_POSTHOC_PATH use '<dataset key>/val'.

# PGD
TR_PGD_PATH = Path(MAX_PATH, 'adversarial_data/pgd_06/train')
TR_PGD_LABEL_PATH = TR_ORI_LABEL_PATH
TR_PGD_IMAGES_PATH = Path(TR_PGD_PATH, 'images')
TR_PGD_POSTHOC_PATH = Path(BASE_POSTHOC_PATH, 'pgd/train/')
TR_PGD_POSTHOC_LABEL_PATH = Path(TR_PGD_POSTHOC_PATH, 'labels.csv')

VAL_PGD_PATH = Path(MAX_PATH, 'adversarial_data/pgd_06/validation')
VAL_PGD_LABEL_PATH = VAL_ORI_LABEL_PATH
VAL_PGD_IMAGES_PATH = Path(VAL_PGD_PATH, 'images')
VAL_PGD_POSTHOC_PATH = Path(BASE_POSTHOC_PATH, 'pgd/validation/')
VAL_PGD_POSTHOC_LABEL_PATH = Path(VAL_PGD_POSTHOC_PATH, 'labels.csv')

# CW
TR_CW_PATH = Path(MAX_PATH, 'adversarial_data/cw/train')
TR_CW_LABEL_PATH = TR_ORI_LABEL_PATH
TR_CW_IMAGES_PATH = Path(TR_CW_PATH, 'images')
TR_CW_POSTHOC_PATH = Path(BASE_POSTHOC_PATH, 'cw/train/')
TR_CW_POSTHOC_LABEL_PATH = Path(TR_CW_POSTHOC_PATH, 'labels.csv')

VAL_CW_PATH = Path(MAX_PATH, 'adversarial_data/cw/validation')
VAL_CW_LABEL_PATH = VAL_ORI_LABEL_PATH
VAL_CW_IMAGES_PATH = Path(VAL_CW_PATH, 'images')
VAL_CW_POSTHOC_PATH = Path(BASE_POSTHOC_PATH, 'cw/validation/')
VAL_CW_POSTHOC_LABEL_PATH = Path(VAL_CW_POSTHOC_PATH, 'labels.csv')

# FGSM
TR_FGSM_PATH = Path(MAX_PATH, 'adversarial_data/fgsm_06/train')
TR_FGSM_LABEL_PATH = TR_ORI_LABEL_PATH
TR_FGSM_IMAGES_PATH = Path(TR_FGSM_PATH, 'images')
TR_FGSM_POSTHOC_PATH = Path(BASE_POSTHOC_PATH, 'fgsm/train/')
TR_FGSM_POSTHOC_LABEL_PATH = Path(TR_FGSM_POSTHOC_PATH, 'labels.csv')

VAL_FGSM_PATH = Path(MAX_PATH, 'adversarial_data/fgsm_06/validation')
VAL_FGSM_LABEL_PATH = VAL_ORI_LABEL_PATH
VAL_FGSM_IMAGES_PATH = Path(VAL_FGSM_PATH, 'images')
VAL_FGSM_POSTHOC_PATH = Path(BASE_POSTHOC_PATH, 'fgsm/validation/')
VAL_FGSM_POSTHOC_LABEL_PATH = Path(VAL_FGSM_POSTHOC_PATH, 'labels.csv')

In [24]:
# If CLASS_SUBSET is specified, INDEX_SUBSET will be ignored. Set CLASS_SUBSET=None if you want to use indexes.
# Alternatives used previously:
# INDEX_SUBSET = get_random_indexes(number_of_images = 50000, n_samples=1000)
# CLASS_SUBSET = get_random_classes(number_of_classes = 25, min_rand_class = 1, max_rand_class = 1001)

CLASS_SUBSET = np.load(CLASS_SUBSET_PATH)
# CLASS_SUBSET = CLASS_SUBSET[:3]
INDEX_SUBSET = None

# DataLoader settings
NUM_WORKERS = 0
BATCH_SIZE = 64

# Use the GPU when one is visible, otherwise fall back to the CPU instead of crashing.
# Pinned host memory only speeds up host->GPU copies, so enable it only with CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PIN_MEMORY = DEVICE == 'cuda'

In [11]:
def _split_paths(train_label, train_images, val_label, val_images):
    """Return the ``{'train': {...}, 'val': {...}}`` label/image path mapping for one dataset."""
    return {
        'train': {'label': train_label, 'images': train_images},
        'val': {'label': val_label, 'images': val_images},
    }


# Base ('b') label/image locations per dataset.
# Adversarial sets reuse the label files of the original train/validation splits.
datasets_paths = {
    'cw': {'b': _split_paths(TR_ORI_LABEL_PATH, TR_CW_IMAGES_PATH,
                             VAL_ORI_LABEL_PATH, VAL_CW_IMAGES_PATH)},
    'ori': {'b': _split_paths(TR_ORI_LABEL_PATH, TR_ORI_IMAGES_PATH,
                              VAL_ORI_LABEL_PATH, VAL_ORI_IMAGES_PATH)},
    # DAmageNet has no train images here, so the train label is None as well
    # (TR_CW_PATH, an image directory rather than a label file, was a wrong placeholder).
    # NOTE(review): val uses the original validation labels, not DN_LABEL_PATH — confirm intended.
    'dn': {'b': _split_paths(None, None,
                             VAL_ORI_LABEL_PATH, DN_IMAGES_PATH)},
    'fgsm_06': {'b': _split_paths(TR_ORI_LABEL_PATH, TR_FGSM_IMAGES_PATH,
                                  VAL_ORI_LABEL_PATH, VAL_FGSM_IMAGES_PATH)},
    'pgd_06': {'b': _split_paths(TR_ORI_LABEL_PATH, TR_PGD_IMAGES_PATH,
                                 VAL_ORI_LABEL_PATH, VAL_PGD_IMAGES_PATH)},
}

In [12]:
# Attach post-hoc ('p') locations for every dataset the forward pass runs on:
# BASE_POSTHOC_PATH/<dataset>/<split>/images (tensor dir) and .../labels.csv (predictions).
datasets = ['cw', 'ori', 'pgd_06', 'fgsm_06']
for ds in datasets:
    ds_dict = datasets_paths[ds]
    ds_root = Path(BASE_POSTHOC_PATH, ds)
    ds_dict['p'] = {
        split: {
            'images': Path(ds_root, split, 'images'),
            'label': Path(ds_root, split, 'labels.csv'),
        }
        for split in ('train', 'val')
    }

In [13]:
# Sanity check: display the base ('b') label/image paths registered for the CW dataset.
datasets_paths['cw']['b']

{'train': {'label': PosixPath('/cluster/scratch/thobauma/dl_data/ori_data/train/correct_labels.txt'),
  'images': PosixPath('/cluster/scratch/mmathys/dl_data/adversarial_data/cw/train/images')},
 'val': {'label': PosixPath('/cluster/scratch/thobauma/dl_data/ori_data/validation/correct_labels.txt'),
  'images': PosixPath('/cluster/scratch/mmathys/dl_data/adversarial_data/cw/validation/images')}}

# Import DINO
Official repository: https://github.com/facebookresearch/dino

In [14]:
class LinearClassifier(nn.Module):
    """Linear layer to train on top of frozen features.

    Args:
        dim: Number of input features per sample after flattening.
        num_labels: Number of output classes (default 1000).

    ``forward`` flattens every dimension after the batch dimension and returns
    logits of shape ``(batch, num_labels)``.
    """

    def __init__(self, dim, num_labels=1000):
        super().__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        # Small-variance normal init for the weights, zero bias.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. features
        # produced by a transpose/permute; for contiguous inputs it returns a view.
        x = x.reshape(x.size(0), -1)

        # linear layer
        return self.linear(x)

In [20]:
# get_dino returns the backbone and a reference linear head; that head is only
# used here to read its input feature size.
model, linear_classifier = get_dino()

# Replace it with a head sized for the class subset.
linear_classifier = LinearClassifier(linear_classifier.linear.in_features,
                                     num_labels=len(CLASS_SUBSET))

# Checkpoint directory is named after the number of classes (e.g. '25_classes'),
# so derive it from CLASS_SUBSET instead of hard-coding it.
CLEAN_CLASSIFIER_PATH = Path(MAX_PATH, 'adversarial_data', 'adv_classifiers',
                             f'{len(CLASS_SUBSET)}_classes', 'clean.pt')
# map_location='cpu' lets the checkpoint load on CPU-only machines; the module is
# moved to DEVICE right after. torch.load unpickles, so only load trusted files.
linear_classifier.load_state_dict(torch.load(CLEAN_CLASSIFIER_PATH, map_location='cpu'))
linear_classifier.to(DEVICE)

Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
Model vit_small built.
Embed dim 1536
We load the reference pretrained linear weights.


LinearClassifier(
  (linear): Linear(in_features=1536, out_features=25, bias=True)
)

In [21]:
from sklearn import preprocessing

# Encode the class ids in CLASS_SUBSET as contiguous labels 0..len(CLASS_SUBSET)-1.
subset_classes = list(CLASS_SUBSET)
label_encoder = preprocessing.LabelEncoder().fit(subset_classes)
label_encoder

LabelEncoder()

# Forward Pass

In [22]:
# Run the frozen backbone + linear head over every dataset/split without an attack,
# keeping validate_network's return value per dataset and split in logger_dict.
logger_dict = {}
for ds in ['ori', 'pgd_06', 'fgsm_06', 'cw']:
    ds_b = datasets_paths[ds]['b']
    ds_p = datasets_paths[ds]['p']
    split_logs = logger_dict[ds] = {}
    # Clean images use ORIGINAL_TRANSFORM; adversarial sets use ONLY_NORMALIZE_TRANSFORM.
    transform = ORIGINAL_TRANSFORM if ds == 'ori' else ONLY_NORMALIZE_TRANSFORM
    for tv in ('train', 'val'):
        base_split, posthoc_split = ds_b[tv], ds_p[tv]
        print(f"images: {base_split['images']}\nlabel: {base_split['label']}\npred: {posthoc_split['label']}")
        data_set = AdvTrainingImageDataset(base_split['images'], base_split['label'], transform,
                                           class_subset=CLASS_SUBSET, label_encoder=label_encoder)
        data_loader = DataLoader(data_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                 pin_memory=PIN_MEMORY, shuffle=False)
        print(f"{ds}: {tv} {len(data_set)}")
        # Presumably writes tensors/predictions to the post-hoc paths
        # (the cell output logs "saving predictions to: ...").
        split_logs[tv] = validate_network(model, linear_classifier, data_loader,
                                          adversarial_attack=None,
                                          tensor_dir=posthoc_split['images'],
                                          path_predictions=posthoc_split['label'])

images: /cluster/scratch/thobauma/dl_data/ori_data/train/images
label: /cluster/scratch/thobauma/dl_data/ori_data/train/correct_labels.txt
pred: /cluster/scratch/mmathys/dl_data/posthoc-fixed-labels/ori/train/labels.csv
ori: train 32181
saving predictions to: /cluster/scratch/mmathys/dl_data/posthoc-fixed-labels/ori/train/labels.csv
Test:  [  0/503]  eta: 0:08:51  loss: 0.018451 (0.018451)  acc1: 98.437500 (98.437500)  acc5: 100.000000 (100.000000)  time: 1.057054  data: 0.751724  max mem: 722
Test:  [ 20/503]  eta: 0:08:47  loss: 0.027935 (0.045099)  acc1: 98.437500 (98.586310)  acc5: 100.000000 (99.925595)  time: 1.093595  data: 0.839092  max mem: 796
Test:  [ 40/503]  eta: 0:08:18  loss: 0.032479 (0.040204)  acc1: 98.437500 (98.818598)  acc5: 100.000000 (99.961890)  time: 1.061497  data: 0.805747  max mem: 796
Test:  [ 60/503]  eta: 0:07:57  loss: 0.025912 (0.038865)  acc1: 98.437500 (98.847336)  acc5: 100.000000 (99.948770)  time: 1.076199  data: 0.820069  max mem: 796
Test:  [ 80/