
# 🎤🔁 StyleTTS2 + Whisper — End‑to‑End Notebook

This notebook provides **everything in one place**:
- **TTS**: StyleTTS2 synthesis with optional style/voice cloning
- **STT**: Whisper transcription (from WAV or microphone)
- **Evaluate**: WER/CER with `jiwer`
- **Demo** and **Batch** runs

> Run the setup cell first. The first TTS and STT runs download the model weights, so they take noticeably longer than later runs.


In [1]:

# === 1) Setup & Installs ===
# NOTE: Whisper speech-to-text is published on PyPI as `openai-whisper`. The
# package named plain `whisper` is unrelated and has no `whisper.load_model`
# (the STT cell below refuses to use it).
# Remove the leading '#' on the next line to install dependencies in this environment.
# %pip install styletts2 openai-whisper jiwer sounddevice soundfile numpy
print("If you haven't installed deps yet, un-comment the %pip line above and run this cell again.")


If you haven't installed deps yet, un-comment the %pip line above and run this cell again.



## 2) TTS (StyleTTS2)
Helper to synthesize text to WAV, optionally cloning a short reference voice sample.


In [7]:

# === Patched TTS helper (handles PyTorch 2.6+ unpickling) ===
from typing import Optional
from pathlib import Path

# --- Patch torch.load BEFORE styletts2 is imported ---
import functools

import torch
from torch import serialization as ts

# 1) Allowlist globals older checkpoints sometimes use.
#    Best-effort: `add_safe_globals` does not exist on older PyTorch releases.
try:
    ts.add_safe_globals([getattr])
except Exception:
    pass  # fine if not supported

# 2) Force torch.load default back to weights_only=False.
#    SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
#    objects, and this patch applies to EVERY torch.load call in the kernel.
#    Only load checkpoints from sources you trust.
#    The guard keeps the patch idempotent. Without it, re-running this cell would
#    capture the already-patched function as `_real_load` and stack one more
#    wrapper per run.
if not getattr(torch.load, "_weights_only_default_patched", False):
    _real_load = torch.load

    @functools.wraps(_real_load)
    def _patched_load(*args, **kwargs):
        """Call the original torch.load, defaulting ``weights_only`` to False.

        An explicit ``weights_only=...`` from the caller is respected.
        """
        kwargs.setdefault("weights_only", False)
        return _real_load(*args, **kwargs)

    # Marker checked above so a second run of this cell is a no-op.
    _patched_load._weights_only_default_patched = True
    torch.load = _patched_load

def _lazy_import_styletts2():
    """
    Import ``styletts2.tts`` on demand and return that module.

    The import is deferred to call time, so the package is only loaded when
    synthesis is actually requested. styletts2 calls ``torch.load``
    internally, and the ``torch.load`` patch set up earlier in this cell is
    what lets those loads succeed.
    """
    from styletts2 import tts as styletts2_tts_module  # type: ignore

    return styletts2_tts_module

# Shared StyleTTS2 engine for this kernel session. Constructing StyleTTS2()
# loads all of its model checkpoints (through torch.load), which is slow, so
# the engine is built once and reused by every synthesis call.
_TTS_ENGINE_CACHE: dict = {}


def _get_tts_engine():
    """Return the shared StyleTTS2 engine, constructing it on first use."""
    engine = _TTS_ENGINE_CACHE.get("engine")
    if engine is None:
        engine = _lazy_import_styletts2().StyleTTS2()
        _TTS_ENGINE_CACHE["engine"] = engine
    return engine


