In [16]:

import os, sys, platform, torch, numpy as np

print("CWD:", os.getcwd())
print("Python:", platform.python_version())
print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())


CWD: /mnt/d/MER/src
Python: 3.12.3
Torch: 2.7.1+cu126 | CUDA available: True


In [17]:
import os, sys, platform, numpy as np, torch, importlib
from pathlib import Path

print("CWD:", os.getcwd())
print("Python:", platform.python_version())
print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

# Nếu notebook ở project root: thêm "src" vào sys.path
if not (Path.cwd() / "configs").exists():
    sys.path.append(str(Path.cwd() / "src"))

# Reload các module để chắc chắn dùng bản mới
import configs.base as base_cfg
importlib.reload(base_cfg)
from configs.base import Config

import loading.dataloader as data_loader
importlib.reload(data_loader)

import model.networks as networks
importlib.reload(networks)
from model.networks import MER

import model.losses as losses_mod
importlib.reload(losses_mod)
from model.losses import get_loss

import training.trainer as trainer_mod
importlib.reload(trainer_mod)
from training.trainer import TorchTrainer

import training.callbacks as cbs_mod
importlib.reload(cbs_mod)
from training.callbacks import CheckpointsCallback 

import training.optimizers as opt_mod
importlib.reload(opt_mod)
from training.optimizers import split_param_groups, build_optimizer

from transformers import get_cosine_schedule_with_warmup, BatchEncoding

CWD: /mnt/d/MER/src
Python: 3.12.3
Torch: 2.7.1+cu126 | CUDA available: True


In [18]:
cfg = Config(
    name="MER_VNEMOS_maskfix_stable",
    checkpoint_dir="../checkpoints/mer_vnemos_maskfix_stable",

    # Train runtime
    num_epochs=30,
    batch_size=8,
    num_workers=2,

    # LR / optim (head lr set tại param_groups)
    learning_rate=2e-5,          
    optimizer_type="AdamW",
    adam_weight_decay=0.01,

    # Scheduler
    scheduler_type="cosine_warmup",
    warmup_ratio=0.05,
    scheduler_step_unit="step",

    # Loss
    loss_type="LabelSmoothingCE",
    label_smoothing=0.05,

    # Data
    data_root="../output",
    jsonl_dir="",
    sample_rate=16000,
    max_audio_sec=None,           
    text_max_length=96,

    # Model
    model_type="MemoCMT",
    text_encoder_type="phobert",
    text_encoder_ckpt="vinai/phobert-base",
    text_encoder_dim=768,
    text_unfreeze=False,

    audio_encoder_type="wav2vec2_xlsr",
    audio_encoder_ckpt="facebook/wav2vec2-large-xlsr-53",
    audio_encoder_dim=1024,
    audio_unfreeze=False,

    fusion_dim=768,
    fusion_head_output_type="mean",    # ổn định hơn "cls" lúc chưa chín
    linear_layer_output=[256, 128],
    dropout=0.10,

    # Tricks
    use_amp=True,
    max_grad_norm=1.0,

    # Gradual Unfreeze (TẮT ở run này; có thể bật lại sau)
    gradual_unfreeze_epoch=3,
    text_unfreeze_last_k=4,
    audio_unfreeze_last_k=4,

    # Checkpoints
    save_best_val=True,
    max_to_keep=2,
)

print("Checkpoint dir:", Path(cfg.checkpoint_dir).resolve())
base_dir = (Path(cfg.data_root) / (cfg.jsonl_dir or "")).resolve()
print("base_dir:", base_dir)
print("audio_root (auto):", Path(getattr(cfg, "audio_root", base_dir.parent)).resolve())
for fn in ["train.jsonl","valid.jsonl","test.jsonl"]:
    print(fn, (base_dir/fn).exists())


Checkpoint dir: /mnt/d/MER/checkpoints/mer_vnemos_maskfix_stable
base_dir: /mnt/d/MER/output
audio_root (auto): /mnt/d/MER
train.jsonl True
valid.jsonl True
test.jsonl True


In [19]:
from typing import List, Dict
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from loading.dataloader import VNEMOSDataset, _clean_text

