In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Whisper Analysis: Average Confidence of Chosen Token vs. Training Hours
=======================================================================

Generates a plot of average model confidence (probability of the chosen token)
against Whisper training hours for various languages.
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import logging
from scipy.stats import pearsonr

# ────────────────────────────────────────────────────────────
# Configuration
# ────────────────────────────────────────────────────────────
# Directory scanned (non-recursively) for per-language "*_subtoken_beam.csv" files.
DATA_DIR = Path("results_beam_600s")
# CSV read by load_training_hours(); needs "Whisper_Code" and
# "Whisper_Training_Hours" columns.
TRAINING_HOURS_CSV = Path("whisper_training_hours.csv")
# Destination for the metrics CSV and the PNG/PDF/SVG plot.
OUTPUT_DIR = Path("analysis_results_beam")
# NOTE: created at import time as a side effect; main() creates it again.
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# Language resource groups (lower-case Whisper codes). A code found in none of
# these sets (e.g. "da") is reported as "Other" by get_resource_group().
HIGH_RESOURCE = {"de", "es", "fr", "pt", "tr"}
MEDIUM_RESOURCE = {"it", "nl", "sv", "ca", "fi", "id", "vi", "ro",
                   "no", "cs", "hu", "yo"}
LOW_RESOURCE = {"cy", "lt", "lv", "az", "et", "eu", "sq", "sw", "mt", "uz"}

# Codes removed from the plot only; they are still written to the metrics CSV.
EXCLUDED_LANGUAGES = {"uz", "mt", "sw", "sq", "yo", "da", "vi", "cs"}

# Display names used for plot annotations; codes missing here fall back to the
# capitalised code itself.
LANGUAGE_CODE_TO_NAME_MAP = {
    "de": "German", "es": "Spanish", "fr": "French", "pt": "Portuguese", "tr": "Turkish",
    "it": "Italian", "nl": "Dutch", "sv": "Swedish", "ca": "Catalan", "fi": "Finnish",
    "id": "Indonesian", "vi": "Vietnamese", "ro": "Romanian", "da": "Danish",
    "no": "Norwegian", "cs": "Czech", "hu": "Hungarian", "yo": "Yoruba",
    "cy": "Welsh", "lt": "Lithuanian", "lv": "Latvian", "az": "Azerbaijani",
    "et": "Estonian", "eu": "Basque", "sq": "Albanian", "sw": "Swahili",
    "mt": "Maltese", "uz": "Uzbek",
}

# Visual styling, keyed by the group labels returned by get_resource_group().
COLOURS = {"High": "steelblue", "Medium": "seagreen",
           "Low": "crimson", "Other": "grey"}
MARKERS = {"High": "o", "Medium": "^", "Low": "s", "Other": "x"}

# Small fonts suited to the compact 4.5 x 3.5 inch figure; files are saved at 300 dpi.
plt.rcParams.update({
    "figure.dpi": 100,
    "savefig.dpi": 300,
    "font.size": 8,
    "axes.titlesize": 10,
    "axes.labelsize": 8,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
    "figure.titlesize": 12
})

# NOTE: basicConfig() does nothing when the root logger already has handlers
# (e.g. inside a notebook kernel), in which case this format is not applied.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")


# ────────────────────────────────────────────────────────────
# Utility functions
# ────────────────────────────────────────────────────────────
def get_resource_group(code):
    """Return the resource-tier label ("High", "Medium", "Low" or "Other") for *code*.

    Tiers are tested in the order High, Medium, Low; a code that belongs to
    none of them (including ``None``) is labelled "Other".
    """
    tiers = (
        ("High", HIGH_RESOURCE),
        ("Medium", MEDIUM_RESOURCE),
        ("Low", LOW_RESOURCE),
    )
    return next((label for label, members in tiers if code in members), "Other")


