# Region Proposal Network (RPN)

*Notebook 3 of 6 in the Faster R-CNN from-scratch series*

The RPN is the key innovation of Faster R-CNN: a small convolutional network that
slides over each FPN feature map and, at every location, scores a fixed set of
anchor boxes for "objectness" and predicts offsets that refine them into region proposals.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from typing import List, Tuple

In [None]:
class AnchorGenerator(nn.Module):
    """Generate anchor boxes for each FPN level.

    Produces anchors at the following scales x aspect ratios per level:
        scales: (32, 64, 128, 256, 512) — one scale per FPN level P2-P6
        ratios: (0.5, 1.0, 2.0)
    k = len(aspect_ratios) anchors per location (3 with the defaults).

    An aspect ratio ``r`` is interpreted as width / height, and every anchor of
    scale ``s`` has area ``s ** 2`` (w = s * sqrt(r), h = s / sqrt(r)).

    Output: (N_total, 4) tensor of [x1, y1, x2, y2] anchors in image pixel coords,
    ordered level by level, then row-major over feature-map cells, then by
    aspect ratio within a cell.
    """

    def __init__(self,
                 anchor_sizes:  Tuple[int, ...]   = (32, 64, 128, 256, 512),
                 aspect_ratios: Tuple[float, ...] = (0.5, 1.0, 2.0),
                 strides:       Tuple[int, ...]   = (4, 8, 16, 32, 64)):
        super().__init__()
        self.anchor_sizes  = anchor_sizes
        self.aspect_ratios = aspect_ratios
        self.strides       = strides

    def _base_anchors(self, size: int) -> torch.Tensor:
        """Create k base anchors centred at the origin for a given scale.

        Returns a (k, 4) float32 tensor, one row per aspect ratio, in the
        order of ``self.aspect_ratios``.
        """
        anchors = []
        for ratio in self.aspect_ratios:
            w = size * (ratio ** 0.5)
            h = size / (ratio ** 0.5)
            anchors.append([-w / 2, -h / 2, w / 2, h / 2])
        return torch.tensor(anchors, dtype=torch.float32)

    def forward(self, feature_maps: List[torch.Tensor],
                image_size: Tuple[int, int]) -> torch.Tensor:
        """
        Args:
            feature_maps: list of (B, C, H_i, W_i) tensors (P2-P6); only their
                spatial sizes and device are used.
            image_size:   (H, W) of the original image. Not read: the pixel
                stride of each level is fixed by ``strides``.
        Returns:
            all_anchors: (N_total, 4) [x1, y1, x2, y2] in image pixel coords,
                on the same device as the feature maps.
        Raises:
            ValueError: if more feature maps are given than levels configured
                by ``anchor_sizes`` / ``strides`` (``zip`` would otherwise drop
                the extra levels silently, leaving them without anchors).
        """
        num_levels = min(len(self.anchor_sizes), len(self.strides))
        if len(feature_maps) > num_levels:
            raise ValueError(
                f"got {len(feature_maps)} feature maps but only {num_levels} "
                f"anchor levels are configured (anchor_sizes={self.anchor_sizes}, "
                f"strides={self.strides})"
            )

        all_anchors = []
        for fm, size, stride in zip(feature_maps, self.anchor_sizes, self.strides):
            _, _, fh, fw = fm.shape
            # Build anchors on the feature map's device so they can be combined
            # with head outputs (e.g. on a GPU) without an explicit transfer.
            base = self._base_anchors(size).to(fm.device)          # (k, 4)

            shift_x = (torch.arange(fw, device=fm.device, dtype=base.dtype) + 0.5) * stride
            shift_y = (torch.arange(fh, device=fm.device, dtype=base.dtype) + 0.5) * stride
            sy, sx  = torch.meshgrid(shift_y, shift_x, indexing='ij')
            shifts  = torch.stack([sx, sy, sx, sy], dim=-1).reshape(-1, 4)  # (H*W, 4)

            anchors = shifts[:, None, :] + base[None, :, :]  # (H*W, k, 4)
            all_anchors.append(anchors.reshape(-1, 4))

        if not all_anchors:
            # torch.cat rejects an empty list; no levels means no anchors.
            return torch.zeros((0, 4), dtype=torch.float32)
        return torch.cat(all_anchors, dim=0)


