In [1]:
import subprocess
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
from matplotlib.colors import ListedColormap
from nibabel.processing import resample_from_to
import scipy.ndimage as ndi
import config.config as cf
from utilities.patient_files import get_center_patient
import torch
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable

2026-02-09 18:24:28,230 - logger - INFO - Logger initialized


In [2]:

# Case selection for the registration experiments below.
patient_id = 51
dataset_id = 201
center, patient = get_center_patient(patient_id)
center_patient = "{}_{}".format(center, patient)

# Load one decoder-stage Grad-CAM volume as a quick sanity check of the paths.
a = nib.load(
    f"{cf.plots}/xai/attention_maps_{center_patient}_test_{dataset_id}"
    "/gcam/decoder.stages.0/attention_map_0_0_0.nii.gz"
)

In [3]:
# Decoder layers for which Grad-CAM maps were exported.
layers = (
    ["decoder"]
    + [f"decoder.stages.{i}" for i in range(5)]
    + [f"decoder.transpconvs.{i}" for i in range(5)]
    + ["decoder.seg_layers.4"]
)

In [4]:
def windows_to_wsl_path(win_path: str) -> str:
    """Translate a Windows drive-letter path into its WSL ``/mnt/<drive>/...`` form.

    Parameters
    ----------
    win_path : str
        Windows path such as ``C:\\data\\img.nii.gz`` (``/`` separators are
        accepted too). On Windows, relative paths are resolved against the
        current working directory first.

    Returns
    -------
    str
        The same location as seen from inside WSL, e.g. ``/mnt/c/data/img.nii.gz``.

    Raises
    ------
    ValueError
        If the path has no ``X:`` drive letter (UNC shares, POSIX paths, ...),
        since such paths cannot be mapped onto ``/mnt/<drive>``.
    """
    from pathlib import PureWindowsPath

    p = Path(win_path).resolve()
    if not p.drive:
        # Not running on Windows (Path is a PosixPath), so interpret the string
        # with Windows rules; this also normalises mixed / and \ separators.
        p = PureWindowsPath(win_path)
    drive = p.drive
    if len(drive) != 2 or drive[1] != ":":
        raise ValueError(f"Cannot map {win_path!r} to a WSL path: no drive letter")
    rest = p.parts[1:]  # parts[0] is the anchor, e.g. 'C:\\'
    return f"/mnt/{drive[0].lower()}/" + "/".join(rest)


In [3]:
import shlex

# Register the decoder Grad-CAM onto the FLAIR with FSL FLIRT, run inside WSL.
cam_dir = f"{cf.plots}/xai/attention_maps_{center_patient}_test_{dataset_id}/gcam/decoder"
# `cam_path` is used instead of `input`, which shadowed the builtin.
cam_path = f"{cam_dir}/attention_map_0_0_0.nii.gz"
output = f"{cam_dir}/attention_map_0_0_0_registered.nii.gz"
output_reorient = f"{cam_dir}/attention_map_0_0_0_reoriented.nii.gz"
# Join with pathlib rather than hand-written backslashes (avoids doubled separators).
flair = str(Path(str(cf.nnUNet_raw), str(cf.dataset_201), "imagesTs", f"Patient-{patient_id}_0004.nii.gz"))

# Reorder the CAM to the closest canonical (RAS+) orientation before registration.
cam_img = nib.as_closest_canonical(nib.load(cam_path))
nib.save(cam_img, output_reorient)
# Resampled copy on the FLAIR grid; currently only kept in memory for inspection.
cam_resampled = resample_from_to(cam_img, nib.load(flair))

# Quote every path for the shell so spaces or quotes in them cannot break the command.
cmd_str = (
    f"flirt -in {shlex.quote(windows_to_wsl_path(output_reorient))}"
    f" -ref {shlex.quote(windows_to_wsl_path(flair))}"
    f" -out {shlex.quote(windows_to_wsl_path(output))}"
    " -dof 6 -interp trilinear"
)
# `bash -i` sources ~/.bashrc, presumably so the FSL environment (flirt on PATH) is set up.
cmd = ["wsl", "bash", "-ic", cmd_str]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
    # Fail loudly; otherwise a stale or missing registered file goes unnoticed.
    raise RuntimeError(f"flirt failed (exit code {result.returncode}):\n{result.stderr}")