def text_to_speech(
    text: str,
    output_wav: str = "tts_output.wav",
    target_voice_wav: Optional[str] = None,
    sample_rate: int = 24000,
    alpha: float = 0.3,
    beta: float = 0.7,
    diffusion_steps: int = 5,
    embedding_scale: int = 1,
) -> str:
    """
    Synthesize speech from text using StyleTTS2.

    Args:
        text: Text to synthesize. Must contain at least one non-whitespace character.
        output_wav: Output WAV path. Missing parent directories are created.
        target_voice_wav: Optional short reference WAV to clone style/voice.
        sample_rate: Output sample rate (Hz).
        alpha, beta, diffusion_steps, embedding_scale: StyleTTS2 inference knobs,
            forwarded unchanged to ``StyleTTS2.inference``.

    Returns:
        Absolute path to the output wav.

    Raises:
        ValueError: if ``text`` is empty or whitespace-only.
    """
    if not text or not text.strip():
        raise ValueError("text must be a non-empty string")

    # The engine writes the file itself; make sure its directory exists.
    Path(output_wav).parent.mkdir(parents=True, exist_ok=True)

    engine = _get_tts_engine()
    engine.inference(
        text=text,
        target_voice_path=target_voice_wav,
        output_wav_file=output_wav,
        output_sample_rate=sample_rate,
        alpha=alpha,
        beta=beta,
        diffusion_steps=diffusion_steps,
        embedding_scale=embedding_scale,
    )
    return str(Path(output_wav).resolve())
print("Patched TTS helper ready (torch.load -> weights_only=False).")



Patched TTS helper ready (torch.load -> weights_only=False).



## 3) STT (Whisper)
Transcribe WAV files or capture from microphone and then transcribe.


In [11]:

# === STT (Whisper) — validated import + same API ===
from typing import Optional
from pathlib import Path
import tempfile
import importlib
import types

def _import_openai_whisper() -> types.ModuleType:
    """
    Load the ``whisper`` module and confirm it really is OpenAI's openai-whisper.

    Returns:
        The imported ``whisper`` module.

    Raises:
        RuntimeError: if ``whisper`` cannot be imported at all, or if the
            importable ``whisper`` has no ``load_model`` (an unrelated package
            that shares the name). The message explains how to fix it.
    """
    try:
        module = importlib.import_module("whisper")
    except Exception as exc:
        raise RuntimeError(
            "Failed to import OpenAI Whisper. Install it with:\n"
            "  %pip install -U openai-whisper"
        ) from exc

    # openai-whisper is identified by its `load_model` entry point.
    if hasattr(module, "load_model"):
        return module

    raise RuntimeError(
        "A different 'whisper' package is installed. Fix with:\n"
        "  %pip uninstall -y whisper\n"
        "  %pip install -U openai-whisper\n"
        "Then restart the kernel."
    )

def _lazy_import_audio():
    """
    Import the microphone and WAV I/O backends only when they are needed.

    Returns:
        tuple: ``(sounddevice, soundfile)``, in that order.
    """
    import sounddevice as _sounddevice  # type: ignore
    import soundfile as _soundfile      # type: ignore

    backends = (_sounddevice, _soundfile)
    return backends

# Whisper models keyed by size name ("base", "small", ...). Loading a model is
# expensive, and evaluate_tts_stt / batch_eval call this function once per
# utterance, so each size is loaded at most once per kernel session.
# NOTE: every distinct size requested stays resident in memory.
_WHISPER_MODEL_CACHE: dict = {}


def speech_to_text_from_file(audio_path: str, model_size: str = "base", language: Optional[str] = None) -> str:
    """
    Transcribe an audio file with OpenAI Whisper.

    Args:
        audio_path: Path to the audio file (a ``pathlib.Path`` is also accepted
            and converted to ``str``). Non-path inputs are passed to Whisper
            unchanged.
        model_size: Whisper model name passed to ``whisper.load_model``.
        language: Optional language code; ``None`` lets Whisper detect it.

    Returns:
        The transcription text with surrounding whitespace stripped.

    Raises:
        FileNotFoundError: if ``audio_path`` does not name an existing file.
        RuntimeError: if OpenAI Whisper cannot be imported.
    """
    if isinstance(audio_path, (str, Path)):
        # Fail fast with a clear, catchable error instead of a later failure
        # inside Whisper's audio loading.
        if not Path(audio_path).is_file():
            raise FileNotFoundError(f"Audio file not found: {audio_path}")
        # Whisper's transcribe() documents `audio` as a str path or a waveform.
        audio_path = str(audio_path)

    model = _WHISPER_MODEL_CACHE.get(model_size)
    if model is None:
        whisper = _import_openai_whisper()
        model = whisper.load_model(model_size)
        _WHISPER_MODEL_CACHE[model_size] = model

    result = model.transcribe(audio_path, language=language)
    return result.get("text", "").strip()

