In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input mount and print the full path of every file found.
# The directory names yielded by os.walk (second tuple element) are not needed here.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install -q pyquaternion

In [None]:
import json
import torch
from torch.utils.data import Dataset, DataLoader
from pyquaternion import Quaternion
import math
from tqdm.auto import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# Device setup
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)

# ========== DATASET PATH SELECTION ==========
# Choose which dataset to use: Kaggle (cloud) or Local (your machine)
# ROOT is the directory that sample_data filenames are joined onto;
# METAROOT is the directory holding the JSON metadata tables.
USE_LOCAL_DATASET = False  # Set to True to use local dataset, False for Kaggle

if USE_LOCAL_DATASET:
    # Local dataset paths (your machine)
    # NOTE: these are machine-specific absolute paths; edit them before running locally.
    ROOT = '/media/taz/One Touch/nuscenes/v1.0-trainval01_blobs/'
    METAROOT = '/media/taz/One Touch/nuscenes/v1.0-trainval_meta/v1.0-trainval/'
    print("üîÑ Using LOCAL dataset")
else:
    # Kaggle dataset paths (cloud)
    # `os` is imported by the first notebook cell, which must run before this one.
    ROOT = '/kaggle/input/nuscences-front-sensors-only/nuscenes_lead_vehicle_distance_data'
    METAROOT = os.path.join(ROOT, 'v1.0-trainval_meta/v1.0-trainval')
    print("‚òÅÔ∏è  Using KAGGLE dataset")

print(f"‚úÖ ROOT: {ROOT}")
print(f"‚úÖ METAROOT: {METAROOT}")

# ========== BEV PARAMETERS ==========
XRANGE = (-100.0, 100.0)  # x-axis range (forward direction) in meters
YRANGE = (-50.0, 50.0)    # y-axis range (lateral direction) in meters
RES = 0.5                 # Resolution: meters per voxel

# Compute grid dimensions
# Number of cells along each axis = covered extent / cell size (truncated to int).
NX = int((XRANGE[1] - XRANGE[0]) / RES)
NY = int((YRANGE[1] - YRANGE[0]) / RES)

print(f"BEV Grid: {NX} x {NY} ({NX*NY} voxels, {RES}m resolution)")

# ========== BEV CHANNEL DEFINITIONS ==========
# Per-sensor channel counts; the BEV builder functions produce arrays with these depths.
BEV_CHANNELS_LIDAR = 4    # count, avg_z, max_z, avg_intensity
BEV_CHANNELS_RADAR = 4    # count, avg_z, avg_doppler, avg_rcs
BEV_CHANNELS_CAM = 8      # placeholder (R, G, B, Mean, Max, Std, Occupancy, Valid)

print(f"BEV Channels: LIDAR={BEV_CHANNELS_LIDAR}, RADAR={BEV_CHANNELS_RADAR}, CAM={BEV_CHANNELS_CAM}")


In [None]:
# ========== KAGGLE PATH DIAGNOSTICS (REMOVE AFTER FIXING) ==========
# Run this cell on Kaggle to find where the JSON files actually are

# Diagnostic only: prints what exists under /kaggle/input, ROOT and a few candidate
# metadata directories so a wrong ROOT/METAROOT can be spotted. Nothing here is
# consumed by later cells.
if not USE_LOCAL_DATASET:
    print("üîç Exploring Kaggle dataset structure...")
    print("="*70)

    # Check what's in /kaggle/input
    print("\nüìÇ Contents of /kaggle/input:")
    try:
        for item in os.listdir('/kaggle/input'):
            print(f"   üìÅ {item}/")
    except Exception as e:
        print(f"   Error: {e}")

    # Check ROOT (only the first 20 entries are listed to keep the output short)
    print(f"\nüìÇ Contents of ROOT ({ROOT}):")
    print(f"   Exists: {os.path.exists(ROOT)}")
    try:
        items = os.listdir(ROOT)
        for item in items[:20]:
            full_path = os.path.join(ROOT, item)
            if os.path.isdir(full_path):
                print(f"   üìÅ {item}/")
            else:
                print(f"   üìÑ {item}")
        if len(items) > 20:
            print(f"   ... and {len(items) - 20} more")
    except Exception as e:
        print(f"   Error: {e}")

    # Try checking some common paths
    test_paths = [
        ROOT,
        os.path.join(ROOT, 'v1.0-trainval_meta'),
        os.path.join(ROOT, 'v1.0-trainval_meta/v1.0-trainval'),
        os.path.join(ROOT, 'v1.0-trainval'),
        METAROOT,
    ]

    print(f"\nüîç Testing potential metadata paths:")
    for path in test_paths:
        exists = os.path.exists(path)
        print(f"   {'‚úÖ' if exists else '‚ùå'} {path}")
        if exists and os.path.isdir(path):
            try:
                files = os.listdir(path)
                json_files = [f for f in files if f.endswith('.json')]
                json_count = len(json_files)
                print(f"      ({len(files)} items, {json_count} .json files)")
                if json_count > 0:
                    print(f"      Sample files: {json_files[:3]}")
            except OSError as e:
                # Report listing failures (permissions, races) instead of hiding them;
                # catching only OSError also lets KeyboardInterrupt propagate.
                print(f"      Could not list directory: {e}")

    print("="*70)
else:
    print("Skipping diagnostic (running locally)")

In [None]:
# ========== LOAD JSON UTILITY FUNCTION ==========
def load_json(name):
    """Read one JSON metadata table from the METAROOT directory.

    Args:
        name: File name of the table inside METAROOT (e.g. 'sample.json').

    Returns:
        The parsed JSON content.

    Raises:
        FileNotFoundError: If the file does not exist. Before raising, up to ten
            entries of METAROOT are printed to help locate the right directory.
    """
    json_path = os.path.join(METAROOT, name)

    # Happy path first: the file is there, so announce and parse it.
    if os.path.exists(json_path):
        print(f'Loading: {json_path}')
        with open(json_path, 'r') as handle:
            return json.load(handle)

    # Missing file: show what the metadata directory actually contains, then fail.
    print(f"‚ùå File not found: {json_path}")
    print(f"Available files in {METAROOT}:")
    if os.path.exists(METAROOT):
        for entry in os.listdir(METAROOT)[:10]:
            print(f"  - {entry}")
    raise FileNotFoundError(f"Cannot find {name} in {METAROOT}")

# ========== DATASET LOADING ==========
# Load metadata - works for both local and Kaggle datasets
# Uses ROOT and METAROOT paths defined in cell 3 via USE_LOCAL_DATASET flag

# Local branch: verify that ROOT and METAROOT exist, then read the seven metadata
# tables into notebook-level variables. Later cells read `samples`, `sample_data`,
# `calibrated_sensor` and `ego_pose` directly.
if USE_LOCAL_DATASET:
    print("\n" + "="*70)
    print("üîÑ LOADING LOCAL DATASET")
    print("="*70)
    
    print(f"\nüìÇ Using dataset paths (defined in cell 3):")
    print(f"   ROOT: {ROOT}")
    print(f"   METAROOT: {METAROOT}")
    
    # Check if paths exist
    # Fail fast with a clear message rather than letting a later load_json call fail.
    if not os.path.exists(ROOT):
        print(f"\n‚ùå ERROR: ROOT does not exist!")
        print(f"   Expected: {ROOT}")
        raise FileNotFoundError(f"Local dataset not found at {ROOT}")
    
    if not os.path.exists(METAROOT):
        print(f"\n‚ùå ERROR: METAROOT does not exist!")
        print(f"   Expected: {METAROOT}")
        raise FileNotFoundError(f"Local metadata not found at {METAROOT}")
    
    print("‚úÖ Paths verified")
    
    print("\nüì• Loading local metadata...")
    # Each call reads <METAROOT>/<name> and raises FileNotFoundError if it is absent.
    samples = load_json('sample.json')
    sample_data = load_json('sample_data.json')
    calibrated_sensor = load_json('calibrated_sensor.json')
    ego_pose = load_json('ego_pose.json')
    sensor = load_json('sensor.json')
    sample_annotation = load_json('sample_annotation.json')
    scene = load_json('scene.json')
    
    print('‚úÖ Loaded tables:', len(samples), 'samples,', len(sample_data), 'sample_data')
else:
    # Kaggle dataset - metadata will be loaded in cell 5
    print("üìç Skipping local loading (will load Kaggle dataset in next cell)")

In [None]:
# ========== KAGGLE DATASET LOADING ==========
# Load metadata from Kaggle dataset (only runs when USE_LOCAL_DATASET = False)

# Kaggle branch: verify that ROOT and METAROOT exist, then read the seven metadata
# tables into notebook-level variables. Later cells read `samples`, `sample_data`,
# `calibrated_sensor` and `ego_pose` directly.
if not USE_LOCAL_DATASET:
    print("\n" + "="*70)
    print("‚òÅÔ∏è  LOADING KAGGLE DATASET")
    print("="*70)
    
    print(f"\nüìÇ Using dataset paths (defined in cell 3):")
    print(f"   ROOT: {ROOT}")
    print(f"   METAROOT: {METAROOT}")
    
    # Check if paths exist
    # Fail fast with a clear message rather than letting a later load_json call fail.
    if not os.path.exists(ROOT):
        print(f"\n‚ùå ERROR: ROOT does not exist!")
        print(f"   Expected: {ROOT}")
        raise FileNotFoundError(f"Kaggle dataset not found at {ROOT}")
    
    if not os.path.exists(METAROOT):
        print(f"\n‚ùå ERROR: METAROOT does not exist!")
        print(f"   Expected: {METAROOT}")
        raise FileNotFoundError(f"Kaggle metadata not found at {METAROOT}")
    
    print("‚úÖ Paths verified")
    
    print("\nüì• Loading Kaggle metadata...")
    # Each call reads <METAROOT>/<name> and raises FileNotFoundError if it is absent.
    samples = load_json('sample.json')
    sample_data = load_json('sample_data.json')
    calibrated_sensor = load_json('calibrated_sensor.json')
    ego_pose = load_json('ego_pose.json')
    sensor = load_json('sensor.json')
    sample_annotation = load_json('sample_annotation.json')
    scene = load_json('scene.json')
    
    print('‚úÖ Loaded tables:', len(samples), 'samples,', len(sample_data), 'sample_data')
    print("="*70)
else:
    print("\nüìç Skipping Kaggle loading (using local dataset)")

In [None]:
# Index each metadata table by its 'token' field so a record can be fetched directly
# by token. If two records share a token, the later one in the table wins.
sd_by_token = dict((row['token'], row) for row in sample_data)
cs_by_token = dict((row['token'], row) for row in calibrated_sensor)
ep_by_token = dict((row['token'], row) for row in ego_pose)
sample_by_token = dict((row['token'], row) for row in samples)

def sensor_channel_from_filename(fn):
    """Extract the sensor channel (second path component) from a sample_data filename.

    The channel is the folder name in a path shaped like
    ``<top-level folder>/<CHANNEL>/<file name>``.

    Args:
        fn: Relative filename using '/' separators.

    Returns:
        The channel folder name, or None when `fn` is not a string or has fewer
        than three '/'-separated components (so there is no channel folder).
    """
    if not isinstance(fn, str):
        return None

    parts = fn.split('/')
    # With fewer than three components parts[1] is either missing (IndexError)
    # or is the file name itself, not a channel folder.
    if len(parts) < 3:
        return None
    return parts[1]  # folder name

# sample_token -> {channel: sample_data_token}
#
# Keyframe records always take the (sample, channel) slot; a non-keyframe record is
# only used when the slot is still empty, so it can never overwrite a keyframe.
# The flag is read as 'is_key_frame' (the nuScenes schema spelling) with
# 'is_keyframe' accepted as a fallback. If neither key is present, the first record
# seen for a slot is kept.
sample_to_sensor = {}
bad_sd = 0
for sd in sample_data:
    samp_tok = sd.get('sample_token')
    fn = sd.get('filename', '')
    # A filename without '/' has no channel folder; skip the helper call for it.
    chan = sensor_channel_from_filename(fn) if '/' in fn else None
    if samp_tok is None or chan is None:
        bad_sd += 1
        continue

    is_key = bool(sd.get('is_key_frame', sd.get('is_keyframe', False)))
    channels = sample_to_sensor.setdefault(samp_tok, {})
    if not is_key and chan in channels:
        continue  # Slot already filled; never replace it with a non-keyframe
    channels[chan] = sd['token']

print('Built sample_to_sensor for', len(sample_to_sensor), 'samples (skipped', bad_sd, 'sample_data records)')
first_sample_token = samples[0]['token']
print('Channels for first sample:', list(sample_to_sensor.get(first_sample_token, {}).keys()))


In [None]:
# ========== HELPER: SENSOR PATH RESOLUTION ==========
def abs_sensor_path(sd_rec):
    """Build the on-disk path of a sample_data record's sensor file.

    Args:
        sd_rec: Record exposing ``.get('filename', default)``; a missing
            'filename' is treated as the empty string.

    Returns:
        ROOT joined with the record's filename.
    """
    relative_name = sd_rec.get('filename', '')
    absolute_name = os.path.join(ROOT, relative_name)
    return absolute_name


# ========== LOADERS: LIDAR + RADAR ==========
def load_lidar_points(sd_rec):
    """Read a LIDAR file as flat float32 values and return (N, 4) [x, y, z, intensity].

    The file is interpreted as rows of 5 float32 values; only the first four columns
    are kept. An empty (0, 4) array is returned when the file is missing, empty, or
    its length is not a multiple of 5.

    Args:
        sd_rec: sample_data record for the LIDAR sweep.

    Returns:
        float32 array of shape (N, 4) with all values finite and within [-1000, 1000].
    """
    no_points = np.zeros((0, 4), dtype=np.float32)

    lidar_file = abs_sensor_path(sd_rec)
    if not os.path.exists(lidar_file):
        return no_points

    raw = np.fromfile(lidar_file, dtype=np.float32)
    floats_per_point = 5  # Standard nuScenes LIDAR: x,y,z,intensity,ring
    if raw.size == 0 or raw.size % floats_per_point != 0:
        return no_points

    xyzi = raw.reshape(-1, floats_per_point)[:, :4]
    # Clip first, then replace NaNs with 0 (clipping leaves NaN untouched, and
    # maps +/-inf onto the +/-1000 bounds).
    xyzi = np.nan_to_num(np.clip(xyzi, -1000, 1000), nan=0.0, posinf=0.0, neginf=0.0)
    return xyzi.astype(np.float32)


def load_radar_points(sd_rec):
    """Read a RADAR file as flat float32 values and return (N, 5) [x, y, z, doppler, rcs].

    The file is reshaped to the first row width in (18, 20, 24) that divides the
    number of values. Columns 0/1/2 are used as x/y/z, columns 6/7 as the velocity
    components and column 8 as rcs. ``doppler`` is sqrt(vx^2 + vy^2) with the squared
    speed capped at 1e3. An empty (0, 5) array is returned when the file is missing,
    empty, has no matching row width, or any exception occurs while parsing.

    NOTE(review): this assumes the file is a headerless float32 dump and that
    rcs lives in column 8 - confirm both against the actual radar file format.

    Args:
        sd_rec: sample_data record for the RADAR sweep.

    Returns:
        float32 array of shape (N, 5) with all values finite.
    """
    radar_file = abs_sensor_path(sd_rec)
    if not os.path.exists(radar_file):
        return np.zeros((0, 5), dtype=np.float32)

    try:
        raw = np.fromfile(radar_file, dtype=np.float32)
        if raw.size == 0:
            return np.zeros((0, 5), dtype=np.float32)

        # Try common RADAR formats: take the first width that divides evenly
        width = None
        for candidate in (18, 20, 24):
            if raw.size % candidate == 0:
                width = candidate
                break
        if width is None:
            return np.zeros((0, 5), dtype=np.float32)

        table = raw.reshape(-1, width)
        n_rows = table.shape[0]

        def column(idx):
            # Column `idx` of the table, or zeros if the table is too narrow.
            return table[:, idx] if table.shape[1] > idx else np.zeros(n_rows)

        def bounded(values, limit):
            # Clip to [-limit, limit], then turn any remaining NaN into 0.
            clipped = np.clip(values, -limit, limit)
            return np.nan_to_num(clipped, nan=0.0, posinf=0.0, neginf=0.0)

        x = bounded(column(0), 100)
        y = bounded(column(1), 100)
        z = bounded(column(2), 10)
        vx = bounded(column(6), 50)
        vy = bounded(column(7), 50)
        rcs = bounded(column(8), 50)

        # Doppler magnitude from the two (already finite) velocity components
        speed_sq = np.clip(vx**2 + vy**2, 0, 1e3)
        doppler = np.nan_to_num(np.sqrt(speed_sq), nan=0.0, posinf=0.0, neginf=0.0)

        stacked = np.stack([x, y, z, doppler, rcs], axis=1)
        return stacked.astype(np.float32)
    except Exception:
        return np.zeros((0, 5), dtype=np.float32)


In [None]:
# ========== BEV BUILDERS ==========

def compute_distance_to_lead_vehicle(pts_lidar):
    """Return the smallest x among points with x > 0.5, clamped to [1, 50].

    Column 0 of `pts_lidar` is used as the forward coordinate. No lateral or height
    filtering is applied, so any point ahead of the 0.5 m margin counts.

    NOTE(review): the points are used in whatever frame the loader returns them;
    confirm that column 0 really is the forward axis in that frame.

    Args:
        pts_lidar: (N, >=1) array whose first column is x.

    Returns:
        50.0 when there are no points or none with x > 0.5; otherwise the minimum
        such x clipped to [1.0, 50.0].
    """
    default_distance = 50.0  # Used whenever nothing is found in front

    if pts_lidar.shape[0] == 0:
        return default_distance

    forward_x = pts_lidar[:, 0]
    ahead = forward_x[forward_x > 0.5]  # 0.5 m ego margin
    if ahead.size == 0:
        return default_distance

    return np.clip(ahead.min(), 1.0, 50.0)


def build_lidar_bev_from_points(pts_lidar, nx=NX, ny=NY, res=RES, xrange=XRANGE, yrange=YRANGE):
    """
    Build LIDAR BEV occupancy grid (4 channels: count, avg_z, max_z, avg_intensity).

    Points are binned into cells of size `res`; cell (0, 0) starts at
    (xrange[0], yrange[0]). Points falling outside the grid are dropped.
    Cells that receive no point are 0 in every channel.

    Args:
        pts_lidar: (N, 4) array [x, y, z, intensity]
        nx, ny: BEV grid size
        res: resolution (m/cell)
        xrange, yrange: (min, max) ranges; only the minimum is used as grid origin

    Returns:
        (4, nx, ny) float32 numpy array
    """
    bev = np.zeros((4, nx, ny), dtype=np.float32)

    if pts_lidar.shape[0] == 0:
        return bev

    x, y, z, intensity = pts_lidar[:, 0], pts_lidar[:, 1], pts_lidar[:, 2], pts_lidar[:, 3]

    # Map coordinates to grid indices. floor() is required: a plain integer cast
    # truncates toward zero, which would put points lying up to one cell *below*
    # the range minimum into cell 0 instead of rejecting them.
    ix = np.floor((x - xrange[0]) / res).astype(np.int32)
    iy = np.floor((y - yrange[0]) / res).astype(np.int32)

    # Filter in-bounds points
    valid = (ix >= 0) & (ix < nx) & (iy >= 0) & (iy < ny)
    ix, iy = ix[valid], iy[valid]
    z, intensity = z[valid], intensity[valid]

    # Channel 0: point count per cell; channels 1 and 3 start as per-cell sums.
    # ufunc.at is used because several points may share one cell.
    np.add.at(bev[0], (ix, iy), 1)
    np.add.at(bev[1], (ix, iy), z)
    np.add.at(bev[3], (ix, iy), intensity)

    occupied = bev[0] > 0

    # Channel 2: max z per cell. Accumulate from -inf so that cells whose points
    # all have negative z report their true maximum rather than 0, then reset
    # the cells that never received a point back to 0.
    max_z = np.full((nx, ny), -np.inf, dtype=np.float32)
    np.maximum.at(max_z, (ix, iy), z)
    bev[2] = np.where(occupied, max_z, 0.0)

    # Turn the sums into averages on occupied cells only (avoids division by zero).
    bev[1] = np.divide(bev[1], bev[0], where=occupied, out=np.zeros_like(bev[1]))
    bev[3] = np.divide(bev[3], bev[0], where=occupied, out=np.zeros_like(bev[3]))

    # Ensure finite
    bev = np.nan_to_num(bev, nan=0.0, posinf=0.0, neginf=0.0)
    return bev


def build_radar_bev_from_points(pts_radar, nx=NX, ny=NY, res=RES, xrange=XRANGE, yrange=YRANGE):
    """
    Build RADAR BEV occupancy grid (4 channels: count, avg_z, avg_doppler, avg_rcs).

    Points are binned into cells of size `res`; cell (0, 0) starts at
    (xrange[0], yrange[0]). Points falling outside the grid are dropped.
    Cells that receive no point are 0 in every channel.

    Args:
        pts_radar: (N, 5) array [x, y, z, doppler, rcs]
        nx, ny: BEV grid size
        res: resolution (m/cell)
        xrange, yrange: (min, max) ranges; only the minimum is used as grid origin

    Returns:
        (4, nx, ny) float32 numpy array
    """
    bev = np.zeros((4, nx, ny), dtype=np.float32)

    if pts_radar.shape[0] == 0:
        return bev

    x, y, z, doppler, rcs = (pts_radar[:, i] for i in range(5))

    # Map coordinates to grid indices. floor() is required: a plain integer cast
    # truncates toward zero, which would put points lying up to one cell *below*
    # the range minimum into cell 0 instead of rejecting them.
    ix = np.floor((x - xrange[0]) / res).astype(np.int32)
    iy = np.floor((y - yrange[0]) / res).astype(np.int32)

    # Filter in-bounds points
    valid = (ix >= 0) & (ix < nx) & (iy >= 0) & (iy < ny)
    ix, iy = ix[valid], iy[valid]
    z, doppler, rcs = z[valid], doppler[valid], rcs[valid]

    # Channel 0: point count per cell; channels 1-3 start as per-cell sums.
    # ufunc.at is used because several points may share one cell.
    np.add.at(bev[0], (ix, iy), 1)
    np.add.at(bev[1], (ix, iy), z)
    np.add.at(bev[2], (ix, iy), doppler)
    np.add.at(bev[3], (ix, iy), rcs)

    # Turn the sums into averages on occupied cells only (avoids division by zero).
    occupied = bev[0] > 0
    for channel in (1, 2, 3):
        bev[channel] = np.divide(
            bev[channel], bev[0], where=occupied, out=np.zeros_like(bev[channel])
        )

    # Ensure finite
    bev = np.nan_to_num(bev, nan=0.0, posinf=0.0, neginf=0.0)
    return bev


