In [None]:
#from torchvision.models import resnet18
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import ScoreCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from scipy.special import softmax
from torchvision.io import read_image
from torchvision.transforms.functional import normalize, resize, to_pil_image

from adapt_methods import SAR, T3A, SHOT
from backbone import ERM


def reshape_transform(tensor, height=14, width=14):
    """Turn ViT/DeiT token activations into a CNN-style feature map.

    Needed for ViT (+ DeiT) based models, whose layers emit a token sequence
    rather than a spatial map. The first token (presumably the class token) is
    dropped, and the remaining tokens are laid out row-major on a
    ``height x width`` grid.

    Args:
        tensor: Activations indexed as ``(batch, tokens, channels)``; must hold
            ``1 + height * width`` tokens.
            NOTE(review): distilled DeiT variants carry two leading tokens —
            confirm the backbone in use has only one.
        height: Number of grid rows (default 14).
        width: Number of grid columns (default 14).

    Returns:
        Tensor of shape ``(batch, channels, height, width)``.
    """
    batch, channels = tensor.size(0), tensor.size(2)
    patch_tokens = tensor[:, 1:, :]
    grid = patch_tokens.reshape(batch, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W): channels-first layout for the CAM code.
    return grid.permute(0, 3, 1, 2)


def visualize_scorecam(model, model_type, impath, target_class=1):
    """Compute, print and plot a Score-CAM explanation for a single image.

    The image is converted to RGB and resized to 224x224. It is normalised with
    mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225], then explained
    with Score-CAM for ``target_class``. The softmax of the model's outputs is
    printed, and the heat map overlay is shown in its own matplotlib figure, so
    several calls in one cell no longer draw over each other.

    Args:
        model: Network to explain. Its backbone must be reachable as
            ``model.network[0].network`` (``layer4`` for ResNet-style
            backbones, ``blocks`` for DeiT).
        model_type: ``"deit"`` selects the transformer path. Any other value is
            treated as a ResNet-style backbone.
        impath: Path to the image file.
        target_class: Output index whose activation map is visualised
            (default 1, the previously hard-coded class).
    """
    backbone = model.network[0].network
    # Only touch the attribute that exists for the chosen architecture: a DeiT
    # backbone has no ``layer4``, so it must not be looked up unconditionally.
    if model_type == "deit":
        target_layers = [backbone.blocks[-1].norm1]
        cam_kwargs = {"reshape_transform": reshape_transform}
    else:
        target_layers = [backbone.layer4[-1]]
        cam_kwargs = {}

    # convert("RGB") guarantees 3 channels (grayscale/RGBA files would break
    # preprocessing and the overlay); the with-block closes the file handle.
    with Image.open(impath) as pil_image:
        image = np.array(pil_image.convert("RGB").resize((224, 224)))
    rgb_img = np.float32(image) / 255
    input_tensor = preprocess_image(rgb_img,
                                    mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
    # Keep input and model on the same device (the notebook moves the model to
    # CUDA). Harmless if the CAM library also moves the input itself.
    input_tensor = input_tensor.to(next(model.parameters()).device)

    # Build the CAM object exactly once so hooks are registered only once.
    cam = ScoreCAM(model=model, target_layers=target_layers, **cam_kwargs)

    targets = [ClassifierOutputTarget(target_class)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets,
                        aug_smooth=False, eigen_smooth=False)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    model_outputs = cam.outputs

    # Per-sample class probabilities: normalise over the class axis only.
    print(softmax(model_outputs.cpu().detach().numpy(), axis=-1))

    # A fresh figure per call; plt.show() renders it before the next call.
    fig, ax = plt.subplots()
    ax.imshow(visualization)
    ax.axis('off')
    ax.set_title(f"Score-CAM ({model_type}, class {target_class})")
    plt.show()

In [None]:
# path to dataset
data_root = '/content/ISIC2019_train'
# path to image to use for CAM visualization
impath = data_root + '/hair/ben/ISIC_0028330.jpg'

# Use the GPU when present; map_location lets the checkpoint load on either device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load base model
model = ERM('resnet50', 1e-4).to(device)
model.load_state_dict(torch.load('./models/resnet50hair.pth', map_location=device))
model.eval()

# select TTA method
adapt_methods = {'SAR': SAR, 'T3A': T3A, 'SHOT': SHOT}
adapt_method = 'SAR'
# Adapt a deep copy so `model` keeps the base weights for the comparison below.
# NOTE(review): TTA wrappers may update the wrapped network in place; without the
# copy, the "base model" CAM could silently show adapted weights.
adapt_model = adapt_methods[adapt_method](copy.deepcopy(model))

# run TTA on entire test set
# NOTE(review): `evaluate`, `dataloaders` and `test_domain` are not defined in this
# notebook as shown — define/import them in an earlier cell, otherwise this line
# raises NameError on a fresh kernel.
acc, pr, rc, f1 = evaluate(adapt_model, dataloaders[test_domain], True)
print(f'{adapt_method} Acc.: {acc.item():.3f} + Rc: {rc:.3f} + Pr.: {pr:.3f} + f1: {f1:.3f}')

# visualize score-CAM for base model
print('Base model:')
visualize_scorecam(model, "resnet50", impath)

# visualize score-CAM for adapted model
# ! pass the model of the TTA object
# NOTE(review): if a method keeps its adaptation in the wrapper rather than in
# `.model` (e.g. a replaced classifier), this shows only the wrapped network —
# confirm against adapt_methods.
print(f'{adapt_method}-adapted model:')
visualize_scorecam(adapt_model.model, "resnet50", impath)