In [None]:

import nibabel as nib
import numpy as np
import config.config as cf
import torch

from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import utilities
from utilities.model import load_model, pad_to_multiple
from utilities.patient_files import get_patient, get_center_patient
import importlib

from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape
from nnunetv2.preprocessing.normalization.default_normalization_schemes import ZScoreNormalization

from utilities.patient_files import get_datasetname
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

import numpy as np
import nibabel as nib

importlib.reload(cf)
importlib.reload(utilities.model)
importlib.reload(utilities.patient_files)

from mpl_toolkits.axes_grid1 import make_axes_locatable
from medcam import medcam


In [2]:
def create_predictor(dataset_id):
    predictor = nnUNetPredictor(
            tile_step_size=0.5,
            use_gaussian=True,
            use_mirroring=True,
            perform_everything_on_device=True,
            device=torch.device("cpu"),
            verbose=False,
            verbose_preprocessing=False,
            allow_tqdm=True
        )
    
    config = "3d_fullres"
    plan = "nnUNetResEncUNetMPlans"
    trainer = "nnUNetTrainer_100epochs"
        
    predictor.initialize_from_trained_model_folder(
        f"{cf.nnunet_trained_models}/{get_datasetname(dataset_id)}/{trainer}__{plan}__{config}",
        use_folds=(0,),
        checkpoint_name='checkpoint_final.pth',
    )
    return predictor

In [3]:
def prepare_patient(predictor, patient_id, dataset_id):
    modalities = [
        rf"{cf.nnUNet_raw}\{get_datasetname(dataset_id)}\imagesTs\\Patient-{patient_id}_0000.nii.gz",
        rf"{cf.nnUNet_raw}\{get_datasetname(dataset_id)}\imagesTs\\Patient-{patient_id}_0001.nii.gz",
        rf"{cf.nnUNet_raw}\{get_datasetname(dataset_id)}\imagesTs\\Patient-{patient_id}_0002.nii.gz",
        rf"{cf.nnUNet_raw}\{get_datasetname(dataset_id)}\imagesTs\\Patient-{patient_id}_0003.nii.gz",
        rf"{cf.nnUNet_raw}\{get_datasetname(dataset_id)}\imagesTs\\Patient-{patient_id}_0004.nii.gz"
    ]
    
    gt_path = rf"{cf.nnUNet_raw}\{get_datasetname(dataset_id)}\labelsTs\Patient-{patient_id}.nii.gz"
    
    preprocessor = predictor.configuration_manager.preprocessor_class(verbose=True)
    
    data, seg, properties = preprocessor.run_case(
        modalities,
        gt_path,
        predictor.plans_manager,
        predictor.configuration_manager,
        predictor.dataset_json
    )
    input_tensor = torch.from_numpy(data).unsqueeze(0).to(torch.float32).to(predictor.device).squeeze(0)
    
    t, _ = pad_to_multiple(input_tensor, 64)
    input_tensor_padded = torch.from_numpy(t).unsqueeze(0).float().to("cpu")
    seg, _ = pad_to_multiple(seg, 64)
    
    return input_tensor_padded, seg


In [2]:
patient_id = 54
center, patient = get_center_patient(patient_id)
center_patient = f"{center}_{patient}"
dataset_id = 201


predictor = create_predictor(dataset_id)

input_tensor, seg = prepare_patient(predictor, patient_id, dataset_id)

mask = nib.load(rf"{cf.nnUNet_raw}\{get_datasetname(dataset_id)}\labelsTs\\Patient-{patient_id}.nii.gz").get_fdata()
image = input_tensor.squeeze(0)[4].detach().numpy()
gt = seg.squeeze(0)

prediction = nib.load(f"{cf.plots}/xai/attention_maps_{center_patient}_{dataset_id}/prediction.nii.gz").get_fdata()


TypeError: cannot unpack non-iterable NoneType object

In [3]:
def compute_mean_heatmap(heatmaps):
    return np.mean(heatmaps, axis=0)

def relevance_window(mean_heatmap, top_fraction=0.5, bins=200):
    values = mean_heatmap.flatten()
    
    hist, bin_edges = np.histogram(values, bins=bins)
    
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    relevance_per_bin = hist * bin_centers
    
    total_relevance = np.sum(relevance_per_bin)
    cumulative = np.cumsum(relevance_per_bin[::-1])  # start from high relevance
    cutoff_idx = np.where(cumulative >= top_fraction * total_relevance)[0][0]
    
    lower_value = bin_edges[::-1][cutoff_idx]
    
    windowed = np.copy(mean_heatmap)
    windowed[windowed < lower_value] = 0
    
    return windowed, lower_value


In [13]:
patient_id = 48
patient, center = get_center_patient(patient_id)
center_patient = f"{patient}_{center}"
dataset_id = 201

gt = nib.load(f"{cf.plots}/xai/attention_maps_{center_patient}_{dataset_id}_new_orient/gt.nii.gz").get_fdata().squeeze(0)
flair = nib.load(f"{cf.plots}/xai/attention_maps_{center_patient}_{dataset_id}/flair.nii.gz").get_fdata()
prediction = nib.load(f"{cf.plots}/xai/attention_maps_{center_patient}_{dataset_id}/prediction.nii.gz").get_fdata()

gradcampp = nib.load(f"{cf.plots}/xai/attention_maps_{center_patient}_{dataset_id}/gcampp/decoder.seg_layers.4/attention_map_0_0_0_registered.nii.gz").get_fdata()
gbp = nib.load(f"{cf.plots}/xai/attention_maps_{center_patient}_{dataset_id}/gbp/attention_map_0_0_0_registered.nii.gz").get_fdata()

flair_nii = nib.load(f"{cf.plots}/xai/attention_maps_{center_patient}_{dataset_id}_new_orient/flair.nii.gz")