def speech_to_text_from_mic(seconds: float = 5.0, samplerate: int = 16000, channels: int = 1,
                            model_size: str = "base", language: Optional[str] = None) -> str:
    """
    Record from the default microphone, then transcribe the clip with Whisper.

    Args:
        seconds: Recording duration in seconds (must be > 0).
        samplerate: Recording sample rate in Hz (must be > 0).
        channels: Number of input channels (must be >= 1).
        model_size: Whisper model name.
        language: Optional Whisper language code; ``None`` auto-detects.

    Returns:
        The transcription text with surrounding whitespace stripped.

    Raises:
        ValueError: if ``seconds``, ``samplerate`` or ``channels`` is out of range.
    """
    if seconds <= 0:
        raise ValueError(f"seconds must be positive, got {seconds!r}")
    if samplerate <= 0:
        raise ValueError(f"samplerate must be positive, got {samplerate!r}")
    if channels < 1:
        raise ValueError(f"channels must be >= 1, got {channels!r}")

    sd, sf = _lazy_import_audio()
    print(f"Recording from microphone for {seconds} seconds...")
    audio = sd.rec(int(seconds * samplerate), samplerate=samplerate, channels=channels, dtype="float32")
    sd.wait()  # block until the recording has finished

    # delete=False + closing the handle first: the WAV is re-opened by name
    # below, which an open NamedTemporaryFile does not allow on Windows.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        sf.write(tmp_path, audio, samplerate)
        return speech_to_text_from_file(tmp_path, model_size=model_size, language=language)
    finally:
        # Remove the temp file even if writing or transcription raised.
        try:
            Path(tmp_path).unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup (e.g. file still locked)

print("STT helpers ready (validated OpenAI Whisper import).")



STT helpers ready (validated OpenAI Whisper import).



## 4) Evaluation (WER/CER)
Helpers that score the Whisper transcription against the original input text using word error rate (WER) and character error rate (CER).


In [9]:

from dataclasses import dataclass

def _lazy_import_metrics():
    """
    Fetch jiwer's scoring functions, importing jiwer only on first use.

    Returns:
        tuple: ``(wer, cer)`` callables from jiwer, in that order.
    """
    from jiwer import wer as word_error_rate, cer as char_error_rate  # type: ignore

    return word_error_rate, char_error_rate

@dataclass
class EvalResult:
    """Outcome of scoring one synthesized clip against the text it was made from."""

    # Text that was synthesized (whitespace-stripped by evaluate_tts_stt).
    reference: str
    # Whisper transcription of the synthesized audio.
    hypothesis: str
    # Word error rate of `hypothesis` against `reference`.
    wer: float
    # Character error rate of `hypothesis` against `reference`.
    cer: float
    # Path of the scored WAV (resolved to an absolute path by evaluate_tts_stt).
    wav_path: str

def _normalize_for_scoring(s: str) -> str:
    """
    Canonicalize text before WER/CER scoring.

    The text is lower-cased and the typographic apostrophe (U+2019) is mapped
    to ``'``. Every other non-word, non-space character becomes a space, and
    runs of whitespace are collapsed, e.g. ``"Hello, World!"`` -> ``"hello world"``.
    """
    import re

    s = s.lower().replace("\u2019", "'")
    s = re.sub(r"[^\w\s']", " ", s)
    return " ".join(s.split())


