In [1]:
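# Imports. NOTE: no autoreload extension is actually loaded in this cell; add
# `%load_ext autoreload` / `%autoreload 2` if edits to external Python files should be reloaded.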
from pathlib import Path
from collections import defaultdict

import numpy as np
import random
import sys
import pandas as pd

import torch
from torch.utils.data import DataLoader
from torch import nn

# allow imports when running script from within project dir
# Only entries that are not yet present are added, so re-running this cell
# does not keep growing sys.path with duplicate '.' / '..' entries.
sys.path.extend(rel_dir for rel_dir in ('.', '..') if rel_dir not in sys.path)

# local
from src.model.dino_model import get_dino
from src.model.train import *
from src.model.data import *
from src.helpers.helpers import create_paths

# seed
# The same value is fed to each random number generator used in this
# notebook: Python's `random`, PyTorch and NumPy (in that order).
SEED = 42
for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

# Two root directories on the cluster scratch file system.
DATA_PATH = Path('/').joinpath('cluster', 'scratch', 'thobauma', 'dl_data')
MAX_PATH = Path('/').joinpath('cluster', 'scratch', 'mmathys', 'dl_data')

# Everything derived from MAX_PATH: adversarial tensors, post-hoc tensors
# and the post-hoc model checkpoints.
BASE_ADV_PATH = MAX_PATH / 'adversarial_data_tensors'
BASE_POSTHOC_PATH = MAX_PATH / 'posthoc_tensors'
POSTHOC_MODELS_PATH = MAX_PATH / 'posthoc_models'

# Original data and the class-subset file that sits inside it.
ORI_PATH = DATA_PATH / 'ori'
CLASS_SUBSET_PATH = ORI_PATH / 'class_subset.npy'
# Load the NumPy array stored at CLASS_SUBSET_PATH.
# NOTE(review): presumably the subset of class ids used for the experiments;
# it is not referenced in the cells shown here - confirm it is still needed.
CLASS_SUBSET = np.load(CLASS_SUBSET_PATH)

# Names of the adversarial datasets; DATASETS is the same list with the
# original ('ori') dataset placed in front.
ADV_DATASETS = ['cw', 'fgsm_06', 'pgd_03']
DATASETS = ['ori'] + ADV_DATASETS
print(DATASETS)

['ori', 'cw', 'fgsm_06', 'pgd_03']


In [2]:
# Keyword arguments that are identical for every create_paths call below.
shared_path_kwargs = {
    'posthoc_base_path': BASE_POSTHOC_PATH,
    'train_str': 'train',
    'val_str': 'validation',
}

# First call: the original dataset, rooted at DATA_PATH, with no existing
# path collection passed in.
DATA_PATHS = create_paths(data_name='ori',
                          datasets_paths=None,
                          initial_base_path=DATA_PATH,
                          **shared_path_kwargs)

# Following calls: one per adversarial dataset, rooted at BASE_ADV_PATH, each
# receiving the previous result. The evaluation cell later indexes the result
# as DATA_PATHS[name]['posthoc']['validation']['images'].
for adv_ds in ADV_DATASETS:
    DATA_PATHS = create_paths(data_name=adv_ds,
                              datasets_paths=DATA_PATHS,
                              initial_base_path=BASE_ADV_PATH,
                              **shared_path_kwargs)

In [3]:
# DataLoader settings (passed as batch_size / num_workers / pin_memory when
# the validation loaders are built).
BATCH_SIZE = 256
NUM_WORKERS = 0
PIN_MEMORY = True

# Remaining run settings.
INDEX_SUBSET = None
EPOCHS = 3
DEVICE = 'cuda'

In [4]:
def prepare_data_df(adv_datasets, dataset_paths, posthoc_base_path=None):
    """Read the merged label tables of every adversarial dataset.

    For each name in ``adv_datasets`` the files
    ``<base>/<name>/train/labels_merged.csv`` and
    ``<base>/<name>/validation/labels_merged.csv`` are read with pandas.

    The training tables are reduced to "adversarial tuples": rows where the
    prediction on the original image is correct (``true_labels == ori_pred``)
    and the prediction on the adversarial image is wrong
    (``true_labels != <name>_pred``). Only the columns ``file``,
    ``true_labels``, ``ori_pred`` and ``<name>_pred`` are kept.

    The validation tables are returned exactly as read (no filtering).
    NOTE(review): confirm that leaving validation unfiltered is intended.

    Args:
        adv_datasets: iterable of adversarial dataset names.
        dataset_paths: unused; kept so existing callers keep working.
        posthoc_base_path: directory containing one sub-directory per
            dataset. Defaults to the module-level ``BASE_POSTHOC_PATH``.

    Returns:
        Tuple ``(train_dfs, val_dfs)`` of dicts mapping dataset name to
        DataFrame.
    """
    # Fall back to the notebook-level constant only when no base is given, so
    # the function can also be pointed at another directory.
    base = BASE_POSTHOC_PATH if posthoc_base_path is None else posthoc_base_path

    def read_labels(ds, split):
        return pd.read_csv(Path(base, ds, split, 'labels_merged.csv'))

    raw_train = {ds: read_labels(ds, 'train') for ds in adv_datasets}
    val_dfs = {ds: read_labels(ds, 'validation') for ds in adv_datasets}

    # get adversarial tuples
    train_dfs = {}
    for name, df in raw_train.items():
        adv_col = name + '_pred'
        ori_correct = df['true_labels'] == df['ori_pred']
        adv_wrong = df['true_labels'] != df[adv_col]
        train_dfs[name] = df.loc[ori_correct & adv_wrong,
                                 ['file', 'true_labels', 'ori_pred', adv_col]]
    return train_dfs, val_dfs

In [5]:
# Read the label tables for every adversarial dataset: train_dfs holds the
# filtered training tables, val_dfs the validation tables (both keyed by
# dataset name).
train_dfs, val_dfs = prepare_data_df(ADV_DATASETS, DATA_PATHS)

In [6]:
# Linear Binary Classifier
class LinearBC(nn.Module):
    """Binary classifier consisting of a single linear layer.

    Args:
        input_shape: size of the last input dimension; forwarded to
            ``nn.Linear`` as ``in_features``.

    Attributes:
        num_labels: number of output classes (always 2).
        fc1: the linear layer producing ``num_labels`` outputs.
    """

    def __init__(self, input_shape):
        # nn.Module must be initialised before attributes are assigned on the
        # subclass (see the torch.nn.Module documentation).
        super().__init__()
        self.num_labels = 2
        # Derive the output width from num_labels so the two cannot diverge.
        self.fc1 = nn.Linear(input_shape, self.num_labels)

    def forward(self, x):
        """Apply the linear layer and return its raw outputs."""
        return self.fc1(x)

In [7]:
# Evaluate every stored post-hoc classifier on every adversarial validation
# set. Results are collected in logger_dict[classifier_name][dataset_name] and
# one predictions CSV per pair is written below POSTHOC_MATRIX_PATH.
dataset_paths = DATA_PATHS
logger_dict = defaultdict(dict)

POSTHOC_MATRIX_PATH = Path(MAX_PATH, 'posthoc_matrix')
POSTHOC_MATRIX_PATH.mkdir(parents=True, exist_ok=True)

# Loop-invariant inputs: the original validation images and the loss are the
# same for every (classifier, dataset) pair, so they are created once.
ori_validation = dataset_paths['ori']['posthoc']['validation']['images']
criterion = nn.CrossEntropyLoss()

for adv_classifier in ADV_DATASETS:
    print("#"*50 + f''' forwardpass on {adv_classifier} classifier ''' + "#"*50)

    log_dir = Path(POSTHOC_MODELS_PATH, adv_classifier)
    # NOTE(review): 1536 is presumably the size of the stored feature
    # tensors - confirm against the tensors on disk.
    classifier = LinearBC(1536)
    # Use the DEVICE constant from the configuration cell instead of a
    # hard-coded .cuda() so the target device is defined in one place.
    classifier.to(DEVICE)

    # run_variables handed to restart_from_checkpoint.
    # NOTE(review): presumably replaced by the values stored in the
    # checkpoint - confirm. Also confirm what happens when the checkpoint
    # file is missing, so an untrained classifier is never evaluated silently.
    to_restore={'epoch':3}

    utils.restart_from_checkpoint(
        Path(log_dir, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=classifier
    )

    for adv_data in ADV_DATASETS:
        print("\n"+"-"*50 + f''' dataset {adv_data} ''' + "-"*50)
        adv_validation = dataset_paths[adv_data]['posthoc']['validation']['images']
        print(f'''original images: {ori_validation}''')
        print(f'''adversarial images: {adv_validation}''')
        val_set = PosthocTrainDataset(ori_validation, adv_validation, val_dfs[adv_data])
        val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, shuffle=False)
        print(f'''val samples: {len(val_set)} \n''')
        logger_dict[adv_classifier][adv_data] =  validate_network(model=None,
                                                   classifier=classifier,
                                                   validation_loader=val_loader,
                                                   criterion=criterion,
                                                   tensor_dir=None,
                                                   adversarial_attack=None,
                                                   path_predictions=Path(POSTHOC_MATRIX_PATH, 'c_'+adv_classifier+'_d_'+adv_data+'.csv'),
                                                   show_image=False,
                                                   log_interval=10)

################################################## forwardpass on cw classifier ##################################################
Found checkpoint at /cluster/scratch/mmathys/dl_data/posthoc_models/cw/checkpoint.pth.tar
=> loaded 'state_dict' from checkpoint '/cluster/scratch/mmathys/dl_data/posthoc_models/cw/checkpoint.pth.tar' with msg <All keys matched successfully>

-------------------------------------------------- dataset cw --------------------------------------------------
original images: /cluster/scratch/mmathys/dl_data/posthoc_tensors/ori/validation/images
adversarial images: /cluster/scratch/mmathys/dl_data/posthoc_tensors/cw/validation/images
val samples: 2500 

saving predictions to: /cluster/scratch/mmathys/dl_data/posthoc_matrix/c_cw_d_cw.csv
Test:  [ 0/10]  eta: 0:00:02  loss: 0.299046 (0.299046)  acc1: 87.500000 (87.500000)  time: 0.287987  data: 0.282263  max mem: 2
Test:  [ 9/10]  eta: 0:00:00  loss: 0.363664 (0.384657)  acc1: 82.653061 (83.200000)  time: 0.268608 