In [None]:
from etils import epath
import numpy as np
from matplotlib import pyplot as plt
import torch
import torchio as tio
from torchsummary import summary
import plotly.express as px

from codebase.preprocessor.images import multi_modal_processor
from codebase.dataloader.images import multi_modal_dataloader
import codebase.terminology as term
import codebase.codebase_settings as cbs

%load_ext autoreload
%autoreload 2

<h3> Preprocessing data </h3>

In [None]:
# Build the TRAIN-phase processor over the test data folder.
# CT and PET are the requested modalities; PET is passed as the reference.
data_folder = cbs.CODEBASE_PATH / 'preprocessor' / 'images' / 'test_data'
hecktor_processor_train = multi_modal_processor.MultiModalProcessor(
    data_folder=data_folder,
    phase=term.Phase.TRAIN,
    modalities=[term.Modality.CT, term.Modality.PET],
    reference=term.Modality.PET,
    problem_type=term.ProblemType.SEGMENTATION,
)

<h4> Original data </h4>

In [None]:
# Build the subject for patient 'CHUM-024' and display it as the cell output.
subject = hecktor_processor_train.create_subject('CHUM-024')
subject

In [None]:
# Print the shape of the subject's 'CT' entry, then plot it.
print(subject['CT'].shape)
subject['CT'].plot()

In [None]:
# Print the shape of the subject's 'PT' entry, then plot it.
# NOTE(review): 'PT' is presumably the key used for the PET modality — confirm.
print(subject['PT'].shape)
subject['PT'].plot()

In [None]:
# Print the shape of the 'LABEL' entry and the maximum value found at
# index 1 and index 2 of its first data axis, then plot it.
# NOTE(review): indices 1 and 2 presumably hold the two segmentation targets — confirm.
print(subject['LABEL'].shape)
print(subject['LABEL'].data[1, ...].max())
print(subject['LABEL'].data[2, ...].max())
subject['LABEL'].plot()

<h4> Resample to reference (PET) </h4>

In [None]:
# Resample the subject via the processor, passing xy_size=(128, 128).
# The TRAIN processor is used here because no plain `hecktor_processor` is ever
# defined in this notebook; the original name raised NameError on a fresh kernel.
resampled_subject = hecktor_processor_train.resample_to_reference(subject=subject, xy_size=(128, 128))

In [None]:
# Print the resampled CT shape and the spacing of the CT, PT and LABEL entries
# so they can be compared by eye, then plot the resampled CT.
print(resampled_subject['CT'].shape)
print(resampled_subject['CT'].spacing)
print(resampled_subject['PT'].spacing)
print(resampled_subject['LABEL'].spacing)
resampled_subject['CT'].plot()

<h4> Apply transformation: normalization and augmentation </h4>

CT: Clamp + Intensity rescale

PET: Histogram Standardization + ZNormalization

In [None]:
# Create the normalization transform from the TRAIN processor and apply it to
# the resampled subject.
# `hecktor_processor_train` replaces the undefined `hecktor_processor`, which
# raised NameError when the notebook was run top to bottom.
normalization = hecktor_processor_train.create_normalization()
normalized_subject = normalization(resampled_subject)

In [None]:
# Report the value range of the normalized CT array, then plot the image.
ct_data = normalized_subject['CT'].numpy()
print(f'max: {ct_data.max()}')
print(f'min: {ct_data.min()}')
normalized_subject['CT'].plot()

<h4> Process and save data </h4>

In [None]:
# Preprocess and save the TRAIN-phase data, passing xy_size=(128, 128) and PET as
# the weight modality with a 0.5 weight threshold.
# `hecktor_processor_train` replaces the undefined `hecktor_processor`, which
# raised NameError when the notebook was run top to bottom.
n = hecktor_processor_train.preprocess_and_save(
    xy_size=(128, 128),
    weight_modality=term.Modality.PET,
    weight_threshold=0.5,
)

In [None]:
# Same processor configuration as the TRAIN one, but for the VALID phase;
# preprocess and save its data with the same settings.
hecktor_processor_valid = multi_modal_processor.MultiModalProcessor(
    data_folder=data_folder,
    phase=term.Phase.VALID,
    modalities=[term.Modality.CT, term.Modality.PET],
    reference=term.Modality.PET,
    problem_type=term.ProblemType.SEGMENTATION,
)
n = hecktor_processor_valid.preprocess_and_save(
    xy_size=(128, 128),
    weight_modality=term.Modality.PET,
    weight_threshold=0.5,
)

In [None]:
# Point a TRAIN-phase data loader at the 'processed_128x128' folder under the test data.
processed_data_path = cbs.CODEBASE_PATH / 'preprocessor' / 'images' / 'test_data' / 'processed_128x128'
hecktor_loader = multi_modal_dataloader.MultiModalDataLoader(
    data_folder=processed_data_path,
    phase=term.Phase.TRAIN,
    modalities=[term.Modality.CT, term.Modality.PET],
    problem_type=term.ProblemType.SEGMENTATION,
)