def evaluate_tts_stt(reference_text: str, tts_out_path: str, whisper_model: str = "base",
                     language: Optional[str] = None, normalize: bool = False) -> "EvalResult":
    """
    Transcribe a synthesized WAV with Whisper and score it against the input text.

    Args:
        reference_text: The text that was synthesized.
        tts_out_path: Path to the synthesized audio.
        whisper_model: Whisper model size used for transcription.
        language: Optional Whisper language code; ``None`` auto-detects.
        normalize: When True, score lower-cased, punctuation-free versions of
            both texts. This stops case/punctuation differences (e.g.
            "Whisper." vs "whisper.") from counting as errors. Default False
            keeps the original raw-text scoring.

    Returns:
        EvalResult holding the stripped reference and hypothesis (never
        normalized, for display), the WER/CER, and the absolute WAV path.
    """
    hyp = speech_to_text_from_file(tts_out_path, model_size=whisper_model, language=language)
    wer, cer = _lazy_import_metrics()
    if normalize:
        ref_scored = _normalize_for_scoring(reference_text)
        hyp_scored = _normalize_for_scoring(hyp)
    else:
        ref_scored, hyp_scored = reference_text, hyp
    return EvalResult(
        reference=reference_text.strip(),
        hypothesis=hyp.strip(),
        wer=wer(ref_scored, hyp_scored),
        cer=cer(ref_scored, hyp_scored),
        wav_path=str(Path(tts_out_path).resolve()),
    )
print("Evaluation helpers ready.")


Evaluation helpers ready.



## 5) Single Demo Run (Robust)
Synthesize → Transcribe → Score. Adds audio playback (if supported) and saves artifacts.


In [15]:

# --- Parameters ---
text = "Hello, this is a TTS and STT demo using StyleTTS2 and Whisper."
out_wav = "tts_eval.wav"

# TTS params
voice_ref = None  # e.g., "path/to/voice_ref.wav" for cloning
sample_rate = 24000
alpha = 0.3
beta = 0.7
diffusion_steps = 5
embedding_scale = 1

# Whisper params
whisper_model = "base"   # tiny, base, small, medium, large
language = None          # e.g., "en" or "de" for better accuracy

# Optional: audio playback in notebook (IPython.display is only importable
# when IPython is installed).
try:
    from IPython.display import Audio, display
    _can_play = True
except ImportError:
    _can_play = False

# Run TTS -> STT -> Metrics
print("Synthesizing...")
try:
    wav_path = text_to_speech(
        text,
        output_wav=out_wav,
        target_voice_wav=voice_ref,
        sample_rate=sample_rate,
        alpha=alpha,
        beta=beta,
        diffusion_steps=diffusion_steps,
        embedding_scale=embedding_scale,
    )
except Exception as e:
    # `from e` records the underlying failure as the explicit cause.
    raise RuntimeError(f"TTS failed: {e}") from e

if _can_play:
    try:
        display(Audio(filename=wav_path, autoplay=False))
    except Exception:
        pass  # playback is a convenience only; never fail the run over it

print("Transcribing & scoring...")
try:
    result = evaluate_tts_stt(text, wav_path, whisper_model=whisper_model, language=language)
except FileNotFoundError:
    raise  # a missing WAV is reported as-is, not wrapped
except Exception as e:
    raise RuntimeError(f"STT/metrics failed: {e}") from e

print("\n=== Results ===")
print(f"WAV:        {result.wav_path}")
print(f"Reference:  {result.reference}")
print(f"Hypothesis: {result.hypothesis}")
print(f"WER: {result.wer:.3f} (lower is better)")
print(f"CER: {result.cer:.3f} (lower is better)")

# Save artifacts
import json
import pathlib

artifacts_dir = pathlib.Path("artifacts")
artifacts_dir.mkdir(parents=True, exist_ok=True)
with open(artifacts_dir / "transcription.txt", "w", encoding="utf-8") as f:
    f.write(result.hypothesis + "\n")
with open(artifacts_dir / "metrics.json", "w", encoding="utf-8") as f:
    json.dump({"wer": result.wer, "cer": result.cer}, f, indent=2)
print("\nSaved: artifacts/transcription.txt, artifacts/metrics.json")


