In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Whisper Analysis: Average Token Entropy vs. Training Hours
===========================================================

Generates a plot of average token entropy (uncertainty)
against Whisper training hours for various languages.
Entropy is calculated over the top K_H probabilities.
"""

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import logging

# ────────────────────────────────────────────────────────────
# Configuration
# ────────────────────────────────────────────────────────────
# Directory scanned by collect_all_results for "*_subtoken_beam.csv" files.
DATA_DIR = Path("results_beam_600s")
# CSV read by load_training_hours; it is expected to provide the
# "Whisper_Code" and "Whisper_Training_Hours" columns.
TRAINING_HOURS_CSV = Path("whisper_training_hours.csv")
# Destination for the metrics CSV and the rendered figures.
OUTPUT_DIR = Path("analysis_results_beam")
# NOTE: executed at import time, so merely importing this module creates the
# directory. main() repeats the same call.
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

K_H = 50  # Number of top candidates for entropy calculation

# Language resource groups (Whisper language codes). get_resource_group maps
# any code outside these three sets to "Other".
HIGH_RESOURCE = {"de", "es", "fr", "pt", "tr"}
MEDIUM_RESOURCE = {"it", "nl", "sv", "ca", "fi", "id", "vi", "ro",
                   "no", "cs", "hu"}
LOW_RESOURCE = {"cy", "lt", "lv", "az", "et", "eu", "sq", "sw", "mt", "uz"}

# Codes dropped from the plot only (see plot_entropy_vs_hours); they are still
# written to the metrics CSV by collect_all_results. The set overlaps the
# groups above: "vi" and "cs" are Medium, "uz", "mt", "sw" and "sq" are Low,
# while "yo" and "da" belong to no group.
EXCLUDED_LANGUAGES = {"uz", "mt", "sw", "sq", "yo", "da", "vi", "cs"}

# Visual styling: scatter colour and marker per resource-group label, keyed by
# the strings returned by get_resource_group.
COLOURS = {"High": "steelblue", "Medium": "seagreen",
           "Low": "crimson", "Other": "grey"}
MARKERS = {"High": "o", "Medium": "^", "Low": "s", "Other": "x"}

# Global matplotlib defaults: small fonts for the 4.5 x 3.5 inch figure built
# in plot_entropy_vs_hours, and a higher dpi for saved files than for display.
plt.rcParams.update({
    "figure.dpi": 100,
    "savefig.dpi": 300,
    "font.size": 8,
    "axes.titlesize": 10,
    "axes.labelsize": 8,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
    "figure.titlesize": 12
})

# NOTE(review): basicConfig is a no-op when the root logger already has
# handlers (as is typically the case inside a notebook kernel), so this format
# may not be the one actually used -- confirm in the target environment.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")


# ────────────────────────────────────────────────────────────
# Utility functions
# ────────────────────────────────────────────────────────────
def get_resource_group(code):
    """Map a language code to its resource-group label.

    The module-level sets are consulted in the order High, Medium, Low and
    the label of the first set containing ``code`` is returned. Any code that
    is in none of them (including ``None``) yields ``"Other"``.
    """
    group_table = (
        ("High", HIGH_RESOURCE),
        ("Medium", MEDIUM_RESOURCE),
        ("Low", LOW_RESOURCE),
    )
    for label, members in group_table:
        if code in members:
            return label
    return "Other"


def load_training_hours():
    """Load per-language Whisper training hours from ``TRAINING_HOURS_CSV``.

    Rows whose ``Whisper_Training_Hours`` cell is missing or equal to the
    literal string ``"Unknown"`` are dropped up front. Each remaining row is
    skipped (with a warning) if its ``Whisper_Code`` is missing or its hours
    value cannot be converted to ``float``. When a code occurs more than once
    the last row wins.

    Returns:
        dict: ``Whisper_Code`` value -> training hours as ``float``. An empty
        dict is returned (and an error logged) if the file does not exist or
        cannot be read/parsed.
    """
    try:
        df = pd.read_csv(TRAINING_HOURS_CSV)
        df = df[df["Whisper_Training_Hours"].notna()]
        df = df[df["Whisper_Training_Hours"] != "Unknown"]

        hours_dict = {}
        for _, row in df.iterrows():
            # .get() yields None when the column is absent; a blank cell is
            # read by pandas as NaN. Both are treated as "missing code" below.
            code = row.get("Whisper_Code")
            try:
                hours = float(row["Whisper_Training_Hours"])
            except (TypeError, ValueError):
                hours = None

            # Without the explicit pd.isna check a blank code would be stored
            # under a NaN key, which no later lookup by language code can hit.
            if pd.isna(code) or hours is None:
                logging.warning(f"Skipping row due to missing code or invalid hours: {row}")
                continue

            hours_dict[code] = hours

        logging.info(f"Loaded training hours for {len(hours_dict)} languages")
        return hours_dict
    except FileNotFoundError:
        logging.error(f"Training hours file not found: {TRAINING_HOURS_CSV}")
        return {}
    except Exception as e:
        # Best-effort boundary: callers treat an empty mapping as "no hours".
        logging.error(f"Failed to load training hours: {e}")
        return {}


def safe_read_csv(path):
    """Read ``path`` with ``pd.read_csv`` without ever raising.

    On success the parsed DataFrame is returned. If the file is missing or
    any other error occurs while reading, a warning is logged and an empty
    DataFrame is returned instead.
    """
    try:
        frame = pd.read_csv(path)
    except FileNotFoundError:
        logging.warning(f"File not found: {path}")
    except Exception as err:
        logging.warning(f"Failed to read {path}: {err}")
    else:
        return frame
    # Reached only after one of the handlers above has logged the failure.
    return pd.DataFrame()


# ────────────────────────────────────────────────────────────
# Core analysis functions
# ────────────────────────────────────────────────────────────
def calculate_avg_entropy(df, k_h_val):
    """Return the mean per-row entropy (in bits) over the top ``k_h_val`` candidates.

    For every prediction row, the leading run of non-missing values in
    ``top1_prob`` ... ``top{k_h_val}_prob`` is taken: reading stops at the
    first absent column or NaN cell. Non-positive values are discarded, the
    remainder is renormalised to sum to one, and the entropy
    ``-sum(p * log2(p))`` is computed. Rows whose ``top1_prob`` is missing
    contribute nothing; rows with no positive probability contribute ``0.0``.

    Args:
        df: DataFrame with ``top{k}_prob`` columns. If a ``step`` column is
            present only rows with ``step >= 0`` are used; otherwise all rows
            are used and a warning is logged.
        k_h_val: Maximum number of top candidates (columns) to include.

    Returns:
        The mean entropy across contributing rows, or ``np.nan`` when there
        are no prediction rows, no ``top1_prob`` column, or no row with a
        usable ``top1_prob`` value.
    """
    if 'step' in df.columns:
        pred_rows = df[df['step'] >= 0]
    else:
        logging.warning("No 'step' column found. Using all rows for entropy calculation.")
        pred_rows = df

    if pred_rows.empty:
        logging.debug("No prediction rows found for entropy calculation")
        return np.nan

    if 'top1_prob' not in pred_rows.columns:
        logging.warning("Missing 'top1_prob' column. Cannot calculate entropy.")
        return np.nan

    # Only the unbroken run of columns starting at top1 is usable: a missing
    # column ends the per-row scan exactly like a NaN cell does.
    prob_cols = []
    for rank in range(1, k_h_val + 1):
        col_name = f'top{rank}_prob'
        if col_name not in pred_rows.columns:
            break
        prob_cols.append(col_name)
    if not prob_cols:
        return np.nan

    # One (rows x candidates) float matrix instead of a Python loop per cell.
    probs = pred_rows[prob_cols].to_numpy(dtype=float, na_value=np.nan)

    # in_prefix[i, j] is True iff cells 0..j of row i are all present, i.e.
    # the cell lies before the first NaN of its row.
    in_prefix = np.logical_and.accumulate(~np.isnan(probs), axis=1)
    contributes = in_prefix[:, 0]  # rows whose top1 probability is present
    if not contributes.any():
        return np.nan

    with np.errstate(divide="ignore", invalid="ignore"):
        # Zero out everything outside the prefix and every non-positive value;
        # zeros add nothing to the sum nor to the entropy.
        kept = np.where(in_prefix & (probs > 0), probs, 0.0)[contributes]
        totals = kept.sum(axis=1, keepdims=True)
        normalised = kept / totals  # all-zero rows give NaN, masked just below
        weighted_logs = np.where(normalised > 0,
                                 normalised * np.log2(normalised), 0.0)

    row_entropies = 0.0 - weighted_logs.sum(axis=1)
    return np.mean(row_entropies)


def process_language_file(file_path, training_hours, k_h_val):
    """Summarise one per-language CSV as a metrics record.

    The display name comes from the file name: the second underscore-separated
    token when the first one is purely numeric, otherwise the first token up
    to its first dot. The language code is the first non-null value of the
    ``whisper_lang`` column, if that column exists.

    Returns a dict with the keys ``language``, ``code``, ``training_hours``,
    ``num_utterances``, ``avg_entropy`` and ``resource_group``, or ``None``
    when the file is empty/unreadable or any error occurs while processing.
    """
    try:
        filename = os.path.basename(file_path)
        tokens = filename.split('_')
        numbered = len(tokens) > 1 and tokens[0].isdigit()
        raw_name = tokens[1] if numbered else tokens[0].split('.')[0]
        lang_name = raw_name.capitalize()

        logging.debug(f"Processing {lang_name} from {filename}...")

        frame = safe_read_csv(file_path)
        if frame.empty:
            logging.warning(f"Empty or invalid data for {lang_name} from {filename}")
            return None

        # Resolve the language code; stays None without a usable column.
        code = None
        if 'whisper_lang' in frame.columns:
            distinct_codes = frame['whisper_lang'].dropna().unique()
            if len(distinct_codes) >= 1:
                code = distinct_codes[0]
            if len(distinct_codes) > 1:
                logging.warning(f"Multiple whisper_lang codes in {filename}: {distinct_codes}. Using first: {code}")

        avg_entropy = calculate_avg_entropy(frame, k_h_val)

        # Rows with step == -1 are counted as utterances; with no 'step'
        # column every row counts.
        if "step" in frame.columns:
            num_utterances = int((frame["step"] == -1).sum())
        else:
            num_utterances = len(frame)

        hours = training_hours.get(code, np.nan) if code else np.nan
        record = {
            "language": lang_name,
            "code": code,
            "training_hours": hours,
            "num_utterances": num_utterances,
            "avg_entropy": avg_entropy,
            "resource_group": get_resource_group(code),
        }
        return record
    except Exception as e:
        logging.error(f"Error processing {file_path}: {e}")
        return None


def collect_all_results(k_h_val):
    """Build the per-language metrics table from every matching CSV file.

    Training hours are loaded first, then each ``*_subtoken_beam.csv`` file in
    ``DATA_DIR`` (in sorted order) is passed to ``process_language_file``.
    The resulting table is also written to
    ``OUTPUT_DIR / language_entropy_kH{k_h_val}_metrics.csv``; a failure to
    write is logged but does not affect the return value.

    Returns the metrics DataFrame, or an empty DataFrame when no matching
    files exist or none of them produced a result.
    """
    training_hours = load_training_hours()
    csv_files = sorted(DATA_DIR.glob("*_subtoken_beam.csv"))

    if len(csv_files) == 0:
        logging.error(f"No relevant CSV files found in {DATA_DIR} (e.g., *_subtoken_beam.csv).")
        # Diagnostic aid: show what CSVs are there, if any.
        try:
            other_csvs = list(DATA_DIR.glob("*.csv"))
            if not other_csvs:
                logging.info(f"Directory {DATA_DIR} contains no CSV files.")
            else:
                logging.info(f"Found these CSV files: {', '.join(f.name for f in other_csvs[:10])}...")
        except Exception as e:
            logging.error(f"Error listing directory contents: {e}")
        return pd.DataFrame()

    logging.info(f"Found {len(csv_files)} potential language CSV files.")

    per_file = (
        process_language_file(csv_path, training_hours, k_h_val)
        for csv_path in csv_files
    )
    records = [record for record in per_file if record]

    if len(records) == 0:
        logging.error("No valid results collected from CSV files.")
        return pd.DataFrame()

    metrics = pd.DataFrame(records)

    try:
        output_file = OUTPUT_DIR / f"language_entropy_kH{k_h_val}_metrics.csv"
        metrics.to_csv(output_file, index=False)
        logging.info(f"Saved metrics to {output_file}")
    except Exception as e:
        logging.error(f"Failed to save metrics CSV: {e}")

    return metrics


# ────────────────────────────────────────────────────────────
# Plotting functions
# ────────────────────────────────────────────────────────────
def _draw_log_trend(ax, df):
    """Overlay a least-squares fit of entropy against log10(training hours) on ``ax``."""
    hours = df["training_hours"].to_numpy(dtype=float)
    entropy = df["avg_entropy"].to_numpy(dtype=float)

    # Build the mask before taking any logarithm so zero or negative hours
    # never reach np.log10 (which would emit RuntimeWarnings).
    usable = (hours > 0) & np.isfinite(hours) & np.isfinite(entropy)
    if usable.sum() < 2:
        logging.warning("Not enough valid data points to calculate trend line for entropy plot.")
        return

    x_min_valid = hours[usable].min()
    x_max_valid = hours[usable].max()
    # A single distinct x value cannot define a slope; bail out before fitting.
    if not x_min_valid < x_max_valid:
        logging.warning("Cannot plot trend line due to insufficient range.")
        return

    try:
        x_log = np.log10(hours[usable])
        y_vals = entropy[usable]

        trend = np.poly1d(np.polyfit(x_log, y_vals, 1))
        r_val = np.corrcoef(x_log, y_vals)[0, 1]

        # Sample evenly in log space so the line is straight on the log axis.
        x_line = np.logspace(np.log10(x_min_valid), np.log10(x_max_valid), 100)
        ax.plot(x_line, trend(np.log10(x_line)), 'k--', linewidth=1, alpha=0.7,
                label=f"Trend (r={r_val:.2f})")
    except Exception as e:
        logging.warning(f"Error fitting trend line: {e}")


def _entropy_axis_limits(entropy_values, k_h_val):
    """Return ``(lower, upper)`` y-limits padded around the finite entropy values."""
    values = np.asarray(entropy_values, dtype=float)
    finite = values[np.isfinite(values)]

    if finite.size == 0:
        # Nothing to scale to: fall back to [0, log2(k_h_val)], the range of
        # an entropy over k_h_val outcomes.
        upper_default = np.log2(k_h_val) if k_h_val > 1 else 1.0
        return 0.0, upper_default

    data_min_y = finite.min()
    data_max_y = finite.max()
    if data_min_y == data_max_y:
        # Flat data: pad by 10% of the value, or a fixed 0.1 around zero.
        padding = 0.1 * abs(data_min_y) if data_min_y != 0 else 0.1
    else:
        padding = (data_max_y - data_min_y) * 0.05

    lower = max(0, data_min_y - padding)  # never show a negative axis
    upper = data_max_y + padding
    if upper <= lower:
        upper = lower + 0.1
    return lower, upper


def plot_entropy_vs_hours(results_df, k_h_val):
    """Plot average token entropy against Whisper training hours.

    Rows whose ``code`` is in ``EXCLUDED_LANGUAGES`` or that lack
    ``avg_entropy`` / ``training_hours`` are dropped. The remaining languages
    are drawn as an annotated scatter (one series per resource group) on a
    log-scaled x axis, with a dashed linear trend in log10(hours) when at
    least two usable points exist.

    Args:
        results_df: DataFrame with ``language``, ``code``, ``training_hours``,
            ``avg_entropy`` and ``resource_group`` columns.
        k_h_val: Number of top candidates used for the entropy; used in the
            output file names, the placeholder axis label and the fallback
            y-range.

    Side effects:
        Saves ``avg_entropy_kH{k_h_val}_vs_hours`` as PNG, PDF and SVG in
        ``OUTPUT_DIR``, or ``..._vs_hours_nodata.png`` when nothing is left to
        plot. Save failures are logged, not raised. The figure is closed
        before returning.
    """
    df = results_df[~results_df["code"].isin(EXCLUDED_LANGUAGES)].copy()
    df = df.dropna(subset=["avg_entropy", "training_hours"])

    fig, ax = plt.subplots(figsize=(4.5, 3.5))

    if df.empty:
        logging.warning("No valid data for average entropy plot after filtering and dropping NaNs.")
        ax.text(0.5, 0.5, "No data to plot for entropy vs. hours.",
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=10)
        ax.set_xlabel("Whisper Training Hours (log scale)")
        ax.set_ylabel(f"Average Token Entropy (bits, $K_H={k_h_val}$)")
        try:
            output_plot_file_png = OUTPUT_DIR / f"avg_entropy_kH{k_h_val}_vs_hours_nodata.png"
            fig.savefig(output_plot_file_png)
            logging.info(f"Saved empty plot placeholder to {output_plot_file_png}")
        except Exception as e:
            logging.error(f"Failed to save empty entropy plot: {e}")
        plt.close(fig)
        return

    for group_name in ["High", "Medium", "Low", "Other"]:
        group_df = df[df["resource_group"] == group_name]
        if group_df.empty:
            continue

        ax.scatter(
            group_df["training_hours"],
            group_df["avg_entropy"],
            s=40,
            alpha=0.7,
            marker=MARKERS.get(group_name, "x"),
            color=COLOURS.get(group_name, "grey"),
            label=f"{group_name}-resource"
        )

        # Label every point with its language name, nudged off the marker.
        for _, row in group_df.iterrows():
            ax.annotate(
                row["language"],
                (row["training_hours"], row["avg_entropy"]),
                xytext=(3, 1),
                textcoords="offset points",
                fontsize=6,
            )

    # A trend needs at least two points.
    if len(df) >= 2:
        _draw_log_trend(ax, df)

    ax.set_xscale("log")
    ax.set_xlabel("Whisper Training Hours (log scale)")
    ax.set_ylabel("Average Token Entropy")

    # Dynamic Y-axis limits based on data range
    y_lower_bound, y_upper_bound = _entropy_axis_limits(df["avg_entropy"], k_h_val)
    ax.set_ylim(y_lower_bound, y_upper_bound)

    ax.grid(True, linestyle='--', alpha=0.5)
    ax.legend()

    fig.tight_layout(pad=0.5)

    try:
        fig.savefig(OUTPUT_DIR / f"avg_entropy_kH{k_h_val}_vs_hours.png")
        fig.savefig(OUTPUT_DIR / f"avg_entropy_kH{k_h_val}_vs_hours.pdf")
        fig.savefig(OUTPUT_DIR / f"avg_entropy_kH{k_h_val}_vs_hours.svg", format="svg")
        logging.info(f"Saved entropy plot to {OUTPUT_DIR} (PNG, PDF, SVG)")
    except Exception as e:
        logging.error(f"Failed to save entropy plot: {e}")

    plt.close(fig)


# ────────────────────────────────────────────────────────────
# Main execution
# ────────────────────────────────────────────────────────────
def main():
    """Run the whole analysis: log the configuration, collect metrics, plot.

    Returns early (after logging an error) when no results could be
    collected. When results exist but none has an entropy value, the plot
    routine is still invoked before returning.
    """
    startup_lines = (
        f"Starting average token entropy analysis (K_H={K_H})",
        f"Data directory: {DATA_DIR}",
        f"Training hours file: {TRAINING_HOURS_CSV}",
        f"Output directory: {OUTPUT_DIR}",
    )
    for line in startup_lines:
        logging.info(line)

    OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

    results_df = collect_all_results(K_H)
    if results_df.empty:
        logging.error("No usable results found from data collection. Exiting.")
        return

    has_entropy = results_df["avg_entropy"].notna().any()
    if not has_entropy:
        logging.warning("No valid average entropy data was found.")
        plot_entropy_vs_hours(results_df, K_H)
        return

    logging.info("Creating entropy plot...")
    plot_entropy_vs_hours(results_df, K_H)

    logging.info("Entropy analysis complete!")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()


2025-05-20 01:05:51,812 [INFO] Starting average token entropy analysis (K_H=50)
2025-05-20 01:05:51,812 [INFO] Data directory: results_beam_600s
2025-05-20 01:05:51,813 [INFO] Training hours file: whisper_training_hours.csv
2025-05-20 01:05:51,813 [INFO] Output directory: analysis_results_beam
2025-05-20 01:05:51,821 [INFO] Loaded training hours for 90 languages
2025-05-20 01:05:51,822 [INFO] Found 28 potential language CSV files.
2025-05-20 01:06:12,461 [INFO] Saved metrics to analysis_results_beam/language_entropy_kH50_metrics.csv
2025-05-20 01:06:12,463 [INFO] Creating entropy plot...
2025-05-20 01:06:13,448 [INFO] Saved entropy plot to analysis_results_beam (PNG, PDF, SVG)
2025-05-20 01:06:13,450 [INFO] Entropy analysis complete!
