In [1]:
# polyglot_accuracy_and_heatmap.py  (with --label & filename-regex support)
# -*- coding: utf-8 -*-

import os, re, json, glob, argparse
from datetime import datetime
import numpy as np
import pandas as pd
import librosa, soundfile as sf
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
import matplotlib.pyplot as plt

# Default locations used when --audio / --text / --csv / --out are not given.
DEFAULT_AUDIO_DIR = r"C:\Users\sagni\Downloads\Poly Glot AI\archive\audio"
DEFAULT_TEXT_DIR = r"C:\Users\sagni\Downloads\Poly Glot AI\archive\text"
DEFAULT_CSV_PATH = r"C:\Users\sagni\Downloads\Poly Glot AI\archive\British English Speech Recognition.csv"
DEFAULT_OUT_DIR = r"C:\Users\sagni\Downloads\Poly Glot AI"

# Column names tried, in priority order, as the classification target
# (find_label_column compares them case- and whitespace-insensitively).
POSSIBLE_LABELS = [
    "label", "Label",
    "target", "Target",
    "accent", "Accent",
    "speaker_id", "Speaker",
    "class", "Class",
    "language", "Language",
]

# Candidate transcript column names, in priority order.
POSSIBLE_TEXT_COLS = [
    "transcript", "Transcript",
    "text", "Text",
    "utterance",
    "sentence",
    "phrase",
]

# ----------------- Helpers -----------------
def list_basenames(root, exts):
    """Map file stem -> path for every file in *root* matching any glob in *exts*.

    Later matches overwrite earlier ones when two files share a stem
    (e.g. ``a.wav`` and ``a.WAV``).
    """
    return {
        os.path.splitext(os.path.basename(match))[0]: match
        for pattern in exts
        for match in glob.glob(os.path.join(root, pattern))
    }

def read_csv_any(path):
    """Read the CSV at *path*, retrying with Latin-1 if it is not valid UTF-8.

    Only a decoding failure triggers the retry. Other errors (missing file,
    malformed CSV, permission problems, ...) propagate unchanged instead of being
    masked behind a pointless second attempt.
    """
    try:
        return pd.read_csv(path)
    except UnicodeDecodeError:
        # Latin-1 maps every byte value to a character, so decoding cannot fail.
        return pd.read_csv(path, encoding="latin-1")

def robust_load_wav(path, sr=16000):
    """Load *path* as a mono float signal at sample rate *sr*.

    Tries soundfile first (averaging channels and resampling with librosa when
    needed), then falls back to ``librosa.load``. Returns ``(signal, rate)``,
    or ``(None, None)`` when both loaders fail.
    """
    try:
        samples, native_rate = sf.read(path, always_2d=False)
        if isinstance(samples, np.ndarray):
            if samples.ndim > 1:
                # soundfile returns (frames, channels); average to mono.
                samples = np.mean(samples, axis=1)
            if native_rate != sr:
                samples = librosa.resample(samples.astype(float), orig_sr=native_rate, target_sr=sr)
                native_rate = sr
        return samples.astype(float), native_rate
    except Exception:
        pass

    # Second chance: librosa can decode formats soundfile rejects.
    try:
        signal, rate = librosa.load(path, sr=sr, mono=True)
    except Exception:
        return None, None
    return signal, rate

# ----------------- Pickle-safe transformers -----------------
class Squeeze1D(BaseEstimator, TransformerMixin):
    """Stateless transformer that flattens its input into a 1-D object array.

    Used to turn a single-column DataFrame slice into the flat sequence that
    text/audio extractors iterate over.
    """

    def fit(self, X, y=None):
        # Nothing to learn.
        return self

    def transform(self, X):
        return np.ravel(np.asarray(X, dtype=object))

