# Imports, Paths and Parameters

In [37]:
import os
from pathlib import Path
import getpass
import numpy as np
import time
import torch
from torch import nn
from tqdm import tqdm
import random
import sys

# Allow `src.*` imports when the notebook is started from the project root or one level below it.
# Only append entries that are missing, so re-running this cell does not keep growing sys.path
# (the previous list-comprehension-for-side-effects appended duplicates on every run).
for _extra_path in ('.', '..'):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

# local
from src.helpers.helpers import get_random_indexes, get_random_classes
from src.model.dino_model import get_dino, ViTWrapper
from src.model.data import create_loader
from src.model.eval import validate_network

# Seed every random source the notebook touches (Python, PyTorch, NumPy), in that order,
# so anything drawn afterwards (presumably including the random subsets below) is reproducible.
SEED = 42
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(SEED)
del _seed_rng

username = getpass.getuser()
# Root of the dataset. Defaults to the original cluster scratch location; set the
# DL_DATA_PATH environment variable to point elsewhere without editing the notebook.
# NOTE(review): `username` is not used here — the default root is pinned to one user's scratch dir.
DATA_PATH = Path(os.environ.get('DL_DATA_PATH', Path('/', 'cluster', 'scratch', 'thobauma', 'dl_data')))

# Clean (original) validation set.
ORI_PATH = Path(DATA_PATH, 'ori_data/validation')
ORI_LABEL_PATH = Path(ORI_PATH, 'correct_labels.txt')
ORI_IMAGES_PATH = Path(ORI_PATH, 'images')

# Adversarial validation images (folder name suggests PGD-perturbed data — confirm).
DN_PATH = Path(DATA_PATH, 'adversarial_data')
# Reuses the clean label file under ORI_PATH, not DN_PATH — presumably intentional because the
# adversarial images keep their ground-truth labels; confirm.
DN_LABEL_PATH = Path(ORI_PATH, 'correct_labels.txt')
DN_IMAGES_PATH = Path(DN_PATH, 'pgd_03/validation/images/')

In [20]:
# If CLASS_SUBSET is specified, INDEX_SUBSET will be ignored. Set CLASS_SUBSET=None if you want to use indexes.
# Both helpers run after the seeding in the first cell. Keep this call order: if they draw from the
# seeded RNGs (assumed — confirm in src.helpers), swapping the lines would change both subsets.
INDEX_SUBSET = get_random_indexes()
CLASS_SUBSET = get_random_classes(number_of_classes=10)  # number of randomly chosen classes

BATCH_SIZE = 50

# Use the GPU when one is visible; fall back to CPU so the notebook still runs (slowly) without one.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [21]:
#!python $HOME/deeplearning/setup/collect_env.py

# Import DINO
Load the pretrained DINO ViT-S/16 backbone together with its linear classification head. Official repo: https://github.com/facebookresearch/dino

In [22]:
# Build the DINO backbone and its linear classification head. No checkpoint arguments are given, so
# the helper falls back to the reference pretrained backbone and linear weights (see the printed output).
model, linear_classifier = get_dino()

Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
Model vit_small built.
Embed dim 1536
We load the reference pretrained linear weights from dino_deitsmall16_pretrain/dino_deitsmall16_linearweights.pth.


# Load data

In [29]:
# Clean validation images, restricted to the same INDEX_SUBSET / CLASS_SUBSET as the adversarial loader.
org_loader = create_loader(
    ORI_IMAGES_PATH,
    ORI_LABEL_PATH,
    INDEX_SUBSET,
    CLASS_SUBSET,
    BATCH_SIZE,
)

In [38]:
# Adversarial validation images, paired with the label file DN_LABEL_PATH and the same subsets as org_loader.
adv_loader = create_loader(
    DN_IMAGES_PATH,
    DN_LABEL_PATH,
    INDEX_SUBSET,
    CLASS_SUBSET,
    BATCH_SIZE,
)

## Wrap model

In [27]:
# Wrap backbone + linear head and move the wrapper to DEVICE in one expression.
# n_last_blocks / avgpool_patchtokens are read back by the feature-extraction cell below.
model_wrap = ViTWrapper(
    model,
    linear_classifier,
    DEVICE,
    n_last_blocks=4,
    avgpool_patchtokens=False,
).to(DEVICE)

# Generate the linear classifier's input features for the adversarial images

In [39]:
# Run the backbone over every adversarial batch and collect the tensor the linear classifier consumes.
# Batches are gathered in a list and concatenated once at the end: calling torch.cat on the
# accumulated tensor inside the loop re-copies everything seen so far on every iteration (quadratic).
feature_batches = []

with torch.no_grad():
    for images, labels in tqdm(adv_loader):
        x = model_wrap.transform(images).to(DEVICE)

        # forward: token 0 (presumably the CLS token) of each of the last `n_last_blocks` blocks, concatenated
        intermediate_output = model_wrap.vits16.get_intermediate_layers(x, model_wrap.n_last_blocks)
        output = torch.cat([block_out[:, 0] for block_out in intermediate_output], dim=-1)
        if model_wrap.avgpool_patchtokens:
            # NOTE(review): assuming each block output is (batch, tokens, dim), this cat only lines up
            # when n_last_blocks == 1; the branch is not taken with avgpool_patchtokens=False as configured.
            output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
            output = output.reshape(output.shape[0], -1)

        feature_batches.append(output)

# Same contract as before: batches stacked along dim 0 in loader order, or None if the loader was empty.
result = torch.cat(feature_batches, 0) if feature_batches else None

100%|██████████| 10/10 [00:04<00:00,  2.22it/s]


In [None]:
def linear_layer_input(wrapper, x):
    """Return the features a ViTWrapper-like object feeds into its linear classifier.

    Previously this code sat at top level and referenced an undefined ``self``, which raised
    NameError on Restart & Run All. It now takes the wrapper explicitly.

    Args:
        wrapper: object exposing ``transform``, ``vits16.get_intermediate_layers``,
            ``n_last_blocks`` and ``avgpool_patchtokens`` (e.g. ``model_wrap``).
        x: input batch. It is passed through ``wrapper.transform`` and not moved to any device here.

    Returns:
        Token-0 features of the last ``n_last_blocks`` blocks, concatenated along the last dim.
        When ``avgpool_patchtokens`` is set, these are interleaved with the mean of the remaining
        tokens of the final block.
    """
    x = wrapper.transform(x)

    # forward
    intermediate_output = wrapper.vits16.get_intermediate_layers(x, wrapper.n_last_blocks)
    output = torch.cat([block_out[:, 0] for block_out in intermediate_output], dim=-1)
    if wrapper.avgpool_patchtokens:
        # NOTE(review): assuming block outputs are (batch, tokens, dim), this only lines up when
        # n_last_blocks == 1 (the concatenated token-0 width must equal the pooled width).
        output = torch.cat((output.unsqueeze(-1), torch.mean(intermediate_output[-1][:, 1:], dim=1).unsqueeze(-1)), dim=-1)
        output = output.reshape(output.shape[0], -1)
    return output