# Object detection in videos with DETR

Now that we've seen how we can get some nice results for images, let's see if we can generalise this to video. 

In [1]:
from transformers import DetrImageProcessor, DetrForObjectDetection 
import torch
import cv2

from tqdm import tqdm
import gc
import numpy as np
import time

Again, we load the DETR image processor and model using the `transformers` package.

In [2]:
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm") 
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50", revision="no_timm")

We use the same convenience functions to create the label font size and colour of the bounding boxes.

In [3]:
def calculate_label_size(box_width, box_height, frame_width, frame_height, min_scale=0.4, max_scale=1.2):
    """
    Calculate appropriate font scale based on bounding box size and frame dimensions.

    The scale grows logarithmically with the fraction of the frame that the
    box covers and is always clamped to ``[min_scale, max_scale]``.

    Args:
        box_width: Width of the bounding box in pixels
        box_height: Height of the bounding box in pixels
        frame_width: Width of the frame in pixels
        frame_height: Height of the frame in pixels
        min_scale: Smallest font scale that may be returned
        max_scale: Largest font scale that may be returned

    Returns:
        The font scale as a plain Python float.
    """
    frame_area = frame_width * frame_height
    if frame_area <= 0:
        # An empty or invalid frame gives no meaningful ratio; avoid dividing by zero.
        return float(min_scale)

    # A degenerate box (one negative side) would push the log argument to zero
    # or below and produce NaN/-inf, so treat it as covering none of the frame.
    box_size_ratio = max((box_width * box_height) / frame_area, 0.0)
    font_scale = min_scale + (np.log(1 + box_size_ratio * 100) / 5) * (max_scale - min_scale)
    return float(np.clip(font_scale, min_scale, max_scale))

def generate_color_palette(num_classes):
    """
    Generate a reproducible colour for every class id in ``range(num_classes)``.

    Args:
        num_classes: Number of class ids to create colours for

    Returns:
        Dict mapping each class id to a 3-tuple of Python ints. Each channel
        is drawn from ``randint(0, 255)``, whose upper bound is exclusive.
    """
    # A private generator seeded with 42 yields the same colours as seeding the
    # global NumPy RNG would, without resetting random state for other code.
    rng = np.random.RandomState(42)
    return {label: tuple(rng.randint(0, 255, 3).tolist()) for label in range(num_classes)}

# Build the palette once at module level so every lookup reuses the same
# label -> colour mapping (COCO has 91 classes by default)
colors = generate_color_palette(num_classes=91)

# Function to get the color for a specific label
def get_label_color(label):
    """Return the palette colour for `label`, falling back to green if it has no entry."""
    fallback_green = (0, 255, 0)
    if label in colors:
        return colors[label]
    return fallback_green

We can keep track of the frames per second using an FPS counter. The function below contains all the formatting options we need to draw a nicely presented counter in the top-right corner of the video.

In [4]:
def add_fps_counter(frame, fps, frame_width):
    """
    Overlay an FPS readout in the top-right corner of a frame.

    The frame is drawn on in place: a filled black box goes down first and
    the white FPS text is rendered on top of it.

    Args:
        frame: The video frame to draw on
        fps: FPS value to display (shown with one decimal place)
        frame_width: Width of the frame, used to right-align the readout
    """
    typeface = cv2.FONT_HERSHEY_SIMPLEX
    scale = 0.7
    thickness = 2
    margin = 10   # gap between the text and the frame edges
    box_pad = 5   # how far the background box extends beyond the text
    black = (0, 0, 0)
    white = (255, 255, 255)

    caption = f"FPS: {fps:.1f}"
    (caption_w, caption_h), baseline = cv2.getTextSize(caption, typeface, scale, thickness)

    # Text origin: right-aligned horizontally, just below the top margin.
    left = frame_width - caption_w - margin
    bottom = caption_h + margin + baseline

    # Filled background box slightly larger than the text itself.
    box_top_left = (int(left - box_pad), int(bottom - caption_h - baseline - box_pad))
    box_bottom_right = (int(left + caption_w + box_pad), int(bottom + box_pad))
    cv2.rectangle(frame, box_top_left, box_bottom_right, black, -1)

    # White, anti-aliased text on top of the box.
    cv2.putText(frame, caption, (int(left), int(bottom)), typeface, scale, white, thickness, cv2.LINE_AA)

Finally, we're ready to run inference over the video. The function below does the following:
* Detects whether a GPU accelerator is available to apply inference to each frame. In my case, I have MPS (MacBook M3), so we'll enable this.
* OpenCV's `VideoCapture` class is used to open a connection to the video.
* The number of frames in the video is calculated, as well as the size of each one.
* The number of frames sets the length of the progress bar, and the time taken to process each frame is used to calculate the FPS shown in the counter.
* For each frame:
    * The image is converted from BGR to RGB, and converted to tensors
    * The image is input into the DETR model for inference
    * The result is post-processed, extracting the most likely label, as well as its probability and bounding box coordinates
    * And then for each detected bounding box ...
        * The results are plotted, with their associated label and class colour
* At the last frame, the connection to the video is closed.

