In [4]:
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from torchvision.models import resnet50
from sklearn.metrics import roc_auc_score
import models_vit
import cv2
import numpy as np
import torch
import timm
import matplotlib.pylab as plt

In [5]:
def reshape_transform(tensor, height=14, width=14):
    """Rearrange a token sequence into a channels-first 2-D feature map.

    The first entry along dimension 1 is dropped (presumably the ViT class
    token -- confirm against the model), and the remaining entries are laid
    out on a ``height`` x ``width`` grid.

    Args:
        tensor: Activations indexed as (batch, tokens, channels); the number
            of tokens must equal ``height * width + 1``.
        height: Number of grid rows.
        width: Number of grid columns.

    Returns:
        A tensor of shape (batch, channels, height, width), like the output
        of a convolutional layer.
    """
    batch_size, channels = tensor.size(0), tensor.size(2)
    patch_tokens = tensor[:, 1:, :]
    grid = patch_tokens.reshape(batch_size, height, width, channels)
    # Channels-last -> channels-first in a single step.
    return grid.permute(0, 3, 1, 2)

In [7]:
def run_grad_cam(model, model_type, target_layers, data_lines):
    """Write a Grad-CAM overlay image for every entry of ``data_lines``.

    Args:
        model: Network to explain; it is put into eval mode before use.
        model_type: String prefix prepended to each relative image path to
            build the output path handed to ``cv2.imwrite``.
        target_layers: Layers passed to ``GradCAM`` as ``target_layers``.
        data_lines: Iterable of text lines whose first whitespace-separated
            token is an image path relative to the module-level ``data_dir``.
            Blank lines are skipped.

    Raises:
        FileNotFoundError: If an image cannot be read from disk.
    """
    # Split on any whitespace (same as eval_image) so trailing newlines never
    # end up in the path, and ignore empty lines.
    img_paths = [fields[0] for fields in (line.split() for line in data_lines) if fields]
    if not img_paths:
        return

    # Loop-invariant: switch to eval mode once.
    model.eval()

    # Build the CAM object once and release its hooks on exit, instead of
    # registering a fresh set of hooks on the model for every image.
    with GradCAM(model=model, target_layers=target_layers,
                 reshape_transform=reshape_transform, use_cuda=1) as cam:
        for img_path in img_paths:
            bgr_img = cv2.imread(data_dir + img_path, 1)
            if bgr_img is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError(f'Could not read image: {data_dir + img_path}')
            rgb_img = bgr_img[:, :, ::-1]  # BGR -> RGB

            rgb_img = cv2.resize(rgb_img, (224, 224))
            rgb_img = np.float32(rgb_img) / 255
            input_tensor = preprocess_image(rgb_img, mean=[0.5056, 0.5056, 0.5056], std=[0.252, 0.252, 0.252])

            # targets=None leaves the choice of explained class to GradCAM
            # (per the pytorch-grad-cam docs: the highest-scoring class).
            grayscale_cam = cam(input_tensor=input_tensor, targets=None)

            # One image per call, so keep the first (only) map of the batch.
            grayscale_cam = grayscale_cam[0, :]
            # NOTE(review): the overlay is built with use_rgb=False from an
            # RGB-ordered image and then written by cv2.imwrite; this looks
            # harmless for grey-scale X-rays but should be confirmed for
            # colour inputs.
            visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=False)
            cv2.imwrite(model_type + img_path, visualization)