def load_training_hours(csv_path=None):
    """Load Whisper training hours per language code from a CSV file.

    The file must contain ``Whisper_Code`` and ``Whisper_Training_Hours``
    columns. Rows whose hours are missing or equal to ``"Unknown"`` are
    dropped silently; rows with a blank/missing code or hours that cannot be
    parsed as a number are skipped with a warning.

    Args:
        csv_path: CSV file to read. Defaults to ``TRAINING_HOURS_CSV``.

    Returns:
        dict mapping the lower-cased, whitespace-stripped language code to its
        training hours as ``float``. An empty dict is returned (and an error
        logged) if the file is missing, unreadable, or lacks a required column.
    """
    path = TRAINING_HOURS_CSV if csv_path is None else Path(csv_path)

    # Keep the try narrow: only reading the file is expected to fail here.
    try:
        df = pd.read_csv(path)
    except FileNotFoundError:
        logging.error(f"Training hours file not found: {path}")
        return {}
    except Exception as e:
        logging.error(f"Failed to load training hours: {e}")
        return {}

    required = ("Whisper_Code", "Whisper_Training_Hours")
    missing = [col for col in required if col not in df.columns]
    if missing:
        logging.error(f"Failed to load training hours: missing column(s) {missing} in {path}")
        return {}

    hours_col = df["Whisper_Training_Hours"]
    df = df[hours_col.notna() & (hours_col != "Unknown")]

    hours_dict = {}
    for _, row in df.iterrows():
        raw_code = row["Whisper_Code"]
        # A NaN code would otherwise become the literal key "nan".
        code = "" if pd.isna(raw_code) else str(raw_code).strip().lower()
        if not code:
            logging.warning(f"Skipping row due to missing language code: {row.to_dict()}")
            continue
        try:
            hours = float(row["Whisper_Training_Hours"])
        except (TypeError, ValueError):
            logging.warning(
                f"Skipping row for '{code}' due to invalid hours: {row['Whisper_Training_Hours']!r}"
            )
            continue
        hours_dict[code] = hours

    logging.info(f"Loaded training hours for {len(hours_dict)} languages")
    return hours_dict


def safe_read_csv(path):
    """Read *path* with pandas, returning an empty DataFrame if it cannot be read.

    A missing file and any other read error are both logged as warnings
    rather than raised.
    """
    frame = pd.DataFrame()
    try:
        frame = pd.read_csv(path)
    except FileNotFoundError:
        logging.warning(f"File not found: {path}")
    except Exception as err:
        logging.warning(f"Failed to read {path}: {err}")
    return frame


# ────────────────────────────────────────────────────────────
# Core analysis functions
# ────────────────────────────────────────────────────────────
def calculate_avg_confidence(df):
    """Return the mean ``chosen_prob`` over the prediction rows of *df*.

    Prediction rows are those with ``step >= 0``. If there is no ``step``
    column, every row is used (with a warning). Returns ``np.nan`` when no
    prediction rows remain, when ``chosen_prob`` is absent, or when all its
    values are missing.
    """
    has_step = 'step' in df.columns
    if not has_step:
        logging.warning("No 'step' column found. Using all rows for confidence calculation.")
    rows = df[df['step'] >= 0].copy() if has_step else df.copy()

    if rows.empty:
        logging.debug("No prediction rows found for confidence calculation")
        return np.nan
    if 'chosen_prob' not in rows.columns:
        logging.warning("Missing 'chosen_prob' column. Cannot calculate confidence.")
        return np.nan

    probs = rows['chosen_prob'].dropna()
    if probs.empty:
        logging.debug("No valid 'chosen_prob' values found.")
        return np.nan
    return probs.mean()


def _lang_code_from_filename(stem_parts):
    """Return the lower-cased two-letter code at the start of a filename stem, or None.

    The first ``_``-separated token counts as a code only if it is two
    non-digit characters and is either the whole stem or followed by one of
    the known suffix tokens ("subtoken", "beam", "wer", "results").
    """
    first = stem_parts[0] if stem_parts else ""
    if len(first) != 2 or first.isdigit():
        return None
    if len(stem_parts) > 1 and stem_parts[1] not in ("subtoken", "beam", "wer", "results"):
        return None
    return first.lower()


def _lang_code_from_csv(df, filename):
    """Return the code in the ``whisper_lang`` column (stripped, lower-cased), or None.

    Blank values are ignored. If several distinct codes are present, the first
    one encountered is used and a warning is logged.
    """
    if 'whisper_lang' not in df.columns:
        return None
    codes = df['whisper_lang'].dropna().astype(str).str.strip().str.lower()
    unique_codes = codes[codes != ""].unique()
    if len(unique_codes) == 0:
        return None
    if len(unique_codes) > 1:
        logging.warning(f"Multiple whisper_lang codes in {filename}: {unique_codes}. Using first: {unique_codes[0]}")
    return unique_codes[0]


