Training a model that turns a 2D bounding box into a 2D segmentation
---

In [None]:
%load_ext autoreload
%autoreload 2
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import os
from evaluation.eval_utils import get_seg_bbox
from pathlib import Path
from tqdm import tqdm

import nrrd
import SimpleITK as sitk
from utils.plot import transparent_cmap

# Dataset: ULS23 DeepLesion3D volumes ("images") and their lesion masks ("labels").
# The root may be overridden with the ULS23_DL3D_DIR environment variable.
data_folder = Path(os.environ.get(
    "ULS23_DL3D_DIR",
    "/media/liushifeng/KINGSTON/ULS Jan 2025/ULS23/novel_data/ULS23_DeepLesion3D",
))
seg_folder = data_folder / "labels"
ct_folder = data_folder / "images"


def _read_name_list(path):
    """Return the whitespace-stripped, non-empty lines of a split file (one lesion name per line)."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


train_names = _read_name_list(data_folder / "train.txt")
val_names = _read_name_list(data_folder / "val.txt")

# os.listdir order is arbitrary; sort so every run iterates lesions in the same order.
# CT volumes are looked up under the same filename as their label volume.
filenames = sorted(x for x in os.listdir(seg_folder) if ".zip" not in x)
seg_paths = {x: seg_folder / x for x in filenames}
ct_paths = {x: ct_folder / x for x in filenames}

In [None]:
# Lesions (filename stems) that the 2D export loop skips.
# str.split() with no argument drops all whitespace, including the blank first/last lines.
ignored_train_samples = """
    000148_04_01_034_lesion_01
    003287_01_01_188_lesion_01
    003931_01_01_078_lesion_01
    000026_06_01_257_lesion_01
    000346_01_01_085_lesion_01
    000215_05_01_096_lesion_01
    001354_04_02_305_lesion_01
    001564_02_02_513_lesion_01