In [37]:
@torch.no_grad()
def eval_image(model, target_layers, data_lines, output_file_path):
    """Score every listed image and append one result record per image.

    Each line of ``data_lines`` is ``<relative image path> <label_1> ...``;
    the first ``num_class`` label tokens are parsed as floats. The image is
    loaded from the module-level ``data_dir``, resized to 224x224,
    normalised and passed through ``model``. The sigmoid of the output is
    compared with the labels via ``accuracy`` and the outcome is appended to
    ``output_file_path`` (one record per line).

    Args:
        model: Network to evaluate; called as ``output, _ = model(x)``.
        target_layers: Unused; kept so the call signature matches
            ``run_grad_cam``.
        data_lines: Iterable of annotation lines. Blank lines are skipped.
        output_file_path: Text file opened in append mode.

    Raises:
        FileNotFoundError: If an image cannot be read from disk.
    """
    # Loop-invariant: switch to eval mode once.
    model.eval()

    # Run on whatever device the model lives on; fall back to CUDA (the
    # previous hard-coded choice) for a model without parameters.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    with open(output_file_path, 'a') as output_file:
        # Iterate the argument, not the notebook-global `lines`.
        for line in data_lines:
            line_split = line.split()
            if not line_split:
                continue
            img_path = line_split[0]
            labels = [float(value) for value in line_split[1:num_class + 1]]
            # Extra outer list gives a leading batch dimension of 1.
            imageLabel = torch.tensor([labels], device=device)

            bgr_img = cv2.imread(data_dir + img_path, 1)
            if bgr_img is None:
                # cv2.imread signals failure by returning None.
                raise FileNotFoundError(f'Could not read image: {data_dir + img_path}')
            rgb_img = bgr_img[:, :, ::-1]  # BGR -> RGB

            rgb_img = cv2.resize(rgb_img, (224, 224))
            rgb_img = np.float32(rgb_img) / 255
            input_tensor = preprocess_image(rgb_img, mean=[0.5056, 0.5056, 0.5056], std=[0.252, 0.252, 0.252])

            output, _ = model(input_tensor.to(device))
            acc, res = accuracy(imageLabel, output.sigmoid(), num_class)

            # Terminate each record so successive records do not run together.
            output_file.write(f'image_path: {img_path}, acc_each_class: {res}, acc_avg: {acc}\n')
        
            
#     outputs = torch.cat(outputs, dim = 0).sigmoid().cpu().numpy()
#     targets = torch.cat(targets, dim = 0).cpu().numpy()
    
#     auc = computeAUROC(targets, outputs, num_class)
#     auc_each_class_array = np.array(auc)
#     missing_classes_index = np.where(auc_each_class_array == 0)[0]
#     # print(missing_classes_index)
#     if missing_classes_index.shape[0] > 0:
#         print('There are classes that not be predicted during testing,'
#               ' the indexes are:', missing_classes_index)

#     auc_avg = np.average(auc_each_class_array[auc_each_class_array != 0])
#     
        

In [5]:
def computeAUROC(dataGT, dataPRED, classCount):
    """Compute the ROC-AUC of each class column and print the list.

    Args:
        dataGT: 2-D array of ground-truth labels, indexed ``[:, class]``.
        dataPRED: 2-D array of predicted scores, indexed ``[:, class]``.
        classCount: Number of leading class columns to score.

    Returns:
        A list with one AUROC per class. A class whose AUROC is undefined
        is reported as ``0.`` so callers can filter it out.
    """
    outAUROC = []
    for i in range(classCount):
        try:
            score = roc_auc_score(dataGT[:, i], dataPRED[:, i])
        except ValueError:
            # Raised by roc_auc_score when the score is undefined (e.g. only
            # one class present in the ground truth). Anything else is a real
            # error and must not be silenced.
            score = 0.
        # NOTE(review): some scikit-learn releases reportedly return NaN with
        # a warning instead of raising -- confirm; map that to the same
        # sentinel either way.
        if np.isnan(score):
            score = 0.
        outAUROC.append(score)
    print(outAUROC)
    return outAUROC

