In [48]:

import os, sys, platform, torch, numpy as np

print("CWD:", os.getcwd())
print("Python:", platform.python_version())
print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())


CWD: /mnt/d/MER/src
Python: 3.12.3
Torch: 2.7.1+cu126 | CUDA available: True


In [49]:

from configs.base import Config


import importlib
import loading.dataloader as data_loader
importlib.reload(data_loader)  

from model.networks import MER   
from model.losses import CrossEntropyLoss

from training.trainer import TorchTrainer
from training.callbacks import CheckpointsCallback
from training.optimizers import split_param_groups, build_optimizer


In [None]:
# %%
from pathlib import Path

cfg = Config(
    name="MER_VNEMOS_PhoBERT_W2V2",
    checkpoint_dir="../checkpoints/mer_vnemos",
    num_epochs=2,
    batch_size=8,
    learning_rate=2e-5,
    optimizer_type="AdamW",
    save_best_val=True,
    max_to_keep=2,
    num_workers=2,


    data_root="../output",
    jsonl_dir="",


    sample_rate=16000,
    max_audio_sec=6.0,
    text_max_length=64,

    # Model
    model_type="MemoCMT",
    text_encoder_type="phobert",
    text_encoder_ckpt="vinai/phobert-base",
    text_encoder_dim=768,
    text_unfreeze=False,

    audio_encoder_type="wav2vec2_xlsr",
    audio_encoder_ckpt="facebook/wav2vec2-large-xlsr-53",
    audio_encoder_dim=1024,
    audio_unfreeze=False,

    fusion_dim=768,
    fusion_head_output_type="mean",
    linear_layer_output=[128],

    use_amp=True,
    max_grad_norm=1.0,
)

FRACTION = 0.1
SEED = 42
STRATIFIED = True

print("Checkpoint dir:", Path(cfg.checkpoint_dir).resolve())
base_dir = (Path(cfg.data_root) / (cfg.jsonl_dir or "")).resolve()
print("base_dir:", base_dir)
print("audio_root (auto):", Path(getattr(cfg, "audio_root", base_dir.parent)).resolve())
for fn in ["train.jsonl","valid.jsonl","test.jsonl"]:
    print(fn, (base_dir/fn).exists())


Checkpoint dir: /mnt/d/MER/checkpoints/mer_vnemos
base_dir: /mnt/d/MER/output
audio_root (auto): /mnt/d/MER
train.jsonl True
valid.jsonl True
test.jsonl True


In [51]:

from torch.utils.data import Subset, RandomSampler

def _subset_indices_random(n, fraction, seed=42):
    k = max(1, int(round(n * fraction)))
    rng = np.random.RandomState(seed)
    idx = np.arange(n)
    rng.shuffle(idx)
    return idx[:k].tolist()

def _subset_indices_stratified_vnemos(ds, fraction, seed=42):

    label2inds = {}
    for i, ex in enumerate(ds.items):
        y = ds.label2id[ex["emotion"]]
        label2inds.setdefault(y, []).append(i)
    rng = np.random.RandomState(seed)
    out = []
    for y, inds in label2inds.items():
        k = max(1, int(round(len(inds) * fraction)))
        rng.shuffle(inds)
        out.extend(inds[:k])
    rng.shuffle(out)
    return out

def _subset_loader(loader, fraction=1.0, seed=42, stratified=False):
    if fraction >= 0.999:
        return loader
    ds = loader.dataset
    if stratified and hasattr(ds, "items") and hasattr(ds, "label2id"):
        indices = _subset_indices_stratified_vnemos(ds, fraction, seed)
    else:
        indices = _subset_indices_random(len(ds), fraction, seed)
    subset = Subset(ds, indices)

    shuffle = isinstance(loader.sampler, RandomSampler)
    return type(loader)(
        subset,
        batch_size=loader.batch_size,
        shuffle=shuffle,
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
    )


train_loader, eval_loader = data_loader.build_train_test_dataset(cfg)


train_ds = train_loader.dataset
if hasattr(train_ds, "label2id"):
    cfg.num_classes = len(train_ds.label2id)
    id2label = [k for k, v in sorted(train_ds.label2id.items(), key=lambda x: x[1])]
else:
    cfg.num_classes = 4
    id2label = list(range(cfg.num_classes))


