In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Whisper Analysis: WER vs Training Hours
=======================================

Generates plots of word error rate (WER) against Whisper training hours
for various languages.
"""

from __future__ import annotations

import argparse
import logging
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from jiwer import wer
from scipy.stats import pearsonr

# ────────────────────────────────────────────────────────────
# Configuration
# ────────────────────────────────────────────────────────────
# Default locations, all relative to the current working directory.
# DATA_DIR and TRAINING_HOURS_CSV are only defaults: parse_args() lets the
# caller override them. OUTPUT_DIR has no command-line override.
DATA_DIR = Path("results_beam_600s")
TRAINING_HOURS_CSV = Path("whisper_training_hours.csv")
OUTPUT_DIR = Path("analysis_results_beam")
# NOTE: this runs at import time, so merely importing the module creates
# the output directory.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Language codes per resource tier; get_resource_group() maps a code to
# "High" / "Medium" / "Low", and to "Other" when it is in none of these sets.
HIGH_RESOURCE = {"de", "es", "fr", "pt", "tr"}
MEDIUM_RESOURCE = {"it", "nl", "sv", "ca", "fi", "id", "vi", "ro",
                   "no", "cs", "hu", "yo"}
LOW_RESOURCE = {"cy", "lt", "lv", "az", "et", "eu"}

# Codes dropped by the plotting functions before drawing.
# NOTE(review): "yo", "vi" and "cs" are listed both here and in
# MEDIUM_RESOURCE, so they are tiered in the metrics CSV but never plotted --
# confirm this is intended.
EXCLUDED_LANGUAGES = {"uz", "mt", "sw", "sq", "yo", "da", "vi", "cs"}

# Per-tier colour and marker used by the plots; keys are the tier names
# returned by get_resource_group().
COLOUR = {"High": "steelblue", "Medium": "seagreen",
          "Low": "crimson", "Other": "grey"}
MARKER = {"High": "o", "Medium": "^", "Low": "s", "Other": "x"}

# Global matplotlib defaults: small fonts for compact figures, 300 dpi when
# saving raster output.
plt.rcParams.update({
    "figure.dpi": 100,
    "savefig.dpi": 300,
    "font.size": 8,
    "axes.titlesize": 10,
    "axes.labelsize": 8,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,
    "legend.fontsize": 7,
    "figure.titlesize": 12
})

# ────────────────────────────────────────────────────────────
# Utility helpers
# ────────────────────────────────────────────────────────────
# Any character that is not a word character, whitespace or an apostrophe,
# plus the underscore (which \w would otherwise accept).
_token_re = re.compile(r"[^\w\s']|_")
# A square-bracketed span, matched non-greedily (e.g. "[noise]").
_bracket_re = re.compile(r"\[.*?\]")


def normalise(text: str) -> str:
    """Return *text* in canonical form for WER scoring.

    The text is lower-cased, bracketed spans are removed, punctuation and
    underscores are replaced by spaces (apostrophes are kept), and runs of
    whitespace are collapsed to single spaces. Anything that is not a
    ``str`` (e.g. ``None`` or NaN) yields the empty string.
    """
    if isinstance(text, str):
        lowered = text.lower()
        without_brackets = _bracket_re.sub("", lowered)
        spaced_out = _token_re.sub(" ", without_brackets)
        words = spaced_out.split()
        return " ".join(words)
    return ""


def get_resource_group(code: str | None) -> str:
    """Return the resource tier name for a language code.

    The tiers are checked in the order High, Medium, Low; a code found in
    none of them (including ``None``) is reported as ``"Other"``.
    """
    tiers = (
        ("High", HIGH_RESOURCE),
        ("Medium", MEDIUM_RESOURCE),
        ("Low", LOW_RESOURCE),
    )
    for tier_name, members in tiers:
        if code in members:
            return tier_name
    return "Other"


def load_hours(csv: Path) -> Dict[str, float]:
    """Return mapping Whisper_Code → training_hours (float).

    Rows whose ``Whisper_Training_Hours`` is missing or the literal string
    ``"Unknown"`` are skipped.

    Args:
        csv: Path to a CSV file with ``Whisper_Code`` and
            ``Whisper_Training_Hours`` columns.

    Returns:
        The code → hours mapping, or an empty dict when the file cannot be
        read, a required column is absent, or a remaining hours value cannot
        be converted to float. Every failure is logged, never raised.
    """
    try:
        df = pd.read_csv(csv)
    except FileNotFoundError:
        logging.error(f"Training hours CSV file not found: {csv}")
        return {}
    except Exception as e:
        logging.error(f"Failed to read training hours CSV {csv}: {e}")
        return {}

    # Validate the schema before touching the columns: the row filters below
    # index the frame directly, so a missing column would otherwise escape as
    # an unhandled KeyError instead of being reported.
    if {"Whisper_Code", "Whisper_Training_Hours"}.difference(df.columns):
        logging.error(f"'Whisper_Code' or 'Whisper_Training_Hours' column not found in {csv}")
        return {}

    hours = df["Whisper_Training_Hours"]
    df = df[hours.notna() & (hours != "Unknown")]
    try:
        return (df.set_index("Whisper_Code")["Whisper_Training_Hours"]
                  .astype(float)
                  .to_dict())
    except ValueError:
        logging.error(f"Could not convert 'Whisper_Training_Hours' to float in {csv}")
        return {}


def safe_read_csv(path: Path, **kw) -> pd.DataFrame:
    """Read *path* with ``pd.read_csv``, never raising.

    Extra keyword arguments are forwarded to ``pd.read_csv``. Any failure is
    logged as a warning (naming only the file, not its directory) and an
    empty DataFrame is returned in its place.
    """
    try:
        frame = pd.read_csv(path, **kw)
    except FileNotFoundError:
        logging.warning("File not found: %s", path.name)
        frame = pd.DataFrame()
    except Exception as err:
        logging.warning("Failed to read %s (%s)", path.name, err)
        frame = pd.DataFrame()
    return frame


# ────────────────────────────────────────────────────────────
# Core analysis
# ────────────────────────────────────────────────────────────
def calc_wer(df: pd.DataFrame) -> float:
    """Compute the word error rate over the rows of *df* with ``step == -1``.

    When there is no ``step`` column every row is scored (with a warning).
    ``ground_truth`` and ``full_transcription`` are passed through
    ``normalise`` and the surviving pairs are handed to ``jiwer.wer`` as two
    parallel lists.

    Returns:
        The WER as a float, or NaN when no rows remain, a required column is
        missing, or no usable pair is left.
    """
    if "step" not in df.columns:
        logging.warning("'step' column not found in DataFrame. Attempting to calculate WER on all rows.")
        final_rows = df
    else:
        final_rows = df[df["step"] == -1]

    if final_rows.empty:
        logging.debug("DataFrame is empty after filtering for step == -1 or 'step' column missing.")
        return np.nan

    required_columns = {"ground_truth", "full_transcription"}
    if not required_columns.issubset(final_rows.columns):
        logging.warning("Missing 'ground_truth' or 'full_transcription' columns.")
        return np.nan

    normalised_refs = final_rows["ground_truth"].map(normalise).tolist()
    normalised_hyps = final_rows["full_transcription"].map(normalise).tolist()

    # Keep only pairs where both sides are non-blank strings.
    # NOTE(review): this also drops utterances whose hypothesis is empty, so
    # those are not counted as errors -- confirm that is the intended metric.
    references: List[str] = []
    hypotheses: List[str] = []
    for ref, hyp in zip(normalised_refs, normalised_hyps):
        ref_ok = isinstance(ref, str) and bool(ref.strip())
        hyp_ok = isinstance(hyp, str) and bool(hyp.strip())
        if ref_ok and hyp_ok:
            references.append(ref)
            hypotheses.append(hyp)

    if not references:
        logging.debug("No valid (reference, hypothesis) pairs found after normalization and filtering.")
        return np.nan

    return wer(references, hypotheses)


def process_language(csv_path: Path, hours: Dict[str, float]) -> Dict:
    """Build the metrics record for one per-language result CSV.

    The display name is the second ``_``-separated part of the file stem
    (the whole stem when there is no second part), capitalised. The language
    code is read from the file's ``whisper_lang`` column; when several
    distinct codes are present the first one is used and a warning is logged.

    Args:
        csv_path: Result CSV for one language.
        hours: Mapping from language code to training hours.

    Returns:
        A dict with keys ``language``, ``code``, ``wer``, ``training_hours``,
        ``num_utterances`` and ``resource_group``, in that order.
    """
    stem = csv_path.stem
    stem_parts = stem.split("_")
    raw_name = stem_parts[1] if len(stem_parts) >= 2 else stem
    lang_name = raw_name.capitalize()

    df = safe_read_csv(csv_path)
    has_rows = not df.empty

    code = None
    if has_rows and "whisper_lang" in df.columns:
        unique_codes = df["whisper_lang"].dropna().unique()
        if len(unique_codes) == 0:
            logging.warning(f"No valid whisper_lang codes found in {csv_path.name} after dropping NaNs.")
        else:
            if len(unique_codes) > 1:
                logging.warning(f"Multiple whisper_lang codes found in {csv_path.name}: {unique_codes}. Using the first one: {unique_codes[0]}")
            code = unique_codes[0]

    # Rows with step == -1 are counted when the column exists; otherwise
    # every row counts as one utterance.
    if not has_rows:
        num_utterances = 0
    elif "step" in df.columns:
        num_utterances = int((df["step"] == -1).sum())
    else:
        num_utterances = len(df)

    word_error_rate = calc_wer(df.copy())
    training_hours = hours.get(code, np.nan) if code else np.nan

    return {
        "language": lang_name,
        "code": code,
        "wer": word_error_rate,
        "training_hours": training_hours,
        "num_utterances": num_utterances,
        "resource_group": get_resource_group(code),
    }


def collect_results(data_dir: Path, hours_csv: Path) -> pd.DataFrame:
    """Process every ``*_subtoken_beam.csv`` in *data_dir* into one table.

    Files are handled in sorted order, one row per file. The table is also
    written to ``language_wer_metrics.csv`` inside ``OUTPUT_DIR``; a failure
    to write is logged and does not prevent the table from being returned.

    Returns:
        One row per input CSV, or an empty DataFrame when no matching files
        exist.
    """
    hours = load_hours(hours_csv)
    if not hours:
        logging.warning("Training hours data could not be loaded. 'training_hours' will be NaN.")

    csv_paths = sorted(data_dir.glob("*_subtoken_beam.csv"))
    if not csv_paths:
        logging.error("No CSV files matching '*_subtoken_beam.csv' found in %s", data_dir)
        return pd.DataFrame()

    records = []
    for csv_path in csv_paths:
        records.append(process_language(csv_path, hours))
    if not records:
        logging.error("No data processed from CSV files.")
        return pd.DataFrame()

    results = pd.DataFrame(records)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    metrics_path = OUTPUT_DIR / "language_wer_metrics.csv"
    try:
        results.to_csv(metrics_path, index=False)
        logging.info(f"Language WER metrics saved to {metrics_path}")
    except Exception as err:
        logging.error(f"Failed to save language_wer_metrics.csv: {err}")
    return results


# ────────────────────────────────────────────────────────────
# Plotting helpers
# ────────────────────────────────────────────────────────────
def _draw_log_trend(ax, hours: pd.Series, wer_values: pd.Series) -> None:
    """Fit WER against log10(hours), report Pearson r/p and draw the line on *ax*."""
    hours = hours.astype(float)
    wer_values = wer_values.astype(float)
    # Select usable points *before* taking the logarithm so that zero or
    # negative hours never reach np.log10 (which would emit RuntimeWarnings).
    usable = (hours > 0) & np.isfinite(hours) & np.isfinite(wer_values)
    hours = hours[usable]
    wer_values = wer_values[usable]

    if len(hours) < 2:
        logging.warning("Not enough valid data points to calculate trend line, r, and p-value.")
        return

    xs = np.log10(hours)
    trend = np.poly1d(np.polyfit(xs, wer_values, 1))
    correlation_coefficient, p_value_calc = pearsonr(xs, wer_values)

    print(f"--- Scatter Plot Correlation ---")
    print(f"Pearson correlation coefficient (r): {correlation_coefficient:.3f}")
    print(f"P-value: {p_value_calc:.3g}")
    logging.info(f"Scatter plot correlation: r={correlation_coefficient:.3f}, p-value={p_value_calc:.3g}")

    min_log_x = xs.min()
    max_log_x = xs.max()
    if min_log_x < max_log_x:
        x_line = np.logspace(min_log_x, max_log_x, 100)
        ax.plot(x_line, trend(np.log10(x_line)), "k--", linewidth=1,
                label=f"Trend (r={correlation_coefficient:.2f})")
    else:
        logging.warning("Could not plot trend line due to insufficient range in training hours after log.")


def plot_scatter(df: pd.DataFrame):
    """Scatter WER against training hours (log x-axis) with a linear trend.

    Languages whose code is in ``EXCLUDED_LANGUAGES`` and rows with a missing
    ``wer`` or ``training_hours`` are left out. Points are coloured and
    marked by ``resource_group`` and annotated with the language name. With
    at least two rows, a least-squares line of WER on log10(hours) is drawn
    and the Pearson correlation is printed and logged.

    The figure is written to ``OUTPUT_DIR`` as ``wer_vs_hours`` in SVG, PNG
    and PDF form. A failure to save is logged rather than raised, and the
    figure is closed on every path.

    Args:
        df: Table with ``code``, ``language``, ``wer``, ``training_hours``
            and ``resource_group`` columns.
    """
    excluded = df["code"].isin(EXCLUDED_LANGUAGES)
    df_filtered = df[~excluded]
    if excluded.any():
        # Report the codes that were actually dropped, not the whole
        # exclusion list, so the count and the names agree.
        dropped_codes = sorted(df.loc[excluded, "code"].astype(str).unique())
        logging.info(f"Excluded {int(excluded.sum())} languages from plotting: {', '.join(dropped_codes)}")

    df_plot = df_filtered.dropna(subset=["wer", "training_hours"])
    if df_plot.empty:
        logging.warning("No data to plot (scatter) after dropping NaNs in 'wer' or 'training_hours'.")
        return

    fig, ax = plt.subplots(figsize=(4.5, 3.5))
    try:
        for grp_name, gdf in df_plot.groupby("resource_group"):
            ax.scatter(gdf["training_hours"], gdf["wer"],
                       s=40, alpha=.7, marker=MARKER.get(grp_name, "x"),
                       label=f"{grp_name}-resource",
                       color=COLOUR.get(grp_name, "grey"))
            for _, r in gdf.iterrows():
                ax.annotate(r.language, (r.training_hours, r.wer),
                            xytext=(3, 1), textcoords="offset points", fontsize=6)

        if len(df_plot) >= 2:
            _draw_log_trend(ax, df_plot["training_hours"], df_plot["wer"])

        ax.set_xscale("log")
        ax.set_xlabel("Whisper training hours (log scale)")
        ax.set_ylabel("WER")
        # Cap the axis at 2; fall back to 2 when the largest WER is 0, which
        # would otherwise give a zero-height axis.
        wer_max = df_plot["wer"].max()
        ax.set_ylim(0, min(2, wer_max * 1.1) if wer_max > 0 else 2)
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.5)
        fig.tight_layout(pad=0.5)

        try:
            fig.savefig(OUTPUT_DIR / "wer_vs_hours.svg", format="svg")
            fig.savefig(OUTPUT_DIR / "wer_vs_hours.png")
            fig.savefig(OUTPUT_DIR / "wer_vs_hours.pdf")
            logging.info(f"Scatter plot 'wer_vs_hours' saved to {OUTPUT_DIR} (including SVG)")
        except Exception as e:
            logging.error(f"Failed to save scatter plot: {e}")
    finally:
        # Always release the figure, even if drawing fails part-way.
        plt.close(fig)


def plot_ranked_bar(df: pd.DataFrame):
    """Draw a horizontal bar chart of languages ordered by WER.

    Languages in ``EXCLUDED_LANGUAGES`` and rows without a ``wer`` or
    ``language`` value are skipped. Bars are coloured by resource tier and
    labelled with the WER to three decimals. The chart is saved to
    ``OUTPUT_DIR`` as ``languages_ranked_by_wer`` (PNG and PDF); a failure
    to save is logged.
    """
    keep = ~df["code"].isin(EXCLUDED_LANGUAGES)
    df_plot = df[keep].dropna(subset=["wer", "language"]).sort_values("wer")
    if df_plot.empty:
        logging.warning("No data to plot (ranked bar) after dropping NaNs in 'wer' or 'language'.")
        return

    # Height grows with the number of bars so labels stay legible.
    figure_size = (4.0, max(2.5, 0.18 * len(df_plot)))
    fig, ax = plt.subplots(figsize=figure_size)

    bar_colours = []
    for lang_code in df_plot.code:
        bar_colours.append(COLOUR.get(get_resource_group(lang_code), "grey"))
    bars = ax.barh(df_plot.language, df_plot.wer, color=bar_colours, alpha=.8)

    for bar, value in zip(bars, df_plot.wer):
        if pd.isna(value):
            continue
        mid_height = bar.get_y() + bar.get_height()/2
        ax.text(value + .005, mid_height, f"{value:.3f}", va="center", fontsize=6)

    ax.set_xlabel("Word error rate")
    # df_plot is known to be non-empty here; cap the axis at 2.
    peak = df_plot.wer.max()
    x_upper = min(2, peak*1.1) if pd.notna(peak) else 2
    ax.set_xlim(0, x_upper)
    ax.grid(axis="x", ls="--", alpha=.3)
    ax.tick_params(axis='y', labelsize=7)
    fig.tight_layout(pad=0.5)
    try:
        for suffix in ("png", "pdf"):
            fig.savefig(OUTPUT_DIR / f"languages_ranked_by_wer.{suffix}")
        logging.info(f"Ranked bar plot 'languages_ranked_by_wer' saved to {OUTPUT_DIR}")
    except Exception as err:
        logging.error(f"Failed to save ranked bar plot: {err}")
    plt.close(fig)


def plot_box(df: pd.DataFrame):
    """Draw one WER box per resource tier, with the individual points overlaid.

    Languages in ``EXCLUDED_LANGUAGES`` and rows without a ``wer`` or
    ``resource_group`` value are skipped; tiers with no data get no box.
    Each tick label carries the tier name and its sample size. The figure is
    saved to ``OUTPUT_DIR`` as ``wer_by_resource_group`` (PNG and PDF); a
    failure to save is logged.
    """
    df_filtered = df[~df["code"].isin(EXCLUDED_LANGUAGES)]
    df_plot = df_filtered.dropna(subset=["wer", "resource_group"])
    if df_plot.empty:
        logging.warning("No data to plot (box plot) after dropping NaNs in 'wer' or 'resource_group'.")
        return

    groups = ["High", "Medium", "Low", "Other"]
    data, labels, colours_for_plot = [], [], []

    for g in groups:
        vals = df_plot[df_plot.resource_group == g].wer.dropna()
        if not vals.empty:
            data.append(vals)
            labels.append(f"{g} (n={len(vals)})")
            colours_for_plot.append(COLOUR.get(g, "grey"))

    if not data:
        logging.warning("No data available for any resource group for box plot.")
        return

    fig, ax = plt.subplots(figsize=(4, 3))
    # Set the tick labels separately instead of passing ``labels=`` to
    # boxplot: that keyword is deprecated in recent matplotlib releases,
    # whereas set_xticklabels works on old and new versions alike.
    bp = ax.boxplot(data, patch_artist=True)
    ax.set_xticklabels(labels)

    for patch, c in zip(bp["boxes"], colours_for_plot):
        patch.set_facecolor(c)
        patch.set_alpha(.6)

    # Horizontal jitter for the overlaid points comes from a local, seeded
    # generator so the figure is identical from run to run and the global
    # NumPy random state is left untouched.
    rng = np.random.default_rng(0)
    for i, vals in enumerate(data):
        ax.scatter(rng.normal(i + 1, .03, len(vals)), vals,
                   s=15, alpha=.7, color="black")

    ax.set_ylabel("Word error rate")
    ax.set_ylim(0, min(2, df_plot.wer.max()*1.1) if not df_plot.empty and pd.notna(df_plot.wer.max()) else 2)
    fig.tight_layout(pad=0.5)
    try:
        fig.savefig(OUTPUT_DIR / "wer_by_resource_group.png")
        fig.savefig(OUTPUT_DIR / "wer_by_resource_group.pdf")
        logging.info(f"Box plot 'wer_by_resource_group' saved to {OUTPUT_DIR}")
    except Exception as e:
        logging.error(f"Failed to save box plot: {e}")
    plt.close(fig)


# ────────────────────────────────────────────────────────────
# CLI entry
# ────────────────────────────────────────────────────────────
def parse_args(custom_argv: Optional[List[str]]) -> argparse.Namespace:
    """Parse the command-line options.

    Args:
        custom_argv: Argument list to parse; ``None`` means use
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``data_dir``, ``hours_csv`` and ``quiet``.
    """
    parser = argparse.ArgumentParser(description="Plot Whisper WER vs training hours")
    parser.add_argument(
        "--data_dir",
        type=Path,
        default=DATA_DIR,
        help="Directory with *_subtoken_beam.csv files",
    )
    parser.add_argument(
        "--hours_csv",
        type=Path,
        default=TRAINING_HOURS_CSV,
        help="CSV mapping Whisper_Code → training hours",
    )
    parser.add_argument("-q", "--quiet", action="store_true", help="Suppress info logs")

    argv = sys.argv[1:] if custom_argv is None else custom_argv
    return parser.parse_args(argv)


def main(custom_argv: Optional[List[str]] = None):
    """Collect per-language WER results and write the WER-vs-hours scatter plot.

    Args:
        custom_argv: Explicit argument list. When omitted, ``sys.argv[1:]``
            is used, except when running under an IPython kernel, where the
            kernel's own arguments are ignored and the defaults apply.

    Only ``plot_scatter`` is invoked here; ``plot_ranked_bar`` and
    ``plot_box`` are not called from this entry point.
    """
    if custom_argv is not None:
        args_to_parse = custom_argv
    elif any('ipykernel_launcher.py' in arg for arg in sys.argv):
        # Inside Jupyter sys.argv holds kernel options, not ours.
        args_to_parse = []
    else:
        args_to_parse = sys.argv[1:]

    args = parse_args(args_to_parse)

    log_level = logging.WARNING if args.quiet else logging.INFO
    logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s")

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    logging.info(f"Using data directory: {args.data_dir}")
    logging.info(f"Using training hours CSV: {args.hours_csv}")
    logging.info(f"Output directory: {OUTPUT_DIR}")

    results = collect_results(args.data_dir, args.hours_csv)
    if results.empty:
        logging.error("No usable results – exiting.")
        return

    named_languages = results[results['language'].notna()]
    logging.info("Collected results for %d languages.", len(named_languages))

    plot_scatter(results)

    logging.info(f"WER vs Hours plot written to {OUTPUT_DIR.resolve()}")


if __name__ == "__main__":
    main()


INFO: Using data directory: results_beam_600s
INFO: Using training hours CSV: whisper_training_hours.csv
INFO: Output directory: analysis_results_beam
INFO: Language WER metrics saved to analysis_results_beam/language_wer_metrics.csv
INFO: Collected results for 28 languages.
INFO: Excluded 7 languages from plotting: mt, uz, sw, cs, sq, yo, da
INFO: Scatter plot correlation: r=-0.598, p-value=0.00421


--- Scatter Plot Correlation ---
Pearson correlation coefficient (r): -0.598
P-value: 0.00421


INFO: Scatter plot 'wer_vs_hours' saved to analysis_results_beam (including SVG for text editing)
INFO: WER vs Hours plot written to /home/siyu/analysis_results_beam
