In [2]:
# %% [markdown]
# ### Dataset Audit for MER (VNEMOS JSONL)
# Kiểm tra: cân bằng nhãn/speaker, độ dài audio & text, clipping/silence, căn chỉnh start/end,
# rò rỉ giữa splits, và các chỉ số IR / Entropy / Gini / Cramér's V.

# %%
import os, json, math, re
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
import soundfile as sf
import torch
from IPython.display import display

from transformers import AutoTokenizer
from configs.base import Config
from loading.dataloader import VNEMOSDataset, _clean_text

# ---------- helper metrics ----------
def imbalance_ratio(counts: dict) -> float:
    if not counts:
        return float("nan")
    vals = np.array(list(counts.values()), dtype=float)
    if vals.min() <= 0:
        vals = vals[vals > 0]
    if vals.size == 0:
        return float("nan")
    return float(vals.max() / max(1.0, vals.min()))

def norm_entropy(counts: dict) -> float:
    vals = np.array(list(counts.values()), dtype=float)
    if vals.sum() == 0 or len(vals) == 0:
        return float("nan")
    p = vals / vals.sum()
    H = -(p * np.log(p + 1e-12)).sum()
    Hmax = math.log(len(vals)) if len(vals) > 0 else 1.0
    return float(H / max(1e-12, Hmax))

def gini_impurity(counts: dict) -> float:
    vals = np.array(list(counts.values()), dtype=float)
    if vals.sum() == 0:
        return float("nan")
    p = vals / vals.sum()
    return float(1.0 - (p**2).sum())

def percentiles(arr, qs=(0, 25, 50, 75, 90, 95, 99, 100)):
    if len(arr) == 0:
        return {}
    return {f"p{q}": float(np.percentile(arr, q)) for q in qs}

def detect_clipping(x, thr=0.999):
    if x.size == 0:
        return 0.0
    return float((np.abs(x) >= thr).mean())

def silence_ratio(x, thr=1e-4):
    if x.size == 0:
        return 0.0
    return float((np.abs(x) <= thr).mean())

def cramers_v(conf_mat: np.ndarray) -> float:
    if conf_mat.size == 0:
        return 0.0
    R, C = conf_mat.shape
    total = conf_mat.sum()
    if R < 2 or C < 2 or total == 0:
        return 0.0
    row_sums = conf_mat.sum(axis=1)
    col_sums = conf_mat.sum(axis=0)
    expected = np.outer(row_sums, col_sums) / max(1, total)
    with np.errstate(divide='ignore', invalid='ignore'):
        term = (conf_mat - expected)**2 / np.where(expected==0, 1, expected)
    chi2 = np.nansum(term)
    k = min(R, C)
    return float(np.sqrt(chi2 / (total * (k - 1) + 1e-12)))

# ---------- config ----------
cfg = Config(
    data_root="../output",
    jsonl_dir="",
    sample_rate=16000,
    max_audio_sec=None,       # kiểm đúng “không cắt cứng”
    text_max_length=64,
)

# Tự động phát hiện các split có mặt
splits = []
for sp in ["train", "valid", "test"]:
    p = (Path(cfg.data_root) / (cfg.jsonl_dir or "") / f"{sp}.jsonl").resolve()
    if p.exists():
        splits.append(sp)
print("Found splits:", splits)

# ---------- load datasets ----------
sets = {sp: VNEMOSDataset(cfg, sp) for sp in splits}
tokenizer = AutoTokenizer.from_pretrained(getattr(cfg, "text_encoder_ckpt", "vinai/phobert-base"), use_fast=True)

# ---------- collect rows (metadata) ----------
rows = []
for sp, ds in sets.items():
    for it in ds.items:
        rows.append({
            "split": sp,
            "utterance_id": it["utterance_id"],
            "speaker_id": it["speaker_id"],
            "wav_path": it["wav_path"],
            "start": float(it.get("start", 0.0) or 0.0),
            "end": float(it.get("end", 0.0) or 0.0),
            "emotion": it["emotion"],
            "transcript": _clean_text(it.get("transcript", "")),
        })
df = pd.DataFrame(rows)
print("Total rows:", len(df))

# ---------- existence (resolve) ----------
missing = []
abs_paths = []
for sp, ds in sets.items():
    for it in ds.items:
        try:
            wav_abs = str(ds._resolve_wav(it["wav_path"]))
            abs_paths.append((sp, wav_abs))
        except Exception as e:
            missing.append((sp, it["wav_path"], str(e)))