NameError: name 'windows_to_wsl_path' is not defined

In [9]:
# Put every layer's Grad-CAM on the FLAIR grid: fix the axis order, store it with
# the FLAIR affine ("_registered"), and upsample it when the grid sizes differ
# ("_resampled").
flair = nib.load(f"{cf.plots}/xai/attention_maps_{center_patient}_{dataset_id}/flair.nii.gz")
flair_img = flair.get_fdata()
layers = ['decoder.stages.0', 'decoder.stages.1', 'decoder.stages.2', 'decoder.stages.3', 'decoder.stages.4', 'decoder.seg_layers.4']
for layer in layers:
    layer_dir = f"{cf.plots}/xai/attention_maps_{center_patient}_test_{dataset_id}/gcam/{layer}"
    # `cam_path` instead of `input`, which shadowed the builtin.
    cam_path = f"{layer_dir}/attention_map_0_0_0.nii.gz"
    output = f"{layer_dir}/attention_map_0_0_0_registered.nii.gz"

    # Reorder axes (a, b, c) -> (c, a, b); presumably the exported CAM stores the
    # FLAIR's last axis first -- TODO confirm against the export code.
    cam_data_correct = np.transpose(nib.load(cam_path).get_fdata(), (2, 0, 1))
    # NOTE(review): when the shapes differ, this file pairs a coarse grid with the
    # fine FLAIR affine, so its voxel geometry is only right after resampling below.
    heatmap = nib.Nifti1Image(cam_data_correct.astype(np.float32), flair.affine)
    nib.save(heatmap, output)

    if heatmap.shape != flair_img.shape:
        # Linear (order=1) upsampling of the CAM onto the FLAIR voxel grid.
        zoom_factors = np.array(flair_img.shape) / np.array(heatmap.shape)
        cam_up = ndi.zoom(heatmap.get_fdata(), zoom=zoom_factors, order=1)

        cam_up_img = nib.Nifti1Image(cam_up, flair.affine, flair.header)
        # With a header passed in, nibabel keeps that header's on-disk dtype, which
        # may be an integer type for the FLAIR; store the CAM as float32 so its
        # values are not quantised.
        cam_up_img.set_data_dtype(np.float32)
        nib.save(cam_up_img, f"{layer_dir}/attention_map_0_0_0_resampled.nii.gz")

In [6]:
def compute_mean_heatmap(heatmaps):
    """Voxel-wise average of several equally shaped heatmaps.

    Parameters
    ----------
    heatmaps : sequence of array_like
        Heatmaps to average; they are stacked along a new leading axis.

    Returns
    -------
    numpy.ndarray
        Mean over that leading axis, i.e. the shape of a single heatmap.
    """
    averaged = np.mean(heatmaps, axis=0)  # collapse the stacking axis
    return averaged

def relevance_window(mean_heatmap, top_fraction=0.5, bins=200):
    """Keep only the highest values that together carry ``top_fraction`` of the relevance.

    The values are histogrammed and each bin's relevance is ``count * bin_center``.
    Bins are accumulated from the highest value downwards until their summed
    relevance reaches ``top_fraction`` of the total. The lower edge of the last
    accumulated bin becomes the threshold, so that bin is itself retained.

    Parameters
    ----------
    mean_heatmap : array_like
        Heatmap to window (any shape).
    top_fraction : float, default 0.5
        Fraction of the total relevance that the retained bins must reach.
    bins : int or sequence, default 200
        Passed to ``np.histogram``.

    Returns
    -------
    windowed : numpy.ndarray
        Copy of ``mean_heatmap`` with values below the threshold set to 0.
    lower_value : float
        The threshold that was applied.
    """
    values = np.asarray(mean_heatmap).ravel()

    hist, bin_edges = np.histogram(values, bins=bins)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    relevance_per_bin = hist * bin_centers

    total_relevance = np.sum(relevance_per_bin)
    # Running relevance starting from the highest-valued bin.
    cumulative = np.cumsum(relevance_per_bin[::-1])
    reached = np.flatnonzero(cumulative >= top_fraction * total_relevance)
    # If the target is never reached (possible when negative values make the running
    # sum non-monotonic), keep every bin instead of raising IndexError.
    cutoff_idx = reached[0] if reached.size else len(cumulative) - 1

    # cumulative[k] covers bins n-1 ... n-1-k; take the *lower* edge of bin n-1-k
    # so the bin that completes the target fraction is kept.
    lower_value = bin_edges[len(hist) - 1 - cutoff_idx]

    windowed = np.copy(mean_heatmap)
    windowed[windowed < lower_value] = 0

    return windowed, lower_value


