In [None]:
!pip install --upgrade --quiet pip
!pip install --upgrade --quiet torch transformers accelerate evaluate jiwer
!pip install datasets[audio]

##  Define SpeculativeWhisperV3 Class (Cross-Version Support)
Extended class for cross-version speculative decoding enabling Whisper Large-V3 to work with incompatible draft models: 
- **Dual Feature Extraction**: Processes audio separately for V3 (128 Mel bins) and draft models (80 Mel bins)
- **Token Remapping**: Maps draft model tokens to V3 vocabulary via text-based decoding/encoding
- **Custom Speculative Generation**: Implements verification logic with remapped tokens
- Achieves 100% token mapping coverage (51,865 tokens) between V3 and Tiny vocabularies

In [None]:
import torch
import time
import librosa
import numpy as np
from typing import List, Union, Optional
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

class SpeculativeWhisperV3:
    """Whisper transcription wrapper with optional speculative decoding.

    A large "main" model produces the final transcription. An optional small
    "draft" model proposes tokens that the main model verifies in a single
    forward pass. When the two checkpoints report different vocabulary sizes,
    the class switches to a custom greedy speculative loop that:

    * extracts audio features separately with each model's own processor,
    * translates token ids between the two vocabularies through their decoded
      text (``token_map``: draft -> main, ``reverse_token_map``: main -> draft).

    Attributes:
        device: Device string the models are placed on.
        dtype: Torch dtype used for weights and input features.
        main_model / main_processor: Verifier model and its processor.
        draft_model / draft_processor: Proposal model and its processor, or
            ``None`` when no draft model was requested.
        use_token_remapping: ``True`` when the vocabularies differ.
        token_map: Draft token id -> main token id.
        reverse_token_map: Main token id -> draft token id.
    """

    def __init__(self,
                 model_id: str = "openai/whisper-large-v3",
                 draft_model_id: Optional[str] = None,
                 device: Optional[str] = None,
                 torch_dtype: Optional[torch.dtype] = None):
        """Load the main model and, optionally, a draft model.

        Args:
            model_id: Hub id or path of the main (verifier) checkpoint.
            draft_model_id: Hub id or path of the draft checkpoint; ``None``
                disables speculative decoding.
            device: Target device; defaults to CUDA when available, else CPU.
            torch_dtype: Weight/feature dtype; defaults to float16 on CUDA and
                float32 otherwise.
        """
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        # `_on_cuda` also recognises indexed devices such as "cuda:0".
        self.dtype = torch_dtype if torch_dtype else (torch.float16 if self._on_cuda() else torch.float32)

        print(f"   [Init] Loading Main:  {model_id}...")
        self.main_model = self._load_model(model_id)
        self.main_processor = AutoProcessor.from_pretrained(model_id)

        self.draft_model = None
        self.draft_processor = None
        self.use_token_remapping = False
        self.token_map = {}
        self.reverse_token_map = {}

        if draft_model_id:
            print(f"   [Init] Loading Draft: {draft_model_id}...")
            self.draft_model = self._load_model(draft_model_id)
            # The draft model gets its own processor: its feature extractor and
            # tokenizer may differ from the main model's.
            self.draft_processor = AutoProcessor.from_pretrained(draft_model_id)

            if self._tokenizers_differ():
                print(f"   [Warning] Tokenizers differ between models.  Enabling token remapping.")
                self.use_token_remapping = True
                self._build_token_mapping()

    def _on_cuda(self) -> bool:
        """Return True when ``self.device`` names a CUDA device ("cuda", "cuda:0", ...)."""
        return str(self.device).startswith("cuda")

    def _sync(self) -> None:
        """Wait for pending CUDA work so wall-clock timings are meaningful."""
        if self._on_cuda():
            torch.cuda.synchronize()

    def _load_model(self, model_id: str):
        """Load a speech seq2seq checkpoint with the shared settings and move it to the device."""
        return AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=self.dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="sdpa",
        ).to(self.device)

    def _tokenizers_differ(self) -> bool:
        """Return True when main and draft models report different vocabulary sizes.

        This is a size-only heuristic: two different vocabularies of the same
        size would not be detected.
        """
        return self.main_model.config.vocab_size != self.draft_model.config.vocab_size

    def _build_token_mapping(self):
        """Build the draft->main and main->draft token id tables.

        Both directions are needed: proposals travel draft -> main for
        verification, and tokens chosen by the main model travel main -> draft
        so the draft decoder keeps seeing ids from its own vocabulary.
        """
        print(f"   [Init] Building token mapping...")

        main_tokenizer = self.main_processor.tokenizer
        draft_tokenizer = self.draft_processor.tokenizer

        self.token_map = self._map_vocabulary(
            draft_tokenizer, main_tokenizer, self.draft_model.config.vocab_size
        )
        self.reverse_token_map = self._map_vocabulary(
            main_tokenizer, draft_tokenizer, self.main_model.config.vocab_size
        )

        print(f"   [Init] Mapped {len(self.token_map)} tokens")

    @staticmethod
    def _map_vocabulary(src_tokenizer, dst_tokenizer, src_vocab_size: int) -> dict:
        """Map every source token id to a destination token id via its text.

        Each source id is decoded to text and re-encoded with the destination
        tokenizer; the first resulting id is used (simple heuristic). Ids that
        cannot be decoded/encoded, or that encode to nothing, fall back to the
        destination tokenizer's ``unk_token_id``.
        """
        fallback = dst_tokenizer.unk_token_id
        mapping = {}
        for src_id in range(src_vocab_size):
            try:
                text = src_tokenizer.decode([src_id], skip_special_tokens=False)
                dst_ids = dst_tokenizer.encode(text, add_special_tokens=False)
            except Exception:
                # Best effort: an id the tokenizer cannot handle must not abort
                # the whole table (but Ctrl-C / SystemExit still propagate).
                dst_ids = []
            mapping[src_id] = dst_ids[0] if dst_ids else fallback
        return mapping

    @staticmethod
    def _remap_ids(ids: torch.Tensor, mapping: dict, fallback) -> torch.Tensor:
        """Translate every id in ``ids`` through ``mapping``, keeping shape, dtype and device."""
        translated = [mapping.get(token_id, fallback) for token_id in ids.flatten().tolist()]
        return torch.tensor(translated, device=ids.device, dtype=ids.dtype).reshape(ids.shape)

    def _remap_draft_tokens(self, draft_ids: torch.Tensor) -> torch.Tensor:
        """Remap draft model token IDs to main model token IDs.

        Returns ``draft_ids`` unchanged when remapping is disabled; ids missing
        from ``token_map`` become the main tokenizer's ``unk_token_id``.
        """
        if not self.use_token_remapping:
            return draft_ids
        return self._remap_ids(draft_ids, self.token_map, self.main_processor.tokenizer.unk_token_id)

    def _remap_main_tokens(self, main_ids: torch.Tensor) -> torch.Tensor:
        """Remap main model token IDs to draft model token IDs (inverse direction).

        Returns ``main_ids`` unchanged when remapping is disabled; ids missing
        from ``reverse_token_map`` become the draft tokenizer's ``unk_token_id``.
        """
        if not self.use_token_remapping:
            return main_ids
        return self._remap_ids(main_ids, self.reverse_token_map, self.draft_processor.tokenizer.unk_token_id)

    def _load_audio(self, audio_paths: Union[List[str], List[np.ndarray]]) -> List[np.ndarray]:
        """Return one waveform per input.

        File paths are loaded with librosa at 16 kHz; numpy arrays are passed
        through untouched (they are assumed to already be 16 kHz waveforms).

        Raises:
            ValueError: If an item is neither a ``str`` nor a ``np.ndarray``.
        """
        audio_arrays = []
        for audio in audio_paths:
            if isinstance(audio, str):
                array, _ = librosa.load(audio, sr=16000)
                audio_arrays.append(array)
            elif isinstance(audio, np.ndarray):
                audio_arrays.append(audio)
            else:
                raise ValueError("Audio must be a file path or numpy array.")
        return audio_arrays

    def _extract_features(self, processor, batch_audio: List[np.ndarray]) -> torch.Tensor:
        """Run ``processor`` on a batch of waveforms and move the features to device/dtype."""
        inputs = processor(
            batch_audio,
            sampling_rate=16000,
            return_tensors="pt",
            padding="max_length",
            truncation=True
        )
        return inputs.input_features.to(self.device, dtype=self.dtype)

    def transcribe(self,
                   audio_inputs: Union[List[str], List[np.ndarray]],
                   batch_size: int = 1,
                   use_speculative: bool = True,
                   language: str = "en",
                   task: str = "transcribe",
                   **generate_kwargs):
        """Transcribe audio and measure the time spent in generation.

        Args:
            audio_inputs: File paths or waveforms (see ``_load_audio``).
            batch_size: Number of clips whose features are extracted together.
            use_speculative: Use the draft model when one is loaded.
            language: Language passed to generation.
            task: Task passed to generation.
            **generate_kwargs: Extra arguments for ``main_model.generate``. In
                the cross-vocabulary (remapping) path only ``max_new_tokens``,
                ``language`` and ``task`` are honoured.

        Returns:
            ``(transcriptions, total_inference_time)``: decoded strings in
            input order, and the summed generation wall time in seconds
            (feature extraction and decoding to text are not timed).
        """
        # Beam search cannot be combined with speculative decoding, so the
        # draft model is silently skipped when several beams are requested.
        if generate_kwargs.get("num_beams", 1) > 1 and use_speculative:
            use_speculative = False

        use_draft = bool(use_speculative) and self.draft_model is not None

        raw_audio = self._load_audio(audio_inputs)
        transcriptions = []
        total_inference_time = 0.0

        for i in range(0, len(raw_audio), batch_size):
            batch_audio = raw_audio[i:i + batch_size]
            main_input_features = self._extract_features(self.main_processor, batch_audio)

            gen_args = {
                "max_new_tokens": 400,
                "language": language,
                "task": task,
                **generate_kwargs
            }

            if use_draft and self.use_token_remapping:
                # Cross-vocabulary path: the draft model needs features from
                # its own processor.
                draft_input_features = self._extract_features(self.draft_processor, batch_audio)

                self._sync()
                start_time = time.perf_counter()

                # The custom loop decodes one clip at a time, so the batch is
                # walked sample by sample instead of being passed whole.
                generated_ids = [
                    self._speculative_generate_with_remapping(
                        main_input_features=main_input_features[j:j + 1],
                        draft_input_features=draft_input_features[j:j + 1],
                        max_new_tokens=gen_args["max_new_tokens"],
                        language=gen_args["language"],
                        task=gen_args["task"]
                    )[0].tolist()
                    for j in range(main_input_features.shape[0])
                ]

                self._sync()
                total_inference_time += time.perf_counter() - start_time
            else:
                if use_draft:
                    # Same vocabulary: use the library's assisted generation.
                    gen_args["assistant_model"] = self.draft_model
                gen_args["input_features"] = main_input_features

                self._sync()
                start_time = time.perf_counter()

                with torch.no_grad():
                    generated_ids = self.main_model.generate(**gen_args)

                self._sync()
                total_inference_time += time.perf_counter() - start_time

            batch_transcripts = self.main_processor.batch_decode(generated_ids, skip_special_tokens=True)
            transcriptions.extend(batch_transcripts)

        return transcriptions, total_inference_time

    @staticmethod
    def _decoder_prompt(model, processor, language: str, task: str) -> List[int]:
        """Return ``[decoder_start, *forced language/task ids]`` in the given model's own vocabulary."""
        prompt = [model.config.decoder_start_token_id]
        forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)
        if forced_decoder_ids:
            prompt.extend(token_id for _, token_id in forced_decoder_ids)
        return prompt

    def _speculative_generate_with_remapping(
        self,
        main_input_features: torch.Tensor,
        draft_input_features: torch.Tensor,
        max_new_tokens: int = 400,
        language: str = "en",
        task: str = "transcribe",
        num_draft_tokens: int = 3,
        **kwargs
    ):
        """Greedy speculative decoding across two vocabularies (single clip).

        Every round the draft model proposes up to ``num_draft_tokens`` tokens,
        which are remapped to the main vocabulary and scored by the main model
        in one forward pass. The longest prefix of proposals that equals the
        main model's own greedy choices is kept, followed by the main model's
        next token (a correction after a mismatch, or a free extra token when
        every proposal matched). The result therefore equals plain greedy
        decoding of the main model.

        Two sequences are kept in parallel: ``main_ids`` in the main vocabulary
        (returned) and ``draft_ids`` in the draft vocabulary (fed to the draft
        decoder), because the models do not share token ids.

        Args:
            main_input_features: Features from the main processor, batch size 1.
            draft_input_features: Features from the draft processor, batch size 1.
            max_new_tokens: Maximum number of tokens generated after the prompt.
            language: Language used to build each model's decoder prompt.
            task: Task used to build each model's decoder prompt.
            num_draft_tokens: Proposals per round (at least 1).
            **kwargs: Accepted for call compatibility; unused.

        Returns:
            Tensor of shape ``(1, prompt_len + generated)`` with main-vocabulary
            ids, ending at the first EOS if one was produced within the limit.

        Raises:
            ValueError: If ``num_draft_tokens < 1`` or the batch size is not 1.
        """
        if num_draft_tokens < 1:
            raise ValueError("num_draft_tokens must be at least 1.")
        if main_input_features.shape[0] != 1 or draft_input_features.shape[0] != 1:
            raise ValueError("Speculative generation with remapping handles one sample at a time.")

        eos_token_id = self.main_processor.tokenizer.eos_token_id

        # Each decoder is primed with the prompt from its own processor.
        main_prompt = self._decoder_prompt(self.main_model, self.main_processor, language, task)
        draft_prompt = self._decoder_prompt(self.draft_model, self.draft_processor, language, task)

        main_ids = torch.tensor([main_prompt], device=self.device)
        draft_ids = torch.tensor([draft_prompt], device=self.device)
        max_length = len(main_prompt) + max_new_tokens

        with torch.no_grad():
            # The encoders run once; only the decoders are re-run per round.
            main_encoder_outputs = self.main_model.get_encoder()(main_input_features)
            draft_encoder_outputs = self.draft_model.get_encoder()(draft_input_features)

            while main_ids.shape[1] < max_length:
                room = max_length - main_ids.shape[1]
                # Never draft more tokens than can still be emitted.
                k = min(num_draft_tokens, room)

                # === DRAFT PHASE === greedy proposals in the draft vocabulary
                proposals = []
                running_draft = draft_ids
                for _ in range(k):
                    draft_logits = self.draft_model(
                        encoder_outputs=draft_encoder_outputs,
                        decoder_input_ids=running_draft,
                    ).logits
                    next_token = draft_logits[:, -1, :].argmax(dim=-1, keepdim=True)
                    proposals.append(next_token)
                    running_draft = torch.cat([running_draft, next_token], dim=1)

                draft_block = torch.cat(proposals, dim=1)
                remapped_block = self._remap_draft_tokens(draft_block)

                # === VERIFICATION PHASE === one main-model pass over the candidate
                candidate = torch.cat([main_ids, remapped_block], dim=1)
                main_logits = self.main_model(
                    encoder_outputs=main_encoder_outputs,
                    decoder_input_ids=candidate,
                ).logits

                # Position p predicts the token at p + 1, so the k + 1
                # predictions starting at the last prefix position cover every
                # proposal plus the token that follows them.
                first = main_ids.shape[1] - 1
                verified = main_logits.argmax(dim=-1)[0, first:first + k + 1]

                # Length of the leading run of proposals that match.
                agreement = (verified[:k] == remapped_block[0]).long()
                num_accepted = int(agreement.cumprod(dim=0).sum().item())

                # Accepted proposals equal the main predictions, so the new
                # main tokens are simply the first num_accepted + 1 of them.
                new_main = verified[:num_accepted + 1].unsqueeze(0)
                # The draft side keeps its own accepted ids and receives the
                # main model's extra token translated into its vocabulary.
                new_draft = torch.cat(
                    [draft_block[:, :num_accepted], self._remap_main_tokens(new_main[:, -1:])],
                    dim=1,
                )

                # Stop at the first EOS, even if it sits mid-block.
                eos_positions = (new_main[0] == eos_token_id).nonzero()
                finished = eos_positions.numel() > 0
                if finished:
                    keep = int(eos_positions[0].item()) + 1
                    new_main = new_main[:, :keep]
                    new_draft = new_draft[:, :keep]

                # Respect max_new_tokens exactly.
                new_main = new_main[:, :room]
                new_draft = new_draft[:, :room]

                main_ids = torch.cat([main_ids, new_main], dim=1)
                draft_ids = torch.cat([draft_ids, new_draft], dim=1)

                if finished:
                    break

        return main_ids

