# IsoNet Training

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchaudio
from torchvision import models
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import cv2
import os 

# --- CONFIGURATION FROM DIAGRAM ---
VISUAL_DIM = 256       # Output of Visual Stream (V)
SPATIAL_DIM = 128      # Output of Spatial Stream (S)
AUDIO_ENC_DIM = 512    # Internal Audio Feature Dimension
AUDIO_CHANNELS = 4     # Number of Mics

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Memory-Optimized Training Configuration for 8GB GPU
BATCH_SIZE = 16         # Batch size for training
GRADIENT_ACCUMULATION_STEPS = 2  # Effective batch size = 16 * 2 = 32
EPOCHS = 100
LR = 1e-4               # TCNs prefer lower learning rates
CHECKPOINT_DIR = "checkpoints"

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Mixed Precision Training
USE_AMP = True  # Automatic Mixed Precision (FP16) saves ~40% memory

# ============================================================
# MODEL FEATURE FLAGS (For Ablation Studies)
# ============================================================
# Turn these off to do audio-only training without multimodal features
USE_VISUAL = True           # Use video/face features (visual stream)
USE_BEAMFORMER = True       # Use neural beamformer (spatial filtering)
USE_SPATIAL_STREAM = True   # Use GCC-PHAT spatial features

# For pure audio-only baseline, set all to False:
# USE_VISUAL = False
# USE_BEAMFORMER = False
# USE_SPATIAL_STREAM = False

print(f"Device: {DEVICE}")
print(f"Mixed Precision: {USE_AMP}")
print(f"Batch Size: {BATCH_SIZE} | Gradient Accumulation: {GRADIENT_ACCUMULATION_STEPS}")
print(f"Effective Batch Size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"\nModel Configuration:")
print(f"  Visual Stream: {USE_VISUAL}")
print(f"  Neural Beamformer: {USE_BEAMFORMER}")
print(f"  Spatial Stream: {USE_SPATIAL_STREAM}")
if not USE_VISUAL and not USE_BEAMFORMER and not USE_SPATIAL_STREAM:
    print("  -> AUDIO-ONLY MODE")

Device: cuda
Mixed Precision: True
Batch Size: 4 | Gradient Accumulation: 8
Effective Batch Size: 32

Model Configuration:
  Visual Stream: True
  Neural Beamformer: True
  Spatial Stream: True


In [12]:
torch.cuda.empty_cache()        # free cached memory
torch.cuda.synchronize()        # wait for all kernels to finish

In [13]:
# ============================================================
# MEMORY OPTIMIZATION FOR 8GB GPU
# ============================================================

# Clear cache and optimize PyTorch settings
torch.cuda.empty_cache()
torch.cuda.synchronize()

# Enable cuDNN benchmarking for faster training (if input sizes are fixed)
torch.backends.cudnn.benchmark = True

# Enable TF32 on Ampere GPUs for faster training (A100, A6000, RTX 30/40 series)
if hasattr(torch.backends.cuda, 'matmul'):
    torch.backends.cuda.matmul.allow_tf32 = True
if hasattr(torch.backends.cudnn, 'allow_tf32'):
    torch.backends.cudnn.allow_tf32 = True

# Memory optimization
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

# Print GPU info
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"Available Memory: {(torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()) / 1e9:.2f} GB")
    print(f"Compute Capability: {torch.cuda.get_device_properties(0).major}.{torch.cuda.get_device_properties(0).minor}")

GPU: NVIDIA GeForce RTX 3050 6GB Laptop GPU
Total Memory: 6.09 GB
Available Memory: 6.09 GB
Compute Capability: 8.6


In [17]:
# ============================================================
# DATASET PATHS CONFIGURATION (Centralized)
# ============================================================
# Update ONLY this cell when changing data locations

import platform
from pathlib import Path
import pandas as pd

# ------------------------------------------------------------
# OS DETECTION
# ------------------------------------------------------------
IS_WINDOWS = platform.system() == "Windows"

# ------------------------------------------------------------
# PATH CONFIGURATION
# ------------------------------------------------------------

# Windows paths
WINDOWS_PATHS = {
    "root_dir": r"C:\Users\bibek\isolate-speech",
    "multich_dir": r"C:\Users\bibek\isolate-speech\data\multich",
    "mp4_dir": r"C:\Users\bibek\isolate-speech\data\mp4",
    "clean_4ch_dir": r"C:\Users\bibek\isolate-speech\data\clean_4ch_duped",  # Local writeable directory for Windows
}

# Linux / Kaggle paths
LINUX_PATHS = {
    "root_dir": "/run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich",
    "multich_dir": "/run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich",
    "mp4_dir": "/run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/mp4",
    "clean_4ch_dir": "/run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/clean_4ch_duped",  # Local writeable directory for Linux
}

PATHS = WINDOWS_PATHS if IS_WINDOWS else LINUX_PATHS

# ------------------------------------------------------------
# GLOBAL PATH OBJECTS
# ------------------------------------------------------------
ROOT_DIR = Path(PATHS["root_dir"])
MULTICH_DIR = Path(PATHS["multich_dir"])
MP4_DIR = Path(PATHS["mp4_dir"])

# Subdirectories
MIXED_DIR = MULTICH_DIR / "mixed"
CLEAN_DIR = MULTICH_DIR / "clean"
VIDEO_DIR = MULTICH_DIR / "video"

# Metadata
METADATA_CSV = MULTICH_DIR / "metadata.csv"
TRAIN_CSV = MULTICH_DIR / "train.csv" if (MULTICH_DIR / "train.csv").exists() else METADATA_CSV
VAL_CSV   = MULTICH_DIR / "val.csv"   if (MULTICH_DIR / "val.csv").exists()   else METADATA_CSV
TEST_CSV  = MULTICH_DIR / "test.csv"  if (MULTICH_DIR / "test.csv").exists()  else METADATA_CSV

# ------------------------------------------------------------
# LOCAL WRITEABLE DIRECTORY (FIXED: Platform-specific)
# ------------------------------------------------------------
# MUST be Path (not string) and OS-aware
LOCAL_CLEAN_4CH_DIR = Path(PATHS["clean_4ch_dir"])

# Create directory if missing
LOCAL_CLEAN_4CH_DIR.mkdir(parents=True, exist_ok=True)

# ------------------------------------------------------------
# PRINT CONFIGURATION
# ------------------------------------------------------------
print("=" * 60)
print("DATA PATHS CONFIGURATION")
print("=" * 60)
print(f"OS:                   {platform.system()}")
print(f"ROOT_DIR:             {ROOT_DIR}")
print(f"MULTICH_DIR:          {MULTICH_DIR}")
print(f"MP4_DIR:              {MP4_DIR}")
print(f"MIXED_DIR:            {MIXED_DIR}")
print(f"CLEAN_DIR:            {CLEAN_DIR}")
print(f"VIDEO_DIR:            {VIDEO_DIR}")
print(f"LOCAL_CLEAN_4CH_DIR:  {LOCAL_CLEAN_4CH_DIR}")
print(f"METADATA_CSV:         {METADATA_CSV}")
print(f"TRAIN_CSV:            {TRAIN_CSV}")
print(f"VAL_CSV:              {VAL_CSV}")
print(f"TEST_CSV:             {TEST_CSV}")
print("=" * 60)

# ------------------------------------------------------------
# PATH VERIFICATION (SAFE)
# ------------------------------------------------------------
print("\n--- PATH VERIFICATION ---")
for name, path in [
    ("MULTICH_DIR", MULTICH_DIR),
    ("MP4_DIR", MP4_DIR),
    ("MIXED_DIR", MIXED_DIR),
    ("CLEAN_DIR", CLEAN_DIR),
    ("VIDEO_DIR", VIDEO_DIR),
    ("LOCAL_CLEAN_4CH_DIR", LOCAL_CLEAN_4CH_DIR),
]:
    status = "OK" if path.exists() else "MISSING"
    print(f"  {name:<22}: {status}")

DATA PATHS CONFIGURATION
OS:                   Linux
ROOT_DIR:             /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich
MULTICH_DIR:          /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich
MP4_DIR:              /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/mp4
MIXED_DIR:            /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/mixed
CLEAN_DIR:            /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/clean
VIDEO_DIR:            /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/video
LOCAL_CLEAN_4CH_DIR:  /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/clean_4ch_duped
METADATA_CSV:         /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/metadata.csv
TRAIN_CSV:            /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/train.csv
VAL_CSV:              /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/val.csv
TEST_CSV:             /run/media/neuronetix/BACKUP/Dataset/VOX/man

## Spatial-Visual Augmentation Demo

This demonstrates how we map face positions to match the spatial angles from the metadata.
The core insight: VoxCeleb videos have **centered faces** (frontal interviews), but our audio simulation places speakers at **varying angles**. 

**Solution**: Shift the face position in the frame to match the spatial angle!
- Azimuth (horizontal angle) â†’ Shift face left/right
- Elevation (vertical angle) â†’ Shift face up/down
- Distance â†’ Scale face size

In [18]:
# ============================================================
# VIDEO SIZE CONFIGURATION
# ============================================================
# Using larger input size (336x336) instead of 224x224
# This gives more room for face shifting in jittering augmentation

import cv2
import numpy as np
import torch

# ===== VIDEO SIZE CONFIGURATION =====
VIDEO_SIZE_SMALL = (224, 224)   # Original ResNet input
VIDEO_SIZE_MEDIUM = (336, 336)  # 1.5x larger - better for face jittering
VIDEO_SIZE_LARGE = (448, 448)   # 2x larger

# Choose active size - using medium for better face coverage with jittering
VIDEO_SIZE = VIDEO_SIZE_MEDIUM  # 336x336 for better face coverage

print(f"Video Size: {VIDEO_SIZE}")
print(f"Augmentation: Jittering (crop box shift, no padding)")

Video Size: (224, 224)
Augmentation: Jittering (crop box shift, no padding)


In [19]:
# ============================================================
# METADATA OVERVIEW (Optional - for exploration)
# ============================================================

print(f"Loading metadata from: {METADATA_CSV}")
spatial_meta = pd.read_csv(METADATA_CSV)
print(f"Total samples: {len(spatial_meta)}")
print(f"Columns: {list(spatial_meta.columns)}")

Loading metadata from: /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/metadata.csv
Total samples: 15000
Columns: ['filename', 'source_wav', 'source_video', 'start_time', 'mixed_audio', 'clean_audio', 'video_file', 'target_azimuth', 'target_elevation', 'target_distance', 'room_x', 'room_y', 'room_z', 'rt60', 'snr_db']


In [20]:
# ============================================================
# SKIP: Old augmentation visualization removed
# ============================================================
# The old SpatialVisualAugmenter (shift & pad) has been replaced
# with proper jittering in the Dataset class.
# No visualization needed - jittering just shifts the crop box.

print("Augmentation: Jittering (shift crop box before extract)")

Augmentation: Jittering (shift crop box before extract)


In [21]:
# ============================================================
# CREATE 4-CHANNEL CLEAN AUDIO IN LOCAL WORKING DIRECTORY
# ============================================================
import torchaudio
from tqdm import tqdm