In [44]:
patient_id = 48
# Unpack as (center, patient), matching the earlier cell; the previous order swapped
# the two labels shown in the viewer title. (Return order assumed from the function
# name -- TODO confirm against utilities.patient_files.)
center, patient = get_center_patient(patient_id)
# Same "<first>_<second>" string as before, so all file paths are unchanged.
center_patient = f"{center}_{patient}"
dataset_id = 201

case_dir = f"{cf.plots}/xai/attention_maps_{center_patient}_{dataset_id}_new_orient"
gt = torch.tensor(nib.load(f"{case_dir}/gt.nii.gz").get_fdata()).squeeze(0)
# Load the FLAIR once: the image object supplies header zooms, its data the slices.
flair_nii = nib.load(f"{case_dir}/flair.nii.gz")
flair = flair_nii.get_fdata()
prediction = nib.load(f"{case_dir}/prediction.nii.gz").get_fdata()


In [48]:

%matplotlib qt

gcam_dir = f"{cf.plots}/xai/attention_maps_{center_patient}_test_{dataset_id}/gcam"

# Stages 0-3 are read from the upsampled "_resampled" files; stage 4 and the
# segmentation head from "_registered" (presumably their grid already matched the FLAIR).
gradcam_stages0, gradcam_stages1, gradcam_stages2, gradcam_stages3 = (
    nib.load(f"{gcam_dir}/decoder.stages.{i}/attention_map_0_0_0_resampled.nii.gz").get_fdata()
    for i in range(4)
)
gradcam_stages4 = nib.load(f"{gcam_dir}/decoder.stages.4/attention_map_0_0_0_registered.nii.gz").get_fdata()
gradcam_seglayers = nib.load(f"{gcam_dir}/decoder.seg_layers.4/attention_map_0_0_0_registered.nii.gz").get_fdata()

# The shared colour range and the relevance threshold come from stages.3,
# stages.4 and seg_layers.4 only.
gradcams = [gradcam_stages3, gradcam_stages4, gradcam_seglayers]

vmin = 0
vmax = np.max([np.max(gc) for gc in gradcams])

mean_heatmap = compute_mean_heatmap(gradcams)
windowed, thr = relevance_window(mean_heatmap, top_fraction=0.5)

plt.rcParams.update({key: 16 for key in ("font.size", "axes.titlesize", "axes.labelsize", "legend.fontsize")})

# Voxel sizes of the FLAIR, used for the physical extent of each slice.
zooms = flair_nii.header.get_zooms()


