11.11.2022

# Test Grad-CAM on UNet model

Provo a usare una Grad-CAM su un modello salvato della UNet e il movie 34 (dove la fine della wave viene detettata come puff).

In [1]:
# autoreload is used to reload modules automatically before entering the
# execution of code typed at the IPython prompt.
%load_ext autoreload
%autoreload 2
# To import modules from parent directory in Jupyter Notebook
import sys

sys.path.append("..")

In [2]:
import os

import torch

# from in_out_tools import write_videos_on_disk
from torch import nn
from torch.utils.data import DataLoader
from utils.visualization_tools import get_discrete_cmap

from medcam import medcam
import napari
from utils.training_inference_tools import get_half_overlap
from utils.training_script_utils import init_model
from data.datasets import SparkDatasetInference

from config import TrainingConfig, config

  from .autonotebook import tqdm as notebook_tqdm


### Set parameters

In [3]:
training_name = "final_model"
config_file = "config_final_model.ini"

print(f"Processing training '{training_name}'...")

# Initialize general parameters
params = TrainingConfig(training_config_file=os.path.join("config_files", config_file))

Processing training 'final_model'...
[23:05:48] [  INFO  ] [   config   ] <290 > -- Loading C:\Users\prisc\Code\sparks_project\config_files\config_final_model.ini


### Configure output folder

In [4]:
# change this to save results for same model with different inference approaches
output_name = training_name

output_folder = os.path.join(
    config.basedir, "evaluation", "gradCAM_script", output_name
)  # Same folder for train and test preds
os.makedirs(output_folder, exist_ok=True)

print(f"Output files will be saved on '{os.path.realpath(output_folder)}'")

Output files will be saved on 'C:\Users\prisc\Code\sparks_project\evaluation\gradCAM_script\final_model'


### Detect GPU, if available

In [5]:
params.set_device(device="auto")
params.display_device_info()

[23:05:49] [  INFO  ] [   config   ] <528 > -- Using cpu


### Config UNet model

In [6]:
network = init_model(params=params)
network = nn.DataParallel(network).to(params.device)

In [None]:
### Load UNet model ###
# Path to the saved model checkpoint
model_filename = f"network_{params.inference_load_epoch:06d}.pth"
models_relative_path = os.path.join(
    "models", "saved_models", params.run_name, model_filename
)
model_dir = os.path.realpath(os.path.join(config.basedir, models_relative_path))

# Load the model state dictionary
print(
    f"Loading trained model '{training_name}' at epoch {params.inference_load_epoch}..."
)
network.load_state_dict(torch.load(model_dir, map_location=params.device))

#### Print summary of network architecture

In [13]:
# summary(network,  input_size=(1, 256, 64, 512), device="cpu")

### Load input sample

In [14]:
### Configure sample input ###
movie_path = (
    r"C:\Users\prisc\Code\sparks_project\data\sparks_dataset\05_class_label.tif"
)

sample_dataset = SparkDatasetInference(
    params=params,
    movie_path=movie_path,
)

print(
    f"Testing dataset of movie {os.path.realpath(movie_path)} "
    f"contains {len(sample_dataset)} samples."
)

Testing dataset of movie C:\Users\prisc\Code\sparks_project\data\sparks_dataset\05_class_label.tif contains 9 samples.


In [16]:
# Create a dataloader
dataset_loader = DataLoader(
    sample_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=params.num_workers,
    pin_memory=params.pin_memory,
)

### Configure Grad-CAM

In [17]:
network.eval()

cam_network = medcam.inject(
    network,
    label=3,
    replace=True,
    # backend="gcam",
    layer="module.final_layer",
    output_dir=output_folder,
    save_maps=True,
)

### Run sample's chunks in network and re-assemble UNet's output

In [19]:
n_chunks = len(sample_dataset)
half_overlap = get_half_overlap(
    data_duration=params.data_duration,
    data_stride=params.data_stride,
    temporal_reduction=params.temporal_reduction,
    num_channels=params.num_channels,
)

In [23]:
# print("network device", next(cam_network.parameters()).device)

network device cpu


In [28]:
out_concat = []
x_concat = []

cam_network.eval()
for i, sample in enumerate(dataset_loader):
    x = sample["data"]

    # define start and end of used frames in chunks
    start = 0 if i == 0 else half_overlap
    end = None if i + 1 == n_chunks else -half_overlap

    x_concat.append(x[0, start:end])

    x = x.to(params.device)
    print("x device", x.device)
    print("network device", next(cam_network.parameters()).device)
    out = cam_network(x[None, :])[0, 0]
    out_concat.append(out[start:end].cpu())
x_concat = torch.cat(x_concat, dim=0).numpy()
out_concat = torch.cat(out_concat, dim=0).numpy()

x device cpu
network device cpu


RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

In [69]:
x_concat.shape

(928, 64, 512)

### Visualise result with Napari

In [28]:
# Configure Napari cmap
cmap = get_discrete_cmap(name="gray", lut=16)

In [70]:
viewer = napari.Viewer()

viewer.add_image(x_concat, name="input movie", colormap=("colors", cmap))

viewer.add_image(out_concat, name="network output", colormap=("colors", cmap))

<Image layer 'network output' at 0x14486fb6130>