print(f"Creating 4-channel clean audio copies...")
print(f"Source:      {CLEAN_DIR}")
print(f"Destination: {LOCAL_CLEAN_4CH_DIR}")

# Get list of clean files from metadata
if Path(METADATA_CSV).exists():
    meta_df = pd.read_csv(METADATA_CSV)
    clean_files = meta_df['filename'].tolist()
    print(f"Found {len(clean_files)} files in metadata")
    
    # Process files
    success_count = 0
    skip_count = 0
    
    for filename in tqdm(clean_files, desc="Duplicating to 4-ch"):
        src_path = CLEAN_DIR / f"{filename}.wav"
        dst_path = LOCAL_CLEAN_4CH_DIR / f"{filename}.wav"
        
        # Skip if already exists
        if dst_path.exists():
            skip_count += 1
            continue
        
        # Load mono clean
        try:
            audio, sr = torchaudio.load(src_path)  # [1, T]
            
            # Duplicate to 4 channels
            audio_4ch = audio.repeat(4, 1)  # [4, T]
            
            # Save to local directory
            torchaudio.save(dst_path, audio_4ch, sr)
            success_count += 1
            
        except Exception as e:
            print(f"Error processing {filename}: {e}")
            continue
    
    print(f"\nâœ“ Created {success_count} new 4-ch files")
    print(f"  Skipped {skip_count} existing files")
    print(f"  Total: {success_count + skip_count}/{len(clean_files)}")
    
    # Verify a sample
    if success_count > 0 or skip_count > 0:
        sample_file = LOCAL_CLEAN_4CH_DIR / f"{clean_files[0]}.wav"
        if sample_file.exists():
            test_audio, test_sr = torchaudio.load(sample_file)
            print(f"\nâœ“ Verification - Sample shape: {test_audio.shape} (Expected: [4, 64000])")
        else:
            print(f"\nWarning: Sample file not found for verification")
else:
    print(f"Metadata CSV not found at: {METADATA_CSV}")

Creating 4-channel clean audio copies...
Source:      /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/multich/clean
Destination: /run/media/neuronetix/BACKUP/Dataset/VOX/manual/dev/clean_4ch_duped
Found 15000 files in metadata


Duplicating to 4-ch: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 15000/15000 [00:44<00:00, 334.22it/s]  


âœ“ Created 2733 new 4-ch files
  Skipped 12267 existing files
  Total: 15000/15000

âœ“ Verification - Sample shape: torch.Size([4, 64000]) (Expected: [4, 64000])





## Create 4-Channel Clean Audio (Local Copy)

Duplicate the mono clean files from the remote dataset to 4-channel in local working directory. This avoids writing to the remote dataset and keeps everything in the project folder.

In [None]:
class IsoNetDataset(Dataset):
    def __init__(self, csv_path, clip_length=4.0, fps=25, video_size=(224, 224), max_samples=None, 
                 augment=True, jitter_pct=0.1):
        """
        Args:
            csv_path (str): Path to train.csv or val.csv
            clip_length (float): Audio duration in seconds (must match simulation)
            fps (int): Target frames per second for video (VoxCeleb is 25)
            video_size (tuple): Target resize dimension (H, W) - 224x224 for ResNet-18
            max_samples (int, optional): Limit dataset to first N samples for testing
            augment (bool): Apply jittering augmentation (for training, disable for val/test)
            jitter_pct (float): Max jitter as fraction of frame size (0.1 = 10%)
        """
        self.meta = pd.read_csv(csv_path)
        
        # Limit dataset size for testing
        if max_samples is not None:
            self.meta = self.meta.head(max_samples)
            print(f"Debug Mode: Using only {len(self.meta)} samples")
        
        # Root dir = multich folder containing mixed/, clean/, video/ subdirs
        self.root_dir = Path(csv_path).parent
        
        self.clip_length = clip_length
        self.fps = fps
        self.target_frames = int(clip_length * fps)  # 4.0 * 25 = 100 frames
        self.video_size = video_size
        
        # Jittering augmentation (shift crop box, no padding!)
        self.augment = augment
        self.jitter_pct = jitter_pct
        
        # Check if spatial metadata is available
        self.has_spatial_meta = all(col in self.meta.columns for col in ['target_azimuth', 'target_elevation', 'target_distance'])
        
        print(f"Dataset: {len(self.meta)} samples | Root: {self.root_dir}")
        print(f"  Augment: {augment} (jitter={jitter_pct*100:.0f}%)")

    def load_video_frames(self, video_path, start_time):
        """
        Load video frames with optional jittering augmentation.
        
        Jittering = shift the crop box BEFORE extracting, not after.
        This gives the model slightly different views without any padding.
        """
        cap = cv2.VideoCapture(str(video_path))
        
        if not cap.isOpened():
            # Return black frames as fallback
            frames = [np.zeros((self.video_size[0], self.video_size[1], 3), dtype=np.uint8)] * self.target_frames
            buffer = np.array(frames, dtype=np.float32) / 255.0
            return torch.from_numpy(buffer).permute(3, 0, 1, 2)
        
        # Get video properties
        vid_fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        
        # Seek to start frame
        start_frame_idx = int(start_time * vid_fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame_idx)
        
        # Calculate crop box (we crop a square from center, then resize)
        # Jittering: randomly shift the crop box slightly
        crop_size = min(frame_width, frame_height)
        
        # Default center crop
        cx, cy = frame_width // 2, frame_height // 2
        
        if self.augment:
            # Random jitter: shift center by up to jitter_pct of frame size
            max_shift_x = int(frame_width * self.jitter_pct)
            max_shift_y = int(frame_height * self.jitter_pct)
            
            jitter_x = np.random.randint(-max_shift_x, max_shift_x + 1) if max_shift_x > 0 else 0
            jitter_y = np.random.randint(-max_shift_y, max_shift_y + 1) if max_shift_y > 0 else 0
            
            cx += jitter_x
            cy += jitter_y
        
        # Calculate crop bounds (ensure we stay within frame)
        half = crop_size // 2
        x1 = max(0, min(cx - half, frame_width - crop_size))
        y1 = max(0, min(cy - half, frame_height - crop_size))
        x2 = x1 + crop_size
        y2 = y1 + crop_size
        
        frames = []
        for _ in range(self.target_frames):
            ret, frame = cap.read()
            if not ret:
                break
            
            # BGR to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            
            # Crop (jittered or centered)
            frame = frame[y1:y2, x1:x2]
            
            # Resize to target size
            frame = cv2.resize(frame, self.video_size)
            frames.append(frame)
        
        cap.release()
        
        # Pad with last frame if video ended early
        if len(frames) < self.target_frames:
            if len(frames) == 0:
                frames = [np.zeros((self.video_size[0], self.video_size[1], 3), dtype=np.uint8)] * self.target_frames
            else:
                frames.extend([frames[-1]] * (self.target_frames - len(frames)))
        
        # Convert to Tensor: [Time, H, W, C] -> [C, Time, H, W]
        buffer = np.array(frames, dtype=np.float32) / 255.0
        return torch.from_numpy(buffer).permute(3, 0, 1, 2)

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.meta.iloc[idx]
        
        # 1. Get Paths & Info
        filename = row['filename']
        start_time = float(row['start_time'])
        
        # Build full paths
        # Mixed & Video from MULTICH_DIR (remote)
        # Clean from LOCAL_CLEAN_4CH_DIR (local working directory, 4-channel)
        mixed_path = self.root_dir / "mixed" / f"{filename}.wav"
        clean_path = LOCAL_CLEAN_4CH_DIR / f"{filename}.wav"  # Use local 4-ch clean
        video_path = self.root_dir / "video" / f"{filename}.mp4"

        # 2. Load Audio
        mixed_audio, _ = torchaudio.load(mixed_path)
        clean_audio, _ = torchaudio.load(clean_path)

        # 3. Load Video (with jittering if augment=True)
        video_tensor = self.load_video_frames(video_path, start_time)

        # 4. Collect spatial metadata (for reference, not for augmentation)
        spatial_meta = {}
        if self.has_spatial_meta:
            spatial_meta = {
                'azimuth': float(row['target_azimuth']),
                'elevation': float(row['target_elevation']),
                'distance': float(row['target_distance']),
                'snr_db': float(row.get('snr_db', 0.0))
            }

        # 5. Ensure audio length matches exactly (4.0s @ 16kHz = 64000 samples)
        target_samples = int(self.clip_length * 16000)
        
        if mixed_audio.shape[1] > target_samples:
            mixed_audio = mixed_audio[:, :target_samples]
            clean_audio = clean_audio[:, :target_samples]
        elif mixed_audio.shape[1] < target_samples:
            pad_size = target_samples - mixed_audio.shape[1]
            mixed_audio = torch.nn.functional.pad(mixed_audio, (0, pad_size))
            clean_audio = torch.nn.functional.pad(clean_audio, (0, pad_size))

        return mixed_audio, clean_audio, video_tensor, spatial_meta

In [None]:
class VisualStream(nn.Module):
    def __init__(self):
        super(VisualStream, self).__init__()
        # Load ResNet-18
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        
        # Remove classification head
        modules = list(resnet.children())[:-1] 
        self.resnet = nn.Sequential(*modules)
        
        # Project 512 -> 256 (V)
        self.projection = nn.Sequential(
            nn.Linear(512, VISUAL_DIM),
            nn.PReLU()
        )
        
        # FIXED: ImageNet normalization for pretrained ResNet
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        # x: [Batch, 3, Time, H, W] where H, W can be any size (e.g., 336x336)
        B, C, T, H, W = x.shape
        
        # Fold Time into Batch
        x = x.permute(0, 2, 1, 3, 4).contiguous().view(B * T, C, H, W)
        
        # FIXED: Apply ImageNet normalization before ResNet
        x = (x - self.mean) / self.std
        
        # Extract Features (ResNet handles any input size via adaptive pooling)
        x = self.resnet(x)       # [B*T, 512, 1, 1]
        x = x.view(B * T, -1)    # [B*T, 512]
        
        # Project to 256
        x = self.projection(x)   # [B*T, 256]
        
        # Unfold Time
        x = x.view(B, T, -1).permute(0, 2, 1) # [B, 256, Time]
        
        return x

