[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/talmolab/dreem/blob/main/examples/quickstart.ipynb)

# DREEM Quickstart: Fly Tracking Demo

Welcome to DREEM! This notebook will guide you through tracking a social interaction between two flies using a pretrained model. By the end, you'll have hands-on experience with the core DREEM workflow.

### What you'll learn:
- How to download sample data and pretrained models
- How to run tracking inference with `dreem-track`
- How to evaluate tracking accuracy with `dreem-eval`
- How to visualize tracking results

### Requirements:
- **Runtime**: ~5-10 minutes total
- **Hardware**: CPU is sufficient (GPU optional but faster)
- **Data**: We provide sample fly videos from the [SLEAP fly32 dataset](https://sleap.ai/datasets.html#fly32)


---
## Step 1: Install Dependencies

First, we'll install DREEM and the required packages. This may take a few minutes.


In [None]:
# Install DREEM and dependencies
%pip install dreem-tracker huggingface_hub opencv-python-headless

# Install ffmpeg for video visualization (Colab-specific)
!apt-get install -y ffmpeg 2>/dev/null || echo "ffmpeg already installed or not on Linux"


### Check available hardware

Let's detect what compute device is available. DREEM works on CPU, CUDA (NVIDIA GPUs), or MPS (Apple Silicon).


In [None]:
import torch

cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

if cuda_available:
    accelerator = "cuda"
    print(f"GPU: {torch.cuda.get_device_name(0)}")
elif torch.backends.mps.is_available():
    accelerator = "mps"
    print("Using Apple Silicon (MPS)")
else:
    accelerator = "cpu"
    print("Using CPU (tracking will be slower but still works!)")

print(f"\n✓ Device selected: {accelerator}")
torch.set_float32_matmul_precision("medium")


---
## Step 2: Download Sample Data

We'll download a sample fly tracking dataset from Hugging Face. This includes:
- Video files (`.mp4`) of fly interactions
- Detection files (`.slp`) with pose keypoints from SLEAP
- Configuration files for running inference


In [None]:
!huggingface-cli download talmolab/sample-flies --repo-type dataset --local-dir ./data


### Expected directory structure

After downloading, your data folder should look like this:

```
./data
    /test
        190719_090330_wt_18159206_rig1.2@15000-17560.mp4
        GT_190719_090330_wt_18159206_rig1.2@15000-17560.slp
    /train
        190612_110405_wt_18159111_rig2.2@4427.mp4
        GT_190612_110405_wt_18159111_rig2.2@4427.slp
    /val
        two_flies.mp4
        GT_two_flies.slp
    /inference
        190719_090330_wt_18159206_rig1.2@15000-17560.mp4
        190719_090330_wt_18159206_rig1.2@15000-17560.slp
    /configs
        inference.yaml
        base.yaml
        eval.yaml
```


In [None]:
# Verify the data downloaded correctly
import os

expected_dirs = ['test', 'train', 'val', 'inference', 'configs']
missing = [d for d in expected_dirs if not os.path.exists(f'./data/{d}')]

if missing:
    print(f"⚠ Missing directories: {missing}")
    print("Please re-run the download cell above.")
else:
    print("✓ Data downloaded successfully!")
    print("\nContents:")
    for d in expected_dirs:
        files = os.listdir(f'./data/{d}')
        print(f"  ./data/{d}/: {len(files)} files")
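
The check above only confirms that the folders exist. As an optional extra, the sketch below looks for videos that have no matching `.slp` file. It assumes the naming convention visible in the tree above (a video `name.mp4` sits next to `name.slp` or `GT_name.slp`); the helper name `find_unpaired_videos` is made up for this example and is not part of DREEM.

```python
from pathlib import Path


def find_unpaired_videos(data_root, splits=("test", "train", "val", "inference")):
    """Return paths of videos with no matching .slp file in the same folder.

    Assumption: a video `name.mp4` is paired with `name.slp` or `GT_name.slp`.
    """
    unpaired = []
    for split in splits:
        split_dir = Path(data_root) / split
        if not split_dir.is_dir():
            continue  # missing folders are already reported by the check above
        for video in sorted(split_dir.glob("*.mp4")):
            candidates = [
                split_dir / f"{video.stem}.slp",
                split_dir / f"GT_{video.stem}.slp",
            ]
            if not any(c.exists() for c in candidates):
                unpaired.append(str(video))
    return unpaired


# An empty list means every video found has a detection file next to it.
print(find_unpaired_videos("./data"))
```

If the list is not empty, re-running the download cell is the simplest fix.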


---
## Step 3: Download Pretrained Model

Now we'll download a pretrained DREEM model. This model was trained on data from several species (mice, flies, zebrafish) and generalizes well to new animal tracking tasks.


In [None]:
!huggingface-cli download talmolab/animals-pretrained animals-pretrained.ckpt --local-dir=./models


In [None]:
# Verify the model downloaded correctly
model_path = "./models/animals-pretrained.ckpt"

if os.path.exists(model_path):
    size_mb = os.path.getsize(model_path) / (1024 * 1024)
    print("✓ Model downloaded successfully!")
    print(f"  Path: {model_path}")
    print(f"  Size: {size_mb:.1f} MB")
else:
    print("⚠ Model not found. Please re-run the download cell above.")


---
## Step 4: Run Tracking

Now for the main event! We'll use `dreem-track` to run tracking inference on our fly video.

**What happens during tracking:**
1. DREEM loads the video and detections (pose keypoints)
2. For each frame, it extracts visual features around each detection
3. The transformer model associates detections across frames to form tracks
4. Results are saved as a new `.slp` file with track IDs assigned
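
To make step 3 more concrete, here is a toy illustration of the association idea. It is not DREEM's model: DREEM uses a transformer to score matches, whereas this sketch uses hand-made feature vectors, cosine similarity, and an exhaustive search over assignments (practical only for a handful of animals). The function names and the feature values are invented for the example.

```python
import math
from itertools import permutations


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def associate(prev_features, curr_features):
    """Assign a previous track ID to each current detection.

    prev_features: dict of track_id -> feature vector from the previous frame
    curr_features: list of feature vectors, one per current detection
    Returns a list of track IDs aligned with curr_features, chosen to
    maximize the total similarity over all one-to-one assignments.
    """
    track_ids = list(prev_features)
    if len(track_ids) != len(curr_features):
        raise ValueError("toy example: needs equal numbers of tracks and detections")
    best_score, best_order = -math.inf, None
    for order in permutations(track_ids):
        score = sum(
            cosine_similarity(prev_features[tid], feat)
            for tid, feat in zip(order, curr_features)
        )
        if score > best_score:
            best_score, best_order = score, order
    return list(best_order)


# Two flies whose detections arrive in swapped order in the next frame.
prev = {0: [1.0, 0.1], 1: [0.1, 1.0]}
curr = [[0.2, 0.9], [0.9, 0.2]]
print(associate(prev, curr))  # → [1, 0]
```

The first current detection looks most like track 1 and the second like track 0, so the IDs follow appearance rather than detection order.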


In [None]:
!dreem-track --config-dir=./data/configs --config-name=inference ckpt_path=./models/animals-pretrained.ckpt


In [None]:
# Check that results were generated
results_dir = "./results"
if os.path.exists(results_dir):
    result_files = [f for f in os.listdir(results_dir) if f.endswith('.slp')]
    if result_files:
        print("✓ Tracking complete! Output files:")
        for f in result_files:
            print(f"  {results_dir}/{f}")
        # Store the latest result file path for later use
        result_file = sorted(result_files)[-1]
        result_path = os.path.join(results_dir, result_file)
    else:
        print("⚠ No .slp files found in results directory")
else:
    print("⚠ Results directory not found. Check the tracking output above for errors.")
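
The cell above treats the last filename in sorted order as the latest result, which holds only if output names sort chronologically. If that is not the case in your setup, a minimal alternative is to pick the file by modification time. The helper name `latest_file` is hypothetical, not a DREEM function.

```python
import os


def latest_file(directory, suffix=".slp"):
    """Return the most recently modified file with the given suffix, or None."""
    if not os.path.isdir(directory):
        return None
    paths = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(suffix)
    ]
    # max() over modification times selects the newest file
    return max(paths, key=os.path.getmtime) if paths else None


print(latest_file("./results"))
```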


---
## Step 5: Evaluate Tracking Accuracy (Optional)

If you have ground truth labels, you can run `dreem-eval` instead of `dreem-track`. In addition to tracking, it computes metrics such as:
- **MOTA** (Multiple Object Tracking Accuracy)
- **IDF1** (Identification F1 score)
- **Number of ID switches**

The eval config points to test data with ground truth labels (files prefixed with `GT_`).
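
For intuition about these numbers, the sketch below computes MOTA from raw counts using the standard CLEAR MOT definition, MOTA = 1 - (FN + FP + IDSW) / GT, and counts ID switches for a made-up two-fly sequence. It is a hand-rolled illustration with invented data; `dreem-eval` has its own implementation, which may differ in matching details.

```python
def mota(false_negatives, false_positives, id_switches, num_gt):
    """MOTA = 1 - (FN + FP + IDSW) / GT, with counts summed over all frames."""
    if num_gt <= 0:
        raise ValueError("num_gt must be positive")
    return 1.0 - (false_negatives + false_positives + id_switches) / num_gt


def count_id_switches(assignments):
    """Count ID switches for one ground-truth animal.

    assignments: the predicted track ID matched to this animal in each frame,
    or None when the animal was missed in that frame.
    """
    switches = 0
    last_id = None
    for pred_id in assignments:
        if pred_id is None:
            continue  # a miss is a false negative, not a switch
        if last_id is not None and pred_id != last_id:
            switches += 1
        last_id = pred_id
    return switches


# Invented example: the two flies swap IDs once, and fly A is missed in one frame.
fly_a = [0, 0, 0, 1, 1, None, 1]
fly_b = [1, 1, 1, 0, 0, 0, 0]
switches = count_id_switches(fly_a) + count_id_switches(fly_b)
num_gt = len(fly_a) + len(fly_b)
misses = sum(p is None for p in fly_a + fly_b)
print(switches, round(mota(misses, 0, switches, num_gt), 3))  # → 2 0.786
```

A single identity swap between two animals counts as two switches, one per ground-truth track.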


In [None]:
!dreem-eval --config-dir=./data/configs --config-name=eval ckpt_path=./models/animals-pretrained.ckpt


---
## Step 6: Visualize Results

Let's visualize the tracking results! We'll create an animation showing the tracked flies with their assigned IDs.

> **Note**: For the best visualization experience with full pose keypoints, you can open the `.slp` file in the SLEAP GUI locally:
> ```bash
> sleap-label results/<your_output_file>.slp
> ```
> The SLEAP GUI won't render in Colab, but the animation below gives you a quick preview.


In [None]:
import numpy as np
import pandas as pd
import sleap_io as sio
import cv2
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.patches import Circle
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display


### Load tracking results


In [None]:
# Find the latest tracking result
results_dir = "./results"
result_files = sorted([f for f in os.listdir(results_dir) if f.endswith('.slp') and 'dreem_inference' in f])
result_path = os.path.join(results_dir, result_files[-1])
print(f"Loading results from: {result_path}")

# Load the predictions
pred_slp = sio.load_slp(result_path)
print(f"Loaded {len(pred_slp)} frames with tracking results")


In [None]:
# Convert predictions to a DataFrame for visualization
list_frames = []
for lf in pred_slp:
    for instance in lf.instances:
        centroid = np.nanmean(instance.numpy(), axis=0)
        track_id = int(instance.track.name) if instance.track else -1
        list_frames.append({
            "frame_id": lf.frame_idx,
            "track_id": track_id,
            "centroid": centroid
        })
df = pd.DataFrame(list_frames)
print(f"Found {df['track_id'].nunique()} unique tracks across {df['frame_id'].nunique()} frames")
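
Before animating, it can help to check how continuous each track is. The sketch below summarizes, per track, the first and last frame and how many frames in between have no detection. It works on plain records shaped like the `list_frames` entries built above; the helper name `summarize_tracks` and the demo records are made up for illustration.

```python
from collections import defaultdict


def summarize_tracks(records):
    """Summarize each track: first frame, last frame, frames present, gaps.

    records: iterable of dicts with "frame_id" and "track_id" keys.
    """
    frames_by_track = defaultdict(set)
    for rec in records:
        frames_by_track[rec["track_id"]].add(rec["frame_id"])

    summary = {}
    for track_id, frames in sorted(frames_by_track.items()):
        first, last = min(frames), max(frames)
        summary[track_id] = {
            "first": first,
            "last": last,
            "n_frames": len(frames),
            # frames inside [first, last] where this track has no detection
            "n_missing": (last - first + 1) - len(frames),
        }
    return summary


# Invented records: track 0 skips frame 2, track 1 is continuous.
demo = [
    {"frame_id": 0, "track_id": 0},
    {"frame_id": 1, "track_id": 0},
    {"frame_id": 3, "track_id": 0},
    {"frame_id": 0, "track_id": 1},
    {"frame_id": 1, "track_id": 1},
]
for track_id, stats in summarize_tracks(demo).items():
    print(track_id, stats)
```

In the notebook, calling `summarize_tracks(list_frames)` would give the same summary for the real results.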


### Create tracking animation


In [None]:
def create_tracking_animation(video_path, metadata_df, fps=15, marker_size=15, max_frames=200, display_width=600):
    """Create and display a tracking animation in the notebook."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError(f"Could not open video file: {video_path}")
    
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    # Create colormap for track IDs
    unique_ids = metadata_df["track_id"].unique()
    # plt.get_cmap is used because cm.get_cmap was removed in Matplotlib 3.9
    cmap = plt.get_cmap("tab10", len(unique_ids))
    id_to_color = {id_val: cmap(i) for i, id_val in enumerate(unique_ids)}
    
    # Setup figure
    fig_width = display_width / 100
    fig_height = fig_width * (height / width)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    
    frame_img = ax.imshow(np.zeros((height, width, 3), dtype=np.uint8))
    markers, texts = [], []
    
    frame_ids = sorted(metadata_df["frame_id"].unique())
    if max_frames and max_frames < len(frame_ids):
        frame_ids = frame_ids[:max_frames]
        print(f"Showing first {max_frames} frames")
    
    def update(frame_num):
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()
        if not ret:
            return []
        
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_img.set_array(frame_rgb)
        
        for m in markers: m.remove()
        for t in texts: t.remove()
        markers.clear()
        texts.clear()
        
        frame_data = metadata_df[metadata_df["frame_id"] == frame_num]
        for _, row in frame_data.iterrows():
            x, y = row["centroid"]
            color = id_to_color[row["track_id"]]
            circle = Circle((x, y), marker_size, color=color, alpha=0.7)
            markers.append(ax.add_patch(circle))
            text = ax.text(x, y, str(row["track_id"]), color="white", fontsize=8, 
                          ha="center", va="center", fontweight="bold")
            texts.append(text)
        
        frame_text = ax.text(10, 20, f"Frame: {frame_num}", color="white", 
                            fontsize=8, backgroundcolor="black")
        texts.append(frame_text)
        return [frame_img] + markers + texts
    
    ax.set_xlim(0, width)
    ax.set_ylim(height, 0)
    ax.axis("off")
    
    print(f"Creating animation with {len(frame_ids)} frames...")
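    # interval is in milliseconds per frame, so the fps argument controls playback speed
    anim = FuncAnimation(fig, update, frames=frame_ids, interval=1000 / fps, blit=True)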
    plt.close(fig)
    
    display(HTML(anim.to_html5_video()))
    cap.release()
    return anim


In [None]:
# Find the video file used for inference
video_dir = "./data/inference"
video_files = [f for f in os.listdir(video_dir) if f.endswith('.mp4')]
video_path = os.path.join(video_dir, video_files[0])
print(f"Video: {video_path}")

# Create the animation
anim = create_tracking_animation(
    video_path=video_path,
    metadata_df=df,
    fps=15,
    marker_size=15,
    max_frames=200
)


---
## Next Steps

Congratulations! You've successfully run DREEM tracking on fly data. Here's where to go next:

### Dive Deeper
- **[End-to-End Demo](https://colab.research.google.com/github/talmolab/dreem/blob/main/examples/dreem-demo.ipynb)**: Train your own model, run inference, and evaluate results
- **[Microscopy Demo](https://colab.research.google.com/github/talmolab/dreem/blob/main/examples/microscopy-demo-simple.ipynb)**: Track cells in microscopy data

### Documentation
- **[Usage Guide](https://dreem.sleap.ai/usage/)**: Full CLI reference and configuration options
- **[Configuration Reference](https://dreem.sleap.ai/configs/)**: Customize training and inference parameters
- **[API Reference](https://dreem.sleap.ai/reference/dreem/)**: Python API documentation

### Use Your Own Data
The pretrained animals model works with any SLEAP detections! Just:
1. Generate pose predictions with [SLEAP](https://sleap.ai/)
2. Update the config to point to your video and `.slp` files
3. Run `dreem-track` as shown above

Happy tracking! 🪰
