In [None]:
import os
from tqdm.auto import tqdm

import numpy as np
import torch
from torchvision import models
from torch.utils.data import DataLoader
from torchvision import transforms

import matplotlib.pyplot as plt
# Style sheets are applied left to right, so rcParams defined by the
# colorblind sheet override the same keys from ggplot.
plt.style.use(['ggplot', 'seaborn-v0_8-colorblind'])


import data_utils
from autoattack import AutoAttack
# (https://github.com/fra31/auto-attack/blob/master/autoattack/autoattack.py)


In [None]:
## load model and default image transform
# Returns the target ResNet-50 (placed on 'cuda') and its default preprocessing
# transform; that transform's .mean/.std are reused by normalize/denormalize below.
# NOTE(review): whether get_target_model already puts the model in eval mode is
# not visible from here — confirm.
model, preprocess0 = data_utils.get_target_model(target_name='resnet50', device='cuda', weights='default')
# Display the default transform so its resize/crop/normalization settings are visible.
preprocess0

In [None]:
## transforms
# `preprocess` mirrors the usual eval resize/crop but stops at ToTensor, so images
# stay in [0, 1] pixel space; normalization is applied separately (see `normalize`,
# used inside forward_pass).
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# Plain [0, 1] tensor conversion with no resizing or cropping.
to_tensor = transforms.ToTensor()

# Reuse the exact statistics from the model's default preprocessing so the
# model sees inputs normalized the way its weights expect.
normalize = transforms.Normalize(preprocess0.mean, preprocess0.std)

def denormalize(x, mean=None, std=None):
    """Invert ``normalize``: map a normalized image tensor back to pixel scale.

    Computes ``x * std + mean`` per channel, the inverse of
    ``transforms.Normalize(mean, std)``.

    Args:
        x: Normalized image tensor whose channel axis is third from the end,
            e.g. ``(C, H, W)`` or a batch ``(B, C, H, W)``. May live on any device.
        mean: Per-channel means. Defaults to ``preprocess0.mean`` (the same
            statistics ``normalize`` uses).
        std: Per-channel standard deviations. Defaults to ``preprocess0.std``.

    Returns:
        Tensor with the same shape and device as ``x``.
    """
    if mean is None:
        mean = preprocess0.mean
    if std is None:
        std = preprocess0.std
    # Build the statistics on x's device: the original always created them on the
    # CPU, which fails for CUDA tensors (everything in this notebook is moved to
    # 'cuda'). view(-1, 1, 1) broadcasts over H and W and over any leading batch dim.
    std = torch.as_tensor(std, device=x.device).view(-1, 1, 1)
    mean = torch.as_tensor(mean, device=x.device).view(-1, 1, 1)
    return x * std + mean  # de-normalize

# Tensor -> PIL image converter. Expects pixel-scale tensors, so pass normalized
# tensors through denormalize first.
to_pil = transforms.ToPILImage()


## model forward pass
def forward_pass(img):
    """Return ``model`` logits for an unnormalized image batch.

    ``img`` is expected in [0, 1] pixel scale (as produced by ``to_tensor`` /
    ``preprocess``). The model's mean/std normalization is applied here rather
    than in the data pipeline, presumably so attacks (AutoAttack is imported
    above) can operate directly in pixel space.
    """
    return model(normalize(img))

In [None]:
from torchvision.datasets import ImageFolder
import pathlib

In [None]:
# dataset = ImageFolder(
# #     '/home/lim38/dataset/imagenet-val-attack/', 
#     '/home/lim38/dataset/imagenet-val/', 
# #     loader=lambda path: pathlib.Path(path).name
# )

In [None]:
## data, loader
# Alternative clean set: data_utils.get_data('imagenet_val', preprocess)
# Only to_tensor is applied here (no resize/crop), so batching assumes every image
# in 'imagenet_val_attack' already has the same spatial size — TODO confirm.
BATCH_SIZE = 64
dataset = data_utils.get_data('imagenet_val_attack', preprocess=to_tensor)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Clean (un-attacked) top-1 accuracy of `model` on `loader`.
# eval() switches BatchNorm/Dropout to inference behaviour. It is idempotent, so
# it is harmless if get_target_model already did this (not visible from here).
model.eval()

correct = 0
total = 0

with torch.no_grad():
    progress = tqdm(loader, desc='clean eval')
    for imgs, targets in progress:
        imgs, targets = imgs.cuda(), targets.cuda()
        preds = forward_pass(imgs).argmax(dim=1)
        correct += (preds == targets).sum().item()
        total += len(targets)
        # Show the running accuracy on the progress bar instead of printing one
        # line per batch, which floods the cell output.
        progress.set_postfix(acc=f'{100 * correct / total:.2f}%')

# Final summary; guarded so an empty loader does not divide by zero.
accuracy = 100 * correct / total if total else float('nan')
print(f'{accuracy:.2f}%, {correct}/{total}')