In [1]:
import json
import os
import sys

import itkwidgets
import monai
import torch

In [2]:
# Resolve the repository root: the notebook normally lives in <repo>/notebooks
# (then we move up one level so relative config paths resolve from the root),
# but it may also be launched from the repository root directly.
current_path = os.getcwd()
if os.path.basename(current_path) == "notebooks":
    repo_root = os.path.dirname(current_path)
    os.chdir(repo_root)
else:
    repo_root = current_path

# Make the project's src/ modules (config_base_model, models, utils) importable.
# The membership check keeps re-runs of this cell from growing sys.path.
src_path = os.path.join(repo_root, "src")
if src_path not in sys.path and os.path.isdir(src_path):
    sys.path.append(src_path)

In [3]:
import config_base_model as config
import models
import utils

In [4]:
# Fix global RNG seeds through MONAI so repeated runs are reproducible.
monai.utils.set_determinism(seed=config.RANDOM_STATE)

In [5]:
# Run on the GPU when one is available. Background loader workers and pinned
# host memory are only enabled in the CUDA case.
on_gpu = torch.cuda.is_available()
device = torch.device("cuda") if on_gpu else torch.device("cpu")
num_workers, pin_memory = (4, True) if device.type == "cuda" else (0, False)
print(f"Using {device} device")

Using cuda device


In [6]:
# Load one trained network per cross-validation fold; together they form the
# ensemble used for inference below.
model_list = []
for fold in range(config.FOLDS):
    checkpoint_path = os.path.join(
        config.checkpoint_dir,
        f"{config.MODEL_NAME}_fold{fold}.tar"
    )
    # Deserialize onto the CPU: the checkpoint may contain more than this one
    # state dict, and mapping it straight to `device` would put all of it in
    # GPU memory. Only the model's own weights are moved to `device` below.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model = models.ProbabilisticSegmentationNet(**config.MODEL_KWARGS_AB2C)
    model.load_state_dict(checkpoint["net_AB2C_state_dict"])
    # Release the raw checkpoint before loading the next fold.
    del checkpoint
    model.to(device)
    model.eval()
    model_list.append(model)

In [7]:
# Read the dataset index and build the test-set pipeline
# (base transforms followed by the evaluation transforms).
data_path = os.path.join(config.data_dir, config.DATA_FILENAME)
# JSON is UTF-8 by specification; don't rely on the platform's locale encoding.
with open(data_path, "r", encoding="utf-8") as data_file:
    data = json.load(data_file)
dataset = monai.data.Dataset(
    data["test"],
    monai.transforms.Compose([
        config.base_transforms,
        config.eval_transforms
    ])
)
print(f"Using {len(dataset)} test samples")

# One volume per batch; worker/pinning settings come from the device cell.
dataloader = monai.data.DataLoader(
    dataset,
    batch_size=1,
    num_workers=num_workers,
    pin_memory=pin_memory
)

Using 8 test samples


In [8]:
# 0-based index of the test batch to inspect (the loader uses batch_size=1,
# so this is also the test-sample index).
BATCH_IDX = 0

# Walk the loader up to the requested batch so the sample goes through the
# same transforms and collation as in evaluation.
for batch_idx_counter, batch in enumerate(dataloader):
    if batch_idx_counter == BATCH_IDX:
        break
else:
    # Without this guard an out-of-range BATCH_IDX would silently fall back to
    # the last batch (or leave `batch` undefined for an empty loader).
    raise IndexError(
        f"BATCH_IDX={BATCH_IDX} is out of range for a loader with "
        f"{len(dataloader)} batches"
    )

imgs = batch["imgs_AB"].to(device)
# argmax over the channel axis (dim 1), then drop the batch axis.
# Presumably seg_C is one-hot encoded — TODO confirm against the transforms.
seg = torch.argmax(batch["seg_C"], dim=1).squeeze(0)

patient_name = utils.get_patient_name(
    batch["seg_C_meta_dict"]["filename_or_obj"][0]
)
print(patient_name)

Patient-006


In [9]:
# Ensemble the fold models: average their per-class softmax probabilities and
# take the argmax over classes as the final label map.
if not model_list:
    raise ValueError("model_list is empty; no fold checkpoints were loaded")

# Mixed precision only on CUDA. torch.cuda.amp.autocast is deprecated and just
# warns (doing nothing) when the device is the CPU.
use_amp = device.type == "cuda"

prob_sum = None
with torch.no_grad():
    for fold_model in model_list:
        with torch.autocast(device_type=device.type, enabled=use_amp):
            logits = monai.inferers.sliding_window_inference(
                imgs,
                roi_size=config.PATCH_SIZE,
                sw_batch_size=config.BATCH_SIZE,
                predictor=fold_model
            )
        # Softmax in float32: logits may come back in half precision under autocast.
        probs = torch.nn.functional.softmax(logits.float(), dim=1)
        # Running sum instead of stacking every fold's full-volume prediction,
        # so peak memory does not grow with the number of folds. Summing
        # element-wise also keeps samples separate if batch_size ever exceeds 1.
        prob_sum = probs if prob_sum is None else prob_sum + probs

pred = prob_sum / len(model_list)
pred = torch.argmax(pred, dim=1).squeeze(0).cpu()

In [10]:
# Interactive comparison of the ensemble prediction with the ground-truth
# segmentation, rendered on a white background.
white_background = (1.0, 1.0, 1.0)
itkwidgets.compare(pred, seg, background=white_background)

AppLayout(children=(HBox(children=(Label(value='Link:'), Checkbox(value=False, description='cmap'), Checkbox(v…

In [11]:
# Single interactive viewer: `pred` is passed first and `seg` second
# (presumably shown as a label overlay — confirm against the itkwidgets.view
# signature of the installed version), on a white background.
white_background = (1.0, 1.0, 1.0)
itkwidgets.view(pred, seg, background=white_background)

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, interpolation=False, point_sets=[], r…