In [1]:

import os, sys, platform, torch, numpy as np

# Environment sanity check: print the working directory (the relative paths used
# further down, e.g. "../output", are resolved against it) and the
# Python / PyTorch versions together with CUDA availability.
print("CWD:", os.getcwd())
print("Python:", platform.python_version())
print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())


CWD: /mnt/d/MER/src
Python: 3.12.3
Torch: 2.7.1+cu126 | CUDA available: True


In [2]:
from configs.base import Config

import importlib
import loading.dataloader as data_loader
importlib.reload(data_loader) 

from model.networks import MER 
from model.losses import CrossEntropyLoss

from training.trainer import TorchTrainer
from training.callbacks import CheckpointsCallback
from training.optimizers import split_param_groups, build_optimizer

2025-08-22 14:57:18.320741: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755874639.637987  158299 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755874639.987095  158299 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755874643.107186  158299 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755874643.107215  158299 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755874643.107217  158299 computation_placer.cc:177] computation placer alr

In [3]:
from pathlib import Path

# Experiment configuration for this run. Every value below is handed to the
# project's Config object; later cells read them back as attributes
# (cfg.learning_rate, cfg.num_epochs, cfg.checkpoint_dir, ...).
cfg = Config(
    name="MER_VNEMOS_PhoBERT_W2V2_A_lengthbucket",
    checkpoint_dir="../checkpoints/mer_vnemos_A",
    num_epochs=2,
    batch_size=8,
    learning_rate=2e-5,
    optimizer_type="AdamW",
    save_best_val=True,
    max_to_keep=2,
    num_workers=2,

    # Data
    data_root="../output",
    jsonl_dir="",
    sample_rate=16000,
    # NOTE(review): the audit cell at the end of the notebook states that None
    # means "no hard truncation" — confirm against the dataset implementation.
    max_audio_sec=None,      
    text_max_length=64,

    # Length-based sampler (Option A)
    use_length_bucket=True,   
    length_bucket_size=64,

    # Model
    model_type="MemoCMT",
    text_encoder_type="phobert",
    text_encoder_ckpt="vinai/phobert-base",
    text_encoder_dim=768,
    text_unfreeze=False,

    audio_encoder_type="wav2vec2_xlsr",
    audio_encoder_ckpt="facebook/wav2vec2-large-xlsr-53",
    audio_encoder_dim=1024,
    audio_unfreeze=False,

    fusion_dim=768,
    fusion_head_output_type="mean",
    linear_layer_output=[128],

    use_amp=True,
    max_grad_norm=1.0,
)

# Display-only preview of the directories implied by the config:
#   base_dir         = data_root / jsonl_dir (jsonl_dir may be empty)
#   audio_root_print = cfg.audio_root when set and non-empty, else base_dir's parent
# NOTE(review): presumably mirrors how the data loader resolves paths — confirm.
base_dir = (Path(cfg.data_root) / (cfg.jsonl_dir or "")).resolve()
audio_root_print = Path(getattr(cfg, "audio_root", "") or base_dir.parent).resolve()
print("base_dir:", base_dir)
print("audio_root (auto):", audio_root_print)

base_dir: /mnt/d/MER/output
audio_root (auto): /mnt/d/MER


In [4]:
# Build the train / eval loaders, then derive the class count and the
# index -> label-name list from the training dataset's label map.
train_loader, eval_loader = data_loader.build_train_test_dataset(cfg)
train_ds = train_loader.dataset

if not hasattr(train_ds, "label2id"):
    # No label map exposed by the dataset: fall back to 4 anonymous classes.
    cfg.num_classes = 4
    id2label = list(range(cfg.num_classes))
else:
    cfg.num_classes = len(train_ds.label2id)
    # Label names ordered by their integer id.
    id2label = sorted(train_ds.label2id, key=train_ds.label2id.get)

print("Classes:", cfg.num_classes, id2label)
n_eval_batches = len(eval_loader) if eval_loader else 0
print(f"Train batches: {len(train_loader)} | Eval batches: {n_eval_batches}")


Classes: 5 ['angry', 'fear', 'happiness', 'neutral', 'sadness']
Train batches: 25 | Eval batches: 4


In [None]:
from transformers import BatchEncoding