In [36]:
def accuracy(dataGT, dataPRED, classCount):
    """Per-class correctness of thresholded multi-label predictions.

    Predictions are binarised at 0.5 and compared element-wise with the
    labels.

    Args:
        dataGT: Float tensor of ground-truth labels, shape (batch, classes)
            or anything broadcastable against ``dataPRED``.
        dataPRED: Float tensor of predicted probabilities.
        classCount: Number of classes per sample.

    Returns:
        ``(acc, res)`` where ``acc`` is the percentage of correct class
        decisions averaged over the samples and ``res`` is an int tensor
        holding 1 where the decision matched the label and 0 elsewhere.
    """
    pred = (dataPRED >= 0.5).float()
    # Compare against all label rows; indexing dataGT[0] matched every
    # prediction row with the first sample's labels only.
    correct = pred.eq(dataGT)
    res = correct.int()
    # Normalise by the batch size too, so the percentage stays within
    # [0, 100]; for a single sample this divisor is 1.
    num_samples = correct.shape[0] if correct.dim() > 1 else 1
    return torch.sum(correct).item() * 100 / (classCount * num_samples), res

In [10]:
# --- Configuration shared by the cells and functions of this notebook ---

# Root folder that every relative image path is appended to.
data_dir = "/data/yyang409/rgoel15/ChestX-ray14/images/"
# Despite its name this is a text file: one annotation line per image.
img_dir = "data_splits/chestxray/grad_cam_test.txt"

# Output prefixes for the Grad-CAM images, one per model variant.
out_dir = '/data/yyang409/rgoel15/medical_mae_soft_low_rank/grad-cam-images/'
base_model = f'{out_dir}base_model/'
low_rank = f'{out_dir}low_rank_model/'

# Text files that receive the per-image evaluation records.
base_auc_out = f'{base_model}base.txt'
low_auc_out = f'{low_rank}low.txt'

# Number of label columns read from each annotation line.
num_class = 14

# Load all annotation lines up front (trailing newlines are kept).
with open(img_dir, 'r') as input_file:
    lines = list(input_file)

In [38]:
# GRAD CAM and EVAL BASE MODEL
# Build the ViT-B/16 classifier and restore the fine-tuned baseline weights.
model = models_vit.vit_base_patch16(
    img_size=224,
    num_classes=14,
    drop_rate=0,
    drop_path_rate=0,
    global_pool=True,
)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load('/data/yyang409/rgoel15/medical_mae_original_models/finetuned_models/vit-b_CXR_0.5M_mae.pth')
model.load_state_dict(checkpoint['model'], strict=True)
model.to(torch.device('cuda'))

# Grad-CAM hooks onto the first LayerNorm of the last transformer block.
target_layers = [model.blocks[-1].norm1]

# Uncomment to also write the Grad-CAM overlays for the base model:
# run_grad_cam(model, base_model, target_layers, lines)
eval_image(model, target_layers, lines, base_auc_out)

In [8]:
# GRAD CAM and EVAL LOW RANK MODEL
# Build the same ViT-B/16 architecture and restore the low-rank checkpoint.
model = models_vit.__dict__['vit_base_patch16'](img_size=224,
                                            num_classes=14,
                                            drop_rate=0,
                                            drop_path_rate=0,
                                            global_pool=True)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load('/data/yyang409/rgoel15/medical_mae_low_rank/vit_b/checkpoint-best_auc_38_0.00125_20.pth')
model.load_state_dict(checkpoint['model'], strict=True)
# Move the model to the GPU explicitly, as the base-model cell does, so that
# eval_image does not depend on run_grad_cam having been executed first.
model.to(torch.device('cuda'))
# Grad-CAM hooks onto the first LayerNorm of the last transformer block.
target_layers = [model.blocks[-1].norm1]

run_grad_cam(model, low_rank, target_layers, lines)
eval_image(model, target_layers, lines, low_auc_out)

[0.7868745938921378, 0.9556737588652482, 0.8418749999999999, 0.6768, 0.6955555555555555, 0.5326460481099656, 0.0, 0.9044444444444445, 0.8179347826086957, 0.7835051546391751, 0.9452631578947368, 0.9243986254295532, 0.7801418439716312, 1.0]
There are classes that not be predicted during testing, the indexes are: [6]
here
