In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.


References:
1. https://github.com/facebookresearch/segment-anything
2. https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb
3. https://github.com/openai/CLIP
4. https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb

# Automatically generating object masks with SAM and classifying masks with CLIP

Since SAM can efficiently process prompts, masks for the entire image can be generated by sampling a large number of prompts over an image. This method was used to generate the dataset SA-1B.

The class `SamAutomaticMaskGenerator` implements this capability. It works by sampling single-point input prompts in a grid over the image, from each of which SAM can predict multiple masks. Then, masks are filtered for quality and deduplicated using non-maximal suppression. Additional options allow for further improvement of mask quality and quantity, such as running prediction on multiple crops of the image or postprocessing masks to remove small disconnected regions and holes.

In [None]:
from IPython.display import display, HTML
display(HTML(
"""
<a target="_blank" href="https://colab.research.google.com/github/healthonrails/annolid/blob/main/docs/tutorials/automatic_mask_generator_example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
"""
))

## Environment Set-up

If running locally using jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'.

In [None]:
using_colab = True

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

    !mkdir images
    !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg

    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

In [None]:
!pip install decord

## Install MobileSAM

In [None]:
!git clone https://github.com/ChaoningZhang/MobileSAM.git


In [None]:
%cd MobileSAM/
!pip install -e .

In [None]:
!pip install timm

## Set-up

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from pkg_resources import packaging

In [None]:
from mobile_encoder.setup_mobile_sam import setup_model
# Load the MobileSAM weights onto the CPU first: this works regardless of the
# device the checkpoint was saved from and avoids holding a second copy of the
# weights in GPU memory. The model is moved to the target device later on.
checkpoint = torch.load('./weights/mobile_sam.pt', map_location='cpu')
mobile_sam = setup_model()
# strict=True: fail loudly if the checkpoint keys do not match the model exactly.
mobile_sam.load_state_dict(checkpoint,strict=True)

