# Modeling


For best results, these techniques are often applied sequentially:
* Distillation: Train a lightweight Student model using a powerful Teacher model.
* Pruning: Further prune the Student model to reduce computation.
* Quantization: Apply INT8 quantization to the final pruned Student model for maximum deployment efficiency.
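
To make the ordering of the last two steps concrete, here is a minimal pure-Python sketch on a toy weight vector. It assumes unstructured magnitude pruning and symmetric per-tensor INT8 quantization; the helper names (`magnitude_prune`, `quantize_int8`, `dequantize`) and the sample weights are illustrative and not part of this project, which would normally rely on a framework's pruning and quantization tooling instead.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    k = int(len(weights) * sparsity)
    # Indices of the k smallest-magnitude weights.
    drop = set(sorted(range(len(weights)), key=lambda i: abs(weights[i]))[:k])
    return [0.0 if i in drop else w for i, w in enumerate(weights)]


def quantize_int8(weights):
    """Symmetric per-tensor quantization: q = round(w / scale), clamped to [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q, scale):
    """Map INT8 codes back to approximate float weights."""
    return [v * scale for v in q]


# Toy "student" weights (hypothetical values).
weights = [0.5, -0.02, 0.01, -1.27, 0.3, 0.04]

pruned = magnitude_prune(weights, sparsity=0.5)  # step 2: prune
q, scale = quantize_int8(pruned)                 # step 3: quantize the pruned weights
restored = dequantize(q, scale)

print(pruned)
print(q, scale)
```

Quantizing after pruning means the scale is computed from the weights that actually survive, and the pruned entries map exactly to the integer code 0.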

* * * * *

### ✅ **Lightweight Models**

Designed for mobile or resource-constrained environments:

-   **`fasterrcnn_mobilenet_v3_large_fpn`**
-   **`fasterrcnn_mobilenet_v3_large_320_fpn`** (lower resolution variant)
-   **`ssdlite320_mobilenet_v3_large`** (SSD Lite optimized for mobile)
-   **`ssd300_vgg16`** (still heavier than MobileNet but lighter than ResNet-based Faster R-CNN)\
    These models trade off some accuracy for speed and low memory footprint. Ideal for real-time inference on edge devices. [[deepwiki.com]](https://deepwiki.com/pytorch/vision/3.2-detection-models)

* * * * *

### ⚖️ **Moderate Models**

Balanced between accuracy and efficiency:

-   **`retinanet_resnet50_fpn`**
-   **`fcos_resnet50_fpn`**\
    These one-stage detectors are faster than two-stage models like Faster R-CNN but heavier than MobileNet-based SSD. Good for scenarios needing decent accuracy with reasonable latency. [[deepwiki.com]](https://deepwiki.com/pytorch/vision/3.2-detection-models)

* * * * *

### 🔍 **Heavy Models**

High accuracy, large backbone, and higher compute requirements:

-   **`fasterrcnn_resnet50_fpn`**
-   **`fasterrcnn_resnet50_fpn_v2`** (improved accuracy)
-   **`maskrcnn_resnet50_fpn`** (adds segmentation head → even heavier)
-   **`keypointrcnn_resnet50_fpn`** (adds keypoint detection head)\
    These two-stage detectors are best for high-accuracy tasks but require GPUs for real-time performance. [[deepwiki.com]](https://deepwiki.com/pytorch/vision/3.2-detection-models)

* * * * *

### **Rule of Thumb**

-   **MobileNet-based → Lightweight**
-   **ResNet-50-based → Moderate to Heavy**
-   **Extra heads (Mask/Keypoint) → Heavy**
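
The rule of thumb can be written down as a small name-based heuristic. This is only a sketch: `size_tier` is a hypothetical helper, it looks at the torchvision builder name alone, and it returns `"unclassified"` for models the rule does not cover (for example `ssd300_vgg16`, which the lists above place in the lightweight group despite its VGG backbone).

```python
def size_tier(model_name: str) -> str:
    """Guess a size tier from a torchvision detection builder name (heuristic only)."""
    name = model_name.lower()
    if "mobilenet" in name:
        return "lightweight"            # MobileNet-based
    if name.startswith(("maskrcnn", "keypointrcnn")):
        return "heavy"                  # extra mask/keypoint head
    if "resnet50" in name:
        # Two-stage Faster R-CNN is the heavy end; one-stage detectors are moderate.
        return "heavy" if name.startswith("fasterrcnn") else "moderate"
    return "unclassified"


for candidate in ("ssdlite320_mobilenet_v3_large", "fcos_resnet50_fpn", "maskrcnn_resnet50_fpn"):
    print(candidate, "->", size_tier(candidate))
```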

In [None]:
import sys; sys.path.append("../src")
sys.path.append("src")
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
import sys
sys.path.extend(["../src", "../../src"])  # list.append() accepts only one argument
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import json
from functools import partial
from pathlib import Path

import numpy

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import ConcatDataset
from torch.utils.data import Dataset, random_split, DataLoader

from contextlib import redirect_stdout
print(torch.cuda.is_available())

In [None]:
from src.fpn_feature_hooks import FeatureHook, RPNHeadHook, FPNAdapter
from src.student_model_builder import build_student
from src.train import train_one_epoch_kd
from src.utils import Tee
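
`FeatureHook` and `RPNHeadHook` live in `src/fpn_feature_hooks.py`, which is not shown in this notebook. As a rough idea of what such a helper can look like, here is a minimal sketch of a callable that matches the signature PyTorch uses for forward hooks, `hook(module, inputs, output)`, and simply remembers the latest output. The class name `FeatureHookSketch` and the fake feature dictionary are assumptions for illustration; the real classes may store or post-process outputs differently.

```python
class FeatureHookSketch:
    """Callable that caches the most recent output of the module it is attached to."""

    def __init__(self):
        self.features = None

    def __call__(self, module, inputs, output):
        # Called by PyTorch after module.forward(); returning None leaves the output unchanged.
        self.features = output


# Stand-in for a forward pass: a hooked FPN backbone would pass a dict of feature maps here.
hook = FeatureHookSketch()
fake_fpn_output = {"0": [[0.1, 0.2]], "pool": [[0.3]]}
hook(module=None, inputs=(), output=fake_fpn_output)
print(sorted(hook.features))
```

With a real model, the object would be passed to `register_forward_hook`, and the returned handle is what gets removed once training finishes.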

In [None]:
# MODELS - https://github.com/pytorch/vision/tree/main/torchvision/models/detection
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_mobilenet_v3_large_fpn, FasterRCNN_MobileNet_V3_Large_FPN_Weights,
    fasterrcnn_mobilenet_v3_large_320_fpn, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights,
    ssd300_vgg16, SSD300_VGG16_Weights,
    ssdlite320_mobilenet_v3_large, SSDLite320_MobileNet_V3_Large_Weights,
    retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights,
    fcos_resnet50_fpn, FCOS_ResNet50_FPN_Weights
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from torchvision.models.detection.fcos import FCOSHead

### Data Parameters

In [None]:
data_dir = '../data/images_v1_v2'
coco_path = '../data/annotations_v1_v2/coco_v1_v2.json'

# 20% annotations
aug_perc = 0.2
sample_coco_path = f"../data/annotations_v1_v2/coco_v1_v2_{aug_perc}.json"

image_size = (256, 256)  # alternatives: (128, 128), (512, 512)
batch_size = 2
val_percent = 0.1
num_classes = 2  # 1 object class + 1 background class

# ---------------- Optim/training hyperparams -----------------------------------
num_epochs = 12
learning_rate = 0.001
momentum = 0.9
weight_decay = 1e-4
print_freq = 100

# ---------------- KD hyperparams ----------------------------------------------
KD_T = 3.0                 # temperature for distillation
LAMBDA_FPN = 0.5           # weight of FPN feature KD
LAMBDA_RPN = 0.5           # weight of RPN objectness KD
FEATURE_NORM = True        # L2-normalize along channels before MSE
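
The actual distillation losses are implemented in `src/train.py` (`train_one_epoch_kd`) and are not shown here. The sketch below is one plausible reading of the four hyperparameters above, written in plain Python on tiny lists: temperature-softened binary cross-entropy for RPN objectness (scaled by `T**2`, a common convention), MSE on optionally L2-normalized feature vectors, and a weighted sum with the detection loss. All function names and the exact loss forms are assumptions.

```python
import math


def _sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def objectness_kd(student_logits, teacher_logits, T=3.0):
    """BCE between temperature-softened teacher and student objectness, scaled by T^2."""
    total = 0.0
    for s, t in zip(student_logits, teacher_logits):
        p_t = _sigmoid(t / T)  # soft target from the frozen teacher
        p_s = _sigmoid(s / T)  # student prediction at the same temperature
        total -= p_t * math.log(p_s) + (1.0 - p_t) * math.log(1.0 - p_s)
    return (T * T) * total / len(student_logits)


def _l2_normalize(vec, eps=1e-12):
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]


def feature_kd(student_vec, teacher_vec, normalize=True):
    """MSE between channel vectors, optionally L2-normalized first (cf. FEATURE_NORM)."""
    if normalize:
        student_vec = _l2_normalize(student_vec)
        teacher_vec = _l2_normalize(teacher_vec)
    return sum((s - t) ** 2 for s, t in zip(student_vec, teacher_vec)) / len(student_vec)


def total_loss(det_loss, fpn_kd, rpn_kd, lambda_fpn=0.5, lambda_rpn=0.5):
    """Detection loss plus weighted KD terms (cf. LAMBDA_FPN, LAMBDA_RPN)."""
    return det_loss + lambda_fpn * fpn_kd + lambda_rpn * rpn_kd


# Same direction, different magnitude: normalization makes the feature loss vanish.
print(feature_kd([3.0, 4.0], [6.0, 8.0]))
print(feature_kd([3.0, 4.0], [6.0, 8.0], normalize=False))
```

Normalizing along channels makes the feature term compare directions rather than magnitudes, which matters when teacher and student activations live on different scales.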

In [None]:
trans = "resize_horz_clrjtr_rot"
n_iter = 1
student_variant = "frcnn_mobilenet_320"  # options: "frcnn_mobilenet_320", "frcnn_mobilenet"

teacher_model_path = "../models/fasterrcnn_resnet50_fpn_v2/20_resize_horz_clrjtr_rot_epoch7.pth"


KD_model_dir = Path(f"../models/{student_variant}_{n_iter}")
KD_model_dir.mkdir(parents=True, exist_ok=True)

log_path = KD_model_dir / f"kd_{trans}_log.txt"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("device:", device)

`data_loader_train` and `data_loader_val` are generated using `CV-pipeline`. The commented-out cell below shows how they are constructed.

In [None]:
# applied_transforms = ResizeTransform(size=image_size)
# dataset = CustomDataset(
#     root=data_dir,
#     annotation=coco_path,
#     transforms=applied_transforms
# )

# # split
# generator1 = torch.Generator().manual_seed(42)
# val_size = int(val_percent * len(dataset))
# train_size = len(dataset) - val_size
# train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size], generator=generator1)

# if "clr" in trans:
#     aug_transforms = ResizeColorTransform(size=image_size)
#     aug_dataset = CustomDataset(
#         root=data_dir,
#         annotation=sample_coco_path,  # augmentation subset (20%)
#         transforms=aug_transforms
#     )
#     # concat augmented into train
#     train_ds = ConcatDataset([train_ds, aug_dataset])
#     print(f"length of train dataset after applying ResizeColorTransform: {len(train_ds)}")

# if "rot" in trans:
#     aug_transforms = ResizeRotateTransform(size=image_size)
#     aug_dataset = CustomDataset(root=data_dir,
#                                      annotation=sample_coco_path,
#                                      transforms=aug_transforms)
#     # Concat train and augmented data
#     train_ds = torch.utils.data.ConcatDataset([train_ds, aug_dataset])
#     print(f"length of train dataset after applying ResizeRotateTransform: {len(train_ds)}")

# if "horz" in trans:
#     aug_transforms = ResizeHorzTransform(size=image_size)
#     aug_dataset = CustomDataset(root=data_dir,
#                                      annotation=sample_coco_path,  # augmentation subset (20%)
#                                      transforms=aug_transforms)
#     # Concat train and augmenented data
#     train_ds = torch.utils.data.ConcatDataset([train_ds, aug_dataset])
#     print(f"length of train dataset after applying ResizeHorzTransform: {len(train_ds)}")

# data_loader_train = DataLoader(
#     train_ds, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=utils.collate_fn
# )
# data_loader_val = DataLoader(
#     val_ds, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=utils.collate_fn
# )
# print(len(train_ds), len(val_ds))

# Training

In [None]:
# ==============================================================================
# Teacher / Student builders
# ==============================================================================
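def build_teacher(ckpt_path: Path, num_classes: int):
    """Rebuild the fine-tuned Faster R-CNN teacher and freeze it for distillation."""
    # Start from the COCO-pretrained architecture, then swap in a box head sized for our classes
    teacher = fasterrcnn_resnet50_fpn_v2(weights=FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1)
    in_feat = teacher.roi_heads.box_predictor.cls_score.in_features
    teacher.roi_heads.box_predictor = FastRCNNPredictor(in_feat, num_classes)
    # Overwrite the weights with the fine-tuned checkpoint (loaded onto CPU first)
    teacher.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    # The teacher only provides targets: no gradients, eval mode
    for p in teacher.parameters():
        p.requires_grad_(False)
    teacher.eval()
    return teacher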

# Teacher
teacher = build_teacher(teacher_model_path, num_classes=num_classes)

# Student
student = build_student(num_classes=num_classes, variant=student_variant)
if hasattr(student, "roi_heads"):
    student.roi_heads.score_thresh = 0.3  # optional

In [None]:
# Optim & per-iteration StepLR (decay every 3 epochs worth of iters)
params = [p for p in student.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
iters_per_epoch = max(1, len(data_loader_train))
step_size_iters = 3 * iters_per_epoch
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size_iters, gamma=0.1)

# ---------------- Register hooks ----------------
teacher_backbone_hook = FeatureHook()
student_backbone_hook = FeatureHook()
teacher_backbone_handle = teacher.backbone.register_forward_hook(teacher_backbone_hook)
student_backbone_handle = student.backbone.register_forward_hook(student_backbone_hook)

teacher_rpn_hook = RPNHeadHook()
student_rpn_hook = RPNHeadHook()
teacher_rpn_handle = teacher.rpn.head.register_forward_hook(teacher_rpn_hook)
student_rpn_handle = student.rpn.head.register_forward_hook(student_rpn_hook)

# FPN channel adapter
adapter = FPNAdapter().to(device)
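
The scheduler at the top of the cell above is stepped per iteration, so `step_size` is expressed in iterations rather than epochs. The sketch below reproduces the resulting learning-rate schedule in closed form, assuming `lr_scheduler.step()` is called once per training iteration inside `train_one_epoch_kd`; `step_lr` and the value of `iters_per_epoch_demo` are hypothetical.

```python
def step_lr(base_lr, gamma, step_size, n_steps):
    """Learning rate of a StepLR schedule after `n_steps` calls to scheduler.step()."""
    return base_lr * gamma ** (n_steps // step_size)


iters_per_epoch_demo = 500                    # hypothetical len(data_loader_train)
step_size_demo = 3 * iters_per_epoch_demo     # decay every 3 epochs' worth of iterations

for epoch in range(12):
    first_iter = epoch * iters_per_epoch_demo
    lr = step_lr(0.001, 0.1, step_size_demo, first_iter)
    print(f"epoch {epoch:2d}: lr = {lr:.0e}")
```

With `num_epochs = 12`, `learning_rate = 0.001` and `gamma = 0.1`, this gives three decays (at the start of epochs 3, 6 and 9), ending at a learning rate of about 1e-6.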

In [None]:
# Training loop
with open(log_path, "w") as f:
    tee = Tee(sys.stdout, f)
    with redirect_stdout(tee):
        try:
            for epoch in range(num_epochs):
                stats = train_one_epoch_kd(
                    teacher, student, optimizer, lr_scheduler,
                    data_loader_train, device, epoch, KD_T, LAMBDA_FPN, LAMBDA_RPN, FEATURE_NORM,
                    adapter, teacher_backbone_hook, student_backbone_hook,
                    teacher_rpn_hook, student_rpn_hook,
                    print_freq=print_freq
                )
                
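                # Evaluate on the validation set (`evaluate` is not imported in this notebook; it is assumed to be defined elsewhere in the project)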
                evaluate(student, data_loader_val, device=device)
                
                # Save
                model_name = f"kd_{trans}_epoch{epoch}"
                save_path = KD_model_dir / f"{model_name}.pth"
                print("Saving:", save_path)
                torch.save(student.state_dict(), save_path)
        finally:
            # Clean up hooks
            teacher_backbone_handle.remove()
            student_backbone_handle.remove()
            teacher_rpn_handle.remove()
            student_rpn_handle.remove()
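
The loop above sends everything printed during training to both the console and the log file through `Tee` from `src/utils.py`, which is not shown. A minimal sketch of such a helper follows, assuming it only needs `write` and `flush`; `TeeSketch` is a hypothetical stand-in, and the two `StringIO` buffers play the roles of `sys.stdout` and the open log file.

```python
import io
from contextlib import redirect_stdout


class TeeSketch:
    """File-like object that duplicates every write to several underlying streams."""

    def __init__(self, *streams):
        self.streams = streams

    def write(self, text):
        for stream in self.streams:
            stream.write(text)
        return len(text)

    def flush(self):
        for stream in self.streams:
            stream.flush()


console, log_file = io.StringIO(), io.StringIO()
with redirect_stdout(TeeSketch(console, log_file)):
    print("epoch 0 done")

print(repr(log_file.getvalue()))
```

In the notebook, `Tee(sys.stdout, f)` is constructed before `redirect_stdout` takes effect, so the tee holds a reference to the real stdout and does not write into itself.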