class TextStatExtractor(BaseEstimator, TransformerMixin):
    """Compute simple length statistics for each text.

    Output columns: character length, word count, mean word length (0.0 when
    there are no words). Non-string inputs are treated as the empty string.
    """

    def fit(self, X, y=None):
        self.feature_names_ = np.array(["char_len", "word_count", "avg_word_len"])
        return self

    @staticmethod
    def _stats(text):
        """Return [char_len, word_count, avg_word_len] for one text."""
        if not isinstance(text, str):
            text = ""
        words = re.findall(r"\w+", text, flags=re.UNICODE)
        n_words = len(words)
        mean_len = sum(map(len, words)) / n_words if n_words else 0.0
        return [len(text), n_words, mean_len]

    def transform(self, X):
        texts = np.asarray(X, dtype=object).ravel().tolist()
        return np.array([self._stats(text) for text in texts], dtype=float)

class AudioFeatureExtractor(BaseEstimator, TransformerMixin):
    """Convert audio file paths into fixed-length summary feature vectors.

    Each path yields ``6 + 2 * n_mfcc`` values, in this order:
    duration in seconds, RMS mean, RMS std, zero-crossing-rate mean,
    spectral-centroid mean, spectral-centroid std, then the per-coefficient
    MFCC means followed by the per-coefficient MFCC standard deviations.

    Any file that cannot be loaded or analysed yields an all-zero row.
    NOTE(review): zeros are not NaN, so a downstream SimpleImputer (default
    missing_values=NaN) will not replace them — confirm this is intended.
    """

    def __init__(self, sr=16000, n_mfcc=13):
        self.sr = sr
        self.n_mfcc = n_mfcc

    def fit(self, X, y=None):
        # Stateless: nothing is learned from the training data.
        return self

    def transform(self, X):
        paths = np.asarray(X, dtype=object).ravel().tolist()
        return np.array([self._feat_one(p) for p in paths])

    def _feat_one(self, path):
        def mean_or_zero(frames):
            return float(np.mean(frames)) if frames.size else 0.0

        def std_or_zero(frames):
            return float(np.std(frames)) if frames.size else 0.0

        try:
            signal, rate = robust_load_wav(path, sr=self.sr)
            if signal is None or len(signal) == 0:
                raise RuntimeError
            rms = librosa.feature.rms(y=signal).flatten()
            zcr = librosa.feature.zero_crossing_rate(signal).flatten()
            centroid = librosa.feature.spectral_centroid(y=signal, sr=rate).flatten()
            mfcc = librosa.feature.mfcc(y=signal, sr=rate, n_mfcc=self.n_mfcc)
            row = [
                len(signal) / float(rate),
                mean_or_zero(rms),
                std_or_zero(rms),
                mean_or_zero(zcr),
                mean_or_zero(centroid),
                std_or_zero(centroid),
            ]
            if mfcc.size:
                row.extend(np.mean(mfcc, axis=1))
                row.extend(np.std(mfcc, axis=1))
            else:
                row.extend([0.0] * (2 * self.n_mfcc))
            return row
        except Exception:
            # Best effort: unreadable/empty audio becomes a zero vector of the same width.
            return [0.0] * (6 + self.n_mfcc * 2)

# ----------------- Label discovery -----------------
def normalize_cols(df):
    """Map each column label of *df* to a lower-cased key with all whitespace removed."""
    squash = re.compile(r"\s+")
    keys = (squash.sub("", str(col)).lower() for col in df.columns)
    return dict(zip(df.columns, keys))

def _is_categorical_like(series):
    """True when *series* has a discrete non-numeric dtype (object, string, category or bool)."""
    dtype = series.dtype
    return (
        pd.api.types.is_object_dtype(dtype)
        or pd.api.types.is_string_dtype(dtype)
        or isinstance(dtype, pd.CategoricalDtype)
        or pd.api.types.is_bool_dtype(dtype)
    )


