<a href="https://colab.research.google.com/github/talmolab/sleap-nn/blob/main/docs/colab_notebooks/Training_with_sleap_nn_on_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This tutorial notebook walks through the training, inference, and evaluation workflows in sleap-nn using its higher-level Python APIs. See the [docs](https://nn.sleap.ai/latest/) for details on using the CLI.

**Note**:
Make sure the GPU runtime is enabled before you start training! Go to Runtime -> Change runtime type -> select "T4 GPU".

## Install sleap-nn

In [None]:
# !pip install sleap-nn[torch] --index-url https://pypi.org/simple --extra-index-url https://download.pytorch.org/whl/cu128

# install from git until 0.0.6 is out!
!pip install "sleap-nn[torch] @ git+https://github.com/talmolab/sleap-nn.git" --index-url https://pypi.org/simple --extra-index-url https://download.pytorch.org/whl/cu128


If `torch.cuda.is_available()` returns `False` in the check below, make sure the GPU runtime is enabled in the Runtime settings.

## Imports

In [None]:
from pathlib import Path
from omegaconf import OmegaConf
import torch

import sleap_io as sio
from sleap_nn.train import run_training
from sleap_nn.config.training_job_config import TrainingJobConfig
from sleap_nn.predict import run_inference
from sleap_nn.evaluation import Evaluator



In [None]:
# Verify installation
import sleap_nn

sleap_nn.__version__

In [None]:
# Check if cuda is available

torch.cuda.is_available()

## Setting up

In [None]:
train_labels_paths = ["/path/to/train/slp/file"]
val_labels_paths = ["/path/to/val/slp/file"] # set this to `None` if you don't have a validation dataset
test_file_path = "path/to/test/file" # set this to `None` if you don't have a test file
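Before going further, it can help to confirm that these paths point at real files, since a typo in a placeholder path otherwise only surfaces once training starts. A minimal sketch using only `pathlib`; `find_missing_paths` is a hypothetical helper, not part of sleap-nn.

```python
from pathlib import Path
from typing import Iterable, List, Optional, Union

PathLike = Union[str, Path]


def find_missing_paths(*groups: Optional[Union[PathLike, Iterable[PathLike]]]) -> List[str]:
    """Return every path in `groups` that does not exist on disk.

    Each group may be a single path, a list of paths, or None (skipped),
    matching how the train/val/test paths are set up in this notebook.
    """
    missing = []
    for group in groups:
        if group is None:
            continue  # e.g. no validation or test set
        # A single str/Path is wrapped in a list; anything else is treated as a list of paths
        paths = [group] if isinstance(group, (str, Path)) else list(group)
        for p in paths:
            if not Path(p).exists():
                missing.append(str(p))
    return missing


# Usage in this notebook:
# missing = find_missing_paths(train_labels_paths, val_labels_paths, test_file_path)
# if missing:
#     print("These paths do not exist:", missing)
```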

If you have a YAML config file, load it with the command below (see the [docs](https://nn.sleap.ai/latest/config/) for the config format and to download sample configs):

In [None]:
config = OmegaConf.load("/path/to/config/file.yaml")

# if you have a `json` config file from SLEAP <= v1.4, then use the below code to get the `sleap-nn` config
# config = TrainingJobConfig.load_sleap_config("path/to/config.json")

In [None]:
# Set the train, val and test file paths

config.data_config.train_labels_path = train_labels_paths
config.data_config.val_labels_path = val_labels_paths # set this to `None` if you don't have a validation dataset
config.data_config.test_file_path = test_file_path # set this to `None` if you don't have a test file
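The assignments in this notebook all follow the same `section.field` nesting. If you prefer to keep every change in one place, one option is to collect the overrides as dotted keys and apply them in a loop. The sketch below is a hypothetical stdlib helper that works on a plain `dict`; on the actual OmegaConf object, `OmegaConf.update(config, "data_config.val_labels_path", value)` plays the same role.

```python
from typing import Any, Dict


def apply_dotted_overrides(cfg: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Set nested keys such as "trainer_config.max_epochs" in a plain dict.

    Missing intermediate sections are created as empty dicts.
    The dict is modified in place and also returned for convenience.
    """
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return cfg


# Keys taken from this notebook; values are placeholders
overrides = {
    "data_config.train_labels_path": ["/path/to/train/slp/file"],
    "data_config.val_labels_path": None,  # no validation set
    "trainer_config.max_epochs": 100,
}
cfg = apply_dotted_overrides({}, overrides)
```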

In [None]:
# To speed up training, cache the images in memory (needs enough RAM to hold the dataset):

# config.data_config.data_pipeline_fw = "torch_dataset_cache_img_memory"
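Caching in memory trades RAM for speed, and Colab sessions have limited RAM. A rough back-of-the-envelope check is the size of the decoded frames themselves. This sketch assumes uint8 pixels and ignores any extra bookkeeping the sleap-nn data pipeline may keep, so treat the result as a lower bound; `estimate_cache_bytes` is a hypothetical helper.

```python
def estimate_cache_bytes(
    n_frames: int, height: int, width: int, channels: int = 1, bytes_per_value: int = 1
) -> int:
    """Rough lower bound on the RAM needed to hold decoded frames in memory.

    Defaults assume grayscale uint8 images (1 channel, 1 byte per pixel).
    """
    return n_frames * height * width * channels * bytes_per_value


# Example: 1,000 grayscale 1024x1024 frames
gib = estimate_cache_bytes(1000, 1024, 1024) / 2**30
print(f"~{gib:.2f} GiB")  # → ~0.98 GiB
```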

In [None]:
# Set up the checkpoint directory and run name

config.trainer_config.ckpt_dir = "path/to/ckpt/dir"
config.trainer_config.run_name = None # if None, a name based on the timestamp and model type is assigned

config.trainer_config.max_epochs = 100

config.trainer_config.train_data_loader.num_workers = 2
config.trainer_config.val_data_loader.num_workers = 2
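Setting `run_name` yourself means you know the checkpoint folder before training starts; that folder (`ckpt_dir/run_name`) is what `model_paths` expects in the inference step. A minimal sketch, assuming a `YYMMDD_HHMMSS.<model_type>` naming scheme chosen here for illustration (it is not necessarily the format sleap-nn generates when `run_name` is `None`); `make_run_name` is hypothetical and `"single_instance"` is just an example model type string.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional


def make_run_name(model_type: str, now: Optional[datetime] = None) -> str:
    """Build a run name from a timestamp and a model type (illustrative format)."""
    if now is None:
        now = datetime.now()
    return f"{now:%y%m%d_%H%M%S}.{model_type}"


ckpt_dir = "path/to/ckpt/dir"
run_name = make_run_name("single_instance", datetime(2025, 1, 2, 3, 4, 5))
model_dir = (Path(ckpt_dir) / run_name).as_posix()
print(model_dir)  # path/to/ckpt/dir/250102_030405.single_instance
```

You would then set `config.trainer_config.run_name = run_name` and later pass `model_paths=[model_dir]` to `run_inference()`.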

In [None]:
# For fine-tuning (initializing the model with previously trained checkpoints)

# prv_model_ckpt_dir_path = "path/to/previous/model/dir"

# If the previous model was trained with SLEAP >= v1.5
# model_ckpt_file_path = (Path(prv_model_ckpt_dir_path) / "best.ckpt").as_posix()

# If the previous model was trained with SLEAP < v1.5
# model_ckpt_file_path = (Path(prv_model_ckpt_dir_path) / "best_model.h5").as_posix()

# config.model_config.pretrained_backbone_weights = model_ckpt_file_path
# config.model_config.pretrained_head_weights = model_ckpt_file_path
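If you are not sure which SLEAP version produced the previous model, you can check which of the two checkpoint files from the cell above is present. A minimal sketch; `find_pretrained_ckpt` is a hypothetical helper that only looks for the `best.ckpt` and `best_model.h5` file names used above.

```python
from pathlib import Path
from typing import Union


def find_pretrained_ckpt(model_dir: Union[str, Path]) -> str:
    """Return the checkpoint file inside a previously trained model folder.

    Prefers `best.ckpt` (SLEAP >= v1.5) and falls back to
    `best_model.h5` (SLEAP < v1.5), mirroring the two cases above.
    """
    model_dir = Path(model_dir)
    for name in ("best.ckpt", "best_model.h5"):
        candidate = model_dir / name
        if candidate.is_file():
            return candidate.as_posix()
    raise FileNotFoundError(f"No best.ckpt or best_model.h5 found in {model_dir}")


# Usage:
# model_ckpt_file_path = find_pretrained_ckpt(prv_model_ckpt_dir_path)
```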

In [None]:
# To log training to Weights & Biases (wandb)

# config.trainer_config.use_wandb = True
# config.trainer_config.wandb.entity = "<wandb entity name>"
# config.trainer_config.wandb.project = "<wandb project name>"
# config.trainer_config.wandb.name =  "<wandb run name>"
# config.trainer_config.wandb.save_viz_imgs_wandb = False
# config.trainer_config.wandb.api_key = "<wandb API key>" # required to log in to your account
# config.trainer_config.wandb.group = "<wandb run group name>"
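Pasting the API key directly into a cell makes it easy to leak when the notebook is shared. A minimal sketch that reads it from the `WANDB_API_KEY` environment variable (the variable wandb itself looks for) and otherwise prompts without echoing; `get_wandb_api_key` is a hypothetical helper.

```python
import os
from getpass import getpass


def get_wandb_api_key(env_var: str = "WANDB_API_KEY") -> str:
    """Read the wandb API key from an environment variable, or prompt for it."""
    key = os.environ.get(env_var)
    if key:
        return key
    return getpass("wandb API key: ")  # typed input is not echoed


# Usage:
# config.trainer_config.wandb.api_key = get_wandb_api_key()
```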

Review the full config and update any parameters if needed:

In [None]:
print(OmegaConf.to_yaml(config, resolve=True, sort_keys=False))

## Training

In [None]:
# If you have custom train and val `Labels` objects, you can pass them to `run_training()` directly.
# Note that these labels override the label paths provided in the config.

# run_training(config, train_labels=[train_labels], val_labels=[val_labels])
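If you start from a single labels file, one way to build custom train/val sets is a reproducible random split of labeled-frame indices. This is a stdlib sketch; `split_indices` is hypothetical, and the selected labeled frames would still need to be assembled into new `Labels` objects before being passed to `run_training()` (check the sleap-io docs for its own splitting utilities).

```python
import random
from typing import List, Tuple


def split_indices(
    n_items: int, val_fraction: float = 0.1, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Shuffle indices 0..n_items-1 with a fixed seed and split them into (train, val)."""
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError("val_fraction must be in [0, 1)")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # same seed -> same split
    n_val = round(n_items * val_fraction)
    if val_fraction > 0 and n_items > 1:
        n_val = max(n_val, 1)  # at least one validation item
    n_val = min(n_val, max(n_items - 1, 0))  # at least one training item
    return sorted(indices[n_val:]), sorted(indices[:n_val])


train_idx, val_idx = split_indices(100, val_fraction=0.2, seed=42)
print(len(train_idx), len(val_idx))  # 80 20
```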

In [None]:
run_training(config)

## Running inference

Once training is complete and the checkpoints are saved, we can run inference with the trained model on either a `.slp` file or an `.mp4` video.

In [None]:
pred_labels = run_inference(
    data_path="path/to/inference/file",
    model_paths=["path/to/ckpt/dir/run_name"],  # the run folder inside `ckpt_dir`
    output_path="predictions.slp",
)
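Instead of a fixed output name, you can derive it from the input file so that predictions for different videos do not overwrite each other. `predictions_path_for` is a hypothetical helper and the `.predictions.slp` suffix is only a naming choice.

```python
from pathlib import Path
from typing import Union


def predictions_path_for(data_path: Union[str, Path], suffix: str = ".predictions.slp") -> str:
    """Place the output next to the input, e.g. videos/session1.mp4 -> videos/session1.predictions.slp."""
    data_path = Path(data_path)
    return data_path.with_name(data_path.stem + suffix).as_posix()


# Usage:
# data_path = "path/to/inference/file"
# pred_labels = run_inference(data_path=data_path, model_paths=[...], output_path=predictions_path_for(data_path))
```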

Evaluate the predictions against ground truth and compute metrics. This requires `gt_labels` to contain ground-truth annotations for the frames you ran inference on.

In [None]:
# NOTE: run this only if you have ground-truth data

gt_labels = sio.load_slp("path/to/gt.slp")

evaluator = Evaluator(
    ground_truth_instances=gt_labels,
    predicted_instances=pred_labels,
)

metrics = evaluator.evaluate()

print("Evaluation metrics:")
print(f"OKS mAP: {metrics['voc_metrics']['oks_voc.mAP']}")
print(f"Dist p90: {metrics['distance_metrics']['p90']}")
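For intuition about the two numbers printed above, here is a stdlib sketch of COCO-style object keypoint similarity (OKS) for one already-matched GT/predicted pair, plus a 90th-percentile point distance. It assumes the instances are already matched, uses a single `sigma` for every keypoint (0.025 is only an illustrative default) and takes the object `area` (s² in the COCO definition) as an input. sleap-nn's `Evaluator` handles matching, missing points and scaling internally, so its numbers may differ.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def _visible(p: Point) -> bool:
    return not (math.isnan(p[0]) or math.isnan(p[1]))


def oks(gt: Sequence[Point], pred: Sequence[Point], area: float, sigma: float = 0.025) -> float:
    """COCO-style OKS for one matched pair, using the same sigma for every keypoint.

    Keypoints missing (NaN) in the ground truth are ignored; a missing
    prediction for a visible GT keypoint contributes a similarity of 0.
    """
    kappa_sq = (2 * sigma) ** 2
    scores = []
    for g, p in zip(gt, pred):
        if not _visible(g):
            continue
        if not _visible(p):
            scores.append(0.0)
            continue
        d_sq = math.dist(g, p) ** 2
        scores.append(math.exp(-d_sq / (2 * area * kappa_sq)))
    return sum(scores) / len(scores) if scores else 0.0


def point_distances(gt: Sequence[Point], pred: Sequence[Point]) -> List[float]:
    """Euclidean errors for keypoints visible in both GT and prediction."""
    return [math.dist(g, p) for g, p in zip(gt, pred) if _visible(g) and _visible(p)]


def percentile(values: Sequence[float], q: float) -> float:
    """Linearly interpolated percentile (same rule as NumPy's default)."""
    xs = sorted(values)
    if not xs:
        raise ValueError("no values")
    pos = (len(xs) - 1) * q / 100
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


nan = float("nan")
gt = [(10.0, 10.0), (20.0, 20.0), (nan, nan)]
pred = [(10.0, 10.0), (23.0, 24.0), (5.0, 5.0)]
print(oks(gt, pred, area=10000.0))
print(percentile(point_distances(gt, pred), 90))  # 4.5
```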

Let's visualize the predictions!

In [None]:
import sleap_io as sio
import matplotlib.pyplot as plt


def _plot_instances(ax, instances, fmt, label):
    """Plot every non-empty instance; only the first one gets a legend label."""
    labeled = False
    for instance in instances:
        if instance.is_empty:
            continue
        pts = instance.numpy()  # (n_nodes, 2) array of x, y coordinates
        ax.plot(
            pts[:, 0],
            pts[:, 1],
            fmt,
            markersize=6,
            alpha=0.8,
            label="_nolegend_" if labeled else label,
        )
        labeled = True


def plot_preds(gt_labels, pred_labels, lf_index):
    """Overlay ground-truth (green circles) and predicted (red crosses) keypoints on one frame."""
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))

    pred_lf = pred_labels[lf_index]
    gt_lf = gt_labels[lf_index] if gt_labels is not None else None

    # Ensure we're plotting keypoints for the same frame
    if gt_lf is not None:
        assert (
            gt_lf.frame_idx == pred_lf.frame_idx
        ), f"Frame mismatch at {lf_index}: GT={gt_lf.frame_idx}, Pred={pred_lf.frame_idx}"

    # Frames include a channel axis, e.g. (H, W, 1) for grayscale; imshow needs (H, W) there
    img = pred_lf.image
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    ax.imshow(img, cmap="gray")
    ax.set_title(
        f"Frame {pred_lf.frame_idx} (lf idx: {lf_index})",
        fontsize=12,
        fontweight="bold",
    )

    # Plot ground-truth instances (if available), then predicted instances
    if gt_lf is not None:
        _plot_instances(ax, gt_lf.instances, "go", "GT")
    _plot_instances(ax, pred_lf.instances, "rx", "Pred")

    # Only draw a legend if at least one instance was plotted
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc="upper right", fontsize=8)
    ax.axis("off")

    plt.suptitle("Predictions", fontsize=16, fontweight="bold", y=0.98)
    plt.tight_layout()
    plt.show()

In [None]:
frame_index_to_view = 9
gt_labels = sio.load_slp("path/to/gt.slp") # set to None if there are no ground-truth labels
plot_preds(gt_labels, pred_labels, lf_index=frame_index_to_view)
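`plot_preds` indexes both label sets by position and asserts that the frame indices agree, which fails if the predictions are ordered differently or cover different frames. A hypothetical helper that pairs positions by frame index instead; it works on plain lists of frame indices (which you could build with `[lf.frame_idx for lf in gt_labels]`) and assumes all frames come from a single video.

```python
from typing import Dict, List, Sequence, Tuple


def match_frame_positions(
    gt_frame_idxs: Sequence[int], pred_frame_idxs: Sequence[int]
) -> List[Tuple[int, int]]:
    """Return (gt_position, pred_position) pairs that refer to the same frame index.

    Assumes one video; with several videos the key would also need the video.
    """
    pred_pos: Dict[int, int] = {f: i for i, f in enumerate(pred_frame_idxs)}
    return [(i, pred_pos[f]) for i, f in enumerate(gt_frame_idxs) if f in pred_pos]


pairs = match_frame_positions([0, 5, 9, 12], [12, 0, 9])
print(pairs)  # [(0, 1), (2, 2), (3, 0)]
```

Since `plot_preds` uses a single `lf_index` for both label sets, it can be used directly only for pairs whose two positions coincide; otherwise it would need a small change to accept separate GT and prediction indices.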