def build_camera_bev_placeholder(nx=NX, ny=NY):
    """
    Return an all-zero stand-in for the camera BEV.

    No image data is read; the array only has the channel depth given by
    BEV_CHANNELS_CAM so it can be returned alongside the other BEVs.

    Args:
        nx, ny: BEV grid size

    Returns:
        (BEV_CHANNELS_CAM, nx, ny) float32 numpy array of zeros
    """
    placeholder_shape = (BEV_CHANNELS_CAM, nx, ny)
    return np.zeros(placeholder_shape, dtype=np.float32)


# Status banner: printed once the BEV builder definitions in this cell have executed.
print("‚úÖ BEV BUILDERS READY")


In [None]:
# ========== DATASET CLASS ==========

class NuScenesBEVDataset(Dataset):
    """
    Torch dataset yielding per-sample BEV tensors and a distance target.

    Each item is the tuple (lidar_bev, camera_bev, radar_bev, distance_label).
    The distance label is not read from annotations: it is derived from the
    LIDAR points by compute_distance_to_lead_vehicle.
    """

    def __init__(self, sample_tokens):
        """Keep only the tokens whose LIDAR_TOP file is present on disk.

        Args:
            sample_tokens: iterable of sample tokens to consider.
        """
        kept = []
        for sample_token in sample_tokens:
            lidar_sd_token = sample_to_sensor.get(sample_token, {}).get("LIDAR_TOP", None)
            # Need both a LIDAR_TOP entry and an existing file behind it.
            if lidar_sd_token is not None and os.path.exists(
                abs_sensor_path(sd_by_token[lidar_sd_token])
            ):
                kept.append(sample_token)

        self.sample_tokens = kept
        print(f"‚úÖ Dataset: {len(self.sample_tokens)} samples with LIDAR_TOP")

    def __len__(self):
        """Number of samples that passed the LIDAR file check."""
        return len(self.sample_tokens)

    def __getitem__(self, idx):
        """Load sensor data for sample `idx` and convert it to BEV tensors.

        Returns:
            Tuple of (lidar_bev, camera_bev, radar_bev, distance) where the BEVs are
            tensors created from the numpy BEV arrays and distance is a float32
            scalar tensor.
        """
        channel_tokens = sample_to_sensor[self.sample_tokens[idx]]

        # ----- LIDAR_TOP: points -> BEV, and the same points -> distance label -----
        lidar_points = load_lidar_points(sd_by_token[channel_tokens["LIDAR_TOP"]])
        lidar_bev = build_lidar_bev_from_points(lidar_points)
        distance = compute_distance_to_lead_vehicle(lidar_points)

        # ----- RADAR: merge whichever front radars this sample has -----
        radar_clouds = [
            load_radar_points(sd_by_token[channel_tokens[channel]])
            for channel in ("RADAR_FRONT", "RADAR_FRONT_LEFT", "RADAR_FRONT_RIGHT")
            if channel in channel_tokens
        ]
        radar_clouds = [cloud for cloud in radar_clouds if cloud.shape[0] > 0]
        if radar_clouds:
            radar_points = np.concatenate(radar_clouds, axis=0)
        else:
            radar_points = np.zeros((0, 5), dtype=np.float32)
        radar_bev = build_radar_bev_from_points(radar_points)

        # ----- CAMERA: zero-filled placeholder -----
        cam_bev = build_camera_bev_placeholder()

        return (
            torch.from_numpy(lidar_bev),
            torch.from_numpy(cam_bev),
            torch.from_numpy(radar_bev),
            torch.tensor(distance, dtype=torch.float32)
        )

# Status banner: printed once the dataset class definition in this cell has executed.
print("‚úÖ DATASET CLASS READY (using computed distance labels)")


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
%matplotlib inline

def visualize_bev_channels(bev_array, sensor_name="", title_prefix=""):
    """
    Plot every channel of one BEV array side by side.

    Each subplot title shows the channel label, the number of strictly positive
    cells and the channel maximum.

    Args:
        bev_array: 3-D array indexed as (channel, row, column)
        sensor_name: 'LIDAR', 'RADAR' or 'CAMERA' selects named channel labels;
            anything else falls back to generic 'Ch i' labels
        title_prefix: text prepended to every subplot title

    Returns:
        (fig, axes); axes is a one-element list when there is a single channel
    """
    n_channels, _rows, _cols = bev_array.shape
    fig, axes = plt.subplots(1, n_channels, figsize=(15, 3))
    if n_channels == 1:
        # plt.subplots returns a bare Axes for a 1x1 grid; make it indexable.
        axes = [axes]

    known_labels = {
        'LIDAR': ['Count', 'Avg Height', 'Max Height', 'Avg Intensity'],
        'RADAR': ['Count', 'Avg Z', 'Avg Doppler', 'Avg RCS'],
        'CAMERA': ['R', 'G', 'B', 'Mean', 'Max', 'Std', 'Occupancy', 'Valid']
    }
    labels = known_labels.get(sensor_name, [f'Ch {i}' for i in range(n_channels)])
    palette = ['hot', 'cool', 'viridis', 'plasma', 'RdYlGn', 'Blues', 'Purples', 'Greys']

    for idx in range(n_channels):
        ax = axes[idx]
        plane = bev_array[idx]
        positive_cells = (plane > 0).sum()
        peak = plane.max()

        # Fall back to a generic label if the named list is shorter than the array.
        label = labels[idx] if idx < len(labels) else f"Ch{idx}"

        image = ax.imshow(plane, cmap=palette[idx % len(palette)], aspect='auto')
        ax.set_title(f'{title_prefix}{label}\n' +
                     f'Nonzero: {positive_cells} | Max: {peak:.2f}')
        ax.set_xlabel('Y (m)')
        ax.set_ylabel('X (m)')
        plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    return fig, axes

def visualize_bev_fusion(lidar_bev, camera_bev, radar_bev, distance_label=None):
    """
    Draw a 3x5 overview figure of the LIDAR, RADAR and CAMERA BEVs.

    Rows 1-3 show the first four channels of the LIDAR, RADAR and CAMERA BEV
    respectively; the fifth column holds a text panel (LIDAR statistics, RADAR
    statistics, distance label). Averages in the text panels are taken over cells
    whose count channel (channel 0) is positive and are shown as 'nan' when no
    such cell exists.

    Args:
        lidar_bev: (4, NX, NY) LIDAR BEV
        camera_bev: (8, NX, NY) CAMERA BEV
        radar_bev: (4, NX, NY) RADAR BEV
        distance_label: scalar distance to lead vehicle, or None to show 'N/A'

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure(figsize=(20, 12))
    gs = fig.add_gridspec(3, 5, hspace=0.4, wspace=0.3)

    def draw_channel_row(row, bev, sensor, labels, cmap):
        # One heat-map per label in grid row `row`, each with its own colorbar.
        for col, label in enumerate(labels):
            ax = fig.add_subplot(gs[row, col])
            ch = bev[col]
            nonzero = (ch > 0).sum()
            im = ax.imshow(ch, cmap=cmap, aspect='auto')
            ax.set_title(f'{sensor} {label}\n({nonzero} voxels, max: {ch.max():.2f})')
            plt.colorbar(im, ax=ax, fraction=0.046)

    def draw_text_panel(row, text, facecolor):
        # Axis-less text box in the fifth column of grid row `row`.
        ax = fig.add_subplot(gs[row, 4])
        ax.axis('off')
        ax.text(0.05, 0.95, text, transform=ax.transAxes,
                fontfamily='monospace', fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor=facecolor, alpha=0.5))

    def occupied_mean(values, counts):
        # Mean of `values` over cells with a positive count. Returns NaN directly
        # when nothing is occupied, instead of calling .mean() on an empty
        # selection (which emits a RuntimeWarning); the formatted text is the same.
        mask = counts > 0
        return values[mask].mean() if mask.any() else float('nan')

    # Row 1: LIDAR channels + summary
    draw_channel_row(0, lidar_bev, 'LIDAR', ['Count', 'Avg Height', 'Max Height', 'Avg Intensity'], 'hot')
    lidar_stats = f"""LIDAR Statistics:
    ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    Count voxels: {(lidar_bev[0] > 0).sum()}
    Avg height: {occupied_mean(lidar_bev[1], lidar_bev[0]):.3f}m
    Max height: {lidar_bev[2].max():.3f}m
    Intensity: {occupied_mean(lidar_bev[3], lidar_bev[0]):.3f}
    Total energy: {lidar_bev.sum():.1f}"""
    draw_text_panel(0, lidar_stats, 'wheat')

    # Row 2: RADAR channels + summary
    draw_channel_row(1, radar_bev, 'RADAR', ['Count', 'Avg Z', 'Avg Doppler', 'Avg RCS'], 'plasma')
    radar_stats = f"""RADAR Statistics:
    ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    Count voxels: {(radar_bev[0] > 0).sum()}
    Avg Z: {occupied_mean(radar_bev[1], radar_bev[0]):.3f}m
    Max Doppler: {radar_bev[2].max():.3f}m/s
    RCS: {occupied_mean(radar_bev[3], radar_bev[0]):.3f}dB
    Total energy: {radar_bev.sum():.1f}"""
    draw_text_panel(1, radar_stats, 'lightblue')

    # Row 3: CAMERA (first 4 channels) + distance label panel
    draw_channel_row(2, camera_bev, 'CAMERA', ['R', 'G', 'B', 'Mean'], 'viridis')
    if distance_label is not None:
        dist_info = f"""Lead Vehicle:
        ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        Distance: {distance_label:.2f}m
        
        Range: [1m, 50m]
        Status: {'‚úì Valid' if 1.0 <= distance_label <= 50.0 else '‚úó Invalid'}
        """
    else:
        dist_info = """Lead Vehicle:
        ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        Distance: N/A
        """
    draw_text_panel(2, dist_info, 'lightgreen')

    plt.suptitle('Multi-Sensor BEV Fusion Visualization', fontsize=16, fontweight='bold', y=0.995)
    return fig

def verify_and_visualize_bev(dataset, name="dataset", max_samples=3):
    """Load the first few dataset items, print per-sample and summary stats, and plot them.

    For each of the first min(max_samples, len(dataset)) items the BEVs are converted
    to numpy (when they expose `.numpy()`), strictly positive entries are counted, and
    a fusion figure is shown. Any exception for a sample is printed with its traceback
    and that sample is skipped.

    Args:
        dataset: indexable object yielding (lidar_bev, cam_bev, radar_bev, distance)
        name: label used in the printed headers
        max_samples: upper bound on the number of items inspected

    Returns:
        False if the dataset is empty; otherwise True when at least one inspected
        sample had a LIDAR BEV with a positive entry.
    """
    print(f"\nüé® BEV VERIFICATION: {name} ({len(dataset)} samples)")

    if len(dataset) == 0:
        print("‚ùå Dataset is empty!")
        return False

    n_checked = min(max_samples, len(dataset))

    def as_numpy(obj):
        # Tensors expose .numpy(); anything else is passed through unchanged.
        return obj.numpy() if hasattr(obj, 'numpy') else obj

    lidar_ok = radar_ok = fusion_ok = 0
    distances, all_lidar_voxels, all_radar_voxels = [], [], []

    for i in range(n_checked):
        try:
            lidar_bev, cam_bev, radar_bev, distance = dataset[i]

            lidar_bev = as_numpy(lidar_bev)
            cam_bev = as_numpy(cam_bev)
            radar_bev = as_numpy(radar_bev)
            if hasattr(distance, 'item'):
                distance = distance.item()

            # Count strictly positive entries across all channels
            lidar_vox = (lidar_bev > 0).sum()
            radar_vox = (radar_bev > 0).sum()
            distances.append(distance)
            all_lidar_voxels.append(lidar_vox)
            all_radar_voxels.append(radar_vox)

            lidar_seen = lidar_vox > 0
            radar_seen = radar_vox > 0
            lidar_ok += int(lidar_seen)
            radar_ok += int(radar_seen)
            fusion_ok += int(lidar_seen and radar_seen)

            # Per-sample report
            print(f"\n  Sample {i+1}/{n_checked}:")
            print(f"    LIDAR:    {lidar_vox} voxels | Height: [{lidar_bev[1].min():.2f}, {lidar_bev[1].max():.2f}]m")
            print(f"    RADAR:    {radar_vox} voxels | Doppler: [{radar_bev[2].min():.2f}, {radar_bev[2].max():.2f}]m/s")
            print(f"    Distance: {distance:.2f}m ‚úì")

            # Show fusion visualization
            fig = visualize_bev_fusion(lidar_bev, cam_bev, radar_bev, distance)
            plt.show()

        except Exception as e:
            print(f"  ‚ùå Error loading sample {i}: {e}")
            import traceback
            traceback.print_exc()
            continue

    # Summary statistics
    print(f"\nüìä {name.upper()} VERIFICATION SUMMARY:")
    print(f"  ‚úÖ Samples processed: {n_checked}")
    print(f"  ‚úÖ LIDAR valid:      {lidar_ok}/{n_checked}")
    print(f"  ‚úÖ RADAR valid:      {radar_ok}/{n_checked}")
    print(f"  ‚úÖ Fusion ready:     {fusion_ok}/{n_checked}")

    if len(distances) > 0:
        print(f"\n  üìè Distance Statistics:")
        print(f"    Range:    {min(distances):.2f}m - {max(distances):.2f}m")
        print(f"    Mean:     {np.mean(distances):.2f}m")
        print(f"    Median:   {np.median(distances):.2f}m")
        print(f"    Std:      {np.std(distances):.2f}m")

    # Same three density lines for each sensor, LIDAR first.
    density_sections = (
        (f"\n  üü° LIDAR Point Density:", all_lidar_voxels),
        (f"\n  üî¥ RADAR Detection Density:", all_radar_voxels),
    )
    for header, voxel_counts in density_sections:
        if len(voxel_counts) > 0:
            print(header)
            print(f"    Mean voxels: {np.mean(voxel_counts):.0f}")
            print(f"    Min voxels:  {np.min(voxel_counts):.0f}")
            print(f"    Max voxels:  {np.max(voxel_counts):.0f}")

    print(f"\n{'='*60}")

    return lidar_ok > 0

# Status banner: printed once the visualization helpers in this cell have been defined.
print("‚úÖ ENHANCED BEV VISUALIZATION FUNCTIONS READY")

In [None]:
# ========== HYPERPARAMETERS ==========
# Of these, only BATCH_SIZE is used in this cell.
LR = 1e-4
EPOCHS = 10
BATCH_SIZE = 4
DBEV = 64  # Feature dimension in model

print(f"Hyperparameters: LR={LR}, EPOCHS={EPOCHS}, BATCH_SIZE={BATCH_SIZE}, DBEV={DBEV}")

# ========== CREATE DATASETS + DATALOADERS ==========
# Build 70/15/15 train/val/test split
import random
all_tokens = [s['token'] for s in samples]

# IMPORTANT: Shuffle to distribute LIDAR samples across splits
# (local dataset may only have partial data)
# NOTE: only Python's `random` module is seeded here, which fixes the split.
# The DataLoader's shuffle order is not seeded in this cell.
random.seed(42)  # For reproducibility
random.shuffle(all_tokens)

ntotal = len(all_tokens)
ntrain = int(0.7 * ntotal)
nval = int(0.15 * ntotal)

# The test split takes everything left after train and val, so rounding
# remainders from the int() truncation above end up in the test set.
train_tokens = all_tokens[:ntrain]
val_tokens = all_tokens[ntrain:ntrain+nval]
test_tokens = all_tokens[ntrain+nval:]

print(f"\nüìä Split: {len(train_tokens)} train / {len(val_tokens)} val / {len(test_tokens)} test")

# Create datasets
# Each constructor drops tokens without an existing LIDAR_TOP file, so the dataset
# sizes can be smaller than the token counts printed above.
train_ds = NuScenesBEVDataset(train_tokens)
val_ds = NuScenesBEVDataset(val_tokens)
test_ds = NuScenesBEVDataset(test_tokens)

# Create dataloaders
# Only the training loader shuffles; num_workers=0 loads data in the main process.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Test one batch
print("\nüîç Testing batch loading...")
batch = next(iter(train_loader))
# len(batch) is the number of tensors in the returned tuple (4), not the sample count.
print(f"  Batch size: {len(batch)} tensors")
print(f"  LIDAR BEV shape: {batch[0].shape}")
print(f"  CAMERA BEV shape: {batch[1].shape}")
print(f"  RADAR BEV shape: {batch[2].shape}")
print(f"  Distance shape: {batch[3].shape}")
print("‚úÖ Dataloaders ready!")

In [None]:
# ========== LOADING VALIDATION ==========
# Comprehensive error checking for all loading functions

print("üîç VALIDATING LOADING FUNCTIONS...")
print("="*70)

def validate_loading_pipeline():
    """Smoke-test the sensor loading pipeline on the first training sample.

    Four checks are run against ``train_tokens[0]`` and a report is printed:
      1. sensor file paths resolve and exist (LIDAR_TOP, RADAR_FRONT),
      2. the LIDAR point cloud loads,
      3. each front RADAR point cloud loads,
      4. the per-modality BEV grids build and concatenate (NumPy and PyTorch).

    Returns:
        bool: True when no check recorded an error. Warnings (missing file,
        empty point cloud, sample without sensor records) do not change the
        result. False is also returned when there is no training sample to
        test, so an empty split is never reported as a passing pipeline.
    """

    errors = []
    warnings = []  # non-fatal findings, listed in the summary only
    radar_keys = ['RADAR_FRONT', 'RADAR_FRONT_LEFT', 'RADAR_FRONT_RIGHT']

    # Only the first training sample is exercised.
    test_sample_idx = 0
    test_token = train_tokens[test_sample_idx] if len(train_tokens) > 0 else None
    # Sensor-name -> sample_data token for the test sample ({} when unknown).
    sensors = sample_to_sensor.get(test_token, {}) if test_token else {}

    if not test_token:
        # Without a sample every check below is skipped; that must not count as success.
        errors.append("No training samples available - nothing could be validated")
        print("  ‚ùå No training samples available")
    elif not sensors:
        warnings.append(f"Sample {test_token} has no sensor records; loaders were not exercised")
        print("  ‚ö†Ô∏è  First training sample has no sensor records")

    # Test 1: abs_sensor_path for different sensor types
    print("\n[1/4] Testing sensor path resolution...")
    for sensor_key in ['LIDAR_TOP', 'RADAR_FRONT']:
        if sensor_key not in sensors:
            continue
        try:
            sd_rec = sd_by_token[sensors[sensor_key]]
            path = abs_sensor_path(sd_rec)
            if os.path.exists(path):
                print(f"  ‚úÖ {sensor_key}: {path[:50]}...")
            else:
                warnings.append(f"{sensor_key} path doesn't exist: {path}")
                print(f"  ‚ö†Ô∏è  {sensor_key}: File not found")
        except Exception as e:
            errors.append(f"Error resolving {sensor_key}: {str(e)}")
            print(f"  ‚ùå {sensor_key}: {e}")

    # Test 2: LIDAR loading
    print("\n[2/4] Testing LIDAR loading...")
    if 'LIDAR_TOP' in sensors:
        try:
            sd_rec = sd_by_token[sensors['LIDAR_TOP']]
            pts = load_lidar_points(sd_rec)
            print(f"  ‚úÖ LIDAR: {pts.shape[0]} points, shape: {pts.shape}")
            if pts.shape[0] == 0:
                warnings.append("LIDAR returned empty point cloud")
        except Exception as e:
            errors.append(f"LIDAR loading failed: {str(e)}")
            print(f"  ‚ùå LIDAR loading: {e}")

    # Test 3: RADAR loading (multiple sensors)
    print("\n[3/4] Testing RADAR loading...")
    for radar_key in radar_keys:
        if radar_key not in sensors:
            continue
        try:
            sd_rec = sd_by_token[sensors[radar_key]]
            pts = load_radar_points(sd_rec)
            print(f"  ‚úÖ {radar_key}: {pts.shape[0]} points, shape: {pts.shape}")
            if pts.shape[0] == 0:
                warnings.append(f"{radar_key} returned empty point cloud")
        except Exception as e:
            errors.append(f"{radar_key} loading failed: {str(e)}")
            print(f"  ‚ùå {radar_key}: {e}")

    # Test 4: BEV building and fusion
    print("\n[4/4] Testing BEV building and fusion...")
    if test_token:
        try:
            # Start from all-zero grids so a missing modality still fuses.
            lidar_bev = np.zeros((BEV_CHANNELS_LIDAR, NX, NY), dtype=np.float32)
            radar_bev = np.zeros((BEV_CHANNELS_RADAR, NX, NY), dtype=np.float32)
            camera_bev = np.zeros((BEV_CHANNELS_CAM, NX, NY), dtype=np.float32)

            if 'LIDAR_TOP' in sensors:
                pts_l = load_lidar_points(sd_by_token[sensors['LIDAR_TOP']])
                lidar_bev = build_lidar_bev_from_points(pts_l)

            # Merge the returns of every available front radar into one cloud.
            radar_pts_list = []
            for radar_key in radar_keys:
                if radar_key in sensors:
                    pts_r = load_radar_points(sd_by_token[sensors[radar_key]])
                    if pts_r.shape[0] > 0:
                        radar_pts_list.append(pts_r)

            if radar_pts_list:
                radar_pts = np.concatenate(radar_pts_list, axis=0)
                radar_bev = build_radar_bev_from_points(radar_pts)

            # Test fusion (concatenation along the channel axis)
            fused = np.concatenate([lidar_bev, camera_bev, radar_bev], axis=0)
            print(f"  ‚úÖ LIDAR BEV:  {lidar_bev.shape}")
            print(f"  ‚úÖ CAMERA BEV: {camera_bev.shape}")
            print(f"  ‚úÖ RADAR BEV:  {radar_bev.shape}")
            print(f"  ‚úÖ Fused:      {fused.shape} ({fused.shape[0]} channels)")

            # Same fusion with torch tensors and a leading batch dimension
            t_l = torch.from_numpy(lidar_bev).unsqueeze(0)
            t_c = torch.from_numpy(camera_bev).unsqueeze(0)
            t_r = torch.from_numpy(radar_bev).unsqueeze(0)
            t_fused = torch.cat([t_l, t_c, t_r], dim=1)
            print(f"  ‚úÖ PyTorch fused: {t_fused.shape}")

        except Exception as e:
            errors.append(f"BEV building/fusion failed: {str(e)}")
            print(f"  ‚ùå BEV fusion: {e}")
            import traceback
            traceback.print_exc()

    # Summary
    print("\n" + "="*70)
    if errors:
        print(f"‚ùå ERRORS ({len(errors)}):")
        for err in errors:
            print(f"   ‚Ä¢ {err}")
    else:
        print("‚úÖ NO CRITICAL ERRORS DETECTED")

    if warnings:
        print(f"\n‚ö†Ô∏è  WARNINGS ({len(warnings)}):")
        for warn in warnings:
            print(f"   ‚Ä¢ {warn}")

    print("="*70)
    return len(errors) == 0

# Execute the smoke test once and report the verdict
pipeline_ok = validate_loading_pipeline()

if pipeline_ok:
    print("\n‚úÖ Pipeline validation PASSED. Safe to proceed with training.")
else:
    print("\n‚ùå Pipeline validation failed. Check errors above before proceeding.")


In [None]:
# ========== OPTIONAL: VERIFY PIPELINE BEFORE TRAINING ==========
# Optional sanity check: inspect a few samples of every split.

print("üîç VERIFYING PIPELINE...")
print("=" * 60)

# (split label, dataset, number of samples to inspect)
verification_plan = (
    ("TRAIN", train_ds, 3),
    ("VAL", val_ds, 2),
    ("TEST", test_ds, 2),
)

for split_label, split_ds, split_max in verification_plan:
    # TRAIN is always verified; VAL/TEST only when they contain samples.
    if split_label == "TRAIN" or len(split_ds) > 0:
        print(f"\nüìã {split_label} Dataset:")
        verify_and_visualize_bev(split_ds, split_label, max_samples=split_max)

print("\n" + "=" * 60)
print("‚úÖ PIPELINE VERIFICATION COMPLETE - Ready to train!")
print("=" * 60)


In [None]:
# ========== BEV TRANSFORMATION VISUALIZATION ==========
# Visualize how raw sensor data transforms into Bird's Eye View

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def visualize_bev_transformation(sample_idx=0):
    """
    Draw a 3-row figure tracing one training sample from raw sensors to BEV.

    Row 1: CAM_FRONT image and a top-down scatter of the LIDAR point cloud.
    Row 2: camera calibration, BEV grid configuration and a pipeline diagram.
    Row 3: count channel of the LIDAR and RADAR BEV grids, the camera BEV
           placeholder and an implementation-status panel.

    Args:
        sample_idx: Index into ``train_tokens`` of the sample to visualize.

    Returns:
        The matplotlib Figure. It is not shown here; the caller decides when
        to display or save it.

    Raises:
        IndexError: ``sample_idx`` is outside ``train_tokens``.
        KeyError: the sample has no entry in ``sample_to_sensor``.
    """

    radar_keys = ['RADAR_FRONT', 'RADAR_FRONT_LEFT', 'RADAR_FRONT_RIGHT']

    # Get a sample
    sample_token = train_tokens[sample_idx]
    sensors = sample_to_sensor[sample_token]

    # Check what sensors are available
    has_lidar = 'LIDAR_TOP' in sensors
    has_cam_front = 'CAM_FRONT' in sensors
    # The RADAR panel merges all three front radars, so any one of them is enough.
    has_radar = any(radar_key in sensors for radar_key in radar_keys)

    fig = plt.figure(figsize=(20, 12))
    gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

    # ========== ROW 1: RAW SENSOR DATA ==========

    # 1.1 Camera Image (if available)
    ax_cam = fig.add_subplot(gs[0, :2])
    if has_cam_front:
        try:
            cam_token = sensors['CAM_FRONT']
            cam_rec = sd_by_token[cam_token]
            cam_path = os.path.join(ROOT, cam_rec['filename'])

            if os.path.exists(cam_path):
                # Context manager releases the file handle once the pixels are drawn.
                with Image.open(cam_path) as img:
                    ax_cam.imshow(img)
                    ax_cam.set_title(f'CAM_FRONT Raw Image\n{img.size[0]}√ó{img.size[1]} pixels', 
                                   fontsize=12, fontweight='bold')
                ax_cam.axis('off')
            else:
                ax_cam.text(0.5, 0.5, f'Image not found:\n{cam_path}', 
                          ha='center', va='center', fontsize=10)
                ax_cam.set_title('CAM_FRONT (not found)')
                ax_cam.axis('off')
        except Exception as e:
            ax_cam.text(0.5, 0.5, f'Error loading camera:\n{str(e)}', 
                      ha='center', va='center', fontsize=10)
            ax_cam.set_title('CAM_FRONT (error)')
            ax_cam.axis('off')
    else:
        ax_cam.text(0.5, 0.5, 'CAM_FRONT not available', ha='center', va='center')
        ax_cam.set_title('CAM_FRONT')
        ax_cam.axis('off')

    # 1.2 LIDAR Point Cloud (top-down view)
    ax_lidar_raw = fig.add_subplot(gs[0, 2:])
    lidar_pts = None  # loaded once here and reused for the BEV panel in row 3
    if has_lidar:
        lidar_pts = load_lidar_points(sd_by_token[sensors['LIDAR_TOP']])

        if lidar_pts.shape[0] > 0:
            # Columns 0..2 are used as x (forward), y (left) and z (height)
            x, y, z = lidar_pts[:, 0], lidar_pts[:, 1], lidar_pts[:, 2]

            # Filter to BEV range
            in_range = (x >= XRANGE[0]) & (x < XRANGE[1]) & (y >= YRANGE[0]) & (y < YRANGE[1])

            # Y goes on the horizontal axis so that "forward" points up
            scatter = ax_lidar_raw.scatter(y[in_range], x[in_range], c=z[in_range], s=0.5, 
                                          cmap='jet', vmin=-3, vmax=3)
            ax_lidar_raw.set_xlim(YRANGE)
            ax_lidar_raw.set_ylim(XRANGE)
            ax_lidar_raw.set_xlabel('Y (m) - Left/Right')
            ax_lidar_raw.set_ylabel('X (m) - Forward')
            ax_lidar_raw.set_title(f'LIDAR Point Cloud (Top View)\n{lidar_pts.shape[0]} points, colored by height', 
                                  fontsize=12, fontweight='bold')
            ax_lidar_raw.grid(True, alpha=0.3)
            plt.colorbar(scatter, ax=ax_lidar_raw, label='Height (m)')

            # Add ego vehicle marker
            ego = patches.Rectangle((-1, -1.5), 2, 3, linewidth=2, 
                                   edgecolor='red', facecolor='none', label='Ego Vehicle')
            ax_lidar_raw.add_patch(ego)
            ax_lidar_raw.legend(loc='upper right')
        else:
            ax_lidar_raw.text(0.5, 0.5, 'No LIDAR points', ha='center', va='center')
    else:
        ax_lidar_raw.text(0.5, 0.5, 'LIDAR_TOP not available', ha='center', va='center')
    ax_lidar_raw.set_aspect('equal')

    # ========== ROW 2: TRANSFORMATION STEPS ==========

    # 2.1 Camera Calibration Info
    ax_calib = fig.add_subplot(gs[1, 0])
    ax_calib.axis('off')
    if has_cam_front:
        try:
            cam_token = sensors['CAM_FRONT']
            cam_rec = sd_by_token[cam_token]
            cs_rec = cs_by_token[cam_rec['calibrated_sensor_token']]

            # Extract calibration data (neutral defaults when a field is missing)
            translation = cs_rec.get('translation', [0, 0, 0])
            rotation = cs_rec.get('rotation', [1, 0, 0, 0])
            camera_intrinsic = cs_rec.get('camera_intrinsic', [[0, 0, 0], [0, 0, 0], [0, 0, 0]])

            calib_text = f"""Camera Calibration
‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ
Translation (m):
  X: {translation[0]:.2f}
  Y: {translation[1]:.2f}
  Z: {translation[2]:.2f}

Rotation (quaternion):
  [{rotation[0]:.3f}, {rotation[1]:.3f},
   {rotation[2]:.3f}, {rotation[3]:.3f}]

Intrinsics:
  fx: {camera_intrinsic[0][0]:.1f}
  fy: {camera_intrinsic[1][1]:.1f}
  cx: {camera_intrinsic[0][2]:.1f}
  cy: {camera_intrinsic[1][2]:.1f}
"""
            ax_calib.text(0.05, 0.95, calib_text, transform=ax_calib.transAxes,
                        fontfamily='monospace', fontsize=9, verticalalignment='top',
                        bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
        except Exception as e:
            ax_calib.text(0.5, 0.5, f'Calibration error:\n{str(e)[:50]}', 
                        ha='center', va='center', fontsize=9)
    else:
        ax_calib.text(0.5, 0.5, 'No camera calibration', ha='center', va='center')

    # 2.2 Grid Discretization Info
    ax_grid = fig.add_subplot(gs[1, 1])
    ax_grid.axis('off')
    grid_text = f"""BEV Grid Configuration
‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ
Resolution: {RES}m per pixel

X Range: {XRANGE[0]:.0f}m to {XRANGE[1]:.0f}m
Y Range: {YRANGE[0]:.0f}m to {YRANGE[1]:.0f}m

Grid Size: {NX} √ó {NY}
Total Voxels: {NX * NY:,}

Channels:
  LIDAR: {BEV_CHANNELS_LIDAR} (count, avg_h, max_h, intensity)
  RADAR: {BEV_CHANNELS_RADAR} (count, z, doppler, rcs)
  CAMERA: {BEV_CHANNELS_CAM} (R,G,B,mean,max,std,occ,valid)

Total BEV Channels: {BEV_CHANNELS_LIDAR + BEV_CHANNELS_RADAR + BEV_CHANNELS_CAM}
"""
    ax_grid.text(0.05, 0.95, grid_text, transform=ax_grid.transAxes,
                fontfamily='monospace', fontsize=9, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.7))

    # 2.3 Transformation Diagram
    ax_transform = fig.add_subplot(gs[1, 2:])
    ax_transform.axis('off')
    # Shape labels come from the grid constants so the diagram cannot go stale
    # when the channel counts or grid size change.
    lidar_shape = f"({BEV_CHANNELS_LIDAR}, {NX}, {NY})"
    cam_shape = f"({BEV_CHANNELS_CAM}, {NX}, {NY})"
    radar_shape = f"({BEV_CHANNELS_RADAR}, {NX}, {NY})"
    fused_shape = f"({BEV_CHANNELS_LIDAR + BEV_CHANNELS_CAM + BEV_CHANNELS_RADAR}, {NX}, {NY})"
    transform_text = f"""
    3D SENSOR DATA ‚Üí BIRD'S EYE VIEW TRANSFORMATION
    ‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê
    
    STEP 1: Load Sensor Data
    ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê   ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê   ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
    ‚îÇ  LIDAR Points   ‚îÇ   ‚îÇ  Camera Image   ‚îÇ   ‚îÇ  RADAR Returns  ‚îÇ
    ‚îÇ  (x, y, z, i)   ‚îÇ   ‚îÇ  (u, v, RGB)    ‚îÇ   ‚îÇ  (x, y, v, rcs) ‚îÇ
    ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò   ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò   ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò
             ‚îÇ                     ‚îÇ                     ‚îÇ
    
    STEP 2: Transform to Ego Frame
             ‚îÇ                     ‚îÇ                     ‚îÇ
             ‚îÇ        Apply Camera ‚îÇ        Already in   ‚îÇ
             ‚îÇ        Calibration  ‚îÇ        Ego Frame    ‚îÇ
             ‚ñº                     ‚ñº                     ‚ñº
    ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
    ‚îÇ           All Sensors in Ego Vehicle Frame                  ‚îÇ
    ‚îÇ         (Origin at ego center, X forward, Y left)           ‚îÇ
    ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò
                                ‚îÇ
    
    STEP 3: Project to BEV Grid
                                ‚îÇ
                    ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
                    ‚îÇ  Discretize into Grid ‚îÇ
                    ‚îÇ  (NX √ó NY voxels)     ‚îÇ
                    ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¨‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò
                                ‚îÇ
    
    STEP 4: Aggregate Features per Voxel
                                ‚îÇ
            ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îº‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
            ‚ñº                   ‚ñº                   ‚ñº
    ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê   ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê   ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
    ‚îÇ  LIDAR BEV   ‚îÇ   ‚îÇ  CAMERA BEV  ‚îÇ   ‚îÇ  RADAR BEV   ‚îÇ
    ‚îÇ{lidar_shape:^14}‚îÇ   ‚îÇ{cam_shape:^14}‚îÇ   ‚îÇ{radar_shape:^14}‚îÇ
    ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò   ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò   ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò
            ‚îÇ                   ‚îÇ                   ‚îÇ
            ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î¥‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò
                                ‚îÇ
                                ‚ñº
                    ‚îå‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îê
                    ‚îÇ   Fused Multi-Modal   ‚îÇ
                    ‚îÇ    BEV Representation ‚îÇ
                    ‚îÇ{fused_shape:^23}‚îÇ
                    ‚îî‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îò
    """
    ax_transform.text(0.05, 0.95, transform_text, transform=ax_transform.transAxes,
                     fontfamily='monospace', fontsize=8, verticalalignment='top',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # ========== ROW 3: FINAL BEV REPRESENTATIONS ==========

    # 3.1 LIDAR BEV (built from the point cloud already loaded for row 1)
    if has_lidar:
        lidar_bev = build_lidar_bev_from_points(lidar_pts)

        ax_lidar_bev = fig.add_subplot(gs[2, 0])
        im = ax_lidar_bev.imshow(lidar_bev[0], cmap='hot', aspect='auto')
        ax_lidar_bev.set_title(f'LIDAR BEV (Count)\n{(lidar_bev[0] > 0).sum()} voxels', 
                              fontsize=10, fontweight='bold')
        ax_lidar_bev.set_xlabel('Y (grid)')
        ax_lidar_bev.set_ylabel('X (grid)')
        plt.colorbar(im, ax=ax_lidar_bev, fraction=0.046)

    # 3.2 RADAR BEV (all available front radars merged into one cloud)
    if has_radar:
        radar_clouds = []
        for radar_key in radar_keys:
            if radar_key in sensors:
                pts_r = load_radar_points(sd_by_token[sensors[radar_key]])
                if pts_r.shape[0] > 0:
                    radar_clouds.append(pts_r)

        if radar_clouds:
            radar_pts = np.concatenate(radar_clouds, axis=0)
            radar_bev = build_radar_bev_from_points(radar_pts)

            ax_radar_bev = fig.add_subplot(gs[2, 1])
            im = ax_radar_bev.imshow(radar_bev[0], cmap='plasma', aspect='auto')
            ax_radar_bev.set_title(f'RADAR BEV (Count)\n{(radar_bev[0] > 0).sum()} voxels', 
                                  fontsize=10, fontweight='bold')
            ax_radar_bev.set_xlabel('Y (grid)')
            ax_radar_bev.set_ylabel('X (grid)')
            plt.colorbar(im, ax=ax_radar_bev, fraction=0.046)

    # 3.3 Camera BEV (placeholder - shows it's empty)
    cam_bev = build_camera_bev_placeholder()
    ax_cam_bev = fig.add_subplot(gs[2, 2])
    im = ax_cam_bev.imshow(cam_bev[0], cmap='viridis', aspect='auto')
    ax_cam_bev.set_title('CAMERA BEV (Placeholder)\nNot implemented', 
                        fontsize=10, fontweight='bold', color='orange')
    ax_cam_bev.set_xlabel('Y (grid)')
    ax_cam_bev.set_ylabel('X (grid)')
    plt.colorbar(im, ax=ax_cam_bev, fraction=0.046)

    # 3.4 Implementation Status
    ax_status = fig.add_subplot(gs[2, 3])
    ax_status.axis('off')
    status_text = f"""BEV Status
‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ
‚úÖ LIDAR BEV
   Fully implemented
   
‚ö†Ô∏è  RADAR BEV
   Working (sparse)
   
‚ùå CAMERA BEV
   Placeholder only
   
To Implement:
1. Load camera image
2. Get depth estimation
3. Project pixels to 3D
4. Map to BEV grid
5. Aggregate RGB values
"""
    ax_status.text(0.05, 0.95, status_text, transform=ax_status.transAxes,
                  fontfamily='monospace', fontsize=9, verticalalignment='top',
                  bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))

    plt.suptitle(f'BEV Transformation Pipeline Visualization - Sample {sample_idx}', 
                fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    return fig

print("‚úÖ BEV transformation visualization function ready!")
print("Run: visualize_bev_transformation(sample_idx=0) to see the full pipeline")

# Render the pipeline overview for the first training sample and display it.
fig = visualize_bev_transformation(sample_idx=0)
plt.show()

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class UTENet4BranchBEV(nn.Module):
    """
    Four-branch BEV regressor with a learned, uncertainty-weighted ensemble.

    Branches:
    - LIDAR branch
    - CAMERA branch
    - RADAR branch
    - Fused (multi-modal) branch, fed by the three encoder outputs

    Every branch predicts one scalar per sample. The primary output is a
    weighted average of the four predictions, where each branch weight is the
    normalized inverse of exp(log-variance) and the log-variances are learned
    scalar parameters shared by all samples.

    The normalized weights are returned so callers can inspect or re-scale
    them (e.g. per weather condition); forward() itself applies no weather
    modulation.
    """

    def __init__(self, c_lidar=BEV_CHANNELS_LIDAR, c_cam=BEV_CHANNELS_CAM, 
                 c_radar=BEV_CHANNELS_RADAR, d_bev=64, dropout=0.2):
        super().__init__()

        def conv_stage(cin, cout, with_dropout):
            """One 3x3 conv -> BN -> ReLU stage, optionally followed by Dropout2d."""
            stage = [
                nn.Conv2d(cin, cout, 3, padding=1),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True),
            ]
            if with_dropout:
                stage.append(nn.Dropout2d(dropout))
            return stage

        def make_encoder(cin):
            """Two conv stages with dropout: cin channels -> d_bev features."""
            return nn.Sequential(
                *conv_stage(cin, d_bev, True),
                *conv_stage(d_bev, d_bev, True),
            )

        # Per-modality encoders
        self.enc_lidar = make_encoder(c_lidar)
        self.enc_cam = make_encoder(c_cam)
        self.enc_radar = make_encoder(c_radar)

        # Fusion encoder over the channel-concatenated encoder outputs (no dropout)
        self.fuse_bev = nn.Sequential(
            *conv_stage(d_bev * 3, d_bev * 2, False),
            *conv_stage(d_bev * 2, d_bev, False),
        )

        # Global average pooling: (B, C, H, W) -> (B, C, 1, 1)
        self.pool = nn.AdaptiveAvgPool2d(1)

        branches = ('lidar', 'cam', 'radar', 'fused')

        # One linear regression head per branch
        for branch in branches:
            setattr(self, f'head_{branch}', nn.Linear(d_bev, 1))

        # One learned scalar log-variance per branch; exp(-logvar), after
        # normalization, becomes that branch's ensemble weight.
        for branch in branches:
            setattr(self, f'logvar_{branch}', nn.Parameter(torch.zeros(())))

    def forward(self, bev_lidar, bev_cam, bev_radar):
        """
        Args:
            bev_lidar: LIDAR BEV tensor, (B, c_lidar, H, W)
            bev_cam: CAMERA BEV tensor, (B, c_cam, H, W)
            bev_radar: RADAR BEV tensor, (B, c_radar, H, W)

        Returns:
            A flat 9-tuple:
              [0]    ensemble prediction, (B,)
              [1:5]  LIDAR, CAMERA, RADAR and FUSED branch predictions, (B,) each
              [5:9]  normalized LIDAR, CAMERA, RADAR and FUSED weights
                     (scalars, detached from the graph)
        """
        # Encode in LIDAR -> CAMERA -> RADAR order, then fuse along channels
        feat_l = self.enc_lidar(bev_lidar)
        feat_c = self.enc_cam(bev_cam)
        feat_r = self.enc_radar(bev_radar)
        feat_f = self.fuse_bev(torch.cat([feat_l, feat_c, feat_r], dim=1))

        def predict(head, feat):
            """Pool a feature map to (B, d_bev) and regress one value per sample."""
            return head(self.pool(feat).flatten(1)).squeeze(-1)

        dl = predict(self.head_lidar, feat_l)
        dc = predict(self.head_cam, feat_c)
        dr = predict(self.head_radar, feat_r)
        df = predict(self.head_fused, feat_f)

        def inverse_variance(logvar):
            """1 / (exp(logvar) + eps); eps guards against division by zero."""
            return 1.0 / (torch.exp(logvar) + 1e-8)

        wl = inverse_variance(self.logvar_lidar)
        wc = inverse_variance(self.logvar_cam)
        wr = inverse_variance(self.logvar_radar)
        wf = inverse_variance(self.logvar_fused)

        # Normalize so the four weights sum to 1
        w_total = wl + wc + wr + wf
        wl_norm, wc_norm, wr_norm, wf_norm = wl / w_total, wc / w_total, wr / w_total, wf / w_total

        # Weighted average of the four branch predictions
        pred_ensemble = wl_norm * dl + wc_norm * dc + wr_norm * dr + wf_norm * df

        branch_preds = (dl, dc, dr, df)
        branch_weights = tuple(w.detach() for w in (wl_norm, wc_norm, wr_norm, wf_norm))
        return (pred_ensemble,) + branch_preds + branch_weights

# Build the 4-branch network and move it to the selected device
model = UTENet4BranchBEV(
    c_lidar=BEV_CHANNELS_LIDAR,
    c_cam=BEV_CHANNELS_CAM,
    c_radar=BEV_CHANNELS_RADAR,
    d_bev=DBEV,
)
model = model.to(DEVICE)

# Parameter count reported in millions
n_params = sum(p.numel() for p in model.parameters())
param_count = n_params / 1e6
print(f'‚úÖ Model created: {param_count:.2f}M parameters')
print(f'   Device: {DEVICE}')
print('   Primary prediction: Ensemble (weighted average of 4 branches)')
print('   Weights learned via uncertainty, modifiable for weather adaptation')


In [None]:
# ========== HYPERPARAMETERS ==========
# NOTE: LR, EPOCHS, BATCH_SIZE and DBEV repeat the values already assigned in
# the dataset/dataloader cell. BATCH_SIZE was consumed when the DataLoaders
# were built and DBEV when the model was instantiated, so editing them here
# has no effect on those objects unless the earlier cells are re-run.
# WEIGHT_DECAY is the only constant introduced by this cell.
LR = 1e-4
WEIGHT_DECAY = 1e-4
EPOCHS = 10
BATCH_SIZE = 4
DBEV = 64

print(f"Hyperparameters: LR={LR}, EPOCHS={EPOCHS}, BATCH_SIZE={BATCH_SIZE}, DBEV={DBEV}, WEIGHT_DECAY={WEIGHT_DECAY}")


In [None]:
# ========== WEATHER-AWARE WEIGHT MODULATION ==========
"""
Extract weather conditions from scene descriptions and compute per-sample weather modulation factors.
This allows the model to adaptively weight branches (LIDAR, CAMERA, RADAR, FUSED) based on 
inferred driving conditions during training and inference.