# Use the GPU when one is visible, then build the model and the loss from the
# shared config.
# NOTE(review): cfg.num_classes is assigned in the data-loading cell; presumably
# MER and/or CrossEntropyLoss read it, so that cell must run first — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"
network = MER(cfg, device=device)
criterion = CrossEntropyLoss(cfg)

class MERTrainer(TorchTrainer):
    """Trainer for the multimodal (text + audio) emotion-recognition network.

    On top of ``TorchTrainer`` this class normalises the batch layout, moves the
    inputs to the training device, and runs the optimisation step with optional
    CUDA mixed precision (autocast + gradient scaling) and gradient clipping.

    Args:
        cfg: Experiment configuration. ``use_amp`` and ``max_grad_norm`` are read
            from it when present.
        network: Model invoked as ``network(text, audio, meta=meta)``. It may
            return the logits directly or a tuple/list whose first element is
            the logits.
        criterion: Loss invoked as ``criterion(output, labels)`` with the raw
            network output.
        **kwargs: Forwarded unchanged to ``TorchTrainer.__init__``.
    """

    def __init__(self, cfg, network, criterion, **kwargs):
        super().__init__(**kwargs)
        self.cfg = cfg
        self.network = network
        self.criterion = criterion
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network.to(self.device)
        # Mixed precision is only used on CUDA here: requesting it on a CPU-only
        # machine silently falls back to full precision instead of emitting
        # warnings from the CUDA-specific AMP helpers.
        amp_requested = bool(getattr(cfg, "use_amp", torch.cuda.is_available()))
        self.use_amp = amp_requested and self.device.type == "cuda"
        # ``or 0.0`` tolerates an explicit ``max_grad_norm=None`` (clipping off).
        self.max_grad_norm = float(getattr(cfg, "max_grad_norm", 0.0) or 0.0)
        # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler.
        self.scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp)

    def _to_device_text(self, x):
        """Move tokenizer output to ``self.device``.

        Accepts a ``BatchEncoding`` or a plain dict (both returned as a plain
        dict) or a single tensor. Dict values without a ``.to`` method are
        passed through untouched.
        """
        if isinstance(x, BatchEncoding):
            return dict(x.to(self.device).items())
        if isinstance(x, dict):
            return {k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in x.items()}
        return x.to(self.device)

    def _standardize_batch(self, batch):
        """Normalise a batch to ``(tok, audio, labels, meta_or_None)``.

        Supported layouts:
          - ``(tok, audio, labels)``
          - ``((tok, audio, labels), meta)``
          - ``(tok, (audio, labels))``  (rare)
          - ``{"text": tok, "audio": audio, "label": labels}``

        Raises:
            ValueError: If the batch matches none of the layouts above.
        """
        if isinstance(batch, (tuple, list)):
            if len(batch) == 2:
                first, second = batch
                # ((tok, audio, labels), meta)
                if isinstance(first, (tuple, list)) and len(first) == 3:
                    tok, audio, labels = first
                    return tok, audio, labels, second
                # (tok, (audio, labels))
                if (
                    isinstance(first, (dict, BatchEncoding))
                    and isinstance(second, (tuple, list))
                    and len(second) == 2
                ):
                    audio, labels = second
                    return first, audio, labels, None
            elif len(batch) == 3:
                # (tok, audio, labels)
                tok, audio, labels = batch
                return tok, audio, labels, None

        if isinstance(batch, dict):
            tok = batch.get("text")
            audio = batch.get("audio")
            labels = batch.get("label")
            if tok is not None and audio is not None and labels is not None:
                return tok, audio, labels, None

        raise ValueError(f"Batch structure không hỗ trợ: preview={str(batch)[:200]}")

    def _mer_forward(self, batch, autocast):
        """Run one forward pass and return ``(logits, loss, labels)``.

        Shared by ``train_step`` and ``test_step`` so both use exactly the same
        batch handling. ``autocast`` toggles mixed precision for this pass.
        """
        input_text, input_audio, labels, meta = self._standardize_batch(batch)
        input_audio = input_audio.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        input_text = self._to_device_text(input_text)

        with torch.amp.autocast(device_type=self.device.type, enabled=autocast):
            # meta is forwarded so the network can build its audio mask (Option A).
            out = self.network(input_text, input_audio, meta=meta)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            # The criterion always receives the raw network output (tuple or tensor).
            loss = self.criterion(out, labels)
        return logits, loss, labels

    @staticmethod
    def _accuracy(logits, labels):
        """Fraction of samples whose arg-max class (dim 1) equals the label."""
        return (torch.argmax(logits.detach(), dim=1) == labels).float().mean().item()

    def train_step(self, batch):
        """Run one optimisation step and return ``{"loss": float, "acc": float}``."""
        self.network.train()
        self.optimizer.zero_grad(set_to_none=True)

        logits, loss, labels = self._mer_forward(batch, autocast=self.use_amp)

        # A disabled GradScaler is a pass-through (scale returns the loss,
        # unscale_/update do nothing, step calls optimizer.step()), so a single
        # code path serves both the AMP and the full-precision case.
        self.scaler.scale(loss).backward()
        if self.max_grad_norm > 0:
            # Gradients must be unscaled before their norm is clipped.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return {"loss": loss.detach().item(), "acc": self._accuracy(logits, labels)}

    def test_step(self, batch):
        """Evaluate one batch (no gradients, full precision).

        Returns ``{"val_loss": float, "val_acc": float}``.
        """
        self.network.eval()
        with torch.no_grad():
            logits, loss, labels = self._mer_forward(batch, autocast=False)
        return {"val_loss": loss.detach().item(), "val_acc": self._accuracy(logits, labels)}