## Multi Sample Test - Large-V3 + Tiny (Cross-Version)
Benchmark of cross-version speculative decoding on 500 LibriSpeech validation samples:
- Compares baseline Large-V3 vs.  speculative (V3 + Tiny)
- Demonstrates token remapping and dual feature extraction
- Shows WER calculation with case-insensitive normalization
- Displays sample token mappings for verification

In [None]:
# === Large-V3 + Tiny: cross-version benchmark over NUM_SAMPLES clips ===
# Transcribes the same LibriSpeech clips twice with one SpeculativeWhisperV3
# instance (baseline Large-V3, then Large-V3 + Tiny speculative) and compares
# generation time, WER/CER and how often the two transcripts are identical.

# Imported here so the cell runs on its own (load_dataset and gc were used
# without ever being imported).
import gc

import numpy as np
import torch
from datasets import load_dataset
from jiwer import wer, cer

print("=== Large-V3 + Tiny Cross-Version Benchmark (500 samples) ===\n")

# Load samples
NUM_SAMPLES = 500
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
samples = list(dataset.take(NUM_SAMPLES))
audio_inputs = [x["audio"]["array"] for x in samples]
# References are lower-cased so the comparison is case-insensitive.
references = [x["text"].lower().strip() for x in samples]

# Initialize model
print("Initializing models...")
sw_v3 = SpeculativeWhisperV3(
    model_id="openai/whisper-large-v3",
    draft_model_id="openai/whisper-tiny"
)

