In [None]:
import os
import sys
sys.path.append('/big_vision/')
from bokeh import io as bokeh_io
import jax
from google.colab import output as colab_output
import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models

from scenic.projects.owl_vit.notebooks import inference
from scenic.projects.owl_vit.notebooks import interactive
from scenic.projects.owl_vit.notebooks import plotting
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import tensorflow as tf

# Set TensorFlow's list of visible GPU devices to empty, so TF runs on CPU only.
# NOTE(review): presumably done so TF does not reserve accelerator memory that
# JAX needs for the model — confirm.
tf.config.experimental.set_visible_devices([], 'GPU')
# Route Bokeh figures to inline notebook output, without the loading banner.
bokeh_io.output_notebook(hide_banner=True)

# Set up the model
This takes a minute or two.

In [None]:
# Fetch the `clip_b16` config. The 'canonical_checkpoint' init mode presumably
# fills in `config.init_from.checkpoint_path`, which is read below — TODO confirm.
config = configs.clip_b16.get_config(init_mode='canonical_checkpoint')
# Build the detection module from the `model` section of the config.
module = models.TextZeroShotDetectionModule(
    body_configs=config.model.body,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias)
# Load the module's variables from the checkpoint path given in the config.
variables = module.load_variables(config.init_from.checkpoint_path)
# Bundle config, module and variables into the inference wrapper that the
# interactive cells further down call into.
model = inference.Model(config, module, variables)
# NOTE(review): presumably triggers compilation up front so that the first
# interactive query is not slow — confirm against inference.Model.
model.warm_up()

# Load example images

Please provide a path to a directory containing example images. Google Cloud Storage and local storage are supported.

In [None]:
IMAGE_DIR = 'gs://scenic-bucket/owl_vit/example_images'  # @param {"type": "string"}
%matplotlib inline

from skimage import data

images = {}

for i, filename in enumerate(tf.io.gfile.listdir(IMAGE_DIR)):
  with tf.io.gfile.GFile(os.path.join(IMAGE_DIR, filename), 'rb') as f:
    image = mpl.image.imread(
        f, format=os.path.splitext(filename)[-1])[..., :3]
  if np.max(image) <= 1.:
    image *= 255
  images[i] = image

images[3] = data.rocket()
images[4] = data.astronaut()

