In [1]:
import sys
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydicom
import torch
from IPython import display
from matplotlib import animation
from sklearn.model_selection import GroupKFold

# Make the stage-1 experiment's training module importable.
sys.path.append("../stage1/exp010")
from train import RSNA2024Stage1LightningModel, RSNA2024Stage1Dataset, RSNA2024Stage1DataModule, EXP_ID

In [4]:
# Series type and CV fold to inspect. The other series type handled by this notebook is "Sagittal_T2-STIR".
series_description = "Sagittal_T1"
fold = 0

ckpt_pattern = f"../../logs/stage1/exp{EXP_ID}/resnetrs50_unet_scse_20x256x256_mixup/{series_description}/fold{fold}/**/best_loss.ckpt"
# glob() order is filesystem-dependent; sort so the same checkpoint is chosen on every run.
ckpt_paths = sorted(glob(ckpt_pattern, recursive=True))
if not ckpt_paths:
    raise FileNotFoundError(f"no checkpoint matches {ckpt_pattern}")
if len(ckpt_paths) > 1:
    print(f"warning: {len(ckpt_paths)} checkpoints match, using {ckpt_paths[0]}")
ckpt_path = ckpt_paths[0]
# Only the wrapped network is needed for inference; run it on CPU in eval mode.
model = RSNA2024Stage1LightningModel.load_from_checkpoint(ckpt_path).model.to("cpu").eval()


In [5]:
# Position in the validation set of the next sample to visualise; the data-loading cell shows it and then advances it.
idx: int = 0

In [None]:
# Load the label tables and attach each study's CV fold.
# NOTE(review): the GroupKFold split below presumably mirrors the one used in training
# (the checkpoint is per-fold) — confirm against train.py.
train_descriptions = pd.read_csv(
    "../../input/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv"
)
train_df = pd.read_csv(
    "../../input/rsna-2024-lumbar-spine-degenerative-classification/train.csv"
)
train_df["fold_id"] = -1
splitter = GroupKFold(n_splits=5)
for i, (train_index, valid_index) in enumerate(
    splitter.split(train_df, np.arange(len(train_df)), train_df.study_id)
):
    train_df.loc[valid_index, "fold_id"] = i

train_coord_df = (
    pd.read_csv(
        "../../input/rsna-2024-lumbar-spine-degenerative-classification/train_label_coordinates.csv"
    )
    .merge(train_descriptions, on=["series_id", "study_id"], how="left")
    .merge(train_df[["study_id", "fold_id"]], on=["study_id"], how="left")
    .sort_values(by=["study_id", "series_id", "level", "condition"])
    .reset_index(drop=True)
)

# Keep only series with a complete label set. The expected count is 5 coordinates when the series'
# first row (after the sort above) is Spinal Canal Stenosis, otherwise 10.
# All per-series arrays below are ordered by series_id, so they line up positionally.
per_series = train_coord_df.groupby("series_id")
coord_label_num = per_series.count().sort_values("series_id").study_id.to_numpy()
first_rows = per_series.head(1).sort_values("series_id")
SCS = first_rows.condition.to_numpy()
series_id = first_rows.series_id
expected_label_num = np.where(SCS == "Spinal Canal Stenosis", 5, 10)
new_series_id = series_id[coord_label_num == expected_label_num]

new_train_coord_df = (
    train_coord_df[train_coord_df.series_id.isin(new_series_id)]
    .sort_values(by=["series_id", "level"])
    .reset_index(drop=True)
)

# Series excluded by hand (presumably bad annotations — the reason is not recorded here).
wrong_series_ids = [
    221289021, 1735851779, 880361156, 1921917205, 2231471633,
    737753815, 1488857550, 3736941525, 1490272456, 3086719329,
    1485193299, 3521409198, 816381378,
]
# CSV spelling of this notebook's series type, e.g. "Sagittal_T2-STIR" -> "Sagittal T2/STIR".
target_description = series_description.replace("_", " ").replace("-", "/")
keep = ~new_train_coord_df.series_id.isin(wrong_series_ids) & (
    new_train_coord_df.series_description == target_description
)
# reset_index() without drop keeps the previous row index as an "index" column.
new_train_coord_df = new_train_coord_df[keep].reset_index()

valid_df = new_train_coord_df[new_train_coord_df.fold_id == fold].reset_index(drop=True)

