# Pruning + Fine-Tuning

In [None]:
import sys
import os

# Extra import roots (resolved against the current working directory), added
# after the existing entries in this order: "../src", then "src".
sys.path.extend(["../src", "src"])

# Debugging aid: synchronous CUDA kernel launches so errors surface at the call site.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
import sys
import os

# list.append() accepts exactly one item; extend() adds each search root in order.
sys.path.extend(["../src", "../../src"])

# Debugging aid: synchronous CUDA kernel launches so errors surface at the call site.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import json
from functools import partial
from pathlib import Path

import numpy

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import ConcatDataset
from torch.utils.data import Dataset, random_split, DataLoader

from contextlib import redirect_stdout
from pathlib import Path
print(torch.cuda.is_available())

from torch.nn.utils import prune as nnprune
import torch.nn as nn
from engine import train_one_epoch, evaluate  # from cv-pipeline

In [None]:
from src.prune import should_skip_module, prune_convs_structured, prune_linears_unstructured
from src.utils import Tee

In [None]:
# MODELS - https://github.com/pytorch/vision/tree/main/torchvision/models/detection
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights,
    fasterrcnn_mobilenet_v3_large_fpn, FasterRCNN_MobileNet_V3_Large_FPN_Weights,
    fasterrcnn_mobilenet_v3_large_320_fpn, FasterRCNN_MobileNet_V3_Large_320_FPN_Weights,
    ssd300_vgg16, SSD300_VGG16_Weights,
    ssdlite320_mobilenet_v3_large, SSDLite320_MobileNet_V3_Large_Weights,
    retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights,
    fcos_resnet50_fpn, FCOS_ResNet50_FPN_Weights
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from torchvision.models.detection.fcos import FCOSHead

### Modeling Parameters

In [None]:
data_dir = '../data/images_v1_v2'
coco_path = '../data/annotations_v1_v2/coco_v1_v2.json'

# 20% annotations: the sampled COCO file name is suffixed with aug_perc.
aug_perc = 0.2
sample_coco_path = f"../data/annotations_v1_v2/coco_v1_v2_{aug_perc}.json"

image_size = (256, 256)  # alternatives: (128, 128), (512, 512)
batch_size = 2
val_percent = 0.1
num_classes = 2  # one object class + background

# ---------------- Optim/training hyperparams -----------------------------------
num_epochs = 12
learning_rate = 0.0005
momentum = 0.9
weight_decay = 1e-4
print_freq = 100

# --- pruning hyperparams ---
TARGET_SPARSITY_CONV = 0.10   # fraction of channels pruned per Conv2d (structured): 10%
TARGET_SPARSITY_FC   = 0.10   # fraction of weights pruned per Linear (unstructured): 10%
FINE_TUNE_EPOCHS     = 5      # number of recovery (fine-tuning) epochs after pruning
FINE_TUNE_LR         = learning_rate * 0.2   # lower LR for recovery
# NOTE(review): not referenced by any visible cell; confirm whether
# should_skip_module() is meant to honour it.
PRUNE_RPN_AND_ROI_PREDICTORS = False  # keep False for safety

# Apply Pruning
### 🔹 Typical flow

| Step | Purpose | Learning rate |
|---|---|---|
| Train full model | Learn full capacity | normal (e.g., 1e-3) |
| Prune weights | Remove least important connections | — |
| Fine-tune | Restore accuracy on reduced model | small (e.g., 1e-4) |

##### Without fine-tuning, the pruned detector’s features become incoherent → zero detections / zero AP (as observed here).

### 🔹 Analogy

* Imagine removing 30% of a musical instrument’s strings.
* It can still produce sound, but it’s out of tune.
* Fine-tuning retrains the musician to play new chords with fewer strings — same music, less hardware.

In [None]:
trans = "resize_horz_clrjtr_rot"
model_path = "../models/frcnn_mobilenet_3/kd_resize_horz_clrjtr_rot_frcnn_mobilenet_epoch10.pth"

prune_model_dir = Path("../models/Pruning")
prune_model_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): log_path is not used by any visible cell (nor is the imported Tee).
log_path = prune_model_dir / f"prune_{trans}_log.txt"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("device:", device)

print(f"Will prune from: {model_path}")

# Rebuild the architecture the checkpoint was trained with, then swap in a
# num_classes-way box predictor. The COCO weights argument is presumably kept so
# the built module structure (e.g. frozen BN layers) matches the checkpoint's
# keys — confirm before changing it; every value is overwritten by the load below.
prune_model = fasterrcnn_mobilenet_v3_large_fpn(weights=FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1)
in_feat = prune_model.roi_heads.box_predictor.cls_score.in_features
prune_model.roi_heads.box_predictor = FastRCNNPredictor(in_feat, num_classes)
prune_model.roi_heads.score_thresh = 0.3

