# Segmenting *C. elegans* in videos with SAM 2

This notebook shows how to use SAM 2 for interactive segmentation of worms in videos. It is adapted from the `video_predictor_example` notebook in the SAM 2 repository and covers the following:

- adding clicks on a frame to get and refine _masklets_ (spatio-temporal masks)
- propagating clicks to get _masklets_ throughout the video
- segmenting and tracking multiple objects at the same time

We use the terms _segment_ or _mask_ to refer to the model prediction for an object on a single frame, and _masklet_ to refer to the spatio-temporal masks across the entire video.

## Environment Set-up

If running locally in Jupyter, first install `sam2` in your environment using the [installation instructions](https://github.com/facebookresearch/sam2#installation) in the repository.

If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'. Note that it's recommended to use **A100 or L4 GPUs when running in Colab** (T4 GPUs might also work, but could be slow and might run out of memory in some cases).

We recommend Google Colab unless you have a powerful GPU locally.

<a href="https://colab.research.google.com/github/pwetterauer/Notebooks/blob/main/segment_worm_video_with_sam2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
using_colab = True

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/sam2.git'

    !mkdir -p ../checkpoints/
    !wget -P ../checkpoints/ https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

## Set-up


In [3]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on TensorFloat-32 (TF32) for Ampere or newer GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

### Loading the SAM 2 video predictor


In [5]:
from sam2.build_sam import build_sam2_video_predictor

sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)
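On GPUs with less memory (such as a T4), a smaller SAM 2.1 model may be a better fit. The checkpoint and config names below are as listed in the SAM 2 repository README at the time of writing (check it for updates). `sam21_paths` is a hypothetical helper: it only builds strings and does not download anything.

```python
# Download location used in the set-up cell of this notebook
SAM21_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/"

# model size -> (checkpoint file, config), as listed in the SAM 2 README (may change)
SAM21_MODELS = {
    "tiny": ("sam2.1_hiera_tiny.pt", "configs/sam2.1/sam2.1_hiera_t.yaml"),
    "small": ("sam2.1_hiera_small.pt", "configs/sam2.1/sam2.1_hiera_s.yaml"),
    "base_plus": ("sam2.1_hiera_base_plus.pt", "configs/sam2.1/sam2.1_hiera_b+.yaml"),
    "large": ("sam2.1_hiera_large.pt", "configs/sam2.1/sam2.1_hiera_l.yaml"),
}


def sam21_paths(size, checkpoint_dir="../checkpoints"):
    """Return (download URL, local checkpoint path, config name) for a model size."""
    ckpt, cfg = SAM21_MODELS[size]
    return SAM21_BASE_URL + ckpt, f"{checkpoint_dir}/{ckpt}", cfg
```

For example, `url, ckpt, cfg = sam21_paths("small")`, then download `url` with `wget` as in the set-up cell and call `build_sam2_video_predictor(cfg, ckpt, device=device)`. Smaller models use less memory but may give less accurate masks.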

In [6]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

#### Select an example video


We assume that the video is stored as a list of JPEG frames with filenames like `<frame_index>.jpg`.

You can extract JPEG frames from a video file with ffmpeg (https://ffmpeg.org/) as follows:
```
ffmpeg -i <your_video>.mp4 -q:v 2 -start_number 0 <output_dir>/'%05d.jpg'
```
where `-q:v 2` produces high-quality JPEG frames (higher numbers mean lower quality) and `-start_number 0` makes ffmpeg number the output files starting from `00000.jpg`.
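The frame-listing cell further down sorts frames with `int(os.path.splitext(p)[0])`, which fails on names such as `frame_001.jpg`. A small, stdlib-only sketch that checks a folder listing before loading: it additionally requires the indices to run 0, 1, 2, ... without gaps, so that the number in each file name equals the position used as `frame_idx` (an assumption about how you want to index frames, not a SAM 2 requirement). The function name is hypothetical.

```python
import os

# Extensions accepted by the frame-listing cell in this notebook
JPEG_EXTS = {".jpg", ".jpeg", ".JPG", ".JPEG"}


def sorted_frame_names(file_names):
    """Return JPEG frame names sorted by frame index.

    Raises ValueError if a JPEG name is not a plain integer (e.g. "frame_001.jpg")
    or if the indices do not run 0, 1, 2, ... without gaps.
    """
    frames = []
    for name in file_names:
        stem, ext = os.path.splitext(name)
        if ext not in JPEG_EXTS:
            continue  # ignore non-JPEG files such as notes or .DS_Store
        if not stem.isdigit():
            raise ValueError(f"frame name is not an integer index: {name!r}")
        frames.append((int(stem), name))
    frames.sort()
    indices = [idx for idx, _ in frames]
    if indices != list(range(len(indices))):
        raise ValueError("frame indices must run 0, 1, 2, ... without gaps")
    return [name for _, name in frames]


print(sorted_frame_names(["00002.jpg", "00000.jpg", "notes.txt", "00001.jpg"]))
# ['00000.jpg', '00001.jpg', '00002.jpg']
```

In the notebook you could call it as `frame_names = sorted_frame_names(os.listdir(video_dir))` to get an early, readable error instead of a failure during loading.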

Upload the extracted frames to a folder called `video` in your Google Drive (i.e. `MyDrive/video`), or adjust the path to the frames in the code. Then mount your Drive in Google Colab by running the cell below.


In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# `video_dir`: a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = "/content/drive/MyDrive/video"

# scan all the JPEG frame names in this directory
frame_names = [
    p for p in os.listdir(video_dir)
    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

# take a look at the first video frame
frame_idx = 0
plt.figure(figsize=(9, 6))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))

#### Initialize the inference state


SAM 2 requires stateful inference for interactive video segmentation, so we need to initialize an **inference state** on this video.

During initialization, it loads all the JPEG frames in `video_path` (here `video_dir`) and stores their pixels in `inference_state` (as shown in the progress bar below).
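Because every frame is held in memory, long videos can exhaust GPU memory. A rough estimate, assuming (as in the SAM 2 frame loader at the time of writing) that each frame is resized to 1024 x 1024 pixels and stored as three float32 channels. It ignores the model weights and the features cached during tracking, so treat it as a lower bound. The helper name is hypothetical.

```python
def estimate_video_tensor_bytes(num_frames, image_size=1024, channels=3, bytes_per_value=4):
    """Approximate size of the frame tensor created by init_state.

    Assumes each frame is resized to image_size x image_size with `channels`
    colour channels stored as float32 (4 bytes per value).
    """
    return num_frames * channels * image_size * image_size * bytes_per_value


print(estimate_video_tensor_bytes(1) / 2**20, "MiB per frame")                    # 12.0 MiB per frame
print(round(estimate_video_tensor_bytes(500) / 2**30, 2), "GiB for 500 frames")  # 5.86 GiB for 500 frames
```

If the estimate approaches your GPU memory, `init_state` also accepts `offload_video_to_cpu=True` (and `offload_state_to_cpu=True`) to keep these tensors in CPU memory, at some cost in speed.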


In [None]:
inference_state = predictor.init_state(video_path=video_dir)

## Track several worms at once

Add position data for all worms. Use e.g. ImageJ to find the pixel coordinates of points inside each worm you want to track.

`worm_points` should contain one array per worm with the coordinates of its points. Each worm needs at least one point, but you can add more. You can also add points outside the worm ('negative points') to better define its boundaries. Each point is given as `[x-coordinate, y-coordinate]` in pixels. Because `worm_points` is built as a single NumPy array, every worm must have the same number of points.

`worm_labels` should contain one array per worm, with 1 for each 'positive point' and 0 for each 'negative point', in the same order as the points.

Adjust `nr_of_worms` accordingly!
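If your worms need different numbers of points, a single stacked array will not work. A hedged alternative sketch keeps one `(points, labels)` pair per worm in a plain list and checks it; `build_prompts` and `prompts_list` are hypothetical names, and the dtypes match the ones used in the example cell.

```python
import numpy as np


def build_prompts(points_per_worm, labels_per_worm):
    """Convert per-worm point and label lists into arrays for SAM 2.

    Each worm may have a different number of points, so the result is a list
    of (points, labels) tuples, one per worm, instead of one stacked array.
    """
    if len(points_per_worm) != len(labels_per_worm):
        raise ValueError("need one label list per worm")
    prompts = []
    for i, (pts, lbls) in enumerate(zip(points_per_worm, labels_per_worm)):
        pts = np.asarray(pts, dtype=np.float32).reshape(-1, 2)  # (N, 2) as [x, y]
        lbls = np.asarray(lbls, dtype=np.int32).reshape(-1)     # (N,) of 1 or 0
        if len(pts) == 0 or len(pts) != len(lbls):
            raise ValueError(f"worm {i}: need >= 1 point and one label per point")
        if not np.isin(lbls, (0, 1)).all():
            raise ValueError(f"worm {i}: labels must be 1 (positive) or 0 (negative)")
        prompts.append((pts, lbls))
    return prompts


# worm 1 has three points, worm 2 only one positive point
prompts_list = build_prompts(
    [[[318, 154], [312, 170], [318, 167]], [[276, 294]]],
    [[1, 1, 0], [1]],
)
print([p.shape for p, _ in prompts_list])  # [(3, 2), (1, 2)]
```

With this layout, the prompt loop would use `points, labels = prompts_list[i]` and `nr_of_worms = len(prompts_list)`.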

In [None]:
# this example contains data for 4 worms in frame 0. For each worm, two 'positive points' and one 'negative point' are given.
ann_frame_idx = 0  # the frame index we interact with
nr_of_worms = 4    # number of worms
worm_points = np.array([[[318, 154], [312, 170], [318, 167]],
                        [[276, 294], [286, 322], [279, 316]],
                        [[415, 300], [419, 325], [426, 324]],
                        [[578, 237], [596, 266], [593, 252]]
                        ],dtype=np.float32)
worm_labels = np.array([[1,1,0],
                        [1,1,0],
                        [1,1,0],
                        [1,1,0]
                        ],np.int32)

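# check that there is one set of points and labels per worm, with one label per point
if (worm_points.shape[0] == nr_of_worms
        and worm_labels.shape[0] == nr_of_worms
        and worm_points.shape[1] == worm_labels.shape[1]):
    print("Data ok!")
else:
    print("`worm_points` and `worm_labels` should contain one array for each worm, with one label per point! Check the number of worms and added points.")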

In [None]:
prompts = {}  # hold all the data we add for visualization

for i in range(nr_of_worms):
    ann_obj_id = i+1  # give a unique id to each object we interact with

    # points for this worm
    points = worm_points[i]
    # corresponding labels (1 = positive, 0 = negative)
    labels = worm_labels[i]
    prompts[ann_obj_id] = points, labels
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=ann_frame_idx,
        obj_id=ann_obj_id,
        points=points,
        labels=labels,
    )

# show the results for all worms on the annotated frame
# (the last call returns the masks of every object added so far)
plt.figure(figsize=(9, 6))
plt.title(f"frame {ann_frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[ann_frame_idx])))
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)
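To refine a single worm, you can add more clicks for its `obj_id` on the same frame. By default `add_new_points_or_box` replaces the earlier clicks for that object; with `clear_old_points=False` (a parameter of the SAM 2 video predictor at the time of writing) the new clicks are added to the existing ones. A hedged sketch; `add_correction_clicks` is a hypothetical helper.

```python
import numpy as np


def add_correction_clicks(predictor, inference_state, frame_idx, obj_id, points, labels):
    """Add extra clicks to one worm without discarding its earlier clicks.

    points: (N, 2) array-like of [x, y] pixels; labels: (N,) with 1 = positive,
    0 = negative. Returns the boolean mask of `obj_id` on `frame_idx`.
    """
    _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=frame_idx,
        obj_id=obj_id,
        points=np.asarray(points, dtype=np.float32),
        labels=np.asarray(labels, dtype=np.int32),
        clear_old_points=False,  # keep the clicks already given for this worm
    )
    # the call returns masks for all objects; pick the one we just refined
    i = list(out_obj_ids).index(obj_id)
    return (out_mask_logits[i] > 0.0).cpu().numpy()
```

For example, `mask = add_correction_clicks(predictor, inference_state, ann_frame_idx, 3, [[430, 330]], [0])` would add a negative click at made-up coordinates to worm 3. Remember to update `prompts` if you want the new clicks drawn, and to re-run propagation afterwards. To start over with no clicks at all, `predictor.reset_state(inference_state)` clears them.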

## Propagate predicted masks to whole video


If the results look good on the annotated frame, propagate them through the whole video. This can take a while, depending on the video length and the GPU.
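The propagation cell below runs forward only, starting from the earliest annotated frame. If you set `ann_frame_idx` to a later frame, the frames before it get no masks, and the visualisation loop (which starts at frame 0) fails with a `KeyError`. A sketch of a helper that adds a backward pass, assuming the `start_frame_idx` and `reverse` arguments of `propagate_in_video` as found in the SAM 2 video predictor at the time of writing; `propagate_both_ways` is a hypothetical name.

```python
def propagate_both_ways(predictor, inference_state, ann_frame_idx):
    """Propagate masks forward and, if needed, backward from the annotated frame.

    Returns {frame_idx: {obj_id: boolean mask}}, like `video_segments`.
    Assumes propagate_in_video() accepts `start_frame_idx` and `reverse`.
    """
    video_segments = {}
    # a backward pass is only needed when there are frames before the annotated one
    directions = [False] if ann_frame_idx == 0 else [False, True]
    for reverse in directions:
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
            inference_state, start_frame_idx=ann_frame_idx, reverse=reverse
        ):
            video_segments[out_frame_idx] = {
                out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }
    return video_segments
```

It would replace the propagation loop with `video_segments = propagate_both_ways(predictor, inference_state, ann_frame_idx)`. The annotated frame is produced by both passes; the dict simply keeps the later result.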

In [None]:
# run propagation throughout the video and collect the results in a dict
video_segments = {}  # video_segments contains the per-frame segmentation results
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
    video_segments[out_frame_idx] = {
        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
        for i, out_obj_id in enumerate(out_obj_ids)
    }

# render the segmentation results every few frames
vis_frame_stride = 30
plt.close("all")
for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, plt.gca(), obj_id=out_obj_id)
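Looking at every 30th frame can miss short tracking failures, for example when worms cross. A hedged sketch that scans all of `video_segments` and flags frames where a worm's mask is (nearly) empty or two worms' masks overlap strongly; `flag_tracking_problems` is a hypothetical name and both thresholds are arbitrary starting points, not values from SAM 2. Worms that touch may legitimately overlap a little, so flagged frames are candidates for inspection, not proven errors.

```python
from itertools import combinations

import numpy as np


def flag_tracking_problems(video_segments, min_area=1, max_iou=0.1):
    """Return (frame_idx, message) pairs for frames that may need more clicks.

    A worm is flagged when its mask has fewer than `min_area` pixels (possibly
    lost); a pair of worms when their masks overlap with an intersection over
    union (IoU) above `max_iou` (possibly merged or swapped).
    """
    problems = []
    for frame_idx in sorted(video_segments):
        # (1, H, W) boolean masks -> (H, W)
        masks = {
            obj_id: np.asarray(m, dtype=bool).reshape(np.shape(m)[-2:])
            for obj_id, m in video_segments[frame_idx].items()
        }
        for obj_id in sorted(masks):
            area = int(masks[obj_id].sum())
            if area < min_area:
                problems.append((frame_idx, f"worm {obj_id}: mask has only {area} px"))
        for a, b in combinations(sorted(masks), 2):
            union = np.logical_or(masks[a], masks[b]).sum()
            if union == 0:
                continue
            iou = np.logical_and(masks[a], masks[b]).sum() / union
            if iou > max_iou:
                problems.append((frame_idx, f"worms {a} and {b} overlap (IoU {iou:.2f})"))
    return problems
```

Usage: `for frame_idx, msg in flag_tracking_problems(video_segments): print(frame_idx, msg)`. Frames it reports are good places to add correction clicks before propagating again.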

## Save the results

The segmentation masks can be stored in a `result.pk` file, which can be analysed further with Python scripts, or as PNG images. Additionally, figures visualising the result on single frames can be saved. All files are saved in the folder specified in `save_dir`.

In [None]:
save_dir = "/content/drive/MyDrive"
os.makedirs(save_dir, exist_ok=True)  # create the folder if it does not exist yet

In [None]:
import pickle

In [None]:
# save results for further analysis in python
with open(os.path.join(save_dir,'result.pk'), 'wb') as f:
    pickle.dump(video_segments, f)
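Pickle files are convenient, but they should only be loaded from trusted sources and are tied to Python. A hedged alternative sketch stores the masks as compressed NumPy arrays, one stack per worm; the function name and the file layout (`frame_idx` plus `id<obj_id>` arrays) are illustrative choices, not part of SAM 2.

```python
import numpy as np


def save_masks_npz(video_segments, path):
    """Save all masks to one compressed .npz file.

    Stores `frame_idx` (sorted frame indices) and one boolean array per worm,
    named "id<obj_id>", with shape (num_frames, H, W). Assumes every worm has
    a mask in every frame, as in the output of propagate_in_video().
    """
    frame_ids = sorted(video_segments)
    obj_ids = sorted({obj_id for f in frame_ids for obj_id in video_segments[f]})
    arrays = {}
    for obj_id in obj_ids:
        masks = [np.asarray(video_segments[f][obj_id], dtype=bool) for f in frame_ids]
        # drop the leading singleton axis: (1, H, W) -> (H, W)
        arrays[f"id{obj_id}"] = np.stack([m.reshape(m.shape[-2:]) for m in masks])
    np.savez_compressed(path, frame_idx=np.array(frame_ids), **arrays)
```

For example, `save_masks_npz(video_segments, os.path.join(save_dir, "masks.npz"))`; later, `data = np.load(os.path.join(save_dir, "masks.npz"))` and `data["id1"]` gives all masks of worm 1.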

In [None]:
# save each mask as a PNG image, one folder per worm (id1, id2, ...)
# masks are inverted: the worm is black on a white background
for frame_nr, frame_masks in video_segments.items():
    for out_obj_id, out_mask in frame_masks.items():
        obj_dir = os.path.join(save_dir, f"id{out_obj_id}")
        os.makedirs(obj_dir, exist_ok=True)
        mask = Image.fromarray(np.invert(out_mask).squeeze())
        mask.save(os.path.join(obj_dir, f"{frame_nr}.png"))

In [None]:
# show a frame with the masks overlaid and save it as a PNG file
frame_nr = 450    # index of the frame to save (must be smaller than the number of frames)

img = Image.open(os.path.join(video_dir, frame_names[frame_nr]))
ratio = img.width/img.height
fig = plt.figure(figsize=(12, 12/ratio+0.3))
ax = plt.Axes(fig, [0., -0.02, 1., 1.])
ax.set_title(f"frame {frame_nr}")
ax.set_axis_off()
fig.add_axes(ax)
ax.imshow(img)
for out_obj_id, out_mask in video_segments[frame_nr].items():
    show_mask(out_mask, ax, obj_id=out_obj_id)

fig.savefig(os.path.join(save_dir,"frame_with_masks.png"))

## Load the saved results

Use the following cell to load the pickled results back into the `video_segments` variable, for example in a new session.

In [None]:
import pickle

# adjust the path if you saved the results elsewhere
with open('/content/drive/MyDrive/result.pk', 'rb') as f:
    video_segments = pickle.load(f)
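After loading, `video_segments` maps each frame index to a dict `{worm id: boolean mask of shape (1, H, W)}`. As one example of further analysis, a sketch that extracts each worm's area in pixels and its centroid per frame. The function name is hypothetical, and converting to physical units needs your pixel size, which this notebook does not record.

```python
import numpy as np


def worm_area_and_centroid(video_segments):
    """Per-worm area (pixels) and centroid (x, y) in every frame.

    Returns {obj_id: [(frame_idx, area, cx, cy), ...]}; cx and cy are NaN
    when the mask is empty.
    """
    tracks = {}
    for frame_idx in sorted(video_segments):
        for obj_id, mask in video_segments[frame_idx].items():
            m = np.asarray(mask, dtype=bool)
            m = m.reshape(m.shape[-2:])   # (1, H, W) -> (H, W)
            ys, xs = np.nonzero(m)        # rows are y, columns are x
            area = int(m.sum())
            cx, cy = (xs.mean(), ys.mean()) if area else (np.nan, np.nan)
            tracks.setdefault(obj_id, []).append((frame_idx, area, float(cx), float(cy)))
    return tracks
```

Centroid displacement between consecutive frames, multiplied by the pixel size and the frame rate, gives a simple speed estimate per worm.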