# Fine-Tune CT-FM 3D Segmentation Model
Use ULS DeepLesion 3D (700+ samples)
* Split each volume into patches of 12.8 cm x 12.8 cm x 6.4 cm, converting millimetres to voxels with the voxel spacing (`Spacing_mm_px_` in DL_info.csv)
* Encode each patch using CT-FM
* Decode into a segmentation mask of the patch's middle slice

### Test Inference

In [None]:
import os
import random
from pathlib import Path
import SimpleITK as sitk

# Load the list of training cases from the nnUNet-format 3D lesion dataset.
data_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset001_3dlesion")
train_images = data_folder / "imagesTr"
train_labels = data_folder / "labelsTr"

# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort a
# single listing so "the first ULS case" (and any index into these lists) is
# reproducible across machines and runs.
_train_files = sorted(os.listdir(train_images))
uls_img = [x for x in _train_files if x.startswith("ULS")]
ap_img = [x for x in _train_files if x.startswith("AutoPET")]
if not uls_img:
    raise FileNotFoundError(f"No ULS images found in {train_images}")
f = uls_img[0]
print(f)

# Paths for the selected case: the label file has the image name with the
# "_0000" suffix removed.
ct_path = train_images / f
seg_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")

In [None]:
import matplotlib.pyplot as plt
import torch
from lighter_zoo import SegResNet
from monai.transforms import LoadImage
from setup_utils import get_inferer, get_preprocess, get_postprocess

# Load the pretrained CT-FM whole-body segmentation model.
model_name = "project-lighter/whole_body_segmentation"
# Fall back to CPU when no GPU is available instead of failing on .to("cuda").
device = "cuda" if torch.cuda.is_available() else "cpu"
seg_model = SegResNet.from_pretrained(model_name).to(device)

# Inference and pre/post-processing pipelines from the local setup_utils module.
inferer = get_inferer(device)
preprocess = get_preprocess()
# postprocess is built from `preprocess`, presumably so it can undo its
# transforms on the prediction — TODO confirm in setup_utils.
postprocess = get_postprocess(preprocess)

In [None]:
# Run the pretrained model on the selected CT and map the prediction back.
# `preprocess` takes the file path and returns a tensor that carries
# `applied_operations` / `affine` metadata (both read below).
input_tensor = preprocess(ct_path)
# NOTE(review): seg_model.eval() is never called; confirm whether
# from_pretrained already returns the model in eval mode.
with torch.no_grad():
    # Add a batch dimension for the inferer, then take the single batch item.
    output = inferer(input_tensor.unsqueeze(dim=0), seg_model.to(device))[0]
    print(f"{output.shape=}")

# Copy the preprocessing metadata onto the prediction so that postprocess can
# presumably invert those transforms back to the original image space.
output.applied_operations = input_tensor.applied_operations
output.affine = input_tensor.affine
# NOTE(review): output[0] indexes the first axis of the un-batched output
# (likely the channel axis); confirm this is what postprocess expects.
result = postprocess(output[0])
print(result.shape)

In [None]:
# Load the raw CT volume (without the model preprocessing) to use as the
# background when displaying the predicted masks.
_ct_loader = LoadImage()
ct_img = _ct_loader(ct_path)

In [None]:
# Show the first checked slice (every 2nd index along the last axis) that has
# any predicted label, next to the matching CT slice.
res = result.squeeze()
for i in range(0, res.shape[-1], 2):
    # Test for foreground first, so empty slices are skipped without rotating.
    if not (res[..., i] > 0).any():
        continue
    seg_mask = res[..., i].rot90()
    ct_slice = ct_img[:, :, i].rot90()
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    axes[0].imshow(ct_slice, cmap="gray")
    # Fixed vmin/vmax (presumably the model's label index range) keep each
    # label's colour stable regardless of which labels appear on the slice.
    axes[1].imshow(seg_mask, vmin=0, vmax=117, cmap="gist_stern")
    plt.show()
    break
else:
    # Previously this case produced no output at all.
    print("No predicted foreground on any checked slice.")


## Fine-tune the segmentation model

In [None]:
from torch.nn import Conv3d
import numpy as np
import cc3d
from utils.plot import transparent_cmap

In [None]:
# Replace the final decoder head with a single-output-channel 1x1x1 conv for
# fine-tuning.
_old_head = seg_model.up_layers[3].head
# Reuse the old head's input width (previously hard-coded as 32), falling back
# to 32 if the old head does not expose `in_channels`.
_head_in_channels = getattr(_old_head, "in_channels", 32)
_new_head = Conv3d(_head_in_channels, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))
# A newly constructed module lives on the CPU, while seg_model was already
# moved to `device`. Move the head alongside the rest of the model so the next
# forward pass does not hit a device mismatch.
seg_model.up_layers[3].head = _new_head.to(next(seg_model.parameters()).device)

### Data processing

In [None]:
# Select one training case and cut CT/label patches around sampled lesion points.
from patching import sample_points, get_lesion_patch
import random

# Pin a specific case here, or set to None to draw a random ULS case.
# Example AutoPET case: "AutoPET-Lymphoma-B_PETCT_0fa313309d_CT_0000.nii.gz"
FIXED_SAMPLE = "ULSDL3D_000441_02_01_187_lesion_01_0000.nii.gz"
f = FIXED_SAMPLE if FIXED_SAMPLE is not None else random.choice(uls_img)

# Load CT and segmentation as numpy arrays.
ct_path = train_images / f
ct_img = sitk.ReadImage(ct_path)
ct_data = sitk.GetArrayFromImage(ct_img)

seg_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")
seg_img = sitk.ReadImage(seg_path)
seg_data = sitk.GetArrayFromImage(seg_img)
# Patches are cut from both arrays at the same indices, so they must align.
if ct_data.shape != seg_data.shape:
    raise ValueError(
        f"CT/label shape mismatch for {f}: {ct_data.shape} vs {seg_data.shape}"
    )
# NOTE(review): SimpleITK GetSpacing() is ordered (x, y, z), while `point`
# below indexes numpy arrays from GetArrayFromImage, which are ordered
# (z, y, x). Confirm that get_lesion_patch reverses the spacing; otherwise
# pass spacing[::-1].
spacing = seg_img.GetSpacing()

# Label connected components of the segmentation; components are numbered
# 1..n_components (0 is not iterated below).
labels, n_components = cc3d.connected_components(seg_data, return_N=True)

# Sample points within each component and crop patches around them.
for c in range(1, n_components + 1):
    # Array indices of every voxel in component c, one row per voxel.
    coords = np.argwhere(labels == c)
    points = sample_points(coords)
    for point in points:
        # Crop the same region from the label and the CT around the point.
        seg_patch = get_lesion_patch(seg_data, point, spacing)
        ct_patch = get_lesion_patch(ct_data, point, spacing)
        print(f"{f[:10]} {c=} {point=}")