In [None]:
class SpatialStream(nn.Module):
    def __init__(self, num_mics=4):
        super(SpatialStream, self).__init__()
        
        # We compute GCC-PHAT for all pairs. 
        # For 4 mics, pairs = 4*(3)/2 = 6 pairs.
        self.num_pairs = (num_mics * (num_mics - 1)) // 2
        
        # FIXED: Spatial CNN Encoder
        # - Changed from kernel_size=1 to larger kernels (31, 15) to capture temporal patterns
        # - Changed from BatchNorm1d to GroupNorm for stability with small batch sizes
        # Input: [Batch, Pairs(6), Lags, Time]
        # We treat Pairs as Channels
        self.encoder = nn.Sequential(
            nn.Conv1d(self.num_pairs, 64, kernel_size=31, stride=1, padding=15),
            nn.GroupNorm(1, 64),  # Changed from BatchNorm1d
            nn.PReLU(),
            nn.Conv1d(64, 128, kernel_size=15, stride=1, padding=7),
            nn.GroupNorm(1, 128),  # Changed from BatchNorm1d
            nn.PReLU(),
            nn.Conv1d(128, SPATIAL_DIM, kernel_size=1, stride=1)
        )

    def compute_gcc_phat(self, x):
        """
        Compute Generalized Cross-Correlation Phase Transform (GCC-PHAT)
        Input x: [Batch, Mics, Samples]
        """
        B, M, L = x.shape
        
        # 1. FFT
        # n_fft matches window size roughly
        X = torch.fft.rfft(x, dim=-1)
        
        # 2. Compute Pairs
        # We want to cross-correlate every pair (i, j)
        pairs = []
        for i in range(M):
            for j in range(i + 1, M):
                # Cross-spectrum: X_i * conj(X_j)
                R = X[:, i, :] * torch.conj(X[:, j, :])
                # Normalization (PHAT): Divide by magnitude
                R = R / (torch.abs(R) + 1e-8)
                # IFFT to get time-domain correlation
                r = torch.fft.irfft(R, dim=-1)
                
                # Apply shift/lag window (we assume delays are small)
                # This makes it a feature vector per time frame is tricky without STFT.
                # Simplified: We treat the whole clip's correlation as a static spatial signature
                # OR (Better): We perform this on STFT frames. 
                
                # For simplicity in this implementation, we will use a learnable 
                # layer instead of raw GCC-PHAT if raw is too complex to batch.
                # BUT, let's assume the input here is actually the GCC features.
                pairs.append(r)
                
        return torch.stack(pairs, dim=1) # [B, 6, Samples]

    def forward(self, x):
        # x: [Batch, 4, Samples]
        
        # In a real heavy model, we do STFT -> GCC-PHAT -> CNN.
        # Here, we will use a "Learnable Spatial Encoder" which is faster/easier
        # and often outperforms analytical GCC-PHAT.
        
        # 1. Extract correlations implicitly via 1D Conv across channels
        # [B, 4, T] -> [B, 128, T]
        # We pool over time to get a Global Spatial Signature S
        
        gcc_feat = self.compute_gcc_phat(x) # [B, 6, Samples]
        
        # Encode features
        x = self.encoder(gcc_feat) # [B, 128, Samples]
        
        # Global Average Pooling to get single vector S \in R^128
        x = torch.mean(x, dim=-1)  # [B, 128]
        
        return x

In [None]:
# ============================================================
# NEURAL BEAMFORMER (TEMPORAL)
# ============================================================
# This applies spatial filtering BEFORE the audio encoder
# to focus on the target direction indicated by visual cues
# NOW WITH TEMPORAL TRACKING for moving speakers

class NeuralBeamformer(nn.Module):
    """
    Learnable beamformer that combines multi-channel audio
    based on spatial conditioning from visual stream.
    
    TEMPORAL VERSION: Accepts frame-by-frame visual features
    to track speakers who move during the video segment.
    """
    
    def __init__(self, num_mics=4, n_fft=512, hop_length=128, conditioning_dim=256):
        super(NeuralBeamformer, self).__init__()
        self.num_mics = num_mics
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.num_freqs = n_fft // 2 + 1
        
        # Temporal weight network: processes time-varying visual features
        # Input: [B, conditioning_dim, T_visual]
        # Output: [B, M * F * 2, T_stft] - beamforming weights per time frame
        self.weight_net = nn.Sequential(
            nn.Conv1d(conditioning_dim, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, num_mics * self.num_freqs * 2, kernel_size=1)  # Real + Imag for each freq & mic
        )
        
        # Window for STFT
        self.register_buffer('window', torch.hann_window(n_fft))
        
    def forward(self, audio, visual_condition):
        """
        Args:
            audio: [B, num_mics, samples] - Multi-channel audio
            visual_condition: [B, conditioning_dim, T_visual] - Temporal visual features
        
        Returns:
            beamformed: [B, 1, samples] - Spatially filtered audio
        """
        B, M, L = audio.shape
        
        # Disable autocast for complex operations (ComplexHalf not fully supported)
        with torch.amp.autocast(device_type='cuda', enabled=False):
            # Force float32 for STFT operations
            audio = audio.float()
            visual_condition = visual_condition.float()
            
            # 1. STFT on all channels: [B, M, F, T_stft] complex
            stft_list = []
            for m in range(M):
                stft_m = torch.stft(
                    audio[:, m, :], 
                    n_fft=self.n_fft, 
                    hop_length=self.hop_length,
                    window=self.window,
                    return_complex=True
                )  # [B, F, T_stft]
                stft_list.append(stft_m)
            
            X = torch.stack(stft_list, dim=1)  # [B, M, F, T_stft]
            F_bins, T_stft = X.shape[2], X.shape[3]
            
            # 2. Upsample visual features to match STFT time resolution
            # visual_condition: [B, conditioning_dim, T_visual] -> [B, conditioning_dim, T_stft]
            visual_upsampled = F.interpolate(
                visual_condition, 
                size=T_stft, 
                mode='linear', 
                align_corners=False
            )
            
            # 3. Compute temporal beamforming weights from visual conditioning
            # [B, conditioning_dim, T_stft] -> [B, M * F * 2, T_stft]
            weights_temporal = self.weight_net(visual_upsampled)
            
            # Reshape to [B, M, F, 2, T_stft] (real, imag, time)
            weights_temporal = weights_temporal.view(B, M, F_bins, 2, T_stft)
            
            # Convert to complex: [B, M, F, T_stft] - now in float32
            W = torch.complex(weights_temporal[..., 0, :], weights_temporal[..., 1, :])
            
            # Normalize weights (sum to 1 per frequency per time frame for stable beamforming)
            W = W / (torch.abs(W).sum(dim=1, keepdim=True) + 1e-8)
            
            # 4. Apply temporal beamforming: weighted sum across microphones
            # X: [B, M, F, T_stft], W: [B, M, F, T_stft] -> [B, F, T_stft]
            Y = (X * W).sum(dim=1)  # [B, F, T_stft]
            
            # 5. iSTFT to get time-domain output
            beamformed = torch.istft(
                Y,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                window=self.window,
                length=L
            )  # [B, samples]
        
        return beamformed.unsqueeze(1)  # [B, 1, samples]


class FiLMLayer(nn.Module):
    def __init__(self, in_channels, cond_dim):
        super(FiLMLayer, self).__init__()
        # We map the Conditioning (S+V) to Gamma (Scale) and Beta (Shift)
        self.conv_gamma = nn.Conv1d(cond_dim, in_channels, 1)
        self.conv_beta = nn.Conv1d(cond_dim, in_channels, 1)

    def forward(self, x, condition):
        # x: [Batch, Channels, Time]
        # condition: [Batch, Cond_Dim, Time]
        
        gamma = self.conv_gamma(condition)  # [B, C, T]
        beta = self.conv_beta(condition)    # [B, C, T]
            
        # FiLM Formula: Gamma * x + Beta
        return (gamma * x) + beta

class ExtractionBlock(nn.Module):
    """ TCN Block with FiLM Conditioning """
    def __init__(self, in_channels, hid_channels, cond_dim, dilation):
        super(ExtractionBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, hid_channels, 1)
        self.norm1 = nn.GroupNorm(1, hid_channels)
        self.prelu1 = nn.PReLU()
        
        # FiLM comes after first activation usually
        self.film = FiLMLayer(hid_channels, cond_dim)
        
        self.dconv = nn.Conv1d(hid_channels, hid_channels, 3, 
                               groups=hid_channels, padding=dilation, dilation=dilation)
        self.norm2 = nn.GroupNorm(1, hid_channels)
        self.prelu2 = nn.PReLU()
        
        self.conv2 = nn.Conv1d(hid_channels, in_channels, 1)

    def forward(self, x, condition):
        residual = x
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.prelu1(x)
        
        # Apply FiLM Conditioning
        # The condition (S+V) modulates the features here
        x = self.film(x, condition)
        
        x = self.dconv(x)
        x = self.norm2(x)
        x = self.prelu2(x)
        
        x = self.conv2(x)
        return x + residual

