# In [None]:
import os

import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as TF
import torchattacks
import torchvision.transforms.functional as VF
from torchvision.models import ResNet18_Weights, resnet18

from utils.utils import ModelWithNormalization, freeze

# setup model: torchvision ResNet-18 with its default pretrained weights, put
# in eval mode and frozen (presumably `freeze` disables requires_grad on the
# parameters — confirm in utils.utils).
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights).eval()
freeze(model)
# Wrap with the standard ImageNet per-channel mean/std; presumably the wrapper
# normalizes [0, 1] inputs before the forward pass — confirm in utils.utils.
model = ModelWithNormalization(
    model,
    [0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# ImageNet labels: line i of the file is assumed to name class i, since the
# list is indexed directly by the model's predicted class index below.
label_path = os.path.join('data', 'imagenet_labels.txt')
with open(label_path, encoding='utf-8') as f:
    # Strip only the line terminator that readlines() would keep on every
    # entry; one entry per line is preserved, so indices are unchanged.
    labels = [line.rstrip('\n') for line in f]

# read img (OpenCV returns BGR channel order; convert to RGB)
p = os.path.join('data', 'plate.JPG')
clean_img = cv2.imread(p)
if clean_img is None:
    # cv2.imread signals failure by returning None rather than raising; fail
    # here with a clear message instead of an opaque cvtColor error later.
    raise FileNotFoundError(f'could not read image: {p}')
clean_img = cv2.cvtColor(clean_img, cv2.COLOR_BGR2RGB)

# preprocess img for torch model: HWC -> CHW, resize to a square
# img_size x img_size (aspect ratio is not preserved), scale to [0, 1],
# then add a leading batch dimension -> shape (1, 3, img_size, img_size).
img_size = 400
clean_img = torch.from_numpy(clean_img.transpose(2, 0, 1))
clean_img = VF.resize(clean_img, [img_size, img_size], antialias=True)
clean_img = (clean_img / 255).unsqueeze(0)

# prediction for clean img. This is inference only, so don't record autograd
# history regardless of which parameters require grad.
with torch.no_grad():
    clean_probs = TF.softmax(model(clean_img), dim=1)
# Rank the class probabilities of the single batch element, highest first.
sorted_clean_probs, sorted_clean_indices = clean_probs[0].sort(descending=True)
sorted_clean_probs *= 100  # express as percentages
predicted_index_for_clean_img = sorted_clean_indices[0]  # top-1 class (0-dim tensor)
predicted_label_for_clean_img = labels[predicted_index_for_clean_img]

# attack: torchattacks PGD with eps=8/255 and 100 steps (other settings left at
# the library defaults). The clean image's top-1 class is passed as the label,
# so the attack works against the model's own original prediction.
atk = torchattacks.PGD(model, eps=8/255, steps=100)
adv_img = atk(clean_img, predicted_index_for_clean_img.reshape(1))

# prediction for adv img. This is inference only, so don't record autograd
# history regardless of which parameters require grad.
with torch.no_grad():
    adv_probs = TF.softmax(model(adv_img), dim=1)
# Rank the class probabilities of the single batch element, highest first.
sorted_adv_probs, sorted_adv_indices = adv_probs[0].sort(descending=True)
sorted_adv_probs *= 100  # express as percentages
predicted_index_for_adv_img = sorted_adv_indices[0]  # top-1 class (0-dim tensor)
predicted_label_for_adv_img = labels[predicted_index_for_adv_img]

# show imgs, each titled with its top-1 prediction and confidence so the
# effect of the attack is visible. The sorted probs are already in percent.
# Labels are stripped because lines read via readlines() keep a trailing '\n'.
plt.axis('off')
plt.title(
    f'clean: {predicted_label_for_clean_img.strip()} '
    f'({float(sorted_clean_probs[0]):.1f}%)'
)
plt.imshow(clean_img[0].permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.show()
plt.axis('off')
plt.title(
    f'adversarial: {predicted_label_for_adv_img.strip()} '
    f'({float(sorted_adv_probs[0]):.1f}%)'
)
plt.imshow(adv_img[0].permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.show()