def find_label_column(df, user_label=None):
    """Choose the column of *df* to use as the classification target.

    Resolution order (column names compared case- and whitespace-insensitively):
      1. *user_label*, if given and present (a warning is printed if it is not);
      2. the first entry of POSSIBLE_LABELS present in *df*;
      3. a heuristic guess: columns with 2..max(2, min(50, n_rows // 2)) distinct
         non-null values, preferring categorical-like dtypes (object, pandas
         string, category, bool), then fewer distinct values, then column order.

    Returns the original column label, or None when nothing qualifies.
    """
    norm = normalize_cols(df)  # original label -> normalized key
    inv = {v: k for k, v in norm.items()}
    if user_label:
        key = re.sub(r"\s+", "", user_label).lower()
        if key in inv:
            return inv[key]
        print(f"[WARN] --label '{user_label}' does not match any column; falling back to auto-detection.")
    for name in POSSIBLE_LABELS:
        key = re.sub(r"\s+", "", name).lower()
        if key in inv:
            return inv[key]

    max_uniques = max(2, min(50, df.shape[0] // 2))
    candidates = []
    for col in df.columns:
        if col in ("basename", "audio_path", "text_value"):
            continue  # columns derived by the pipeline itself, never labels
        uniques = df[col].dropna().nunique()
        if 2 <= uniques <= max_uniques:
            candidates.append((not _is_categorical_like(df[col]), uniques, col))
    # Stable sort: categorical-like first, then fewer uniques, ties keep column order.
    candidates.sort(key=lambda c: (c[0], c[1]))
    return candidates[0][2] if candidates else None

def label_from_filename(df, regex, group=1):
    """Extract a label from each ``df["basename"]`` value with *regex*.

    Uses ``re.search`` and returns capture group *group* of the first match;
    basenames that do not match map to None.
    """
    finder = re.compile(regex).search

    def _extract(name):
        hit = finder(name)
        if hit is None:
            return None
        return hit.group(group)

    return df["basename"].astype(str).apply(_extract)

# ----------------- Main -----------------
# Candidate CSV columns holding the audio file name/path, in priority order
# (matched case- and whitespace-insensitively).
_FILENAME_COLS = ("filename", "file", "path", "wav", "audio", "fname", "id", "name")


def _matching_columns(df, names):
    """Return the columns of *df* matching *names* (case/space-insensitive), in *names* order."""
    by_key = {}
    for original, key in normalize_cols(df).items():
        by_key.setdefault(key, original)  # first column wins if two normalize alike
    found = []
    for name in names:
        col = by_key.get(re.sub(r"\s+", "", name).lower())
        if col is not None and col not in found:
            found.append(col)
    return found


def _basenames(values):
    """Strip directory and extension from each value of *values* (coerced to str)."""
    return values.astype(str).map(lambda p: os.path.splitext(os.path.basename(p))[0])


def _pick_filename_column(df_csv, audio_map):
    """Return the filename-like column whose basenames match the most audio files.

    Ties go to the earlier entry of _FILENAME_COLS; None if no such column exists.
    """
    audio_keys = list(audio_map)
    best_col, best_hits = None, -1
    for col in _matching_columns(df_csv, _FILENAME_COLS):
        hits = int(_basenames(df_csv[col]).isin(audio_keys).sum())
        if hits > best_hits:
            best_col, best_hits = col, hits
    return best_col


def _read_text_file(path):
    """Return the stripped contents of *path* (UTF-8, undecodable bytes dropped), or "" if unavailable."""
    if not path or not os.path.isfile(path):
        return ""
    try:
        with open(path, "r", encoding="utf-8", errors="ignore") as fh:
            return fh.read().strip()
    except OSError:
        return ""


def _build_sample_table(csv_path, audio_map, text_map):
    """Merge CSV metadata, audio paths and transcripts into one DataFrame.

    Adds ``basename``, ``audio_path`` (NaN when no audio file matches) and
    ``text_value`` (CSV transcript, else the matching .txt file, else "").
    Without a usable CSV, one row is created per audio/text basename.
    """
    df_csv = read_csv_any(csv_path) if os.path.isfile(csv_path) else pd.DataFrame()
    text_col = None
    if df_csv.empty:
        union = sorted(set(audio_map) | set(text_map))
        df_csv = pd.DataFrame({"basename": union})
    else:
        fname_col = _pick_filename_column(df_csv, audio_map)
        if fname_col is not None:
            df_csv["basename"] = _basenames(df_csv[fname_col])
        else:
            df_csv["basename"] = df_csv.index.astype(str)
        text_cols = _matching_columns(df_csv, POSSIBLE_TEXT_COLS)
        text_col = text_cols[0] if text_cols else None

    df_csv["audio_path"] = df_csv["basename"].map(audio_map)
    if text_col is not None:
        csv_text = df_csv[text_col].fillna("").astype(str).tolist()
    else:
        csv_text = [""] * len(df_csv)
    # A list comprehension (not DataFrame.apply) also behaves on an empty frame.
    df_csv["text_value"] = [
        text if text else _read_text_file(text_map.get(base))
        for text, base in zip(csv_text, df_csv["basename"])
    ]
    return df_csv


def _encode_labels(series):
    """Encode non-numeric labels as integer codes assigned in sorted order.

    Returns ``(codes, class_names)``. Numeric (incl. bool) labels are returned
    unchanged with ``class_names=None``. *series* must contain no missing values.
    """
    if pd.api.types.is_numeric_dtype(series.dtype):
        return series.copy(), None
    values = series.astype(object)
    uniques = list(pd.unique(values))
    try:
        class_names = sorted(uniques)
    except TypeError:  # mutually unorderable types: fall back to textual order
        class_names = sorted(uniques, key=str)
    codes = {value: code for code, value in enumerate(class_names)}
    return values.map(codes).astype(int), class_names


def _split_train_test(X, y, test_size=0.25, random_state=42):
    """Stratified train/test split, falling back to a random split when stratification is impossible."""
    try:
        return train_test_split(X, y, test_size=test_size, random_state=random_state, stratify=y)
    except ValueError as err:  # e.g. a class with a single member
        print(f"[WARN] Stratified split not possible ({err}); using a random split instead.")
        return train_test_split(X, y, test_size=test_size, random_state=random_state)


def _build_feature_transformer(use_tfidf=True):
    """Return a fresh, unfitted transformer mapping audio_path/text_value to features."""
    audio_pipe = Pipeline([
        ("sel", Squeeze1D()),
        ("afe", AudioFeatureExtractor(sr=16000, n_mfcc=13)),
        ("imputer", SimpleImputer(strategy="median")),
        ("scaler", StandardScaler()),
    ])
    text_parts = [
        ("tstats", Pipeline([("sel", Squeeze1D()), ("tstat", TextStatExtractor())]), "text_value"),
    ]
    if use_tfidf:
        tfidf = TfidfVectorizer(analyzer="char", ngram_range=(2, 4), max_features=800)
        text_parts.insert(0, ("tfidf", Pipeline([("sel", Squeeze1D()), ("tfidf", tfidf)]), "text_value"))
    text_features = ColumnTransformer(transformers=text_parts, remainder="drop")
    return ColumnTransformer(
        transformers=[
            ("audio", audio_pipe, ["audio_path"]),
            ("text", text_features, ["text_value"]),
        ],
        remainder="drop",
    )


def _binary_auc(model, X, y_true):
    """ROC-AUC of a fitted *model* on (X, y_true) for a binary problem, or None when undefined."""
    classes = np.asarray(getattr(model, "classes_", []))
    if (len(classes) != 2
            or not np.array_equal(np.unique(y_true), classes)
            or not hasattr(model, "predict_proba")):
        return None
    try:
        # Column 1 is classes[1], the greater label, which roc_auc_score treats as positive.
        probs = model.predict_proba(X)[:, 1]
        if np.allclose(np.min(probs), np.max(probs)):
            return None  # constant scores carry no ranking information
        return float(roc_auc_score(y_true, probs))
    except ValueError:
        return None


def _jsonable(value):
    """Convert numpy scalars to Python values; stringify anything json cannot encode."""
    if isinstance(value, np.generic):
        value = value.item()
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    return str(value)


def _save_accuracy_chart(metrics, path):
    """Bar-plot each model's test accuracy, save it to *path*, display it, then close it."""
    names = list(metrics)
    accs = [metrics[n]["accuracy"] for n in names]
    plt.figure(figsize=(7, 5), dpi=140)
    bars = plt.bar(names, accs)
    plt.title("Model Accuracy (Test Set)")
    plt.ylabel("Accuracy")
    plt.ylim(0, 1.05)  # headroom so a 1.000 annotation stays inside the axes
    for bar, acc in zip(bars, accs):
        plt.text(bar.get_x() + bar.get_width() / 2, acc + 0.01, f"{acc:.3f}",
                 ha="center", va="bottom", fontsize=9)
    plt.tight_layout()
    plt.savefig(path)
    plt.show()
    plt.close()


def _save_confusion_heatmap(cm, tick_labels, title, path):
    """Render *cm* as an annotated heatmap, save it to *path*, display it, then close it."""
    plt.figure(figsize=(6, 5), dpi=140)
    image = plt.imshow(cm, interpolation="nearest")
    plt.title(title)
    plt.colorbar(image, fraction=0.046, pad=0.04)
    ticks = np.arange(len(tick_labels))
    plt.xticks(ticks, tick_labels, rotation=45, ha="right")
    plt.yticks(ticks, tick_labels)
    plt.xlabel("Predicted")
    plt.ylabel("Actual")
    thresh = cm.max() / 2.0 if cm.size else 0.5
    for (i, j), count in np.ndenumerate(cm):
        plt.text(j, i, format(count, "d"), ha="center", va="center",
                 color="white" if count > thresh else "black")
    plt.tight_layout()
    plt.savefig(path)
    plt.show()
    plt.close()


def main():
    """Train two classifiers on audio + transcript features and report their accuracy.

    Writes polyglot_accuracy.png, polyglot_confusion_heatmap.png and
    polyglot_metrics.json to --out. Returns early with a [WARN] message when the
    data cannot support training (no matched audio, no label, < 2 classes).
    """
    ap = argparse.ArgumentParser(description="PolyGlotAI accuracy + heatmap (with --label / filename regex)")
    ap.add_argument("--audio", default=DEFAULT_AUDIO_DIR)
    ap.add_argument("--text",  default=DEFAULT_TEXT_DIR)
    ap.add_argument("--csv",   default=DEFAULT_CSV_PATH)
    ap.add_argument("--out",   default=DEFAULT_OUT_DIR)
    ap.add_argument("--label", default=None, help="Exact column name to use as label (case/space-insensitive)")
    ap.add_argument("--label_from_filename", default=None,
                    help=r"Regex with one capturing group to extract label from basename, e.g. '.*_accent-([a-z]+)_.*' or '^(.*?)-\d+$'")
    ap.add_argument("--label_group", type=int, default=1, help="Capturing group index for --label_from_filename (default 1)")
    # parse_known_args tolerates unrelated argv entries (e.g. when run inside a notebook kernel).
    args, _ = ap.parse_known_args()
    os.makedirs(args.out, exist_ok=True)

    audio_map = list_basenames(args.audio, ("*.wav", "*.WAV"))
    text_map = list_basenames(args.text, ("*.txt", "*.TXT"))
    table = _build_sample_table(args.csv, audio_map, text_map)

    # keep rows that have audio
    df = table[table["audio_path"].notna()].reset_index(drop=True)
    if df.empty:
        print(f"[WARN] No samples could be matched to audio files ({len(audio_map)} .wav files found in {args.audio}).")
        print("CSV columns detected:", list(table.columns))
        return

    # ---- choose / build label ----
    label_col = None
    if args.label_from_filename:
        df["__label_from_name"] = label_from_filename(df, args.label_from_filename, args.label_group)
        if df["__label_from_name"].notna().sum() >= 2:
            label_col = "__label_from_name"
    if label_col is None:
        label_col = find_label_column(df, user_label=args.label)
    if label_col is None:
        print("[WARN] No label column found or derived.")
        print("Tip 1: Pass --label \"<your_column>\" (exact name from CSV headers)")
        print(r"Tip 2: Derive from filename: --label_from_filename '.*_accent-([a-z]+)_.*'  (adjust regex)")
        print("CSV columns detected:", list(df.columns))
        return

    # Unlabelled rows (e.g. basenames the regex did not match) would put NaN into
    # stratify/fit, so drop them explicitly.
    labelled = df[label_col].notna()
    if not labelled.all():
        print(f"[WARN] Dropping {int((~labelled).sum())} row(s) without a value in label column '{label_col}'.")
        df = df[labelled].reset_index(drop=True)
    y, class_names = _encode_labels(df[label_col])
    if y.nunique() < 2:
        print(f"[WARN] Label column '{label_col}' has fewer than 2 classes. Need at least 2 to compute accuracy/heatmap.")
        return

    # ---- features ----
    X_frame = pd.DataFrame({
        "audio_path": df["audio_path"].astype(str).values,
        "text_value": df["text_value"].astype(str).values,
    })

    # ---- split & train ----
    X_train, X_test, y_train, y_test = _split_train_test(X_frame, y)
    if y_train.nunique() < 2:
        print("[WARN] Training split has fewer than 2 classes; add more labelled samples.")
        return

    # Character 2-4-grams need at least one transcript with 2+ non-space chars,
    # otherwise TfidfVectorizer fails with an empty vocabulary.
    use_tfidf = bool(X_train["text_value"].str.strip().str.len().ge(2).any())
    if not use_tfidf:
        print("[INFO] No usable transcripts in the training split; skipping character TF-IDF features.")
    # Fit the (deterministic) feature transformer once and share the matrices between
    # models: identical results, but audio is decoded once instead of per fit/predict.
    features = _build_feature_transformer(use_tfidf=use_tfidf)
    Xtr = features.fit_transform(X_train)
    Xte = features.transform(X_test)

    models = {
        "LogisticRegression": LogisticRegression(max_iter=2000),
        "RandomForest": RandomForestClassifier(n_estimators=400, random_state=42, n_jobs=-1),
    }
    metrics, predictions = {}, {}
    for name, model in models.items():
        model.fit(Xtr, y_train)
        predictions[name] = model.predict(Xte)
        metrics[name] = {
            "accuracy": float(accuracy_score(y_test, predictions[name])),
            "roc_auc": _binary_auc(model, Xte, y_test),
        }

    # Rank by ROC-AUC only when every model has one, so the scores are comparable.
    use_auc = all(m["roc_auc"] is not None for m in metrics.values())
    for m in metrics.values():
        m["chosen_score"] = m["roc_auc"] if use_auc else m["accuracy"]
    best_name = max(metrics, key=lambda n: metrics[n]["chosen_score"])  # first model wins ties

    # ---- accuracy bar chart ----
    acc_path = os.path.join(args.out, "polyglot_accuracy.png")
    _save_accuracy_chart(metrics, acc_path)

    # ---- confusion matrix heatmap ----
    y_pred_best = predictions[best_name]
    labels_sorted = np.unique(np.concatenate([np.asarray(y_test), y_pred_best]))
    cm = confusion_matrix(y_test, y_pred_best, labels=labels_sorted)
    if class_names is None:
        tick_labels = [str(v) for v in labels_sorted]
    else:
        tick_labels = [str(class_names[int(v)]) for v in labels_sorted]
    heat_path = os.path.join(args.out, "polyglot_confusion_heatmap.png")
    _save_confusion_heatmap(cm, tick_labels, f"Confusion Matrix (Best: {best_name})", heat_path)

    # ---- save metrics ----
    metrics_path = os.path.join(args.out, "polyglot_metrics.json")
    out_metrics = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "audio_dir": args.audio, "text_dir": args.text, "csv_path": args.csv,
        "label_column_used": _jsonable(label_col),
        "models": metrics, "best_model": best_name,
        "labels": [_jsonable(x) for x in labels_sorted],
    }
    if class_names is not None:
        # Original label values; "labels" holds their integer codes (index into this list).
        out_metrics["class_names"] = [_jsonable(c) for c in class_names]
    with open(metrics_path, "w", encoding="utf-8") as f:
        json.dump(out_metrics, f, indent=2)

    print("\n=== Saved ===")
    print("Accuracy chart ->", acc_path)
    print("Heatmap        ->", heat_path)
    print("Metrics JSON   ->", metrics_path)
    print("Label column   ->", label_col)


if __name__ == "__main__":
    main()


[WARN] No label column found or derived.
Tip 1: Pass --label "<your_column>" (exact name from CSV headers)
Tip 2: Derive from filename: --label_from_filename '.*_accent-([a-z]+)_.*'  (adjust regex)
CSV columns detected: ['ID', 'Audio', 'Text', 'basename', 'audio_path', 'text_value']