# Untimed warm-up so one-off initialisation cost is not charged to the
# baseline, which runs first.
sw_v3.transcribe(audio_inputs[:1], batch_size=1, use_speculative=False)

# === BASELINE TEST ===
print(f"\nRunning BASELINE (Large-V3 only) on {NUM_SAMPLES} samples...")
baseline_preds, baseline_time = sw_v3.transcribe(
    audio_inputs,
    batch_size=1,
    use_speculative=False
)
baseline_preds_norm = [p.lower().strip() for p in baseline_preds]
print(f" Baseline completed in {baseline_time:.2f}s")

# === SPECULATIVE TEST ===
print(f"\nRunning SPECULATIVE (Large-V3 + Tiny) on {NUM_SAMPLES} samples...")
spec_preds, spec_time = sw_v3.transcribe(
    audio_inputs,
    batch_size=1,
    use_speculative=True
)
spec_preds_norm = [p.lower().strip() for p in spec_preds]
print(f" Speculative completed in {spec_time:.2f}s")

# === CALCULATE METRICS ===
print("\nCalculating metrics...")

# Overall WER (computed over the whole corpus at once)
baseline_wer_overall = wer(references, baseline_preds_norm)
spec_wer_overall = wer(references, spec_preds_norm)

# Per-sample WER and CER
baseline_wers = [wer([ref], [pred]) for ref, pred in zip(references, baseline_preds_norm)]
spec_wers = [wer([ref], [pred]) for ref, pred in zip(references, spec_preds_norm)]
baseline_cers = [cer([ref], [pred]) for ref, pred in zip(references, baseline_preds_norm)]
spec_cers = [cer([ref], [pred]) for ref, pred in zip(references, spec_preds_norm)]