train_loader = _subset_loader(train_loader, fraction=FRACTION, seed=SEED, stratified=STRATIFIED)
if eval_loader is not None:
    eval_loader = _subset_loader(eval_loader, fraction=FRACTION, seed=SEED, stratified=STRATIFIED)

print("Classes:", cfg.num_classes, id2label)
print(f"Train batches: {len(train_loader)} | Eval batches: {len(eval_loader) if eval_loader else 0}")


Classes: 5 ['angry', 'fear', 'happiness', 'neutral', 'sadness']
Train batches: 3 | Eval batches: 1


In [None]:

from transformers import BatchEncoding

device = "cuda" if torch.cuda.is_available() else "cpu"
network = MER(cfg, device=device)
criterion = CrossEntropyLoss(cfg)

class MERTrainer(TorchTrainer):
    def __init__(self, cfg, network, criterion, **kwargs):
        super().__init__(**kwargs)
        self.cfg = cfg
        self.network = network
        self.criterion = criterion
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network.to(self.device)
        self.use_amp = bool(getattr(cfg, "use_amp", torch.cuda.is_available()))
        self.max_grad_norm = float(getattr(cfg, "max_grad_norm", 0.0))
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def _to_device_text(self, x):

        if isinstance(x, BatchEncoding):
            x = x.to(self.device)
            return {k: v for k, v in x.items()}
        if isinstance(x, dict):
            return {k: v.to(self.device) for k, v in x.items()}

        return x.to(self.device)

    def _standardize_batch(self, batch):
        """
        Chuẩn hoá về (tok_dict, audio_tensor, labels_tensor).
        Hỗ trợ:
          - (tok, audio, labels)
          - ((tok, audio, labels), meta)
          - {"text": tok, "audio": audio, "label": labels}
          - (tok_dict, (audio, labels))
        """
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            a, b = batch
            if isinstance(a, (tuple, list)) and len(a) == 3:
                return a
            if isinstance(b, (tuple, list)) and len(b) == 3:
                return b
            if isinstance(a, (dict, BatchEncoding)) and isinstance(b, (tuple, list)) and len(b) == 2:
                tok = a
                audio, labels = b
                return tok, audio, labels

        if isinstance(batch, (tuple, list)) and len(batch) == 3:
            return batch

        if isinstance(batch, dict):
            tok = batch.get("text"); audio = batch.get("audio"); labels = batch.get("label")
            if tok is not None and audio is not None and labels is not None:
                return tok, audio, labels

        raise ValueError(f"Batch structure không hỗ trợ: preview={str(batch)[:200]}")

    def train_step(self, batch):
        self.network.train()
        self.optimizer.zero_grad(set_to_none=True)

        input_text, input_audio, labels = self._standardize_batch(batch)
        input_audio = input_audio.to(self.device)
        labels = labels.to(self.device)
        input_text = self._to_device_text(input_text)

        if self.use_amp:
            with torch.cuda.amp.autocast():
                out = self.network(input_text, input_audio)
                logits = out[0] if isinstance(out, (tuple, list)) else out
                loss = self.criterion(out, labels) if isinstance(out, (tuple, list)) else self.criterion(logits, labels)
            self.scaler.scale(loss).backward()
            if self.max_grad_norm and self.max_grad_norm > 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            out = self.network(input_text, input_audio)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            loss = self.criterion(out, labels) if isinstance(out, (tuple, list)) else self.criterion(logits, labels)
            loss.backward()
            if self.max_grad_norm and self.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
            self.optimizer.step()

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        return {"loss": float(loss.detach().cpu()), "acc": float(acc.detach().cpu())}

    def test_step(self, batch):
        self.network.eval()
        input_text, input_audio, labels = self._standardize_batch(batch)
        input_audio = input_audio.to(self.device)
        labels = labels.to(self.device)
        input_text = self._to_device_text(input_text)
        with torch.no_grad():
            out = self.network(input_text, input_audio)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            loss = self.criterion(out, labels) if isinstance(out, (tuple, list)) else self.criterion(logits, labels)
            preds = torch.argmax(logits, dim=1)
            acc = (preds == labels).float().mean()
        return {"val_loss": float(loss.detach().cpu()), "val_acc": float(acc.detach().cpu())}

trainer = MERTrainer(cfg, network, criterion, log_dir="logs")
print("Trainer ready. Device:", trainer.device)



