In [1]:
import json
import os
import sys

import itkwidgets
import monai
import torch

In [2]:
def resolve_project_root(cwd: str) -> str:
    """Return the project root for a kernel started in ``cwd``.

    If ``cwd`` is the ``notebooks`` directory, its parent is the project root;
    otherwise ``cwd`` is returned unchanged (assumed to be the project root).
    """
    normalized = os.path.normpath(cwd)
    if os.path.basename(normalized) == "notebooks":
        return os.path.dirname(normalized)
    return cwd


# Work from the project root (presumably so relative paths used by `config`
# resolve — confirm) and make the project's `src` directory importable.
current_path = os.getcwd()
project_root = resolve_project_root(current_path)
if project_root != current_path:
    os.chdir(project_root)
src_path = os.path.join(project_root, "src")
# Prepend rather than append: `config` and `utils` are generic module names, and
# an appended entry would lose to any same-named module already importable
# (e.g. from site-packages). The membership check keeps re-runs from adding
# duplicate entries.
if os.path.isdir(src_path) and src_path not in sys.path:
    sys.path.insert(0, src_path)

In [3]:
import config
from probabilistic_unet.model import ProbabilisticUnet
import utils

In [4]:
# Runtime settings that depend on whether a GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"
num_workers = 4 if use_cuda else 0  # DataLoader worker processes; load in-process on CPU
pin_memory = use_cuda  # page-locked host buffers only pay off for host->GPU copies
sw_batch_size = 4 if use_cuda else 1  # intended as `sw_batch_size` for sliding-window inference
print(f"Using {device} device")

In [None]:
# Load one checkpoint per fold and restore the A2B network from each.
# Checkpoints are mapped to CPU: they stay referenced by `checkpoint_list`, and
# mapping them straight to `device` would keep every tensor they hold (not only
# the A2B weights used below) resident on the GPU for the rest of the session.
# `load_state_dict` copies the weights onto each model's device regardless.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint_list = [
    torch.load(
        os.path.join(config.model_dir, f"{config.MODEL_NAME}_fold{fold}.tar"),
        map_location="cpu",
    )
    for fold in range(config.FOLDS)
]
model_list = [
    ProbabilisticUnet(**config.MODEL_KWARGS_A2B).to(device)
    for _ in range(config.FOLDS)
]
for model, checkpoint in zip(model_list, checkpoint_list):
    model.load_state_dict(checkpoint["net_A2B_state_dict"])
    model.eval()  # switch layers such as dropout/batch-norm to inference behavior

<All keys matched successfully>

In [None]:
# Build the test dataset from the JSON data description. The file is read as
# UTF-8 (the JSON standard encoding) instead of the platform's locale default.
data_path = os.path.join(config.data_dir, config.DATA_FILENAME)
with open(data_path, "r", encoding="utf-8") as data_file:
    data = json.load(data_file)
dataset = monai.data.Dataset(
    data=data["test"],
    transform=monai.transforms.Compose([
        config.base_transforms,
        config.eval_transforms,
    ]),
)
print(f"Using {len(dataset)} test samples")

# One test sample per batch.
dataloader = monai.data.DataLoader(
    dataset=dataset,
    batch_size=1,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

Using 8 test samples


In [None]:
# Index of the test sample to inspect (the test dataloader uses batch_size=1).
BATCH_IDX = 0

for batch_idx_counter, batch in enumerate(dataloader):
    if batch_idx_counter == BATCH_IDX:
        break
else:
    # The loop finished without reaching BATCH_IDX. Without this guard the cell
    # would silently use the last batch (or a stale `batch` left over from a
    # previous run when the loader is empty).
    raise IndexError(
        f"BATCH_IDX={BATCH_IDX} is out of range for a dataloader with "
        f"{len(dataloader)} batches"
    )

images = batch["images_A"].to(device)
# Decode one-hot labels (classes along dim 1) to class indices, then drop the
# batch dimension.
label = torch.argmax(batch["label"], dim=1).squeeze(0)
patient_name = utils.get_patient_name(
    batch["label_meta_dict"]["filename_or_obj"][0]
)
print(patient_name)

Patient-006


In [None]:
# Ensemble prediction: sliding-window inference with every fold model, average
# the model outputs, and take the per-voxel argmax as the predicted class.
# NOTE(review): raw outputs (presumably logits) are averaged rather than softmax
# probabilities — confirm this is the intended ensembling.
use_amp = device.type == "cuda"  # mixed precision on GPU only; CPU stays in full precision
with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp):
    preds = [
        monai.inferers.sliding_window_inference(
            inputs=images,
            roi_size=config.PATCH_SIZE,
            # Device-aware patch batch size from the setup cell, rather than the
            # training batch size.
            sw_batch_size=sw_batch_size,
            predictor=fold_model,
        )
        for fold_model in model_list
    ]
preds = torch.cat(preds, dim=0)
pred = torch.mean(preds, dim=0, keepdim=True)
pred = torch.argmax(pred, dim=1).squeeze(0).cpu()

In [None]:
# Side-by-side view: ensemble prediction vs. decoded ground-truth label, white background.
itkwidgets.compare(pred, label, background=(1.0, 1.0, 1.0))

AppLayout(children=(HBox(children=(Label(value='Link:'), Checkbox(value=False, description='cmap'), Checkbox(v…

In [None]:
# Single viewer: prediction as the image, ground truth passed as `label_image`.
itkwidgets.view(background=(1.0, 1.0, 1.0), image=pred, label_image=label)

Viewer(background=(1.0, 1.0, 1.0), geometries=[], gradient_opacity=0.22, interpolation=False, point_sets=[], r…