In [None]:
import os
# Hide every CUDA device from this process. Set before the ML libraries below
# are imported so the whole pipeline stays on the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import glob, json, random, logging, gc, re
from pathlib import Path
from time import perf_counter
import inspect

import numpy as np
import soundfile as sf
import librosa
from tqdm import tqdm

from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

In [None]:
# ── Locations ────────────────────────────────────────────────────────────────
RAW_ROOT      = "data/VNEMOS"   # root folder of the raw corpus
PROC_WAV_DIR  = "wavs16k"       # where the 16 kHz mono copies are written
OUTPUT_DIR    = "output"        # where the JSONL manifests are written

# ── Models ───────────────────────────────────────────────────────────────────
# NOTE(review): LOCAL_MODEL is not referenced anywhere else in the visible
# code — confirm whether it is still needed.
LOCAL_MODEL   = "vinai/PhoWhisper-base"
PUNC_MODEL    = "vinai/bartpho-word-base"

# ── Run options ──────────────────────────────────────────────────────────────
TEST_RUN      = True     # set to False for the real (full) run
TEST_COUNT    = 5        # number of files kept when TEST_RUN is on
VALID_RATIO   = 0.10     # share of speakers assigned to the validation split
TEST_RATIO    = 0.10     # share of speakers assigned to the test split
RANDOM_SEED   = 42

