In [None]:
import torch
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
from etils import epath

from codebase.custom_metrics import monai_metrics
import codebase.codebase_settings as cbs
from codebase import terminology as term
from codebase.projects.hecktor2022.evaluation import subvolume_evaluation
from codebase.projects.hecktor2022.evaluation import image_evaluation

%load_ext autoreload
%autoreload 2

<h3> Common Settings </h3>

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

In [None]:
def comparison_plot(input_data: torch.Tensor, label_data: torch.Tensor,
                    prediction: torch.Tensor, channel: int, nslice: int,
                    vmin: float = 0, vmax: float = 2):
    """Plot CT, PET, label and prediction for one slice in a single row.

    Only the first batch item is shown. Every tensor is indexed as
    ``tensor[0, chan, ..., nslice]``, i.e. the slice is taken along the last axis.

    Args:
        input_data: Model input; channel 0 is drawn as CT and channel 1 as PET.
        label_data: Ground-truth tensor; channel ``channel`` is drawn.
        prediction: Prediction tensor; channel ``channel`` is drawn.
        channel: Label/prediction channel to display.
        nslice: Index of the slice along the last axis.
        vmin: Lower colour limit shared by the label and prediction panels.
        vmax: Upper colour limit shared by the label and prediction panels.
    """
    def _slice(volume: torch.Tensor, chan: int):
        # detach() lets tensors that still track gradients (raw model output) be plotted.
        return volume[0, chan, ..., nslice].detach().cpu().numpy()

    # Label and prediction share one colour scale so equal values get equal colours.
    mask_style = dict(cmap='viridis', vmin=vmin, vmax=vmax)
    panels = [
        ('CT', _slice(input_data, 0), {}),
        ('PET', _slice(input_data, 1), {}),
        (f'Label-ch{channel}', _slice(label_data, channel), mask_style),
        (f'Prediction-ch{channel}', _slice(prediction, channel), mask_style),
    ]

    # num=1 with clear=True reuses (and clears) figure 1 instead of opening a new figure per call.
    fig, axes = plt.subplots(1, len(panels), num=1, clear=True, figsize=(12, 3))
    for ax, (title, image, imshow_kwargs) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title, pad=10, verticalalignment='bottom')
    fig.tight_layout()
    plt.show()

In [None]:
# Model checkpoint to evaluate, plus the experiment config passed to the evaluators below.
# Other checkpoints previously referenced here (swap in as needed):
#   /workspace/codebase/preprocessor/images/test_data/processed_256x256/subvolume_32/experiments/hecktor_test/version_1/checkpoints/checkpoint-epoch=79-val_loss=0.93.ckpt
#   /workspace/data/hecktor2022/processed_256x256/subvolume_32/experiments/hecktor_exp070423_segresent/version_1/checkpoints/checkpoint-epoch=63-val_loss=0.12.ckpt
#   /workspace/data/hecktor2022/processed_256x256/subvolume_32/experiments/hecktor_exp062324_segresent/version_1/checkpoints/checkpoint-epoch=31-val_loss=0.13.ckpt
#   /workspace/data/hecktor2022/processed_128x128/subvolume_32/set1/experiments/hecktor_exp061323/gfd_1class_best_model.pth
checkpoint_path = (
    '/workspace/data/hecktor2022/processed_256x256/subvolume_32/experiments/'
    'hecktor_exp070923_segresnet/version_1/checkpoints/'
    'checkpoint-epoch=79-val_loss=0.59.ckpt'
)
# Alternative config for quick runs: 'test_config.yml' in the same folder.
config_file = (
    cbs.CODEBASE_PATH / 'projects' / 'hecktor2022' / 'experiments'
    / 'experiment_config.yml'
)

<h3> Subvolume test </h3>

In [None]:
# Subvolume-level evaluator for the TEST phase, built from the checkpoint and
# experiment config selected above.
sve_module = subvolume_evaluation.SubVolumeEvaluationModule(
    phase=term.Phase.TEST,
    exp_config=config_file,
    checkpoint_path=epath.Path(checkpoint_path),
)

In [None]:
# Single subvolume to inspect (another id used before: 'MDA-103_34').
subvolume_id = 'MDA-195_74'
images, prediction, label, dice = sve_module.evaluate_an_example(subvolume_id)
# Show the score returned for this example so it sits next to the plots below.
dice

In [None]:
# Example subvolume: channel 0, slice index 2.
sve_module.comparison_plot(images, label, prediction, nslice=2, channel=0)

In [None]:
# Example subvolume: channel 2, slice index 2.
sve_module.comparison_plot(images, label, prediction, nslice=2, channel=2)

In [None]:
# Example subvolume: channel 2, slice index 28.
sve_module.comparison_plot(images, label, prediction, nslice=28, channel=2)

In [None]:
# Dice scores over the test cohort; the next cell reads column 0.
# NOTE(review): presumably one row per subvolume and one column per channel — confirm.
all_dices = sve_module.run_cohort_test()

