# Splitting volumes into patches
* create patches centred on lesions
* create sliding-window patches
  * keep only the ones that contain lesion voxels
* get the xyz ranges of the patches
* use the ranges to crop the other channels
* save the patches
* save the ranges (in filenames or pickles) so more channels can be cropped in the future

In [8]:
import math
from pathlib import Path

import SimpleITK as sitk
import cc3d
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F


# Dataset root with sibling image/label folders (nnUNet_raw-style layout).
data_folder = Path("/media/liushifeng/KINGSTON/nnUNet_raw/Dataset001_3dlesion")
train_images, train_labels = (data_folder.joinpath(sub) for sub in ("imagesTr", "labelsTr"))

In [None]:
# constants -- every tuple is ordered (z, y, x), matching the numpy volumes
PATCH_SIZE_MM = (64, 128, 128)  # zyx: physical extent (mm) cropped around a point in get_patch
PATCH_DIMS = (24, 128, 128)  # zyx: voxel shape each crop is resized to

In [None]:
def get_centroid(coords):
    """Return the integer voxel index closest to the mean of ``coords``.

    ``coords`` is an (N, ndim) collection of voxel indices, e.g. the output of
    ``np.argwhere``. Halves are rounded with numpy's round-half-to-even rule.
    """
    mean_position = np.mean(coords, axis=0)
    return np.round(mean_position).astype(int)


def sample_points(coords, n_points=2, rng=None):
    """Return the centroid of ``coords`` plus up to ``n_points`` random member voxels.

    Args:
        coords: non-empty (N, ndim) array of voxel indices of one component.
        n_points: how many non-centroid voxels to draw (without replacement).
            Fewer are drawn when the component has fewer non-centroid voxels,
            so tiny components (e.g. a single voxel) no longer raise.
        rng: optional ``np.random.Generator``; when None the global numpy RNG
            is used (same draws as before for the default arguments).

    Returns:
        Unique, row-sorted (M, ndim) int array with M <= n_points + 1.
    """
    coords = np.asarray(coords)
    centroid = get_centroid(coords)

    # vectorised equivalent of filtering rows that differ from the centroid
    others = coords[np.any(coords != centroid, axis=1)]
    n_draw = min(n_points, len(others))

    if n_draw > 0:
        choose = np.random.choice if rng is None else rng.choice
        picked = others[choose(len(others), n_draw, replace=False)]
    else:
        picked = others[:0]  # keeps the (0, ndim) shape so vstack still works

    sampled_points = np.vstack([picked, centroid]).astype(int)
    return np.unique(sampled_points, axis=0)


def get_xyz_range(point, spacing, patch_size_mm):
    """Return the voxel index ranges of a patch centred on ``point``.

    ``point`` and ``patch_size_mm`` are ordered (z, y, x), while ``spacing`` is
    ordered (x, y, z). The result is ``((x_start, x_end), (y_start, y_end),
    (z_start, z_end))`` with exclusive ends. The bounds may lie outside the
    volume; ``calculate_padding`` handles that.
    """
    ranges = []
    for xyz_axis in range(3):
        zyx_axis = 2 - xyz_axis  # the same physical axis in (z, y, x) order
        n_voxels = round(patch_size_mm[zyx_axis] / spacing[xyz_axis])
        start = point[zyx_axis] - n_voxels // 2
        ranges.append((start, start + n_voxels))
    return tuple(ranges)


def calculate_padding(array, x_range, y_range, z_range):
    """Return ``(pad_x, pad_y, pad_z)`` needed so the ranges fit inside ``array``.

    ``array`` is indexed (z, y, x). Each range is ``(start, end)`` with an
    exclusive end. Each pad is ``(before, after)``: how far the range sticks out
    below 0 and beyond the axis length, clipped at 0.
    """
    def overhang(bounds, extent):
        return max(-bounds[0], 0), max(bounds[1] - extent, 0)

    z_len, y_len, x_len = array.shape[0], array.shape[1], array.shape[2]
    return overhang(x_range, x_len), overhang(y_range, y_len), overhang(z_range, z_len)


