# Custom 3D Segmentation Model
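Dataset: ULS DeepLesion 3D (700+ samples)
* Split each volume into patches of 12.8 cm × 12.8 cm × 6.4 cm, converting the physical size to voxels with the `Spacing_mm_px_` column of DL_info.csv
* Encode each patch with the CT-FM feature extractor
* Decode the features into a segmentation mask of the patch's middle slice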

In [None]:
import os
from pathlib import Path
import SimpleITK as sitk

# nnU-Net raw dataset layout: training images live in imagesTr/, masks in labelsTr/.
data_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset001_3dlesion")
train_images, train_labels = (data_folder / sub for sub in ("imagesTr", "labelsTr"))

In [None]:
# List the training images once and sort them: os.listdir returns entries in
# arbitrary, filesystem-dependent order, which would make `f` below (and any
# later `ap_img[0]`) change between machines/runs.
all_images = sorted(os.listdir(train_images))
uls_img = [x for x in all_images if x.startswith("ULS")]
ap_img = [x for x in all_images if x.startswith("AutoPET")]

# filenames = random.sample(ap_img, 5) + random.sample(uls_img, 5)
filenames = uls_img + ap_img
f = filenames[0]
f

In [None]:
# Pair the CT volume with its label map: the label file name is the image
# name with the "_0000" channel suffix removed.
ct_path = train_images / f
label_name = f.replace("_0000.nii.gz", ".nii.gz")
seg_path = train_labels / label_name

seg_img = sitk.ReadImage(seg_path)
seg_data = sitk.GetArrayFromImage(seg_img)
# Optional preview of volumes with enough foreground (disabled):
# if seg_data.mean() > 5e-4:
#     print(seg_data.mean())
#     ct_img = sitk.ReadImage(ct_path)
#     ct_data = sitk.GetArrayFromImage(ct_img)
#     plot(f, ct_data, seg_data)

### CT-FM Feature Extractor

In [None]:
import torch
from lighter_zoo import SegResEncoder
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground
)

In [None]:
# Load the pretrained CT-FM encoder and switch it to inference mode.
model = SegResEncoder.from_pretrained("project-lighter/ct_fm_feature_extractor")
model.eval()

In [None]:
# Feature-extractor preprocessing: load with a channel axis, standardise the
# orientation, linearly map HU [-1024, 2048] onto [0, 1] (values outside the
# window are clipped), then crop away the background.
hu_window = ScaleIntensityRange(a_min=-1024, a_max=2048, b_min=0, b_max=1, clip=True)
preprocess = Compose([
    LoadImage(ensure_channel_first=True),
    EnsureType(),
    Orientation(axcodes="SPL"),
    hu_window,
    CropForeground(allow_smaller=True),
])

In [None]:
# Shape of the raw (un-preprocessed) image with a leading channel axis.
raw_ct = LoadImage(ensure_channel_first=True)(ct_path)
raw_ct.shape

In [None]:
# Encode the preprocessed CT with CT-FM as a batch of one.
input_tensor = preprocess(ct_path)
batch = input_tensor.unsqueeze(0)

with torch.no_grad():
    # The encoder returns a sequence of feature maps; keep the last one
    # (presumably the deepest level).
    feature_maps = model(batch)
    output = feature_maps[-1]
    # Global average pooling over every spatial position gives one feature
    # vector. Use `output` directly instead to keep the spatial feature maps.
    pooled = torch.nn.functional.adaptive_avg_pool3d(output, 1)
    avg_output = pooled.squeeze()

print("✅ Feature extraction completed")
print(f"Output shape: {avg_output.shape}")

### Segmentation model

In [None]:
import matplotlib.pyplot as plt
import torch
from lighter_zoo import SegResNet
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground, Invert,
    Activations, AsDiscrete, KeepLargestConnectedComponent,
    SaveImage, Spacing
)
from monai.inferers import SlidingWindowInferer
# Run on the GPU when one is available; fall back to the CPU so the `.to(device)`
# calls below do not crash on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the pretrained whole-body segmentation network onto `device` and put it
# in inference mode (same as the CT-FM encoder) for the predictions below.
# Call `seg_model.train()` again before fine-tuning.
seg_model = SegResNet.from_pretrained(
    "project-lighter/whole_body_segmentation",
).to(device).eval()

In [None]:
import random

In [None]:
# Load one randomly chosen AutoPET CT volume as a numpy array.
ct_path = train_images / random.choice(ap_img)
ct_img = sitk.ReadImage(ct_path)
ct_data = sitk.GetArrayFromImage(ct_img)

In [None]:
# Array shape of the CT volume; GetArrayFromImage yields axes in (z, y, x) order.
ct_data.shape

In [None]:
import joblib


In [None]:
# Inspect the pickle nnU-Net stored for one preprocessed case (presumably its
# per-case properties). The folder is derived from `data_folder` (same drive
# and dataset name) instead of hard-coding the absolute path a second time.
# NOTE: joblib.load unpickles the file; only use it on trusted, locally
# generated nnU-Net output.
preprocessed_folder = (
    data_folder.parent.parent
    / "nnUNet_preprocessed"
    / data_folder.name
    / "nnUNetPlans_3d_fullres"
)
joblib.load(preprocessed_folder / "AutoPET-Lymphoma-B_PETCT_0f4ee9e078_CT.pkl")

In [None]:
# Mean-intensity projection of the CT along array axis 1 (y in the (z, y, x)
# array order) as a quick overview image.
plt.imshow(ct_data.mean(axis=1))

In [None]:
# Scratch arithmetic: presumably slice count × slice spacing (mm) to get a
# scan's z-extent. TODO confirm which volume these numbers came from.
676 * 2.5

In [None]:
# Voxel grid size and voxel spacing of the loaded SimpleITK image.
ct_img.GetSize(), ct_img.GetSpacing()

In [None]:
# Sliding-window inference over 96×160×160 patches with 25 % overlap, blended
# with Gaussian weights. Each window runs on `device`; the stitched prediction
# is accumulated on the CPU.
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],
    sw_batch_size=1,
    overlap=0.25,
    mode="gaussian",
    sw_device=device,
    device='cpu',
)

# Preprocessing: resample to 2.0 spacing on every axis (presumably mm),
# standardise orientation, map HU [-1024, 2048] onto [0, 1] with clipping,
# and crop away background to reduce computation.
resample = Spacing(pixdim=(2.0, 2.0, 2.0))
hu_window = ScaleIntensityRange(a_min=-1024, a_max=2048, b_min=0, b_max=1, clip=True)
preprocess = Compose([
    LoadImage(ensure_channel_first=True),
    resample,
    EnsureType(device=torch.device("cpu")),
    Orientation(axcodes="SPL"),
    hu_window,
    CropForeground(allow_smaller=True),
])

# Postprocessing: class scores -> label map, then invert `preprocess` to bring
# the prediction back onto the original image grid.
postprocess = Compose([
    Activations(softmax=True),
    AsDiscrete(argmax=True),  # threshold=0.1, dtype=torch.int16
    # KeepLargestConnectedComponent(),
    Invert(transform=preprocess),
    # SaveImage(output_dir="./ct_fm_output")
])

In [None]:
# Preprocess the first AutoPET volume in `ap_img` for whole-body segmentation.
ct_path = train_images.joinpath(ap_img[0])
input_tensor = preprocess(ct_path)

In [None]:
# Shape of the preprocessed, channel-first tensor.
input_tensor.shape

In [None]:
# NOTE(review): repeats the previous cell; safe to delete.
input_tensor.shape

In [None]:
# Run sliding-window inference (no gradients needed).
with torch.no_grad():
    output = inferer(input_tensor.unsqueeze(dim=0), seg_model.to(device))[0]

# Copy metadata from the input so `Invert` in `postprocess` can replay the
# recorded preprocessing in reverse and restore the original space.
output.applied_operations = input_tensor.applied_operations
output.affine = input_tensor.affine

# NOTE(review): together with the `[0]` above, `output[0]` assumes the inferer
# returns a sequence whose first item is a batched (B, C, D, H, W) prediction.
# If it returned a plain tensor, this would select a single channel. Confirm.
result = postprocess(output[0])
# SaveImage is commented out in `postprocess`, so nothing is written to disk;
# the message no longer claims the result was saved.
print("✅ Segmentation completed")

#### Visualize

In [None]:
# Reload the CT without a channel axis on its original grid, for side-by-side display.
loader = LoadImage()
ct_img = loader(ct_path)
ct_img.shape

In [None]:
# Show every 50th slice along the last axis that contains any predicted
# foreground label: CT on the left, label map on the right.
# (The loop variable is no longer called `slice`, which shadowed the builtin
# for every later cell of the notebook.)
res = result.squeeze()
for i in range(0, res.shape[-1], 50):
    seg_slice = res[..., i].rot90()
    if not (seg_slice > 0).any():
        continue
    ct_slice = ct_img[:, :, i].rot90()
    _, axes = plt.subplots(1, 2, figsize=(6, 3))
    axes[0].imshow(ct_slice, cmap="gray")
    # NOTE(review): vmax=117 presumably is the highest class index of the
    # whole-body model, keeping colours fixed across slices. Confirm.
    axes[1].imshow(seg_slice, vmin=0, vmax=117, cmap="gist_stern")
    plt.show()


## Fine-tune the segmentation model

In [None]:
from torch.nn import Conv3d

In [None]:
# Replace the final output head with a single-output-channel 1×1×1 convolution
# for fine-tuning. A freshly constructed Conv3d lives on the CPU, while
# `seg_model` was moved to `device`; move the new head there too, otherwise
# the forward pass fails with a device mismatch.
seg_model.up_layers[3].head = Conv3d(32, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1)).to(device)

In [None]:
# Load the label map of one randomly chosen AutoPET case (the label name is
# the image name without the "_0000" channel suffix).
seg_path = train_labels / random.choice(ap_img).replace("_0000.nii.gz", ".nii.gz")
seg_img = sitk.ReadImage(seg_path)
seg_data = sitk.GetArrayFromImage(seg_img)

In [None]:
# Project the label mask along array axis 1: a pixel is lit when any voxel
# along that axis is non-zero.
plt.imshow((seg_data != 0).any(axis=1))

In [None]:
import numpy as np
import cc3d


# Label connected components (26-connected by default)
labels, num_components = cc3d.connected_components(seg_data, return_N=True)

# Centroid (mean voxel index, [z, y, x] order) of every component, computed in
# a single pass: per-label coordinate sums and voxel counts via np.bincount,
# instead of re-scanning the whole volume once per component.
foreground = labels > 0  # skip background (label 0)
voxel_coords = np.argwhere(foreground)
# argwhere and boolean indexing both visit voxels in C order, so row k of
# voxel_coords belongs to voxel_labels[k]. Cast to intp because np.bincount
# rejects label dtypes it cannot safely cast (e.g. uint64).
voxel_labels = labels[foreground].astype(np.intp)
n_bins = num_components + 1
counts = np.bincount(voxel_labels, minlength=n_bins)[1:]
coord_sums = np.stack(
    [
        np.bincount(voxel_labels, weights=voxel_coords[:, axis], minlength=n_bins)[1:]
        for axis in range(labels.ndim)
    ],
    axis=1,
)
# One array per component, as before.
centroids = list(coord_sums / counts[:, None])

In [None]:
# Display the per-component centroids ([z, y, x] voxel coordinates).
centroids