In [10]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Whisper Analysis: Average Rank of GOLD Token in Model Candidates vs. Training Hours
==================================================================================

This script calculates the rank of ground truth (gold) subtokens within the
model's prediction candidates. It requires ground truth transcriptions corresponding
to the processed _subtoken_beam.csv files.
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import glob
import logging
import Levenshtein # For alignment
from transformers import WhisperTokenizer # For Whisper's tokenizer
from scipy.stats import pearsonr # <<<< NEW IMPORT FOR P-VALUE CALCULATION

# ────────────────────────────────────────────────────────────
# Hardcoded configuration
# ────────────────────────────────────────────────────────────
DATA_DIR = Path("results_beam_600s")  # folder containing the per-language *_subtoken_beam.csv files
TRAINING_HOURS_CSV = Path("whisper_training_hours.csv")  # table of Whisper training hours per language
OUTPUT_DIR = Path("analysis_results_beam")  # all CSVs and figures produced by this script go here
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Resource tiers, keyed by lower-case Whisper language code; used for grouping
# and colouring in the plots.
HIGH_RESOURCE = set("de es fr pt tr".split())
MEDIUM_RESOURCE = set("it nl sv ca fi id vi ro da no cs hu".split())
LOW_RESOURCE = set("cy lt lv az et eu sq sw mt uz".split())

# Lower-case language code -> human-readable name (listing order preserved).
# "yo" appears in no resource tier; it is listed because it is also in
# EXCLUDED_LANGUAGES. Add further lower-case codes here if new files appear.
LANGUAGE_CODE_TO_NAME_MAP = dict(
    de="German", es="Spanish", fr="French", pt="Portuguese", tr="Turkish",
    it="Italian", nl="Dutch", sv="Swedish", ca="Catalan", fi="Finnish",
    id="Indonesian", vi="Vietnamese", ro="Romanian", da="Danish",
    no="Norwegian", cs="Czech", hu="Hungarian",
    cy="Welsh", lt="Lithuanian", lv="Latvian", az="Azerbaijani",
    et="Estonian", eu="Basque", sq="Albanian", sw="Swahili",
    mt="Maltese", uz="Uzbek",
    yo="Yoruba",
)

# Codes omitted from both plots.
EXCLUDED_LANGUAGES = set("uz mt sw sq yo da vi cs".split())

# K in "top-K": how many candidate columns per decoding step are searched.
MAX_RANK = 50

# Tokenizer checkpoint -- must match the ASR model that produced the CSVs
# (e.g. "openai/whisper-small", "openai/whisper-medium").
MODEL_NAME = "openai/whisper-base"
# Column assumed to carry the reference transcription in every CSV.
GROUND_TRUTH_COLUMN_NAME = "ground_truth"

# Per-resource-group plot styling.
COLOURS = dict(High="steelblue", Medium="seagreen", Low="crimson", Other="grey")
MARKERS = dict(High="o", Medium="^", Low="s", Other="x")

# Global Matplotlib defaults for every figure drawn below: small fonts sized
# for compact figures, and 300-dpi file exports. No font family is forced, so
# Matplotlib's configured default family is used.
_RC_OVERRIDES = {
    # Resolution: on-screen rendering vs. saved files.
    "figure.dpi": 100,
    "savefig.dpi": 300,
    # Typography, in points.
    "font.size": 8,
    "figure.titlesize": 12,
    "axes.titlesize": 10,
    "axes.labelsize": 8,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
}
plt.rcParams.update(_RC_OVERRIDES)

# ────────────────────────────────────────────────────────────
# Setup logging
# ────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# ────────────────────────────────────────────────────────────
# Global Tokenizer Initialization
# ────────────────────────────────────────────────────────────
# The tokenizer is loaded once at import time and shared by all analysis
# steps. Nothing downstream can work without it, so a load failure aborts the
# script with a non-zero exit status.
try:
    GLOBAL_TOKENIZER = WhisperTokenizer.from_pretrained(MODEL_NAME)
    logging.info(f"Successfully loaded WhisperTokenizer for model: {MODEL_NAME}")
