# Creating and extracting information from saliency maps
* organizing the code
* removing the multiclass case
* extracting quantitative information from the maps

In [86]:
from prep_test_data import *
from pathlib import Path
import json
import torch
from matplotlib import pyplot as plt
import shutil

-------------------------------------------------------

## Main methods

In [120]:
def create_maps_folders(main_folder, beat, labels, delete_prior):
    """Create the per-label output directories for one heartbeat's maps.

    Resulting layout: ``<main_folder>/label_<beat>_beat/<label>/`` for every
    entry of ``labels``.

    Args:
        main_folder: root directory (str or Path) for the attribution maps.
        beat: heartbeat identifier interpolated into the sub-folder name.
        labels: iterable of label names; one sub-directory is made per label.
        delete_prior: if True, ``main_folder`` and everything under it is
            removed first, so stale maps from earlier runs disappear.

    Returns:
        Path of the ``label_<beat>_beat`` directory. It is created and
        returned even when ``labels`` is empty.
    """
    root = Path(main_folder)
    if delete_prior and root.exists():
        shutil.rmtree(root)
    # Built before the loop so it is always bound (and created), even for an
    # empty `labels`.
    folder = root / f"label_{beat}_beat"
    folder.mkdir(parents=True, exist_ok=True)
    for label in labels:
        (folder / label).mkdir(parents=True, exist_ok=True)
    return folder

