In [1]:
import json
import os
import sys

import itkwidgets
import monai
import torch

In [2]:
# Make the notebook runnable from either the project root or `notebooks/`.
# When launched from `notebooks/`, switch to the parent (project) directory,
# presumably so relative paths used by the config resolve — TODO confirm.
current_path = os.getcwd()
if os.path.basename(current_path) == "notebooks":
    parent_path = os.path.dirname(current_path)
    os.chdir(parent_path)

# Always expose `src/` (home of `config_base_model` and `utils`) on the import
# path, not only when started from `notebooks/`. Skip it if already present
# so re-running this cell does not pile up duplicate entries.
src_path = os.path.join(os.getcwd(), "src")
if os.path.isdir(src_path) and src_path not in sys.path:
    sys.path.append(src_path)

In [3]:
import config_base_model as config
import utils

In [4]:
# Seed MONAI's random state from the shared config so the random augmentations below are reproducible.
monai.utils.set_determinism(seed=config.RANDOM_STATE)

Set `prob = 1.0` on every random (non-deterministic) training transform so that each augmentation is always applied and its effect can be inspected visually.

In [5]:
def _force_prob_one(transform):
    """Set ``prob = 1.0`` on ``transform`` and on every transform nested inside it.

    Transforms that have a ``prob`` attribute get it forced to 1.0 so they are
    always applied. Containers such as ``Compose`` keep their children in a
    ``transforms`` sequence. Recursing into it also forces on random transforms
    wrapped in a nested container, not only the top-level entries of
    ``train_transforms``.

    Args:
        transform: a transform or container of transforms; modified in place.
    """
    if hasattr(transform, "prob"):
        transform.prob = 1.0
    children = getattr(transform, "transforms", ())
    if isinstance(children, (list, tuple)):
        for child in children:
            _force_prob_one(child)


# NOTE: this mutates the transform objects owned by `config` in place, so any
# later use of `config.train_transforms` in this kernel also runs with prob = 1.0.
_force_prob_one(config.train_transforms)
transforms = monai.transforms.Compose([
    config.base_transforms,
    config.train_transforms
])

In [6]:
# Load the datalist JSON and wrap its "train" entries in a MONAI Dataset that
# applies `transforms` to each entry.
data_path = os.path.join(config.data_dir, config.DATA_FILENAME)
# JSON is UTF-8 by specification; pass the encoding explicitly instead of
# relying on the platform's locale default.
with open(data_path, "r", encoding="utf-8") as data_file:
    data = json.load(data_file)
dataset = monai.data.Dataset(data["train"], transforms)
# batch_size=1: one datalist entry per batch. The patch-cropping transform may
# still yield several patches for that entry.
dataloader = monai.data.DataLoader(dataset, batch_size=1)

Note: `monai.transforms.RandCropByPosNegLabel` returns a list of patches for each input image/segmentation pair, so a single batch from the dataloader can contain several patches.

In [7]:
BATCH_IDX = 0  # which dataloader batch to visualize
PATCH_IDX = 0  # which sample/patch within that batch to visualize
CHANNEL_IDX = 0  # 0: ct1, 1: flair, 2: t1, 3: t2

# Walk the dataloader up to BATCH_IDX. The `else` branch runs only when the loop
# never hits `break`, i.e. BATCH_IDX is past the end. Without it, `batch` would
# silently be the last batch (or undefined for an empty loader).
for batch_idx_counter, batch in enumerate(dataloader):
    if batch_idx_counter == BATCH_IDX:
        break
else:
    raise IndexError(f"BATCH_IDX={BATCH_IDX} is out of range for this dataloader")

# Index the leading (batch) dimension explicitly instead of `.squeeze(0)`.
# When the crop transform yields several patches, they share that dimension,
# and `squeeze(0)` would leave it in place, shifting every later index by one
# axis. With a single patch this is identical to `squeeze(0)`.
# imgs_AB presumably stacks the A sequences followed by the B sequences along
# the channel axis, `config.num_sequences` each — TODO confirm against config.
imgs_A = batch["imgs_AB"][PATCH_IDX][CHANNEL_IDX]
imgs_B = batch["imgs_AB"][PATCH_IDX][CHANNEL_IDX + config.num_sequences]
imgs_C = batch["imgs_C"][PATCH_IDX][CHANNEL_IDX]
# seg_C appears to be one-hot along its first axis; argmax over it gives a label map.
seg = torch.argmax(batch["seg_C"][PATCH_IDX], dim=0)

patient_name = utils.get_patient_name(
    batch["seg_C_meta_dict"]["filename_or_obj"][PATCH_IDX]
)
print(patient_name)

Patient-001


In [8]:
# Interactive view of `imgs_A` (channel CHANNEL_IDX of the selected imgs_AB sample) on a white background.
itkwidgets.view(imgs_A, background=(1.0, 1.0, 1.0))

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.it…

In [9]:
# Interactive view of `imgs_B` (the matching channel of the second half of imgs_AB) on a white background.
itkwidgets.view(imgs_B, background=(1.0, 1.0, 1.0))

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.it…

In [10]:
# Interactive view of `imgs_C` with `seg` passed as the second positional argument,
# which the viewer renders as a label overlay (its repr shows interpolation=False).
itkwidgets.view(imgs_C, seg, background=(1.0, 1.0, 1.0))

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, interpolation=False, point_sets=[], r…