except Exception as e:
    logging.error(f"Failed to load Whisper tokenizer ({MODEL_NAME}): {e}. "
                  "Please ensure 'transformers' and 'sentencepiece' are installed, "
                  "and the model name is correct.")
    # The builtin `exit()` is a `site`-module convenience (missing under
    # `python -S` / some embedded interpreters) and exits with status 0, which
    # hides the failure from shells and pipelines. Exit explicitly with 1.
    raise SystemExit(1) from e

# ────────────────────────────────────────────────────────────
# Utility functions
# ────────────────────────────────────────────────────────────
def get_resource_group(code):
    """Return the resource tier label for a language code.

    Checks HIGH_RESOURCE, MEDIUM_RESOURCE and LOW_RESOURCE in that order and
    returns "High", "Medium" or "Low" for the first set containing ``code``.
    Any other code (including None) maps to "Other".
    """
    tiers = (
        (HIGH_RESOURCE, "High"),
        (MEDIUM_RESOURCE, "Medium"),
        (LOW_RESOURCE, "Low"),
    )
    for members, label in tiers:
        if code in members:
            return label
    return "Other"

def load_training_hours(path=None):
    """Load per-language Whisper training hours from a CSV file.

    The CSV must have ``Whisper_Code`` and ``Whisper_Training_Hours`` columns.
    A row is skipped when its code is missing, or when its hour value is
    missing or not numeric (e.g. the literal ``"Unknown"``). One malformed row
    therefore no longer discards the whole table.

    Args:
        path: CSV to read. Defaults to the module-level ``TRAINING_HOURS_CSV``.

    Returns:
        dict mapping lower-cased, whitespace-stripped language code to training
        hours (float). If the file cannot be read or lacks the expected
        columns, the error is logged and ``{}`` is returned.
    """
    if path is None:
        path = TRAINING_HOURS_CSV
    try:
        df = pd.read_csv(path)
        codes = df["Whisper_Code"]
        # Coerce instead of float(): any non-numeric entry becomes NaN and is
        # dropped below rather than raising for the entire file.
        hours = pd.to_numeric(df["Whisper_Training_Hours"], errors="coerce")
    except Exception as e:
        logging.error(f"Failed to load training hours: {e}")
        return {}

    valid = codes.notna() & hours.notna()
    hours_dict = {
        str(code).strip().lower(): float(value)
        for code, value in zip(codes[valid], hours[valid])
    }
    logging.info(f"Loaded training hours for {len(hours_dict)} languages")
    return hours_dict

def safe_read_csv(path):
    """Read ``path`` with ``pd.read_csv``; never raise.

    A missing file is logged as "File not found". Any other read error
    (e.g. an empty or malformed file) is logged with its message. In both
    cases an empty DataFrame is returned.
    """
    problem = None
    try:
        frame = pd.read_csv(path)
    except FileNotFoundError:
        problem = f"File not found: {path}"
    except Exception as exc:
        problem = f"Failed to read {path}: {exc}"
    else:
        return frame
    logging.warning(problem)
    return pd.DataFrame()

