# Loss Functions, Label Assignment, and Training

*Notebook 4 of 5 in the YOLOv11 from-scratch series*

## Introduction

With our model architecture complete (backbone, neck, and head from Notebooks 2-3), we now face three critical challenges that determine whether the detector actually learns to find objects:

1. **IoU computation** - How do we measure the geometric overlap between predicted and ground-truth boxes? The choice of IoU variant directly affects gradient quality and convergence speed.

2. **Label assignment strategy** - Given thousands of anchor points but only a handful of ground-truth boxes per image, which anchors should be responsible for predicting each object? This is the assignment problem.

3. **Loss function design** - How do we combine classification, localization, and distribution regression objectives into a single scalar loss that balances all three tasks?

YOLOv11 addresses these with:
- **CIoU** (Complete IoU) for box regression, which captures overlap, center distance, and aspect ratio in a single differentiable metric
- **Task-Aligned Learning (TAL)** for label assignment, which selects anchors based on both classification confidence and localization quality
- A **composite loss** combining BCE classification, CIoU box regression, and Distribution Focal Loss (DFL)

We will build each component from scratch and then run a short training loop on a handful of real COCO images to verify that the entire pipeline works end-to-end.

In [None]:
# --- Colab Environment Setup ---
import sys
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    %pip install -q matplotlib seaborn scikit-learn scipy tqdm datasets
    print("Colab dependencies installed")


## Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
from typing import List, Tuple, Optional

from datasets import load_dataset
from PIL import Image

## Model components from Notebooks 2-3

To keep this notebook self-contained, we re-define all model components (backbone, neck, head) in a single compact cell. These are identical to the implementations in Notebooks 2 and 3. Refer to those notebooks for detailed explanations of each block.

In [None]:
class ConvBNSiLU(nn.Module):
    def __init__(self, c_in, c_out, k=1, s=1, p=None, g=1):
        super().__init__()
        if p is None: p = k // 2
        self.conv = nn.Conv2d(c_in, c_out, k, s, p, groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.SiLU(inplace=True)
    def forward(self, x): return self.act(self.bn(self.conv(x)))

class Bottleneck(nn.Module):
    def __init__(self, c_in, c_out, shortcut=True, k=(3,3), e=0.5):
        super().__init__()
        c_hid = int(c_out * e)
        self.cv1 = ConvBNSiLU(c_in, c_hid, k[0])
        self.cv2 = ConvBNSiLU(c_hid, c_out, k[1])
        self.add = shortcut and c_in == c_out
    def forward(self, x): return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

class C3k2(nn.Module):
    def __init__(self, c_in, c_out, n=1, shortcut=True, e=0.5):
        super().__init__()
        self.c = int(c_out * e)
        self.cv1 = ConvBNSiLU(c_in, 2 * self.c, 1)
        self.cv2 = ConvBNSiLU((2 + n) * self.c, c_out, 1)
        self.bottlenecks = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, (3,3), 1.0) for _ in range(n))
    def forward(self, x):
        y = list(self.cv1(x).chunk(2, dim=1))
        for bn in self.bottlenecks: y.append(bn(y[-1]))
        return self.cv2(torch.cat(y, dim=1))

class SPPF(nn.Module):
    def __init__(self, c_in, c_out, k=5):
        super().__init__()
        c_hid = c_in // 2
        self.cv1 = ConvBNSiLU(c_in, c_hid, 1)
        self.cv2 = ConvBNSiLU(c_hid * 4, c_out, 1)
        self.pool = nn.MaxPool2d(k, stride=1, padding=k // 2)
    def forward(self, x):
        x = self.cv1(x); y1 = self.pool(x); y2 = self.pool(y1); y3 = self.pool(y2)
        return self.cv2(torch.cat([x, y1, y2, y3], dim=1))

class YOLOv11Backbone(nn.Module):
    def __init__(self, c_in=3, base=64):
        super().__init__()
        c1, c2, c3, c4, c5 = base, base*2, base*4, base*8, base*16
        self.stem = ConvBNSiLU(c_in, c1, 3, s=2)
        self.s1_down = ConvBNSiLU(c1, c2, 3, s=2); self.s1_c3k2 = C3k2(c2, c2, n=2)
        self.s2_down = ConvBNSiLU(c2, c3, 3, s=2); self.s2_c3k2 = C3k2(c3, c3, n=2)
        self.s3_down = ConvBNSiLU(c3, c4, 3, s=2); self.s3_c3k2 = C3k2(c4, c4, n=2)
        self.s4_down = ConvBNSiLU(c4, c5, 3, s=2); self.s4_c3k2 = C3k2(c5, c5, n=2)
        self.sppf = SPPF(c5, c5)
    def forward(self, x):
        x = self.stem(x)
        x = self.s1_c3k2(self.s1_down(x))
        p3 = self.s2_c3k2(self.s2_down(x))
        p4 = self.s3_c3k2(self.s3_down(p3))
        p5 = self.sppf(self.s4_c3k2(self.s4_down(p4)))
        return p3, p4, p5

class FPN(nn.Module):
    def __init__(self, c3=256, c4=512, c5=1024):
        super().__init__()
        self.lateral5 = ConvBNSiLU(c5, c4, 1)
        self.lateral4 = ConvBNSiLU(c4, c3, 1)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.fuse4 = C3k2(c4 * 2, c4, n=2)
        self.fuse3 = C3k2(c3 * 2, c3, n=2)
    def forward(self, p3, p4, p5):
        p5_up = self.up(self.lateral5(p5))
        p4 = self.fuse4(torch.cat([p4, p5_up], dim=1))
        p4_up = self.up(self.lateral4(p4))
        p3 = self.fuse3(torch.cat([p3, p4_up], dim=1))
        return p3, p4, p5

class PAN(nn.Module):
    def __init__(self, c3=256, c4=512, c5=1024):
        super().__init__()
        self.down3 = ConvBNSiLU(c3, c3, 3, s=2)
        self.fuse4 = C3k2(c3 + c4, c4, n=2)
        self.down4 = ConvBNSiLU(c4, c4, 3, s=2)
        self.fuse5 = C3k2(c4 + c5, c5, n=2)
    def forward(self, p3, p4, p5):
        p4 = self.fuse4(torch.cat([self.down3(p3), p4], dim=1))
        p5 = self.fuse5(torch.cat([self.down4(p4), p5], dim=1))
        return p3, p4, p5

class C2PSA(nn.Module):
    def __init__(self, c, n_heads=8):
        super().__init__()
        self.cv1 = ConvBNSiLU(c, c, 1)
        self.attn = nn.MultiheadAttention(c, n_heads, batch_first=True)
        self.ffn = nn.Sequential(ConvBNSiLU(c, c * 2, 1), ConvBNSiLU(c * 2, c, 1))
        self.cv2 = ConvBNSiLU(c, c, 1)
    def forward(self, x):
        y = self.cv1(x)
        B, C, H, W = y.shape
        flat = y.flatten(2).permute(0, 2, 1)
        flat = flat + self.attn(flat, flat, flat, need_weights=False)[0]
        y = flat.permute(0, 2, 1).view(B, C, H, W)
        y = y + self.ffn(y)
        return self.cv2(y)

class DFLHead(nn.Module):
    def __init__(self, c_in, num_classes=80, reg_max=16):
        super().__init__()
        self.reg_max = reg_max
        self.num_classes = num_classes
        self.cls_convs = nn.Sequential(ConvBNSiLU(c_in, c_in, 3), ConvBNSiLU(c_in, c_in, 3))
        self.reg_convs = nn.Sequential(ConvBNSiLU(c_in, c_in, 3), ConvBNSiLU(c_in, c_in, 3))
        self.cls_pred = nn.Conv2d(c_in, num_classes, 1)
        self.reg_pred = nn.Conv2d(c_in, 4 * reg_max, 1)
        self.proj = nn.Parameter(torch.arange(reg_max, dtype=torch.float32), requires_grad=False)
    def forward(self, x):
        cls_out = self.cls_pred(self.cls_convs(x))
        reg_raw = self.reg_pred(self.reg_convs(x))
        B, _, H, W = reg_raw.shape
        reg_dist = reg_raw.view(B, 4, self.reg_max, H, W)
        reg_box = F.softmax(reg_dist, dim=2)
        reg_box = (reg_box * self.proj.view(1, 1, -1, 1, 1)).sum(dim=2)
        return cls_out, reg_box, reg_raw

class DetectionHead(nn.Module):
    def __init__(self, channels=[256, 512, 1024], num_classes=80, reg_max=16):
        super().__init__()
        self.heads = nn.ModuleList([DFLHead(c, num_classes, reg_max) for c in channels])
    def forward(self, features):
        return [head(f) for head, f in zip(self.heads, features)]

class YOLOv11(nn.Module):
    def __init__(self, num_classes=80, reg_max=16):
        super().__init__()
        self.backbone = YOLOv11Backbone()
        self.fpn = FPN()
        self.pan = PAN()
        self.c2psa = C2PSA(1024)
        self.head = DetectionHead(num_classes=num_classes, reg_max=reg_max)
    def forward(self, x):
        p3, p4, p5 = self.backbone(x)
        p5 = self.c2psa(p5)
        p3, p4, p5 = self.fpn(p3, p4, p5)
        p3, p4, p5 = self.pan(p3, p4, p5)
        return self.head([p3, p4, p5])

print("Model components loaded successfully.")
print(f"YOLOv11 parameters: {sum(p.numel() for p in YOLOv11(num_classes=80).parameters()):,}")

## The evolution of IoU metrics

Intersection over Union (IoU) is the foundational metric for measuring bounding box quality. However, the basic IoU has significant limitations that led to a series of improvements:

### IoU (Intersection over Union)
$$\text{IoU} = \frac{|B_p \cap B_{gt}|}{|B_p \cup B_{gt}|}$$

Simple and intuitive, but has a critical flaw: when two boxes do not overlap, IoU is zero regardless of how far apart they are. This means **zero gradient** for non-overlapping predictions, making it useless as a standalone loss for poorly initialized detectors.

### GIoU (Generalized IoU)
$$\text{GIoU} = \text{IoU} - \frac{|C \setminus (B_p \cup B_{gt})|}{|C|}$$

where $C$ is the smallest enclosing box. GIoU adds a penalty for the gap between the predicted and ground-truth boxes. It provides gradients even when boxes do not overlap, but converges slowly because it only penalizes the empty area ratio, not the distance directly.
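The difference between IoU and GIoU for non-overlapping boxes is easy to check numerically. The plain-Python sketch below (the helper name `iou_giou` is illustrative) compares a ground-truth box with two disjoint predictions, one close and one far away. IoU cannot tell them apart, while GIoU penalizes the farther box more.

```python
def iou_giou(a, b):
    """IoU and GIoU for two boxes in [x1, y1, x2, y2] format (plain Python)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    iou = inter / union
    # Smallest enclosing box C
    c_area = (max(a[2], b[2]) - min(a[0], b[0])) * (max(a[3], b[3]) - min(a[1], b[1]))
    giou = iou - (c_area - union) / c_area
    return iou, giou

gt = [0.0, 0.0, 2.0, 2.0]
near = [3.0, 0.0, 5.0, 2.0]   # 1-unit gap to the GT
far = [6.0, 0.0, 8.0, 2.0]    # 4-unit gap to the GT

print(iou_giou(gt, near))     # IoU 0.0, GIoU -0.2
print(iou_giou(gt, far))      # IoU 0.0, GIoU -0.5
```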

### DIoU (Distance IoU)
$$\text{DIoU} = \text{IoU} - \frac{d^2(\mathbf{b}_p, \mathbf{b}_{gt})}{c^2}$$

where $d$ is the Euclidean distance between box centers and $c$ is the diagonal of the enclosing box. By directly penalizing center-point distance, DIoU converges much faster than GIoU.

### CIoU (Complete IoU)
$$\text{CIoU} = \text{IoU} - \frac{d^2(\mathbf{b}_p, \mathbf{b}_{gt})}{c^2} - \alpha v$$

where $v = \frac{4}{\pi^2}\left(\arctan\frac{w_{gt}}{h_{gt}} - \arctan\frac{w_p}{h_p}\right)^2$ measures aspect ratio consistency, and $\alpha = \frac{v}{(1 - \text{IoU}) + v}$ is an adaptive weight that is treated as a constant during backpropagation (the implementation below computes it under `torch.no_grad()`). CIoU therefore accounts for overlap, center distance, and aspect ratio simultaneously, and it is what YOLOv11 uses for box regression.
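To see the aspect-ratio term at work, here is a small worked example in plain Python (`diou_ciou` is an illustrative helper, not a library function). Both boxes have the same center and the same area (16), but one is 4×4 and the other 8×2. The center-distance penalty is zero, so DIoU equals IoU ($1/3$), and only the $\alpha v$ term separates CIoU from DIoU.

```python
import math

def diou_ciou(p, g, eps=1e-7):
    """IoU, DIoU and CIoU for a predicted box p and GT box g, both [x1, y1, x2, y2]."""
    iw = max(0.0, min(p[2], g[2]) - max(p[0], g[0]))
    ih = max(0.0, min(p[3], g[3]) - max(p[1], g[1]))
    inter = iw * ih
    wp, hp = p[2] - p[0], p[3] - p[1]
    wg, hg = g[2] - g[0], g[3] - g[1]
    union = wp * hp + wg * hg - inter
    iou = inter / union
    # d^2: squared center distance; c^2: squared diagonal of the enclosing box
    d2 = ((p[0] + p[2]) / 2 - (g[0] + g[2]) / 2) ** 2 + ((p[1] + p[3]) / 2 - (g[1] + g[3]) / 2) ** 2
    c2 = (max(p[2], g[2]) - min(p[0], g[0])) ** 2 + (max(p[3], g[3]) - min(p[1], g[1])) ** 2
    diou = iou - d2 / c2
    # Aspect-ratio consistency v and its adaptive weight alpha
    v = (4 / math.pi ** 2) * (math.atan(wg / hg) - math.atan(wp / hp)) ** 2
    alpha = v / ((1 - iou) + v + eps)
    return iou, diou, diou - alpha * v

# Same center (4, 4), same area: a 4x4 prediction vs an 8x2 ground truth
iou, diou, ciou = diou_ciou([2, 2, 6, 6], [0, 3, 8, 5])
print(f"IoU={iou:.4f}  DIoU={diou:.4f}  CIoU={ciou:.4f}")
```

CIoU comes out slightly below DIoU (about 0.315 vs. 0.333), which is exactly the shape-mismatch penalty that DIoU cannot express.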

## IoU implementations

In [None]:
def compute_iou(box1, box2, mode='ciou', eps=1e-7):
    """Compute IoU variants between two sets of boxes.
    
    Args:
        box1: (N, 4) in [x1, y1, x2, y2] format
        box2: (M, 4) in [x1, y1, x2, y2] format
        mode: 'iou', 'giou', 'diou', or 'ciou'
    Returns:
        iou: (N, M) pairwise IoU values
    """
    # Intersection
    inter_x1 = torch.max(box1[:, None, 0], box2[None, :, 0])
    inter_y1 = torch.max(box1[:, None, 1], box2[None, :, 1])
    inter_x2 = torch.min(box1[:, None, 2], box2[None, :, 2])
    inter_y2 = torch.min(box1[:, None, 3], box2[None, :, 3])
    inter = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0)
    
    # Union
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    union = area1[:, None] + area2[None, :] - inter
    
    iou = inter / (union + eps)
    
    if mode == 'iou':
        return iou
    
    # Enclosing box
    enc_x1 = torch.min(box1[:, None, 0], box2[None, :, 0])
    enc_y1 = torch.min(box1[:, None, 1], box2[None, :, 1])
    enc_x2 = torch.max(box1[:, None, 2], box2[None, :, 2])
    enc_y2 = torch.max(box1[:, None, 3], box2[None, :, 3])
    enc_area = (enc_x2 - enc_x1) * (enc_y2 - enc_y1)
    
    if mode == 'giou':
        return iou - (enc_area - union) / (enc_area + eps)
    
    # Center distance
    cx1 = (box1[:, 0] + box1[:, 2]) / 2
    cy1 = (box1[:, 1] + box1[:, 3]) / 2
    cx2 = (box2[:, 0] + box2[:, 2]) / 2
    cy2 = (box2[:, 1] + box2[:, 3]) / 2
    
    center_dist = (cx1[:, None] - cx2[None, :]) ** 2 + (cy1[:, None] - cy2[None, :]) ** 2
    diag_dist = (enc_x2 - enc_x1) ** 2 + (enc_y2 - enc_y1) ** 2
    
    if mode == 'diou':
        return iou - center_dist / (diag_dist + eps)
    
    # CIoU: aspect ratio penalty
    w1 = box1[:, 2] - box1[:, 0]
    h1 = box1[:, 3] - box1[:, 1]
    w2 = box2[:, 2] - box2[:, 0]
    h2 = box2[:, 3] - box2[:, 1]
    
    v = (4 / math.pi ** 2) * (
        torch.atan(w2[None, :] / (h2[None, :] + eps)) - 
        torch.atan(w1[:, None] / (h1[:, None] + eps))
    ) ** 2
    
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    
    return iou - center_dist / (diag_dist + eps) - alpha * v

## Visualizing IoU variants

To build intuition for how these metrics differ, we slide a prediction box horizontally away from a fixed ground-truth box and plot each IoU variant. Notice how:

- **IoU** drops to zero once boxes separate and stays there (no gradient signal)
- **GIoU** continues to decrease below zero but slowly
- **DIoU** decreases more steeply due to the direct distance penalty
- **CIoU** behaves like DIoU here (same aspect ratio), but would differ for shape changes
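The "no gradient signal" point can also be confirmed directly with autograd. The sketch below re-implements IoU and DIoU for single boxes (the helper `iou_and_diou` is illustrative and self-contained, separate from `compute_iou`), places the prediction at offset 5 so it is fully disjoint from the reference box used in the plot, and backpropagates `1 - IoU` and `1 - DIoU`.

```python
import torch

def iou_and_diou(p, g):
    """IoU and DIoU for two single boxes given as 1-D tensors [x1, y1, x2, y2]."""
    iw = (torch.min(p[2], g[2]) - torch.max(p[0], g[0])).clamp(min=0)
    ih = (torch.min(p[3], g[3]) - torch.max(p[1], g[1])).clamp(min=0)
    inter = iw * ih
    union = (p[2] - p[0]) * (p[3] - p[1]) + (g[2] - g[0]) * (g[3] - g[1]) - inter
    iou = inter / union
    # Squared center distance over squared enclosing-box diagonal
    d2 = ((p[0] + p[2] - g[0] - g[2]) / 2) ** 2 + ((p[1] + p[3] - g[1] - g[3]) / 2) ** 2
    c2 = (torch.max(p[2], g[2]) - torch.min(p[0], g[0])) ** 2 \
        + (torch.max(p[3], g[3]) - torch.min(p[1], g[1])) ** 2
    return iou, iou - d2 / c2

gt = torch.tensor([2.0, 2.0, 5.0, 5.0])
grads = {}
for name, idx in [("IoU loss", 0), ("DIoU loss", 1)]:
    # Reference box shifted right by 5: no overlap with the GT
    pred = torch.tensor([7.0, 2.0, 10.0, 5.0], requires_grad=True)
    loss = 1 - iou_and_diou(pred, gt)[idx]
    loss.backward()
    grads[name] = pred.grad.clone()
    print(name, grads[name])
```

The IoU loss yields an all-zero gradient because the clamped intersection is flat, while the DIoU loss still produces a non-zero gradient through its center-distance term.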

In [None]:
def visualize_iou_variants():
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    
    # Fixed reference box
    ref = torch.tensor([[2.0, 2.0, 5.0, 5.0]])
    
    # Move a box horizontally
    offsets = torch.linspace(0, 6, 50)
    modes = ['iou', 'giou', 'diou', 'ciou']
    
    for ax, mode in zip(axes, modes):
        values = []
        for dx in offsets:
            pred = torch.tensor([[2.0 + dx.item(), 2.0, 5.0 + dx.item(), 5.0]])
            val = compute_iou(ref, pred, mode=mode)
            values.append(val.item())
        ax.plot(offsets.numpy(), values, linewidth=2)
        ax.set_title(mode.upper(), fontsize=14, fontweight='bold')
        ax.set_xlabel('Horizontal offset')
        ax.set_ylabel(f'{mode} value')
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
        ax.grid(True, alpha=0.3)
    
    plt.suptitle('IoU Variants: Response to Horizontal Box Translation', fontsize=14)
    plt.tight_layout()
    plt.show()

visualize_iou_variants()

## Task-Aligned Learning (TAL)

Label assignment is the bridge between ground-truth annotations and the thousands of predictions a detector makes. For each ground-truth box, we need to decide which anchor points are "responsible" for predicting it.

### The problem with simpler strategies

- **IoU-based assignment** (used in earlier YOLO versions): assigns anchors based purely on spatial overlap with GT. This ignores whether the model is actually confident about the prediction, leading to misalignment between assignment and model capacity.
- **Center-based assignment** (e.g., FCOS): assigns all anchors whose centers fall inside the GT box. Simple but does not consider prediction quality.
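The "anchor center inside the GT box" test is simple to vectorize. It is the same signed-distance check the `TaskAlignedAssigner` below applies as its first filter. In this minimal sketch (`anchors_inside_boxes` is an illustrative name), an anchor counts as inside when its distances to all four edges are positive:

```python
import torch

def anchors_inside_boxes(anchor_points, gt_boxes):
    """(A, G) bool mask: True where the anchor center lies strictly inside the GT box.

    anchor_points: (A, 2) [x, y]; gt_boxes: (G, 4) [x1, y1, x2, y2].
    """
    lt = anchor_points[:, None, :] - gt_boxes[None, :, :2]   # distance to left/top edges
    rb = gt_boxes[None, :, 2:] - anchor_points[:, None, :]   # distance to right/bottom edges
    return torch.cat([lt, rb], dim=-1).amin(dim=-1) > 0

# Stride-8 anchor centers on a 4x4 grid: (4, 4), (12, 4), ..., (28, 28)
ys, xs = torch.meshgrid(torch.arange(4), torch.arange(4), indexing="ij")
anchors = (torch.stack([xs.flatten(), ys.flatten()], dim=-1).float() + 0.5) * 8
gt = torch.tensor([[0.0, 0.0, 16.0, 16.0]])   # covers the top-left 2x2 cells
mask = anchors_inside_boxes(anchors, gt)
print(mask.squeeze(1).view(4, 4).int())      # 1s only in the top-left 2x2 block
```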

### Task-Aligned Learning

TAL resolves this by computing an **alignment metric** that combines both classification and localization quality:

$$t = s^\alpha \cdot u^\beta$$

where:
- $s$ is the predicted classification score for the GT class
- $u$ is the IoU between the predicted box and the GT box
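- $\alpha$ and $\beta$ control the relative importance of the two terms. Because both $s$ and $u$ lie in $[0, 1]$, the larger exponent ($\beta = 6.0$ vs. $\alpha = 1.0$ in this notebook) makes $t$ far more sensitive to localization quality. These values, and $k = 13$ below, follow the original TOOD paper; the Ultralytics YOLOv8/YOLOv11 loss uses $\alpha = 0.5$, $\beta = 6.0$, and $k = 10$.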
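A quick calculation shows how strongly $\beta = 6$ favors localization. Below, two hypothetical anchors compete for the same GT. The first has a high class score but a loose box; the second has a modest score but a tight box.

```python
def alignment_metric(score, iou, alpha=1.0, beta=6.0):
    """TAL alignment metric t = s**alpha * u**beta."""
    return score ** alpha * iou ** beta

# Hypothetical anchors competing for the same GT
confident_but_loose = alignment_metric(score=0.9, iou=0.5)   # 0.9 * 0.5**6
modest_but_tight = alignment_metric(score=0.3, iou=0.8)      # 0.3 * 0.8**6
print(f"{confident_but_loose:.4f} vs {modest_but_tight:.4f}")  # 0.0141 vs 0.0786
```

The tighter box wins by more than 5×, even though its classification score is one third as large.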

The assignment procedure:
1. Filter to anchors whose centers lie inside each GT box
2. Compute the alignment metric $t$ for all valid anchor-GT pairs
3. Select the **top-k** anchors (default $k=13$) per GT based on $t$
4. Resolve conflicts (anchor assigned to multiple GTs) by keeping the highest-alignment GT
5. Generate **soft label targets** by normalizing the alignment scores
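Steps 3 to 5 can be sketched in vectorized form on a toy alignment matrix. `select_topk_and_resolve` is a hypothetical helper, not a library function, and it assumes steps 1 and 2 have already zeroed the entries of anchors outside each GT. For step 5 it uses the per-GT normalization described in the TOOD paper, $\hat{t} = t / \max t \cdot \max u$, where both maxima are taken over each GT's final positive anchors. The best-aligned anchor of a GT therefore gets a target equal to that GT's best IoU rather than 1.

```python
import torch

def select_topk_and_resolve(alignment, iou, topk=2, eps=1e-9):
    """Sketch of TAL steps 3-5 on (num_anchors, num_gt) alignment and IoU matrices.

    Returns assigned_gt (-1 = background) and a per-anchor soft target.
    """
    num_anchors, _ = alignment.shape
    # Step 3: keep the top-k anchors per GT (column-wise), never zero-alignment ones
    idx = alignment.topk(min(topk, num_anchors), dim=0).indices      # (k, num_gt)
    topk_mask = torch.zeros_like(alignment).scatter_(0, idx, 1.0).bool()
    topk_mask = topk_mask & (alignment > 0)
    metric = alignment * topk_mask
    # Step 4: an anchor claimed by several GTs keeps only the best-aligned one
    best_val, assigned_gt = metric.max(dim=1)
    fg = best_val > eps
    assigned_gt[~fg] = -1
    # Step 5: per-GT normalization t / max_t(gt) * max_iou(gt)
    keep = torch.zeros_like(topk_mask)
    keep[fg, assigned_gt[fg]] = True
    max_t = (metric * keep).amax(dim=0)                               # (num_gt,)
    max_iou = (iou * keep).amax(dim=0)                                # (num_gt,)
    soft = torch.zeros(num_anchors)
    g = assigned_gt[fg]
    soft[fg] = best_val[fg] / (max_t[g] + eps) * max_iou[g]
    return assigned_gt, soft

# 4 anchors x 2 GTs; anchor 1 is a top-k candidate for both GTs
alignment = torch.tensor([[0.5, 0.0], [0.2, 0.4], [0.1, 0.3], [0.0, 0.1]])
iou = torch.tensor([[0.9, 0.0], [0.7, 0.8], [0.5, 0.75], [0.0, 0.6]])
assigned_gt, soft = select_topk_and_resolve(alignment, iou)
print(assigned_gt.tolist(), soft.tolist())   # [0, 1, 1, -1] and roughly [0.9, 0.8, 0.6, 0.0]
```

Anchor 1 is in the top-2 for both GTs, and step 4 hands it to GT 1, where its alignment (0.4) is higher. Anchor 3 is not in the top-2 for any GT, so it becomes background.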

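This approach works well because the assignment adapts to the model: anchors that already produce both a confident class score and a well-aligned box are selected as positives. Training on them pushes classification and localization quality to agree, so that predictions that score highly at inference time also tend to be the best-localized ones.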

In [None]:
class TaskAlignedAssigner:
    """Task-Aligned Label Assignment for anchor-free detection.
    
    Assigns ground truth to predictions using alignment metric
    that considers both classification confidence and box IoU.
    """
    
    def __init__(self, topk: int = 13, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
        self.topk = topk
        self.alpha = alpha
        self.beta = beta
        self.eps = eps
    
    @torch.no_grad()
    def assign(self, pred_scores, pred_bboxes, gt_labels, gt_bboxes, anchor_points, stride):
        """
        Args:
            pred_scores: (num_anchors, num_classes) predicted class scores (sigmoid)
            pred_bboxes: (num_anchors, 4) predicted boxes [x1, y1, x2, y2]
            gt_labels: (num_gt,) ground truth class indices
            gt_bboxes: (num_gt, 4) ground truth boxes [x1, y1, x2, y2]
            anchor_points: (num_anchors, 2) anchor center positions
            stride: feature stride for this level
        Returns:
            assigned_labels: (num_anchors,) -1 for background
            assigned_bboxes: (num_anchors, 4) 
            assigned_scores: (num_anchors, num_classes) soft labels
        """
        device = pred_scores.device
        num_anchors = pred_scores.shape[0]
        num_gt = gt_bboxes.shape[0]
        num_classes = pred_scores.shape[1]
        
        if num_gt == 0:
            return (
                torch.full((num_anchors,), -1, dtype=torch.long, device=device),
                torch.zeros((num_anchors, 4), device=device),
                torch.zeros((num_anchors, num_classes), device=device)
            )
        
        # Check if anchor centers fall inside GT boxes
        # anchor_points: (num_anchors, 2) [cx, cy]
        lt = anchor_points[:, None, :] - gt_bboxes[None, :, :2]  # (na, ng, 2)
        rb = gt_bboxes[None, :, 2:] - anchor_points[:, None, :]  # (na, ng, 2)
        in_gt = torch.cat([lt, rb], dim=-1).min(dim=-1).values > 0  # (na, ng)
        
        # Compute alignment metric
        # Get predicted class scores for GT classes
        gt_cls_scores = pred_scores[:, gt_labels]  # (na, ng)
        
        # Compute pairwise IoU
        pair_iou = compute_iou(pred_bboxes, gt_bboxes, mode='iou')  # (na, ng)
        pair_iou = pair_iou.clamp(0, 1)
        
        # Alignment metric: score^alpha * iou^beta
        alignment = gt_cls_scores.pow(self.alpha) * pair_iou.pow(self.beta)
        alignment[~in_gt] = 0  # mask out anchors not inside GT
        
        # Select top-k anchors per GT
        topk_mask = torch.zeros_like(alignment, dtype=torch.bool)
        for j in range(num_gt):
            vals = alignment[:, j]
            k = min(self.topk, (vals > 0).sum().item())
            if k > 0:
                _, topk_idx = vals.topk(k)
                topk_mask[topk_idx, j] = True
        
        alignment[~topk_mask] = 0
        
        # Resolve conflicts: each anchor -> highest alignment GT
        assigned_gt = alignment.argmax(dim=1)  # (na,)
        max_alignment = alignment.max(dim=1).values  # (na,)
        
        # Background mask
        bg_mask = max_alignment < self.eps
        assigned_gt[bg_mask] = -1
        
        # Build outputs
        assigned_labels = torch.where(
            bg_mask,
            torch.tensor(-1, device=device),
            gt_labels[assigned_gt.clamp(min=0)]
        )
        assigned_labels[bg_mask] = -1
        assigned_bboxes = torch.zeros((num_anchors, 4), device=device)
        assigned_bboxes[~bg_mask] = gt_bboxes[assigned_gt[~bg_mask]]
        
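        # Soft label targets: per-GT normalized alignment (as in TOOD), i.e.
        # t / max_t(gt) * max_IoU(gt) over the anchors finally assigned to that GT
        assigned_scores = torch.zeros((num_anchors, num_classes), device=device)
        fg_mask = ~bg_mask
        if fg_mask.any():
            fg_gt = assigned_gt[fg_mask]
            assign_mask = torch.zeros_like(topk_mask)
            assign_mask[fg_mask, fg_gt] = True
            pos_align_max = (alignment * assign_mask).amax(dim=0)  # (ng,)
            pos_iou_max = (pair_iou * assign_mask).amax(dim=0)     # (ng,)
            norm_align = max_alignment[fg_mask] / (pos_align_max[fg_gt] + self.eps) * pos_iou_max[fg_gt]
            assigned_scores[fg_mask, assigned_labels[fg_mask]] = norm_align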
        
        return assigned_labels, assigned_bboxes, assigned_scores

print("TaskAlignedAssigner ready.")

## Composite loss function

YOLOv11's loss function combines three complementary objectives:

1. **Classification loss** (BCE with soft labels): Binary cross-entropy between predicted class logits and the soft label targets produced by TAL. Soft labels (values between 0 and 1 based on alignment quality) provide richer supervisory signal than hard 0/1 labels.

2. **Box regression loss** (CIoU): $\mathcal{L}_{box} = 1 - \text{CIoU}(\hat{b}, b^*)$, applied only to foreground (assigned) anchors. CIoU captures overlap, center distance, and aspect ratio in a single loss term.

3. **Distribution Focal Loss** (DFL): Instead of directly regressing LTRB offsets, the DFL head predicts a discrete probability distribution over integer bins $\{0, 1, \ldots, \text{reg\_max}-1\}$. The DFL loss is a weighted cross-entropy between adjacent bins:

$$\mathcal{L}_{DFL}(S_i, S_{i+1}, y) = -(1 - (y - i)) \log(S_i) - (y - i) \log(S_{i+1})$$

where $y$ is the continuous target offset, $i = \lfloor y \rfloor$, and $S_i$ is the softmax probability for bin $i$.
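A worked example makes the two-bin weighting concrete. For a target offset $y = 3.3$, the loss puts weight $0.7$ on bin 3 and $0.3$ on bin 4. The sketch below computes DFL the same way `YOLOv11Loss` does (two weighted `F.cross_entropy` calls) and checks the result against the formula written out directly. Because cross-entropy against a fixed target is minimized when the predicted probabilities match that target, the optimum is a softmax with $0.7$ on bin 3 and $0.3$ on bin 4. Its expectation, which is how `DFLHead` decodes offsets, is exactly $y$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
reg_max = 16
y = torch.tensor([3.3])                  # continuous target offset, in bins
left = y.floor().long()                  # i = 3
right = left + 1                         # i + 1 = 4
w_right = y - left.float()               # y - i = 0.3
w_left = 1.0 - w_right                   # 1 - (y - i) = 0.7

logits = torch.randn(1, reg_max)         # raw bin logits for one LTRB side
dfl = (F.cross_entropy(logits, left, reduction="none") * w_left
       + F.cross_entropy(logits, right, reduction="none") * w_right)

# Same value written directly from the formula, with S = softmax(logits)
log_s = F.log_softmax(logits, dim=-1)
manual = -(w_left * log_s[0, left] + w_right * log_s[0, right])
print(dfl.item(), manual.item())

# The ideal distribution (0.7 on bin 3, 0.3 on bin 4) decodes back to y
ideal = torch.zeros(reg_max)
ideal[3], ideal[4] = 0.7, 0.3
decoded = (ideal * torch.arange(reg_max)).sum()
print(decoded.item())                    # approximately 3.3
```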

The total loss is a weighted sum:

$$\mathcal{L} = \lambda_{cls} \cdot \mathcal{L}_{cls} + \lambda_{box} \cdot \mathcal{L}_{box} + \lambda_{dfl} \cdot \mathcal{L}_{dfl}$$

with default weights $\lambda_{cls} = 0.5$, $\lambda_{box} = 7.5$, $\lambda_{dfl} = 1.5$. The high box weight reflects the importance of precise localization in object detection.
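To see how these weights shape the total, the sketch below combines hypothetical unweighted component values (chosen only for illustration, not measured from training). Even a modest box loss of 0.4 ends up contributing most of the total once multiplied by 7.5.

```python
def composite_loss(l_cls, l_box, l_dfl, w_cls=0.5, w_box=7.5, w_dfl=1.5):
    """Weighted sum L = w_cls*L_cls + w_box*L_box + w_dfl*L_dfl, plus each term's share."""
    terms = {"cls": w_cls * l_cls, "box": w_box * l_box, "dfl": w_dfl * l_dfl}
    total = sum(terms.values())
    shares = {name: value / total for name, value in terms.items()}
    return total, shares

# Hypothetical unweighted component values
total, shares = composite_loss(l_cls=1.0, l_box=0.4, l_dfl=1.0)
print(total, {k: round(v, 2) for k, v in shares.items()})   # 5.0, box share 0.6
```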

In [None]:
class YOLOv11Loss(nn.Module):
    """Composite loss for YOLOv11 training.
    
    Components:
        1. Classification: BCE with soft labels from TAL
        2. Box regression: CIoU loss
        3. Distribution Focal Loss: cross-entropy on DFL bins
    """
    
    def __init__(self, num_classes: int = 80, reg_max: int = 16, strides: List[int] = [8, 16, 32],
                 cls_weight: float = 0.5, box_weight: float = 7.5, dfl_weight: float = 1.5):
        super().__init__()
        self.num_classes = num_classes
        self.reg_max = reg_max
        self.strides = strides
        self.cls_weight = cls_weight
        self.box_weight = box_weight
        self.dfl_weight = dfl_weight
        self.assigner = TaskAlignedAssigner()
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
    
    def _make_anchor_points(self, feat_sizes, device):
        """Generate anchor points for all feature levels."""
        all_points = []
        all_strides = []
        for (h, w), stride in zip(feat_sizes, self.strides):
            sy, sx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
            points = torch.stack([sx.flatten(), sy.flatten()], dim=-1).float()
            points = (points + 0.5) * stride  # center of each cell in image coords
            all_points.append(points)
            all_strides.append(torch.full((h * w,), stride, dtype=torch.float32))
        return torch.cat(all_points).to(device), torch.cat(all_strides).to(device)
    
    def _decode_boxes(self, box_pred, anchor_points, strides):
        """Decode LTRB offsets to x1y1x2y2 boxes."""
        lt = box_pred[:, :2] * strides.unsqueeze(-1)
        rb = box_pred[:, 2:] * strides.unsqueeze(-1)
        x1y1 = anchor_points - lt
        x2y2 = anchor_points + rb
        return torch.cat([x1y1, x2y2], dim=-1)
    
    def forward(self, predictions, gt_boxes_list, gt_labels_list):
        """
        Args:
            predictions: list of (cls_pred, box_pred, box_raw) per scale
            gt_boxes_list: list of (num_gt, 4) per image, normalized [cx, cy, w, h]
            gt_labels_list: list of (num_gt,) per image
        """
        device = predictions[0][0].device
        batch_size = predictions[0][0].shape[0]
        
        feat_sizes = [(p[0].shape[2], p[0].shape[3]) for p in predictions]
        anchor_points, anchor_strides = self._make_anchor_points(feat_sizes, device)
        
        # Concatenate predictions across scales
        all_cls = torch.cat([p[0].flatten(2).permute(0, 2, 1) for p in predictions], dim=1)
        all_box = torch.cat([p[1].flatten(2).permute(0, 2, 1) for p in predictions], dim=1)
        all_raw = torch.cat([p[2].flatten(2).permute(0, 2, 1) for p in predictions], dim=1)
        
        total_cls_loss = torch.tensor(0.0, device=device)
        total_box_loss = torch.tensor(0.0, device=device)
        total_dfl_loss = torch.tensor(0.0, device=device)
        num_pos = 0
        
        for b in range(batch_size):
            cls_pred = all_cls[b].sigmoid()  # (num_anchors, num_classes)
            box_pred = all_box[b]            # (num_anchors, 4) LTRB
            raw_pred = all_raw[b]            # (num_anchors, 4*reg_max)
            
            # Decode predicted boxes
            pred_bboxes = self._decode_boxes(box_pred, anchor_points, anchor_strides)
            
            gt_boxes = gt_boxes_list[b]
            gt_labels = gt_labels_list[b]
            
            if len(gt_boxes) == 0:
                total_cls_loss += self.bce(all_cls[b], torch.zeros_like(all_cls[b])).sum()
                continue
            
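            # Convert GT from normalized [cx, cy, w, h] to pixel [x1, y1, x2, y2].
            # Input size is recovered from the stride-8 feature map instead of hard-coding 640.
            img_h = feat_sizes[0][0] * self.strides[0]
            img_w = feat_sizes[0][1] * self.strides[0]
            gt_xyxy = torch.zeros_like(gt_boxes)
            gt_xyxy[:, 0] = (gt_boxes[:, 0] - gt_boxes[:, 2] / 2) * img_w
            gt_xyxy[:, 1] = (gt_boxes[:, 1] - gt_boxes[:, 3] / 2) * img_h
            gt_xyxy[:, 2] = (gt_boxes[:, 0] + gt_boxes[:, 2] / 2) * img_w
            gt_xyxy[:, 3] = (gt_boxes[:, 1] + gt_boxes[:, 3] / 2) * img_h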
            
            # Task-aligned assignment
            assigned_labels, assigned_bboxes, assigned_scores = self.assigner.assign(
                cls_pred, pred_bboxes, gt_labels.long(), gt_xyxy,
                anchor_points, anchor_strides
            )
            
            fg_mask = assigned_labels >= 0
            num_fg = fg_mask.sum().item()
            num_pos += num_fg
            
            # Classification loss (BCE with soft labels)
            cls_targets = assigned_scores.to(device)
            total_cls_loss += self.bce(all_cls[b], cls_targets).sum()
            
            if num_fg > 0:
                # Box loss (CIoU)
                fg_pred_boxes = pred_bboxes[fg_mask]
                fg_gt_boxes = assigned_bboxes[fg_mask].to(device)
                ciou = compute_iou(fg_pred_boxes, fg_gt_boxes, mode='ciou')
                ciou_diag = torch.diag(ciou)
                total_box_loss += (1.0 - ciou_diag).sum()
                
                # DFL loss
                fg_raw = raw_pred[fg_mask]  # (num_fg, 4*reg_max)
                fg_raw = fg_raw.view(-1, self.reg_max)  # (num_fg*4, reg_max)
                # Target: continuous LTRB offsets
                fg_target_ltrb = torch.zeros((num_fg, 4), device=device)
                fg_target_ltrb[:, :2] = (anchor_points[fg_mask] - fg_gt_boxes[:, :2]) / anchor_strides[fg_mask].unsqueeze(-1)
                fg_target_ltrb[:, 2:] = (fg_gt_boxes[:, 2:] - anchor_points[fg_mask]) / anchor_strides[fg_mask].unsqueeze(-1)
                fg_target_ltrb = fg_target_ltrb.clamp(0, self.reg_max - 1 - 0.01)
                target_flat = fg_target_ltrb.view(-1)
                # DFL: cross-entropy between adjacent integer bins
                target_left = target_flat.long()
                target_right = target_left + 1
                weight_right = target_flat - target_left.float()
                weight_left = 1.0 - weight_right
                dfl_loss = (
                    F.cross_entropy(fg_raw, target_left, reduction='none') * weight_left +
                    F.cross_entropy(fg_raw, target_right.clamp(max=self.reg_max - 1), reduction='none') * weight_right
                )
                total_dfl_loss += dfl_loss.sum()
        
        num_pos = max(num_pos, 1)
        loss_cls = self.cls_weight * total_cls_loss / num_pos
        loss_box = self.box_weight * total_box_loss / num_pos
        loss_dfl = self.dfl_weight * total_dfl_loss / num_pos
        total_loss = loss_cls + loss_box + loss_dfl
        
        return total_loss, {
            'cls_loss': loss_cls.item(),
            'box_loss': loss_box.item(),
            'dfl_loss': loss_dfl.item(),
            'total_loss': total_loss.item(),
            'num_pos': num_pos
        }

print("YOLOv11Loss ready.")
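The anchor layout and the LTRB encoding used inside `YOLOv11Loss` can be checked in isolation. The standalone sketch below mirrors `_make_anchor_points` and `_decode_boxes` under the assumption of a square input, and adds the inverse encoding that the DFL target computation performs. The helper names are illustrative. It confirms the familiar anchor count for a 640×640 input and an exact encode/decode round trip.

```python
import torch

def make_anchor_points(img_size=640, strides=(8, 16, 32)):
    """Cell-center anchor points (x, y) and their strides for every feature level."""
    points, stride_list = [], []
    for s in strides:
        n = img_size // s                                   # grid cells per side
        ys, xs = torch.meshgrid(torch.arange(n), torch.arange(n), indexing="ij")
        points.append((torch.stack([xs.flatten(), ys.flatten()], dim=-1).float() + 0.5) * s)
        stride_list.append(torch.full((n * n,), float(s)))
    return torch.cat(points), torch.cat(stride_list)

def encode_ltrb(boxes, anchors, strides):
    """x1y1x2y2 boxes -> LTRB distances in stride units (the DFL target space)."""
    lt = (anchors - boxes[:, :2]) / strides[:, None]
    rb = (boxes[:, 2:] - anchors) / strides[:, None]
    return torch.cat([lt, rb], dim=-1)

def decode_ltrb(ltrb, anchors, strides):
    """LTRB distances in stride units -> x1y1x2y2 boxes in pixels."""
    lt = ltrb[:, :2] * strides[:, None]
    rb = ltrb[:, 2:] * strides[:, None]
    return torch.cat([anchors - lt, anchors + rb], dim=-1)

anchors, strides = make_anchor_points()
print(anchors.shape)                    # 80*80 + 40*40 + 20*20 = 8400 anchors

# Round trip for the first stride-8 anchor, at (4, 4), and a box containing it
a, s = anchors[:1], strides[:1]
box = torch.tensor([[0.0, 1.0, 30.0, 20.0]])
ltrb = encode_ltrb(box, a, s)           # [0.5, 0.375, 3.25, 2.0] in bins
print(decode_ltrb(ltrb, a, s))          # recovers the original box
```

Because DFL bins run from 0 to `reg_max - 1`, each side can reach at most 15 strides from its anchor (120 px at stride 8, 480 px at stride 32). This is why the loss clamps the DFL targets to that range.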

## Real COCO data for training

Rather than using synthetic shapes, we stream real COCO images from [detection-datasets/coco](https://huggingface.co/datasets/detection-datasets/coco) on the Hugging Face Hub and buffer 32 of them in memory. This keeps the demo fast while still training on real-world data.

> **Data source**: Images streamed from [detection-datasets/coco](https://huggingface.co/datasets/detection-datasets/coco). See our [HF COCO streaming tutorial](/blog/tutorials/hf-coco-streaming) for details.
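The loader below converts each annotation to YOLO format, and the loss later converts it back to pixel corners. A minimal sketch of both conversions follows. Like the loader, it assumes the streamed `bbox` values are COCO-style `[x_min, y_min, width, height]` in pixels; the helper names are illustrative. Because images are resized to 640×640 without letterboxing, normalized coordinates stay valid, and the vertical axis of a 640×480 image is simply stretched by 640/480.

```python
def coco_xywh_to_yolo(bbox, img_w, img_h):
    """[x_min, y_min, w, h] in pixels -> [cx, cy, w, h] normalized to [0, 1]."""
    x, y, w, h = bbox
    return [(x + w / 2) / img_w, (y + h / 2) / img_h, w / img_w, h / img_h]

def yolo_to_xyxy_pixels(box, size=640):
    """Normalized [cx, cy, w, h] -> [x1, y1, x2, y2] pixels on a size x size input."""
    cx, cy, w, h = box
    return [(cx - w / 2) * size, (cy - h / 2) * size, (cx + w / 2) * size, (cy + h / 2) * size]

# A 200x100 box at (100, 50) in a 640x480 image
yolo = coco_xywh_to_yolo([100, 50, 200, 100], img_w=640, img_h=480)
print(yolo)                         # cx = w = 0.3125, cy = h ~ 0.2083
print(yolo_to_xyxy_pixels(yolo))    # x unchanged (100 to 300); y scaled by 640/480 (~66.7 to 200)
```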

In [None]:
class COCOStreamDetectionDataset(torch.utils.data.Dataset):
    """Buffer real COCO images from HF streaming for detection training.

    Pre-fetches max_samples images via streaming and stores them in memory,
    providing random access and len() support for the DataLoader.
    Annotations are converted to YOLO format [cx, cy, w, h] normalized.
    """

    def __init__(self, split='train', max_samples=32, img_size=640, num_classes=80):
        self.img_size = img_size
        self.num_classes = num_classes
        self.data = []

        print(f"Streaming {max_samples} COCO images from Hugging Face...")
        ds = load_dataset('detection-datasets/coco', split=split, streaming=True)

        for example in ds:
            if len(self.data) >= max_samples:
                break

            img_pil = example['image'].convert('RGB')
            img_np = np.array(img_pil)
            h, w = img_np.shape[:2]

            bboxes = example['objects']['bbox']
            cats = example['objects']['category']

            boxes = []
            labels = []
            for bbox, cat_id in zip(bboxes, cats):
                bx, by, bw, bh = bbox
                if bw <= 0 or bh <= 0:
                    continue
                cx = (bx + bw / 2) / w
                cy = (by + bh / 2) / h
                boxes.append([cx, cy, bw / w, bh / h])
                labels.append(int(cat_id))

            if len(boxes) == 0:
                continue

            # Resize to model input size
            img_resized = np.array(img_pil.resize((self.img_size, self.img_size)))
            img_tensor = torch.from_numpy(img_resized).permute(2, 0, 1).float() / 255.0

            boxes_t = torch.tensor(boxes, dtype=torch.float32)
            labels_t = torch.tensor(labels, dtype=torch.long)

            self.data.append((img_tensor, boxes_t, labels_t))

        print(f"Buffered {len(self.data)} COCO images")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def collate_fn(batch):
    """Custom collate: images are stacked, boxes/labels stay as lists."""
    imgs, boxes, labels = zip(*batch)
    imgs = torch.stack(imgs, dim=0)
    targets = torch.zeros(len(imgs))  # placeholder; the loss consumes the per-image box/label lists
    return imgs, targets, list(boxes), list(labels)


# Create dataset and dataloader with real COCO images
dataset = COCOStreamDetectionDataset(max_samples=32, num_classes=80)
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# Visualize samples
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
for i, ax in enumerate(axes):
    img, boxes, labels = dataset[i]
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.set_title(f'Image {i}: {len(boxes)} objects')
    ax.axis('off')
    for box in boxes:
        cx, cy, w, h = box.numpy() * 640
        rect = plt.Rectangle((cx - w/2, cy - h/2), w, h,
                             linewidth=2, edgecolor='white', facecolor='none')
        ax.add_patch(rect)
plt.suptitle('Real COCO Training Data', fontsize=14)
plt.tight_layout()
plt.show()

print(f"Dataset: {len(dataset)} images, DataLoader: {len(loader)} batches")

## Training loop

The training loop follows a standard PyTorch pattern, with a few additions that make a detector easier to debug:

- **Gradient clipping** (`max_norm=10.0`) rescales the gradients whenever their global L2 norm exceeds 10. This guards against gradient spikes, which are most likely early in training, when the classification, box, and DFL terms have very different magnitudes.
- We track the individual loss components (classification, box, DFL) so that a stalled term is visible rather than hidden inside the total.
- We log the number of positive (foreground) anchors per batch as a sanity check: if it stays at zero, the assigner is not matching any anchors to the ground-truth boxes.
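`clip_grad_norm_` computes one L2 norm over all the gradients it is given. If that norm exceeds `max_norm`, it rescales every gradient in place by roughly `max_norm / total_norm`, which preserves the update direction. The function returns the norm measured *before* clipping. Here is a small sketch on a single toy parameter with a deliberately large, hand-set gradient:

```python
import torch

# Toy parameter with a hand-set gradient of L2 norm 50 (a 3-4-5 triangle scaled by 10).
p = torch.nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([30.0, 40.0])

# Returns the total norm before clipping; p.grad is rescaled in place.
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=10.0)

print(float(total_norm))             # norm before clipping
print(p.grad, float(p.grad.norm()))  # same direction, norm reduced to about 10
```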

In [None]:
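def train_one_epoch(model, dataloader, optimizer, loss_fn, device, epoch):
    """Train for one pass over `dataloader`; return the mean of each loss component.

    `epoch` is accepted for logging symmetry but is not used inside the function.
    """
    model.train()
    epoch_losses = {'cls_loss': 0.0, 'box_loss': 0.0, 'dfl_loss': 0.0, 'total_loss': 0.0}

    # `_targets` is the placeholder from collate_fn; GT comes from the per-image lists
    for batch_idx, (imgs, _targets, boxes_list, labels_list) in enumerate(dataloader):
        imgs = imgs.to(device)

        # Move per-image GT to device (kept as lists because box counts differ)
        gt_boxes = [b.to(device) for b in boxes_list]
        gt_labels = [lbl.to(device) for lbl in labels_list]

        # Forward: multi-scale predictions -> scalar loss + per-component values
        predictions = model(imgs)
        loss, loss_dict = loss_fn(predictions, gt_boxes, gt_labels)

        # Backward with global gradient-norm clipping
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()

        # float() accepts Python numbers and 0-dim tensors alike and keeps no autograd graph
        for k in epoch_losses:
            epoch_losses[k] += float(loss_dict[k])

        if batch_idx % 2 == 0:
            print(f"  Batch {batch_idx}: loss={loss_dict['total_loss']:.4f} "
                  f"(cls={loss_dict['cls_loss']:.4f}, box={loss_dict['box_loss']:.4f}, "
                  f"dfl={loss_dict['dfl_loss']:.4f}, pos={loss_dict['num_pos']})")

    n = len(dataloader)
    return {k: v / n for k, v in epoch_losses.items()}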

## Running the training demo

We train for 5 epochs on our small buffered COCO subset (up to 32 images). The goal is not good detection performance, which requires the full dataset and many more epochs. Instead, we want to verify that:

1. The forward pass produces valid predictions at all three scales
2. The TAL assigner finds positive anchors for the ground-truth boxes
3. All three loss components produce valid gradients
4. The total loss decreases over training

We use AdamW with a cosine-annealed learning rate, a common choice for short training runs like this one.
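`CosineAnnealingLR` decays the learning rate from its initial value $\eta_0$ to `eta_min` (default 0) along half a cosine period: $\eta_t = \eta_{min} + \frac{1}{2}(\eta_0 - \eta_{min})\left(1 + \cos\frac{\pi t}{T_{max}}\right)$. The sketch below builds the same optimizer settings as the training cell below (`lr=1e-3`, `T_max=5`), but on a throwaway one-parameter tensor instead of the detector. It prints the per-epoch learning rates next to the closed-form values.

```python
import math
import torch

# Throwaway parameter so an optimizer can be built without the full model.
toy_param = torch.nn.Parameter(torch.zeros(1))
toy_opt = torch.optim.AdamW([toy_param], lr=1e-3, weight_decay=0.05)
toy_sched = torch.optim.lr_scheduler.CosineAnnealingLR(toy_opt, T_max=5)

lrs = []
for epoch in range(5):
    lrs.append(toy_sched.get_last_lr()[0])  # LR in effect during this epoch
    toy_opt.step()                          # no gradients here, so parameters are unchanged
    toy_sched.step()

# Closed-form cosine schedule with eta_min = 0, for comparison.
closed_form = [1e-3 * 0.5 * (1 + math.cos(math.pi * t / 5)) for t in range(5)]
for t, (lr, ref) in enumerate(zip(lrs, closed_form)):
    print(f"epoch {t + 1}: lr={lr:.6f} (closed form {ref:.6f})")
```

Because `scheduler.step()` runs once per epoch and `T_max` equals the number of epochs, the last epoch still trains at a small but non-zero learning rate, and the rate only reaches 0 after the final step.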

In [None]:
# Small-scale training demo with real COCO data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = YOLOv11(num_classes=80).to(device)
loss_fn = YOLOv11Loss(num_classes=80)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)

# Cosine LR scheduler
num_epochs = 5
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Training for {num_epochs} epochs on {len(dataset)} real COCO images\n")

history = []
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs} (lr={scheduler.get_last_lr()[0]:.6f})")
    epoch_loss = train_one_epoch(model, loader, optimizer, loss_fn, device, epoch)
    scheduler.step()
    history.append(epoch_loss)
    print(f"  -> Avg loss: {epoch_loss['total_loss']:.4f}")

print("\nTraining complete!")
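Item 3 of the checklist above (all loss components produce valid gradients) can be checked directly after a backward pass. Below is a minimal sketch of a hypothetical helper, `check_grads`. It reports trainable parameters whose gradient is missing (the loss never reached them) or contains NaN/Inf. The demo uses a tiny two-branch toy module rather than the detector.

```python
import torch
import torch.nn as nn

def check_grads(model: nn.Module):
    """Hypothetical helper: names of trainable params with no grad or a non-finite grad."""
    missing, non_finite = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue                      # frozen params are expected to have no grad
        if p.grad is None:
            missing.append(name)          # not connected to the loss
        elif not torch.isfinite(p.grad).all():
            non_finite.append(name)       # NaN or Inf somewhere in the gradient
    return missing, non_finite

# Toy model: the 'unused' branch never contributes to the loss.
toy = nn.ModuleDict({'used': nn.Linear(4, 2), 'unused': nn.Linear(4, 2)})
loss = toy['used'](torch.randn(3, 4)).pow(2).mean()
loss.backward()

missing, non_finite = check_grads(toy)
print("missing:", missing)
print("non-finite:", non_finite)
```

In the training loop, a call such as `check_grads(model)` right after `loss.backward()` would flag a disconnected head branch (non-empty `missing`) or a numerically unstable loss term (non-empty `non_finite`).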

## Loss curves visualization

Plotting the individual loss components over training helps diagnose issues:

- **Classification loss** should decrease as the model learns to distinguish object classes from background
- **Box (CIoU) loss** should decrease as predicted boxes align better with ground truth
- **DFL loss** should decrease as the distribution predictions sharpen around the correct offsets

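If one component plateaus while the others keep decreasing, the loss weights may be imbalanced. With only five epochs on a few dozen images, however, treat these curves as a smoke test rather than as evidence of convergence.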
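One way to make the plateau check concrete is to compare how much each component dropped between the first and last epoch. Below is a sketch of a hypothetical helper, `relative_drop`, that works on a list of per-epoch dicts shaped like the `history` collected during training. The values in `toy_history` are made up purely to exercise the function, and the 5% threshold is an arbitrary choice.

```python
from typing import Dict, List

def relative_drop(history: List[Dict[str, float]], key: str) -> float:
    """Fractional decrease of history[key] from the first to the last epoch."""
    first, last = history[0][key], history[-1][key]
    return (first - last) / first if first != 0 else 0.0

# Made-up values for illustration only (not results from this notebook).
toy_history = [
    {'cls_loss': 4.0, 'box_loss': 2.0, 'dfl_loss': 3.0},
    {'cls_loss': 3.0, 'box_loss': 1.6, 'dfl_loss': 2.97},
    {'cls_loss': 2.0, 'box_loss': 1.5, 'dfl_loss': 2.94},
]
for key in ['cls_loss', 'box_loss', 'dfl_loss']:
    drop = relative_drop(toy_history, key)
    flag = "  <- possible plateau" if drop < 0.05 else ""  # arbitrary 5% threshold
    print(f"{key}: {drop:.1%}{flag}")
```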

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(20, 4))
keys = ['total_loss', 'cls_loss', 'box_loss', 'dfl_loss']
titles = ['Total Loss', 'Classification Loss', 'Box (CIoU) Loss', 'DFL Loss']

for ax, key, title in zip(axes, keys, titles):
    values = [h[key] for h in history]
    ax.plot(range(1, len(values)+1), values, 'b-o', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)

plt.suptitle('Training Loss Curves', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

In this notebook, we built the complete training pipeline for YOLOv11 from scratch. Here is a recap of the key components:

1. **CIoU** extends IoU with two penalty terms: the squared distance between box centers, normalized by the squared diagonal of the smallest enclosing box, and an aspect-ratio consistency term. The center-distance penalty stays informative when boxes do not overlap, so the optimizer still receives a useful gradient where basic IoU is identically zero and provides none.

2. **Task-Aligned Learning (TAL)** assigns ground-truth boxes to anchor points based on both classification confidence and localization quality. The alignment metric is $t = s^\alpha \cdot u^\beta$, where $s$ is the predicted score for the ground-truth class and $u$ is the IoU between the predicted and ground-truth boxes. It selects anchors that already score well on both tasks, and training on those anchors improves them further, creating a virtuous cycle.

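3. **Distribution Focal Loss (DFL)** improves box regression by predicting a probability distribution over discrete offset bins rather than a single scalar; the offset is decoded as the distribution's expected value. The loss is a cross-entropy on the two bins that bracket the continuous target, each weighted by how close the target lies to that bin, which preserves the continuous nature of the target.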

4. The **composite loss** is a weighted sum of the classification ($\lambda = 0.5$), localization ($\lambda = 7.5$), and distribution-quality ($\lambda = 1.5$) terms; these gains match the Ultralytics defaults. The large weight on box regression emphasizes precise localization, which detection metrics evaluated at high IoU thresholds reward.
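The claim that plain IoU gives no gradient for disjoint boxes, while CIoU still does, is easy to check numerically. Below is a compact, self-contained CIoU for boxes in `(x1, y1, x2, y2)` format, following the standard definition $\text{CIoU} = \text{IoU} - \rho^2/c^2 - \alpha v$. Here $\rho$ is the distance between box centers, $c$ is the diagonal of the smallest enclosing box, $v = \frac{4}{\pi^2}\left(\arctan\frac{w^{gt}}{h^{gt}} - \arctan\frac{w}{h}\right)^2$, and $\alpha = v / (1 - \text{IoU} + v)$. This sketch is for illustration and is not necessarily identical to the implementation used earlier in this notebook. For example, it treats $\alpha$ as a constant (detached), which is a common implementation choice.

```python
import math
import torch

def iou_and_ciou(b1: torch.Tensor, b2: torch.Tensor, eps: float = 1e-7):
    """Return (IoU, CIoU) for boxes in (x1, y1, x2, y2) format, shape (..., 4); b2 is the GT."""
    w1, h1 = b1[..., 2] - b1[..., 0], b1[..., 3] - b1[..., 1]
    w2, h2 = b2[..., 2] - b2[..., 0], b2[..., 3] - b2[..., 1]

    # Intersection area (exactly zero when the boxes are disjoint).
    iw = (torch.min(b1[..., 2], b2[..., 2]) - torch.max(b1[..., 0], b2[..., 0])).clamp(min=0)
    ih = (torch.min(b1[..., 3], b2[..., 3]) - torch.max(b1[..., 1], b2[..., 1])).clamp(min=0)
    inter = iw * ih
    iou = inter / (w1 * h1 + w2 * h2 - inter + eps)

    # c^2: squared diagonal of the smallest enclosing box.
    cw = torch.max(b1[..., 2], b2[..., 2]) - torch.min(b1[..., 0], b2[..., 0])
    ch = torch.max(b1[..., 3], b2[..., 3]) - torch.min(b1[..., 1], b2[..., 1])
    c2 = cw ** 2 + ch ** 2 + eps

    # rho^2: squared distance between the two box centers.
    rho2 = ((b1[..., 0] + b1[..., 2] - b2[..., 0] - b2[..., 2]) ** 2 +
            (b1[..., 1] + b1[..., 3] - b2[..., 1] - b2[..., 3]) ** 2) / 4

    # Aspect-ratio consistency term v and its trade-off weight alpha (held constant).
    v = (4 / math.pi ** 2) * (torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps))) ** 2
    with torch.no_grad():
        alpha = v / (1 - iou + v + eps)
    return iou, iou - rho2 / c2 - alpha * v

# Predicted box lies entirely to the left of the ground truth: no overlap.
pred = torch.tensor([0.0, 0.0, 2.0, 2.0], requires_grad=True)
gt = torch.tensor([5.0, 0.0, 7.0, 2.0])

iou, ciou = iou_and_ciou(pred, gt)
grad_iou, = torch.autograd.grad(1 - iou, pred, retain_graph=True)
grad_ciou, = torch.autograd.grad(1 - ciou, pred)
print("IoU:", iou.item(), "grad of 1 - IoU:", grad_iou)
print("CIoU:", ciou.item(), "grad of 1 - CIoU:", grad_ciou)
```

The gradient of `1 - IoU` is all zeros here. The gradient of `1 - CIoU` with respect to the predicted right edge `x2` is negative, so a gradient-descent step moves that edge toward the ground truth.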
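The effect of the TAL exponents shows up clearly when the alignment metric $t = s^\alpha \cdot u^\beta$ is evaluated for a few candidates. Below, three hypothetical candidate anchors of one ground-truth box are scored, with $s$ the predicted score for the GT class and $u$ the IoU between the anchor's predicted box and the GT. The exponents $\alpha = 0.5$ and $\beta = 6$ are assumed values commonly used by YOLOv8/YOLOv11-style assigners and may differ from those used earlier in this notebook. The selection is simplified to a plain top-k over these three anchors. A full assigner would first restrict candidates to anchors inside the GT box and typically uses a larger k.

```python
import torch

alpha, beta = 0.5, 6.0  # assumed exponents (see lead-in)

# Three hypothetical candidate anchors for a single GT box.
scores = torch.tensor([0.9, 0.3, 0.6])  # s: predicted probability of the GT class
ious = torch.tensor([0.5, 0.9, 0.8])    # u: IoU of each predicted box with the GT

align = scores.pow(alpha) * ious.pow(beta)  # t = s^alpha * u^beta
topk = 2
chosen = align.topk(topk).indices           # keep the k best-aligned anchors as positives

for i, (s, u, t) in enumerate(zip(scores, ious, align)):
    print(f"anchor {i}: s={s:.1f}, u={u:.1f}, t={t:.4f}")
print("positives:", sorted(chosen.tolist()))
```

With $\beta = 6$, the confidently classified but poorly localized anchor 0 ($u = 0.5$) receives the lowest alignment and is not selected, so localization quality dominates the assignment.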
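For a continuous target $y$ with $y_i \le y < y_{i+1} = y_i + 1$, DFL is $-\big((y_{i+1} - y)\log p_i + (y - y_i)\log p_{i+1}\big)$. The sketch below uses 16 bins (an assumed `reg_max`) and the target $y = 3.3$. It shows that the loss is low when the predicted distribution puts matching mass (70% / 30%) on bins 3 and 4, and much higher for a distribution peaked on the wrong bin. It also shows that the expected value of the matched distribution recovers the target.

```python
import torch
import torch.nn.functional as F

def dfl(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """DFL for logits (N, n_bins) and float targets (N,) in [0, n_bins - 1)."""
    left = target.long()                # y_i: bin just below the target
    right = left + 1                    # y_{i+1}: bin just above the target
    w_left = right.float() - target     # weight on y_i (larger when y is closer to y_i)
    w_right = target - left.float()     # weight on y_{i+1}
    return (F.cross_entropy(logits, left, reduction='none') * w_left +
            F.cross_entropy(logits, right, reduction='none') * w_right)

n_bins = 16                             # assumed reg_max
target = torch.tensor([3.3])

# Distribution with 70% on bin 3 and 30% on bin 4 (tiny mass elsewhere); log-probs as logits.
probs = torch.full((1, n_bins), 1e-6)
probs[0, 3], probs[0, 4] = 0.7, 0.3
matched = dfl(probs.log(), target)

# A distribution peaked on the wrong bin, for contrast.
wrong = torch.full((1, n_bins), 1e-6)
wrong[0, 8] = 1.0
mismatched = dfl(wrong.log(), target)

# Expected bin index under the (normalized) matched distribution.
expected = (probs / probs.sum()) @ torch.arange(n_bins, dtype=torch.float32)
print(matched.item(), mismatched.item(), expected.item())
```

For fixed weights, the two-bin cross-entropy is minimized when $p_i = y_{i+1} - y$ and $p_{i+1} = y - y_i$. That is exactly the distribution whose expected value is $y$, which is why DFL keeps the target's sub-bin precision.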

### Next steps

In **Notebook 5**, we will build the inference pipeline: decoding predictions into bounding boxes, applying Non-Maximum Suppression (NMS), and evaluating detection quality using COCO metrics (mAP, AP50, AP75).