In [5]:
def _select_device():
    """Pick the fastest available torch device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _empty_device_cache(device):
    """Release cached accelerator memory for the backend that is actually in use."""
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps" and hasattr(torch, "mps"):
        # torch.mps is only present in newer PyTorch releases.
        torch.mps.empty_cache()


def _draw_detection(frame, box, label, score, frame_width, frame_height):
    """Draw one bounding box and its "<class>: <score>" caption onto `frame` in place."""
    xmin, ymin, xmax, ymax = box
    font_scale = calculate_label_size(xmax - xmin, ymax - ymin, frame_width, frame_height)

    label_text = f"{model.config.id2label[label]}: {score:.2f}"
    (text_width, text_height), baseline = cv2.getTextSize(
        label_text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1
    )

    # Same colour for the box outline and the caption background of a class.
    color = get_label_color(label)
    left, top, right, bottom = int(xmin), int(ymin), int(xmax), int(ymax)

    cv2.rectangle(frame, (left, top), (right, bottom), color, max(2, int(font_scale * 3)))
    cv2.rectangle(frame, (left, top - text_height - baseline - 5),
                  (left + text_width, top), color, -1)
    cv2.putText(frame, label_text, (left, top - baseline - 2),
                cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0),
                max(1, int(font_scale * 1.5)), cv2.LINE_AA)


def process_video(video_path, output_path, threshold=0.9):
    """
    Process a video file using GPU acceleration (if available) for object detection.

    Every frame is run through the module-level DETR `processor` and `model`;
    detections scoring above `threshold` are drawn onto the frame together
    with an FPS counter, and the annotated frame is appended to `output_path`.

    Args:
        video_path: Path of the video to read
        output_path: Path of the annotated video to write (mp4v codec)
        threshold: Minimum detection score for a box to be drawn

    Raises:
        ValueError: If the input video or the output writer cannot be opened
    """
    device = _select_device()
    print(f"Using device: {device}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise ValueError("Error opening video file")

    out = None
    pbar = None
    frame_count = 0  # Frames processed so far.
    total_time = 0.0  # Seconds spent processing, for the final average FPS.

    try:
        model.to(device)

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Keep the fractional rate (e.g. 29.97) so the output keeps the source
        # duration, and fall back when no usable FPS value is reported.
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        if not source_fps > 0:
            source_fps = 30.0

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, source_fps, (frame_width, frame_height))
        if not out.isOpened():
            raise ValueError(f"Error opening output video file: {output_path}")

        # If the frame count is not positive, let tqdm run without a total.
        pbar = tqdm(total=total_frames if total_frames > 0 else None,
                    desc="Processing frames", unit="frames", dynamic_ncols=True)

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Start timing for this frame
            start_time = time.perf_counter()

            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            inputs = processor(images=frame_rgb, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model(**inputs)

            # Pass the threshold through: the post-processor otherwise applies
            # its own default cut-off, so lower thresholds would have no effect.
            target_sizes = torch.tensor([frame_rgb.shape[:2]]).to(device)
            results = processor.post_process_object_detection(
                outputs, threshold=threshold, target_sizes=target_sizes
            )[0]

            # Plain Python numbers keep the label usable as a dict key.
            boxes = results["boxes"].cpu().tolist()
            labels = results["labels"].cpu().tolist()
            scores = results["scores"].cpu().tolist()

            for box, label, score in zip(boxes, labels, scores):
                _draw_detection(frame, box, label, score, frame_width, frame_height)

            elapsed = time.perf_counter() - start_time
            # Guard against a zero interval on coarse clocks.
            current_fps = 1.0 / elapsed if elapsed > 0 else 0.0
            total_time += elapsed
            frame_count += 1

            # Add FPS counter to frame
            add_fps_counter(frame, current_fps, frame_width)

            out.write(frame)
            pbar.update(1)

            # Periodically hand cached accelerator memory back.
            if frame_count % 100 == 0:
                _empty_device_cache(device)
                gc.collect()

    except Exception as e:
        print(f"An error occurred: {e}")
        raise
    finally:
        # No display windows are opened by this function, so only the
        # capture, writer and progress bar need closing.
        if pbar is not None:
            pbar.close()
        cap.release()
        if out is not None:
            out.release()
        _empty_device_cache(device)
        gc.collect()
        model.to('cpu')

    print("\nVideo processing completed!")
    if frame_count and total_time > 0:
        print(f"Average processing speed: {frame_count / total_time:.1f} FPS")

## Let's apply this to a video!

We'll use another video from one of my holidays: a little clip of driving along Tower Bridge in London. This scene has a lot going on so it will give us a good idea of what DETR is capable of.

In [6]:
def display_mp4(video_path):
    """
    Play a video file in an OpenCV window.

    Playback stops when the video ends or when the user presses Q. If the
    file cannot be opened, an error message is printed and nothing is shown.

    Args:
        video_path: Path of the video file to play
    """
    # Create a VideoCapture object and read from input file
    cap = cv2.VideoCapture(video_path)

    try:
        # Check if the video opened successfully
        if not cap.isOpened():
            print("Error opening video file")
            return

        # Read until video is completed
        while True:
            # Capture frame-by-frame; ret is falsy once there is nothing left to read
            ret, frame = cap.read()
            if not ret:
                break

            # Display the resulting frame
            cv2.imshow('Frame', frame)

            # Press Q on keyboard to exit
            if cv2.waitKey(25) & 0xFF == ord('q'):
                break
    finally:
        # Always release the capture and close the window, even if
        # displaying a frame raised.
        cap.release()
        cv2.destroyAllWindows()

In [8]:
# Preview the source clip in an OpenCV window (press Q to close it early).
# NOTE(review): this plays the "-long" clip, whereas the processing cell below
# uses "tower-bridge-london.mp4" — confirm that is intentional.
display_mp4("inference_data/own_videos/tower-bridge-london-long.mp4")



In [10]:
# Run detection with the default 0.9 score threshold; the relative output path
# means the labelled video is written to the current working directory.
process_video("inference_data/own_videos/tower-bridge-london.mp4", "tower-bridge-labeled.mp4")

Using device: mps


Processing frames: 100%|██████████| 664/664 [00:59<00:00, 11.12frames/s]



Video processing completed!


In [9]:
# The processing cell wrote its result to "tower-bridge-labeled.mp4" relative to
# the working directory, so read it back from that same path.
display_mp4("tower-bridge-labeled.mp4")
