In [None]:
import os
import sys
import glob

import numpy as np
import pandas as pd
from PIL import Image
import cv2
import torch

# GradCAM
from pytorch_grad_cam import (
    GradCAM,
    ScoreCAM,
    GradCAMPlusPlus,
    AblationCAM,
    XGradCAM,
    EigenCAM,
    EigenGradCAM,
    LayerCAM,
    FullGrad,
)
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.ablation_layer import AblationLayerVit

# Make the sibling 'experiments' directory importable. Note that the working
# directory itself is NOT changed; only the module search path is extended
# (presumably this is where the `src` package imported below lives).
current_dir = os.getcwd()
folder2_path = os.path.abspath(os.path.join(current_dir, '..', 'experiments'))
# Re-running this cell must not keep growing sys.path with duplicate entries.
if folder2_path not in sys.path:
    sys.path.append(folder2_path)

# Module
import src.classification as lc

In [None]:
def get_all_layers(model):
    """Return every sub-module of ``model`` as a flat list.

    The list follows the iteration order of ``model.named_modules()``. The
    entry whose name is the empty string (the root module itself) is left out.
    """
    return [
        submodule
        for layer_name, submodule in model.named_modules()
        if layer_name != ''
    ]

def reshape_transform(tensor, height=16, width=16):
    """Rearrange a token sequence into a channels-first 2-D grid.

    The first token along dimension 1 is dropped (presumably the class
    token - confirm against the model), the remaining ``height * width``
    tokens are laid out on a ``height`` x ``width`` grid, and the last
    dimension is moved to position 1.

    Args:
        tensor: 3-D tensor indexed as (batch, tokens, channels); dimension 1
            must hold ``1 + height * width`` entries.
        height: number of grid rows.
        width: number of grid columns.

    Returns:
        A tensor of shape (batch, channels, height, width).
    """
    batch, channels = tensor.size(0), tensor.size(2)
    patch_tokens = tensor[:, 1:, :]
    grid = patch_tokens.reshape(batch, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W) in one step.
    return grid.permute(0, 3, 1, 2)

def cropImg(df, imagePath):
    """Open an image and crop it to its annotated bounding box.

    Args:
        df: annotation table with an ``img_fName`` column and the bounding-box
            columns ``bbx_xtl``, ``bbx_ytl``, ``bbx_xbr``, ``bbx_ybr``.
        imagePath: path of the image; also the value looked up in
            ``df.img_fName``. When several rows match, the first one is used.

    Returns:
        The cropped image returned by ``Image.crop``.

    Raises:
        IndexError: if ``df`` has no row whose ``img_fName`` equals
            ``imagePath``.
    """
    bbox_columns = ['bbx_xtl', 'bbx_ytl', 'bbx_xbr', 'bbx_ybr']

    # The context manager closes the source image (and its file handle) on
    # every exit path, including the lookup failure below.
    with Image.open(imagePath) as image:
        matches = df[df.img_fName == imagePath]
        if matches.empty:
            # Same exception type as the bare `.iloc[0]` would raise, but with
            # a message that names the offending image.
            raise IndexError(f"no annotation row with img_fName == {imagePath!r}")

        # Extract bbox
        bbx_xtl, bbx_ytl, bbx_xbr, bbx_ybr = matches.iloc[0][bbox_columns]

        # Crop the image
        return image.crop((bbx_xtl, bbx_ytl, bbx_xbr, bbx_ybr))

In [3]:
class Args:
    """Plain settings holder for the CAM run; values are class-level defaults."""

    # Key into the `methods` registry selecting the CAM variant.
    method = 'gradcam'
    # Paths of the images to process.
    image_list = []
    # Smoothing switches forwarded to the CAM call.
    eigen_smooth = False
    aug_smooth = False
    # True when torch reports an available CUDA device.
    use_cuda = torch.cuda.is_available()


args = Args()

