# Custom 3D Segmentation Model
Use ULS DeepLesion 3D (700+ samples)
* Split data into patches of 12.8cm x 12.8cm x 6.4cm, based on `Spacing_mm_px_` in DL_info.csv
* Encode using CT-FM
* Decode into segmentation mask of middle slice

In [None]:
import os
from pathlib import Path
import SimpleITK as sitk

# Root of the raw dataset; training images and labels sit in two sub-folders of it.
data_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset001_3dlesion")
train_images = data_folder.joinpath("imagesTr")
train_labels = data_folder.joinpath("labelsTr")

In [None]:
# List the training images once and sort them: os.listdir returns entries in
# arbitrary order, so without sorting `filenames[0]` (and any other index into
# these lists) could differ between runs or machines.
image_names = sorted(os.listdir(train_images))
uls_img = [x for x in image_names if x.startswith("ULS")]
ap_img = [x for x in image_names if x.startswith("AutoPET")]

# filenames = random.sample(ap_img, 5) + random.sample(uls_img, 5)
filenames = uls_img + ap_img
f = filenames[0]
f

In [None]:
# Image files end in "_0000.nii.gz"; the matching label file has the same
# name without the "_0000" part.
ct_path = train_images / f
seg_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")

# Read the label volume and convert it to a NumPy array.
seg_img = sitk.ReadImage(seg_path)
seg_data = sitk.GetArrayFromImage(seg_img)
# Disabled exploration: plot the case only when enough of the label is foreground.
# NOTE(review): `plot` is not defined or imported in the visible cells — confirm before re-enabling.
# if seg_data.mean() > 5e-4:
#     print(seg_data.mean())
#     ct_img = sitk.ReadImage(ct_path)
#     ct_data = sitk.GetArrayFromImage(ct_img)
#     plot(f, ct_data, seg_data)

### CT-FM Feature Extractor

In [None]:
import torch
from lighter_zoo import SegResEncoder
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground
)

In [None]:
# Load the pretrained CT-FM feature extractor identified by this repository name.
model = SegResEncoder.from_pretrained(
    "project-lighter/ct_fm_feature_extractor")
# Switch to evaluation mode (as the cell's last expression this also displays the model).
model.eval()

In [None]:
# Preprocessing applied to a CT file path before it is passed to the encoder.
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                         # Ensure correct data type
    Orientation(axcodes="SPL"),           # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,    # Min HU value
        a_max=2048,     # Max HU value
        b_min=0,        # Target min
        b_max=1,        # Target max
        clip=True       # Clip values outside range
    ),
    CropForeground(allow_smaller=True)    # Crop the volume to its foreground region
])

In [None]:
# Sanity check: shape of the raw volume once a channel dimension has been added.
LoadImage(ensure_channel_first=True)(ct_path).shape

In [None]:
# Preprocess input
input_tensor = preprocess(ct_path)

# Run inference
with torch.no_grad():
    # Add a batch dimension and keep only the last element of what the encoder returns.
    output = model(input_tensor.unsqueeze(0))[-1]

    # Adaptive average pooling to size 1 collapses the three spatial dimensions,
    # leaving one value per feature channel. If this is not desired, remove this
    # line and use `output` directly to keep the spatial feature maps.
    avg_output = torch.nn.functional.adaptive_avg_pool3d(output, 1).squeeze()

print("✅ Feature extraction completed")
print(f"Output shape: {avg_output.shape}")

### Segmentation model

In [None]:
import matplotlib.pyplot as plt
import torch
from lighter_zoo import SegResNet
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground, Invert,
    Activations, AsDiscrete, KeepLargestConnectedComponent,
    SaveImage, Spacing
)
from monai.inferers import SlidingWindowInferer
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the pretrained whole-body segmentation network and move it to `device`.
# It is put in evaluation mode explicitly (as is done for the feature extractor)
# so that layers which behave differently during training do not affect
# inference. `.eval()` returns the module itself, so the chain still yields
# the model.
seg_model = SegResNet.from_pretrained(
    "project-lighter/whole_body_segmentation",
).to(device).eval()

In [None]:
# `random` is otherwise only imported in a much later cell; import it here so a
# top-to-bottom run does not fail with NameError.
import random

