In [None]:
# %%
# Environment & quick sanity
import os, sys, platform, torch, numpy as np
from pathlib import Path

# Report the runtime environment so later failures are easier to diagnose.
print("CWD:", os.getcwd())
print("Python:", platform.python_version())
print("Torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

# When the working directory has no "configs" package next to it, make the
# packages under ./src importable (original note: the notebook is then assumed
# to sit at the project root).
if not (Path.cwd() / "configs").exists():
    src_dir = str(Path.cwd() / "src")
    # Guard against duplicates so re-running this cell does not keep growing sys.path.
    if src_dir not in sys.path:
        sys.path.append(src_dir)


In [None]:
# %%
# Reload updated modules
# Pattern used for every project module below: import the module, reload it so
# edits made on disk are picked up without restarting the kernel, then re-import
# the public names so they bind to the freshly reloaded module object.
import importlib

# Configuration dataclass/object.
import configs.base as base_cfg
importlib.reload(base_cfg)
from configs.base import Config

# Dataset / DataLoader builder.
import loading.dataloader as data_loader
importlib.reload(data_loader)
from loading.dataloader import build_train_test_dataset

# Model definition.
import model.networks as networks
importlib.reload(networks)
from model.networks import MER  # (original note: networks.py also exposes MemoCMT as an alias of MER)

# Loss factory.
import model.losses as losses_mod
importlib.reload(losses_mod)
from model.losses import get_loss

# Base trainer class.
import training.trainer as trainer_mod
importlib.reload(trainer_mod)
from training.trainer import TorchTrainer

# Checkpoint callback.
import training.callbacks as cbs_mod
importlib.reload(cbs_mod)
from training.callbacks import CheckpointsCallback

# Optimizer helpers (parameter-group splitting and optimizer construction).
import training.optimizers as opt_mod
importlib.reload(opt_mod)
from training.optimizers import split_param_groups, build_optimizer

# Third-party: LR schedule and the tokenizer output container type.
from transformers import get_cosine_schedule_with_warmup, BatchEncoding


In [None]:
# %%
# Config — synchronized with the new code (num_classes=5, momentum, bucketing/sampler options)
# All hyper-parameters for the run live here; later cells read them from `cfg`.
cfg = Config(
    name="MER_VNEMOS_maskfix_stable",
    checkpoint_dir="../checkpoints/mer_vnemos_maskfix_stable",

    # Train runtime
    num_epochs=30,
    batch_size=8,
    num_workers=2,

    # LR / optim (original note: the head LR is set in the param groups)
    learning_rate=2e-5,
    optimizer_type="AdamW",
    adam_weight_decay=0.01,

    # Scheduler
    scheduler_type="cosine_warmup",
    warmup_ratio=0.05,
    scheduler_step_unit="step",

    # Loss (pick one: "FocalLoss" or "LabelSmoothingCE")
    loss_type="LabelSmoothingCE",
    label_smoothing=0.05,

    # Data
    data_root="../output",
    jsonl_dir="",
    sample_rate=16000,
    max_audio_sec=None,           # no hard crop
    text_max_length=96,

    # Bucketing & sampler to counter length bias
    use_length_bucket=True,
    length_bucket_size=64,
    bucketing_text_alpha=0.03,
    use_weighted_sampler=True,    # original note: with length bucketing on, the sampler is used for the 'else' branch
    lenfreq_alpha=0.5,

    # Model
    model_type="MemoCMT",
    text_encoder_type="phobert",
    text_encoder_ckpt="vinai/phobert-base",
    text_encoder_dim=768,
    text_unfreeze=False,

    audio_encoder_type="wav2vec2_xlsr",
    audio_encoder_ckpt="facebook/wav2vec2-large-xlsr-53",
    audio_encoder_dim=1024,
    audio_unfreeze=False,

    fusion_dim=768,
    fusion_head_output_type="cls",     # recommended: "cls" or "mean"
    linear_layer_output=[256, 128],
    dropout=0.10,

    # Tricks
    use_amp=True,
    max_grad_norm=1.0,

    # Gradual unfreeze (can be enabled later)
    gradual_unfreeze_epoch=3,
    text_unfreeze_last_k=4,
    audio_unfreeze_last_k=4,

    # Checkpoints
    save_best_val=True,
    max_to_keep=2,
)

# Preview the resolved paths so a wrong working directory is caught before training.
print("Checkpoint dir:", Path(cfg.checkpoint_dir).resolve())
base_dir = (Path(cfg.data_root) / (cfg.jsonl_dir or "")).resolve()
print("base_dir:", base_dir)
# Fall back to base_dir.parent both when cfg has no `audio_root` attribute and when
# it is present but unset (None / ""); Path(None) would raise TypeError.
print("audio_root (auto):", Path(getattr(cfg, "audio_root", None) or base_dir.parent).resolve())
# Report which of the expected split files exist under base_dir.
for fn in ["train.jsonl", "valid.jsonl", "test.jsonl"]:
    print(fn, (base_dir / fn).exists())


In [None]:
# %%
# Build the loaders with the new builder (original note: it already provides
# mixed bucketing and audio_attn_mask).
train_loader, eval_loader = build_train_test_dataset(cfg)

# Derive id2label from a train-split dataset (original note: the builder fixes
# label2id on the train split internally). A throwaway dataset instance is
# created only to read its label map.
from loading.dataloader import VNEMOSDataset
tmp_ds = VNEMOSDataset(cfg, "train")
label2id = tmp_ds.label2id
# Label names ordered by their integer id.
id2label = sorted(label2id, key=label2id.get)
cfg.num_classes = len(id2label)

print("Classes:", cfg.num_classes, id2label)
n_eval_batches = len(eval_loader) if eval_loader else 0
print(f"Train batches: {len(train_loader)} | Eval batches: {n_eval_batches}")


In [None]:
# %%
# Sanity: check the mask & encoded lengths on one batch (uses meta["audio_attn_mask"])
import torch
from model.modules import build_audio_encoder

batch = next(iter(train_loader))
(tok_sc, audio_sc, labels_sc), meta_sc = batch
device_tmp = "cuda" if torch.cuda.is_available() else "cpu"


def _audio_lengths_tensor():
    """Return meta_sc["audio_lengths"] as a long tensor on device_tmp."""
    return torch.tensor(meta_sc["audio_lengths"], device=device_tmp, dtype=torch.long)


attn_mask_audio_input = meta_sc.get("audio_attn_mask", None)
if attn_mask_audio_input is not None:
    attn_mask_audio_input = attn_mask_audio_input.to(device_tmp)
else:
    # Fallback (original note: not needed with the new collator): build the mask from lengths.
    lengths = _audio_lengths_tensor()
    B, T = audio_sc.shape
    sample_pos = torch.arange(T, device=device_tmp).unsqueeze(0)
    attn_mask_audio_input = (sample_pos < lengths.unsqueeze(1)).long()

has_audio_lengths = "audio_lengths" in meta_sc
any_all_masked = False

audio_enc = build_audio_encoder(cfg).to(device_tmp).eval()
with torch.no_grad():
    # Original note: the encoder output is (B, L_a, D).
    a_feat = audio_enc(audio_sc.to(device_tmp), attention_mask=attn_mask_audio_input)
    if has_audio_lengths:
        lengths = _audio_lengths_tensor()
        n_frames = a_feat.size(1)
        # Encoded lengths, clamped to [1, n_frames].
        L_valid = audio_enc.get_feat_lengths(lengths).clamp(min=1, max=n_frames)
        frame_pos = torch.arange(n_frames, device=device_tmp).unsqueeze(0)
        # True where a frame position lies at or beyond the sample's valid length.
        kpm_audio = frame_pos >= L_valid.unsqueeze(1)
        any_all_masked = bool(kpm_audio.all(dim=1).any().item())

print("Sanity | audio input:", tuple(audio_sc.shape), "| encoded len L_a:", a_feat.size(1))
if has_audio_lengths:
    print("Sanity | first 8 L_valid:", L_valid[:8].tolist())
print("Sanity | Any sample all-masked? ->", any_all_masked)

# Drop the temporary encoder and release cached GPU memory before building the real model.
del audio_enc
if torch.cuda.is_available():
    torch.cuda.empty_cache()


In [None]:
# %%
# Lightweight trainer for the notebook — keeps the conversion of BatchEncoding into plain keyword-style dicts
import torch
from transformers import BatchEncoding

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Build the MER network from cfg and move it to the selected device.
network = MER(cfg, device=device).to(device)
# Loss object produced by get_loss from cfg.
criterion = get_loss(cfg)

class MERNotebookTrainer(TorchTrainer):
    """Notebook-side trainer that plugs the MER network and loss into TorchTrainer.

    Batches are expected as ``((tok, audio, labels), meta)``. ``train_step`` runs one
    optimisation step (optionally with CUDA mixed precision and gradient clipping);
    ``test_step`` runs one evaluation step. Both return plain metric dictionaries.

    Args:
        cfg: Config object; ``use_amp`` and ``max_grad_norm`` are read from it when present.
        network: Model called as ``network(tok, audio, meta=meta)``.
        criterion: Loss called as ``criterion(output, labels)`` where ``output`` is
            exactly what the network returned (a tensor or a tuple/list).
        **kwargs: Forwarded unchanged to ``TorchTrainer.__init__``.

    Note:
        ``self.optimizer`` is not set here; it is presumably installed by the base
        class (e.g. through ``compile``) before ``train_step`` is called.
    """

    def __init__(self, cfg, network, criterion, **kwargs):
        super().__init__(**kwargs)
        self.cfg = cfg
        self.network = network
        self.criterion = criterion
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.network.to(self.device)
        # torch.cuda.amp only has an effect on CUDA. Gate AMP on the device so a
        # CPU-only run with cfg.use_amp=True does not emit a warning every step.
        amp_requested = bool(getattr(cfg, "use_amp", torch.cuda.is_available()))
        self.use_amp = amp_requested and self.device.type == "cuda"
        # Treat a missing or None max_grad_norm as "clipping disabled".
        self.max_grad_norm = float(getattr(cfg, "max_grad_norm", 0.0) or 0.0)
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

    def _to_device_text(self, x):
        """Move tokenizer output to ``self.device``.

        A ``BatchEncoding`` or ``dict`` is returned as a plain ``dict``; any other
        object is moved with its own ``.to`` method.
        """
        if isinstance(x, BatchEncoding):
            return dict(x.to(self.device).items())
        if isinstance(x, dict):
            # Values without a .to method (e.g. lists of strings) are passed through untouched.
            return {k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in x.items()}
        return x.to(self.device)

    def _standardize_batch(self, batch):
        """Unpack ``((tok, audio, labels), meta)`` into ``(tok, audio, labels, meta)``.

        Raises:
            ValueError: If the batch does not have that nesting.
        """
        if isinstance(batch, (tuple, list)) and len(batch) == 2:
            a, meta = batch
            if isinstance(a, (tuple, list)) and len(a) == 3:
                tok, audio, labels = a
                return tok, audio, labels, meta
        raise ValueError(f"Batch structure không hỗ trợ: preview={str(batch)[:200]}")

    def _prepare_batch(self, batch):
        """Unpack a batch and move text, audio and labels to ``self.device``."""
        tok, audio, labels, meta = self._standardize_batch(batch)
        audio = audio.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        tok = self._to_device_text(tok)
        return tok, audio, labels, meta

    def _forward_loss(self, tok, audio, labels, meta):
        """Run the network and the criterion; return ``(logits, loss)``.

        The criterion always receives the raw network output; the logits are the
        first element when that output is a tuple/list, otherwise the output itself.
        """
        out = self.network(tok, audio, meta=meta)
        logits = out[0] if isinstance(out, (tuple, list)) else out
        loss = self.criterion(out, labels)
        return logits, loss

    def _clip_gradients(self):
        """Clip the global gradient norm when ``max_grad_norm`` is positive."""
        if self.max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)

    def train_step(self, batch):
        """Run one optimisation step and return ``{"loss": float, "acc": float}``."""
        self.network.train()
        self.optimizer.zero_grad(set_to_none=True)

        tok, audio, labels, meta = self._prepare_batch(batch)

        if self.use_amp:
            with torch.cuda.amp.autocast():
                logits, loss = self._forward_loss(tok, audio, labels, meta)
            self.scaler.scale(loss).backward()
            if self.max_grad_norm > 0:
                # Gradients must be unscaled before their norm is measured.
                self.scaler.unscale_(self.optimizer)
            self._clip_gradients()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            logits, loss = self._forward_loss(tok, audio, labels, meta)
            loss.backward()
            self._clip_gradients()
            self.optimizer.step()

        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        return {"loss": float(loss.detach().cpu()), "acc": float(acc.detach().cpu())}

    def test_step(self, batch):
        """Run one evaluation step without gradients.

        Returns:
            dict with ``val_loss`` and ``val_acc`` as floats plus the CPU tensors
            ``preds`` and ``targets`` for this batch.
        """
        self.network.eval()
        tok, audio, labels, meta = self._prepare_batch(batch)
        with torch.no_grad():
            logits, loss = self._forward_loss(tok, audio, labels, meta)
            preds = torch.argmax(logits, dim=1)
            acc = (preds == labels).float().mean()
        return {"val_loss": float(loss.detach().cpu()),
                "val_acc": float(acc.detach().cpu()),
                "preds": preds.detach().cpu(),
                "targets": labels.detach().cpu()}