# Statistics
speedup = baseline_time / spec_time
exact_matches = sum(1 for b, s in zip(baseline_preds_norm, spec_preds_norm) if b == s)

# === RESULTS ===
print("\n" + "="*70)
print("BENCHMARK RESULTS")
print("="*70)
print(f"\nBASELINE (Large-V3):")
print(f"  Total Time:      {baseline_time:.2f}s")
print(f"  Avg per sample: {baseline_time/NUM_SAMPLES:.3f}s")
print(f"  WER (overall):  {baseline_wer_overall:.4f}")
print(f"  WER (mean):     {np.mean(baseline_wers):.4f} ± {np.std(baseline_wers):.4f}")
print(f"  CER (mean):     {np.mean(baseline_cers):.4f} ± {np.std(baseline_cers):.4f}")

print(f"\nSPECULATIVE (V3 + Tiny):")
print(f"  Total Time:      {spec_time:.2f}s")
# Format spec fixed: ":. 3f" is not a valid specifier and raised ValueError.
print(f"  Avg per sample: {spec_time/NUM_SAMPLES:.3f}s")
print(f"  WER (overall):  {spec_wer_overall:.4f}")
print(f"  WER (mean):     {np.mean(spec_wers):.4f} ± {np.std(spec_wers):.4f}")
print(f"  CER (mean):     {np.mean(spec_cers):.4f} ± {np.std(spec_cers):.4f}")

