<p align="center">
    <a href="https://predict-idlab.github.io/landmarker">
        <img alt="landmarker" src="https://raw.githubusercontent.com/predict-idlab/landmarker/main/docs/_static/images/logo.svg" width="66%">
    </a>
</p>

# Training and Evaluating a One-hot Encoded (Mask) Regression Model for Landmark Localization on UWSpineCT (3D)

In this tutorial, we will train and evaluate a one-hot encoded (mask) regression model for landmark
localization on UWSpineCT. The UWSpineCT dataset consists of XX annotated CT volumes of the spine.
The CT volumes are transformed to a uniform scale of 128 × 128 × 64.

We will go through the following steps:
* [Loading the dataset](#Loading-the-dataset)
* [Inspecting the dataset](#Inspecting-the-dataset)
* [Training and initializing the UNet model](#Training-and-initializing-the-Unet-model)
* [Evaluating the model](#Evaluating-the-model)

<a target="_blank" href="https://colab.research.google.com/github/predict-idlab/landmarker/blob/main/examples/3D-example-UWSpineCT-maskdataset.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Setup environment

In [1]:
# !python -c "import landmarker" || pip install landmarker

import sys
import os

sys.path.append("../src/")
import landmarker

## Setup imports and variables

In [2]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

## Loading the dataset

### Short description of the data and dataset module
The [landmarker](https://github.com/predict-idlab/landmarker) package has several built-in
datasets in the `landmarker.datasets` module, as well as utility classes for building your own
datasets in the `landmarker.data` module. There are three types of datasets: `LandmarkDataset`,
`HeatmapDataset`, and `MaskDataset`. A `LandmarkDataset` contains images with landmarks, a
`HeatmapDataset` contains images with heatmaps, and a `MaskDataset` contains images with masks
(i.e., binary segmentation masks indicating the location of the landmarks). `HeatmapDataset` and
`MaskDataset` both inherit from `LandmarkDataset` and thus also contain information about the
landmarks. A `MaskDataset` can be constructed from image-landmark pairs or from image-mask pairs,
since the latter is often how such data is distributed. A `HeatmapDataset` can be constructed from
image-landmark pairs.
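To make the mask representation concrete, here is a minimal NumPy sketch that turns landmarks into per-class binary masks. The helper names `ball_mask` and `landmarks_to_masks`, the rounding to the nearest voxel, and the optional radius are our own assumptions for illustration; `MaskDataset` may construct its masks differently. With `radius=0` each channel contains exactly one voxel (a strictly one-hot mask), and missing (NaN) landmarks give empty channels.

```python
import numpy as np

def ball_mask(shape, center, radius=0.0):
    """Hypothetical helper: 1 inside a ball of `radius` voxels around `center`, else 0."""
    center = np.rint(np.asarray(center, dtype=float))  # snap to the nearest voxel
    grid = np.indices(shape)  # (3, D0, D1, D2) grid of voxel indices
    dist2 = ((grid - center.reshape(3, 1, 1, 1)) ** 2).sum(axis=0)
    return (dist2 <= radius ** 2).astype(np.uint8)

def landmarks_to_masks(shape, landmarks, radius=0.0):
    """Hypothetical helper: (C, 3) landmarks -> (C, *shape) binary masks."""
    masks = np.zeros((len(landmarks), *shape), dtype=np.uint8)
    for c, lm in enumerate(landmarks):
        if not np.isnan(lm).any():  # a NaN landmark leaves its channel empty
            masks[c] = ball_mask(shape, lm, radius)
    return masks

landmarks = np.array([[1.4, 2.6, 3.0], [np.nan, np.nan, np.nan], [4.0, 4.0, 4.0]])
masks = landmarks_to_masks((8, 8, 8), landmarks, radius=0.0)
print(masks.reshape(3, -1).sum(axis=1))  # number of marked voxels per channel
```

A larger radius marks a small ball instead of a single voxel, which gives the model more positive voxels per landmark at the cost of a less precise target.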

Images can be provided as a list of paths to stored images, or as a NumPy array, a torch tensor,
a list of NumPy arrays, or a list of torch tensors. Landmarks can be provided as NumPy arrays or
torch tensors, in one of three shapes: (1) (N, D), where N is the number of samples and D is the
number of spatial dimensions; (2) (N, C, D), where C is the number of landmark classes; or
(3) (N, C, I, D), where I is the number of instances per landmark class. If fewer than I instances
are provided for a class, the remaining instances are filled with NaNs.
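As an illustration of the third shape, the sketch below packs a variable number of instances per class into a NaN-padded (N, C, I, D) array with NumPy. The function name `to_instance_array` and its dict-based input format are hypothetical and not part of landmarker. This notebook itself uses the second shape, (N, C, 3), with NaN rows for vertebrae that are not annotated.

```python
import numpy as np

def to_instance_array(samples, n_classes, max_instances, dims=3):
    """Hypothetical helper: pack per-sample {class_index: [coords, ...]} dicts
    into an (N, C, I, D) array, padding missing instances with NaN."""
    out = np.full((len(samples), n_classes, max_instances, dims), np.nan, dtype=np.float32)
    for n, per_class in enumerate(samples):
        for c, instances in per_class.items():
            # Keep at most `max_instances` instances of this class
            inst = np.asarray(instances, dtype=np.float32).reshape(-1, dims)[:max_instances]
            out[n, c, : len(inst)] = inst
    return out

samples = [
    {0: [[10.0, 20.0, 5.0]], 1: [[12.0, 22.0, 6.0], [14.0, 24.0, 7.0]]},
    {1: [[11.0, 21.0, 5.5]]},  # class 0 is not annotated in this sample
]
arr = to_instance_array(samples, n_classes=2, max_instances=2)
print(arr.shape)
```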

In [4]:
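from monai.transforms import Compose, ScaleIntensityd

fn_keys = ('image', 'mask')
# No spatial augmentations are used here; MONAI dictionary transforms that should act on
# both the image and the mask (e.g. with keys=fn_keys) can be added to this list.
spatial_transformd = []

# Training: rescale image intensities, then apply any spatial transforms
train_transformd = Compose([
    ScaleIntensityd(('image', )),  # Scale intensity
] + spatial_transformd)

# Validation/test: intensity scaling only
inference_transformd = Compose([
    ScaleIntensityd(('image', )),
])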
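`ScaleIntensityd` rescales the intensities of each image to a fixed range; to our understanding its defaults are `minv=0.0` and `maxv=1.0`, i.e. a per-volume min-max normalization. The NumPy sketch below reproduces that idea with a hypothetical `min_max_scale` helper. It is not MONAI's implementation, and edge cases (such as constant volumes) may be handled differently there.

```python
import numpy as np

def min_max_scale(volume, minv=0.0, maxv=1.0):
    """Hypothetical helper: linearly map the volume's min to minv and its max to maxv."""
    volume = np.asarray(volume, dtype=np.float32)
    vmin, vmax = float(volume.min()), float(volume.max())
    if vmax == vmin:  # constant volume: avoid division by zero
        return np.full_like(volume, minv)
    norm = (volume - vmin) / (vmax - vmin)  # now in [0, 1]
    return (norm * (maxv - minv) + minv).astype(np.float32)

ct = np.array([-1000.0, 0.0, 1000.0, 3000.0])  # e.g. Hounsfield units
print(min_max_scale(ct))  # values: 0.0, 0.25, 0.5, 1.0
```

Because the scaling is per volume, the same Hounsfield value can map to different intensities in different scans; a fixed HU window is a common alternative for CT.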

In [None]:
from glob import glob

import pandas as pd

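# Root of the local UWSpineCT copy; adjust this path to your own setup
path_data = "/Users/jefjonkers/Data/landmark-datasets/UWSpineCT"

# Warning: ensure that the lists of paths for volumes, landmarks, and pixel spacings are aligned
# (same length and order), since the i-th landmark file is matched with the i-th volume.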

# volume paths
volume_paths_train_1 = sorted(glob(f"{path_data}/spine-1/*/*/*.nii.gz"))
volume_paths_train_2 = sorted(glob(f"{path_data}/spine-2/*/*/*.nii.gz"))
volume_paths_train_3 = sorted(glob(f"{path_data}/spine-3/*/*/*.nii.gz"))
volume_paths_train_4 = sorted(glob(f"{path_data}/spine-4/*/*/*.nii.gz"))
volume_paths_train_5 = sorted(glob(f"{path_data}/spine-5/*/*/*.nii.gz"))
volume_paths_test = sorted(glob(f"{path_data}/spine-test-data/*.nii.gz"))
volume_paths_train = volume_paths_train_1 + volume_paths_train_2 + volume_paths_train_3 + volume_paths_train_4
volume_paths_val = volume_paths_train_5

# landmark paths for each split (parsed into one DataFrame per split)
landmark_paths_train_1 = sorted(glob(f"{path_data}/spine-1/*/*/*.lml"))
landmark_paths_train_2 = sorted(glob(f"{path_data}/spine-2/*/*/*.lml"))
landmark_paths_train_3 = sorted(glob(f"{path_data}/spine-3/*/*/*.lml"))
landmark_paths_train_4 = sorted(glob(f"{path_data}/spine-4/*/*/*.lml"))
landmark_paths_train_5 = sorted(glob(f"{path_data}/spine-5/*/*/*.lml"))
landmark_paths_test = sorted(glob(f"{path_data}/spine-test-data/*.lml"))
landmark_paths_train = landmark_paths_train_1 + landmark_paths_train_2 + landmark_paths_train_3 + landmark_paths_train_4
landmark_paths_val = landmark_paths_train_5

def read_lml_files(paths, prefix):
    """Parse whitespace-separated .lml landmark files into a single DataFrame.

    The first line of each file is skipped; the remaining columns are ID, Label, Y, X and Z.
    Each file gets a `file_id` such as "train_000".
    """
    dfs = []
    for i, path in enumerate(paths):
        df = pd.read_csv(path, sep=r"\s+", skiprows=1, header=None)
        df.columns = ["ID", "Label", "Y", "X", "Z"]
        df["file_id"] = f"{prefix}_{i:03d}"
        df["source_file"] = path   # keep track of origin file
        dfs.append(df)
    return pd.concat(dfs)


df_landmarks_train = read_lml_files(landmark_paths_train, "train")
df_landmarks_val = read_lml_files(landmark_paths_val, "val")
df_landmarks_test = read_lml_files(landmark_paths_test, "test")


In [22]:
import nibabel as nib

class_names = df_landmarks_train["Label"].unique().tolist()


def landmarks_to_voxels(df_landmarks, volume_paths, class_names):
    """Return an (N, C, 3) landmark array in voxel units and an (N, 3) spacing array.

    The i-th file_id is matched with volume_paths[i], so both must be in the same order.
    Landmark classes that are not annotated in a volume are left as NaN.
    """
    file_ids = df_landmarks["file_id"].unique()
    landmarks = np.full((len(file_ids), len(class_names), 3), np.nan, dtype=np.float32)
    spacing = np.zeros((len(file_ids), 3), dtype=np.float32)
    for i, file_id in enumerate(file_ids):
        img = nib.load(volume_paths[i])
        spacing[i] = img.header.get_zooms()[:3]  # voxel size from the NIfTI header
        df = df_landmarks[df_landmarks["file_id"] == file_id]
        for j, class_name in enumerate(class_names):
            coords = df[df["Label"] == class_name][["Y", "X", "Z"]].values
            if len(coords) > 0:
                # Keep the first occurrence and convert to voxel units
                landmarks[i, j] = coords[0] / spacing[i]
    return landmarks, spacing


landmarks_train, spacing_train = landmarks_to_voxels(df_landmarks_train, volume_paths_train, class_names)
landmarks_val, spacing_val = landmarks_to_voxels(df_landmarks_val, volume_paths_val, class_names)
landmarks_test, spacing_test = landmarks_to_voxels(df_landmarks_test, volume_paths_test, class_names)
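The cell above converts the landmark coordinates from the `.lml` files to voxel units by dividing by the voxel spacing from the NIfTI header. This matches applying the inverse of the image affine only when that affine is diagonal with zero translation (no rotation, no origin offset). The NumPy sketch below, with the hypothetical helper `world_to_voxel`, shows the general inverse-affine mapping and the case where both approaches coincide; which coordinate convention the `.lml` files follow is an assumption we have not checked.

```python
import numpy as np

def world_to_voxel(points_mm, affine):
    """Hypothetical helper: map (K, 3) world coordinates (mm) to voxel coordinates
    by applying the inverse of a 4x4 affine to homogeneous coordinates."""
    points_mm = np.asarray(points_mm, dtype=float)
    homog = np.c_[points_mm, np.ones(len(points_mm))]  # (K, 4)
    return (homog @ np.linalg.inv(affine).T)[:, :3]

spacing = np.array([0.8, 0.8, 2.5])
pts = np.array([[40.0, 80.0, 25.0]])

# Diagonal affine without translation: same result as dividing by the spacing
affine_diag = np.diag([*spacing, 1.0])
print(world_to_voxel(pts, affine_diag), pts / spacing)

# With an origin offset, plain division by the spacing is no longer correct
affine_shift = affine_diag.copy()
affine_shift[:3, 3] = [-10.0, -20.0, 5.0]
print(world_to_voxel(pts, affine_shift))
```

nibabel also provides `nibabel.affines.apply_affine`, which can be used with `np.linalg.inv(img.affine)` for the same mapping.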

In [23]:
dim_img = None  # no fixed target size: volumes are not resized to a common shape

In [24]:
from landmarker.data import MaskDataset

ds_train = MaskDataset(
    imgs=volume_paths_train,
    landmarks=landmarks_train,
    spatial_dims=3,
    pixel_spacing=spacing_train,
    transform=train_transformd,
    store_imgs=False,
    dim_img=dim_img,
    resize_pad=False
)

ds_val = MaskDataset(
    imgs=volume_paths_val,
    landmarks=landmarks_val,
    spatial_dims=3,
    pixel_spacing=spacing_val,
    transform=inference_transformd,
    store_imgs=False,
    dim_img=dim_img,
    resize_pad=False
)

ds_test = MaskDataset(
    imgs=volume_paths_test,
    landmarks=landmarks_test,
    spatial_dims=3,
    pixel_spacing=spacing_test,
    transform=inference_transformd,
    store_imgs=False,
    dim_img=dim_img,
    resize_pad=False
)

## Inspecting the dataset

In [17]:
batch = ds_train[0]

In [22]:
# %pip install 'napari[all]'
# %pip install 'ipyvolume'

### Visualize 3D Image + Landmarks

In [19]:
batch["landmark"]

tensor([[174.9549, 233.4400,  96.0000],
        [164.6944, 236.5181,  84.0000],
        [143.1472, 254.9869,  71.0000],
        [115.4438, 264.2214,  58.0000],
        [ 98.0010, 259.0912,  44.0000],
        [ 99.0272, 219.0752,  28.0000],
        [172.9027, 173.9290,  12.0000],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],
        [     nan,      nan,      nan],


In [21]:
import napari
import numpy as np

viewer = napari.Viewer(ndisplay=3)

# Voxel spacing of this volume, so that it is displayed in physical units
voxel_spacing = tuple(batch["spacing"].tolist())

viewer.add_image(batch["image"][0].numpy(),
                 scale=voxel_spacing,
                 name="image")

# Add landmarks (as points layer), dropping missing (NaN) landmarks first
landmarks = batch["landmark"].numpy()
landmarks = landmarks[~np.isnan(landmarks).any(axis=1)]
viewer.add_points(
    np.stack((landmarks[:, 1], landmarks[:, 0], landmarks[:, 2]), axis=-1),
    size=5,
    face_color="red",
    name="landmarks",
    scale=voxel_spacing
)

napari.run()




In [None]:
import nibabel as nib

img = nib.load(ds_train.img_paths[2])

In [None]:
img.affine

In [None]:
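import ipyvolume as ipv
import numpy as np

# Show volume (ipyvolume works on NumPy arrays)
ipv.quickvolshow(batch["image"][0].numpy(), level=[0.3, 0.6], opacity=0.03)

# Split the landmark columns into coordinates, skipping missing (NaN) landmarks
landmarks = batch["landmark"].numpy()
landmarks = landmarks[~np.isnan(landmarks).any(axis=1)]
z, y, x = landmarks.T

# Plot landmarks
ipv.scatter(x, y, z, color="red", size=5, marker="sphere")

ipv.show()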

### Visualize 3D Image + Masks + Landmarks

In [None]:
import napari
import numpy as np

# Launch viewer
viewer = napari.Viewer(ndisplay=3)

# Add base volume
viewer.add_image(batch["image"][0].numpy(), name="volume")

# Add landmarks (as points layer), skipping missing (NaN) landmarks
landmarks = batch["landmark"].numpy()
viewer.add_points(
    landmarks[~np.isnan(landmarks).any(axis=1)],
    size=2,
    face_color="blue",
    name="landmarks"
)

# Add each channel (one per landmark class) as a label layer (colored mask)
for c in range(batch["mask"].shape[0]):
    viewer.add_labels(batch["mask"][c].numpy().astype(np.int32), name=f"landmark_{c}")

napari.run()

## Training and initializing the Unet model

### Initializing the model, optimizer and loss function

In [None]:
from monai.networks.nets import FlexibleUNet
from landmarker.losses import NLLLoss
from landmarker.models.utils import SoftmaxND

model = FlexibleUNet(
                    in_channels=1,
                    out_channels=len(class_names),  # one output channel per landmark class
                    backbone="efficientnet-b0",
                    pretrained=True,
                    decoder_channels=[128, 128, 128, 128, 128],
                    spatial_dims=3,
                    norm="batch",
                    act="relu",
                    dropout=0.5,
                    decoder_bias=False,
                    upsample="nontrainable",
                    pre_conv="default",
                    interp_mode="nearest",
                    is_pad=True,
                    ).to(device)
print("Number of learnable parameters: {}".format(
    sum(p.numel() for p in model.parameters() if p.requires_grad)))
lr = 1e-5
batch_size = 1
epochs = 60

optimizer = torch.optim.Adam(
            [
                {"params": model.parameters()},
            ],
            lr=lr,
)


criterion = NLLLoss(spatial_dims=3)

decoder_method = "argmax"

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                          patience=10, cooldown=10)
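For intuition about training a regression model on one-hot masks, the sketch below shows one common way to define such a loss: a softmax over all voxels of each output channel, followed by the negative log-likelihood of the voxels marked in the mask. `spatial_nll_loss` is a hypothetical helper; we have not verified that `landmarker.losses.NLLLoss` is defined exactly this way (for example, how it normalizes or treats empty channels).

```python
import math

import torch
import torch.nn.functional as F

def spatial_nll_loss(logits, masks, eps=1e-8):
    """Hypothetical helper. logits, masks: (B, C, D, H, W)."""
    b, c = logits.shape[:2]
    # Softmax over all voxels of each channel -> one spatial distribution per landmark
    log_probs = F.log_softmax(logits.reshape(b, c, -1), dim=-1)  # (B, C, V)
    target = masks.reshape(b, c, -1).float()
    n_marked = target.sum(-1)  # (B, C) number of mask voxels per channel
    # Mean negative log-probability of the marked voxels of each channel
    per_channel = -(log_probs * target).sum(-1) / (n_marked + eps)
    # Ignore channels without a landmark (empty masks)
    return per_channel[n_marked > 0].mean()

logits = torch.zeros(1, 2, 4, 4, 4)  # uniform prediction over 64 voxels
masks = torch.zeros(1, 2, 4, 4, 4)
masks[0, 0, 1, 2, 3] = 1  # channel 1 stays empty (missing landmark)
print(spatial_nll_loss(logits, masks).item(), math.log(64))
```

A uniform prediction yields a loss of log(V) for V voxels; putting more probability mass on the marked voxel lowers it towards 0.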

### Setting the data loaders

In [None]:
train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=4)

### Training the model

In [None]:
from landmarker.heatmap.decoder import heatmap_to_coord
from landmarker.metrics import point_error

def train_epoch(model, train_loader, criterion, optimizer, device):
    running_loss = 0
    model.train()
    for i, batch in enumerate(tqdm(train_loader)):
        images = batch["image"].to(device)
        masks = batch["mask"].to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10000.0)
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)

def val_epoch(model, val_loader, criterion, device, method="argmax"):
    eval_loss = 0
    eval_mpe = 0
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(tqdm(val_loader)):
            images = batch["image"].to(device)
            outputs = model(images)
            dim_orig = batch["dim_original"].to(device)
            pixel_spacing = batch["spacing"].to(device)
            padding = batch["padding"].to(device)
            masks = batch["mask"].to(device)
            landmarks = batch["landmark"].to(device)
            loss = criterion(outputs, masks)
            pred_landmarks = heatmap_to_coord(outputs, method=method, spatial_dims=3)
            eval_loss += loss.item()
            eval_mpe += point_error(landmarks, pred_landmarks, images.shape[-3:], dim_orig,
                                    pixel_spacing, padding, reduction="mean")
    return eval_loss / len(val_loader), eval_mpe / len(val_loader)

def train(model, train_loader, val_loader, criterion, optimizer, device, epochs=1000):
    for epoch in tqdm(range(epochs)):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_mpe = val_epoch(model, val_loader, criterion, device)
        print(f"Epoch {epoch+1}/{epochs} - Train loss: {train_loss:.4f} - Val loss: {val_loss:.4f} - Val mpe: {val_mpe:.4f}")
        lr_scheduler.step(val_loss)
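`heatmap_to_coord(..., method="argmax")` decodes each output channel to a single coordinate. A minimal PyTorch sketch of the idea, using the hypothetical helper `argmax_coords_3d`, is shown below. It returns integer voxel indices in the tensor's own axis order, whereas landmarker's decoder may use a different coordinate order, dtype, or sub-voxel refinement.

```python
import torch

def argmax_coords_3d(heatmaps):
    """Hypothetical helper: (B, C, D0, D1, D2) -> (B, C, 3) voxel indices of each channel's maximum."""
    b, c, d0, d1, d2 = heatmaps.shape
    flat_idx = heatmaps.reshape(b, c, -1).argmax(dim=-1)  # (B, C) flat voxel index
    # Unravel the flat (row-major) index into three axis indices
    i0 = flat_idx // (d1 * d2)
    i1 = (flat_idx // d2) % d1
    i2 = flat_idx % d2
    return torch.stack([i0, i1, i2], dim=-1)

heatmaps = torch.zeros(1, 2, 4, 5, 6)
heatmaps[0, 0, 1, 2, 3] = 1.0
heatmaps[0, 1, 3, 4, 5] = 1.0
print(argmax_coords_3d(heatmaps))
```

Argmax decoding is limited to the voxel grid, so its precision is bounded by the voxel size; this matters when reporting errors in millimetres.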

In [None]:
train(model, train_loader, val_loader, criterion, optimizer, device,
      epochs=epochs)
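With `patience=10` and `cooldown=10`, the `ReduceLROnPlateau` scheduler halves the learning rate only after the validation loss has failed to improve for more than 10 consecutive epochs, and then ignores the following 10 epochs. The standalone sketch below (a dummy parameter and a constant loss, both for illustration only) replays this schedule over the 60 training epochs used here; in the worst case of no improvement at all, the learning rate is reduced three times.

```python
import torch

# Dummy parameter so that an optimizer (and hence a scheduler) can be built
param = torch.nn.Parameter(torch.zeros(1))
optimizer_demo = torch.optim.Adam([param], lr=1e-5)
scheduler_demo = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_demo, mode="min", factor=0.5, patience=10, cooldown=10)

lrs = []
for epoch in range(60):
    scheduler_demo.step(1.0)  # a validation loss that never improves
    lrs.append(optimizer_demo.param_groups[0]["lr"])

# Epochs (1-based) after which the learning rate was reduced
reductions = [e + 1 for e in range(1, 60) if lrs[e] < lrs[e - 1]]
print(reductions, lrs[-1])
```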

In [None]:
# torch.save(model.state_dict(), "3D-mml-one-hot-unet.pt")

## Evaluating the model

In [None]:
# model.load_state_dict(torch.load("3D-mml-one-hot-unet.pt", weights_only=True))

In [None]:
pred_landmarks_test = []
true_landmarks_test = []
dim_origs_test = []
pixel_spacings_test = []
paddings_test = []
test_mpe = 0
model.eval()
model.to(device)
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_loader)):
        images = batch["image"].to(device)
        outputs = model(images)
        dim_orig = batch["dim_original"].to(device)
        pixel_spacing = batch["spacing"].to(device)
        padding = batch["padding"].to(device)
        landmarks = batch["landmark"].to(device)
        pred_landmark = heatmap_to_coord(outputs, method="argmax", spatial_dims=3)
        test_mpe += point_error(landmarks, pred_landmark, images.shape[-3:], dim_orig,
                                pixel_spacing, padding, reduction="mean")
        pred_landmarks_test.append(pred_landmark.cpu())
        true_landmarks_test.append(landmarks.cpu())
        dim_origs_test.append(dim_orig.cpu())
        pixel_spacings_test.append(pixel_spacing.cpu())
        paddings_test.append(padding.cpu())

pred_landmarks_test = torch.cat(pred_landmarks_test)
true_landmarks_test = torch.cat(true_landmarks_test)
dim_origs_test = torch.cat(dim_origs_test)
pixel_spacings_test = torch.cat(pixel_spacings_test)
paddings_test = torch.cat(paddings_test)

test_mpe /= len(test_loader)

print(f"Test Mean PE: {test_mpe:.4f}")
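`point_error` reports the mean distance between predicted and true landmarks in millimetres. Assuming, as the settings `dim_img = None` and `resize_pad=False` suggest, that no resizing or padding has to be undone, and that missing landmarks are NaN, the computation reduces to scaling voxel differences by the voxel spacing and averaging the Euclidean norms, as in this NumPy sketch. The helpers `point_errors_mm` and `mean_point_error` are hypothetical; landmarker additionally maps coordinates back through `dim`, `dim_orig` and `padding`, and we have not checked how it handles NaNs.

```python
import numpy as np

def point_errors_mm(true_vox, pred_vox, spacing):
    """Hypothetical helper: Euclidean errors in mm.

    true_vox, pred_vox: (N, C, 3) voxel coordinates; spacing: (N, 3) voxel size in mm.
    Returns an (N, C) array that is NaN wherever the true landmark is missing.
    """
    diff = np.asarray(pred_vox) - np.asarray(true_vox)
    diff_mm = diff * np.asarray(spacing)[:, None, :]  # broadcast spacing over classes
    return np.sqrt((diff_mm ** 2).sum(axis=-1))

def mean_point_error(true_vox, pred_vox, spacing):
    """Hypothetical helper: mean error over annotated landmarks only (NaNs ignored)."""
    return float(np.nanmean(point_errors_mm(true_vox, pred_vox, spacing)))

true_vox = np.array([[[0.0, 0.0, 0.0], [np.nan, np.nan, np.nan]]])
pred_vox = np.array([[[3.0, 4.0, 0.0], [1.0, 1.0, 1.0]]])
spacing = np.array([[1.0, 1.0, 2.0]])
print(mean_point_error(true_vox, pred_vox, spacing))
```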

In [None]:
from landmarker.metrics import sdr

sdr_test = sdr([2.0, 2.5, 3.0, 4.0], true_landmarks=true_landmarks_test, pred_landmarks=pred_landmarks_test,
               dim=dim_img, dim_orig=dim_origs_test.int(), pixel_spacing=pixel_spacings_test, padding=paddings_test)

print("Results on Test Set:")
for key in sdr_test:
    print(f"SDR for {key}mm: {sdr_test[key]:.4f}")
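The successful detection rate (SDR) for a radius r is the share of landmarks whose point error is at most r mm. Given an array of per-landmark errors in mm, it can be sketched as below with the hypothetical helper `sdr_from_errors` and synthetic errors. Whether landmarker's `sdr` uses an inclusive bound, returns fractions or percentages, and how it treats missing landmarks are assumptions here.

```python
import numpy as np

def sdr_from_errors(errors_mm, radii):
    """Hypothetical helper: SDR in percent for each radius, over annotated landmarks only."""
    errors = np.asarray(errors_mm, dtype=float).ravel()
    errors = errors[~np.isnan(errors)]  # skip missing landmarks
    return {r: 100.0 * float(np.mean(errors <= r)) for r in radii}

errors = np.array([1.0, 2.2, 2.7, 3.5, 5.0, np.nan])  # synthetic errors in mm
print(sdr_from_errors(errors, [2.0, 2.5, 3.0, 4.0]))
```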

In [None]:
from landmarker.visualize import detection_report

print("Test Set")
detection_report(true_landmarks_test, pred_landmarks_test, dim=dim_img, dim_orig=dim_origs_test.int(),
                 pixel_spacing=pixel_spacings_test, padding=paddings_test, class_names=class_names,
                 radius=[2.0, 2.5, 3.0, 4.0], digits=2)

In [None]:
from landmarker.visualize import plot_cpe

plot_cpe(true_landmarks_test, pred_landmarks_test, dim=dim_img, dim_orig=dim_origs_test.int(),
         pixel_spacing=pixel_spacings_test, padding=paddings_test, class_names=class_names,
         group=False, title="CPE curve", save_path=None,
         stat='proportion', unit='mm', kind='ecdf')