# Browse the validation set: every run of this cell loads sample `idx`, prints it, then advances it.
dataset = RSNA2024Stage1Dataset(valid_df, mode="valid")
data = dataset[idx]
print(idx)
idx += 1

In [None]:
volume = data["volume"]
series_id = data["series_id"]
study_id = data["study_id"]

# Grey-scale volume replicated into 3 colour channels so colour overlays can be added.
volume_vis = volume[..., None].repeat(1, 1, 1, 3).numpy()
# One RGB colour (0-255) per label channel.
anno_map = np.asarray([
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (0, 127, 127),
    (127, 0, 127),
])
# Colour each label channel: (C, ..., 1) * (C, 1, 1, 3) -> (C, ..., 3), scaled to [0, 1].
label = data["label"].numpy()[..., None] * anno_map[:, None, None] / 255
# Add all coloured channels onto every slice. Clip to the float-RGB range imshow expects
# (imshow would clip anyway, but with a warning on every frame).
volume_vis_gt = np.clip(volume_vis + label.sum(axis=0)[None], 0, 1)

fig = plt.figure(figsize=(3, 3))
im = plt.imshow(volume_vis_gt[0])


def draw(i):
    """Animation callback: show slice `i` of the ground-truth overlay."""
    im.set_array(volume_vis_gt[i])
    return [im]


anim = animation.FuncAnimation(
    fig, draw, frames=volume_vis_gt.shape[0], interval=200, blit=True
)
plt.close(fig)  # keep only the HTML animation, not a static copy of the figure
display.HTML(anim.to_jshtml())

In [None]:
threshold = 0.3  # sigmoid probability below which heatmap pixels are ignored

# One RGB colour (0-255) per output channel.
anno_map = np.asarray([
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (0, 127, 127),
    (127, 0, 127),
])
# Inference only: no autograd graph, so activations are not kept alive and outputs need no detach().
with torch.no_grad():
    output = model(volume[None, ...])
logit = output["logit"].sigmoid()[0]
heatmap = logit * (logit > threshold)

# Keypoint per channel: argmax of the row/column marginals of its 2-D map, normalised to [0, 1).
keypoints = {}
for c, h in enumerate(heatmap):
    if h.max() < threshold:
        keypoints[c] = []
    else:
        y = h.sum(1).argmax()
        x = h.sum(0).argmax()
        keypoints[c] = [float(x) / h.shape[1], float(y) / h.shape[0]]
print(keypoints)

# Min-max normalise over all channels, colour each channel, and add the colours onto every slice.
heatmap = heatmap.numpy()[..., None]
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)
heatmap = heatmap * anno_map[:, None, None] / 255
volume_vis2 = np.clip(volume_vis + heatmap.sum(axis=0)[None], 0, 1)

fig = plt.figure(figsize=(3, 3))
im = plt.imshow(volume_vis2[0])


def draw(i):
    """Animation callback: show slice `i` of the prediction overlay."""
    im.set_array(volume_vis2[i])
    return [im]


anim = animation.FuncAnimation(
    fig, draw, frames=volume_vis2.shape[0], interval=200, blit=True
)
plt.close(fig)  # keep only the HTML animation, not a static copy of the figure
display.HTML(anim.to_jshtml())

In [None]:
# Show what exists on disk under the current study (presumably one directory per series).
study_dir_pattern = f"../../input/rsna-2024-lumbar-spine-degenerative-classification/train_images/**/{study_id}/*"
glob(study_dir_pattern, recursive=True)