# Shape verification: one P2-P6 pyramid for an 800x800 image.
gen = AnchorGenerator()
fps = [
    torch.zeros(1, 256, 200, 200),   # P2 stride 4
    torch.zeros(1, 256, 100, 100),   # P3 stride 8
    torch.zeros(1, 256,  50,  50),   # P4 stride 16
    torch.zeros(1, 256,  25,  25),   # P5 stride 32
    torch.zeros(1, 256,  13,  13),   # P6 stride 64
]
anchors = gen(fps, (800, 800))
print(f"Total anchors: {len(anchors):,}")
# P2: 200*200*3=120,000 | P3: 30,000 | P4: 7,500 | P5: 1,875 | P6: 507
# Expected: 159,882
# Derive the expected count from the feature-map shapes (H_i * W_i cells per
# level, one anchor per aspect ratio) so the cell actually checks it.
expected = sum(fm.shape[-2] * fm.shape[-1] for fm in fps) * len(gen.aspect_ratios)
assert anchors.shape == (expected, 4), f"expected ({expected}, 4), got {tuple(anchors.shape)}"

In [None]:
class RPNHead(nn.Module):
    """Per-level RPN prediction head.

    One 3x3 convolution (followed by ReLU) is shared across every FPN level it
    is applied to. Two sibling 1x1 convolutions then predict, for each of the
    ``num_anchors`` anchors at every spatial location, one objectness logit
    and four box-regression deltas. All weights start from N(0, 0.01^2) and
    all biases start at zero.
    """

    def __init__(self, in_channels: int = 256, num_anchors: int = 3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1)
        self.bbox_pred = nn.Conv2d(in_channels, 4 * num_anchors, kernel_size=1)

        for module in (self.conv, self.cls_logits, self.bbox_pred):
            nn.init.normal_(module.weight, mean=0.0, std=0.01)
            nn.init.constant_(module.bias, 0.0)

    def forward(self, features: List[torch.Tensor]):
        """Run the head independently on each level.

        Args:
            features: per-level feature maps, each (B, in_channels, H_i, W_i).
        Returns:
            A pair of lists with one entry per level:
            objectness logits of shape (B, num_anchors, H_i, W_i) and
            box deltas of shape (B, 4 * num_anchors, H_i, W_i).
        """
        objectness: List[torch.Tensor] = []
        box_deltas: List[torch.Tensor] = []
        for level_feat in features:
            shared = self.conv(level_feat).relu()
            objectness.append(self.cls_logits(shared))
            box_deltas.append(self.bbox_pred(shared))
        return objectness, box_deltas