In [None]:
# Load the processed subject for patient 'HGJ-080' and display it.
processed_subject = hecktor_loader.create_subject(patient='HGJ-080')
processed_subject

In [None]:
# Display the shape of the processed 'LABEL' entry.
processed_subject['LABEL'].shape

In [None]:
# Plot the 'WEIGHT' entry of the processed subject (the same key is later
# passed to the patch sampler as its probability map).
processed_subject['WEIGHT'].plot()

<h4> Data Augmentation </h4>

In [None]:
# Augmentation spec: a single 'flip' entry with p=1.0 over the 'LR' and 'AP' axes.
transform_dict = {
    'flip': {'p': 1.0, 'axes': ('LR', 'AP')},
}
transformation = hecktor_loader.create_augmentation(transform_keys=transform_dict)

# Apply the augmentation to the processed subject loaded earlier.
final_subject = transformation(processed_subject)

In [None]:
# Plot the CT entry of the augmented subject.
final_subject['CT'].plot()

<h4> Create dataset and dataloader </h4>

In [None]:
# Build the subject list, then wrap it in a dataset that uses the
# augmentation created above.
subjects = hecktor_loader.create_subject_list()
subject_dataset = hecktor_loader.create_subject_dataset(subjects=subjects, augmentation=transformation)

In [None]:
# Display the first subject in the list.
subjects[0]

In [None]:
# Print the ID of the first listed subject and of the first dataset item,
# run the 'spacing' consistency check on the dataset item, then display the
# first listed subject.
print(subjects[0].ID)
print(subject_dataset[0].ID)
subject_dataset[0].check_consistent_attribute('spacing')
subjects[0]

In [None]:
# Patch sampling and batching settings.
patch_size = (128, 128, 32)
batch_size = 2
num_workers = 1

# Patches are drawn with the 'WEIGHT' image as the sampling probability map.
sampler = tio.data.WeightedSampler(patch_size=patch_size, probability_map='WEIGHT')

train_dataloader = hecktor_loader.create_patch_dataloader(
    subject_dataset=subject_dataset,
    sampler=sampler,
    samples_per_volume=4,
    max_queue_length=32,
    batch_size=batch_size,
    num_workers=num_workers,
)

In [None]:
# Pull a single batch from the patch dataloader for inspection.
batch = next(iter(train_dataloader))

In [None]:
# Display the shape of the CT tensor in the batch.
batch['CT'][tio.DATA].shape

In [None]:
# Display the shape of the LABEL tensor in the batch.
batch['LABEL'][tio.DATA].shape

In [None]:
# Concatenate the CT and PT batch tensors along dim=1 and display the result's shape.
# NOTE(review): dim=1 is presumably the channel axis — confirm.
new_input = torch.cat([batch['CT'][tio.DATA], batch['PT'][tio.DATA]], dim=1)
new_input.shape

In [None]:
# Keep indices 1 onward along dim 1 of the label tensor (index 0 is dropped),
# then print the resulting shape and the maximum of each of the first two
# remaining entries along that axis. The tensor itself is the cell output.
# NOTE(review): index 0 is presumably the background channel — confirm.
new_label = batch['LABEL'][tio.DATA][:, 1:, ...]
print(new_label.shape)
print(new_label[:, 0, ...].max())
print(new_label[:, 1, ...].max())
new_label

<h3> Subvolume generation </h3>

In [None]:
# Generate and save subvolumes from the processed TRAIN data.
# NOTE(review): the keyword `subvolume_intervel` is kept exactly as the call was
# written; it presumably mirrors the spelling in the processor's signature — confirm.
hecktor_processor_train.create_and_save_subvolumes(
    data_path=data_folder / 'processed_128x128',
    valid_channel=[1, 2],
    subvolume_intervel=4,
    subvolume_size=32,
)

In [None]:
# Load one saved subvolume pair (input and label .npy files) for patient CHUM-024
# from the TRAIN subvolume folder.
image = np.load(str(data_folder / 'processed_128x128/subvolume_32/train/images/CHUM-024_38__input.npy'))
label = np.load(str(data_folder / 'processed_128x128/subvolume_32/train/labels/CHUM-024_38__label.npy'))

In [None]:
# Print the shapes of the loaded input and label arrays.
print(image.shape, label.shape)

In [None]:
# Swap axes 1 and 3 of the loaded input array before plotting.
all_imgs = np.swapaxes(image, 1, 3)
# Interactive viewer: axis 0 is laid out as facet columns and axis 1 drives the
# animation slider (labelled 'slice').
# NOTE(review): the width of 500*3 suggests three facets were expected — confirm
# against the printed shape of `image`.
px.imshow(
    all_imgs,
    animation_frame=1,
    labels={'animation_frame': 'slice'},
    facet_col=0,
    color_continuous_scale='Gray',
    width=500*3, height=500
)

