# 🚀 OpenWakeWord Training Platform - A100 Optimized

**Complete training pipeline for high-accuracy wake word detection models.**

## Features
- A100 GPU optimized (batch_size=256, torch.compile, FP16 autocast mixed precision)
- x86_64 target (ResNet18, 80 Mel bands, 2.0s audio)
- Full augmentation pipeline (noise, RIR, SpecAugment)
- EMA, mixed precision, early stopping
- ONNX export for deployment

## Dataset
Expects dataset at: `/content/drive/My Drive/OpenWakeWord_Backups/dataset.tar.gz`

---

## 1️⃣ GPU Check & System Setup

In [None]:
# Check GPU
# (Colab/IPython cell: lines starting with '!' run as shell commands, not Python.)
!nvidia-smi

# Install system dependencies
# libsndfile1 is the native library used by the `soundfile` package; ffmpeg is
# presumably there so compressed clips (the dataset scan includes *.mp3) can be
# decoded — confirm against the librosa backend in use.
!apt-get update -qq
!apt-get install -y -qq libsndfile1 ffmpeg

import torch
print(f"\n‚úÖ PyTorch: {torch.__version__}")
print(f"‚úÖ CUDA: {torch.version.cuda}")
print(f"‚úÖ Device: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

# Detect A100 for optimizations
# IS_A100 is read later to choose batch_size and to decide whether to torch.compile.
# The conditional expression binds loosest, so this evaluates as
# ('A100' in device_name) when CUDA is available, otherwise False.
IS_A100 = 'A100' in torch.cuda.get_device_name(0) if torch.cuda.is_available() else False
print(f"‚úÖ A100 Detected: {IS_A100}")

## 2️⃣ Install Python Packages

In [None]:
%%capture
# Install required packages (suppress output)
# NOTE: %%capture must stay the first line of the cell; it captures all cell
# output, including the confirmation print at the end.
# NOTE(review): the pinned versions (e.g. scikit-learn==1.3.0,
# librosa==0.10.0.post2) may not ship wheels for the current Colab
# Python/NumPy; confirm they install cleanly on the runtime image.
!pip install librosa==0.10.0.post2 soundfile==0.12.1 resampy==0.4.2
!pip install scikit-learn==1.3.0 tqdm pyyaml structlog
!pip install onnx onnxruntime-gpu
!pip install matplotlib seaborn plotly

print("‚úÖ All packages installed!")

## 3️⃣ Mount Google Drive & Extract Dataset

In [None]:
from google.colab import drive
import os
import tarfile
from pathlib import Path

# Mount Drive
drive.mount('/content/drive')

# Dataset paths
DATASET_TAR = Path('/content/drive/My Drive/OpenWakeWord_Backups/dataset.tar.gz')
DATASET_DIR = Path('/content/dataset')
OUTPUT_DIR = Path('/content/output')

# Fail fast, before creating anything, if the archive is missing.
if not DATASET_TAR.exists():
    raise FileNotFoundError(f"Dataset not found: {DATASET_TAR}")

# Create directories
DATASET_DIR.mkdir(exist_ok=True)
OUTPUT_DIR.mkdir(exist_ok=True)

# Extract dataset.
# The 'data' extraction filter (Python 3.12+, backported to recent 3.8-3.11
# patch releases) rejects absolute paths, '..' traversal and links pointing
# outside DATASET_DIR. Older interpreters without it fall back to plain
# extraction.
print(f"üì¶ Extracting {DATASET_TAR.name}...")
with tarfile.open(DATASET_TAR, 'r:gz') as tar:
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(DATASET_DIR, filter='data')
    else:
        tar.extractall(DATASET_DIR)
print(f"‚úÖ Extracted to {DATASET_DIR}")

# List contents: number of .wav + .mp3 files under each top-level folder.
print("\nüìÇ Dataset structure:")
for item in sorted(DATASET_DIR.iterdir()):
    if item.is_dir():
        count = len(list(item.rglob('*.wav'))) + len(list(item.rglob('*.mp3')))
        print(f"  {item.name}/: {count} audio files")

## 4️⃣ Configuration (A100 Optimized, x86_64 Target)

In [None]:
from dataclasses import dataclass, field, asdict
from typing import List, Optional, Dict, Any
import json

@dataclass
class Config:
    """A100-optimized configuration for x86_64 deployment.

    A single mutable settings object that the cells below read through the
    global ``config``. The defaults target an A100; the instantiating cell
    lowers ``batch_size`` on other GPUs. ``asdict(config)`` is stored in
    each saved checkpoint.

    Fields commented "not read in this notebook" are not referenced by any
    cell here, so changing them has no effect on this pipeline.
    """
    # Paths
    data_root: str = "/content/dataset"  # scanned for positive/, negative/, background/, rirs/
    output_dir: str = "/content/output"  # not read in this notebook (cells use the OUTPUT_DIR global)
    
    # Audio parameters
    sample_rate: int = 16000  # Hz; all audio is resampled to this rate
    audio_duration: float = 2.0  # 2s for x86_64 (more context)
    n_mels: int = 80  # High resolution for x86_64
    n_fft: int = 512  # STFT size (32 ms at 16 kHz)
    hop_length: int = 160  # STFT hop (10 ms at 16 kHz)
    
    # Model
    architecture: str = "resnet18"  # not read in this notebook (ResNet18Wakeword is always built)
    num_classes: int = 2
    dropout: float = 0.4  # dropout before the final linear layer
    
    # Training (A100 optimized)
    batch_size: int = 256  # Large batch for A100
    epochs: int = 100
    learning_rate: float = 0.001  # peak LR, reached after warmup
    weight_decay: float = 0.02  # passed to AdamW
    num_workers: int = 4  # DataLoader worker processes
    early_stopping_patience: int = 20  # epochs without a val-F1 improvement before stopping
    
    # Optimizer
    optimizer: str = "adamw"  # not read in this notebook (AdamW is hard-coded)
    scheduler: str = "cosine"  # not read in this notebook (warmup + cosine LambdaLR is hard-coded)
    warmup_epochs: int = 5  # linear LR warmup length, in epochs
    gradient_clip: float = 1.0  # max global gradient norm
    mixed_precision: bool = True  # enables CUDA autocast and the GradScaler
    
    # Loss
    loss_function: str = "focal_loss"  # not read in this notebook (FocalLoss is hard-coded)
    focal_alpha: float = 0.75  # weight for class-1 targets (1 - alpha for class 0)
    focal_gamma: float = 2.0
    label_smoothing: float = 0.1
    
    # Augmentation
    time_stretch_range: tuple = (0.85, 1.15)  # not read in this notebook (no time stretching implemented)
    pitch_shift_range: tuple = (-3, 3)  # not read in this notebook (no pitch shifting implemented)
    noise_prob: float = 0.6  # probability of mixing in background noise
    noise_snr_range: tuple = (3.0, 20.0)  # (min, max) SNR in dB for noise mixing
    rir_prob: float = 0.5  # probability of applying a room impulse response
    spec_augment: bool = True
    freq_mask_param: int = 20  # max frequency-mask width, in mel bins
    time_mask_param: int = 40  # max time-mask width, in frames
    
    # EMA
    use_ema: bool = True
    ema_decay: float = 0.999  # applied once per optimizer step
    
    # Checkpointing
    checkpoint_dir: str = "/content/output/checkpoints"
    save_to_drive: bool = True  # also copy the best checkpoint and ONNX model to drive_checkpoint_dir
    drive_checkpoint_dir: str = "/content/drive/My Drive/OpenWakeWord_Backups/checkpoints"

# Instantiate the run configuration from the dataclass defaults.
config = Config()

# The default batch of 256 assumes A100 memory; fall back to 64 elsewhere.
if not IS_A100:
    config.batch_size = 64
    print("‚ö†Ô∏è Non-A100 GPU detected, reduced batch_size to 64")

# Echo every setting so the run is documented in the notebook output.
print("üìã Configuration:")
print("\n".join(f"  {key}: {value}" for key, value in asdict(config).items()))

## 5️⃣ Core Modules (Inline)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import numpy as np
import librosa
import soundfile as sf
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
import random
import warnings
# NOTE(review): this silences every warning for the rest of the session
# (including AMP/GradScaler and sklearn metric warnings). Narrow it if
# something behaves unexpectedly.
warnings.filterwarnings('ignore')

# Seed every RNG this notebook draws from.
def set_seed(seed=42, deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: value passed to every RNG.
        deterministic: if True, force deterministic cuDNN kernels and turn off
            cuDNN autotuning so GPU runs repeat exactly. The default (False)
            keeps the speed-oriented setup: ``cudnn.benchmark`` on, which
            picks convolution algorithms per input shape and is not
            bit-reproducible across runs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic

set_seed(42)
print("‚úÖ Core imports ready")

## 6️⃣ Audio Processing & Feature Extraction

In [None]:
class AudioProcessor:
    """Load audio files as fixed-length mono float32 waveforms.

    Each clip is resampled to ``sample_rate``, then zero-padded at the end or
    truncated so that it is exactly ``int(sample_rate * duration)`` samples.
    """

    def __init__(self, sample_rate=16000, duration=2.0):
        self.sample_rate = sample_rate
        self.duration = duration
        self.target_length = int(sample_rate * duration)

    def load_audio(self, path):
        """Return the clip at *path* as a float32 array of ``target_length``.

        Any failure is reported on stdout and replaced by silence, so one bad
        file does not abort a DataLoader worker.
        """
        try:
            samples, _ = librosa.load(path, sr=self.sample_rate, mono=True)
            shortfall = self.target_length - len(samples)
            if shortfall > 0:
                samples = np.pad(samples, (0, shortfall))
            else:
                samples = samples[:self.target_length]
            return samples.astype(np.float32)
        except Exception as e:
            print(f"Error loading {path}: {e}")
            return np.zeros(self.target_length, dtype=np.float32)

class MelSpectrogramExtractor(nn.Module):
    """Log-mel spectrogram front end.

    Turns waveforms of shape (B, samples) or (B, 1, samples) into dB-scaled
    mel spectrograms of shape (B, 1, n_mels, frames).

    Two details matter for train/inference consistency:

    * The ``top_db`` floor of ``AmplitudeToDB`` is computed per example. The
      channel axis is added *before* the dB conversion because torchaudio
      reads a 3-D input as (channel, freq, time). Feeding it
      (B, n_mels, frames) would derive one floor from the loudest clip in the
      whole batch, so each clip's features would depend on its batch-mates
      and differ from batch-size-1 inference.
    * The computation runs in float32 with autocast disabled. The training
      and validation loops call this module inside a CUDA autocast region,
      where the mel filterbank matmul would otherwise run in float16. Loud
      power-spectrum bins can exceed float16's range (max ~65504).
    """
    def __init__(self, sample_rate=16000, n_mels=80, n_fft=512, hop_length=160):
        super().__init__()
        self.mel_spec = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            power=2.0
        )
        # Power spectrogram -> dB, floored top_db=80 dB below the peak.
        self.amplitude_to_db = T.AmplitudeToDB(stype='power', top_db=80)

    def forward(self, waveform):
        """Map (B, samples) or (B, 1, samples) audio to (B, 1, n_mels, time) dB features."""
        if waveform.dim() == 3:
            waveform = waveform.squeeze(1)
        # Float32 even inside the caller's autocast region (see class docstring).
        with torch.autocast(device_type=waveform.device.type, enabled=False):
            # (B, n_mels, time) -> (B, 1, n_mels, time) so top_db is per example.
            mel = self.mel_spec(waveform.float()).unsqueeze(1)
            mel_db = self.amplitude_to_db(mel)
        return mel_db  # (B, 1, n_mels, time)

print("‚úÖ Audio processors ready")

## 7️⃣ Data Augmentation

In [None]:
class AudioAugmentation(nn.Module):
    """Waveform augmentation: additive background noise, then room reverb.

    Args:
        sample_rate: rate used when loading the noise and RIR files.
        noise_files: background-noise paths. At most 200 are loaded, and clips
            of 1 s or shorter are skipped.
        rir_files: room-impulse-response paths. At most 100 are loaded, each
            peak-normalised to 1.
        noise_prob: probability that a call mixes in noise.
        rir_prob: probability that a call applies an RIR.
        snr_range: (low, high) SNR in dB, sampled uniformly for noise mixing.

    Loading is best-effort: files that cannot be read are skipped, and the
    number skipped is reported. Waveforms are expected as 1-D tensors (one
    clip of samples).
    """
    def __init__(self, sample_rate=16000, noise_files=None, rir_files=None,
                 noise_prob=0.5, rir_prob=0.25, snr_range=(5, 20)):
        super().__init__()
        self.sample_rate = sample_rate
        self.noise_prob = noise_prob
        self.rir_prob = rir_prob
        self.snr_range = snr_range

        # Load noise files (capped to bound memory)
        self.noises = []
        failed = 0
        for f in (noise_files or [])[:200]:
            try:
                audio, _ = librosa.load(f, sr=sample_rate, mono=True)
            except Exception:
                failed += 1
                continue
            if len(audio) > sample_rate:  # keep only clips longer than 1 s
                self.noises.append(torch.from_numpy(audio).float())
        print(f"  Loaded {len(self.noises)} noise files")
        if failed:
            print(f"  Skipped {failed} unreadable noise files")

        # Load RIR files (capped to bound memory), peak-normalised
        self.rirs = []
        failed = 0
        for f in (rir_files or [])[:100]:
            try:
                audio, _ = librosa.load(f, sr=sample_rate, mono=True)
            except Exception:
                failed += 1
                continue
            if audio.size == 0:  # an empty response cannot be normalised or applied
                failed += 1
                continue
            rir = torch.from_numpy(audio).float()
            self.rirs.append(rir / (rir.abs().max() + 1e-8))
        print(f"  Loaded {len(self.rirs)} RIR files")
        if failed:
            print(f"  Skipped {failed} unreadable or empty RIR files")

    def add_noise(self, waveform):
        """With probability ``noise_prob``, add a random noise clip at a random SNR."""
        if not self.noises or random.random() > self.noise_prob:
            return waveform

        noise = random.choice(self.noises)
        target_len = waveform.shape[-1]

        # Random crop of a longer clip, or tile a shorter one
        if len(noise) > target_len:
            start = random.randint(0, len(noise) - target_len)
            noise = noise[start:start + target_len]
        else:
            noise = noise.repeat((target_len // len(noise)) + 1)[:target_len]

        noise = noise.to(waveform.device)

        # Scale the noise so that signal_power / noise_power == 10 ** (snr / 10)
        snr = random.uniform(*self.snr_range)
        signal_power = waveform.pow(2).mean()
        noise_power = noise.pow(2).mean()
        scale = torch.sqrt(signal_power / (noise_power * (10 ** (snr / 10)) + 1e-8))

        return waveform + scale * noise

    def apply_rir(self, waveform):
        """With probability ``rir_prob``, blend the clip with a reverberant copy of itself."""
        if not self.rirs or random.random() > self.rir_prob:
            return waveform

        rir = random.choice(self.rirs).to(waveform.device)
        original_len = waveform.shape[-1]

        # Causal convolution y[n] = sum_j rir[j] * x[n - j]. conv1d computes a
        # cross-correlation, so the kernel is flipped and the signal is padded on
        # the LEFT by len(rir) - 1. Right-padding would return samples from
        # len(rir) - 1 steps in the future, shifting the utterance earlier and
        # cutting its onset.
        waveform_padded = F.pad(waveform, (len(rir) - 1, 0))
        convolved = F.conv1d(
            waveform_padded.reshape(1, 1, -1), rir.flip(0).reshape(1, 1, -1)
        ).reshape(-1)[:original_len]

        # Match the reverberant peak to the dry peak so the blend keeps the
        # clip's original level instead of jumping to full scale.
        dry_peak = waveform.abs().max()
        convolved = convolved * (dry_peak / (convolved.abs().max() + 1e-8))

        # Dry/wet mix
        wet = random.uniform(0.3, 0.7)
        return (1 - wet) * waveform + wet * convolved

    def forward(self, waveform):
        waveform = self.add_noise(waveform)
        waveform = self.apply_rir(waveform)
        return waveform

class SpecAugment(nn.Module):
    """SpecAugment-style frequency and time masking for spectrogram batches.

    Args:
        freq_mask: maximum width (in frequency bins) of each frequency mask.
        time_mask: maximum width (in frames) of each time mask.
        n_freq: number of frequency masks applied per call.
        n_time: number of time masks applied per call.
        iid_masks: draw an independent mask for every example in the batch.
            torchaudio's maskers default to a single mask shared by the whole
            batch, which would blank the same bands/frames in all B examples.
            Per-example masks require a 4-D (batch, channel, freq, time) input.
    """
    def __init__(self, freq_mask=20, time_mask=40, n_freq=2, n_time=2, iid_masks=True):
        super().__init__()
        self.freq_mask = T.FrequencyMasking(freq_mask, iid_masks=iid_masks)
        self.time_mask = T.TimeMasking(time_mask, iid_masks=iid_masks)
        self.n_freq = n_freq
        self.n_time = n_time

    def forward(self, spec):
        for _ in range(self.n_freq):
            spec = self.freq_mask(spec)
        for _ in range(self.n_time):
            spec = self.time_mask(spec)
        return spec

print("‚úÖ Augmentation modules ready")

## 8️⃣ Dataset Class

In [None]:
class WakewordDataset(Dataset):
    """Map-style dataset yielding ``(waveform_tensor, label)`` pairs.

    Args:
        files: audio file paths, indexed in parallel with ``labels``.
        labels: class label per file.
        audio_processor: object with ``load_audio(path)`` returning a NumPy
            waveform.
        augmentation: optional callable applied to the float waveform tensor;
            pass None for evaluation splits.
    """

    def __init__(self, files, labels, audio_processor, augmentation=None):
        self.files = files
        self.labels = labels
        self.audio_processor = audio_processor
        self.augmentation = augmentation

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        waveform = torch.from_numpy(
            self.audio_processor.load_audio(self.files[idx])
        ).float()
        augment = self.augmentation
        if augment is not None:
            waveform = augment(waveform)
        return waveform, self.labels[idx]

def scan_dataset(data_root):
    """Scan ``data_root`` and return the file lists used for training.

    Expected layout (each folder is searched recursively; missing folders
    contribute no files)::

        positive/    wake-word clips (*.wav, *.mp3)        -> label 1
        negative/    non-wake-word audio (*.wav, *.mp3)    -> label 0
        background/  noise clips for augmentation (*.wav)
        rirs/        room impulse responses (*.wav, *.flac)

    Returns:
        ``(files, labels, bg_files, rir_files)``. ``files`` holds str paths
        with a parallel int ``labels`` list (positives first, then
        negatives). ``bg_files`` and ``rir_files`` hold ``Path`` objects.

    Every list is sorted, so the result does not depend on filesystem
    enumeration order. The seeded stratified split downstream is only
    reproducible if its input order is stable.
    """
    data_root = Path(data_root)

    def _collect(folder, patterns):
        # Sorted, de-duplicated matches of any pattern below `folder`.
        if not folder.exists():
            return []
        found = set()
        for pattern in patterns:
            found.update(folder.rglob(pattern))
        return sorted(found)

    files = []
    labels = []

    # Positive samples (label = 1)
    pos_dir = data_root / 'positive'
    if pos_dir.exists():
        pos_files = _collect(pos_dir, ('*.wav', '*.mp3'))
        files.extend(pos_files)
        labels.extend([1] * len(pos_files))
        print(f"  Positive: {len(pos_files)} files")

    # Negative samples (label = 0)
    neg_dir = data_root / 'negative'
    if neg_dir.exists():
        neg_files = _collect(neg_dir, ('*.wav', '*.mp3'))
        files.extend(neg_files)
        labels.extend([0] * len(neg_files))
        print(f"  Negative: {len(neg_files)} files")

    # Background noise (for augmentation)
    bg_files = _collect(data_root / 'background', ('*.wav',))
    print(f"  Background noise: {len(bg_files)} files")

    # RIRs (for augmentation)
    rir_files = _collect(data_root / 'rirs', ('*.wav', '*.flac'))
    print(f"  RIRs: {len(rir_files)} files")

    return [str(f) for f in files], labels, bg_files, rir_files

# Scan dataset
print("üìÇ Scanning dataset...")
all_files, all_labels, bg_files, rir_files = scan_dataset(config.data_root)
print(f"\n‚úÖ Total: {len(all_files)} audio files")

# The stratified train/val/test split and the binary metrics need both
# classes. With only one class the split would still succeed and training
# would run silently on a degenerate problem, so stop here instead.
n_positive = sum(all_labels)
if n_positive == 0 or n_positive == len(all_labels):
    raise ValueError(
        f"Need both positive and negative clips under {config.data_root}; "
        f"found {n_positive} positive and {len(all_labels) - n_positive} negative."
    )

## 9️⃣ Train/Val/Test Split & DataLoaders

In [None]:
# Stratified 70/15/15 split: hold out 30%, then halve the hold-out into val/test.
train_files, temp_files, train_labels, temp_labels = train_test_split(
    all_files, all_labels, test_size=0.3, random_state=42, stratify=all_labels
)
val_files, test_files, val_labels, test_labels = train_test_split(
    temp_files, temp_labels, test_size=0.5, random_state=42, stratify=temp_labels
)

# Per-split counts; the label column is padded to 7 characters ("Train: ",
# "Val:   ", "Test:  ") so the numbers line up.
print("üìä Dataset splits:")
for split_name, split_files, split_labels in (
    ("Train", train_files, train_labels),
    ("Val", val_files, val_labels),
    ("Test", test_files, test_labels),
):
    split_pos = sum(split_labels)
    print(f"  {split_name + ':':<7}{len(split_files)} (pos: {split_pos}, neg: {len(split_labels) - split_pos})")

# Shared waveform loader
audio_processor = AudioProcessor(config.sample_rate, config.audio_duration)

# Noise/RIR augmentation is applied to the training split only.
print("\nüîä Loading augmentation files...")
train_augmentation = AudioAugmentation(
    sample_rate=config.sample_rate,
    noise_files=bg_files,
    rir_files=rir_files,
    noise_prob=config.noise_prob,
    rir_prob=config.rir_prob,
    snr_range=config.noise_snr_range,
)

train_dataset = WakewordDataset(train_files, train_labels, audio_processor, train_augmentation)
val_dataset = WakewordDataset(val_files, val_labels, audio_processor, None)
test_dataset = WakewordDataset(test_files, test_labels, audio_processor, None)

# Loader settings common to every split.
loader_kwargs = dict(batch_size=config.batch_size, num_workers=config.num_workers, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print(f"\n‚úÖ DataLoaders ready: {len(train_loader)} train batches, {len(val_loader)} val batches")

## 🔟 Model Architecture (ResNet18)

In [None]:
import torchvision.models as models

class ResNet18Wakeword(nn.Module):
    """ResNet-18 classifier over single-channel spectrogram "images".

    The torchvision ResNet-18 is built with random weights (no pretraining).
    Its stem convolution is replaced by a 1-input-channel version, and its
    final fully connected layer by dropout plus a linear head with
    ``num_classes`` outputs.
    """

    def __init__(self, num_classes=2, dropout=0.4):
        super().__init__()
        self.resnet = models.resnet18(weights=None)
        # Single-channel stem (same geometry as the stock 3-channel conv1).
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_features, num_classes))

    def forward(self, x):
        logits = self.resnet(x)
        return logits

# Build the model on the available device, in channels_last layout.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = ResNet18Wakeword(num_classes=config.num_classes, dropout=config.dropout)
model = model.to(device, memory_format=torch.channels_last)

# torch.compile (PyTorch 2.0+) is only enabled on A100s.
if IS_A100 and hasattr(torch, 'compile'):
    model = torch.compile(model, mode='max-autotune')
    print("‚úÖ torch.compile enabled (max-autotune)")

# Parameter counts (the compiled wrapper exposes the same parameters).
model_params = list(model.parameters())
total_params = sum(p.numel() for p in model_params)
trainable_params = sum(p.numel() for p in model_params if p.requires_grad)
print(f"\nüß† Model: ResNet18")
print(f"  Total params: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

## 1️⃣1️⃣ Loss Function & Optimizer

In [None]:
class FocalLoss(nn.Module):
    """Alpha-weighted focal loss over class logits, with label smoothing.

    Per sample: ``alpha_t * (1 - exp(-CE)) ** gamma * CE``, averaged over the
    batch. ``CE`` is the label-smoothed cross-entropy, and ``alpha_t`` is
    ``alpha`` for targets equal to 1 and ``1 - alpha`` otherwise. With
    ``label_smoothing > 0``, ``exp(-CE)`` is not exactly the model's
    probability for the true class.
    """

    def __init__(self, alpha=0.75, gamma=2.0, label_smoothing=0.1):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(
            inputs, targets, reduction='none', label_smoothing=self.label_smoothing
        )
        modulating = (1 - torch.exp(-per_sample_ce)) ** self.gamma
        class_weights = torch.where(targets == 1, self.alpha, 1 - self.alpha)
        return (class_weights * modulating * per_sample_ce).mean()

# Create loss, optimizer, scheduler
criterion = FocalLoss(alpha=config.focal_alpha, gamma=config.focal_gamma,
                      label_smoothing=config.label_smoothing)

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
                              weight_decay=config.weight_decay)

# Cosine scheduler with linear warmup. scheduler.step() is called once per
# batch, so both lengths are measured in optimizer steps.
total_steps = len(train_loader) * config.epochs
warmup_steps = len(train_loader) * config.warmup_epochs

def lr_lambda(step):
    """LR multiplier: linear ramp over ``warmup_steps``, then cosine decay to 0.

    ``decay_steps`` is floored at 1 to avoid dividing by zero when there is no
    post-warmup phase (epochs <= warmup_epochs). ``progress`` is capped at 1
    so the multiplier never climbs back up after ``total_steps``.
    """
    if step < warmup_steps:
        return step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return 0.5 * (1 + np.cos(np.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Mixed precision scaler (inactive when mixed_precision is False)
scaler = torch.cuda.amp.GradScaler(enabled=config.mixed_precision)

# Feature extractor & SpecAugment (both run on the training device)
mel_extractor = MelSpectrogramExtractor(
    sample_rate=config.sample_rate, n_mels=config.n_mels,
    n_fft=config.n_fft, hop_length=config.hop_length
).to(device)

spec_augment = SpecAugment(
    freq_mask=config.freq_mask_param, time_mask=config.time_mask_param
).to(device) if config.spec_augment else None

print("‚úÖ Loss, optimizer, scheduler ready")

## 1️⃣2️⃣ EMA (Exponential Moving Average)

In [None]:
class EMA:
    """Exponential moving average of a model's weights and normalisation statistics.

    Tracks every trainable parameter plus every floating-point buffer (e.g.
    BatchNorm ``running_mean`` / ``running_var``). Evaluating with the
    averaged weights therefore also uses averaged statistics rather than the
    live model's. Non-float buffers such as ``num_batches_tracked`` are left
    untouched.

    All swaps copy values in place, so the model's tensors keep their storage.
    Rebinding ``param.data`` would hand the model new tensors, which a
    compiled or CUDA-graph-captured model may not expect.

    Typical use: call ``update()`` after each optimizer step, and wrap
    evaluation in ``apply_shadow()`` ... ``restore()``.
    """
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        for name, tensor in self._tracked():
            self.shadow[name] = tensor.detach().clone()

    def _tracked(self):
        """Yield ``(name, tensor)`` for every tensor this EMA follows."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                yield name, param
        for name, buf in self.model.named_buffers():
            if buf.is_floating_point():
                yield name, buf

    @torch.no_grad()
    def update(self):
        """shadow <- decay * shadow + (1 - decay) * current, in place."""
        for name, tensor in self._tracked():
            self.shadow[name].mul_(self.decay).add_(tensor.detach(), alpha=1 - self.decay)

    @torch.no_grad()
    def apply_shadow(self):
        """Back up the live values and load the averaged ones into the model."""
        for name, tensor in self._tracked():
            self.backup[name] = tensor.detach().clone()
            tensor.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self):
        """Put back the values saved by ``apply_shadow`` (no-op if none were saved)."""
        for name, tensor in self._tracked():
            if name in self.backup:
                tensor.copy_(self.backup[name])
        self.backup = {}

# Global EMA tracker, or None when disabled (callers test it for truthiness).
ema = None
if config.use_ema:
    ema = EMA(model, decay=config.ema_decay)
print(f"‚úÖ EMA: {'enabled' if ema else 'disabled'}")

## 1️⃣3️⃣ Training Loop

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
import time

def train_epoch(model, loader, criterion, optimizer, scheduler, scaler, mel_extractor, spec_augment, device):
    """Run one training epoch and return ``(mean_batch_loss, accuracy)``.

    Per batch: waveforms -> log-mel features (channels_last) -> optional
    SpecAugment -> model, under CUDA autocast when ``config.mixed_precision``
    is set. The backward pass goes through the AMP ``scaler``; gradients are
    unscaled and clipped to ``config.gradient_clip`` before the optimizer
    step, and the scheduler steps once per batch. The global ``ema`` (if set)
    is updated after every step. Reads the globals ``config`` and ``ema``.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc='Training', leave=False)
    for audio, labels in progress:
        audio = audio.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with torch.autocast(device_type='cuda', enabled=config.mixed_precision):
            features = mel_extractor(audio).to(memory_format=torch.channels_last)
            if spec_augment is not None and model.training:
                features = spec_augment(features)
            outputs = model(features)
            loss = criterion(outputs, labels)

        # Unscale before clipping so the norm threshold applies to true gradients.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        if ema:
            ema.update()

        batch_loss = loss.item()
        loss_sum += batch_loss
        n_seen += labels.size(0)
        n_correct += outputs.argmax(dim=1).eq(labels).sum().item()

        progress.set_postfix({'loss': f'{batch_loss:.4f}', 'acc': f'{100*n_correct/n_seen:.1f}%'})

    return loss_sum / len(loader), n_correct / n_seen

@torch.no_grad()
def validate(model, loader, criterion, mel_extractor, device):
    """Evaluate ``model`` on ``loader`` and return a metrics dict.

    Keys: 'loss' (mean batch loss), 'acc', 'f1', 'precision', 'recall',
    'fpr' = FP / (FP + TN), 'fnr' = FN / (FN + TP) as Python floats; 'cm', a
    2x2 confusion matrix (rows: true 0/1, columns: predicted 0/1); and the
    per-sample positive-class 'probs' and true 'labels' arrays.

    The confusion matrix is always built over labels [0, 1]. Without that,
    sklearn returns a 1x1 matrix when only one class appears in a split, the
    FPR/FNR would silently collapse to 0, and ``cm[1, 1]`` indexing would
    fail. Reads the global ``config`` for ``mixed_precision``.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    all_probs = []

    for audio, labels in tqdm(loader, desc='Validating', leave=False):
        audio = audio.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.autocast(device_type='cuda', enabled=config.mixed_precision):
            features = mel_extractor(audio)
            features = features.to(memory_format=torch.channels_last)
            outputs = model(features)
            loss = criterion(outputs, labels)

        total_loss += loss.item()
        # Softmax in float32: half-precision logits lose resolution near 0/1,
        # which matters for the ROC / threshold analysis of 'probs'.
        logits = outputs.float()
        probs = F.softmax(logits, dim=1)
        predicted = logits.argmax(dim=1)

        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs[:, 1].cpu().numpy())

    # Calculate metrics
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_probs = np.array(all_probs)

    acc = float((all_preds == all_labels).mean())
    f1 = float(f1_score(all_labels, all_preds, zero_division=0))
    precision = float(precision_score(all_labels, all_preds, zero_division=0))
    recall = float(recall_score(all_labels, all_preds, zero_division=0))
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    # FPR, FNR (cm is guaranteed 2x2 by labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    fpr = float(fp / (fp + tn)) if (fp + tn) > 0 else 0.0
    fnr = float(fn / (fn + tp)) if (fn + tp) > 0 else 0.0

    return {
        'loss': total_loss / len(loader),
        'acc': acc, 'f1': f1, 'precision': precision, 'recall': recall,
        'fpr': fpr, 'fnr': fnr, 'cm': cm, 'probs': all_probs, 'labels': all_labels
    }

print("‚úÖ Training functions ready")

## 1️⃣4️⃣ Training Execution

In [None]:
import shutil

# Training state
best_f1 = -1.0  # below any real F1, so the first validated epoch always writes a checkpoint
best_epoch = 0
epochs_without_improvement = 0
history = {'train_loss': [], 'val_loss': [], 'val_f1': [], 'val_fpr': [], 'val_fnr': []}

# Checkpointing
checkpoint_dir = Path(config.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

print(f"üöÄ Starting training for {config.epochs} epochs...")
print(f"   Batch size: {config.batch_size}, LR: {config.learning_rate}")
print("=" * 60)

start_time = time.time()

for epoch in range(config.epochs):
    epoch_start = time.time()

    # Train
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, scheduler,
        scaler, mel_extractor, spec_augment, device
    )

    # Validate with the EMA weights when EMA is enabled, and keep them loaded
    # while checkpointing, so the saved weights are exactly the ones whose F1
    # was measured. try/finally puts the live training weights back even if
    # validation or saving raises.
    if ema:
        ema.apply_shadow()
    try:
        val_metrics = validate(model, val_loader, criterion, mel_extractor, device)

        is_best = val_metrics['f1'] > best_f1
        if is_best:
            best_f1 = val_metrics['f1']
            best_epoch = epoch
            epochs_without_improvement = 0

            # torch.compile keeps the plain module under `_orig_mod`; save its
            # keys so the checkpoint loads into an uncompiled model too.
            weights_owner = model._orig_mod if hasattr(model, '_orig_mod') else model
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': weights_owner.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Scalar metrics only, as plain floats: numpy arrays/scalars
                # cannot be read by torch.load(weights_only=True), the default
                # since PyTorch 2.6.
                'val_metrics': {
                    k: float(v) for k, v in val_metrics.items()
                    if k not in ('cm', 'probs', 'labels')
                },
                'config': asdict(config),
                'ema_weights': ema is not None,
            }
            torch.save(checkpoint, checkpoint_dir / 'best_model.pt')
        else:
            epochs_without_improvement += 1
    finally:
        if ema:
            ema.restore()

    improved = "‚úÖ NEW BEST" if is_best else ""

    # Update history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_metrics['loss'])
    history['val_f1'].append(val_metrics['f1'])
    history['val_fpr'].append(val_metrics['fpr'])
    history['val_fnr'].append(val_metrics['fnr'])

    epoch_time = time.time() - epoch_start
    lr = scheduler.get_last_lr()[0]

    print(f"Epoch {epoch+1:3d}/{config.epochs} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_metrics['loss']:.4f} | "
          f"F1: {val_metrics['f1']:.4f} | "
          f"FPR: {val_metrics['fpr']:.4f} | "
          f"FNR: {val_metrics['fnr']:.4f} | "
          f"LR: {lr:.6f} | "
          f"{epoch_time:.1f}s {improved}")

    # Early stopping
    if epochs_without_improvement >= config.early_stopping_patience:
        print(f"\n‚èπÔ∏è Early stopping at epoch {epoch+1}")
        break

total_time = time.time() - start_time
print("=" * 60)
print(f"‚úÖ Training complete in {total_time/3600:.2f} hours")
print(f"üèÜ Best F1: {best_f1:.4f} at epoch {best_epoch+1}")

# Copy best model to Drive
if config.save_to_drive:
    drive_dir = Path(config.drive_checkpoint_dir)
    drive_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy(checkpoint_dir / 'best_model.pt', drive_dir / 'best_model.pt')
    print(f"üíæ Model saved to Drive: {drive_dir / 'best_model.pt'}")

## 1️⃣5️⃣ Evaluation on Test Set

In [None]:
# Load best model.
# weights_only=False: this checkpoint is written by this notebook (trusted) and
# may hold numpy values, which the weights_only=True default of
# PyTorch >= 2.6 refuses to unpickle. Never use this on untrusted files.
checkpoint = torch.load(checkpoint_dir / 'best_model.pt', map_location=device, weights_only=False)
# torch.compile keeps the plain module under `_orig_mod`; the checkpoint
# stores that module's keys.
weights_target = model._orig_mod if hasattr(model, '_orig_mod') else model
weights_target.load_state_dict(checkpoint['model_state_dict'])

# Evaluate on test set
print("üìä Evaluating on test set...")
test_metrics = validate(model, test_loader, criterion, mel_extractor, device)

print("\n" + "=" * 60)
print("üìà TEST SET RESULTS")
print("=" * 60)
print(f"  Accuracy:  {test_metrics['acc']*100:.2f}%")
print(f"  F1 Score:  {test_metrics['f1']:.4f}")
print(f"  Precision: {test_metrics['precision']:.4f}")
print(f"  Recall:    {test_metrics['recall']:.4f}")
print(f"  FPR:       {test_metrics['fpr']:.4f}")
print(f"  FNR:       {test_metrics['fnr']:.4f}")
print("\nConfusion Matrix:")
print(f"  TN={test_metrics['cm'][0,0]:5d}  FP={test_metrics['cm'][0,1]:5d}")
print(f"  FN={test_metrics['cm'][1,0]:5d}  TP={test_metrics['cm'][1,1]:5d}")

## 1️⃣6️⃣ Training Visualization

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve

def _finish_axis(ax, xlabel, ylabel, title):
    """Apply the labels, title, legend and grid shared by every panel."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

# 2x2 panel grid, unpacked row by row.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax_loss, ax_f1, ax_roc, ax_rates = axes.ravel()

# Loss curves
ax_loss.plot(history['train_loss'], label='Train')
ax_loss.plot(history['val_loss'], label='Validation')
_finish_axis(ax_loss, 'Epoch', 'Loss', 'Training & Validation Loss')

# Validation F1, with the best value marked
ax_f1.plot(history['val_f1'], color='green')
ax_f1.axhline(y=best_f1, color='r', linestyle='--', label=f'Best: {best_f1:.4f}')
_finish_axis(ax_f1, 'Epoch', 'F1 Score', 'Validation F1 Score')

# ROC curve on the test set
fpr_curve, tpr_curve, _ = roc_curve(test_metrics['labels'], test_metrics['probs'])
roc_auc = auc(fpr_curve, tpr_curve)
ax_roc.plot(fpr_curve, tpr_curve, label=f'ROC (AUC = {roc_auc:.4f})')
ax_roc.plot([0, 1], [0, 1], 'k--')
_finish_axis(ax_roc, 'False Positive Rate', 'True Positive Rate', 'ROC Curve')

# FPR / FNR per epoch
ax_rates.plot(history['val_fpr'], label='FPR', color='red')
ax_rates.plot(history['val_fnr'], label='FNR', color='blue')
_finish_axis(ax_rates, 'Epoch', 'Rate', 'FPR & FNR Over Training')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'training_results.png', dpi=150)
plt.show()

print(f"üìä Results saved to {OUTPUT_DIR / 'training_results.png'}")

## 1️⃣7️⃣ Export to ONNX

In [None]:
import onnx

import copy
import shutil  # imported here so this cell does not depend on an earlier cell's import

# Prepare model for export: torch.compile keeps the plain nn.Module under
# `_orig_mod`; export that. A CPU deep copy is exported so the live `model`
# keeps its device and mode.
source_model = model._orig_mod if hasattr(model, '_orig_mod') else model
export_model = copy.deepcopy(source_model).to('cpu')
export_model.eval()

# The exported graph takes log-mel features, NOT raw audio, so a deployment
# runtime must reproduce the same front end (n_mels, n_fft, hop_length, dB).
# Frame count 1 + samples // hop matches torchaudio's default centred STFT;
# re-check it if the MelSpectrogram arguments change.
n_samples = int(config.sample_rate * config.audio_duration)
n_frames = n_samples // config.hop_length + 1
dummy_input = torch.randn(1, 1, config.n_mels, n_frames)

# Export
onnx_path = OUTPUT_DIR / 'wakeword_model.onnx'
torch.onnx.export(
    export_model,
    dummy_input,
    str(onnx_path),
    opset_version=14,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)

# Validate
onnx_model = onnx.load(str(onnx_path))
onnx.checker.check_model(onnx_model)

# Size
size_mb = onnx_path.stat().st_size / (1024 * 1024)

print(f"‚úÖ ONNX model exported: {onnx_path}")
print(f"üì¶ Model size: {size_mb:.2f} MB")

# Copy to Drive
if config.save_to_drive:
    drive_dir = Path(config.drive_checkpoint_dir)
    drive_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy(onnx_path, drive_dir / 'wakeword_model.onnx')
    print(f"üíæ ONNX saved to Drive: {drive_dir / 'wakeword_model.onnx'}")

## ✅ Training Complete!

**Outputs saved to:**
- Best checkpoint: `/content/output/checkpoints/best_model.pt`
- ONNX model: `/content/output/wakeword_model.onnx`
- Training plots: `/content/output/training_results.png`

**Copied to Google Drive:**
- `/content/drive/My Drive/OpenWakeWord_Backups/checkpoints/`