def resize_volume(volume, new_shape, mode="trilinear"):
    """Resample a 3-D array to ``new_shape`` with ``torch.nn.functional.interpolate``.

    Args:
        volume: 3-D array-like; converted to float32 first, so integer dtypes
            torch cannot wrap directly (e.g. uint16) are accepted.
        new_shape: target (d, h, w) shape. Singleton sizes are preserved.
        mode: interpolation mode passed to ``F.interpolate``. The default is
            ``"trilinear"``; use ``"nearest"`` for label maps so no fractional
            labels are created.

    Returns:
        float32 numpy array of shape ``new_shape``.
    """
    as_float = np.ascontiguousarray(volume, dtype=np.float32)
    tensor = torch.from_numpy(as_float)[None, None]  # add batch and channel dims -> (1, 1, d, h, w)
    resized_tensor = F.interpolate(tensor, size=tuple(new_shape), mode=mode)
    # index instead of squeeze(): squeeze() would also drop singleton target dims
    return resized_tensor[0, 0].numpy()


def get_patch(array, point, spacing, patch_dims=PATCH_DIMS, patch_size_mm=PATCH_SIZE_MM,
              pad_mode="reflect", pad_value=0):
    """Crop a fixed physical-size patch centred on ``point`` and resize it.

    Args:
        array: 3-D volume indexed (z, y, x).
        point: (z, y, x) voxel index the patch is centred on.
        spacing: voxel spacing in (x, y, z) order, as consumed by
            ``get_xyz_range`` (the notebook passes SimpleITK's ``GetSpacing()``).
        patch_dims: (z, y, x) voxel shape of the returned patch.
        patch_size_mm: (z, y, x) physical extent of the crop in mm.
        pad_mode: ``np.pad`` mode used where the crop extends past the volume.
            Defaults to ``"reflect"``. ``"constant"`` is usually preferable for
            label maps, because reflection mirrors lesions into the padding.
        pad_value: fill value; only used when ``pad_mode == "constant"``.

    Returns:
        float32 array of shape ``patch_dims`` (resized with ``resize_volume``).
    """
    # get ranges to crop
    x_range, y_range, z_range = get_xyz_range(point, spacing, patch_size_mm)

    # pad array so it fits within the range
    pad_x, pad_y, pad_z = calculate_padding(array, x_range, y_range, z_range)
    pad_width = (pad_z, pad_y, pad_x)  # np.pad follows the array's (z, y, x) axis order
    if pad_mode == "constant":
        array_padded = np.pad(array, pad_width, mode="constant", constant_values=pad_value)
    else:
        # other modes reject constant_values, so it is only passed for "constant"
        array_padded = np.pad(array, pad_width, mode=pad_mode)

    # shift the ranges by the amount padded before each axis
    z_start, z_end = (z + pad_z[0] for z in z_range)
    y_start, y_end = (y + pad_y[0] for y in y_range)
    x_start, x_end = (x + pad_x[0] for x in x_range)

    # crop and resize
    patch = array_padded[z_start:z_end, y_start:y_end, x_start:x_end]
    return resize_volume(patch, patch_dims)

In [None]:
# f = random.sample(uls_img, 1)[0]
f = "AutoPET-Lymphoma-B_PETCT_0fa313309d_CT_0000.nii.gz"
# f = "ULSDL3D_000441_02_01_187_lesion_01_0000.nii.gz"

# load CT and seg (image "<case>_0000.nii.gz" <-> label "<case>.nii.gz")
ct_path = train_images / f
ct_img = sitk.ReadImage(str(ct_path))  # str() also works on SimpleITK versions without PathLike support
ct_data = sitk.GetArrayFromImage(ct_img)  # numpy array indexed (z, y, x)

seg_path = train_labels / f.replace("_0000.nii.gz", ".nii.gz")
seg_img = sitk.ReadImage(str(seg_path))
seg_data = sitk.GetArrayFromImage(seg_img)
spacing = seg_img.GetSpacing()  # (x, y, z) order, unlike the (z, y, x) numpy arrays

# Both volumes are later cropped with the same voxel indices and the seg spacing,
# so they must share one voxel grid; fail loudly instead of producing misaligned patches.
if ct_data.shape != seg_data.shape or not np.allclose(ct_img.GetSpacing(), spacing, rtol=1e-4):
    raise ValueError(
        f"CT (shape {ct_data.shape}, spacing {ct_img.GetSpacing()}) and seg "
        f"(shape {seg_data.shape}, spacing {spacing}) are not on the same voxel grid"
    )

# get connected components in seg
labels, n_components = cc3d.connected_components(seg_data, return_N=True)
n_components