In [16]:
plt.rcParams.update({     
    "font.size": 16,        
    "axes.titlesize": 16, 
    "axes.labelsize": 16,   
    "legend.fontsize": 16   
})

zooms = flair_nii.header.get_zooms()

def plot_attention_on_image(axs, flair_slice, gt_slice, mask_slice, attention_map,title, alpha=1):
    masked = flair_slice.copy()
    eps = 1e-5
    masked[np.abs(masked) < eps] = np.min(flair_slice)

    print(flair_slice.shape)            
    print(zooms)            
    extent = [0, flair_slice.shape[0] * zooms[2], 0, flair_slice.shape[1] * zooms[0]]
    print(extent)
    extent = [0,2,0,1.5]


    axs[0].imshow(masked, cmap='gray', extent=extent, aspect="equal")
    axs[0].set_title('FLAIR')
    axs[0].axis('off')

    lesion_cmap_gt = ListedColormap([[0,0,0,0], [1,0,0,alpha]])
    gt_bin = (gt_slice == 1).astype(int)
    axs[1].imshow(masked, cmap='gray', extent=extent, aspect="equal")
    axs[1].imshow(gt_bin, cmap=lesion_cmap_gt, extent=extent, aspect="equal")
    axs[1].set_title('FLAIR + GT')
    axs[1].axis('off')

    lesion_cmap_pred = ListedColormap([[0,0,0,0], [0,0,1,alpha]])
    mask_bin = (mask_slice == 1).astype(int)
    axs[2].imshow(masked, cmap='gray', extent=extent, aspect="equal")
    axs[2].imshow(mask_bin, cmap=lesion_cmap_pred, extent=extent, aspect="equal")
    axs[2].set_title('FLAIR + Prediction')
    axs[2].axis('off')
    
    global vmax 
    vmax = np.max(attention_map)
    #axs[3].imshow(masked, cmap='gray', aspect="equal")
    im = axs[3].imshow(attention_map, cmap='jet', alpha=alpha, vmin=np.min(attention_map), vmax=np.max(attention_map), extent=extent, aspect="equal")
    axs[3].set_title(title)
    axs[3].axis('off')
    
    return im

In [None]:
%matplotlib qt
class SliceViewer:
    def __init__(self, image, gt, prediction, attention_map, title):
        self.image = image
        self.prediction = prediction
        self.gt = gt
        self.attention_map = attention_map
        self.title = title

        self.slice = image.shape[2] // 2

        self.fig, self.axs = plt.subplots(1, 4, figsize=(19, 8), constrained_layout=False)
        
        self.fig.subplots_adjust(bottom=0.15)  #  # nur in der HÃ¶he strecken

        
        divider = make_axes_locatable(self.axs[-1])
        self.cax = divider.append_axes("right", size="5%", pad=0.05)

        self.colorbar = None
        self.update()

        self.fig.canvas.mpl_connect("scroll_event", self.on_scroll)
        self.fig.canvas.mpl_connect("key_press_event", self.on_key)

    def get_slices(self, idx):
        flair = np.fliplr(np.rot90(self.image[:, idx, :], k=1))
        gt_slice = np.fliplr(np.rot90(self.gt[:, idx, :], k=1))
        mask_slice = np.fliplr(np.rot90(self.prediction[:, idx, :], k=1))
        attention_map = np.fliplr(np.rot90(self.attention_map[:, idx, :], k=1))
        return flair, gt_slice, mask_slice, attention_map

    def update(self):
        flair, gt_slice, mask_slice, gradcam_slices = self.get_slices(self.slice)

        im = plot_attention_on_image(self.axs, flair, gt_slice, mask_slice, gradcam_slices, self.title)

        self.fig.suptitle(f"Patient {patient} of Centre {center} | Slice {self.slice}/{self.image.shape[1]-1}", fontsize=16)
        
        self.fig.subplots_adjust(bottom=0.15) 

        if self.colorbar is None:
            self.colorbar = self.fig.colorbar(
                im,
                cax=self.cax,
                orientation="vertical",
                fraction=0.046,
                shrink=0.8,  
                pad=0.1  
            )
            self.colorbar.ax.tick_params(labelsize=12, pad=2)
            self.colorbar.set_label("Attention Score")
        else:
            self.colorbar.mappable.set_clim(vmin=np.min(gradcam_slices), vmax=np.max(gradcam_slices))

        for ax in self.axs.flatten():
            ax.set_aspect("equal", adjustable="box")
            ax.axis("off")

        self.fig.canvas.draw_idle()

    def on_scroll(self, event):
        if event.button == 'up':
            self.slice = (self.slice + 1) % self.image.shape[2]
        elif event.button == 'down':
            self.slice = (self.slice - 1) % self.image.shape[2]
        self.update()

    def on_key(self, event):
        if event.key == 'right':
            self.slice = (self.slice + 1) % self.image.shape[2]
        elif event.key == 'left':
            self.slice = (self.slice - 1) % self.image.shape[2]
        self.update()

In [None]:
vmin = 0
vmax = np.max(gradcampp)

mean_heatmap = compute_mean_heatmap(gradcampp)
windowed, thr = relevance_window(mean_heatmap, top_fraction=0.5)
viewer = SliceViewer(flair, gt, prediction, gradcampp,"GradCAM++: \ndecoder.seg_layer.4")

In [6]:
vmin = 0
vmax = np.max(gbp)

mean_heatmap = compute_mean_heatmap(gbp)
windowed, thr = relevance_window(mean_heatmap, top_fraction=0.5)
viewer = SliceViewer(flair, gt, prediction, gbp,"Guided Backpropagation")

NameError: name 'compute_mean_heatmap' is not defined