Trainer ready. Device: cuda


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)


In [53]:

from torch.optim.lr_scheduler import CosineAnnealingLR

param_groups = split_param_groups(trainer, lr_enc=cfg.learning_rate * 0.25,
                                  lr_head=cfg.learning_rate, weight_decay=0.01)
optimizer = build_optimizer("adamw", param_groups, lr=cfg.learning_rate, weight_decay=0.01)

scheduler = CosineAnnealingLR(optimizer, T_max=max(1, cfg.num_epochs))
trainer.compile(optimizer=optimizer, scheduler=scheduler, lr=cfg.learning_rate, param_groups=None)
print("Param groups:", len(optimizer.param_groups))


Param groups: 1


In [54]:

ckpt_cb = CheckpointsCallback(cfg.checkpoint_dir, save_freq=200, max_to_keep=2,
                              save_best_val=True, monitor="val_loss", mode="min")
trainer.fit(train_loader, epochs=cfg.num_epochs, eval_data=eval_loader, callbacks=[ckpt_cb])
print("Best checkpoint:", getattr(ckpt_cb, "best_path", ""))


Epoch 1/2
2025-08-09 13:52:15,367 - Training - INFO - Epoch 1/2
Epoch 1:   0%|          | 0/3 [00:00<?, ?it/s]Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space

Best checkpoint: ../checkpoints/mer_vnemos/best_val_loss/checkpoint_0.pth


In [56]:
# %%  (EVALUATION — sửa để tái dùng chuẩn hoá batch của trainer)

import torch
import numpy as np

def collect_preds(trainer, loader):
    all_preds, all_labels = [], []
    trainer.network.eval()
    with torch.no_grad():
        for batch in loader:
            # DÙ batch ở dạng gì, chuẩn hoá về (tok_dict, audio, labels)
            input_text, input_audio, labels = trainer._standardize_batch(batch)
            input_audio = input_audio.to(trainer.device)
            labels = labels.to(trainer.device)
            input_text = trainer._to_device_text(input_text)

            out = trainer.network(input_text, input_audio)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    return torch.cat(all_preds), torch.cat(all_labels)

if eval_loader is not None:
    preds, labels = collect_preds(trainer, eval_loader)
    num_classes = cfg.num_classes

    # Confusion matrix + metrics
    cm = torch.zeros((num_classes, num_classes), dtype=torch.int32)
    for t, p in zip(labels, preds):
        cm[t.long(), p.long()] += 1

    per_class_acc = (cm.diag().float() / cm.sum(dim=1).clamp(min=1).float()).numpy()
    overall_acc = (preds == labels).float().mean().item()

    print("Overall Acc: %.4f" % overall_acc)
    print("Per-class Acc:")
    for i, acc in enumerate(per_class_acc):
        name = str(i)
        try:
            name = id2label[i]
        except Exception:
            pass
        print(f"  {i} ({name}): {acc:.4f}")
    print("Confusion Matrix:\n", cm.numpy())

    # Classification report (nếu sklearn có sẵn)
    try:
        from sklearn.metrics import classification_report
        target_names = [str(i) for i in range(num_classes)]
        try:
            target_names = [id2label[i] for i in range(num_classes)]
        except Exception:
            pass
        print("\nClassification Report:")
        print(classification_report(labels.numpy(), preds.numpy(), target_names=target_names, digits=4))
    except Exception:
        pass
else:
    print("No eval_loader available.")


Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.


Overall Acc: 0.2000
Per-class Acc:
  0 (angry): 1.0000
  1 (fear): 0.0000
  2 (happiness): 0.0000
  3 (neutral): 0.0000
  4 (sadness): 0.0000
Confusion Matrix:
 [[1 0 0 0 0]
 [1 0 0 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 0 0 1 0]]

Classification Report:
              precision    recall  f1-score   support

       angry     0.5000    1.0000    0.6667         1
        fear     0.0000    0.0000    0.0000         1
   happiness     0.0000    0.0000    0.0000         1
     neutral     0.0000    0.0000    0.0000         1
     sadness     0.0000    0.0000    0.0000         1

    accuracy                         0.2000         5
   macro avg     0.1000    0.2000    0.1333         5
weighted avg     0.1000    0.2000    0.1333         5



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
