In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Whisper sub-token dump (top-50 per decoding step) for multiple languages (Beam Search with Rank)
================================================================================================
For each specified language:
  • Streams ≈600 s of Common Voice 17.0 audio
  • Forces the <|lang|> token so Whisper skips language-ID detection
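  • Transcribes with BEAM_SIZE / TEMPERATURE as configured below
    (note: Whisper only honours beam_size when temperature is 0; at any
    temperature > 0 it falls back to sampling)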
  • Decodes step-by-step ALONG THE BEAM SEARCH PATH and records top-K candidate IDs + strings
  • Adds a 'chosen_rank' column indicating the rank of the chosen token in that step's top-K predictions
  • EVERY row carries `audio_path`, `ground_truth`, `full_transcription`,
    and `whisper_lang`, so merging is unnecessary
  • A sentinel row with step = -1 stores utterance-level metadata
  • One CSV per language is written to results_beam_600s/
"""

# ───────────────────────── Imports ────────────────────────────
import os
import shutil
import warnings
from pathlib import Path
import traceback

import torch
import whisper
import soundfile as sf
import pandas as pd
from datasets import load_dataset, Audio
from whisper.tokenizer import get_tokenizer
from tqdm import tqdm

# Silence every UserWarning (this also hides the harmless HF warning about trust_remote_code)
warnings.filterwarnings("ignore", category=UserWarning)

# ──────────────── Language configuration ─────────────────────────
# One row per language: (language_name, num, whisper_code, cv_code).
#   num          – numeric id used for processing order and as the CSV filename prefix
#   whisper_code – language code handed to Whisper
#   cv_code      – Common Voice configuration name
# Rows are grouped by the author's resource tiers; the order of rows here is
# also the insertion order of the resulting dict.
_LANGUAGE_TABLE = (
    # High-resource languages (more than ~10k hours)
    ("german", 3, "de", "de"),
    ("spanish", 4, "es", "es"),
    ("french", 7, "fr", "fr"),
    ("portuguese", 9, "pt", "pt"),
    ("turkish", 10, "tr", "tr"),

    # Medium-resource languages (between 100 and ~10k hours)
    ("italian", 11, "it", "it"),
    ("dutch", 12, "nl", "nl"),
    ("swedish", 13, "sv", "sv-SE"),
    ("catalan", 14, "ca", "ca"),
    ("finnish", 15, "fi", "fi"),
    ("indonesian", 16, "id", "id"),
    ("vietnamese", 17, "vi", "vi"),
    ("romanian", 18, "ro", "ro"),
    ("danish", 19, "da", "da"),
    ("norwegian", 20, "no", "nn-NO"),
    ("czech", 21, "cs", "cs"),
    ("hungarian", 22, "hu", "hu"),
    ("yoruba", 23, "yo", "yo"),

    # Low-resource languages (less than 100 hours)
    ("welsh", 39, "cy", "cy"),
    ("lithuanian", 35, "lt", "lt"),
    ("latvian", 43, "lv", "lv"),
    ("azerbaijani", 46, "az", "az"),
    ("estonian", 49, "et", "et"),
    ("basque", 52, "eu", "eu"),
    ("albanian", 59, "sq", "sq"),
    ("swahili", 60, "sw", "sw"),
    ("maltese", 85, "mt", "mt"),
    ("uzbek", 79, "uz", "uz"),
)

# Format: "language_name": {"num": index, "code": whisper_code, "cv": cv_code}
LANGUAGES = {
    name: {"num": num, "code": whisper_code, "cv": cv_name}
    for name, num, whisper_code, cv_name in _LANGUAGE_TABLE
}

# ───────────── Global configuration ──────────────────────────────
TOP_K = 50  # Number of highest-probability candidates recorded per decoding step
MAX_DECODER_INPUT_TOKENS = 448  # Max tokens for decoder input sequence (context window)
TARGET_SEC = 600  # 10 minutes of audio per language
BEAM_SIZE = 5
# Must stay at 0: Whisper's transcribe() discards `beam_size` for any
# temperature > 0 and samples instead, so a non-zero value here silently
# turns the "beam search" run into a sampling run.
TEMPERATURE = 0.0

# All artefacts live under one run directory whose name encodes TARGET_SEC.
OUT_DIR = Path(f"results_beam_{TARGET_SEC}s")
TEMP_DIR = OUT_DIR / "tmp_audio"  # scratch space for per-clip WAV files
ERROR_LOG = OUT_DIR / "errors.log"

OUT_DIR.mkdir(exist_ok=True, parents=True)
TEMP_DIR.mkdir(exist_ok=True, parents=True)
# Start every run with a fresh error log; entries are appended during the run.
if ERROR_LOG.exists():
    ERROR_LOG.unlink()

# ─────────────── Whisper model (shared) ──────────────────────────
# Pick the device once, then derive the human-readable banner from it.
device = "cuda" if torch.cuda.is_available() else "cpu"
_device_label = torch.cuda.get_device_name(0) if device == "cuda" else "CPU"
print(f"Loading Whisper (large-v2) on device: {_device_label}")

# A single model instance is shared by every language processed below.
model = whisper.load_model("large-v2", device=device)
print("Whisper model loaded.")


# ─────────── Helper: stream ~TARGET_SEC of audio ─────────────────
def load_some_cv_audio(cv_code: str, seconds: int):
    """
    Load approximately `seconds` worth of audio from Common Voice 17.0.

    The "validated", "train" and "test" splits are streamed in that order.
    The first split that accumulates at least `seconds` of audio wins and its
    clips are returned immediately.

    If no split reaches the target (the split is too small, or streaming
    failed part-way through), the longest partial batch collected from any
    split is returned instead of discarding usable audio.

    Args:
        cv_code: Common Voice configuration name (e.g. "de", "sv-SE").
        seconds: Target amount of audio, in seconds.

    Returns:
        A list of dataset examples, each augmented with a "cv_path" key of the
        form "<cv_code>_<split>/<file name>". Empty if nothing could be loaded.
    """
    splits = ["validated", "train", "test"]
    best_batch, best_dur = [], 0.0

    for split in splits:
        # Initialised outside the try so a partial batch survives a mid-stream error.
        batch, dur = [], 0.0
        try:
            ds = load_dataset(
                "mozilla-foundation/common_voice_17_0",
                cv_code,
                split=split,
                streaming=True,
                trust_remote_code=True,
            ).cast_column("audio", Audio(16000))

            for ex in ds:
                raw = ex["path"]
                ex["cv_path"] = f"{cv_code}_{split}/{Path(raw).name}"
                # Audio was cast to 16 kHz above, so samples / 16000 is seconds.
                dur += len(ex["audio"]["array"]) / 16000
                batch.append(ex)
                if dur >= seconds:
                    return batch

        except Exception as e:
            # Best effort: report and fall through to the next split.
            print(f"  ⚠️ Error loading {cv_code} {split}: {str(e)}")

        # Remember the longest shortfall batch as a fallback.
        if dur > best_dur:
            best_batch, best_dur = batch, dur

    if best_batch:
        print(
            f"  ⚠️ Only {best_dur:.1f}s of audio available for {cv_code} "
            f"(requested {seconds}s); using the partial batch."
        )
        return best_batch

    print(f"No audio found for {cv_code} after trying all splits.")
    return []


def log_error_to_file(language, error_msg):
    """Append a "Language / Error" record, followed by a blank line, to ERROR_LOG."""
    record = (
        f"Language: {language}\n",
        f"Error: {error_msg}\n\n",
    )
    with open(ERROR_LOG, mode='a', encoding='utf-8') as log_file:
        log_file.writelines(record)


# ───────────────────── Main loop ─────────────────────────────────
# Counters reported in the end-of-run summary.
total_langs_to_process = len(LANGUAGES)
processed_langs_count = 0
successful_langs_count = 0
failed_langs_count = 0

# Languages are handled in ascending "num" order; one CSV is written per language.
for lang_name, info in sorted(LANGUAGES.items(), key=lambda x: x[1]["num"]):
    processed_langs_count += 1
    lang_num = info["num"]
    lang_code = info["code"]
    cv_code = info["cv"]
    # Reset per language so the error handler never acts on a stale path.
    wav_path = None

    # `lang_num` is a sparse identifier, not a position, so progress is
    # reported with the running counter and the id is shown separately.
    print(f"\n─── [{processed_langs_count:02d}/{total_langs_to_process}] {lang_name.upper()} (#{lang_num:03d}: {lang_code} / {cv_code}) ───────────")

    try:
        samples = load_some_cv_audio(cv_code, TARGET_SEC)
        if not samples:
            print(f"⚠️  No audio found for {lang_name} – skipping.")
            log_error_to_file(lang_name, "No audio found after trying all splits.")
            failed_langs_count += 1
            continue

        total_dur = sum(len(s["audio"]["array"]) / 16000 for s in samples)
        print(f"✓ Processing {len(samples)} clips, {total_dur:.1f}s total for {lang_name}")

        # Tokenizer plus the three prompt tokens that start every decoder input.
        tok = get_tokenizer(multilingual=True, language=lang_code, task="transcribe")
        lang_token_id = tok.encode(f"<|{lang_code}|>", allowed_special="all")[0]
        transcribe_token_id = tok.encode("<|transcribe|>", allowed_special="all")[0]
        sot_token_id = tok.sot

        rows = []
        for clip_id, ex in enumerate(tqdm(samples, desc=f"Decoding {lang_name}")):
            cv_path = ex["cv_path"]
            # The sentence field may be absent or None; treat both as empty.
            ground_truth = (ex.get("sentence") or "").strip()

            # Whisper is fed from a temporary 16 kHz / 16-bit WAV file.
            wav_path = TEMP_DIR / f"{lang_name}_{cv_code}_{clip_id}.wav"
            sf.write(wav_path, ex["audio"]["array"], 16000, subtype='PCM_16')

            # Reference transcription with the configured BEAM_SIZE / TEMPERATURE.
            res = model.transcribe(
                str(wav_path),
                language=lang_code,
                task="transcribe",
                beam_size=BEAM_SIZE,
                temperature=TEMPERATURE,
                verbose=None
            )

            full_text = res["text"].strip()
            lang_guess = res["language"]

            # Token path chosen by transcribe(): concatenate every segment's tokens.
            beam_search_output_tokens = []
            if "segments" in res and res["segments"]:
                for segment in res["segments"]:
                    beam_search_output_tokens.extend(segment['tokens'])
            else:
                print(f"Warning: Segments not found for clip {clip_id} ({lang_name}). Re-tokenizing full_text.")
                beam_search_output_tokens = tok.encode(full_text)

            # Make sure the path ends with EOT: append it when there is room,
            # otherwise overwrite the final token.
            if not beam_search_output_tokens or beam_search_output_tokens[-1] != tok.eot:
                if len(beam_search_output_tokens) < MAX_DECODER_INPUT_TOKENS - 1:
                    beam_search_output_tokens.append(tok.eot)
                else:
                    beam_search_output_tokens[-1] = tok.eot

            # Sentinel row (step = -1) carrying utterance-level metadata only.
            rows.append({
                "clip": clip_id,
                "step": -1,
                "audio_path": cv_path,
                "ground_truth": ground_truth,
                "full_transcription": full_text,
                "whisper_lang": lang_guess,
                "chosen_rank": None
            })

            # Encode the audio once per clip. Inference only, so no autograd
            # graph should be recorded for the encoder pass.
            audio_features = whisper.load_audio(str(wav_path))
            mel_segment = whisper.log_mel_spectrogram(whisper.pad_or_trim(audio_features)).to(device)
            with torch.no_grad():
                encoder_output = model.encoder(mel_segment.unsqueeze(0))

            current_decoder_input_token_ids = [sot_token_id, lang_token_id, transcribe_token_id]

            # Replay the chosen path token by token, recording the model's
            # top-K distribution before each token is appended.
            for step_idx, actual_next_token_id in enumerate(beam_search_output_tokens):
                if len(current_decoder_input_token_ids) >= MAX_DECODER_INPUT_TOKENS:
                    print(f"Warning: Decoder input sequence for clip {clip_id}, step {step_idx} "
                          f"reached {len(current_decoder_input_token_ids)} tokens. Breaking step-by-step.")
                    break

                decoder_input_tensor = torch.tensor([current_decoder_input_token_ids], device=device)

                with torch.no_grad():
                    # Logits at the last position predict the next token.
                    logits_for_next_token = model.decoder(decoder_input_tensor, encoder_output)[0, -1]

                probabilities_for_next_token = torch.softmax(logits_for_next_token, dim=-1)
                top_k_probs, top_k_ids = torch.topk(probabilities_for_next_token, TOP_K)
                top_k_id_list = top_k_ids.tolist()
                top_k_prob_list = top_k_probs.tolist()

                # Probability of the token that was actually chosen
                # (0.0 if its id lies outside the model's vocabulary).
                chosen_token_prob = 0.0
                if actual_next_token_id < probabilities_for_next_token.shape[0]:
                    chosen_token_prob = probabilities_for_next_token[actual_next_token_id].item()

                # 1-based rank within the top-K; TOP_K + 1 means "not in the top-K".
                if actual_next_token_id in top_k_id_list:
                    chosen_rank = top_k_id_list.index(actual_next_token_id) + 1
                else:
                    chosen_rank = TOP_K + 1

                row_data = {
                    "clip": clip_id,
                    "step": step_idx,
                    "audio_path": cv_path,
                    "ground_truth": ground_truth,
                    "full_transcription": full_text,
                    "whisper_lang": lang_guess,
                    "chosen_id": actual_next_token_id,
                    "chosen_prob": chosen_token_prob,
                    "chosen_rank": chosen_rank,
                }

                # topN_id / topN_txt / topN_prob columns, N = 1..TOP_K.
                # torch.topk always returns exactly TOP_K entries, so no padding is needed.
                for k_rank, (token_id_k, prob_k) in enumerate(zip(top_k_id_list, top_k_prob_list), start=1):
                    row_data[f"top{k_rank}_id"] = token_id_k
                    row_data[f"top{k_rank}_txt"] = tok.decode([token_id_k])
                    row_data[f"top{k_rank}_prob"] = float(prob_k)

                rows.append(row_data)

                current_decoder_input_token_ids.append(actual_next_token_id)
                if actual_next_token_id == tok.eot:
                    break

            if wav_path.exists():
                wav_path.unlink()

        # Write results to CSV
        csv_path = OUT_DIR / f"{lang_num:03d}_{lang_name}_subtoken_beam.csv"
        df = pd.DataFrame(rows)

        # Place chosen_rank directly after chosen_prob for readability. The
        # sentinel row introduces chosen_rank *before* chosen_prob, so the
        # column has to be moved regardless of which one comes first; skip
        # the move entirely when no step rows (and hence no chosen_prob) exist.
        if "chosen_rank" in df.columns and "chosen_prob" in df.columns:
            cols = [c for c in df.columns if c != "chosen_rank"]
            cols.insert(cols.index("chosen_prob") + 1, "chosen_rank")
            df = df[cols]

        df.to_csv(csv_path, index=False, encoding='utf-8')
        print(f"✓ CSV written → {csv_path}")
        successful_langs_count += 1

    except Exception as e:
        # Top-level boundary per language: log the failure and move on so
        # one bad language does not abort the whole run.
        detailed_error_msg = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
        print(f"❌ Error processing {lang_name}: {detailed_error_msg}")
        log_error_to_file(lang_name, detailed_error_msg)
        failed_langs_count += 1

        # Clean up temp file if error occurred
        if wav_path is not None and wav_path.exists():
            wav_path.unlink()
        continue

# ───────────── Cleanup ───────────────────────────────────────────
# Remove the scratch directory used for per-clip WAV files, if it is still there.
if TEMP_DIR.exists():
    try:
        shutil.rmtree(TEMP_DIR)
    except OSError as e:
        print(f"Error removing temporary directory {TEMP_DIR}: {e.strerror}")
    else:
        print(f"✓ Temporary directory {TEMP_DIR} removed.")

# ───────────── Summary ───────────────────────────────────────────
# Assemble the report first, then emit it in one go; the failure line is
# only included when at least one language failed.
_summary_lines = [
    f"\n✓ All {processed_langs_count}/{total_langs_to_process} language tasks attempted.",
    f"  Successfully processed: {successful_langs_count} languages.",
]
if failed_langs_count > 0:
    _summary_lines.append(
        f"  Failed to process: {failed_langs_count} languages (see {ERROR_LOG} for details)."
    )
_summary_lines.append(f"  Output files are in: {OUT_DIR}")
print("\n".join(_summary_lines))


  from .autonotebook import tqdm as notebook_tqdm


Loading Whisper (large-v2) on device: NVIDIA A100-PCIE-40GB
Whisper model loaded.

─── [003/28] GERMAN (de / de) ───────────


Reading metadata...: 917350it [00:22, 41317.27it/s]


✓ Processing 98 clips, 601.7s total for german


Decoding german: 100%|██████████| 98/98 [03:10<00:00,  1.94s/it]


✓ CSV written → results_beam_600s/003_german_subtoken_beam.csv

─── [004/28] SPANISH (es / es) ───────────


Reading metadata...: 415157it [00:09, 42742.92it/s]


✓ Processing 101 clips, 602.7s total for spanish


Decoding spanish:  40%|███▉      | 40/101 [01:14<01:51,  1.83s/it]



Decoding spanish: 100%|██████████| 101/101 [03:39<00:00,  2.17s/it]


✓ CSV written → results_beam_600s/004_spanish_subtoken_beam.csv

─── [007/28] FRENCH (fr / fr) ───────────


Reading metadata...: 726005it [00:22, 32094.48it/s]


✓ Processing 105 clips, 602.0s total for french


Decoding french:  14%|█▍        | 15/105 [00:28<02:26,  1.63s/it]



Decoding french:  73%|███████▎  | 77/105 [02:17<00:46,  1.66s/it]



Decoding french: 100%|██████████| 105/105 [03:05<00:00,  1.77s/it]


✓ CSV written → results_beam_600s/007_french_subtoken_beam.csv

─── [009/28] PORTUGUESE (pt / pt) ───────────


Reading metadata...: 152647it [00:03, 46667.14it/s]


✓ Processing 119 clips, 603.3s total for portuguese


Decoding portuguese: 100%|██████████| 119/119 [03:04<00:00,  1.55s/it]


✓ CSV written → results_beam_600s/009_portuguese_subtoken_beam.csv

─── [010/28] TURKISH (tr / tr) ───────────


Reading metadata...: 114056it [00:03, 35459.69it/s]


✓ Processing 129 clips, 604.2s total for turkish


Decoding turkish: 100%|██████████| 129/129 [03:24<00:00,  1.58s/it]


✓ CSV written → results_beam_600s/010_turkish_subtoken_beam.csv

─── [011/28] ITALIAN (it / it) ───────────


Reading metadata...: 236777it [00:09, 25735.00it/s]


✓ Processing 104 clips, 607.0s total for italian


Decoding italian: 100%|██████████| 104/104 [03:15<00:00,  1.88s/it]


✓ CSV written → results_beam_600s/011_italian_subtoken_beam.csv

─── [012/28] DUTCH (nl / nl) ───────────


Reading metadata...: 90449it [00:03, 27364.37it/s]


✓ Processing 117 clips, 601.7s total for dutch


Decoding dutch: 100%|██████████| 117/117 [03:24<00:00,  1.75s/it]


✓ CSV written → results_beam_600s/012_dutch_subtoken_beam.csv

─── [013/28] SWEDISH (sv / sv-SE) ───────────


Reading metadata...: 40770it [00:01, 27393.88it/s]


✓ Processing 112 clips, 603.5s total for swedish


Decoding swedish: 100%|██████████| 112/112 [03:23<00:00,  1.82s/it]


✓ CSV written → results_beam_600s/013_swedish_subtoken_beam.csv

─── [014/28] CATALAN (ca / ca) ───────────


Reading metadata...: 1824534it [01:02, 29174.55it/s]


✓ Processing 95 clips, 601.8s total for catalan


Decoding catalan: 100%|██████████| 95/95 [03:05<00:00,  1.95s/it]


✓ CSV written → results_beam_600s/014_catalan_subtoken_beam.csv

─── [015/28] FINNISH (fi / fi) ───────────


Reading metadata...: 10447it [00:00, 24901.36it/s]


✓ Processing 112 clips, 601.6s total for finnish


Decoding finnish: 100%|██████████| 112/112 [03:48<00:00,  2.04s/it]


✓ CSV written → results_beam_600s/015_finnish_subtoken_beam.csv

─── [016/28] INDONESIAN (id / id) ───────────


Reading metadata...: 26108it [00:01, 25112.22it/s]


✓ Processing 136 clips, 601.9s total for indonesian


Decoding indonesian:  80%|████████  | 109/136 [02:38<00:37,  1.40s/it]



Decoding indonesian: 100%|██████████| 136/136 [03:19<00:00,  1.46s/it]


✓ CSV written → results_beam_600s/016_indonesian_subtoken_beam.csv

─── [017/28] VIETNAMESE (vi / vi) ───────────


Reading metadata...: 5135it [00:00, 9787.16it/s]


✓ Processing 140 clips, 604.5s total for vietnamese


Decoding vietnamese: 100%|██████████| 140/140 [03:45<00:00,  1.61s/it]


✓ CSV written → results_beam_600s/017_vietnamese_subtoken_beam.csv

─── [018/28] ROMANIAN (ro / ro) ───────────


Reading metadata...: 17737it [00:00, 21752.07it/s]


✓ Processing 137 clips, 600.3s total for romanian


Decoding romanian: 100%|██████████| 137/137 [03:54<00:00,  1.72s/it]


✓ CSV written → results_beam_600s/018_romanian_subtoken_beam.csv

─── [019/28] DANISH (da / da) ───────────


Reading metadata...: 10225it [00:00, 30319.81it/s]


✓ Processing 211 clips, 600.5s total for danish


Decoding danish:  88%|████████▊ | 186/211 [02:31<00:19,  1.29it/s]



Decoding danish: 100%|██████████| 211/211 [02:50<00:00,  1.24it/s]


✓ CSV written → results_beam_600s/019_danish_subtoken_beam.csv

─── [020/28] NORWEGIAN (no / nn-NO) ───────────


Reading metadata...: 1173it [00:00, 3931.75it/s]


✓ Processing 124 clips, 605.7s total for norwegian


Decoding norwegian: 100%|██████████| 124/124 [03:35<00:00,  1.73s/it]


✓ CSV written → results_beam_600s/020_norwegian_subtoken_beam.csv

─── [021/28] CZECH (cs / cs) ───────────


Reading metadata...: 61391it [00:01, 47032.49it/s]


✓ Processing 171 clips, 600.2s total for czech


Decoding czech: 100%|██████████| 171/171 [04:06<00:00,  1.44s/it]


✓ CSV written → results_beam_600s/021_czech_subtoken_beam.csv

─── [022/28] HUNGARIAN (hu / hu) ───────────


Reading metadata...: 60358it [00:03, 20093.29it/s]


✓ Processing 101 clips, 607.4s total for hungarian


Decoding hungarian: 100%|██████████| 101/101 [04:09<00:00,  2.48s/it]


✓ CSV written → results_beam_600s/022_hungarian_subtoken_beam.csv

─── [023/28] YORUBA (yo / yo) ───────────


Reading metadata...: 3077it [00:00, 6571.41it/s]


✓ Processing 87 clips, 600.1s total for yoruba


Decoding yoruba: 100%|██████████| 87/87 [03:42<00:00,  2.56s/it]


✓ CSV written → results_beam_600s/023_yoruba_subtoken_beam.csv

─── [035/28] LITHUANIAN (lt / lt) ───────────


Reading metadata...: 16643it [00:00, 25221.00it/s]


✓ Processing 111 clips, 601.8s total for lithuanian


Decoding lithuanian: 100%|██████████| 111/111 [04:14<00:00,  2.29s/it]


✓ CSV written → results_beam_600s/035_lithuanian_subtoken_beam.csv

─── [039/28] WELSH (cy / cy) ───────────


Reading metadata...: 90369it [00:01, 47477.70it/s]


✓ Processing 100 clips, 600.3s total for welsh


Decoding welsh: 100%|██████████| 100/100 [03:53<00:00,  2.33s/it]


✓ CSV written → results_beam_600s/039_welsh_subtoken_beam.csv

─── [043/28] LATVIAN (lv / lv) ───────────


Reading metadata...: 171652it [00:03, 46328.42it/s]


✓ Processing 105 clips, 603.3s total for latvian


Decoding latvian: 100%|██████████| 105/105 [04:23<00:00,  2.51s/it]


✓ CSV written → results_beam_600s/043_latvian_subtoken_beam.csv

─── [046/28] AZERBAIJANI (az / az) ───────────


Reading metadata...: 130it [00:00, 1012.10it/s]


✓ Processing 115 clips, 601.0s total for azerbaijani


Decoding azerbaijani: 100%|██████████| 115/115 [04:48<00:00,  2.51s/it]


✓ CSV written → results_beam_600s/046_azerbaijani_subtoken_beam.csv

─── [049/28] ESTONIAN (et / et) ───────────


Reading metadata...: 24381it [00:01, 18162.27it/s]


✓ Processing 81 clips, 601.6s total for estonian


Decoding estonian: 100%|██████████| 81/81 [05:02<00:00,  3.73s/it]


✓ CSV written → results_beam_600s/049_estonian_subtoken_beam.csv

─── [052/28] BASQUE (eu / eu) ───────────


Reading metadata...: 179132it [00:06, 26277.73it/s]


✓ Processing 100 clips, 603.0s total for basque


Decoding basque: 100%|██████████| 100/100 [03:51<00:00,  2.32s/it]


✓ CSV written → results_beam_600s/052_basque_subtoken_beam.csv

─── [059/28] ALBANIAN (sq / sq) ───────────


Reading metadata...: 1524it [00:00, 4959.71it/s]


✓ Processing 122 clips, 605.2s total for albanian


Decoding albanian: 100%|██████████| 122/122 [04:23<00:00,  2.16s/it]


✓ CSV written → results_beam_600s/059_albanian_subtoken_beam.csv

─── [060/28] SWAHILI (sw / sw) ───────────


Reading metadata...: 267001it [00:10, 26219.02it/s]


✓ Processing 104 clips, 605.4s total for swahili


Decoding swahili:  47%|████▋     | 49/104 [02:00<01:41,  1.84s/it]



Decoding swahili:  86%|████████▌ | 89/104 [03:26<00:39,  2.64s/it]



Decoding swahili: 100%|██████████| 104/104 [04:18<00:00,  2.49s/it]


✓ CSV written → results_beam_600s/060_swahili_subtoken_beam.csv

─── [079/28] UZBEK (uz / uz) ───────────


Reading metadata...: 86430it [00:01, 49121.30it/s]


✓ Processing 109 clips, 603.7s total for uzbek


Decoding uzbek: 100%|██████████| 109/109 [04:40<00:00,  2.58s/it]


✓ CSV written → results_beam_600s/079_uzbek_subtoken_beam.csv

─── [085/28] MALTESE (mt / mt) ───────────


Reading metadata...: 6442it [00:00, 11225.47it/s]


✓ Processing 110 clips, 600.7s total for maltese


Decoding maltese: 100%|██████████| 110/110 [05:53<00:00,  3.22s/it]


✓ CSV written → results_beam_600s/085_maltese_subtoken_beam.csv
✓ Temporary directory results_beam_600s/tmp_audio removed.

✓ All 28/28 language tasks attempted.
  Successfully processed: 28 languages.
  Output files are in: results_beam_600s