In [107]:
def deprocess(image):
    """Undo per-channel normalisation of a 3-channel image tensor and return a PIL image.

    The constants are the ImageNet mean/std; this assumes the data pipeline
    normalised with exactly these values (TODO: confirm against
    DataPreparation). ``transforms`` is presumably torchvision's, brought in
    by the star import at the top of the notebook.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    unnormalize = transforms.Compose([
        # Normalize computes (x - m) / s, so s = 1/std multiplies by std ...
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1 / s for s in std]),
        # ... and m = -mean adds the mean back.
        transforms.Normalize(mean=[-m for m in mean], std=[1.0, 1.0, 1.0]),
    ])
    # Detach: callers pass slices of a tensor that requires grad.
    restored = unnormalize(image.detach())
    # ToPILImage scales floats by 255 and casts to uint8; clamp first so
    # values slightly outside [0, 1] saturate instead of wrapping around.
    return transforms.ToPILImage()(restored.clamp(0, 1))

def show_img(PIL_IMG):
    """Draw an image (PIL or anything ``np.asarray`` accepts) on the current matplotlib axes."""
    pixels = np.asarray(PIL_IMG)
    plt.imshow(pixels)

In [113]:
def saliency_maps(model, data, main_folder, n_batches=None, device=0):
    """Save a vanilla-gradient saliency overlay for each test image.

    For every image, this takes the gradient of its own top-1 class score
    with respect to the input. The maximum of ``|gradient|`` over the
    channel dimension is drawn as a heat map over the de-normalised image.
    The result is written to ``<main_folder>/<true label>/<batch>_<index>.png``.

    Args:
        model: classifier returning per-class scores, presumably shaped
            ``(batch, n_classes)`` (TODO confirm); expected to be in eval mode.
        data: mapping whose ``"test"`` DataLoader's dataset exposes ``classes``.
        main_folder: directory (str or Path) that already contains one
            sub-folder per class name, e.g. from ``create_maps_folders``.
        n_batches: stop after this many batches; ``None``/``0`` processes all.
        device: device the inputs are moved to (default ``0``, as before).
    """
    classes = data["test"].dataset.classes
    out_dir = Path(main_folder)
    for batch_idx, (inputs, labels) in enumerate(data["test"]):
        x = inputs.to(device).requires_grad_()
        scores = model(x)
        # Select one score per sample: its own top-1 logit. Indexing with
        # scores[:, idx] would build a (batch, batch) matrix, so every image's
        # gradient would also include the classes predicted for the others.
        top_scores = scores.gather(1, scores.argmax(dim=1, keepdim=True))
        # autograd.grad yields the input gradient without accumulating .grad
        # on the model parameters batch after batch.
        (grad,) = torch.autograd.grad(top_scores.sum(), x)
        saliency = grad.abs().max(dim=1).values.cpu()
        images = x.detach().cpu()

        for index, (sal, img, label) in enumerate(zip(saliency, images, labels)):
            name = classes[int(label)]
            _save_overlay(sal, img, out_dir / f"{name}/{batch_idx}_{index}.png")

        if n_batches and batch_idx + 1 == n_batches:
            break


def _save_overlay(saliency, image, path):
    """Write one saliency heat map blended over its de-normalised image to ``path``."""
    fig = plt.figure()
    plt.imshow(saliency.numpy(), cmap=plt.cm.hot, alpha=.7)
    plt.imshow(deprocess(image), alpha=.4)
    plt.axis('off')
    fig.savefig(str(path))
    # Close explicitly: one figure per image would otherwise pile up in memory.
    plt.close(fig)

In [114]:
def create_saliency_maps_one_heartbeat(data_path, models_main_path, model_name, beat, saliency_maps_path, nr_egs):
    """Load one trained model and write saliency maps for its test split.

    Args:
        data_path: dataset folder handed to ``DataPreparation`` (str or Path).
        models_main_path: root (str or Path) holding
            ``label_<beat>/<model_name>.pth``.
        model_name: file stem of the saved model.
        beat: heartbeat identifier selecting the model sub-folder.
        saliency_maps_path: output directory with one sub-folder per label.
        nr_egs: number of test *batches* (not examples) to process; it is
            passed on as ``n_batches``, so ``None``/``0`` means all.
    """
    data_prep = DataPreparation(str(data_path))
    # NOTE(review): (16, False, 4) presumably are batch size, shuffle and
    # num_workers — confirm against DataPreparation.create_dataloaders.
    data, _ = data_prep.create_dataloaders(16, False, 4)
    model_path = Path(models_main_path) / f"label_{beat}" / f"{model_name}.pth"
    # torch.load unpickles a whole model object: only load trusted files.
    # NOTE: PyTorch >= 2.6 defaults to weights_only=True, which rejects full
    # model pickles; such versions need weights_only=False here.
    model = torch.load(model_path)
    model.eval()
    saliency_maps(model, data, saliency_maps_path, nr_egs)

----------------------------------------------------------------

## Configuration and run

In [115]:
# Project-wide settings; 'labels_bin' supplies the label folder names below.
# The with-block closes the file itself, so no explicit close() is needed.
with open("../config.json", encoding="utf-8") as f:
    config_data = json.load(f)

In [118]:
# Heartbeat whose test figures and model are used; also names the output folder.
HEARTBEAT = "initial"
# Root directory for the generated saliency maps.
MAP_DIR = "../attribution_maps/saliency_maps"
# Wipe MAP_DIR before writing, so maps from earlier runs do not linger.
DELETE_PRIOR_DIR = True
TEST_DATA_PATH = Path(f'../data/figures_{HEARTBEAT}/test')
MODELS_PATH = Path("../models")
# File stem of the saved model under MODELS_PATH / f"label_{HEARTBEAT}".
MODEL_NAME = "resnet50_d_19_t_16_46"
# Number of test batches to turn into maps (None or 0 = all).
NR_BATCHES = 1

In [119]:
# Prepare one output folder per binary label, then render the maps into them.
binary_labels = config_data['labels_bin']
saliency_folder = create_maps_folders(
    main_folder=MAP_DIR,
    beat=HEARTBEAT,
    labels=binary_labels,
    delete_prior=DELETE_PRIOR_DIR,
)
create_saliency_maps_one_heartbeat(
    data_path=TEST_DATA_PATH,
    models_main_path=MODELS_PATH,
    model_name=MODEL_NAME,
    beat=HEARTBEAT,
    saliency_maps_path=saliency_folder,
    nr_egs=NR_BATCHES,
)