# ────────────────────────────────────────────────────────────
# Core analysis functions
# ────────────────────────────────────────────────────────────
def calculate_rank_of_gold_subtokens(model_outputs_df, ground_truth_text, tokenizer, current_max_rank):
    """
    Rank every gold (ground-truth) subtoken among the model's top-K candidates.

    ``ground_truth_text`` is tokenized with ``tokenizer`` (no special tokens).
    The result is aligned against the sequence of tokens the model chose
    (``chosen_id`` column) with ``Levenshtein.opcodes``. Scoring per gold token:

    * 'equal' / 'replace' span: rank is the 1-based position of the gold id in
      the aligned step's candidate list (``top<k>_id`` columns ordered by k,
      truncated to ``current_max_rank``).
    * Not among the candidates, or deleted in the alignment: sentinel rank
      ``current_max_rank + 1``.
    * Inserted model tokens (no gold counterpart) are ignored.

    Args:
        model_outputs_df: one row per decoding step. Rows with a NaN
            ``chosen_id`` are ignored.
        ground_truth_text: reference transcription for the utterance.
        tokenizer: callable ``tokenizer(text, add_special_tokens=False)`` whose
            result has ``.input_ids``.
        current_max_rank: K, the number of candidate columns considered.

    Returns:
        ``(avg_rank, median_rank, percent_found, ranks)``. ``percent_found`` is
        the fraction (0-1) of gold tokens found within the top-K, and ``ranks``
        is the per-gold-token rank list. The return is ``(nan, nan, nan, [])``
        when the model output or ground truth is missing, or tokenization
        fails. It is ``(nan, nan, 0.0, [])`` when the ground truth tokenizes
        to nothing.
    """
    if model_outputs_df is None or model_outputs_df.empty:
        logging.debug("Model output DataFrame is empty for rank calculation.")
        return np.nan, np.nan, np.nan, []
    if not ground_truth_text:
        logging.debug("Ground truth text is empty for rank calculation.")
        return np.nan, np.nan, np.nan, []
    try:
        gold_subtoken_ids = tokenizer(ground_truth_text, add_special_tokens=False).input_ids
    except Exception as e:
        logging.error(f"Tokenizer error on ground truth: '{str(ground_truth_text)[:50]}...': {e}")
        return np.nan, np.nan, np.nan, []
    if not gold_subtoken_ids:
        logging.debug("Ground truth tokenized to empty sequence.")
        return np.nan, np.nan, 0.0, []

    not_found_rank = current_max_rank + 1

    # Derive BOTH the chosen-token sequence and the per-step candidate lists
    # from the same filtered rows. Dropping NaN chosen_ids from only one of
    # them would shift every later alignment index onto the wrong step.
    steps = model_outputs_df[model_outputs_df['chosen_id'].notna()]
    model_chosen_ids = steps['chosen_id'].astype(int).tolist()

    top_k_cols = sorted(
        [col for col in steps.columns if col.startswith('top') and col.endswith('_id')],
        key=lambda x: int(x.replace('top', '').replace('_id', ''))
    )[:current_max_rank]
    # to_numpy() keeps one (possibly empty) row per step even with no top-k columns.
    model_candidates_at_step = [
        [int(value) for value in row if pd.notna(value)]
        for row in steps[top_k_cols].to_numpy()
    ]

    if not model_chosen_ids:
        # No usable model steps: every gold token is unmatched.
        ranks_for_gold_tokens = [not_found_rank] * len(gold_subtoken_ids)
        return (np.mean(ranks_for_gold_tokens), np.median(ranks_for_gold_tokens),
                0.0, ranks_for_gold_tokens)

    ranks_for_gold_tokens = []
    found_in_candidates_count = 0
    for tag, i1, i2, j1, j2 in Levenshtein.opcodes(gold_subtoken_ids, model_chosen_ids):
        if tag in ('equal', 'replace'):
            for k_op in range(i2 - i1):
                gold_token_id = gold_subtoken_ids[i1 + k_op]
                rank = not_found_rank
                # Only consult a model step that belongs to this aligned span.
                if k_op < j2 - j1:
                    candidates = model_candidates_at_step[j1 + k_op]
                    if gold_token_id in candidates:
                        # Candidate lists are truncated to K, so this rank is <= K.
                        rank = candidates.index(gold_token_id) + 1
                        found_in_candidates_count += 1
                ranks_for_gold_tokens.append(rank)
        elif tag == 'delete':
            # Gold tokens the model skipped entirely.
            ranks_for_gold_tokens.extend([not_found_rank] * (i2 - i1))
        # 'insert': extra model tokens with no gold counterpart -- nothing to rank.

    if not ranks_for_gold_tokens:
        return np.nan, np.nan, 0.0, []
    avg_rank = np.mean(ranks_for_gold_tokens)
    median_rank = np.median(ranks_for_gold_tokens)
    percent_found = found_in_candidates_count / len(ranks_for_gold_tokens)
    return avg_rank, median_rank, percent_found, ranks_for_gold_tokens

