# Faster RCNN Inference and Evaluation

*Notebook 6 of 6 in the Faster RCNN from-scratch series*

We load the checkpoint saved in notebook 05, run inference on COCO validation
images streamed from Hugging Face, and visualise the resulting proposals and
detections. Because the checkpoint was trained for only 5 steps (a demo run),
its predictions are essentially untrained noise — this notebook focuses on the
**inference pipeline**, not on accuracy.

Topics covered:
- Loading and verifying a checkpoint
- Running an eval-mode forward pass (`model.eval()`): proposal generation followed by post-processing
- Visualising class-agnostic proposals and final detections
- Measuring per-image inference latency

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.checkpoint import checkpoint as grad_ckpt
from typing import List, Tuple, Optional
import time, os

# Network input side length in pixels; must match the value used in notebook 05.
IMG_SIZE = 400
# 80 COCO object categories plus a background slot at index 0.
NUM_CLASSES = 81

if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Device: {DEVICE}")

# Per-channel ImageNet statistics shaped (3, 1, 1) so they broadcast over a (3, H, W) image.
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

# Class names indexed by label id: index 0 is background, ids 1..80 are the
# COCO object categories in contiguous order.
COCO_NAMES = ['__background__'] + [
    # 1-10
    'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light',
    # 11-20
    'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
    'cat', 'dog', 'horse', 'sheep', 'cow',
    # 21-30
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
    'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    # 31-40
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    # 41-50
    'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange',
    # 51-60
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',
    'cake', 'chair', 'couch', 'potted plant', 'bed',
    # 61-70
    'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 71-80
    'toaster', 'sink', 'refrigerator', 'book', 'clock',
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',
]

In [None]:
# ─── Re-define model components (self-contained) ──────────────────────────────