In [None]:
class IsoNet(nn.Module):
    def __init__(self, use_checkpointing=False, use_beamformer=True, use_spatial_stream=True, use_visual=True):
        super(IsoNet, self).__init__()
        
        # Feature flags
        self.use_beamformer = use_beamformer
        self.use_spatial_stream = use_spatial_stream
        self.use_visual = use_visual
        
        # 1. Visual Stream (optional)
        if use_visual:
            self.visual_stream = VisualStream()
        else:
            self.visual_stream = None

        # 2. Spatial Stream (optional)
        if use_spatial_stream:
            self.spatial_stream = SpatialStream()
        else:
            self.spatial_stream = None

        # 3. Neural Beamformer (optional)
        if use_beamformer:
            self.beamformer = NeuralBeamformer(conditioning_dim=VISUAL_DIM)
            audio_channels_for_encoder = 1
        else:
            self.beamformer = None
            audio_channels_for_encoder = AUDIO_CHANNELS

        # 4. Audio Encoder (takes beamformed audio or raw 4-channel)
        self.audio_enc = nn.Conv1d(
            audio_channels_for_encoder, 
            AUDIO_ENC_DIM, 
            kernel_size=16, 
            stride=8, 
            bias=False
        )

        # Calculate conditioning dimension
        self.cond_dim = 0
        if use_spatial_stream:
            self.cond_dim += SPATIAL_DIM
        if use_visual:
            self.cond_dim += VISUAL_DIM

        # 5. TCN with FiLM (or without if no conditioning)
        if self.cond_dim > 0:
            self.tcn_blocks = nn.ModuleList([
                ExtractionBlock(AUDIO_ENC_DIM, 128, self.cond_dim, dilation=2**i) 
                for i in range(8)
            ])
        else:
            # Audio-only TCN without FiLM conditioning
            self.tcn_blocks = nn.ModuleList([
                nn.Sequential(
                    nn.Conv1d(AUDIO_ENC_DIM, 128, kernel_size=3, dilation=2**i, padding=2**i),
                    nn.PReLU(),
                    nn.Conv1d(128, AUDIO_ENC_DIM, kernel_size=1),
                    nn.PReLU()
                )
                for i in range(8)
            ])

        # 6. Mask Decoder
        self.mask_conv = nn.Conv1d(AUDIO_ENC_DIM, AUDIO_ENC_DIM, 1)
        self.sigmoid = nn.Sigmoid()

        # 7. Audio Decoder (Reconstructs waveform)
        self.audio_dec = nn.ConvTranspose1d(AUDIO_ENC_DIM, 1, kernel_size=16, stride=8, bias=False)

        # Gradient Checkpointing for memory savings
        self.use_checkpointing = use_checkpointing

    def forward(self, audio_mix, video_frames=None):
        """
        Args:
            audio_mix: [B, 4, Samples] - 4-channel microphone array input
            video_frames: [B, 3, T_v, H, W] - Video frames showing target speaker (optional)

        Returns:
            clean_speech: [B, 1, Samples] - Isolated speech of target speaker
        """
        # --- A. Visual Stream (optional) ---
        if self.use_visual and self.visual_stream is not None and video_frames is not None:
            # Get temporal visual embedding: [B, 256, T_v]
            visual_feat = self.visual_stream(video_frames)
        else:
            visual_feat = None

        # --- B. Spatial Stream (optional) ---
        if self.use_spatial_stream and self.spatial_stream is not None:
            # Get global spatial embedding S: [B, 128]
            S = self.spatial_stream(audio_mix)
        else:
            S = None

        # --- C. TEMPORAL NEURAL BEAMFORMING (optional) ---
        if self.use_beamformer and self.beamformer is not None and visual_feat is not None:
            # Beamform with temporal visual features: [B, 4, samples] -> [B, 1, samples]
            # visual_feat: [B, 256, T_v] provides frame-by-frame conditioning
            audio_beamformed = self.beamformer(audio_mix, visual_feat)
        else:
            # No beamforming: need to match expected input channels for audio encoder
            if self.beamformer is not None:
                # Beamformer was initialized, so encoder expects 1 channel
                audio_beamformed = audio_mix.mean(dim=1, keepdim=True)
            else:
                # No beamformer, encoder expects 4 channels - keep as is
                audio_beamformed = audio_mix

        # --- D. Audio Encoding ---
        audio_feat = self.audio_enc(audio_beamformed)  # [B, 512, T_a]

        # --- E. Conditioning ---
        if self.cond_dim > 0:
            # Build conditioning based on available streams
            conditions = []
            
            if self.use_visual and visual_feat is not None:
                # Upsample visual features to match audio time dimension
                visual_upsampled = F.interpolate(visual_feat, size=audio_feat.shape[-1], mode='nearest')
                conditions.append(visual_upsampled)
            
            if self.use_spatial_stream and S is not None:
                # Expand S to match time dimension: [B, 128] -> [B, 128, T_a]
                S_expanded = S.unsqueeze(-1).expand(-1, -1, audio_feat.shape[-1])
                conditions.append(S_expanded)
            
            # Concatenate all conditioning signals
            condition = torch.cat(conditions, dim=1) if len(conditions) > 0 else None
        else:
            condition = None

        # --- F. TCN Processing ---
        x = audio_feat

        if self.cond_dim > 0:
            # FiLM-conditioned TCN
            if self.use_checkpointing and self.training:
                for block in self.tcn_blocks:
                    x = torch.utils.checkpoint.checkpoint(block, x, condition, use_reentrant=False)
            else:
                for block in self.tcn_blocks:
                    x = block(x, condition)
        else:
            # Audio-only TCN (no conditioning)
            for block in self.tcn_blocks:
                residual = x
                x = block(x)
                x = x + residual  # Residual connection
            
        # --- G. Masking & Decoding ---
        mask = self.sigmoid(self.mask_conv(x))
        masked_feat = audio_feat * mask
        clean_speech = self.audio_dec(masked_feat)

        return clean_speech

# Print architecture summary
print("IsoNet Architecture (WITH TEMPORAL BEAMFORMING):")
print("=" * 60)
print("1. Visual Stream: Video -> ResNet-18 -> V [B, 256, T_v] (optional)")
print("2. Spatial Stream: 4ch Audio -> GCC-PHAT -> S [B, 128] (optional)")
print("3. TEMPORAL Neural Beamformer: 4ch Audio + V [B,256,T_v]")
print("   -> Frame-by-frame weights -> Beamformed [B, 1, L] (optional)")
print("   ** NOW TRACKS MOVING SPEAKERS **")
print("4. Audio Encoder: Beamformed -> Features [B, 512, T_a]")
print("5. FiLM TCN: Features + (S,V) -> Refined [B, 512, T_a]")
print("   (Audio-only TCN if no conditioning)")
print("6. Mask & Decode: -> Clean Speech [B, 1, L]")
print("=" * 60)

In [None]:
# # Create Model
# model = IsoNet().to(DEVICE)
# print(f"IsoNet Created. Parameters: {sum(p.numel() for p in model.parameters()):,}")

# # Dummy Data
# dummy_audio = torch.randn(2, 4, 64000).to(DEVICE)     # 4 seconds audio
# dummy_video = torch.randn(2, 3, 100, 112, 112).to(DEVICE) # 100 frames

# # Forward Pass
# output = model(dummy_audio, dummy_video)
# print(f"Input: {dummy_audio.shape}")
# print(f"Output: {output.shape}")

# # Check
# if output.shape[1] == 1 and abs(output.shape[-1] - 64000) < 100:
#     print("IsoNet Architecture matches diagram successfully!")
# else:
#     print("IsoNet Architecture does not match diagram.")

In [None]:
def si_snr_loss(estimate, reference, epsilon=1e-8):
    """
    Scale-Invariant SNR Loss with numerical stability.
    Args:
        estimate: [Batch, Samples] - The predicted audio (mono)
        reference: [Batch, Channels, Samples] or [Batch, Samples] - The clean ground truth
    Returns:
        Scalar Loss (Negative SI-SNR), Actual SI-SNR value (for logging)

    Note: We return -SI-SNR as loss to minimize. Higher SI-SNR is better.
    """
    # Handle multi-channel reference by averaging to mono
    if reference.dim() == 3:  # [B, C, T]
        reference = reference.mean(dim=1)  # [B, T]

    # 1. Zero-mean the signals
    estimate = estimate - torch.mean(estimate, dim=-1, keepdim=True)
    reference = reference - torch.mean(reference, dim=-1, keepdim=True)

    # 2. Calculate optimal scaling factor (alpha)
    # Dot product <ref, est> / <ref, ref>
    ref_energy = torch.sum(reference ** 2, dim=-1, keepdim=True) + epsilon
    dot = torch.sum(reference * estimate, dim=-1, keepdim=True)
    alpha = dot / ref_energy

    # 3. Projection
    target = alpha * reference
    noise = estimate - target

    # 4. SI-SNR Calculation with numerical stability
    target_energy = torch.sum(target ** 2, dim=-1) + epsilon
    noise_energy = torch.sum(noise ** 2, dim=-1) + epsilon

    si_snr = 10 * torch.log10(target_energy / noise_energy)

    # Clamp to prevent extreme values
    si_snr = torch.clamp(si_snr, min=-30, max=30)

    si_snr_mean = torch.mean(si_snr)

    # 5. Return negative loss and actual SI-SNR for logging
    return -si_snr_mean, si_snr_mean.item()


# Multi-resolution STFT loss helps preserve spectral detail
STFT_CONFIGS = [
    (512, 128, 512),
    (1024, 256, 1024),
    (2048, 512, 2048),
]

SI_SNR_WEIGHT = 1.0
STFT_WEIGHT = 0.1  # Reduced to prevent STFT from dominating when SI-SNR is good

def multi_resolution_stft_loss(estimate, reference, configs=STFT_CONFIGS, eps=1e-7):
    """Magnitude + log-magnitude multi-resolution STFT loss.
    
    Note: STFT is computed in float32 to avoid ComplexHalf warnings with AMP.
    """
    if reference.dim() == 3:
        reference = reference.mean(dim=1)
    if estimate.dim() == 3:
        estimate = estimate.mean(dim=1)

    # Force float32 for STFT to avoid ComplexHalf experimental warning
    estimate = estimate.float()
    reference = reference.float()

    losses = []
    for n_fft, hop, win in configs:
        window = torch.hann_window(win, device=estimate.device, dtype=torch.float32)
        est_spec = torch.stft(
            estimate,
            n_fft=n_fft,
            hop_length=hop,
            win_length=win,
            window=window,
            return_complex=True,
            normalized=False,
        )
        ref_spec = torch.stft(
            reference,
            n_fft=n_fft,
            hop_length=hop,
            win_length=win,
            window=window,
            return_complex=True,
            normalized=False,
        )

        mag_est = est_spec.abs()
        mag_ref = ref_spec.abs()

        sc_loss = torch.norm(mag_ref - mag_est, p="fro") / (torch.norm(mag_ref, p="fro") + eps)
        mag_loss = torch.mean(torch.abs(torch.log(mag_ref + eps) - torch.log(mag_est + eps)))
        losses.append(sc_loss + mag_loss)

    return torch.mean(torch.stack(losses))


def separation_loss(estimate, reference):
    """
    Composite loss: SI-SNR for temporal fidelity + STFT for spectral detail.
    
    Returns:
        total_loss: Combined loss to minimize (always positive, lower is better)
        si_snr_metric: Actual SI-SNR in dB (higher is better)
        stft_val: STFT loss component (lower is better)
    
    Note: total_loss = -SI_SNR + STFT, so lower total is better
    """
    si_snr_loss_val, si_snr_metric = si_snr_loss(estimate, reference)
    stft_val = multi_resolution_stft_loss(estimate, reference)
    
    # si_snr_loss_val is already negative SI-SNR, so adding them makes sense
    # When SI-SNR is high (good), si_snr_loss_val is very negative
    # When STFT is low (good), stft_val is low positive
    # Total should trend negative for good performance, but we want positive loss
    # Solution: Return absolute value or add offset
    total = SI_SNR_WEIGHT * si_snr_loss_val + STFT_WEIGHT * stft_val
    
    return total, si_snr_metric, stft_val.detach()

In [None]:
print(f"Using VIDEO_SIZE: {VIDEO_SIZE}")

# Set max_samples=None for full training, or a number for quick debugging
MAX_TRAIN_SAMPLES = None  # Use full dataset (set to e.g. 100 for quick debug)
MAX_VAL_SAMPLES = None    # Use full dataset (set to e.g. 10 for quick debug)

train_ds = IsoNetDataset(
    TRAIN_CSV, 
    max_samples=MAX_TRAIN_SAMPLES,
    video_size=VIDEO_SIZE,
    augment=True,      # Enable jittering for training
    jitter_pct=0.1     # 10% max shift
)

val_ds = IsoNetDataset(
    VAL_CSV, 
    max_samples=MAX_VAL_SAMPLES,
    video_size=VIDEO_SIZE,
    augment=False,     # No augmentation for validation (deterministic)
    jitter_pct=0.0
)

print(f"\nTrain: {len(train_ds)} | Val: {len(val_ds)} | Size: {VIDEO_SIZE}")

In [None]:
# Memory-Optimized DataLoaders
# On Windows, multiprocessing can cause issues in notebooks - use num_workers=0

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # Set to 0 for Windows notebook compatibility
    pin_memory=True,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,  # Set to 0 for Windows notebook compatibility
    pin_memory=True,
)

print(f"DataLoaders created (num_workers=0 for notebook compatibility)")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

In [None]:
model = IsoNet(
    use_checkpointing=True,
    use_beamformer=USE_BEAMFORMER,
    use_spatial_stream=USE_SPATIAL_STREAM,
    use_visual=USE_VISUAL
).to(DEVICE)