miss_df = pd.DataFrame(missing, columns=["split","wav_path","error"])
print("Missing/Unreadable files:", len(miss_df))

# ---------- duplicates (robust) ----------
# Absolute path duplicates across all splits
path_counts = Counter([p for _, p in abs_paths])
dup_list = [{"path": p, "count": c} for p, c in path_counts.items() if c > 1]
dup_df = pd.DataFrame(dup_list, columns=["path","count"])
if not dup_df.empty:
    dup_df = dup_df.sort_values("count", ascending=False).reset_index(drop=True)
print("Duplicate absolute wav paths (any split):", 0 if dup_df.empty else len(dup_df))

# Duplicate utterance_id across splits
uid_counts = df.groupby(["utterance_id"]).size().sort_values(ascending=False)
uid_dup_df = uid_counts[uid_counts > 1].rename("count").reset_index()
print("Duplicate utterance_id across splits:", 0 if uid_dup_df.empty else len(uid_dup_df))

# Duplicate JSONL wav_path across splits (relative path duplicates)
rel_counts = df.groupby(["wav_path"]).size().sort_values(ascending=False)
rel_dup_df = rel_counts[rel_counts > 1].rename("count").reset_index()
print("Duplicate JSONL wav_path across splits:", 0 if rel_dup_df.empty else len(rel_dup_df))

# ---------- leakage: speaker overlap ----------
spk_by_split = {sp: set(df[df["split"]==sp]["speaker_id"]) for sp in splits}
overlap = {}
for i, a in enumerate(splits):
    for b in splits[i+1:]:
        overlap[(a,b)] = len(spk_by_split[a].intersection(spk_by_split[b]))

# ---------- label balance ----------
label_counts = {sp: Counter(df[df["split"]==sp]["emotion"]) for sp in splits}
balance_table = []
for sp in splits:
    counts = dict(label_counts[sp])
    row = {
        "split": sp,
        "IR_max/min": imbalance_ratio(counts),
        "H_norm": norm_entropy(counts),
        "Gini": gini_impurity(counts),
        **{f"cnt_{k}": v for k, v in sorted(counts.items())},
        "total": int(sum(counts.values())),
    }
    balance_table.append(row)
balance_df = pd.DataFrame(balance_table).fillna(0)

# ---------- token length & truncation ----------
tok_stats = []
for sp, ds in sets.items():
    texts = df[df["split"]==sp]["transcript"].tolist()
    lens = []
    over = 0
    for txt in texts:
        ids = tokenizer(txt, add_prefix_space=True)["input_ids"]
        L = len(ids); lens.append(L)
        if L > cfg.text_max_length:
            over += 1
    arr = np.array(lens, dtype=int)
    sts = percentiles(arr, qs=(0,25,50,75,90,95,99,100))
    tok_stats.append({
        "split": sp,
        **{f"tok_{k}": v for k, v in sts.items()},
        "tok_over_maxlen": int(over),
        "tok_over_rate": float(over / max(1, len(arr))) if len(arr) else 0.0,
    })
tok_df = pd.DataFrame(tok_stats)

# ---------- audio quality & alignment (read audio) ----------
def expected_len_after_resample(wav_abs, start, end, target_sr):
    info = sf.info(wav_abs)
    orig_sr = info.samplerate
    orig_frames = info.frames
    if end and end > 0:
        dur_s = max(0.0, end - start)
    else:
        dur_s = orig_frames / float(orig_sr)
    return int(round(dur_s * target_sr))

aq_rows = []
for sp, ds in sets.items():
    for it in ds.items:
        try:
            wav_abs = str(ds._resolve_wav(it["wav_path"]))
            data, sr = sf.read(wav_abs, always_2d=False)
            if data.ndim == 2: 
                data = data.mean(axis=1)
            # crop theo start/end
            start = float(it.get("start", 0.0) or 0.0)
            end   = float(it.get("end", 0.0) or 0.0)
            if end and end > 0:
                s = int(max(0.0, start) * sr)
                e = min(int(end * sr), len(data))
                data = data[s:e]
            # metrics
            clip = detect_clipping(data)
            sil  = silence_ratio(data)
            exp_len = expected_len_after_resample(wav_abs, start, end, cfg.sample_rate)
            act_len = int(round(len(data) * (cfg.sample_rate / sr)))  # quy đổi về 16k
            dur_sec = len(data) / float(sr) if sr > 0 else 0.0
            aq_rows.append({
                "split": sp,
                "wav_abs": wav_abs,
                "sr": sr,
                "len_samples": int(len(data)),
                "dur_sec": float(dur_sec),
                "clip_rate": clip,
                "silence_rate": sil,
                "len_expected_16k": int(exp_len),
                "len_actual_16k": int(act_len),
                "len_diff": int(act_len - exp_len),
            })
        except Exception as e:
            aq_rows.append({"split": sp, "wav_abs": it["wav_path"], "error": str(e)})