print(f"\n{'='*70}")
print(f" SPEEDUP:           {speedup:.2f}x")
print(f" TIME SAVED:       {baseline_time - spec_time:.2f}s ({100*(baseline_time-spec_time)/baseline_time:.1f}%)")
print(f" WER DIFFERENCE:   {abs(spec_wer_overall - baseline_wer_overall):.4f}")
print(f" EXACT MATCHES:    {exact_matches}/{NUM_SAMPLES} ({100*exact_matches/NUM_SAMPLES:.1f}%)")
print(f"{'='*70}")

# === DETAILED STATISTICS ===
print("\n" + "="*70)
print("DETAILED STATISTICS")
print("="*70)
print(f"\nPer-Sample WER Statistics:")
print(f"  Baseline  - Min: {min(baseline_wers):.4f}, Max: {max(baseline_wers):.4f}, Median: {np.median(baseline_wers):.4f}")
print(f"  Speculative - Min: {min(spec_wers):.4f}, Max: {max(spec_wers):.4f}, Median: {np.median(spec_wers):.4f}")

print(f"\nPer-Sample CER Statistics:")
print(f"  Baseline  - Min: {min(baseline_cers):.4f}, Max: {max(baseline_cers):.4f}, Median: {np.median(baseline_cers):.4f}")
print(f"  Speculative - Min: {min(spec_cers):.4f}, Max: {max(spec_cers):.4f}, Median: {np.median(spec_cers):.4f}")