class CollatorA:
    def __init__(self, tokenizer: AutoTokenizer, text_max_length: int = 96):
        self.tok = tokenizer
        self.text_max_length = text_max_length

    def __call__(self, batch: List[Dict]):
        # lengths trước pad
        audios_raw = [b["audio"] for b in batch]
        lengths = [int(a.numel()) for a in audios_raw]
        T = max(lengths) if lengths else 0

        # pad audio lên T
        aud = [a if a.numel() == T else torch.nn.functional.pad(a, (0, T - a.numel())) for a in audios_raw]
        audio_tensor = torch.stack(aud, dim=0)  # (B, T)

        # tokenize text
        texts = [b["text"] for b in batch]
        tok = self.tok(
            texts,
            padding=True,
            truncation=True,
            max_length=self.text_max_length,
            return_tensors="pt",
            add_prefix_space=True,
        )
        tok = {k: v for k, v in tok.items()}
        labels = torch.tensor([b["label"] for b in batch], dtype=torch.long)

        out = (tok, audio_tensor, labels)
        meta = {
            "utterance_id": [b.get("utterance_id") for b in batch],
            "speaker_id":   [b.get("speaker_id") for b in batch],
            "audio_lengths": lengths,  # QUAN TRỌNG cho mask audio
        }
        return out, meta


def build_loaders(cfg: Config):
    base_dir = (Path(cfg.data_root) / (getattr(cfg, "jsonl_dir", "") or "")).resolve()
    has_valid = (base_dir / "valid.jsonl").exists()
    eval_split = "valid" if has_valid else "test"

    train_set = VNEMOSDataset(cfg, "train")
    lbl2id = train_set.label2id
    eval_set  = VNEMOSDataset(cfg, eval_split, label2id=lbl2id)

    collate = CollatorA(train_set.tokenizer, text_max_length=getattr(cfg, "text_max_length", 96))

    train_loader = DataLoader(
        train_set,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        collate_fn=collate,
        pin_memory=True,
    )
    eval_loader = DataLoader(
        eval_set,
        batch_size=max(1, cfg.batch_size),
        shuffle=False,
        num_workers=cfg.num_workers,
        collate_fn=collate,
        pin_memory=True,
    )
    return train_loader, eval_loader, lbl2id

train_loader, eval_loader, label2id = build_loaders(cfg)

id2label = [k for k, v in sorted(label2id.items(), key=lambda x: x[1])]
cfg.num_classes = len(id2label)
print("Classes:", cfg.num_classes, id2label)
print(f"Train batches: {len(train_loader)} | Eval batches: {len(eval_loader) if eval_loader else 0}")



Classes: 5 ['angry', 'fear', 'happiness', 'neutral', 'sadness']
Train batches: 25 | Eval batches: 4


In [20]:
from model.modules import build_audio_encoder
import torch

# Lấy 1 batch
batch = next(iter(train_loader))
# ĐÚNG: batch là ((tok, audio, labels), meta)
((tok_sc, audio_sc, labels_sc), meta_sc) = batch

device_tmp = "cuda" if torch.cuda.is_available() else "cpu"

# lengths (số mẫu audio trước pad)
lengths = torch.tensor(meta_sc["audio_lengths"], device=device_tmp, dtype=torch.long)  # (B,)
B, T = audio_sc.shape

# attention mask cho đầu vào Wav2Vec2: 1=keep, 0=pad
attn_mask_audio_input = (torch.arange(T, device=device_tmp).unsqueeze(0) < lengths.unsqueeze(1)).long()

# Chỉ khởi tạo audio encoder để kiểm tra (nhanh & nhẹ hơn)
audio_enc = build_audio_encoder(cfg).to(device_tmp)
audio_enc.eval()

with torch.no_grad():
    # Mã hoá audio (sau fix trong mã nguồn, get_feat_lengths đã clamp min=1)
    a_feat = audio_enc(audio_sc.to(device_tmp), attention_mask=attn_mask_audio_input)  # (B, L_a, D)
    L_valid = audio_enc.get_feat_lengths(lengths)  # (B,)
    L_a = a_feat.size(1)
    L_valid = L_valid.clamp(min=1, max=L_a)
    # key_padding_mask cho MHA: True=PAD
    kpm_audio = (torch.arange(L_a, device=device_tmp).unsqueeze(0) >= L_valid.unsqueeze(1))