cols = 5
rows = max(len(images) // 5, 1)
fig, axs = plt.subplots(rows, cols, figsize=(16, 8 * rows))

for ax in axs.ravel():
  ax.set_visible(False)

for ax, (ind, image) in zip(axs.ravel(), images.items()):
  ax.set_visible(True)
  ax.imshow(image)
  ax.set_xticks([])
  ax.set_yticks([])
  ax.set_title(f'Image ID: {ind}')

fig.tight_layout()

# Text-conditioned detection
Enter comma-separated queries in the text box above the image to detect objects. If nothing happens, try running the cell first (<kbd>Ctrl</kbd>+<kbd>Enter</kbd>).

In [None]:
#@title { run: "auto" }
# Key into the `images` dict selecting the image to run detection on.
IMAGE_ID =   4# @param {"type": "number"}
image = images[IMAGE_ID]
# embed_image returns three values; only the third (the boxes) is used here.
_, _, boxes = model.embed_image(image)
# Build the figure from the preprocessed image and the boxes computed above.
plotting.create_text_conditional_figure(
    image=model.preprocess_image(image), boxes=boxes, fig_size=900)
# NOTE(review): presumably wires the figure's text input to the model through
# the Colab output bridge, so typed queries trigger detection — confirm.
interactive.register_text_input_callback(model, image, colab_output)

In [None]:
#@title { run: "auto" }
# Key into the `images` dict selecting the image to run detection on.
IMAGE_ID =   2# @param {"type": "number"}
image = images[IMAGE_ID]
# embed_image returns three values; only the third (the boxes) is used here.
_, _, boxes = model.embed_image(image)
# Build the figure from the preprocessed image and the boxes computed above.
plotting.create_text_conditional_figure(
    image=model.preprocess_image(image), boxes=boxes, fig_size=900)
# NOTE(review): presumably wires the figure's text input to the model through
# the Colab output bridge, so typed queries trigger detection — confirm.
interactive.register_text_input_callback(model, image, colab_output)

# Image-conditioned detection

In image-conditioned detection, the model is tasked to detect objects that match a given example image. In the cell below, the example image is chosen by drawing a bounding box around an object in the left image. The model will then detect similar objects in the right image.

In [None]:
#@title { run: "auto" }

#@markdown The *query image* is used to select example objects:
QUERY_IMAGE_ID = 3  # @param {"type": "number"}

#@markdown Objects will be detected in the *target image* :
TARGET_IMAGE_ID = 4  # @param {"type": "number"}

#@markdown Threshold for the minimum confidence that a detection must have to
#@markdown be displayed (higher values mean fewer boxes will be shown):
MIN_CONFIDENCE = 0.9994 #@param { type: "slider", min: 0.9, max: 1.0, step: 0.0001}


#@markdown Threshold for non-maximum suppression of overlapping boxes (higher
#@markdown values mean more boxes will be shown):
NMS_THRESHOLD = 0.15 #@param { type: "slider", min: 0.05, max: 1.0, step: 0.01}

# The thresholds are passed on by overwriting module-level attributes of
# `interactive`, so they stay in effect until another cell overwrites them.
interactive.IMAGE_COND_MIN_CONF = MIN_CONFIDENCE
interactive.IMAGE_COND_NMS_IOU_THRESHOLD = NMS_THRESHOLD

# Both IDs are keys into the `images` dict.
query_image = images[QUERY_IMAGE_ID]
target_image = images[TARGET_IMAGE_ID]
# Only the target image is embedded here; of the three values returned, only
# the third (the boxes) is used.
_, _, boxes = model.embed_image(target_image)
plotting.create_image_conditional_figure(
    query_image=model.preprocess_image(query_image),
    target_image=model.preprocess_image(target_image),
    target_boxes=boxes, fig_size=600)
# NOTE(review): presumably connects box selection on the query image to
# detection on the target image via the Colab output bridge — confirm.
interactive.register_box_selection_callback(model, query_image, target_image, colab_output)

In [None]:
#@title { run: "auto" }

#@markdown The *query image* is used to select example objects:
QUERY_IMAGE_ID = 1  # @param {"type": "number"}

#@markdown Objects will be detected in the *target image* :
TARGET_IMAGE_ID = 0  # @param {"type": "number"}

#@markdown Threshold for the minimum confidence that a detection must have to
#@markdown be displayed (higher values mean fewer boxes will be shown):
MIN_CONFIDENCE = 0.9994 #@param { type: "slider", min: 0.9, max: 1.0, step: 0.0001}


#@markdown Threshold for non-maximum suppression of overlapping boxes (higher
#@markdown values mean more boxes will be shown):
NMS_THRESHOLD = 0.15 #@param { type: "slider", min: 0.05, max: 1.0, step: 0.01}

# The thresholds are passed on by overwriting module-level attributes of
# `interactive`, so they stay in effect until another cell overwrites them.
interactive.IMAGE_COND_MIN_CONF = MIN_CONFIDENCE
interactive.IMAGE_COND_NMS_IOU_THRESHOLD = NMS_THRESHOLD

# Both IDs are keys into the `images` dict.
query_image = images[QUERY_IMAGE_ID]
target_image = images[TARGET_IMAGE_ID]
# Only the target image is embedded here; of the three values returned, only
# the third (the boxes) is used.
_, _, boxes = model.embed_image(target_image)
plotting.create_image_conditional_figure(
    query_image=model.preprocess_image(query_image),
    target_image=model.preprocess_image(target_image),
    target_boxes=boxes, fig_size=600)
# NOTE(review): presumably connects box selection on the query image to
# detection on the target image via the Colab output bridge — confirm.
interactive.register_box_selection_callback(model, query_image, target_image, colab_output)

# Benchmark inference speed
- This section shows how to benchmark the inference speed of OWL-ViT. 
- Speed and accuracy can be traded off by reducing the input resolution. 
- This is done by truncating the position embeddings, and it works if the model was trained with heavy size augmentation and padding at the bottom and/or right of the image.

In [None]:
# Start from a fresh copy of the `clip_b16` config; the fields added below
# therefore do not exist on any config object created by earlier cells.
config = configs.clip_b16.get_config(init_mode='canonical_checkpoint')

# To use variable inference resolution, patch size and native (=training) grid
# size need to be added to the config:
# The patch size is parsed from the last two characters of the variant string,
# so this assumes the variant name ends in a two-digit patch size.
config.model.body.patch_size = int(config.model.body.variant[-2:])
# Number of patches along one side at the training resolution (integer
# division of the configured input size by the patch size).
config.model.body.native_image_grid_size = (
    config.dataset_configs.input_size // config.model.body.patch_size
)

In [None]:
class PredictWithTextEmbeddings(models.TextZeroShotDetectionModule):
  """Detection module variant that consumes already-computed query embeddings.

  Only the image is embedded inside `__call__`; the query embeddings are taken
  as an input and handed to the class predictor directly.
  """

  def __call__(self, image, query_embeddings):
    """Predicts boxes and class logits for a single, unbatched image.

    Args:
      image: One image without a batch axis; a leading batch axis of size 1 is
        added before it is passed to the image embedder.
      query_embeddings: Precomputed query embeddings without a batch axis; a
        leading batch axis of size 1 is added before class prediction.

    Returns:
      A `(boxes, logits)` tuple holding the 'pred_boxes' entry of the box
      predictor output and the 'pred_logits' entry of the class predictor
      output.
    """
    # The second positional argument is presumably the `train` flag — TODO
    # confirm against the image embedder's signature.
    grid_features = self.image_embedder(image[None, ...], False)

    # Flatten the two middle (spatial) axes into a single token axis.
    batch, height, width, depth = grid_features.shape
    token_features = grid_features.reshape(batch, height * width, depth)

    box_outputs = self.box_predictor(
        image_features=token_features, feature_map=grid_features
    )
    class_outputs = self.class_predictor(
        token_features, query_embeddings[None, ...]
    )
    return box_outputs['pred_boxes'], class_outputs['pred_logits']


# Instantiate the benchmarking module from the `model` section of the config.
# This rebinds the notebook-level name `module`.
module = PredictWithTextEmbeddings(
    body_configs=config.model.body,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias,
)

# Load the variables from the checkpoint path given in the config. This rebinds
# the notebook-level name `variables`.
variables = module.load_variables(config.init_from.checkpoint_path)


@jax.jit
def predict(image, query_embeddings):
  """Applies the notebook-level `module` with `variables` to the given inputs.

  `module` and `variables` are not arguments; they are read from the enclosing
  notebook scope.

  Args:
    image: Image forwarded to `module.apply`.
    query_embeddings: Query embeddings forwarded to `module.apply`.

  Returns:
    Whatever `module.apply` returns for these inputs.
  """
  outputs = module.apply(variables, image, query_embeddings)
  return outputs

In [None]:
import time

# Get fake query embeddings for benchmarking (1203 classes):
embed_dim = models.clip_model.CONFIGS[config.model.body.variant]['embed_dim']
query_embeddings = jax.random.normal(jax.random.PRNGKey(0), (1203, embed_dim))

# Resolutions at which to benchmark the model:
if config.model.body.patch_size == 16:
  sizes = [100, 200, 368, 400, 448, 480, 528, 576, 624, 672, 736]
else:
  raise ValueError(
      'Please define image sizes for patch size:'
      f' {config.model.body.patch_size}'
  )
num_trials = 5
# Maps image size -> median seconds per `predict` call at that size.
all_timings = {}
for image_size in sizes:
  print(f'Benchmarking image size: {image_size}')

  # Get fake image for benchmarking:
  image = jax.random.uniform(jax.random.PRNGKey(0), (image_size, image_size, 3))
  timings = []
  # One extra leading trial per image size absorbs the jit compilation that
  # each new input shape triggers.
  for _ in range(num_trials + 1):
    # perf_counter is a monotonic, high-resolution clock intended for measuring
    # short durations; time.time() can jump when the system clock is adjusted.
    start_time = time.perf_counter()
    boxes, logits = predict(image, query_embeddings)
    # JAX dispatches asynchronously, so wait for the results before stopping
    # the clock.
    jax.block_until_ready((boxes, logits))
    timings.append(time.perf_counter() - start_time)

  # Store the median over the measured trials only. The first trial is dropped
  # because it is always very slow due to model compilation, and including it
  # biases the median upwards.
  all_timings[image_size] = np.median(timings[1:])
  print(f'FPS at resolution={image_size}: {1/all_timings[image_size]:.2f}\n')