# Fetching code from the repository

In [None]:
import os

root_name = "trodo"

if not (os.path.exists(root_name) or os.getcwd().lower().endswith(root_name)):
    !git clone https://github.com/Allliance/trodo

if not os.getcwd().lower().endswith(root_name):
    %cd {root_name}
!git pull

import src

# Configurations

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import gc
import numpy as np

mapping = "All to One"
dataset = ['cifar10', 'mnist', 'gtsrb', 'cifar100', 'pubfig'][0]

batch_size = 8 if dataset == 'pubfig' else 256

arch = "resnet"

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

In [None]:
# Per-dataset constants: normalization mean/std and number of classes.
from src.constants import NORM_MEAN, NORM_STD
from src.constants import num_classes as num_classes_dict

# Pick the model-loading function that matches `arch`; reject anything else early.
if arch not in ('preact', 'resnet'):
    raise NotImplementedError("This architecture is not supported")

if arch == 'preact':
    from src.models.loaders import load_preact as model_loader
else:
    from src.models.loaders import load_resnet as model_loader

# Preparations

## Loading Models

In [None]:
!mkdir models
%cd models
!pip install gdown
import gdown
gdown.download_folder("https://drive.google.com/drive/folders/1zocvSNKke4XbXyfn9-vbxzFH_xOGy0Qs")
!tar -xf A2O/clean.tar
!mv content/clean A2O/clean
!rm -r content
!rm A2O/clean.tar
!tar -xf A2O/trojaned.tar
!mv content/trojaned A2O/trojaned
!rm -r content
!rm A2O/trojaned.tar
!ls A2O
%cd ..

In [None]:
from src.modelset import ModelDataset

num_classes = num_classes_dict[dataset]


def final_model_loader(x, meta_data):
    """Load a single model with the dataset-specific settings bound in.

    Forwards ``x`` and ``meta_data`` to the architecture-specific
    ``model_loader`` and fixes the class count plus the normalization
    mean/std for the configured ``dataset`` (with normalization enabled).
    ``x`` is presumably a checkpoint path supplied by ``ModelDataset`` — confirm
    in ``src.modelset``.
    """
    return model_loader(x,
                        num_classes=num_classes,
                        mean=NORM_MEAN[dataset],
                        std=NORM_STD[dataset],
                        normalize=True,
                        meta_data=meta_data)


clean_root = './models/A2O/clean'
trojaned_root = './models/A2O/trojaned'

test_modelset = ModelDataset(clean_root,
                             trojaned_root,
                             final_model_loader
                             )

# Label 0 marks clean models; anything else is a trojaned one (assumes the
# model set only holds these two kinds, one per root above).
num_clean = sum(1 for m in test_modelset.models_data if m['label'] == 0)
num_trojaned = sum(1 for m in test_modelset.models_data if m['label'] != 0)
print("No. clean models in test set:", num_clean)
print("No. trojaned models in test set:", num_trojaned)

# Experiments

## Validation

In [None]:
from src.data.loaders import get_near_ood_loader
from src.visualization import visualize_samples


def get_dataloader():
    """Return a near-OOD dataloader built for the configured ``dataset``.

    The notebook-level ``dataset`` and ``batch_size`` are read at call time,
    so the loader always reflects the current configuration.
    """
    return get_near_ood_loader(source_dataset=dataset, batch_size=batch_size)


dataloader = get_dataloader()
print(len(dataloader.dataset))
# To preview a few samples from the loader, uncomment:
# visualize_samples(dataloader, 1)

# Testing

In [None]:
from src.evaluate import evaluate_modelset, mean_id_score_diff

# Evaluate every model in the test set using the `mean_id_score_diff`
# signature. `get_dataloader` is passed as a factory; evaluate_modelset decides
# when to call it. NOTE(review): eps is presumably the perturbation budget
# (2/255) — confirm in src.evaluate.
# Kept as the cell's last expression so Jupyter displays its result.
evaluate_modelset(
    test_modelset,
    signature_function=mean_id_score_diff,
    signature_function_kwargs=dict(eps=2 / 255, device=device, verbose=True),
    get_dataloader_func=get_dataloader,
    progress=False,
)