In [3]:
import os, sys, platform, torch, numpy as np
from pathlib import Path

print("CWD:", os.getcwd())
print("Python:", platform.python_version())
print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

# Ensure src is on path (if running from project root)
if (Path.cwd() / "src").exists():
    sys.path.insert(0, str(Path.cwd() / "src"))
elif (Path.cwd().name == "src"):
    sys.path.insert(0, str(Path.cwd()))
print("Sys.path[0:3] =", sys.path[:3])

CWD: /mnt/d/MER/src
Python: 3.12.3
Torch: 2.7.1+cu126 | CUDA available: True
Sys.path[0:3] = ['/mnt/d/MER/src', '/mnt/d/MER/src', '/usr/lib/python312.zip']


In [4]:
from configs.base import Config

# reload custom modules if needed
import importlib
import loading.dataloader as data_loader
importlib.reload(data_loader)

from model.networks import MER   # sử dụng mạng MER
from model.losses import get_loss

from training.trainer import TorchTrainer
from training.callbacks import CheckpointsCallback, GradualUnfreezeCallback
from training.optimizers import split_param_groups, build_optimizer

from transformers import BatchEncoding

  """


In [5]:
cfg = Config(
    name="MER_VNEMOS_optimal",
    checkpoint_dir="../checkpoints/mer_vnemos_optimal",
    num_epochs=45,
    batch_size=8,
    learning_rate=1.5e-5,        # lr cho head (nhóm head); encoder sẽ nhỏ hơn qua param groups
    optimizer_type="AdamW",
    save_best_val=True,
    max_to_keep=2,
    num_workers=2,

    # Data roots
    data_root="../output",
    jsonl_dir="",

    # Audio/Text lengths
    sample_rate=16000,
    max_audio_sec=None,          # KHÔNG cắt cứng (theo Phương án A)
    text_max_length=64,

    # Model
    model_type="MER",
    text_encoder_type="phobert",
    text_encoder_ckpt="vinai/phobert-base",
    text_encoder_dim=768,
    text_unfreeze=False,         # ban đầu đóng băng, sẽ mở dần bằng callback

    audio_encoder_type="wav2vec2_xlsr",
    audio_encoder_ckpt="facebook/wav2vec2-large-xlsr-53",
    audio_encoder_dim=1024,
    audio_unfreeze=False,

    fusion_dim=768,
    fusion_head_output_type="mean",
    linear_layer_output=[128],
    dropout=0.05,

    # Train tricks
    use_amp=True,
    max_grad_norm=1.0,

    # Scheduler (dùng LambdaLR warmup+cosine ở notebook, step-based)
    scheduler_type="cosine_warmup",
    scheduler_step_unit="step",
    warmup_ratio=0.05,

    # Loss: Sample-Weighted CE theo độ dài mẫu
    loss_type="SampleWeightedCE",
    label_smoothing=0.02,            # 0.0 - 0.02
    sample_weight_by_duration=True,
    sample_weight_beta=1.0,
    sample_weight_eps=0.1,

    # Random crop (train-only) để chuẩn hoá độ dài hiệu dụng
    random_crop_sec_train=8.0,

    # Duration-balanced sampler (train-only)
    enable_duration_weight_sampler=True,
    duration_weight_beta=1.0,
    duration_weight_eps=0.1,

    # Augmentation (train-only)
    augment_audio=True,
    aug_speed_prob=0.5, aug_speed_range=[0.92, 1.08],
    aug_timeshift_prob=0.3, aug_timeshift_sec=0.2,
    aug_noise_prob=0.3, aug_noise_snr_db=[20.0, 25.0],
    aug_gain_prob=0.3, aug_gain_db=[-3.0, 3.0],

    # Gradual unfreeze schedule
    gradual_unfreeze_epoch1=5,
    text_unfreeze_last_k1=2,
    audio_unfreeze_last_k1=2,
    gradual_unfreeze_epoch2=12,
    text_unfreeze_last_k2=4,
    audio_unfreeze_last_k2=4,
)

print("Checkpoint dir:", Path(cfg.checkpoint_dir).resolve())
base_dir = (Path(cfg.data_root) / (cfg.jsonl_dir or "")).resolve()
print("base_dir:", base_dir)
print("audio_root (auto):", Path(getattr(cfg, "audio_root", base_dir.parent)).resolve())
for fn in ["train.jsonl","valid.jsonl","test.jsonl"]:
    print(fn, (base_dir/fn).exists())


Checkpoint dir: /mnt/d/MER/checkpoints/mer_vnemos_optimal
base_dir: /mnt/d/MER/output
audio_root (auto): /mnt/d/MER
train.jsonl True
valid.jsonl True
test.jsonl True


In [6]:
train_loader, eval_loader = data_loader.build_train_test_dataset(cfg)

train_ds = train_loader.dataset
if hasattr(train_ds, "label2id"):
    cfg.num_classes = len(train_ds.label2id)
    id2label = [k for k, v in sorted(train_ds.label2id.items(), key=lambda x: x[1])]
else:
    cfg.num_classes = 4
    id2label = list(range(cfg.num_classes))

print("Classes:", cfg.num_classes, id2label)
print(f"Train batches: {len(train_loader)} | Eval batches: {len(eval_loader) if eval_loader else 0}")

Classes: 5 ['angry', 'fear', 'happiness', 'neutral', 'sadness']
Train batches: 25 | Eval batches: 4


In [7]:

device = "cuda" if torch.cuda.is_available() else "cpu"
network = MER(cfg, device=device)


criterion = get_loss(cfg)

class MERTrainer(TorchTrainer):
    def __init__(self, cfg, network, criterion, **kwargs):
        super().__init__(**kwargs)
        self.cfg = cfg
        self.network = network
        self.criterion = criterion
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network.to(self.device)
        self.use_amp = bool(getattr(cfg, "use_amp", torch.cuda.is_available()))
        self.max_grad_norm = float(getattr(cfg, "max_grad_norm", 0.0))
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def _to_device_text(self, x):
        if isinstance(x, BatchEncoding):
            x = x.to(self.device)
            return {k: v for k, v in x.items()}
        if isinstance(x, dict):
            return {k: v.to(self.device) for k, v in x.items()}
        return x.to(self.device)

    def _standardize_batch(self, batch):
        """
        Chuẩn hoá về ((tok, audio, labels), meta) → tok, audio, labels, meta
        """
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            a, meta = batch
            if isinstance(a, (tuple, list)) and len(a) == 3:
                tok, audio, labels = a
                return tok, audio, labels, meta
        raise ValueError(f"Batch structure không hỗ trợ: preview={str(batch)[:200]}")

    def _compute_loss(self, logits, labels, meta):
        # Nếu criterion có hỗ trợ sample_weight (SampleWeightedCE)
        try:
            name = self.criterion.__class__.__name__.lower()
            if "sampleweighted" in name or "sample_weighted" in name:
                sw = meta.get("sample_weight", None)
                return self.criterion(logits, labels, sample_weight=sw)
        except Exception:
            pass
        # fallback CE thường
        if labels.dtype != torch.long:
            labels = labels.long()
        return self.criterion(logits, labels)

    def train_step(self, batch):
        self.network.train()
        self.optimizer.zero_grad(set_to_none=True)

        tok, audio, labels, meta = self._standardize_batch(batch)
        audio  = audio.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        tok    = self._to_device_text(tok)

        if self.use_amp:
            with torch.cuda.amp.autocast():
                out = self.network(tok, audio)
                logits = out[0] if isinstance(out, (tuple, list)) else out
                loss = self._compute_loss(logits, labels, meta)
            self.scaler.scale(loss).backward()
            if self.max_grad_norm and self.max_grad_norm > 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            out = self.network(tok, audio)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            loss = self._compute_loss(logits, labels, meta)
            loss.backward()
            if self.max_grad_norm and self.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
            self.optimizer.step()

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        return {"loss": float(loss.detach().cpu()), "acc": float(acc.detach().cpu())}

    def test_step(self, batch):
        self.network.eval()
        tok, audio, labels, meta = self._standardize_batch(batch)
        audio  = audio.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        tok    = self._to_device_text(tok)
        with torch.no_grad():
            out = self.network(tok, audio)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            loss = self._compute_loss(logits, labels, meta)
            preds = torch.argmax(logits, dim=1)
            acc = (preds == labels).float().mean()
        return {"val_loss": float(loss.detach().cpu()), "val_acc": float(acc.detach().cpu())}

trainer = MERTrainer(cfg, network, criterion, log_dir="logs")
print("Trainer ready. Device:", trainer.device)


Trainer ready. Device: cuda


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)


In [8]:
from math import pi, cos

# Param groups: encoder LR nhỏ hơn head LR
param_groups = split_param_groups(trainer, lr_enc=3e-6, lr_head=cfg.learning_rate, weight_decay=0.01)
optimizer = build_optimizer("adamw", param_groups, lr=cfg.learning_rate, weight_decay=0.01)

# LambdaLR: warmup+cosine theo STEP
total_steps = max(1, len(train_loader) * cfg.num_epochs)
warmup_steps = int(getattr(cfg, "warmup_ratio", 0.05) * total_steps)

def lr_lambda(step):
    if step < warmup_steps:
        return float(step) / max(1, warmup_steps)
    progress = float(step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + cos(pi * progress))

from torch.optim.lr_scheduler import LambdaLR
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

trainer.compile(
    optimizer=optimizer,
    scheduler=scheduler,
    lr=cfg.learning_rate,
    param_groups=None,
    scheduler_step_unit="step",     # step-based scheduler
)
print("Param groups:", len(optimizer.param_groups))

Param groups: 1


In [9]:
ckpt_cb = CheckpointsCallback(
    cfg.checkpoint_dir, save_freq=200, max_to_keep=2,
    save_best_val=True, monitor="val_loss", mode="min"
)

unfreeze_cb = GradualUnfreezeCallback(
    epoch_trigger1=cfg.gradual_unfreeze_epoch1,
    text_last_k1=cfg.text_unfreeze_last_k1, audio_last_k1=cfg.audio_unfreeze_last_k1,
    epoch_trigger2=cfg.gradual_unfreeze_epoch2,
    text_last_k2=cfg.text_unfreeze_last_k2, audio_last_k2=cfg.audio_unfreeze_last_k2,
)




In [10]:
trainer.fit(train_loader, epochs=cfg.num_epochs, eval_data=eval_loader, callbacks=[ckpt_cb, unfreeze_cb])
print("Best checkpoint:", getattr(ckpt_cb, "best_path", ""))

Epoch 1/45
2025-08-29 15:11:41,824 - Training - INFO - Epoch 1/45
Epoch 1:   0%|          | 0/25 [00:00<?, ?it/s]Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_sp

Best checkpoint: ../checkpoints/mer_vnemos_optimal/best_val_loss/checkpoint_0.pth


In [11]:
import torch
import numpy as np

USE_TTMC = True           
TTMC_WIN_SEC = 8.0
TTMC_MIN_TOTAL_SEC = 12.0

def _ttmc_logits_single(trainer, tok_dict, audio_1d, sr=16000, win_sec=8.0):
    """Chia audio_1d thành các cửa sổ 8s không chồng lấn, rồi average logits."""
    trainer.network.eval()
    T = audio_1d.shape[-1]
    win = int(round(win_sec * sr))
    if T <= win:
        with torch.no_grad():
            out = trainer.network(tok_dict, audio_1d.unsqueeze(0).to(trainer.device))
            logits = out[0] if isinstance(out, (tuple, list)) else out
        return logits.squeeze(0).detach().cpu()

    # số crop không chồng lấn
    num = T // win
    logits_list = []
    with torch.no_grad():
        for i in range(num):
            seg = audio_1d[i*win:(i+1)*win].unsqueeze(0).to(trainer.device)
            out = trainer.network(tok_dict, seg)
            lg = out[0] if isinstance(out, (tuple, list)) else out
            logits_list.append(lg.detach().cpu())
    return torch.stack(logits_list, dim=0).mean(dim=0).squeeze(0)

def collect_preds(trainer, loader, use_ttmc=False, win_sec=8.0, min_total_sec=12.0, sr=16000):
    all_preds, all_labels = [], []
    id2dur = []
    trainer.network.eval()
    with torch.no_grad():
        for batch in loader:
            # Chuẩn hoá về ((tok,audio,labels), meta)
            if isinstance(batch, (tuple, list)) and len(batch) == 2:
                (tok, audio, labels), meta = batch
            else:
                raise ValueError("Unexpected batch format for evaluation")

            # to device
            audio = audio.to(trainer.device)
            labels = labels.to(trainer.device)
            if isinstance(tok, BatchEncoding):
                tok = tok.to(trainer.device)
                tok = {k: v for k, v in tok.items()}
            elif isinstance(tok, dict):
                tok = {k: v.to(trainer.device) for k, v in tok.items()}

            if use_ttmc:
                lengths = meta.get("audio_lengths", [audio.shape[-1]] * audio.shape[0])
                for i in range(audio.shape[0]):
                    a1 = audio[i, :lengths[i]].cpu()  # cắt về đúng length
                    if lengths[i] >= int(round(min_total_sec * sr)):
                        lg = _ttmc_logits_single(trainer, {k: v[i:i+1] for k,v in tok.items()}, a1, sr=sr, win_sec=win_sec)
                    else:
                        out = trainer.network({k: v[i:i+1] for k,v in tok.items()}, a1.unsqueeze(0).to(trainer.device))
                        lg = (out[0] if isinstance(out, (tuple, list)) else out).detach().cpu().squeeze(0)
                    pred = torch.argmax(lg, dim=-1)
                    all_preds.append(pred.unsqueeze(0))
                all_labels.append(labels.cpu())
            else:
                out = trainer.network(tok, audio)
                logits = out[0] if isinstance(out, (tuple, list)) else out
                preds = torch.argmax(logits, dim=1)
                all_preds.append(preds.cpu())
                all_labels.append(labels.cpu())
    return torch.cat(all_preds), torch.cat(all_labels)

if eval_loader is not None:
    preds, labels = collect_preds(
        trainer, eval_loader,
        use_ttmc=USE_TTMC, win_sec=TTMC_WIN_SEC, min_total_sec=TTMC_MIN_TOTAL_SEC, sr=cfg.sample_rate
    )
    num_classes = cfg.num_classes

    # Confusion matrix + metrics
    cm = torch.zeros((num_classes, num_classes), dtype=torch.int32)
    for t, p in zip(labels, preds):
        cm[t.long(), p.long()] += 1

    per_class_acc = (cm.diag().float() / cm.sum(dim=1).clamp(min=1).float()).numpy()
    overall_acc = (preds == labels).float().mean().item()

    print("Overall Acc: %.4f" % overall_acc)
    print("Per-class Acc:")
    for i, acc in enumerate(per_class_acc):
        name = str(i)
        try:
            name = id2label[i]
        except Exception:
            pass
        print(f"  {i} ({name}): {acc:.4f}")
    print("Confusion Matrix:\n", cm.numpy())

    # Classification report
    try:
        from sklearn.metrics import classification_report
        target_names = [str(i) for i in range(num_classes)]
        try:
            target_names = [id2label[i] for i in range(num_classes)]
        except Exception:
            pass
        print("\nClassification Report:")
        print(classification_report(labels.numpy(), preds.numpy(), target_names=target_names, digits=4))
    except Exception:
        pass
else:
    print("No eval_loader available.")

# %%
# === SANITY CHECK: 1 batch để xem meta (optional) ===
batch = next(iter(train_loader))
((tok_sc, audio_sc, labels_sc), meta_sc) = batch
print("Batch audio shape:", audio_sc.shape)
print("Sample weights present:", isinstance(meta_sc.get("sample_weight", None), torch.Tensor))
print("Lengths:", meta_sc.get("audio_lengths")[:4], "... (showing first 4)")
print("Durations:", meta_sc.get("duration_sec")[:4], "... (first 4)")

Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_

Overall Acc: 0.2800
Per-class Acc:
  0 (angry): 0.2500
  1 (fear): 0.2000
  2 (happiness): 0.0000
  3 (neutral): 0.4000
  4 (sadness): 0.5000
Confusion Matrix:
 [[1 2 0 1 0]
 [1 1 0 2 1]
 [0 3 0 1 1]
 [1 2 0 2 0]
 [1 1 0 1 3]]

Classification Report:
              precision    recall  f1-score   support

       angry     0.2500    0.2500    0.2500         4
        fear     0.1111    0.2000    0.1429         5
   happiness     0.0000    0.0000    0.0000         5
     neutral     0.2857    0.4000    0.3333         5
     sadness     0.6000    0.5000    0.5455         6

    accuracy                         0.2800        25
   macro avg     0.2494    0.2700    0.2543        25
weighted avg     0.2634    0.2800    0.2661        25



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space'

Batch audio shape: torch.Size([8, 125269])
Sample weights present: True
Lengths: [118514, 101796, 40868, 53871] ... (showing first 4)
Durations: [7.4071875, 6.3623125, 2.55425, 3.3669375] ... (first 4)