print(f"\nSpeed Metrics:")
# Format spec fixed here as well (":. 3f" -> ":.3f").
print(f"  Baseline avg latency:      {baseline_time/NUM_SAMPLES:.3f}s per sample")
print(f"  Speculative avg latency:  {spec_time/NUM_SAMPLES:.3f}s per sample")
print(f"  Time saved per sample:    {(baseline_time - spec_time)/NUM_SAMPLES:.3f}s")

print(f"\nQuality Preservation:")
print(f"  Exact match rate:         {100*exact_matches/NUM_SAMPLES:.2f}%")
print(f"  WER correlation:           {np.corrcoef(baseline_wers, spec_wers)[0,1]:.4f}")

# Cleanup: drop the models and release cached GPU memory before the next cell.
del sw_v3
gc.collect()
torch.cuda.empty_cache()

print("\n Benchmark complete!")

## Beam Search Benchmark - Large-V3 
Analyzes beam search trade-offs for Whisper Large-V3 with beam sizes 1-5:
- Measures latency vs. accuracy for each beam size
- Compares greedy decoding (beam=1) vs. beam search
- Provides efficiency metrics (time overhead vs. WER improvement)
- **Key Finding**: Beam search increases latency without improving WER on clean audio

In [None]:
# === Beam Search Benchmark for Large-V3 ===
# Transcribes the same LibriSpeech clips with several beam sizes and reports
# wall time, WER and CER for each, so latency can be weighed against accuracy.