def _display_language_name(code, stem_parts):
    """Pick a human-readable language name for labels.

    Prefers the mapped name for *code*, then the capitalised code. Without a
    code, it falls back to a non-numeric filename token, and finally to "Unknown".
    """
    if code:
        return LANGUAGE_CODE_TO_NAME_MAP.get(code, code.capitalize())
    if stem_parts and len(stem_parts[0]) > 2 and not stem_parts[0].isdigit():
        return stem_parts[0].capitalize()
    if len(stem_parts) > 1 and not stem_parts[1].isdigit():
        return stem_parts[1].capitalize()
    return "Unknown"


def _count_utterances(df):
    """Count utterances as rows with ``step == -1``.

    A non-empty frame without such marker rows counts as one utterance
    (presumably a single-utterance file — confirm against the CSV producer).
    An empty frame counts as zero.
    """
    if "step" in df.columns:
        marker_rows = int((df["step"] == -1).sum())
        if marker_rows:
            return marker_rows
    return 0 if df.empty else 1


def process_language_file(file_path, training_hours):
    """Summarise one per-language CSV into a metrics record.

    The language code comes from the file's ``whisper_lang`` column when
    present, otherwise from the filename (e.g. ``de_subtoken_beam.csv``).

    Args:
        file_path: Path of the CSV to read.
        training_hours: Mapping of language code -> Whisper training hours.

    Returns:
        dict with keys ``language``, ``code``, ``training_hours``,
        ``num_utterances``, ``avg_confidence`` and ``resource_group``.
        Returns ``None`` if the file is empty/unreadable or if processing
        raises (the error is logged with a traceback).
    """
    try:
        filename = Path(file_path).name
        stem_parts = Path(filename).stem.split('_')
        lang_code_from_fn = _lang_code_from_filename(stem_parts)

        df = safe_read_csv(file_path)
        if df.empty:
            logging.warning(f"Empty or invalid data from {filename}")
            return None

        # The CSV's own language column wins over whatever the filename suggests.
        final_lang_code = _lang_code_from_csv(df, filename) or lang_code_from_fn
        display_language_name = _display_language_name(final_lang_code, stem_parts)

        logging.debug(f"Processing {display_language_name} (Code: {final_lang_code if final_lang_code else 'N/A'}) from {filename}...")

        avg_confidence = calculate_avg_confidence(df)
        num_utterances = _count_utterances(df)

        return {
            "language": display_language_name,
            "code": final_lang_code,
            "training_hours": training_hours.get(final_lang_code, np.nan) if final_lang_code else np.nan,
            "num_utterances": num_utterances,
            "avg_confidence": avg_confidence,
            "resource_group": get_resource_group(final_lang_code) if final_lang_code else "Other",
        }
    except Exception as e:
        # Per-file boundary: one bad file must not abort the whole collection.
        logging.error(f"Error processing {file_path}: {e}", exc_info=True)
        return None


def collect_all_results():
    """Build the per-language metrics table from every ``*_subtoken_beam.csv`` in DATA_DIR.

    Training hours are loaded first. Each matching file (in sorted order) is
    summarised with process_language_file(), and the table is written to
    ``OUTPUT_DIR / "language_confidence_metrics.csv"``; a failed write is
    logged, not raised.

    Returns:
        The metrics DataFrame, or an empty DataFrame when no file matches or
        no file yields a record.
    """
    hours_by_code = load_training_hours()
    beam_files = sorted(DATA_DIR.glob("*_subtoken_beam.csv"))

    if not beam_files:
        logging.error(f"No relevant CSV files found in {DATA_DIR} (e.g., *_subtoken_beam.csv).")
        # Help diagnose naming mismatches by listing up to ten CSVs that do exist.
        try:
            other_csvs = list(DATA_DIR.glob("*.csv"))
            if not other_csvs:
                logging.info(f"Directory {DATA_DIR} contains no CSV files.")
            else:
                preview = ', '.join(f.name for f in other_csvs[:10])
                logging.info(f"Found these CSV files: {preview}...")
        except Exception as e:
            logging.error(f"Error listing directory contents: {e}")
        return pd.DataFrame()

    logging.info(f"Found {len(beam_files)} potential language CSV files.")

    summaries = (process_language_file(path, hours_by_code) for path in beam_files)
    records = [summary for summary in summaries if summary]

    if not records:
        logging.error("No valid results collected from CSV files.")
        return pd.DataFrame()

    metrics = pd.DataFrame(records)
    try:
        destination = OUTPUT_DIR / "language_confidence_metrics.csv"
        metrics.to_csv(destination, index=False, float_format='%.4f')
        logging.info(f"Saved metrics to {destination}")
    except Exception as e:
        logging.error(f"Failed to save metrics CSV: {e}")

    return metrics