# Instantiate the trainer. `log_dir` is not consumed by MERTrainer itself; it is
# passed through **kwargs to TorchTrainer.__init__.
trainer = MERTrainer(cfg, network, criterion, log_dir="logs")
print("Trainer ready. Device:", trainer.device)

Trainer ready. Device: cuda


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)


In [6]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# Two learning rates are requested: lr_enc is a quarter of lr_head, both with
# weight decay 0.01.
# NOTE(review): the recorded output shows a single param group — presumably
# because both encoders are frozen (text_unfreeze / audio_unfreeze are False);
# confirm against split_param_groups.
param_groups = split_param_groups(trainer, lr_enc=cfg.learning_rate * 0.25,
                                  lr_head=cfg.learning_rate, weight_decay=0.01)
# NOTE(review): the optimizer name is hard-coded as "adamw" while the config
# carries optimizer_type="AdamW" — confirm whether build_optimizer should read it.
optimizer = build_optimizer("adamw", param_groups, lr=cfg.learning_rate, weight_decay=0.01)

# T_max is expressed in epochs.
# NOTE(review): this assumes the trainer steps the scheduler once per epoch —
# confirm in TorchTrainer.
scheduler = CosineAnnealingLR(optimizer, T_max=max(1, cfg.num_epochs))
# The optimizer and scheduler are handed over ready-made; param_groups is None.
trainer.compile(optimizer=optimizer, scheduler=scheduler, lr=cfg.learning_rate, param_groups=None)
print("Param groups:", len(optimizer.param_groups))


Param groups: 1


In [7]:
# Checkpoint retention and best-model tracking come from the experiment config
# so they cannot drift from it; the fallbacks are the values that were
# previously hard-coded here.
# NOTE(review): save_freq=200 exceeds the 2 x 25 training batches reported above;
# if the frequency is counted in steps no periodic checkpoint is written — confirm.
ckpt_cb = CheckpointsCallback(
    cfg.checkpoint_dir,
    save_freq=200,
    max_to_keep=getattr(cfg, "max_to_keep", 2),
    save_best_val=getattr(cfg, "save_best_val", True),
    monitor="val_loss",
    mode="min",
)
trainer.fit(train_loader, epochs=cfg.num_epochs, eval_data=eval_loader, callbacks=[ckpt_cb])
print("Best checkpoint:", getattr(ckpt_cb, "best_path", ""))

Epoch 1/2
2025-08-22 14:58:46,424 - Training - INFO - Epoch 1/2
Epoch 1:   0%|          | 0/25 [00:00<?, ?it/s]Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_spac

Best checkpoint: ../checkpoints/mer_vnemos_A/best_val_loss/checkpoint_0.pth


Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_

In [8]:
import torch
import numpy as np

