# In [5]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Whisper Analysis: Alternate-Candidate Diversity (TTR) vs. Training Hours
=======================================================================
Generates a plot of alternate-candidate diversity (Type–Token Ratio of non-top-1
tokens within the top K_D predictions) against Whisper training hours.
"""

import os
import logging
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ────────────────────────────────────────────────────────────
# Hard-coded configuration
# ────────────────────────────────────────────────────────────
# Input and output locations. All three are relative paths, so they are
# resolved against the current working directory when the script runs.
DATA_DIR           = Path("results_beam_600s")          # CSVs with beam output
TRAINING_HOURS_CSV = Path("whisper_training_hours.csv") # Training-hours sheet
OUTPUT_DIR         = Path("analysis_results_beam")
# NOTE: runs at import time and creates the output directory (and any missing
# parents); an already-existing directory is not an error.
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# calculate_diversity() looks at the columns top2_id .. top{K_D}_id, so K_D must
# be greater than 1 for any alternate candidates to exist (main() enforces this).
K_D = 50  # Number of top candidates to consider for diversity

# Resource-group lists: language codes bucketed into tiers. A code that is in
# none of the three sets is labelled "Other" by get_resource_group().
HIGH_RESOURCE   = {"de", "es", "fr", "pt", "tr"}
MEDIUM_RESOURCE = {"it", "nl", "sv", "ca", "fi", "id", "vi", "ro",
                   "no", "cs", "hu"}
LOW_RESOURCE    = {"cy", "lt", "lv", "az", "et", "eu", "sq", "sw", "mt", "uz"}

# Codes dropped from the plot only (the metrics CSV still contains them).
# This set overlaps the groups above ("cs"/"vi" are Medium; "uz"/"mt"/"sw"/"sq"
# are Low) and also lists "yo" and "da", which belong to no group.
EXCLUDED_LANGUAGES = {"uz", "mt", "sw", "sq", "yo", "da", "cs", "vi"}

# Per-group scatter styling; keys must match the labels returned by
# get_resource_group().
COLOURS = {"High": "steelblue", "Medium": "seagreen",
           "Low": "crimson",  "Other": "grey"}
MARKERS = {"High": "o", "Medium": "^", "Low": "s", "Other": "x"}

# ────────────────────────────────────────────────────────────
# Matplotlib style block – stated to be identical to the companion entropy
# script (that script is not visible here, so keep the two in sync manually).
# These settings mutate the global rcParams and therefore affect every figure
# created in the same process after this module is imported.
# ────────────────────────────────────────────────────────────
plt.rcParams.update({
    "figure.dpi":      100,   # on-screen resolution
    "savefig.dpi":     300,   # resolution of saved raster output
    "font.size":       8,
    "axes.titlesize": 10,
    "axes.labelsize":  8,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
    "figure.titlesize": 12
})

# ────────────────────────────────────────────────────────────
# Logging: root logger at INFO, messages rendered as "LEVEL: message".
# ────────────────────────────────────────────────────────────
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

# ────────────────────────────────────────────────────────────
# Utility helpers
# ────────────────────────────────────────────────────────────
def get_resource_group(code: str | None) -> str:
    """Return the resource tier label for a language code.

    The code is looked up in HIGH_RESOURCE, MEDIUM_RESOURCE and LOW_RESOURCE,
    in that order; the first set containing it decides the label. Anything
    else (including ``None``) is labelled ``"Other"``.
    """
    tiers = (
        ("High", HIGH_RESOURCE),
        ("Medium", MEDIUM_RESOURCE),
        ("Low", LOW_RESOURCE),
    )
    for label, members in tiers:
        if code in members:
            return label
    return "Other"

def load_training_hours() -> dict[str, float]:
    """Load the training-hours sheet as a ``{language code: hours}`` mapping.

    Reads ``TRAINING_HOURS_CSV`` and uses its ``Whisper_Code`` and
    ``Whisper_Training_Hours`` columns. Rows whose hours are missing or not
    numeric (for example the placeholder ``'Unknown'``) and rows without a
    language code are skipped. If a code appears more than once, the last
    row wins.

    Returns:
        The mapping of language code to hours as ``float``. An empty dict is
        returned (and an error is logged) when the file is missing or cannot
        be loaded at all.
    """
    try:
        df = pd.read_csv(TRAINING_HOURS_CSV)

        raw_hours = df["Whisper_Training_Hours"]
        # Coerce rather than call float() per row: a single unexpected
        # placeholder would otherwise raise and discard the entire table.
        hours_num = pd.to_numeric(raw_hours, errors="coerce")

        # Report values that are present but unusable, except the known
        # 'Unknown' placeholder, which is expected and silently ignored.
        unexpected = (raw_hours.notna()
                      & hours_num.isna()
                      & (raw_hours.astype(str).str.strip() != "Unknown"))
        if unexpected.any():
            logging.warning("Ignoring %d row(s) with non-numeric training hours",
                            int(unexpected.sum()))

        usable = hours_num.notna() & df["Whisper_Code"].notna()
        hours = {code: float(h)
                 for code, h in zip(df.loc[usable, "Whisper_Code"],
                                    hours_num[usable])}
        logging.info("Loaded training hours for %d languages", len(hours))
        return hours
    except FileNotFoundError:
        logging.error("Training-hours file not found: %s", TRAINING_HOURS_CSV)
        return {}
    except Exception as e:
        # Deliberate best-effort boundary: the analysis can still produce the
        # metrics CSV without training hours, so log and continue.
        logging.error("Failed to load training hours: %s", e)
        return {}

def safe_read_csv(p: Path) -> pd.DataFrame:
    """Read a CSV file, falling back to an empty DataFrame on any failure.

    Any exception raised by ``pd.read_csv`` is logged as a warning together
    with the path; it is never propagated to the caller.
    """
    frame = pd.DataFrame()
    try:
        frame = pd.read_csv(p)
    except Exception as err:
        logging.warning("Could not read %s: %s", p, err)
    return frame

# ────────────────────────────────────────────────────────────
# Core analysis
# ────────────────────────────────────────────────────────────
def calculate_diversity(df: pd.DataFrame, k_d_val: int) -> float:
    """Type–Token Ratio over alternate tokens top2..topK.

    The alternate tokens are the values of the columns ``top2_id`` ..
    ``top{k_d_val}_id`` that exist in ``df``. When ``df`` has a ``step``
    column, only rows with ``step >= 0`` are used (negative steps are
    presumably non-prediction marker rows — confirm against the producer of
    the CSVs). Missing values are ignored.

    Args:
        df: Beam-output table for one language.
        k_d_val: Highest candidate rank to include; must be greater than 1.

    Returns:
        ``unique tokens / total tokens`` as a float, or ``NaN`` whenever the
        ratio is undefined: ``k_d_val <= 1``, none of the alternate columns is
        present, no rows remain, or every alternate token is missing.
    """
    if k_d_val <= 1:
        return np.nan

    alt_cols = [col for col in (f"top{k}_id" for k in range(2, k_d_val + 1))
                if col in df.columns]
    if not alt_cols:
        return np.nan

    # Select only the needed columns instead of copying the whole frame.
    if "step" in df.columns:
        alternates = df.loc[df["step"] >= 0, alt_cols]
    else:
        alternates = df[alt_cols]

    tokens = alternates.to_numpy().ravel()
    tokens = tokens[~pd.isna(tokens)]
    if tokens.size == 0:
        # No tokens means the ratio is 0/0. Report NaN like every other
        # "no data" case so the value is not plotted as a real 0.0 point.
        return np.nan
    return len(pd.unique(tokens)) / tokens.size

def process_language_file(file_path: Path,
                          training_hours: dict[str, float],
                          k_d_val: int) -> dict | None:
    """Compute the diversity metrics record for one beam-output CSV.

    The display name comes from the file name: when the text before the first
    underscore is all digits, the second underscore-separated field is used;
    otherwise the text before the first dot. The name is capitalised.

    The language code is the first non-null value of the ``whisper_lang``
    column. If that column is absent or holds no values, the code is ``None``,
    the training hours are ``NaN`` and the resource group is ``"Other"``.

    Args:
        file_path: CSV file to analyse.
        training_hours: Mapping of language code to training hours.
        k_d_val: Highest candidate rank passed on to ``calculate_diversity``.

    Returns:
        A dict with the keys ``language``, ``code``, ``training_hours``,
        ``num_utterances``, ``diversity`` and ``resource_group``, or ``None``
        when the file is empty/unreadable or processing fails (the failure is
        logged, never raised).
    """
    try:
        fname = file_path.name
        name_parts = fname.split("_")
        if name_parts[0].isdigit():
            lang_name = name_parts[1].capitalize()
        else:
            lang_name = fname.split(".")[0].capitalize()

        df = safe_read_csv(file_path)
        if df.empty:
            logging.warning("Empty data for %s", lang_name)
            return None

        code = None
        if "whisper_lang" in df.columns:
            codes = df["whisper_lang"].dropna().unique()
            if len(codes) > 0:
                code = codes[0]
            else:
                # Column present but empty: treat it like a missing column
                # instead of failing the whole file with an IndexError.
                logging.warning("No whisper_lang value for %s", lang_name)

        diversity = calculate_diversity(df, k_d_val)
        # Rows with step == -1 are counted as utterances (presumably one
        # marker row per utterance — confirm against the CSV producer).
        if "step" in df.columns:
            num_utts = int((df["step"] == -1).sum())
        else:
            num_utts = len(df)

        return {
            "language":       lang_name,
            "code":           code,
            "training_hours": training_hours.get(code, np.nan) if code else np.nan,
            "num_utterances": num_utts,
            "diversity":      diversity,
            "resource_group": get_resource_group(code)
        }
    except Exception as e:
        # Per-file boundary: one bad file must not abort the whole batch.
        logging.error("Failed processing %s: %s", file_path, e)
        return None

def collect_all_results(k_d_val: int) -> pd.DataFrame:
    """Build the per-language metrics table and write it to disk.

    Every ``*_subtoken_beam.csv`` file in ``DATA_DIR`` is processed in sorted
    order; files that yield no record are skipped. A non-empty table is saved
    to ``OUTPUT_DIR`` as ``language_diversity_kD{k_d_val}_metrics.csv``.

    Returns:
        The metrics table, or an empty DataFrame when no input files exist
        or none of them produced a record.
    """
    hours_by_code = load_training_hours()
    beam_csvs = sorted(DATA_DIR.glob("*_subtoken_beam.csv"))
    if len(beam_csvs) == 0:
        logging.error("No *_subtoken_beam.csv files in %s", DATA_DIR)
        return pd.DataFrame()

    records = [
        record
        for csv_path in beam_csvs
        if (record := process_language_file(csv_path, hours_by_code, k_d_val))
    ]

    summary = pd.DataFrame(records)
    if summary.empty:
        return summary

    metrics_path = OUTPUT_DIR / f"language_diversity_kD{k_d_val}_metrics.csv"
    summary.to_csv(metrics_path, index=False)
    logging.info("Saved metrics to %s", metrics_path)
    return summary

# ────────────────────────────────────────────────────────────
# Plotting
# ────────────────────────────────────────────────────────────
def plot_diversity_vs_hours(results: pd.DataFrame, k_d_val: int) -> None:
    """Plot diversity against training hours and save it as PNG, PDF and SVG.

    Rows whose ``code`` is in ``EXCLUDED_LANGUAGES`` or whose ``diversity`` or
    ``training_hours`` is missing are dropped first. Each resource group gets
    its own marker and colour, every point is annotated with the language
    name, and a least-squares line of diversity against log10(hours) is added
    when at least two points have positive hours.

    Args:
        results: Metrics table with the columns ``code``, ``language``,
            ``training_hours``, ``diversity`` and ``resource_group``.
        k_d_val: Candidate depth; only used in the output file names.

    Returns:
        None. Files are written to ``OUTPUT_DIR``; save failures are logged
        rather than raised.
    """
    df = (results
          .query("code not in @EXCLUDED_LANGUAGES")
          .dropna(subset=["diversity", "training_hours"])
          .copy())
    if df.empty:
        logging.warning("No valid data for diversity plot.")
        return

    fig, ax = plt.subplots(figsize=(4.5, 3.5))
    try:
        for grp in ["High", "Medium", "Low", "Other"]:
            g = df[df["resource_group"] == grp]
            if g.empty:
                continue
            ax.scatter(g["training_hours"], g["diversity"],
                       s=40, alpha=0.7,
                       marker=MARKERS[grp], color=COLOURS[grp],
                       label=f"{grp}-resource")
            for name, hours, div in zip(g["language"],
                                        g["training_hours"],
                                        g["diversity"]):
                ax.annotate(name, (hours, div),
                            xytext=(3, 1), textcoords="offset points",
                            fontsize=6)

        # Optional log-trend line, fitted only on points with positive hours.
        mask = (df["training_hours"] > 0) & np.isfinite(df["diversity"])
        if mask.sum() >= 2:
            x_log = np.log10(df.loc[mask, "training_hours"])
            y     = df.loc[mask, "diversity"]
            try:
                z = np.polyfit(x_log, y, 1)
                p = np.poly1d(z)
                r = np.corrcoef(x_log, y)[0, 1]
                # Span the line over the fitted (masked) points only. Using
                # the unmasked minimum would evaluate log10 of a value <= 0
                # when a language has zero recorded hours.
                x_line = np.logspace(x_log.min(), x_log.max(), 100)
                ax.plot(x_line, p(np.log10(x_line)), "k--",
                        linewidth=1, alpha=0.7,
                        label=f"Trend (r={r:.2f})")
            except Exception as e:
                logging.warning("Trend line failed: %s", e)

        ax.set_xscale("log")
        ax.set_xlabel("Whisper Training Hours (log scale)")
        ax.set_ylabel("Alternate Candidate Diversity")

        # dynamic y-range: always show min/max ±5 %
        y_min = max(0.0, df["diversity"].min() * 0.95)
        y_max = df["diversity"].max() * 1.05
        if y_max <= y_min:  # flat data fallback
            y_max = y_min + 1.0
        ax.set_ylim(y_min, y_max)

        ax.grid(True, linestyle="--", alpha=0.5)
        ax.legend()
        fig.tight_layout(pad=0.5)

        try:
            out_png = OUTPUT_DIR / f"avg_diversity_kD{k_d_val}_vs_hours.png"
            out_pdf = OUTPUT_DIR / f"avg_diversity_kD{k_d_val}_vs_hours.pdf"
            out_svg = OUTPUT_DIR / f"avg_diversity_kD{k_d_val}_vs_hours.svg"
            fig.savefig(out_png)
            fig.savefig(out_pdf)
            fig.savefig(out_svg, format="svg")
            logging.info("Saved diversity plot (PNG/PDF/SVG) to %s", OUTPUT_DIR)
        except Exception as e:
            logging.error("Plot save failed: %s", e)
    finally:
        # Always release the figure, even if drawing raised part-way through.
        plt.close(fig)

# ────────────────────────────────────────────────────────────
# Main
# ────────────────────────────────────────────────────────────
def main() -> None:
    """Run the diversity analysis end to end with the module-level settings.

    Validates ``K_D``, logs the configuration, collects the per-language
    metrics and, when any were produced, draws the diversity plot. Problems
    are reported through logging; nothing is raised or returned.
    """
    if K_D <= 1:
        logging.error("K_D must be > 1 (got %d)", K_D)
        return

    logging.info("Running diversity analysis (K_D=%d)", K_D)
    settings = (
        ("Data dir: %s", DATA_DIR),
        ("Training-hours file: %s", TRAINING_HOURS_CSV),
        ("Output dir: %s", OUTPUT_DIR),
    )
    for template, value in settings:
        logging.info(template, value)

    metrics = collect_all_results(K_D)
    if metrics.empty:
        logging.error("No usable results; exiting.")
        return

    plot_diversity_vs_hours(metrics, K_D)
    logging.info("Done!")


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


# Captured output of the run above:
# INFO: Running diversity analysis (K_D=50)
# INFO: Data dir: results_beam_600s
# INFO: Training-hours file: whisper_training_hours.csv
# INFO: Output dir: analysis_results_beam
# INFO: Loaded training hours for 90 languages
# INFO: Saved metrics to analysis_results_beam/language_diversity_kD50_metrics.csv
# INFO: Saved diversity plot (PNG/PDF/SVG) to analysis_results_beam
# INFO: Done!