In [None]:
def decode_boxes(anchors: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
    """Turn (dx, dy, dw, dh) regression deltas into [x1, y1, x2, y2] boxes.

    Standard R-CNN parameterisation relative to each anchor's centre and size:
        cx' = cx + dx * w,   cy' = cy + dy * h,
        w'  = w * exp(dw),   h'  = h * exp(dh),
    with dw and dh capped at 4.0 before exponentiation.

    Args:
        anchors: (N, 4) reference boxes as [x1, y1, x2, y2].
        deltas:  (N, 4) predicted offsets as [dx, dy, dw, dh].
    Returns:
        (N, 4) decoded boxes as [x1, y1, x2, y2].
    """
    ax1, ay1, ax2, ay2 = anchors[:, :4].unbind(dim=1)
    widths = ax2 - ax1
    heights = ay2 - ay1
    ctr_x = ax1 + widths / 2
    ctr_y = ay1 + heights / 2

    dx, dy, dw, dh = deltas[:, :4].unbind(dim=1)
    # Cap the log-scale factors so exp() cannot overflow (at most e^4 ≈ 55x growth).
    new_w = widths * dw.clamp(max=4.0).exp()
    new_h = heights * dh.clamp(max=4.0).exp()
    new_cx = ctr_x + dx * widths
    new_cy = ctr_y + dy * heights

    half_w = new_w / 2
    half_h = new_h / 2
    return torch.stack(
        (new_cx - half_w, new_cy - half_h, new_cx + half_w, new_cy + half_h),
        dim=1,
    )


class RegionProposalNetwork(nn.Module):
    """Full RPN: generates, scores, and filters proposals.

    For each image in the batch:
      1. the head predicts one objectness logit and four box deltas per anchor,
      2. the deltas are decoded against the anchors into [x1, y1, x2, y2] boxes,
      3. boxes are clipped to the image, and any box narrower or shorter than
         ``min_size`` pixels is dropped,
      4. the ``pre_nms_top_n`` highest-scoring boxes are kept (ranked jointly
         over all FPN levels, not per level),
      5. greedy NMS at IoU ``nms_thresh`` runs, and at most ``post_nms_top_n``
         boxes survive.

    Proposals are returned detached from the autograd graph. As in
    torchvision's RPN, no gradient flows through proposal coordinates or
    through the score-based selection.
    """

    def __init__(self, rpn_head: "RPNHead", anchor_gen: "AnchorGenerator",
                 pre_nms_top_n: int = 2000, post_nms_top_n: int = 1000,
                 nms_thresh: float = 0.7, min_size: int = 16):
        super().__init__()
        self.head          = rpn_head
        self.anchor_gen    = anchor_gen
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh    = nms_thresh
        self.min_size      = min_size

    def _filter_proposals(self, proposals: torch.Tensor, scores: torch.Tensor,
                          img_size: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Clip, size-filter, top-k and NMS the proposals of one image.

        Args:
            proposals: (N, 4) [x1, y1, x2, y2] boxes; left unmodified.
            scores:    (N,) objectness probabilities, aligned with ``proposals``.
            img_size:  (H, W) used to clip the boxes.
        Returns:
            (boxes, scores) of the surviving proposals, in descending score
            order, at most ``post_nms_top_n`` of them.
        """
        H, W = img_size
        # Build a clipped copy instead of writing into the caller's tensor.
        x1, y1, x2, y2 = proposals.unbind(dim=1)
        proposals = torch.stack(
            (x1.clamp(0, W), y1.clamp(0, H), x2.clamp(0, W), y2.clamp(0, H)), dim=1)

        ws = proposals[:, 2] - proposals[:, 0]
        hs = proposals[:, 3] - proposals[:, 1]
        keep = (ws >= self.min_size) & (hs >= self.min_size)
        proposals, scores = proposals[keep], scores[keep]

        k = min(self.pre_nms_top_n, scores.numel())
        scores, order = scores.topk(k)
        proposals = proposals[order]

        # NMS returns indices in descending score order, so truncating keeps the best.
        keep = self._nms(proposals, scores, self.nms_thresh)[:self.post_nms_top_n]
        return proposals[keep], scores[keep]

    @staticmethod
    def _nms(boxes: torch.Tensor, scores: torch.Tensor,
             thresh: float) -> torch.Tensor:
        """Greedy NMS — pure PyTorch, no torchvision dependency.

        Repeatedly keeps the highest-scoring remaining box and discards every
        remaining box whose IoU with it exceeds ``thresh``.

        Returns:
            1-D long tensor of kept indices into ``boxes``, in descending score
            order, on the same device as ``boxes``.
        """
        x1, y1, x2, y2 = boxes.unbind(1)
        areas = (x2 - x1) * (y2 - y1)
        order = scores.argsort(descending=True)
        keep = []
        while order.numel() > 0:
            i = int(order[0])
            keep.append(i)
            rest = order[1:]
            if rest.numel() == 0:
                break
            xx1 = x1[rest].clamp(min=x1[i])
            yy1 = y1[rest].clamp(min=y1[i])
            xx2 = x2[rest].clamp(max=x2[i])
            yy2 = y2[rest].clamp(max=y2[i])
            inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
            iou   = inter / (areas[i] + areas[rest] - inter).clamp(min=1e-6)
            order = rest[iou <= thresh]
        return torch.tensor(keep, dtype=torch.long, device=boxes.device)

    def forward(self, features: List[torch.Tensor],
                image_size: Tuple[int, int],
                targets=None) -> List[torch.Tensor]:
        """
        Args:
            features:   per-level FPN feature maps, each (B, C, H_i, W_i).
            image_size: (H, W) of the input images, used for clipping.
            targets:    accepted for interface compatibility; not used here.
        Returns:
            List of B tensors, each (K_b, 4) [x1, y1, x2, y2] proposals.
        Raises:
            ValueError: if the anchor generator and the head disagree on the
                number of anchors.
        """
        cls_logits, bbox_preds = self.head(features)

        # Flatten across levels: (B, N_total) scores and (B, N_total, 4) deltas.
        # The permute puts the anchor index innermost, matching the anchor order.
        all_scores = torch.cat([
            c.permute(0, 2, 3, 1).reshape(c.shape[0], -1)
            for c in cls_logits
        ], dim=1)
        all_deltas = torch.cat([
            b.permute(0, 2, 3, 1).reshape(b.shape[0], -1, 4)
            for b in bbox_preds
        ], dim=1)

        # Faster R-CNN does not backpropagate through proposal selection, so
        # work on detached copies; this also avoids keeping an autograd graph
        # alive over every anchor.
        all_scores = all_scores.detach()
        all_deltas = all_deltas.detach()

        anchors = self.anchor_gen(features, image_size).to(all_deltas.device)
        if anchors.shape[0] != all_deltas.shape[1]:
            raise ValueError(
                f"anchor generator produced {anchors.shape[0]} anchors but the "
                f"head predicted deltas for {all_deltas.shape[1]}"
            )

        proposals_list = []
        for logits_i, deltas_i in zip(all_scores, all_deltas):
            props_i = decode_boxes(anchors, deltas_i)
            props_i, _ = self._filter_proposals(props_i, logits_i.sigmoid(), image_size)
            proposals_list.append(props_i)

        return proposals_list


# Smoke test: random features through a random-weight RPN.
head = RPNHead()
rpn  = RegionProposalNetwork(head, gen)
feat_maps = [
    torch.randn(1, 256, 200, 200),
    torch.randn(1, 256, 100, 100),
    torch.randn(1, 256,  50,  50),
    torch.randn(1, 256,  25,  25),
    torch.randn(1, 256,  13,  13),
]
# Inference-only check: no autograd graph is needed over the ~160k anchors.
with torch.no_grad():
    proposals = rpn(feat_maps, (800, 800))
print(f"Proposals per image: {[len(p) for p in proposals]}")
# Expected: [~1000]
# One (K, 4) tensor per image, never more than post_nms_top_n boxes.
assert len(proposals) == feat_maps[0].shape[0]
for p in proposals:
    assert p.ndim == 2 and p.shape[1] == 4 and len(p) <= rpn.post_nms_top_n

In [None]:
# Inspection: anchor grid at P3 (stride 8) — sample every 8th cell
import os

os.makedirs('images', exist_ok=True)  # plt.savefig does not create missing directories

stride, fh, fw = 8, 100, 100
step = 8  # plot every `step`-th row and column of the feature-map grid
rows, cols = np.meshgrid(np.arange(0, fh, step), np.arange(0, fw, step), indexing='ij')
centres_x = (cols.ravel() + 0.5) * stride
centres_y = (rows.ravel() + 0.5) * stride

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xlim(0, fw * stride)
ax.set_ylim(fh * stride, 0)  # flip y so row 0 is at the top, as in image coordinates
ax.set_facecolor('#1a1a2e')
ax.set_title('Anchor centres at P3 (stride 8), every 8th cell')
# One vectorised plot call instead of one Line2D per sampled cell.
ax.plot(centres_x, centres_y, 'c.', markersize=2)
plt.tight_layout()
plt.savefig('images/anchor_grid.png', dpi=100, bbox_inches='tight')
plt.show()
print(f"P3 total locations: {fh * fw:,}  |  total anchors: {fh * fw * 3:,}")

In [None]:
# Inspection: objectness score distribution (random weights)
# Re-run head on feat_maps to get scores
import os

os.makedirs('images', exist_ok=True)  # plt.savefig does not create missing directories

# Inference only: no autograd graph is needed for plotting.
with torch.no_grad():
    cls_logits_list, bbox_preds_list = head(feat_maps)
# Flatten with the same permute the RPN uses (anchor index innermost).
# NOTE: reshape(-1) also merges the batch dimension; feat_maps has batch size 1 here.
all_scores_flat = torch.cat([
    c.permute(0, 2, 3, 1).reshape(-1) for c in cls_logits_list
]).sigmoid().cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 4))
ax.hist(all_scores_flat, bins=80, color='steelblue', alpha=0.8, edgecolor='none')
ax.set_xlabel('Objectness score')
ax.set_ylabel('Anchor count')
ax.set_title('Objectness score distribution (random-weight RPN)')
plt.tight_layout()
plt.savefig('images/objectness_dist.png', dpi=100, bbox_inches='tight')
plt.show()

In [None]:
# Inspection: top-50 proposals before NMS on a blank canvas
import os

os.makedirs('images', exist_ok=True)  # plt.savefig does not create missing directories

img_h, img_w = 800, 800
anchors_all = gen(feat_maps, (img_h, img_w))
all_deltas_flat = torch.cat([
    b.permute(0, 2, 3, 1).reshape(-1, 4) for b in bbox_preds_list
]).detach()
# from_numpy shares memory with the array instead of copying every anchor score.
all_scores_1d = torch.from_numpy(all_scores_flat)

# topk selects the best 50 without fully sorting every anchor score.
top50_idx  = all_scores_1d.topk(min(50, all_scores_1d.numel())).indices
top50_props = decode_boxes(anchors_all[top50_idx], all_deltas_flat[top50_idx])
# Clip x-coordinates to the image width and y-coordinates to the height.
top50_props[:, 0::2] = top50_props[:, 0::2].clamp(0, img_w)
top50_props[:, 1::2] = top50_props[:, 1::2].clamp(0, img_h)

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xlim(0, img_w)
ax.set_ylim(img_h, 0)  # flip y so the origin is top-left, as in image coordinates
ax.set_facecolor('#1a1a2e')
ax.set_title('Top-50 proposals (before NMS, random weights)')
for box in top50_props.tolist():
    x1, y1, x2, y2 = box
    rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                             linewidth=1, edgecolor='cyan', facecolor='none', alpha=0.5)
    ax.add_patch(rect)
plt.tight_layout()
plt.savefig('images/top50_proposals.png', dpi=100, bbox_inches='tight')
plt.show()