In [None]:
# get centroid patches
# Scan the label volume once for foreground voxels and group them per component,
# instead of building a full-volume `labels == c` mask for every component.
fg_coords = np.argwhere(labels > 0)  # (N, 3) voxel indices, (z, y, x), in C order
fg_ids = labels[tuple(fg_coords.T)]  # component id of each foreground voxel

# NOTE(review): seg patches go through the same reflect padding and trilinear resize
# as the CT (get_patch / resize_volume defaults), so label values may become
# fractional and lesions can be mirrored into padded borders -- confirm this is intended.
seg_patches = []
ct_patches = []
for c in range(1, n_components + 1):
    coords = fg_coords[fg_ids == c]  # same rows, same order as np.argwhere(labels == c)
    point = get_centroid(coords)

    seg_patches.append(get_patch(seg_data, point, spacing))
    ct_patches.append(get_patch(ct_data, point, spacing))

## Segmentation: pretrained whole-body model on sliding-window CT patches

In [9]:
from lighter_zoo import SegResNet
from monai.inferers import SlidingWindowInferer
from monai.transforms import (
    Compose, LoadImage, Orientation,
    ScaleIntensityRange, Activations, AsDiscrete, KeepLargestConnectedComponent
)
# use the GPU when available; a hard-coded "cuda" fails on CPU-only machines
device = "cuda" if torch.cuda.is_available() else "cpu"
seg_model = SegResNet.from_pretrained(
    "project-lighter/whole_body_segmentation",
).to(device)

# label volumes: load + reorient only (no intensity scaling)
seg_preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    Orientation(axcodes="SPL"),  # Standardize orientation
])

# CT volumes: load + reorient, then map [-1024, 2048] HU linearly onto [0, 1]
ct_preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    Orientation(axcodes="SPL"),  # Standardize orientation
    ScaleIntensityRange(
        a_min=-1024,  # Min HU value
        a_max=2048,  # Max HU value
        b_min=0,  # Target min
        b_max=1,  # Target max
        clip=True  # Clip values outside range
    ),
])

# MONAI array transforms expect ONE channel-first prediction (C, z, y, x):
# Activations/AsDiscrete act along dim 0, so strip the batch dim before calling this.
postprocess = Compose([
    Activations(softmax=True),
    # 118 is presumably the model's number of output classes -- confirm against the model card
    AsDiscrete(argmax=True, to_onehot=118),  # threshold=0.1, dtype=torch.int16
    KeepLargestConnectedComponent(),
    # Invert(transform=ct_preprocess),           # Restore original space
    # SaveImage(output_dir="./ct_fm_output")
])

# sliding-window inference: each window runs on `device` (sw_device) and the
# stitched prediction is accumulated on the CPU (device)
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],
    sw_batch_size=1,
    overlap=0.5,
    mode="gaussian",
    device='cpu',
    sw_device=device,
)

In [10]:
# load both volumes channel-first and reoriented; only the CT gets intensity scaling
ct = ct_preprocess(ct_path)
seg = seg_preprocess(seg_path)

In [None]:
def get_n_windows(x, width, overlap):
    """Count how many windows of ``width`` with ``overlap`` fit along a length ``x``.

    Returns ``(n_full, remainder)``:
    - ``n_full`` is the floored number of whole windows. It can be 0, or even
      negative, for axes shorter than the window.
    - ``remainder`` is the uncovered leftover length as a fraction of one
      window width.
    """
    stride = width - overlap
    exact = (x - overlap) / stride
    leftover_fraction = exact % 1 * stride / width
    return math.floor(exact), leftover_fraction


def get_step_size(x, width, n) -> float:
    """Return the stride that spreads ``n`` windows of ``width`` evenly over ``x``.

    The first window starts at 0 and the last one ends exactly at ``x``. The
    stride is generally fractional (callers round window starts themselves).
    ``n`` is expected to be >= 1; a single window needs no stride, so 0 is
    returned for ``n == 1``.
    """
    if n == 1:
        return 0
    # overlap that makes n windows of `width` span exactly x
    overlap = (n * width - x) / (n - 1)
    return width - overlap


def get_range(dim, x, width, step) -> tuple:
    """Return the ``(start, end)`` slice bounds of window number ``x`` along one axis.

    The window starts at ``round(x * step)`` and is ``width`` long. If it would run
    past ``dim``, it is shifted back so it ends at ``dim``. If the axis is shorter
    than ``width``, the window is clamped to ``(0, dim)``. A negative start would
    otherwise wrap around under Python slicing and select only the axis tail.
    """
    start = round(x * step)
    end = start + width
    if end > dim:  # if it exceeds the dimension, align the window with the far edge
        end = dim
        start = max(end - width, 0)
    return start, end

