# Hands-on tutorial for DETR object detection with Transformers (run it in Colab)

This notebook is a reference for:
* using the pre-trained models to perform object detection

Source repository: https://github.com/facebookresearch/detr?tab=readme-ov-file

## Preliminaries
This section contains the boilerplate necessary for the other sections. Run it first.

In [1]:
import math

from PIL import Image
import requests
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'

import ipywidgets as widgets
from IPython.display import display, clear_output

import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False);

In [2]:
# COCO classes
# The list position is the class id: plot_results looks a label up with
# CLASSES[argmax of the class probabilities], so the order must not change.
# NOTE(review): the 'N/A' entries are presumably placeholders for class ids
# that have no category, kept so the remaining names stay at the right index
# -- confirm against the dataset's category ids.
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# colors for visualization
# Each entry is a 3-element list of floats in [0, 1]; plot_results passes one
# entry per box as matplotlib's `color=` argument (interpreted as RGB) and
# repeats the palette when there are more boxes than entries.
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

In [3]:
# standard PyTorch mean-std input image normalization
# Pre-processing pipeline applied to each PIL image in load_predict before the
# image is passed to the model.
transform = T.Compose([
    # With an int argument torchvision resizes the shorter image side to this
    # many pixels and keeps the aspect ratio.
    T.Resize(800),
    # PIL image -> float tensor in (C, H, W) layout with values in [0, 1].
    T.ToTensor(),
    # Per-channel (mean, std); three values each, so a 3-channel input is expected.
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def box_cxcywh_to_xyxy(x):
    """
    Convert boxes from center-width-height format to corner format.

    Used for post-processing the predicted bounding boxes.

    Parameters:
    x (torch.Tensor): Boxes as (center_x, center_y, width, height) in the last
        dimension, e.g. shape (N, 4) or (batch, N, 4); a single box of
        shape (4,) is accepted as well.

    Returns:
    torch.Tensor: Boxes as (xmin, ymin, xmax, ymax), same shape as `x`.
    """
    # Split and re-stack along the LAST dimension so that any number of leading
    # (batch) dimensions works; for (N, 4) input this is identical to dim=1.
    x_c, y_c, w, h = x.unbind(-1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = [x_c - half_w, y_c - half_h,
               x_c + half_w, y_c + half_h]
    return torch.stack(corners, dim=-1)

def rescale_bboxes(out_bbox, size):
    """
    Rescale bounding boxes from normalized [0; 1] coordinates to image scales.

    Parameters:
    out_bbox (torch.Tensor): Predicted bounding boxes in center-width-height
        format with normalized coordinates, shape (N, 4).
    size (tuple): (width, height) of the image in pixels, e.g. `PIL.Image.size`.

    Returns:
    torch.Tensor: Rescaled bounding boxes in xmin, ymin, xmax, ymax format,
        shape (N, 4), on the same device as `out_bbox`.
    """
    img_w, img_h = size  # Image width and height

    # Convert center-width-height to xmin-ymin-xmax-ymax
    b = box_cxcywh_to_xyxy(out_bbox)

    # x coordinates scale with the width, y coordinates with the height.
    # The scale tensor is created on the boxes' device so the multiplication
    # also works when the model (and therefore its output) lives on a GPU.
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=b.device)

    return b * scale

In [4]:
def plot_results(pil_img, prob, boxes):
    """
    Plot the results of object detection on an image and print a summary.

    Every box is drawn as a coloured rectangle labelled with the most likely
    class and its probability; afterwards one line per detection is printed.

    Parameters:
    pil_img (PIL.Image.Image): The input image in PIL format.
    prob (torch.Tensor): Per-detection class probabilities; one row per box,
        whose argmax is used as an index into CLASSES.
    boxes (torch.Tensor): The bounding boxes for each detected object, in the
        format (xmin, ymin, xmax, ymax), in image pixel coordinates.

    Returns:
    None
    """
    from itertools import cycle

    # Create a new figure with a specified size and display the image
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)

    # Get the current axes instance on the current figure
    ax = plt.gca()

    # (label, score, bbox) for every drawn detection, in drawing order
    detections = []

    # cycle() repeats the palette indefinitely, so no detection is silently
    # dropped however many boxes are passed in; zip stops when `prob` ends.
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), cycle(COLORS)):
        # Add a rectangle patch to the axes for the bounding box
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))

        # Class with the highest probability, as a plain int / float
        cl = int(p.argmax())
        name = CLASSES[cl]
        score = float(p[cl])
        detections.append((name, score, (xmin, ymin, xmax, ymax)))

        # Label the box at its top-left corner with class name and probability
        ax.text(xmin, ymin, f'{name}: {score:0.2f}', fontsize=6,
                bbox=dict(facecolor='yellow', alpha=0.5))

    print("Identified objects **")
    for name, score, bbox in detections:
        print(f"label: {name}, score: {score}, bbox: {bbox}")

    # Remove the axes for better visualization
    plt.axis('off')

    # Show the plot
    plt.show()

