In [1]:
import json
import os
import sys

import itkwidgets
import monai

In [2]:
# Make the notebook runnable from inside the "notebooks" folder: switch the
# working directory to its parent and append the parent's "src" folder to the
# import path. When started from any other directory nothing is changed.
current_path = os.getcwd()
launched_from_notebooks = os.path.basename(current_path) == "notebooks"
if launched_from_notebooks:
    parent_path = os.path.dirname(current_path)
    src_path = os.path.join(parent_path, "src")
    os.chdir(parent_path)
    sys.path.append(src_path)

In [3]:
import config
from utils import get_patient_name

In [4]:
# Fix the random seed through MONAI, using the project-configured RANDOM_STATE.
monai.utils.set_determinism(seed=config.RANDOM_STATE)

In [5]:
# Preprocessing / augmentation pipeline previewed in this notebook.
# Compose applies the transforms in the order they are listed.
load_keys = config.modality_keys_A + config.modality_keys_B + ["label"]
volume_keys = ["images_A", "images_B", "label"]
image_keys = ["images_A", "images_B"]

transforms = monai.transforms.Compose([
    # Load every modality of A and B plus the label, channel axis first.
    # image_only=False (image_only=True was left commented out originally);
    # a later cell reads batch["label_meta_dict"].
    monai.transforms.LoadImaged(
        keys=load_keys,
        image_only=False,
        ensure_channel_first=True,
    ),
    # Stack the individual modalities along the channel axis into one
    # "images_A" / "images_B" entry each, then drop the per-modality entries.
    monai.transforms.ConcatItemsd(keys=config.modality_keys_A, name="images_A", dim=0),
    monai.transforms.DeleteItemsd(keys=config.modality_keys_A),
    monai.transforms.ConcatItemsd(keys=config.modality_keys_B, name="images_B", dim=0),
    monai.transforms.DeleteItemsd(keys=config.modality_keys_B),
    # Crop all three entries to the foreground computed from "images_A".
    monai.transforms.CropForegroundd(keys=volume_keys, source_key="images_A"),
    # With above=False, label values that are not below the threshold (1) are
    # replaced by cval (1) -- presumably this merges all foreground classes
    # into a single class; confirm against the label definition.
    monai.transforms.ThresholdIntensityd(keys="label", threshold=1, above=False, cval=1),
    # One-hot encode the label with two classes.
    monai.transforms.AsDiscreted(keys="label", to_onehot=2),
    monai.transforms.Orientationd(keys=volume_keys, axcodes="SPL"),
    # prob is 1.0 here (prob=0.1 was left commented out originally), so the
    # affine augmentation is applied to every sample. The label is resampled
    # with "nearest", the images with "bilinear".
    monai.transforms.RandAffined(
        keys=volume_keys,
        prob=1.0,
        rotate_range=0.1,
        scale_range=0.1,
        mode=("bilinear", "bilinear", "nearest"),
    ),
    # Draw one patch of size config.PATCH_SIZE per sample, with pos=1 / neg=1.
    monai.transforms.RandCropByPosNegLabeld(
        keys=volume_keys,
        label_key="label",
        spatial_size=config.PATCH_SIZE,
        pos=1,
        neg=1,
        num_samples=1,
    ),
    # images_A and images_B have different number of channels, which leads to
    # an error when processed together by RandGaussianNoised, so each key gets
    # its own transform instance (same order as before: images_A, images_B).
    # prob is 1.0 here as well (prob=0.1 was left commented out originally).
    # NOTE(review): the noise is added before NormalizeIntensityd, i.e.
    # std=0.1 is relative to the un-normalised intensities -- confirm intended.
    *[
        monai.transforms.RandGaussianNoised(keys=[key], prob=1.0, mean=0.0, std=0.1)
        for key in ("images_A", "images_B")
    ],
    # Normalise each channel of both image stacks separately.
    monai.transforms.NormalizeIntensityd(keys=image_keys, channel_wise=True),
])

In [6]:
# Read the dataset description (a JSON file located through the project
# config) and wrap its "train" entry in a MONAI dataset that applies
# `transforms` to every item.
data_path = os.path.join(config.data_dir, config.DATA_FILENAME)
# JSON text is UTF-8; state it explicitly instead of relying on the
# platform-dependent default encoding of open().
with open(data_path, "r", encoding="utf-8") as data_file:
    data = json.load(data_file)
dataset = monai.data.Dataset(data=data["train"], transform=transforms)
# batch_size=1: the batch axis is removed again with squeeze(0) when the
# individual volumes are extracted for viewing.
dataloader = monai.data.DataLoader(dataset, batch_size=1)
# Converts a one-hot encoded label back to a single label map via argmax.
decode_onehot = monai.transforms.AsDiscrete(argmax=True, keepdim=False)

Note: `monai.transforms.RandCropByPosNegLabeld` returns a list of patches (here of length `num_samples=1`) for each input image/label pair.

In [7]:
# Select one batch from the dataloader and pull out the single-channel volumes
# that the viewer cells display.
BATCH_IDX = 0
CHANNEL_IDX = 0  # 0: ct1, 1: flair, 2: t1, 3: t2

# Advance the loader until the requested batch is reached. The `else` branch
# only runs when the loop ends without `break`, i.e. when the loader yielded
# fewer than BATCH_IDX + 1 batches; failing loudly avoids silently showing the
# last batch (or hitting a NameError on an empty loader).
for batch_idx_counter, batch in enumerate(dataloader):
    if batch_idx_counter == BATCH_IDX:
        break
else:
    raise IndexError(
        f"BATCH_IDX={BATCH_IDX} is out of range: "
        "the dataloader yielded fewer batches than requested"
    )

# Half of the channel axis of "images_A"; assumes images_A stacks two images
# with the same number of channels each -- TODO confirm against config.
num_channels_per_image = batch["images_A"].shape[1] // 2

# squeeze(0) removes the batch axis (the loader uses batch_size=1); the index
# that follows then selects one channel.
image_A_1 = batch["images_A"].squeeze(0)[CHANNEL_IDX]
image_A_2 = batch["images_A"].squeeze(0)[CHANNEL_IDX + num_channels_per_image]
image_B = batch["images_B"].squeeze(0)[CHANNEL_IDX]
# Collapse the one-hot label back into a single label map for the overlay.
label_B = decode_onehot(batch["label"].squeeze(0))

# The patient name is derived from the file name of the loaded label.
patient_name = get_patient_name(
    batch["label_meta_dict"]["filename_or_obj"][0]
)
print(patient_name)

Patient-001


In [8]:
# Show the selected channel of the first half of images_A on a white background.
itkwidgets.view(image=image_A_1, background=(1.0, 1.0, 1.0))

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.it…

In [9]:
# Show the matching channel of the second half of images_A on a white background.
itkwidgets.view(image=image_A_2, background=(1.0, 1.0, 1.0))

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.it…

In [10]:
# Show the selected channel of images_B with the decoded label as an overlay.
itkwidgets.view(image=image_B, label_image=label_B, background=(1.0, 1.0, 1.0))

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, interpolation=False, point_sets=[], r…