## 3D可视化

3D Grad-CAM可视化模块

In [None]:
import os
import monai
from glob import glob
import matplotlib.pyplot as plt
from onekey_algo import get_param_in_cwd


# Directory containing the NIfTI volumes to visualise, read from the project
# configuration in the current working directory.
mydir = get_param_in_cwd('data_pattern')
# mydir = 'your/own/directory'
# Keep only NIfTI files (.nii / .nii.gz). os.listdir() returns entries in an
# arbitrary, filesystem-dependent order, so sort the names to make the sample
# order (and therefore display order and output file naming) reproducible.
samples = [os.path.join(mydir, f) for f in sorted(os.listdir(mydir))
           if f.endswith(('.nii', '.nii.gz'))]

# samples = [samples[-1]]
samples

## 确定可视化模型

通过层名称（关键词）指定要提取哪一层进行可视化。

### 支持的模型名称

下表中的模型名称对应代码中的 `model_name` 变量（由配置项 `sel_dl_model_name` 读取）。

| **模型系列** | **模型名称**                                                 |
| ------------ | ------------------------------------------------------------ |
| ResNet       | resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 |
| DenseNet     | DenseNet121, DenseNet169, DenseNet201, DenseNet264           |
| ShuffleNet   | ShuffleNet |
| VisionTransformer     | ViT, SimpleViT, CCT         |

In [None]:
from monai.data import ImageDataset
from torch.utils.data import DataLoader
from onekey_algo.custom.components.comp2 import extract, init_from_onekey3d

# Each trained model lives in <model_root>/<model name>. Its 'viz' sub-folder is
# handed to init_from_onekey3d, which returns (model, transformer, device).
model_name = get_param_in_cwd('sel_dl_model_name')
models_base_dir = get_param_in_cwd('model_root')
model_root = os.path.join(models_base_dir, model_name)
viz_config_dir = os.path.join(model_root, 'viz')
model, transformer, device = init_from_onekey3d(viz_config_dir)

# Print every named sub-module so the Grad-CAM target layer can be chosen by name.
for layer_name, layer_module in model.named_modules():
    print('Feature name:', layer_name, "|| Module:", layer_module)

## 可视化卷积层

`Feature name:` 之后的名称即为可用于可视化的层名，例如 `layer4.2.conv3`；一般选择特征提取部分的最后一个卷积层。不同模型的层名不同，请以上一步实际打印出的名称为准。

**注意**：要可视化的层必须是卷积层（名称中通常带有 `conv`），一般选择最后一个卷积层。

In [None]:
from onekey_algo.custom.components.comp2 import show_cam_on_image
import torch

# Layer whose activations/gradients drive Grad-CAM. It must be one of the names
# printed by `model.named_modules()` in the previous cell; "layer4.2.conv3" is
# only an example and differs between architectures.
target_layer = "layer4.2.conv3"
gradcam = monai.visualize.GradCAM(nn_module=model, target_layers=target_layer)

val_ds = ImageDataset(image_files=samples, transform=transformer)
# Validation loader: batch_size=1 and no shuffling, so batch `sidx` is `samples[sidx]`.
val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)
viz_dir = os.path.join(model_root, 'Grad-CAM')
os.makedirs(viz_dir, exist_ok=True)
for sidx, sample_ in enumerate(val_loader):
    # class_idx=None: per the MONAI docs, GradCAM explains the argmax class.
    res_cam = gradcam(x=sample_.to(device), class_idx=None)
    # Convert to numpy once per volume instead of once per slice.
    cam_np = res_cam[0][0].cpu().detach().numpy()
    sample_np = sample_.cpu().detach().numpy()
    # PNG name prefix: strip only a trailing ".gz", so "x.nii.gz" and "x.nii"
    # both become "x.nii_<slice>.png". A plain str.replace(".gz", ...) leaves
    # ".nii" names untouched, which made savefig receive a non-image extension.
    base_name = os.path.basename(samples[sidx])
    if base_name.endswith('.gz'):
        base_name = base_name[:-len('.gz')]
    # One figure per index along the last axis: CAM heat map (left), input (right).
    for idx in range(sample_np.shape[-1]):
        fig, axes = plt.subplots(1, 2, figsize=(10, 5), facecolor='white')
        axes[0].imshow(cam_np[..., idx], cmap='jet')
        axes[0].axis('off')
        axes[1].imshow(sample_np[0][0][..., idx], cmap='gray')
        axes[1].axis('off')
        plt.savefig(os.path.join(viz_dir, f'{base_name}_{idx}.png'), bbox_inches='tight')
        plt.show()
        plt.close(fig)