In [None]:
import numpy as np
import torchio as tio
import matplotlib.pyplot as plt

import codebase.codebase_settings as cbs
from codebase.projects.hecktor2022 import read_config

In [None]:
def comparison_plot(images, nslice: int, titles=None, channel: int = 0, panel_size: float = 4.0):
    """Show the same slice of several volumes side by side in one figure.

    Each element of ``images`` is indexed as ``image[channel, :, :, nslice]``,
    so it must be at least 4-D with the channel axis first and the slice axis
    last (e.g. ``tio.Image.data`` or the ``.npy`` sub-volumes in this notebook).

    Args:
        images: Sequence of array-likes (numpy arrays or torch tensors).
        nslice: Index along the last axis of the slice to display.
        titles: Optional sequence of panel titles, one per image.
        channel: Channel index to display (defaults to 0, the previous behaviour).
        panel_size: Width and height in inches of each panel.

    Raises:
        ValueError: If ``images`` is empty or ``titles`` has the wrong length.
    """
    nimages = len(images)
    if nimages == 0:
        raise ValueError("comparison_plot needs at least one image")
    if titles is not None and len(titles) != nimages:
        raise ValueError(f"expected {nimages} titles, got {len(titles)}")
    # squeeze=False keeps `axes` 2-D even for a single image; without it
    # plt.subplots returns a bare Axes object and indexing it fails.
    # The figure height is fixed per panel instead of growing with the
    # number of columns, so a single image is no longer squashed to 4x1 inches.
    _, axes = plt.subplots(
        1, nimages, num=1, clear=True, squeeze=False,
        figsize=(nimages * panel_size, panel_size),
    )
    for i, (ax, image) in enumerate(zip(axes[0], images)):
        ax.imshow(image[channel, :, :, nslice])
        if titles is not None:
            ax.set_title(titles[i])
    plt.tight_layout()
    plt.show()

In [None]:
# Root of the raw (un-preprocessed) data: `images/` holds `<id>__CT.nii.gz` and
# `<id>__PT.nii.gz`, `labels/` holds `<id>.nii.gz` (see the loading cell below).
# Swap to the commented DATA_PATH line to inspect the hecktor2022 data under DATA_PATH.
raw_data_path = cbs.CODEBASE_PATH / 'preprocessor' / 'images' / 'test_data'
# raw_data_path = cbs.DATA_PATH / 'hecktor2022'
# Case identifier used by every cell below.
# NOTE(review): `id` shadows the Python builtin `id()`; renaming it requires
# updating all later cells at once.
id = 'CHUP-052'
# id = 'CHUP-017'
# id = 'CHUP-028'

In [None]:
# Load the raw CT, PET and label volumes of case `id`, printing each one's
# shape and voxel spacing.
raw_images_dir = raw_data_path / 'images'
raw_ct = tio.ScalarImage(raw_images_dir / f'{id}__CT.nii.gz')
print(raw_ct.shape, raw_ct.spacing)
raw_pt = tio.ScalarImage(raw_images_dir / f'{id}__PT.nii.gz')
print(raw_pt.shape, raw_pt.spacing)
raw_lb = tio.LabelMap(raw_data_path / 'labels' / f'{id}.nii.gz')
print(raw_lb.shape, raw_lb.spacing)

In [None]:
# Compare slice 150 (last axis) of the raw CT, PET and label volumes.
# NOTE(review): 150 is hand-picked for this case and must be below the last
# dimension of every volume printed above; the raw modalities may have
# different shapes/spacings, so the same index need not show the same anatomy.
comparison_plot([raw_ct.data, raw_pt.data, raw_lb.data], nslice=150)

In [None]:
# Preprocessed training split of the same test data (the directory name
# suggests 256x256 in-plane resampling). It uses the same images/ and labels/
# layout as the raw data.
data_path = cbs.CODEBASE_PATH / 'preprocessor' / 'images' / 'test_data' / 'processed_256x256' / 'train'
# data_path = cbs.DATA_PATH / 'hecktor2022' / 'processed_256x256' / 'train'

In [None]:
# Load the preprocessed volumes for the same case so their shapes and
# spacings can be compared with the raw ones printed earlier.
processed_images_dir = data_path / 'images'
ct = tio.ScalarImage(processed_images_dir / f'{id}__CT.nii.gz')
print(ct.shape, ct.spacing)
pt = tio.ScalarImage(processed_images_dir / f'{id}__PT.nii.gz')
print(pt.shape, pt.spacing)
lb = tio.LabelMap(data_path / 'labels' / f'{id}.nii.gz')
print(lb.shape, lb.spacing)

In [None]:
# Same side-by-side view on the preprocessed volumes at slice 75 (last axis).
# NOTE(review): hand-picked index; it must be below the last dimension printed above.
comparison_plot([ct.data, pt.data, lb.data], nslice=75)

In [None]:
# Pre-computed sub-volumes (directory name: subvolume_32) of the processed test data.
# The DATA_PATH alternative is commented out, as in the other path cells.
# Previously it was active and silently overwrote the test_data path, so this
# cell read a different dataset from the rest of the notebook.
subvolume_data_path = cbs.CODEBASE_PATH / 'preprocessor' / 'images' / 'test_data' / 'processed_256x256' / 'subvolume_32' / 'train'
# subvolume_data_path = cbs.DATA_PATH / 'hecktor2022' / 'processed_256x256' / 'subvolume_32' / 'train'

In [None]:
# A sub-volume is stored as an `__input.npy` array whose first channel is CT
# and whose remaining channel(s) are PET, plus a separate `__label.npy` array.
sub_id = '_117'
sub_prefix = id + sub_id
sub_input = np.load(subvolume_data_path / 'images' / f'{sub_prefix}__input.npy')
sub_ct = sub_input[0:1, ...]  # slice (not index) to keep the channel axis
print(sub_ct.shape)
sub_pt = sub_input[1:, ...]
print(sub_pt.shape)
sub_lb = np.load(subvolume_data_path / 'labels' / f'{sub_prefix}__label.npy')
print(sub_lb.shape)

In [None]:
# Slice 25 (last axis) of the sub-volume CT channel, PET channel(s) and label.
# These are numpy arrays from np.load; comparison_plot indexes them like tensors.
comparison_plot([sub_ct, sub_pt, sub_lb], nslice=25)

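### Dataloader

Build the data module from the test experiment config and inspect one training and one validation batch. Then experiment with splitting a batched tensor into per-sample tensors.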

In [None]:
from codebase import terminology as term
from codebase.dataloader.images import data_module
from monai.utils import first

In [None]:
# Read the test experiment configuration used to build the data module below.
config_file = cbs.CODEBASE_PATH / 'projects' / 'hecktor2022' / 'experiments' / 'test_config.yml'
config = read_config.read_experiment_config(config_file)

In [None]:
# Data module for a segmentation task, configured from the experiment config.
mdata = data_module.MedicalImageDataModule(task_type=term.ProblemType.SEGMENTATION, config=config)

In [None]:
# Run the data module's preparation and setup steps (in this order), then
# build the training dataloader.
mdata.prepare_data()
mdata.setup()
train_dataloader = mdata.train_dataloader()

In [None]:
# Fetch one training batch. `monai.utils.first` returns the first item of an
# iterable (or its default, None, when empty) instead of raising StopIteration
# like next(iter(train_dataloader)).
batch = first(train_dataloader)

In [None]:
# Shape of the model input in the batch (the batch is indexed by key, e.g. 'input').
print(batch['input'].shape)

In [None]:
# Validation dataloader from the same (already set-up) data module.
valid_dataloader = mdata.val_dataloader()

In [None]:
# Fetch one validation batch (replaces the training batch held in `batch`).
batch = first(valid_dataloader)

In [None]:
# Shape of the label tensor in the validation batch.
print(batch['label'].shape)

In [None]:
from monai.data.utils import decollate_batch, _non_zipping_check
from monai.transforms import AsDiscrete
from collections.abc import Iterable
import torch

In [None]:
# Dummy batched tensor for experimenting with decollation: batch of 1,
# 3 channels (matching `to_onehot=3` below), 256x256 in-plane, 155 slices.
# NOTE(review): shape chosen to mimic a model output — confirm against the real model.
torch.manual_seed(0)  # seed so the random tensor is reproducible across re-runs
x = torch.rand((1, 3, 256, 256, 155))
# Post-processing transform: argmax over channels, then one-hot with 3 classes.
pred_onehot = AsDiscrete(argmax=True, to_onehot=3)
x.shape

In [None]:
# Call MONAI's `_non_zipping_check` (imported from monai.data.utils together
# with decollate_batch) directly on the dummy tensor to inspect its result.
# NOTE(review): the leading underscore marks a private helper, so its
# signature and return value may change between MONAI versions.
_non_zipping_check(x, detach=True, pad=True, fill_value=None)

In [None]:
# Sanity check: the tensor is iterable, which the per-sample list built below relies on.
assert isinstance(x, Iterable)

In [None]:
# Split the batch along its first dimension into a list of per-sample tensors.
# MONAI alternative: [pred_onehot(i) for i in decollate_batch(x)]
val_outputs = list(x)

In [None]:
# Number of per-sample tensors; equals the size of x's first dimension.
len(val_outputs)

In [None]:
# Each element drops the batch dimension, i.e. (3, 256, 256, 155) for the dummy x above.
val_outputs[0].shape