def _infer_language(filename, df):
    """Return ``(language_code_or_None, display_name)`` for one results file."""
    stem_parts = Path(filename).stem.split('_')

    # Filename hint: a 2-letter first part is taken as the code. Otherwise look
    # for any part spelling a known language name, e.g.
    # "003_german_subtoken_beam" -> "de".
    if len(stem_parts[0]) == 2:
        lang_code_from_fn = stem_parts[0].lower()
    else:
        name_to_code = {name.lower(): code for code, name in LANGUAGE_CODE_TO_NAME_MAP.items()}
        lang_code_from_fn = next(
            (name_to_code[part.lower()] for part in stem_parts if part.lower() in name_to_code),
            None,
        )

    # The CSV's own whisper_lang column takes precedence over the filename.
    csv_lang_code = None
    if 'whisper_lang' in df.columns:
        unique_codes = df['whisper_lang'].dropna().unique()
        if len(unique_codes) > 1:
            logging.warning(f"Multiple whisper_lang codes in {filename}: {unique_codes}. Using first: {unique_codes[0]}")
        if len(unique_codes) >= 1:
            csv_lang_code = str(unique_codes[0]).lower()

    final_lang_code = csv_lang_code if csv_lang_code else lang_code_from_fn
    if final_lang_code:
        return final_lang_code, LANGUAGE_CODE_TO_NAME_MAP.get(final_lang_code, final_lang_code.capitalize())

    # No code at all: best-effort display name from the first filename part,
    # which is never 2 characters long at this point.
    display_language_name = "Unknown"
    name_candidate = stem_parts[0]
    if len(name_candidate) > 2:
        display_language_name = name_candidate.capitalize()
    elif name_candidate:
        display_language_name = LANGUAGE_CODE_TO_NAME_MAP.get(name_candidate.lower(), name_candidate.upper())
    logging.warning(f"Could not determine a definitive language code for {filename}. Attempting to use '{display_language_name}' based on filename parts.")
    return None, display_language_name


def _extract_ground_truth(df, filename):
    """Return the first non-null ground-truth string (stripped), or None if absent."""
    if GROUND_TRUTH_COLUMN_NAME not in df.columns:
        logging.warning(f"Ground truth column '{GROUND_TRUTH_COLUMN_NAME}' not found in {filename}. ")
        return None
    gts = df[GROUND_TRUTH_COLUMN_NAME].dropna()
    if gts.empty:
        logging.warning(f"Column '{GROUND_TRUTH_COLUMN_NAME}' found but contains no data for {filename}.")
        return None
    ground_truth_text = str(gts.iloc[0]).strip()
    if gts.astype(str).str.strip().nunique() > 1:
        logging.warning(
            f"Multiple distinct ground truth strings found in column '{GROUND_TRUTH_COLUMN_NAME}' "
            f"for {filename} even after stripping whitespace. Using the first one: '{ground_truth_text[:100]}...'"
        )
    return ground_truth_text


def _compute_rank_metrics(model_outputs_df, ground_truth_text, tokenizer_instance, current_max_rank, filename):
    """Return ``(avg_rank, median_rank, percent_found)``, special-casing missing/empty inputs."""
    if ground_truth_text is None:
        logging.warning(f"No ground truth text available for {filename}. Cannot calculate gold token ranks.")
        return np.nan, np.nan, np.nan
    if model_outputs_df.empty and not ground_truth_text:
        logging.info(f"Both model output and ground truth are empty for {filename}.")
        return 0.0, 0.0, 1.0
    if model_outputs_df.empty:
        logging.info(f"Model output is empty but ground truth exists for {filename}.")
        try:
            gold_ids = tokenizer_instance(ground_truth_text, add_special_tokens=False).input_ids
        except Exception as e_tok:
            logging.error(f"Tokenizer error on GT for empty model output {filename}: {e_tok}")
            return np.nan, np.nan, np.nan
        if gold_ids:
            # Every gold token is unmatched when the model produced nothing.
            return current_max_rank + 1.0, current_max_rank + 1.0, 0.0
        return 0.0, 0.0, 1.0
    avg_rank, median_rank, percent_found, _ = calculate_rank_of_gold_subtokens(
        model_outputs_df, ground_truth_text, tokenizer_instance, current_max_rank
    )
    return avg_rank, median_rank, percent_found


def _count_utterances(df, ground_truth_text, model_outputs_df):
    """Count utterances: rows with step == -1 if any, else 1 when there is any data."""
    num_utterances = 0
    if ground_truth_text is not None or not model_outputs_df.empty:
        num_utterances = 1
    if "step" in df.columns:
        marker_rows = int((df["step"] == -1).sum())
        if marker_rows > 0:
            num_utterances = marker_rows
    return num_utterances