optimizer = Adam(model.parameters(), lr=LR)

# Print model info
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Parameters: {total_params:,}")
print(f"Trainable Parameters: {trainable_params:,}")

# Print active features
active_features = []
if USE_VISUAL:
    active_features.append("Visual")
if USE_BEAMFORMER:
    active_features.append("Beamformer")
if USE_SPATIAL_STREAM:
    active_features.append("Spatial")
if not active_features:
    active_features.append("Audio-Only")
print(f"Active Features: {', '.join(active_features)}")
# print(f"Model Size: ~{total_params * 4 / 1e6:.1f} MB (FP32)")

In [None]:
# ============================================================
# TRAINING LOOP (Clean Logging with Early Stopping)
# ============================================================

import time
from datetime import datetime

best_val_loss = float('inf')
best_val_si_snr = -float('inf')
scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)

# Early Stopping Configuration
EARLY_STOPPING_PATIENCE = 5  # Stop if no improvement for N epochs
early_stopping_counter = 0

# Print config once at start
print(f"Training | {EPOCHS} epochs | BS={BATCH_SIZE}x{GRADIENT_ACCUMULATION_STEPS} | LR={LR} | AMP={USE_AMP}")
print(f"Data     | Train: {len(train_ds)} | Val: {len(val_ds)} | Jitter: {train_ds.augment}")
print(f"Loss Weights | SI-SNR: {SI_SNR_WEIGHT} | STFT: {STFT_WEIGHT}")
print(f"Early Stopping | Patience: {EARLY_STOPPING_PATIENCE} epochs")
print("-" * 80)
print(f"{'Epoch':<8} {'Train Loss':<12} {'Train STFT':<12} {'Train SI-SNR':<14} {'Val Loss':<12} {'Val STFT':<10} {'Val SI-SNR':<12} {'Time':<8} {'Status'}")
print("-" * 80)

for epoch in range(EPOCHS):
    epoch_start = time.time()
    model.train()
    train_loss = 0
    train_si_snr_sum = 0
    train_stft_sum = 0
    batch_count = 0

    # Mask statistics
    mask_means = []
    mask_stds = []

    optimizer.zero_grad()

    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=False)

    for i, batch_data in enumerate(pbar):
        try:
            if len(batch_data) == 4:
                mixed, clean, video, spatial_meta = batch_data
            else:
                mixed, clean, video = batch_data

            mixed = mixed.to(DEVICE, non_blocking=True)
            clean = clean.to(DEVICE, non_blocking=True)
            video = video.to(DEVICE, non_blocking=True)

            with torch.amp.autocast('cuda', enabled=USE_AMP):
                estimate = model(mixed, video)

                # Get mask for logging (before final decoding)
                with torch.no_grad():
                    # Access the mask from the model's last forward pass
                    # We'll compute it again just for stats (lightweight)
                    if model.visual_stream is not None:
                        V = model.visual_stream(video)
                        V_pooled = V.mean(dim=-1)
                    else:
                        V = None
                        V_pooled = None

                    if model.use_spatial_stream and model.spatial_stream is not None:
                        S = model.spatial_stream(mixed)
                    else:
                        S = None

                    if model.use_beamformer and model.beamformer is not None and V_pooled is not None:
                        audio_beamformed = model.beamformer(mixed, V_pooled)
                    else:
                        # Handle case where beamformer expects 1 channel but we need to average
                        if model.beamformer is not None:
                            audio_beamformed = mixed.mean(dim=1, keepdim=True)
                        else:
                            audio_beamformed = mixed

                    audio_feat = model.audio_enc(audio_beamformed)
                    
                    # Build condition based on available streams
                    conditions = []
                    if V is not None:
                        V_upsampled = F.interpolate(V, size=audio_feat.shape[-1], mode='nearest')
                        conditions.append(V_upsampled)
                    
                    if model.use_spatial_stream and S is not None:
                        S_expanded = S.unsqueeze(-1).expand(-1, -1, audio_feat.shape[-1])
                        conditions.append(S_expanded)
                    
                    condition = torch.cat(conditions, dim=1) if len(conditions) > 0 else None

                    x = audio_feat
                    if condition is not None and model.cond_dim > 0:
                        for block in model.tcn_blocks:
                            x = block(x, condition)
                    else:
                        # Audio-only TCN without conditioning
                        for block in model.tcn_blocks:
                            residual = x
                            x = block(x)
                            x = x + residual

                    mask = torch.sigmoid(model.mask_conv(x))

                    # Log mask statistics (every 10 batches to reduce overhead)
                    if i % 10 == 0:
                        mask_means.append(mask.mean().item())
                        mask_stds.append(mask.std().item())

                if estimate.shape[-1] != clean.shape[-1]:
                    min_len = min(estimate.shape[-1], clean.shape[-1])
                    estimate = estimate[..., :min_len]
                    clean = clean[..., :min_len]

                total_loss, si_snr_val, stft_val = separation_loss(estimate.squeeze(1), clean)
                loss = total_loss / GRADIENT_ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            if (i + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            train_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
            train_si_snr_sum += si_snr_val
            train_stft_sum += stft_val.item()
            batch_count += 1

            # Progress bar shows current SI-SNR
            current_si_snr = train_si_snr_sum / batch_count
            current_stft = train_stft_sum / batch_count
            pbar.set_postfix_str(f"SI-SNR={current_si_snr:.2f}dB | STFT={current_stft:.3f} | mask={mask_means[-1] if mask_means else 0:.3f}")

            if i % 100 == 0:
                torch.cuda.empty_cache()

        except RuntimeError as e:
            print(f"[!] Batch {i} error: {str(e)[:50]}")
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            continue

    avg_train_loss = train_loss / max(batch_count, 1)
    avg_train_si_snr = train_si_snr_sum / max(batch_count, 1)
    avg_train_stft = train_stft_sum / max(batch_count, 1)

    # Compute mask statistics
    avg_mask_mean = np.mean(mask_means) if mask_means else 0.0
    avg_mask_std = np.mean(mask_stds) if mask_stds else 0.0

    # Validation
    model.eval()
    val_loss_sum = 0
    val_si_snr_sum = 0
    val_stft_sum = 0
    val_count = 0

    with torch.no_grad():
        for batch_data in tqdm.tqdm(val_loader, desc="Val", leave=False):
            try:
                if len(batch_data) == 4:
                    mixed, clean, video, _ = batch_data
                else:
                    mixed, clean, video = batch_data

                mixed = mixed.to(DEVICE, non_blocking=True)
                clean = clean.to(DEVICE, non_blocking=True)
                video = video.to(DEVICE, non_blocking=True)

                with torch.amp.autocast('cuda', enabled=USE_AMP):
                    estimate = model(mixed, video)
                    if estimate.shape[-1] != clean.shape[-1]:
                        min_len = min(estimate.shape[-1], clean.shape[-1])
                        estimate = estimate[..., :min_len]
                        clean = clean[..., :min_len]
                    val_total, si_snr_val, stft_val = separation_loss(estimate.squeeze(1), clean)
                    val_loss_sum += val_total.item()
                    val_si_snr_sum += si_snr_val
                    val_stft_sum += stft_val.item()
                    val_count += 1
            except RuntimeError:
                continue

    avg_val_loss = val_loss_sum / max(val_count, 1)
    avg_val_si_snr = val_si_snr_sum / max(val_count, 1)
    avg_val_stft = val_stft_sum / max(val_count, 1)
    epoch_time = time.time() - epoch_start

    # Check for improvement (use SI-SNR as primary metric since loss can be misleading)
    improved = avg_val_si_snr > best_val_si_snr

    if improved:
        best_val_loss = avg_val_loss
        best_val_si_snr = avg_val_si_snr
        early_stopping_counter = 0
        status = "* BEST"
    else:
        early_stopping_counter += 1
        status = f"  ({early_stopping_counter}/{EARLY_STOPPING_PATIENCE})"

    # Print epoch summary with SI-SNR and mask statistics
    print(f"[{epoch+1:02d}/{EPOCHS}]  {avg_train_loss:>10.4f}  {avg_train_stft:>10.3f}  {avg_train_si_snr:>12.2f}dB  {avg_val_loss:>10.4f}  {avg_val_stft:>10.3f}  {avg_val_si_snr:>12.2f}dB  {epoch_time/60:>6.1f}m  mask:{avg_mask_mean:.3f}Â±{avg_mask_std:.3f}  {status}")

    # Save checkpoint
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'train_si_snr': avg_train_si_snr,
        'val_si_snr': avg_val_si_snr,
        'train_stft': avg_train_stft,
        'val_stft': avg_val_stft,
        'best_val_loss': best_val_loss,
        'best_val_si_snr': best_val_si_snr,
        'scaler_state_dict': scaler.state_dict(),
    }, f"{CHECKPOINT_DIR}/checkpoint_epoch_{epoch+1}.pth")

    # Save best model
    if improved:
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': best_val_loss,
            'train_si_snr': avg_train_si_snr,
            'val_si_snr': avg_val_si_snr,
            'train_stft': avg_train_stft,
            'val_stft': avg_val_stft,
        }, f"{CHECKPOINT_DIR}/best_model.pth")
        print(f"         -> Saved best model to {CHECKPOINT_DIR}/best_model.pth")

    # Early Stopping Check
    if early_stopping_counter >= EARLY_STOPPING_PATIENCE:
        print("-" * 80)
        print(f"Early stopping triggered! No improvement for {EARLY_STOPPING_PATIENCE} epochs.")
        break

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()

print("-" * 80)
print(f"Training Complete!")
print(f"  Best Val Loss:   {best_val_loss:.4f}")
print(f"  Best Val SI-SNR: {best_val_si_snr:.2f} dB")
print(f"  Saved: {CHECKPOINT_DIR}/best_model.pth")
print("-" * 80)

## Visual Encoding Test

Visualize what the visual encoding model sees. Using larger input size (336x336) instead of 224x224 to accommodate spatial augmentation - this gives more room for face shifting without cutting off content.

ResNet-18 handles any input size via adaptive pooling, so the output dimension stays the same.

In [None]:
import matplotlib.pyplot as plt

# Test dataset with jittering
print(f"Loading dataset sample to visualize visual encoding input...")
print(f"Using VIDEO_SIZE: {VIDEO_SIZE}")
test_ds = IsoNetDataset(TRAIN_CSV, max_samples=5, video_size=VIDEO_SIZE, augment=True)

# Load a sample
mixed, clean, video, spatial_meta = test_ds[2]

print("\n--- Tensor Shapes ---")
print(f"Mixed Audio: {mixed.shape}  (Expected: [4, 64000])")
print(f"Clean Audio: {clean.shape}  (Expected: [1, 64000])")
print(f"Video:       {video.shape}  (Expected: [3, 100, {VIDEO_SIZE[0]}, {VIDEO_SIZE[1]}])")

# Show spatial metadata
print(f"\n--- Spatial Metadata ---")
if spatial_meta:
    for key, val in spatial_meta.items():
        if key in ['azimuth', 'elevation']:
            print(f"  {key}: {val:.4f} rad ({np.degrees(val):.1f} deg)")
        else:
            print(f"  {key}: {val:.4f}")