# Seed both RNGs up front so later shuffles start from a known state.
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Create both output folders (including missing parents) if they are absent.
os.makedirs(PROC_WAV_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
logging.info("Loading faster-whisper (PhoWhisper-base, CPU-INT8) …")
t0 = perf_counter()
# Use half of the cores so the rest stay available to the system.
# os.cpu_count() can return None, and halving a single core gives 0, so fall
# back to 1 core and clamp the result to at least one thread.
cpu_threads  = max(1, (os.cpu_count() or 1) // 2)

# NOTE(review): "pho_base_ct2" is presumably a local CTranslate2 conversion of
# PhoWhisper-base — confirm the directory exists next to the notebook.
asr_model = WhisperModel(
    "pho_base_ct2",
    device="cpu",
    compute_type="int8",
    cpu_threads=cpu_threads,   # same value that is logged below
)
logging.info(
    f"faster-whisper ready | threads={cpu_threads} | "
    f"load_time={perf_counter() - t0:.1f}s"
)

In [None]:
def _build_transcribe_kwargs(model):
    sig = inspect.signature(model.transcribe)
    params = sig.parameters
    kw = dict(beam_size=5, vad_filter=True)

    # tên độ dài đoạn
    if "chunk_length" in params:
        kw["chunk_length"] = 20
    elif "chunk_length_s" in params:
        kw["chunk_length_s"] = 20

    # tên độ chồng
    if "chunk_overlap" in params:
        kw["chunk_overlap"] = 5
    elif "chunk_overlap_s" in params:
        kw["chunk_overlap_s"] = 5
    elif "stride_length" in params:
        kw["stride_length"] = 5
    elif "stride" in params:           # bản 1.1.x
        kw["stride"] = (5, 5)

    return kw

# Probe the loaded ASR model once; the resulting kwargs are reused for every
# transcription call.
TRANS_KWARGS = _build_transcribe_kwargs(asr_model)
logging.info(f"Transcribe kwargs: {TRANS_KWARGS}")

# ──────────────────────── PUNCTUATION (PyTorch) ─────────────────────────────
logging.info("Loading punctuation model (BARTpho-word) …")
t0 = perf_counter()
# Tokenizer and seq2seq weights both come from the same checkpoint name.
punc_tok = AutoTokenizer.from_pretrained(PUNC_MODEL, use_fast=True)
punc_model = AutoModelForSeq2SeqLM.from_pretrained(PUNC_MODEL)
punc_model = punc_model.eval()       # switch to evaluation mode
punc_model = punc_model.to("cpu")    # keep the weights on the CPU
logging.info(f"Punctuation ready in {perf_counter()-t0:.1f}s")

@torch.inference_mode()
def restore_punctuation(text: str) -> str:
    """Run ``text`` through the seq2seq model and upper-case its first letter.

    Args:
        text: Raw transcript to post-process.

    Returns:
        The decoded model output with only its first character upper-cased
        (the rest of the string keeps the casing produced by the model), or
        ``""`` when ``text`` is empty or whitespace-only.

    NOTE(review): PUNC_MODEL looks like a generic pretrained checkpoint —
    confirm it has actually been fine-tuned for punctuation restoration.
    """
    # Nothing to restore. Returning early also avoids asking the model to
    # generate text from an input that contains special tokens only.
    if not text or not text.strip():
        return ""

    enc = punc_tok(text, return_tensors="pt")
    # Dropped before generate(); presumably the model does not accept it.
    enc.pop("token_type_ids", None)
    out = punc_model.generate(
        **enc,
        max_length=enc["input_ids"].shape[1] + 16,  # input length + small margin
        do_sample=False,                            # no sampling
    )
    decoded = punc_tok.decode(out[0], skip_special_tokens=True)
    # str.capitalize() would lower-case everything after the first character
    # (e.g. proper nouns), so only the leading character is upper-cased.
    return decoded[:1].upper() + decoded[1:]

# ───────────────────────────── ASR WRAPPER ──────────────────────────────────
def transcribe(path: str) -> str:
    """Transcribe one audio file and return its text as a single string.

    Args:
        path: Location of the audio file handed to ``asr_model.transcribe``.

    Returns:
        The stripped text of every returned segment, joined by single spaces
        (an empty string when there are no segments).
    """
    segments, _info = asr_model.transcribe(path, **TRANS_KWARGS)
    pieces = [segment.text.strip() for segment in segments]
    return " ".join(pieces)

# ───────────────────────────── LOAD & RESAMPLE ──────────────────────────────
logging.info("Scanning raw WAV files …")
items = []
# glob returns files in a filesystem-dependent order; sorting keeps the
# TEST_RUN subset and the order of the output records reproducible.
for wav in sorted(glob.glob(os.path.join(RAW_ROOT, "**", "*.wav"), recursive=True)):
    rel = Path(wav).relative_to(RAW_ROOT)
    # The first two path components are read as <emotion>/<speaker>; a file
    # sitting directly in RAW_ROOT has no such components and is skipped.
    # NOTE(review): assumes the layout <emotion>/<speaker>/…/<file>.wav — for
    # a file directly inside an emotion folder, parts[1] is the file name.
    if len(rel.parts) < 2:
        logging.warning(f"Skipping {wav}: not inside an emotion/speaker folder")
        continue
    emotion, speaker = rel.parts[0].lower(), rel.parts[1]
    utt_id = re.sub(r"\s+", "_", rel.stem)
    items.append({"wav_original": wav, "wav_id": utt_id,
                  "speaker_id": speaker, "emotion": emotion})

# wav_id doubles as the output file name. When the same stem occurs in several
# folders, the later file would overwrite the earlier one, so colliding
# entries get an id built from their full relative path instead. Unique stems
# keep their original id.
stem_counts = {}
for itm in items:
    stem_counts[itm["wav_id"]] = stem_counts.get(itm["wav_id"], 0) + 1
for itm in items:
    if stem_counts[itm["wav_id"]] > 1:
        rel = Path(itm["wav_original"]).relative_to(RAW_ROOT).with_suffix("")
        itm["wav_id"] = re.sub(r"\s+", "_", "__".join(rel.parts))

if TEST_RUN:
    items = items[:TEST_COUNT]
    logging.info(f"TEST_RUN=True → {len(items)} files only")

logging.info("Resampling to 16 kHz mono …")
for itm in tqdm(items, desc="Resample"):
    # always_2d gives a 2-D array even for mono input, so the channel average
    # below works for any channel count.
    y, sr = sf.read(itm["wav_original"], dtype="float32", always_2d=True)
    y = np.mean(y, axis=1)

    if sr != 16000:
        y = librosa.resample(y, orig_sr=sr, target_sr=16000)

    # Scale down only when the peak exceeds 1.0; quieter signals are untouched.
    p = np.abs(y).max() + 1e-9
    if p > 1.0:
        y = y / p

    dst = Path(PROC_WAV_DIR) / f"{itm['wav_id']}.wav"
    sf.write(dst, y, 16000, subtype="PCM_16")

    # Record where the converted file lives and how long it is in seconds.
    itm["wav_path"] = str(dst)
    itm["duration"] = len(y) / 16000

# ────────────────────────── STT + PUNCTUATION ───────────────────────────────
logging.info("Running ASR + punctuation …")
metadata = []
for itm in tqdm(items, desc="STT"):
    # Two stages per file: raw ASR text first, then the punctuation pass.
    raw = transcribe(itm["wav_path"])
    text = restore_punctuation(raw)

    # Each file is stored as one utterance spanning its full duration.
    metadata.append(
        {
            "utterance_id": itm["wav_id"],
            "speaker_id": itm["speaker_id"],
            "wav_path": itm["wav_path"],
            "start": 0.0,
            "end": itm["duration"],
            "transcript": text,
            "emotion": itm["emotion"],
        }
    )
    # Force a garbage-collection pass after every file.
    gc.collect()


In [None]:
logging.info("Saving JSONL & splitting …")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Full manifest: one JSON object per line, non-ASCII text kept readable.
with open(Path(OUTPUT_DIR)/"vnemos_all.jsonl", "w", encoding="utf-8") as f:
    for row in metadata:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

# Speaker-level split, so no speaker appears in more than one subset.
# sorted(): the iteration order of a set of strings changes between
# interpreter runs, so shuffling list(set) is not reproducible even when
# seeded. A dedicated Random(RANDOM_SEED) also makes re-running this cell
# give the same split, regardless of earlier use of the global RNG.
spk = sorted({m["speaker_id"] for m in metadata})
random.Random(RANDOM_SEED).shuffle(spk)

# At least one speaker is requested for validation and for test.
val_n  = max(1, int(len(spk) * VALID_RATIO))
test_n = max(1, int(len(spk) * TEST_RATIO))
val_spk   = set(spk[:val_n])
test_spk  = set(spk[val_n:val_n + test_n])
train_spk = set(spk[val_n + test_n:])

for name, grp in {"train": train_spk, "valid": val_spk, "test": test_spk}.items():
    # With very few speakers a split can end up empty; say so explicitly
    # instead of silently writing an empty file.
    if not grp:
        logging.warning(f"{name} split has no speakers ({len(spk)} speakers in total)")
    out = Path(OUTPUT_DIR)/f"{name}.jsonl"
    cnt = 0
    with open(out, "w", encoding="utf-8") as f:
        for rec in metadata:
            if rec["speaker_id"] in grp:
                f.write(json.dumps(rec, ensure_ascii=False) + "\n")
                cnt += 1
    logging.info(f"{name:5s}: {cnt:4d} utt | {len(grp)} spk")

logging.info("Done ✔  All outputs in ./output/")