def process_language_file(file_path, training_hours, tokenizer_instance, current_max_rank):
    """Process one ``*_subtoken_beam.csv`` file into a row of gold-token rank metrics.

    Args:
        file_path: path to the CSV.
        training_hours: mapping of lower-case language code -> training hours.
        tokenizer_instance: tokenizer used to split the ground truth into subtokens.
        current_max_rank: K, the number of candidate columns searched per step.

    Returns:
        dict with keys ``language``, ``code``, ``training_hours``,
        ``num_utterances``, ``avg_rank``, ``median_rank``, ``percent_found`` and
        ``resource_group``. Returns None if the file is empty/unreadable or any
        error occurs; errors are logged with a traceback.
    """
    try:
        filename = os.path.basename(file_path)

        df = safe_read_csv(file_path)
        if df.empty:
            logging.warning(f"Empty or invalid data for file: {filename}")
            return None

        final_lang_code, display_language_name = _infer_language(filename, df)
        logging.info(f"Processing {display_language_name} (Code: {final_lang_code if final_lang_code else 'N/A'}, File: {filename})...")

        ground_truth_text = _extract_ground_truth(df, filename)

        # Rows with step >= 0 are per-step model predictions.
        if 'step' in df.columns:
            model_outputs_df = df[df['step'] >= 0].copy()
        else:
            logging.warning(f"'step' column not found in {filename}. Assuming all rows are model outputs.")
            model_outputs_df = df.copy()
        if model_outputs_df.empty and 'step' in df.columns:
            logging.warning(f"No model output (prediction) rows found in {filename}.")

        avg_rank, median_rank, percent_found = _compute_rank_metrics(
            model_outputs_df, ground_truth_text, tokenizer_instance, current_max_rank, filename
        )
        num_utterances = _count_utterances(df, ground_truth_text, model_outputs_df)

        return {"language": display_language_name, "code": final_lang_code,
                "training_hours": training_hours.get(final_lang_code, np.nan) if final_lang_code else np.nan,
                "num_utterances": num_utterances, "avg_rank": avg_rank, "median_rank": median_rank,
                "percent_found": percent_found,
                "resource_group": get_resource_group(final_lang_code) if final_lang_code else "Other"}
    except Exception as e:
        logging.error(f"Error processing file {file_path}: {e}", exc_info=True)
        return None

def collect_all_results(tokenizer_instance, current_max_rank):
    """Run ``process_language_file`` over every ``*_subtoken_beam.csv`` in DATA_DIR.

    Training hours are loaded once and shared by all files. Files for which
    processing returns None are skipped. The combined table is also written to
    ``OUTPUT_DIR/language_gold_token_rank_metrics_FULL.csv``; a failed write is
    logged, not raised.

    Returns:
        DataFrame with one row per successfully processed file, or an empty
        DataFrame when no files are found or none yields a result.
    """
    hours_by_code = load_training_hours()
    csv_files = sorted(DATA_DIR.glob("*_subtoken_beam.csv"))
    if not csv_files:
        logging.error(f"No *_subtoken_beam.csv files found in {DATA_DIR}")
        return pd.DataFrame()

    logging.info(f"Found {len(csv_files)} files to process.")
    rows = []
    for csv_path in csv_files:
        row = process_language_file(csv_path, hours_by_code, tokenizer_instance, current_max_rank)
        if row is not None:
            rows.append(row)

    if not rows:
        logging.error("No valid results collected after processing all files.")
        return pd.DataFrame()

    combined = pd.DataFrame(rows)
    try:
        # "_FULL" suffix keeps this distinct from the per-language average-rank CSV.
        destination = OUTPUT_DIR / "language_gold_token_rank_metrics_FULL.csv"
        combined.to_csv(destination, index=False, float_format='%.4f')
        logging.info(f"Saved full metrics to {destination}")
    except Exception as e:
        logging.error(f"Failed to save full metrics: {e}")
    return combined