In [None]:
# Histogram of the first Dice column of the cohort results.
# NOTE(review): assumes all_dices is (n_subvolumes, n_channels) — confirm.
fig, ax = plt.subplots()
ax.hist(all_dices[:, 0], bins=100)
ax.set(title='Subvolume Dice, channel 0', xlabel='Dice', ylabel='Count');

<h3> Whole image test </h3>

In [None]:
# Raw test images, resolved from the codebase root (as the cohort-evaluation section
# does) instead of a hard-coded absolute path.
# NOTE(review): assumes cbs.CODEBASE_PATH is /workspace/codebase, the previous literal root.
data_path = cbs.CODEBASE_PATH / 'preprocessor' / 'images' / 'test_data'

In [None]:
# Whole-image evaluator for the TEST phase: CT + PET inputs, CT as reference modality.
# NOTE(review): subvolume_size presumably has to match what the checkpoint was trained
# on (processed_256x256/subvolume_32), and key_word presumably selects the reference
# image files — confirm against ImageEvaluationModule.
ie_module = image_evaluation.ImageEvaluationModule(
    checkpoint_path=epath.Path(checkpoint_path),
    data_path=epath.Path(data_path),
    exp_config=config_file,
    phase=term.Phase.TEST,
    modalities=[term.Modality.CT, term.Modality.PET],
    reference_modality=term.Modality.CT,
    key_word='CT.nii.gz',
    subvolume_size=(256, 256, 32),
)

In [None]:
# Run whole-image prediction restricted to a single subject id.
ie_module.cohort_predict(['CHUM-024'])

In [None]:
# Fetch the prediction/label pair for the subject predicted above and report the label size.
subject = ie_module.get_prediction_label_pair(id='CHUM-024')
label_volume = subject['LABEL']
print(label_volume.shape)

In [None]:
# Summarise the entries by shape instead of dumping every voxel of each volume.
# NOTE(review): assumes subject is dict-like (it is indexed as subject['LABEL']).
{key: getattr(value, 'shape', value) for key, value in subject.items()}

In [None]:
# Visual check of slice index 136 of the prediction/label pair fetched above.
ie_module.comparison_plot(subject, nslice=136)

In [None]:
# Dice for the single subject predicted above; shown as the cell output.
ie_module.calculate_dice(ids=['CHUM-024'])

In [None]:
# Stand-alone check on one pre-saved validation subvolume (input + label .npy pair).
data_folder = cbs.CODEBASE_PATH / 'preprocessor' / 'images' / 'test_data'
# Other subvolumes tried before, under data_folder / 'processed_128x128/subvolume_32/train':
#   images/CHUM-024_38__input.npy + labels/CHUM-024_38__label.npy
#   images/CHUM-024_25__input.npy + labels/CHUM-024_25__label.npy
# NOTE(review): these files come from the 128x128 preprocessing while the checkpoint
# selected above is from processed_256x256 — confirm the model accepts this size.
valid_dir = '/workspace/data/hecktor2022/processed_128x128/subvolume_32/valid'
valid_subvolume_id = 'CHUV-008_239'
image = torch.Tensor(np.load(f'{valid_dir}/images/{valid_subvolume_id}__input.npy'))
label = torch.Tensor(np.load(f'{valid_dir}/labels/{valid_subvolume_id}__label.npy'))
# Add a leading batch axis of size 1 to both tensors.
batch = {'input': image[None, ...], 'label': label[None, ...]}
# FIXME: prepare_subvolume_batch is neither defined nor imported in this notebook,
# so this line raises NameError on a fresh kernel — import it before running.
features, targets = prepare_subvolume_batch(batch)

In [None]:
# Forward pass on the scratch batch, thresholded into a 0/1 int mask.
# FIXME: `model` is not defined in this notebook (left over from a removed cell).
print(features.shape, targets.shape)
# Inference only: no_grad() avoids building and keeping an autograd graph.
with torch.no_grad():
    logits = model(features)
# Sigmoid then a 0.5 threshold, as before.
prediction = (torch.sigmoid(logits) > 0.5).int()
print(prediction.shape)

In [None]:
# Metric on the scratch prediction vs. target, both wrapped in single-item lists.
# FIXME: `metrics` is not defined in this notebook; presumably it should be built from
# monai_metrics (imported at the top but otherwise unused) — define it before running.
dice_value = metrics([prediction], [targets])
dice_value

In [None]:
# Take batch item 0 of prediction/targets/features as NumPy arrays for plotting.
# NOTE(review): this rebinds `prediction` and `targets` in place, so re-running the cell
# without re-running the cells above fails (.detach()/.cpu() on a NumPy array) — consider
# new names such as prediction_np/targets_np, updating the cells that read them.
prediction = prediction[0].detach().cpu().numpy()
targets = targets[0].cpu().numpy()
inputs = features[0].cpu().numpy()
# Modality channels 0 and 1, presumably CT and PET as in comparison_plot — confirm.
ct = inputs[0, ...]
pet = inputs[1, ...]
# prediction = image.numpy()
prediction.shape

