In [25]:
import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch

from monai.config import print_config
from monai.data import Dataset, DataLoader, create_test_image_3d, decollate_batch
from monai.inferers import sliding_window_inference
from monai.networks.nets import UNETR

from monai.transforms import (
    Activationsd,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    Invertd,
    LoadImaged,
    Orientationd,
    SaveImaged,
    Spacingd,
    CropForegroundd,
    EnsureTyped,
    ScaleIntensityRanged,
    ToTensord
)


In [26]:
# Collect the test images to segment, one {"image": path} dict per file.
# Set the KOMP_IMAGES_DIR environment variable to point at a different folder
# instead of editing this hard-coded default.
# NOTE: `tempdir` is not a temporary directory; the name is kept so any later cells that use it still work.
tempdir = os.environ.get("KOMP_IMAGES_DIR", "/home/sara/MONAI/KOMP/data/imagesTs")
images = sorted(glob(os.path.join(tempdir, "A*.nii.gz")))
if not images:
    # Fail loudly here: an empty list would otherwise make the inference loop a
    # silent no-op and only surface later as a confusing NameError.
    raise FileNotFoundError(f"No images matching 'A*.nii.gz' found in {tempdir!r}")
files = [{"image": image} for image in images]

In [27]:
# define pre transforms
# Applied to each test image before inference. Post-processing inverts this
# pipeline with Invertd so the saved mask is in the original image's space.
# NOTE(review): this should mirror the preprocessing used during training — TODO confirm.
pre_transforms = Compose([
    LoadImaged(keys="image"),
    # Orientationd/Spacingd expect channel-first data, so the channel axis must be
    # added before them; otherwise the first spatial axis is treated as a channel.
    EnsureChannelFirstd(keys="image"),
    Orientationd(keys="image", axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=(1, 1, 1), mode="bilinear"),
    # Clip intensities to the window [-175, 250] and rescale that window to [0, 1].
    ScaleIntensityRanged(
        keys=["image"], a_min=-175, a_max=250,
        b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image"], source_key="image"),
    ToTensord(keys=["image"]),
])

# define post transforms
# Applied to each decollated prediction dict.
# The network has 51 output channels and the visualisation code uses argmax.
# So treat the output as one class per voxel: softmax, then argmax to a label map.
# sigmoid + threshold would instead produce 51 independent binary masks.
# NOTE(review): assumes the model was trained as a multi-class (softmax) segmenter — TODO confirm.
post_transforms = Compose([
    EnsureTyped(keys="pred"),
    Activationsd(keys="pred", softmax=True),
    Invertd(
        keys="pred",
        transform=pre_transforms,
        orig_keys="image",
        meta_keys="pred_meta_dict",
        orig_meta_keys="image_meta_dict",
        meta_key_postfix="meta_dict",
        # Interpolate the continuous probabilities while undoing resampling;
        # discretisation happens afterwards in AsDiscreted.
        nearest_interp=False,
        to_tensor=True,
    ),
    AsDiscreted(keys="pred", argmax=True),
    # Invertd already mapped the prediction back to the original geometry, so no extra resampling on save.
    SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
])

In [28]:
# Build the inference data pipeline: the Dataset applies `pre_transforms` to each
# {"image": path} entry on access, and the loader yields one sample per batch.
# One sample per batch because CropForegroundd generally leaves volumes of different shapes that cannot be stacked.
dataset = Dataset(transform=pre_transforms, data=files)
dataloader = DataLoader(dataset, num_workers=4, batch_size=1)

In [None]:
# Load the trained UNETR and run sliding-window inference on every test image.
# CUDA_DEVICE_ORDER is read when CUDA is initialised in this process. If an earlier
# cell already initialised CUDA, this assignment has no effect until a kernel restart.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
torch.set_num_threads(24)  # CPU intra-op threads; tune to the host machine
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root_dir = "/home/sara/MONAI/KOMP"

# The network's input size and the sliding-window ROI must agree, so define it once.
roi_size = (96, 96, 96)

# Hyper-parameters must match those used when the checkpoint was trained.
net = UNETR(
    in_channels=1,
    out_channels=51,
    img_size=roi_size,
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    pos_embed="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
).to(device)

# map_location lets a checkpoint saved on GPU load on the CPU fallback device.
state_dict = torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device)
net.load_state_dict(state_dict)
net.eval()

with torch.no_grad():
    for d in dataloader:
        images = d["image"].to(device)
        # define sliding window size and batch size for windows inference
        d["pred"] = sliding_window_inference(
            inputs=images,
            roi_size=roi_size,
            sw_batch_size=4,
            predictor=net,
            overlap=0.8,
        )
        # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
        # (after the loop, `d` holds the post-processed samples of the last batch)
        d = [post_transforms(i) for i in decollate_batch(d)]


In [None]:
# Show one slice of the post-processed prediction for the last image processed.
# After the inference loop, `d` is a list of sample dicts, so index into it
# rather than passing the list itself to torch.
import matplotlib.pyplot as plt  # `plt` was used here but never imported

pred = d[0]["pred"]
# Collapse channels to a label map: argmax over channels for multi-channel
# outputs, otherwise drop the single channel.
# NOTE(review): assumes a channel-first (C, H, W, D) layout — TODO confirm.
label_map = pred.argmax(dim=0) if pred.shape[0] > 1 else pred[0]
slice_idx = label_map.shape[-1] // 2  # middle slice along the last spatial axis

fig, ax = plt.subplots()
ax.imshow(label_map[:, :, slice_idx].detach().cpu().numpy(), interpolation="nearest")
ax.set_title(f"Predicted labels, slice {slice_idx}")
ax.axis("off")
plt.show()