In [1]:
import copy
import json
import os
import sys

import itkwidgets
import monai
import torch

In [2]:
# Locate the project root (the notebook normally runs from <root>/notebooks),
# make it the working directory, and expose <root>/src for the project imports.
current_path = os.getcwd()
if os.path.basename(current_path) == "notebooks":
    parent_path = os.path.dirname(current_path)
    os.chdir(parent_path)
else:
    # Assume we are already at the project root (e.g. the kernel was started
    # there, or this cell is being re-run after the chdir above).
    parent_path = current_path
src_path = os.path.join(parent_path, "src")
# Avoid piling up duplicate sys.path entries when the cell is executed again.
if src_path not in sys.path:
    sys.path.append(src_path)

In [3]:
import config
import transforms
import utils

In [4]:
# Seed MONAI (and the libraries it seeds) so augmentations are reproducible.
monai.utils.set_determinism(seed=config.RANDOM_STATE)

Set `prob = 1.0` for every random (non-deterministic) transform so that each augmentation is always applied and its effect can be visualized.

In [5]:
# Work on a private copy of the training pipeline so that forcing prob=1.0 does
# not leak into the shared objects stored in `transforms.transforms_dict`.
base_transforms = transforms.transforms_dict["base_model"]["base_transforms"]
train_transforms = copy.deepcopy(
    transforms.transforms_dict["base_model"]["train_transforms"]
)


def force_probability(transform_chain, prob=1.0):
    """Set ``prob`` on every transform in a (possibly nested) pipeline that has one.

    Args:
        transform_chain: an object exposing a ``transforms`` sequence
            (e.g. a MONAI ``Compose``).
        prob: probability value to assign; defaults to 1.0 (always apply).
    """
    for transform in transform_chain.transforms:
        if hasattr(transform, "prob"):
            transform.prob = prob
        # Descend into nested containers (sub-Compose etc.) so their random
        # transforms are not silently skipped.
        if isinstance(getattr(transform, "transforms", None), (list, tuple)):
            force_probability(transform, prob)


force_probability(train_transforms)

In [6]:
# Read the JSON data description (a dict with a "train" list of samples) and
# wrap the training samples in a dataset that runs the base transforms followed
# by the training (augmentation) transforms.
data_path = os.path.join(config.data_dir, config.DATA_FILENAME)
with open(data_path, "r") as data_file:
    data = json.load(data_file)
dataset = monai.data.Dataset(
    data=data["train"],
    transform=monai.transforms.Compose([base_transforms, train_transforms]),
)

Note: `monai.transforms.RandCropByPosNegLabel` returns a list of patches for each input image/segmentation pair, so indexing the dataset yields a list and `[0]` below selects its first patch.

In [7]:
DATA_POINT_IDX = 0
CHANNEL_IDX = 0  # 0: ct1, 1: flair, 2: t1, 3: t2
# img_B is read at CHANNEL_IDX + num_sequences, so CHANNEL_IDX must address one
# of image A's sequences; otherwise img_A would silently show a B channel.
assert 0 <= CHANNEL_IDX < config.num_sequences, "CHANNEL_IDX out of range"

# The crop transform yields a list of patches per sample; inspect the first one.
data_point = dataset[DATA_POINT_IDX][0]
# Indexing assumes img_AB holds image A's sequences followed by image B's.
img_A = data_point["img_AB"][CHANNEL_IDX]
img_B = data_point["img_AB"][CHANNEL_IDX + config.num_sequences]
# seg_C is presumably channel-first per-class maps; argmax over dim 0 collapses
# it into a single label map for display.
seg = torch.argmax(data_point["seg_C"], dim=0)

# Depending on how the sample was loaded/collated, filename_or_obj may be a
# single path or a sequence of paths. `[0]` on a plain string would return
# only its first character, so index only when it is a sequence.
seg_filename = data_point["seg_C_meta_dict"]["filename_or_obj"]
if not isinstance(seg_filename, (str, os.PathLike)):
    seg_filename = seg_filename[0]
patient_name = utils.get_patient_name(seg_filename)
print(patient_name)




In [8]:
# Interactive 3D viewer for the selected channel of image A on a white background.
itkwidgets.view(image=img_A, background=(1.0, 1.0, 1.0))

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.it…

In [9]:
# Interactive 3D viewer for the matching channel of image B on a white background.
itkwidgets.view(image=img_B, background=(1.0, 1.0, 1.0))

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.it…

In [10]:
# Interactive 3D viewer for the argmax label map on a white background.
itkwidgets.view(image=seg, background=(1.0, 1.0, 1.0))

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, point_sets=[], rendered_image=<itk.it…