# Object detection with Vision Transformer for Open-World Localization (OWL-ViT)
The **[OWL-ViT](https://arxiv.org/abs/2205.06230)** model is an **open-vocabulary object detection model** that uses the standard Vision Transformer to perform detection.  
![OWL-ViT](https://github.com/google-research/scenic/raw/main/scenic/projects/owl_vit/data/owl_vit_schematic.png)  
Given an image and a free-text query, it finds objects matching that query in the image. It can also perform one-shot object detection, i.e. detect objects based on a single example image. This notebook evaluates both tasks using the pre-trained model available in Hugging Face's *transformers* library.   
In my evaluation, the model appears robust for open-vocabulary object detection, but not yet robust enough for accurate one-shot object detection.   
A free Colab VM without hardware acceleration is enough to execute the code in this notebook.

## Settings

Install the missing requirements (only *transformers* is not already present in the Colab VM).

In [None]:
!pip install transformers

Import the general-purpose packages and classes.

In [None]:
import torch
from PIL import Image
from transformers import OwlViTProcessor, OwlViTForObjectDetection

Load the pre-trained OWL-ViT processor and model.

In [None]:
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

Define a function to upload images.

In [None]:
from google.colab import files

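def upload_files():
  # Open the Colab file picker and save each uploaded file to the working directory
  uploaded = files.upload()
  for name, data in uploaded.items():
    with open(name, 'wb') as f:
      f.write(data)
  return list(uploaded.keys())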

## Open-vocabulary Object Detection

Define a function to add bounding boxes to the input image.

In [None]:
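import cv2
import numpy as np

def add_bounding_boxes(boxes, scores, labels, text_queries, img):
  # Note: relies on the global `score_threshold`, defined in a later cell
  img = np.array(img)
  img_height = img.shape[0]

  font = cv2.FONT_HERSHEY_SIMPLEX

  for box, score, label in zip(boxes, scores, labels):
    box = [int(i) for i in box.tolist()]

    if score >= score_threshold:
      # Draw the box in red (the array is in RGB order)
      img = cv2.rectangle(img, tuple(box[:2]), tuple(box[2:]), (255, 0, 0), 5)
      # Put the label below the box, or just above its bottom edge if it would fall outside the image
      if box[3] + 25 > img_height:
        y = box[3] - 10
      else:
        y = box[3] + 25

      img = cv2.putText(
          img, text_queries[label], (box[0], y), font, 1, (255, 0, 0), 2, cv2.LINE_AA
      )

  return Image.fromarray(img, 'RGB')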
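The label-placement rule inside `add_bounding_boxes` can be isolated into a pure function, which makes it easy to check without OpenCV. This is a minimal sketch; `label_y` and its parameter names are hypothetical, and it takes the image height as an argument rather than assuming a fixed size.

```python
def label_y(box, img_height, below_offset=25, above_margin=10):
    """Return the y coordinate for a text label anchored at a box's bottom edge.

    box is (x0, y0, x1, y1) in pixels. The label goes below the box when it
    fits inside the image, otherwise just above the box's bottom edge.
    """
    y_bottom = box[3]
    if y_bottom + below_offset > img_height:
        return y_bottom - above_margin
    return y_bottom + below_offset
```

For example, `label_y((10, 10, 100, 200), 480)` returns 225, while a box whose bottom edge is at y = 470 in the same 480-pixel-high image gets its label at 460.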

Upload an image and display it.

In [None]:
uploaded_image_list = upload_files()
# Convert to RGB so that images with an alpha channel or in grayscale work too
image = Image.open(uploaded_image_list[0]).convert("RGB")
display(image)

Specify the free-text queries (one list of queries per image) and the minimum score threshold for object detection.

In [None]:
texts = [["tuxedo cat", "tabby cat"]]
score_threshold = 0.08

Run object detection on the uploaded image, querying with the chosen free texts.

In [None]:
inputs = processor(text=texts, images=image, return_tensors="pt")
with torch.no_grad():  # inference only, no gradients needed
  outputs = model(**inputs)
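Conceptually, open-vocabulary detection scores each predicted box against every text query by comparing their embeddings, so the set of classes is simply whatever queries are passed in. The sketch below illustrates that idea with toy vectors; it uses plain cosine similarity and deliberately omits the learned logit shift and scale that the real OWL-ViT class head applies. `cosine` and `classify_box` are hypothetical helpers, not part of *transformers*.

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def classify_box(box_embed, query_embeds):
    """Return (best_query_index, similarity) for one predicted box."""
    sims = [cosine(box_embed, q) for q in query_embeds]
    best = max(range(len(sims)), key=sims.__getitem__)
    return best, sims[best]
```

For instance, a box embedding `[1.0, 0.0]` compared against queries `[[0.0, 1.0], [1.0, 0.1]]` is assigned to query 1, the nearly parallel vector.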

Build the target image size as (height, width) to rescale the box predictions, then convert the outputs (bounding boxes and class logits) to the COCO API format.

In [None]:
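# PIL reports (width, height); the post-processor expects (height, width)
target_sizes = torch.Tensor([image.size[::-1]])

# `post_process` is deprecated in newer transformers releases;
# `post_process_object_detection` with threshold=0.0 keeps all predictions, as before
results = processor.post_process_object_detection(outputs=outputs, threshold=0.0, target_sizes=target_sizes)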
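As I understand the OWL-ViT post-processing in *transformers*, each box takes the query with the highest logit as its label and the sigmoid of that logit as its score, and the predicted box is converted from normalized center format (cx, cy, w, h) to pixel corners (x0, y0, x1, y1) using the (height, width) pair in `target_sizes`, which is why `image.size` is reversed above. The pure-Python sketch below mirrors that logic for a single prediction; `post_process_one` is a hypothetical helper, not a library function.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def post_process_one(logits, box, img_h, img_w):
    """Convert one prediction to (score, label, [x0, y0, x1, y1]) in pixels.

    logits: one raw logit per text query.
    box: (cx, cy, w, h), normalized to [0, 1].
    """
    # Label = query with the highest logit; score = sigmoid of that logit
    label = max(range(len(logits)), key=logits.__getitem__)
    score = sigmoid(logits[label])
    # Center format -> corner format, still normalized
    cx, cy, w, h = box
    corners = [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]
    # Scale x by the width and y by the height
    scale = [img_w, img_h, img_w, img_h]
    return score, label, [c * s for c, s in zip(corners, scale)]
```

For a 200 × 100 (width × height) image, a box centered at (0.5, 0.5) with normalized size 0.2 × 0.4 maps to the pixel corners [80, 30, 120, 70].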

Retrieve the predictions for the first (and only) image together with its text queries.

In [None]:
i = 0  # index of the image in the batch
text = texts[i]
boxes, scores, labels = results[i]["boxes"], results[i]["scores"], results[i]["labels"]

for box, score, label in zip(boxes, scores, labels):
    box = [round(coord, 2) for coord in box.tolist()]
    if score >= score_threshold:
        print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")

Draw a bounding box for each prediction above the score threshold, overlay the boxes on the input image, and display the result.

In [None]:
labelled_image = add_bounding_boxes(boxes, scores, labels, text, image)
display(labelled_image)

## Image-guided Object Detection

Upload the target and query images and display them. The first uploaded file is used as the target image and the second as the query image.

In [None]:
uploaded_image_guided_list = upload_files()

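target_image = Image.open(uploaded_image_guided_list[0]).convert("RGB")
# PIL reports (width, height); the post-processor expects (height, width)
target_sizes = torch.Tensor([target_image.size[::-1]])

query_image = Image.open(uploaded_image_guided_list[1]).convert("RGB")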

display(target_image)
display(query_image)

Process the target and query images.

In [None]:
device = 'cpu'
inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)

for key, val in inputs.items():
    print(f"{key}: {val.shape}")

Get the predictions.

In [None]:
with torch.no_grad():
  outputs = model.image_guided_detection(**inputs)

for k, val in outputs.items():
    if k not in {"text_model_output", "vision_model_output"}:
        print(f"{k}: shape of {val.shape}")

print("\nVision model outputs")
for k, val in outputs.vision_model_output.items():
    print(f"{k}: shape of {val.shape}") 
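The number of candidate boxes in these shapes follows from the ViT patch grid: OWL-ViT predicts one box per image patch token. Assuming the processor resizes inputs to 768 × 768 pixels (what I expect for `google/owlvit-base-patch32`; the `pixel_values` shape printed in the previous cell confirms it) and a patch size of 32, that gives 24 × 24 = 576 candidate boxes per image. `num_patch_predictions` is a hypothetical helper for this arithmetic.

```python
def num_patch_predictions(image_size=768, patch_size=32):
    """Number of box predictions for a square input split into square patches.

    Assumes one predicted box per ViT patch token, i.e.
    (image_size // patch_size) ** 2.
    """
    if image_size % patch_size:
        raise ValueError("image_size must be a multiple of patch_size")
    return (image_size // patch_size) ** 2
```

With the defaults this returns 576; for a hypothetical patch size of 16 at the same resolution it would return 2304.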

Draw a bounding box for each detection, overlay the boxes on the target image, and display the result.

In [None]:
# Work directly in RGB so the boxes are drawn in red, as in the text-query section
img = np.array(target_image)

# Move the tensors needed for post-processing to the CPU (a no-op when already on CPU)
outputs.logits = outputs.logits.cpu()
outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()

results = processor.post_process_image_guided_detection(outputs=outputs, threshold=0.6, nms_threshold=0.3, target_sizes=target_sizes)
boxes, scores = results[0]["boxes"], results[0]["scores"]

for box, score in zip(boxes, scores):
    box = [int(i) for i in box.tolist()]
    img = cv2.rectangle(img, tuple(box[:2]), tuple(box[2:]), (255, 0, 0), 5)

display(Image.fromarray(img))
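The `nms_threshold` argument controls non-maximum suppression (NMS), which discards a box when it overlaps a higher-scoring box by more than the given intersection over union (IoU). The sketch below is a generic greedy NMS in pure Python with hypothetical helper names (`iou`, `greedy_nms`); it illustrates the idea and is not the exact code inside `post_process_image_guided_detection`.

```python
def iou(a, b):
    """Intersection over union of two (x0, y0, x1, y1) boxes."""
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def greedy_nms(boxes, scores, iou_threshold=0.3):
    """Return indices of kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        # Keep a box only if it does not overlap any already-kept box too much
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep
```

Given two heavily overlapping boxes (IoU of about 0.68) and one separate box, `greedy_nms` with a threshold of 0.3 keeps the higher-scoring box of the overlapping pair plus the separate one.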