In [None]:
# Scroll through slices of PET, prediction channel 0 and target channel 0 side by side.
# np.swapaxes(..., 0, 2) moves the last axis (the slice axis used by comparison_plot)
# to the front, so the stacked array is (panel, slice, ...): panels become facet
# columns and slices become animation frames.
panel_titles = ['PET', 'Prediction-ch0', 'Target-ch0']
all_imgs = np.stack([
    np.swapaxes(pet, 0, 2),
    np.swapaxes(prediction[0], 0, 2),
    np.swapaxes(targets[0], 0, 2),
])
fig = px.imshow(
    all_imgs,
    animation_frame=1,
    facet_col=0,
    labels={'animation_frame': 'slice'},
    color_continuous_scale='Gray',
    width=500 * len(panel_titles), height=500,
)
# Replace the default 'facet_col=<i>' facet titles with readable panel names.
for annotation, title in zip(fig.layout.annotations, panel_titles):
    annotation.text = title
fig

<h2> Cohort Evaluation </h2>

In [None]:
# Checkpoint and raw-image folder for the cohort evaluation below.
# Other checkpoints previously referenced here (swap in as needed):
#   /workspace/codebase/preprocessor/images/test_data/processed_256x256/subvolume_32/experiments/hecktor_test/version_1/checkpoints/checkpoint-epoch=03-val_loss=0.21.ckpt
#   /workspace/data/hecktor2022/processed_128x128/subvolume_32/set1/experiments/hecktor_exp061323/generalized_focal_dice_best_model.pth
#   /workspace/data/hecktor2022/processed_128x128/subvolume_32/set1/experiments/hecktor_exp061323/gfd_1class_best_model.pth
checkpoint_path = (
    '/workspace/data/hecktor2022/processed_256x256/subvolume_32/experiments/'
    'hecktor_exp070423_segresent/version_1/checkpoints/'
    'checkpoint-epoch=63-val_loss=0.12.ckpt'
)
# Alternative data: cbs.DATA_PATH / 'hecktor2022' / 'processed_128x128' / 'subvolume_32' / 'valid'
data_path = (
    cbs.CODEBASE_PATH / 'preprocessor' / 'images'
    / 'test_data'
)

In [None]:
# Rebuild the whole-image evaluator for the checkpoint_path / data_path chosen just above
# (TEST phase, CT + PET inputs, CT as reference modality).
ie_module = image_evaluation.ImageEvaluationModule(
    exp_config=config_file,
    phase=term.Phase.TEST,
    checkpoint_path=epath.Path(checkpoint_path),
    data_path=epath.Path(data_path),
    subvolume_size=(256, 256, 32),
    reference_modality=term.Modality.CT,
    modalities=[term.Modality.CT, term.Modality.PET],
    key_word='CT.nii.gz',
)

In [None]:
# No id filter here, unlike the single-id call earlier; presumably predicts every subject
# under data_path — confirm.
ie_module.cohort_predict()

In [None]:
# Subject scored in the cohort-evaluation section.
# NOTE(review): this rebinds the builtin `id` for the rest of the kernel session; the next
# cell reads it, so rename both together (e.g. subject_id).
id = 'CHUM-024'
results = ie_module.calculate_dice([id])

In [None]:
# Prediction/label pair for the subject id set in the previous cell (global `id`).
subject = ie_module.get_prediction_label_pair(id)


In [None]:
# NOTE(review): `prediction` here is the NumPy array from the single-subvolume scratch
# cells, not the cohort `subject` fetched just above — confirm which one is intended.
np.min(prediction)

In [None]:
from monai.networks.utils import one_hot
import torch.nn.functional as F

In [None]:
# Toy check of the softmax -> argmax -> one-hot conversion on random 3-class logits.
# A locally seeded generator makes the cell reproducible without touching torch's global RNG.
input_tensor = torch.randn(1, 3, 2, 2, generator=torch.Generator().manual_seed(0))

# (Name kept as spelled: a later cell displays `probabilites`.)
probabilites = F.softmax(input_tensor, dim=1)
max_indices = torch.argmax(probabilites, dim=1)
# Re-insert a size-1 channel axis: (1, 2, 2) -> (1, 1, 2, 2), the layout MONAI's one_hot expands.
second_tensor = max_indices[:, None, ...]
# MONAI one_hot turns the size-1 channel axis into num_classes 0/1 channels.
num_classes = input_tensor.shape[1]
one_hot_labels = one_hot(second_tensor, num_classes)


In [None]:
# Random logits, shape (1, 3, 2, 2).
input_tensor

In [None]:
# Softmax over dim 1 of the logits above.
probabilites

In [None]:
# Argmax indices with a re-inserted channel axis: (1, 1, 2, 2).
second_tensor.shape

In [None]:
# Expected (1, num_classes, 2, 2) = (1, 3, 2, 2) after one-hot expansion.
one_hot_labels.shape