else:
    print("  No spatial metadata available")

# Create visualization
fig, axes = plt.subplots(2, 5, figsize=(15, 6))
fig.suptitle(f'Visual Encoding Input (Jittered) - {VIDEO_SIZE[0]}x{VIDEO_SIZE[1]}', fontsize=14, fontweight='bold')

# Sample 10 frames evenly across the 4-second clip
sample_frames = np.linspace(0, video.shape[1]-1, 10, dtype=int)

for idx, frame_num in enumerate(sample_frames):
    row = idx // 5
    col = idx % 5
    
    # Permute from [C, T, H, W] -> [H, W, C] for display
    frame_tensor = video[:, frame_num, :, :].permute(1, 2, 0)
    axes[row, col].imshow(frame_tensor.numpy())
    axes[row, col].set_title(f'Frame {frame_num}', fontsize=10)
    axes[row, col].axis('off')

plt.tight_layout()
plt.savefig("visual_encoding_input.png", dpi=100, bbox_inches='tight')
plt.show()

print("\nVisual encoding input visualization complete!")

## Comprehensive Dataset Sample Check

Load a random sample and visualize EVERYTHING to verify data integrity: video frames, audio waveforms, spectrograms, and metadata alignment.

In [None]:
import matplotlib.pyplot as plt
import IPython.display as ipd
import librosa
import librosa.display
import random

# Load dataset with jittering
print(f"Loading dataset for comprehensive check...")
print(f"Using VIDEO_SIZE: {VIDEO_SIZE}")
check_ds = IsoNetDataset(TRAIN_CSV, max_samples=20, video_size=VIDEO_SIZE, augment=True)

# Pick a random sample
random_idx = random.randint(0, len(check_ds) - 1)
print(f"\n{'='*80}")
print(f"RANDOM SAMPLE CHECK - Index: {random_idx}/{len(check_ds)-1}")
print(f"{'='*80}")

# Load sample
mixed_audio, clean_audio, video_tensor, spatial_meta = check_ds[random_idx]

# Get metadata
row = check_ds.meta.iloc[random_idx]
print(f"\n--- METADATA ---")
print(f"Filename: {row['filename']}")
print(f"Video Path: {row.get('video_path', row.get('source_video', 'N/A'))}")
print(f"Start Time: {row['start_time']:.2f}s")

# Show spatial metadata
print(f"\n--- SPATIAL METADATA ---")
if spatial_meta:
    for key, val in spatial_meta.items():
        if key in ['azimuth', 'elevation']:
            print(f"  {key}: {val:.4f} rad ({np.degrees(val):.1f} deg)")
        else:
            print(f"  {key}: {val:.4f}")
else:
    print("  No spatial metadata available")

print(f"\n--- TENSOR SHAPES ---")
print(f"Mixed Audio: {mixed_audio.shape}  (Expected: [4, 64000])")
print(f"Clean Audio: {clean_audio.shape}  (Expected: [1, 64000])")
print(f"Video:       {video_tensor.shape}  (Expected: [3, 100, {VIDEO_SIZE[0]}, {VIDEO_SIZE[1]}])")

print(f"\n--- AUDIO STATISTICS ---")
print(f"Mixed Audio - Max: {mixed_audio.max():.4f}, Min: {mixed_audio.min():.4f}, Mean: {mixed_audio.mean():.4f}")
print(f"Clean Audio - Max: {clean_audio.max():.4f}, Min: {clean_audio.min():.4f}, Mean: {clean_audio.mean():.4f}")

print(f"\n--- VIDEO STATISTICS ---")
print(f"Video - Max: {video_tensor.max():.4f}, Min: {video_tensor.min():.4f}, Mean: {video_tensor.mean():.4f}")
print(f"Video - Frames: {video_tensor.shape[1]}, Duration: {video_tensor.shape[1]/25:.2f}s @ 25fps")

# Check for silent audio
if mixed_audio.abs().max() < 1e-6:
    print(f"\nWARNING: Mixed audio appears to be SILENT!")
if clean_audio.abs().max() < 1e-6:
    print(f"\nWARNING: Clean audio appears to be SILENT!")

# Check duration mismatch
audio_duration = mixed_audio.shape[1] / 16000
video_duration = video_tensor.shape[1] / 25
print(f"\n--- DURATION CHECK ---")
print(f"Audio Duration: {audio_duration:.3f}s")
print(f"Video Duration: {video_duration:.3f}s")

In [None]:
# Audio playback - Mixed Audio (Mic 1)
print("\nðŸ”Š Playing Mixed Audio (Microphone 1)...")
ipd.display(ipd.Audio(mixed_audio[0].numpy(), rate=16000))

In [None]:
# Audio playback - Clean Audio
print("\nðŸ”Š Playing Clean Audio (Target Speech)...")
ipd.display(ipd.Audio(clean_audio[0].numpy(), rate=16000))

In [None]:
# Compare all 4 microphone channels side-by-side
print("\n--- Multi-Channel Comparison ---")
fig, axes = plt.subplots(4, 1, figsize=(15, 8), sharex=True, sharey=True)
time_axis = np.arange(mixed_audio.shape[1]) / 16000