# Imported here so the cell runs on its own (load_dataset and gc were used
# without ever being imported; time and torch came from an earlier cell).
import gc
import time

import numpy as np
import torch
from datasets import load_dataset
from jiwer import wer, cer
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

print("=== Beam Search Benchmark (Large-V3) ===\n")

# Load samples
NUM_SAMPLES = 20  # Reduced for beam search (it's slower)
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
samples = list(dataset.take(NUM_SAMPLES))
audio_inputs = [x["audio"]["array"] for x in samples]
references = [x["text"].lower().strip() for x in samples]

# Beam sizes to test
BEAM_SIZES = [1, 2, 3, 4, 5]

# Initialize model (no draft for beam search)
print("Initializing model...")

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v3",
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="sdpa",
).to(device)

processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")

# Untimed warm-up so one-off initialisation cost is not charged to the first
# beam size, which is the reference for every slowdown ratio below.
warmup_inputs = processor(
    [audio_inputs[0]],
    sampling_rate=16000,
    return_tensors="pt",
    padding="max_length",
    truncation=True
)
with torch.no_grad():
    model.generate(
        input_features=warmup_inputs.input_features.to(device, dtype=dtype),
        language="en",
        task="transcribe",
        max_new_tokens=8
    )

# Storage for results (one entry per beam size, in BEAM_SIZES order)
results = {
    "beam_size": [],
    "total_time": [],
    "avg_time": [],
    "wer_overall": [],
    "wer_mean": [],
    "wer_std": [],
    "cer_mean": [],
    "cer_std": []
}

# === RUN BENCHMARKS ===
print(f"\nRunning benchmarks on {NUM_SAMPLES} samples.. .\n")

for beam_size in BEAM_SIZES:
    print(f"Testing beam_size={beam_size}...")

    # Process all samples
    predictions = []

    # The timed region covers feature extraction, generation and decoding.
    if device == "cuda":
        torch.cuda.synchronize()
    start_time = time.time()

    for audio in audio_inputs:
        # Process audio
        inputs = processor(
            [audio],
            sampling_rate=16000,
            return_tensors="pt",
            padding="max_length",
            truncation=True
        )
        input_features = inputs.input_features.to(device, dtype=dtype)

        # Generate with beam search
        with torch.no_grad():
            generated_ids = model.generate(
                input_features=input_features,
                language="en",
                task="transcribe",
                num_beams=beam_size,
                max_new_tokens=400
            )

        # Decode
        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        predictions.append(transcription.lower().strip())

    if device == "cuda":
        torch.cuda.synchronize()
    total_time = time.time() - start_time

    # Calculate metrics
    wer_overall = wer(references, predictions)
    per_sample_wer = [wer([ref], [pred]) for ref, pred in zip(references, predictions)]
    per_sample_cer = [cer([ref], [pred]) for ref, pred in zip(references, predictions)]

    # Store results
    results["beam_size"].append(beam_size)
    results["total_time"].append(total_time)
    results["avg_time"].append(total_time / NUM_SAMPLES)
    results["wer_overall"].append(wer_overall)
    results["wer_mean"].append(np.mean(per_sample_wer))
    results["wer_std"].append(np.std(per_sample_wer))
    results["cer_mean"].append(np.mean(per_sample_cer))
    results["cer_std"].append(np.std(per_sample_cer))

    print(f"   Time: {total_time:.2f}s | WER: {wer_overall:.4f}")