# ────────────────────────────────────────────────────────────
# NEW function to save average rank CSV
# ────────────────────────────────────────────────────────────
def save_average_rank_csv(results_df):
    """Write ``average_gold_token_rank_by_language.csv`` to OUTPUT_DIR.

    Keeps the ``language``, ``code``, ``avg_rank`` and ``resource_group``
    columns, drops rows with no ``avg_rank``, and sorts ascending by
    ``avg_rank`` (best first). Nothing is written, and a warning is logged, if
    the input is empty, lacks ``avg_rank``, or has no non-NaN ranks. Write
    errors are logged rather than raised.
    """
    if results_df.empty or 'avg_rank' not in results_df.columns:
        logging.warning("Results DataFrame is empty or 'avg_rank' column is missing. Cannot save average rank CSV.")
        return

    ranked = (
        results_df.loc[:, ['language', 'code', 'avg_rank', 'resource_group']]
        .dropna(subset=['avg_rank'])
        .sort_values(by="avg_rank", ascending=True)
    )
    if ranked.empty:
        logging.warning("No data to save for average rank CSV after filtering NaNs.")
        return

    try:
        destination = OUTPUT_DIR / "average_gold_token_rank_by_language.csv"
        ranked.to_csv(destination, index=False, float_format='%.4f')
        logging.info(f"Successfully saved average rank data to {destination}")
    except Exception as e:
        logging.error(f"Failed to save average rank CSV: {e}")

# ────────────────────────────────────────────────────────────
# Plotting functions
# ────────────────────────────────────────────────────────────
def plot_avg_rank_vs_hours(results_df):
    """Scatter average gold-token rank against Whisper training hours.

    One point per language on a log-scaled x-axis, coloured and shaped by
    resource group and labelled with the language name. Rows whose code is in
    EXCLUDED_LANGUAGES, or that lack ``avg_rank`` / ``training_hours`` /
    ``language``, are dropped. If at least two points have finite, positive
    hours, two extra steps run: the Pearson correlation between log10(hours)
    and avg_rank is logged and printed, and a least-squares line in
    log10(hours) is drawn. The figure is saved to OUTPUT_DIR as PNG, PDF and
    SVG.
    """
    frame = results_df.copy()
    if 'code' not in frame.columns:
        frame['code'] = None

    included = frame[~frame["code"].isin(EXCLUDED_LANGUAGES)]
    df = included.dropna(subset=["avg_rank", "training_hours", "language"])
    if df.empty:
        logging.warning("No valid data for average gold rank plot after filtering.")
        return

    fig, ax = plt.subplots(figsize=(5.5, 4.0))

    for group_name, members in df.groupby("resource_group"):
        ax.scatter(members["training_hours"], members["avg_rank"],
                   s=40, alpha=0.7,
                   marker=MARKERS.get(group_name, MARKERS["Other"]),
                   color=COLOURS.get(group_name, COLOURS["Other"]),
                   label=f"{group_name}-resource")
        for lang, hours, rank in zip(members["language"], members["training_hours"], members["avg_rank"]):
            ax.annotate(lang, (hours, rank), xytext=(3, 1), textcoords="offset points",
                        fontsize=6, alpha=0.9)

    # Correlation and trend line need >= 2 points with finite, positive hours
    # (positive because the fit is done in log10 space).
    if len(df) >= 2:
        usable = (np.isfinite(df["training_hours"]) & (df["training_hours"] > 0)
                  & np.isfinite(df["avg_rank"]))
        if sum(usable) >= 2:
            usable_hours = df.loc[usable, "training_hours"]
            x_log = np.log10(usable_hours)
            y_vals = df.loc[usable, "avg_rank"]
            try:
                r, p_value = pearsonr(x_log, y_vals)
                summary = f"Correlation (log(training_hours) vs avg_rank): r={r:.3f}, p-value={p_value:.3e}"
                logging.info(summary)
                print(f"INFO: {summary}")  # also echo to stdout for the terminal/notebook

                fit = np.poly1d(np.polyfit(x_log, y_vals, 1))
                x_min, x_max = usable_hours.min(), usable_hours.max()

                if x_min > 0 and x_max > 0 and x_min < x_max:
                    x_line = np.logspace(np.log10(max(x_min * 0.9, 0.1)), np.log10(x_max * 1.1), 100)
                    ax.plot(x_line, fit(np.log10(x_line)), 'k--', alpha=0.7, linewidth=1,
                            label=f"Trend (r={r:.2f})")
                elif x_min == x_max and x_min > 0:
                    # A single distinct x value gives no meaningful slope; only
                    # the correlation reported above is kept.
                    logging.warning(f"Only one distinct training hour value after log transformation for trend line: {x_min}. Trend line may not be representative.")
                else:
                    logging.warning("Could not plot trend line due to insufficient range or invalid values in training hours.")
            except Exception as e_trend:
                logging.warning(f"Could not compute trend line or correlation: {e_trend}")
                print(f"WARNING: Could not compute trend line or correlation: {e_trend}")

    ax.set_xscale("log")
    ax.set_xlabel("Whisper training hours (log scale)")
    ax.set_ylabel("Average Rank of Correct Token")
    rank_ceiling = MAX_RANK + 2
    y_max_val = df["avg_rank"].max()
    padded_top = y_max_val * 1.1 if pd.notna(y_max_val) and y_max_val > 0 else rank_ceiling
    ax.set_ylim(0.9, min(rank_ceiling, padded_top))
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.legend()
    fig.tight_layout(pad=0.5)

    stem = f"avg_gold_rank_vs_hours_top{MAX_RANK}"
    try:
        plt.savefig(OUTPUT_DIR / f"{stem}.png")
        plt.savefig(OUTPUT_DIR / f"{stem}.pdf")
        plt.savefig(OUTPUT_DIR / f"{stem}.svg", format="svg")
        logging.info(f"Saved gold rank vs hours plot to {OUTPUT_DIR}")
    except Exception as e:
        logging.error(f"Failed to save gold rank vs hours plot: {e}")
    plt.close(fig)

