Training a 2D bbox to 2D segmentation model
---

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import os
from evaluation.eval_utils import get_seg_bbox
from pathlib import Path
from tqdm import tqdm

import nrrd
import SimpleITK as sitk
from utils.plot import transparent_cmap

In [None]:
# Dataset: ULS23 DeepLesion3D volumes. CT volumes live in images/, lesion masks in labels/,
# and a CT volume and its mask share the same filename.
data_folder = Path("/media/liushifeng/KINGSTON/ULS Jan 2025/ULS23/novel_data/ULS23_DeepLesion3D")
seg_folder = data_folder / "labels"
ct_folder = data_folder / "images"


def _read_names(path):
    """Return the whitespace-stripped, non-empty lines of a split file (one lesion name per line)."""
    return [line.strip() for line in Path(path).read_text().splitlines() if line.strip()]


train_names = _read_names(data_folder / "train.txt")
val_names = _read_names(data_folder / "val.txt")

# sorted(): os.listdir order is filesystem-dependent; a fixed order keeps re-runs reproducible.
filenames = sorted(x for x in os.listdir(seg_folder) if ".zip" not in x)
seg_paths = {x: seg_folder / x for x in filenames}
ct_paths = {x: ct_folder / x for x in filenames}

In [None]:
# Lesions skipped by the slice-extraction loop (matched against the filename without extension).
ignored_train_samples = [
    "000148_04_01_034_lesion_01",
    "003287_01_01_188_lesion_01",
    "003931_01_01_078_lesion_01",
    "000026_06_01_257_lesion_01",
    "000346_01_01_085_lesion_01",
    "000215_05_01_096_lesion_01",
    "001354_04_02_305_lesion_01",
    "001564_02_02_513_lesion_01",
]

In [None]:
# HU bin edges for the histograms: -1000, -950, ..., 2000, 2050 (uniform 50 HU steps).
bins = list(range(-1000, 2050 + 1, 50))
hu_counter = Counter()  # np.digitize bin index -> voxel count

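For every lesion (identified by its filename), build one sample per slice that contains segmentation:
- input
    - crop around the lesion (32X × 32Y)
    - rectangular bounding-box patch, or the border of the segment (32X × 32Y)
    - full slice as context (fixed size, e.g. 256 × 256)
- output
    - segmentation crop (32X × 32Y) — **open question: blur it to get a soft mask?**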

In [None]:
from train_utils import crop_from_img

output_folder = Path("/media/liushifeng/KINGSTON/ULS DL3D 2D dataset/")

plot = 0
save = 0
# Optionally fill `hu_counter` from the lesion crops:
#   None -> skip, "crop" -> every voxel of the crop, "lesion" -> only voxels inside the mask.
hu_source = None

# HU of air. Passed as the 4th argument of crop_from_img for CT crops (masks use 0);
# presumably the fill value for crops that extend past the slice border - TODO confirm.
AIR_VALUE = -1000
margin = 2  # extra pixels around the segmentation bbox when cropping

image_out_dir = output_folder / "images"
if save:
    # np.save does not create missing folders
    image_out_dir.mkdir(parents=True, exist_ok=True)

for filename in list(seg_paths.keys()):
    lesion_name = filename.split(".")[0]
    if lesion_name in ignored_train_samples:
        print("skipped:", lesion_name)
        continue

    # load arrays
    ct = sitk.ReadImage(ct_paths[filename])
    ct_array = sitk.GetArrayFromImage(ct)
    seg = sitk.ReadImage(seg_paths[filename])
    seg_array = sitk.GetArrayFromImage(seg)

    # indices (along axis 0) of slices containing any segmentation
    seg_slice_indices = np.flatnonzero(seg_array.any(axis=(1, 2)))

    for i in seg_slice_indices:
        ct_slice = ct_array[i]
        seg_slice = seg_array[i]

        crop_bbox = get_seg_bbox(seg_slice)

        if save:
            print(f"saving {lesion_name}_slice{i}")
            # NOTE(review): only the CT slice is written here; the training cells read labels
            # from "<data_dir>/masks", which this loop does not produce.
            np.save(str(image_out_dir / f"{lesion_name}_slice{i}.npy"), ct_slice)

        if plot:
            vmin, vmax = -200, 200
            plt.imshow(ct_slice, vmin=vmin, vmax=vmax, cmap='gray')
            plt.show()
            plt.imshow(seg_slice)
            plt.show()

        if plot or hu_source:
            # crop the lesion area (+ margin) from the slice and its mask
            ct_crop = crop_from_img(ct_slice, crop_bbox, margin, AIR_VALUE)
            seg_crop = crop_from_img(seg_slice, crop_bbox, margin, 0)

        if hu_source == "crop":
            hu_counter.update(np.digitize(ct_crop.ravel(), bins))
        elif hu_source == "lesion":
            hu_counter.update(np.digitize(ct_crop[seg_crop == 1], bins))

In [None]:
# if save:
#     # save training data set
#     train_or_val = "train" if lesion_name in train_names else "val"
#     save_folder = output_folder / train_or_val
#     save_folder.mkdir(parents=True, exist_ok=True)
#
#     np.save(save_folder / f"{lesion_name}_slice{i}_crop", np.array(ct_crop, dtype=np.int16))
#     np.save(save_folder / f"{lesion_name}_slice{i}_cropseg", np.array(seg_crop, dtype=np.uint8))
#     np.save(save_folder / f"{lesion_name}_slice{i}_slice", np.array(ct_slice, dtype=np.int16))
#
# if plot:
#     _, axes = plt.subplots(1, 2, figsize=(4, 2))
#
#     # plot the crop and segment
#     vmin, vmax = -200, 200
#     axes[0].imshow(ct_crop, vmin=vmin, vmax=vmax, cmap='gray')
#     axes[1].imshow(ct_crop, vmin=vmin, vmax=vmax, cmap='gray')
#     axes[1].imshow(seg_crop, cmap=transparent_cmap("blue"), alpha=0.4)
#     for ax in axes:
#         ax.axis("off")
#     plt.show()
#
#     # plot the entire slice
#     _, axes = plt.subplots(1, 2, figsize=(10, 5))
#     axes[0].imshow(ct_slice, cmap='gray')
#     axes[1].imshow(ct_slice, cmap='gray')
#     axes[1].imshow(seg_slice, cmap=transparent_cmap("blue"), alpha=0.4)
#     for ax in axes:
#         ax.axis("off")
#     plt.show()

In [None]:
# HU histogram of the voxels counted in `hu_counter` (keys are np.digitize indices into `bins`).
# np.digitize returns i with bins[i-1] <= v < bins[i], so bar i starts at its left edge bins[i-1].
# Index 0 (v < bins[0]) and len(bins) (v >= bins[-1]) are out of range: reported, not drawn.
in_range = {int(i): n for i, n in hu_counter.items() if 0 < int(i) < len(bins)}
n_outside = sum(hu_counter.values()) - sum(in_range.values())
fig, ax = plt.subplots(figsize=(6, 2.5))
ax.bar([bins[i - 1] for i in in_range], list(in_range.values()),
       width=[bins[i] - bins[i - 1] for i in in_range], align="edge")
ax.set(xlabel="HU", ylabel="voxel count",
       title=f"HU histogram ({n_outside} voxels outside [{bins[0]}, {bins[-1]}))");

In [None]:
# HU histogram of the voxels counted in `hu_counter` (keys are np.digitize indices into `bins`).
# np.digitize returns i with bins[i-1] <= v < bins[i], so bar i starts at its left edge bins[i-1].
# Index 0 (v < bins[0]) and len(bins) (v >= bins[-1]) are out of range: reported, not drawn.
in_range = {int(i): n for i, n in hu_counter.items() if 0 < int(i) < len(bins)}
n_outside = sum(hu_counter.values()) - sum(in_range.values())
fig, ax = plt.subplots(figsize=(6, 2.5))
ax.bar([bins[i - 1] for i in in_range], list(in_range.values()),
       width=[bins[i] - bins[i - 1] for i in in_range], align="edge")
ax.set(xlabel="HU", ylabel="voxel count",
       title=f"HU histogram ({n_outside} voxels outside [{bins[0]}, {bins[-1]}))");

In [None]:
# Overlay of the last slice visited by the extraction loop (ct_slice / seg_slice are loop leftovers).
_, ax = plt.subplots()
ax.imshow(ct_slice, cmap='gray')
ax.imshow(seg_slice, cmap='jet', alpha=0.5)
ax.axis("off");

## Load data

In [None]:
# Root of the extracted 2D slice dataset (same folder as `output_folder` in the extraction cell);
# the loading cell below expects images/ and masks/ subfolders inside it.
data_dir = Path("/media/liushifeng/KINGSTON/ULS DL3D 2D dataset")

In [None]:
import os
import numpy as np
from torch.utils.data import DataLoader
from dataset import SegmentationDataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from functools import partial

image_dir = data_dir / "images"
label_dir = data_dir / "masks"
all_images = sorted(os.listdir(image_dir))

# Split slices by lesion using the official train/val lesion lists. Slice files are named
# "<lesion>_slice<i>.npy", so every slice of a lesion lands in the same split.
val_lesions = set(val_names)      # sets: O(1) membership instead of scanning the lists
train_lesions = set(train_names)
assert not (val_lesions & train_lesions), "train.txt and val.txt share lesions"

val_images = [x for x in all_images if x.split("_slice")[0] in val_lesions]
train_images = [x for x in all_images if x.split("_slice")[0] in train_lesions]

In [None]:
# transform functions
def clip_image(image, min_val, max_val, **kwargs):
    """Clamp intensities of ``image`` into the closed range [min_val, max_val].

    Any extra keyword arguments (e.g. those supplied by an ``A.Lambda`` wrapper) are ignored.
    """
    lower, upper = min_val, max_val
    clipped = np.clip(image, lower, upper)
    return clipped

def rescale_image(image, min_val, max_val, **kwargs):
    """Linearly map intensities so that ``min_val`` -> 0 and ``max_val`` -> 1.

    Args:
        image: array of intensities (e.g. an int16 CT slice).
        min_val, max_val: window bounds mapped to 0 and 1 respectively.
        **kwargs: ignored (accepted so the function can be used inside ``A.Lambda``).

    Returns:
        float32 array with the same shape as ``image``.

    Raises:
        ValueError: if ``max_val == min_val`` (the mapping is undefined).
    """
    if max_val == min_val:
        raise ValueError(f"empty intensity window: min_val == max_val == {min_val}")
    # Convert before subtracting: on integer input (int16 CT) `image - min_val` can silently
    # wrap around. float32 is also the float dtype albumentations/PyTorch pipelines expect.
    image = np.asarray(image, dtype=np.float32)
    return (image - min_val) / (max_val - min_val)

def pass_through(mask, **kwargs):
    """Identity transform: hand ``mask`` back untouched (extra keyword arguments are ignored)."""
    del kwargs  # accepted only to match the A.Lambda callback signature
    return mask

# HU window (lower, upper): images are clipped to it and then rescaled to [0, 1].
v = (-500, 1000)
window_lo, window_hi = v
clip = partial(clip_image, min_val=window_lo, max_val=window_hi)
rescale = partial(rescale_image, min_val=window_lo, max_val=window_hi)

def _hu_window_transforms():
    """Fresh clip-then-rescale Lambda transforms (image is windowed; mask passes through)."""
    return [
        A.Lambda(image=clip, mask=pass_through),
        A.Lambda(image=rescale, mask=pass_through),
    ]


# Validation: intensity windowing only.
val_transform = A.Compose(_hu_window_transforms() + [ToTensorV2()])

# Training: intensity windowing, then pixel-level and spatial augmentations.
train_transform = A.Compose(
    _hu_window_transforms()
    + [
        # pixel-level
        A.GaussianBlur(blur_limit=5, p=0.5),
        A.GaussNoise(std_range=(0.01, 0.05), per_channel=False, p=0.5),
        # spatial
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Transpose(p=0.5),
        A.OpticalDistortion(distort_limit=(-0.5, -0.1), p=0.5),
        A.GridDistortion(p=0.5),  # NOTE(review): seems to add black edges - check its border_mode
        ToTensorV2(),
    ]
)

batch_size = 1
workers = 1

# Fail early with a clear message instead of a StopIteration / empty loader later on.
assert train_images and val_images, "empty train or val split - check data_dir and the split lists"

train_dataset = SegmentationDataset(image_dir, label_dir, train_images, train_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

val_dataset = SegmentationDataset(image_dir, label_dir, val_images, val_transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)

# Sanity check: first validation sample with its mask (blue) and box mask (green).
# transparent_cmap keeps background pixels see-through, so the CT stays visible.
data = next(iter(val_loader))
fig, ax = plt.subplots(figsize=(3, 3))
ax.imshow(data['image'].squeeze(), cmap='gray')
ax.imshow(data['mask'].squeeze(), cmap=transparent_cmap("blue"), alpha=0.5)
ax.imshow(data['box_mask'].squeeze(), cmap=transparent_cmap("green"), alpha=0.3)
ax.set_title("val sample: mask (blue), box (green)")
plt.show()

In [None]:
import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)
import torch
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping
import segmentation_models_pytorch as smp

# 'medium' lets float32 matmuls use faster reduced-precision internal kernels where available.
torch.set_float32_matmul_precision('medium')
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# Total number of lesions listed in train.txt and val.txt.
sum(len(names) for names in (train_names, val_names))

In [None]:
from model import UNet

# Seed model init, data shuffling and augmentations (incl. dataloader workers) for reproducibility.
L.seed_everything(42, workers=True)

unet = UNet()
early_stopping = EarlyStopping('val_loss', verbose=True)

# Number of validation batches per epoch (int = #batches, float = fraction of the val set).
# Previously 1: with batch_size=1 and an unshuffled val_loader, the val_loss driving early
# stopping came from a single fixed slice. Lower this only if validation is too slow.
val_batches = 1.0

trainer = L.Trainer(
    accumulate_grad_batches=32,  # one optimizer step per 32 * batch_size samples
    max_epochs=100,
    limit_val_batches=val_batches,
    callbacks=[early_stopping],
)
trainer.fit(unet, train_loader, val_loader)

In [None]:
# Mean of the current batch's mask (foreground fraction for a 0/1 mask). Cast to float first:
# torch.mean raises on integer tensors.
data['mask'].float().mean()

In [None]:
# `fp` (false-positive mask) is created by the model-visualization cell further down.
# Guard it so this cell does not raise NameError on Restart & Run All, and show the
# pixel count instead of dumping the whole tensor.
fp.sum() if "fp" in globals() else "run the model-visualization cell first"

In [None]:
# Visualize one validation prediction over the first input channel:
# true positives (green), false positives (red), false negatives (blue).
data = next(iter(val_loader))
net = unet.model.to(device)
was_training = net.training
net.eval()  # so any Dropout/BatchNorm layers behave as at inference time
with torch.no_grad():  # no autograd graph needed for plotting
    raw_out = net(data['input'].to(device))
net.train(was_training)  # restore the previous mode

# NOTE(review): 0.5 assumes the model outputs probabilities; if it returns logits
# (e.g. smp models with activation=None) the equivalent threshold is 0.0.
out = raw_out.cpu().squeeze() > 0.5

mask = data['mask'].bool().squeeze()
tp = mask & out
fp = ~mask & out
fn = mask & ~out

fig, ax = plt.subplots(figsize=(3, 3))
ax.imshow(data['input'].squeeze()[0], cmap='gray')
ax.imshow(tp, cmap=transparent_cmap("limegreen"), alpha=0.5)
ax.imshow(fp, cmap=transparent_cmap("red"), alpha=0.5)
ax.imshow(fn, cmap=transparent_cmap("blue"), alpha=0.5)
ax.axis("off");