Synthesizing...
Invalid or missing model checkpoint path. Loading default model...
Invalid or missing config path. Loading default config...
Invalid ASR config path. Loading default config...
Invalid ASR model checkpoint path. Loading default model...
Invalid F0 model path. Loading default model...
bert loaded
bert_encoder loaded
predictor loaded
decoder loaded
text_encoder loaded
predictor_encoder loaded
style_encoder loaded
diffusion loaded
text_aligner loaded
pitch_extractor loaded
mpd loaded
msd loaded
wd loaded
Cloning default target voice...
177
hɛlˈoʊ | ðˈɪs ˈɪz ə tˈi tˈi ˈɛs ˈænd ˈɛs tˈi tˈi dˈɛmoʊ jˈuzɪŋ stˈaɪləts ˈænd wˈɪspɚ ‖
hɛlˈoʊ | ðˈɪs ˈɪz ə tˈi tˈi ˈɛs ˈænd ˈɛs tˈi tˈi dˈɛmoʊ jˈuzɪŋ stˈaɪləts ˈænd wˈɪspɚ ‖


Transcribing & scoring...

=== Results ===
WAV:        C:\Users\Timothy\OneDrive\Desktop\New folder\tts_eval.wav
Reference:  Hello, this is a TTS and STT demo using StyleTTS2 and Whisper.
Hypothesis: Hello, this is a TTS and STT demo using stylets and whisper.
WER: 0.167 (lower is better)
CER: 0.097 (lower is better)

Saved: artifacts/transcription.txt, artifacts/metrics.json



## 6) Batch Evaluation (Optional)
Provide a list of sentences. Each one is synthesized, transcribed, and scored, and the mean WER/CER across all sentences is returned.


In [14]:

from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple
import json
import os
import statistics


def batch_eval(
    texts: Sequence[str],
    whisper_model: str = "base",
    language: Optional[str] = None,
    out_dir: str = "batch_outputs",
    tts_kwargs: Optional[Dict[str, object]] = None,
) -> Dict[str, float]:
    """
    Synthesize, transcribe and score each sentence, then report mean WER/CER.

    For the i-th sentence (1-based), ``utt_{i:03d}.wav`` and ``utt_{i:03d}.txt``
    (the transcription) are written to ``out_dir``. ``summary.json`` is written
    there as well.

    Args:
        texts: Sentences to evaluate (each is converted with ``str``).
        whisper_model: Whisper model size used for transcription.
        language: Optional Whisper language code; ``None`` auto-detects.
        out_dir: Output directory (created if missing).
        tts_kwargs: Optional extra keyword arguments forwarded to
            ``text_to_speech`` (e.g. ``{"target_voice_wav": "ref.wav"}``).
            Must not contain ``output_wav``, which is set per utterance.

    Returns:
        ``{"utterances": count, "wer_mean": ..., "cer_mean": ...}``. The means are
        0.0 when ``texts`` is empty.
    """
    out_root = Path(out_dir)
    out_root.mkdir(parents=True, exist_ok=True)
    tts_options = dict(tts_kwargs or {})

    wers, cers = [], []
    for i, t in enumerate(texts, 1):
        stem = f"utt_{i:03d}"
        sentence = str(t)
        wav_path = text_to_speech(sentence, output_wav=str(out_root / f"{stem}.wav"), **tts_options)
        res = evaluate_tts_stt(sentence, str(wav_path), whisper_model=whisper_model, language=language)
        wers.append(res.wer)
        cers.append(res.cer)
        (out_root / f"{stem}.txt").write_text(res.hypothesis + "\n", encoding="utf-8")

    summary = {
        "utterances": len(wers),
        "wer_mean": float(statistics.mean(wers)) if wers else 0.0,
        "cer_mean": float(statistics.mean(cers)) if cers else 0.0,
    }
    with open(out_root / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    return summary

# Example (uncomment to run):
# summary = batch_eval(
#     ["The quick brown fox jumps over the lazy dog.",
#      "StyleTTS2 sounds natural when configured well.",
#      "Whisper transcribes robustly in many languages."],
#     whisper_model="base",
#     language=None
# )
# summary
print("Batch evaluation helper ready.")


Batch evaluation helper ready.