In [None]:
def show_anns(anns):
    """Overlay every mask in ``anns`` on the current matplotlib axes.

    Masks are painted from the largest ``'area'`` to the smallest, so a small
    mask is drawn after (on top of) a larger one. Each mask receives a random
    RGB colour with alpha 0.35; pixels covered by no mask stay fully
    transparent.

    Args:
        anns (list): Mask records. Each record is a dict providing ``'area'``
            (used for ordering) and ``'segmentation'`` (used to index the
            overlay; presumably a boolean H x W array -- confirm with caller).

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if not len(anns):
        return

    by_area_desc = sorted(anns, key=lambda record: record['area'], reverse=True)

    axes = plt.gca()
    # Keep the axis limits of whatever is already displayed.
    axes.set_autoscale_on(False)

    # RGBA canvas sized like the first mask: white, with alpha 0 everywhere.
    height, width = by_area_desc[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[..., 3] = 0

    for record in by_area_desc:
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[record['segmentation']] = rgba

    axes.imshow(overlay)

## Example image

In [None]:
image_path = '/content/R2202_02-10-2023_000100275.png'
image = cv2.imread(image_path)
# cv2.imread returns None (it does not raise) when the file is missing or
# unreadable; fail here with a clear message instead of inside cvtColor.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# OpenCV loads images as BGR; the rest of the notebook works with RGB.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

In [None]:
plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

## Automatic mask generation

In [None]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
device = "cuda"

In [None]:
mobile_sam.to(device=device)
mobile_sam.eval()
predictor = SamPredictor(mobile_sam)

In [None]:
mask_generator = SamAutomaticMaskGenerator(mobile_sam)


To run automatic mask generation, provide a SAM model to the `SamAutomaticMaskGenerator` class. Set the path below to the SAM checkpoint. Running on CUDA and with the default model is recommended.

In [None]:

# sam_checkpoint = "sam_vit_h_4b8939.pth"
# model_type = "vit_h"
# sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)

# mask_generator = SamAutomaticMaskGenerator(sam)

To generate masks, just run `generate` on an image.

In [None]:
masks = mask_generator.generate(image)

Mask generation returns a list of masks, where each mask is a dictionary containing various data about the mask. The keys are:
* `segmentation` : the mask
* `area` : the area of the mask in pixels
* `bbox` : the bounding box of the mask in XYWH format
* `predicted_iou` : the model's own prediction for the quality of the mask
* `point_coords` : the sampled input point that generated this mask
* `stability_score` : an additional measure of mask quality
* `crop_box` : the crop of the image used to generate this mask in XYWH format

In [None]:
print(len(masks))
print(masks[0].keys())

Show all the masks overlaid on the image.

In [None]:
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

## Automatic mask generation options

There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:

In [None]:
# Use the MobileSAM model loaded earlier in this notebook. The name `sam` only
# exists in the commented-out ViT-H cell, so referencing it here raises a
# NameError on a fresh "Restart & Run All".
mask_generator_2 = SamAutomaticMaskGenerator(
    model=mobile_sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)

In [None]:
masks2 = mask_generator_2.generate(image)

In [None]:
len(masks2)

In [None]:
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()

In [None]:
import numpy as np
import cv2
from google.colab.patches import cv2_imshow

def crop_image_with_masks(image,
                          masks,
                          max_area=8000,
                          min_area=500,
                          width_height_ratio=0.9):
    """
    Cut one masked crop out of ``image`` for every mask that passes the filters.

    A mask is kept when its ``'area'`` lies in ``[min_area, max_area]`` and its
    bounding box satisfies ``w / h >= width_height_ratio``. Bounding boxes with
    zero height are skipped instead of raising ``ZeroDivisionError``.

    For each kept mask the image is cropped to the bounding box, pixels outside
    the segmentation are set to zero, and the first and last colour channels are
    swapped (``cv2.COLOR_BGR2RGB``). An RGB input therefore yields BGR crops,
    and a BGR input yields RGB crops.

    Args:
        image (numpy.ndarray): The input image (H x W x 3).
        masks (list): A list of dictionaries containing mask data. Each needs
            the keys ``'bbox'`` (x, y, w, h), ``'area'`` and ``'segmentation'``
            (full-image mask indexed as ``seg[y:y+h, x:x+w]``).
        max_area (int): Max area of the mask
        min_area (int): Min area of the mask
        width_height_ratio(float): Min width / height

    Returns:
        list: A list of cropped images with applied masks, in the order of the
        accepted entries of ``masks``.
    """
    cropped_images = []

    for mask_data in masks:
        x, y, w, h = mask_data['bbox']

        # Filter before touching any pixels so rejected masks cost nothing.
        if not (min_area <= mask_data['area'] <= max_area):
            continue
        # Guard the ratio test against degenerate (zero-height) boxes.
        if h <= 0 or w / h < width_height_ratio:
            continue

        seg = mask_data['segmentation']

        # Crop the image based on the bounding box
        cropped_image = image[y:y+h, x:x+w]

        # Create an 8-bit mask (0 or 255) from the segmentation data
        mask = np.asarray(seg[y:y+h, x:x+w], dtype=np.uint8) * 255

        # Keep only the pixels inside the mask; everything else becomes 0
        cropped_image = cv2.bitwise_and(cropped_image, cropped_image, mask=mask)
        # Swap the first and last channels of the crop
        cropped_image = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)
        cropped_images.append(cropped_image)

    return cropped_images

In [None]:
cropped_images = crop_image_with_masks(image,masks)

In [None]:
for cimg in cropped_images:
    cv2_imshow(cimg)

# Classify masks with openai/CLIP

# Loading the model

`clip.available_models()` will list the names of available CLIP models.

In [None]:
import clip

clip.available_models()

In [None]:
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

# Setting up input images and texts

In [None]:
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# images to use and their textual descriptions
descriptions = {
    "rat": "a small mammal with fur, long tail, and a white stripe on its back",
    "hand": "a human hand with fingers and palm",
    "arm": "a human arm extending from the shoulder to the hand",
    "cup": "a petri dish used for holding odor treated white or yellow sand",
    "book": "a bound collection of paper sheets used for writing or reading",
}


# Zero-Shot Image Classification

You can classify images using the cosine similarity (times 100) as the logits to the softmax operation.


In [None]:
class_names = list(descriptions.keys())
class_names

In [None]:
text_descriptions = [f"This is a photo of a {label}, {descriptions[label]}" for label in class_names]
text_tokens = clip.tokenize(text_descriptions).cuda()

In [None]:
with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

In [None]:
def get_top_probs_and_labels(image, masks, model, text_features):
    """
    Computes the top probabilities and labels for matching text and images.

    The image is cut into masked crops with ``crop_image_with_masks``, each crop
    is run through the notebook-level CLIP ``preprocess`` transform and encoded
    with ``model``. Image features are L2-normalised so that, together with the
    already normalised ``text_features``, the logits are 100 x cosine
    similarity, which are then turned into probabilities with a softmax.

    Args:
        image (ndarray): The input image, expected in RGB channel order.
        masks (list): Mask dictionaries, forwarded to ``crop_image_with_masks``.
        model: The CLIP model used for encoding image features.
        text_features (Tensor): L2-normalised text features, one row per class.

    Returns:
        top_probs (Tensor): Top probabilities of text matching for the cropped
            images, shape (num_crops, k) with k = min(5, number of classes).
        top_labels (Tensor): Class indices corresponding to ``top_probs``.
        Both tensors have zero rows when no mask passes the crop filters.
    """
    # Never ask topk for more entries than there are classes.
    num_top = min(5, text_features.shape[0])

    # Crop images using masks
    cropped_images = crop_image_with_masks(image, masks)
    if not cropped_images:
        return (torch.empty((0, num_top)),
                torch.empty((0, num_top), dtype=torch.long))

    # crop_image_with_masks returns crops with the first/last channels swapped
    # relative to `image`; swap them back so PIL and CLIP see RGB data.
    images = [
        preprocess(Image.fromarray(cv2.cvtColor(cimg, cv2.COLOR_BGR2RGB)))
        for cimg in cropped_images
    ]

    # Batch the crops and move them to whichever device the model lives on.
    device = next(model.parameters()).device
    image_input = torch.stack(images).to(device)

    with torch.no_grad():
        # Encode image features
        image_features = model.encode_image(image_input).float()
        # Normalise so the dot product below is a cosine similarity.
        image_features /= image_features.norm(dim=-1, keepdim=True)

        # Calculate text probabilities
        logits = 100.0 * image_features @ text_features.to(image_features.device).T
        text_probs = logits.softmax(dim=-1)

        # Get top probabilities and labels
        top_probs, top_labels = text_probs.cpu().topk(num_top, dim=-1)

    return top_probs, top_labels


In [None]:
def get_mask_features(image, mask, model):
    """
    Computes the features of the mask portion of an image.

    Every pixel outside ``mask`` is set to zero on a copy of ``image`` (the
    caller's array is not modified). The result is passed through the
    notebook-level CLIP ``preprocess`` transform and encoded with ``model``.

    Args:
        image (ndarray): The input image.
        mask (ndarray): Full-image mask selecting the pixels to keep. It is
            interpreted as boolean, so 0/1 integer masks are accepted as well.
        model: The CLIP model used for encoding image features.

    Returns:
        mask_features (ndarray): The encoded features as a NumPy array with a
        leading batch dimension of 1.
    """
    # Coerce to bool: `~` on an integer mask would be a bitwise NOT and would
    # then be used as integer indices instead of a pixel selector.
    keep = np.asarray(mask, dtype=bool)

    # Apply the mask to a copy of the image
    masked_image = image.copy()
    masked_image[~keep] = 0

    # Preprocess the masked image and add the batch dimension
    image_input = preprocess(Image.fromarray(masked_image)).unsqueeze(0)

    # Run on the device the model lives on rather than assuming CUDA
    device = next(model.parameters()).device
    image_input = image_input.to(device).float()

    with torch.no_grad():
        # Encode image features
        image_features = model.encode_image(image_input).float()

    return image_features.cpu().numpy()

In [None]:
top_probs, top_labels = get_top_probs_and_labels(image, masks, model, text_features)

In [None]:
num_rows = int(np.ceil(len(cropped_images)/ 2))

In [None]:
# Show every crop next to a horizontal bar chart of its class probabilities.
# The grid has 4 columns, so each row holds two (crop, chart) pairs.
plt.figure(figsize=(16, 16))

for i, _image in enumerate(cropped_images):
    # Swap the first/last channels of the crop before handing it to matplotlib.
    _image = cv2.cvtColor(_image,cv2.COLOR_BGR2RGB)
    # Odd slot of the pair: the crop itself.
    plt.subplot(num_rows, 4, 2 * i + 1)
    plt.imshow(_image)
    plt.axis("off")

    # Even slot of the pair: the probabilities for this crop.
    plt.subplot(num_rows, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    # Invert the y axis so the first (highest) entry is drawn at the top.
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [class_names[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")

plt.subplots_adjust(wspace=0.5)
plt.show()

In [None]:
from scipy.spatial.distance import euclidean, cosine

def generate_mask_id(mask_features, existing_masks, threshold=6.0, distance_metric="euclidean"):
    """
    Generates an ID for the mask based on its features and compares it with existing masks.

    The flattened features are compared, in list order, with every entry of
    ``existing_masks``. The ID of the first entry whose distance is strictly
    below ``threshold`` is reused. If none qualifies, a new ID
    (``len(existing_masks) + 1``) is created and ``(new_id, flattened_features)``
    is appended to ``existing_masks`` in place.

    Args:
        mask_features (ndarray): The features of the mask.
        existing_masks (list): List of ``(mask_id, features)`` tuples for the
            masks seen so far. Mutated when a new ID is assigned.
        threshold (float): Maximum distance for considering a match
            (default: 6.0). Smaller distances mean more similar masks.
        distance_metric (str): Distance metric to be used (default: "euclidean").
                               Options: "euclidean", "cosine".

    Returns:
        mask_id (int): The generated ID for the mask.

    Raises:
        ValueError: If ``distance_metric`` is not one of the supported options.
    """
    if distance_metric == "euclidean":
        distance_function = euclidean
    elif distance_metric == "cosine":
        distance_function = cosine
    else:
        raise ValueError("Invalid distance metric. Choose either 'euclidean' or 'cosine'.")

    query = mask_features.flatten()

    # First match wins: reuse the ID of the earliest entry that is close enough.
    for existing_id, existing_features in existing_masks:
        distance = distance_function(query, existing_features.flatten())
        if distance < threshold:
            return existing_id

    # No match found: register this mask under a fresh ID.
    mask_id = len(existing_masks) + 1
    existing_masks.append((mask_id, query))
    return mask_id


In [None]:
import pycocotools.mask as mask_util
def convert_to_annolid_format(frame_number,
                              masks,
                              frame=None,
                              model=None,
                              min_mask_area=float('-inf'),
                              max_mask_area=float('inf'),
                              existing_masks=None
                              ):
    """Converts predicted SAM masks information to annolid format.

    Args:
        frame_number (int): The frame number associated with the masks.
        masks (list): List of dictionaries representing the predicted masks.
            Each dictionary should contain the following keys:
                -segmentation : the mask
                -area : the area of the mask in pixels
                -bbox : the boundary box of the mask in XYWH format
                -predicted_iou : the model's own prediction for the quality of the mask
                -point_coords : the sampled input point that generated this mask
                -stability_score : an additional measure of mask quality
                -crop_box : the crop of the image used to generate this mask in XYWH format
            Optional "instance_name" and "tracking_id" keys are copied to the
            output when present.
        frame (ndarray): The image the masks were predicted on; passed to
            ``get_mask_features`` for every accepted mask.
        model: The model passed to ``get_mask_features`` to encode each mask.
        min_mask_area (float): Masks with a smaller area are skipped.
        max_mask_area (float): Masks with a larger area are skipped.
        existing_masks (list): ``(mask_id, features)`` tuples of masks seen so
            far. It is updated in place by ``generate_mask_id``; pass the same
            list for every frame to keep IDs consistent across frames. When
            omitted, a fresh empty list is used for this call.

    Returns:
        list: List of dictionaries representing the masks in annolid format.
            Each dictionary contains the following keys:
                - "frame_number": The frame number associated with the masks.
                - "x1", "y1", "x2", "y2": The coordinates of the bounding box in XYXY format.
                - "instance_name": The name of the instance/object
                  (``instance_<mask_id>`` unless the mask provides one).
                - "class_score": The predicted IoU (Intersection over Union) for the mask.
                - "segmentation": The segmentation mask, RLE-encoded with pycocotools.
                - "tracking_id": The tracking ID associated with the mask.

    """
    # A None default would otherwise be iterated by generate_mask_id and fail.
    if existing_masks is None:
        existing_masks = []

    pred_rows = []
    for mask in masks:
        mask_area = mask.get("area",0)
        if not (min_mask_area <= mask_area <= max_mask_area):
            continue

        # XYWH -> XYXY
        x1, y1, box_w, box_h = mask.get("bbox")[:4]
        x2 = x1 + box_w
        y2 = y1 + box_h

        score = mask.get("predicted_iou", '')
        segmentation = mask.get("segmentation", '')

        # Match this mask against previously seen masks by feature distance.
        mask_features = get_mask_features(frame,segmentation,model)
        mask_id = generate_mask_id(mask_features,existing_masks)
        instance_name = mask.get("instance_name", f'instance_{mask_id}')

        # pycocotools' RLE encoder works on Fortran-ordered uint8 arrays, so
        # convert the mask explicitly before encoding.
        segmentation = mask_util.encode(
            np.asfortranarray(segmentation, dtype=np.uint8))
        tracking_id = mask.get("tracking_id", "")

        pred_rows.append({
            "frame_number": frame_number,
            "x1": x1,
            "y1": y1,
            "x2": x2,
            "y2": y2,
            "instance_name": instance_name,
            "class_score": score,
            "segmentation": segmentation,
            "tracking_id": tracking_id
        })

    return pred_rows

In [None]:
predict_rows = convert_to_annolid_format(100275,masks,image,model,existing_masks=[])

In [None]:
import pandas as pd
df = pd.DataFrame(predict_rows)
df.head()

In [None]:
df.to_csv("rats_v1_coco_dataset_R2202_02-10-2023_mask_rcnn_tracking_results_with_segmentation.csv")

# Example video

In [None]:
!wget https://storage.googleapis.com/sleap-data/datasets/eleni_mice/clips/20200111_USVpairs_court1_M1_F1_top-01112020145828-0000%400-2560.mp4

In [None]:
original_video_file = "/content/MobileSAM/20200111_USVpairs_court1_M1_F1_top-01112020145828-0000@0-2560.mp4"
video_file = '/content/usvpairs_court1.mp4'

## Cut 2s of the video starting from second 30

In [None]:
!ffmpeg -i {original_video_file} -ss 00:00:30 -t 00:00:02 -c:v copy -c:a copy {video_file}

In [None]:
import decord as de
import pandas as pd

def process_video_and_save_tracking_results(video_file, mask_generator):
    """
    Process a video file, generate tracking results with segmentation masks,
    and save the results to a CSV file.

    Only the key frames reported by the video reader are processed. Each key
    frame is segmented with ``mask_generator`` and converted with
    ``convert_to_annolid_format``, sharing one ``existing_masks`` list so that
    instance IDs stay consistent across frames. Mask features are encoded with
    the notebook-level ``model``.

    Args:
        video_file (str): Path to the video file.
        mask_generator: An instance of the mask generator class; it must
            provide ``generate(frame)``.

    Returns:
        str: Path of the CSV file that was written. It is the video path with
        its extension replaced by
        ``_mask_tracking_results_with_segmentation.csv``.
    """
    import os

    video_reader = de.VideoReader(video_file)
    tracking_results = []
    existing_masks = []

    for key_index in video_reader.get_key_indices():
        frame = video_reader[key_index].asnumpy()
        masks = mask_generator.generate(frame)
        tracking_results.extend(
            convert_to_annolid_format(key_index, masks, frame, model,
                                      existing_masks=existing_masks))
        print(key_index)

    dataframe = pd.DataFrame(tracking_results)
    # splitext removes only the final extension, so dots elsewhere in the path
    # (e.g. "./clip.mp4" or "run.v2/clip.mp4") do not truncate the output name.
    video_stem = os.path.splitext(video_file)[0]
    output_file = f"{video_stem}_mask_tracking_results_with_segmentation.csv"
    dataframe.to_csv(output_file)
    return output_file

In [None]:
tracking_results_file = process_video_and_save_tracking_results(video_file, mask_generator)

In [None]:
from google.colab.files import download

In [None]:
download(tracking_results_file)

In [None]:
download(video_file)