<!-- Autogenerated by `scripts/make_examples.py` -->
<table align="left">
    <td>
        <a target="_blank" href="https://colab.research.google.com/github/voxel51/fiftyone-examples/blob/master/examples/zero_shot_instance_segmentation.ipynb">
            <img src="https://user-images.githubusercontent.com/25985824/104791629-6e618700-5769-11eb-857f-d176b37d2496.png" height="32" width="32">
            Try in Google Colab
        </a>
    </td>
    <td>
        <a target="_blank" href="https://nbviewer.jupyter.org/github/voxel51/fiftyone-examples/blob/master/examples/zero_shot_instance_segmentation.ipynb">
            <img src="https://user-images.githubusercontent.com/25985824/104791634-6efa1d80-5769-11eb-8a4c-71d6cb53ccf0.png" height="32" width="32">
            Share via nbviewer
        </a>
    </td>
    <td>
        <a target="_blank" href="https://github.com/voxel51/fiftyone-examples/blob/master/examples/zero_shot_instance_segmentation.ipynb">
            <img src="https://user-images.githubusercontent.com/25985824/104791633-6efa1d80-5769-11eb-8ee3-4b2123fe4b66.png" height="32" width="32">
            View on GitHub
        </a>
    </td>
    <td>
        <a href="https://github.com/voxel51/fiftyone-examples/raw/master/examples/zero_shot_instance_segmentation.ipynb" download>
            <img src="https://user-images.githubusercontent.com/25985824/104792428-60f9cc00-576c-11eb-95a4-5709d803023a.png" height="32" width="32">
            Download notebook
        </a>
    </td>
</table>


# Zero-Shot Segmentation with OWL-ViT, SAM, and FiftyOne

This notebook walks you through how to add zero-shot instance segmentation masks to your dataset using [FiftyOne](https://voxel51.com/docs/fiftyone/). You will also see how to turn this into tracking data when applied to videos!

In particular, you will learn how to:
- Extract PNGs from the frames of a video
- Run zero-shot object detection on the images
- Add segmentation masks to these detections
- Create tracks from those detections and masks

![wildlife-watcher1](https://user-images.githubusercontent.com/12500356/260869179-c8b191a2-f729-4f22-87ac-545545a39c27.gif)

For the purposes of illustration, we will use the `wildlife-watcher` dataset from Wildlife AI, which contains 7868 short video clips of a variety of animals.

**Note**: You can also browse this dataset (with instance segmentation masks and tracking info) for free at [try.fiftyone.ai](https://try.fiftyone.ai/datasets/wildlife-watcher/samples)!

We will use the following libraries:
- [FiftyOne](https://github.com/voxel51/fiftyone) to organize our dataset and visualize the results
- [transformers](https://huggingface.co/docs/transformers/index) from Hugging Face to load and run inference with [OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit) for zero-shot object detection
- [Ultralytics](https://docs.ultralytics.com/) to run instance segmentation inference with bounding box prompts using Facebook's [Segment Anything Model](https://docs.ultralytics.com/models/sam/#sam-prediction-example)

There are many ways to get these models, and other models can also be used for zero-shot object detection and instance segmentation. For example, [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO) is a state-of-the-art zero-shot object detection model at the time of writing.

We will also use the headless version of [OpenCV](https://opencv.org/) to convert the videos to frame images.

**Note**: For image-only datasets, FiftyOne [natively supports SAM](https://docs.voxel51.com/user_guide/model_zoo/models.html#segment-anything-vith-torch) as part of the [FiftyOne Model Zoo](https://docs.voxel51.com/user_guide/model_zoo/index.html#)!

## Setup

Let's install the necessary libraries:

In [None]:
!pip install fiftyone transformers ultralytics opencv-python-headless

And then import all of the necessary packages:

In [None]:
from glob import glob
import os

import cv2
import fiftyone as fo
import numpy as np
from PIL import Image
from transformers import pipeline
from ultralytics import SAM

Download the dataset from [this zip file](https://drive.google.com/file/d/1UfB3klvMUs9R7wlqUZ8Dlbdm1nF3Pu1a/view?usp=sharing) to a folder `taranaki` and unzip it.

Then load it into FiftyOne:

In [None]:
dataset = fo.Dataset.from_dir(
    "taranaki",
    dataset_type=fo.types.FiftyOneVideoLabelsDataset,
)

Then we can give it a name and make it persistent:

In [5]:
dataset.name = "wildlife-watcher"
dataset.persistent = True

Next, we convert the videos into sequences of images and save them in a folder. We will also use FiftyOne's `ensure_frames()` method to make sure that every video sample has a frame instance for each frame of its source video, so that we can add predictions to those frames.

In [None]:
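# Populate frame width/height metadata, which is used later to normalize boxes
dataset.compute_metadata()
dataset.ensure_frames()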

mp4_files = glob("taranaki/data/*")

### Create PNGs for each frame
for mf in mp4_files:
    subdir = os.path.basename(mf).split(".")[0]
    frames_dir = f'taranaki/frames/{subdir}'
    os.makedirs(frames_dir, exist_ok=True)
    frame_number = 0
    video = cv2.VideoCapture(mf)
    while True:
        success, frame = video.read()
        if not success:
            break
        frame_path = os.path.join(frames_dir, f'frame_{frame_number}.png')
        cv2.imwrite(frame_path, frame)
        frame_number += 1
    video.release() 
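The extraction loop above names PNGs with a 0-based counter (`frame_0.png`, `frame_1.png`, and so on), whereas FiftyOne numbers video frames starting at 1. The sketch below makes that mapping explicit with a hypothetical helper, `frame_png_path()`, that mirrors the directory and file naming used above; the video filename `clip_001.mp4` is likewise only an example.

```python
import os


def frame_png_path(frames_root: str, video_path: str, frame_number: int) -> str:
    """Hypothetical helper: map a 1-based FiftyOne frame number to the
    0-based PNG written by the extraction loop."""
    if frame_number < 1:
        raise ValueError("FiftyOne frame numbers start at 1")
    # Same subdirectory rule as the extraction loop: basename up to the first "."
    subdir = os.path.basename(video_path).split(".")[0]
    # OpenCV's read counter starts at 0, so frame N is stored as frame_{N-1}.png
    return os.path.join(frames_root, subdir, f"frame_{frame_number - 1}.png")


print(frame_png_path("taranaki/frames", "taranaki/data/clip_001.mp4", 1))
```

This is the same off-by-one that `add_masks_to_sample()` handles when it pairs `sample.frames[i + 1]` with `frame_{i}.png`.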

## Add predictions to dataset

Now that we have our frame PNGs, we will iterate through the samples in our dataset, adding zero-shot predictions.

First, we define our detector. In this case, we know the *type* of animal we can expect to see in each image (from the `ground_truth` label on the frames), so we can use this as a text prompt for zero-shot detection with OWL-ViT:

In [None]:
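# OWL-ViT checkpoint on the Hugging Face Hub (assumed here; larger OWL-ViT variants also exist)
checkpoint = "google/owlvit-base-patch32"
detector = pipeline(model=checkpoint, task="zero-shot-object-detection")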

We only want the bounding box of the highest confidence prediction, if there is one:

In [None]:
def get_bounding_box(image, label):
    predictions = detector(
        image,
        candidate_labels=[label],
    )
    
    if len(predictions) == 0:
        return None, None

    prediction = max(predictions, key=lambda x: x['score'])
    score, box = prediction['score'], prediction['box']

    bounding_box = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
    return bounding_box, score
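The `transformers` zero-shot object detection pipeline returns a list of dicts with `score`, `label`, and `box` keys, where `box` holds absolute `xmin`, `ymin`, `xmax`, and `ymax` pixel values. As a minimal sketch, the function below applies the same highest-score selection as `get_bounding_box()` to hand-written predictions in that shape, plus a hypothetical `min_score` cutoff (not part of the original notebook) for discarding weak matches:

```python
from typing import Optional


def best_box(
    predictions: list[dict], min_score: float = 0.0
) -> tuple[Optional[list[int]], Optional[float]]:
    # Keep only predictions at or above the (hypothetical) confidence cutoff
    kept = [p for p in predictions if p["score"] >= min_score]
    if not kept:
        return None, None
    # Select the single most confident prediction, as get_bounding_box() does
    best = max(kept, key=lambda p: p["score"])
    box = best["box"]
    return [box["xmin"], box["ymin"], box["xmax"], box["ymax"]], best["score"]


# Hypothetical predictions shaped like the pipeline's output
preds = [
    {"score": 0.12, "label": "deer", "box": {"xmin": 5, "ymin": 8, "xmax": 40, "ymax": 60}},
    {"score": 0.47, "label": "deer", "box": {"xmin": 100, "ymin": 50, "xmax": 220, "ymax": 180}},
]
print(best_box(preds))                 # ([100, 50, 220, 180], 0.47)
print(best_box(preds, min_score=0.5))  # (None, None)
```

Raising `min_score` trades recall for precision: fewer frames get a mask, but those that do are less likely to contain a spurious box.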

We can then create an instance segmentation function that uses this bounding box as a prompt for SAM, along with a helper that generates a mask for a given frame image:

In [None]:
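sam_model = SAM("sam_l.pt")

def run_box_sam_segmentation(image, bbox):
    # Prompt SAM with the absolute xyxy box; the result is a full-image mask
    res = sam_model(image, bboxes=bbox)
    # masks.data has shape (num_masks, H, W); keep the mask for our single box
    mask = res[0].masks.data[0].cpu().numpy()
    return mask

def generate_mask(image, label, width, height):
    # Detect the animal with OWL-ViT (absolute xyxy pixel coordinates)
    abs_bbox, score = get_bounding_box(Image.open(image), label)
    if abs_bbox is None:
        return None, None
    (cmin, rmin, cmax, rmax) = abs_bbox

    # Segment with SAM, then crop the full-image mask to the bounding box
    mask = run_box_sam_segmentation(image, abs_bbox)
    bounding_box_mask = mask[rmin:rmax+1, cmin:cmax+1]

    # Convert to relative [x, y, w, h], the format fo.Detection expects
    rel_bbox = [cmin/width, rmin/height, (cmax-cmin)/width, (rmax-rmin)/height]
    return bounding_box_mask, rel_bbox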

The bounding box output by OWL-ViT is in absolute `xyxy` pixel coordinates, and we use these coordinates to crop the full-image mask generated by SAM down to an instance segmentation mask. Finally, we convert the bounding box to relative `xywh` coordinates, which is the format that FiftyOne `Detection` labels expect.
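To make the coordinate handling in `generate_mask()` concrete, here is a small sketch on a synthetic mask (the image size and box are made up). NumPy indexes masks as `[row, column]`, that is `[y, x]`, so the `xyxy` box is sliced with `ymin:ymax + 1` first. The helper names `crop_mask_to_box()` and `xyxy_abs_to_xywh_rel()` are illustrative, not part of any library:

```python
import numpy as np


def crop_mask_to_box(full_mask: np.ndarray, xyxy: list[int]) -> np.ndarray:
    # Rows are y, columns are x; the +1 keeps the xmax/ymax pixels, as in generate_mask()
    xmin, ymin, xmax, ymax = xyxy
    return full_mask[ymin:ymax + 1, xmin:xmax + 1]


def xyxy_abs_to_xywh_rel(xyxy: list[int], width: int, height: int) -> list[float]:
    # FiftyOne boxes are [top-left-x, top-left-y, width, height], each in [0, 1]
    xmin, ymin, xmax, ymax = xyxy
    return [xmin / width, ymin / height, (xmax - xmin) / width, (ymax - ymin) / height]


# Synthetic 100x200 (height x width) full-image mask with a filled rectangle
full_mask = np.zeros((100, 200), dtype=bool)
full_mask[20:41, 50:91] = True
box = [50, 20, 90, 40]  # xmin, ymin, xmax, ymax in pixels

crop = crop_mask_to_box(full_mask, box)
print(crop.shape, crop.all())               # (21, 41) True
print(xyxy_abs_to_xywh_rel(box, 200, 100))  # [0.25, 0.2, 0.2, 0.2]
```

Because the crop includes the `xmax` and `ymax` pixels, it is one pixel larger in each dimension than the relative width and height imply; FiftyOne treats a detection's mask as spanning its bounding box, so this discrepancy should be negligible in practice.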

To add masks to a sample, we loop over the frames in the video and add the generated mask, if there is one, to each frame:

In [None]:
def add_masks_to_sample(sample, label):
    # PNGs for this video live in taranaki/frames/<filename without extension>
    subdir = sample.filename.split(".")[0]
    frames_dir = f"taranaki/frames/{subdir}"
    n_frames = len(sample.frames)
    for i in range(n_frames):
        # FiftyOne frame numbers are 1-based; the PNGs written by OpenCV are 0-based
        frame = sample.frames[i + 1]
        frame_img = f"{frames_dir}/frame_{i}.png"
        if not os.path.exists(frame_img):
            continue
        mask, bounding_box = generate_mask(
            frame_img,
            label,
            sample.metadata.frame_width,
            sample.metadata.frame_height,
        )
        if mask is None:
            continue
        # One detection per frame; index=1 links all of them into a single track
        frame["sam_track"] = fo.Detections(
            detections=[
                fo.Detection(
                    label=label,
                    bounding_box=bounding_box,
                    mask=mask,
                    index=1,
                )
            ]
        )
    # Save the frame-level edits once per video rather than once per frame
    sample.save()

Note that here we are setting `index=1` for each detection. This is because we only expect to see a single animal in each video, so we can say with decent certainty that all of the detections correspond to the same animal. This index will be used to associate detections with tracks, which you can extract with `to_trajectories()`!
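Conceptually, a trajectory is just the set of frames that share a detection's `(label, index)` pair. The plain-Python sketch below illustrates that grouping on hypothetical per-frame detections; it is not FiftyOne's implementation, and `group_into_tracks()` is an illustrative name. In FiftyOne itself, `dataset.to_trajectories("frames.sam_track")` builds a view with one sample per unique `(label, index)` trajectory in that frame-level field.

```python
from collections import defaultdict


def group_into_tracks(frame_detections: dict[int, list[dict]]) -> dict[tuple[str, int], list[int]]:
    """Hypothetical helper: map each (label, index) pair to the sorted
    frame numbers in which a detection with that pair appears."""
    tracks: dict[tuple[str, int], list[int]] = defaultdict(list)
    for frame_number in sorted(frame_detections):
        for det in frame_detections[frame_number]:
            tracks[(det["label"], det["index"])].append(frame_number)
    return dict(tracks)


# Hypothetical per-frame detections; frame 3 had no detection
per_frame = {
    1: [{"label": "deer", "index": 1}],
    2: [{"label": "deer", "index": 1}],
    4: [{"label": "deer", "index": 1}],
}
print(group_into_tracks(per_frame))  # {('deer', 1): [1, 2, 4]}
```

Since every detection in a video gets `index=1`, each video yields a single track, with gaps wherever OWL-ViT found nothing.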

Now all that is left is to loop over samples in the dataset:

In [2]:
def add_sam_tracks(dataset):
    for sample in dataset.iter_samples(autosave=True, progress=True):
        # Use the first frame's ground truth label as the OWL-ViT text prompt
        if "ground_truth" not in sample.frames[1]:
            continue
        if sample.frames[1].ground_truth is None:
            continue
        label = sample.frames[1].ground_truth.label
        add_masks_to_sample(sample, label)

add_sam_tracks(dataset)

In [None]:
session = fo.launch_app(dataset)

![wildlife-watcher2](https://user-images.githubusercontent.com/12500356/260869182-061034a3-729d-485e-af34-13ec2ee074cd.gif)

One of the main takeaways from this experiment is that zero-shot computer vision pipelines have real limitations. The models do not generalize perfectly to classes they were not trained on, and on any given class they typically perform worse than a model trained specifically on that class. In this case, there were very few detections of small animals, of animals not seen from a frontal view, or of animals that were occluded or truncated!