# Posthoc Classifier

## Install, Paths and Parameters

In [1]:
# This extension reloads external Python files
%load_ext autoreload
import os
from pathlib import Path
import getpass
import numpy as np
import time
import torch
from torch import nn
from tqdm import tqdm
import random
import sys

# allow imports when running script from within project dir
[sys.path.append(i) for i in ['.', '..']]

# local
from src.helpers.helpers import get_random_indexes, get_random_classes
from src.model.dino_model import get_dino
from src.model.data import create_loader, adv_dataset
from src.model.eval import validate_network

# seed
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

username = getpass.getuser()
DATA_PATH = Path('/','cluster', 'scratch', 'thobauma', 'dl_data')

DN_PATH = Path(DATA_PATH, 'damageNet')
DN_LABEL_PATH = Path(DN_PATH, 'val_damagenet.txt')
DN_IMAGES_PATH = Path(DN_PATH, 'images')

ORI_PATH = Path(DATA_PATH, 'ori_data')
ORI_LABEL_PATH = Path(ORI_PATH,'correct_labels.txt')
ORI_IMAGES_PATH = Path(ORI_PATH,'images')

In [2]:
# If CLASS_SUBSET is specified, INDEX_SUBSET will be ignored. Set CLASS_SUBSET=None if you want to use indexes.
INDEX_SUBSET = get_random_indexes()
CLASS_SUBSET = get_random_classes()

BATCH_SIZE = 1

DEVICE = 'cuda'

In [3]:
#!python $HOME/deeplearning/setup/collect_env.py

## Import DINO
Official repo: https://github.com/facebookresearch/dino

In [4]:
model, linear_classifier = get_dino()

Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
Model vit_small built.
We load the reference pretrained linear weights.


## Load data

In [5]:
ori_loader = create_loader(ORI_IMAGES_PATH, ORI_LABEL_PATH, INDEX_SUBSET, CLASS_SUBSET, BATCH_SIZE)

In [6]:
dn_loader = create_loader(DN_IMAGES_PATH, DN_LABEL_PATH, INDEX_SUBSET, CLASS_SUBSET, BATCH_SIZE)

## Test model linear classifier inputs

In [7]:
# Performs a forward pass given a sample `inp` and a classifier.
def generate_model_output(inp, n=4):
    inp = inp.to("cuda")
    inp = inp.unsqueeze(dim=0)
    print(inp.shape)
    intermediate_output = model.get_intermediate_layers(inp, n)
    return torch.cat([x[:, 0] for x in intermediate_output], dim=-1)

In [8]:
import sys

# output first 5 tuples generated.
samples = adv_dataset(ori_loader, dn_loader, model, linear_classifier)
total=1000

for i in range(total):
  if i+1 % 10000 == 0:
    drive.flush_and_unmount()
    drive.mount('/content/drive')

  num, org, adv = next(samples)
  sys.stdout.write(f"\rtuple {i+1}/{total}")
  sys.stdout.flush()

  # self attention
  # add one dimension to input image (get_last_selfattention expects it)

  org_out = generate_model_output(org)
  adv_out = generate_model_output(adv)
  
  #torch.save(org_out, ORG_OUT_PATH + f"{num}.pt")
  #torch.save(adv_out, ADV_OUT_PATH + f"{num}.pt")

  # folders for org and adv, filename: {org, adv}/<original number>.pt (leave out prefix)

org_name: ILSVRC2012_val_00000007.JPEG, adv_name: ILSVRC2012_val_00000007.png
tuple 1/1000torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
org_name: ILSVRC2012_val_00000011.JPEG, adv_name: ILSVRC2012_val_00000011.png
org_name: ILSVRC2012_val_00000019.JPEG, adv_name: ILSVRC2012_val_00000019.png
org_name: ILSVRC2012_val_00000021.JPEG, adv_name: ILSVRC2012_val_00000021.png
tuple 2/1000torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
org_name: ILSVRC2012_val_00000093.JPEG, adv_name: ILSVRC2012_val_00000093.png
tuple 3/1000torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
org_name: ILSVRC2012_val_00000098.JPEG, adv_name: ILSVRC2012_val_00000098.png
tuple 4/1000torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
org_name: ILSVRC2012_val_00000104.JPEG, adv_name: ILSVRC2012_val_00000104.png
org_name: ILSVRC2012_val_00000114.JPEG, adv_name: ILSVRC2012_val_00000114.png
tuple 5/1000torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
org_name: ILSVRC2012_val_0

AssertionError: Numbers are not matching: org=ILSVRC2012_val_00000117.JPEG, adv=ILSVRC2012_val_00000148.png