Example: In rainy scenes, LIDAR and CAMERA are less reliable, so their weights are reduced and RADAR is boosted.
"""

import re

# Whole-word keyword patterns, one per condition. The alternatives sit inside
# a non-capturing group so that \b applies to every keyword; without the group
# the boundary binds only to the first and last alternative (e.g. "chemist"
# would match mist\b and "puddles" would never match puddle\b).
_WEATHER_PATTERNS = {
    'rain': re.compile(r'\b(?:rain(?:s|y|ing)?|wet|puddles?)\b'),
    'night': re.compile(r'\b(?:night(?:s|time)?|dark(?:ness)?|evening)\b'),
    'fog': re.compile(r'\b(?:fog(?:gy)?|mist(?:y)?)\b'),
    'snow': re.compile(r'\b(?:snow(?:s|y|ing)?)\b'),
}


def extract_weather_from_description(description):
    """
    Parse a scene description for weather / lighting keywords.

    Matching is case-insensitive and on whole words only (plus the common
    inflections listed in ``_WEATHER_PATTERNS``).

    Args:
        description: Free-text scene description. ``None`` or an empty string
            is treated as "no condition mentioned".

    Returns:
        dict with keys: rain, night, fog, snow
        values: binary (0 or 1) indicating presence
    """
    desc_lower = (description or '').lower()

    return {
        condition: int(pattern.search(desc_lower) is not None)
        for condition, pattern in _WEATHER_PATTERNS.items()
    }


def compute_branch_weights_for_weather(weather_dict):
    """
    Compute per-branch weight multipliers from binary weather flags.

    Args:
        weather_dict: dict from extract_weather_from_description. Missing keys
            (including an empty dict or None, e.g. the result of
            ``sample_weather_map.get(token, {})`` for an unknown token) are
            treated as "condition absent".

    Returns:
        dict with keys 'lidar', 'camera', 'radar', 'fused'
        values: float weight multipliers

    Multipliers applied (active conditions compound multiplicatively):
        - Rain:  LIDAR x0.75, CAMERA x0.75, RADAR x1.25
        - Night: CAMERA x0.70, LIDAR x1.10, RADAR x1.10
        - Fog:   LIDAR x0.80, RADAR x1.20
        - Snow:  LIDAR x0.90, CAMERA x0.90, RADAR x0.95
        - Fused branch: constant 1.1, independent of the weather
    """
    flags = weather_dict or {}

    w_lidar = 1.0
    w_camera = 1.0
    w_radar = 1.0
    w_fused = 1.1  # Ensemble safeguard

    if flags.get('rain'):
        w_lidar *= 0.75
        w_camera *= 0.75
        w_radar *= 1.25

    if flags.get('night'):
        w_camera *= 0.70
        w_lidar *= 1.10
        w_radar *= 1.10

    if flags.get('fog'):
        w_lidar *= 0.80
        w_radar *= 1.20

    if flags.get('snow'):
        w_lidar *= 0.90
        w_camera *= 0.90
        w_radar *= 0.95

    return {
        'lidar': w_lidar,
        'camera': w_camera,
        'radar': w_radar,
        'fused': w_fused
    }


# Parse each scene description once: scene_token -> weather_dict
scene_weather_by_token = {
    scene_rec['token']: extract_weather_from_description(scene_rec.get('description', ''))
    for scene_rec in scene
}

# Build weather modulation map: sample_token -> weather_dict
# A single pass over the samples with a dict lookup replaces the previous
# scan of all samples for every scene. Samples whose scene is unknown are
# left out of the map, and samples of one scene share one weather dict.
sample_weather_map = {}
for sample_rec in samples:
    weather = scene_weather_by_token.get(sample_rec.get('scene_token'))
    if weather is not None:
        sample_weather_map[sample_rec['token']] = weather

print(f"‚úÖ Weather map built for {len(sample_weather_map)} samples")

# Show weather distribution (number of samples flagged per condition)
weather_counts = {
    cond: sum(1 for w in sample_weather_map.values() if w[cond])
    for cond in ('rain', 'night', 'fog', 'snow')
}

print("\nüìä Weather Distribution in Dataset:")
for cond, count in sorted(weather_counts.items(), key=lambda x: x[1], reverse=True):
    # Guard against an empty map to avoid dividing by zero
    pct = 100 * count / len(sample_weather_map) if sample_weather_map else 0
    print(f"  {cond:15s}: {count:5d} samples ({pct:5.1f}%)")

print("\nüí° Usage: weather_dict = sample_weather_map.get(sample_token, {})")
print("         weights = compute_branch_weights_for_weather(weather_dict)")


In [None]:
# Loss and optimizer; Adam's weight_decay is applied to every model parameter
mse_loss = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)


def epoch_pass(loader, train=False, epoch_idx=0, phase_name='train'):
    """Run one pass over ``loader`` (training or evaluation) and return metrics.

    Uses the module-level ``model``, ``optimizer`` and ``mse_loss``. When
    ``train`` is False the forward pass runs with autograd disabled, so no
    computation graph is built during validation/testing.

    Args:
        loader: DataLoader yielding ``(bev_lidar, bev_cam, bev_radar, dist)`` batches.
        train: if True, back-propagate the combined loss and step the optimizer.
        epoch_idx: epoch number; only used in the progress-bar label.
        phase_name: label for the progress bar.

    Returns:
        dict with keys ``loss`` (combined loss averaged over the samples that
        were actually processed), ``mse``, ``rmse``, ``mae`` and ``r2`` (all
        computed on the ensemble prediction).
    """
    model.train() if train else model.eval()
    total_loss = 0.0
    # Count processed samples explicitly: with drop_last or a custom sampler the
    # loader may yield fewer samples than len(loader.dataset).
    n_seen = 0
    y_true_all, y_pred_all = [], []

    pbar = tqdm(loader, desc=f'{phase_name} epoch {epoch_idx}', leave=False)
    for batch in pbar:
        bev_l, bev_c, bev_r, dist = batch
        # Replace non-finite inputs with 0 and non-finite targets with 50 (NaN/+inf) or 0 (-inf)
        bev_l = torch.nan_to_num(bev_l.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
        bev_c = torch.nan_to_num(bev_c.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
        bev_r = torch.nan_to_num(bev_r.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
        dist = torch.nan_to_num(dist.to(DEVICE), nan=50.0, posinf=50.0, neginf=0.0)

        if train:
            optimizer.zero_grad()

        # Gradients are tracked only while training; evaluation builds no graph.
        with torch.set_grad_enabled(train):
            # Forward pass: out[0] is the ensemble, out[1:5] the individual branches
            out = model(bev_l, bev_c, bev_r)
            pred_ensemble = torch.nan_to_num(out[0], nan=50.0, posinf=50.0, neginf=0.0)  # primary
            pred_lidar    = torch.nan_to_num(out[1], nan=50.0, posinf=50.0, neginf=0.0)
            pred_cam      = torch.nan_to_num(out[2], nan=50.0, posinf=50.0, neginf=0.0)
            pred_radar    = torch.nan_to_num(out[3], nan=50.0, posinf=50.0, neginf=0.0)
            pred_fused    = torch.nan_to_num(out[4], nan=50.0, posinf=50.0, neginf=0.0)

            # Loss: main ensemble + auxiliary individual branch losses
            loss_ensemble = mse_loss(pred_ensemble, dist)
            loss_l = mse_loss(pred_lidar, dist)
            loss_c = mse_loss(pred_cam, dist)
            loss_r = mse_loss(pred_radar, dist)
            loss_f = mse_loss(pred_fused, dist)

            # Combined loss: emphasize ensemble, regularize with branch losses
            loss = loss_ensemble + 0.15 * (loss_l + loss_c + loss_r + loss_f)

            if train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()

        batch_n = dist.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n
        # Flatten before collecting so scalar (0-d) or (B, 1) outputs are handled
        # the same way as (B,) outputs.
        y_true_all.extend(dist.detach().reshape(-1).cpu().numpy())
        y_pred_all.extend(pred_ensemble.detach().reshape(-1).cpu().numpy())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # Compute metrics
    avg_loss = total_loss / max(n_seen, 1)
    y_true_all = np.nan_to_num(np.array(y_true_all), nan=50.0, posinf=50.0, neginf=0.0)
    y_pred_all = np.nan_to_num(np.array(y_pred_all), nan=50.0, posinf=50.0, neginf=0.0)
    mse = mean_squared_error(y_true_all, y_pred_all)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true_all, y_pred_all)
    r2 = r2_score(y_true_all, y_pred_all)

    return {'loss': avg_loss, 'mse': mse, 'rmse': rmse, 'mae': mae, 'r2': r2}


# Training loop
best_val_mse = float('inf')
STATE_PATH = '/kaggle/working/utenet4bev_state.pth'
FULL_PATH = '/kaggle/working/utenet4bev_full.pth'

print('\n' + '='*60)
print('üöÄ STARTING TRAINING (Ensemble BEV - Weather Adaptive)')
print('='*60 + '\n')

for epoch in range(1, EPOCHS + 1):
    # Train pass
    train_metrics = epoch_pass(train_loader, train=True, epoch_idx=epoch, phase_name='TRAIN')

    # Val pass
    val_metrics = epoch_pass(val_loader, train=False, epoch_idx=epoch, phase_name='VAL')

    # Log summary to console
    print(
        f"Epoch {epoch:2d}/{EPOCHS} | Train RMSE: {train_metrics['rmse']:.4f} | "
        f"Val RMSE: {val_metrics['rmse']:.4f} | R¬≤: {val_metrics['r2']:.4f} | "
        f"Train MAE: {train_metrics['mae']:.4f}"
    )

    # Track learned uncertainties / weights for quick inspection.
    # exp(logvar) is the learned per-branch variance.
    sig_l = torch.exp(model.logvar_lidar).item()
    sig_c = torch.exp(model.logvar_cam).item()
    sig_r = torch.exp(model.logvar_radar).item()
    sig_f = torch.exp(model.logvar_fused).item()

    # Inverse-variance weights, normalized to sum to 1 (epsilon avoids division by zero)
    wl = 1.0 / (sig_l + 1e-8)
    wc = 1.0 / (sig_c + 1e-8)
    wr = 1.0 / (sig_r + 1e-8)
    wf = 1.0 / (sig_f + 1e-8)
    w_sum = wl + wc + wr + wf
    wl_norm = wl / w_sum
    wc_norm = wc / w_sum
    wr_norm = wr / w_sum
    wf_norm = wf / w_sum

    print(
        f"    Weights -> lidar: {wl_norm:.2f}, cam: {wc_norm:.2f}, radar: {wr_norm:.2f}, fused: {wf_norm:.2f}"
    )

    # Save best model (state dict for reloading, full pickle for convenience)
    if val_metrics['mse'] < best_val_mse:
        best_val_mse = val_metrics['mse']
        torch.save(model.state_dict(), STATE_PATH)
        torch.save(model, FULL_PATH)
        print(f'  ‚úÖ Saved best model (Val MSE: {best_val_mse:.4f})')

print('\n' + '='*60)
print('‚úÖ TRAINING COMPLETE')
print('='*60)

# Restore the checkpoint with the best validation MSE before testing, so the
# reported test metrics belong to the same weights as 'Best Val MSE' instead of
# whatever the final epoch produced. Skipped when no checkpoint was written.
if best_val_mse < float('inf'):
    model.load_state_dict(torch.load(STATE_PATH, map_location=DEVICE))
    print(f'Loaded best checkpoint from {STATE_PATH}')

# Test pass
print('\nüß™ Testing on held-out set...')
test_metrics = epoch_pass(test_loader, train=False, epoch_idx=0, phase_name='TEST')
print(f"Test RMSE: {test_metrics['rmse']:.4f} | Test MAE: {test_metrics['mae']:.4f} | Test R¬≤: {test_metrics['r2']:.4f}")

print('\n' + '='*60)
print('üìä FINAL RESULTS')
print('='*60)
print(f'Best Val MSE: {best_val_mse:.4f}')
print(f"Test RMSE:    {test_metrics['rmse']:.4f}m")
print(f"Test MAE:     {test_metrics['mae']:.4f}m")
print(f"Test R¬≤:      {test_metrics['r2']:.4f}")
print('\nüí° Branch weights can be modulated at inference for weather-adaptive predictions')
print('   Example: In rain, increase RADAR weight, decrease CAMERA weight')


In [None]:
# ========== WEATHER-AWARE TRAINING (OPTIONAL) ==========
"""
Example: Use weather modulation to adaptively weight branch losses during training.

To enable this, modify epoch_pass() to:
1. Look up weather condition for each sample
2. Compute branch weight multipliers
3. Scale individual branch losses by these weights before summing

This encourages the model to learn to trust LIDAR less in rain and CAMERA less at night, 
and to rely more on robust modalities per condition.
"""

def epoch_pass_weather_aware(loader, train=False, epoch_idx=0, phase_name='train', 
                              sample_weather_map=None, use_weather=False):
    """
    Single epoch pass with optional weather-aware loss weighting.

    Uses the module-level ``model``, ``optimizer`` and ``mse_loss``. When
    ``train`` is False the forward pass runs with autograd disabled.

    Args:
        loader: DataLoader yielding ``(bev_lidar, bev_cam, bev_radar, dist)`` batches;
            ``loader.dataset.sample_tokens`` is read when weather weighting is active
        train: bool, back-propagate and step the optimizer when True
        epoch_idx: epoch number (progress-bar label only)
        phase_name: name for progress bar
        sample_weather_map: dict mapping sample_token -> weather_dict
        use_weather: if True, modulate the auxiliary branch losses by the
            batch-averaged weather multipliers

    Returns:
        dict with keys 'loss' (combined loss averaged over processed samples),
        'mse', 'rmse', 'mae', 'r2' (computed on the ensemble prediction)
    """
    model.train() if train else model.eval()
    total_loss = 0.0
    # Count processed samples explicitly: with drop_last or a custom sampler the
    # loader may yield fewer samples than len(loader.dataset).
    n_seen = 0
    y_true_all, y_pred_all = [], []
    branch_names = ('lidar', 'camera', 'radar', 'fused')
    
    pbar = tqdm(loader, desc=f'{phase_name} epoch {epoch_idx}', leave=False)
    for step, batch in enumerate(pbar):
        bev_l, bev_c, bev_r, dist = batch
        # Replace non-finite inputs with 0 and non-finite targets with 50 (NaN/+inf) or 0 (-inf)
        bev_l = torch.nan_to_num(bev_l.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
        bev_c = torch.nan_to_num(bev_c.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
        bev_r = torch.nan_to_num(bev_r.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
        dist = torch.nan_to_num(dist.to(DEVICE), nan=50.0, posinf=50.0, neginf=0.0)
        
        if train:
            optimizer.zero_grad()
        
        # Gradients are tracked only while training; evaluation builds no graph.
        with torch.set_grad_enabled(train):
            # Forward pass: out[0] is the ensemble, out[1:5] the individual branches
            out = model(bev_l, bev_c, bev_r)
            pred_ensemble = torch.nan_to_num(out[0], nan=50.0, posinf=50.0, neginf=0.0)
            pred_lidar    = torch.nan_to_num(out[1], nan=50.0, posinf=50.0, neginf=0.0)
            pred_cam      = torch.nan_to_num(out[2], nan=50.0, posinf=50.0, neginf=0.0)
            pred_radar    = torch.nan_to_num(out[3], nan=50.0, posinf=50.0, neginf=0.0)
            pred_fused    = torch.nan_to_num(out[4], nan=50.0, posinf=50.0, neginf=0.0)
            
            # Compute branch losses
            loss_ensemble = mse_loss(pred_ensemble, dist)
            loss_l = mse_loss(pred_lidar, dist)
            loss_c = mse_loss(pred_cam, dist)
            loss_r = mse_loss(pred_radar, dist)
            loss_f = mse_loss(pred_fused, dist)
            
            # (Optional) Apply weather modulation
            if use_weather and sample_weather_map:
                # NOTE(review): tokens are recovered by batch position, which only
                # matches the batch content when the loader walks the dataset in
                # order (no shuffling / custom sampler) - confirm before enabling
                # this on a shuffled training loader.
                # Slice with the loader's own batch size; fall back to the global
                # BATCH_SIZE only when the loader does not expose one.
                batch_size = loader.batch_size or BATCH_SIZE
                start = step * batch_size
                batch_tokens = loader.dataset.sample_tokens[start:start + batch_size]
                
                # One multiplier lookup per sample, reused for all four branches
                per_sample_weights = [
                    compute_branch_weights_for_weather(sample_weather_map.get(tok, {}))
                    for tok in batch_tokens
                ]
                
                if per_sample_weights:
                    # Average the multipliers across the batch (one scalar per branch)
                    avg_weights = {
                        name: float(np.mean([w[name] for w in per_sample_weights]))
                        for name in branch_names
                    }
                    
                    loss_l = loss_l * avg_weights['lidar']
                    loss_c = loss_c * avg_weights['camera']
                    loss_r = loss_r * avg_weights['radar']
                    loss_f = loss_f * avg_weights['fused']
            
            # Combined loss
            loss = loss_ensemble + 0.15 * (loss_l + loss_c + loss_r + loss_f)
            
            if train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()
        
        batch_n = dist.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n
        # Flatten before collecting so scalar (0-d) or (B, 1) outputs are handled
        # the same way as (B,) outputs.
        y_true_all.extend(dist.detach().reshape(-1).cpu().numpy())
        y_pred_all.extend(pred_ensemble.detach().reshape(-1).cpu().numpy())
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    # Metrics
    avg_loss = total_loss / max(n_seen, 1)
    y_true_all = np.nan_to_num(np.array(y_true_all), nan=50.0, posinf=50.0, neginf=0.0)
    y_pred_all = np.nan_to_num(np.array(y_pred_all), nan=50.0, posinf=50.0, neginf=0.0)
    mse = mean_squared_error(y_true_all, y_pred_all)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true_all, y_pred_all)
    r2 = r2_score(y_true_all, y_pred_all)
    
    return {'loss': avg_loss, 'mse': mse, 'rmse': rmse, 'mae': mae, 'r2': r2}


# Cell banner: confirms the function above is defined and shows how to call it.
print(
    "‚úÖ WEATHER-AWARE TRAINING FUNCTION READY",
    "\nüí° To use: Call epoch_pass_weather_aware(..., use_weather=True, sample_weather_map=sample_weather_map)",
    sep="\n",
)


In [None]:
# # ========== OPTIONAL: WEATHER-ADAPTIVE INFERENCE ==========
# # Demonstrates how to use weather conditions to modulate branch weights

# def predict_with_weather_adaptation(model, bev_lidar, bev_cam, bev_radar, weather_condition='clear'):
#     """
#     Predict distance with weather-adaptive ensemble weights.
    
#     Args:
#         model: trained UTENet4BranchBEV
#         bev_lidar, bev_cam, bev_radar: BEV tensors (B, C, H, W)
#         weather_condition: 'clear', 'rain', 'fog', 'night', 'snow'
    
#     Returns:
#         pred: (B,) ensemble predictions
#         weights_adapted: (B, 4) adapted weights for each branch
#     """
#     model.eval()
#     with torch.no_grad():
#         # Forward pass to get raw predictions and learned weights
#         out = model(bev_lidar.to(DEVICE), bev_cam.to(DEVICE), bev_radar.to(DEVICE))
#         pred_ensemble, pred_l, pred_c, pred_r, pred_f = out[:5]
#         w_l, w_c, w_r, w_f = out[5:]  # learned weights from uncertainty
        
#         # Weather-adaptive weight modulation
#         if weather_condition == 'clear':
#             # Clear weather: all sensors reliable
#             weather_weights = torch.tensor([0.25, 0.25, 0.25, 0.25], device=DEVICE)
            
#         elif weather_condition == 'rain':
#             # Rain: LIDAR attenuated, RADAR peaks through
#             weather_weights = torch.tensor([0.15, 0.10, 0.45, 0.30], device=DEVICE)
            
