# BEVFormer++ Project
This notebook implements and verifies the enhanced BEVFormer model with Memory Bank and ConvRNN.

## Data Setup
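To use real data, download the **nuScenes mini** dataset as described below, then set `USE_REAL_DATA = True` in the "Load Data" cell (it defaults to `False`, which uses dummy data):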
1. Register/Login at [nuscenes.org](https://www.nuscenes.org/nuscenes#download)
2. Download the "mini" split.
3. Extract it to `./data/nuscenes`.
   Structure should be:
   ```
   data/
     nuscenes/
       maps/
       samples/
       sweeps/
       v1.0-mini/
   ```

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import sys
import importlib

# Add current directory to path
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

from modules.dataset import CarlaDataset, NuScenesDataset
import modules.bevformer
# Force reload to ensure latest changes are picked up
importlib.reload(modules.bevformer)
from modules.bevformer import EnhancedBEVFormer

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. Load Data (Dummy or Real)

In [None]:
USE_REAL_DATA = False  # Set to True to load the real nuScenes mini split instead of dummy data

NUSCENES_VERSION = 'v1.0-mini'
# Directories probed (in order) for the extracted dataset. './data/nuscenes' is
# the nested layout; './data' is the flat layout this cell originally required.
# A root qualifies when it contains the '<version>/' metadata folder.
NUSCENES_ROOT_CANDIDATES = ('data/nuscenes', 'data')

nuscenes_root = next(
    (
        root
        for root in NUSCENES_ROOT_CANDIDATES
        if os.path.isdir(os.path.join(root, NUSCENES_VERSION))
    ),
    None,  # no candidate holds the dataset
)

if USE_REAL_DATA and nuscenes_root is not None:
    print("Loading NuScenes Mini dataset...")
    # NOTE(review): assumes NuScenesDataset treats `dataroot` as the folder that
    # directly contains '<version>/' (as with dataroot='data' before) - confirm.
    dataset = NuScenesDataset(version=NUSCENES_VERSION, dataroot=nuscenes_root)
else:
    # Reached when the flag is off OR when no dataset folder was found.
    print("Using Dummy Data (NuScenes not found or disabled)...")
    dataset = CarlaDataset(root_dir='data', dummy_mode=True)

# Only the first sample is used by the rest of the notebook.
data = dataset[0]

# Move every tensor the model consumes onto the selected device.
imgs = data['img'].to(device) # (Seq, 6, 3, H, W)
intrinsics = data['intrinsics'].to(device)
extrinsics = data['extrinsics'].to(device)
ego_pose = data['ego_pose'].to(device)

print(f"Images shape: {imgs.shape}")

## 2. Initialize Model

In [None]:
# Build the model on the selected device.
model = EnhancedBEVFormer(bev_h=200, bev_w=200, embed_dim=256).to(device)

# Load Checkpoint if available
checkpoint_path = 'checkpoints/latest.pth'
if os.path.exists(checkpoint_path):
    print(f"Loading checkpoint from {checkpoint_path}...")
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # strict=False tolerates key mismatches, so inspect the returned report
    # instead of assuming the weights were actually applied.
    load_report = model.load_state_dict(checkpoint['backbone_state_dict'], strict=False)
    missing_keys = list(load_report.missing_keys)
    unexpected_keys = list(load_report.unexpected_keys)
    if missing_keys or unexpected_keys:
        print(
            f"WARNING: checkpoint only partially matched the model "
            f"({len(missing_keys)} missing, {len(unexpected_keys)} unexpected keys)."
        )
        if missing_keys:
            print(f"  Missing (first 5): {missing_keys[:5]}")
        if unexpected_keys:
            print(f"  Unexpected (first 5): {unexpected_keys[:5]}")
    else:
        print("Checkpoint loaded successfully.")
else:
    print("No checkpoint found. Using random initialization.")

# This notebook only runs inference/visualization, so switch layers such as
# dropout and batch norm to evaluation behaviour.
model.eval()

print("Model initialized.")

## 3. Forward Pass Verification (Sequence)

In [None]:
# Sequence forward pass with gradient tracking disabled.
with torch.no_grad():
    # Prepend a size-1 axis to every input,
    # e.g. imgs: (Seq, 6, 3, H, W) -> (1, Seq, 6, 3, H, W).
    seq_imgs, seq_intrinsics, seq_extrinsics, seq_ego_pose = (
        tensor.unsqueeze(0) for tensor in (imgs, intrinsics, extrinsics, ego_pose)
    )

    if not hasattr(model, 'forward_sequence'):
        # Best-effort: report the problem and leave `bev_output` undefined.
        print("ERROR: forward_sequence method missing despite reload!")
    else:
        bev_seq_output = model.forward_sequence(
            seq_imgs, seq_intrinsics, seq_extrinsics, seq_ego_pose
        )
        # Keep the final entry along axis 1 (treated here as the last frame).
        bev_output = bev_seq_output[:, -1]
        print(f"BEV Sequence Output shape: {bev_seq_output.shape}")
        print(f"Last Frame BEV Output shape: {bev_output.shape}")

## 4. Visualization

In [None]:
# Show the last-frame BEV features as one heat map by averaging over axis 0 of
# the first batch entry. Nothing is drawn when `bev_output` was never created.
if 'bev_output' in locals():
    bev_map = bev_output[0].mean(dim=0).cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    heatmap = ax.imshow(bev_map, cmap='viridis')
    ax.set_title("BEV Feature Map (Last Frame)")
    fig.colorbar(heatmap, ax=ax)
    plt.show()

## 5. Interpretability: Spatial Correspondence

In [None]:
def project_bev_to_cam(bev_pos, intrinsic, extrinsic, img_shape):
    """Project a grid of 3D points into a camera image and keep the visible ones.

    Args:
        bev_pos: Tensor of shape (H, W, 3) holding one 3D point per grid cell.
        intrinsic: (3, 3) matrix applied to the camera-frame points.
        extrinsic: (4, 4) matrix applied to the homogeneous input points.
        img_shape: Indexable as (height, width); used only for bounds checks.

    Returns:
        Tuple ``(u, v)`` of 1-D tensors with the pixel coordinates of the points
        that have positive depth and fall inside ``[0, width) x [0, height)``.
    """
    # Unpacking fails for anything that is not rank 3, mirroring the (H, W, 3) contract.
    grid_h, grid_w, _ = bev_pos.shape

    flat_points = bev_pos.reshape(-1, 3)  # (N, 3)
    unit_column = torch.ones(flat_points.shape[0], 1, device=flat_points.device)
    homogeneous = torch.cat([flat_points, unit_column], dim=1)  # (N, 4)

    # Apply the extrinsic to row vectors, then drop the homogeneous coordinate.
    in_camera = (homogeneous @ extrinsic.T)[:, :3]  # (N, 3)

    # Apply the intrinsic; the third column is the depth used for the divide.
    projected = in_camera @ intrinsic.T  # (N, 3)
    z = projected[:, 2]
    safe_depth = z + 1e-5  # small offset keeps the division finite at z == 0
    u = projected[:, 0] / safe_depth
    v = projected[:, 1] / safe_depth

    img_height, img_width = img_shape[0], img_shape[1]
    in_front = z > 0
    inside_cols = (u >= 0) & (u < img_width)
    inside_rows = (v >= 0) & (v < img_height)
    visible = in_front & inside_cols & inside_rows
    return u[visible], v[visible]

# Get BEV grid from model. Detach so the projected coordinates can be turned
# into NumPy arrays even if `bev_pos` is stored as a learnable parameter.
bev_pos = model.bev_pos[0].detach() # (H, W, 3)

# Visualize on the first camera
cam_idx = 0
img_tensor = imgs[-1, cam_idx].cpu().permute(1, 2, 0).numpy() # Use last frame image, channels last

# Min-max normalize for display; the floor on the range avoids a 0/0 (NaN image)
# when every pixel has the same value.
img_min, img_max = img_tensor.min(), img_tensor.max()
img_tensor = (img_tensor - img_min) / max(img_max - img_min, 1e-8)

intrinsic = intrinsics[-1, cam_idx]
extrinsic = extrinsics[-1, cam_idx]

# Bounds-check projected points against the size of the image actually shown
# rather than a hard-coded resolution.
img_hw = tuple(img_tensor.shape[:2])  # (height, width)
u, v = project_bev_to_cam(bev_pos, intrinsic, extrinsic, img_hw)

plt.figure(figsize=(12, 6))
plt.imshow(img_tensor)
plt.scatter(u.cpu().numpy(), v.cpu().numpy(), s=1, c='red', alpha=0.5)
plt.title(f"BEV Grid Projection on Camera {cam_idx}")
plt.show()