# ────────────────────────────────────────────────────────────
# Plotting functions
# ────────────────────────────────────────────────────────────
def _scatter_resource_groups(ax, df):
    """Draw one scatter series per resource group and annotate each point with its language."""
    for group_name in ("High", "Medium", "Low", "Other"):
        group_df = df[df["resource_group"] == group_name]
        if group_df.empty:
            continue

        ax.scatter(
            group_df["training_hours"],
            group_df["avg_confidence"],
            s=40,
            alpha=0.7,
            marker=MARKERS.get(group_name, "x"),
            color=COLOURS.get(group_name, "grey"),
            label=f"{group_name}-resource",
        )

        if 'language' not in group_df.columns:
            logging.warning("'language' column missing, skipping annotations.")
            continue
        for _, row in group_df.iterrows():
            ax.annotate(
                row["language"],
                (row["training_hours"], row["avg_confidence"]),
                xytext=(3, 1),
                textcoords="offset points",
                fontsize=6,
            )


def _add_trend_placeholder(ax, label):
    """Add a data-less dashed line so the legend still reports the trend status."""
    ax.plot([], [], 'k--', linewidth=1, alpha=0.7, label=label)


def _add_trend_line(ax, df):
    """Fit avg_confidence against log10(training_hours), draw the fit and report Pearson r.

    Only rows with positive, finite hours and finite confidence are used. With
    fewer than two such rows, or if fitting fails, a placeholder legend entry
    is drawn instead.
    """
    if len(df) < 2:
        _add_trend_placeholder(ax, "Trend (insufficient data)")
        return

    hours = df["training_hours"].astype(float)
    confidence = df["avg_confidence"]
    valid_mask = (hours > 0) & np.isfinite(hours) & np.isfinite(confidence)
    if valid_mask.sum() < 2:
        logging.warning("Not enough valid data points to calculate trend line for confidence plot.")
        _add_trend_placeholder(ax, "Trend (insufficient data)")
        return

    valid_hours = hours[valid_mask]
    x_log_values = np.log10(valid_hours)
    y_values = confidence[valid_mask]

    try:
        r_val, p_value = pearsonr(x_log_values, y_values)
        logging.info(f"Correlation (log(training_hours) vs avg_confidence): r={r_val:.3f}, p-value={p_value:.3e}")
        print(f"INFO: Correlation (log(training_hours) vs avg_confidence): r={r_val:.3f}, p-value={p_value:.3e}")

        p_polyfit = np.poly1d(np.polyfit(x_log_values, y_values, 1))
        trend_label = f"Trend (r={r_val:.2f}, p={p_value:.2e})"

        x_min, x_max = valid_hours.min(), valid_hours.max()
        if x_min < x_max:
            # Sample evenly in log space so the line is straight on the log x-axis.
            x_line = np.logspace(np.log10(x_min), np.log10(x_max), 100)
            ax.plot(x_line, p_polyfit(np.log10(x_line)), 'k--', linewidth=1, alpha=0.7, label=trend_label)
        else:
            _add_trend_placeholder(ax, trend_label)
    except Exception as e:
        logging.warning(f"Error fitting trend line or calculating correlation: {e}")
        _add_trend_placeholder(ax, "Trend (error)")


def _apply_confidence_ylim(ax, confidence):
    """Zoom the y-axis to the data range padded by 0.05 and clamped to [0, 1]; fall back to (0, 1)."""
    lowest, highest = confidence.min(), confidence.max()
    y_min_plot = max(0, lowest - 0.05) if pd.notna(lowest) else 0
    y_max_plot = min(1, highest + 0.05) if pd.notna(highest) else 1
    if y_min_plot < y_max_plot:
        ax.set_ylim(y_min_plot, y_max_plot)
    else:
        ax.set_ylim(0, 1)


def _save_confidence_figure(fig):
    """Save *fig* as PNG, PDF and SVG in OUTPUT_DIR; failures are logged, not raised."""
    try:
        fig.savefig(OUTPUT_DIR / "avg_confidence_vs_hours.png")
        fig.savefig(OUTPUT_DIR / "avg_confidence_vs_hours.pdf")
        fig.savefig(OUTPUT_DIR / "avg_confidence_vs_hours.svg", format="svg")
        logging.info(f"Saved confidence plot to {OUTPUT_DIR} (PNG, PDF, SVG)")
    except Exception as e:
        logging.error(f"Failed to save confidence plot: {e}")


