# ENV & IMPORT

In [5]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""       

import glob, json, random, logging, gc, re
from pathlib import Path
from time import perf_counter
import inspect

import numpy as np
import soundfile as sf
import librosa
from tqdm import tqdm

from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# CONFIG & SEED

In [6]:
RAW_ROOT      = "data/VNEMOS"      
PROC_WAV_DIR  = "wavs16k"
OUTPUT_DIR    = "output"

LOCAL_MODEL   = "vinai/PhoWhisper-base"
PUNC_MODEL    = "vinai/bartpho-word-base"

VALID_RATIO   = 0.10                
TEST_RATIO    = 0.10               
RANDOM_SEED   = 42


logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)
Path(PROC_WAV_DIR).mkdir(parents=True, exist_ok=True)
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# LOAD ASR

In [7]:
logging.info("Loading faster-whisper (PhoWhisper-base, INT8-CPU)…")
t0 = perf_counter()
asr_model = WhisperModel(
    "pho_base_ct2",        
    device="cpu",
    compute_type="int8",
    cpu_threads=os.cpu_count() // 2,
)
logging.info(f"faster-whisper ready in {perf_counter() - t0:.1f}s")

def _build_transcribe_kwargs(model):
    sig = inspect.signature(model.transcribe)
    params = sig.parameters
    kw = dict(beam_size=5, vad_filter=True)

    if "chunk_length"     in params: kw["chunk_length"]     = 20
    elif "chunk_length_s" in params: kw["chunk_length_s"]   = 20

    if   "chunk_overlap"   in params: kw["chunk_overlap"]   = 5
    elif "chunk_overlap_s" in params: kw["chunk_overlap_s"] = 5
    elif "stride_length"   in params: kw["stride_length"]   = 5
    elif "stride"          in params: kw["stride"]          = (5, 5)

    return kw

TRANS_KWARGS = _build_transcribe_kwargs(asr_model)
logging.info(f"Transcribe kwargs: {TRANS_KWARGS}")


14:02:12 | INFO | Loading faster-whisper (PhoWhisper-base, INT8-CPU)…
14:02:13 | INFO | faster-whisper ready in 1.3s
14:02:13 | INFO | Transcribe kwargs: {'beam_size': 5, 'vad_filter': True, 'chunk_length': 20}


# PUNCTUATION

In [8]:
logging.info("Loading punctuation model (BARTpho-word) …")
t0 = perf_counter()
punc_tok   = AutoTokenizer.from_pretrained(PUNC_MODEL, use_fast=True)
punc_model = AutoModelForSeq2SeqLM.from_pretrained(PUNC_MODEL).eval().to("cpu")
logging.info(f"Punctuation ready in {perf_counter() - t0:.1f}s")

@torch.inference_mode()
def restore_punctuation(text: str) -> str:
    enc = punc_tok(text, return_tensors="pt")
    enc.pop("token_type_ids", None)        # Bỏ khóa thừa
    out = punc_model.generate(
        **enc,
        max_length=enc["input_ids"].shape[1] + 16,
        do_sample=False,
    )
    return punc_tok.decode(out[0], skip_special_tokens=True).capitalize()


def transcribe(path: str) -> str:
    segments, _ = asr_model.transcribe(path, **TRANS_KWARGS)
    return " ".join(seg.text.strip() for seg in segments)


logging.info("Scanning WAV files …")
items = []
for wav in glob.glob(os.path.join(RAW_ROOT, "**", "*.wav"), recursive=True):
    rel = Path(wav).relative_to(RAW_ROOT)
    emotion, speaker = rel.parts[0].lower(), rel.parts[1]
    utt_id = re.sub(r"\s+", "_", rel.stem)
    items.append({"wav_original": wav, "wav_id": utt_id,
                  "speaker_id": speaker, "emotion": emotion})

logging.info(f"Found {len(items)} WAV files")