def plot_ranks_by_language(results_df):
    """Horizontal bar chart of average gold-token rank per language.

    Rows whose code is in EXCLUDED_LANGUAGES, or that lack ``avg_rank`` /
    ``median_rank`` / ``language``, are dropped. Bars are sorted by
    ``avg_rank`` ascending (lowest rank at y=0), coloured by resource group and
    labelled with their value. The figure is saved to OUTPUT_DIR as PNG and PDF.
    """
    frame = results_df.copy()
    if 'code' not in frame.columns:
        frame['code'] = None

    kept = frame[~frame["code"].isin(EXCLUDED_LANGUAGES)]
    df = kept.dropna(subset=["avg_rank", "median_rank", "language"])
    if df.empty:
        logging.warning("No valid data for ranks by language plot after filtering.")
        return
    df = df.sort_values(by="avg_rank")

    # Figure height grows with the number of languages.
    fig, ax = plt.subplots(figsize=(4.5, max(2.5, 0.18 * len(df))))

    y_pos = np.arange(len(df))
    ax.barh(y_pos, df["avg_rank"], height=0.4,
            color=[COLOURS.get(grp, COLOURS["Other"]) for grp in df["resource_group"]],
            alpha=0.8, label="Avg Gold Rank")

    label_offset = 0.05 * df["avg_rank"].max()
    for i, value in enumerate(df["avg_rank"]):
        ax.text(value + label_offset, i, f"{value:.2f}", va="center", fontsize=6)

    ax.set_yticks(y_pos)
    ax.set_yticklabels([
        f"{lang} ({code})" if pd.notna(code) and code != "" else f"{lang}"
        for lang, code in zip(df["language"], df["code"])
    ])
    ax.set_xlabel(f"Rank of Gold Token in Top-{MAX_RANK} Candidates (lower is better)")
    ax.set_title(f"Average Rank of Gold Token in Candidates by Language (Top-{MAX_RANK})")
    max_val = df["avg_rank"].max()
    right_edge = (max_val * 1.2) if pd.notna(max_val) and max_val > 0 else (MAX_RANK + 2)
    ax.set_xlim(0.8, right_edge)
    ax.grid(axis="x", linestyle="--", alpha=0.5)

    from matplotlib.lines import Line2D

    # Legend: one swatch per group actually plotted, in COLOURS order, with
    # "Other" always last.
    present = set(df["resource_group"].unique())
    groups = [grp for grp in COLOURS if grp in present and grp != "Other"]
    if "Other" in present and "Other" in COLOURS:
        groups.append("Other")
    handles = [
        Line2D([0], [0], marker='s', color='w', label=f'{grp}-resource',
               markerfacecolor=COLOURS[grp], markersize=8)
        for grp in groups
    ]
    if handles:
        ax.legend(handles=handles, loc="lower right", title="Resource Groups")
    fig.tight_layout(pad=0.5)

    stem = f"gold_ranks_by_language_top{MAX_RANK}"
    try:
        plt.savefig(OUTPUT_DIR / f"{stem}.png")
        plt.savefig(OUTPUT_DIR / f"{stem}.pdf")
        logging.info(f"Saved gold ranks by language plot to {OUTPUT_DIR}")
    except Exception as e:
        logging.error(f"Failed to save gold ranks by language plot: {e}")
    plt.close(fig)

