# QuantV2X: 3-Stage Training & PTQ Deep Dive

**A comprehensive guide to the 3-stage training pipeline and Post-Training Quantization (PTQ) with detailed code structure and implementation**

---

## 📋 Table of Contents

### [Part I: Introduction & Code Structure](#part-i)
1. [Overview & Setup](#overview)
2. [Code Structure Roadmap](#roadmap)
3. [Key Components & Architecture](#architecture)

### [Part II: 3-Stage Training Pipeline (Detailed)](#part-ii)
4. [Stage 1: Full-Precision Pretraining](#stage1)
   - Configuration & Code
   - Training Script Deep Dive
   - Model Architecture
   - Training Commands
5. [Stage 2: Codebook-Only Training](#stage2)
   - Configuration Differences
   - Codebook Implementation
   - Parameter Freezing Mechanism
   - Training Commands
6. [Stage 3: End-to-End Co-Training](#stage3)
   - Purpose & Key Differences
   - Unfreezing Logic
   - Training Commands

### [Part III: Post-Training Quantization (PTQ)](#part-iii)
7. [PTQ Theory & Implementation](#ptq-theory)
8. [Quantization Pipeline](#ptq-pipeline)
   - PTQ Workflow
   - PTQ Commands & Parameters
   - Implementation Code
9. [Detailed PTQ Workflow](#ptq-workflow)
   - Complete Structure of `inference_mc_quant.py`
   - Step-by-Step Process
10. [Model Structure Comparison](#model-comparison)
    - FP32 vs Quantized Models
    - Model Printing with `ic()`
    - Key Observations
11. [QuantModel Implementation](#quantmodel-impl)
    - Class Structure
    - Layer Replacement Process
    - Key Methods

12. [Summary & Resources](#summary)

---

<a id='part-i'></a>
# Part I: Introduction & Code Structure

<a id='overview'></a>
## 1. Overview & Setup

### QuantV2X System Overview

```
┌─────────────────────────────────────────────────────────────────┐
│                    QuantV2X Complete Pipeline                   │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  STAGE 1: Full-Precision Pretraining (20 epochs)               │
│  ───────────────────────────────────────────────────────────    │
│  • Model: HeterPyramidCollabMC (NO codebook)                   │
│  • Purpose: Establish strong baseline                          │
│  • Output: Pretrained weights for encoder/backbone/heads       │
│                                                                 │
│                         ↓                                       │
│                                                                 │
│  STAGE 2: Codebook-Only Training (20 epochs)                   │
│  ───────────────────────────────────────────────────────────    │
│  • Model: HeterPyramidCollabCodebookMC (WITH codebook)         │
│  • Load: Stage 1 checkpoint → FREEZE all except codebook       │
│  • Purpose: Learn optimal codebook for feature compression     │
│  • Output: Trained codebook + frozen detector                  │
│                                                                 │
│                         ↓                                       │
│                                                                 │
│  STAGE 3: End-to-End Co-Training (10 epochs)                   │
│  ───────────────────────────────────────────────────────────    │
│  • Model: HeterPyramidCollabCodebookMC                         │
│  • Load: Stage 2 checkpoint → UNFREEZE all parameters          │
│  • Purpose: Joint optimization of codebook + detector          │
│  • Output: Final model with optimal compression & accuracy     │
│                                                                 │
│                         ↓                                       │
│                                                                 │
│  STAGE 4: Post-Training Quantization (PTQ)                     │
│  ───────────────────────────────────────────────────────────    │
│  • Quantize: Weights FP32→INT8, Activations FP32→INT8          │
│  • Purpose: Inference speedup, Memory reduction                │
│  • Output: Fully quantized model ready for edge deployment     │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```

In [3]:
# Setup and imports
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path

# Configure environment
PROJECT_ROOT = os.path.abspath('../..')  # Navigate to quantv2x_official root
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (14, 7)

# Display system info
print("="*80)
print(" " * 20 + "QuantV2X: 3-Stage Training & PTQ Guide")
print("="*80)
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"Device: {DEVICE}")
print(f"Project root: {PROJECT_ROOT}")
print("="*80)

                    QuantV2X: 3-Stage Training & PTQ Guide
PyTorch version: 2.4.1+cu121
CUDA available: True
GPU: NVIDIA L40S
GPU Memory: 47.7 GB
Device: cuda:0
Project root: /home/zhihao/workspace/quantv2x_official


---
<a id='roadmap'></a>
## 2. Code Structure Roadmap

### Directory Structure

```
quantv2x_official/
├── opencood/
│   ├── hypes_yaml/                    # Configuration files
│   │   └── v2x_real/Codebook/
│   │       ├── stage1/
│   │       │   └── lidar_pyramid_stage1.yaml      # Stage 1 config
│   │       ├── stage2/
│   │       │   └── lidar_pyramid_stage2.yaml      # Stage 2 config
│   │       └── stage3/
│   │           └── lidar_pyramid_stage3.yaml      # Stage 3 config
│   │
│   ├── tools/                         # Training & inference scripts
│   │   ├── train.py                   # Stage 1 training
│   │   ├── train_ddp.py               # Stage 1 multi-GPU
│   │   ├── train_stage2.py            # Stage 2 training
│   │   ├── train_stage3.py            # Stage 3 training
│   │   ├── inference_mc.py            # Full-precision inference
│   │   └── inference_mc_quant.py      # PTQ inference
│   │
│   ├── models/                        # Model definitions
│   │   ├── heter_pyramid_collab_mc.py              # Stage 1 model (no codebook)
│   │   ├── heter_pyramid_collab_codebook_mc.py     # Stage 2/3 model (with codebook)
│   │   ├── sub_modules/
│   │   │   ├── codebook.py            # UMGMQuantizer implementation
│   │   │   ├── pillar_vfe.py          # PointPillar encoder
│   │   │   ├── base_bev_backbone_resnet.py   # ResNet backbone
│   │   │   └── ...
│   │   └── fuse_modules/
│   │       └── pyramid_fuse.py        # Pyramid fusion module
│   │
│   ├── loss/                          # Loss functions
│   │   └── point_pillar_pyramid_loss_mc.py   # Multi-class loss
│   │
│   ├── data_utils/                    # Data processing
│   │   └── datasets/
│   │       └── intermediate_fusion_dataset_mc_multistage.py
│   │
│   └── utils/                         # Utilities
│       ├── quant_utils.py             # PTQ utilities
│       ├── quant_model.py             # Quantized layer implementations
│       └── ...
│
└── docs/
    ├── Tutorial_V2X-Real_Baseline.md
    ├── Tutorial_V2X-Real_Codebook.md
    └── notebook/
        └── QuantV2X_3Stage_PTQ_Guide.ipynb   # This notebook
```

### Key Files Overview

#### 1. Configuration Files
- **`lidar_pyramid_stage1.yaml`**: Full-precision baseline config
  - Model: `heter_pyramid_collab_mc`
  - No codebook parameters
  - LR: 0.002, Batch: 8, Epochs: 20

- **`lidar_pyramid_stage2.yaml`**: Codebook-only training config
  - Model: `heter_pyramid_collab_codebook_mc`
  - Codebook: seg_num=1, dict_size=128
  - Freeze all except codebook
  - LR: 0.002, Batch: 8, Epochs: 20

- **`lidar_pyramid_stage3.yaml`**: End-to-end fine-tuning config
  - Model: `heter_pyramid_collab_codebook_mc`
  - Train all parameters
  - LR: 0.0002 (10x smaller), Batch: 4, Epochs: 10
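
The three configurations differ in only a handful of values. As a rough sketch, the settings listed above can be collected in one place and compared. The `StageConfig` class and its field names are illustrative only and are not part of the QuantV2X code base.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StageConfig:
    """Illustrative container for the per-stage settings listed above."""
    name: str
    model: str
    lr: float
    batch_size: int
    epochs: int
    trainable: str


STAGES = [
    StageConfig("stage1", "heter_pyramid_collab_mc", 0.002, 8, 20, "all (no codebook)"),
    StageConfig("stage2", "heter_pyramid_collab_codebook_mc", 0.002, 8, 20, "codebook only"),
    StageConfig("stage3", "heter_pyramid_collab_codebook_mc", 0.0002, 4, 10, "all"),
]

# Derived quantities: total training length and the Stage 2 -> Stage 3 LR drop.
total_epochs = sum(s.epochs for s in STAGES)
lr_ratio = round(STAGES[1].lr / STAGES[2].lr)

for s in STAGES:
    print(f"{s.name}: model={s.model}, lr={s.lr}, batch={s.batch_size}, "
          f"epochs={s.epochs}, trainable={s.trainable}")
print(f"total epochs: {total_epochs}, stage2/stage3 LR ratio: {lr_ratio}")
```

Only the model name changes between Stage 1 and Stage 2, and only the learning rate, batch size, and epoch count change between Stage 2 and Stage 3.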

#### 2. Training Scripts
- **`train.py`**: Standard training loop for Stage 1
- **`train_stage2.py`**: Loads Stage 1, freezes parameters, trains codebook
- **`train_stage3.py`**: Loads Stage 2, unfreezes all, co-trains

#### 3. Model Files
- **`heter_pyramid_collab_mc.py`**: Baseline model without the codebook module
- **`heter_pyramid_collab_codebook_mc.py`**: Same model with the codebook module inserted before fusion
- **`codebook.py`**: Codebook (vector quantization) module, implemented as `UMGMQuantizer`

#### 4. PTQ Files
- **`inference_mc_quant.py`**: Inference with post-training quantization
- **`quant_utils.py`**: Calibration and quantization utilities
- **`quant_model.py`**: Quantized Conv2d, Linear, etc.

---
<a id='architecture'></a>
## 3. Key Components & Architecture

### Model Hierarchy

```python
# Stage 1 Model (No Codebook)
HeterPyramidCollabMC(
    encoder,           # PointPillar: Point cloud → BEV features
    backbone,          # ResNet: Feature extraction
    aligner,           # Identity: Feature alignment
    pyramid_backbone,  # PyramidFusion: Multi-scale fusion
    shrink_header,     # Feature aggregation
    cls_head,          # Classification head
    reg_head,          # Regression head
    dir_head           # Direction head
)

# Stage 2/3 Model (With Codebook)
HeterPyramidCollabCodebookMC(
    encoder,           # Same as above
    backbone,          # Same as above
    aligner,           # Same as above
    codebook,          # ← NEW: UMGMQuantizer (codebook module)
    pyramid_backbone,  # Same but operates on quantized features
    shrink_header,     # Same as above
    cls_head,          # Same as above
    reg_head,          # Same as above
    dir_head           # Same as above
)
```

### Data Flow

```
Input: Multi-Agent Point Clouds [B, N_agents, N_points, 4]
   ↓
[Per-Agent Encoding]
   PointPillar Encoder
   → BEV Features [B*N_agents, 64, H, W]
   ↓
[Per-Agent Backbone]
   ResNet Backbone
   → Extracted Features [B*N_agents, 64, H/2, W/2]
   ↓
[Feature Alignment]
   Aligner (Identity for LiDAR)
   → Aligned Features [B*N_agents, 64, H/2, W/2]
   ↓
[CODEBOOK QUANTIZATION] ← Stage 2/3 only!
   Vector Quantization
   → Quantized Features [B*N_agents, 64, H/2, W/2]
   → Indices [B*N_agents, 1, H/2, W/2]  ← Transmitted!
   ↓
[Multi-Scale Pyramid Fusion]
   Level 1 (1x): 64 channels
   Level 2 (2x): 128 channels
   Level 3 (4x): 256 channels
   → Fused Features [B, 384, H/2, W/2]
   ↓
[Feature Aggregation]
   Shrink Header
   → Aggregated Features [B, 256, H/2, W/2]
   ↓
[Detection Heads]
   Classification: [B, num_class*num_anchor*num_class, H/2, W/2]
   Regression:     [B, 7*num_anchor*num_class, H/2, W/2]
   Direction:      [B, 2*num_anchor*num_class, H/2, W/2]
   ↓
[Post-Processing]
   Anchor decoding → NMS → Filtering
   ↓
Output: 3D Bounding Boxes [class, score, x, y, z, w, l, h, θ]
```
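
To make the codebook step in the flow above more concrete, here is a toy sketch of nearest-codeword lookup followed by a back-of-the-envelope payload estimate. It is not the `UMGMQuantizer` implementation. The helper `nearest_codeword` is hypothetical, and the estimate assumes FP32 features, `seg_num` indices per BEV cell, and ideal packing at `ceil(log2(dict_size))` bits per index.

```python
import math


def nearest_codeword(vec, codebook):
    """Return the index of the codeword closest to `vec` (squared L2 distance)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(range(len(codebook)), key=lambda i: sq_dist(vec, codebook[i]))


# Toy codebook: 4 codewords of dimension 2.
# (The Stage 2 config above uses dict_size=128 over 64-channel features.)
codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
feature = [0.9, 0.2]

idx = nearest_codeword(feature, codebook)  # this index is what gets transmitted
reconstructed = codebook[idx]              # the receiver looks the codeword up again
print(f"index={idx}, reconstructed={reconstructed}")

# Payload per BEV cell: raw FP32 feature vector vs. codebook index.
channels, bits_per_float = 64, 32
dict_size, seg_num = 128, 1

raw_bits = channels * bits_per_float
index_bits = seg_num * math.ceil(math.log2(dict_size))
ratio = raw_bits / index_bits
print(f"raw: {raw_bits} bits, index: {index_bits} bits, ratio: {ratio:.1f}x")
```

Under these assumptions each cell shrinks from 2048 bits to 7 bits, which is why only the index map needs to be sent between agents.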

---
<a id='part-ii'></a>
# Part II: 3-Stage Training Pipeline (Detailed)

<a id='stage1'></a>
## 4. Stage 1: Full-Precision Pretraining

### Purpose

Stage 1 trains a **full-precision baseline model WITHOUT codebook quantization**. It serves three purposes:
1. **Accuracy reference**: Establishes the accuracy the architecture reaches without feature compression, against which Stages 2 and 3 are compared
2. **Stable initialization**: Provides pretrained weights for the encoder, backbone, and detection heads
3. **Architectural validation**: Confirms that the base architecture trains and detects well before the codebook is added

### Configuration: `lidar_pyramid_stage1.yaml`

The cell below loads and prints the configuration file from the repository, falling back to a typical configuration if the file is not found:

In [4]:
# Read and display the actual Stage 1 configuration
import yaml
from pathlib import Path

stage1_config_path = Path(PROJECT_ROOT) / "opencood" / "hypes_yaml" / "v2x_real" / "Codebook" / "stage1" / "lidar_pyramid_stage1.yaml"

if stage1_config_path.exists():
    with open(stage1_config_path, 'r') as f:
        stage1_config = yaml.safe_load(f)
    
    print("="*80)
    print(" " * 20 + "Stage 1 Configuration (Actual)")
    print("="*80)
    print(yaml.dump(stage1_config, default_flow_style=False, sort_keys=False))
else:
    print(f"Configuration file not found: {stage1_config_path}")
    print("\nTypical Stage 1 configuration:")
    print("""
name: lidar_pyramid_stage1
data_dir: "/data/dataset/v2xreal"
root_dir: "opencood/logs/stage1_model"

train_params:
  batch_size: 8
  epoches: 20
  eval_freq: 2
  save_freq: 2
  max_cav: 4

model:
  core_method: heter_pyramid_collab_mc  # NO codebook!
  args:
    num_class: 3
    lidar_range: [-140.8, -40, -3, 140.8, 40, 1]
    supervise_single: True
    fusion_method: pyramid
    
optimizer:
  core_method: Adam
  lr: 0.002
  
lr_scheduler:
  core_method: multistep
  gamma: 0.1
  step_size: [15, 25]
    """)

                    Stage 1 Configuration (Actual)
name: stage1_model
root_dir: /data/dataset/v2xreal/train
validate_dir: /data/dataset/v2xreal/test
test_dir: /data/dataset/v2xreal/test
yaml_parser: load_general_params
train_params:
  batch_size: 8
  epoches: 20
  eval_freq: 2
  save_freq: 2
  max_cav: 4
comm_range: 70
input_source:
- lidar
label_type: lidar
cav_lidar_range: &id001
- -140.8
- -40
- -3
- 140.8
- 40
- 1
num_class: 3
dataset_mode: v2v
heter:
  assignment_path: opencood/modality_assign/v2xreal_4modality.json
  ego_modality: m1
  mapping_dict:
    m1: none
    m2: none
    m3: m1
    m4: m1
  modality_setting:
    m1:
      sensor_type: lidar
      core_method: point_pillar
      preprocess:
        core_method: SpVoxelPreprocessor
        args:
          voxel_size: &id002
          - 0.4
          - 0.4
          - 4
          max_points_per_voxel: 32
          max_voxel_train: 32000
          max_voxel_test: 70000
        cav_lidar_range: *id001
fusion:
  core_method: in
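
The `cav_lidar_range` and `voxel_size` values printed above fix the size of the BEV grid that the PointPillar encoder scatters into. The sketch below works this out, assuming the usual convention that the grid size is the range extent divided by the voxel size; `bev_grid_size` is a hypothetical helper, not a function from the repository.

```python
def bev_grid_size(lidar_range, voxel_size):
    """Number of voxels along x, y, z for a [x_min, y_min, z_min, x_max, y_max, z_max] range."""
    x_min, y_min, z_min, x_max, y_max, z_max = lidar_range
    vx, vy, vz = voxel_size
    # round() guards against floating-point error in divisions such as 281.6 / 0.4.
    return (
        round((x_max - x_min) / vx),
        round((y_max - y_min) / vy),
        round((z_max - z_min) / vz),
    )


cav_lidar_range = [-140.8, -40, -3, 140.8, 40, 1]
voxel_size = [0.4, 0.4, 4]

nx, ny, nz = bev_grid_size(cav_lidar_range, voxel_size)
print(f"grid: W={nx}, H={ny}, D={nz}")
```

A single voxel along z means each voxel is a full-height pillar. If the backbone halves the resolution as shown in the data flow (H/2, W/2), the per-agent feature map would be 100 x 352 cells.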

### Training Script: `train.py`

The listing below shows the key sections of the training loop in `opencood/tools/train.py`:

```python
# opencood/tools/train.py (Key sections)

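import argparse

import torch
from torch.utils.data import DataLoader
from opencood.hypes_yaml.yaml_utils import load_yaml
from opencood.data_utils.datasets import build_dataset
from opencood.tools import train_utils  # create_model, setup_optimizer, create_loss, ...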

def train_parser():
    parser = argparse.ArgumentParser(description="QuantV2X training")
    parser.add_argument("-y", "--hypes_yaml", required=True,
                        help="Path to config YAML file")
    parser.add_argument("--model_dir", default="",
                        help="Path to resume from checkpoint")
    args = parser.parse_args()
    return args

def main():
    args = train_parser()
    hypes = load_yaml(args.hypes_yaml)
    
    # ============================================================
    # 1. Create dataset
    # ============================================================
    print("Creating dataset...")
    opencood_train_dataset = build_dataset(
        hypes, 
        visualize=False, 
        train=True
    )
    opencood_validate_dataset = build_dataset(
        hypes,
        visualize=False,
        train=False
    )
    
    train_loader = DataLoader(
        opencood_train_dataset,
        batch_size=hypes['train_params']['batch_size'],
        num_workers=8,
        shuffle=True,
        pin_memory=True,
        drop_last=True
    )
    
    val_loader = DataLoader(
        opencood_validate_dataset,
        batch_size=hypes['train_params']['batch_size'],
        num_workers=8,
        shuffle=False,
        pin_memory=True,
        drop_last=True
    )
    
    # ============================================================
    # 2. Build model
    # ============================================================
    print("Building model...")
    model = train_utils.create_model(hypes)  # Creates HeterPyramidCollabMC
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    # ============================================================
    # 3. Define optimizer and scheduler
    # ============================================================
    optimizer = train_utils.setup_optimizer(
        hypes, 
        model
    )  # Adam with lr=0.002
    
    scheduler = train_utils.setup_lr_scheduler(
        hypes,
        optimizer
    )  # MultiStepLR with milestones=[15, 25]; with 20 epochs only the milestone at 15 takes effect
    
    # ============================================================
    # 4. Define loss function
    # ============================================================
    criterion = train_utils.create_loss(hypes)  # PointPillarPyramidLossMC
    
    # ============================================================
    # 5. Training loop
    # ============================================================
    epoches = hypes['train_params']['epoches']
    save_freq = hypes['train_params']['save_freq']
    eval_freq = hypes['train_params']['eval_freq']
    
    for epoch in range(epoches):
        # Training phase
        model.train()
        for i, batch_data in enumerate(train_loader):
            # Move data to the selected device (GPU if available)
            batch_data = train_utils.to_device(batch_data, device)
            
            # Forward pass
            output_dict = model(batch_data['ego'])
            
            # Compute loss
            final_loss = criterion(
                output_dict,
                batch_data['ego']['label_dict']
            )
            
            # Backward pass
            optimizer.zero_grad()
            final_loss.backward()
            optimizer.step()
            
            # Log
            if i % 10 == 0:
                print(f"Epoch {epoch}, Iter {i}, Loss: {final_loss.item():.4f}")
        
        # Update learning rate
        scheduler.step()
        
        # Validation
        if (epoch + 1) % eval_freq == 0:
            model.eval()
            with torch.no_grad():
                val_loss = 0.0
                for batch_data in val_loader:
                    batch_data = train_utils.to_device(batch_data, device)
                    output_dict = model(batch_data['ego'])
                    loss = criterion(output_dict, batch_data['ego']['label_dict'])
                    val_loss += loss.item()
                val_loss /= len(val_loader)
                print(f"Epoch {epoch}, Validation Loss: {val_loss:.4f}")
        
        # Save checkpoint
        if (epoch + 1) % save_freq == 0:
            torch.save(
                model.state_dict(),
                f"opencood/logs/stage1_model/net_epoch_{epoch+1}.pth"
            )
    
    print("Training completed!")

if __name__ == '__main__':
    main()
```
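
The loop above steps the scheduler once per epoch and evaluates and saves on fixed intervals. As a quick sanity check, the sketch below reproduces that bookkeeping in plain Python, using the values from the typical Stage 1 configuration (lr 0.002, gamma 0.1, milestones [15, 25], 20 epochs, eval and save every 2 epochs). The helper `multistep_lr` is a hypothetical stand-in that mirrors the MultiStepLR rule of multiplying by `gamma` once per milestone reached; it does not call PyTorch.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate in effect during 0-indexed `epoch` when the scheduler steps once per epoch."""
    passed = bisect_right(sorted(milestones), epoch)  # milestones already reached
    return base_lr * gamma ** passed


base_lr, gamma, milestones = 0.002, 0.1, [15, 25]
epochs, eval_freq, save_freq = 20, 2, 2

schedule = [multistep_lr(base_lr, milestones, gamma, e) for e in range(epochs)]
eval_epochs = [e + 1 for e in range(epochs) if (e + 1) % eval_freq == 0]
save_epochs = [e + 1 for e in range(epochs) if (e + 1) % save_freq == 0]

for e, lr in enumerate(schedule):
    print(f"epoch index {e:2d}: lr={lr:.6f}")
print("validation after epochs:", eval_epochs)
print("checkpoints net_epoch_N.pth for N in:", save_epochs)
```

With these values the learning rate drops once, starting at epoch index 15, and the second milestone at 25 is never reached because training stops after 20 epochs.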


### Model Architecture: `HeterPyramidCollabMC`

The listing below is a simplified excerpt of the model implementation in `opencood/models/heter_pyramid_collab_mc.py`:

```python
# opencood/models/heter_pyramid_collab_mc.py (simplified excerpt)

import torch
import torch.nn as nn
from opencood.models.sub_modules.pillar_vfe import PillarVFE
from opencood.models.sub_modules.point_pillar_scatter import PointPillarScatter
from opencood.models.sub_modules.base_bev_backbone_resnet import ResNetBEVBackbone
from opencood.models.sub_modules.downsample_conv import DownsampleConv
from opencood.models.sub_modules.naive_compress import NaiveCompressor
from opencood.models.fuse_modules.pyramid_fuse import PyramidFusion

class HeterPyramidCollabMC(nn.Module):
    """
    Heterogeneous Pyramid Collaborative Perception for Multi-Class Detection
    (Stage 1: Full-precision baseline WITHOUT codebook)
    """
    
    def __init__(self, args):
        super(HeterPyramidCollabMC, self).__init__()
        
        # ========================================
        # 1. Per-Modality Encoders
        # ========================================
        # For modality 1 (LiDAR)
        self.pillar_vfe_m1 = PillarVFE(
            args['pillar_vfe'],
            num_point_features=4,
            voxel_size=args['voxel_size'],
            point_cloud_range=args['lidar_range']
        )
        self.scatter_m1 = PointPillarScatter(args['point_pillar_scatter'])
        
        # ========================================
        # 2. Per-Modality Backbones
        # ========================================
        self.backbone_m1 = ResNetBEVBackbone(
            args['base_bev_backbone'],
            64  # Input channels from scatter
        )
        
        # ========================================
        # 3. Feature Aligners
        # ========================================
        # For LiDAR-only, use identity (no alignment needed)
        self.aligner_m1 = nn.Identity()
        
        # ========================================
        # 4. Pyramid Fusion Backbone
        # ========================================
        self.fusion_net = PyramidFusion(args['fusion_args'])
        
        # ========================================
        # 5. Shrink Header (Feature Aggregation)
        # ========================================
        self.shrink_flag = False
        if 'shrink_header' in args:
            self.shrink_flag = True
            self.shrink_conv = DownsampleConv(args['shrink_header'])
        
        # ========================================
        # 6. Detection Heads
        # ========================================
        self.num_class = args['num_class']  # 3: vehicle, pedestrian, truck
        self.anchor_num = args['anchor_num']  # 2 rotations per class
        
        # Classification head
        self.cls_head = nn.Conv2d(
            256,  # Input channels after shrink
            self.anchor_num * self.num_class,
            kernel_size=1
        )
        
        # Regression head (7 DOF: x, y, z, w, l, h, theta)
        self.reg_head = nn.Conv2d(
            256,
            7 * self.anchor_num * self.num_class,
            kernel_size=1
        )
        
        # Direction head (2 bins for heading)
        if 'dir_args' in args:
            self.use_dir = True
            self.dir_head = nn.Conv2d(
                256,
                2 * self.anchor_num * self.num_class,
                kernel_size=1
            )
        else:
            self.use_dir = False
    
    def forward(self, data_dict):
        """
        Forward pass
        
        Args:
            data_dict: Dictionary containing:
                - 'processed_lidar': dict with voxel features
                - 'record_len': number of agents per scene
                - 'pairwise_t_matrix': transformation matrices
        
        Returns:
            output_dict: Dictionary containing:
                - 'cls_preds': classification predictions
                - 'reg_preds': regression predictions
                - 'dir_preds': direction predictions (optional)
        """
        
        # ========================================
        # 1. Per-Agent Encoding (PointPillar)
        # ========================================
        voxel_features = data_dict['processed_lidar']['voxel_features']
        voxel_coords = data_dict['processed_lidar']['voxel_coords']
        voxel_num_points = data_dict['processed_lidar']['voxel_num_points']
        
        # Pillar VFE: [N_pillars, 32, 9] → [N_pillars, 64]
        batch_dict = {
            'voxel_features': voxel_features,
            'voxel_coords': voxel_coords,
            'voxel_num_points': voxel_num_points
        }
        batch_dict = self.pillar_vfe_m1(batch_dict)
        
        # Scatter to BEV: [N_pillars, 64] → [B, 64, H, W]
        batch_dict = self.scatter_m1(batch_dict)
        spatial_features = batch_dict['spatial_features']
        
        # ========================================
        # 2. Per-Agent Backbone (ResNet)
        # ========================================
        spatial_features_2d = self.backbone_m1(spatial_features)['spatial_features_2d']
        
        # ========================================
        # 3. Feature Alignment
        # ========================================
        psm_single = self.aligner_m1(spatial_features_2d)
        
        # ========================================
        # 4. Multi-Scale Pyramid Fusion
        # ========================================
        # Note: In Stage 1, NO codebook quantization happens here!
        record_len = data_dict['record_len']
        pairwise_t_matrix = data_dict['pairwise_t_matrix']
        
        fused_feature = self.fusion_net(
            psm_single,
            record_len,
            pairwise_t_matrix
        )
        
        # ========================================
        # 5. Shrink Header
        # ========================================
        if self.shrink_flag:
            fused_feature = self.shrink_conv(fused_feature)
        
        # ========================================
        # 6. Detection Heads
        # ========================================
        cls_preds = self.cls_head(fused_feature)
        reg_preds = self.reg_head(fused_feature)
        
        output_dict = {
            'cls_preds': cls_preds,
            'reg_preds': reg_preds,
            'psm': fused_feature
        }
        
        if self.use_dir:
            dir_preds = self.dir_head(fused_feature)
            output_dict['dir_preds'] = dir_preds
        
        return output_dict
```
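The head sizes in the constructor follow directly from `num_class` and `anchor_num`. The sketch below reproduces that arithmetic, assuming the values noted in the code comments (3 classes, 2 anchors per class). `head_channels` is a hypothetical helper for illustration, not a function from the repository.

```python
def head_channels(num_class: int, anchor_num: int, use_dir: bool = True) -> dict:
    """Output channels of the 1x1 conv heads, mirroring the constructor above."""
    per_location = anchor_num * num_class        # one score per (anchor, class) pair
    channels = {
        "cls_head": per_location,
        "reg_head": 7 * per_location,            # x, y, z, w, l, h, theta
    }
    if use_dir:
        channels["dir_head"] = 2 * per_location  # 2 heading bins
    return channels


print(head_channels(num_class=3, anchor_num=2))
# {'cls_head': 6, 'reg_head': 42, 'dir_head': 12}
```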


### Stage 1: Training Commands

```bash
# Single GPU training
python ./opencood/tools/train.py \
    -y ./opencood/hypes_yaml/v2x_real/Codebook/stage1/lidar_pyramid_stage1.yaml

# Multi-GPU training (recommended)
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
    --nproc_per_node=2 --use_env \
    ./opencood/tools/train_ddp.py \
    -y ./opencood/hypes_yaml/v2x_real/Codebook/stage1/lidar_pyramid_stage1.yaml
```

### Expected Output

```
opencood/logs/
└── stage1_model_YYYY_MM_DD_HH_MM_SS/
    ├── config.yaml
    ├── net_epoch_1.pth
    ├── net_epoch_2.pth
    ├── ...
    ├── net_epoch_20.pth
    └── net_epoch_best.pth       # ← Use this for Stage 2!
```
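Since Stage 2 needs a Stage 1 checkpoint, it can help to resolve the file programmatically. The sketch below assumes the file naming shown in the listing above (`net_epoch_best.pth` and `net_epoch_<N>.pth`). `pick_checkpoint` is a hypothetical helper, not part of the repository.

```python
import re
from pathlib import Path
from typing import Optional


def pick_checkpoint(run_dir: str) -> Optional[Path]:
    """Return net_epoch_best.pth if present, else the highest-numbered epoch file."""
    run = Path(run_dir)
    best = run / "net_epoch_best.pth"
    if best.is_file():
        return best
    numbered = []
    for path in run.glob("net_epoch_*.pth"):
        # Keep only files whose suffix is a plain epoch number.
        match = re.fullmatch(r"net_epoch_(\d+)\.pth", path.name)
        if match:
            numbered.append((int(match.group(1)), path))
    if not numbered:
        return None
    # Compare by epoch number so that epoch 10 sorts after epoch 2.
    return max(numbered, key=lambda item: item[0])[1]
```

Calling `pick_checkpoint` on a run directory returns a `Path` to pass as the Stage 1 model, or `None` when the directory holds no checkpoint.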

---
<a id='stage2'></a>
## 5. Stage 2: Codebook-Only Training

### Purpose

Stage 2 focuses on **learning the codebook quantization** while keeping the rest of the model frozen:
1. **Load pretrained weights** from Stage 1
2. **Add codebook module** (randomly initialized)
3. **Freeze everything else** (encoder, backbone, fusion network, and detection heads)
4. **Train only codebook parameters**

This prevents catastrophic forgetting of the pretrained detector while learning optimal feature compression.

### Key Differences from Stage 1

| Aspect | Stage 1 | Stage 2 |
|--------|---------|----------|
| Model | `heter_pyramid_collab_mc` | `heter_pyramid_collab_codebook_mc` |
| Codebook | ❌ None | ✅ seg_num=1, dict_size=128 |
| Training | All parameters | **Codebook only** |
| Initialization | Random | **Load Stage 1 checkpoint** |
| Batch Size | 8 | 8 (same) |
| Learning Rate | 0.002 | 0.002 (same) |
| Epochs | 20 | 20 (same) |
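
To give a feel for what `seg_num=1, dict_size=128` means for transmission, the sketch below counts the bits needed per BEV cell when only codeword indices are sent. It assumes 64 feature channels stored as 32-bit floats and ignores any entropy coding or the one-time cost of sharing the codebook; both function names are hypothetical.

```python
import math


def bits_per_location(seg_num: int, dict_size: int) -> int:
    """Bits needed to send one codeword index per segment for one BEV cell."""
    return seg_num * math.ceil(math.log2(dict_size))


def compression_ratio(channels: int, seg_num: int, dict_size: int,
                      bits_per_value: int = 32) -> float:
    """Raw feature bits divided by index bits for a single BEV cell."""
    raw_bits = channels * bits_per_value
    return raw_bits / bits_per_location(seg_num, dict_size)


print(bits_per_location(seg_num=1, dict_size=128))                 # 7
print(round(compression_ratio(64, seg_num=1, dict_size=128), 1))   # 292.6
```

Under these assumptions each cell shrinks from 2048 bits to 7 bits, which is why the codebook quality matters so much for detection accuracy.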

### Configuration Diff

The main difference in the YAML configuration:

                         Stage 1 vs Stage 2 Config Diff

[STAGE 1]
model:
  core_method: heter_pyramid_collab_mc  # NO codebook
  args:
    num_class: 3
    lidar_range: [-140.8, -40, -3, 140.8, 40, 1]
    # No codebook configuration

[STAGE 2]
model:
  core_method: heter_pyramid_collab_codebook_mc  # WITH codebook
  args:
    num_class: 3
    lidar_range: [-140.8, -40, -3, 140.8, 40, 1]
    # NEW: Codebook configuration
    codebook:
      seg_num: 1          # Number of segments (groups)
      dict_size: 128      # Codebook size per segment
    use_codebook: true


### Codebook Implementation: `UMGMQuantizer`

The core quantization module is implemented in `opencood/models/sub_modules/codebook.py`.

Below is the **actual implementation** (key parts):

```python
# opencood/models/sub_modules/codebook.py (Core quantization logic)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class _multiCodebookQuantization(nn.Module):
    """
    Multi-codebook vector quantization module.
    
    Args:
        codebook: [m, k, d] learnable codebook
            m: number of segments
            k: dict_size (number of codewords per segment)
            d: channel dimension per segment (total_channels // m)
        permutationRate: probability of random perturbation
    """
    
    def __init__(self, codebook: nn.Parameter, permutationRate: float = 0.0):
        super().__init__()
        self._m, self._k, self._d = codebook.shape  # e.g., [1, 128, 64]
        self._codebook = codebook  # Learnable!
        self._scale = math.sqrt(self._k)
        self._temperature = nn.Parameter(torch.ones((self._m, 1)))
    
    def _distance(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute L2 distance between input features and all codewords.
        
        Args:
            x: [n, c] input features (c = m * d)
        
        Returns:
            distance: [n, m, k] distances to each codeword
        """
        n, _ = x.shape
        # Reshape: [n, c] → [n, m, d]
        x = x.reshape(n, self._m, self._d)
        
        # Squared norm of input: [n, m, 1]
        x2 = (x ** 2).sum(2, keepdim=True)
        
        # Squared norm of codebook: [m, k]
        c2 = (self._codebook ** 2).sum(-1, keepdim=False)
        
        # Inner product: [n, m, k]
        inter = torch.einsum("nmd,mkd->nmk", x, self._codebook)
        
        # L2 distance: ||x - c||^2 = ||x||^2 + ||c||^2 - 2<x, c>
        distance = x2 + c2 - 2 * inter
        
        return distance
    
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input features to codebook indices.
        
        Args:
            x: [n, c] input features
        
        Returns:
            code: [n, m] codebook indices
        """
        distance = self._distance(x)  # [n, m, k]
        code = distance.argmin(-1)     # [n, m] - nearest codeword
        return code
    
    def forward(self, x: torch.Tensor):
        """
        Forward pass with Gumbel-Softmax for differentiable sampling.
        
        Args:
            x: [n, c] input features
        
        Returns:
            sample: [n, m, k] soft assignment (differentiable)
            code: [n, m] hard assignment (indices)
            oneHot: [n, m, k] one-hot encoding
            logit: [n, m, k] logits for probability distribution
        """
        # Compute logits (negative distance)
        logit = -1 * self._distance(x) / self._scale
        
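        # Gumbel-Softmax sampling (differentiable!)
        # gumbelSoftmax is a helper defined outside this excerpt;
        # torch.nn.functional.gumbel_softmax(logit, tau=1.0, hard=True) is a comparable built-in.
        sample = gumbelSoftmax(logit, temperature=1.0, hard=True)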
        
        # Get hard assignment
        code = logit.argmax(-1, keepdim=True)  # [n, m, 1]
        
        # One-hot encoding
        oneHot = torch.zeros_like(logit).scatter_(-1, code, 1)
        
        return sample, code[..., 0], oneHot, logit


class _multiCodebookDeQuantization(nn.Module):
    """
    De-quantization module to reconstruct features from codeword assignments.
    """
    
    def __init__(self, codebook: nn.Parameter):
        super().__init__()
        self._m, self._k, self._d = codebook.shape
        self._codebook = codebook  # Shared with quantization module
    
    def forward(self, sample: torch.Tensor) -> torch.Tensor:
        """
        Reconstruct features from soft assignment.
        
        Args:
            sample: [n, m, k] soft assignment
        
        Returns:
            reconstructed: [n, c] reconstructed features
        """
        n, _, _ = sample.shape
        # Weighted sum: [n, m, k] × [m, k, d] → [n, m, d] → [n, c]
        return torch.einsum("nmk,mkd->nmd", sample, self._codebook).reshape(n, -1)


class UMGMQuantizer(nn.Module):
    """
    Unified Multivariate Gaussian Mixture (UMGM) quantizer.
    
    This is the main codebook module used in QuantV2X.
    """
    
    def __init__(self, channel: int, m: int, k: int, 
                 permutationRate: float, components: dict):
        """
        Args:
            channel: Total number of feature channels (e.g., 64)
            m: Number of segments (e.g., 1)
            k: Dictionary size per segment (e.g., 128)
            permutationRate: Random perturbation rate
            components: Dictionary of component builders
        """
        super().__init__()
        
        # Initialize codebook with SmallInit
        # From "Transformers without Tears" (https://arxiv.org/pdf/1910.05895.pdf)
        codebook = nn.Parameter(
            nn.init.normal_(
                torch.empty(m, k, channel // m),
                std=math.sqrt(2 / (5 * channel / m))
            )
        )
        
        # Create quantizer and dequantizer
        self.quantizer = _multiCodebookQuantization(codebook, permutationRate)
        self.dequantizer = _multiCodebookDeQuantization(codebook)
        
        # Frequency tracking for codebook reassignment (EMA)
        self.ema = 0.9
        self._freqEMA = nn.Parameter(
            torch.ones(m, k) / k, 
            requires_grad=False
        )
    
    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [B, C, H, W] input features
        
        Returns:
            quantized: [B, C, H, W] quantized features
            indices: [B, m, H, W] codebook indices (for transmission)
            reconstruction_loss: MSE between input and quantized
        """
        B, C, H, W = x.shape
        
        # Flatten spatial dimensions: [B, C, H, W] → [B*H*W, C]
        x_flat = x.permute(0, 2, 3, 1).reshape(-1, C)
        
        # Quantize
        sample, code, oneHot, logit = self.quantizer(x_flat)
        
        # Dequantize
        x_recon = self.dequantizer(sample)
        
        # Reshape back: [B*H*W, C] → [B, C, H, W]
        x_recon = x_recon.view(B, H, W, C).permute(0, 3, 1, 2)
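        # [B*H*W, m] → [B, H, W, m] → [B, m, H, W]
        # (a plain view(B, -1, H, W) is only correct when m = 1)
        indices = code.view(B, H, W, -1).permute(0, 3, 1, 2)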
        
        # Compute reconstruction loss
        recon_loss = F.mse_loss(x_recon, x)
        
        return x_recon, indices, recon_loss
```
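The distance expansion used in `_distance` and the hard assignment in `encode` can be checked on a tiny example. The sketch below uses plain Python lists instead of tensors and a single segment (m = 1); it illustrates the nearest-codeword lookup only and is not the training path, which relies on Gumbel-Softmax. All names and values here are made up for illustration.

```python
def squared_distances(x, codebook):
    """||x - c||^2 for every codeword c, via ||x||^2 + ||c||^2 - 2<x, c>."""
    x2 = sum(v * v for v in x)
    out = []
    for c in codebook:
        c2 = sum(v * v for v in c)
        inner = sum(a * b for a, b in zip(x, c))
        out.append(x2 + c2 - 2 * inner)
    return out


def encode(x, codebook):
    """Index of the nearest codeword (the value that would be transmitted)."""
    d = squared_distances(x, codebook)
    return min(range(len(d)), key=d.__getitem__)


def decode(index, codebook):
    """Look the codeword back up on the receiving side."""
    return codebook[index]


codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]   # k = 3 codewords, d = 2
feature = [0.9, 1.2]
idx = encode(feature, codebook)
print(idx, decode(idx, codebook))   # 1 [1.0, 1.0]
```

The receiver only needs `idx` and a copy of the codebook to recover an approximation of the original feature vector.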


### Parameter Freezing Mechanism in Stage 2

The **critical difference** in Stage 2 is that we freeze all parameters except the codebook.

Here's the **actual code** from `opencood/tools/train_stage2.py`:

```python
# opencood/tools/train_stage2.py (Parameter freezing logic)

import torch
import torch.nn as nn
from opencood.tools import train_utils

def main():
    # ... (load config, create dataset) ...
    
    # ============================================================
    # 1. Create model WITH codebook
    # ============================================================
    print("Building model with codebook...")
    model = train_utils.create_model(hypes)  # HeterPyramidCollabCodebookMC
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # ============================================================
    # 2. Load Stage 1 checkpoint (pretrained weights)
    # ============================================================
    stage1_checkpoint = args.stage1_model
    print(f"Loading Stage 1 checkpoint from: {stage1_checkpoint}")
    
    stage1_dict = torch.load(stage1_checkpoint, map_location='cpu')
    
    # Load pretrained weights (ignoring codebook since it doesn't exist in Stage 1)
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in stage1_dict.items() 
                       if k in model_dict and 'codebook' not in k}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    
    print(f"Loaded {len(pretrained_dict)} / {len(stage1_dict)} parameters from Stage 1")
    
    # ============================================================
    # 3. FREEZE all parameters
    # ============================================================
    print("\nFreezing all parameters...")
    model.eval()  # Set to eval mode
    for p in model.parameters():
        p.requires_grad_(False)  # Freeze everything!
    
    # ============================================================
    # 4. UNFREEZE codebook parameters only
    # ============================================================
    print("Unfreezing codebook parameters...")
    model.codebook.train()  # Set codebook to train mode
    for p in model.codebook.parameters():
        p.requires_grad_(True)  # Enable gradients for codebook!
    
    # ============================================================
    # 5. Print trainable parameters
    # ============================================================
    print("\n" + "="*70)
    print("Trainable Parameters (Stage 2):")
    print("="*70)
    
    total_params = 0
    total_trainable = 0
    
    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            print(f"  ✓ {name}: {param.data.shape}")
            total_trainable += param.numel()
        else:
            # Print frozen params only while the running frozen-element count is below 1000
            if total_params - total_trainable < 1000:
                print(f"  ✗ {name}: {param.data.shape} [FROZEN]")
    
    print(f"\nTotal parameters:     {total_params:,}")
    print(f"Trainable parameters: {total_trainable:,} ({total_trainable/total_params*100:.2f}%)")
    print(f"Frozen parameters:    {total_params-total_trainable:,} ({(total_params-total_trainable)/total_params*100:.2f}%)")
    print("="*70)
    
    # ============================================================
    # 6. Create optimizer ONLY for codebook parameters
    # ============================================================
    codebook_params = [p for n, p in model.named_parameters() 
                       if p.requires_grad]
    
    print(f"\nOptimizer will update {len(codebook_params)} parameter tensors")
    optimizer = torch.optim.Adam(codebook_params, lr=hypes['optimizer']['lr'])
    
    # ... (rest of training loop) ...
```
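The key-filtering step in section 2 of the script can be tried in isolation. In the sketch below, plain dicts stand in for `state_dict` objects and the key names are invented for illustration. `filter_pretrained` is a hypothetical helper that additionally reports which keys were skipped and which the new model must initialize itself.

```python
def filter_pretrained(stage1_state: dict, model_state: dict,
                      skip_substring: str = "codebook"):
    """Split Stage 1 entries into those to load and those to skip."""
    to_load, skipped = {}, []
    for key, value in stage1_state.items():
        if key in model_state and skip_substring not in key:
            to_load[key] = value
        else:
            skipped.append(key)
    # Keys the new model has but Stage 1 cannot provide (e.g. the codebook).
    missing = [k for k in model_state if k not in to_load]
    return to_load, skipped, missing


# Plain dicts stand in for state_dicts; key names are made up for illustration.
stage1 = {"backbone_m1.weight": 1, "cls_head.weight": 2, "old_layer.bias": 3}
stage2 = {"backbone_m1.weight": 0, "cls_head.weight": 0,
          "codebook.quantizer._codebook": 0}

to_load, skipped, missing = filter_pretrained(stage1, stage2)
print(sorted(to_load))  # ['backbone_m1.weight', 'cls_head.weight']
print(skipped)          # ['old_layer.bias']
print(missing)          # ['codebook.quantizer._codebook']
```

Printing `skipped` and `missing` before training makes it easier to notice a renamed layer that would otherwise silently stay at its random initialization.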


### Stage 2: Training Commands

```bash
# IMPORTANT: Must provide --stage1_model checkpoint!
python ./opencood/tools/train_stage2.py \
    --hypes_yaml ./opencood/hypes_yaml/v2x_real/Codebook/stage2/lidar_pyramid_stage2.yaml \
    --stage1_model opencood/logs/stage1_model_YYYY_MM_DD_HH_MM_SS/net_epoch_best.pth

# Multi-GPU version
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
    --nproc_per_node=2 --use_env \
    ./opencood/tools/train_stage2.py \
    --hypes_yaml ./opencood/hypes_yaml/v2x_real/Codebook/stage2/lidar_pyramid_stage2.yaml \
    --stage1_model opencood/logs/stage1_model_YYYY_MM_DD_HH_MM_SS/net_epoch_best.pth
```

---
<a id='stage3'></a>
## 6. Stage 3: End-to-End Co-Training

### Purpose

Stage 3 performs **end-to-end fine-tuning** of the entire model:
1. **Load Stage 2 checkpoint** (pretrained model + trained codebook)
2. **Unfreeze ALL parameters** (codebook + detector)
3. **Joint optimization** with a smaller learning rate
4. **Achieve best performance** by co-adapting codebook and detector

### Key Differences from Stage 2

| Aspect | Stage 2 | Stage 3 |
|--------|---------|----------|
| Model | `heter_pyramid_collab_codebook_mc` | Same |
| Training | **Codebook only** | **ALL parameters** |
| Initialization | Load Stage 1 | **Load Stage 2** |
| Batch Size | 8 | **4** (reduced) |
| Learning Rate | 0.002 | **0.0002** (10x smaller) |
| Epochs | 20 | **10** (fewer) |
| LR Schedule | [15, 25] | **[5, 8]** (earlier decay) |
| Special | - | `stage3_codebook_weight: 0.05` |

### Why These Changes?

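1. **Smaller Batch Size (4 vs 8)**:
   - Likely a memory trade-off: with every parameter trainable, gradients and optimizer state are kept for the whole network
   - Smaller batches give noisier gradient estimates, so this change works together with the reduced learning rate rather than stabilizing training on its own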

2. **Smaller Learning Rate (0.0002 vs 0.002)**:
   - Prevents destroying pretrained features
   - Allows gentle co-adaptation

3. **Fewer Epochs (10 vs 20)**:
   - Already good initialization from Stage 2
   - Prevents overfitting

4. **Codebook Loss Weight (`stage3_codebook_weight: 0.05`)**:
   - Balances detection loss and codebook compression
   - Prevents codebook from degrading during joint training
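
The effect of the `[5, 8]` schedule can be worked out by hand. The sketch below computes the learning rate in each epoch, assuming a decay factor of 0.1 per milestone, one scheduler step per epoch, and 0-indexed epochs. It mirrors the closed form of a multi-step decay; `multistep_lr` is a hypothetical helper, not the scheduler itself.

```python
from bisect import bisect_right


def multistep_lr(epoch: int, base_lr: float, milestones, gamma: float = 0.1) -> float:
    """Learning rate in effect during `epoch` (0-indexed) under multi-step decay."""
    passed = bisect_right(sorted(milestones), epoch)  # milestones already reached
    return base_lr * gamma ** passed


for epoch in (0, 4, 5, 7, 8, 9):
    print(epoch, f"{multistep_lr(epoch, 2e-4, [5, 8]):.0e}")
```

Under these assumptions the rate stays at 2e-4 for epochs 0 to 4, drops to 2e-5 for epochs 5 to 7, and ends at 2e-6 for the last two epochs.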

```python
# opencood/tools/train_stage3.py (Unfreezing logic)

def main():
    # ... (load config, create dataset) ...
    
    # ============================================================
    # 1. Create model
    # ============================================================
    model = train_utils.create_model(hypes)  # HeterPyramidCollabCodebookMC
    
    # ============================================================
    # 2. Load Stage 2 checkpoint (pretrained + codebook)
    # ============================================================
    stage2_checkpoint = args.stage2_model
    print(f"Loading Stage 2 checkpoint from: {stage2_checkpoint}")
    
    stage2_dict = torch.load(stage2_checkpoint, map_location='cpu')
    model.load_state_dict(stage2_dict)  # Load everything!
    
    print(f"Loaded complete model from Stage 2 (detector + codebook)")
    
    # ============================================================
    # 3. UNFREEZE all parameters
    # ============================================================
    print("\nUnfreezing all parameters for end-to-end training...")
    model.train()  # Set entire model to train mode
    
    for param in model.parameters():
        param.requires_grad = True  # Enable gradients for ALL parameters!
    
    # ============================================================
    # 4. Verify all parameters are trainable
    # ============================================================
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"\nTotal parameters:     {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,} (100%)")
    print(f"Frozen parameters:    0")
    
    assert total_params == trainable_params, "Some parameters are still frozen!"
    
    # ============================================================
    # 5. Create optimizer with SMALLER learning rate
    # ============================================================
    lr = hypes['optimizer']['lr']  # 0.0002 (10x smaller than Stage 1/2)
    print(f"\nLearning rate: {lr} (10x smaller for fine-tuning)")
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # ============================================================
    # 6. Learning rate scheduler with EARLIER decay
    # ============================================================
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[5, 8],  # Decay at epochs 5 and 8 (earlier than Stage 1/2)
        gamma=0.1
    )
    
    # ... (training loop) ...
```
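
To make the scheduler settings concrete, the sketch below reproduces the step-decay rule in plain Python, with no PyTorch required. It assumes the base learning rate of 0.0002 quoted in the comment above and that the scheduler is stepped once per epoch. The helper name `multistep_lr` is hypothetical; it only mirrors the closed-form behaviour of `MultiStepLR`, which multiplies the rate by `gamma` once for every milestone already reached.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate in effect during `epoch` (0-indexed) under step decay."""
    # Count the milestones already reached; each one multiplies the rate by gamma.
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


if __name__ == "__main__":
    # Stage 3 settings from the snippet above: milestones [5, 8], gamma 0.1.
    for epoch in range(10):
        lr = multistep_lr(2e-4, [5, 8], 0.1, epoch)
        print(f"epoch {epoch}: lr = {lr:.1e}")
```

Under these assumptions, epochs 0-4 train at the base rate, epochs 5-7 at one tenth of it, and epochs 8-9 at one hundredth.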


### Stage 3: Training Commands

```bash
# IMPORTANT: Must provide --stage2_model checkpoint!
python ./opencood/tools/train_stage3.py \
    --hypes_yaml ./opencood/hypes_yaml/v2x_real/Codebook/stage3/lidar_pyramid_stage3.yaml \
    --stage2_model opencood/logs/stage2_model_YYYY_MM_DD_HH_MM_SS/net_epoch_best.pth

# Multi-GPU version
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch \
    --nproc_per_node=2 --use_env \
    ./opencood/tools/train_stage3.py \
    --hypes_yaml ./opencood/hypes_yaml/v2x_real/Codebook/stage3/lidar_pyramid_stage3.yaml \
    --stage2_model opencood/logs/stage2_model_YYYY_MM_DD_HH_MM_SS/net_epoch_best.pth
```

### Expected Output

```
opencood/logs/
└── stage3_model_YYYY_MM_DD_HH_MM_SS/
    ├── config.yaml
    ├── net_epoch_1.pth
    ├── ...
    ├── net_epoch_10.pth
    └── net_epoch_best.pth       # ← Final model for deployment!
```

---
<a id='part-iii'></a>
# Part III: Post-Training Quantization (PTQ)

<a id='ptq-theory'></a>
## 7. PTQ Theory & Implementation

### What is Post-Training Quantization?

**Post-Training Quantization (PTQ)** converts a trained FP32 model to lower precision (INT8/INT4) **without retraining**:

```
┌─────────────────────────────────────────────────────────────────┐
│  FP32 Model (Stage 3 output)                                    │
│  ───────────────────────────────────────────────────────────    │
│  • Weights: FP32 (4 bytes per parameter)                       │
│  • Activations: FP32 (4 bytes per value)                       │
│  • Operations: FP32 multiply-add                               │
└─────────────────────────────────────────────────────────────────┘
                         ↓ PTQ (W8A8)
┌─────────────────────────────────────────────────────────────────┐
│  INT8 Model (After PTQ)                                         │
│  ───────────────────────────────────────────────────────────    │
│  • Weights: INT8 (1 byte per parameter)                        │
│  • Activations: INT8 (1 byte per value)                        │
│  • Operations: INT8 multiply-add                               │
└─────────────────────────────────────────────────────────────────┘
```
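
The byte counts in the diagram translate directly into storage size. The sketch below is a back-of-the-envelope calculation only: the helper `footprint_mb` and the parameter count are hypothetical and are not measured from QuantV2X.

```python
def footprint_mb(num_values, n_bits):
    """Storage in megabytes (10**6 bytes) for `num_values` numbers of `n_bits` each."""
    return num_values * n_bits / 8 / 1e6


if __name__ == "__main__":
    num_params = 5_000_000  # hypothetical parameter count, for illustration only
    for bits in (32, 8, 4):
        print(f"{bits:>2}-bit weights: {footprint_mb(num_params, bits):.1f} MB")
```

The estimate ignores the per-channel scales and zero-points and any layers left in FP32, so a real quantized checkpoint shrinks by slightly less than the ideal 4x for INT8.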

### Quantization Formula

For a given FP32 value $x$, quantization to INT8 is:

$$
x_{\text{quant}} = \text{clamp}\left(\text{round}\left(\frac{x}{s}\right) + z, 0, 255\right)
$$

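where:
- $s$ = scale factor (calibrated from data)
- $z$ = zero-point, the integer that represents the real value 0 (symmetric schemes fix it, typically at 0 for a signed integer range)
- $[0, 255]$ = the unsigned 8-bit range; for $b$ bits it becomes $[0, 2^b - 1]$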

**Dequantization** (to recover approximate FP32):

$$
x_{\text{dequant}} = (x_{\text{quant}} - z) \times s
$$
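
The two formulas can be checked with a few lines of plain Python. This is a minimal sketch and not the repository's quantizer: the function names are hypothetical, and the scale and zero-point come from a simple min/max calibration over the sample values.

```python
def minmax_qparams(values, n_bits=8):
    """Derive scale s and zero-point z from the min/max of calibration values."""
    qmax = 2 ** n_bits - 1
    lo = min(min(values), 0.0)  # keep real 0 inside the representable range
    hi = max(max(values), 0.0)
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero_point = round(-lo / scale)
    return scale, zero_point


def quantize(x, scale, zero_point, n_bits=8):
    """x_quant = clamp(round(x / s) + z, 0, 2^b - 1)."""
    qmax = 2 ** n_bits - 1
    return max(0, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """x_dequant = (x_quant - z) * s."""
    return (q - zero_point) * scale


if __name__ == "__main__":
    samples = [-1.0, -0.25, 0.0, 0.5, 3.0]
    s, z = minmax_qparams(samples)
    for x in samples:
        q = quantize(x, s, z)
        print(f"x={x:+.3f} -> q={q:3d} -> x_hat={dequantize(q, s, z):+.5f}")
```

For values inside the calibrated range, the round trip is off by at most half a quantization step ($s/2$); values outside the range are clamped and lose more.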

---
<a id='ptq-pipeline'></a>
## 8. Quantization Pipeline

### PTQ Workflow

```
1. Load FP32 Model (Stage 3)
   ↓
2. Replace layers with quantized versions
   • nn.Conv2d → QuantConv2d
   • nn.Linear → QuantLinear
   ↓
3. Calibration Phase
   • Forward pass with calibration data
   • Collect activation statistics
   • Compute optimal scales
   ↓
4. Weight Optimization
   • Minimize reconstruction error
   • Iterate for N steps
   ↓
5. Validation & Inference
   • Run on test set
   • Compare with FP32 baseline
```

### PTQ Command

```bash
# W8A8 Quantization
python opencood/tools/inference_mc_quant.py \
    --model_dir opencood/logs/stage3_model_YYYY_MM_DD_HH_MM_SS \
    --fusion_method intermediate \
    --num_cali_batches 16 \
    --n_bits_w 8 \
    --n_bits_a 8 \
    --iters_w 5000
```

### PTQ Parameters

| Parameter | Description | Recommended |
|-----------|-------------|-------------|
| `--num_cali_batches` | Number of batches for calibration | 4-32 |
| `--n_bits_w` | Weight bitwidth | 4, 8 |
| `--n_bits_a` | Activation bitwidth | 4, 8 |
| `--iters_w` | Weight optimization iterations | 2000-10000 |
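
The flags in the table could be wired up with `argparse` roughly as follows. This is an illustrative sketch, not the parser in `inference_mc_quant.py`: the `build_parser` helper is hypothetical, and the defaults are assumptions copied from the example command above.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the PTQ flags shown in the table."""
    parser = argparse.ArgumentParser(description="PTQ inference (illustrative sketch)")
    parser.add_argument("--model_dir", type=str, required=True,
                        help="Directory holding the Stage 3 checkpoint")
    parser.add_argument("--fusion_method", type=str, default="intermediate")
    parser.add_argument("--num_cali_batches", type=int, default=16,
                        help="Number of batches used for calibration")
    parser.add_argument("--n_bits_w", type=int, default=8, help="Weight bitwidth")
    parser.add_argument("--n_bits_a", type=int, default=8, help="Activation bitwidth")
    parser.add_argument("--iters_w", type=int, default=5000,
                        help="Weight optimization iterations")
    return parser


if __name__ == "__main__":
    # Example: a W4A8 run that keeps the other settings at their assumed defaults.
    opt = build_parser().parse_args(["--model_dir", "opencood/logs/example", "--n_bits_w", "4"])
    print(opt)
```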


```python
# opencood/tools/inference_mc_quant.py (Key sections)

import torch
import torch.nn.functional as F

from opencood.utils import quant_utils
from opencood.utils.quant_model import QuantConv2d, QuantLinear

def main():
    # ============================================================
    # 1. Load FP32 model
    # ============================================================
    fp_model = load_model(args.model_dir)
    fp_model.eval()
    
    print(f"Loaded FP32 model: {sum(p.numel() for p in fp_model.parameters()):,} parameters")
    
    # ============================================================
    # 2. Build quantization parameters
    # ============================================================
    wq_params = {
        'n_bits': args.n_bits_w,        # Weight bitwidth (e.g., 8)
        'channel_wise': True,          # Per-channel quantization
        'scale_method': 'minmax'         # Min/max-based scale initialization
    }
    
    aq_params = {
        'n_bits': args.n_bits_a,        # Activation bitwidth (e.g., 8)
        'channel_wise': False,         # Per-layer quantization
        'scale_method': 'minmax',
        'leaf_param': True
    }
    
    # ============================================================
    # 3. Replace layers with quantized versions
    # ============================================================
    print("\nReplacing layers with quantized versions...")
    qnn = quant_utils.QuantModel(
        model=fp_model,
        weight_quant_params=wq_params,
        act_quant_params=aq_params
    )
    qnn.cuda()
    qnn.eval()
    
    # ============================================================
    # 4. Calibration Phase
    # ============================================================
    print(f"\nCalibrating with {args.num_cali_batches} batches...")
    
    # Disable quantization for calibration
    qnn.set_quant_state(False, False)
    
    # Forward pass to collect statistics
    with torch.no_grad():
        for i, batch in enumerate(calibration_loader):
            if i >= args.num_cali_batches:
                break
            _ = qnn(batch)
            print(f"  Calibration: {i+1}/{args.num_cali_batches}", end='\r')
    
    # ============================================================
    # 5. Initialize quantization parameters
    # ============================================================
    print("\n\nInitializing quantization parameters...")
    qnn.set_quant_state(True, True)  # Enable weight + activation quantization
    
    # ============================================================
    # 6. Weight Optimization (BRECQ)
    # ============================================================
    print(f"\nOptimizing weights for {args.iters_w} iterations...")
    
    for name, module in qnn.named_modules():
        if isinstance(module, (QuantConv2d, QuantLinear)):
            # Optimize scale for this layer
            optimize_layer_scale(
                module,
                calibration_loader,
                iters=args.iters_w,
                lr=4e-5
            )
    
    print("\nQuantization complete!")
    
    # ============================================================
    # 7. Inference & Evaluation
    # ============================================================
    print("\nRunning quantized inference...")
    results = evaluate(qnn, test_loader)
    
    print(f"\nQuantized Model Results:")
    print(f"  AP: {results['AP']:.4f}")
    print(f"  Inference Time: {results['time']:.2f} ms/frame")


def optimize_layer_scale(layer, data_loader, iters, lr):
    """
    Optimize quantization scale for a single layer.
    
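    Minimizes: MSE(output_fp32, output_quant)

    Note: `input_data` is assumed to hold this layer's cached input
    activations gathered from `data_loader`; that collection step is
    omitted here.
    """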
    optimizer = torch.optim.Adam([layer.weight_quantizer.scale], lr=lr)
    
    for i in range(iters):
        # Get FP32 output
        layer.set_quant_state(False, False)
        out_fp = layer(input_data)
        
        # Get quantized output
        layer.set_quant_state(True, True)
        out_quant = layer(input_data)
        
        # Compute reconstruction loss
        loss = F.mse_loss(out_quant, out_fp.detach())
        
        # Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 1000 == 0:
            print(f"    Iter {i}/{iters}, Loss: {loss.item():.6f}")
```
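
The `channel_wise: True` setting for weights matters because output channels can differ in magnitude by orders of magnitude. The sketch below compares one shared scale against one scale per channel. It is framework-free and simplified: it uses symmetric signed quantization rather than the affine form above, and all names are hypothetical.

```python
def symmetric_scale(values, n_bits=8):
    """Scale that maps the largest magnitude onto the top of the signed range."""
    qmax = 2 ** (n_bits - 1) - 1
    max_abs = max(abs(v) for v in values)
    return max_abs / qmax if max_abs > 0 else 1.0


def fake_quantize(values, scale, n_bits=8):
    """Quantize then dequantize so the rounding error can be measured in float."""
    qmax = 2 ** (n_bits - 1) - 1
    return [max(-qmax - 1, min(qmax, round(v / scale))) * scale for v in values]


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def compare(weights, n_bits=8):
    """Return (per-tensor MSE, per-channel MSE) for a list of per-channel weight rows."""
    flat = [w for row in weights for w in row]
    tensor_scale = symmetric_scale(flat, n_bits)
    per_tensor = [q for row in weights for q in fake_quantize(row, tensor_scale, n_bits)]
    per_channel = [q for row in weights
                   for q in fake_quantize(row, symmetric_scale(row, n_bits), n_bits)]
    return mse(flat, per_tensor), mse(flat, per_channel)


if __name__ == "__main__":
    # Two output channels whose weights differ by two orders of magnitude.
    toy_weights = [[0.01, -0.02, 0.015], [1.0, -2.0, 1.5]]
    tensor_err, channel_err = compare(toy_weights)
    print(f"per-tensor MSE:  {tensor_err:.3e}")
    print(f"per-channel MSE: {channel_err:.3e}")
```

With a shared scale, the small-magnitude channel collapses onto one or two integer levels; giving each channel its own scale avoids that at the cost of storing one scale per output channel.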


<a id='ptq-workflow'></a>
### Detailed PTQ Workflow in `inference_mc_quant.py`

Let's examine the **complete workflow** of the PTQ inference script:

```python
# opencood/tools/inference_mc_quant.py - Complete Workflow Structure

# ============================================================
# STEP 1: Import Required Modules
# ============================================================
import copy

import torch
import torch.nn as nn
from opencood.quant import (
    QuantModel,                    # Wrapper to make model quantizable
    block_reconstruction,          # Reconstruct blocks (ResNet, etc.)
    layer_reconstruction,          # Reconstruct layers (Conv2d, Linear)
    pyramid_reconstruction,        # Reconstruct pyramid fusion
    encoder_reconstruction,        # Reconstruct PointPillar encoder
    set_weight_quantize_params,    # Initialize weight quantizers
)
from opencood.quant.quant_block import (
    QuantBaseBEVBackbone,
    QuantResNetBEVBackbone,
    QuantPyramidFusion,
    QuantPFNLayer,
    QuantDownsampleConv,
    # ... (other quantized block types)
)
from opencood.quant.quant_layer import QuantModule  # Quantized Conv2d/Linear wrapper
from icecream import ic  # For model printing
# ... (train_utils, inference_utils_mc, eval_utils_mc imports omitted) ...

# ============================================================
# STEP 2: Load Full-Precision Model (Stage 3 Checkpoint)
# ============================================================
print('Creating Model')
trained_model = train_utils.create_model(hypes)  # HeterPyramidCollabCodebookMC

print('Loading Model from checkpoint')
saved_path = opt.model_dir
resume_epoch, trained_model = train_utils.load_saved_model(saved_path, trained_model)
print(f"resume from {resume_epoch} epoch.")

trained_model.cuda()
trained_model.eval()

# Create a copy for FP32 baseline
fp_model = copy.deepcopy(trained_model)
fp_model.cuda()
fp_model.eval()

# ============================================================
# STEP 3: Define Quantization Parameters
# ============================================================
# Weight quantization parameters
wq_params = {
    'n_bits': opt.n_bits_w,           # e.g., 8 for W8
    'channel_wise': opt.channel_wise, # True for per-channel
    'scale_method': opt.init_wmode    # 'mse' or 'minmax'
}

# Activation quantization parameters
aq_params = {
    'n_bits': opt.n_bits_a,           # e.g., 8 for A8
    'channel_wise': False,            # False for per-layer
    'scale_method': opt.init_amode,   # 'mse' or 'minmax'
    'leaf_param': True,
    'prob': opt.prob
}

# ============================================================
# STEP 4: Wrap Models with QuantModel
# ============================================================
# FP32 model (for comparison, no quantization)
fp_model = QuantModel(
    model=fp_model,
    weight_quant_params=wq_params,
    act_quant_params=aq_params,
    is_fusing=False  # Don't fuse BatchNorm for FP32
)
fp_model.cuda()
fp_model.eval()
fp_model.set_quant_state(False, False)  # Disable quantization

# Quantized model
qt_model = QuantModel(
    model=trained_model,
    weight_quant_params=wq_params,
    act_quant_params=aq_params
)  # is_fusing=True by default - fuses BatchNorm
qt_model.cuda()
qt_model.eval()

# ============================================================
# STEP 5: Print Model Structures Using ic()
# ============================================================
print('the fp model is below!')
ic(fp_model)  # Shows FP32 model structure

# Disable quantization for output layers (detection heads)
qt_model.disable_network_output_quantization()

print('the quantized model is below!')
ic(qt_model)  # Shows quantized model structure with QuantModule layers

# ============================================================
# STEP 6: Prepare Calibration Data
# ============================================================
cali_data = get_train_samples(train_loader, num_batches=opt.num_cali_batches)
print(f"Collected {len(cali_data)} calibration batches")

# Kwargs for reconstruction
kwargs = dict(
    cali_data=cali_data,
    iters=opt.iters_w,             # e.g., 5000
    weight=opt.weight,             # Rounding loss weight
    b_range=(opt.b_start, opt.b_end),  # Temperature range
    warmup=opt.warmup,             # Warmup period
    opt_mode='mse',                # Optimization mode
    lr=opt.lr,                     # Learning rate for LSQ
    input_prob=opt.input_prob,
    keep_gpu=not opt.keep_cpu,
    lamb_r=opt.lamb_r,             # KL divergence weight
    T=opt.T,                       # Temperature for KD
    bn_lr=opt.bn_lr,               # BN learning rate
    lamb_c=opt.lamb_c              # BN constraint weight
)

# ============================================================
# STEP 7: Initialize Weight Quantizers
# ============================================================
set_weight_quantize_params(qt_model)

# ============================================================
# STEP 8: Recursive Block/Layer Reconstruction
# ============================================================
def recon_model(qt: nn.Module, fp: nn.Module):
    """
    Recursively reconstruct quantized model to match FP32 outputs.
    
    For each module type, use specialized reconstruction:
    - QuantModule (Conv2d, Linear): layer_reconstruction
    - QuantPyramidFusion: pyramid_reconstruction  
    - QuantResNetBEVBackbone: block_reconstruction
    - QuantPFNLayer: encoder_reconstruction
    - etc.
    """
    
    for (name, module), (_, fp_module) in zip(qt.named_children(), fp.named_children()):
        if isinstance(module, QuantModule):
            print(f'Reconstruction for layer {name}')
            layer_reconstruction(qt_model, fp_model, module, fp_module, **kwargs)
        
        elif isinstance(module, QuantPyramidFusion):
            print(f'Reconstruction for pyramid fusion block {name}')
            pyramid_reconstruction(qt_model, fp_model, module, fp_module, **kwargs)
        
        elif isinstance(module, (QuantResNetBEVBackbone, QuantDownsampleConv, QuantBaseBEVBackbone)):
            print(f'Reconstruction for block {name}')
            block_reconstruction(qt_model, fp_model, module, fp_module, **kwargs)
        
        elif isinstance(module, QuantPFNLayer):
            print(f'Reconstruction for PointPillar PFN {name}')
            encoder_reconstruction(qt_model, fp_model, module, fp_module, **kwargs)
        
        else:
            # Recursively process sub-modules
            recon_model(module, fp_module)

# Start reconstruction
print("\nStarting block/layer-wise reconstruction...")
recon_model(qt_model, fp_model)

# ============================================================
# STEP 9: Enable Quantization
# ============================================================
qt_model.set_quant_state(weight_quant=True, act_quant=True)
print('Quantization is done!')

# ============================================================
# STEP 10: Print Memory Footprint
# ============================================================
print(qt_model.get_memory_footprint())
# Expected output: "Model Memory Footprint: 5.3 MB" (vs 21 MB for FP32)

# ============================================================
# STEP 11: Inference on Test Set
# ============================================================
qt_model.eval()

# Evaluation loop
with torch.no_grad():
    for i, batch_data in enumerate(test_loader):
        batch_data = train_utils.to_device(batch_data, device)
        
        # Run quantized inference
        if opt.fusion_method == 'intermediate':
            pred_box_tensor, pred_score, gt_box_tensor, gt_label_tensor = \
                inference_utils_mc.inference_intermediate_fusion(
                    batch_data,
                    qt_model,  # Use quantized model!
                    opencood_test_dataset
                )
        
        # Calculate metrics (AP, etc.)
        eval_utils_mc.calculate_tp_fp(...)

print("Inference complete!")
```


<a id='model-comparison'></a>
### Model Structure Comparison: FP32 vs Quantized

The `ic()` (IceCream) debugger is used in the PTQ script to print model structures.

**Key Differences**:


```python
# Example output from ic(fp_model) and ic(qt_model)

# ============================================================
# FP32 Model Structure (fp_model)
# ============================================================
ic| fp_model: QuantModel(
  (model): HeterPyramidCollabCodebookMC(
    (pillar_vfe_m1): PillarVFE(
      (pfn_layers): ModuleList(
        (0): PFNLayer(
          (linear): Linear(in_features=9, out_features=64, bias=False)
          (norm): BatchNorm1d(64, ...)
        )
      )
    )
    (scatter_m1): PointPillarScatter()
    (backbone_m1): ResNetBEVBackbone(
      (blocks): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(64, ...)
        (2): ReLU(inplace=True)
      )
    )
    (codebook): UMGMQuantizer(...)  # Codebook from Stage 3
    (fusion_net): PyramidFusion(
      (pyramid_blocks): ResNet(
        (layer1): Conv2d(64, 64, ...)
        (layer2): Conv2d(64, 128, ...)
        (layer3): Conv2d(128, 256, ...)
      )
    )
    (shrink_conv): DownsampleConv(
      (conv): Conv2d(384, 256, kernel_size=(3, 3), padding=(1, 1))
    )
    (cls_head): Conv2d(256, 6, kernel_size=(1, 1))  # 3 classes × 2 anchors
    (reg_head): Conv2d(256, 42, kernel_size=(1, 1)) # 7 DOF × 6
    (dir_head): Conv2d(256, 12, kernel_size=(1, 1)) # 2 bins × 6
  )
)

# ============================================================
# Quantized Model Structure (qt_model)
# ============================================================
ic| qt_model: QuantModel(
  (model): HeterPyramidCollabCodebookMC(
    (pillar_vfe_m1): QuantPillarVFE(  # ← Wrapped in Quant version
      (pfn_layers): ModuleList(
        (0): QuantPFNLayer(
          (linear): QuantModule(  # ← Conv/Linear wrapped in QuantModule
            (fwd_func): Linear(in_features=9, out_features=64, bias=False)
            (weight_quantizer): UniformAffineQuantizer(
              n_bits=8,
              scale=Parameter[64],  # Per-channel scales
              zero_point=Parameter[64]
            )
            (act_quantizer): UniformAffineQuantizer(
              n_bits=8,
              scale=Parameter[1],  # Per-layer scale
              zero_point=Parameter[1]
            )
          )
          # BatchNorm fused into previous QuantModule
        )
      )
    )
    (scatter_m1): PointPillarScatter()
    (backbone_m1): QuantResNetBEVBackbone(  # ← Quantized backbone
      (blocks): Sequential(
        (0): QuantModule(
          (fwd_func): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
          (weight_quantizer): UniformAffineQuantizer(n_bits=8, ...)
          (act_quantizer): UniformAffineQuantizer(n_bits=8, ...)
        )
        # BatchNorm and ReLU fused
      )
    )
    (codebook): UMGMQuantizer(...)  # Kept in FP32 (small overhead)
    (fusion_net): QuantPyramidFusion(
      (pyramid_blocks): QuantResNet(
        (layer1): QuantModule(Conv2d(64, 64, ...))
        (layer2): QuantModule(Conv2d(64, 128, ...))
        (layer3): QuantModule(Conv2d(128, 256, ...))
      )
    )
    (shrink_conv): QuantDownsampleConv(
      (conv): QuantModule(Conv2d(384, 256, ...))
    )
    # Detection heads: Quantization disabled (disable_network_output_quantization)
    (cls_head): Conv2d(256, 6, kernel_size=(1, 1))  # FP32 for precision
    (reg_head): Conv2d(256, 42, kernel_size=(1, 1)) # FP32
    (dir_head): Conv2d(256, 12, kernel_size=(1, 1)) # FP32
  )
)
```

### Key Observations

1. **Layer Wrapping**:
   - FP32: `Conv2d`, `Linear`
   - Quantized: `QuantModule` wrapping `Conv2d`/`Linear`

2. **Quantizers Added**:
   - `weight_quantizer`: Per-channel for weights
   - `act_quantizer`: Per-layer for activations
   - Each has `scale` and `zero_point` parameters

3. **BatchNorm Fusion**:
   - FP32: Separate `BatchNorm2d` layers
   - Quantized: BatchNorm fused into preceding `QuantModule`
   - Removes the separate normalization op at inference: BN statistics are folded into the preceding layer's weights and bias

4. **Codebook Preserved**:
   - Codebook module kept in FP32 (negligible overhead)
   - Focus PTQ on compute-heavy layers

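5. **Detection Heads**:
   - Outputs of the last 3 layers (detection heads) are left unquantized for better accuracy
   - Configured via `disable_network_output_quantization()`, which sets `disable_act_quant = True` on the last three `QuantModule`s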
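
The two granularities in observation 2 can be illustrated with a minimal, dependency-free sketch. It assumes plain min/max range estimation and asymmetric 8-bit quantization; `affine_params`, `fake_quantize`, and `max_abs_error` are hypothetical helper names, and the repository's `UniformAffineQuantizer` may initialize `scale` and `zero_point` differently.

```python
from typing import List, Tuple


def affine_params(values: List[float], n_bits: int = 8) -> Tuple[float, int]:
    """Estimate (scale, zero_point) from the min/max range of `values`."""
    q_max = 2 ** n_bits - 1
    lo = min(min(values), 0.0)  # the representable range must contain 0
    hi = max(max(values), 0.0)
    scale = (hi - lo) / q_max
    if scale == 0.0:            # all-zero input: any positive scale works
        scale = 1.0
    zero_point = round(-lo / scale)
    return scale, zero_point


def fake_quantize(values: List[float], scale: float, zero_point: int,
                  n_bits: int = 8) -> List[float]:
    """Quantize to integers in [0, 2^n_bits - 1], then map back to floats."""
    q_max = 2 ** n_bits - 1
    out = []
    for v in values:
        q = min(max(round(v / scale) + zero_point, 0), q_max)
        out.append((q - zero_point) * scale)
    return out


def max_abs_error(original: List[float], restored: List[float]) -> float:
    return max(abs(a - b) for a, b in zip(original, restored))


# Toy weight with two output channels of very different magnitude.
weight = [
    [-0.02, 0.01, 0.015],  # channel 0: small values
    [-2.0, 1.0, 1.5],      # channel 1: large values
]

# Per-channel: one (scale, zero_point) pair per output channel.
per_channel = [affine_params(row) for row in weight]
err_per_channel = max_abs_error(weight[0], fake_quantize(weight[0], *per_channel[0]))

# Per-layer: a single (scale, zero_point) pair shared by every value.
flat = [v for row in weight for v in row]
shared = affine_params(flat)
err_per_layer = max_abs_error(weight[0], fake_quantize(weight[0], *shared))

print(f"channel 0 error, per-channel scale: {err_per_channel:.6f}")
print(f"channel 0 error, shared scale:      {err_per_layer:.6f}")
```

With a shared scale, the small-magnitude channel is resolved by only a few integer levels, which is the usual motivation for giving weights per-channel scales while activations keep one scale per layer.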


<a id='quantmodel-impl'></a>
### QuantModel Implementation

The `QuantModel` class is the key to making a model quantizable. Let's see its **actual implementation**:


```python
# opencood/quant/quant_model.py (Actual implementation)

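import torch.nn as nn
from opencood.quant.quant_block import specials, opencood_specials, BaseQuantBlock
from opencood.quant.quant_layer import QuantModule, StraightThrough, UniformAffineQuantizer
# Note: search_fold_and_remove_bn and specials_unquantized_names are used below,
# but their imports are not shown in this excerpt.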

class QuantModel(nn.Module):
    """
    Wrapper that converts an FP32 model to a quantizable model.
    
    Recursively replaces:
    - nn.Conv2d → QuantModule(Conv2d)
    - nn.Linear → QuantModule(Linear)
    - Custom blocks (PyramidFusion, etc.) → QuantBlock versions
    - Fuses BatchNorm into preceding layers
    """
    
    def __init__(self, model: nn.Module, weight_quant_params: dict = {}, 
                 act_quant_params: dict = {}, is_fusing=True):
        super().__init__()
        
        if is_fusing:
            # Search and fuse BatchNorm layers
            search_fold_and_remove_bn(model)
            self.model = model
            self.quant_module_refactor(self.model, weight_quant_params, act_quant_params)
        else:
            # Keep BatchNorm separate (for FP32 model)
            self.model = model
            self.quant_module_refactor_wo_fuse(self.model, weight_quant_params, act_quant_params)
    
    def quant_module_refactor(self, module: nn.Module, weight_quant_params: dict = {}, 
                              act_quant_params: dict = {}):
        """
        Recursively replace layers with quantized versions.
        """
        prev_quantmodule = None  # Track last QuantModule so a following ReLU can be fused into it
        
        for name, child_module in module.named_children():
            # Skip unquantized layers (e.g., codebook)
            if name in specials_unquantized_names:
                continue
            
            # Replace OpenCOOD-specific modules (PyramidFusion, etc.)
            if type(child_module) in opencood_specials:
                setattr(module, name, 
                       opencood_specials[type(child_module)](
                           child_module, weight_quant_params, act_quant_params
                       ))
            
            # Replace Conv2d and Linear
            elif isinstance(child_module, (nn.Conv2d, nn.Linear)):
                setattr(module, name, 
                       QuantModule(child_module, weight_quant_params, act_quant_params))
                prev_quantmodule = getattr(module, name)
            
            # Fuse ReLU into previous QuantModule
            elif isinstance(child_module, (nn.ReLU, nn.ReLU6)):
                if prev_quantmodule is not None:
                    prev_quantmodule.activation_function = child_module
                    setattr(module, name, StraightThrough())  # Replace with passthrough
                else:
                    continue
            
            elif isinstance(child_module, StraightThrough):
                continue
            
            else:
                # Recursively process sub-modules
                self.quant_module_refactor(child_module, weight_quant_params, act_quant_params)
    
    def set_quant_state(self, weight_quant: bool = False, act_quant: bool = False):
        """
        Enable/disable quantization for all QuantModules.
        
        Args:
            weight_quant: Enable weight quantization
            act_quant: Enable activation quantization
        """
        for m in self.model.modules():
            if isinstance(m, (QuantModule, BaseQuantBlock)):
                m.set_quant_state(weight_quant, act_quant)
    
    def forward(self, input):
        return self.model(input)
    
    def disable_network_output_quantization(self):
        """
        Disable activation quantization for the last 3 QuantModules (detection heads).
        
        This leaves the outputs of the classification, regression, and direction
        heads unquantized for better precision in predictions.
        """
        module_list = []
        for m in self.model.modules():
            if isinstance(m, QuantModule):
                module_list.append(m)
        
        # modules() yields layers in registration order, so the last three
        # entries are the detection heads only if the heads are registered last.
        if len(module_list) >= 3:
            module_list[-1].disable_act_quant = True  # last-registered head (dir_head in the printed model)
            module_list[-2].disable_act_quant = True  # reg_head
            module_list[-3].disable_act_quant = True  # cls_head in the printed model
    
```
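
The traversal pattern in `quant_module_refactor` can be tried on a toy model. The sketch below is a simplification under stated assumptions, not the repository code: `ToyQuantWrapper` is a hypothetical stand-in for `QuantModule` (it carries no quantizers), `wrap_layers` is a hypothetical name, and `nn.Identity` plays the role of `StraightThrough`.

```python
import torch
import torch.nn as nn


class ToyQuantWrapper(nn.Module):
    """Hypothetical stand-in for QuantModule: wraps a layer, may absorb a ReLU."""

    def __init__(self, layer: nn.Module):
        super().__init__()
        self.layer = layer
        self.activation_function = None  # set when a following ReLU is absorbed

    def forward(self, x):
        out = self.layer(x)
        if self.activation_function is not None:
            out = self.activation_function(out)
        return out


def wrap_layers(module: nn.Module) -> None:
    """Recursively wrap Conv2d/Linear children and absorb the ReLU after them."""
    prev_wrapper = None  # last wrapped layer within this parent container
    for name, child in module.named_children():
        if isinstance(child, (nn.Conv2d, nn.Linear)):
            wrapper = ToyQuantWrapper(child)
            setattr(module, name, wrapper)  # replace the child in place
            prev_wrapper = wrapper
        elif isinstance(child, nn.ReLU):
            if prev_wrapper is not None:
                prev_wrapper.activation_function = child
                setattr(module, name, nn.Identity())  # passthrough placeholder
        else:
            wrap_layers(child)  # descend into containers and custom blocks


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(4, 2, kernel_size=1), nn.ReLU()),
)
x = torch.randn(1, 3, 8, 8)
expected = model(x)

wrap_layers(model)
print(model)  # Conv2d layers are now wrapped; ReLUs are replaced by Identity
```

Because `prev_wrapper` is local to each call, a ReLU is absorbed only when it is a sibling of the wrapped layer inside the same parent container; the wrapped model still produces the same output as before.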

### Key Methods

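1. **`__init__`**: 
   - Wraps FP32 model
   - Folds BatchNorm via `search_fold_and_remove_bn` when `is_fusing=True`
   - Calls `quant_module_refactor` (or `quant_module_refactor_wo_fuse` when `is_fusing=False`) to replace layers

2. **`quant_module_refactor`**:
   - Recursively traverses model
   - Skips children listed in `specials_unquantized_names` (e.g., the codebook)
   - Replaces Conv2d/Linear with QuantModule
   - Replaces custom blocks with Quant versions
   - Fuses ReLU into the preceding QuantModule (BatchNorm is already folded in `__init__`)

3. **`set_quant_state`**:
   - Enables/disables weight and activation quantization for every `QuantModule` and `BaseQuantBlock`
   - `(False, False)`: quantizers bypassed, used during calibration
   - `(True, True)`: weights and activations quantized, used during inference

4. **`disable_network_output_quantization`**:
   - Sets `disable_act_quant = True` on the last three `QuantModule`s (the detection heads)
   - Keeps the head outputs unquantized, which improves prediction accuracy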
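
The BatchNorm folding that `__init__` delegates to `search_fold_and_remove_bn` rests on a simple identity: an affine layer followed by inference-mode BatchNorm is itself an affine layer. The sketch below checks this for a linear layer in plain Python. `fold_bn_into_linear`, `linear`, and `batchnorm_eval` are hypothetical helpers written for illustration; the repository routine, which operates on `nn.Module`s, is not reproduced here.

```python
import math
from typing import List, Tuple


def linear(weight: List[List[float]], bias: List[float], x: List[float]) -> List[float]:
    """y = W x + b for a single input vector."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm using running statistics."""
    return [g * (yi - mu) / math.sqrt(v + eps) + bt
            for yi, g, bt, mu, v in zip(y, gamma, beta, mean, var)]


def fold_bn_into_linear(weight, bias, gamma, beta, mean, var,
                        eps=1e-5) -> Tuple[List[List[float]], List[float]]:
    """Return (W', b') such that W' x + b' == BN(W x + b)."""
    folded_w, folded_b = [], []
    for row, b, g, bt, mu, v in zip(weight, bias, gamma, beta, mean, var):
        factor = g / math.sqrt(v + eps)          # per-output-channel rescaling
        folded_w.append([w * factor for w in row])
        folded_b.append((b - mu) * factor + bt)  # a bias appears even if b was 0
    return folded_w, folded_b


# Toy layer: 2 outputs, 3 inputs, no bias (treated as zeros).
W = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
b = [0.0, 0.0]
gamma, beta = [1.2, 0.8], [0.1, -0.3]
running_mean, running_var = [0.4, -0.2], [2.0, 0.5]
x = [1.0, -2.0, 0.5]

reference = batchnorm_eval(linear(W, b, x), gamma, beta, running_mean, running_var)
W_folded, b_folded = fold_bn_into_linear(W, b, gamma, beta, running_mean, running_var)
folded = linear(W_folded, b_folded, x)

print(reference)
print(folded)
```

After folding, only the fused weights and bias remain to be quantized, so the weight quantizer sees the values actually used at inference.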


---
<a id='part-iv'></a>
# Part IV: Complete Workflow

<a id='e2e'></a>
## 11. End-to-End Training Workflow

### Complete Pipeline (Start to Finish)

```bash
# ============================================================
# STAGE 1: Full-Precision Pretraining
# ============================================================
python opencood/tools/train.py \
    -y opencood/hypes_yaml/v2x_real/Codebook/stage1/lidar_pyramid_stage1.yaml

# Output: opencood/logs/stage1_model/net_epoch_best.pth

# ============================================================
# STAGE 2: Codebook-Only Training
# ============================================================
python opencood/tools/train_stage2.py \
    --hypes_yaml opencood/hypes_yaml/v2x_real/Codebook/stage2/lidar_pyramid_stage2.yaml \
    --stage1_model opencood/logs/stage1_model/net_epoch_best.pth

# Output: opencood/logs/stage2_model/net_epoch_best.pth

# ============================================================
# STAGE 3: End-to-End Co-Training
# ============================================================
python opencood/tools/train_stage3.py \
    --hypes_yaml opencood/hypes_yaml/v2x_real/Codebook/stage3/lidar_pyramid_stage3.yaml \
    --stage2_model opencood/logs/stage2_model/net_epoch_best.pth

# Output: opencood/logs/stage3_model/net_epoch_best.pth

# ============================================================
# INFERENCE: Full-Precision
# ============================================================
python opencood/tools/inference_mc.py \
    --model_dir opencood/logs/stage3_model \
    --fusion_method intermediate

# ============================================================
# PTQ: W8A8 Quantization
# ============================================================
python opencood/tools/inference_mc_quant.py \
    --model_dir opencood/logs/stage3_model \
    --fusion_method intermediate \
    --num_cali_batches 16 \
    --n_bits_w 8 \
    --n_bits_a 8 \
    --iters_w 5000
```
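
For repeated runs, the same sequence can be driven from Python. The following is a sketch under assumptions: `build_pipeline` and `run_pipeline` are hypothetical helpers, the log directory names are the placeholder paths from the bash block above, and only the flags shown there are used.

```python
import shlex
import subprocess
from typing import List

LOGS = "opencood/logs"
HYPES = "opencood/hypes_yaml/v2x_real/Codebook"


def build_pipeline() -> List[List[str]]:
    """Return the five commands of the workflow, in execution order."""
    stage1_ckpt = f"{LOGS}/stage1_model/net_epoch_best.pth"
    stage2_ckpt = f"{LOGS}/stage2_model/net_epoch_best.pth"
    stage3_dir = f"{LOGS}/stage3_model"
    return [
        # Stage 1: full-precision pretraining
        ["python", "opencood/tools/train.py",
         "-y", f"{HYPES}/stage1/lidar_pyramid_stage1.yaml"],
        # Stage 2: codebook-only training, initialized from stage 1
        ["python", "opencood/tools/train_stage2.py",
         "--hypes_yaml", f"{HYPES}/stage2/lidar_pyramid_stage2.yaml",
         "--stage1_model", stage1_ckpt],
        # Stage 3: end-to-end co-training, initialized from stage 2
        ["python", "opencood/tools/train_stage3.py",
         "--hypes_yaml", f"{HYPES}/stage3/lidar_pyramid_stage3.yaml",
         "--stage2_model", stage2_ckpt],
        # Full-precision inference
        ["python", "opencood/tools/inference_mc.py",
         "--model_dir", stage3_dir, "--fusion_method", "intermediate"],
        # PTQ (W8A8) inference
        ["python", "opencood/tools/inference_mc_quant.py",
         "--model_dir", stage3_dir, "--fusion_method", "intermediate",
         "--num_cali_batches", "16", "--n_bits_w", "8", "--n_bits_a", "8",
         "--iters_w", "5000"],
    ]


def run_pipeline(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stops the pipeline on failure


run_pipeline(build_pipeline())  # dry run: prints the commands only
```

Passing `dry_run=False` executes the stages in order, and `check=True` aborts the sequence if any stage exits with a non-zero status, so a later stage never starts from a missing checkpoint.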

---

<a id='summary'></a>
## 12. Summary

This notebook provided a comprehensive guide to QuantV2X's 3-stage training pipeline and PTQ:

### Key Takeaways

1. **Stage 1**: Pretrain full-precision baseline (20 epochs, lr=0.002)
2. **Stage 2**: Train codebook ONLY, freeze detector (20 epochs, lr=0.002)
3. **Stage 3**: Co-train everything (10 epochs, lr=0.0002, batch=4)
4. **PTQ**: W8A8 quantization for substantial speedup with <2% accuracy drop

### Resources

- **Paper**: [QuantV2X (arXiv:2509.03704)](https://arxiv.org/abs/2509.03704)
- **Code**: [github.com/ucla-mobility/QuantV2X](https://github.com/ucla-mobility/QuantV2X)
- **Dataset**: [V2X-Real](https://mobility-lab.seas.ucla.edu/v2x-real/)

---

**Happy Training! 🚗⚡**