def collect_preds(trainer, loader):
    """Run the trainer's network over `loader` and gather predictions and labels.

    The network is switched to eval mode and called without gradients as
    ``trainer.network(tok, audio, meta=meta)``; the predicted class is the
    arg-max over dim 1 of the logits (the first element when the network
    returns a tuple/list).

    Args:
        trainer: Object exposing ``network``, ``device``, ``_standardize_batch``
            and ``_to_device_text``.
        loader: Iterable of batches understood by ``trainer._standardize_batch``.

    Returns:
        ``(preds, labels)`` as CPU tensors concatenated over all batches. An
        empty loader yields two empty ``int64`` tensors instead of raising.
    """
    pred_chunks, label_chunks = [], []
    trainer.network.eval()
    with torch.no_grad():
        for batch in loader:
            tok, audio, labels, meta = trainer._standardize_batch(batch)
            audio = audio.to(trainer.device, non_blocking=True)
            tok = trainer._to_device_text(tok)

            out = trainer.network(tok, audio, meta=meta)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            pred_chunks.append(torch.argmax(logits, dim=1).cpu())
            # Labels are only collected, so they never need to visit the device.
            label_chunks.append(labels.detach().cpu())

    if not pred_chunks:
        # torch.cat([]) raises; return well-typed empty tensors instead.
        return torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)
    return torch.cat(pred_chunks), torch.cat(label_chunks)

# Evaluation summary: overall / per-class accuracy, confusion matrix and, when
# scikit-learn is available, a classification report.
if eval_loader is not None:
    preds, labels = collect_preds(trainer, eval_loader)
    num_classes = cfg.num_classes

    # Display names as strings; fall back to the numeric index for any class
    # that id2label does not cover.
    class_names = []
    for i in range(num_classes):
        try:
            class_names.append(str(id2label[i]))
        except (IndexError, KeyError, TypeError, NameError):
            class_names.append(str(i))

    # Confusion matrix: rows are true classes, columns are predicted classes.
    cm = torch.zeros((num_classes, num_classes), dtype=torch.int32)
    for t, p in zip(labels, preds):
        cm[t.long(), p.long()] += 1

    # clamp(min=1) avoids a division by zero for classes absent from the eval set.
    per_class_acc = (cm.diag().float() / cm.sum(dim=1).clamp(min=1).float()).numpy()
    overall_acc = (preds == labels).float().mean().item()

    print("Overall Acc: %.4f" % overall_acc)
    print("Per-class Acc:")
    for i, acc in enumerate(per_class_acc):
        print(f"  {i} ({class_names[i]}): {acc:.4f}")
    print("Confusion Matrix:\n", cm.numpy())

    # The report stays best-effort, but a failure is now reported instead of
    # being silently discarded.
    try:
        from sklearn.metrics import classification_report
        report = classification_report(
            labels.numpy(),
            preds.numpy(),
            # Pin the label set so target_names stays aligned even when a class
            # does not occur in the evaluation split.
            labels=list(range(num_classes)),
            target_names=class_names,
            digits=4,
            zero_division=0,
        )
    except ImportError:
        print("\nClassification Report skipped: scikit-learn is not installed.")
    except ValueError as exc:
        print(f"\nClassification Report skipped: {exc}")
    else:
        print("\nClassification Report:")
        print(report)
else:
    print("No eval_loader available.")

Overall Acc: 0.4000
Per-class Acc:
  0 (angry): 0.0000
  1 (fear): 0.2000
  2 (happiness): 0.6000
  3 (neutral): 0.2000
  4 (sadness): 0.8333
Confusion Matrix:
 [[0 1 2 1 0]
 [1 1 1 2 0]
 [0 1 3 1 0]
 [2 1 0 1 1]
 [0 1 0 0 5]]

Classification Report:
              precision    recall  f1-score   support

       angry     0.0000    0.0000    0.0000         4
        fear     0.2000    0.2000    0.2000         5
   happiness     0.5000    0.6000    0.5455         5
     neutral     0.2000    0.2000    0.2000         5
     sadness     0.8333    0.8333    0.8333         6

    accuracy                         0.4000        25
   macro avg     0.3467    0.3667    0.3558        25