aq_df = pd.DataFrame(aq_rows)

# ---------- speaker × emotion coverage & association ----------
spk_em_rows = []
for sp in splits:
    sub = df[df["split"]==sp]
    if len(sub) == 0:
        continue
    grp = sub.groupby(["speaker_id","emotion"]).size().reset_index(name="cnt")
    for spk, emo, cnt in grp.values:
        spk_em_rows.append({"split": sp, "speaker_id": spk, "emotion": emo, "count": int(cnt)})
spk_em_df = pd.DataFrame(spk_em_rows)

pivot = df.pivot_table(index="speaker_id", columns="emotion", aggfunc="size", fill_value=0)
cramersV = cramers_v(pivot.values) if pivot.size > 0 else 0.0

# ---------- SUMMARY ----------
print("\n=== EXISTENCE / DUPLICATION / LEAKAGE ===")
print("Missing files:", len(miss_df))
print("Duplicate absolute wav paths:", 0 if dup_df.empty else len(dup_df))
print("Duplicate utterance_id across splits:", 0 if uid_dup_df.empty else len(uid_dup_df))
print("Duplicate JSONL wav_path across splits:", 0 if rel_dup_df.empty else len(rel_dup_df))
print("Speaker overlap across splits:", overlap)

print("\n=== LABEL BALANCE (per split) ===")
display(balance_df)

print("\n=== TOKEN LENGTH & TRUNCATION (per split) ===")
display(tok_df)

print("\n=== AUDIO QUALITY & ALIGNMENT (aggregates) ===")
if len(aq_df) > 0 and "error" not in aq_df.columns:
    agg = aq_df.groupby("split").agg({
        "sr": ["min","max"],
        "dur_sec": ["count","mean","median","max"],
        "clip_rate": ["mean","max"],
        "silence_rate": ["mean","max"],
        "len_diff": ["mean","min","max"],
    })
    display(agg)
else:
    print("No audio inspected or errors present.")
    if "error" in aq_df.columns:
        print(aq_df[["split","wav_abs","error"]].head())

print("\n=== SPEAKER × EMOTION COVERAGE ===")
print("Rows:", len(spk_em_df), "| speakers:", df['speaker_id'].nunique(), "| emotions:", df['emotion'].nunique())
print(f"Cramér's V (speaker↔emotion): {cramersV:.3f}  (≈0: độc lập, →1: phụ thuộc mạnh)")

# ---------- QUICK RULES / FLAGS ----------
flags = []

# A. Cân bằng nhãn (train)
train_counts = label_counts.get("train", Counter())
if train_counts:
    IR = imbalance_ratio(train_counts)
    Hn = norm_entropy(train_counts)
    Gi = gini_impurity(train_counts)
    if IR > 3.0: flags.append(f"[Label balance] Imbalance Ratio = {IR:.2f} > 3 (cân nhắc bổ sung/oversample/class-weight).")
    if Hn < 0.90: flags.append(f"[Label balance] Entropy chuẩn hoá = {Hn:.2f} < 0.90 (phân bố lệch).")
    if Gi > 0.40: flags.append(f"[Label balance] Gini = {Gi:.2f} > 0.40 (phân bố lệch).")

# B. Token truncation
for sp in splits:
    row = tok_df[tok_df["split"]==sp]
    if len(row):
        rate = float(row["tok_over_rate"].values[0])
        if rate > 0.10:
            flags.append(f"[Text truncation] {sp}: {rate*100:.1f}% bị cắt > max_length={cfg.text_max_length}. Cân nhắc tăng text_max_length.")