class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand (x4).

    Args:
        in_ch: channels of the input tensor.
        out_ch: width of the inner convolutions; the block outputs ``out_ch * 4``.
        stride: stride of the 3x3 conv (spatial downsampling factor).
        downsample: optional module applied to the input on the shortcut path.
            It is needed whenever the main path changes shape
            (``stride != 1`` or ``in_ch != out_ch * 4``); ``None`` means identity.
    """
    expansion = 4
    def __init__(self, in_ch, out_ch, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.conv3 = nn.Conv2d(out_ch, out_ch * 4, 1, bias=False)
        self.bn3   = nn.BatchNorm2d(out_ch * 4)
        self.downsample = downsample
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # Explicit None check: module truthiness depends on __len__ (e.g. an
        # empty nn.Sequential would be falsy), which is not what we mean here.
        identity = x if self.downsample is None else self.downsample(x)
        return F.relu(out + identity)

class ResNet50(nn.Module):
    """ResNet-50 trunk returning the C2..C5 feature maps used by the FPN.

    Outputs have 256/512/1024/2048 channels at strides 4/8/16/32 relative to
    the input (stem conv + max-pool give /4, layers 2-4 each halve again).
    """
    def __init__(self):
        super().__init__()
        self.stem   = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1))
        self.layer1 = self._make(  64,  64, 3, 1)
        self.layer2 = self._make( 256, 128, 4, 2)
        self.layer3 = self._make( 512, 256, 6, 2)
        self.layer4 = self._make(1024, 512, 3, 2)
    def _make(self, in_ch, out_ch, blocks, stride):
        """Build one stage of ``blocks`` bottlenecks; only the first may stride/project."""
        ds = None
        if stride != 1 or in_ch != out_ch*4:
            ds = nn.Sequential(nn.Conv2d(in_ch, out_ch*4, 1, stride=stride, bias=False),
                               nn.BatchNorm2d(out_ch*4))
        layers = [Bottleneck(in_ch, out_ch, stride, ds)]
        for _ in range(1, blocks): layers.append(Bottleneck(out_ch*4, out_ch))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.stem(x); c2 = self.layer1(x); c3 = self.layer2(c2)
        if torch.is_grad_enabled():
            # Activation checkpointing trades compute for memory during
            # backprop; it is only useful when a graph is being recorded.
            c4 = grad_ckpt(self.layer3, c3, use_reentrant=False)
            c5 = grad_ckpt(self.layer4, c4, use_reentrant=False)
        else:
            # Inference (e.g. under torch.no_grad): plain calls, no checkpoint overhead.
            c4 = self.layer3(c3)
            c5 = self.layer4(c4)
        return c2, c3, c4, c5

class FPN(nn.Module):
    """Feature Pyramid Network neck.

    Consumes backbone maps ``(c2, c3, c4, c5)`` and returns five maps
    ``[P2, P3, P4, P5, P6]`` with ``out_channels`` channels each. 1x1 lateral
    convs are merged top-down with nearest-neighbour upsampling, and every
    merged level is smoothed by a 3x3 conv. P6 is the smoothed P5 subsampled
    with stride 2 (a kernel-1 max-pool).
    """
    def __init__(self, in_channels=(256,512,1024,2048), out_channels=256):
        super().__init__()
        laterals = [nn.Conv2d(ch, out_channels, 1) for ch in in_channels]
        smoothers = [nn.Conv2d(out_channels, out_channels, 3, padding=1) for _ in in_channels]
        self.lateral = nn.ModuleList(laterals)
        self.output = nn.ModuleList(smoothers)
        self.p6 = nn.MaxPool2d(1, stride=2)

    def forward(self, features):
        c2, c3, c4, c5 = features
        top = self.lateral[3](c5)
        merged = [top]
        # Walk down the pyramid: lateral(c) + upsample(level above).
        for idx, c in ((2, c4), (1, c3), (0, c2)):
            top = self.lateral[idx](c) + F.interpolate(top, size=c.shape[-2:], mode='nearest')
            merged.append(top)
        merged.reverse()  # -> [p2, p3, p4, p5]
        pyramid = [smooth(level) for smooth, level in zip(self.output, merged)]
        pyramid.append(self.p6(pyramid[-1]))
        return pyramid

def decode_boxes(anchors, deltas):
    """Apply ``(dx, dy, dw, dh)`` regression deltas to xyxy reference boxes.

    Centre offsets are scaled by the reference width/height; the log-size
    deltas are clamped at 4.0 before ``exp`` so a large prediction cannot blow
    up the box size.

    Args:
        anchors: (N, 4) reference boxes as (x1, y1, x2, y2).
        deltas: (N, 4) predicted offsets as (dx, dy, dw, dh).
    Returns:
        (N, 4) decoded boxes as (x1, y1, x2, y2).
    """
    widths = anchors[:, 2] - anchors[:, 0]
    heights = anchors[:, 3] - anchors[:, 1]
    ctr_x = anchors[:, 0] + 0.5 * widths
    ctr_y = anchors[:, 1] + 0.5 * heights

    new_cx = deltas[:, 0] * widths + ctr_x
    new_cy = deltas[:, 1] * heights + ctr_y
    new_w = torch.exp(deltas[:, 2].clamp(max=4.)) * widths
    new_h = torch.exp(deltas[:, 3].clamp(max=4.)) * heights

    half_w = 0.5 * new_w
    half_h = 0.5 * new_h
    corners = [new_cx - half_w, new_cy - half_h, new_cx + half_w, new_cy + half_h]
    return torch.stack(corners, dim=1)

class AnchorGenerator(nn.Module):
    """Dense anchors for every pyramid level.

    Level ``i`` uses anchor size ``sizes[i]`` and stride ``strides[i]``. Every
    feature-map cell gets ``len(ratios)`` anchors centred on the cell centre
    ``((col + 0.5) * stride, (row + 0.5) * stride)``. Per level the output is
    ordered (row, col, ratio), and levels are concatenated in input order.
    ``img_sz`` is accepted for interface compatibility but not used.
    """
    def __init__(self,sizes=(32,64,128,256,512),ratios=(0.5,1.,2.),strides=(4,8,16,32,64)):
        super().__init__()
        self.sizes = sizes
        self.ratios = ratios
        self.strides = strides

    def _base(self, sz):
        """Zero-centred xyxy templates of scale ``sz``; ratio r gives width/height = r."""
        templates = []
        for r in self.ratios:
            half_w = sz * (r ** .5) / 2
            half_h = sz / (r ** .5) / 2
            templates.append([-half_w, -half_h, half_w, half_h])
        return torch.tensor(templates, dtype=torch.float32)

    def forward(self, fmaps, img_sz):
        per_level = []
        for fmap, size, stride in zip(fmaps, self.sizes, self.strides):
            _, _, fh, fw = fmap.shape
            dev = fmap.device
            templates = self._base(size).to(dev)
            xs = (torch.arange(fw, device=dev) + 0.5) * stride
            ys = (torch.arange(fh, device=dev) + 0.5) * stride
            grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
            centres = torch.stack([grid_x, grid_y, grid_x, grid_y], dim=-1).reshape(-1, 4)
            level_anchors = centres.unsqueeze(1) + templates.unsqueeze(0)
            per_level.append(level_anchors.reshape(-1, 4))
        return torch.cat(per_level, 0)

class RPNHead(nn.Module):
    """Shared RPN head applied independently to every pyramid level.

    A 3x3 conv + ReLU feeds two sibling 1x1 convs: ``cls`` emits ``k``
    objectness logits per location and ``box`` emits ``4 * k`` box deltas
    (``k`` = anchors per location). All three convs start from N(0, 0.01)
    weights and zero bias.
    """
    def __init__(self,in_ch=256,k=3):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, in_ch, 3, padding=1)
        self.cls = nn.Conv2d(in_ch, k, 1)
        self.box = nn.Conv2d(in_ch, k * 4, 1)
        for layer in (self.conv, self.cls, self.box):
            nn.init.normal_(layer.weight, std=0.01)
            nn.init.zeros_(layer.bias)

    def forward(self,features):
        logits, deltas = [], []
        for feat in features:
            hidden = F.relu(self.conv(feat))
            logits.append(self.cls(hidden))
            deltas.append(self.box(hidden))
        return logits, deltas

class RegionProposalNetwork(nn.Module):
    """Region Proposal Network (inference path).

    Scores every anchor with ``head``, decodes the anchors with the predicted
    deltas, then per image clips, size-filters, keeps the top ``pre_nms`` by
    objectness, applies NMS at ``nms_thr`` and keeps at most ``post_nms``.

    Args:
        head: module returning per-level (objectness logits, box deltas).
        anchor_gen: module returning all anchors for the given feature maps.
        pre_nms: max proposals kept (by score) before NMS.
        post_nms: max proposals kept after NMS.
        nms_thr: IoU above which a lower-scoring proposal is suppressed.
        min_sz: minimum width and height, in input pixels, of a clipped proposal.
    """
    def __init__(self,head,anchor_gen,pre_nms=2000,post_nms=1000,nms_thr=0.7,min_sz=16):
        super().__init__()
        self.head=head; self.anchor_gen=anchor_gen
        self.pre_nms=pre_nms; self.post_nms=post_nms; self.nms_thr=nms_thr; self.min_sz=min_sz
    def _filter(self,props,scores,img_size):
        """Clip, size-filter, top-k and NMS one image's proposals.

        Returns ``(boxes, scores)`` in descending score order. The input
        tensors are left unmodified.
        """
        H,W=img_size
        # Clip a copy: in forward the argument is a detached alias sharing
        # storage with the decoded boxes, so in-place clipping would mutate it.
        props=props.clone()
        props[:,[0,2]]=props[:,[0,2]].clamp(0,W); props[:,[1,3]]=props[:,[1,3]].clamp(0,H)
        keep=(props[:,2]-props[:,0]>=self.min_sz)&(props[:,3]-props[:,1]>=self.min_sz)
        props,scores=props[keep],scores[keep]
        scores,order=scores.topk(min(self.pre_nms,len(scores)))
        props=props[order]
        keep=self._nms(props,scores,self.nms_thr)[:self.post_nms]
        return props[keep],scores[keep]
    @staticmethod
    def _nms(boxes,scores,thr):
        """Greedy NMS over xyxy ``boxes``.

        Returns the kept indices (highest score first) as an int64 tensor on
        ``boxes.device``; an empty input yields an empty index tensor.
        """
        x1,y1,x2,y2=boxes.unbind(1); areas=(x2-x1)*(y2-y1)
        order=scores.argsort(descending=True); keep=[]
        while order.numel()>0:
            i=order[0].item(); keep.append(i)
            if order.numel()==1: break
            rest=order[1:]
            xx1=x1[rest].clamp(min=x1[i]); yy1=y1[rest].clamp(min=y1[i])
            xx2=x2[rest].clamp(max=x2[i]); yy2=y2[rest].clamp(max=y2[i])
            inter=(xx2-xx1).clamp(0)*(yy2-yy1).clamp(0)
            iou=inter/(areas[i]+areas[rest]-inter).clamp(1e-6)
            order=rest[iou<=thr]
        # Keep indices on the same device as the boxes they index.
        return torch.tensor(keep,dtype=torch.long,device=boxes.device)
    def forward(self,features,image_size,targets=None):
        """Return ``(list of per-image proposal tensors, {})``; ``targets`` is unused here."""
        cls_o,box_o=self.head(features); anchors=self.anchor_gen(features,image_size)
        # Flatten every level to (B, A_total) scores / (B, A_total, 4) deltas,
        # matching the (row, col, anchor) order produced by the anchor generator.
        all_scores=torch.cat([c.permute(0,2,3,1).reshape(c.shape[0],-1) for c in cls_o],1)
        all_deltas=torch.cat([b.permute(0,2,3,1).reshape(b.shape[0],-1,4) for b in box_o],1)
        props=[]
        for i in range(all_scores.shape[0]):
            sc=all_scores[i].sigmoid(); pr=decode_boxes(anchors,all_deltas[i])
            pr,_=self._filter(pr.detach(),sc.detach(),image_size); props.append(pr)
        return props,{}

class ROIAlign(nn.Module):
    """Multi-level RoI feature pooling implemented with bilinear ``grid_sample``.

    Each proposal is assigned to one pyramid level by its scale, and an
    ``out_size x out_size`` grid of bilinear samples is read from that level.
    One sample is taken per output cell (no averaging over sub-samples).

    Args:
        out_size: side length of the pooled feature grid.
        k0: level assigned to a box whose sqrt(area) is 224 px.
        k_min, k_max: range the computed level is clamped to; the level index
            returned by ``_level`` is shifted by ``-k_min`` so it indexes ``fmaps``.
    """
    def __init__(self,out_size=7,k0=4,k_min=2,k_max=5):
        super().__init__()
        self.out_size=out_size; self.k0=k0; self.k_min=k_min; self.k_max=k_max
    def _level(self,boxes):
        """Return, per xyxy box, the index into ``fmaps`` of the level to pool from.

        level = floor(k0 + log2(sqrt(area) / 224)), clamped to [k_min, k_max]
        and then shifted by -k_min (0..3 with the defaults).
        """
        # Despite the name, ``areas`` holds sqrt(area), i.e. the box scale;
        # the clamp keeps log2 finite for degenerate (zero-area) boxes.
        areas=((boxes[:,2]-boxes[:,0])*(boxes[:,3]-boxes[:,1])).clamp(1e-6).sqrt()
        return torch.floor(self.k0+torch.log2(areas/224.)).long().clamp(self.k_min,self.k_max)-self.k_min
    def forward(self,fmaps,proposals,image_size):
        """Pool features for every proposal of every image in the batch.

        Args:
            fmaps: list of batched feature maps indexed as ``fm[bi]``; only the
                first four are sampled (FasterRCNN passes P2..P5). The channel
                count is taken from ``fmaps[0]``.
            proposals: per-image (n_i, 4) boxes, xyxy in input-pixel coordinates.
            image_size: (H, W) of the network input, used to normalise boxes.
        Returns:
            (sum n_i, C, out_size, out_size) features, concatenated image by
            image in proposal order; images without proposals add no rows.
        """
        H,W=image_size; all_feats=[]
        for bi,props in enumerate(proposals):
            if len(props)==0: continue
            levels=self._level(props)
            feats=torch.zeros(len(props),fmaps[0].shape[1],self.out_size,self.out_size,device=props.device)
            for lvl,fm in enumerate(fmaps[:4]):
                mask=levels==lvl
                if not mask.any(): continue
                lp=props[mask]; n=len(lp)
                # Map box corners from pixel coordinates to grid_sample's [-1, 1] range.
                x1=lp[:,0]/W*2-1; y1=lp[:,1]/H*2-1; x2=lp[:,2]/W*2-1; y2=lp[:,3]/H*2-1
                # out_size evenly spaced fractions spanning each box edge to edge (inclusive).
                gx=torch.linspace(0,1,self.out_size,device=props.device)
                gy=torch.linspace(0,1,self.out_size,device=props.device)
                gy_g,gx_g=torch.meshgrid(gy,gx,indexing='ij')
                gx_g=x1[:,None,None]+(x2-x1)[:,None,None]*gx_g[None]
                gy_g=y1[:,None,None]+(y2-y1)[:,None,None]*gy_g[None]
                grid=torch.stack([gx_g,gy_g],dim=-1)
                # NOTE(review): with align_corners=True, -1/+1 address the centres of
                # the corner feature pixels, whereas x/W*2-1 maps the image borders to
                # -1/+1, so samples may be offset by up to half a feature cell. Left
                # unchanged so sampling matches how the loaded weights were trained
                # (confirm against notebook 05). padding_mode='border' repeats edge
                # values for any sample that falls outside the map.
                crops=F.grid_sample(fm[bi:bi+1].expand(n,-1,-1,-1),grid,
                                    align_corners=True,mode='bilinear',padding_mode='border')
                # Scatter back so rows stay in the original proposal order.
                feats[mask]=crops
            all_feats.append(feats)
        if not all_feats:
            # No image had proposals: return an empty batch with the expected trailing shape.
            return torch.zeros(0,fmaps[0].shape[1],self.out_size,self.out_size,device=fmaps[0].device)
        return torch.cat(all_feats,0)

class TwoMLPHead(nn.Module):
    """Box head: flatten each RoI's features, then two fully-connected + ReLU layers.

    Args:
        in_channels: flattened RoI feature size (default 256 * 7 * 7).
        fc_dim: width of both hidden layers.
    """
    def __init__(self,in_channels=256*7*7,fc_dim=1024):
        super().__init__()
        self.fc1 = nn.Linear(in_channels, fc_dim)
        self.fc2 = nn.Linear(fc_dim, fc_dim)

    def forward(self,x):
        hidden = F.relu(self.fc1(torch.flatten(x, 1)))
        return F.relu(self.fc2(hidden))

class FastRCNNPredictor(nn.Module):
    """Final Fast R-CNN layers: per-class logits and class-specific box deltas.

    ``cls`` maps features to ``num_classes`` logits; ``box`` maps them to
    ``num_classes * 4`` deltas (one xyxy-delta set per class). Weights start
    from small Gaussians (std 0.01 / 0.001) with zero biases.
    """
    def __init__(self,in_channels=1024,num_classes=81):
        super().__init__()
        self.cls = nn.Linear(in_channels, num_classes)
        self.box = nn.Linear(in_channels, num_classes * 4)
        for layer, std in ((self.cls, 0.01), (self.box, 0.001)):
            nn.init.normal_(layer.weight, std=std)
            nn.init.zeros_(layer.bias)

    def forward(self,x):
        class_logits = self.cls(x)
        box_deltas = self.box(x)
        return class_logits, box_deltas

class FasterRCNN(nn.Module):
    """Faster R-CNN: ResNet-50 + FPN backbone, RPN, RoI box head (inference path).

    ``forward`` returns ``(detections, proposals)``. ``detections`` holds one
    dict per image with ``'boxes'`` (N, 4; xyxy in input-pixel coordinates),
    ``'scores'`` (N,) and ``'labels'`` (N,) int64 in ``1..num_classes-1``.
    ``proposals`` is the per-image list of RPN boxes.
    """
    # Post-processing knobs: per-class score cut-off, per-class NMS IoU
    # threshold, and the maximum number of detections kept per image.
    SCORE_THR=0.05; NMS_THR=0.5; MAX_DETS=100
    def __init__(self,num_classes=81):
        super().__init__()
        self.num_classes=num_classes
        self.backbone=ResNet50(); self.fpn=FPN()
        self.rpn=RegionProposalNetwork(RPNHead(),AnchorGenerator())
        self.roi_align=ROIAlign(out_size=7)
        self.box_head=TwoMLPHead(); self.predictor=FastRCNNPredictor(num_classes=num_classes)
    @staticmethod
    def _empty_result(device):
        """Detections dict with zero entries, allocated on *device*."""
        return {'boxes':torch.zeros(0,4,device=device),
                'scores':torch.zeros(0,device=device),
                'labels':torch.zeros(0,dtype=torch.long,device=device)}
    def _postprocess(self,cls_logits,bbox_preds,proposals_list,image_size):
        """Turn per-RoI class logits and box deltas into per-image detections.

        Args:
            cls_logits: (R, C) logits for all RoIs of the batch, concatenated
                image by image in the order of ``proposals_list``.
            bbox_preds: (R, C*4) class-specific box deltas, same RoI order.
            proposals_list: per-image (n_i, 4) proposal boxes.
            image_size: (H, W) used to clip the decoded boxes.
        Returns:
            One detections dict per image (see the class docstring); all
            tensors live on the proposals' device, including empty results.
        """
        H,W=image_size; C=self.num_classes; results=[]; offset=0
        for props in proposals_list:
            n=len(props)
            if n==0:
                # Proposal-less images contribute no RoI rows, so offset is unchanged.
                results.append(self._empty_result(props.device)); continue
            logits=cls_logits[offset:offset+n]
            deltas=bbox_preds[offset:offset+n].view(n,C,4)   # reshape once, not per class
            offset+=n
            scores=F.softmax(logits,-1)
            all_b,all_s,all_l=[],[],[]
            for ci in range(1,C):   # index 0 is background
                mask=scores[:,ci]>self.SCORE_THR
                if not mask.any(): continue
                # Decode only the RoIs that pass this class's score threshold.
                boxes=decode_boxes(props[mask],deltas[mask,ci,:])
                boxes[:,[0,2]]=boxes[:,[0,2]].clamp(0,W); boxes[:,[1,3]]=boxes[:,[1,3]].clamp(0,H)
                sc=scores[mask,ci]
                keep=RegionProposalNetwork._nms(boxes,sc,self.NMS_THR)
                all_b.append(boxes[keep]); all_s.append(sc[keep])
                all_l.append(torch.full((len(keep),),ci,dtype=torch.long,device=props.device))
            if not all_b:
                results.append(self._empty_result(props.device)); continue
            b=torch.cat(all_b); s=torch.cat(all_s); l=torch.cat(all_l)
            top=s.argsort(descending=True)[:self.MAX_DETS]
            results.append({'boxes':b[top],'scores':s[top],'labels':l[top]})
        return results
    def forward(self,images,targets=None):
        """Run detection on a normalised (B, 3, H, W) batch; ``targets`` is unused here."""
        img_sz=(images.shape[2],images.shape[3])
        feats=self.backbone(images); fpn_fs=self.fpn(feats)
        props,_=self.rpn(fpn_fs,img_sz)
        # P6 is only used for anchors/proposals; RoI pooling reads P2..P5.
        roi_feats=self.roi_align(fpn_fs[:4],props,img_sz)
        box_feats=self.box_head(roi_feats)
        cls_logits,bbox_preds=self.predictor(box_feats)
        return self._postprocess(cls_logits,bbox_preds,props,img_sz), props

print("Model architecture defined.")

In [None]:
# ─── Load checkpoint ──────────────────────────────────────────────────────────

CKPT_PATH = 'checkpoints/faster_rcnn_demo.pth'
# Explicit raise rather than `assert`: asserts are stripped under `python -O`,
# and FileNotFoundError is the error type callers expect for a missing file.
if not os.path.exists(CKPT_PATH):
    raise FileNotFoundError(f"Checkpoint not found at {CKPT_PATH}. Run notebook 05 first.")

model = FasterRCNN(num_classes=NUM_CLASSES).to(DEVICE)
# NOTE: torch.load unpickles the file — only load checkpoints you produced yourself.
ckpt  = torch.load(CKPT_PATH, map_location=DEVICE)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()   # BatchNorm layers switch to their running statistics

print(f"Checkpoint loaded: {CKPT_PATH}")
print(f"  Trained for {ckpt['steps_trained']} steps")
print(f"  Final losses: { {k: f'{v:.4f}' for k,v in ckpt['final_losses'].items()} }")
total = sum(p.numel() for p in model.parameters())
print(f"  Parameters: {total/1e6:.1f}M")

In [None]:
# ─── Inference on 4 validation images ─────────────────────────────────────────

val_ds = load_dataset('detection-datasets/coco', split='val', streaming=True)

images_pil, results_list, latencies = [], [], []
NUM_IMAGES = 4

with torch.no_grad():
    # Warm-up pass on a dummy input so one-off start-up costs of the first
    # forward (e.g. CUDA kernel / allocator initialisation) are not charged to
    # image 1's latency.
    model(torch.zeros(1, 3, IMG_SIZE, IMG_SIZE, device=DEVICE))

    for i, sample in enumerate(val_ds):
        if i >= NUM_IMAGES:
            break
        img_pil = sample['image'].convert('RGB')
        images_pil.append(img_pil)
        t = ((TF.to_tensor(img_pil.resize((IMG_SIZE, IMG_SIZE))) - IMAGENET_MEAN) / IMAGENET_STD)
        t = t.unsqueeze(0).to(DEVICE)

        # CUDA work is queued asynchronously: drain pending work (warm-up,
        # host-to-device copy) before starting the clock, and again before stopping it.
        if DEVICE.type == 'cuda': torch.cuda.synchronize()
        t0 = time.perf_counter()
        dets, proposals = model(t)
        if DEVICE.type == 'cuda': torch.cuda.synchronize()
        latencies.append((time.perf_counter() - t0) * 1000)
        results_list.append(dets[0])

if latencies:
    print(f"Mean latency: {sum(latencies)/len(latencies):.1f} ms  ({IMG_SIZE}x{IMG_SIZE} input)")
else:
    print("No validation images were processed.")
print("Detections per image:", [len(r['boxes']) for r in results_list])
print("(Detections are random — model trained only 5 steps)")

In [None]:
# ─── Visualise proposals + detections ─────────────────────────────────────────

TOP_K_PROPS = 30
# Reuse the images decoded by the inference cell instead of re-streaming the
# dataset, so column i is guaranteed to show the image behind results_list[i].
n_cols = max(len(images_pil), 1)
# squeeze=False keeps `axes` 2-D even when only a single image is plotted.
fig, axes = plt.subplots(2, n_cols, figsize=(5*n_cols, 10), squeeze=False)

with torch.no_grad():
    for col, (img_pil, dets) in enumerate(zip(images_pil, results_list)):
        img_res = img_pil.resize((IMG_SIZE, IMG_SIZE))
        t = ((TF.to_tensor(img_res) - IMAGENET_MEAN) / IMAGENET_STD).unsqueeze(0).to(DEVICE)

        # Re-run the model only to recover proposals (the inference cell keeps
        # detections only). Proposals come back in descending score order.
        _, proposals = model(t)
        props = proposals[0].cpu()[:TOP_K_PROPS]

        # Row 0: RPN proposals
        ax = axes[0][col]
        ax.imshow(img_res); ax.axis('off')
        ax.set_title(f'Image {col+1}: top-{TOP_K_PROPS} proposals', fontsize=9)
        for box in props.tolist():
            x1,y1,x2,y2=box
            ax.add_patch(patches.Rectangle((x1,y1),x2-x1,y2-y1,
                                            linewidth=1,edgecolor='cyan',facecolor='none'))

        # Row 1: final detections
        ax = axes[1][col]
        ax.imshow(img_res); ax.axis('off')
        n_det = len(dets['boxes'])
        ax.set_title(f'Image {col+1}: {n_det} detections', fontsize=9)
        for box,score,label in zip(dets['boxes'].tolist(),
                                   dets['scores'].tolist(),
                                   dets['labels'].tolist()):
            x1,y1,x2,y2=box
            cls_name = COCO_NAMES[label] if label < len(COCO_NAMES) else str(label)
            ax.add_patch(patches.Rectangle((x1,y1),x2-x1,y2-y1,
                                            linewidth=1.5,edgecolor='red',facecolor='none'))
            ax.text(x1,y1-2,f'{cls_name} {score:.2f}',
                    color='white',fontsize=6,backgroundcolor='red')

plt.suptitle('Row 1: RPN proposals  |  Row 2: Final detections (5-step model)', y=1.01)
plt.tight_layout()
os.makedirs('images', exist_ok=True)
plt.savefig('images/inference_results.png', dpi=100, bbox_inches='tight')
plt.show()

In [None]:
# ─── Latency bar chart ─────────────────────────────────────────────────────────

fig, ax = plt.subplots(figsize=(7, 4))
# One bar per measured image: use len(latencies) rather than NUM_IMAGES so the
# x positions and bar heights always have the same length.
ax.bar(range(1, len(latencies)+1), latencies, color='steelblue', edgecolor='white')
if latencies:
    ax.axhline(sum(latencies)/len(latencies), color='red', linestyle='--', label='mean')
    ax.legend()
ax.set_xlabel('Image index'); ax.set_ylabel('Latency (ms)')
ax.set_title(f'Per-image inference latency on {str(DEVICE).upper()} ({IMG_SIZE}x{IMG_SIZE})')
plt.tight_layout()
os.makedirs('images', exist_ok=True)   # this cell may be run without the previous one
plt.savefig('images/latency.png', dpi=100, bbox_inches='tight')
plt.show()
print("\nSeries complete. Faster RCNN from scratch — all 6 notebooks executed.")