# Instantiate the trainer; log_dir is forwarded to TorchTrainer.__init__ through **kwargs.
trainer = MERNotebookTrainer(cfg, network, criterion, log_dir="logs")
print("Trainer ready. Device:", trainer.device)


In [None]:
# %%
# Optimizer (split LRs: low for the encoders, higher for the head) + cosine warmup scheduler.
# Hyper-parameters are read from cfg so that editing the config cell takes effect;
# the fallbacks are the values this cell used to hard-code.
enc_lr = 5e-6
head_lr = float(getattr(cfg, "learning_rate", 2e-5))
weight_decay = float(getattr(cfg, "adam_weight_decay", 0.01))
optimizer_name = str(getattr(cfg, "optimizer_type", "adamw")).lower()

param_groups = split_param_groups(trainer, lr_enc=enc_lr, lr_head=head_lr, weight_decay=weight_decay)
optimizer = build_optimizer(optimizer_name, param_groups, lr=head_lr, weight_decay=weight_decay)

# NOTE(review): the schedule length is counted in optimizer steps, which matches
# scheduler_step_unit="step"; revisit if a different unit is configured.
total_steps = len(train_loader) * cfg.num_epochs
# At least one warmup step so the schedule never starts at its peak by accident.
warmup_steps = max(1, int(cfg.warmup_ratio * total_steps))
print("Total steps:", total_steps, "| Warmup steps:", warmup_steps)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