# C. Audio quality
if len(aq_df) and "error" not in aq_df.columns:
    bysp = aq_df.groupby("split").agg(clip_mean=("clip_rate","mean"), clip_max=("clip_rate","max"),
                                      sil_mean=("silence_rate","mean"), sr_min=("sr","min"), sr_max=("sr","max"),
                                      len_diff_min=("len_diff","min"), len_diff_max=("len_diff","max"))
    for sp, r in bysp.iterrows():
        if r["clip_max"] > 0.01:
            flags.append(f"[Audio clipping] {sp}: clip_max={r['clip_max']:.3f} > 1%. Có thể normalize/gain staging.")
        if int(r["sr_min"]) != 16000 or int(r["sr_max"]) != 16000:
            flags.append(f"[Sample rate] {sp}: phát hiện sr khác 16k (min={int(r['sr_min'])}, max={int(r['sr_max'])}). Chuẩn hoá về 16k.")
        if abs(int(r["len_diff_min"])) > 5 or abs(int(r["len_diff_max"])) > 5:
            flags.append(f"[Alignment] {sp}: len_diff nên trong [-5,+5] mẫu @16k; thấy min={int(r['len_diff_min'])}, max={int(r['len_diff_max'])}.")

# D. Leakage
for (a,b), n in overlap.items():
    if n > 0:
        flags.append(f"[Leakage] Speaker trùng giữa {a} và {b}: {n} speaker. Nên tách disjoint.")

# E. Speaker bias
if cramersV > 0.50:
    flags.append(f"[Bias] Cramér's V(speaker↔emotion)={cramersV:.2f} cao. Nguy cơ model học speaker thay vì cảm xúc.")

print("\n=== FLAGS / RECOMMENDATIONS ===")
if flags:
    for f in flags:
        print("-", f)
else:
    print("Không phát hiện vấn đề đáng lo theo các ngưỡng mặc định. ✅")


Found splits: ['train', 'valid', 'test']
Total rows: 250


Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_

Missing/Unreadable files: 0
Duplicate absolute wav paths (any split): 0
Duplicate utterance_id across splits: 0
Duplicate JSONL wav_path across splits: 0


Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_


=== EXISTENCE / DUPLICATION / LEAKAGE ===
Missing files: 0
Duplicate absolute wav paths: 0
Duplicate utterance_id across splits: 0
Duplicate JSONL wav_path across splits: 0
Speaker overlap across splits: {('train', 'valid'): 0, ('train', 'test'): 0, ('valid', 'test'): 0}

=== LABEL BALANCE (per split) ===


Unnamed: 0,split,IR_max/min,H_norm,Gini,cnt_angry,cnt_fear,cnt_happiness,cnt_neutral,cnt_sadness,total
0,train,1.162162,0.99899,0.79935,43,38,42,37,40,200
1,valid,1.5,0.994996,0.7968,4,5,5,5,6,25
2,test,2.666667,0.946371,0.7648,3,7,3,8,4,25



=== TOKEN LENGTH & TRUNCATION (per split) ===


Unnamed: 0,split,tok_p0,tok_p25,tok_p50,tok_p75,tok_p90,tok_p95,tok_p99,tok_p100,tok_over_maxlen,tok_over_rate
0,train,5.0,14.75,22.0,33.0,44.0,48.0,60.07,72.0,2,0.01
1,valid,5.0,16.0,22.0,35.0,40.0,46.4,63.96,69.0,1,0.04
2,test,5.0,15.0,21.0,25.0,39.6,40.8,62.28,69.0,1,0.04



=== AUDIO QUALITY & ALIGNMENT (aggregates) ===


Unnamed: 0_level_0,sr,sr,dur_sec,dur_sec,dur_sec,dur_sec,clip_rate,clip_rate,silence_rate,silence_rate,len_diff,len_diff,len_diff
Unnamed: 0_level_1,min,max,count,mean,median,max,mean,max,mean,max,mean,min,max
split,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
test,16000,16000,25,7.941258,7.035687,30.882562,2.049226e-06,2.8e-05,0.027693,0.188055,0.0,0,0
train,16000,16000,200,7.865092,7.158,22.662688,1.244188e-06,4.2e-05,0.027324,0.484088,0.0,0,0
valid,16000,16000,25,9.016475,8.661062,18.436688,6.040562e-07,1.5e-05,0.035206,0.230401,0.0,0,0



=== SPEAKER × EMOTION COVERAGE ===
Rows: 250 | speakers: 250 | emotions: 5
Cramér's V (speaker↔emotion): 1.000  (≈0: độc lập, →1: phụ thuộc mạnh)

=== FLAGS / RECOMMENDATIONS ===
- [Label balance] Gini = 0.80 > 0.40 (phân bố lệch).
- [Bias] Cramér's V(speaker↔emotion)=1.00 cao. Nguy cơ model học speaker thay vì cảm xúc.