In [80]:
def get_dicom_geometry(series_id):
    """Read the slice geometry of one series from its DICOM headers.

    Slices are ordered by descending ImagePositionPatient along the series' stacking
    axis: index 2 (patient z) for "Axial T2", index 0 (patient x) for any other
    description (the sagittal series).

    Args:
        series_id: series identifier (int or numeric string) listed in the
            module-level ``train_descriptions`` table.

    Returns:
        dict with, per slice in the sorted order:
            "positions":    (N, 3) float ImagePositionPatient (mm),
            "orientations": (N, 6) float ImageOrientationPatient direction cosines,
            "img_sizes":    (N, 2) (Rows, Columns),
        plus "pixel_spacing": (2,) float PixelSpacing of the first file, in DICOM
        order (row spacing, column spacing), i.e. (y, x).

    Raises:
        FileNotFoundError: no .dcm files were found for the series.
    """
    series_description = train_descriptions[train_descriptions.series_id == int(series_id)].series_description.iloc[0]
    stack_axis = 2 if series_description == "Axial T2" else 0
    dcm_paths = sorted(glob(f"../../input/rsna-2024-lumbar-spine-degenerative-classification/train_images/**/{series_id}/*.dcm", recursive=True))
    if not dcm_paths:
        raise FileNotFoundError(f"no DICOM files found for series {series_id}")
    # Only header fields are used; skip decoding/loading the pixel data.
    dicoms = [pydicom.dcmread(f, stop_before_pixels=True) for f in dcm_paths]
    positions = np.asarray([d.ImagePositionPatient for d in dicoms], dtype=float)
    # Stable sort over sorted paths: ties (duplicate positions) resolve the same way on every run.
    order = np.argsort(-positions[:, stack_axis], kind="stable")
    orientations = np.asarray([d.ImageOrientationPatient for d in dicoms], dtype=float)[order]
    pixel_spacing = np.asarray(dicoms[0].PixelSpacing, dtype=float)  # (row, col) = (y, x)
    img_sizes = np.asarray([[d.Rows, d.Columns] for d in dicoms])[order]  # (yx)
    return {"positions": positions[order], "orientations": orientations, "pixel_spacing": pixel_spacing, "img_sizes": img_sizes}

In [None]:
sag = get_dicom_geometry(series_id)
# Use the middle sagittal slice as the reference plane; its ImagePositionPatient is the top-left pixel.
mid_slice = len(sag["positions"]) // 2
top_left_hand_corner_sag, img_size = sag["positions"][mid_slice], sag["img_sizes"][mid_slice]
top_left_hand_corner_sag, img_size

In [None]:
# Inspect the sorted per-slice ImagePositionPatient values of the sagittal series.
sag["positions"]

In [83]:
# Patient-space coordinate (mm) of every pixel row / column of the reference sagittal slice.
# DICOM PixelSpacing is (row spacing, column spacing): row spacing is the vertical step (image y),
# column spacing the horizontal step (image x).
# NOTE(review): assumes the usual sagittal orientation (image x along patient +y, image y along
# patient -z) — confirm via ImageOrientationPatient.
sag_row_spacing, sag_col_spacing = sag["pixel_spacing"]
# Build from integer indices: np.arange with a float step can return one element too many or too few.
sag_y_axis_to_pixel_space = top_left_hand_corner_sag[2] - sag_row_spacing * np.arange(img_size[0])
sag_x_axis_to_pixel_space = top_left_hand_corner_sag[1] + sag_col_spacing * np.arange(img_size[1])

In [84]:
# Axial T2 series of the same study; its slices are projected onto the sagittal reference slice.
axial_candidates = train_descriptions[
    (train_descriptions.study_id == int(study_id))
    & (train_descriptions.series_description == "Axial T2")
].series_id
if axial_candidates.empty:
    raise ValueError(f"study {study_id} has no 'Axial T2' series")
# NOTE(review): a study may list several Axial T2 series; only the first listed is used.
axial_series_id = axial_candidates.iloc[0]
axial = get_dicom_geometry(axial_series_id)

In [85]:
def _nearest_index(axis_mm, coords_mm):
    """For each value in `coords_mm`, the index of the closest entry in `axis_mm`."""
    return np.abs(axis_mm[:, None] - coords_mm[None, :]).argmin(axis=0)


# Normalised sagittal-image position of each axial slice's top-left corner:
# its patient z selects the sagittal row, its patient y the sagittal column.
y_start_points = _nearest_index(sag_y_axis_to_pixel_space, axial["positions"][:, 2]) / img_size[0]
x_start_points = _nearest_index(sag_x_axis_to_pixel_space, axial["positions"][:, 1]) / img_size[1]