logging.info("Resampling to 16 kHz mono …")
for itm in tqdm(items, desc="Resample"):
    y, sr = sf.read(itm["wav_original"], dtype="float32", always_2d=True)
    y = y.mean(axis=1)
    if sr != 16000:
        y = librosa.resample(y, orig_sr=sr, target_sr=16000)
    if (p := np.max(np.abs(y)) + 1e-9) > 1.0:
        y = y / p
    dst = Path(PROC_WAV_DIR) / f"{itm['wav_id']}.wav"
    sf.write(dst, y, 16000, subtype="PCM_16")
    itm.update({"wav_path": str(dst), "duration": len(y) / 16000})

14:02:13 | INFO | Loading punctuation model (BARTpho-word) …
2025-07-31 14:02:21.914815: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1753970541.964767   33687 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753970541.980347   33687 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753970542.091742   33687 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753970542.091836   33687 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753970542.091

#  STT + PUNCTUATION PIPELINE 

In [9]:
logging.info("Running ASR + punctuation …")
metadata = []
for itm in tqdm(items, desc="STT"):
    raw  = transcribe(itm["wav_path"])
    text = restore_punctuation(raw)
    metadata.append({
        "utterance_id": itm["wav_id"],
        "speaker_id" : itm["speaker_id"],
        "wav_path"   : itm["wav_path"],
        "start"      : 0.0,
        "end"        : itm["duration"],
        "transcript" : text,
        "emotion"    : itm["emotion"],
    })
    gc.collect()


14:04:17 | INFO | Running ASR + punctuation …
STT:   0%|          | 0/250 [00:00<?, ?it/s]14:04:18 | INFO | Processing audio with duration 00:04.783
14:04:18 | INFO | VAD filter removed 00:00.272 of audio
14:04:19 | INFO | Detected language 'vi' with probability 1.00
STT:   0%|          | 1/250 [00:03<12:45,  3.07s/it]14:04:21 | INFO | Processing audio with duration 00:02.461
14:04:21 | INFO | VAD filter removed 00:00.000 of audio
14:04:21 | INFO | Detected language 'vi' with probability 1.00
STT:   1%|          | 2/250 [00:04<09:25,  2.28s/it]14:04:22 | INFO | Processing audio with duration 00:10.403
14:04:22 | INFO | VAD filter removed 00:01.552 of audio
14:04:23 | INFO | Detected language 'vi' with probability 1.00
STT:   1%|          | 3/250 [00:07<10:04,  2.45s/it]14:04:25 | INFO | Processing audio with duration 00:08.824
14:04:25 | INFO | VAD filter removed 00:00.000 of audio
14:04:25 | INFO | Detected language 'vi' with probability 1.00
STT:   2%|▏         | 4/250 [00:09<09:41, 

# SAVE JSONL & SPLIT

In [10]:
logging.info("Saving JSONL & creating splits …")
with open(Path(OUTPUT_DIR) / "vnemos_all.jsonl", "w", encoding="utf-8") as f:
    for row in metadata:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

speakers = list({m["speaker_id"] for m in metadata}); random.shuffle(speakers)
val_n  = max(1, int(len(speakers) * VALID_RATIO))
test_n = max(1, int(len(speakers) * TEST_RATIO))

val_spk   = set(speakers[:val_n])
test_spk  = set(speakers[val_n : val_n + test_n])
train_spk = set(speakers[val_n + test_n :])

for name, grp in {"train": train_spk, "valid": val_spk, "test": test_spk}.items():
    out = Path(OUTPUT_DIR) / f"{name}.jsonl"; cnt = 0
    with open(out, "w", encoding="utf-8") as f:
        for rec in metadata:
            if rec["speaker_id"] in grp:
                f.write(json.dumps(rec, ensure_ascii=False) + "\n"); cnt += 1
    logging.info(f"{name:5s}: {cnt:6d} utt | {len(grp)} speakers")

logging.info(" Pipeline finished ")

14:13:59 | INFO | Saving JSONL & creating splits …
14:13:59 | INFO | train:    200 utt | 200 speakers
14:13:59 | INFO | valid:     25 utt | 25 speakers
14:13:59 | INFO | test :     25 utt | 25 speakers
14:13:59 | INFO |  Pipeline finished 