any_all_masked = bool(kpm_audio.all(dim=1).any().item())

print("Sanity | audio input:", tuple(audio_sc.shape), "| encoded len L_a:", L_a)
print("Sanity | first 8 L_valid:", L_valid[:8].tolist())
print("Sanity | Any sample all-masked? ->", any_all_masked)

# Dọn tài nguyên tạm
del audio_enc
if torch.cuda.is_available():
    torch.cuda.empty_cache()

Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_

Sanity | audio input: (8, 248918) | encoded len L_a: 777
Sanity | first 8 L_valid: [777, 250, 703, 429, 659, 444, 391, 180]
Sanity | Any sample all-masked? -> False


In [21]:
device = "cuda" if torch.cuda.is_available() else "cpu"
network = MER(cfg, device=device).to(device)
criterion = get_loss(cfg)

class MERNotebookTrainer(TorchTrainer):
    def __init__(self, cfg, network, criterion, **kwargs):
        super().__init__(**kwargs)
        self.cfg = cfg
        self.network = network
        self.criterion = criterion
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network.to(self.device)
        self.use_amp = bool(getattr(cfg, "use_amp", torch.cuda.is_available()))
        self.max_grad_norm = float(getattr(cfg, "max_grad_norm", 0.0))
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def _to_device_text(self, x):
        if isinstance(x, BatchEncoding):
            x = x.to(self.device)
            return {k: v for k, v in x.items()}
        if isinstance(x, dict):
            return {k: v.to(self.device) for k, v in x.items()}
        return x.to(self.device)

    def _standardize_batch(self, batch):
        # batch = ((tok, audio, labels), meta)
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            a, meta = batch
            if isinstance(a, (tuple, list)) and len(a) == 3:
                tok, audio, labels = a
                return tok, audio, labels, meta
        raise ValueError(f"Batch structure không hỗ trợ: preview={str(batch)[:200]}")

    def train_step(self, batch):
        self.network.train()
        self.optimizer.zero_grad(set_to_none=True)

        (tok, audio, labels, meta) = self._standardize_batch(batch)
        audio  = audio.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        tok    = self._to_device_text(tok)

        if self.use_amp:
            with torch.cuda.amp.autocast():
                out = self.network(tok, audio, meta=meta)
                logits = out[0] if isinstance(out, (tuple, list)) else out
                loss = self.criterion(out, labels) if isinstance(out, (tuple, list)) else self.criterion(logits, labels)
            self.scaler.scale(loss).backward()
            if self.max_grad_norm and self.max_grad_norm > 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            out = self.network(tok, audio, meta=meta)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            loss = self.criterion(out, labels) if isinstance(out, (tuple, list)) else self.criterion(logits, labels)
            loss.backward()
            if self.max_grad_norm and self.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
            self.optimizer.step()

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        return {"loss": float(loss.detach().cpu()), "acc": float(acc.detach().cpu())}

    def test_step(self, batch):
        self.network.eval()
        (tok, audio, labels, meta) = self._standardize_batch(batch)
        audio  = audio.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        tok    = self._to_device_text(tok)
        with torch.no_grad():
            out = self.network(tok, audio, meta=meta)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            loss = self.criterion(out, labels) if isinstance(out, (tuple, list)) else self.criterion(logits, labels)
            preds = torch.argmax(logits, dim=1)
            acc = (preds == labels).float().mean()
        return {"val_loss": float(loss.detach().cpu()),
                "val_acc": float(acc.detach().cpu()),
                "preds": preds.detach().cpu(),
                "targets": labels.detach().cpu()}

trainer = MERNotebookTrainer(cfg, network, criterion, log_dir="logs")
print("Trainer ready. Device:", trainer.device)

Trainer ready. Device: cuda


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)


In [22]:
enc_lr  = 5e-6
head_lr = 2e-5

param_groups = split_param_groups(trainer, lr_enc=enc_lr, lr_head=head_lr, weight_decay=0.01)
optimizer = build_optimizer("adamw", param_groups, lr=head_lr, weight_decay=0.01)

total_steps = len(train_loader) * cfg.num_epochs
warmup_steps = max(1, int(cfg.warmup_ratio * total_steps))
print("Total steps:", total_steps, "| Warmup steps:", warmup_steps)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

