# Grad-CAM

### 参数说明

1. sample_dir： 你自己的样本目录
2. model_root：你自己的模型目录。注意这里不需要精确到viz目录；所有的Grad-CAM图会生成到此目录下的Grad-CAM文件夹中。
3. target_layer：你想要可视化的层的名称。如果不知道具体的层名称，可以先运行一次（取消代码中打印 `named_modules` 的注释即可列出所有层名），再输入修改。

In [None]:
import torch
import os
import random
import numpy as np
import monai
from glob import glob
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib
matplotlib.use('Agg')
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from onekey_algo import get_param_in_cwd
from onekey_algo.custom.components.comp2 import target_layer_mapping
from onekey_algo.datasets.image_loader import default_loader
from onekey_algo.custom.components.comp2 import show_cam_on_image
from onekey_algo.custom.components.comp2 import extract, init_from_model, init_from_onekey
from onekey_algo.utils.MultiProcess import MultiProcess


def viz_sample(samples, thread_id):
    """Render a Grad-CAM figure for each sample and save it under ``<model_root>/Grad-CAM``.

    The model, input transformer and device are loaded with ``init_from_onekey``
    from ``<model_root>/viz``. ``model_root`` and ``target_layer`` are read from
    module-level globals, which the driver loop of this script sets before each
    call.

    Each figure contains one grayscale panel per channel of the loaded image,
    followed by a Grad-CAM heat-map panel with a colour bar. A sample whose
    output file already exists is skipped, so an interrupted run can be resumed.

    :param samples: iterable of sample file paths accepted by ``default_loader``.
    :param thread_id: unused; presumably kept so the signature matches a
        ``MultiProcess`` worker (TODO confirm).
    """
    model, transformer, device = init_from_onekey(os.path.join(model_root, 'viz'))
    # To discover valid names for ``target_layer``, uncomment:
    # for n, m in model.named_modules():
    #     print('Feature name:', n, "|| Module:", m)
    gradcam = monai.visualize.GradCAM(nn_module=model, target_layers=target_layer)
    viz_dir = os.path.join(model_root, 'Grad-CAM')
    os.makedirs(viz_dir, exist_ok=True)

    for sample in samples:
        # Output name = sample basename, with a trailing ".npy" extension swapped
        # for ".png" (any other extension is kept as-is).
        stem, ext = os.path.splitext(os.path.basename(sample))
        out_path = os.path.join(viz_dir, stem + ('.png' if ext == '.npy' else ext))
        # Test the path we actually write to. Checking the raw basename never
        # matched for ".npy" inputs, so those were re-rendered on every run.
        if os.path.exists(out_path):
            continue

        img = default_loader(sample)
        model_input = transformer(img).unsqueeze(0).to(device)  # add batch axis
        res_cam = gradcam(x=model_input, class_idx=None)

        # Channels-last array for plotting; a 2-D image gets a singleton channel
        # axis, otherwise shape[-1] would be the width and one panel per column
        # would be drawn.
        img_arr = np.asarray(img)
        if img_arr.ndim == 2:
            img_arr = img_arr[..., np.newaxis]
        num_channels = img_arr.shape[-1]

        fig, axes = plt.subplots(1, num_channels + 1, figsize=(4 * (num_channels + 1), 4), facecolor='white')
        try:
            for channel_idx, ax in enumerate(axes[:num_channels]):
                ax.imshow(img_arr[..., channel_idx], cmap='gray')
                ax.axis('off')

            cam_ax = axes[num_channels]
            # NOTE(review): the sign flip presumably compensates for MONAI's
            # default normalizer, which maps the CAM's (min, max) to (1, 0);
            # negating makes larger values mean stronger activation. Confirm
            # against the installed MONAI version.
            cam = -res_cam[0][0].detach().cpu().numpy()
            heat = cam_ax.imshow(cam, cmap='jet')
            cam_ax.axis('off')

            # Put the colour bar just right of the heat map, vertically aligned
            # with it. get_position() reflects the image's aspect ratio, unlike
            # the fixed bottom offset used before (only correct for square images).
            pos = cam_ax.get_position()
            cax = fig.add_axes([0.92, pos.y0, 0.02, pos.height])
            fig.colorbar(heat, cax=cax)
            # No plt.show(): the backend is forced to Agg at import time, where
            # show() cannot display anything.
            fig.savefig(out_path, bbox_inches='tight')
        finally:
            # Always release the figure so a failing sample does not leak memory.
            plt.close(fig)

# Driver: for every (task type, sample directory) pair and every configured
# model, render Grad-CAM figures for all samples in the directory.
compare_settings = get_param_in_cwd('compare_settings')
task_types = compare_settings['task_types']
data_patterns = compare_settings['data_patterns']
# zip() would silently drop unmatched entries; a length mismatch is a config error.
if len(task_types) != len(data_patterns):
    raise ValueError(
        f"compare_settings: 'task_types' ({len(task_types)}) and "
        f"'data_patterns' ({len(data_patterns)}) must have the same length"
    )
model_root_base = get_param_in_cwd('model_root')

for task_type, sample_dir in zip(task_types, data_patterns):
    # Each data pattern is used as a directory; every entry inside is a sample.
    samples = glob(os.path.join(sample_dir, '*'))
    if not samples:
        # Avoid loading every model just to find there is nothing to visualize.
        print(f"No samples found in {sample_dir}, skipping task {task_type}.")
        continue
    for model_name in compare_settings['model_names']:
        # viz_sample reads these two module-level globals.
        model_root = os.path.join(model_root_base, task_type, model_name)
        target_layer = target_layer_mapping[model_name + '_2D']
        random.shuffle(samples)
        viz_sample(samples, thread_id=1)