# Detection - using a pre-trained model from TorchHub

We load the simplest model (DETR-R50) from Torch Hub for fast inference, run it on a custom image, and print the result. Any other model from the model zoo can be used instead.

In [5]:
# Fetch the 'detr_resnet50' entry point of the facebookresearch/detr hub repo
# with pre-trained weights (a local hub cache is reused when present).
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
# Put the module in evaluation mode for inference; the trailing `;` suppresses
# the cell's printed repr of the model.
model.eval();

Using cache found in /root/.cache/torch/hub/facebookresearch_detr_main


We now retrieve the image as a PIL image, apply some pre-processing, and run it through the model.

In [6]:
def load_predict(url, timeout=30):
    """
    Load an image from URL, preprocess it, pass it through the model, and
    return the model outputs along with the loaded image.

    Parameters:
    url (str): URL of the image to load.
    timeout (float): Seconds to wait for the server before giving up.

    Returns:
    dict: Model outputs, containing 'pred_logits' and 'pred_boxes'.
    PIL.Image.Image: The image loaded from the URL, converted to RGB.

    Raises:
    requests.HTTPError: If the server answers with an error status.
    requests.Timeout: If the server does not answer within `timeout` seconds.
    """
    from io import BytesIO

    # Download the whole body first: unlike the raw stream, `.content` is
    # transparently decoded (gzip/deflate), and a timeout prevents the cell
    # from hanging forever on an unresponsive server.
    response = requests.get(url, timeout=timeout)
    # Fail here with a clear HTTP error instead of letting PIL choke on an
    # error page further down.
    response.raise_for_status()

    # The normalization in `transform` has three channel statistics, so force
    # 3-channel RGB (grayscale / RGBA / palette images would otherwise fail).
    im = Image.open(BytesIO(response.content)).convert('RGB')

    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)

    # propagate through the model, which returns outputs that contain
    # pred_logits (torch.Tensor): Raw class scores for each predicted box.
    # pred_boxes (torch.Tensor): Coordinates of the predicted bounding boxes.
    outputs = model(img)

    return outputs, im

Next, we filter the predictions. In particular, we keep only the objects for which the class confidence is higher than 0.4 (discounting the "non-object" predictions). You can lower this threshold (the `threshold` argument) if you want more predictions.

In [9]:
# filter the predictions by confidence (0.4 unless another threshold is given) and plot them
def review_prediction(outputs, im, threshold=0.4):
    """
    Filter the model's predictions by confidence and visualise the survivors.

    Parameters:
    outputs (dict): Model outputs containing 'pred_logits' and 'pred_boxes'.
    im (PIL.Image.Image): Original image in PIL format.
    threshold (float): A prediction is kept only if its highest class
        probability is strictly greater than this value.

    Returns:
    None
    """
    # Class probabilities of the first batch item; the last class column is
    # dropped before taking the per-prediction maximum.
    logits = outputs['pred_logits']
    class_probs = logits.softmax(dim=-1)[0, :, :-1]

    # Boolean mask over predictions that are confident enough
    best_score, _ = class_probs.max(dim=-1)
    selected = best_score > threshold

    # Bring the surviving boxes from [0; 1] to image pixel coordinates
    kept_boxes = outputs['pred_boxes'][0, selected]
    scaled_boxes = rescale_bboxes(kept_boxes, im.size)

    plot_results(im, class_probs[selected], scaled_boxes)

That's it! Now try it on your own image and see what the self-attention of the Transformer Encoder learned by itself!


In [10]:
# Example input image; replace with any reachable image URL.
# NOTE(review): this looks like a dated traffic-camera snapshot and may no
# longer be served -- confirm the URL is still reachable.
url = "https://images.data.gov.sg/api/traffic-images/2022/03/881b8734-cca2-49d2-844f-96f16e53a1ac.jpg"
# Download + run the model, then filter (default threshold 0.4) and plot.
outputs, im = load_predict(url)
review_prediction(outputs,im)

Output hidden; open in https://colab.research.google.com to view.


# Conclusion

In this notebook, we showed:
- how to use torchhub to compute predictions on your own image,