# Pick one AutoPET case at random and load it as a NumPy array.
ct_path = train_images / random.sample(ap_img, 1)[0]
ct_img = sitk.ReadImage(ct_path)
ct_data = sitk.GetArrayFromImage(ct_img)

In [None]:
# Sliding-window inference: windows are run through the network on `sw_device`
# while the `device` argument is set to the CPU.
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],  # Size of patches to process
    sw_batch_size=1,          # Number of windows to process in parallel
    overlap=0.25,            # Overlap between windows (reduces boundary artifacts)
    mode="gaussian",           # Gaussian weighting for overlap regions
    sw_device=device,
    device='cpu',
)

# Preprocessing pipeline
# NOTE: this rebinds `preprocess` from the feature-extraction section; this
# version additionally resamples the volume with `Spacing`.
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    Spacing(pixdim=(2.0, 2.0, 2.0)),       # Resample to a pixdim of 2.0 along each axis
    EnsureType(device=torch.device("cpu")),                         # Ensure correct data type
    Orientation(axcodes="SPL"),           # Standardize orientation
    ScaleIntensityRange(
        a_min=-1024,    # Min HU value
        a_max=2048,     # Max HU value
        b_min=0,        # Target min
        b_max=1,        # Target max
        clip=True,      # Clip values outside range
    ),
    CropForeground(allow_smaller=True),    # Remove background to reduce computation
])

# Postprocessing pipeline
postprocess = Compose([
    Activations(softmax=True),              # Softmax over the network output
    AsDiscrete(argmax=True),  # threshold=0.1, dtype=torch.int16
    # KeepLargestConnectedComponent(),
    Invert(transform=preprocess),           # Restore original space
    # SaveImage(output_dir="./ct_fm_output")
])

In [None]:
# Run the preprocessing pipeline on the first entry of `ap_img`.
ct_path = train_images / ap_img[0]
input_tensor = preprocess(ct_path)

In [None]:
# Run inference
with torch.no_grad():
    output = inferer(input_tensor.unsqueeze(dim=0), seg_model.to(device))[0]

# Copy metadata from input
output.applied_operations = input_tensor.applied_operations
output.affine = input_tensor.affine

# Postprocess. Nothing is written to disk here: the SaveImage step of the
# postprocessing pipeline is commented out, so the status message must not
# claim that the result was saved.
result = postprocess(output[0])
print("✅ Segmentation completed")

#### Visualize

In [None]:
# Reload the CT with LoadImage's default settings (no `ensure_channel_first`)
# for plotting; the trailing expression displays its shape.
ct_img = LoadImage()(ct_path)
ct_img.shape

In [None]:
# Show every 50th slice along the last axis, skipping slices without any
# predicted label. The loop variable is deliberately not called `slice`, so the
# built-in `slice` type is not shadowed for the rest of the notebook.
res = result.squeeze()
for i in range(0, res.shape[-1], 50):
    seg_slice = res[..., i].rot90()

    if (seg_slice > 0).any():
        ct_slice = ct_img[:, :, i].rot90()
        fig, axes = plt.subplots(1, 2, figsize=(6, 3))
        axes[0].imshow(ct_slice, cmap="gray")
        # Fixed colour range so a given label value maps to the same colour on every slice.
        axes[1].imshow(seg_slice, vmin=0, vmax=117, cmap="gist_stern")
        plt.show()


## Fine-tune seg model

In [None]:
from torch.nn import Conv3d
import numpy as np
import cc3d
from utils.plot import transparent_cmap

In [None]:
# Replace the head of `up_layers[3]` with a single-channel 1x1x1 convolution.
# A freshly constructed layer lives on the CPU, so move it to the device the
# rest of the network is already on; otherwise the next forward pass of a GPU
# model fails with a device mismatch.
seg_model.up_layers[3].head = Conv3d(
    32, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1)
).to(next(seg_model.parameters()).device)

### Data processing
* Load CT, other channels, and seg
* Get all connected components in seg
* For every cc:
    * randomly sample max(1, 1% of all points) points in cc
    * get (64,128,128) mm patch around it + pad?
* crop the same patch in other channels (ctfm-seg, boxes, seg target)
* resize all to (C,24,128,128) pixels (c, z, y, x)