trainer.compile(
    optimizer=optimizer,
    scheduler=scheduler,
    lr=head_lr,
    param_groups=None,
    scheduler_step_unit=cfg.scheduler_step_unit,  
)
print("Param groups:", len(optimizer.param_groups))



Total steps: 750 | Warmup steps: 37
Param groups: 1


In [23]:
ckpt_cb = CheckpointsCallback(
    cfg.checkpoint_dir,
    save_freq=200,
    max_to_keep=cfg.max_to_keep,
    save_best_val=True,
    monitor="val_macro_f1",
    mode="max",
)
callbacks = [ckpt_cb]





In [24]:
trainer.fit(train_loader, epochs=cfg.num_epochs, eval_data=eval_loader, callbacks=callbacks)
print("Best checkpoint:", getattr(ckpt_cb, "best_path", ""))


Epoch 1/30
2025-08-23 05:11:41,857 - Training - INFO - Epoch 1/30
Epoch 1:   0%|          | 0/25 [00:00<?, ?it/s]Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_sp

Best checkpoint: ../checkpoints/mer_vnemos_maskfix_stable/best_val_macro_f1/checkpoint_0.pth


In [25]:
best_path = getattr(ckpt_cb, "best_path", "")
if isinstance(best_path, str) and best_path.endswith(".pth") and Path(best_path).exists():
    print("Loading best weights for eval:", best_path)
    state = torch.load(best_path, map_location=trainer.device)
    trainer.network.load_state_dict(state)

import numpy as np
import torch

def collect_preds(trainer, loader):
    all_preds, all_labels = [], []
    trainer.network.eval()
    with torch.no_grad():
        for batch in loader:
            (tok, audio, labels, meta) = trainer._standardize_batch(batch)
            audio  = audio.to(trainer.device)
            labels = labels.to(trainer.device)
            tok    = trainer._to_device_text(tok)

            out = trainer.network(tok, audio, meta=meta)
            logits = out[0] if isinstance(out, (tuple, list)) else out
            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    return torch.cat(all_preds), torch.cat(all_labels)

if eval_loader is not None:
    preds, labels = collect_preds(trainer, eval_loader)
    num_classes = cfg.num_classes

    cm = torch.zeros((num_classes, num_classes), dtype=torch.int32)
    for t, p in zip(labels, preds):
        cm[t.long(), p.long()] += 1

    per_class_acc = (cm.diag().float() / cm.sum(dim=1).clamp(min=1).float()).numpy()
    overall_acc = (preds == labels).float().mean().item()

    print("Overall Acc: %.4f" % overall_acc)
    print("Per-class Acc:")
    for i, acc in enumerate(per_class_acc):
        name = id2label[i] if i < len(id2label) else str(i)
        print(f"  {i} ({name}): {acc:.4f}")
    print("Confusion Matrix:\n", cm.numpy())

    try:
        from sklearn.metrics import classification_report
        target_names = [id2label[i] for i in range(num_classes)]
        print("\nClassification Report:")
        print(classification_report(labels.numpy(), preds.numpy(), target_names=target_names, digits=4))
    except Exception as e:
        print("sklearn not available:", e)
else:
    print("No eval_loader available.")

Loading best weights for eval: ../checkpoints/mer_vnemos_maskfix_stable/best_val_macro_f1/checkpoint_0.pth


Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_prefix_space': True} not recognized.
Keyword arguments {'add_

Overall Acc: 0.4400
Per-class Acc:
  0 (angry): 0.5000
  1 (fear): 0.2000
  2 (happiness): 0.4000
  3 (neutral): 0.4000
  4 (sadness): 0.6667
Confusion Matrix:
 [[2 1 1 0 0]
 [1 1 1 2 0]
 [0 1 2 2 0]
 [3 0 0 2 0]
 [0 1 0 1 4]]

Classification Report:
              precision    recall  f1-score   support

       angry     0.3333    0.5000    0.4000         4
        fear     0.2500    0.2000    0.2222         5
   happiness     0.5000    0.4000    0.4444         5
     neutral     0.2857    0.4000    0.3333         5
     sadness     1.0000    0.6667    0.8000         6

    accuracy                         0.4400        25
   macro avg     0.4738    0.4333    0.4400        25
weighted avg     0.5005    0.4400    0.4560        25

