In [1]:
from train import RSNA2024Stage1LightningModel, RSNA2024Stage1Dataset, RSNA2024Stage1DataModule, EXP_ID
import pandas as pd
from sklearn.model_selection import GroupKFold
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython import display
from glob import glob
import cv2
import torch

In [2]:
# Which series type and which cross-validation fold to inspect.
series_description = "Sagittal_T2-STIR"
# series_description = "Sagittal_T1"
fold = 0

# Run directories (one per backbone) whose checkpoints are loaded; the
# inference cells further down average the predictions of all loaded models.
ARCHITECTURES = [
    "resnetrs50_unet_scse_20x256x256_mixup",
    "caformer_s18_unet_scse_20x256x256_mixup",
    "swinv2_tiny_unet_scse_20x256x256_mixup",
    "convnext_tiny_unet_scse_20x256x256_mixup",
]


def find_best_ckpt(architecture):
    """Return the path of the ``best_loss.ckpt`` file for one architecture.

    The search is recursive below the run directory selected by ``EXP_ID``,
    ``series_description`` and ``fold`` (read from the notebook namespace at
    call time).

    Args:
        architecture: Name of the run directory, e.g. one of ``ARCHITECTURES``.

    Returns:
        The lexicographically first matching checkpoint path. ``glob`` returns
        matches in arbitrary order, so the result is sorted to make the choice
        reproducible when several checkpoints exist.

    Raises:
        FileNotFoundError: If no checkpoint matches, instead of the opaque
            ``IndexError`` a bare ``glob(...)[0]`` would raise.
    """
    pattern = (
        f"../../../logs/stage1/exp{EXP_ID}/{architecture}/"
        f"{series_description}/fold{fold}/**/best_loss.ckpt"
    )
    matches = sorted(glob(pattern, recursive=True))
    if not matches:
        raise FileNotFoundError(f"No checkpoint found for pattern {pattern!r}")
    return matches[0]


ckpt_paths = [find_best_ckpt(architecture) for architecture in ARCHITECTURES]

# Use the EMA copy of each network, moved to CPU and switched to eval mode.
models = [
    RSNA2024Stage1LightningModel.load_from_checkpoint(ckpt_path)
    .model_ema.module.to("cpu")
    .eval()
    for ckpt_path in ckpt_paths
]

In [3]:
# Position in the validation dataset of the sample to visualise. The data
# loading cell below prints this value and then increments it, so re-running
# that cell steps to the next sample; re-run this cell to reset the position.
idx = 14

In [None]:
# Build the per-series key-point table, split it by fold and fetch one
# validation sample (`data`) for the visualisation cells below.

# Attach the series description (e.g. "Sagittal T2/STIR") to every annotated
# coordinate row.
train_descriptions = pd.read_csv(
    "../../../input/rsna-2024-lumbar-spine-degenerative-classification/train_series_descriptions.csv"
)
train_coord_df = pd.read_csv(
    "../../../input/rsna-2024-lumbar-spine-degenerative-classification/train_label_coordinates.csv"
)
train_coord_df = train_coord_df.merge(
    train_descriptions, on=["series_id", "study_id"], how="left"
)

# Assign every row of train.csv to one of 5 folds, grouped by study_id so all
# rows of a study share a fold. `fold_id` is the fold in which the row is part
# of the validation split.
train_df = pd.read_csv(
    "../../../input/rsna-2024-lumbar-spine-degenerative-classification/train.csv"
)
train_df["fold_id"] = -1
for i, (train_index, valid_index) in enumerate(
    GroupKFold(n_splits=5).split(
        train_df, np.arange(len(train_df)), train_df.study_id
    )
):
    # valid_index is positional; this relies on the default RangeIndex that
    # read_csv produces.
    train_df.loc[valid_index, "fold_id"] = i
train_coord_df = train_coord_df.merge(
    train_df.loc[:, ["study_id", "fold_id"]], on=["study_id"], how="left"
)
train_coord_df = train_coord_df.sort_values(
    by=["study_id", "series_id", "level", "condition"]
).reset_index(drop=True)