# ────────────────────────────────────────────────────────────
# Main execution
# ────────────────────────────────────────────────────────────
def main():
    """Run the gold-token rank analysis end to end.

    Steps: log the configuration, process all result files, write the
    average-rank CSV, and, if any language has a valid average rank, draw both
    plots. Returns early (after logging) when the tokenizer lacks ``encode``,
    when no results are collected, or when no ranks could be computed.
    """
    startup_messages = (
        "Starting GOLD TOKEN rank analysis",
        f"Data directory: {DATA_DIR}",
        f"Training hours file: {TRAINING_HOURS_CSV}",
        f"Output directory: {OUTPUT_DIR}",
        f"Using Whisper Tokenizer for model: {MODEL_NAME}",
        f"Max rank (K in Top-K) to check: {MAX_RANK}",
        f"Attempting to read ground truth from CSV column: '{GROUND_TRUTH_COLUMN_NAME}'",
    )
    for message in startup_messages:
        logging.info(message)

    if not hasattr(GLOBAL_TOKENIZER, 'encode'):
        logging.error("GLOBAL_TOKENIZER is not properly initialized. Exiting.")
        return

    results_df = collect_all_results(GLOBAL_TOKENIZER, MAX_RANK)
    if results_df.empty:
        logging.error("No usable results found after processing all files. Exiting.")
        return

    save_average_rank_csv(results_df)

    has_any_rank = results_df["avg_rank"].notna().any()
    if not has_any_rank:
        logging.warning("No valid rank data was calculated. Plots will not be generated. "
                        "Check logs for issues with ground truth or data processing.")
        return

    logging.info("Creating plots based on GOLD token ranks...")
    plot_avg_rank_vs_hours(results_df)
    plot_ranks_by_language(results_df)
    logging.info(f"Analysis complete! Results are in {OUTPUT_DIR}")

if __name__ == "__main__":
    main()

INFO: Successfully loaded WhisperTokenizer for model: openai/whisper-base
INFO: Starting GOLD TOKEN rank analysis
INFO: Data directory: results_beam_600s
INFO: Training hours file: whisper_training_hours.csv
INFO: Output directory: analysis_results_beam
INFO: Using Whisper Tokenizer for model: openai/whisper-base
INFO: Max rank (K in Top-K) to check: 50
INFO: Attempting to read ground truth from CSV column: 'ground_truth'
INFO: Loaded training hours for 90 languages
INFO: Found 28 files to process.
INFO: Processing German (Code: de, File: 003_german_subtoken_beam.csv)...
INFO: Processing Spanish (Code: es, File: 004_spanish_subtoken_beam.csv)...
INFO: Processing French (Code: fr, File: 007_french_subtoken_beam.csv)...
INFO: Processing Portuguese (Code: pt, File: 009_portuguese_subtoken_beam.csv)...
INFO: Processing Turkish (Code: tr, File: 010_turkish_subtoken_beam.csv)...
INFO: Processing Italian (Code: it, File: 011_italian_subtoken_beam.csv)...
INFO: Processing Dutch (Code: nl, File

INFO: Correlation (log(training_hours) vs avg_rank): r=-0.355, p-value=1.246e-01


INFO: Saved gold rank vs hours plot to analysis_results_beam
INFO: Saved gold ranks by language plot to analysis_results_beam
INFO: Analysis complete! Results are in analysis_results_beam