def plot_confidence_vs_hours(results_df):
    """Plot average chosen-token confidence against Whisper training hours (log x-axis).

    Languages in EXCLUDED_LANGUAGES are dropped. Rows missing either value
    are also dropped, as are rows with non-positive hours, which cannot be
    placed on a log axis. Points are coloured by resource group and a
    log-linear trend line with Pearson r is added.

    Args:
        results_df: DataFrame with ``avg_confidence`` and ``training_hours``
            columns, and optionally ``code``, ``resource_group`` and ``language``.

    Side effects:
        Writes the plot as PNG, PDF and SVG into OUTPUT_DIR. Nothing is
        returned.
    """
    if 'code' not in results_df.columns:
        logging.warning("'code' column not found in results_df. Skipping language exclusion for plotting.")
        df = results_df.copy()
    else:
        df = results_df[~results_df["code"].isin(EXCLUDED_LANGUAGES)].copy()

    df = df.dropna(subset=["avg_confidence", "training_hours"])

    non_positive = df["training_hours"] <= 0
    if non_positive.any():
        logging.warning(
            f"Dropping {int(non_positive.sum())} language(s) with non-positive training hours "
            f"(cannot be shown on a log axis)."
        )
        df = df[~non_positive]

    if df.empty:
        logging.warning("No valid data for average confidence plot after filtering and dropping NaNs.")
        return

    if 'resource_group' not in df.columns:
        logging.warning("'resource_group' column not found. Plotting all points as 'Other'.")
        df['resource_group'] = 'Other'

    fig, ax = plt.subplots(figsize=(4.5, 3.5))
    try:
        _scatter_resource_groups(ax, df)
        _add_trend_line(ax, df)

        ax.set_xscale("log")
        ax.set_xlabel("Whisper Training Hours (log scale)")
        ax.set_ylabel("Average Confidence of Chosen Token")
        _apply_confidence_ylim(ax, df["avg_confidence"])

        ax.grid(True, linestyle='--', alpha=0.5)
        ax.legend()
        fig.tight_layout(pad=0.5)

        _save_confidence_figure(fig)
    finally:
        # Close even if drawing raised, so pyplot does not keep the figure alive.
        plt.close(fig)


# ────────────────────────────────────────────────────────────
# Main execution
# ────────────────────────────────────────────────────────────
def main():
    """Run the pipeline: collect per-language metrics, then plot confidence vs. hours.

    Stops early, after logging why, when no results were collected or when no
    language has a usable average confidence.
    """
    for message in (
        "Starting average confidence analysis",
        f"Data directory: {DATA_DIR}",
        f"Training hours file: {TRAINING_HOURS_CSV}",
        f"Output directory: {OUTPUT_DIR}",
    ):
        logging.info(message)

    OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

    results_df = collect_all_results()
    if results_df.empty:
        logging.error("No usable results found from data collection. Exiting.")
        return
    if not results_df["avg_confidence"].notna().any():
        logging.warning("No valid average confidence data was found.")
        return

    logging.info("Creating confidence plot...")
    plot_confidence_vs_hours(results_df)
    logging.info("Confidence analysis complete!")


if __name__ == "__main__":
    main()


2025-05-20 00:52:14,380 [INFO] Starting average confidence analysis
2025-05-20 00:52:14,381 [INFO] Data directory: results_beam_600s
2025-05-20 00:52:14,382 [INFO] Training hours file: whisper_training_hours.csv
2025-05-20 00:52:14,382 [INFO] Output directory: analysis_results_beam
2025-05-20 00:52:14,390 [INFO] Loaded training hours for 90 languages
2025-05-20 00:52:14,391 [INFO] Found 28 potential language CSV files.
2025-05-20 00:52:16,278 [INFO] Saved metrics to analysis_results_beam/language_confidence_metrics.csv
2025-05-20 00:52:16,279 [INFO] Creating confidence plot...
2025-05-20 00:52:16,299 [INFO] Correlation (log(training_hours) vs avg_confidence): r=0.314, p-value=1.778e-01


INFO: Correlation (log(training_hours) vs avg_confidence): r=0.314, p-value=1.778e-01


2025-05-20 00:52:17,114 [INFO] Saved confidence plot to analysis_results_beam (PNG, PDF, SVG)
2025-05-20 00:52:17,115 [INFO] Confidence analysis complete!