#         elif weather_condition == 'fog':
#             # Fog: LIDAR poor, RADAR moderate, Fused best
#             weather_weights = torch.tensor([0.10, 0.20, 0.30, 0.40], device=DEVICE)
            
#         elif weather_condition == 'night':
#             # Night: CAMERA poor, LIDAR + RADAR good
#             weather_weights = torch.tensor([0.35, 0.10, 0.35, 0.20], device=DEVICE)
            
#         elif weather_condition == 'snow':
#             # Snow: similar to rain but RADAR slightly better
#             weather_weights = torch.tensor([0.10, 0.05, 0.50, 0.35], device=DEVICE)
        
#         else:
#             weather_weights = torch.tensor([0.25, 0.25, 0.25, 0.25], device=DEVICE)
        
#         # Blend learned weights with weather-adaptive weights
#         # (can adjust blend ratio based on confidence)
#         alpha = 0.7  # Weight for learned uncertainty vs weather
#         combined_weights = alpha * torch.stack([w_l, w_c, w_r, w_f], dim=1) + \
#                           (1 - alpha) * weather_weights.unsqueeze(0)
#         combined_weights = combined_weights / combined_weights.sum(dim=1, keepdim=True)
        
#         # Compute ensemble with weather-adaptive weights
#         pred_adapted = (combined_weights[:, 0] * pred_l +
#                        combined_weights[:, 1] * pred_c +
#                        combined_weights[:, 2] * pred_r +
#                        combined_weights[:, 3] * pred_f)
        
#         return pred_adapted.cpu(), combined_weights.detach().cpu()

# # Example usage (uncomment to test)
# # print("\nüåßÔ∏è  WEATHER-ADAPTIVE INFERENCE EXAMPLE")
# # print("="*60)
# # 
# # # Get a batch from test set
# # test_batch = next(iter(test_loader))
# # bev_l_test, bev_c_test, bev_r_test, dist_test = test_batch
# # 
# # for weather in ['clear', 'rain', 'fog', 'night', 'snow']:
# #     pred, weights = predict_with_weather_adaptation(
# #         model, bev_l_test, bev_c_test, bev_r_test, weather
# #     )
# #     avg_weights = weights.mean(dim=0)
# #     print(f"\n{weather.upper():8} | L:{avg_weights[0]:.2f} C:{avg_weights[1]:.2f} " +
# #           f"R:{avg_weights[2]:.2f} F:{avg_weights[3]:.2f} | Pred: {pred[0]:.2f}m")


In [None]:
# ========== WEATHER-ADAPTIVE INFERENCE ==========
"""
At inference time, you can dynamically adjust branch weights based on detected weather,
overriding the learned logvar weights to favor more robust sensors in bad conditions.

Example: In rain, manually boost RADAR weight and reduce LIDAR/CAMERA to improve robustness.
"""

def predict_with_weather_adaptation(model, bev_l, bev_c, bev_r, weather_dict=None, device=DEVICE):
    """
    Run single-sample inference and optionally apply weather-adaptive branch weighting.
    
    The model is switched to eval mode and called without gradients. Branch
    weights are the inverse learned variances ``1 / exp(logvar)``, normalized to
    sum to 1; with a weather dict they are first scaled by the multipliers from
    ``compute_branch_weights_for_weather``.
    
    Args:
        model: UTENet4BranchBEV
        bev_l, bev_c, bev_r: BEV tensors for ONE sample (a batch axis is added
            here with ``unsqueeze(0)``)
        weather_dict: dict from extract_weather_from_description (or None / empty
            for learned weights only)
        device: torch device
    
    Returns:
        dict. Both modes contain:
            'individual': {'lidar', 'camera', 'radar', 'fused'} branch predictions (floats)
            'weather': the weather_dict that was applied, or None
        With a weather_dict:
            'prediction_learned': ensemble output of the model
            'prediction_adaptive': ensemble recomputed with weather-scaled weights
            'weights_learned': normalized learned weights per branch
            'weights_adaptive': normalized weather-scaled weights per branch
        Without a weather_dict:
            'prediction': ensemble output of the model
            'weights': normalized learned weights per branch
    """
    model.eval()
    with torch.no_grad():
        # Standard forward pass: out[0] is the ensemble, out[1:5] the branches
        out = model(bev_l.unsqueeze(0).to(device), 
                   bev_c.unsqueeze(0).to(device), 
                   bev_r.unsqueeze(0).to(device))
        
        pred_ensemble_learned = out[0].squeeze().item()
        individual = {
            'lidar': out[1].squeeze().item(),
            'camera': out[2].squeeze().item(),
            'radar': out[3].squeeze().item(),
            'fused': out[4].squeeze().item(),
        }
        
        # Inverse-variance weights from the learned log-variances
        # (epsilon avoids division by zero)
        inv_var = {
            'lidar': 1.0 / (torch.exp(model.logvar_lidar).item() + 1e-8),
            'camera': 1.0 / (torch.exp(model.logvar_cam).item() + 1e-8),
            'radar': 1.0 / (torch.exp(model.logvar_radar).item() + 1e-8),
            'fused': 1.0 / (torch.exp(model.logvar_fused).item() + 1e-8),
        }
    
    # Learned weights are reported normalized in both modes so they are directly
    # comparable with the (normalized) adaptive weights.
    inv_var_total = sum(inv_var.values())
    weights_learned = {name: w / inv_var_total for name, w in inv_var.items()}
    
    if not weather_dict:
        # No weather: use learned weights only
        return {
            'prediction': pred_ensemble_learned,
            'individual': individual,
            'weights': weights_learned,
            'weather': None
        }
    
    # Scale each inverse-variance weight by its weather multiplier, then renormalize
    weather_mults = compute_branch_weights_for_weather(weather_dict)
    scaled = {name: inv_var[name] * weather_mults[name] for name in inv_var}
    scaled_total = sum(scaled.values())
    weights_adaptive = {name: w / scaled_total for name, w in scaled.items()}
    
    # Recompute the ensemble as the weighted sum of the branch predictions
    pred_ensemble_adaptive = sum(weights_adaptive[name] * individual[name] for name in inv_var)
    
    return {
        'prediction_learned': pred_ensemble_learned,
        'prediction_adaptive': pred_ensemble_adaptive,
        'individual': individual,
        'weights_learned': weights_learned,
        'weights_adaptive': weights_adaptive,
        'weather': weather_dict
    }


# Cell banner: confirms the inference helper is defined and prints a usage snippet.
usage_lines = (
    "‚úÖ WEATHER-ADAPTIVE INFERENCE READY",
    "\nüí° Usage:",
    "   result = predict_with_weather_adaptation(model, bev_l, bev_c, bev_r, weather_dict)",
    "   print(f'Learned prediction: {result[\"prediction_learned\"]:.2f}m')",
    "   print(f'Adaptive prediction: {result[\"prediction_adaptive\"]:.2f}m')",
    "   print(f'Weights (adaptive): {result[\"weights_adaptive\"]}')",
)
print(*usage_lines, sep="\n")


In [None]:
# ========== WEATHER-AWARE vs BASELINE COMPARISON ==========
"""
Compare model performance with and without weather-adaptive weighting.
This evaluates whether dynamically adjusting branch weights based on weather improves predictions.
"""

# Section banner for the comparison cell
print(
    "\n" + "="*70,
    "üå¶Ô∏è  WEATHER-AWARE EVALUATION COMPARISON",
    "="*70,
    sep="\n",
)

# Restore the best checkpoint saved during training and switch to inference mode
model.load_state_dict(
    torch.load(STATE_PATH, map_location=DEVICE)
)
model.eval()

