In [None]:
import os
import numpy as np
import pandas as pd
import ants
import matplotlib.pyplot as plt

import torch
from torchsurv.loss.weibull import survival_function

In [None]:
# Locations of the preprocessed test scans and of the per-patient result files.
data_path = 'path/to/data/test/preprocessed'
result_path = 'results'

label_df_path = 'path/to/data/test_label.csv'
label_df = pd.read_csv(label_df_path)

# Columns that must all hold a usable value for a patient to be kept.
variable_list = ['sex', 'age', 'who_grade', 'group', 'gbl', 'idh', '_1p19q', 'mgmt', 'eor_op1', 'kps']

# Keep only rows in which every listed column is non-null and different from -1
# (presumably -1 is the dataset's "missing" code -- confirm against the label file).
# One vectorised mask replaces re-filtering the frame once per column.
required_values = label_df[variable_list]
is_complete = (required_values.notna() & (required_values != -1)).all(axis=1)
# Explicit copy so the column assignment below never writes into a view of the raw frame.
label_df = label_df.loc[is_complete].copy()

# Patient ids are used to build directory and file names, so they are handled as strings.
label_df['patient_id'] = label_df['patient_id'].astype(str)
patient_list = label_df['patient_id'].tolist()

In [None]:
# Time grid (0 to 36 in steps of 0.1) on which the survival curve is evaluated;
# the survival axis below is labelled in months.
new_times = torch.arange(0, 36, 0.1)
# Entries of result['modality_cam'][1] that are kept before aggregation.
select_indices = [4, 5, 8, 9, 11]
full_feature_names = ['MRI', 'Age', 'Sex', 'WHO Grade', 'Histopathology', 'IDH status', '1p/19q status', 'MGMTp status', 'KPS', 'Extent of resection', 'Radiotherapy', 'Chemotherapy']
# Number of patients to draw. 1 reproduces the previous behaviour (the loop used
# to end with an unconditional `break`); set to None to draw every patient.
max_patients = 1


def _load_volume(path):
    """Read an image file with ANTs and return its voxel data as a NumPy array."""
    return ants.image_read(path).numpy()


def _scale_by_q99(volume):
    """Divide `volume` by its 99th-percentile intensity.

    The volume is returned unscaled when that percentile is 0 or not finite,
    instead of producing an image full of NaN/inf values.
    """
    upper = np.quantile(volume, 0.99)
    if not np.isfinite(upper) or upper == 0:
        return volume
    return volume / upper


def _load_masked_modality(path, brain_mask):
    """Load one MRI volume, multiply it by the brain mask and scale it by its q99."""
    return _scale_by_q99(_load_volume(path) * brain_mask)


def _show_slice(ax, image, title, overlay=None, overlay_alpha=None):
    """Draw a grayscale slice on `ax`, optionally with a 'jet' overlay, and hide the axes."""
    # NOTE(review): imshow is called without vmin/vmax, so matplotlib rescales each
    # slice to its own min/max and the q99 scaling does not change the picture.
    # Pass vmin=0, vmax=1 here if the 99th-percentile window is meant to apply.
    ax.imshow(image, cmap='gray')
    if overlay is not None:
        ax.imshow(overlay, cmap='jet', alpha=overlay_alpha)
    ax.set_title(title)
    ax.axis('off')


