In [None]:
# Re-import edited modules (e.g. experiments.utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import glob
import os
from pathlib import Path

import matplotlib.pyplot as plt
import torch

from experiments.model_utils import load_medsam
from experiments.utils import (
    load_process_csv, process_slice, process_bbox_str, segment, split_seg, get_slices_filenames,
)
from experiments.viz_utils import plot_results

# Run MedSAM on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

MEDSAM_CKPT = "../work_dir/MedSAM/medsam_vit_b.pth"
medsam = load_medsam(MEDSAM_CKPT, device)

In [None]:
# Root of the extracted DeepLesion dataset. Set the DEEPLESION_DIR environment
# variable to use another location; the default is the original author's local path.
data_folder = Path(os.environ.get("DEEPLESION_DIR", "/home/liushifeng/Desktop/DeepLesion Dataset/"))
df = load_process_csv(data_folder / "DL_info.csv")

# Extracted image batch folders (names starting with "Images_png_"), skipping the
# .zip archives. Sorted because os.listdir returns entries in arbitrary order.
batch_folders = sorted(
    data_folder / name
    for name in os.listdir(data_folder)
    if name.startswith("Images_png_") and ".zip" not in name
)

In [None]:
# Scratch check: 1-based code -> name pairs, matching the `lesion_types` mapping below (minus -1).
dict(enumerate("bone, abdomen, mediastinum, liver, lung, kidney, soft tissue, pelvis".split(", "), start=1))

In [None]:
# Coarse lesion type code (the `Coarse_lesion_type` column) -> human-readable name.
# Code -1 has no named type and maps to None; codes 1..8 follow the list order.
_coarse_type_names = ['bone', 'abdomen', 'mediastinum', 'liver',
                      'lung', 'kidney', 'soft tissue', 'pelvis']
lesion_types = {-1: None, **dict(enumerate(_coarse_type_names, start=1))}

## Segment 2D slices

In [None]:
# Every folder matching <data_folder>/Images_png_*/Images_png/* (each holds the PNG
# slices of one scan). glob returns entries in arbitrary filesystem order; sort so
# positional indices (see `ids` in the 3D cell) pick the same scans on every machine.
scan_folders = sorted(Path(f) for f in glob.glob(str(data_folder / "Images_png_*/Images_png/*")))

In [None]:
# Demo: segment only the first annotated key slice of the first scan folder.
# Each scan folder holds that scan's PNG slices; `df_scan` holds its annotation rows.
for scan_folder in scan_folders[:1]:
    scan_name = scan_folder.stem
    print("Scan name:", scan_name)
    df_scan = df[df['scan_name'].eq(scan_name)]
    if df_scan.empty:
        continue

    df_slice = df_scan.iloc[0]
    k = df_slice['Key_slice_index']
    key_slice_path = scan_folder / df_slice['file_name']

    # process_slice yields two renderings of the slice; only `abdomen` is used below.
    lung, abdomen = process_slice(key_slice_path, rgb=True)

    bbox = list(map(round, process_bbox_str(df_slice['Bounding_boxes'])))
    seg = segment(abdomen, bbox, medsam)
    segs = split_seg(seg)

    plot_results(abdomen, [bbox], segs, plot=True,
                 save_path=f"outputs/2d segs/{scan_name}_{k}")

## Extend to 3D

In [None]:
from experiments.utils import slice_num, segment_slices
from PIL import Image


def _load_slices(paths):
    """Open and decode each slice image so its file handle is released.

    ``Image.open`` is lazy and keeps the file open until pixel data is read;
    ``Image.load()`` reads the data and, for single-frame files such as PNG,
    closes the file Pillow opened.
    """
    images = []
    for path in paths:
        img = Image.open(path)
        img.load()
        images.append(img)
    return images


# Hand-picked positions in `scan_folders` to process (depends on that list's order).
ids = [0, 2, 3, 6, 20, 29, 36, 40, 45]
# Optional filter on `Coarse_lesion_type` (5 = 'lung' in `lesion_types`).
# Disabled by default, so every annotation row of the selected scans is processed.
lesion_type = 5
filter_by_lesion_type = False

# segment_slices options shared by the upward and downward passes.
plot = True
save = False
window = True

# every scan_folder contains N png slices
for si, scan_folder in enumerate(scan_folders):
    print(si)
    if si not in ids:
        continue
    scan_name = scan_folder.stem
    df_scan = df[df['scan_name'].eq(scan_name)]

    if filter_by_lesion_type and lesion_type not in set(df_scan['Coarse_lesion_type']):
        continue

    # one iteration per annotation row of this scan
    for i in range(len(df_scan)):
        df_slice = df_scan.iloc[i]
        if filter_by_lesion_type and df_slice['Coarse_lesion_type'] != lesion_type:
            continue

        scan = df_slice['scan_name']
        k = df_slice['Key_slice_index']
        slice_files = get_slices_filenames(df_slice['Slice_range'].split(", "))

        # Both selections include the key slice k (>= / <=). Assuming slice_files is
        # in ascending order, each list starts at k and moves away from it, which is
        # what get_slice_stats' normalization by the first row relies on.
        up_paths = [scan_folder / f for f in slice_files if slice_num(f) >= k]
        down_paths = [scan_folder / f for f in slice_files if slice_num(f) <= k][::-1]

        up_slices = _load_slices(up_paths)
        down_slices = _load_slices(down_paths)
        up_indices = [x.stem for x in up_paths]
        down_indices = [x.stem for x in down_paths]

        # Segment outward from the key slice in both directions using the key-slice bbox.
        # Both results are needed: the slice-stats cell further down reads
        # `up_slice_segs` and `down_slice_segs` (from the last row processed here).
        up_slice_segs = segment_slices(
            medsam, up_slices, df_slice['bbox'], scan, up_indices, window, plot=plot, save=save
        )
        down_slice_segs = segment_slices(
            medsam, down_slices, df_slice['bbox'], scan, down_indices, window, plot=plot, save=save
        )