# === PRINT RESULTS ===
print("\n" + "="*80)
print("BEAM SEARCH RESULTS")
print("="*80)
print(f"{'Beam Size':<12} {'Total Time':<12} {'Avg/Sample':<12} {'WER':<10} {'CER':<10}")
print("-"*80)
for i in range(len(BEAM_SIZES)):
    print(f"{results['beam_size'][i]:<12} "
          f"{results['total_time'][i]:<12.2f} "
          f"{results['avg_time'][i]:<12.3f} "
          f"{results['wer_overall'][i]:<10.4f} "
          f"{results['cer_mean'][i]:<10.4f}")
print("="*80)

# Find best configurations
best_speed_idx = np.argmin(results["total_time"])
best_wer_idx = np.argmin(results["wer_overall"])

print(f"\n FASTEST:        Beam={results['beam_size'][best_speed_idx]} ({results['total_time'][best_speed_idx]:.2f}s)")
print(f" BEST WER:      Beam={results['beam_size'][best_wer_idx]} (WER={results['wer_overall'][best_wer_idx]:.4f})")
print(f" SLOWDOWN (Beam 5 vs 1): {results['total_time'][-1]/results['total_time'][0]:.2f}x slower")

# === DETAILED ANALYSIS ===
print("\n" + "="*80)
print("DETAILED ANALYSIS")
print("="*80)

print("\nTime Scaling:")
for i in range(len(BEAM_SIZES)):
    # Ratio > 1 means slower than the first (beam size 1) configuration.
    speedup_vs_baseline = results['total_time'][i] / results['total_time'][0]
    print(f"  Beam {results['beam_size'][i]}: {speedup_vs_baseline:.2f}x vs Beam 1")

print("\nWER Statistics per Beam Size:")
for i in range(len(BEAM_SIZES)):
    print(f"  Beam {results['beam_size'][i]}:  Mean={results['wer_mean'][i]:.4f}, Std={results['wer_std'][i]:.4f}")

print("\nCER Statistics per Beam Size:")
for i in range(len(BEAM_SIZES)):
    print(f"  Beam {results['beam_size'][i]}: Mean={results['cer_mean'][i]:.4f}, Std={results['cer_std'][i]:.4f}")

print("\nEfficiency Metrics:")
best_time = min(results['total_time'])
best_wer = min(results['wer_overall'])
for i in range(len(BEAM_SIZES)):
    time_overhead = ((results['total_time'][i] - best_time) / best_time) * 100
    # Relative to the best WER seen; 0 when the best WER is itself 0.
    wer_improvement = ((best_wer - results['wer_overall'][i]) / best_wer) * 100 if best_wer > 0 else 0
    # Format spec fixed: ": +.1f" is not a valid specifier and raised ValueError.
    print(f"  Beam {results['beam_size'][i]}: +{time_overhead:.1f}% time, {wer_improvement:+.1f}% WER change")

# Cleanup: drop the model and release cached GPU memory.
del model
del processor
gc.collect()
torch.cuda.empty_cache()

print("\n Beam search benchmark complete!")