# Exercise 5: DETR Walkthrough

Explore query-based object detection with DETR


- a) Model inspection: Load DETR and identify components
- b) Forward pass: Inspect output tensor shapes
- c) No-object analysis: Count sparse predictions
- d) Visualization: Post-process and display detections

## Setup

In [None]:
import torch
import numpy as np
import requests
from io import BytesIO
from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.patches as patches

from transformers import DetrForObjectDetection, DetrImageProcessor

In [None]:
# Fetch the pretrained DETR checkpoint together with its matching image processor.
model_id = "facebook/detr-resnet-50"
device = torch.device("cpu")  # laptop-friendly

model = DetrForObjectDetection.from_pretrained(model_id)
processor = DetrImageProcessor.from_pretrained(model_id)

# Both .to() and .eval() return the module itself, so the calls can be chained:
# move the weights to the target device and switch to evaluation mode.
model = model.to(device).eval()

print(f"Model loaded: {model_id}")

## Part a) Model Inspection

Identify the main components of DETR:
- CNN backbone (ResNet-50)
- Transformer encoder
- Transformer decoder
- Prediction heads (class and box)

In [None]:
# High-level model structure: printing the module lists its nested submodules,
# which is the quickest way to spot the backbone, encoder, decoder and heads.
print(model)

In [None]:
# Drill into the wrapped base model, starting with the backbone.
print("DETR components", "Backbone:", sep="\n")
# Left as a bare trailing expression so the notebook renders the module repr.
model.model.backbone


In [None]:
# Show the encoder submodule of the base model.
print("Encoder:")
# Bare trailing expression: the notebook displays the module repr.
model.model.encoder


In [None]:
# Show the decoder submodule of the base model.
print("Decoder:")
# Bare trailing expression: the notebook displays the module repr.
model.model.decoder

In [None]:
# Object queries - the learnable embeddings, stored as an embedding weight matrix
query_embeds = model.model.query_position_embeddings.weight
# Rows are individual queries, columns are the embedding dimension.
num_queries, query_dim = query_embeds.shape

print("Object Queries")
print(f"Shape: {query_embeds.shape}")
print(f"Number of queries (Q): {num_queries}")
print(f"Query dimension: {query_dim}")

In [None]:
# Prediction heads: one for class logits, one for box coordinates
print("Prediction Heads")
prediction_heads = (
    ("Class head", model.class_labels_classifier),
    ("Box head", model.bbox_predictor),
)
for head_title, head_module in prediction_heads:
    print(f"{head_title}: {head_module}")

## Part b) Forward Pass and Tensor Interpretation

Run inference and inspect output shapes.

In [None]:
# Load an image
url = "https://cdn.mos.cms.futurecdn.net/vhQreQN76LUVdycsEDUFTH-1024-80.jpg"
# requests has no default timeout; bound the call so a stalled connection
# raises instead of hanging the notebook indefinitely.
response = requests.get(url, timeout=30)
# Fail loudly on HTTP errors rather than trying to decode an error page as an image.
response.raise_for_status()
# Force 3-channel RGB so grayscale/RGBA/palette images are handled uniformly.
image = Image.open(BytesIO(response.content)).convert("RGB")

print(f"Image size (W, H): {image.size}")
plt.figure(figsize=(10, 6))
plt.imshow(image)
plt.axis("off")
plt.title("Input image")
plt.show()

In [None]:
# Preprocess: let the processor turn the PIL image into PyTorch tensors
encoded = processor(images=image, return_tensors="pt")
# Rebuild as a plain dict with every tensor moved to the model's device.
inputs = {name: tensor.to(device) for name, tensor in encoded.items()}

pixel_values = inputs["pixel_values"]
print(f"Input keys: {inputs.keys()}")
print(f"pixel_values shape: {pixel_values.shape}")  # (B, 3, H, W)

In [None]:
# Forward pass (no gradients needed for inference)
with torch.no_grad():
    outputs = model(**inputs)

logits_shape = outputs.logits.shape          # (B, Q, K+1)
boxes_shape = outputs.pred_boxes.shape       # (B, Q, 4)

print("Output shapes")
print(f"logits: {logits_shape}")
print(f"pred_boxes: {boxes_shape}")

In [None]:
# Why K+1 classes?
# Compare the configured label count with the width of the class logits.
num_classes = model.config.num_labels
total_logit_classes = outputs.logits.shape[-1]

print(f"Number of object classes: {num_classes}")
print(f"Total classes in logits: {total_logit_classes}")
print("Extra class is: no-object (background)")

## Part c) No-Object Analysis

Most queries predict "no object" - examine this sparsity.

In [None]:
# Per-query class logits for the single image in the batch.
logits = outputs.logits[0].cpu().numpy()  # (Q, K+1)
# Most likely class per query.
pred_classes = logits.argmax(axis=-1)

# In DETR, the last class index is "no-object"
no_obj_id = model.config.num_labels  # 91 (index of no-object class)
is_no_object = pred_classes == no_obj_id
num_noobj = np.sum(is_no_object)