In [None]:
# Draw, on the middle slice of the model input, the line where each axial slice cuts the sagittal plane.
# Each line starts at the axial slice's top-left corner and runs along the axial column direction
# (ImageOrientationPatient[3:6]) for Rows * row-spacing mm. The length is converted to sagittal native
# pixels and then normalised by the *sagittal* image size (img_size), the same normalisation used
# for the start points.
# NOTE(review): assumes sagittal image x runs along patient +y and image y along patient -z — confirm.
img = volume_vis[len(volume_vis) // 2].copy()
frame_h, frame_w = img.shape[:2]
sag_row_spacing, sag_col_spacing = sag["pixel_spacing"]  # DICOM order: (row, column) spacing
for i, (x, y) in enumerate(zip(x_start_points, y_start_points)):
    start = (int(x * frame_w), int(y * frame_h))
    img = cv2.circle(img, start, 1, (255, 0, 0), -1)
    extent_mm = axial["img_sizes"][i, 0] * axial["pixel_spacing"][0]
    col_dir = axial["orientations"][i][3:6]
    x_end = int(start[0] + extent_mm * col_dir[1] / sag_col_spacing / img_size[1] * frame_w)
    # Image y grows towards -z, hence the minus sign on the z component.
    y_end = int(start[1] - extent_mm * col_dir[2] / sag_row_spacing / img_size[0] * frame_h)
    end = (x_end, y_end)
    img = cv2.circle(img, end, 1, (0, 255, 0), -1)
    img = cv2.line(img, start, end, (0, 0, 255), thickness=1, lineType=cv2.LINE_8, shift=0)
plt.imshow(img);

In [None]:
# One colour per output channel, as plain int tuples (the form cv2 drawing functions accept).
anno_map = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (0, 127, 127),
    (127, 0, 127),
]
# Inference only: no autograd graph is built for the forward pass.
with torch.no_grad():
    output = model(volume[None, ...])
logit = output["logit"].sigmoid()[0]
heatmap = logit * (logit > threshold)

# Keypoint per channel: argmax of the row/column marginals, normalised to [0, 1); [] if below threshold.
keypoints = []
for h in heatmap:
    if h.max() < threshold:
        keypoints.append([])
    else:
        y = h.sum(1).argmax()
        x = h.sum(0).argmax()
        keypoints.append([float(x) / h.shape[1], float(y) / h.shape[0]])
print(keypoints)

volume_vis2 = volume_vis.copy()
frame_h, frame_w = volume_vis2.shape[1:3]

# The axial-plane lines are the same on every frame, so compute their endpoints once.
# Each line runs along the axial column direction (ImageOrientationPatient[3:6]) for Rows * row-spacing mm.
# That length is converted to sagittal native pixels and normalised by the sagittal image size.
# NOTE(review): assumes sagittal image x runs along patient +y and image y along patient -z — confirm.
sag_row_spacing, sag_col_spacing = sag["pixel_spacing"]  # DICOM order: (row, column) spacing
axial_lines = []
for j, (x, y) in enumerate(zip(x_start_points, y_start_points)):
    start = (int(x * frame_w), int(y * frame_h))
    extent_mm = axial["img_sizes"][j, 0] * axial["pixel_spacing"][0]
    col_dir = axial["orientations"][j][3:6]
    x_end = int(start[0] + extent_mm * col_dir[1] / sag_col_spacing / img_size[1] * frame_w)
    y_end = int(start[1] - extent_mm * col_dir[2] / sag_row_spacing / img_size[0] * frame_h)
    axial_lines.append((start, (x_end, y_end)))

fig = plt.figure(figsize=(3, 3))
im = plt.imshow(volume_vis2[0])


def draw(frame_idx):
    """Animation callback: slice `frame_idx` with the axial-plane lines and predicted keypoints."""
    img = volume_vis2[frame_idx].copy()
    for start, end in axial_lines:
        img = cv2.circle(img, start, 1, (255, 0, 0), -1)
        img = cv2.circle(img, end, 1, (0, 255, 0), -1)
        img = cv2.line(img, start, end, (0, 0, 255), thickness=1, lineType=cv2.LINE_8, shift=0)
    for color, k in zip(anno_map, keypoints):
        if len(k) == 0:
            continue
        img = cv2.circle(img, (int(k[0] * frame_w), int(k[1] * frame_h)), 5, color, -1)
    im.set_array(img)
    return [im]


anim = animation.FuncAnimation(
    fig, draw, frames=volume_vis2.shape[0], interval=200, blit=True
)
plt.close(fig)  # keep only the HTML animation, not a static copy of the figure
display.HTML(anim.to_jshtml())