# VAD Training Pipeline

This notebook provides a complete end-to-end workflow for training a tiny VAD model using knowledge distillation.

## Workflow Steps

1. **Imports & Config** - Load dependencies and configuration
2. **Dataset Discovery & Split** - Scan WAV directory and create train/val/test splits
3. **Preprocessing** - Extract features, generate teacher labels, create chunks
4. **Quick Data Inspection** - Visualize sample chunks and metadata
5. **Training** - Train the tiny VAD model on chunks
6. **Evaluate a Sample WAV** - Test the trained model on a sample audio file
7. **Export ONNX** - Export model to ONNX format for deployment
8. **Final Summary** - Display training results and file locations


## Cell 1: Imports & Config


In [1]:
# Setup Python path to include project root
import sys
from pathlib import Path

# Get project root directory (parent of notebooks directory)
notebook_dir = Path.cwd()
if notebook_dir.name == 'notebooks':
    project_root = notebook_dir.parent
else:
    # If running from project root, use current directory
    project_root = notebook_dir

# Add project root to Python path
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

print(f"Project root: {project_root}")
print(f"Working directory: {Path.cwd()}")

# Standard library imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import json
import random
import os

# Import preprocessing pipeline (this handles sampling logic)
from vad_distill.scripts.run_preprocessing_pipeline import run_preprocessing_pipeline
from preprocessing.chunk_config import CHUNK_SIZE, N_MELS, SAMPLE_RATE

# Import dataset and model
from vad_distill.distill.tiny_vad.dataset import TinyVADChunkDataset
from vad_distill.distill.tiny_vad.train import train_tiny_vad
from vad_distill.distill.tiny_vad.model import TinyVADModel, build_tiny_vad_model
from vad_distill.distill.tiny_vad.export_onnx import export_tiny_vad_onnx

# Import test function
from vad_distill.scripts.test_single_wav import test_single_wav

# Import config loader
from vad_distill.utils.config import load_yaml

# Load configuration (use absolute path based on project_root)
config_path = project_root / "vad_distill" / "configs" / "student_tiny_vad.yaml"
config = load_yaml(str(config_path))
print("Configuration loaded:")
print(json.dumps(config, indent=2))

# Check CUDA availability
print(f"\nCUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")


Project root: h:\Personal Projects\VAD_Training
Working directory: h:\Personal Projects\VAD_Training\notebooks
Configuration loaded:
{
  "chunk_size": 100,
  "stride": 50,
  "n_mels": 80,
  "chunks_dir": "data/chunks",
  "teacher_labels_dir": "data/teacher_labels",
  "data": {
    "sample_list": "data/sample_list.txt",
    "num_samples": 2000
  },
  "batch_size": 32,
  "learning_rate": 0.001,
  "epochs": 100,
  "num_workers": 4,
  "random_seed": 2025,
  "use_tensorboard": true,
  "save_every_n_steps": 1000,
  "validate_every_n_epochs": 5,
  "resume_from": null,
  "model": {
    "n_mels": 80,
    "hidden_dims": [
      32,
      64,
      32
    ],
    "kernel_sizes": [
      3,
      3,
      3
    ],
    "dropout": 0.1
  },
  "paths": {
    "checkpoint_dir": "vad_distill/distill/tiny_vad/checkpoints",
    "onnx_dir": "models/tiny_vad/onnx",
    "logs_dir": "logs"
  }
}

CUDA available: True
CUDA device: NVIDIA GeForce RTX 3080 Ti
CUDA version: 11.8


## Cell 2: Dataset Discovery & Split


In [2]:
# User specifies audio directory path
# For WenetSpeech dataset, use: "H:/wenet_data/audio/train/podcast" or similar
audio_dir = Path("H:/wenet_data/audio/train/podcast")  # CHANGE THIS to your audio directory

# Recursively scan for audio files (WAV and OPUS)
# WenetSpeech uses OPUS format, but we support both
audio_files = []
if audio_dir.exists():
    # Recursive search for audio files
    audio_files.extend(audio_dir.rglob("*.wav"))
    audio_files.extend(audio_dir.rglob("*.WAV"))
    audio_files.extend(audio_dir.rglob("*.opus"))
    audio_files.extend(audio_dir.rglob("*.OPUS"))
    audio_files = sorted(list(set(audio_files)))  # Remove duplicates and sort