for ch in range(4):
    axes[ch].plot(time_axis, mixed_audio[ch].numpy(), linewidth=0.5)
    axes[ch].set_ylabel(f'Mic {ch+1}')
    axes[ch].grid(True, alpha=0.3)
    axes[ch].set_xlim([0, time_axis[-1]])
    
    # Calculate RMS energy
    rms = np.sqrt(np.mean(mixed_audio[ch].numpy()**2))
    axes[ch].text(0.02, 0.95, f'RMS: {rms:.4f}', transform=axes[ch].transAxes, 
                  fontsize=9, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

axes[-1].set_xlabel('Time (s)')
fig.suptitle('4-Channel Microphone Array - Waveform Comparison', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig(f"multichannel_comparison_{random_idx}.png", dpi=120, bbox_inches='tight')
plt.show()

print(f"\nâœ“ All visualizations complete!")
print(f"âœ“ Files saved: dataset_check_sample_{random_idx}.png, multichannel_comparison_{random_idx}.png")

## Model Inference Test

Load the best trained model and test it on a validation sample. Compare input mixed audio, model output, and ground truth clean audio.

In [None]:
import matplotlib.pyplot as plt
import IPython.display as ipd
import librosa
import librosa.display

# Load Best Model
print("Loading best model checkpoint...")
best_model_path = f"{CHECKPOINT_DIR}/best_model.pth"

if not os.path.exists(best_model_path):
    print(f"Best model not found at: {best_model_path}")
    print(f"Available checkpoints:")
    for f in os.listdir(CHECKPOINT_DIR):
        print(f"  - {f}")
else:
    # ============================================================
    # IMPORTANT: Initialize model with SAME config used during training!
    # The checkpoint was trained with:
    #   USE_VISUAL = False
    #   USE_BEAMFORMER = True  (but beamformer requires visual, so it wasn't created)
    #   USE_SPATIAL_STREAM = True
    # ============================================================
    print(f"Initializing model with training config:")
    print(f"  use_visual={USE_VISUAL}, use_beamformer={USE_BEAMFORMER}, use_spatial_stream={USE_SPATIAL_STREAM}")
    
    test_model = IsoNet(
        use_checkpointing=False,
        use_beamformer=USE_BEAMFORMER,
        use_spatial_stream=USE_SPATIAL_STREAM,
        use_visual=USE_VISUAL
    ).to(DEVICE)
    
    # Load checkpoint
    checkpoint = torch.load(best_model_path, map_location=DEVICE)
    test_model.load_state_dict(checkpoint['model_state_dict'])
    test_model.eval()
    
    print(f"\nModel loaded successfully!")
    print(f"  Epoch: {checkpoint['epoch']}")
    print(f"  Validation Loss: {checkpoint['val_loss']:.6f}")
    print(f"  Training Loss: {checkpoint['train_loss']:.6f}")
    
    # ============================================================
    # FIX: Use batch size of 4 to avoid BatchNorm1d issues
    # BatchNorm1d behaves poorly with batch_size=1 during inference
    # ============================================================
    print("\n" + "="*80)
    print("RUNNING INFERENCE ON 4 VALIDATION SAMPLES (BatchNorm Fix)")
    print("="*80)
    
    NUM_TEST_SAMPLES = 4  # Use 4 samples to match training batch size
    
    # Collect 4 real samples from validation set
    mixed_list, clean_list, video_list, meta_list = [], [], [], []
    
    for test_idx in range(NUM_TEST_SAMPLES):
        sample_data = val_ds[test_idx]
        
        if len(sample_data) == 4:
            mixed_audio, clean_audio, video_tensor, spatial_meta = sample_data
        else:
            mixed_audio, clean_audio, video_tensor = sample_data
            spatial_meta = None
        
        mixed_list.append(mixed_audio)
        clean_list.append(clean_audio)
        video_list.append(video_tensor)
        meta_list.append(spatial_meta)
        
        row = val_ds.meta.iloc[test_idx]
        print(f"  Sample {test_idx+1}: {row['filename']}")
    
    # Stack into batches [4, C, T] and [4, 3, T, H, W]
    mixed_batch = torch.stack(mixed_list, dim=0).to(DEVICE)
    clean_batch = torch.stack(clean_list, dim=0).to(DEVICE)
    video_batch = torch.stack(video_list, dim=0).to(DEVICE) if USE_VISUAL else None
    
    print(f"\nBatch shapes:")
    print(f"  Mixed: {mixed_batch.shape}")
    if video_batch is not None:
        print(f"  Video: {video_batch.shape}")
    else:
        print(f"  Video: None (USE_VISUAL=False)")
    
    # Run inference with batch size 4
    print("\nRunning model inference with batch_size=4...")
    with torch.no_grad():
        with torch.amp.autocast('cuda', enabled=USE_AMP):
            estimated_batch = test_model(mixed_batch, video_batch)
    
    print(f"  Output shape: {estimated_batch.shape}")
    
    # Move to CPU for visualization
    estimated_batch = estimated_batch.cpu()
    mixed_batch = mixed_batch.cpu()
    clean_batch = clean_batch.cpu()
    
    # Handle shape mismatch
    min_len = min(estimated_batch.shape[-1], clean_batch.shape[-1], mixed_batch.shape[-1])
    estimated_batch = estimated_batch[..., :min_len]
    clean_batch = clean_batch[..., :min_len]
    mixed_batch = mixed_batch[..., :min_len]
    
    # ============================================================
    # Calculate metrics for each sample
    # ============================================================
    print(f"\n{'='*80}")
    print(f"{'Sample':<10} {'Input SI-SNR':<15} {'Output SI-SNR':<15} {'Improvement':<15}")
    print(f"{'='*80}")
    
    improvements = []
    for i in range(NUM_TEST_SAMPLES):
        with torch.no_grad():
            clean_mono = clean_batch[i].mean(dim=0, keepdim=True).unsqueeze(0)
            mixed_mono = mixed_batch[i, 0:1].unsqueeze(0)
            estimated_mono = estimated_batch[i].unsqueeze(0)
            
            _, input_si_snr = si_snr_loss(mixed_mono, clean_mono)
            _, output_si_snr = si_snr_loss(estimated_mono, clean_mono)
            improvement = output_si_snr - input_si_snr
            improvements.append(improvement)
        
        print(f"  {i+1:<8} {input_si_snr:>12.2f} dB   {output_si_snr:>12.2f} dB   {improvement:>+12.2f} dB")
    
    avg_improvement = np.mean(improvements)
    print(f"{'='*80}")
    print(f"  Average SI-SNR Improvement: {avg_improvement:+.2f} dB")
    print(f"{'='*80}")
    
    # ============================================================
    # VISUALIZATION: Show all 4 samples
    # ============================================================
    fig, axes = plt.subplots(NUM_TEST_SAMPLES, 3, figsize=(20, 4*NUM_TEST_SAMPLES))
    
    for i in range(NUM_TEST_SAMPLES):
        mixed_audio = mixed_batch[i]
        estimated_audio = estimated_batch[i]
        clean_audio = clean_batch[i]
        
        time_axis = np.arange(mixed_audio.shape[1]) / 16000
        row_data = val_ds.meta.iloc[i]
        
        # Calculate metrics for title
        with torch.no_grad():
            clean_mono = clean_audio.mean(dim=0, keepdim=True).unsqueeze(0)
            mixed_mono = mixed_audio[0:1].unsqueeze(0)
            estimated_mono = estimated_audio.unsqueeze(0)
            _, input_si_snr = si_snr_loss(mixed_mono, clean_mono)
            _, output_si_snr = si_snr_loss(estimated_mono, clean_mono)
        
        # Mixed
        axes[i, 0].plot(time_axis, mixed_audio[0].numpy(), linewidth=0.5, color='blue')
        axes[i, 0].set_title(f'Sample {i+1} Input (Mixed)\nSI-SNR: {input_si_snr:.2f} dB', fontsize=10)
        axes[i, 0].set_ylabel('Amplitude')
        axes[i, 0].grid(True, alpha=0.3)
        axes[i, 0].set_xlim([0, time_axis[-1]])
        
        # Estimated
        axes[i, 1].plot(time_axis, estimated_audio[0].numpy(), linewidth=0.5, color='red')
        axes[i, 1].set_title(f'Sample {i+1} Output (Separated)\nSI-SNR: {output_si_snr:.2f} dB', fontsize=10)
        axes[i, 1].set_ylabel('Amplitude')
        axes[i, 1].grid(True, alpha=0.3)
        axes[i, 1].set_xlim([0, time_axis[-1]])
        
        # Clean (Ground Truth)
        axes[i, 2].plot(time_axis, clean_audio[0].numpy(), linewidth=0.5, color='green')
        axes[i, 2].set_title(f'Sample {i+1} Ground Truth (Clean)', fontsize=10)
        axes[i, 2].set_ylabel('Amplitude')
        axes[i, 2].grid(True, alpha=0.3)
        axes[i, 2].set_xlim([0, time_axis[-1]])
        
        if i == NUM_TEST_SAMPLES - 1:
            for j in range(3):
                axes[i, j].set_xlabel('Time (s)')
    
    plt.suptitle(f'Model Inference Test (Batch Size = {NUM_TEST_SAMPLES}) | Avg SI-SNR Improvement: {avg_improvement:+.2f} dB', 
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig("inference_test_result.png", dpi=150, bbox_inches='tight')
    plt.show()
    
    print(f"\nVisualization saved to: inference_test_result.png")
    
    # ============================================================
    # AUDIO PLAYBACK for first sample
    # ============================================================
    print(f"\n{'='*80}")
    print("AUDIO PLAYBACK (Sample 1)")
    print(f"{'='*80}")
    print("\nMixed Audio (Input):")
    ipd.display(ipd.Audio(mixed_batch[0, 0].numpy(), rate=16000))
    print("\nSeparated Audio (Output):")
    ipd.display(ipd.Audio(estimated_batch[0, 0].numpy(), rate=16000))
    print("\nClean Audio (Ground Truth):")
    ipd.display(ipd.Audio(clean_batch[0, 0].numpy(), rate=16000))

In [None]:
# Play Input Mixed Audio (Mic 1)
print("\nðŸ”Š Input: Mixed Audio (Microphone 1)")
ipd.display(ipd.Audio(mixed_audio[0].numpy(), rate=16000))

In [None]:
# Play Model Output
print("\nðŸ”Š Output: Model Separated Speech")
ipd.display(ipd.Audio(estimated_audio[0].numpy(), rate=16000))

In [None]:
# Play Ground Truth Clean Audio
print("\nðŸ”Š Ground Truth: Clean Speech")
ipd.display(ipd.Audio(clean_audio[0].numpy(), rate=16000))

## Test Set Evaluation

Evaluate the best trained model on the held-out test set. Compute SI-SNR improvement statistics across all test samples.

In [None]:
# ============================================================
# TEST SET EVALUATION
# ============================================================
# Evaluate the best model on held-out test set

import numpy as np

print("=" * 80)
print("TEST SET EVALUATION")
print("=" * 80)

# Load test dataset
print(f"\nLoading test dataset from: {TEST_CSV}")
test_ds = IsoNetDataset(
    TEST_CSV,
    max_samples=None,  # Use full test set
    video_size=VIDEO_SIZE,
    augment=False,     # No augmentation for test (deterministic)
    jitter_pct=0.0
)

# Use same batch size as training
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

print(f"Test samples: {len(test_ds)} | Test batches: {len(test_loader)} | Batch size: {BATCH_SIZE}")

# Load best model
checkpoint_path = os.path.join(CHECKPOINT_DIR, "best_model.pt")
print(f"\nLoading checkpoint: {checkpoint_path}")

test_model = IsoNet(
    use_checkpointing=False,
    use_beamformer=USE_BEAMFORMER,
    use_spatial_stream=USE_SPATIAL_STREAM,
    use_visual=USE_VISUAL
).to(DEVICE)

checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
test_model.load_state_dict(checkpoint['model_state_dict'])
test_model.eval()

print(f"Loaded model from epoch {checkpoint['epoch']} | Best Val Loss: {checkpoint['val_loss']:.6f}")

# Test evaluation loop
with torch.cuda.amp.autocast(enabled=USE_AMP):
    all_input_si_snr = []
    all_output_si_snr = []
    all_si_snr_improvement = []
    test_losses = []
    
    print(f"\nRunning evaluation on test set...")
    print("-" * 80)
    
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(tqdm.tqdm(test_loader, desc="Testing")):
            try:
                if len(batch_data) == 4:
                    mixed, clean, video, _ = batch_data
                else:
                    mixed, clean, video = batch_data
                
                mixed = mixed.to(DEVICE, non_blocking=True)
                clean = clean.to(DEVICE, non_blocking=True)
                video = video.to(DEVICE, non_blocking=True)
                
                with torch.amp.autocast('cuda', enabled=USE_AMP):
                    estimate = test_model(mixed, video)
                    
                    # Handle shape mismatch
                    if estimate.shape[-1] != clean.shape[-1]:
                        min_len = min(estimate.shape[-1], clean.shape[-1])
                        estimate = estimate[..., :min_len]
                        clean = clean[..., :min_len]
                        mixed = mixed[..., :min_len]
                    
                    # Compute loss (unpack tuple: si_snr_loss returns (loss, si_snr))
                    loss, _ = si_snr_loss(estimate.squeeze(1), clean.squeeze(1))
                    test_losses.append(loss.item())
                    
                    # Compute SI-SNR metrics for each sample in batch
                    for i in range(mixed.shape[0]):
                        # Extract single samples
                        mix_sample = mixed[i, 0, :].cpu()
                        clean_sample = clean[i, 0, :].cpu()
                        est_sample = estimate[i, 0, :].cpu()
                        
                        # Compute input SI-SNR (mixture vs clean)
                        # si_snr_loss returns (negative_si_snr_loss, si_snr_value_in_dB)
                        _, input_si_snr = si_snr_loss(mix_sample, clean_sample)
                        
                        # Compute output SI-SNR (estimate vs clean)
                        _, output_si_snr = si_snr_loss(est_sample, clean_sample)
                        
                        # SI-SNR Improvement
                        improvement = output_si_snr - input_si_snr
                        
                        all_input_si_snr.append(input_si_snr)
                        all_output_si_snr.append(output_si_snr)
                        all_si_snr_improvement.append(improvement)
            
            except Exception as e:
                print(f"\nError in batch {batch_idx}: {e}")
                import traceback
                traceback.print_exc()
                continue

# Convert to numpy for statistics
all_input_si_snr = np.array(all_input_si_snr)
all_output_si_snr = np.array(all_output_si_snr)
all_si_snr_improvement = np.array(all_si_snr_improvement)

# Print results
print("\n" + "=" * 80)
print("TEST RESULTS")
print("=" * 80)
if len(test_losses) > 0:
    avg_test_loss = np.mean(test_losses)
    avg_input_si_snr = all_input_si_snr.mean()
    avg_output_si_snr = all_output_si_snr.mean()
    avg_improvement = all_si_snr_improvement.mean()
    median_improvement = np.median(all_si_snr_improvement)
    std_improvement = all_si_snr_improvement.std()
    
    print(f"\nTest Metrics (n={len(all_si_snr_improvement)}):")
    print(f"  Test Loss:           {avg_test_loss:>8.6f}")
    print(f"  Input SI-SNR:        {avg_input_si_snr:>8.2f} Â± {np.std(all_input_si_snr):>6.2f} dB")
    print(f"  Output SI-SNR:       {avg_output_si_snr:>8.2f} Â± {np.std(all_output_si_snr):>6.2f} dB")
    print(f"  Improvement:         {avg_improvement:>8.2f} Â± {std_improvement:>6.2f} dB")
    print(f"  Median Improvement:  {median_improvement:>8.2f} dB")
    print(f"\nPercentile Statistics:")
    print(f"  10th percentile:     {np.percentile(all_si_snr_improvement, 10):>8.2f} dB")
    print(f"  25th percentile:     {np.percentile(all_si_snr_improvement, 25):>8.2f} dB")
    print(f"  50th percentile:     {np.percentile(all_si_snr_improvement, 50):>8.2f} dB")
    print(f"  75th percentile:     {np.percentile(all_si_snr_improvement, 75):>8.2f} dB")
    print(f"  90th percentile:     {np.percentile(all_si_snr_improvement, 90):>8.2f} dB")
    print("=" * 80)

In [None]:
# ============================================================
# TEST RESULTS VISUALIZATION
# ============================================================

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# 1. SI-SNR Improvement Distribution
axes[0].hist(all_si_snr_improvement, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
axes[0].axvline(x=0, color='red', linestyle='--', linewidth=2, label='No Improvement')
axes[0].axvline(x=all_si_snr_improvement.mean(), color='green', linestyle='-', linewidth=2, label=f'Mean: {all_si_snr_improvement.mean():.2f} dB')
axes[0].set_xlabel('SI-SNR Improvement (dB)')
axes[0].set_ylabel('Count')
axes[0].set_title('SI-SNR Improvement Distribution')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# 2. Input vs Output SI-SNR Scatter
axes[1].scatter(all_input_si_snr, all_output_si_snr, alpha=0.5, s=10)
min_val = min(all_input_si_snr.min(), all_output_si_snr.min()) - 2
max_val = max(all_input_si_snr.max(), all_output_si_snr.max()) + 2
axes[1].plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='No Change Line')
axes[1].set_xlabel('Input SI-SNR (dB)')
axes[1].set_ylabel('Output SI-SNR (dB)')
axes[1].set_title('Input vs Output SI-SNR')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_xlim([min_val, max_val])
axes[1].set_ylim([min_val, max_val])

# 3. Cumulative Distribution
sorted_improvements = np.sort(all_si_snr_improvement)
cdf = np.arange(1, len(sorted_improvements) + 1) / len(sorted_improvements)
axes[2].plot(sorted_improvements, cdf * 100, linewidth=2, color='steelblue')
axes[2].axvline(x=0, color='red', linestyle='--', linewidth=2)
axes[2].axhline(y=50, color='gray', linestyle=':', linewidth=1)
axes[2].set_xlabel('SI-SNR Improvement (dB)')
axes[2].set_ylabel('Cumulative % of Samples')
axes[2].set_title('Cumulative Distribution of Improvement')
axes[2].grid(True, alpha=0.3)

plt.suptitle(f'Test Set Results | Mean Improvement: {all_si_snr_improvement.mean():.2f} dB | {len(all_si_snr_improvement)} samples', 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig("test_results_visualization.png", dpi=150, bbox_inches='tight')
plt.show()

print(f"\nVisualization saved to: test_results_visualization.png")

## Cross-Video Testing: Same Audio + Different Video Selections

This test demonstrates the core multimodal speaker isolation capability:
- **Same mixed audio** (containing multiple speakers)
- **Different video inputs** (showing different target speakers)
- **Expected result**: Different isolated outputs based on who is shown in the video

This proves that the model uses visual cues to select which speaker to isolate.

In [None]:
# ============================================================
# CROSS-VIDEO TESTING: Same Audio + Different Videos
# ============================================================
# Test how model responds to different video inputs with same audio

print("=" * 80)
print("CROSS-VIDEO TESTING")
print("=" * 80)

# Load test dataset for cross-video testing
cross_test_ds = IsoNetDataset(
    TEST_CSV,
    max_samples=10,  # Load a few samples for cross-testing
    video_size=VIDEO_SIZE,
    augment=False,
    jitter_pct=0.0
)

# Load best model if not already loaded
if 'test_model' not in locals():
    print("Loading best model...")
    test_model = IsoNet(use_checkpointing=False).to(DEVICE)
    checkpoint = torch.load(f"{CHECKPOINT_DIR}/best_model.pth", map_location=DEVICE)
    test_model.load_state_dict(checkpoint['model_state_dict'])
    test_model.eval()
    print(f"Model loaded from epoch {checkpoint['epoch']}")

# Select a base sample (audio to isolate from)
base_idx = 0
base_mixed, base_clean, base_video, base_meta = cross_test_ds[base_idx]
base_row = cross_test_ds.meta.iloc[base_idx]

print(f"\nBase Mixed Audio: {base_row['filename']}")
print(f"Base Audio Shape: {base_mixed.shape}")
print(f"\nThis audio will be tested with videos from {min(5, len(cross_test_ds))} different samples")

# Test with multiple different videos
num_videos_to_test = min(5, len(cross_test_ds))
cross_results = []

for video_idx in range(num_videos_to_test):
    # Get video from a different sample
    _, _, test_video, _ = cross_test_ds[video_idx]
    test_row = cross_test_ds.meta.iloc[video_idx]
    
    print(f"\n{'-'*60}")
    print(f"Test {video_idx + 1}/{num_videos_to_test}: Using video from {test_row['filename']}")
    
    # Use base_mixed audio + test_video (potentially from different speaker)
    mixed_batch = base_mixed.unsqueeze(0).to(DEVICE)
    video_batch = test_video.unsqueeze(0).to(DEVICE)
    
    with torch.no_grad():
        with torch.amp.autocast('cuda', enabled=USE_AMP):
            isolated = test_model(mixed_batch, video_batch)
    
    isolated = isolated.squeeze(0).cpu()
    
    # Handle shape mismatch
    if isolated.shape[-1] != base_clean.shape[-1]:
        min_len = min(isolated.shape[-1], base_clean.shape[-1])
        isolated = isolated[..., :min_len]
    
    # Calculate metrics (unpack the tuple properly)
    output_rms = torch.sqrt(torch.mean(isolated[0] ** 2)).item()
    loss_val, si_snr_val = si_snr_loss(isolated, base_clean[..., :isolated.shape[-1]])
    output_si_snr = si_snr_val  # Already positive SI-SNR in dB
    
    # Check if this is the matching video (same sample)
    is_matching = (video_idx == base_idx)
    match_str = " [MATCHING VIDEO]" if is_matching else ""
    
    print(f"  Output RMS: {output_rms:.6f} | SI-SNR: {output_si_snr:.2f} dB{match_str}")
    
    cross_results.append({
        'video_idx': video_idx,
        'video_filename': test_row['filename'],
        'video_frames': test_video,
        'isolated_audio': isolated,
        'output_rms': output_rms,
        'output_si_snr': output_si_snr,
        'is_matching': is_matching
    })

print(f"\n{'='*80}")
print(f"Cross-video testing complete!")
print(f"Same mixed audio produced {num_videos_to_test} different outputs")
print(f"{'='*80}")

In [None]:
# ============================================================
# CROSS-VIDEO VISUALIZATION: Waveforms & Spectrograms
# ============================================================

num_videos = len(cross_results)
fig, axes = plt.subplots(num_videos, 2, figsize=(16, 3 * num_videos))

if num_videos == 1:
    axes = axes.reshape(1, -1)

for test_num, result in enumerate(cross_results):
    isolated = result['isolated_audio'][0].numpy()
    time_axis = np.arange(len(isolated)) / 16000
    
    # Color based on matching
    color = 'green' if result['is_matching'] else f'C{test_num}'
    title_suffix = " [MATCHING]" if result['is_matching'] else ""
    
    # Waveform
    axes[test_num, 0].plot(time_axis, isolated, linewidth=0.5, color=color)
    axes[test_num, 0].set_title(
        f'Output {test_num+1}: {result["video_filename"][:20]}...{title_suffix}\n'
        f'SI-SNR: {result["output_si_snr"]:.2f} dB | RMS: {result["output_rms"]:.4f}', 
        fontsize=10, fontweight='bold'
    )
    axes[test_num, 0].set_ylabel('Amplitude')
    axes[test_num, 0].grid(True, alpha=0.3)
    axes[test_num, 0].set_xlim([0, time_axis[-1]])
    
    if test_num == num_videos - 1:
        axes[test_num, 0].set_xlabel('Time (s)')
    
    # Spectrogram
    D = librosa.amplitude_to_db(np.abs(librosa.stft(isolated)), ref=np.max)
    librosa.display.specshow(D, sr=16000, x_axis='time', y_axis='hz', 
                            ax=axes[test_num, 1], cmap='viridis')
    axes[test_num, 1].set_title(f'Output {test_num+1} Spectrogram', fontsize=10, fontweight='bold')
    axes[test_num, 1].set_ylim([0, 8000])

plt.suptitle(f'Cross-Video Test: Same Audio ({base_row["filename"]}) + Different Video Selections', 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig("cross_video_output_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

print("Cross-video visualization saved: cross_video_output_comparison.png")

# Print summary statistics
print("\n" + "=" * 80)
print("CROSS-VIDEO TEST SUMMARY")
print("=" * 80)
print(f"\n{'Video':<35} {'SI-SNR (dB)':<15} {'RMS':<12} {'Match'}")
print("-" * 80)
for result in cross_results:
    match_str = "YES" if result['is_matching'] else "no"
    print(f"{result['video_filename'][:32]:<35} {result['output_si_snr']:>12.2f}   {result['output_rms']:>10.6f}   {match_str}")
print("-" * 80)

# Calculate differences
si_snr_values = [r['output_si_snr'] for r in cross_results]
rms_values = [r['output_rms'] for r in cross_results]
print(f"\nSI-SNR Range: {min(si_snr_values):.2f} to {max(si_snr_values):.2f} dB (diff: {max(si_snr_values)-min(si_snr_values):.2f} dB)")
print(f"RMS Range: {min(rms_values):.6f} to {max(rms_values):.6f}")
print(f"\nObservation: Different videos â†’ Different isolated outputs!")
print("=" * 80)

In [None]:
# ============================================================
# VISUALIZE VIDEO FRAMES USED FOR SELECTION
# ============================================================

num_videos = len(cross_results)
fig, axes = plt.subplots(num_videos + 1, 5, figsize=(15, 3 * (num_videos + 1)))

# Top row: Base mixed audio waveform
time_axis = np.arange(base_mixed.shape[1]) / 16000
for col in range(5):
    if col == 2:  # Center subplot for waveform
        axes[0, col].plot(time_axis, base_mixed[0].numpy(), linewidth=0.5, color='blue')
        axes[0, col].set_title(f'Base Mixed Audio\n{base_row["filename"][:25]}...', 
                               fontsize=9, fontweight='bold')
        axes[0, col].set_xlabel('Time (s)', fontsize=8)
        axes[0, col].set_ylabel('Amplitude', fontsize=8)
        axes[0, col].grid(True, alpha=0.3)
    else:
        axes[0, col].axis('off')

# Subsequent rows: Video frames for each test
for test_num, result in enumerate(cross_results):
    video = result['video_frames']
    sample_frames = np.linspace(0, video.shape[1]-1, 5, dtype=int)
    
    border_color = 'green' if result['is_matching'] else 'gray'
    
    for col, frame_num in enumerate(sample_frames):
        frame = video[:, frame_num, :, :].permute(1, 2, 0).numpy()
        axes[test_num + 1, col].imshow(frame)
        axes[test_num + 1, col].axis('off')
        
        # Add colored border for matching video
        if result['is_matching']:
            for spine in axes[test_num + 1, col].spines.values():
                spine.set_edgecolor('green')
                spine.set_linewidth(3)
                spine.set_visible(True)
        
        if col == 0:
            match_str = " [MATCH]" if result['is_matching'] else ""
            axes[test_num + 1, col].set_ylabel(
                f"Video {test_num+1}{match_str}\n{result['video_filename'][:12]}...", 
                fontsize=9, fontweight='bold'
            )

plt.suptitle(f'Cross-Video Test: Same Audio + Different Video Selections\n'
             f'Base Audio: {base_row["filename"]}', 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.savefig("cross_video_frames.png", dpi=120, bbox_inches='tight')
plt.show()

print("Cross-video frames visualization saved: cross_video_frames.png")