In [None]:
import torch
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
from etils import epath

from codebase.custom_metrics import monai_metrics
import codebase.codebase_settings as cbs
from codebase import terminology as term
from codebase.projects.hecktor2022.evaluation import subvolume_evaluation
from codebase.projects.hecktor2022.evaluation import image_evaluation

%load_ext autoreload
%autoreload 2

<h3> Common Settings </h3>

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print('Device:', device)

In [None]:
def comparison_plot(input_data: torch.Tensor, label_data: torch.Tensor,
                    prediction: torch.Tensor, channel: int, nslice: int,
                    vmin: float = 0, vmax: float = 2) -> None:
    """Plot CT, PET, label and prediction slices side by side in one row.

    Only batch item 0 is shown. Figure number 1 is reused and cleared on every
    call, so repeated calls do not accumulate open figures.

    Args:
        input_data: Image tensor indexed as ``[batch, modality, ..., slice]``;
            modality 0 is drawn as "CT" and modality 1 as "PET".
        label_data: Label tensor indexed as ``[batch, channel, ..., slice]``.
        prediction: Prediction tensor indexed like ``label_data``.
        channel: Label/prediction channel to display.
        nslice: Index along the last axis of every tensor selecting the slice.
        vmin: Lower colour limit shared by the label and prediction panels.
        vmax: Upper colour limit shared by the label and prediction panels.
            Sharing the limits keeps those two panels directly comparable.
    """
    def _slice(tensor: torch.Tensor, ch: int):
        # detach() so tensors still attached to an autograd graph (e.g. raw
        # model output) can be converted to numpy.
        return tensor[0, ch, ..., nslice].detach().cpu().numpy()

    fig, axes = plt.subplots(1, 4, num=1, clear=True, figsize=(12, 3))

    label_style = dict(cmap='viridis', vmin=vmin, vmax=vmax)
    panels = (
        (_slice(input_data, 0), 'CT', {}),
        (_slice(input_data, 1), 'PET', {}),
        (_slice(label_data, channel), f'Label-ch{channel}', label_style),
        (_slice(prediction, channel), f'Prediction-ch{channel}', label_style),
    )
    for ax, (image, title, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        # Same title styling on every panel (the CT panel previously differed).
        ax.set_title(title, pad=10, verticalalignment='bottom')

    fig.tight_layout()
    plt.show()

In [None]:
# Model and experiment configuration under evaluation.
#
# Active checkpoint: hecktor_exp071823_segresnet, epoch 25 (val_loss 0.28).
# Previously evaluated checkpoints, kept for reference:
#   /workspace/codebase/preprocessor/images/test_data/processed_256x256/subvolume_32/experiments/hecktor_test/version_1/checkpoints/checkpoint-epoch=79-val_loss=0.93.ckpt
#   /workspace/data/hecktor2022/processed_256x256/subvolume_32/experiments/hecktor_exp070923_segresnet/version_1/checkpoints/checkpoint-epoch=79-val_loss=0.59.ckpt
#   /workspace/data/hecktor2022/processed_256x256/subvolume_32/experiments/hecktor_exp070423_segresent/version_1/checkpoints/checkpoint-epoch=63-val_loss=0.12.ckpt
#   /workspace/data/hecktor2022/processed_256x256/subvolume_32/experiments/hecktor_exp062324_segresent/version_1/checkpoints/checkpoint-epoch=31-val_loss=0.13.ckpt
#   /workspace/data/hecktor2022/processed_128x128/subvolume_32/set1/experiments/hecktor_exp061323/gfd_1class_best_model.pth
checkpoint_path = '/workspace/data/hecktor2022/processed_256x256/subvolume_32/experiments/hecktor_exp071823_segresnet/version_1/checkpoints/checkpoint-epoch=25-val_loss=0.28.ckpt'

# Experiment config. An alternative in the same folder is 'test_config.yml'.
experiments_dir = cbs.CODEBASE_PATH / 'projects' / 'hecktor2022' / 'experiments'
config_file = experiments_dir / 'experiment_config.yml'

<h3> Subvolume Test </h3>

In [None]:
# Subvolume-level evaluator for the TEST phase, built from the checkpoint and
# experiment config selected above.
sve_module = subvolume_evaluation.SubVolumeEvaluationModule(
    phase=term.Phase.TEST,
    exp_config=config_file,
    checkpoint_path=epath.Path(checkpoint_path),
)

In [None]:
# Single-subvolume sanity check ('MDA-103_34' is another example that was used).
subvolume_id = 'MDA-195_74'
example_outputs = sve_module.evaluate_an_example(subvolume_id)
images, prediction, label, dice = example_outputs

In [None]:
# Channel 0 at slice index 0.
sve_module.comparison_plot(images, label, prediction, nslice=0, channel=0)

In [None]:
# Channel 1 at slice index 0.
sve_module.comparison_plot(images, label, prediction, nslice=0, channel=1)

In [None]:
# Channel 2 at slice index 5.
sve_module.comparison_plot(images, label, prediction, nslice=5, channel=2)

In [None]:
# Cohort-level subvolume evaluation; the result is indexed as a 2-D array below.
run_cohort_test = sve_module.run_cohort_test
all_dices = run_cohort_test()

In [None]:
# Distribution of the subvolume Dice scores stored in column 1 of `all_dices`
# (presumably one class per column — TODO confirm which class this is).
fig, ax = plt.subplots()
ax.hist(all_dices[:, 1], bins=100)
ax.set(title='Subvolume Dice distribution (column 1)',
       xlabel='Dice score', ylabel='Count')
plt.show()

<h3> Whole-Image Test </h3>

In [None]:
# Root of the HECKTOR 2022 data used for whole-image evaluation.
# Alternative (small fixture): '/workspace/codebase/preprocessor/images/test_data'.
data_path = cbs.DATA_PATH / 'hecktor2022'

In [None]:
# Whole-image evaluator for the TEST phase.
# NOTE(review): subvolume_size presumably has to match the
# processed_256x256/subvolume_32 data the checkpoint was trained on — confirm.
ie_module = image_evaluation.ImageEvaluationModule(
    # Model and experiment configuration.
    checkpoint_path=epath.Path(checkpoint_path),
    exp_config=config_file,
    phase=term.Phase.TEST,
    # Input data location and input settings.
    data_path=epath.Path(data_path),
    key_word='CT.nii.gz',
    modalities=[term.Modality.CT, term.Modality.PET],
    reference_modality=term.Modality.CT,
    subvolume_size=(256, 256, 32),
)

In [None]:
# Patient to inspect. Other patients examined previously: 'CHUM-024', 'CHUP-017'.
# NOTE(review): `id` shadows the Python builtin; it is kept only because later
# cells read it — TODO rename to `patient_id` throughout the notebook.
id = 'CHUV-036'

In [None]:

# Predict only the selected patient (cohort_predict takes a list of IDs).
single_patient = [id]
ie_module.cohort_predict(single_patient)

In [None]:
# Prediction/label pair for the selected patient; load_images=True presumably
# also loads the input images (the next cell reads 'PT' and 'CT').
subject = ie_module.get_prediction_label_pair(
    id=id,
    load_images=True,
)

In [None]:
# Compare shape and spacing of every image held in the subject.
for image_key in ('PT', 'CT', 'LABEL', 'PREDICT'):
    print(subject[image_key].shape, subject[image_key].spacing)

In [None]:
# Visual check of the whole-image prediction at slice index 150.
whole_image_slice = 150
ie_module.comparison_plot(subject, nslice=whole_image_slice)

In [None]:
# Dice for the selected patient only (last expression, so the result is displayed).
patient_ids = [id]
ie_module.calculate_dice(ids=patient_ids)

<h3> Cohort Test </h3>

In [None]:
# Root of the HECKTOR 2022 cohort used for the full test run.
dataset_dir_name = 'hecktor2022'
test_folder = cbs.DATA_PATH / dataset_dir_name

In [None]:
# Whole-image evaluator for the full TEST cohort.
# NOTE(review): subvolume_size presumably has to match the
# processed_256x256/subvolume_32 data the checkpoint was trained on — confirm.
ie_module = image_evaluation.ImageEvaluationModule(
    # Model and experiment configuration.
    checkpoint_path=epath.Path(checkpoint_path),
    exp_config=config_file,
    phase=term.Phase.TEST,
    # Input data location and input settings.
    data_path=epath.Path(test_folder),
    key_word='CT.nii.gz',
    modalities=[term.Modality.CT, term.Modality.PET],
    reference_modality=term.Modality.CT,
    subvolume_size=(256, 256, 32),
)

In [None]:
# Patient IDs found for the TEST phase; display how many there are.
test_patients = ie_module.get_patient_lists()
num_test_patients = len(test_patients)
num_test_patients

In [None]:
# Whole-image inference for every patient in the test cohort.
predict_cohort = ie_module.cohort_predict
predict_cohort(test_patients)

In [None]:
# Dice scores for the whole test cohort; indexed as a 2-D array in the next cell.
dices = ie_module.calculate_dice(
    ids=test_patients,
)

In [None]:
# Cohort summary: mean Dice of columns 0 and 1 (presumably one column per
# class — TODO confirm), then the distribution of column 1.
# nanmean is used because np.mean turns the whole mean into NaN if any single
# score is NaN (e.g. an undefined Dice — TODO confirm whether calculate_dice can
# produce these); the number of ignored NaN entries is reported.
dice_scores = np.asarray(dices, dtype=float)
for column in (0, 1):
    column_scores = dice_scores[:, column]
    n_missing = int(np.isnan(column_scores).sum())
    print(f'Mean Dice (column {column}): {np.nanmean(column_scores)}'
          f' ({n_missing} NaN entries ignored)')

column1_scores = dice_scores[:, 1]
fig, ax = plt.subplots()
ax.hist(column1_scores[~np.isnan(column1_scores)], bins=100)
ax.set(title='Whole-image Dice distribution (column 1)',
       xlabel='Dice score', ylabel='Count')
plt.show()

In [None]:
import torchio as tio


In [None]:
# Unprocessed CT image and label files for a single patient (MDA-190), used
# for a direct geometry check.
hecktor_root = cbs.DATA_PATH / 'hecktor2022'
image_file = hecktor_root / 'images' / 'MDA-190__CT.nii.gz'
label_file = hecktor_root / 'labels' / 'MDA-190.nii.gz'

In [None]:
def _load_and_report_shape(image_cls, path):
    """Construct a torchio image of class ``image_cls`` from ``path`` and print its spatial shape."""
    loaded = image_cls(path)
    print(loaded.spatial_shape)
    return loaded


image = _load_and_report_shape(tio.ScalarImage, image_file)
label = _load_and_report_shape(tio.LabelMap, label_file)

In [None]:
# Spacing of the CT image, then of the label.
for loaded_image in (image, label):
    print(loaded_image.spacing)