total_queries = logits.shape[0]
print(f"Total queries: {total_queries}")
print(f"Queries predicting 'no-object': {num_noobj}")
print(f"Queries predicting actual objects: {total_queries - num_noobj}")

In [None]:
# Examine confidence distribution
# Softmax over the class axis turns each query's logits into probabilities
probs = outputs.logits[0].softmax(dim=-1).cpu().numpy()
max_probs = probs.max(axis=-1)  # top probability per query (may be the no-object class)

print("\nConfidence distribution:")
confidence_stats = (
    ("Max", max_probs.max()),
    ("Min", max_probs.min()),
    ("Mean", max_probs.mean()),
)
for stat_name, stat_value in confidence_stats:
    print(f"  {stat_name} confidence: {stat_value:.3f}")

# Queries with high confidence for actual objects:
# the top class must be a real object AND its probability must clear the threshold.
threshold = 0.7
is_object = pred_classes != no_obj_id
is_confident = max_probs > threshold
confident_objects = np.sum(is_object & is_confident)
print(f"\nQueries with object class and confidence > {threshold}: {confident_objects}")

## Part d) Post-processing and Visualization

Convert DETR's normalized (cx, cy, w, h) boxes to absolute (xmin, ymin, xmax, ymax) pixel coordinates and visualize them.

In [None]:
# Post-process with threshold
threshold = 0.7
# PIL reports (W, H); the processor expects one (H, W) row per image.
width, height = image.size
target_sizes = torch.tensor([[height, width]], device=device)  # (H, W)

# One result dict per image; take the entry for the only image in the batch.
per_image_results = processor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=threshold
)
results = per_image_results[0]

print(f"Detections (threshold={threshold})")
print(f"Number of detections: {len(results['boxes'])}")
for key in ("boxes", "scores", "labels"):
    print(f"{key.capitalize()} shape: {results[key].shape}")

In [None]:
# Print detections: one line per kept query with class name, score and box corners
print("Detected objects:")
id2label = model.config.id2label
# Convert the tensors to plain Python numbers once, up front.
detections = zip(
    results["scores"].tolist(),
    results["labels"].tolist(),
    results["boxes"].tolist(),
)
for score, label_id, corners in detections:
    cls = id2label[label_id]
    box_coords = [round(coord, 1) for coord in corners]
    print(f"  {cls:>12s}  score={score:.3f}  box(xyxy)={box_coords}")

In [None]:
# Visualize detections: draw each kept box on the image with a class/score tag
fig, ax = plt.subplots(figsize=(12, 8))
ax.imshow(image)

colors = plt.cm.tab10.colors

for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    # Convert the label tensor once; it is needed for both the color and the name.
    class_id = label.item()
    cls = model.config.id2label[class_id]
    # Indexing by class id gives every instance of a class the same color
    # (classes whose ids collide modulo the palette size share one).
    color = colors[class_id % len(colors)]

    # Boxes arrive as corner coordinates; Rectangle wants origin + width/height.
    xmin, ymin, xmax, ymax = box.tolist()
    rect = patches.Rectangle(
        (xmin, ymin), xmax - xmin, ymax - ymin,
        fill=False, linewidth=2, edgecolor=color
    )
    ax.add_patch(rect)

    # Place the tag slightly above the box's top-left corner.
    ax.text(
        xmin, ymin - 5,
        f"{cls} {score:.2f}",
        bbox=dict(facecolor=color, alpha=0.7),
        fontsize=10, color="white"
    )

ax.axis("off")
ax.set_title(f"DETR detections (threshold={threshold})")
plt.tight_layout()
plt.show()

## Effect of Threshold

Lowering the threshold reveals more (potentially false) detections.

In [None]:
# Compare different thresholds: one panel per threshold, same model outputs
thresholds = [0.9, 0.7, 0.5, 0.3]

# squeeze=False always returns a 2-D array of axes, so the loop below also
# works when `thresholds` holds a single value. Width scales with panel count.
fig, axes = plt.subplots(
    1, len(thresholds), figsize=(4 * len(thresholds), 4), squeeze=False
)

for ax, thresh in zip(axes[0], thresholds):
    # Use a dedicated name so the `results` computed in Part d is not
    # overwritten; re-running the earlier cells then still shows its detections.
    thresh_results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=thresh
    )[0]
    thresh_boxes = thresh_results["boxes"]

    ax.imshow(image)
    # Only the boxes are drawn here; scores and labels are not needed.
    for box in thresh_boxes:
        xmin, ymin, xmax, ymax = box.tolist()
        rect = patches.Rectangle(
            (xmin, ymin), xmax - xmin, ymax - ymin,
            fill=False, linewidth=2, edgecolor="red"
        )
        ax.add_patch(rect)

    ax.set_title(f"threshold={thresh}\n({len(thresh_boxes)} detections)")
    ax.axis("off")

plt.tight_layout()
plt.show()