# ROI Head: ROI Align + Classification and Box Regression

*Notebook 4 of 6 in the Faster R-CNN from-scratch series*

Given the proposals produced by the RPN, we extract a fixed-size feature crop for
each one with ROI Align (from the FPN level matched to the proposal's size), then
classify each proposal and refine its bounding box.

**Mask R-CNN extension point**: this notebook also demonstrates the 14×14
ROI Align variant used by the mask head (notebook 07).

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import List, Tuple

In [None]:
class ROIAlign(nn.Module):
    """ROI Align using bilinear interpolation via F.grid_sample.

    Extracts a fixed (out_size x out_size) feature crop for each proposal,
    selecting the FPN level based on box size (Lin et al. FPN level assignment):

        k = clip(k0 + floor(log2(sqrt(wh) / canonical_scale)), k_min, k_max)
        k0=4, k_min=2 (P2), k_max=5 (P5), canonical_scale=224  ->  index 0..3

    Sampling follows ROI Align with one sample per bin: the box is split into
    out_size x out_size bins and each output cell is the bilinear sample at
    its bin centre. Image pixel coordinates are mapped onto a feature map of
    width W_feat with the half-pixel convention

        x_feat = x_img * (W_feat / W_img) - 0.5

    which is exactly what ``F.grid_sample(..., align_corners=False)`` does for
    coordinates normalised as ``x_img / W_img * 2 - 1``.
    """

    def __init__(self, out_size: int = 7,
                 k0: int = 4, k_min: int = 2, k_max: int = 5,
                 canonical_scale: float = 224.0):
        super().__init__()
        self.out_size = out_size
        self.k0 = k0
        self.k_min = k_min
        self.k_max = k_max
        # Box scale (sqrt of area, in pixels) that is assigned to level k0.
        self.canonical_scale = canonical_scale

    def _assign_level(self, boxes: torch.Tensor) -> torch.Tensor:
        """Return the 0-indexed FPN level per box (0 = level k_min, e.g. P2).

        Args:
            boxes: (N, 4) boxes as (x1, y1, x2, y2) in pixels.
        Returns:
            (N,) long tensor with values in [0, k_max - k_min].
        """
        ws = boxes[:, 2] - boxes[:, 0]
        hs = boxes[:, 3] - boxes[:, 1]
        # sqrt(w*h) is the box scale; the clamp keeps log2 finite for
        # degenerate (zero/negative-area) boxes.
        scales = (ws * hs).clamp(min=1e-6).sqrt()
        levels = torch.floor(self.k0 + torch.log2(scales / self.canonical_scale)).long()
        return levels.clamp(self.k_min, self.k_max) - self.k_min

    def forward(self, feature_maps: List[torch.Tensor],
                proposals: List[torch.Tensor],
                image_size: Tuple[int, int]) -> torch.Tensor:
        """
        Args:
            feature_maps: [P2, P3, P4, P5] — (B, C, H_i, W_i) each. List index
                          must match the level index from ``_assign_level``;
                          proposals assigned to a level with no map keep
                          all-zero features.
            proposals:    list of (N_i, 4) per image, (x1, y1, x2, y2) in
                          pixel coords of the input image.
            image_size:   (H, W) of the image the proposal coords refer to.
        Returns:
            roi_features: (sum(N_i), C, out_size, out_size), with the dtype and
                          device of the feature maps, rows ordered image by
                          image. Empty (0, C, out_size, out_size) tensor when
                          there are no proposals at all.
        """
        H, W = image_size
        ref = feature_maps[0]
        channels = ref.shape[1]

        # Bin centres as fractions of the box extent: (i + 0.5) / out_size.
        frac = (torch.arange(self.out_size, device=ref.device) + 0.5) / self.out_size
        frac_y, frac_x = torch.meshgrid(frac, frac, indexing='ij')

        all_features = []
        for batch_idx, props in enumerate(proposals):
            if len(props) == 0:
                continue
            props = props.to(ref.device)
            levels = self._assign_level(props)
            feats = ref.new_zeros(len(props), channels, self.out_size, self.out_size)

            for lvl, fm in enumerate(feature_maps):
                mask = levels == lvl
                if not mask.any():
                    continue
                lvl_props = props[mask]
                n = len(lvl_props)

                # Normalise box coords to [-1, 1] relative to the image; with
                # align_corners=False this lands on x * (W_i / W) - 0.5 in fm.
                x1 = lvl_props[:, 0] / W * 2 - 1
                y1 = lvl_props[:, 1] / H * 2 - 1
                x2 = lvl_props[:, 2] / W * 2 - 1
                y2 = lvl_props[:, 3] / H * 2 - 1

                gx = x1[:, None, None] + (x2 - x1)[:, None, None] * frac_x[None]
                gy = y1[:, None, None] + (y2 - y1)[:, None, None] * frac_y[None]
                # grid_sample requires the grid to match the input's dtype.
                grid = torch.stack([gx, gy], dim=-1).to(fm)

                fm_exp = fm[batch_idx:batch_idx + 1].expand(n, -1, -1, -1)
                feats[mask] = F.grid_sample(fm_exp, grid, mode='bilinear',
                                            padding_mode='border',
                                            align_corners=False)

            all_features.append(feats)

        if not all_features:
            # torch.cat([]) raises; return a well-formed empty batch instead.
            return ref.new_zeros(0, channels, self.out_size, self.out_size)
        return torch.cat(all_features, dim=0)

In [None]:
class TwoMLPHead(nn.Module):
    """Box head: flatten each ROI crop, then apply two Linear + ReLU layers.

    Args:
        in_channels: flattened size of one ROI feature crop.
        fc_dim: width of both fully-connected layers (and of the output).
    """

    def __init__(self, in_channels: int = 256 * 7 * 7, fc_dim: int = 1024):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, fc_dim)
        self.fc2 = nn.Linear(fc_dim, fc_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (N, ...) ROI features to (N, fc_dim) ReLU-activated features."""
        hidden = torch.flatten(x, start_dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return hidden


class FastRCNNPredictor(nn.Module):
    """Final box-head layers: per-ROI class logits and per-class box deltas.

    Args:
        in_channels: size of the incoming per-ROI feature vector.
        num_classes: number of classes scored (the default 81 is presumably
            80 COCO classes + background — confirm against the dataset).
    """

    def __init__(self, in_channels: int = 1024, num_classes: int = 81):
        super().__init__()
        self.cls_score = nn.Linear(in_channels, num_classes)
        self.bbox_pred = nn.Linear(in_channels, num_classes * 4)
        # Small-std normal weights and zero biases; classifier initialised first.
        for layer, std in ((self.cls_score, 0.01), (self.bbox_pred, 0.001)):
            nn.init.normal_(layer.weight, std=std)
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor):
        """Return ``(class_logits, box_deltas)`` shaped (N, K) and (N, 4K)."""
        logits = self.cls_score(x)
        deltas = self.bbox_pred(x)
        return logits, deltas

In [None]:
# Smoke test with dummy feature maps and proposals
roi_align = ROIAlign(out_size=7)
mlp_head  = TwoMLPHead()
predictor = FastRCNNPredictor()

feat_maps = [
    torch.randn(1, 256, 200, 200),   # P2
    torch.randn(1, 256, 100, 100),   # P3
    torch.randn(1, 256,  50,  50),   # P4
    torch.randn(1, 256,  25,  25),   # P5
]
proposals = [torch.tensor([
    [ 50.,  50., 300., 300.],
    [100., 100., 400., 400.],
    [200., 200., 600., 600.],
])]

roi_feats  = roi_align(feat_maps, proposals, (800, 800))
box_feats  = mlp_head(roi_feats)
cls_logits, bbox_preds = predictor(box_feats)

print(f"ROI features: {roi_feats.shape}")   # [3, 256, 7, 7]
print(f"Box features: {box_feats.shape}")   # [3, 1024]
print(f"Class logits: {cls_logits.shape}")  # [3, 81]
print(f"Box preds:    {bbox_preds.shape}")  # [3, 324]

# Check the expected shapes rather than only printing them, so a broken
# head fails loudly when the notebook is re-run.
num_rois    = sum(len(p) for p in proposals)
num_classes = predictor.cls_score.out_features
out_size    = roi_align.out_size
assert roi_feats.shape  == (num_rois, feat_maps[0].shape[1], out_size, out_size)
assert box_feats.shape  == (num_rois, mlp_head.fc2.out_features)
assert cls_logits.shape == (num_rois, num_classes)
assert bbox_preds.shape == (num_rois, num_classes * 4)

In [None]:
# Inspection: mean-channel activation of the first few ROI crops
from pathlib import Path

n_show = min(3, roi_feats.shape[0])
crop_size = roi_feats.shape[-1]
if n_show == 0:
    print("No ROI crops to plot.")
else:
    # squeeze=False keeps `axes` 2-D even when only one ROI is shown.
    fig, axes = plt.subplots(1, n_show, figsize=(4 * n_show, 4), squeeze=False)
    for i, ax in enumerate(axes[0]):
        # .cpu() so this also works when the crops live on a GPU.
        crop = roi_feats[i].mean(dim=0).detach().cpu().numpy()
        im = ax.imshow(crop, cmap='viridis')
        ax.set_title(f'ROI {i} — {crop_size}×{crop_size} '
                     f'(mean over {roi_feats.shape[1]} ch)')
        ax.axis('off')
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    plt.suptitle(f'ROI Align {crop_size}×{crop_size} Crops (dummy feature maps)')
    plt.tight_layout()
    # savefig does not create missing directories.
    Path('images').mkdir(parents=True, exist_ok=True)
    plt.savefig('images/roi_crops.png', dpi=100, bbox_inches='tight')
    plt.show()

In [None]:
# Mask R-CNN extension point: 14x14 ROI Align
mask_roi_align = ROIAlign(out_size=14)
mask_roi_feats = mask_roi_align(feat_maps, proposals, (800, 800))
print(f"Mask ROI features (14x14): {mask_roi_feats.shape}")
# One 14x14 crop per proposal, with the feature maps' channel count
# (expected: [3, 256, 14, 14] for the smoke-test inputs).
assert mask_roi_feats.shape == (
    sum(len(p) for p in proposals), feat_maps[0].shape[1], 14, 14)
print("Extension point ready for Mask RCNN mask head (notebook 07).")