def generate_patches(volume, patch_size, overlap_ratio=0.5):
    """Cut ``volume`` into overlapping sliding windows that cover every axis.

    Args:
        volume: array/tensor whose first three axes are (z, y, x).
        patch_size: (z, y, x) window size in voxels.
        overlap_ratio: initial overlap as a fraction of the window size. The
            actual stride is then recomputed so the windows are spread evenly
            from the start to the end of each axis.

    Returns:
        List of basic slices of ``volume`` (views for numpy/torch inputs), in
        z-major, then y, then x order. Along an axis shorter than the window,
        a single slice covers the whole axis.
    """
    shape = list(volume.shape)  # zyx

    prelim_overlaps = [overlap_ratio * w for w in patch_size]
    # one (n_full, remainder) pair per axis
    fits = [get_n_windows(x, w, s) for x, w, s in zip(shape, patch_size, prelim_overlaps)]

    # add an extra window if there's remainder >10% of window; always keep at least
    # one window, since short axes give 0 or negative counts (and then no patches)
    n_windows = [max(n + 1 if r > 0.1 else n, 1) for n, r in fits]

    # calculate new step size to evenly distribute the windows
    steps = [get_step_size(x, w, n) for x, w, n in zip(shape, patch_size, n_windows)]

    print("winds", n_windows)
    print("steps", steps)

    patches = []
    for z in range(n_windows[0]):
        z_start, z_end = get_range(shape[0], z, patch_size[0], steps[0])
        for y in range(n_windows[1]):
            y_start, y_end = get_range(shape[1], y, patch_size[1], steps[1])
            for x in range(n_windows[2]):
                x_start, x_end = get_range(shape[2], x, patch_size[2], steps[2])
                patches.append(volume[z_start:z_end, y_start:y_end, x_start:x_end])
    return patches

In [None]:
# sliding-window settings: window size in voxels, ordered (z, y, x)
overlap_ratio = 0.5
patch_size = [64, 128, 128]

# channel 0 of the preprocessed CT -> list of (z, y, x) windows
ct_patches = generate_patches(ct[0], patch_size, overlap_ratio=overlap_ratio)
# seg_patches = generate_patches(seg[0], patch_size, overlap_ratio=overlap_ratio)

In [None]:
from utils.plot import transparent_cmap

n = 373  # index of the sliding-window patch to inspect

# Step through the first-axis slices of patch n: CT alone vs CT with the seg overlay.
# NOTE(review): seg_patches[n] must be the seg patch that matches ct_patches[n]. As
# written, seg_patches was last filled by the centroid-patch loop, not by
# generate_patches -- confirm they align before trusting the overlay.
for i, ct_slice in enumerate(ct_patches[n]):
    print(i)
    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(ct_slice, cmap="gray")
    axes[1].imshow(ct_slice, cmap="gray")
    axes[1].imshow(seg_patches[n][i], cmap=transparent_cmap("r"), alpha=0.5)
    plt.show()

In [None]:
# inference only: no_grad avoids keeping an autograd graph for every sliding window
with torch.no_grad():
    # add batch and channel dims: (z, y, x) -> (1, 1, z, y, x)
    out_raw = inferer(ct_patches[373].unsqueeze(0).unsqueeze(0), seg_model)

In [None]:
# MONAI array transforms (softmax/argmax) act along dim 0, so drop the batch dim
# first; with the batch dim kept they would operate over the size-1 batch axis.
out = postprocess(out_raw[0])

In [None]:
# inspect the post-processed prediction; with AsDiscrete(to_onehot=118) the leading
# dim is expected to be the 118 one-hot channels -- confirm
out.shape

In [None]:
# collapse the one-hot channels (dim 0) to a class-index map, then show its middle
# first-axis slice; vmin/vmax span class indices 0..117
label_map = out.detach().cpu().numpy().argmax(axis=0)
plt.imshow(label_map[label_map.shape[0] // 2], vmin=0, vmax=117, cmap="gist_stern")

In [None]:
# imshow needs a 2-D image, but `out` has a channel axis plus three spatial axes;
# show the middle first-axis slice of the class-index map with default scaling
class_map = out.detach().cpu().numpy().argmax(axis=0)
plt.imshow(class_map[class_map.shape[0] // 2])