# Create extra seg mask input channel
Uses segmentation masks from the [CT-FM segmentator](https://huggingface.co/project-lighter/whole_body_segmentation) as an extra input channel for nnUNetv2.
* original seg output: 118 classes, one channel for every class
* processed as input: class indices (0-117) stored as integers in a single channel, then divided by 117 to normalize to the range 0-1


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path

import SimpleITK as sitk
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from lighter_zoo import SegResNet
from monai.transforms import LoadImage
from ctfm_utils import inference

In [None]:
# Dataset root with its image and label folders.
data_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset002_3dlesion_ctfm_seg/")
train_images = data_folder / "imagesTr"
train_labels = data_folder / "labelsTr"

# Keep only the channel-0 ("_0000") files. os.listdir returns entries in
# arbitrary order, so sort them: this makes the processing order, and any
# positional slice such as train_names[9:], reproducible between runs.
train_names = sorted(n for n in os.listdir(train_images) if n.endswith("_0000.nii.gz"))

# AutoPET subset of the channel-0 files.
ap_names = [n for n in train_names if n.startswith("AutoPET")]

## Crop large volumes
AutoPET CT and label volumes are cropped along the slice axis to the range of horizontal slices that contain lesions, plus a margin of 10% of the total slice count above and below.

In [None]:
# NOTE: this cell is left commented out. If it is re-enabled, be aware that:
#   * it writes the cropped volumes back to ct_path / label_path, i.e. it
#     overwrites the original files in place, so running it a second time
#     crops the already-cropped volumes again;
#   * label_slices is empty for a label volume with no foreground voxels, and
#     .min() / .max() on an empty array raise ValueError.
# import numpy as np
# for ap_name in tqdm(ap_names):
#     ct_path = train_images / ap_name
#     ct_img = sitk.ReadImage(ct_path)
#
#     label_path = train_labels / ap_name.replace("_0000.nii.gz", ".nii.gz")
#     label_img = sitk.ReadImage(label_path)
#     label = sitk.GetArrayFromImage(label_img)
#
#     tall = label.shape[0]
#     label_slices = np.where(label.any(axis=(1,2)))[0]
#     start = label_slices.min()
#     end = label_slices.max()
#     pad = tall // 10
#
#     start = int(max(0, start - pad))
#     end = int(min(tall, end + pad))
#
#     roi_size = list(ct_img.GetSize())
#     roi_size[2] = end - start
#     roi_index = [0, 0, start]
#     cropped_ct = sitk.RegionOfInterest(ct_img, roi_size, roi_index)
#     cropped_label = sitk.RegionOfInterest(label_img, roi_size, roi_index)
#
#     # write cropped volumes
#     sitk.WriteImage(cropped_ct, ct_path)
#     sitk.WriteImage(cropped_label, label_path)

## Run CT-FM totalsegmentator model
Get a segmentation mask for each CT and save it as an extra input channel (`_0001.nii.gz`).

In [None]:
# CT files for which CT-FM failed to produce a segmentation mask
# (possibly a memory issue).
failed = [
    'AutoPET-Lymphoma-B_PETCT_987c8a1160_CT_0000.nii.gz',
    'AutoPET-Melanoma-B_PETCT_32aa845af1_CT_0000.nii.gz',
    'AutoPET-Melanoma-B_PETCT_b510436d83_CT_0000.nii.gz',
    'AutoPET-Melanoma-B_PETCT_6efefcb92a_CT_0000.nii.gz',
    'AutoPET-Melanoma-B_PETCT_8e02f36295_CT_0000.nii.gz',
    'AutoPET-Melanoma-B_PETCT_7ce196485f_CT_0000.nii.gz',
    'AutoPET-Melanoma-B_PETCT_1b199d094d_CT_0000.nii.gz',
]

# Create an empty (all-zero) 2nd channel for the CTs that have no seg mask.
for f in tqdm(failed):
    ct_path = train_images / f
    ct_img = sitk.ReadImage(ct_path)

    ct = sitk.GetArrayFromImage(ct_img)
    blank_img = sitk.GetImageFromArray(ct * 0)
    # An image built from a bare array carries default origin/spacing/direction.
    # Copy the CT's geometry so the blank channel is aligned with channel 0,
    # exactly as is done for the real masks written by the inference loop.
    blank_img.CopyInformation(ct_img)

    sitk.WriteImage(blank_img, train_images / f.replace("_0000.nii.gz", "_0001.nii.gz"))
    print(f)

In [None]:
# Run inference for one image per child process. Running via subprocess
# prevents the memory leak issue that happens with the in-process loop.
import subprocess
import sys

for ct_name in tqdm(train_names):
    # Skip images already known to fail.
    if ct_name in failed:
        continue
    # Skip images whose mask channel has already been written.
    if os.path.exists(train_images / ct_name.replace("_0000.nii.gz", "_0001.nii.gz")):
        print("======== already processed", ct_name)
        continue
    # sys.executable is the interpreter running this notebook, so the child
    # process uses the same environment instead of whatever "python" is on PATH.
    cmd = [sys.executable, "run_inference_on_one_img.py", f"{train_images}", f"{ct_name}"]
    print(">>>>>>>>", " ".join(cmd))
    out = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
    if out.returncode:
        failed.append(ct_name)
        print("++++++++ Failed", out.returncode, ct_name)
        # The child's output is captured, so surface the end of stderr here;
        # otherwise the reason for the failure is lost.
        if out.stderr:
            print(out.stderr[-2000:])

In [None]:
# Run CT-FM inference for each CT inside this process and save the result as
# the "_0001" channel next to the CT.
# Possible memory leak issue with SlidingWindowInferer running in a loop

model_name = "project-lighter/whole_body_segmentation"
device = "cuda"
model = SegResNet.from_pretrained(model_name).to(device)
# Make sure dropout / normalization layers run in inference mode. This is a
# no-op if the model is already in eval mode.
model.eval()

# NOTE: the slice skips the first 9 entries of train_names.
for ct_name in tqdm(train_names[9:]):
    ct_path = train_images / ct_name
    ct_img = sitk.ReadImage(ct_path)

    print("Running inference...")
    with torch.no_grad():
        out = inference(model, ct_path)
    out = out.squeeze().permute([2,1,0])  # same shape as ct
    out = out / 117  # normalize to 0-1 (117 is the highest class index)

    # .cpu() is a no-op for a CPU tensor and avoids a failure in .numpy() if
    # the helper returns a CUDA tensor (NOTE(review): confirm what
    # ctfm_utils.inference returns).
    mask = sitk.GetImageFromArray(out.cpu().numpy())
    # Give the mask the same origin/spacing/direction as the CT.
    mask.CopyInformation(ct_img)

    mask_path = train_images / ct_name.replace("_0000.nii.gz", "_0001.nii.gz")
    sitk.WriteImage(mask, mask_path)
    print("Saved:", mask_path)

In [None]:
# Load the scan at ct_path with MONAI's LoadImage so it can be viewed with the
# masks, then display index 55 along its last axis.
ct_img = LoadImage()(ct_path)
plt.imshow(ct_img[:, :, 55])
plt.show()

In [None]:
# Display index 55 along the last axis of `out`.
# NOTE(review): `out` was permuted with [2,1,0] in the inference loop, so its
# last axis may not be the same anatomical axis as the last axis of the
# LoadImage volume shown for the CT -- confirm the two views correspond.
plt.imshow(out[:, :, 55])
plt.show()

In [None]:
import random
import SimpleITK as sitk

# load data from 3D dataset for nnUNet

# List the folder once and keep only channel-0 ("_0000") files. The images
# folder also receives "_0001" mask files; picking one of those would make
# ct_path point at a mask and leave seg_path pointing at a label file that does
# not exist. Sorting makes the selected sample reproducible.
ct_files = sorted(x for x in os.listdir(train_images) if x.endswith("_0000.nii.gz"))
uls_img = [x for x in ct_files if x.startswith("ULS")]
ap_img = [x for x in ct_files if x.startswith("AutoPET")]
f = uls_img[0]
print(f)

# Paths of the selected CT and its label
ct_path = train_images / f
seg_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")
# seg_data = sitk.GetArrayFromImage(sitk.ReadImage(seg_path))
# ct_data = sitk.GetArrayFromImage(sitk.ReadImage(ct_path))

In [None]:
# visualize: walk every second index along the last axis and plot the first
# one whose segmentation mask is non-empty, with the CT slice beside it.
# NOTE(review): `result` is not defined in the visible cells -- confirm where
# it comes from and that it matches the currently loaded ct_img.
res = result.squeeze()
for i in range(0, res.shape[-1], 2):
    seg_mask = res[..., i].rot90()
    if (seg_mask > 0).sum() == 0:
        # nothing segmented at this index, try the next one
        continue

    ct_slice = ct_img[:, :, i].rot90()
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    axes[0].imshow(ct_slice, cmap="gray")
    axes[1].imshow(seg_mask, vmin=0, vmax=117, cmap="gist_stern")
    plt.show()
    # only the first non-empty slice is shown
    break
        break