for patient_id in patient_list[:max_patients]:
    data_dir = f'{data_path}/{patient_id}'

    # SECURITY: allow_pickle=True unpickles arbitrary objects; only load result
    # files produced by a trusted pipeline.
    result = np.load(f'{result_path}/{patient_id}.npy', allow_pickle=True).item()
    vision_cam = result['vision_cam']

    brain_mask_numpy = _load_volume(f'{data_dir}/brain_mask.nii.gz')
    tumor_mask_numpy = _load_volume(f'{data_dir}/WT_tumor_mask.nii.gz')

    t1_numpy = _load_masked_modality(f'{data_dir}/t1.nii.gz', brain_mask_numpy)
    t1ce_numpy = _load_masked_modality(f'{data_dir}/t1ce.nii.gz', brain_mask_numpy)
    t2_numpy = _load_masked_modality(f'{data_dir}/t2.nii.gz', brain_mask_numpy)
    flair_numpy = _load_masked_modality(f'{data_dir}/flair.nii.gz', brain_mask_numpy)

    # Slice along the last axis with the largest tumor-mask sum. argmax returns 0
    # for an all-zero mask, in which case slice 0 is shown.
    tumor_largest_z_index = np.argmax(tumor_mask_numpy.sum(axis=(0, 1)))

    t1_slice_numpy = t1_numpy[..., tumor_largest_z_index]
    t1ce_slice_numpy = t1ce_numpy[..., tumor_largest_z_index]
    t2_slice_numpy = t2_numpy[..., tumor_largest_z_index]
    flair_slice_numpy = flair_numpy[..., tumor_largest_z_index]
    tumor_mask_slice_numpy = tumor_mask_numpy[..., tumor_largest_z_index]
    # Assumes vision_cam shares the image's last (slice) axis -- TODO confirm.
    vision_cam_slice_numpy = vision_cam[..., tumor_largest_z_index]

    # 3x3 layout: four MRI panels and two overlays on the first two rows,
    # bar chart (two columns wide) and survival curve on the last row.
    fig = plt.figure(figsize=(15, 15))
    gs = fig.add_gridspec(3, 3)

    ax_t1 = fig.add_subplot(gs[0, 0])
    ax_t1ce = fig.add_subplot(gs[0, 1])
    ax_t2 = fig.add_subplot(gs[0, 2])
    ax_flair = fig.add_subplot(gs[1, 0])
    ax_tumor_mask = fig.add_subplot(gs[1, 1])
    ax_vision_cam = fig.add_subplot(gs[1, 2])

    ax_bar = fig.add_subplot(gs[2, 0:2])
    ax_survival = fig.add_subplot(gs[2, 2])

    _show_slice(ax_t1, t1_slice_numpy, 'T1')
    _show_slice(ax_t1ce, t1ce_slice_numpy, 'T1CE')
    _show_slice(ax_t2, t2_slice_numpy, 'T2')
    _show_slice(ax_flair, flair_slice_numpy, 'FLAIR')
    _show_slice(ax_tumor_mask, t1ce_slice_numpy, 'Tumor Mask',
                overlay=tumor_mask_slice_numpy, overlay_alpha=0.7)
    _show_slice(ax_vision_cam, t1ce_slice_numpy, 'Vision CAM on T1CE',
                overlay=vision_cam_slice_numpy, overlay_alpha=0.5)

    # Reduce the stored modality CAM to one value per feature name:
    # take element 1, keep the selected entries, move the last axis to the front,
    # take index 0 of the new last axis, then the median over the selected entries.
    # Assumes the moved axis has one entry per name in full_feature_names -- TODO confirm.
    modality_cam = result['modality_cam']
    modality_cam = modality_cam[1]
    modality_cam = modality_cam[select_indices]
    modality_cam = np.transpose(modality_cam, (2, 0, 1))
    modality_cam = modality_cam[..., 0]
    modality_cam = np.median(modality_cam, axis=-1)

    ax_bar.bar(full_feature_names, modality_cam)
    ax_bar.set_title('Modality CAM')
    ax_bar.set_ylabel('Importance')
    ax_bar.tick_params(axis='x', rotation=45)

    output = result['output']
    event = result['event']  # read from the result file; not used in the figure
    # Plain float so the tensor arithmetic and the '{:.1f}' formatting below do not
    # depend on how the value was stored.
    duration = float(result['duration'])

    patient_outputs = torch.from_numpy(output).unsqueeze(0)

    estimates = [survival_function(patient_outputs, t) for t in new_times]
    # Each estimate holds a single value, so convert the list of tensors to a 1-D
    # float array; matplotlib then receives plain numbers instead of nested tensors.
    survival_curve = np.array([estimate.item() for estimate in estimates])
    # Survival probability at the grid point closest to the patient's duration.
    nearest_time_index = int(torch.argmin(torch.abs(new_times - duration)))
    prob_at_duration = float(survival_curve[nearest_time_index])

    ax_survival.plot(new_times.numpy(), survival_curve)
    ax_survival.scatter(duration, prob_at_duration, color='red')
    ax_survival.text(duration, prob_at_duration, r"S($t$={:.1f})={:.0f}%".format(duration, prob_at_duration*100),
                        color='black', va='bottom', ha='left', fontsize=9, fontweight='bold')
    ax_survival.set_title('Survival Probability')
    ax_survival.set_xlabel("Time (months)")
    ax_survival.set_ylabel("Survival Probability", labelpad=18, fontsize=10, weight='bold')
    ax_survival.grid(True, which='major', linewidth=1.5, color='#F0F0F0')

    plt.tight_layout()
    plt.show()
    # Release the figure so drawing many patients does not accumulate open figures.
    plt.close(fig)
    