def plot_attention_on_image(axs, flair_slice, gt_slice, mask_slice, gradcam_slices, alpha=1):
    """Draw one slice of FLAIR, GT, prediction and six Grad-CAM overlays on a 3x3 grid.

    Uses the notebook globals ``zooms`` (FLAIR voxel sizes), ``thr`` (relevance
    threshold) and ``vmin``/``vmax`` (shared Grad-CAM colour range).

    Parameters
    ----------
    axs : numpy.ndarray of matplotlib Axes, shape (3, 3)
    flair_slice, gt_slice, mask_slice : 2-D arrays
        FLAIR intensities, ground truth and predicted mask; label 1 is drawn as lesion.
    gradcam_slices : sequence of 2-D arrays
        Up to six Grad-CAM slices, drawn row-major into rows 1 and 2.
    alpha : float, default 1
        Opacity of the lesion and Grad-CAM overlays.

    Returns
    -------
    list of matplotlib.image.AxesImage
        The Grad-CAM overlay images (the caller feeds the last one to the colour bar).
    """
    # Map (near-)zero background to the slice minimum, the bottom of the gray scale.
    background = flair_slice.copy()
    background[np.abs(background) < 1e-5] = np.min(flair_slice)

    # Remove the previous slice's images. Iterate over a snapshot: removing while
    # iterating ax.images directly skips every other image, so images piled up.
    for ax in axs.flat:
        for im in list(ax.images):
            im.remove()

    # extent is (left, right, bottom, top): columns span the horizontal axis, rows the
    # vertical one. Assumes slices come from vol[:, idx, :] rotated with rot90(k=3),
    # so columns follow array axis 0 (zooms[0]) and rows array axis 2 (zooms[2]).
    n_rows, n_cols = background.shape
    extent = [0, n_cols * zooms[0], 0, n_rows * zooms[2]]
    shown = dict(origin="lower", extent=extent, aspect="equal")

    axs[0, 0].imshow(background, cmap='gray', **shown)
    axs[0, 0].set_title('FLAIR')
    axs[0, 0].axis('off')

    # Binary label overlays: transparent for 0, opaque colour for label 1.
    label_panels = [
        ((0, 1), gt_slice, [1, 0, 0, alpha], 'FLAIR + GT'),
        ((0, 2), mask_slice, [0, 0, 1, alpha], 'FLAIR + Prediction'),
    ]
    for pos, labels, colour, title in label_panels:
        ax = axs[pos]
        ax.imshow(background, cmap='gray', **shown)
        ax.imshow((labels == 1).astype(int), cmap=ListedColormap([[0, 0, 0, 0], colour]), **shown)
        ax.set_title(title)
        ax.axis('off')

    titles = ['stages.0', r'$\mathbf{Grad-CAMs\ of\ decoder\ layers}$' + '\n' + 'stages.1', 'stages.2',
              'stages.3', 'stages.4', 'seg_layers.4']
    positions = [(1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]
    ims = []

    for gc, title, pos in zip(gradcam_slices, titles, positions):
        # Values below the relevance threshold become NaN and render transparent.
        gc_masked = np.where(gc >= thr, gc, np.nan)
        ax = axs[pos]
        ax.imshow(background, cmap='gray', **shown)
        # Pin both ends of the colour range so every panel matches the colour bar.
        im = ax.imshow(gc_masked, cmap='jet', alpha=alpha, vmin=vmin, vmax=vmax, **shown)
        ax.set_title(title)
        ax.axis('off')
        ims.append(im)

    return ims

class SliceViewer:
    """Interactive 3x3 viewer that scrolls through the volumes along array axis 1.

    Each slice shows FLAIR, GT, prediction and six Grad-CAM overlays (drawn by
    ``plot_attention_on_image``). The scroll wheel or the left/right arrow keys move
    one slice, wrapping at the ends. Uses the notebook globals ``vmax``, ``patient``
    and ``center``.
    """

    def __init__(self, image,gt, prediction,  gradcam_stages0, gradcam_stages1, gradcam_stages2, gradcam_stages3, gradcam_stages4,  gradcam_seglayers, start_slice=128):
        self.image = image
        self.prediction = prediction
        self.gt = gt
        self.gradcams = [gradcam_stages0, gradcam_stages1, gradcam_stages2,
                         gradcam_stages3, gradcam_stages4, gradcam_seglayers]

        # Start at `start_slice`, clamped into the valid range of the navigated axis.
        self.slice = min(max(start_slice, 0), self._n_slices() - 1)

        self.fig, self.axs = plt.subplots(3, 3, figsize=(20, 10), constrained_layout=False)
        self.fig.subplots_adjust(hspace=0.4)

        # Extra thin column on the right for the shared colour bar (rows 1-2).
        gs = gridspec.GridSpec(3, 4, width_ratios=[1, 1, 1, 0.05], figure=self.fig)
        self.cax = self.fig.add_subplot(gs[1:, 3])
        self.colorbars = []

        self.update()
        self.fig.canvas.mpl_connect("scroll_event", self.on_scroll)
        self.fig.canvas.mpl_connect("key_press_event", self.on_key)

    def _n_slices(self):
        """Number of slices along axis 1, the axis that get_slices indexes."""
        return self.image.shape[1]

    def get_slices(self, idx):
        """Return the FLAIR, GT, prediction and Grad-CAM slices at index ``idx`` of axis 1.

        Every slice is rotated with ``np.rot90(k=3)`` for display.
        """
        k = 3
        flair = np.rot90(self.image[:, idx, :], k=k)
        gt_slice = np.rot90(self.gt[:, idx, :], k=k)
        mask_slice = np.rot90(self.prediction[:, idx, :], k=k)
        gradcam_slices = [np.rot90(gc[:, idx, :], k=k) for gc in self.gradcams]

        return flair, gt_slice, mask_slice, gradcam_slices

    def update(self):
        """Redraw all panels for ``self.slice`` and refresh the title and colour bar."""
        flair, gt_slice, mask_slice, gradcam_slices = self.get_slices(self.slice)
        ims = plot_attention_on_image(self.axs, flair, gt_slice, mask_slice, gradcam_slices)

        if not self.colorbars:
            cbar = self.fig.colorbar(ims[-1], cax=self.cax, orientation="vertical")
            cbar.set_label("Attention Score")
            self.colorbars.append(cbar)
        # Keep the colour bar on the fixed 0..vmax range, including the first draw.
        self.colorbars[0].mappable.set_clim(vmin=0, vmax=vmax)

        self.fig.suptitle(f"Patient {patient} of Centre {center} | Slice {self.slice}/{self._n_slices()-1}", fontsize=16)
        for ax in self.axs.flatten():
            ax.set_aspect("equal", adjustable="box")
            ax.axis("off")

        self.fig.canvas.draw_idle()
        self.fig.canvas.flush_events()

    def _step(self, delta):
        """Move ``delta`` slices along axis 1 (wrapping around) and redraw."""
        self.slice = (self.slice + delta) % self._n_slices()
        self.update()

    def on_scroll(self, event):
        """Scroll up/down moves one slice forward/backward."""
        delta = {'up': 1, 'down': -1}.get(event.button)
        if delta is not None:
            self._step(delta)

    def on_key(self, event):
        """Right/left arrow moves one slice forward/backward."""
        delta = {'right': 1, 'left': -1}.get(event.key)
        if delta is not None:
            self._step(delta)


# Launch the 3x3 viewer (all six decoder Grad-CAMs).
viewer = SliceViewer(
    image=flair,
    gt=gt,
    prediction=prediction,
    gradcam_stages0=gradcam_stages0,
    gradcam_stages1=gradcam_stages1,
    gradcam_stages2=gradcam_stages2,
    gradcam_stages3=gradcam_stages3,
    gradcam_stages4=gradcam_stages4,
    gradcam_seglayers=gradcam_seglayers,
)


(256, 128)
(256, 128)


In [None]:
%matplotlib qt

gcam_dir = f"{cf.plots}/xai/attention_maps_{center_patient}_test_{dataset_id}/gcam"

gradcam_stages3 = nib.load(f"{gcam_dir}/decoder.stages.3/attention_map_0_0_0_resampled.nii.gz").get_fdata()
gradcam_stages4, gradcam_seglayers = (
    nib.load(f"{gcam_dir}/{layer}/attention_map_0_0_0_registered.nii.gz").get_fdata()
    for layer in ("decoder.stages.4", "decoder.seg_layers.4")
)

# Only the last three decoder outputs are shown; they also define the colour range
# and the relevance threshold.
gradcams = [gradcam_stages3, gradcam_stages4, gradcam_seglayers]

vmin = 0
vmax = np.max([np.max(gc) for gc in gradcams])

mean_heatmap = compute_mean_heatmap(gradcams)
windowed, thr = relevance_window(mean_heatmap, top_fraction=0.5)

plt.rcParams.update(dict.fromkeys(("font.size", "axes.titlesize", "axes.labelsize", "legend.fontsize"), 25))

# Voxel sizes of the FLAIR, used for the physical extent of each slice.
zooms = flair_nii.header.get_zooms()

def plot_attention_on_image(axs, flair_slice, gt_slice, mask_slice, gradcam_slices, alpha=1):
    """Draw one slice of FLAIR, GT, prediction and three Grad-CAM overlays on a 2x3 grid.

    Uses the notebook globals ``zooms`` (FLAIR voxel sizes), ``thr`` (relevance
    threshold) and ``vmin``/``vmax`` (shared Grad-CAM colour range).

    Parameters
    ----------
    axs : numpy.ndarray of matplotlib Axes, shape (2, 3)
    flair_slice, gt_slice, mask_slice : 2-D arrays
        FLAIR intensities, ground truth and predicted mask; label 1 is drawn as lesion.
    gradcam_slices : sequence of 2-D arrays
        Up to three Grad-CAM slices (stages.3, stages.4, seg_layers.4), drawn in row 1.
    alpha : float, default 1
        Opacity of the lesion and Grad-CAM overlays.

    Returns
    -------
    list of matplotlib.image.AxesImage
        The Grad-CAM overlay images (the caller feeds the last one to the colour bar).
    """
    # Map (near-)zero background to the slice minimum, the bottom of the gray scale.
    background = flair_slice.copy()
    background[np.abs(background) < 1e-4] = np.min(flair_slice)

    # Remove the previous slice's images from *every* panel (the third column was
    # previously skipped). Iterate over a snapshot, because removing while iterating
    # ax.images directly skips every other image.
    for ax in axs.flat:
        for im in list(ax.images):
            im.remove()

    # extent is (left, right, bottom, top): columns span the horizontal axis, rows the
    # vertical one. Assumes slices come from vol[:, idx, :] rotated with rot90(k=3),
    # so columns follow array axis 0 (zooms[0]) and rows array axis 2 (zooms[2]).
    n_rows, n_cols = background.shape
    extent = [0, n_cols * zooms[0], 0, n_rows * zooms[2]]
    shown = dict(origin="lower", extent=extent, aspect="equal")

    axs[0, 0].imshow(background, cmap='gray', **shown)
    axs[0, 0].set_title('FLAIR')
    axs[0, 0].axis('off')

    # Binary label overlays: transparent for 0, opaque colour for label 1.
    label_panels = [
        ((0, 1), gt_slice, [1, 0, 0, alpha], 'FLAIR + GT'),
        ((0, 2), mask_slice, [0, 0, 1, alpha], 'FLAIR + Prediction'),
    ]
    for pos, labels, colour, title in label_panels:
        ax = axs[pos]
        ax.imshow(background, cmap='gray', **shown)
        ax.imshow((labels == 1).astype(int), cmap=ListedColormap([[0, 0, 0, 0], colour]), **shown)
        ax.set_title(title)
        ax.axis('off')

    titles = ['stages.3', r'$\mathbf{Grad-CAMs\ of\ decoder\ layers}$' + '\n' + 'stages.4', 'seg_layers.4']
    positions = [(1, 0), (1, 1), (1, 2)]
    ims = []

    for gc, title, pos in zip(gradcam_slices, titles, positions):
        # Values below the relevance threshold become NaN and render transparent.
        gc_masked = np.where(gc >= thr, gc, np.nan)
        ax = axs[pos]
        ax.imshow(background, cmap='gray', **shown)
        # Pin both ends of the colour range so every panel matches the colour bar.
        im = ax.imshow(gc_masked, cmap='jet', alpha=alpha, vmin=vmin, vmax=vmax, **shown)
        ax.set_title(title)
        ax.axis('off')
        ims.append(im)

    return ims

class SliceViewer:
    """Interactive 2x3 viewer that scrolls through the volumes along array axis 1.

    Each slice shows FLAIR, GT, prediction and three Grad-CAM overlays (drawn by
    ``plot_attention_on_image``). The scroll wheel or the left/right arrow keys move
    one slice, wrapping at the ends. Uses the notebook globals ``vmax``, ``patient``
    and ``center``.
    """

    def __init__(self, image,gt, prediction,  gradcam_stages3, gradcam_stages4,  gradcam_seglayers):
        self.image = image
        self.prediction = prediction
        self.gt = gt
        self.gradcams = [gradcam_stages3, gradcam_stages4, gradcam_seglayers]

        # Start in the middle of the axis that get_slices actually indexes (axis 1).
        self.slice = self._n_slices() // 2

        self.fig, self.axs = plt.subplots(2, 3, figsize=(15, 10), constrained_layout=True)

        pads = dict(w_pad=0.05, h_pad=0.05, hspace=0.05, wspace=0.05)
        get_engine = getattr(self.fig, "get_layout_engine", None)
        if get_engine is not None:
            # Matplotlib >= 3.6: configure the layout engine directly
            # (set_constrained_layout_pads is superseded by this API).
            get_engine().set(**pads)
        else:
            self.fig.set_constrained_layout_pads(**pads)
        # No subplots_adjust here: with constrained layout it is either ignored with a
        # warning or, on older Matplotlib, switches constrained layout off.

        # Extra thin column on the right for the shared colour bar (row 1).
        gs = gridspec.GridSpec(2, 4, width_ratios=[1, 1, 1, 0.05], figure=self.fig)
        self.cax = self.fig.add_subplot(gs[1:, 3])
        self.colorbars = []

        self.update()
        self.fig.canvas.mpl_connect("scroll_event", self.on_scroll)
        self.fig.canvas.mpl_connect("key_press_event", self.on_key)

    def _n_slices(self):
        """Number of slices along axis 1, the axis that get_slices indexes."""
        return self.image.shape[1]

    def get_slices(self, idx):
        """Return the FLAIR, GT, prediction and Grad-CAM slices at index ``idx`` of axis 1.

        Every slice is rotated with ``np.rot90(k=3)`` for display.
        """
        k = 3
        flair = np.rot90(self.image[:, idx, :], k=k)
        gt_slice = np.rot90(self.gt[:, idx, :], k=k)
        mask_slice = np.rot90(self.prediction[:, idx, :], k=k)
        gradcam_slices = [np.rot90(gc[:, idx, :], k=k) for gc in self.gradcams]

        return flair, gt_slice, mask_slice, gradcam_slices

    def update(self):
        """Redraw all panels for ``self.slice`` and refresh the title and colour bar."""
        # Fetch the slices for the current index.
        flair, gt_slice, mask_slice, gradcam_slices = self.get_slices(self.slice)
        ims = plot_attention_on_image(self.axs, flair, gt_slice, mask_slice, gradcam_slices)

        self.fig.suptitle(f"Patient {patient} of Centre {center} | Slice {self.slice}/{self._n_slices()-1}", fontsize=25)

        # Create the colour bar only on the first draw.
        if not self.colorbars:
            cbar = self.fig.colorbar(ims[-1], cax=self.cax, orientation="vertical")
            cbar.set_label("Attention Score")
            self.colorbars.append(cbar)
        # Keep the colour bar on the fixed 0..vmax range, including the first draw.
        self.colorbars[0].mappable.set_clim(vmin=0, vmax=vmax)

        for ax in self.axs.flatten():
            ax.set_aspect("equal", adjustable="box")
            ax.axis("off")

        # Force a redraw.
        self.fig.canvas.draw_idle()
        self.fig.canvas.flush_events()

    def _step(self, delta):
        """Move ``delta`` slices along axis 1 (wrapping around) and redraw."""
        self.slice = (self.slice + delta) % self._n_slices()
        self.update()

    def on_scroll(self, event):
        """Scroll up/down moves one slice forward/backward."""
        delta = {'up': 1, 'down': -1}.get(event.button)
        if delta is not None:
            self._step(delta)

    def on_key(self, event):
        """Right/left arrow moves one slice forward/backward."""
        delta = {'right': 1, 'left': -1}.get(event.key)
        if delta is not None:
            self._step(delta)


# Launch the 2x3 viewer (stages.3, stages.4 and seg_layers.4 Grad-CAMs).
viewer = SliceViewer(
    image=flair,
    gt=gt,
    prediction=prediction,
    gradcam_stages3=gradcam_stages3,
    gradcam_stages4=gradcam_stages4,
    gradcam_seglayers=gradcam_seglayers,
)