weighted avg     0.3800    0.4000    0.3891        25



In [9]:
# %%
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
from loading.dataloader import VNEMOSDataset, _clean_text
from configs.base import Config

# Use the same settings as the training config: Option A requires max_audio_sec=None.
cfg_check = Config(
    data_root="../output",
    jsonl_dir="",
    sample_rate=16000,
    max_audio_sec=None,     # VERY IMPORTANT: with None the dataset does not hard-truncate audio
    text_max_length=64,
    use_length_bucket=True,
)

train_set = VNEMOSDataset(cfg_check, "train")  # or "valid" / "test"

def expected_len_after_resample(item, target_sr: int, ds=None):
    """Expected sample count for `item` after start/end trimming and resampling.

    Args:
        item: Manifest record with a "wav_path" key and optional "start" / "end"
            values, which the computation treats as seconds.
        target_sr: Sampling rate the audio is resampled to.
        ds: Dataset whose ``_resolve_wav`` turns "wav_path" into a real path.
            Defaults to the notebook-level ``train_set`` so existing two-argument
            calls keep working.

    Returns:
        ``round(duration_in_seconds * target_sr)`` as an int.
    """
    source = train_set if ds is None else ds
    info = sf.info(str(source._resolve_wav(item["wav_path"])))
    start = float(item.get("start", 0.0) or 0.0)
    end = float(item.get("end", 0.0) or 0.0)
    if end > 0:
        dur_s = max(0.0, end - start)
    else:
        # Without a positive "end" the whole file duration is used.
        # NOTE(review): "start" is ignored in this branch — confirm this matches
        # how the dataset trims such items.
        dur_s = info.frames / float(info.samplerate)
    # After resampling, the sample count is approximately dur_s * target_sr.
    return int(round(dur_s * target_sr))

def audit_audio_full(ds, max_examples=999999, tol_samples=4):
    """Compare actual vs. expected audio length for the items of `ds`.

    Args:
        ds: Dataset exposing ``items``, ``sample_rate``, ``_resolve_wav`` and
            ``__getitem__`` returning a mapping with an "audio" entry that has
            ``numel()``.
        max_examples: Audit only the first `max_examples` items; ``None`` audits
            all of them.
        tol_samples: Tolerance of a few samples for rounding / resampling.

    Returns:
        ``(bad, diffs)`` where ``bad`` lists ``(index, wav_path, expected, actual)``
        for items that look truncated (``actual + tol_samples < expected``) and
        ``diffs`` is an int64 array of ``actual - expected`` for every audited item.
    """
    items = ds.items if max_examples is None else ds.items[:max_examples]
    bad = []
    diffs = []
    for i, it in enumerate(items):
        # Resolve paths through the dataset being audited, not a global one.
        exp_len = expected_len_after_resample(it, ds.sample_rate, ds)
        act_len = int(ds[i]["audio"].numel())  # goes through ds.__getitem__
        diffs.append(act_len - exp_len)
        # Shorter than expected beyond the tolerance: probably truncated.
        if exp_len > 0 and act_len + tol_samples < exp_len:
            bad.append((i, it["wav_path"], exp_len, act_len))
    return bad, np.array(diffs, dtype=np.int64)

# Run the audit on the training split and summarise the length differences.
bad, diffs = audit_audio_full(train_set)

print(f"[Audio] Tổng mẫu kiểm: {len(train_set.items)}")
# min()/max() raise on an empty array, so the statistics line is skipped when
# nothing was audited.
if diffs.size:
    print("[Audio] Sai khác (act-exp) min/mean/max:", diffs.min(), diffs.mean().round(2), diffs.max())
print(f"[Audio] Số mẫu NGHI BỊ CẮT (act << exp): {len(bad)}")
if bad:
    # Show at most the first five suspicious items.
    print("Ví dụ 5 mẫu đầu NGHI BỊ CẮT:")
    for i, path, exp_len, act_len in bad[:5]:
        print(f"  idx={i} | {path} | expected={exp_len} | actual={act_len}")


[Audio] Tổng mẫu kiểm: 200
[Audio] Sai khác (act-exp) min/mean/max: 0 0.0 0
[Audio] Số mẫu NGHI BỊ CẮT (act << exp): 0
