# Creating and extracting information from Grad-CAM maps
* Organize the code
* Remove the multiclass case
* Extract quantitative information from the maps

In [1]:
from prep_test_data import *
from pathlib import Path
import json
import torch
import torchray
from matplotlib import pyplot as plt
from torchray.attribution.common import Probe, get_module
from torchray.attribution.grad_cam import gradient_to_grad_cam_saliency
import shutil
import numpy as np
import cv2

-------------------------------------------------------

## Main methods

In [2]:
def create_maps_folders(main_folder, beat, labels, delete_prior):
    """Create one output sub-folder per label for a given heartbeat.

    The resulting layout is ``<main_folder>/label_<beat>_beat/<label>/``.

    Args:
        main_folder: root directory for the maps (str or Path).
        beat: heartbeat identifier embedded in the sub-folder name.
        labels: iterable of label names; one directory is created per label.
        delete_prior: if True, ``main_folder`` is removed recursively first.

    Returns:
        Path of ``<main_folder>/label_<beat>_beat``. The folder is created
        even when ``labels`` is empty (previously that case raised
        UnboundLocalError because the path was only assigned in the loop).
    """
    root = Path(main_folder)
    if delete_prior and root.exists():
        # Wipes everything under main_folder, not only this beat's folder.
        shutil.rmtree(root)
    folder = root / f"label_{beat}_beat"
    folder.mkdir(parents=True, exist_ok=True)
    for label in labels:
        (folder / label).mkdir(parents=True, exist_ok=True)
    return folder

In [3]:
def deprocess(image):
    """Convert a normalised 3-channel image tensor back into a PIL image.

    Undoes the standard ImageNet normalisation: the first ``Normalize``
    divides by 1/std (4.3668 ~= 1/0.229, 4.4643 ~= 1/0.224,
    4.4444 ~= 1/0.225), i.e. multiplies by std; the second adds the means
    back (``(x - (-mean)) / 1``). Presumably this matches the transform used
    by the data loader -- confirm against DataPreparation.

    Args:
        image: channel-first float tensor with 3 channels; pass a detached
            CPU tensor, since ToPILImage converts it through numpy.

    Returns:
        The PIL image produced by ``transforms.ToPILImage``.
    """
    denormalize = transforms.Compose([
        transforms.Normalize(mean=[0, 0, 0], std=[4.3668, 4.4643, 4.4444]),
        transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1, 1, 1]),
    ])
    # ToPILImage converts floats with mul(255) followed by a uint8 cast that
    # does not saturate, so values outside [0, 1] would wrap around and show
    # up as speckles. Clamp first.
    restored = denormalize(image).clamp(0, 1)
    return transforms.ToPILImage()(restored)

def show_img(PIL_IMG):
    """Draw an image (PIL or array-like) on the current matplotlib axes."""
    pixels = np.asarray(PIL_IMG)
    plt.imshow(pixels)

In [11]:
def grad_cam_maps(model, data, main_folder, n_batches=None):
    """Compute Grad-CAM saliency on ``model.layer4`` for the test split.

    For every image, the gradient comes from the score of the class the model
    predicts for *that* image (its own arg-max logit).

    Args:
        model: network exposing a ``layer4`` sub-module; expected to be in
            eval mode already. Inputs are moved to the device of its
            parameters.
        data: mapping whose ``"test"`` entry yields ``(inputs, labels)``
            batches and whose ``.dataset`` has a ``classes`` list.
        main_folder: Path of the output folder; only used by the (currently
            disabled) saving code below.
        n_batches: stop after this many batches; ``None`` or 0 processes the
            whole loader.
    """
    # Follow the model's device instead of hard-coding GPU 0.
    device = next(model.parameters()).device
    saliency_layer = get_module(model, model.layer4)
    # One probe for the whole run: its hook captures layer4's output on every
    # forward pass. Creating a new probe per batch would stack hooks.
    probe = Probe(saliency_layer, target='output')
    try:
        for batch_idx, (inputs, labels) in enumerate(data['test']):
            # Only activation gradients are used; keep parameter .grad from
            # accumulating across batches.
            model.zero_grad()
            x = inputs.to(device)
            x.requires_grad_()
            y = model(x)
            # Select y[i, argmax_i] for each sample i. Indexing y[:, idx] with
            # a (B,) index tensor would give a (B, B) matrix and back-propagate
            # every sample's predicted class through every other sample.
            predicted = y.argmax(dim=1, keepdim=True)
            z = y.gather(1, predicted)
            z.backward(torch.ones_like(z))
            saliency = gradient_to_grad_cam_saliency(probe.data[0])

            for index in range(len(saliency)):
                heatmap = np.float32(saliency[index, 0].cpu().detach())
                img = np.array(deprocess(x[index].cpu().detach()))
                print(img.shape)
                # TODO: overlay + save is disabled while debugging. Draft:
                #   heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
                #   plt.imshow(np.uint8(255 * heatmap), alpha=.7)
                #   plt.imshow(img, alpha=.8)
                #   plt.axis('off')
                #   label = data["test"].dataset.classes[labels[index]]
                #   plt.savefig(str(main_folder / f"{label}/{batch_idx}_{index}.png"))
                #   plt.close()

            if n_batches and batch_idx + 1 == n_batches:
                break
    finally:
        # Detach the forward hook so the model is left unmodified.
        probe.remove()