# The checkpoint is handed straight to load_state_dict, so it is a plain mapping
# of tensors; the restricted (weights_only) unpickler is enough and refuses to
# execute arbitrary pickled code.
state_dict = torch.load(model_path, weights_only=True, map_location=torch.device('cpu'))
prune_model.load_state_dict(state_dict)
del state_dict  # values were copied into the model; free the CPU copy
prune_model.to(device).eval()

In [None]:
# Gather the layers to prune as (module, qualified_name) pairs, skipping any
# module name that should_skip_module() rejects.
conv_targets = []
linear_targets = []
for qual_name, module in prune_model.named_modules():
    if should_skip_module(qual_name):
        continue
    if isinstance(module, nn.Conv2d):
        conv_targets.append((module, qual_name))
    elif isinstance(module, nn.Linear):
        linear_targets.append((module, qual_name))

print(f"[PRUNE] Conv2d: {len(conv_targets)} | Linear: {len(linear_targets)}")

# Structured pruning for convolutions, unstructured pruning for linear layers.
prune_convs_structured(conv_targets, TARGET_SPARSITY_CONV)
prune_linears_unstructured(linear_targets, TARGET_SPARSITY_FC)

In [None]:
# Make masks permanent
for mod in prune_model.modules():
    for attr in ["weight", "bias"]:
        if hasattr(mod, f"{attr}_mask"):
            try: nnprune.remove(mod, attr)
            except: pass

# Lower score threshold and evaluate
prune_model.eval()
prune_model.roi_heads.score_thresh = 0.3 #0.05
evaluate(prune_model, data_loader_val, device=device)

# Apply Fine-Tuning

In [None]:
def _attach_zero_masks(modules):
    """Re-attach pruning masks that freeze every currently-zero weight.

    If the pruning masks were made permanent, the weights are plain parameters
    and SGD (gradients, momentum, weight decay) would regrow the pruned entries.
    custom_from_mask re-parametrises ``weight`` as ``weight_orig * weight_mask``,
    so masked entries get zero gradient and the sparsity pattern survives.
    """
    for module in modules:
        nnprune.custom_from_mask(module, "weight", mask=(module.weight != 0))


def _dense_state_dict(model):
    """Return model.state_dict() with pruning re-parametrisations folded in.

    ``<name>_orig * <name>_mask`` is stored under ``<name>`` and the matching
    ``_mask`` buffers are dropped, so the checkpoint loads into an un-pruned
    model of the same architecture with a plain ``load_state_dict``.
    """
    sd = model.state_dict()
    dense = {}
    for key, value in sd.items():
        if key.endswith("_orig") and key[:-len("_orig")] + "_mask" in sd:
            base = key[:-len("_orig")]
            dense[base] = value * sd[base + "_mask"]
        elif key.endswith("_mask") and key[:-len("_mask")] + "_orig" in sd:
            continue  # already folded into its parameter above
        else:
            dense[key] = value
    return dense


def _make_masks_permanent(model):
    """Fold every remaining pruning mask on `model` into its parameter."""
    for module in model.modules():
        for attr in ("weight", "bias"):
            if hasattr(module, f"{attr}_mask"):
                nnprune.remove(module, attr)


# Keep the pruned weights at zero throughout fine-tuning. Only the modules that
# were selected for pruning (conv_targets / linear_targets) are masked.
_attach_zero_masks([m for m, _ in conv_targets + linear_targets])

# Same parameter objects as before masking (weight_orig reuses the Parameter),
# so frozen layers (requires_grad=False) stay excluded.
ft_optim = torch.optim.SGD([p for p in prune_model.parameters() if p.requires_grad],
                           lr=max(1e-5, FINE_TUNE_LR), momentum=momentum, weight_decay=weight_decay)
ft_sched = torch.optim.lr_scheduler.StepLR(ft_optim, step_size=2, gamma=0.1)

for epoch in range(FINE_TUNE_EPOCHS):
    # evaluate() leaves the model in eval mode, so switch back every epoch.
    prune_model.train()
    # NOTE(review): data_loader_train / data_loader_val are not defined by any
    # visible cell — confirm they are built before this cell runs.
    train_one_epoch(prune_model, ft_optim, data_loader_train, device, epoch, print_freq=print_freq)
    ft_sched.step()

    prune_model.eval()
    prune_model.roi_heads.score_thresh = 0.3
    evaluate(prune_model, data_loader_val, device=device)

    # save model (dense weights, loadable without pruning hooks)
    model_name = f"prune_{trans}_epoch{epoch}"
    model_save_path = prune_model_dir / f"{model_name}.pth"
    print("model will be saved as:", model_save_path, "\n")
    torch.save(_dense_state_dict(prune_model), model_save_path)

# Drop the hooks so the in-memory model has plain (sparse) weights again.
_make_masks_permanent(prune_model)

print("\n[PRUNE] Final sparsity (pruned+finetuned):")
# NOTE(review): report_sparsity is not imported or defined in any visible cell;
# presumably it lives in src.prune — confirm and import it.
report_sparsity(prune_model)