In [None]:
# Re-create the VALID-phase processor (same configuration as earlier in the
# notebook) and generate its subvolumes with the same settings as for TRAIN.
# NOTE(review): the keyword `subvolume_intervel` is kept exactly as the call was
# written; it presumably mirrors the spelling in the processor's signature — confirm.
hecktor_processor_valid = multi_modal_processor.MultiModalProcessor(
    data_folder=data_folder,
    phase=term.Phase.VALID,
    modalities=[term.Modality.CT, term.Modality.PET],
    reference=term.Modality.PET,
    problem_type=term.ProblemType.SEGMENTATION,
)
hecktor_processor_valid.create_and_save_subvolumes(
    data_path=data_folder / 'processed_128x128',
    valid_channel=[1, 2],
    subvolume_intervel=4,
    subvolume_size=32,
)

<h4> Check subvolume Dataloader </h4>

In [None]:
# Import through the `codebase` package, matching how the sibling
# `multi_modal_dataloader` module is imported in the first cell. The bare
# `dataloader.images` path only resolves when the codebase directory itself is
# on sys.path, and would load the package under a second module name.
from codebase.dataloader.images import subvolume_dataloader

In [None]:
# Augmentation spec for the subvolume loaders.
transform_dict = {
    'flip': {'p': 0.5, 'axes': ('LR', 'AP')},
    # The rotation range has to account for whether a channel axis exists,
    # because the transform assumes there are no channels.
    'rotate': {'radians': [0, 0.5, 0.5], 'p': 0.8},
    'affine': {'p': 0.5, 'degrees': 0.5, 'translation': 0.3},
}
print(data_folder)

# TRAIN-phase loader over the saved 32-slice subvolumes.
# NOTE(review): the keyword `num_workders` is kept exactly as the call was
# written; it presumably mirrors the spelling in the loader's signature — confirm.
loader_processor_train = subvolume_dataloader.ProcessedSubVolumeDataLoader(
    data_folder=(data_folder / 'processed_128x128' / 'subvolume_32'),
    phase=term.Phase.TRAIN,
    batch_size=2,
    transform_dict=transform_dict,
    num_workders=2,
)

In [None]:
# VALID-phase loader over the same subvolume folder, with the same settings.
# NOTE(review): the keyword `num_workders` is kept exactly as the call was
# written; it presumably mirrors the spelling in the loader's signature — confirm.
loader_processor_valid = subvolume_dataloader.ProcessedSubVolumeDataLoader(
    data_folder=(data_folder / 'processed_128x128' / 'subvolume_32'),
    phase=term.Phase.VALID,
    batch_size=2,
    transform_dict=transform_dict,
    num_workders=2,
)

In [None]:
# Get the dataloader from the TRAIN-phase subvolume loader.
loader = loader_processor_train.get_dataloader()

In [None]:
# Pull a single batch from the subvolume dataloader. This rebinds `batch`,
# which previously held a patch-dataloader batch.
batch = next(iter(loader))

In [None]:
# Take the first item of the batch's 'label' entry and print its shape.
# Note that `image` holds a label here, despite its name.
image = batch['label'][0, ...]
print(image.shape)

In [None]:
# Display the shape of the batch's 'input' entry.
batch['input'].shape

In [None]:
# Compute volumes for channels 1 and 2 of the processed data and write them to
# 'patient_stats.csv' under the given column names.
columns = ['ID', 'GTVp volume', 'GTVn volume']
filename = 'patient_stats.csv'
hecktor_processor_train.calculate_volumes(
    data_path=data_folder / 'processed_128x128',
    output_file=filename,
    channels=[1, 2],
    column_names=columns,
)

<h4> Model debugging </h4>

In [None]:
# NOTE(review): this import omits the `codebase.` prefix used by the imports in
# the first cell, although the config path below places 'projects' under
# CODEBASE_PATH — confirm which package root is intended.
from projects.hecktor2022.trainers import hecktor_trainer

# Build the trainer from the test experiment config; the path is passed as a string.
config_file = cbs.CODEBASE_PATH / 'projects' / 'hecktor2022' / 'experiments' / 'test_config.yml'
trainer = hecktor_trainer.Trainer(str(config_file))

In [None]:
# Turn the subvolume batch pulled above into a (features, label) pair.
# This rebinds `label`, which previously held the array loaded from disk.
features, label = trainer.prepare_subvolume_batch(batch)

In [None]:
# Print the shapes of the prepared features and label.
print(features.shape, label.shape)

In [None]:
# Move the first label in the batch to the CPU, convert it to a NumPy array
# and display its shape.
image = label[0].cpu().numpy()
image.shape

In [None]:
# Run the trainer's model on the prepared features.
prediction = trainer.model(features)

In [None]:
# Display the shape of the model output.
prediction.shape

In [None]:
# Detach the first prediction from the autograd graph, move it to the CPU,
# convert it to a NumPy array and display its shape.
image = prediction[0].detach().cpu().numpy()
image.shape

In [None]:
# Evaluate the trainer's loss on the prediction against the prepared label.
loss = trainer.loss(prediction, label)

In [None]:
# Display the computed loss.
loss