<p align="center">
    <a href="https://predict-idlab.github.io/landmarker">
        <img alt="landmarker" src="https://raw.githubusercontent.com/predict-idlab/landmarker/main/docs/_static/images/logo.svg" width="66%">
    </a>
</p>

# Training and Evaluating a Static Heatmap Regression Model for Multi-Instance and Multi-Class Landmark Detection (EndoVis 2015 Challenge)

In this tutorial, we will train and evaluate a direct static heatmap regression model for landmark 
detection on the EndoVis 2015 Challenge dataset. We will use part of the dataset to 
construct a multi-instance and multi-class landmark detection task. The dataset contains 4 training 
and 6 testing videos of robotic surgery. The goal is to predict the location of the instruments in the video, 
more specifically the tip of the clasper. We only consider the clasper points and ignore the other points,
since they are far more ambiguous.

The videos are converted into images and the annotations are 
given as 2D points. The dataset is split into a training set of 4 videos and 
a testing set of 6 videos, as specified in the challenge. 

We will go through the following steps:
* [Loading the dataset](#Loading-the-dataset)
* [Inspecting the dataset](#Inspecting-the-dataset)
* [Training and initializing the UNet model](#Training-the-model)
* [Evaluating the model](#Evaluating-the-model)

<a target="_blank" href="https://colab.research.google.com/github/predict-idlab/landmarker/examples/static_unet_endovis2015.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Setup environment

In [None]:
# !python -c "import landmarker" || pip install landmarker

import sys
import os

sys.path.append("../src/")

## Setup imports and variables

In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from monai.transforms import (Compose, RandAffined, RandGaussianNoised, ScaleIntensityd,
                              RandScaleIntensityd, RandAdjustContrastd, RandHistogramShiftd)
from tqdm.notebook import tqdm

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Loading the dataset

In [None]:
fn_keys = ('image', 'mask')
spatial_transformd = [RandAffined(fn_keys, prob=1,
                        rotate_range=(-np.pi/12, np.pi/12),
                        translate_range=(-10, 10),
                        scale_range=(-0.1, 0.1),
                        shear_range=(-0.1, 0.1)
                        )]

train_transformd = Compose([
                            RandGaussianNoised(('image', ), prob=0.2, mean=0, std=0.1),  # Add gaussian noise
                            RandScaleIntensityd(('image', ), factors=0.25, prob=0.2),  # Add random intensity scaling
                            RandAdjustContrastd(('image', ), prob=0.2, gamma=(0.5,4.5)),  # Randomly adjust contrast
                            RandHistogramShiftd(('image', ), prob=0.2),  # Randomly shift histogram
                            ScaleIntensityd(('image', )),  # Scale intensity
                        ] + spatial_transformd)

inference_transformd = Compose([
    ScaleIntensityd(('image', )),
])

In [None]:
from landmarker.datasets import get_endovis2015_heatmap_datasets

data_dir = "/Users/jefjonkers/Data/landmark-datasets"
ds_train, ds_test = get_endovis2015_heatmap_datasets(data_dir, train_transform = train_transformd,
                                                     inference_transform= inference_transformd,
                                                     dim_img = (512, 512), sigma=3)
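
The datasets above return a heatmap target (`mask`) per image, built with `sigma=3`. As a rough illustration of what a static Gaussian heatmap for several instances of one landmark class can look like, here is a plain-Python sketch. It is not the library's implementation: the helper name `gaussian_heatmap`, the `(y, x)` coordinate order, and merging overlapping blobs with a pixel-wise maximum are all assumptions made for this example.

```python
import math


def gaussian_heatmap(points, height, width, sigma=3.0):
    """Return a height x width grid with one Gaussian blob per (y, x) point.

    Overlapping blobs are merged with a pixel-wise maximum (an assumption of
    this sketch, not necessarily what landmarker does internally).
    """
    heatmap = [[0.0] * width for _ in range(height)]
    for py, px in points:
        for y in range(height):
            for x in range(width):
                squared_dist = (y - py) ** 2 + (x - px) ** 2
                value = math.exp(-squared_dist / (2 * sigma ** 2))
                heatmap[y][x] = max(heatmap[y][x], value)
    return heatmap


# Two instances of the same landmark class share a single channel.
heatmap = gaussian_heatmap([(8, 8), (20, 24)], height=32, width=32, sigma=3.0)
print(heatmap[8][8], heatmap[20][24])  # value at each instance location
```

A smaller `sigma` gives sharper targets that are harder to learn, while a larger one makes nearby instances blur into each other.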

## Inspecting the dataset

In [None]:
from landmarker.visualize import inspection_plot

# Plot 3 randomly selected images from the training set
inspection_plot(ds_train, np.random.randint(0, len(ds_train), 3))

In [None]:
# Plot the first 3 images from the test set
inspection_plot(ds_test, range(3))

## Training and initializing the UNet model

### Initializing the model, optimizer and loss function

In [None]:
from torch import nn
from monai.networks.nets import UNet


model = UNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)

lr = 1e-4
batch_size = 4
epochs = 5

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)

criterion = nn.BCEWithLogitsLoss()

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                                                          patience=20, verbose=True, cooldown=10)
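
The model outputs raw logits, and `nn.BCEWithLogitsLoss` compares them directly with the heatmap targets, which take values in [0, 1]. The sketch below shows, for a single pixel, the numerically stable form such a loss is usually computed with and checks it against the textbook definition. The helper names are hypothetical; the PyTorch loss additionally averages over all pixels by default.

```python
import math


def sigmoid(logit):
    """Map a logit to a probability."""
    return 1.0 / (1.0 + math.exp(-logit))


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for one logit and a target in [0, 1]."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# Compare with the textbook form -[t * log(p) + (1 - t) * log(1 - p)].
logit, target = 1.5, 0.7
p = sigmoid(logit)
naive = -(target * math.log(p) + (1 - target) * math.log(1 - p))
print(abs(bce_with_logits(logit, target) - naive) < 1e-9)

# The stable form does not overflow for extreme logits.
print(bce_with_logits(1000.0, 1.0))
```

Because the loss works on logits, a sigmoid still has to be applied to the outputs before thresholding them at inference time.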

### Setting the data loaders and split training set

In [None]:
split_lengths = [0.8, 0.2]
ds_train_train, ds_train_val = torch.utils.data.random_split(ds_train, split_lengths)
train_loader = DataLoader(ds_train_train, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(ds_train_val, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False, num_workers=0)
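
When `random_split` receives fractions instead of absolute lengths, it has to turn them into subset sizes. The following sketch mirrors how we understand that conversion to work (floor each share, then hand out the remaining samples one by one); `split_sizes` is a hypothetical helper written for illustration, not a PyTorch function.

```python
import math


def split_sizes(n_samples, fractions):
    """Convert fractions that sum to 1 into integer subset sizes."""
    sizes = [math.floor(n_samples * fraction) for fraction in fractions]
    remainder = n_samples - sum(sizes)
    # Distribute the samples lost to flooring, round-robin over the subsets.
    for i in range(remainder):
        sizes[i % len(sizes)] += 1
    return sizes


print(split_sizes(11, [0.8, 0.2]))
```

Note that the split in this tutorial is random over frames, so frames of the same video can end up in both the training and the validation subset.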

### Training the model

In [None]:
from landmarker.heatmap.decoder import heatmap_to_coord, heatmap_to_multiple_coord
from landmarker.metrics import point_error

from torch.nn.functional import sigmoid

from landmarker.metrics.metrics import multi_instance_point_error

def train_epoch(model, train_loader, criterion, optimizer, device):
    running_loss = 0
    model.train()
    for i, batch in enumerate(tqdm(train_loader)):
        images = batch["image"].to(device)
        heatmaps = batch["mask"].to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, heatmaps)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)

def val_epoch(model, val_loader, criterion, device):
    eval_loss = 0
    model.eval()
    with torch.no_grad():
        for _, batch in enumerate(tqdm(val_loader)):
            images = batch["image"].to(device)
            heatmaps = batch["mask"].to(device)
            landmarks = batch["landmark"].to(device)
            outputs = model(images)
            loss = criterion(outputs, heatmaps)
            eval_loss += loss.item()
    return eval_loss / len(val_loader)

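def train(model, train_loader, val_loader, criterion, optimizer, device, epochs=1000):
    # Note: relies on the global `ds_train` and `lr_scheduler` defined above.
    # Keep a reference to the augmentation pipeline so it can be restored each epoch;
    # otherwise augmentation would only be applied during the first epoch.
    train_transform = ds_train.transform
    for epoch in tqdm(range(epochs)):
        ds_train.transform = train_transform  # enable augmentation for training
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
        ds_train.transform = None  # disable augmentation for validation
        val_loss = val_epoch(model, val_loader, criterion, device)
        print(f"Epoch {epoch+1}/{epochs} - Train loss: {train_loss:.4f} - Val loss: {val_loss:.4f}")
        lr_scheduler.step(val_loss)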

In [None]:
train(model, train_loader, val_loader, criterion, optimizer, device,
      epochs=epochs)
# model.load_state_dict(torch.load("best_weights_unet_endovis_static.pt"))

## Evaluating the model

In [None]:
pred_landmarks = []
true_landmarks = []
dim_origs = []
pixel_spacings = []
paddings = []
tp = []
fp = []
fn = []
test_mpe = 0
test_tp = 0
test_fp = 0
test_fn = 0
model.eval()
with torch.no_grad():
    for i, batch in enumerate(tqdm(test_loader)):
        images = batch["image"]
        heatmaps = batch["mask"]
        landmarks = batch["landmark"]
        outputs = model(images.to(device)).detach().cpu()
        offset_coords = outputs.shape[1]-landmarks.shape[1]
        pred_landmarks_list, _ = heatmap_to_multiple_coord(sigmoid(outputs), window=5,
                                                           threshold=0.5,
                                                           method="argmax")
        pe_batch, tp_batch, fp_batch, fn_batch, pred_landmarks_torch = multi_instance_point_error(
            true_landmarks=landmarks, pred_landmarks=pred_landmarks_list, dim=(512, 512),
            dim_orig=batch["dim_original"], pixel_spacing=batch["spacing"],
            padding=batch["padding"], reduction="none")
        test_mpe += torch.nanmean(pe_batch).item()
        test_tp += torch.nansum(tp_batch).item()
        test_fp += torch.nansum(fp_batch).item()
        test_fn += torch.nansum(fn_batch).item()
        pred_landmarks.append(pred_landmarks_torch)
        true_landmarks.append(landmarks)
        dim_origs.append(batch["dim_original"])
        pixel_spacings.append(batch["spacing"])
        paddings.append(batch["padding"])
        tp.append(tp_batch)
        fp.append(fp_batch)
        fn.append(fn_batch)


test_mpe /= len(test_loader)

print(f"Test Mean PE: {test_mpe:.4f}")
print(f"Test TP: {test_tp:.4f}")
print(f"Test FP: {test_fp:.4f}")
print(f"Test FN: {test_fn:.4f}")
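
In the loop above, `heatmap_to_multiple_coord` is called with `window=5`, `threshold=0.5` and `method="argmax"` to turn each predicted heatmap into a list of landmark candidates. The idea behind this kind of decoding can be sketched in plain Python as follows. This is only an illustration of window-based peak picking under our own assumptions, not the library's implementation, and `local_maxima` is a hypothetical helper.

```python
def local_maxima(heatmap, window=5, threshold=0.5):
    """Return (y, x) positions above `threshold` that are the maximum of their window."""
    height, width = len(heatmap), len(heatmap[0])
    half = window // 2
    peaks = []
    for y in range(height):
        for x in range(width):
            value = heatmap[y][x]
            if value <= threshold:
                continue
            neighbourhood = [
                heatmap[ny][nx]
                for ny in range(max(0, y - half), min(height, y + half + 1))
                for nx in range(max(0, x - half), min(width, x + half + 1))
            ]
            if value >= max(neighbourhood):
                peaks.append((y, x))
    return peaks


grid = [[0.0] * 12 for _ in range(12)]
grid[2][3] = 0.9  # strong peak
grid[2][4] = 0.6  # weaker neighbour, suppressed by the window
grid[9][8] = 0.8  # second instance
grid[5][5] = 0.3  # below the threshold
print(local_maxima(grid))
```

A larger window suppresses more nearby candidates, which reduces duplicate detections but can merge instruments that are close together; a higher threshold trades false positives for false negatives.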

In [None]:
from landmarker.metrics import sdr

sdr_test = sdr([4, 5, 10, 20], true_landmarks=torch.cat(true_landmarks, axis=0), pred_landmarks=torch.cat(pred_landmarks, axis=0),
               dim=(512, 512), dim_orig=torch.cat(dim_origs, axis=0).int(), pixel_spacing=torch.cat(pixel_spacings, axis=0),
               padding=torch.cat(paddings, axis=0))
for key in sdr_test:
    print(f"SDR for {key}mm: {sdr_test[key]:.4f}")
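
The success detection rate (SDR) reports, for each radius, how many predictions fall within that distance of their ground-truth landmark. A minimal sketch of the metric on already matched point pairs is shown below. The helper name is hypothetical, the points are made-up values, and the result is returned as a fraction; we make no assumption here about how the library scales or handles missing predictions.

```python
import math


def success_detection_rate(true_points, pred_points, radii):
    """Fraction of predictions lying within each radius of their ground-truth point."""
    errors = [math.dist(t, p) for t, p in zip(true_points, pred_points)]
    return {r: sum(e <= r for e in errors) / len(errors) for r in radii}


# Made-up example points with errors of 3, 6, 15 and 0 units.
true_points = [(10, 10), (20, 20), (30, 30), (40, 40)]
pred_points = [(10, 13), (20, 26), (30, 45), (40, 40)]
result = success_detection_rate(true_points, pred_points, radii=[4, 5, 10, 20])
print(result)
```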

In [None]:
from landmarker.visualize.utils import prediction_inspect_plot_multi_instance

model.to("cpu")
prediction_inspect_plot_multi_instance(ds_test, model, range(3))

In [None]:
from landmarker.visualize import plot_cpe

plot_cpe(torch.cat(true_landmarks, axis=0), torch.cat(pred_landmarks, axis=0), dim=(512, 512),
            dim_orig=torch.cat(dim_origs, axis=0).int(), pixel_spacing=torch.cat(pixel_spacings, axis=0),
            padding=torch.cat(paddings, axis=0), class_names=ds_test.class_names,
            group=False, title="CPE curve", save_path=None,
            stat='proportion', unit='pixels', kind='ecdf')

In [None]:
from landmarker.visualize.evaluation import multi_instance_detection_report

multi_instance_detection_report(torch.cat(true_landmarks, axis=0), torch.cat(pred_landmarks, axis=0),
                                torch.cat(tp, axis=0), torch.cat(fp, axis=0), torch.cat(fn, axis=0), dim=(512, 512),
                                dim_orig=torch.cat(dim_origs, axis=0).int(), pixel_spacing=torch.cat(pixel_spacings, axis=0),
                                padding=torch.cat(paddings, axis=0), class_names=ds_test.class_names)
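
True positive, false positive and false negative counts such as the ones accumulated during evaluation are commonly summarised as precision, recall and F1 score. The sketch below shows that computation with a hypothetical helper and made-up counts; it is not taken from the library, and the report above may present its results differently.

```python
def detection_scores(tp, fp, fn):
    """Return (precision, recall, f1) from detection counts, guarding against zero division."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1


# Made-up counts for illustration only.
precision, recall, f1 = detection_scores(tp=80, fp=20, fn=10)
print(f"Precision: {precision:.4f} - Recall: {recall:.4f} - F1: {f1:.4f}")
```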