# Segment Anything (Keras Core Port) Benchmarks

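This notebook benchmarks the Keras Core port of the Segment Anything Model (SAM) on the TensorFlow, JAX, and PyTorch backends. The backend is selected with `KERAS_BACKEND` below; the outputs shown here were produced with the TensorFlow backend.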

The notebook runs three kinds of benchmarks:

1. End-to-end model inference (`image_encoder + prompt_encoder + mask_decoder`), timing only the Keras model on pre-processed inputs
2. End-to-end model inference with pre- and post-processing (`SAMPredictor.predict`)
3. Prompt benchmarks (`prompt_encoder + mask_decoder` on precomputed image features)

## Install dependencies and download the demo images

In [1]:
# Get the dependencies
!pip install -Uq torch torchvision torchaudio torchtext >> /dev/null
!pip install -Uq tensorflow >> /dev/null
!pip install -Uq jax >> /dev/null
!pip install -Uq keras-nlp >> /dev/null
!pip install -Uq keras-cv >> /dev/null
!pip install -Uq keras >> /dev/null
!pip install -Uq git+https://github.com/tirthasheshpatel/segment_anything_keras.git >> /dev/null
# Get the image for the demo
!curl -sSL https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg -o truck.jpg
!curl -sSL https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg -o groceries.jpg

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.0.0 which is incompatible.[0m[31m
[0m

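## Set the backend

Keras reads `KERAS_BACKEND` when it is first imported, so this cell must run before `import keras`. Restart the kernel after the install cell, and again whenever you switch backends.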

In [1]:
import os

# Keras picks its backend from KERAS_BACKEND when it is first imported, so this
# cell has to run before `import keras`. Options: "tensorflow", "jax", "torch".
BACKEND = "tensorflow"
os.environ["KERAS_BACKEND"] = BACKEND

## Choose the model size

The preset loaded below is `sam_{model_type}_sa1b` (for example `sam_huge_sa1b`).

In [2]:
# Size of the SAM variant; the preset loaded below is f"sam_{model_type}_sa1b".
model_type: str = "huge"

## Import Dependencies

In [3]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import ops
from keras_cv.models import SegmentAnythingModel
from sam_keras import SAMPredictor

## Load the pretrained model

In [4]:
# Download (or reuse the cached copy of) the pretrained SAM weights for this size.
preset = f"sam_{model_type}_sa1b"
sam = SegmentAnythingModel.from_preset(preset)

Downloading data from https://storage.googleapis.com/keras-cv/models/segment_anything/sam_huge.h5
[1m2564774344/2564774344[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 0us/step


## End-to-End Model Inference with Pre- and Post-Processing

### Setup

In [5]:
# Wrap the KerasCV model in the sam_keras predictor; `model.transform` is used
# below to resize the image and to map prompt coordinates to the resized image.
model = SAMPredictor(sam)
transform = model.transform

# Load the demo image as RGB. cv2.imread does not raise on a missing or
# unreadable file -- it returns None -- so fail here with a clear message
# instead of letting cvtColor raise an obscure assertion error.
image = cv2.imread("truck.jpg")
if image is None:
    raise FileNotFoundError(
        "Could not read 'truck.jpg'; re-run the download cell first."
    )
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# A single prompt point at pixel (500, 375) of the original image.
# NOTE(review): label 1 is assumed to mean "foreground point" (SAM convention).
input_point = np.array([[500, 375]])
input_label = np.array([1])

image_record = {}

# Resized image with a leading batch axis of size 1.
image_record["image"] = ops.convert_to_tensor(
    transform.apply_image(image)[np.newaxis, ...],
    dtype="float32",
)

# (height, width) of the original image; apply_coords below needs it to rescale points.
image_record["original_size"] = (image.shape[0], image.shape[1])

# Point prompts shaped (1, 1, 2), then mapped from original-image pixels into
# the resized image's coordinate frame.
image_record["point_coords"] = ops.reshape(
    ops.convert_to_tensor(input_point, dtype="float32"),
    (1, 1, 2),
)
image_record["point_coords"] = transform.apply_coords(
    image_record["point_coords"], image_record["original_size"]
)

# Point labels shaped (1, 1).
image_record["point_labels"] = ops.convert_to_tensor(
    input_label[np.newaxis, ...],
    dtype="float32",
)

### Benchmark

In [6]:
# Warm-up call: the first predict builds the model (about 21 s in the output
# below), so it is kept out of the timed cells that follow.
out = model.predict(image_record)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 21s/step


In [7]:
# Second call after warm-up: the per-step time Keras prints here is a quick
# sanity check for the %timeit figure in the next cell.
out = model.predict(image_record)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 39ms/step


In [8]:
# Benchmark the model
%timeit out = model.predict(image_record, verbose=0)

195 ms ± 513 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


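## End-to-End Model Inference

Here the pre-processing is done once, up front, so only the Keras model (`model.model.predict`) is timed.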

### Setup

In [9]:
# Do the pre-processing once, up front, so the timed cell below measures only
# the Keras model (`model.model`). No post-processing is applied in this section.
images = model.preprocess_images(image_record["image"])


def _record_tensor(key, empty_shape):
    """Return `image_record[key]` as float32, or `ops.ones(empty_shape)` if absent.

    Every `empty_shape` used below has a zero-length axis 1, so the
    placeholder contributes no prompts of that kind.
    """
    fallback = ops.ones(empty_shape)
    return ops.convert_to_tensor(image_record.get(key, fallback), dtype="float32")


points = _record_tensor("point_coords", (1, 0, 2))
labels = _record_tensor("point_labels", (1, 0))
box = _record_tensor("boxes", (1, 0, 2, 2))
mask = _record_tensor("mask_inputs", (1, 0, 256, 256, 1))

# Point prompts without a box prompt get one padding point at (0, 0) with
# label -1 appended along the point axis.
if ops.size(points) and not ops.size(box):
    pad_point = ops.zeros((points.shape[0], 1, 2), dtype="float32")
    points = ops.concatenate([points, pad_point], axis=1)
    pad_label = -ops.ones((labels.shape[0], 1), dtype="float32")
    labels = ops.concatenate([labels, pad_label], axis=1)

# Broadcast every input to the largest batch size among them.
B = max(tensor.shape[0] for tensor in (images, points, labels, box, mask))

images, points, labels, box, mask = model._broadcast_batch(
    B, images, points, labels, box, mask
)

model_input = dict(images=images, points=points, labels=labels, boxes=box, masks=mask)

### Benchmark

In [10]:
# Warm-up call so the first-call build time (about 16 s in the output below)
# stays out of the timing; the trailing `;` suppresses the output display.
model.model.predict(model_input);

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 16s/step


In [11]:
%timeit model.model.predict(model_input, verbose=0)

197 ms ± 2.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


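## Prompt Benchmarks

The image encoder runs once to produce image features. Only the prompt encoder and mask decoder are then timed, wrapped in a small functional model that takes those features as an input.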

### Setup

In [12]:
# Run the image encoder (backbone) once. The prompt benchmarks below reuse
# these embeddings, so only the prompt encoder and mask decoder are timed.
backbone_out = model.model.backbone.predict(model_input["images"], verbose=0)
features = ops.convert_to_tensor(backbone_out, dtype="float32")

In [13]:
class SAMPrompter(keras.Model):
    """Functional model running only SAM's prompt encoder and mask decoder.

    The image embeddings are a model input ("features") rather than being
    computed by the image encoder. This lets prompts be benchmarked against
    precomputed features.

    Args:
        prompt_encoder: Layer called on a dict with "points", "labels",
            "boxes" and "masks"; its output must provide
            "dense_positional_embeddings", "sparse_embeddings" and
            "dense_embeddings".
        mask_decoder: Layer called on the image embeddings plus the prompt
            embeddings; its output is this model's output.
        feature_shape: Shape of one image's embeddings, without the batch axis.
        **kwargs: Forwarded to `keras.Model.__init__`.
    """

    def __init__(self, prompt_encoder, mask_decoder, feature_shape=(64, 64, 256), **kwargs):
        # Symbolic prompt inputs, created in the order points, labels, boxes, masks.
        prompt_specs = (
            ("points", [None, 2]),
            ("labels", [None]),
            ("boxes", [None, 2, 2]),
            ("masks", [None, None, None, 1]),
        )
        prompts = {
            input_name: keras.Input(shape=input_shape, name=input_name)
            for input_name, input_shape in prompt_specs
        }

        # Full model inputs: the image features first, followed by the prompts.
        inputs = {"features": keras.Input(feature_shape, name="features"), **prompts}

        embeddings = prompt_encoder(prompts)
        outputs = mask_decoder(
            {
                "image_embeddings": inputs["features"],
                "image_pe": embeddings["dense_positional_embeddings"],
                "sparse_prompt_embeddings": embeddings["sparse_embeddings"],
                "dense_prompt_embeddings": embeddings["dense_embeddings"],
            }
        )

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # Expose the sub-layers under the same attribute names as the full model.
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder

In [14]:
# Reuse the already-built prompt encoder and mask decoder from the full model.
prompter_model = SAMPrompter(
    prompt_encoder=model.model.prompt_encoder,
    mask_decoder=model.model.mask_decoder,
    feature_shape=features.shape[1:],
)

In [15]:
# Same prompts as the end-to-end benchmark; the raw images are replaced by the
# precomputed image features.
prompt_inputs = {"features": features}
prompt_inputs.update(
    {key: model_input[key] for key in ("points", "labels", "boxes", "masks")}
)

### Benchmark

In [16]:
# Warm-up call: the first predict on the new functional model builds it
# (about 3 s in the output below), so it is kept out of the timing.
outs = prompter_model.predict(prompt_inputs)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step


In [17]:
# Second call after warm-up: the per-step time Keras prints here is a quick
# sanity check for the %timeit figure in the next cell.
outs = prompter_model.predict(prompt_inputs)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step


In [18]:
%timeit outs = prompter_model.predict(prompt_inputs, verbose=0)

75.5 ms ± 598 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