In [None]:
# constants
# Physical extent of one patch in millimetres, ordered (z, y, x).
PATCH_SIZE_MM = (64, 128, 128)  # zyx
# Size in voxels that every cropped patch is resized to, ordered (z, y, x).
PATCH_DIMS = (24, 128, 128)  # zyx

In [None]:
import torch
import torch.nn.functional as F


def get_centroid(coords):
    """Return the mean of `coords` along axis 0, rounded to whole numbers.

    The result is not cast to an integer dtype here; `sample_points` does that.
    Being a rounded mean, it is not guaranteed to be one of the rows of
    `coords`.
    """
    return np.mean(coords, axis=0).round()


def sample_points(coords, n_points=2):
    """Choose patch centres for one connected component.

    Args:
        coords: (N, D) array of voxel indices, e.g. from `np.argwhere`.
        n_points: Number of random voxels to draw in addition to the
            centroid. Fewer are drawn when the component does not contain
            that many voxels other than the centroid.

    Returns:
        Integer array of unique points, sorted row-wise by `np.unique`: the
        rounded centroid plus up to `n_points` distinct voxels of the
        component.

    Raises:
        ValueError: If `coords` is not a non-empty two-dimensional array.
    """
    coords = np.asarray(coords)
    if coords.ndim != 2 or len(coords) == 0:
        raise ValueError("coords must be a non-empty (N, D) array of voxel indices")

    centroid = get_centroid(coords)

    # Keep the rows that differ from the centroid in at least one coordinate
    # (one vectorised comparison instead of a Python loop over every voxel).
    candidates = coords[np.any(coords != centroid, axis=1)]

    # Tiny components (one or two voxels) may have fewer candidates than
    # requested; sampling without replacement would raise, so clamp the count.
    n_draw = min(n_points, len(candidates))
    if n_draw > 0:
        picked = np.random.choice(len(candidates), n_draw, replace=False)
        sampled_points = np.vstack([candidates[picked], centroid])
    else:
        sampled_points = centroid[np.newaxis, :]

    return np.unique(sampled_points.astype(int), axis=0)


def get_xyz_range(point, spacing, patch_size_mm):
    """Compute the index window of a patch centred on `point`.

    `point` and `patch_size_mm` are indexed (z, y, x) whereas `spacing` is
    indexed (x, y, z), so the two are paired up in opposite order.

    Args:
        point: Centre of the patch as (z, y, x) indices.
        spacing: Size of one voxel along (x, y, z).
        patch_size_mm: Extent of the patch along (z, y, x), in the same unit
            as `spacing`.

    Returns:
        `((x_start, x_end), (y_start, y_end), (z_start, z_end))`, where each
        end is `start + size`. The bounds are not clipped and may be negative.
    """
    spans = []
    # Walk the axes in x, y, z order: index 2, 1, 0 of the zyx-ordered inputs.
    for axis, step in zip((2, 1, 0), spacing):
        n_voxels = round(patch_size_mm[axis] / step)
        first = point[axis] - n_voxels // 2
        spans.append((first, first + n_voxels))
    return tuple(spans)


def calculate_padding(array, x_range, y_range, z_range):
    """Work out how much `array` must be padded to contain the given ranges.

    Args:
        array: Volume indexed (z, y, x); only its `shape` is read.
        x_range, y_range, z_range: `(start, end)` index bounds per axis.

    Returns:
        `(pad_x, pad_y, pad_z)`, each a `(before, after)` pair of
        non-negative padding amounts.
    """
    # array is zyx
    width, height, depth = array.shape[2], array.shape[1], array.shape[0]

    def _pad(bounds, size):
        # Pad in front when the start is negative, behind when the end overruns.
        return (max(-bounds[0], 0), max(bounds[1] - size, 0))

    return _pad(x_range, width), _pad(y_range, height), _pad(z_range, depth)


def resize_volume(volume, new_shape):
    """Resample a 3-D array to `new_shape` with trilinear interpolation.

    Args:
        volume: Three-dimensional NumPy array of any numeric dtype.
        new_shape: Target size along the three axes, in the same axis order
            as `volume`.

    Returns:
        A float32 NumPy array whose shape is exactly `new_shape`, including
        any axis of size 1.
    """
    # Cast in NumPy rather than in torch: this also accepts dtypes and
    # negative-stride views that torch.from_numpy may reject.
    data = np.ascontiguousarray(volume, dtype=np.float32)
    # F.interpolate expects (batch, channel, D, H, W).
    tensor = torch.from_numpy(data)[None, None]
    resized_tensor = F.interpolate(tensor, size=new_shape, mode='trilinear')
    # Index away only the batch and channel axes; a bare squeeze() would also
    # drop any target axis of size 1 and change the rank of the result.
    return resized_tensor[0, 0].numpy()