else:
    print(f"Warning: Directory not found: {audio_dir}")
    print("Please update audio_dir to point to your audio files directory")

print(f"Found {len(audio_files)} audio files")

# Calculate total duration (approximate)
if len(audio_files) > 0:
    from vad_distill.utils.audio_io import load_wav
    total_duration = 0.0
    sample_size = min(10, len(audio_files))
    successful_samples = 0
    
    for audio_file in audio_files[:sample_size]:
        try:
            waveform = load_wav(str(audio_file), target_sr=SAMPLE_RATE)
            total_duration += len(waveform) / SAMPLE_RATE
            successful_samples += 1
        except Exception as e:
            print(f"Warning: Failed to load {audio_file.name}: {e}")
            pass
    
    if successful_samples > 0:
        avg_duration = total_duration / successful_samples
        estimated_total = avg_duration * len(audio_files)
        print(f"Estimated total duration: {estimated_total:.2f} seconds ({estimated_total/3600:.2f} hours)")
        print(f"Average file duration: {avg_duration:.2f} seconds")
    else:
        print("Warning: Could not load any sample files for duration estimation")
else:
    print("No audio files found. Please check the directory path.")

# Check if we have files before splitting
if len(audio_files) == 0:
    print("\nERROR: No audio files found. Cannot proceed with dataset split.")
    print("Please check the audio_dir path and ensure it contains audio files.")
    train_files = []
    val_files = []
    test_files = []
else:
    # Shuffle with fixed seed
    seed = config.get('random_seed', 2025)
    random.seed(seed)
    audio_files_shuffled = audio_files.copy()
    random.shuffle(audio_files_shuffled)
    
    # Split into train/val/test (70/15/15)
    num_total = len(audio_files_shuffled)
    num_train = int(0.7 * num_total)
    num_val = int(0.15 * num_total)
    num_test = num_total - num_train - num_val
    
    train_files = audio_files_shuffled[:num_train]
    val_files = audio_files_shuffled[num_train:num_train+num_val]
    test_files = audio_files_shuffled[num_train+num_val:]
    
    print(f"\nSplit: Train={len(train_files)}, Val={len(val_files)}, Test={len(test_files)}")

# Create manifest file for preprocessing pipeline (JSONL format)
# The preprocessing pipeline will apply sampling logic based on config
if len(train_files) > 0:
    splits_dir = project_root / "data" / "splits"
    splits_dir.mkdir(parents=True, exist_ok=True)
    
    # Create manifest file for training (JSONL format)
    manifest_path = splits_dir / "train_manifest.jsonl"
    with open(manifest_path, 'w', encoding='utf-8') as f:
        for audio_file in train_files:
            item = {
                "utt_id": audio_file.stem,
                "wav_path": str(audio_file.resolve())
            }
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    
    # Also save split lists for reference (val and test)
    with open(splits_dir / "val_list.txt", "w") as f:
        for audio_file in val_files:
            f.write(f"{audio_file}\n")
    
    with open(splits_dir / "test_list.txt", "w") as f:
        for audio_file in test_files:
            f.write(f"{audio_file}\n")
    
    print(f"\nManifest file created: {manifest_path}")
    print(f"Total files in manifest: {len(train_files)}")
    print(f"Note: Preprocessing will apply sampling logic from config")
    print(f"  - sample_list: {config.get('data', {}).get('sample_list', 'not configured')}")
    print(f"  - num_samples: {config.get('data', {}).get('num_samples', 'not configured')}")
    print("\nSample train files:")
    for audio_file in train_files[:3]:
        print(f"  {audio_file.name}")
else:
    print("\nNo manifest file created (no files found)")


Found 37325 audio files
Estimated total duration: 12694238.80 seconds (3526.18 hours)
Average file duration: 340.10 seconds

Split: Train=26127, Val=5598, Test=5600

Manifest file created: h:\Personal Projects\VAD_Training\data\splits\train_manifest.jsonl
Total files in manifest: 26127
Note: Preprocessing will apply sampling logic from config
  - sample_list: data/sample_list.txt
  - num_samples: 2000

Sample train files:
  X0000027864_330958324.opus
  X0000003112_4617142.opus
  X0000027920_331675656.opus