# Registry of the supported CAM variants: method name -> pytorch_grad_cam class.
methods = dict((
    ("gradcam", GradCAM),
    ("scorecam", ScoreCAM),
    ("gradcam++", GradCAMPlusPlus),
    ("ablationcam", AblationCAM),
    ("xgradcam", XGradCAM),
    ("eigencam", EigenCAM),
    ("eigengradcam", EigenGradCAM),
    ("layercam", LayerCAM),
    ("fullgrad", FullGrad),
))

In [None]:
# Annotation table, sorted by image file name and re-indexed from 0.
anno_path = '../data_round_2/mosAlert_new_annotation_2/test_annotation_2.csv'
df = pd.read_csv(anno_path).sort_values(by='img_fName').reset_index(drop=True)

checkpoint_path = './modelCheckpoint/epoch=6-val_loss=0.5844640731811523-val_f1_score=0.9127286076545715-val_multiclass_accuracy=0.9220854043960571.ckpt'
# Map the checkpoint onto the GPU only when one is available; a hard-coded
# 'cuda' target makes loading fail on CPU-only machines.
device = torch.device('cuda' if args.use_cuda else 'cpu')
modelCLIP = lc.MosquitoClassifier.load_from_checkpoint(checkpoint_path, map_location=device)
# NOTE(review): every sub-module is used as a CAM target layer. Confirm that
# all of them emit token tensors compatible with `reshape_transform`; a CAM is
# usually computed on one late transformer block only.
target_layers = get_all_layers(modelCLIP)

In [None]:
# Choosing method
args.image_list = df.img_fName.tolist()
args.method = 'gradcam'

# Fail fast on an unknown method name. ValueError is still an Exception
# subclass, so existing `except Exception` handlers keep working.
if args.method not in methods:
    raise ValueError(f"method should be one of {list(methods.keys())}")

In [None]:
# Build the CAM object once. It does not depend on the image, and constructing
# it registers hooks on every target layer, so rebuilding it per image is waste.
cam_kwargs = dict(model=modelCLIP, target_layers=target_layers, reshape_transform=reshape_transform)
if args.method == "ablationcam":
    # AblationCAM additionally needs an ablation layer suited to ViT models.
    cam_kwargs["ablation_layer"] = AblationLayerVit()
cam = methods[args.method](**cam_kwargs)
cam.batch_size = 32

# Root folder of the results; one sub-folder per class label is created below.
outputPath = 'CLIPgradcamResult'

for image_path in args.image_list:

    # File name without directory and without its final extension. Splitting
    # on the first '.' would truncate names such as 'img.v2.jpg' and could make
    # different images overwrite each other's output.
    imgName = os.path.splitext(os.path.basename(image_path))[0]
    label = df[df.img_fName == image_path].class_label.values[0]

    # Read the image: crop to the annotated box, force 3 channels, resize to
    # the 224x224 network input and scale to [0, 1].
    newImg = cropImg(df, image_path).convert('RGB')
    rgb_img = np.array(newImg)
    rgb_img = cv2.resize(rgb_img, (224, 224))
    rgb_img = np.float32(rgb_img) / 255
    # NOTE(review): mean/std of 0.5 must match the preprocessing the model was
    # trained with - confirm against the training pipeline.
    input_tensor = preprocess_image(rgb_img, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # CAM
    # targets=None lets the library pick the target category itself.
    targets = None
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets, eigen_smooth=args.eigen_smooth, aug_smooth=args.aug_smooth)
    # Only one image was passed in, so keep the first (only) map of the batch.
    grayscale_cam = grayscale_cam[0, :]

    # Visualize CAM
    # `rgb_img` is in RGB order, so ask for an RGB heat map as well; otherwise
    # a BGR heat map is blended onto an RGB photo and the channels disagree.
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
    # cv2.imwrite expects BGR channel order.
    cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)

    # Output
    outputClass = str(label)
    outputImgName = f'{imgName}_{args.method}_cam.jpg'
    full_output_path = os.path.join(outputPath, outputClass)

    os.makedirs(full_output_path, exist_ok=True)
    full_output_path = os.path.join(full_output_path, outputImgName)

    # cv2.imwrite reports failure through its return value instead of raising.
    if not cv2.imwrite(full_output_path, cam_image):
        raise OSError(f"could not write CAM image to {full_output_path}")
    print(imgName)