""".split()

In [None]:
# HU bin edges for histograms: -1000 ... 2050 in steps of 50 (plain Python ints).
bins = list(range(-1000, 2051, 50))
hu_counter = Counter()  # counts per HU-bin index

For every lesion (identified by its filename):
- input
    - crop (32X × 32Y)
    - rectangle patch (box) or border of the segment (32X × 32Y)
    - full slice as context (fixed size, e.g. 256 × 256)
- output
    - segmentation (32X × 32Y) — **TODO: blur for a soft mask?**

In [None]:
# Spot-check: load one exported 2D mask (the NRRD header is discarded).
sample_mask_path = "/media/liushifeng/KINGSTON/ULS DL3D 2D dataset/masks/002779_01_01_183_lesion_01_slice61_seg"
seg, _ = nrrd.read(sample_mask_path)

In [None]:
from train_utils import crop_from_img

output_folder = Path("/media/liushifeng/KINGSTON/ULS DL3D 2D dataset/")

# Fill value (Hounsfield units of air) for CT crops that reach past the slice border.
# Defined here because crop_from_img below needs it and nothing else in this notebook defines it.
# NOTE(review): assumes the last argument of crop_from_img is the padding value — confirm in train_utils.
AIR_VALUE = -1000

plot = False  # show every annotated slice and its mask
save = False  # write every annotated CT slice to output_folder / "images" as .npy

if save:
    # np.save does not create missing directories.
    (output_folder / "images").mkdir(parents=True, exist_ok=True)

for filename in seg_paths:
    lesion_name = filename.split(".")[0]
    if lesion_name in ignored_train_samples:
        print("skipped:", lesion_name)
        continue

    # load arrays (axis 0 indexes slices, as used below)
    ct = sitk.ReadImage(str(ct_paths[filename]))
    ct_array = sitk.GetArrayFromImage(ct)
    seg = sitk.ReadImage(str(seg_paths[filename]))
    seg_array = sitk.GetArrayFromImage(seg)

    # indices of slices that contain any segmentation
    seg_slice_indices = np.flatnonzero(seg_array.any(axis=(1, 2)))

    for i in seg_slice_indices:
        ct_slice = ct_array[i]
        seg_slice = seg_array[i]

        if save:
            print(f"saving {lesion_name}_slice{i}")
            # train_or_val = "train" if lesion_name in train_names else "val"
            # save_folder = output_folder / train_or_val
            # save_folder.mkdir(parents=True, exist_ok=True)
            np.save(str(output_folder / "images" / f"{lesion_name}_slice{i}.npy"), ct_slice)
            # nrrd.write(str(output_folder / "slices" / f"{lesion_name}_slice{i}_seg"), seg_slice)

        if plot:
            vmin, vmax = -200, 200
            plt.imshow(ct_slice, vmin=vmin, vmax=vmax, cmap='gray')
            plt.show()
            plt.imshow(seg_slice)
            plt.show()

            # crop areas around the lesion bbox (plus margin) for plotting / HU statistics
            crop_bbox = get_seg_bbox(seg_slice)
            margin = 2
            ct_crop = crop_from_img(ct_slice, crop_bbox, margin, AIR_VALUE)
            seg_crop = crop_from_img(seg_slice, crop_bbox, margin, 0)

            # HU histogram accumulation (disabled); it needs ct_crop/seg_crop from this branch.
            # hu_counter.update(np.digitize(ct_crop.flatten(), bins))
            # hu_counter.update(np.digitize(ct_crop[seg_crop == 1], bins))

In [None]:
# Disabled plotting snippets: a bar chart of the HU-bin counts in `hu_counter`, and an
# overlay of the last processed CT slice with its mask.
# NOTE(review): np.digitize returns len(bins) for values >= bins[-1], so bins[int(x)] would
# raise IndexError for that key; each bar is also placed at the bin's upper edge
# (bins[i] for bins[i-1] <= x < bins[i]). Check both before re-enabling.
# plt.figure(figsize=(6,2.5));
# plt.bar([bins[int(x)] for x in hu_counter.keys()], hu_counter.values(), width=50);
#
# plt.imshow(ct_slice, cmap='gray'); plt.axis("off")
# plt.imshow(seg_slice, cmap='jet', alpha=0.5); plt.axis("off");

## Load data

In [None]:
import os
from torch.utils.data import DataLoader
from dataset import SegmentationDataset
from transforms import train_transform, val_transform, context_transform

# data
data_dir = Path("/media/liushifeng/KINGSTON/ULS DL3D 2D dataset")
image_dir = data_dir / "images"
label_dir = data_dir / "masks"
all_images = sorted(os.listdir(image_dir))

# Slice files are named "<lesion>_slice<k>.npy"; each slice goes to the split of its lesion.
# Sets make each membership test O(1) instead of scanning the name lists.
val_name_set = set(val_names)
train_name_set = set(train_names)
val_images = [x for x in all_images if x.split("_slice")[0] in val_name_set]
train_images = [x for x in all_images if x.split("_slice")[0] in train_name_set]
assert train_images and val_images, "empty train or validation split"
assert not set(train_images) & set(val_images), "train/val splits overlap"

# dataloaders
image_names = None
batch_size = 32
workers = 24
train_dataset = SegmentationDataset(
    image_dir, label_dir, train_images, train_transform, context_transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)

val_dataset = SegmentationDataset(image_dir, label_dir, val_images, val_transform, context_transform)
# Validation data is not shuffled so evaluation order is reproducible.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)

In [None]:
# Sanity-check one validation batch: first image with its mask and box mask overlaid.
data = next(iter(val_loader))
first_image = data['image'][0]

fig, ax = plt.subplots(figsize=(1.5, 1.5))
ax.imshow(first_image.squeeze(), cmap='gray')
ax.imshow(data['mask'][0].squeeze(), cmap=transparent_cmap("b"), alpha=0.5)
ax.imshow(data['box_mask'][0].squeeze(), vmin=0, vmax=1, cmap=transparent_cmap("g"), alpha=0.3)
# ax.axis('off')
print(first_image.shape)
plt.show()

## Train Model

In [None]:
from glob import glob
import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)
import torch
import lightning as L
from lightning.pytorch.callbacks import EarlyStopping
# Take the logger from the same `lightning.pytorch` namespace as the Trainer and callbacks:
# mixing in the legacy `pytorch_lightning` package is unsupported and can cause type mismatches.
from lightning.pytorch.loggers import TensorBoardLogger


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

In [None]:
from model import UNet
from lightning.pytorch.callbacks import ModelCheckpoint

SEED = 42

# Validation order does not need to be random; keep it deterministic.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)

# Seed Python/NumPy/torch (and dataloader workers) so weight init and shuffling are reproducible.
L.seed_everything(SEED, workers=True)

# torch.set_float32_matmul_precision('medium')
unet = UNet()
unet = torch.compile(unet)
early_stopping = EarlyStopping('val_loss', patience=5, verbose=True)
# Keep the best checkpoint by validation loss. Without it, Lightning keeps only the last epoch,
# which after early stopping is `patience` epochs past the best one.
best_checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)
logger = TensorBoardLogger("lightning_logs", name="dl1")  # name="my_model", version="GAT"
trainer = L.Trainer(
    max_epochs=70,
    callbacks=[early_stopping, best_checkpoint],
    logger=logger,
)
trainer.fit(unet, train_loader, val_loader)

In [None]:
# Load the trained UNet from a previous run.
# NOTE(review): the training cell logs to "lightning_logs/dl1", but this loads "dl3/version_1" —
# confirm which run is intended.
CKPT_GLOB = "lightning_logs/dl3/version_1/*/*.ckpt"
ckpt_candidates = sorted(glob(CKPT_GLOB), key=os.path.getmtime)
if not ckpt_candidates:
    raise FileNotFoundError(f"No checkpoint matches {CKPT_GLOB!r}")
ckpt_path = ckpt_candidates[-1]  # glob order is arbitrary; take the most recently written file
print(ckpt_path)
unet = UNet.load_from_checkpoint(ckpt_path)

In [None]:
# Slice files of a few large lesions, for targeted inspection: lesion -> annotated slice range.
_large_lesion_slices = {
    "000083_07_01_114_lesion_01": range(56, 69),
    "000425_01_01_058_lesion_01": range(61, 68),
    "000482_02_01_099_lesion_01": range(62, 75),
    "000879_03_01_054_lesion_01": range(63, 73),
}
large_lesion_names = [
    f"{lesion}_slice{slice_idx}.npy"
    for lesion, slice_range in _large_lesion_slices.items()
    for slice_idx in slice_range
]

In [None]:
# visualize data and model output
from torchmetrics.segmentation import DiceScore

dice = DiceScore(1)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, num_workers=workers)

# Move the network once and switch it to inference mode. eval() stops BatchNorm/Dropout
# layers (if any) from behaving as in training; no_grad() below skips autograd bookkeeping.
net = unet.model.to(device)
net.eval()

res = []
n = 0
with torch.no_grad():
    for i, data in enumerate(val_loader):
        # To inspect only selected slices, e.g.:
        # if data['name'][0] not in large_lesion_names:
        #     continue

        box_area_ratio = data['box_mask'].mean().item()

        # assumes data['input'] is (1, 2, H, W): channel 0 = image, channel 1 = box mask — TODO confirm
        img = data['input'].squeeze()[0]
        box_mask = data['input'].squeeze()[1]
        out_raw = net(data['input'].to(device)).cpu().squeeze()
        out = out_raw > 0.5

        gt = data['mask'].bool().squeeze()
        pred = out.bool().squeeze()

        # get metrics
        tp = gt & pred
        fp = ~gt & pred
        fn = gt & ~pred

        dice_score = dice(pred.unsqueeze(0), gt.unsqueeze(0))
        # NOTE(review): this is 1 - |pred - gt| / gt (relative area error); the usual volume
        # similarity is 1 - |pred - gt| / (pred + gt). Kept as-is so results stay comparable.
        gt_area = gt.sum().item()
        pred_area = pred.sum().item()
        vol_sim = 1 - abs(pred_area - gt_area) / gt_area if gt_area else float("nan")

        res.append({"dice": dice_score.item(), "vs": vol_sim, "area_ratio": box_area_ratio})

        # Plot only slices whose misclassified pixels exceed both 30 px and 30% of the GT area.
        if fp.sum().item() + fn.sum().item() > max(30, 0.3 * gt_area):
            print(f"dice: {dice_score:.3f}, volume similarity: {vol_sim:.3f}")
            print(data['name'])
            fig, ax = plt.subplots(1, 3, figsize=(6, 2))
            ax[0].imshow(img, vmin=0, vmax=1, cmap='gray')

            ax[1].imshow(img, vmin=0, vmax=1, cmap='gray')
            ax[1].imshow(out * tp, cmap=transparent_cmap("green"), alpha=0.6)
            ax[1].imshow(out * fp, cmap=transparent_cmap("red"), alpha=0.6)
            ax[1].imshow(fn, cmap=transparent_cmap("blueviolet"), alpha=0.6)

            ax[2].imshow(out_raw, vmin=0, vmax=1, cmap="summer_r", alpha=1)
            ax[2].imshow(box_mask, cmap=transparent_cmap("cyan"), alpha=0.2)

            for a in ax:
                a.axis('off')
            plt.show()
            n += 1

        # Optional: stop after a handful of plotted failures.
        # if n > 10:
        #     break

In [None]:
import pandas as pd

# Median and mean of the per-slice metrics collected by the evaluation loop.
df = pd.DataFrame(res)
summary = df.describe().round(3)
summary.loc[['50%', 'mean']]

In [None]:
import pandas as pd
import seaborn as sns

In [None]:
# get lesion stats
# Order does not matter for these statistics; iterate deterministically.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=workers)

sizes = []
boxes = []
h_ratios = []
w_ratios = []
area_ratios = []
for i, data in enumerate(val_loader):
    size = data['original_size']
    box = [x.item() for x in data['original_bbox']]

    # mean of the GT mask (the foreground fraction for a binary mask)
    area_ratios.append(data['mask'][0].mean().item())

    # Box extent relative to the box-mask size (previously a hard-coded 128 px).
    # NOTE(review): assumes get_seg_bbox returns x along the last (column) axis — confirm.
    box_mask = data['box_mask'][0]
    mask_h, mask_w = box_mask.shape[-2:]
    x0, y0, x1, y1 = get_seg_bbox(box_mask)
    w_ratios.append((x1 - x0) / mask_w)
    h_ratios.append((y1 - y0) / mask_h)

    sizes.append(size[0].item())
    boxes.append(box)

df = pd.DataFrame(boxes, columns=["xmin", "ymin", "xmax", "ymax"])

df['width'] = df['xmax'] - df['xmin']
df['height'] = df['ymax'] - df['ymin']

df['width_ratio'] = w_ratios
df['height_ratio'] = h_ratios
df['box_area_ratio'] = df['width_ratio'] * df['height_ratio']
df['area_ratio'] = area_ratios

In [None]:
# Joint distribution of the width/height ratio columns; marker size and hue encode box_area_ratio.
sns.jointplot(
    data=df,
    x="width_ratio",
    y="height_ratio",
    hue='box_area_ratio',
    size=df['box_area_ratio'].tolist(),
    legend=None,
    height=5,
);

In [None]:
# NOTE(review): identical to the previous jointplot cell; one copy is enough.
sns.jointplot(
    data=df,
    x="width_ratio",
    y="height_ratio",
    hue='box_area_ratio',
    size=df['box_area_ratio'].tolist(),
    legend=None,
    height=5,
);

In [None]:
# NOTE(review): third identical copy of the width/height-ratio jointplot; consider removing.
sns.jointplot(
    data=df,
    x="width_ratio",
    y="height_ratio",
    hue='box_area_ratio',
    size=df['box_area_ratio'].tolist(),
    legend=None,
    height=5,
);

In [None]:
# Absolute box width vs. height (from original_bbox); marker size encodes the GT mask area ratio.
sns.jointplot(
    data=df,
    x="width",
    y="height",
    size=df['area_ratio'].tolist(),
    height=5,
    # hue='box_area_ratio'
);