# Evaluation function for both modes
def evaluate_with_weather_modes(loader, dataset_name="test"):
    """
    Evaluate model in two modes:
    1. Baseline: Use learned weights only (original ensemble)
    2. Weather-Adaptive: Modulate weights by detected weather conditions
    
    Uses the module-level ``model``, ``sample_weather_map`` and
    ``compute_branch_weights_for_weather``. Inputs, targets and predictions are
    sanitized with the same ``nan_to_num`` fills as ``epoch_pass`` so the metrics
    are comparable and non-finite values cannot break the metric functions.
    
    NOTE(review): sample tokens are recovered by batch position
    (``batch_idx * loader.batch_size``), which assumes the loader walks
    ``loader.dataset`` in order (no shuffling) - confirm for the loader passed in.
    
    Args:
        loader: DataLoader yielding ``(bev_lidar, bev_cam, bev_radar, dist)``
            batches; ``loader.dataset.sample_tokens`` must be available
        dataset_name: label for the progress bar
    
    Returns:
        dict with results for both modes:
            'baseline' / 'adaptive': {'rmse', 'mae', 'r2'}
            'weather_breakdown': per-condition {'count', 'baseline_rmse',
                'adaptive_rmse', 'baseline_mae', 'adaptive_mae'} for conditions
                that occur at least once
            'predictions': {'baseline', 'adaptive', 'true'} as numpy arrays
    """
    
    # Storage for predictions
    baseline_preds = []
    adaptive_preds = []
    true_labels = []
    weather_conditions_list = []
    
    def _to_flat_numpy(tensor):
        """Sanitize a prediction/target tensor and return it as a 1-D numpy array."""
        clean = torch.nan_to_num(tensor, nan=50.0, posinf=50.0, neginf=0.0)
        return clean.reshape(-1).cpu().numpy()
    
    model.eval()
    with torch.no_grad():
        # Learned inverse-variance weights are fixed model parameters, so they
        # are computed once instead of once per batch.
        inv_l = 1.0 / (torch.exp(model.logvar_lidar).item() + 1e-8)
        inv_c = 1.0 / (torch.exp(model.logvar_cam).item() + 1e-8)
        inv_r = 1.0 / (torch.exp(model.logvar_radar).item() + 1e-8)
        inv_f = 1.0 / (torch.exp(model.logvar_fused).item() + 1e-8)
        
        for batch_idx, batch in enumerate(tqdm(loader, desc=f"Evaluating {dataset_name}")):
            bev_l, bev_c, bev_r, dist = batch
            bev_l = torch.nan_to_num(bev_l.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            bev_c = torch.nan_to_num(bev_c.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            bev_r = torch.nan_to_num(bev_r.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            
            # Get sample tokens for this batch
            start_idx = batch_idx * loader.batch_size
            end_idx = min(start_idx + loader.batch_size, len(loader.dataset))
            batch_tokens = loader.dataset.sample_tokens[start_idx:end_idx]
            
            # Forward pass: out[0] is the ensemble with learned weights,
            # out[1:5] are the individual branch predictions.
            # Flattening keeps scalar (0-d) and (B, 1) outputs indexable per sample.
            out = model(bev_l, bev_c, bev_r)
            pred_baseline = _to_flat_numpy(out[0])
            pred_lidar = _to_flat_numpy(out[1])
            pred_cam = _to_flat_numpy(out[2])
            pred_radar = _to_flat_numpy(out[3])
            pred_fused = _to_flat_numpy(out[4])
            
            # Compute weather-adaptive predictions for each sample
            pred_adaptive_batch = []
            for i, token in enumerate(batch_tokens):
                weather = sample_weather_map.get(token, {})
                weather_conditions_list.append(weather)
                
                if not weather:
                    # No weather info for this sample: fall back to the baseline
                    pred_adaptive_batch.append(pred_baseline[i])
                    continue
                
                # Get weather modulation factors
                weather_mults = compute_branch_weights_for_weather(weather)
                
                # Modulate inverse-variance weights by weather
                wl = inv_l * weather_mults['lidar']
                wc = inv_c * weather_mults['camera']
                wr = inv_r * weather_mults['radar']
                wf = inv_f * weather_mults['fused']
                
                # Normalize
                w_sum = wl + wc + wr + wf
                wl_norm = wl / w_sum
                wc_norm = wc / w_sum
                wr_norm = wr / w_sum
                wf_norm = wf / w_sum
                
                # Recompute ensemble with weather-modulated weights
                pred_adaptive = (wl_norm * pred_lidar[i] + 
                               wc_norm * pred_cam[i] + 
                               wr_norm * pred_radar[i] + 
                               wf_norm * pred_fused[i])
                pred_adaptive_batch.append(pred_adaptive)
            
            baseline_preds.extend(pred_baseline)
            adaptive_preds.extend(pred_adaptive_batch)
            true_labels.extend(_to_flat_numpy(dist))
    
    # Convert to arrays
    baseline_preds = np.array(baseline_preds)
    adaptive_preds = np.array(adaptive_preds)
    true_labels = np.array(true_labels)
    
    # Compute metrics for both modes
    baseline_rmse = np.sqrt(mean_squared_error(true_labels, baseline_preds))
    baseline_mae = mean_absolute_error(true_labels, baseline_preds)
    baseline_r2 = r2_score(true_labels, baseline_preds)
    
    adaptive_rmse = np.sqrt(mean_squared_error(true_labels, adaptive_preds))
    adaptive_mae = mean_absolute_error(true_labels, adaptive_preds)
    adaptive_r2 = r2_score(true_labels, adaptive_preds)
    
    # Breakdown by weather condition
    weather_breakdown = {}
    for cond in ['rain', 'night', 'fog', 'snow']:
        # Find samples with this condition (truthy flag, same test as the
        # weather-distribution summary)
        indices = [i for i, w in enumerate(weather_conditions_list) if w.get(cond)]
        
        if len(indices) > 0:
            cond_true = true_labels[indices]
            cond_baseline = baseline_preds[indices]
            cond_adaptive = adaptive_preds[indices]
            
            weather_breakdown[cond] = {
                'count': len(indices),
                'baseline_rmse': np.sqrt(mean_squared_error(cond_true, cond_baseline)),
                'adaptive_rmse': np.sqrt(mean_squared_error(cond_true, cond_adaptive)),
                'baseline_mae': mean_absolute_error(cond_true, cond_baseline),
                'adaptive_mae': mean_absolute_error(cond_true, cond_adaptive),
            }
    
    return {
        'baseline': {'rmse': baseline_rmse, 'mae': baseline_mae, 'r2': baseline_r2},
        'adaptive': {'rmse': adaptive_rmse, 'mae': adaptive_mae, 'r2': adaptive_r2},
        'weather_breakdown': weather_breakdown,
        'predictions': {
            'baseline': baseline_preds,
            'adaptive': adaptive_preds,
            'true': true_labels
        }
    }


# Run comparison on test set (baseline ensemble vs weather-adaptive re-weighting)
results = evaluate_with_weather_modes(test_loader, "test")

# Display results: one table row per mode
print("\n" + "="*70)
print("üìä OVERALL RESULTS (Test Set)")
print("="*70)
print(f"\n{'Mode':<20} {'RMSE (m)':<15} {'MAE (m)':<15} {'R¬≤':<10}")
print("-"*70)
print(f"{'Baseline':<20} {results['baseline']['rmse']:<15.4f} {results['baseline']['mae']:<15.4f} {results['baseline']['r2']:<10.4f}")
print(f"{'Weather-Adaptive':<20} {results['adaptive']['rmse']:<15.4f} {results['adaptive']['mae']:<15.4f} {results['adaptive']['r2']:<10.4f}")

# Compute improvement as the relative error reduction in percent
# (positive means the weather-adaptive mode has the lower error)
rmse_improvement = ((results['baseline']['rmse'] - results['adaptive']['rmse']) / results['baseline']['rmse']) * 100
mae_improvement = ((results['baseline']['mae'] - results['adaptive']['mae']) / results['baseline']['mae']) * 100

print(f"\n{'Improvement':<20} {rmse_improvement:>14.2f}% {mae_improvement:>14.2f}%")

# Weather-specific breakdown (only printed when at least one condition occurred)
if results['weather_breakdown']:
    print("\n" + "="*70)
    print("üå¶Ô∏è  WEATHER-SPECIFIC PERFORMANCE")
    print("="*70)
    
    # Most frequent condition first
    for cond, metrics in sorted(results['weather_breakdown'].items(), key=lambda x: x[1]['count'], reverse=True):
        if metrics['count'] > 0:
            # Deltas are baseline minus adaptive: positive means adaptive is better
            rmse_delta = metrics['baseline_rmse'] - metrics['adaptive_rmse']
            mae_delta = metrics['baseline_mae'] - metrics['adaptive_mae']
            
            print(f"\n{cond.upper()} ({metrics['count']} samples):")
            print(f"  Baseline:  RMSE={metrics['baseline_rmse']:.4f}m  MAE={metrics['baseline_mae']:.4f}m")
            print(f"  Adaptive:  RMSE={metrics['adaptive_rmse']:.4f}m  MAE={metrics['adaptive_mae']:.4f}m")
            print(f"  Delta:     RMSE={rmse_delta:+.4f}m  MAE={mae_delta:+.4f}m  {'‚úÖ Better' if rmse_delta > 0 else '‚ùå Worse'}")

print("\n" + "="*70)
print("‚úÖ COMPARISON COMPLETE")
print("="*70)


In [None]:
# ========== BRANCH PREDICTION VISUALIZATION ==========
"""
Visualize predictions from all 4 branches alongside ground truth.
Shows how each branch (LIDAR, CAMERA, RADAR, FUSED) individually predicts distance.
"""

def visualize_branch_predictions(model, loader, num_samples=6, figsize=(16, 12)):
    """
    Create overlaid visualization of predictions from all 4 branches.
    
    One bar chart is drawn per sample (two charts per row) showing the LIDAR,
    CAMERA, RADAR, FUSED and ENSEMBLE predictions against the ground truth.
    
    Args:
        model: trained UTENet4BranchBEV
        loader: DataLoader (train, val, or test)
        num_samples: number of samples to visualize
        figsize: matplotlib figure size
    
    Returns:
        (fig, predictions_data): the matplotlib figure and the list of
        per-sample dicts that were collected. Whole batches are collected, so the
        list may hold more than ``num_samples`` entries; only the first
        ``num_samples`` are plotted.
    """
    model.eval()
    
    predictions_data = []
    with torch.no_grad():
        for batch in loader:
            if len(predictions_data) >= num_samples:
                break
            
            bev_l, bev_c, bev_r, dist = batch
            bev_l = torch.nan_to_num(bev_l.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            bev_c = torch.nan_to_num(bev_c.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            bev_r = torch.nan_to_num(bev_r.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            dist = torch.nan_to_num(dist.to(DEVICE), nan=50.0, posinf=50.0, neginf=0.0)
            
            # Forward pass: out[0] ensemble, out[1:5] branch predictions
            out = model(bev_l, bev_c, bev_r)
            pred_ensemble = np.atleast_1d(out[0].cpu().numpy()).flatten()
            pred_l = np.atleast_1d(out[1].cpu().numpy()).flatten()
            pred_c = np.atleast_1d(out[2].cpu().numpy()).flatten()
            pred_r = np.atleast_1d(out[3].cpu().numpy()).flatten()
            pred_f = np.atleast_1d(out[4].cpu().numpy()).flatten()
            dist_gt = np.atleast_1d(dist.cpu().numpy()).flatten()
            
            # Get weights (out[5:9]) - flatten them (handles both scalar and batch cases)
            w_l = np.atleast_1d(out[5].cpu().numpy()).flatten()
            w_c = np.atleast_1d(out[6].cpu().numpy()).flatten()
            w_r = np.atleast_1d(out[7].cpu().numpy()).flatten()
            w_f = np.atleast_1d(out[8].cpu().numpy()).flatten()
            
            # Store for each sample in batch; a weight array shorter than the
            # batch is treated as shared and its first entry is reused
            for j in range(len(dist_gt)):
                predictions_data.append({
                    'gt': dist_gt[j],
                    'ensemble': pred_ensemble[j],
                    'lidar': pred_l[j],
                    'camera': pred_c[j],
                    'radar': pred_r[j],
                    'fused': pred_f[j],
                    'weight_l': w_l[j] if j < len(w_l) else w_l[0],
                    'weight_c': w_c[j] if j < len(w_c) else w_c[0],
                    'weight_r': w_r[j] if j < len(w_r) else w_r[0],
                    'weight_f': w_f[j] if j < len(w_f) else w_f[0],
                })
    
    # Create figure with subplots (one per sample, two per row).
    # At least one row is requested so num_samples=0 still yields a valid figure;
    # squeeze=False always returns a 2-D axes array, so flatten() is safe.
    nrows = max((num_samples + 1) // 2, 1)
    fig, axes = plt.subplots(nrows, 2, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    
    # Plot each sample
    shown = predictions_data[:max(num_samples, 0)]
    for idx, data in enumerate(shown):
        ax = axes[idx]
        
        # Predictions (the ensemble bar is labelled with weight 1.0)
        branches = ['LIDAR', 'CAMERA', 'RADAR', 'FUSED', 'ENSEMBLE']
        preds = [data['lidar'], data['camera'], data['radar'], data['fused'], data['ensemble']]
        weights = [data['weight_l'], data['weight_c'], data['weight_r'], data['weight_f'], 1.0]
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#2ECC71']
        
        # Create bar chart
        x_pos = np.arange(len(branches))
        bars = ax.bar(x_pos, preds, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
        
        # Add ground truth horizontal line
        ax.axhline(y=data['gt'], color='red', linestyle='--', linewidth=3, label=f"Ground Truth: {data['gt']:.2f}m")
        
        # Add value labels on bars
        for bar, pred, weight in zip(bars, preds, weights):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                   f"{pred:.2f}m\n(w:{weight:.2f})", 
                   ha='center', va='bottom', fontsize=9, fontweight='bold')
        
        # Styling
        ax.set_ylabel('Distance (m)', fontsize=11, fontweight='bold')
        ax.set_title(f'Sample {idx+1} - Branch Predictions vs Ground Truth', 
                    fontsize=12, fontweight='bold', pad=10)
        ax.set_xticks(x_pos)
        ax.set_xticklabels(branches, fontsize=10)
        ax.legend(fontsize=10, loc='upper right')
        ax.grid(axis='y', alpha=0.3, linestyle=':')
        ax.set_ylim(0, max(max(preds), data['gt']) * 1.15)
    
    # Hide unused subplots: every axis after the last one actually drawn.
    # (Counting from the number of plotted samples, not collected samples,
    # so the spare axis of an odd num_samples is hidden as well.)
    for unused_ax in axes[len(shown):]:
        unused_ax.axis('off')
    
    plt.tight_layout()
    return fig, predictions_data

# Run visualization on test set (first 6 samples the loader yields)
print("üé® GENERATING BRANCH PREDICTION VISUALIZATIONS")
print("="*70)
fig, pred_data = visualize_branch_predictions(model, test_loader, num_samples=6)
plt.show()

# Text summary for the first three collected samples: each branch's prediction,
# its absolute error against the ground truth, and its ensemble weight
print("\nüìä Sample Statistics from Visualization:")
for i, data in enumerate(pred_data[:3]):
    print(f"\nSample {i+1}:")
    print(f"  Ground Truth:  {data['gt']:.2f}m")
    print(f"  LIDAR pred:    {data['lidar']:.2f}m (error: {abs(data['lidar']-data['gt']):.2f}m, weight: {data['weight_l']:.3f})")
    print(f"  CAMERA pred:   {data['camera']:.2f}m (error: {abs(data['camera']-data['gt']):.2f}m, weight: {data['weight_c']:.3f})")
    print(f"  RADAR pred:    {data['radar']:.2f}m (error: {abs(data['radar']-data['gt']):.2f}m, weight: {data['weight_r']:.3f})")
    print(f"  FUSED pred:    {data['fused']:.2f}m (error: {abs(data['fused']-data['gt']):.2f}m, weight: {data['weight_f']:.3f})")
    print(f"  ENSEMBLE pred: {data['ensemble']:.2f}m (error: {abs(data['ensemble']-data['gt']):.2f}m)")


In [None]:
# ========== BRANCH PERFORMANCE COMPARISON ==========
"""
Comprehensive metrics comparing all 4 branches across entire test set.
Shows which branches are most reliable and their error distributions.
"""

def evaluate_all_branches(model, loader, dataset_name='test'):
    """
    Score every prediction head of the model over a whole DataLoader.

    The model output is indexed positionally: element 0 is the ensemble
    prediction, elements 1-4 are the lidar / camera / radar / fused
    predictions and elements 5-8 are the matching branch weights.

    Args:
        model: trained UTENet4BranchBEV
        loader: DataLoader yielding (bev_lidar, bev_camera, bev_radar, distance)
        dataset_name: label shown in the progress bar

    Returns:
        (results, ground_truth_arr) where ``results`` maps each branch name
        (plus 'ensemble') to a dict with 'rmse', 'mae', 'r2' and
        'predictions'; the four sensor branches additionally carry
        'avg_weight'. ``ground_truth_arr`` is the sanitized target array.
    """
    sensor_branches = ('lidar', 'camera', 'radar', 'fused')

    def as_flat(tensor):
        # Model outputs may be 0-d or column shaped; always hand back a 1-D array.
        return np.atleast_1d(tensor.cpu().numpy()).flatten()

    def clean_bev(tensor):
        # Non-finite BEV cells are zeroed before they reach the network.
        return torch.nan_to_num(tensor.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)

    model.eval()

    branches_preds = {name: [] for name in sensor_branches + ('ensemble',)}
    all_weights = {name: [] for name in sensor_branches}
    ground_truth = []

    with torch.no_grad():
        for bev_l, bev_c, bev_r, dist in tqdm(loader, desc=f'Evaluating {dataset_name}', leave=True):
            bev_l, bev_c, bev_r = clean_bev(bev_l), clean_bev(bev_c), clean_bev(bev_r)
            # Non-finite targets are replaced with 50.0 (negative infinity with 0.0).
            dist = torch.nan_to_num(dist.to(DEVICE), nan=50.0, posinf=50.0, neginf=0.0)

            out = model(bev_l, bev_c, bev_r)
            flat_out = [as_flat(out[pos]) for pos in range(9)]

            for offset, name in enumerate(sensor_branches, start=1):
                branches_preds[name].extend(flat_out[offset])
                all_weights[name].extend(flat_out[offset + 4])
            branches_preds['ensemble'].extend(flat_out[0])
            ground_truth.extend(as_flat(dist))

    # The targets are identical for every branch, so sanitize them once.
    ground_truth_arr = np.nan_to_num(np.array(ground_truth), nan=50.0, posinf=50.0, neginf=0.0)

    results = {}
    for name, collected in branches_preds.items():
        cleaned = np.nan_to_num(np.array(collected), nan=50.0, posinf=50.0, neginf=0.0)
        results[name] = {
            'rmse': np.sqrt(mean_squared_error(ground_truth_arr, cleaned)),
            'mae': mean_absolute_error(ground_truth_arr, cleaned),
            'r2': r2_score(ground_truth_arr, cleaned),
            'predictions': cleaned,
        }

    # Only the four sensor branches have a weight; the ensemble does not.
    for name in sensor_branches:
        results[name]['avg_weight'] = np.mean(all_weights[name])

    return results, ground_truth_arr

# Driver: evaluate every branch on the full test loader, print a metrics table
# and draw a 2x2 comparison figure. `branch_results` and `gt_array` stay
# available as notebook globals after this cell.
print("\n" + "="*70)
print("üìä BRANCH PERFORMANCE COMPARISON (Full Test Set)")
print("="*70)

branch_results, gt_array = evaluate_all_branches(model, test_loader, "test")

# Display comparison table
print(f"\n{'Branch':<15} {'RMSE (m)':<12} {'MAE (m)':<12} {'R¬≤':<10} {'Avg Weight':<12}")
print("-"*70)
for branch in ['lidar', 'camera', 'radar', 'fused', 'ensemble']:
    metrics = branch_results[branch]
    rmse = metrics['rmse']
    mae = metrics['mae']
    r2 = metrics['r2']
    # The ensemble entry has no 'avg_weight' key, so it is printed as 1.000.
    weight = metrics.get('avg_weight', 1.0)
    
    print(f"{branch.upper():<15} {rmse:<12.4f} {mae:<12.4f} {r2:<10.4f} {weight:<12.3f}")

# Create comparison visualizations (2x2 grid: RMSE, MAE, scatter, error histogram)
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1. RMSE comparison
ax = axes[0, 0]
# Display labels are upper-case; results are looked up with the lower-cased name.
branches = ['LIDAR', 'CAMERA', 'RADAR', 'FUSED', 'ENSEMBLE']
rmses = [branch_results[b.lower()]['rmse'] for b in branches]
colors_bars = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#2ECC71']
bars = ax.bar(branches, rmses, color=colors_bars, edgecolor='black', linewidth=2, alpha=0.8)
ax.set_ylabel('RMSE (m)', fontsize=12, fontweight='bold')
ax.set_title('RMSE by Branch', fontsize=13, fontweight='bold')
ax.grid(axis='y', alpha=0.3)
# Value label just above each bar.
for bar, rmse in zip(bars, rmses):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
           f'{rmse:.3f}', ha='center', va='bottom', fontweight='bold')

# 2. MAE comparison (same layout and colors as the RMSE panel)
ax = axes[0, 1]
maes = [branch_results[b.lower()]['mae'] for b in branches]
bars = ax.bar(branches, maes, color=colors_bars, edgecolor='black', linewidth=2, alpha=0.8)
ax.set_ylabel('MAE (m)', fontsize=12, fontweight='bold')
ax.set_title('MAE by Branch', fontsize=13, fontweight='bold')
ax.grid(axis='y', alpha=0.3)
for bar, mae in zip(bars, maes):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
           f'{mae:.3f}', ha='center', va='bottom', fontweight='bold')

# 3. Prediction scatter (Ensemble vs GT)
ax = axes[1, 0]
ensemble_preds = branch_results['ensemble']['predictions']
ax.scatter(gt_array, ensemble_preds, alpha=0.5, s=30, color='#2ECC71', edgecolor='black', linewidth=0.5)
# Diagonal y = x across the ground-truth range: points on it are exact predictions.
ax.plot([gt_array.min(), gt_array.max()], [gt_array.min(), gt_array.max()], 
        'r--', linewidth=2, label='Perfect Prediction')
ax.set_xlabel('Ground Truth (m)', fontsize=12, fontweight='bold')
ax.set_ylabel('Ensemble Prediction (m)', fontsize=12, fontweight='bold')
ax.set_title('Ensemble: Predictions vs Ground Truth', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(alpha=0.3)

# 4. Error distribution (overlaid absolute-error histograms, one per branch)
ax = axes[1, 1]
for branch, color in zip(['lidar', 'camera', 'radar', 'fused', 'ensemble'], colors_bars):
    preds = branch_results[branch]['predictions']
    errors = np.abs(preds - gt_array)
    ax.hist(errors, bins=30, alpha=0.5, label=branch.upper(), color=color, edgecolor='black')
ax.set_xlabel('Absolute Error (m)', fontsize=12, fontweight='bold')
ax.set_ylabel('Frequency', fontsize=12, fontweight='bold')
ax.set_title('Error Distribution by Branch', fontsize=13, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print("\n" + "="*70)
print("‚úÖ BRANCH COMPARISON COMPLETE")
print("="*70)


In [None]:
# ========== DETAILED BRANCH OVERLAY VISUALIZATION ==========
"""
Side-by-side overlay of all 4 branch predictions with learned weights.
Shows exactly how each branch is contributing to the ensemble decision.
"""

def visualize_branch_overlays(model, loader, num_samples=8, figsize=(18, 10)):
    """
    Draw one bar chart per sample comparing the four branch predictions.

    Each subplot shows the LIDAR / CAMERA / RADAR / FUSED predictions as bars,
    the ground-truth distance as a dashed red line and the ensemble prediction
    as a dotted green line. Every bar is annotated with its prediction, its
    branch weight and its absolute error against the ground truth.

    The model output is indexed positionally: element 0 is the ensemble
    prediction, elements 1-4 the branch predictions and elements 5-8 the
    branch weights.

    Args:
        model: trained UTENet4BranchBEV
        loader: DataLoader yielding (bev_lidar, bev_camera, bev_radar, distance)
        num_samples: number of samples to visualize
        figsize: figure size

    Returns:
        (fig, sample_data): the matplotlib figure and the list of per-sample
        records. Batches are collected whole, so the list may hold more than
        ``num_samples`` records; only the first ``num_samples`` are plotted.
    """
    model.eval()

    def _flat(tensor):
        # Always hand back a 1-D numpy array, whatever shape the head emits.
        return np.atleast_1d(tensor.cpu().numpy()).flatten()

    def _pick(values, pos):
        # Fall back to the first entry when fewer weights than samples came
        # back (presumably one shared weight per batch -- confirm against the model).
        return values[pos] if pos < len(values) else values[0]

    sample_data = []
    with torch.no_grad():
        for batch in loader:
            if len(sample_data) >= num_samples:
                break

            bev_l, bev_c, bev_r, dist = batch
            bev_l = torch.nan_to_num(bev_l.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            bev_c = torch.nan_to_num(bev_c.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            bev_r = torch.nan_to_num(bev_r.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            dist = torch.nan_to_num(dist.to(DEVICE), nan=50.0, posinf=50.0, neginf=0.0)

            # Forward pass
            out = model(bev_l, bev_c, bev_r)
            pred_ensemble = _flat(out[0])
            pred_l, pred_c, pred_r, pred_f = (_flat(out[k]) for k in range(1, 5))
            w_l, w_c, w_r, w_f = (_flat(out[k]) for k in range(5, 9))
            dist_gt = _flat(dist)

            # One record per sample in the batch
            for j in range(len(dist_gt)):
                sample_data.append({
                    'gt': dist_gt[j],
                    'ensemble': pred_ensemble[j],
                    'branches': {
                        'LIDAR': pred_l[j],
                        'CAMERA': pred_c[j],
                        'RADAR': pred_r[j],
                        'FUSED': pred_f[j]
                    },
                    'weights': {
                        'LIDAR': _pick(w_l, j),
                        'CAMERA': _pick(w_c, j),
                        'RADAR': _pick(w_r, j),
                        'FUSED': _pick(w_f, j)
                    }
                })

    # Four charts per row; always keep at least one row so that
    # num_samples <= 0 yields an empty figure instead of a subplots() error.
    nrows = max(1, (num_samples + 3) // 4)
    fig, axes = plt.subplots(nrows, 4, figsize=figsize)
    axes = np.atleast_1d(axes).flatten()

    colors_branch = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A']

    # Only the first num_samples records are drawn, even if more were collected.
    shown = sample_data[:max(num_samples, 0)]
    for idx, data in enumerate(shown):
        ax = axes[idx]

        branch_names = list(data['branches'].keys())
        branch_preds = list(data['branches'].values())
        branch_weights = list(data['weights'].values())

        # One bar per branch; bar height is the predicted distance.
        x = np.arange(len(branch_names))
        bars = ax.bar(x, branch_preds, color=colors_branch, alpha=0.7,
                      edgecolor='black', linewidth=2.5)

        # Add ground truth line
        ax.axhline(y=data['gt'], color='red', linestyle='--', linewidth=3,
                   label=f"GT: {data['gt']:.2f}m", zorder=10)

        # Add ensemble prediction marker
        ax.axhline(y=data['ensemble'], color='green', linestyle=':', linewidth=2.5,
                   label=f"Ensemble: {data['ensemble']:.2f}m", zorder=9)

        # Annotate each bar with prediction and weight
        for bar, pred, weight in zip(bars, branch_preds, branch_weights):
            height = bar.get_height()
            error = abs(pred - data['gt'])

            # Main label: prediction value, centered inside the bar
            ax.text(bar.get_x() + bar.get_width()/2, height/2,
                    f"{pred:.2f}m",
                    ha='center', va='center', fontsize=11, fontweight='bold',
                    color='white', bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

            # Top label: weight and error
            ax.text(bar.get_x() + bar.get_width()/2, height + 1.5,
                    f"w:{weight:.2f}\nerr:{error:.2f}m",
                    ha='center', va='bottom', fontsize=9, fontweight='bold')

        # Styling: leave 25% headroom above the tallest bar / GT line
        max_y = max(max(branch_preds), data['gt']) * 1.25
        ax.set_ylim(0, max_y)
        ax.set_ylabel('Distance (m)', fontsize=11, fontweight='bold')
        ax.set_title(f'Sample {idx+1}', fontsize=12, fontweight='bold', pad=8)
        ax.set_xticks(x)
        ax.set_xticklabels(branch_names, fontsize=9, rotation=0)
        ax.legend(fontsize=9, loc='upper right')
        ax.grid(axis='y', alpha=0.3, linestyle=':')

    # Hide every axis that did not receive a chart. This must start at the
    # number of charts drawn, not at len(sample_data): whole batches are
    # collected, so sample_data can be longer than the grid and the old
    # range(len(sample_data), len(axes)) was then empty.
    for ax in axes[len(shown):]:
        ax.axis('off')

    plt.suptitle('Branch Predictions Overlay: All 4 Sensors vs Ensemble vs Ground Truth',
                 fontsize=15, fontweight='bold', y=0.995)
    plt.tight_layout()
    return fig, sample_data

# Driver: draw the overlay charts for 8 test samples. `overlay_data` holds the
# per-sample records returned by visualize_branch_overlays.
print("\nüé® DETAILED BRANCH OVERLAY VISUALIZATION")
print("="*70)
fig, overlay_data = visualize_branch_overlays(model, test_loader, num_samples=8)
plt.show()

# Print detailed analysis for the first five samples
print("\nüìã DETAILED SAMPLE ANALYSIS:")
print("="*70)
for idx, data in enumerate(overlay_data[:5]):
    print(f"\nüîç SAMPLE {idx+1}:")
    print(f"   Ground Truth: {data['gt']:.2f}m")
    print(f"   Ensemble:    {data['ensemble']:.2f}m (Error: {abs(data['ensemble']-data['gt']):.3f}m)")
    print(f"\n   Individual Branch Predictions:")
    
    for branch in ['LIDAR', 'CAMERA', 'RADAR', 'FUSED']:
        pred = data['branches'][branch]
        weight = data['weights'][branch]
        error = abs(pred - data['gt'])
        # Relative error in percent; reported as 0 when the ground truth is not positive
        # (avoids dividing by zero).
        error_pct = (error / data['gt'] * 100) if data['gt'] > 0 else 0
        
        # Determine if branch is better or worse than ensemble
        # (exact float comparison, so "SAME" only appears for identical errors)
        ensemble_error = abs(data['ensemble'] - data['gt'])
        if error < ensemble_error:
            indicator = "‚úÖ BETTER than ensemble"
        elif error > ensemble_error:
            indicator = "‚ùå WORSE than ensemble"
        else:
            indicator = "‚ûñ SAME as ensemble"
        
        print(f"      {branch:8} ‚Üí {pred:.2f}m  (weight: {weight:.3f}, error: {error:.3f}m ¬±{error_pct:.1f}%) {indicator}")

print("\n" + "="*70)


In [None]:
# ========== BRANCH CONTRIBUTION HEATMAP ==========
"""
Show how much each branch is being used (by weight) and how accurate each is.
Reveals which branches are trusted and which are down-weighted.
"""

def create_branch_analysis_heatmap(model, loader, figsize=(14, 8)):
    """
    Collect per-sample branch weights and absolute errors, then plot them.

    The figure is a 2x2 grid: mean weight per branch, mean absolute error per
    branch, a weight-vs-error scatter and a violin plot of the weights.
    Model outputs are read positionally: elements 1-4 are the branch
    predictions and elements 5-8 the branch weights.

    Args:
        model: trained UTENet4BranchBEV
        loader: DataLoader yielding (bev_lidar, bev_camera, bev_radar, distance)
        figsize: figure size

    Returns:
        (fig, branch_stats) where ``branch_stats[name]`` holds the lists
        'weights' and 'errors' for 'lidar', 'camera', 'radar' and 'fused'.
    """
    branch_keys = ('lidar', 'camera', 'radar', 'fused')

    def to_flat(tensor):
        return np.atleast_1d(tensor.cpu().numpy()).flatten()

    def scrub(tensor):
        # Zero out non-finite BEV cells before the forward pass.
        return torch.nan_to_num(tensor.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)

    model.eval()

    branch_stats = {key: {'weights': [], 'errors': []} for key in branch_keys}

    with torch.no_grad():
        for bev_l, bev_c, bev_r, dist in tqdm(loader, desc='Computing branch stats', leave=False):
            bev_l, bev_c, bev_r = scrub(bev_l), scrub(bev_c), scrub(bev_r)
            dist = torch.nan_to_num(dist.to(DEVICE), nan=50.0, posinf=50.0, neginf=0.0)

            out = model(bev_l, bev_c, bev_r)
            preds_by_branch = [to_flat(out[pos]) for pos in range(1, 5)]
            weights_by_branch = [to_flat(out[pos]) for pos in range(5, 9)]
            targets = to_flat(dist)

            for row in range(len(targets)):
                for key, branch_pred, branch_w in zip(branch_keys, preds_by_branch, weights_by_branch):
                    bucket = branch_stats[key]
                    # Reuse the first weight when fewer weights than samples came back.
                    bucket['weights'].append(branch_w[row] if row < len(branch_w) else branch_w[0])
                    bucket['errors'].append(np.abs(branch_pred[row] - targets[row]))

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    branches = ['LIDAR', 'CAMERA', 'RADAR', 'FUSED']
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A']

    # Panel 1 (top-left): mean weight per branch
    panel = axes[0, 0]
    avg_weights = [np.mean(branch_stats[key]['weights']) for key in branch_keys]
    weight_bars = panel.bar(branches, avg_weights, color=colors, alpha=0.8, edgecolor='black', linewidth=2)
    panel.set_ylabel('Average Weight', fontsize=11, fontweight='bold')
    panel.set_title('Learned Branch Weights (How Much Each Branch is Used)', fontsize=12, fontweight='bold')
    panel.set_ylim(0, max(avg_weights) * 1.2)
    panel.grid(axis='y', alpha=0.3)
    for rect, weight in zip(weight_bars, avg_weights):
        panel.text(rect.get_x() + rect.get_width()/2, rect.get_height() + 0.01,
                   f'{weight:.3f}', ha='center', va='bottom', fontweight='bold')

    # Panel 2 (top-right): mean absolute error per branch
    panel = axes[0, 1]
    avg_errors = [np.mean(branch_stats[key]['errors']) for key in branch_keys]
    error_bars = panel.bar(branches, avg_errors, color=colors, alpha=0.8, edgecolor='black', linewidth=2)
    panel.set_ylabel('Mean Absolute Error (m)', fontsize=11, fontweight='bold')
    panel.set_title('Average Error per Branch (Accuracy)', fontsize=12, fontweight='bold')
    panel.set_ylim(0, max(avg_errors) * 1.2)
    panel.grid(axis='y', alpha=0.3)
    for rect, error in zip(error_bars, avg_errors):
        panel.text(rect.get_x() + rect.get_width()/2, rect.get_height() + 0.01,
                   f'{error:.3f}m', ha='center', va='bottom', fontweight='bold')

    # Panel 3 (bottom-left): per-sample weight against per-sample error
    panel = axes[1, 0]
    for key, shade in zip(branch_keys, colors):
        panel.scatter(np.array(branch_stats[key]['weights']),
                      np.array(branch_stats[key]['errors']),
                      alpha=0.5, s=20, label=key.upper(), color=shade)
    panel.set_xlabel('Learned Weight', fontsize=11, fontweight='bold')
    panel.set_ylabel('Prediction Error (m)', fontsize=11, fontweight='bold')
    panel.set_title('Branch Weight vs Prediction Error\n(Ideal: high weight + low error)',
                    fontsize=12, fontweight='bold')
    panel.legend(fontsize=10)
    panel.grid(alpha=0.3)

    # Panel 4 (bottom-right): distribution of the weights per branch
    panel = axes[1, 1]
    tick_positions = np.arange(len(branches))
    panel.violinplot([branch_stats[key]['weights'] for key in branch_keys],
                     positions=tick_positions, showmeans=True, showmedians=True)
    panel.set_ylabel('Weight Value', fontsize=11, fontweight='bold')
    panel.set_title('Weight Distribution per Branch (Violin Plot)', fontsize=12, fontweight='bold')
    panel.set_xticks(np.arange(len(branches)))
    panel.set_xticklabels(branches)
    panel.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    return fig, branch_stats

# Driver: build the 2x2 weight/error figure for the test loader. Despite the
# section title, the figure consists of bar, scatter and violin plots.
print("\nüìä BRANCH CONTRIBUTION & ACCURACY ANALYSIS")
print("="*70)
fig, branch_stats = create_branch_analysis_heatmap(model, test_loader)
plt.show()

# Print summary statistics (mean weight, mean / min / max absolute error per branch)
print("\nüìà BRANCH STATISTICS SUMMARY:")
print("="*70)
print(f"\n{'Branch':<12} {'Avg Weight':<15} {'Avg Error':<15} {'Min Error':<15} {'Max Error':<15}")
print("-"*70)

for branch in ['lidar', 'camera', 'radar', 'fused']:
    weights = np.array(branch_stats[branch]['weights'])
    errors = np.array(branch_stats[branch]['errors'])
    
    print(f"{branch.upper():<12} {np.mean(weights):<15.4f} {np.mean(errors):<15.4f} "
          f"{np.min(errors):<15.4f} {np.max(errors):<15.4f}")

# Static reading guide for the table above; nothing here is computed from the data.
print("\nüí° INTERPRETATION:")
print("-"*70)
print("‚Ä¢ High Weight + Low Error = Branch is trusted and accurate")
print("‚Ä¢ High Weight + High Error = Branch is trusted but unreliable (needs retraining)")
print("‚Ä¢ Low Weight + Low Error = Branch is accurate but not used (good ensemble decision)")
print("‚Ä¢ Low Weight + High Error = Branch is unreliable and correctly down-weighted")
print("\n" + "="*70)


In [None]:
# ========== SIMPLE CAMERA + BEV VISUALIZATION ==========
"""
Show a single sample with:
- CAM_FRONT image
- LIDAR BEV count channel
- RADAR BEV count channel
"""

import matplotlib.pyplot as plt
from PIL import Image
import random


def load_cam_image(sample_token, fallback_size=(640, 360)):
    """
    Load the CAM_FRONT image that belongs to ``sample_token``.

    The lookup goes sample token -> sensor map -> CAM_FRONT sample-data token
    -> sample-data record -> absolute file path. If any step fails (unknown
    token, no CAM_FRONT entry, missing record, missing file, unreadable
    image) a dark-grey placeholder image is returned instead, so callers can
    always pass the result to ``imshow``.

    Args:
        sample_token: sample token to look up in ``sample_to_sensor``
        fallback_size: (width, height) of the placeholder image

    Returns:
        An RGB image: the decoded camera frame, or a blank
        ``fallback_size`` image filled with (30, 30, 30).
    """
    def _blank():
        return Image.new('RGB', fallback_size, color=(30, 30, 30))

    if sample_token not in sample_to_sensor:
        return _blank()
    sensors = sample_to_sensor[sample_token]
    if 'CAM_FRONT' not in sensors:
        return _blank()

    sd_rec = sd_by_token.get(sensors['CAM_FRONT'])
    if not sd_rec:
        return _blank()

    img_path = abs_sensor_path(sd_rec)
    if not os.path.exists(img_path):
        return _blank()

    try:
        # Image.open is lazy and keeps the file handle; convert() returns an
        # independent copy, so the handle can be released as soon as the
        # pixels are decoded -- also when decoding fails half-way.
        with Image.open(img_path) as raw:
            return raw.convert('RGB')
    except Exception:
        # Deliberate best effort: a corrupt frame must not abort a visualization.
        return _blank()


def visualize_camera_and_bev(dataset, sample_idx=None, figsize=(16, 5)):
    """
    Plot one sample as three panels: camera frame, LIDAR BEV, RADAR BEV.

    Only channel 0 of each BEV tensor is drawn (labelled "Count" in the
    panel titles). When ``sample_idx`` is None a random index is drawn, so
    repeated calls show different samples unless ``random`` is seeded.
    """
    def first_channel(bev):
        # Tensors are converted to numpy; anything else is indexed as-is.
        return bev[0].numpy() if hasattr(bev, 'numpy') else bev[0]

    if sample_idx is None:
        sample_idx = random.randint(0, len(dataset) - 1)

    # The camera BEV and the distance label are part of the item but not drawn.
    bev_l, _bev_c, bev_r, _dist = dataset[sample_idx]
    img = load_cam_image(dataset.sample_tokens[sample_idx])

    lidar_count = first_channel(bev_l)
    radar_count = first_channel(bev_r)

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    cam_ax = axes[0]
    cam_ax.imshow(img)
    cam_ax.axis('off')
    cam_ax.set_title(f'CAM_FRONT (idx {sample_idx})')

    bev_panels = (
        (axes[1], lidar_count, 'viridis', 'LIDAR BEV (Count)'),
        (axes[2], radar_count, 'plasma', 'RADAR BEV (Count)'),
    )
    for bev_ax, grid, cmap_name, title in bev_panels:
        mappable = bev_ax.imshow(grid, cmap=cmap_name, aspect='auto')
        bev_ax.set_title(title)
        fig.colorbar(mappable, ax=bev_ax, fraction=0.046)

    plt.tight_layout()
    plt.show()


# Run: simple camera + BEV visualization on one random sample.
# No sample_idx is passed, so the function draws a random index from test_ds;
# the displayed sample changes between runs unless `random` is seeded.
print("\nüé® SIMPLE CAMERA + BEV VISUALIZATION")
print("="*70)
visualize_camera_and_bev(test_ds)


In [None]:
# ========== CAMERA + BEV PERSPECTIVE ANALYSIS ==========
"""
Understand why objects in camera images appear farther than BEV suggests.
This compares camera perspective projection with BEV bird's-eye-view.
"""

print("\nüìä CAMERA vs BEV PERSPECTIVE ANALYSIS")
print("="*70)

# Camera parameters (typical nuScenes CAM_FRONT)
CAM_INTRINSICS = {
    'fx': 1266.4,      # focal length x
    'fy': 1266.4,      # focal length y
    'cx': 816.0,       # principal point x
    'cy': 491.0,       # principal point y
    'width': 1600,
    'height': 900,
}

CAM_HEIGHT = 1.7  # meters above ground (approximate ego mounting height)
IMG_WIDTH = CAM_INTRINSICS['width']
IMG_HEIGHT = CAM_INTRINSICS['height']

print(f"\nüé• CAMERA PROPERTIES:")
print(f"  Image size: {IMG_WIDTH} x {IMG_HEIGHT}")
print(f"  Focal length: {CAM_INTRINSICS['fx']:.1f}px")
print(f"  Principal point: ({CAM_INTRINSICS['cx']:.1f}, {CAM_INTRINSICS['cy']:.1f})")
print(f"  Camera height above ground: {CAM_HEIGHT}m")

# Field of view calculation
fov_x = 2 * np.arctan(IMG_WIDTH / (2 * CAM_INTRINSICS['fx'])) * 180 / np.pi
fov_y = 2 * np.arctan(IMG_HEIGHT / (2 * CAM_INTRINSICS['fy'])) * 180 / np.pi

print(f"\nüìê FIELD OF VIEW:")
print(f"  Horizontal FOV: {fov_x:.1f}¬∞")
print(f"  Vertical FOV: {fov_y:.1f}¬∞")

print(f"\nüõ£Ô∏è  BEV COVERAGE:")
print(f"  Forward range: {XRANGE[0]:.0f}m to {XRANGE[1]:.0f}m (200m total)")
print(f"  Lateral range: {YRANGE[0]:.0f}m to {YRANGE[1]:.0f}m (100m total)")
print(f"  Grid resolution: {RES}m/cell")
print(f"  Grid size: {NX} x {NY} cells")

print(f"\nüí° WHY OBJECTS APPEAR FARTHER IN CAMERA:")
print(f"  1. Camera perspective: Objects at 10m look ~5% into the 200m BEV range")
print(f"  2. Pinhole projection: Smaller image = appears farther")
print(f"  3. Camera height: Ground at 10m is below horizon, looks very distant")
print(f"  4. No depth cues: Without stereo/depth, perception is ambiguous")

print(f"\nüìè EXAMPLE DISTANCE INTERPRETATIONS:")
test_distances = [5, 10, 20, 30, 50]
for dist in test_distances:
    percent_of_bev = (dist - XRANGE[0]) / (XRANGE[1] - XRANGE[0]) * 100
    print(f"  {dist:2d}m lead vehicle = {percent_of_bev:5.1f}% into forward BEV range")

print("\n‚úÖ CONCLUSION: Use BEV predictions, not camera visual perception for distance.")
print("="*70)

In [None]:
# ========== CAMERA + FUSION + BRANCH WEIGHTS (SINGLE SAMPLE) ==========
"""
Show for one sample:
- CAM_FRONT image
- BEV fusion visualization (LIDAR + RADAR + CAMERA)
- Branch predictions (no GT line)
- Normalized branch weights for that sample
"""

import matplotlib.pyplot as plt
import random


def visualize_camera_fusion_and_weights(model, dataset, sample_idx=None):
    """
    Show camera frame, BEV fusion view, branch predictions and weights for one sample.

    Three figures are displayed in turn: the CAM_FRONT image, the fusion view
    produced by ``visualize_bev_fusion`` and a two-panel bar chart with the
    five predictions (four branches plus ensemble) and the four branch
    weights. A one-line summary is printed at the end.

    The model output is read positionally: elements 0-4 are the ensemble /
    lidar / camera / radar / fused predictions, elements 5-8 the branch
    weights. Each must hold exactly one value for the single-sample batch.

    Args:
        model: trained UTENet4BranchBEV
        dataset: dataset exposing ``__getitem__`` -> (bev_l, bev_c, bev_r, dist)
            and a ``sample_tokens`` sequence
        sample_idx: index to show; a random index is drawn when None

    Returns:
        None. The figures are shown with ``plt.show()``.
    """
    model.eval()
    if sample_idx is None:
        sample_idx = random.randint(0, len(dataset) - 1)

    # Grab sample (the distance label is not used: no GT line is drawn)
    bev_l, bev_c, bev_r, _ = dataset[sample_idx]
    sample_token = dataset.sample_tokens[sample_idx]

    def _as_model_input(bev):
        # Add the batch dimension and zero non-finite cells, matching the
        # sanitization the batch evaluation code applies before the forward
        # pass; a NaN here would propagate into the predictions and make
        # set_ylim fail below.
        return torch.nan_to_num(bev.unsqueeze(0).to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)

    with torch.no_grad():
        out = model(_as_model_input(bev_l), _as_model_input(bev_c), _as_model_input(bev_r))
        pred_ensemble, pred_l, pred_c, pred_r, pred_f = [t.cpu().item() for t in out[:5]]
        w_l, w_c, w_r, w_f = [t.cpu().item() for t in out[5:9]]

    # Camera image
    img = load_cam_image(sample_token)

    # BEVs to numpy (the raw, unsanitized grids are what gets visualized)
    lidar_bev = bev_l.numpy() if hasattr(bev_l, 'numpy') else bev_l
    cam_bev = bev_c.numpy() if hasattr(bev_c, 'numpy') else bev_c
    radar_bev = bev_r.numpy() if hasattr(bev_r, 'numpy') else bev_r

    # 1) Camera image
    plt.figure(figsize=(6, 4))
    plt.imshow(img)
    plt.axis('off')
    plt.title(f'CAM_FRONT (sample {sample_idx})')
    plt.show()

    # 2) Fusion visualization (uses existing helper)
    visualize_bev_fusion(lidar_bev, cam_bev, radar_bev, distance_label=None)
    plt.show()

    # 3) Predictions + Weights
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))

    branches = ['LIDAR', 'CAMERA', 'RADAR', 'FUSED', 'ENSEMBLE']
    preds = [pred_l, pred_c, pred_r, pred_f, pred_ensemble]
    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#2ECC71']
    x = np.arange(len(branches))
    bars = axes[0].bar(x, preds, color=colors, edgecolor='black', linewidth=1.5, alpha=0.85)
    axes[0].set_xticks(x)
    axes[0].set_xticklabels(branches, rotation=15)
    axes[0].set_ylabel('Distance (m)')
    axes[0].set_title('Branch Predictions (no GT)')
    axes[0].grid(axis='y', alpha=0.3, linestyle=':')
    axes[0].set_ylim(0, max(preds) * 1.25)
    for bar, val in zip(bars, preds):
        axes[0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.2,
                     f"{val:.2f}m", ha='center', va='bottom', fontsize=9, fontweight='bold')

    weights = [w_l, w_c, w_r, w_f]
    branches_w = ['LIDAR', 'CAM', 'RADAR', 'FUSED']
    xw = np.arange(len(weights))
    # The weight bars reuse the first four prediction colors so branches match.
    bars_w = axes[1].bar(xw, weights, color=colors[:4], edgecolor='black', linewidth=1.5, alpha=0.85)
    axes[1].set_xticks(xw)
    axes[1].set_xticklabels(branches_w, rotation=15)
    axes[1].set_ylabel('Weight (normalized)')
    axes[1].set_title('Branch Weights (this sample)')
    axes[1].grid(axis='y', alpha=0.3, linestyle=':')
    axes[1].set_ylim(0, max(weights) * 1.25)
    for bar, val in zip(bars_w, weights):
        axes[1].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f"{val:.3f}", ha='center', va='bottom', fontsize=9, fontweight='bold')

    plt.tight_layout()
    plt.show()

    print(f"‚úÖ Sample {sample_idx}: Token={sample_token[:8]}... | Ensemble={pred_ensemble:.2f}m | Weights L/C/R/F = {w_l:.3f}/{w_c:.3f}/{w_r:.3f}/{w_f:.3f}")


# Run on a single random sample.
# sample_idx is left at its default, so a random index from test_ds is used
# and the output differs between runs unless `random` is seeded.
print("\nüé® CAMERA + FUSION + BRANCH WEIGHTS (SINGLE SAMPLE)")
print("="*70)
visualize_camera_fusion_and_weights(model, test_ds)


In [None]:
# ========== SIMPLE SENSOR STREAM + FUSION VIEW ==========
"""
Minimal visualization for one sample:
- CAM_FRONT image
- LIDAR BEV count channel (with nonzero count + max)
- RADAR BEV count channel (with nonzero count + max)
- Fusion (LIDAR + RADAR count) with stats
"""

import matplotlib.pyplot as plt
import random
import numpy as np


def visualize_sensor_stream_simple(dataset, sample_idx=None, figsize=(18, 4)):
    """
    Plot one sample as four panels: camera, LIDAR, RADAR and their sum.

    Channel 0 of the LIDAR and RADAR BEV tensors is shown, plus the
    element-wise sum of the two. Each BEV title reports how many cells are
    greater than zero and the maximum cell value. A random index is used
    when ``sample_idx`` is None.
    """
    def channel_zero(bev):
        return bev[0].cpu().numpy() if hasattr(bev, 'cpu') else bev[0]

    def stats(grid):
        # (number of cells > 0, maximum value); an empty grid reports max 0.0
        peak = float(grid.max()) if grid.size > 0 else 0.0
        return int((grid > 0).sum()), peak

    if sample_idx is None:
        sample_idx = random.randint(0, len(dataset) - 1)

    # The camera BEV and the distance label are unpacked but not used here.
    bev_l, _bev_c, bev_r, _dist = dataset[sample_idx]
    img = load_cam_image(dataset.sample_tokens[sample_idx])

    lidar_count = channel_zero(bev_l)
    radar_count = channel_zero(bev_r)
    fusion_count = lidar_count + radar_count

    lid_nz, lid_max = stats(lidar_count)
    rad_nz, rad_max = stats(radar_count)
    fus_nz, fus_max = stats(fusion_count)

    fig, axes = plt.subplots(1, 4, figsize=figsize)

    cam_ax = axes[0]
    cam_ax.imshow(img)
    cam_ax.axis('off')
    cam_ax.set_title(f'CAM_FRONT (idx {sample_idx})')

    grid_panels = [
        (lidar_count, 'viridis', f'LIDAR Count\nnonzero={lid_nz}, max={lid_max:.2f}'),
        (radar_count, 'plasma', f'RADAR Count\nnonzero={rad_nz}, max={rad_max:.2f}'),
        (fusion_count, 'magma', f'Fusion (L+R)\nnonzero={fus_nz}, max={fus_max:.2f}'),
    ]
    for grid_ax, (grid, cmap_name, title) in zip(axes[1:], grid_panels):
        mappable = grid_ax.imshow(grid, cmap=cmap_name, aspect='auto')
        grid_ax.set_title(title)
        fig.colorbar(mappable, ax=grid_ax, fraction=0.046)

    plt.tight_layout()
    plt.show()


print("\nüé® SIMPLE SENSOR STREAM + FUSION VIEW")
print("="*70)
visualize_sensor_stream_simple(test_ds)


In [None]:
# ========== WEATHER-AWARE HELPER FUNCTIONS ==========

# Weather modulation factors (applied multiplicatively).
# Outer key: weather condition; inner key: branch name. A factor below 1.0
# shrinks that branch's weight under the condition, a factor above 1.0 grows
# it, and 1.0 leaves it unchanged (the 'fused' branch is never modulated).
# When several conditions are active their factors multiply, and
# get_weather_adjusted_weights renormalizes the result afterwards.
# NOTE(review): the values are hard-coded here; nothing in this cell shows
# how they were chosen -- confirm whether they were tuned or are heuristics.
WEATHER_MODULATION = {
    'rain':  {'lidar': 0.7, 'camera': 0.6, 'radar': 1.2, 'fused': 1.0},
    'night': {'lidar': 1.0, 'camera': 0.5, 'radar': 1.1, 'fused': 1.0},
    'fog':   {'lidar': 0.6, 'camera': 0.5, 'radar': 1.3, 'fused': 1.0},
    'snow':  {'lidar': 0.8, 'camera': 0.7, 'radar': 0.9, 'fused': 1.0}
}


def extract_weather_from_description(description):
    """
    Extract weather conditions from a scene description string.

    Matching is case-insensitive keyword search:
      - rain:  "rain" at the start of a word (rain, rainy, raining, ...) or
               the substring "wet"
      - night: the substring "night"
      - fog:   the substring "fog" or "overcast"
      - snow:  the substring "snow"

    "rain" must not be preceded by a letter, so words that merely contain it
    ("train", "terrain", "drain") are not reported as rain.

    Args:
        description: str, scene description (e.g., 'rainy', 'night', 'foggy')

    Returns:
        dict: flags (0 or 1) for the keys 'rain', 'night', 'fog', 'snow'.
        Anything that is not a non-empty string yields all zeros.
    """
    import re

    # Type check first: truthiness of arbitrary objects (e.g. arrays) is not
    # always defined, while isinstance never raises.
    if not isinstance(description, str) or not description:
        return {'rain': 0, 'night': 0, 'fog': 0, 'snow': 0}

    desc_lower = description.lower()

    # Lookbehind instead of \b so that "light_rain" or "2rain" still match,
    # while "train" / "terrain" do not.
    mentions_rain = re.search(r'(?<![a-z])rain', desc_lower) is not None

    weather = {
        'rain': 1 if mentions_rain or 'wet' in desc_lower else 0,
        'night': 1 if 'night' in desc_lower else 0,
        'fog': 1 if 'fog' in desc_lower or 'overcast' in desc_lower else 0,
        'snow': 1 if 'snow' in desc_lower else 0
    }

    return weather


def get_weather_adjusted_weights(base_weights, weather_dict):
    """
    Rescale the four branch weights for the active weather conditions.

    Each active condition that has an entry in WEATHER_MODULATION multiplies
    every branch weight by its factor. The scaled weights are then divided by
    their sum (plus 1e-8 so an all-zero input does not divide by zero).

    Args:
        base_weights: (w_lidar, w_camera, w_radar, w_fused); entries that expose
            `.item()` (e.g. single-element tensors) are converted to scalars first
        weather_dict: dict mapping a condition name to a truthy/falsy flag

    Returns:
        tuple: (w_lidar_adj, w_camera_adj, w_radar_adj, w_fused_adj)
    """
    branch_keys = ('lidar', 'camera', 'radar', 'fused')

    # Unpack first so that anything other than exactly four weights is rejected.
    lidar_w, camera_w, radar_w, fused_w = base_weights
    scaled = [
        w.item() if hasattr(w, 'item') else w
        for w in (lidar_w, camera_w, radar_w, fused_w)
    ]

    for condition, is_active in weather_dict.items():
        if not is_active or condition not in WEATHER_MODULATION:
            continue
        factors = WEATHER_MODULATION[condition]
        for pos, key in enumerate(branch_keys):
            scaled[pos] *= factors[key]

    total = sum(scaled) + 1e-8
    return tuple(w / total for w in scaled)


# Confirm the helper cell ran and list the conditions that have modulation factors
print("‚úÖ WEATHER-AWARE HELPER FUNCTIONS LOADED")
print(f"   - Weather conditions: {list(WEATHER_MODULATION)}")
print("   - Extraction & adjustment functions ready")

In [None]:
# ========== WEATHER-AWARE TRAINING WITH WEIGHT TRACKING ==========

# Per-epoch bookkeeping, filled in by epoch_pass_with_weather:
#   weight_history[phase][epoch] -> dict of mean/std branch weights for that epoch
#   loss_history[phase]          -> list of average losses, one entry per epoch
weight_history = dict(train={}, val={})
loss_history = dict(train=[], val=[])


def epoch_pass_with_weather(loader, train=False, epoch_idx=0, phase_name='train', use_weather_modulation=True):
    """
    Run one pass over `loader`, weighting the auxiliary branch losses by the
    (optionally weather-adjusted) branch weights, and record the weights used.

    Relies on notebook globals: model, optimizer, mse_loss, DEVICE,
    sample_to_sensor, sample_by_token, weight_history and loss_history.

    Args:
        loader: DataLoader yielding (bev_lidar, bev_camera, bev_radar, distance).
            When weather modulation is active its dataset must expose
            `sample_tokens`, and batches are paired with tokens by position,
            so the loader must iterate the dataset in order.
        train: if True, put the model in train mode and take optimizer steps;
            otherwise evaluate with autograd disabled
        epoch_idx: epoch number (progress-bar label and key in weight_history)
        phase_name: name for progress bar
        use_weather_modulation: if True (and train is True), apply weather
            adjustments to the branch weights

    Returns:
        dict with keys 'loss', 'mse', 'rmse', 'mae', 'r2' for this pass.

    Side effects:
        Stores the per-batch mean/std of the normalized branch weights in
        weight_history and appends the average loss to loss_history.
    """
    import warnings  # local import: only needed for the alignment warning below

    weather_enabled = train and use_weather_modulation

    # Tokens are looked up by batch position, which only lines up with the
    # samples in the batch when the loader walks the dataset sequentially.
    if weather_enabled and not isinstance(
        getattr(loader, 'sampler', None), torch.utils.data.SequentialSampler
    ):
        warnings.warn(
            "epoch_pass_with_weather pairs batches with dataset.sample_tokens by position; "
            "with a shuffling or custom sampler the weather flags will not match the samples.",
            RuntimeWarning,
        )

    model.train() if train else model.eval()
    total_loss = 0.0
    n_seen = 0
    y_true_all, y_pred_all = [], []
    w_l_list, w_c_list, w_r_list, w_f_list = [], [], [], []
    clear_weather = {'rain': 0, 'night': 0, 'fog': 0, 'snow': 0}

    pbar = tqdm(loader, desc=f'{phase_name} epoch {epoch_idx}', leave=False)
    for step, batch in enumerate(pbar):
        bev_l, bev_c, bev_r, dist = batch

        bev_l = torch.nan_to_num(bev_l.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
        bev_c = torch.nan_to_num(bev_c.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
        bev_r = torch.nan_to_num(bev_r.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
        dist = torch.nan_to_num(dist.to(DEVICE), nan=50.0, posinf=50.0, neginf=0.0)

        if train:
            optimizer.zero_grad()

        # Forward pass; no autograd graph is built when evaluating
        with torch.set_grad_enabled(train):
            out = model(bev_l, bev_c, bev_r)
        pred_ensemble, pred_lidar, pred_cam, pred_radar, pred_fused = out[:5]

        # Batch mean of each learned branch weight as a Python float. Reducing
        # with mean() works whether the model emits one weight per sample or a
        # single shared weight.
        base_means = tuple(t.detach().mean().item() for t in out[5:9])

        # Weather-adjusted weights, one tuple per sample token (training only)
        adjusted_weights_list = []
        if weather_enabled:
            start_idx = step * loader.batch_size
            end_idx = min(start_idx + loader.batch_size, len(loader.dataset))
            for token in loader.dataset.sample_tokens[start_idx:end_idx]:
                if token in sample_to_sensor:
                    weather = extract_weather_from_description(
                        sample_by_token.get(token, {}).get('scene_description', '')
                    )
                else:
                    weather = clear_weather
                adjusted_weights_list.append(get_weather_adjusted_weights(base_means, weather))

        if adjusted_weights_list:
            # Average the per-sample adjusted weights into one weight per branch
            branch_w = torch.tensor(
                adjusted_weights_list, dtype=torch.float32, device=DEVICE
            ).mean(dim=0)
        else:
            # Learned weights only (validation, modulation disabled, or no tokens)
            branch_w = torch.tensor(base_means, dtype=torch.float32, device=DEVICE)

        # Normalize weights; branch_w is built from plain floats, so it carries no gradient
        branch_w = branch_w / branch_w.sum()
        w_l_norm, w_c_norm, w_r_norm, w_f_norm = branch_w.unbind(0)

        # Track weights
        w_l_val, w_c_val, w_r_val, w_f_val = branch_w.tolist()
        w_l_list.append(w_l_val)
        w_c_list.append(w_c_val)
        w_r_list.append(w_r_val)
        w_f_list.append(w_f_val)

        # Compute losses
        loss_ensemble = mse_loss(pred_ensemble, dist)
        # Branch losses are scaled by their weight (emphasize reliable branches)
        loss_l = mse_loss(pred_lidar, dist) * w_l_norm
        loss_c = mse_loss(pred_cam, dist) * w_c_norm
        loss_r = mse_loss(pred_radar, dist) * w_r_norm
        loss_f = mse_loss(pred_fused, dist) * w_f_norm

        # Combined loss: ensemble + weighted auxiliary branches (0.15 = auxiliary weight)
        loss = loss_ensemble + 0.15 * (loss_l + loss_c + loss_r + loss_f)

        if train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()

        n_batch = dist.size(0)
        total_loss += loss.item() * n_batch
        n_seen += n_batch
        y_true_all.extend(dist.detach().cpu().numpy())
        y_pred_all.extend(pred_ensemble.detach().cpu().numpy())

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # Metrics; average over the samples actually processed, which can be fewer
    # than len(loader.dataset) when the loader drops the last batch
    avg_loss = total_loss / max(n_seen, 1)
    y_true_all = np.nan_to_num(np.array(y_true_all), nan=50.0, posinf=50.0, neginf=0.0)
    y_pred_all = np.nan_to_num(np.array(y_pred_all), nan=50.0, posinf=50.0, neginf=0.0)
    mse = mean_squared_error(y_true_all, y_pred_all)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true_all, y_pred_all)
    r2 = r2_score(y_true_all, y_pred_all)

    # Store weight history (mean and spread across the batches of this epoch)
    phase = 'train' if train else 'val'
    weight_history[phase][epoch_idx] = {
        'lidar': np.mean(w_l_list),
        'camera': np.mean(w_c_list),
        'radar': np.mean(w_r_list),
        'fused': np.mean(w_f_list),
        'lidar_std': np.std(w_l_list),
        'camera_std': np.std(w_c_list),
        'radar_std': np.std(w_r_list),
        'fused_std': np.std(w_f_list),
    }
    loss_history[phase].append(avg_loss)

    return {'loss': avg_loss, 'mse': mse, 'rmse': rmse, 'mae': mae, 'r2': r2}


# Train for EPOCHS epochs and keep the checkpoint with the lowest validation MSE
best_val_mse = float('inf')
STATE_PATH = '/kaggle/working/utenet4bev_state.pth'
FULL_PATH = '/kaggle/working/utenet4bev_full.pth'

print('\n' + '='*60)
print('üöÄ STARTING TRAINING (Weather-Aware Weight Modulation)')
print('='*60 + '\n')

for epoch in range(1, EPOCHS + 1):
    # Training pass: weather modulation switched on
    train_metrics = epoch_pass_with_weather(
        train_loader,
        train=True,
        epoch_idx=epoch,
        phase_name='TRAIN',
        use_weather_modulation=True,
    )

    # Validation pass: learned weights only, no weather modulation
    val_metrics = epoch_pass_with_weather(
        val_loader,
        train=False,
        epoch_idx=epoch,
        phase_name='VAL',
        use_weather_modulation=False,
    )

    print(
        f"Epoch {epoch:2d}/{EPOCHS} | Train RMSE: {train_metrics['rmse']:.4f} | "
        f"Val RMSE: {val_metrics['rmse']:.4f} | R¬≤: {val_metrics['r2']:.4f} | "
        f"Train MAE: {train_metrics['mae']:.4f}"
    )

    # Only write a checkpoint when validation MSE improved
    if not (val_metrics['mse'] < best_val_mse):
        continue
    best_val_mse = val_metrics['mse']
    torch.save(model.state_dict(), STATE_PATH)  # weights only
    torch.save(model, FULL_PATH)                # whole pickled model object
    print(f"   ‚úÖ New best validation MSE: {best_val_mse:.6f}")

print("\n" + "="*60)
print("‚úÖ TRAINING COMPLETE WITH WEATHER-AWARE MODULATION")
print("="*60)

In [None]:
# ========== VISUALIZING BRANCH WEIGHT EVOLUTION DURING TRAINING ==========
# Draws a 2x2 figure from `weight_history` and `loss_history`, which are filled
# by epoch_pass_with_weather during the training cell, saves it as a PNG and
# prints the final-epoch weights. Requires at least one recorded train and val
# epoch (the summary below indexes the last epoch).

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Branch Weight Evolution During Training (Weather-Aware)', fontsize=14, fontweight='bold')

# Epoch numbers recorded for each phase, in ascending order
epochs_train = sorted(weight_history['train'].keys())
epochs_val = sorted(weight_history['val'].keys())

# 1. Training weights per branch; the shaded band is mean +/- std, where the
#    std is taken across the batches of each epoch
ax = axes[0, 0]
branches = ['lidar', 'camera', 'radar', 'fused']
colors = {'lidar': '#FF6B6B', 'camera': '#4ECDC4', 'radar': '#45B7D1', 'fused': '#FFA07A'}

for branch in branches:
    means = [weight_history['train'][e][branch] for e in epochs_train]
    stds = [weight_history['train'][e][f'{branch}_std'] for e in epochs_train]
    ax.plot(epochs_train, means, marker='o', label=branch.upper(), color=colors[branch], linewidth=2)
    ax.fill_between(epochs_train, 
                     np.array(means) - np.array(stds),
                     np.array(means) + np.array(stds),
                     alpha=0.2, color=colors[branch])

ax.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax.set_ylabel('Weight', fontsize=11, fontweight='bold')
ax.set_title('Training Weights (Mean ¬± Std)', fontsize=12)
ax.legend(loc='best', fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_xticks(epochs_train)

# 2. Validation weights per branch (means only; no std band is drawn here)
ax = axes[0, 1]
for branch in branches:
    means = [weight_history['val'][e][branch] for e in epochs_val]
    ax.plot(epochs_val, means, marker='s', label=branch.upper(), color=colors[branch], linewidth=2)

ax.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax.set_ylabel('Weight', fontsize=11, fontweight='bold')
ax.set_title('Validation Weights (Learned Only)', fontsize=12)
ax.legend(loc='best', fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_xticks(epochs_val)

# 3. Training weights as a stacked area chart (share of each branch per epoch)
ax = axes[1, 0]
all_means_train = {branch: [weight_history['train'][e][branch] for e in epochs_train] for branch in branches}
ax.stackplot(epochs_train,
             all_means_train['lidar'],
             all_means_train['camera'],
             all_means_train['radar'],
             all_means_train['fused'],
             labels=[b.upper() for b in branches],
             colors=[colors[b] for b in branches],
             alpha=0.7)
ax.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax.set_ylabel('Normalized Weight', fontsize=11, fontweight='bold')
ax.set_title('Stacked Weight Distribution (Training)', fontsize=12)
ax.legend(loc='upper left', fontsize=10)
ax.grid(True, alpha=0.3, axis='y')
ax.set_xticks(epochs_train)
ax.set_ylim([0, 1])

# 4. Train/val loss (left axis) together with the weight variance (right axis)
ax = axes[1, 1]
ax2 = ax.twinx()  # second y-axis sharing the same x-axis

# Losses on the primary axis; x positions are 1..N in the order they were appended
ax.plot(range(1, len(loss_history['train'])+1), loss_history['train'], 
        marker='o', label='Train Loss', color='#2E86C1', linewidth=2)
ax.plot(range(1, len(loss_history['val'])+1), loss_history['val'], 
        marker='s', label='Val Loss', color='#E74C3C', linewidth=2)
ax.set_xlabel('Epoch', fontsize=11, fontweight='bold')
ax.set_ylabel('Loss', fontsize=11, fontweight='bold', color='#2E86C1')
ax.tick_params(axis='y', labelcolor='#2E86C1')

# Variance across the four branch means of each training epoch, on the secondary axis
train_var = [np.var([weight_history['train'][e][b] for b in branches]) for e in epochs_train]
ax2.plot(epochs_train, train_var, marker='^', label='Weight Variance', 
         color='#27AE60', linewidth=2, linestyle='--')
ax2.set_ylabel('Weight Variance', fontsize=11, fontweight='bold', color='#27AE60')
ax2.tick_params(axis='y', labelcolor='#27AE60')

ax.set_title('Loss & Weight Stability Over Time', fontsize=12)
ax.grid(True, alpha=0.3)
ax.set_xticks(epochs_train)

# Merge the handles of both y-axes into a single legend
lines1, labels1 = ax.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right', fontsize=10)

plt.tight_layout()
plt.savefig('/kaggle/working/weight_evolution_training.png', dpi=150, bbox_inches='tight')
plt.show()

# Text summary: number of tracked epochs and the weights of the last epoch
print("\n‚úÖ WEIGHT EVOLUTION VISUALIZATION COMPLETE")
print(f"   - Tracked {len(epochs_train)} training epochs")
print(f"   - Tracked {len(epochs_val)} validation epochs")
print(f"   - Final Training Weights: L={weight_history['train'][epochs_train[-1]]['lidar']:.3f}, "
      f"C={weight_history['train'][epochs_train[-1]]['camera']:.3f}, "
      f"R={weight_history['train'][epochs_train[-1]]['radar']:.3f}, "
      f"F={weight_history['train'][epochs_train[-1]]['fused']:.3f}")
print(f"   - Final Validation Weights: L={weight_history['val'][epochs_val[-1]]['lidar']:.3f}, "
      f"C={weight_history['val'][epochs_val[-1]]['camera']:.3f}, "
      f"R={weight_history['val'][epochs_val[-1]]['radar']:.3f}, "
      f"F={weight_history['val'][epochs_val[-1]]['fused']:.3f}")

In [None]:
# ========== WEATHER-AWARE MODEL: BRANCH PREDICTIONS VS GROUND TRUTH ==========
"""
Visualize how the weather-aware trained model performs across all branches.
Compare individual branch predictions against ground truth and ensemble output.
"""

def visualize_weather_trained_branch_performance(model, loader, num_samples=8, figsize=(18, 10)):
    """
    Plot, for each of the first samples of `loader`, a bar chart of the four
    branch predictions with the ground truth and ensemble prediction overlaid.

    The figure is laid out four panels per row and is also written to
    '/kaggle/working/weather_aware_branch_predictions.png'. Relies on the
    notebook globals DEVICE and plt.

    Args:
        model: trained UTENet4BranchBEV (weather-aware)
        loader: DataLoader yielding (bev_lidar, bev_camera, bev_radar, distance)
        num_samples: number of samples to visualize
        figsize: figure size

    Returns:
        (fig, sample_data): the matplotlib figure and a list with one dict per
        collected sample ('gt', 'ensemble', 'branches', 'weights'). Whole
        batches are collected, so the list can hold more than num_samples
        entries; only the first num_samples are plotted.
    """
    def _flat(tensor):
        """Move a tensor to the CPU and flatten it into a 1-D numpy array."""
        return np.atleast_1d(tensor.cpu().numpy()).flatten()

    def _pick(values, j):
        """Return entry j, or the first entry when the array has fewer than j+1 values."""
        return values[j] if j < len(values) else values[0]

    model.eval()

    sample_data = []
    with torch.no_grad():
        for batch in loader:
            if len(sample_data) >= num_samples:
                break

            bev_l, bev_c, bev_r, dist = batch
            bev_l = torch.nan_to_num(bev_l.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            bev_c = torch.nan_to_num(bev_c.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            bev_r = torch.nan_to_num(bev_r.to(DEVICE), nan=0.0, posinf=0.0, neginf=0.0)
            dist = torch.nan_to_num(dist.to(DEVICE), nan=50.0, posinf=50.0, neginf=0.0)

            # Forward pass: out[0:5] are predictions, out[5:9] the branch weights
            out = model(bev_l, bev_c, bev_r)
            pred_ensemble, pred_l, pred_c, pred_r, pred_f = (_flat(t) for t in out[:5])
            w_l, w_c, w_r, w_f = (_flat(t) for t in out[5:9])
            dist_gt = _flat(dist)

            # Collect for each sample
            for j in range(len(dist_gt)):
                sample_data.append({
                    'gt': dist_gt[j],
                    'ensemble': pred_ensemble[j],
                    'branches': {
                        'LIDAR': pred_l[j],
                        'CAMERA': pred_c[j],
                        'RADAR': pred_r[j],
                        'FUSED': pred_f[j]
                    },
                    'weights': {
                        'LIDAR': _pick(w_l, j),
                        'CAMERA': _pick(w_c, j),
                        'RADAR': _pick(w_r, j),
                        'FUSED': _pick(w_f, j)
                    }
                })

    # Number of panels that actually get drawn
    n_plotted = max(min(len(sample_data), num_samples), 0)

    # Four panels per row; always at least one row so the figure can be created
    nrows = max((num_samples + 3) // 4, 1)
    fig, axes = plt.subplots(nrows, 4, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    colors_branch = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A']

    # Plot each sample
    for idx, data in enumerate(sample_data[:n_plotted]):
        ax = axes[idx]

        branch_names = list(data['branches'].keys())
        branch_preds = list(data['branches'].values())
        branch_weights = list(data['weights'].values())

        # Bars show the per-branch predictions
        x = np.arange(len(branch_names))
        bars = ax.bar(x, branch_preds, color=colors_branch, alpha=0.75,
                      edgecolor='black', linewidth=2.5)

        # Ground truth line
        ax.axhline(y=data['gt'], color='red', linestyle='--', linewidth=3,
                   label=f"GT: {data['gt']:.2f}m", zorder=10)

        # Ensemble prediction line
        ax.axhline(y=data['ensemble'], color='green', linestyle=':', linewidth=2.5,
                   label=f"Ensemble: {data['ensemble']:.2f}m", zorder=9)

        # Annotate each bar with prediction, weight and absolute error
        for bar, pred, weight in zip(bars, branch_preds, branch_weights):
            height = bar.get_height()
            error = abs(pred - data['gt'])

            # Main label: prediction value
            ax.text(bar.get_x() + bar.get_width()/2, height/2,
                    f"{pred:.2f}m",
                    ha='center', va='center', fontsize=11, fontweight='bold',
                    color='white', bbox=dict(boxstyle='round,pad=0.3', facecolor='black', alpha=0.7))

            # Top label: weight and error
            ax.text(bar.get_x() + bar.get_width()/2, height + 1.5,
                    f"w:{weight:.2f}\nerr:{error:.2f}m",
                    ha='center', va='bottom', fontsize=9, fontweight='bold')

        # Styling; fall back to a unit-high axis when nothing is positive so
        # the y-range never collapses or flips
        top = max(max(branch_preds), data['gt'], data['ensemble'])
        ax.set_ylim(0, top * 1.25 if top > 0 else 1.0)
        ax.set_ylabel('Distance (m)', fontsize=11, fontweight='bold')
        ax.set_title(f'Sample {idx+1}', fontsize=12, fontweight='bold', pad=8)
        ax.set_xticks(x)
        ax.set_xticklabels(branch_names, fontsize=9, rotation=0)
        ax.legend(fontsize=9, loc='upper right')
        ax.grid(axis='y', alpha=0.3, linestyle=':')

    # Hide every panel that was not drawn (sample_data may be longer than
    # num_samples, so count from the number of plotted panels instead)
    for unused_ax in axes[n_plotted:]:
        unused_ax.axis('off')

    fig.suptitle('Weather-Aware Model: Branch Predictions vs Ground Truth',
                 fontsize=15, fontweight='bold', y=0.995)
    fig.tight_layout()
    fig.savefig('/kaggle/working/weather_aware_branch_predictions.png', dpi=150, bbox_inches='tight')
    return fig, sample_data


# Build the branch-vs-ground-truth figure from test_loader and show it
print("\n" + "="*80)
print("üå¶Ô∏è  WEATHER-AWARE MODEL: BRANCH PREDICTION ANALYSIS")
print("="*80)
fig_weather, weather_data = visualize_weather_trained_branch_performance(
    model, test_loader, num_samples=8
)
plt.show()

# Per-sample report: absolute error of the ensemble and of every branch
print("\nüìä DETAILED PERFORMANCE BREAKDOWN:")
print("="*80)

branch_errors = {name: [] for name in ('LIDAR', 'CAMERA', 'RADAR', 'FUSED', 'ENSEMBLE')}

for idx, data in enumerate(weather_data[:8]):
    print(f"\nüîç SAMPLE {idx+1}:")
    print(f"   Ground Truth: {data['gt']:.2f}m")

    ensemble_error = abs(data['ensemble'] - data['gt'])
    branch_errors['ENSEMBLE'].append(ensemble_error)
    print(f"   Ensemble:    {data['ensemble']:.2f}m (Error: {ensemble_error:.3f}m)")

    print("\n   Branch Performance:")

    for branch in ('LIDAR', 'CAMERA', 'RADAR', 'FUSED'):
        pred = data['branches'][branch]
        weight = data['weights'][branch]
        error = abs(pred - data['gt'])

        # Relative error is reported as 0 when the ground truth is not positive
        if data['gt'] > 0:
            error_pct = error / data['gt'] * 100
        else:
            error_pct = 0

        branch_errors[branch].append(error)

        # Compare to ensemble; errors up to 10% above the ensemble's count as similar
        indicator = (
            "‚úÖ BETTER than ensemble" if error < ensemble_error
            else "‚ùå WORSE than ensemble" if error > ensemble_error * 1.1
            else "‚ûñ SIMILAR to ensemble"
        )

        print(f"      {branch:8} ‚Üí {pred:.2f}m  (w:{weight:.3f}, err:{error:.3f}m ¬±{error_pct:.1f}%) {indicator}")

print("\n" + "="*80)
print("üìà AGGREGATE STATISTICS (Across All Samples):")
print("="*80)

# Mean, spread and RMSE of the absolute errors; the ensemble row is set apart
for branch, errors in branch_errors.items():
    mean_err = np.mean(errors)
    std_err = np.std(errors)
    rmse = np.sqrt(np.mean(np.square(np.array(errors))))

    print(
        ("\nüéØ " if branch == 'ENSEMBLE' else "   ")
        + f"{branch:10} ‚Üí Mean Error: {mean_err:.3f}m ¬± {std_err:.3f}m | RMSE: {rmse:.3f}m"
    )

print("\n" + "="*80)
print("‚úÖ WEATHER-AWARE MODEL EVALUATION COMPLETE")
print("="*80)