# The three objects below are all ordered by series_id so they can be combined
# positionally in the boolean mask that follows.
# Number of (non-null study_id) coordinate rows per series.
coord_label_num = (
    train_coord_df.groupby("series_id")
    .count()
    .sort_values("series_id")
    .study_id.to_numpy()
)
# Condition of the first row of each series (in the sort order set above).
SCS = (
    train_coord_df.groupby("series_id")
    .head(1)
    .sort_values("series_id")
    .condition.to_numpy()
)
series_id = (
    train_coord_df.groupby("series_id").head(1).sort_values("series_id").series_id
)
# Keep a series only if it has exactly 5 rows when its first condition is
# "Spinal Canal Stenosis", or exactly 10 rows otherwise.
# NOTE(review): presumably 5 levels, times two sides for the other conditions,
# i.e. only fully annotated series survive — confirm.
new_series_id = series_id[
    ((SCS == "Spinal Canal Stenosis") & (coord_label_num == 5))
    | ((SCS != "Spinal Canal Stenosis") & (coord_label_num == 10))
]
new_train_coord_df = (
    train_coord_df[train_coord_df.series_id.isin(new_series_id)]
    .sort_values(by=["series_id", "level"])
    .reset_index(drop=True)
)
# Series that are excluded explicitly; the reason is not recorded in this
# notebook (the name suggests known-bad annotations — TODO confirm).
wrong_series_ids = [
    221289021,
    1735851779,
    880361156,
    1921917205,
    2231471633,
    737753815,
    1488857550,
    3736941525,
    1490272456,
    3086719329,
    1485193299,
    3521409198,
    816381378,
]

new_train_coord_df =  new_train_coord_df[~new_train_coord_df.series_id.isin(wrong_series_ids)]

# Externally corrected key points. For "Spinal Canal Stenosis" only the rows
# with side == "R" are kept; rows of every other condition are kept unchanged.
# NOTE(review): presumably this avoids duplicating stenosis rows in the merge
# below — confirm.
clean_keypoints = pd.read_csv(
    "../../../input/lumbar-coordinate-pretraining-dataset/coords_rsna_improved.csv",
    index_col=0,
)
clean_keypoints = pd.concat(
    [
        clean_keypoints[
            (clean_keypoints.condition == "Spinal Canal Stenosis")
            & (clean_keypoints.side == "R")
        ],
        clean_keypoints[~(clean_keypoints.condition == "Spinal Canal Stenosis")],
    ]
)
# Default (inner) merge: rows without a matching cleaned key point are dropped,
# the rest gain the relative_x / relative_y columns.
new_train_coord_df = new_train_coord_df.merge(
    clean_keypoints.loc[
        :, ["series_id", "level", "condition", "relative_x", "relative_y"]
    ],
    on=["series_id", "level", "condition"],
)

# Restrict to the selected series type. The notebook-style name is converted
# to the CSV spelling, e.g. "Sagittal_T2-STIR" -> "Sagittal T2/STIR".
# reset_index() is called without drop=True, so the previous index is kept as
# an extra "index" column.
new_train_coord_df = new_train_coord_df[
    new_train_coord_df.series_description
    == series_description.replace("_", " ").replace("-", "/")
].reset_index()

# Split by fold. Note that `train_df` is re-bound here: it no longer holds the
# contents of train.csv but the training part of the key-point table.
train_df = new_train_coord_df[new_train_coord_df.fold_id != fold].reset_index(
    drop=True
)
valid_df = new_train_coord_df[new_train_coord_df.fold_id == fold].reset_index(
    drop=True
)

# Fetch the sample at position `idx`, then advance `idx` so that re-running
# this cell moves on to the next validation sample.
dataset = RSNA2024Stage1Dataset(valid_df, mode="valid")
data = dataset[idx]
print(idx)
idx += 1

In [None]:
# Threshold applied to the predicted heatmaps in the prediction cells below.
threshold = 0.3

volume = data["volume"]
# Append a trailing axis and repeat it three times so that every slice can be
# shown as an RGB image (and coloured overlays can be added later).
volume_vis = volume.unsqueeze(-1).repeat(1, 1, 1, 3).numpy()

fig, ax = plt.subplots(figsize=(3, 3))
im = ax.imshow(volume_vis[0])


def draw(frame_idx):
    """Show slice ``frame_idx`` of the raw volume in the existing image artist."""
    im.set_array(volume_vis[frame_idx])
    return [im]


n_frames = len(volume_vis)
anim = animation.FuncAnimation(
    fig, draw, frames=n_frames, interval=200, blit=True
)
# Close the static figure so that only the animation is rendered in the output.
plt.close(fig)
display.HTML(anim.to_jshtml())

In [None]:
# Ground-truth overlay: colour each label channel and add it to every slice.

# One RGB colour per label channel (row k colours channel k).
anno_map = np.asarray([
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (0, 127, 127),
    (127, 0, 127),
])
label = data["label"]
# Replicate each channel map into three identical colour planes, then tint it
# with its colour scaled to the 0-1 range.
label_rgb = label[..., None].repeat(1, 1, 1, 3).numpy()
label_rgb = label_rgb * anno_map[:, None, None] / 255
# Sum the tinted channel maps into one overlay and add it to every slice.
# Float RGB images are only valid in [0, 1]; clip explicitly instead of
# relying on imshow's implicit clipping (which also emits a warning per frame).
volume_vis_gt = np.clip(volume_vis + label_rgb.sum(axis=0)[None], 0.0, 1.0)