In [16]:
def create_grad_cam_maps_one_heartbeat(data_path, models_main_path, model_name, beat, saliency_maps_path, nr_egs):
    """Load the model trained for one heartbeat and run Grad-CAM on its test set.

    Args:
        data_path: root of the test data, handed to ``DataPreparation``.
        models_main_path: directory (str or Path) holding a ``label_<beat>``
            sub-folder that contains ``<model_name>.pth``.
        model_name: checkpoint file name without the ``.pth`` extension.
        beat: heartbeat identifier selecting the model sub-folder.
        saliency_maps_path: output folder forwarded to ``grad_cam_maps``.
        nr_egs: number of test *batches* (not examples) to process; it is
            forwarded as ``grad_cam_maps``'s ``n_batches`` (None -> all).
    """
    import inspect

    data_prep = DataPreparation(str(data_path))
    # NOTE(review): positional args (16, False, 4) follow DataPreparation's
    # signature; presumably batch size / shuffle / workers -- confirm.
    data, _ = data_prep.create_dataloaders(16, False, 4)

    model_path = Path(models_main_path) / f"label_{beat}" / f"{model_name}.pth"
    load_kwargs = {"map_location": torch.device(0)}
    # The checkpoint is a whole pickled nn.Module, not a state_dict. PyTorch
    # >= 2.6 defaults to weights_only=True and refuses to unpickle it, so opt
    # out explicitly where the argument exists. Unpickling can execute
    # arbitrary code: only load checkpoints you trust.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_path, **load_kwargs)
    model.eval()
    grad_cam_maps(model, data, saliency_maps_path, nr_egs)

----------------------------------------------------------------

## Configuration and execution

In [17]:
# Project-wide settings; "labels_bin" is used below to create one output
# folder per label. The `with` block closes the file, so no explicit close().
with open("../config.json", encoding="utf-8") as f:
    config_data = json.load(f)

In [18]:
# Heartbeat to analyse: selects models/label_<HEARTBEAT>/ and the
# label_<HEARTBEAT>_beat output sub-folder.
HEARTBEAT = "initial"
# Root output directory for the Grad-CAM maps.
MAP_DIR = "../attribution_maps/grad_cam_maps"
# If True, MAP_DIR is deleted recursively before new folders are created.
DELETE_PRIOR_DIR = True
TEST_DATA_PATH = Path("/mnt/Media/bernardo/DSL_test_data")
MODELS_PATH = Path("../models")
# Checkpoint file name without the ".pth" extension.
MODEL_NAME = "resnet50_d_22_t_19_13"
# Number of test batches to process; None processes the whole test set.
NR_BATCHES = 1

In [19]:
# Prepare label_<HEARTBEAT>_beat/<label>/ output folders, then run Grad-CAM
# on the first NR_BATCHES test batches.
gradcam_folder = create_maps_folders(
    main_folder=MAP_DIR,
    beat=HEARTBEAT,
    labels=config_data['labels_bin'],
    delete_prior=DELETE_PRIOR_DIR,
)
create_grad_cam_maps_one_heartbeat(
    data_path=TEST_DATA_PATH,
    models_main_path=MODELS_PATH,
    model_name=MODEL_NAME,
    beat=HEARTBEAT,
    saliency_maps_path=gradcam_folder,
    nr_egs=NR_BATCHES,
)

(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
(224, 224, 3)
