From e26445dafe65d24b1dac57bde7d28ad41a161805 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Mon, 11 May 2026 22:18:34 +1200 Subject: [PATCH] feat: extend accuracy report to cover wavekat-zh backend Reference probabilities are now keyed by `(backend, file)` so multiple Smart Turn variants can coexist. `make accuracy` builds with `wavekat-smart-turn` and emits one row per pair, covering upstream Pipecat on both English and Mandarin fixtures plus the WaveKat zh fine-tune on the Mandarin fixtures. All 9 rows pass within the 0.02 tolerance; pipecat-only and no-feature builds remain green. Co-Authored-By: Claude Opus 4.7 (1M context) --- Makefile | 6 +- crates/wavekat-turn/tests/accuracy.rs | 147 +++++++++++++++++++------- scripts/README.md | 13 ++- scripts/gen_reference.py | 99 +++++++++++++---- tests/fixtures/reference.json | 33 ++++++ 5 files changed, 235 insertions(+), 63 deletions(-) diff --git a/Makefile b/Makefile index 2ae6151..c293fa8 100644 --- a/Makefile +++ b/Makefile @@ -21,9 +21,11 @@ check: test: cargo test --workspace -# Cross-validate Rust mel+ONNX pipeline against Python reference probabilities +# Cross-validate Rust mel+ONNX pipeline against Python reference probabilities. +# Builds with `wavekat-smart-turn` so the zh fine-tune rows are also emitted; +# WaveKat weights are fetched from HuggingFace on first run (cached in $HF_HOME). accuracy: - cargo test --features pipecat --test accuracy -- --ignored accuracy_report --nocapture + cargo test --features wavekat-smart-turn --test accuracy -- --ignored accuracy_report --nocapture # Compare Rust vs Python mel spectrograms element-wise (requires .npy fixtures) mel: diff --git a/crates/wavekat-turn/tests/accuracy.rs b/crates/wavekat-turn/tests/accuracy.rs index dbb5efd..b384b2d 100644 --- a/crates/wavekat-turn/tests/accuracy.rs +++ b/crates/wavekat-turn/tests/accuracy.rs @@ -1,7 +1,8 @@ //! Cross-validation accuracy test: Rust pipeline vs. Python reference. //! //! Verifies that our mel preprocessing and ONNX inference produce probabilities -//! within ±0.02 of the Python (Pipecat) reference for each fixture audio clip. +//! within ±0.02 of the Python reference for each fixture audio clip, across +//! every enabled backend. //! //! Prerequisites: //! 1. Run `python scripts/gen_reference.py` once to produce @@ -10,6 +11,10 @@ //! //! Run individual regression tests: `cargo test --features pipecat --test accuracy` //! Run the full report table: `make accuracy` +//! +//! When the `wavekat-smart-turn` feature is enabled, the report additionally +//! exercises the WaveKat zh fine-tune against the `zh_*.wav` fixtures. Weights +//! are downloaded from HuggingFace on first run (cached under `$HF_HOME/hub/`). use std::path::PathBuf; @@ -34,10 +39,19 @@ fn fixtures_dir() -> PathBuf { #[cfg(any(feature = "pipecat"))] #[derive(serde::Deserialize)] struct RefEntry { + /// Which backend produced this reference probability. + /// Defaults to "pipecat" so older `reference.json` files keep working. + #[serde(default = "default_backend")] + backend: String, file: String, probability: f32, } +#[cfg(any(feature = "pipecat"))] +fn default_backend() -> String { + "pipecat".to_string() +} + #[cfg(any(feature = "pipecat"))] fn load_reference() -> Vec { let path = fixtures_dir().join("reference.json"); @@ -50,6 +64,11 @@ fn load_reference() -> Vec { serde_json::from_str(&json).expect("invalid reference.json") } +#[cfg(any(feature = "pipecat"))] +fn entries_for<'a>(entries: &'a [RefEntry], backend: &str) -> Vec<&'a RefEntry> { + entries.iter().filter(|e| e.backend == backend).collect() +} + // --------------------------------------------------------------------------- // Report row — one entry per (backend, clip) // --------------------------------------------------------------------------- @@ -75,44 +94,57 @@ impl Row { } } +// --------------------------------------------------------------------------- +// Shared audio helpers used by backend modules +// --------------------------------------------------------------------------- + +#[cfg(feature = "pipecat")] +fn load_wav_f32(path: &std::path::Path) -> Vec { + let mut reader = hound::WavReader::open(path) + .unwrap_or_else(|e| panic!("failed to open {}: {}", path.display(), e)); + let spec = reader.spec(); + assert_eq!(spec.sample_rate, 16_000, "expected 16 kHz"); + assert_eq!(spec.channels, 1, "expected mono"); + match spec.sample_format { + hound::SampleFormat::Int => reader + .samples::() + .map(|s| s.unwrap() as f32 / 32768.0) // match soundfile's normalization + .collect(), + hound::SampleFormat::Float => reader.samples::().map(|s| s.unwrap()).collect(), + } +} + +#[cfg(feature = "pipecat")] +fn raw_prob(pred: &wavekat_turn::TurnPrediction) -> f32 { + use wavekat_turn::TurnState; + match pred.state { + TurnState::Finished => pred.confidence, + TurnState::Unfinished => 1.0 - pred.confidence, + TurnState::Wait => unreachable!(), + } +} + // --------------------------------------------------------------------------- // Pipecat backend // --------------------------------------------------------------------------- #[cfg(feature = "pipecat")] mod pipecat { - use std::path::Path; - use wavekat_turn::audio::PipecatSmartTurn; - use wavekat_turn::{AudioFrame, AudioTurnDetector, TurnPrediction, TurnState}; - - use super::{fixtures_dir, RefEntry, Row, TOLERANCE}; - - fn load_wav_f32(path: &Path) -> Vec { - let mut reader = hound::WavReader::open(path) - .unwrap_or_else(|e| panic!("failed to open {}: {}", path.display(), e)); - let spec = reader.spec(); - assert_eq!(spec.sample_rate, 16_000, "expected 16 kHz"); - assert_eq!(spec.channels, 1, "expected mono"); - match spec.sample_format { - hound::SampleFormat::Int => reader - .samples::() - .map(|s| s.unwrap() as f32 / 32768.0) // match soundfile's normalization - .collect(), - hound::SampleFormat::Float => reader.samples::().map(|s| s.unwrap()).collect(), - } - } + use wavekat_turn::{AudioFrame, AudioTurnDetector}; + + use super::{entries_for, fixtures_dir, load_wav_f32, raw_prob, RefEntry, Row, TOLERANCE}; fn reference_prob(entries: &[RefEntry], name: &str) -> f32 { entries .iter() - .find(|e| e.file == name) - .unwrap_or_else(|| panic!("no entry for '{}' in reference.json", name)) + .find(|e| e.backend == "pipecat" && e.file == name) + .unwrap_or_else(|| panic!("no pipecat entry for '{}' in reference.json", name)) .probability } pub(super) fn rows(entries: &[RefEntry]) -> Vec { - entries + entries_for(entries, "pipecat") .iter() .map(|entry| { let samples = load_wav_f32(&fixtures_dir().join(&entry.file)); @@ -132,18 +164,11 @@ mod pipecat { .collect() } - fn raw_prob(pred: &TurnPrediction) -> f32 { - match pred.state { - TurnState::Finished => pred.confidence, - TurnState::Unfinished => 1.0 - pred.confidence, - TurnState::Wait => unreachable!(), - } - } - pub(super) fn run_regression(clip: &str) { let entries = super::load_reference(); let python_prob = reference_prob(&entries, clip); let row = rows(&[RefEntry { + backend: "pipecat".to_string(), file: clip.to_string(), probability: python_prob, }]) @@ -173,12 +198,53 @@ mod pipecat { } } -// Add future audio backends here: +// --------------------------------------------------------------------------- +// WaveKat zh backend (Smart Turn fine-tune) +// --------------------------------------------------------------------------- // -// #[cfg(feature = "livekit-audio")] -// mod livekit_audio { -// pub(super) fn rows(entries: &[super::RefEntry]) -> Vec { ... } -// } +// Loads `wavekat/smart-turn-ONNX` (zh) from HuggingFace on first run. Subsequent +// runs hit the HF cache under `$HF_HOME/hub/`. The shared mel/inference pipeline +// is identical to upstream Pipecat — only the weights differ — so reusing the +// pipecat helpers is intentional. + +#[cfg(feature = "wavekat-smart-turn")] +mod wavekat { + use wavekat_turn::audio::{PipecatSmartTurn, SmartTurnLang, SmartTurnVariant}; + use wavekat_turn::{AudioFrame, AudioTurnDetector}; + + use super::{entries_for, fixtures_dir, load_wav_f32, raw_prob, RefEntry, Row}; + + pub(super) fn rows(entries: &[RefEntry]) -> Vec { + let backend_entries = entries_for(entries, "wavekat-zh"); + if backend_entries.is_empty() { + return Vec::new(); + } + + // Load once, score every clip — the HF download is the slowest step. + let mut detector = + PipecatSmartTurn::with_variant(SmartTurnVariant::Wavekat(SmartTurnLang::Zh)) + .expect("failed to load wavekat zh model from HuggingFace"); + + backend_entries + .iter() + .map(|entry| { + detector.reset(); + let samples = load_wav_f32(&fixtures_dir().join(&entry.file)); + for chunk in samples.chunks(1600) { + detector.push_audio(&AudioFrame::new(chunk, 16_000)); + } + let pred = detector.predict().expect("predict failed"); + let rust_prob = raw_prob(&pred); + Row { + backend: "wavekat-zh", + clip: entry.file.clone(), + python_prob: entry.probability, + rust_prob, + } + }) + .collect() + } +} // --------------------------------------------------------------------------- // Accuracy report — prints a markdown table covering all enabled backends @@ -194,7 +260,12 @@ fn accuracy_report() { #[allow(unused_mut)] let mut r = Vec::new(); #[cfg(feature = "pipecat")] - r.extend(pipecat::rows(&load_reference())); + { + let entries = load_reference(); + r.extend(pipecat::rows(&entries)); + #[cfg(feature = "wavekat-smart-turn")] + r.extend(wavekat::rows(&entries)); + } r }; diff --git a/scripts/README.md b/scripts/README.md index b3774e5..186cbf5 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -28,6 +28,17 @@ scripts/.venv/bin/python3 scripts/gen_reference.py | File | Description | |------|-------------| | `tests/fixtures/silence_2s.wav` | 2 s of zeros at 16 kHz (generated if missing) | -| `tests/fixtures/reference.json` | P(complete) for each fixture clip | +| `tests/fixtures/reference.json` | P(complete) for each `(backend, clip)` pair | + +Each entry in `reference.json` is keyed by `backend` so multiple Smart Turn +variants can coexist: + +- `pipecat` — upstream Pipecat Smart Turn v3 on the English fixtures, plus the + same model run on the Mandarin fixtures as a cross-lingual baseline. +- `wavekat-zh` — WaveKat zh fine-tune of Smart Turn, only run on the + Mandarin fixtures (`zh_*.wav`). + +The Rust accuracy test (`make accuracy`) filters rows by backend at compile +time — feature `wavekat-smart-turn` enables the second group. Commit both files after re-running. diff --git a/scripts/gen_reference.py b/scripts/gen_reference.py index 0a1d1e1..2a6985b 100644 --- a/scripts/gen_reference.py +++ b/scripts/gen_reference.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 -"""Generate reference probabilities from the Pipecat Python pipeline. +"""Generate reference probabilities from the Python pipelines. Outputs tests/fixtures/reference.json for use in the Rust accuracy test. +Each entry is keyed by ``(backend, file)``; the Rust test filters by enabled +backend at compile time. Usage: pip install transformers onnxruntime numpy soundfile @@ -9,9 +11,16 @@ Re-run when: - A fixture WAV changes - - The model version changes (bump MODEL_VERSION in build.rs at the same time) + - A model version changes (bump MODEL_VERSION constants below + build.rs) -Speech fixture source: +Backends covered: + - ``pipecat`` — upstream Pipecat Smart Turn v3.2-cpu, scored on + silence_2s / speech_finished / speech_mid (English). + - ``wavekat-zh`` — WaveKat zh fine-tune of Smart Turn, scored on + zh_speech_finished / zh_speech_finished_short / + zh_speech_mid (Mandarin). + +Speech fixture source (English): speech_finished.wav and speech_mid.wav are original recordings of: "Wavekat knows when you've finished speaking." recorded at 16 kHz mono 16-bit PCM. @@ -34,28 +43,57 @@ FIXTURES = REPO_ROOT / "tests" / "fixtures" SCRIPTS = REPO_ROOT / "scripts" -MODEL_URL = "https://huggingface.co/pipecat-ai/smart-turn-v3/resolve/main/smart-turn-v3.2-cpu.onnx" -MODEL_VERSION = "v3.2-cpu" -MODEL_CACHE = SCRIPTS / f"smart-turn-{MODEL_VERSION}.onnx" +PIPECAT_MODEL_URL = ( + "https://huggingface.co/pipecat-ai/smart-turn-v3/resolve/main/smart-turn-v3.2-cpu.onnx" +) +PIPECAT_MODEL_VERSION = "v3.2-cpu" +PIPECAT_MODEL_CACHE = SCRIPTS / f"smart-turn-{PIPECAT_MODEL_VERSION}.onnx" + +WAVEKAT_ZH_MODEL_URL = ( + "https://huggingface.co/wavekat/smart-turn-ONNX/resolve/main/zh/smart-turn-cpu.onnx" +) +WAVEKAT_ZH_MODEL_VERSION = "wavekat-zh-cpu" +WAVEKAT_ZH_MODEL_CACHE = SCRIPTS / f"smart-turn-{WAVEKAT_ZH_MODEL_VERSION}.onnx" SAMPLE_RATE = 16_000 BUFFER_SAMPLES = 128_000 # 8 seconds at 16 kHz (matches Rust ring buffer) -CLIPS = ["silence_2s.wav", "speech_finished.wav", "speech_mid.wav"] +# (backend, clip) — drives both the Python pipeline and the entry list written +# to reference.json. Add new rows here, then re-run the script. +TASKS: list[tuple[str, str]] = [ + # Pipecat upstream on English fixtures. + ("pipecat", "silence_2s.wav"), + ("pipecat", "speech_finished.wav"), + ("pipecat", "speech_mid.wav"), + # Pipecat upstream on Mandarin fixtures (cross-lingual baseline — useful to + # compare against the wavekat-zh fine-tune below). + ("pipecat", "zh_speech_finished.wav"), + ("pipecat", "zh_speech_finished_short.wav"), + ("pipecat", "zh_speech_mid.wav"), + # WaveKat zh fine-tune on Mandarin fixtures. + ("wavekat-zh", "zh_speech_finished.wav"), + ("wavekat-zh", "zh_speech_finished_short.wav"), + ("wavekat-zh", "zh_speech_mid.wav"), +] + +BACKEND_MODELS: dict[str, tuple[str, Path]] = { + "pipecat": (PIPECAT_MODEL_URL, PIPECAT_MODEL_CACHE), + "wavekat-zh": (WAVEKAT_ZH_MODEL_URL, WAVEKAT_ZH_MODEL_CACHE), +} # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def ensure_model() -> Path: - if MODEL_CACHE.exists(): - return MODEL_CACHE - print(f"Downloading model from {MODEL_URL} ...", flush=True) +def ensure_model(url: str, cache: Path) -> Path: + if cache.exists(): + return cache + print(f"Downloading model from {url} ...", flush=True) SCRIPTS.mkdir(parents=True, exist_ok=True) - urllib.request.urlretrieve(MODEL_URL, MODEL_CACHE) - print(f"Saved to {MODEL_CACHE}", flush=True) - return MODEL_CACHE + urllib.request.urlretrieve(url, cache) + print(f"Saved to {cache}", flush=True) + return cache def ensure_silence() -> None: @@ -80,7 +118,10 @@ def load_audio(path: Path) -> np.ndarray: def infer(audio: np.ndarray, session, extractor) -> tuple[float, np.ndarray]: - """Run the Pipecat pipeline on audio. + """Run a Smart Turn pipeline on audio. + + All Smart Turn variants share the same feature extractor and tensor I/O, + so a single helper works for both pipecat and wavekat models. Returns: (probability, mel_tensor) where mel_tensor has shape [80, 800]. @@ -106,22 +147,36 @@ def main() -> None: sys.exit(1) ensure_silence() - model_path = ensure_model() extractor = WhisperFeatureExtractor(chunk_length=8) - session = ort.InferenceSession(str(model_path)) + + # One ORT session per backend, reused across that backend's clips. + sessions: dict[str, "ort.InferenceSession"] = {} + for backend, (url, cache) in BACKEND_MODELS.items(): + if any(b == backend for b, _ in TASKS): + model_path = ensure_model(url, cache) + sessions[backend] = ort.InferenceSession(str(model_path)) results = [] - for name in CLIPS: + for backend, name in TASKS: path = FIXTURES / name if not path.exists(): print(f"ERROR: missing fixture {path}", file=sys.stderr) sys.exit(1) audio = load_audio(path) - prob, mel = infer(audio, session, extractor) - np.save(str(FIXTURES / f"{name}.mel.npy"), mel) - print(f" {name}: probability = {prob:.4f}") - results.append({"file": name, "probability": round(prob, 6)}) + prob, mel = infer(audio, sessions[backend], extractor) + # Only save mel fixtures for the pipecat backend; the mel preprocessing + # is identical across Smart Turn variants, so one set is enough. + if backend == "pipecat": + np.save(str(FIXTURES / f"{name}.mel.npy"), mel) + print(f" [{backend}] {name}: probability = {prob:.4f}") + results.append( + { + "backend": backend, + "file": name, + "probability": round(prob, 6), + } + ) out_path = FIXTURES / "reference.json" with open(out_path, "w") as f: diff --git a/tests/fixtures/reference.json b/tests/fixtures/reference.json index dfdf721..ce7965c 100644 --- a/tests/fixtures/reference.json +++ b/tests/fixtures/reference.json @@ -1,14 +1,47 @@ [ { + "backend": "pipecat", "file": "silence_2s.wav", "probability": 0.987037 }, { + "backend": "pipecat", "file": "speech_finished.wav", "probability": 0.984858 }, { + "backend": "pipecat", "file": "speech_mid.wav", "probability": 0.047724 + }, + { + "backend": "pipecat", + "file": "zh_speech_finished.wav", + "probability": 0.986523 + }, + { + "backend": "pipecat", + "file": "zh_speech_finished_short.wav", + "probability": 0.98232 + }, + { + "backend": "pipecat", + "file": "zh_speech_mid.wav", + "probability": 0.071746 + }, + { + "backend": "wavekat-zh", + "file": "zh_speech_finished.wav", + "probability": 0.766031 + }, + { + "backend": "wavekat-zh", + "file": "zh_speech_finished_short.wav", + "probability": 0.875143 + }, + { + "backend": "wavekat-zh", + "file": "zh_speech_mid.wav", + "probability": 0.164025 } ]