# Hand the ready-made optimizer and scheduler to the trainer; param_groups=None
# because the groups are already baked into `optimizer`.
trainer.compile(
    optimizer=optimizer,
    scheduler=scheduler,
    lr=head_lr,
    param_groups=None,
    scheduler_step_unit=cfg.scheduler_step_unit,
)
print("Param groups:", len(optimizer.param_groups))


In [None]:

# %%
# Callbacks — monitor macro-F1 (original note: TorchTrainer has to compute it from preds/targets)
ckpt_cb = CheckpointsCallback(
    cfg.checkpoint_dir,
    save_freq=200,
    max_to_keep=cfg.max_to_keep,
    # Honour the config flag instead of a hard-coded True (falls back to True if cfg lacks it).
    save_best_val=getattr(cfg, "save_best_val", True),
    # NOTE(review): confirm TorchTrainer aggregates "val_macro_f1" from the
    # preds/targets returned by test_step; otherwise no best checkpoint is tracked.
    monitor="val_macro_f1",
    mode="max",
)
callbacks = [ckpt_cb]


In [None]:
# %%
# Train
# Runs training for cfg.num_epochs epochs, with eval_loader as the evaluation data
# and the checkpoint callback attached.
trainer.fit(train_loader, epochs=cfg.num_epochs, eval_data=eval_loader, callbacks=callbacks)
# Read best_path defensively: an empty string is printed if the callback has no such attribute.
print("Best checkpoint:", getattr(ckpt_cb, "best_path", ""))