def get_patch(array, point, spacing, patch_dims=PATCH_DIMS, patch_size_mm=PATCH_SIZE_MM):
    """Cut a patch of fixed physical size around `point` and resize it.

    Args:
        array: Volume to crop from, indexed (z, y, x).
        point: Patch centre as (z, y, x) indices.
        spacing: Voxel size along (x, y, z).
        patch_dims: Output size of the patch in voxels, (z, y, x).
        patch_size_mm: Physical extent of the patch, (z, y, x).

    Returns:
        The cropped patch resized to `patch_dims` by `resize_volume`.
    """
    # Index windows per axis; they may reach outside the volume.
    x_range, y_range, z_range = get_xyz_range(point, spacing, patch_size_mm)

    # Mirror-pad the volume just enough for the whole window to fit inside.
    pad_x, pad_y, pad_z = calculate_padding(array, x_range, y_range, z_range)
    array_padded = np.pad(array, (pad_z, pad_y, pad_x), mode='reflect')  # mode='constant', constant_values=pad_value

    # Padding in front of an axis shifts every index on that axis by the same amount.
    z_lo, z_hi = z_range[0] + pad_z[0], z_range[1] + pad_z[0]
    y_lo, y_hi = y_range[0] + pad_y[0], y_range[1] + pad_y[0]
    x_lo, x_hi = x_range[0] + pad_x[0], x_range[1] + pad_x[0]

    cropped = array_padded[z_lo:z_hi, y_lo:y_hi, x_lo:x_hi]
    return resize_volume(cropped, patch_dims)

In [None]:
# select sample
import random
# f = random.sample(uls_img, 1)[0]
# f = "AutoPET-Lymphoma-B_PETCT_0fa313309d_CT_0000.nii.gz"
f = "ULSDL3D_000441_02_01_187_lesion_01_0000.nii.gz"

# load CT and seg
ct_path = train_images / f
ct_img = sitk.ReadImage(ct_path)
ct_data = sitk.GetArrayFromImage(ct_img)

# the label file has the same name as the image without the "_0000" part
seg_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")
seg_img = sitk.ReadImage(seg_path)
seg_data = sitk.GetArrayFromImage(seg_img)
# voxel spacing taken from the label image; get_xyz_range indexes it as (x, y, z)
spacing = seg_img.GetSpacing()

# get connected components in seg
labels, n_components = cc3d.connected_components(seg_data, return_N=True)

# sample points within cc
# NOTE: seg_patch / ct_patch are overwritten on every iteration, so only the
# patch of the last sampled point is still available after the loop.
for c in range(1, n_components + 1):
    # indices of all voxels that belong to component c
    coords = np.argwhere(labels == c)
    points = sample_points(coords)
    for point in points:
        # crop volume around the point
        seg_patch = get_patch(seg_data, point, spacing)
        ct_patch = get_patch(ct_data, point, spacing)
        print(f"{f[:10]} {c=} {point=}")

In [None]:
# visualize slices
# Overview: mean projection of the patch along each of its three axes, with
# the segmentation projection drawn on top.
fig, axes = plt.subplots(1, 3, figsize=(6, 2))
for axis, ax in enumerate(axes):
    ax.imshow(ct_patch.mean(axis=axis))
    ax.imshow(seg_patch.mean(axis=axis), cmap=transparent_cmap("r"))

# Detail: one figure per slice along the first axis; CT alone on the left,
# CT with the segmentation overlaid on the right.
for ct_plane, seg_plane in zip(ct_patch, seg_patch):
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    plain_ax, overlay_ax = axes
    plain_ax.imshow(ct_plane, cmap="gray")
    overlay_ax.imshow(ct_plane, cmap="gray")
    overlay_ax.imshow(seg_plane, cmap=transparent_cmap("r"))
    plt.show()