fig = plt.figure(figsize=(3, 3))
im = plt.imshow(volume_vis_gt[0])


def draw(i):
    """Show slice ``i`` of the volume with the ground-truth overlay."""
    im.set_array(volume_vis_gt[i])
    return [im]


anim = animation.FuncAnimation(
    fig, draw, frames=volume_vis_gt.shape[0], interval=200, blit=True
)
# Close the static figure so that only the animation is rendered in the output.
plt.close()
display.HTML(anim.to_jshtml())

In [None]:
# Predicted heatmap overlay: average the sigmoid outputs of all models,
# threshold, normalise, colour per channel and add to every slice.

# One RGB colour per output channel (row k colours channel k).
anno_map = np.asarray([
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (0, 127, 127),
    (127, 0, 127),
])

# Pure inference: disable autograd so no graph is built for the forward passes.
# To try a cropped field of view, slice the model input and `volume_vis`
# identically (e.g. `[:, :200]`).
with torch.no_grad():
    outputs = [model(volume[None, ...]) for model in models]
    # Ensemble by averaging the per-model probabilities of the single sample.
    logit = torch.stack(
        [output["logit"].sigmoid()[0] for output in outputs]
    ).mean(0)

heatmap = logit.detach()
heatmap = heatmap[..., None].repeat(1, 1, 1, 3).numpy()
# Zero out everything at or below the threshold, then rescale to [0, 1];
# the epsilon guards against division by zero when the map is constant.
heatmap = heatmap * (heatmap > threshold)
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)
heatmap = heatmap * anno_map[:, None, None] / 255
# Sum the tinted channel maps into one overlay and add it to every slice.
# Float RGB images are only valid in [0, 1]; clip explicitly instead of
# relying on imshow's implicit clipping (which also emits a warning per frame).
volume_vis2 = np.clip(volume_vis + heatmap.sum(axis=0)[None], 0.0, 1.0)

fig = plt.figure(figsize=(3, 3))
im = plt.imshow(volume_vis2[0])


def draw(i):
    """Show slice ``i`` of the volume with the predicted heatmap overlay."""
    im.set_array(volume_vis2[i])
    return [im]


anim = animation.FuncAnimation(
    fig, draw, frames=volume_vis2.shape[0], interval=200, blit=True
)
# Close the static figure so that only the animation is rendered in the output.
plt.close()
display.HTML(anim.to_jshtml())

In [None]:
# Predicted key-point overlay: reduce each channel of the averaged heatmap to
# a single (x, y) location and mark it with a filled circle on every slice.

# One RGB colour per output channel (entry k colours channel k).
anno_map = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (0, 127, 127),
    (127, 0, 127),
]

# Pure inference: disable autograd so no graph is built for the forward passes.
# To try a cropped field of view, slice the model input and `volume_vis`
# identically (e.g. `[:, :200]`).
with torch.no_grad():
    outputs = [model(volume[None, ...]) for model in models]
    # Ensemble by averaging the per-model probabilities of the single sample.
    logit = torch.stack(
        [output["logit"].sigmoid()[0] for output in outputs]
    ).mean(0)

heatmap = logit.detach()
# Zero out everything at or below the threshold.
heatmap = heatmap * (heatmap > threshold)

# The key-point locations do not depend on the slice, so compute them once
# instead of once per animation frame. Each location is the argmax of the row
# sums (y) and of the column sums (x), not the position of the global peak.
keypoints = []
for channel, channel_map in enumerate(heatmap):
    # Nothing survived thresholding for this channel: draw no marker.
    if channel_map.max() < threshold:
        continue
    y = int(channel_map.sum(1).argmax())
    x = int(channel_map.sum(0).argmax())
    keypoints.append((x, y, anno_map[channel]))

# cv2 draws on uint8 images, so convert the 0-1 float frames to 0-255.
volume_vis3 = (volume_vis * 255).astype(np.uint8)
fig = plt.figure(figsize=(3, 3))
im = plt.imshow(volume_vis3[0])


def draw(i):
    """Mark every predicted key point on slice ``i`` and show that slice."""
    for x, y, color in keypoints:
        # Radius 5, thickness -1 (filled). Re-drawing is harmless if the
        # animation renders a frame more than once.
        volume_vis3[i] = cv2.circle(volume_vis3[i], (x, y), 5, color, -1)
    im.set_array(volume_vis3[i])
    return [im]


anim = animation.FuncAnimation(
    fig, draw, frames=volume_vis3.shape[0], interval=200, blit=True
)
# Close the static figure so that only the animation is rendered in the output.
plt.close()
display.HTML(anim.to_jshtml())

In [9]:
# series_description = "Sagittal_T2-STIR"
# fold = 0
# 14, 72