In [None]:
# %%
# Evaluation & report
# Load the best checkpoint (when one was recorded) before evaluating.
best_path = getattr(ckpt_cb, "best_path", "")
# Accept both str and path-like values; anything falsy means "no best checkpoint".
best_path_str = str(best_path) if best_path else ""
if best_path_str.endswith(".pth") and Path(best_path_str).exists():
    print("Loading best weights for eval:", best_path_str)
    # NOTE: torch.load unpickles the file; only load checkpoints produced by this project.
    # NOTE(review): assumes the file holds a bare state_dict — confirm against CheckpointsCallback.
    state = torch.load(best_path_str, map_location=trainer.device)
    trainer.network.load_state_dict(state)
else:
    # Make it explicit that the numbers below come from the current in-memory weights.
    print("No best checkpoint found; evaluating the current in-memory weights.")

import numpy as np
import torch

def collect_preds(trainer, loader):
    """Run the trainer's network over ``loader`` and gather predictions and labels.

    Args:
        trainer: Object exposing ``network``, ``device``, ``_standardize_batch`` and
            ``_to_device_text`` (as MERNotebookTrainer does).
        loader: Iterable of batches accepted by ``trainer._standardize_batch``.

    Returns:
        ``(preds, labels)`` as 1-D CPU tensors concatenated over all batches. For an
        empty loader both are empty ``torch.long`` tensors.

    Side effects:
        Puts ``trainer.network`` into eval mode and leaves it there.
    """
    all_preds, all_labels = [], []
    trainer.network.eval()
    with torch.no_grad():
        for batch in loader:
            tok, audio, labels, meta = trainer._standardize_batch(batch)
            audio = audio.to(trainer.device)
            labels = labels.to(trainer.device)
            tok = trainer._to_device_text(tok)

            out = trainer.network(tok, audio, meta=meta)
            # The network may return a tuple/list whose first element holds the logits.
            logits = out[0] if isinstance(out, (tuple, list)) else out
            all_preds.append(torch.argmax(logits, dim=1).cpu())
            all_labels.append(labels.cpu())

    if not all_preds:
        # torch.cat raises on an empty list; return well-typed empty results instead.
        return torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)
    return torch.cat(all_preds), torch.cat(all_labels)

