In [None]:
import os
from tqdm.auto import tqdm
import pathlib
import numpy as np

import torch
from torchvision import models
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

import matplotlib.pyplot as plt
# Apply the style sheets in order: ggplot first, then the colorblind-friendly seaborn palette on top.
plt.style.use(['ggplot', 'seaborn-v0_8-colorblind'])



import data_utils
from autoattack import AutoAttack
# (https://github.com/fra31/auto-attack/blob/master/autoattack/autoattack.py)


In [None]:
## Load the standard and the robust ResNet-50, together with each model's default image transform.
# Each transform gets its own name: previously both calls assigned to `preprocess0`, so the
# standard model's transform was silently overwritten by the robust model's.
model, preprocess_standard = data_utils.get_target_model(target_name='resnet50', device='cuda', weights='default')
model_robust, preprocess_robust = data_utils.get_target_model(target_name='resnet50robust', device='cuda', weights='default')
preprocess0 = preprocess_robust  # kept for compatibility: the object `preprocess0` held before

# Show both transforms so any difference between them is visible.
preprocess_standard, preprocess_robust

In [None]:
# Show model_robust's repr (for a torch nn.Module this lists its layers) for inspection.
model_robust

In [None]:
## Transforms and ImageNet channel statistics.
# Standard per-channel RGB statistics used by torchvision's ImageNet-pretrained models.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# One Normalize transform, also reused inside `totensor_normalize` below (it holds no mutable state).
normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

# Usual ImageNet evaluation geometry: shorter side to 256, central 224x224 crop, then to a tensor.
resize_crop_totensor = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
)

# Image -> tensor -> ImageNet-normalized tensor, with no resizing or cropping.
totensor_normalize = transforms.Compose([transforms.ToTensor(), normalize])

def denormalize(x):
    """Invert `normalize`: compute x * std + mean per channel using the ImageNet statistics.

    Args:
        x: tensor whose channel axis (size 3) is third-from-last, e.g. (3, H, W) or
            (N, 3, H, W); the statistics are reshaped to (3, 1, 1) and broadcast.

    Returns:
        The de-normalized tensor, on the same device as `x`.
    """
    # Create the statistics on x's device. Previously they were always CPU tensors, which
    # fails with a device-mismatch error for CUDA inputs. The dtype stays at the default
    # float so type promotion (e.g. for integer or float64 inputs) is unchanged.
    mean = torch.tensor(imagenet_mean, device=x.device).view(3, 1, 1)
    std = torch.tensor(imagenet_std, device=x.device).view(3, 1, 1)
    return x * std + mean

# Converters between images and tensors (ToTensor rescales 8-bit images to [0, 1]).
to_tensor, to_pil = transforms.ToTensor(), transforms.ToPILImage()


## model forward pass
def forward_pass(model, img):
    """Normalize `img` with the ImageNet statistics, then return `model`'s output on it.

    The model expects normalized input while the data pipeline yields unnormalized tensors,
    so normalization is applied here, right before the forward call.
    """
    return model(normalize(img))

In [None]:
# dataset = ImageFolder(
# #     '/home/lim38/dataset/imagenet-val-attack/', 
#     '/home/lim38/dataset/imagenet-val/', 
# #     loader=lambda path: pathlib.Path(path).name
# )

In [None]:
## Data and loader.
# The dataset is preprocessed with `to_tensor` only; ImageNet normalization is applied later
# inside `forward_pass`.
# Alternative (clean validation set with resize/crop):
#     dataset = data_utils.get_data('imagenet_val', resize_crop_totensor)
dataset = data_utils.get_data('imagenet_val_attack', preprocess=to_tensor)
loader = DataLoader(dataset, batch_size=50, shuffle=False)  # fixed order, as by default

In [None]:
def evaluate_accuracy(model, loader, desc=None):
    """Return (correct, total) top-1 counts for `model` over every batch in `loader`.

    Batches go through `forward_pass`, which applies ImageNet normalization, so `loader`
    should yield (unnormalized image tensor, integer class target) pairs.
    """
    # Eval mode makes BatchNorm use its running statistics (and stops it from updating them,
    # which train mode does even under no_grad) and disables dropout. get_target_model may
    # already return models in eval mode; calling eval() again is harmless.
    model.eval()
    device = next(model.parameters()).device
    correct = 0
    total = 0
    with torch.no_grad():
        progress = tqdm(loader, desc=desc)
        for imgs, targets in progress:
            imgs, targets = imgs.to(device), targets.to(device)
            preds = forward_pass(model, imgs).argmax(dim=1)
            correct += (preds == targets).sum().item()
            total += int(targets.shape[0])
            # Running accuracy on the progress bar instead of one printed line per batch.
            progress.set_postfix(acc=f'{100 * correct / total:.2f}%')
    return correct, total


# Top-1 accuracy (fraction in [0, 1]) per model, kept for later cells.
accuracies = {}
for name, m in [('regular resnet50', model), ('robust resnet50', model_robust)]:
    correct, total = evaluate_accuracy(m, loader, desc=name)
    accuracies[name] = correct / total if total else float('nan')
    print(f'{name} accuracy: {100 * accuracies[name]:.2f}% ({correct}/{total})')

In [None]:
# Recorded results from a full run over the 50,000-image `imagenet_val_attack` set:
# regular resnet50 accuracy: 31.79% (15895/50000)
# robust resnet50 accuracy: 69.49% (34745/50000)

    