In [None]:
# Path to one ULS23 DeepLesion3D CT volume on the author's external drive
# (not referenced by any other cell in this notebook).
ct_path = (
    "/media/liushifeng/KINGSTON/ULS Jan 2025/part1/ULS23/novel_data/"
    "ULS23_DeepLesion3D/images/003717_02_01_056_lesion_01.nii.gz"
)

In [None]:
import numpy as np
import pandas as pd

In [None]:
def get_slice_stats(slices, slice_segs, conf_threshold=0.1):
    """Per-slice segmentation statistics, normalized to the first slice.

    For each (image, segmentation) pair this computes:
      - ``median_conf``: median of the segmentation values strictly above
        ``conf_threshold``;
      - ``size``: pixel count of the ``mid`` plus ``high`` masks from
        ``split_seg(seg)``;
      - ``intensity``: mean image value over the ``high`` mask.
    Every column is then divided by the first row's value, so row 0 becomes
    all ones (callers pass the key slice first).

    Args:
        slices: sequence of images (anything ``np.array`` accepts, e.g. PIL
            images), aligned element-wise with ``slice_segs``.
        slice_segs: sequence of per-pixel segmentation arrays, one per slice.
        conf_threshold: values above this count toward ``median_conf``.
            Defaults to 0.1, the previously hard-coded value.

    Returns:
        DataFrame with columns ``['median_conf', 'size', 'intensity']`` and one
        row per slice. A statistic over an empty selection is NaN (without
        NumPy's empty-slice RuntimeWarning). Empty input gives an empty frame.
        A zero or NaN in the first row makes that column inf/NaN after
        normalization.

    Raises:
        ValueError: if ``slices`` and ``slice_segs`` differ in length.
    """
    if len(slices) != len(slice_segs):
        raise ValueError(
            f"got {len(slices)} slices but {len(slice_segs)} segmentations"
        )

    columns = ['median_conf', 'size', 'intensity']
    rows = []
    for image, seg in zip(slices, slice_segs):
        img = np.array(image)
        _low, mid, high = split_seg(seg)
        high_mask = high.astype(bool)
        confident = seg[seg > conf_threshold]

        rows.append([
            np.median(confident) if confident.size else np.nan,
            mid.sum() + high.sum(),
            np.mean(img[high_mask]) if high_mask.any() else np.nan,
        ])

    res = pd.DataFrame(rows, columns=columns, dtype=float)
    if res.empty:
        # No first row to normalize against.
        return res
    return res / res.iloc[0]

In [None]:
# Relative slice statistics along the lesion for the last annotation row processed
# by the 3D loop (needs `up_slices`, `up_slice_segs`, `down_slices`, `down_slice_segs`).
res_up = get_slice_stats(up_slices, up_slice_segs)
res_down = get_slice_stats(down_slices, down_slice_segs)

# Both frames start with the key slice. Drop it from the upward pass and reverse
# that pass, so rows run from the far upward end, through the key slice, to the
# far downward end.
res = pd.concat([res_up.iloc[1:].iloc[::-1], res_down], axis=0).reset_index(drop=True)
key_row = max(len(res_up) - 1, 0)  # position of the key slice within `res`

fig, ax = plt.subplots(figsize=(5, 4))
res.plot(ax=ax)
ax.axvline(key_row, color='grey', linestyle='--', linewidth=1, label='key slice')
ax.set(xlabel='slice position (upward end to downward end)',
       ylabel='value / key-slice value',
       title='Slice statistics relative to the key slice')
ax.legend();

In [None]:
import numpy as np

# Centroid (x, y) of the pixels equal to 1 in segs[2], the last mask returned by
# split_seg in the 2D demo cell; (None, None) when no pixel matches.
mask = segs[2]
y, x = np.where(mask == 1)
centroid = (None, None)
if x.size:
    centroid = (np.mean(x), np.mean(y))
print("Centroid:", centroid)

In [None]:
from matplotlib.colors import LinearSegmentedColormap, to_rgba

# Overlay segs[2] from the 2D demo cell on the abdomen rendering of the key slice.
# `transparent_cmap` was never imported or defined in this notebook (NameError).
# Build the colormap here instead: 0 maps to fully transparent red and 1 to opaque
# red, so the CT image stays visible outside the mask.
_red = to_rgba("red")
red_overlay_cmap = LinearSegmentedColormap.from_list(
    "transparent_red", [(*_red[:3], 0.0), (*_red[:3], 1.0)]
)

fig, ax = plt.subplots(figsize=(6, 5))
ax.imshow(abdomen)
overlay = ax.imshow(segs[2], cmap=red_overlay_cmap)
fig.colorbar(overlay, ax=ax)
ax.axis('off');