if eval_loader is not None:
    preds, labels = collect_preds(trainer, eval_loader)
    num_classes = cfg.num_classes

    y_true = labels.long().reshape(-1)
    y_pred = preds.long().reshape(-1)
    if y_true.numel():
        # Fail loudly on out-of-range class ids instead of miscounting them.
        lowest = min(int(y_true.min()), int(y_pred.min()))
        highest = max(int(y_true.max()), int(y_pred.max()))
        if lowest < 0 or highest >= num_classes:
            raise ValueError(
                f"class index out of range [0, {num_classes}): min={lowest}, max={highest}"
            )

    # Confusion matrix (rows = true class, columns = predicted class) built in one
    # vectorized pass: each (t, p) pair maps to the flat index t * num_classes + p.
    flat_idx = y_true * num_classes + y_pred
    cm = torch.bincount(flat_idx, minlength=num_classes * num_classes)
    cm = cm.reshape(num_classes, num_classes).to(torch.int32)

    # Per-class accuracy = diagonal / row total; empty rows are clamped to avoid 0/0.
    per_class_acc = (cm.diag().float() / cm.sum(dim=1).clamp(min=1).float()).numpy()
    overall_acc = (preds == labels).float().mean().item()

    print("Overall Acc: %.4f" % overall_acc)
    print("Per-class Acc:")
    for i, acc in enumerate(per_class_acc):
        name = id2label[i] if i < len(id2label) else str(i)
        print(f"  {i} ({name}): {acc:.4f}")
    print("Confusion Matrix:\n", cm.numpy())

    try:
        from sklearn.metrics import classification_report
    except ImportError as e:
        # Only a missing scikit-learn is tolerated; real errors in the report propagate.
        print("sklearn not available:", e)
    else:
        class_ids = list(range(num_classes))
        target_names = [id2label[i] if i < len(id2label) else str(i) for i in class_ids]
        print("\nClassification Report:")
        # Passing labels= keeps the report aligned with target_names even when a
        # class is absent from both the targets and the predictions.
        print(classification_report(
            labels.numpy(),
            preds.numpy(),
            labels=class_ids,
            target_names=target_names,
            digits=4,
            zero_division=0,
        ))
else:
    print("No eval_loader available.")


In [None]:
# %%
# Save model weights to HDF5 (.h5)
import json, h5py
from pathlib import Path
import torch
import numpy as np

# Make sure training has finished and trainer.network holds the weights you want
# to export (the best ones, if they were loaded in the evaluation cell).
h5_path = Path(cfg.checkpoint_dir) / "model_weights.h5"
# The checkpoint directory may not exist yet (e.g. if no checkpoint was written).
h5_path.parent.mkdir(parents=True, exist_ok=True)

state = trainer.network.state_dict()  # OrderedDict[str, Tensor]

with h5py.File(h5_path, "w") as f:
    # --- Metadata ---
    f.attrs["model_name"] = "MER"
    f.attrs["num_classes"] = int(cfg.num_classes)
    f.attrs["id2label"] = json.dumps(id2label, ensure_ascii=False)
    f.attrs["torch_version"] = torch.__version__

    # Store the whole configuration, one attribute per field, serialized to a
    # string so any value type can be written.
    cfg_group = f.create_group("config")
    for k, v in vars(cfg).items():
        try:
            cfg_group.attrs[k] = json.dumps(v, ensure_ascii=False)
        except (TypeError, ValueError):
            # Not JSON-serializable (or circular): fall back to the plain string form.
            cfg_group.attrs[k] = str(v)

    # --- Weights ---
    weights_group = f.create_group("state_dict")
    for name, tensor in state.items():
        arr = tensor.detach().cpu().numpy()
        # One dataset per state_dict entry, named after the parameter (original
        # note: replace the dots with underscores if a flat naming is preferred).
        dset = weights_group.create_dataset(name, data=arr)
        dset.attrs["shape"] = arr.shape
        dset.attrs["dtype"] = str(arr.dtype)

print(f"Saved HDF5 weights to: {h5_path.resolve()}")
