In [4]:
import os
import glob
import torch
import requests
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd, ScaleIntensityRanged, CropForegroundd, Invertd, AsDiscreted, SaveImaged
from monai.data import Dataset, DataLoader

# Define the directory containing the test images
data_dir = "/opt/app-root/src/chbox/data"

# Define the endpoint URL
model_endpoint = "https://spleen-spleen.apps.cluster-4r7z9.sandbox1640.opentlc.com/v2/models/spleen/infer"

# Define the device for inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the transformations for test data
test_org_transforms = Compose([
    LoadImaged(keys="image"),
    EnsureChannelFirstd(keys="image"),
    Orientationd(keys=["image"], axcodes="RAS"),
    Spacingd(keys=["image"], pixdim=(1.5, 1.5, 2.0), mode="bilinear"),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image"], source_key="image"),
])

# Define the dataset and data loader for test data
test_images = sorted(glob.glob(os.path.join(data_dir, "imagesTs", "*.nii.gz")))
test_data = [{"image": image} for image in test_images]
test_org_ds = Dataset(data=test_data, transform=test_org_transforms)
test_org_loader = DataLoader(test_org_ds, batch_size=1, num_workers=4)

# Load the model from the endpoint
response = requests.get(model_endpoint)
if response.status_code == 200:
    model_state_dict = response.json()["model_state_dict"]

    # Instantiate your UNet model
    from monai.networks.nets import UNet
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,  # Assuming you have 2 classes (background and foreground)
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm="batch",
    )

    model.load_state_dict(model_state_dict)
    model.to(device)
    model.eval()

    # Define the post-processing transformations
    post_transforms = Compose([
        Invertd(keys="pred", transform=test_org_transforms, orig_keys="image", meta_keys="pred_meta_dict", orig_meta_keys="image_meta_dict", meta_key_postfix="meta_dict", nearest_interp=False, to_tensor=True),
        AsDiscreted(keys="pred", argmax=True, to_onehot=2),
        SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir="./out", output_postfix="seg", resample=False),
    ])

    # Perform inference using the model
    with torch.no_grad():
        for test_data in test_org_loader:
            test_inputs = test_data["image"].to(device)
            roi_size = (160, 160, 160)
            sw_batch_size = 4
            test_data["pred"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)

            test_data = [post_transforms(i) for i in decollate_batch(test_data)]
else:
    print(f"Failed to load the model. Status code: {response.status_code}")


Failed to load the model. Status code: 501
