<a href="https://colab.research.google.com/github/sophia-zhang-qwq/animal-pose-est/blob/main/docs/notebooks/Training_and_inference_on_an_example_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Training and inference on an example dataset

In this notebook we'll install SLEAP, download a sample dataset, run training and inference on that dataset using the SLEAP command-line interface, and then download the predictions.

## Install SLEAP
Note: Before installing SLEAP, check the [SLEAP releases](https://github.com/talmolab/sleap/releases) page for the latest version.

In [None]:
# 📦 Step 1: Install Python 3.10 and set it as the default
!sudo apt-get update
!sudo apt-get install python3.10 python3.10-distutils python3.10-venv -y
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
!sudo update-alternatives --set python3 /usr/bin/python3.10

# 🛠 Step 2: Install pip manually for Python 3.10
!wget https://bootstrap.pypa.io/get-pip.py
!python3 get-pip.py

# ✅ Step 3: Upgrade pip, setuptools, and wheel
!python3 -m pip install --upgrade pip setuptools wheel

Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:3 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:5 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:7 https://r2u.stat.illinois.edu/ubuntu jammy/main all Packages [8,832 kB]
Hit:8 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Get:9 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1,243 kB]
Hit:10 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Get:11 http://archive.ubuntu.com/ubuntu jammy-updates/universe amd64 Packages [1,542 kB]
Hit:12 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:13 https://r2u.stat.illinois.edu/ubunt

In [None]:
# Upgrade pip tools
!pip install -U pip setuptools wheel build cython --quiet

# Uninstall OpenCV if it conflicts (ignore warning if not installed)
!pip uninstall -y opencv-python opencv-contrib-python --quiet

# Now install SLEAP (specific version to avoid issues)
#!pip install sleap==1.3.3 --quiet

!pip install -qqq "sleap[pypi]>=1.3.3"

In [None]:
!pip show sleap
import sleap
print(sleap.__version__)

Name: sleap
Version: 1.3.3
Summary: SLEAP (Social LEAP Estimates Animal Poses) is a deep learning framework for animal pose tracking.
Home-page: https://sleap.ai
Author: Talmo Pereira
Author-email: talmo@salk.edu
License: BSD 3-Clause License
Location: /usr/local/lib/python3.10/dist-packages
Requires: imgstore, ndx-pose, nixio, pynwb, qimage2ndarray, segmentation-models
Required-by: 


ModuleNotFoundError: No module named 'sleap'
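
The `pip show` output above lists SLEAP under `/usr/local/lib/python3.10/dist-packages`, yet the import fails. One plausible explanation (an assumption, not something this notebook confirms) is that the `!pip` shell command and the notebook kernel use different Python interpreters. Below is a minimal diagnostic sketch using only the standard library; the helper name `describe_module` is hypothetical.

```python
import importlib.util
import sys


def describe_module(name):
    """Report whether top-level module `name` is importable by this interpreter."""
    spec = importlib.util.find_spec(name)
    return {
        "module": name,
        "importable": spec is not None,
        # Path the module would be loaded from, if it was found.
        "origin": getattr(spec, "origin", None),
        # Interpreter that is actually executing the notebook kernel.
        "interpreter": sys.executable,
        "python_version": "{}.{}".format(*sys.version_info[:2]),
    }


print(describe_module("sleap"))
```

If `importable` is `False` while `pip show` lists the package, comparing `python_version` with the `Location` path reported by pip shows whether the two disagree.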

In [None]:
import sleap
sleap.disable_preallocation()  # This initializes the GPU and prevents TensorFlow from filling the entire GPU memory
sleap.versions()
sleap.system_summary()

ModuleNotFoundError: No module named 'sleap'

In [None]:
!pip install jsmin --quiet
!pip install imgaug --quiet
!pip install pyzmq --quiet
!pip install tensorflow_hub --quiet
!pip install pykalman --quiet
!pip install seaborn --quiet

  Preparing metadata (setup.py) ... done
  Building wheel for jsmin (setup.py) ... done

In [None]:
!pip install cattrs --quiet

## Download sample training data into Colab
Let's download a sample dataset from the SLEAP [sample datasets repository](https://github.com/talmolab/sleap-datasets) into Colab.

In [None]:
!apt-get install tree
!wget -O dataset.zip https://github.com/talmolab/sleap-datasets/releases/download/dm-courtship-v1/drosophila-melanogaster-courtship.zip
!mkdir dataset
!unzip dataset.zip -d dataset
!rm dataset.zip
!tree dataset

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
The following NEW packages will be installed:
  tree
0 upgraded, 1 newly installed, 0 to remove and 36 not upgraded.
Need to get 47.9 kB of archives.
After this operation, 116 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu jammy/universe amd64 tree amd64 2.0.2-1 [47.9 kB]
Fetched 47.9 kB in 1s (78.4 kB/s)
Selecting previously unselected package tree.
(Reading database ... 126347 files and directories currently installed.)
Preparing to unpack .../tree_2.0.2-1_amd64.deb ...
Unpacking tree (2.0.2-1) ...
Setting up tree (2.0.2-1) ...
Processing triggers for man-db (2.10.2-1) ...
--2025-04-17 18:11:03--  https://github.com/talmolab/sleap-datasets/releases/download/dm-courtship-v1/drosophila-melanogaster-courtship.zip
Resolving github.com (github.com)... 140.82.116.4
Connecting to github.com (github.com)|140.82.116.4|:443... connected.
HTTP request sent, awaiting

## Train models
For the top-down pipeline, we'll need to train two models: a centroid model and a centered-instance model.

Using the command-line interface, we'll first train a model for centroids using the default **training profile**. The training profile determines the model architecture, the learning rate, and other parameters.
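
As a rough illustration of how a `sleap-train` call is put together, the sketch below assembles the argument list in Python. The helper `build_sleap_train_command` is hypothetical, and it uses only the flags that appear in this notebook's cells (`--run_name`, `--video-paths`, `--gpu`).

```python
import shlex


def build_sleap_train_command(profile, labels_path, run_name, video_path=None, gpu=None):
    """Assemble a sleap-train invocation as an argument list (hypothetical helper)."""
    # Positional arguments: training profile first, then the labels file.
    cmd = ["sleap-train", profile, labels_path, "--run_name", run_name]
    if video_path is not None:
        cmd += ["--video-paths", video_path]
    if gpu is not None:
        cmd += ["--gpu", str(gpu)]
    return cmd


centroid_cmd = build_sleap_train_command(
    "baseline.centroid.json",
    "dataset/drosophila-melanogaster-courtship/courtship_labels.slp",
    run_name="courtship.centroid",
    video_path="dataset/drosophila-melanogaster-courtship/20190128_113421.mp4",
)
print(shlex.join(centroid_cmd))
```

The resulting list can be passed to `subprocess.run` outside a notebook, or printed and pasted after `!` in a cell.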

When you start training, you'll first see the training parameters and then the training and validation loss for each training epoch.

As soon as you're satisfied with the validation loss you see for an epoch during training, you're welcome to stop training by clicking the stop button. The version of the model with the lowest validation loss is saved during training, and that's what will be used for inference.

If you don't stop training, it will run for 200 epochs or until validation loss fails to improve for some number of epochs (controlled by the `early_stopping` fields in the training profile).
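
To make the stopping behavior concrete, here is a small simulation of "keep the best checkpoint, stop when validation loss plateaus". It is a sketch of the general technique, not SLEAP's implementation; the parameter names `patience` and `min_delta` are assumptions and may not match the fields in the training profile.

```python
def simulate_training(val_losses, max_epochs=200, patience=5, min_delta=0.0):
    """Return (best_epoch, stop_epoch) for a sequence of per-epoch validation losses."""
    best_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    stop_epoch = None

    for epoch, loss in enumerate(val_losses[:max_epochs]):
        stop_epoch = epoch
        if loss < best_loss - min_delta:
            # New best: this is where a checkpoint would be written.
            best_loss = loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # validation loss has plateaued

    return best_epoch, stop_epoch


losses = [0.9, 0.5, 0.4, 0.41, 0.42, 0.43, 0.44]
print(simulate_training(losses, patience=3))  # (2, 5)
```

In this example the best model comes from epoch 2, while training continues until epoch 5 before the plateau rule ends it.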

In [None]:
import tensorflow as tf
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
#!pip install tensorflow-gpu==2.11.0

Num GPUs Available:  1


In [None]:
import tensorflow as tf
tf.debugging.set_log_device_placement(True)

In [None]:
!nvcc --version
!nvidia-smi

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0
Thu Apr 17 18:32:31 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   37C    P8              9W /   70W |       2MiB /  15360MiB |      0%      Default |
|                       

In [None]:
!apt-get install cuda-11.0

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Note, selecting 'libcuda-11.0-1' for regex 'cuda-11.0'
0 upgraded, 0 newly installed, 0 to remove and 36 not upgraded.


In [None]:
import os
os.environ['MPLBACKEND'] = 'Agg'  # or 'inline' for Jupyter/Colab
#!sleap-train baseline.centroid.json "dataset/drosophila-melanogaster-courtship/courtship_labels.slp" --run_name "courtship.centroid" --video-paths "dataset/drosophila-melanogaster-courtship/20190128_113421.mp4"
!sleap-train baseline.centroid.json "dataset/drosophila-melanogaster-courtship/courtship_labels.slp" --run_name "courtship.centroid" --video-paths "dataset/drosophila-melanogaster-courtship/20190128_113421.mp4" --gpu "0"

2025-04-17 18:43:23.405435: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/lib/python3.10/dist-packages/cv2/../../lib64:/usr/local/lib/python3.11/dist-packages/cv2/../../lib64:/usr/lib64-nvidia
INFO:sleap.nn.training:Versions:
SLEAP: 1.3.3
TensorFlow: 2.8.4
Numpy: 1.22.4
Python: 3.10.12
OS: Linux-6.1.123+-x86_64-with-glibc2.35
INFO:sleap.nn.training:Training labels file: dataset/drosophila-melanogaster-courtship/courtship_labels.slp
INFO:sleap.nn.training:Training profile: /usr/local/lib/python3.10/dist-packages/sleap/training_profiles/baseline.centroid.json
INFO:sleap.nn.training:
INFO:sleap.nn.training:Arguments:
INFO:sleap.nn.training:{
    "training_job_path": "baseline.centroid.json",
    "labels_path": "dataset/drosophila-melanogaster-courtship/courtship_labels.slp",
    "video_paths": [
        "da

In [None]:
import os
print(os.path.exists("dataset/drosophila-melanogaster-courtship/20190128_113421.mp4"))
import sleap
labels = sleap.load_file('dataset/drosophila-melanogaster-courtship/courtship_labels.slp')
print(labels)

True


ModuleNotFoundError: No module named 'sleap'

In [None]:
import numpy as np
import scipy
import sklearn
import tensorflow as tf
print(f"NumPy version: {np.__version__}")
print(f"SciPy version: {scipy.__version__}")
print(f"Scikit-learn version: {sklearn.__version__}")
print(f"Tensorflow version: {tf.__version__}")

NumPy version: 2.0.2
SciPy version: 1.14.1
Scikit-learn version: 1.6.1
Tensorflow version: 2.18.0


In [None]:
# Install Python 3.7
!sudo apt-get update -y
!sudo apt-get install python3.7 python3.7-dev python3.7-distutils python3.7-venv -y

# Register Python 3.7 as an alternative and select it non-interactively
# (`--config` prompts for input, which a notebook cell cannot answer)
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.7 1
!sudo update-alternatives --set python3 /usr/bin/python3.7
!python3 --version  # Verify Python version


Hit:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:4 http://security.ubuntu.com/ubuntu jammy-security InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy-updates InRelease
Hit:6 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Hit:7 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:8 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:9 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Fetched 3,632 B in 1s (2,654 B/s)
Reading package lists... Done
W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)

In [None]:
!python --version

Python 3.10.12


In [None]:
# Install specific versions of NumPy, TensorFlow, and SLEAP
!pip install numpy==1.21.5
!pip install tensorflow==2.7.0
!pip install sleap==1.3.2

Collecting numpy==1.21.5
  Downloading numpy-1.21.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.1 kB)
Downloading numpy-1.21.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.9 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 2.2.4
    Uninstalling numpy-2.2.4:
      Successfully uninstalled numpy-2.2.4
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
contourpy 1.3.2 requires numpy>=1.23, but you have numpy 1.21.5 which is incompatible.
matplotlib 3.10.1 requires numpy>=1.23, but you have numpy 1.21.5 which is incompatible.
pandas 2.2.3 requires numpy>=1.22.4; python_version < "3.11", but you have numpy 1.21.5 which is incompatible.

In [None]:
!pip install tensorflow==2.11.0 keras==2.11.0 --force-reinstall

Collecting tensorflow==2.11.0
  Using cached tensorflow-2.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.1 kB)
Collecting keras==2.11.0
  Using cached keras-2.11.0-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting absl-py>=1.0.0 (from tensorflow==2.11.0)
  Using cached absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting astunparse>=1.6.0 (from tensorflow==2.11.0)
  Using cached astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=2.0 (from tensorflow==2.11.0)
  Using cached flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting gast<=0.4.0,>=0.2.1 (from tensorflow==2.11.0)
  Using cached gast-0.4.0-py3-none-any.whl.metadata (1.1 kB)
Collecting google-pasta>=0.1.1 (from tensorflow==2.11.0)
  Using cached google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting grpcio<2.0,>=1.24.3 (from tensorflow==2.11.0)
  Using cached grpcio-1.71.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.met

Let's now train a centered-instance model.

In [None]:
!sleap-train baseline_medium_rf.topdown.json "dataset/drosophila-melanogaster-courtship/courtship_labels.slp" --run_name "courtship.topdown_confmaps" --video-paths "dataset/drosophila-melanogaster-courtship/20190128_113421.mp4"

INFO:sleap.nn.training:Versions:
SLEAP: 1.3.2
TensorFlow: 2.7.0
Numpy: 1.21.5
Python: 3.7.12
OS: Linux-5.15.0-78-generic-x86_64-with-debian-bookworm-sid
INFO:sleap.nn.training:Training labels file: dataset/drosophila-melanogaster-courtship/courtship_labels.slp
INFO:sleap.nn.training:Training profile: /home/talmolab/sleap-estimates-animal-poses/pull-requests/sleap/sleap/training_profiles/baseline_medium_rf.topdown.json
INFO:sleap.nn.training:
INFO:sleap.nn.training:Arguments:
INFO:sleap.nn.training:{
    "training_job_path": "baseline_medium_rf.topdown.json",
    "labels_path": "dataset/drosophila-melanogaster-courtship/courtship_labels.slp",
    "video_paths": [
        "dataset/drosophila-melanogaster-courtship/20190128_113421.mp4"
    ],
    "val_labels": null,
    "test_labels": null,
    "base_checkpoint": null,
    "tensorboard": false,
    "save_viz": false,
    "zmq": false,
    "run_name": "courtship.topdown_confmaps",
    "prefix": "",
    "suffix": "",
    "cpu": false,
    "

The models (along with the profiles and ground truth data used to train and validate the model) are saved in the `models/` directory:

In [None]:
!tree models/

models/
├── courtship.centroid
│   ├── best_model.h5
│   ├── initial_config.json
│   ├── labels_gt.train.slp
│   ├── labels_gt.val.slp
│   ├── labels_pr.train.slp
│   ├── labels_pr.val.slp
│   ├── metrics.train.npz
│   ├── metrics.val.npz
│   ├── training_config.json
│   └── training_log.csv
└── courtship.topdown_confmaps
    ├── best_model.h5
    ├── initial_config.json
    ├── labels_gt.train.slp
    ├── labels_gt.val.slp
    ├── labels_pr.train.slp
    ├── labels_pr.val.slp
    ├── metrics.train.npz
    ├── metrics.val.npz
    ├── training_config.json
    └── training_log.csv

2 directories, 20 files
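
Each model directory contains a `training_log.csv`. As a sketch of how you might find the epoch with the lowest validation loss, the helper below reads such a log. It assumes the file has `epoch` and `val_loss` columns, which this notebook does not show, and the name `best_epoch_from_log` is hypothetical.

```python
import csv
import io


def best_epoch_from_log(csv_file, metric="val_loss"):
    """Return (epoch, value) for the row with the lowest `metric` in a training log."""
    rows = list(csv.DictReader(csv_file))
    if not rows:
        raise ValueError("training log is empty")
    best = min(rows, key=lambda row: float(row[metric]))
    return int(best["epoch"]), float(best[metric])


# Synthetic log for illustration only; real values come from training_log.csv.
example_log = io.StringIO(
    "epoch,loss,val_loss\n"
    "0,0.90,0.80\n"
    "1,0.50,0.45\n"
    "2,0.40,0.47\n"
)
print(best_epoch_from_log(example_log))  # (1, 0.45)
```

For a real run, open the file with `open("models/courtship.centroid/training_log.csv", newline="")` and pass the file object to the function.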


## Inference
Let's run inference with our trained models for centroids and centered instances.

In [None]:
!sleap-track "dataset/drosophila-melanogaster-courtship/20190128_113421.mp4" --frames 0-100 -m "models/courtship.centroid" -m "models/courtship.topdown_confmaps"

Started inference at: 2023-09-01 13:42:03.066840
Args:
{
│   'data_path': 'dataset/drosophila-melanogaster-courtship/20190128_113421.mp4',
│   'models': [
│   │   'models/courtship.centroid',
│   │   'models/courtship.topdown_confmaps'
│   ],
│   'frames': '0-100',
│   'only_labeled_frames': False,
│   'only_suggested_frames': False,
│   'output': None,
│   'no_empty_frames': False,
│   'verbosity': 'rich',
│   'video.dataset': None,
│   'video.input_format': 'channels_last',
│   'video.index': '',
│   'cpu': False,
│   'first_gpu': False

When inference is finished, predictions are saved in a file. Since we didn't specify a path, it will be saved as `<video filename>.predictions.slp` in the same directory as the video:

In [None]:
!tree dataset/drosophila-melanogaster-courtship

dataset/drosophila-melanogaster-courtship
├── 20190128_113421.mp4
├── 20190128_113421.mp4.predictions.slp
├── courtship_labels.slp
└── example.jpg

0 directories, 4 files
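
The default output name follows the pattern visible in the listing: the full video filename with `.predictions.slp` appended. A minimal sketch that reproduces this naming is shown below; the helper `default_predictions_path` is hypothetical and only mirrors the observed pattern.

```python
from pathlib import Path


def default_predictions_path(video_path):
    """Build '<video filename>.predictions.slp' in the same directory as the video."""
    video_path = Path(video_path)
    # Append to the full filename (extension included) rather than replacing the suffix.
    return video_path.with_name(video_path.name + ".predictions.slp")


print(default_predictions_path(
    "dataset/drosophila-melanogaster-courtship/20190128_113421.mp4"
))
```

This is handy when a later cell needs the predictions path without hard-coding it.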


You can inspect your predictions file using `sleap-inspect`:

In [None]:
!sleap-inspect dataset/drosophila-melanogaster-courtship/20190128_113421.mp4.predictions.slp

Labeled frames: 101
Tracks: 0
Video files:
  dataset/drosophila-melanogaster-courtship/20190128_113421.mp4
    labeled frames: 101
    labeled frames from 0 to 100
    user labeled frames: 0
    tracks: 1
    max instances in frame: 2
Total user labeled frames: 0

Provenance:
  model_paths: ['models/courtship.centroid/training_config.json', 'models/courtship.topdown_confmaps/training_config.json']
  predictor: TopDownPredictor
  sleap_version: 1.3.2
  platform: Linux-5.15.0-78-generic-x86_64-with-debian-bookworm-sid
  command: /home/talmolab/micromamba/envs/s0/bin/sleap-track dataset/drosophila-melanogaster-courtship/20190128_113421.mp4 --frames 0-100 -m models/courtship.centroid -m models/courtship.topdown_confmaps
  data_path: dataset/drosophila-melanogaster-courtship/20190128_113421.mp4
  output_path: dataset/drosophila-melanogaster-courtship/20190128_113421.mp4.predictions.slp
  total_elapsed: 7.775644779205322
  start_timestamp: 2023-09-01 13:42:03.066840
  finish_timestamp: 2023-
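
The report shows 101 labeled frames for `--frames 0-100`, which means the range includes both endpoints. The sketch below expands such a range string; the helper `parse_frame_range` is hypothetical and handles only the `start-end` form used in this notebook.

```python
def parse_frame_range(spec):
    """Expand an inclusive 'start-end' string such as '0-100' into frame indices."""
    start_text, end_text = spec.split("-")
    start, end = int(start_text), int(end_text)
    if end < start:
        raise ValueError(f"end frame {end} precedes start frame {start}")
    return range(start, end + 1)  # +1 because the end frame is included


frames = parse_frame_range("0-100")
print(len(frames), frames[0], frames[-1])  # 101 0 100
```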

If you're using Chrome, you can download your trained models like so:

In [None]:
# Zip up the models directory
!zip -r trained_models.zip models/

# Download.
from google.colab import files
files.download("/content/trained_models.zip")

  adding: models/ (stored 0%)
  adding: models/courtship.topdown_confmaps/ (stored 0%)
  adding: models/courtship.topdown_confmaps/labels_pr.val.slp (deflated 74%)
  adding: models/courtship.topdown_confmaps/metrics.val.npz (deflated 0%)
  adding: models/courtship.topdown_confmaps/labels_pr.train.slp (deflated 67%)
  adding: models/courtship.topdown_confmaps/labels_gt.val.slp (deflated 72%)
  adding: models/courtship.topdown_confmaps/initial_config.json (deflated 73%)
  adding: models/courtship.topdown_confmaps/training_log.csv (deflated 55%)
  adding: models/courtship.topdown_confmaps/metrics.train.npz (deflated 0%)
  adding: models/courtship.topdown_confmaps/labels_gt.train.slp (deflated 61%)
  adding: models/courtship.topdown_confmaps/best_model.h5 (deflated 8%)
  adding: models/courtship.topdown_confmaps/training_config.json (deflated 88%)
  adding: models/courtship.centroid/ (stored 0%)
  adding: models/courtship.centroid/labels_pr.val.slp (deflated 82%)
  adding: models/courtship

And you can likewise download your predictions:

In [None]:
from google.colab import files
files.download('dataset/drosophila-melanogaster-courtship/20190128_113421.mp4.predictions.slp')

In some other browsers (such as Safari), you might get an error. In that case, download the files from the "Files" tab in the side panel (it has a folder icon). Select "Show table of contents" in the "View" menu if you don't see the side panel.