## Cell 3: Preprocessing


In [None]:
# Run preprocessing pipeline using run_preprocessing_pipeline
# This will automatically apply deterministic sampling based on config.data.sample_list
# The pipeline will only process files from the sampled list

import time

# Setup paths
output_root = project_root / "data"
teacher_model_dir = project_root / "teacher"
manifest_path = project_root / "data" / "splits" / "train_manifest.jsonl"

# Check if manifest exists
if not manifest_path.exists():
    print(f"ERROR: Manifest file not found: {manifest_path}")
    print("Please run Cell 4 first to create the manifest file.")
else:
    print(f"Using manifest file: {manifest_path}")
    print(f"Configuration:")
    print(f"  - sample_list: {config.get('data', {}).get('sample_list', 'not configured')}")
    print(f"  - num_samples: {config.get('data', {}).get('num_samples', 'not configured')}")
    print()
    
    # Run preprocessing pipeline
    # This will automatically:
    # 1. Load all WAV files from manifest
    # 2. Apply deterministic sampling (if configured)
    # 3. Process only the sampled files
    start_time = time.time()
    
    stats = run_preprocessing_pipeline(
        manifest_path=str(manifest_path),
        output_root=str(output_root),
        teacher_model_dir=str(teacher_model_dir),
        device="cpu",
        config=config,
    )
    
    # Calculate total time
    total_time = time.time() - start_time
    hours = int(total_time // 3600)
    minutes = int((total_time % 3600) // 60)
    
    # Print summary
    print(f"\nPreprocessing complete!")
    print(f"Files processed: {stats['success_count']}")
    print(f"Total chunks: {stats['total_chunks']:,}")
    if hours > 0:
        print(f"Total time: {hours}h {minutes}m")
    else:
        print(f"Total time: {minutes}m")
    print(f"Chunks directory: {stats['chunks_dir']}")


Using manifest file: h:\Personal Projects\VAD_Training\data\splits\train_manifest.jsonl
Configuration:
  - sample_list: data/sample_list.txt
  - num_samples: 2000

Loaded fixed sample list from: H:\Personal Projects\VAD_Training\data\sample_list.txt
Using fixed subset: 2000 files
Notice: ffmpeg is not installed. torchaudio is used to load audio
If you want to use ffmpeg backend to load audio, please install it by:
	sudo apt install ffmpeg # ubuntu
	# brew install ffmpeg # mac


Processing audio files:   6%|â–Œ         | 111/2000 [05:05<1:36:26,  3.06s/it]

In [None]:
# Load a sample chunk (use project_root for absolute paths)
chunks_dir = project_root / "data" / "chunks"
chunk_files = sorted(chunks_dir.glob("chunk_*.npy"))
if len(chunk_files) > 0:
    sample_chunk_path = chunk_files[0]
    chunk_data = np.load(sample_chunk_path, allow_pickle=True).item()
    
    features = chunk_data['features']  # (100, 80)
    labels = chunk_data['labels']      # (100,)
    
    print(f"Chunk shape: features={features.shape}, labels={labels.shape}")
    print(f"Chunk UID: {chunk_data.get('uid', 'unknown')}")
    
    # Visualize spectrogram
    plt.figure(figsize=(12, 6))
    
    plt.subplot(2, 1, 1)
    plt.imshow(features.T, aspect='auto', origin='lower', cmap='viridis')
    plt.colorbar(label='Log-mel energy')
    plt.title('Log-mel Spectrogram')
    plt.xlabel('Frame index')
    plt.ylabel('Mel bin')
    
    # Overlay labels
    plt.subplot(2, 1, 2)
    plt.plot(labels, 'r-', linewidth=2, label='Teacher VAD probability')
    plt.axhline(y=0.5, color='k', linestyle='--', alpha=0.5, label='Threshold')
    plt.xlabel('Frame index')
    plt.ylabel('VAD probability')
    plt.title('Teacher VAD Labels')
    plt.legend()
    plt.ylim([0, 1])
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Print metadata
    metadata_path = chunks_dir / "metadata.json"
    if metadata_path.exists():
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)
        print("\nMetadata.json content:")
        print(json.dumps(metadata, indent=2))
else:
    print("No chunks found. Run Cell 3 first.")


## Cell 5: Training


In [None]:
# Train the model using existing train_tiny_vad function
# This function handles dataset creation, splitting, and training loop internally
train_tiny_vad(config)

# Load training history for plotting (use project_root for absolute paths)
checkpoint_dir_rel = config.get('paths', {}).get('checkpoint_dir', 'vad_distill/distill/tiny_vad/checkpoints')
checkpoint_dir = project_root / checkpoint_dir_rel if not Path(checkpoint_dir_rel).is_absolute() else Path(checkpoint_dir_rel)
history_path = checkpoint_dir / "train_history.json"

if history_path.exists():
    with open(history_path, 'r') as f:
        history = json.load(f)
    
    # Extract training curves
    epochs = [h['epoch'] for h in history]
    train_losses = [h['train_loss'] for h in history]
    val_losses = [h['val_loss'] for h in history if h['val_loss'] is not None]
    val_epochs = [h['epoch'] for h in history if h['val_loss'] is not None]
    
    # Plot training curves
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2)
    if len(val_losses) > 0:
        plt.plot(val_epochs, val_losses, 'r-', label='Val Loss', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Curves')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
    
    print(f"\nTraining complete!")
    print(f"Best checkpoint: {checkpoint_dir / 'best.pt'}")
    print(f"Final train loss: {train_losses[-1]:.6f}")
    if len(val_losses) > 0:
        print(f"Best val loss: {min(val_losses):.6f}")
else:
    print("Training history not found. Check logs for details.")


## Cell 6: Evaluate a Sample WAV


In [None]:
# Select a sample WAV file for evaluation
sample_wav = test_files[0] if len(test_files) > 0 else train_files[0]
print(f"Evaluating: {sample_wav}")

# Load model checkpoint (use project_root for absolute paths)
# Ensure checkpoint_dir is defined (from Cell 5 or define here)
if 'checkpoint_dir' not in globals():
    checkpoint_dir_rel = config.get('paths', {}).get('checkpoint_dir', 'vad_distill/distill/tiny_vad/checkpoints')
    checkpoint_dir = project_root / checkpoint_dir_rel if not Path(checkpoint_dir_rel).is_absolute() else Path(checkpoint_dir_rel)
checkpoint_path = checkpoint_dir / "best.pt"
config_path = str(project_root / "vad_distill" / "configs" / "student_tiny_vad.yaml")
output_dir = project_root / "outputs"
output_dir.mkdir(parents=True, exist_ok=True)

# Run inference using existing test function
test_single_wav(
    wav_path=sample_wav,
    model_path=checkpoint_path,
    output_dir=output_dir,
    config_path=config_path,
    threshold=0.5,
    use_onnx=False,
)

# Load and visualize results
wav_name = sample_wav.stem
scores_path = output_dir / f"{wav_name}_scores.npy"
segments_path = output_dir / f"{wav_name}_segments.json"

if scores_path.exists():
    scores = np.load(scores_path)
    
    # Load audio for waveform display
    from vad_distill.utils.audio_io import load_wav
    wav = load_wav(str(sample_wav), target_sr=SAMPLE_RATE)
    time_axis = np.arange(len(wav)) / SAMPLE_RATE
    frame_time = np.arange(len(scores)) * 0.01  # 10ms per frame
    
    # Plot results
    plt.figure(figsize=(14, 8))
    
    # Waveform
    plt.subplot(3, 1, 1)
    plt.plot(time_axis, wav, 'b-', linewidth=0.5)
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.title('Waveform')
    plt.grid(True, alpha=0.3)
    
    # VAD scores
    plt.subplot(3, 1, 2)
    plt.plot(frame_time, scores, 'r-', linewidth=1.5, label='VAD probability')
    plt.axhline(y=0.5, color='k', linestyle='--', alpha=0.5, label='Threshold')
    plt.xlabel('Time (s)')
    plt.ylabel('VAD probability')
    plt.title('Model VAD Output (Postprocessed)')
    plt.legend()
    plt.ylim([0, 1])
    plt.grid(True, alpha=0.3)
    
    # Segments overlay
    if segments_path.exists():
        with open(segments_path, 'r') as f:
            segments = json.load(f)
        plt.subplot(3, 1, 3)
        for start, end in segments:
            plt.axvspan(start, end, alpha=0.3, color='green')
        plt.plot(frame_time, scores, 'r-', linewidth=1.5)
        plt.xlabel('Time (s)')
        plt.ylabel('VAD probability')
        plt.title('Speech Segments (Green)')
        plt.ylim([0, 1])
        plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\nResults saved to {output_dir}")
else:
    print("Inference output not found.")


## Cell 7: Export ONNX


In [None]:
# Export model to ONNX format
export_tiny_vad_onnx(config, checkpoint_path=None)

# Verify ONNX file exists (use project_root for absolute paths)
onnx_dir_rel = config.get('paths', {}).get('onnx_dir', 'models/tiny_vad/onnx')
onnx_dir = project_root / onnx_dir_rel if not Path(onnx_dir_rel).is_absolute() else Path(onnx_dir_rel)
onnx_path = onnx_dir / "tiny_vad.onnx"

if onnx_path.exists():
    print(f"ONNX export OK: {onnx_path}")
    print(f"File size: {onnx_path.stat().st_size / 1024:.2f} KB")
    
    # Optional: Test ONNX inference
    try:
        import onnxruntime as ort
        import numpy as np
        
        session = ort.InferenceSession(str(onnx_path))
        dummy_input = np.random.randn(1, 100, 80).astype(np.float32)
        outputs = session.run(None, {'mel_features': dummy_input})
        print(f"ONNX inference test passed. Output shape: {outputs[0].shape}")
    except ImportError:
        print("onnxruntime not available, skipping ONNX inference test")
    except Exception as e:
        print(f"ONNX inference test failed: {e}")
else:
    print(f"ONNX export failed. Check logs for details.")


## Cell 8: Final Summary


In [None]:
# Print final summary
print("=" * 60)
print("VAD TRAINING PIPELINE SUMMARY")
print("=" * 60)

# Ensure checkpoint_dir and onnx_dir are defined
if 'checkpoint_dir' not in globals():
    checkpoint_dir_rel = config.get('paths', {}).get('checkpoint_dir', 'vad_distill/distill/tiny_vad/checkpoints')
    checkpoint_dir = project_root / checkpoint_dir_rel if not Path(checkpoint_dir_rel).is_absolute() else Path(checkpoint_dir_rel)
if 'onnx_dir' not in globals():
    onnx_dir_rel = config.get('paths', {}).get('onnx_dir', 'models/tiny_vad/onnx')
    onnx_dir = project_root / onnx_dir_rel if not Path(onnx_dir_rel).is_absolute() else Path(onnx_dir_rel)
if 'config_path' not in globals():
    config_path = str(project_root / "vad_distill" / "configs" / "student_tiny_vad.yaml")

# Checkpoint location
checkpoint_path = checkpoint_dir / "best.pt"
print(f"\nCheckpoint saved: {checkpoint_path}")
print(f"  Exists: {checkpoint_path.exists()}")

# Chunks location (use project_root for absolute paths)
chunks_dir = project_root / "data" / "chunks"
chunk_files = list(chunks_dir.glob("chunk_*.npy"))
print(f"\nChunks directory: {chunks_dir}")
print(f"  Total chunks: {len(chunk_files)}")
print(f"  Metadata: {chunks_dir / 'metadata.json'}")

# Config used
print(f"\nConfig used: {config_path}")
print(f"  Batch size: {config.get('batch_size', 'N/A')}")
print(f"  Learning rate: {config.get('learning_rate', 'N/A')}")
print(f"  Epochs: {config.get('epochs', 'N/A')}")

# ONNX file
onnx_path = onnx_dir / "tiny_vad.onnx"
print(f"\nONNX file: {onnx_path}")
print(f"  Exists: {onnx_path.exists()}")

# Training history
if history_path.exists():
    with open(history_path, 'r') as f:
        history = json.load(f)
    final_epoch = history[-1]['epoch']
    final_train_loss = history[-1]['train_loss']
    val_losses = [h['val_loss'] for h in history if h['val_loss'] is not None]
    print(f"\nTraining completed:")
    print(f"  Final epoch: {final_epoch}")
    print(f"  Final train loss: {final_train_loss:.6f}")
    if len(val_losses) > 0:
        print(f"  Best val loss: {min(val_losses):.6f}")

print("\n" + "=" * 60)
