In [None]:
# ① Install NumPy from the regular PyPI repository, constrained to the 1.x series
#    (any release below 2.0 satisfies the spec; 1.26.4 is one example of such a release)
!pip install --no-cache-dir "numpy<2.0"

# ② Next, install PyTorch / vision / audio from the cu121 index
#    (all three packages are pinned to matching +cu121 builds)
!pip install --no-cache-dir \
  torch==2.2.2+cu121 torchvision==0.17.2+cu121 torchaudio==2.2.2+cu121 \
  --index-url https://download.pytorch.org/whl/cu121

In [None]:
from google.colab import drive
drive.mount('/content/drive')   # First, mount Google Drive normally

# pydub supplies the AudioSegment class imported below (used to slice and export WAV audio)
!pip -q install pydub

import os
import shutil
import zipfile
import warnings
import pandas as pd
from datetime import datetime
from pydub import AudioSegment
from google.colab import files  # <-- Required for downloading to local machine

# ------------------------------------------------------------
# 0) Path Settings (Existing folders are read-only)
# ------------------------------------------------------------
# Inputs: per-participant folders "<pid>_P" are expected under BASE_EXTR,
# and PHQ_CSV must provide the Participant_ID / PHQ8_Binary columns read in step 5.
BASE_EXTR = "/content/drive/MyDrive/DAIC-WOZ/extracted"  # Existing
PHQ_CSV   = "/content/combined_PHQ_sorted.csv"           # Existing

# Outputs: everything generated by this cell is written below TMP_ROOT
# (audio segments, visual segments and the metadata CSV), then zipped in step 7.
TMP_ROOT  = "/content/m5ov1_tmp"                         # Local temporary directory
TMP_AUD   = f"{TMP_ROOT}/audio"
TMP_VIS   = f"{TMP_ROOT}/visual"
os.makedirs(TMP_AUD, exist_ok=True)
os.makedirs(TMP_VIS, exist_ok=True)

# ------------------------------------------------------------
# 1) Parameters
# ------------------------------------------------------------
# NOTE: output file names contain the literal tag "m5ov1"; changing N_MERGE / OVL
# changes the windowing but not the file names.
N_MERGE = 5      # Number of segments to merge
OVL     = 1      # Number of overlapping segments
STEP    = N_MERGE - OVL      # = 4 (stride, in participant turns, between consecutive windows)
FPS     = 30.0   # Frame rate from OpenFace; used to derive timestamp = frame / FPS when a CLNF file has no timestamp column
SR      = 16000  # Sample rate for audio  (NOTE(review): not referenced anywhere else in this cell)

# ------------------------------------------------------------
# 2) TRANSCRIPT Loader
# ------------------------------------------------------------
def load_transcript(path):
    """Parse a transcript file into a DataFrame of utterances.

    The first line is treated as a header and skipped. Every other line is
    split on whitespace into at most four fields: start time, stop time,
    speaker, and the remaining text (which may itself contain spaces).

    Lines with fewer than four fields, or whose start/stop fields are not
    numeric, are skipped.

    Parameters
    ----------
    path : str
        Path of the transcript file.

    Returns
    -------
    pandas.DataFrame
        Columns ``start_time`` (float), ``stop_time`` (float), ``speaker``
        and ``value``. The columns are present even when no row was parsed,
        so callers can always filter on ``df.speaker``.
    """
    columns = ["start_time", "stop_time", "speaker", "value"]
    rows = []
    with open(path, 'r', encoding='utf-8', errors='replace') as f:
        next(f, None)  # Skip the header row
        for ln in f:
            pts = ln.strip().split(maxsplit=3)
            if len(pts) < 4:
                continue
            s, e, spk, val = pts
            try:
                start, stop = float(s), float(e)
            except ValueError:
                # Malformed timing fields: skip the line instead of aborting the whole file.
                continue
            rows.append(dict(start_time=start,
                             stop_time=stop,
                             speaker=spk,
                             value=val))
    # Passing `columns` keeps the schema stable for an empty transcript.
    return pd.DataFrame(rows, columns=columns)

# ------------------------------------------------------------
# 3) AUDIO Segment Generation (to local temp directory)
# ------------------------------------------------------------
def extract_audio(pid:int):
    """Cut one participant's recording into merged, overlapping WAV segments.

    Every transcript row whose speaker is 'Participant' is sliced out of the
    recording (times are converted to milliseconds). Windows of ``N_MERGE``
    consecutive slices, advancing by ``STEP`` slices, are concatenated and
    exported to ``TMP_AUD``.

    Returns the list of written file paths; an empty list (with a warning)
    when the transcript or the WAV file does not exist.
    """
    participant_dir = f"{BASE_EXTR}/{pid}_P"
    transcript_file = f"{participant_dir}/{pid}_TRANSCRIPT.csv"
    wav_file        = f"{participant_dir}/{pid}_AUDIO.wav"
    if not os.path.exists(transcript_file) or not os.path.exists(wav_file):
        warnings.warn(f"[AUDIO] missing {pid}")
        return []

    transcript       = load_transcript(transcript_file)
    participant_rows = transcript[transcript.speaker=='Participant']
    recording        = AudioSegment.from_wav(wav_file)

    # One audio slice per participant utterance (pydub slices by milliseconds).
    utterances = []
    for _, row in participant_rows.iterrows():
        begin_ms = int(row.start_time*1000)
        end_ms   = int(row.stop_time*1000)
        utterances.append(recording[begin_ms:end_ms])

    written = []
    window_starts = range(0, len(utterances)-N_MERGE+1, STEP)
    for window_idx, first in enumerate(window_starts):
        chunk    = sum(utterances[first:first+N_MERGE])
        out_path = f"{TMP_AUD}/{pid}_m5ov1_{window_idx}.wav"
        chunk.export(out_path, format="wav")
        written.append(out_path)
    return written

# ------------------------------------------------------------
# 4) VISUAL Segment Generation (to local temp directory)
# ------------------------------------------------------------
def _safe_read(fp):
    try:  # Try reading as comma-separated
        return pd.read_csv(fp, sep=",", header=0)
    except Exception: # Fallback to whitespace-delimited
        return pd.read_csv(fp, delim_whitespace=True, header=0)

def _rename_dups(df, existing_cols, tag):
    new_cols = []
    for c in df.columns:
        if c in existing_cols and c != 'timestamp':
            new_cols.append(f"{c}{tag}")
        else:
            new_cols.append(c)
    df.columns = new_cols
    return df

def load_clnf(pid:int):
    """Load and merge the CLNF feature files of one participant.

    The pose, gaze, features3D, features and AUs files found in the
    participant folder are read, their column names stripped of surrounding
    blanks, and outer-joined on 'timestamp'. When a file has no 'timestamp'
    column but has 'frame', the timestamp is derived as ``frame / FPS``.
    Column names already present in the merged frame get a per-file suffix
    (e.g. ``_gaze``) to avoid clashes.

    Files that are missing are skipped. Files that offer neither a
    'timestamp' nor a 'frame' column are skipped with a warning, because they
    cannot be joined.

    Returns
    -------
    pandas.DataFrame
        Merged features sorted by timestamp, or an empty DataFrame when no
        usable file was found.
    """
    folder = f"{BASE_EXTR}/{pid}_P"
    # Named `clnf_names` so the module-level `files` import from google.colab is not shadowed.
    clnf_names = [
        f"{pid}_CLNF_pose.txt",
        f"{pid}_CLNF_gaze.txt",
        f"{pid}_CLNF_features3D.txt",
        f"{pid}_CLNF_features.txt",
        f"{pid}_CLNF_AUs.txt"
    ]
    merged = None
    for fn in clnf_names:
        fp = f"{folder}/{fn}"
        if not os.path.exists(fp):
            continue
        df = _safe_read(fp)
        df.columns = [c.strip() for c in df.columns]
        if "timestamp" not in df.columns:
            if "frame" not in df.columns:
                # Without a join key the merge below would raise KeyError and abort the run.
                warnings.warn(f"[CLNF] no timestamp/frame column in {fp}; file skipped")
                continue
            df["timestamp"] = df["frame"] / FPS
        if merged is None:
            merged = df
        else:
            # Suffix derived from the file name, e.g. "300_CLNF_gaze.txt" -> "_gaze".
            tag = f"_{fn.split('_')[-1].split('.')[0]}"
            df = _rename_dups(df, set(merged.columns), tag)
            # NOTE(review): the join key is a float; rows only align when the
            # timestamps are exactly equal across files — confirm for derived timestamps.
            merged = merged.merge(df, on="timestamp", how="outer")
    if merged is None:
        return pd.DataFrame()
    return merged.sort_values("timestamp").reset_index(drop=True)

def extract_visual(pid:int):
    """Cut one participant's CLNF features into merged, overlapping CSV segments.

    For every transcript row whose speaker is 'Participant', the feature rows
    with ``start_time <= timestamp <= stop_time`` are selected. Windows of
    ``N_MERGE`` consecutive selections, advancing by ``STEP``, are
    concatenated and written to ``TMP_VIS``.

    Returns the list of written file paths. An empty list is returned, with a
    warning (mirroring ``extract_audio``), when the transcript is missing or
    empty, or when no CLNF features could be loaded.
    """
    tr = f"{BASE_EXTR}/{pid}_P/{pid}_TRANSCRIPT.csv"
    if not os.path.exists(tr):
        warnings.warn(f"[VISUAL] missing transcript {pid}")
        return []
    df_t = load_transcript(tr)
    if df_t.empty:
        # An empty frame may lack the 'speaker' column; nothing to segment anyway.
        warnings.warn(f"[VISUAL] empty transcript {pid}")
        return []
    parts = df_t[df_t.speaker=='Participant']
    df_c = load_clnf(pid)
    if df_c.empty:
        warnings.warn(f"[VISUAL] missing CLNF features {pid}")
        return []

    # Look the column up once instead of once per utterance.
    timestamps = df_c["timestamp"]
    segs = [
        df_c[(timestamps >= r.start_time) & (timestamps <= r.stop_time)]
        for _, r in parts.iterrows()
    ]

    out = []
    for i in range(0, len(segs)-N_MERGE+1, STEP):
        merged = pd.concat(segs[i:i+N_MERGE], ignore_index=True)
        name   = f"{pid}_m5ov1_{i//STEP}.csv"
        path   = f"{TMP_VIS}/{name}"
        merged.to_csv(path, index=False)
        out.append(path)
    return out

# ------------------------------------------------------------
# 5) Load PHQ Labels
# ------------------------------------------------------------
# Map every participant id to its binary PHQ-8 label.
df_phq = pd.read_csv(PHQ_CSV)
label_map = {
    participant: phq_flag
    for participant, phq_flag in zip(df_phq.Participant_ID, df_phq.PHQ8_Binary)
}

# ------------------------------------------------------------
# 6) Generation Loop (writing to local disk)
# ------------------------------------------------------------
audio_all, vis_all, labels = [], [], []
for participant, phq_flag in label_map.items():
    audio_files  = extract_audio(participant)
    visual_files = extract_visual(participant)
    # zip() stops at the shorter list, so only segments that exist in BOTH
    # modalities are kept, paired by window index.
    for audio_file, visual_file in zip(audio_files, visual_files):
        audio_all.append(audio_file)
        vis_all.append(visual_file)
        labels.append(phq_flag)

# Save metadata CSV to the local temp directory
meta_csv = f"{TMP_ROOT}/dataset_info_all_ov1.csv"
segment_table = pd.DataFrame({
    "audio_path": audio_all,
    "visual_path": vis_all,
    "label": labels,
})
segment_table.to_csv(meta_csv, index=False)

print(f"✅ Local generation done: {len(audio_all)} segments")

# ------------------------------------------------------------
# 7) Zip -> Copy to Drive -> Download to Local
# ------------------------------------------------------------
# `datetime.utcnow()` is deprecated since Python 3.12; an aware UTC "now"
# formats to exactly the same timestamp string.
from datetime import timezone

ts       = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
zip_path = f"/content/m5ov1_segments_{ts}.zip"

# Archive everything below TMP_ROOT; entries are stored relative to TMP_ROOT
# (e.g. "audio/<name>.wav", "visual/<name>.csv", "dataset_info_all_ov1.csv").
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
    for root, _, files_in_dir in os.walk(TMP_ROOT):
        for fn in files_in_dir:
            abs_path = os.path.join(root, fn)
            rel_path = os.path.relpath(abs_path, TMP_ROOT)
            zf.write(abs_path, arcname=rel_path)

print(f"🎁 Zipped → {zip_path}")

# Keep a persistent copy on Drive under the same timestamped name.
drive_dest = f"/content/drive/MyDrive/DAIC-WOZ/m5ov1_segments_{ts}.zip"
shutil.copy(zip_path, drive_dest)
print(f"🚀 Copied to Drive → {drive_dest}")

# Download directly in Colab (be mindful of the file size)
files.download(zip_path)

In [None]:
# ============================================================
# One-shot script to merge o1pro prediction
# into the 5-merge / 1-overlap dataset.
# ============================================================
# Reads the segment archive written by the generation cell, joins the
# per-participant text predictions onto the segment table and saves the
# result to Drive as OUT_CSV.
!pip -q install pandas

from google.colab import drive
drive.mount('/content/drive')

import os, zipfile, re
import pandas as pd

# ---------- 0) Set paths ----------
# NOTE: ZIP_PATH contains a fixed timestamp; update it when a new archive is generated.
ZIP_PATH   = "/content/drive/MyDrive/DAIC-WOZ/m5ov1_segments_20250422_152451.zip"  # <- Path to the saved ZIP file
TEXT_PATH  = "/content/text-modality-result.csv"    # <- Path to the text prediction results (ID, o1pro-prediction)
WORK_DIR   = "/content/m5ov1_data"                  # Local extraction destination
OUT_CSV    = "/content/drive/MyDrive/DAIC-WOZ/dataset_info_all_text_ov1.csv"

# ---------- 1) Unzip the archive (only once) ----------
# The presence of the metadata CSV is used as the "already extracted" marker,
# so re-running the cell does not unzip again.
if not os.path.exists(f"{WORK_DIR}/dataset_info_all_ov1.csv"):
    os.makedirs(WORK_DIR, exist_ok=True)
    with zipfile.ZipFile(ZIP_PATH, 'r') as zf:
        zf.extractall(WORK_DIR)
    print(f"✔️  Unzipped into {WORK_DIR}")

# ---------- 2) Load the base CSV ----------
base_csv = f"{WORK_DIR}/dataset_info_all_ov1.csv"
df_base  = pd.read_csv(base_csv, low_memory=False)

# Derive 'participant_id' from the audio file name when the column is absent
if "participant_id" not in df_base.columns:
    def get_pid(path):
        """Return the leading integer of the file name (before the first '_'), or -1 if there is none."""
        match = re.match(r"(\d+)_", os.path.basename(path))
        if match is None:
            return -1
        return int(match.group(1))
    df_base["participant_id"] = df_base["audio_path"].apply(get_pid)

# ---------- 3) Load the text results ----------
# The 'ID' column is renamed so both tables share the 'participant_id' join key.
df_text = pd.read_csv(TEXT_PATH).rename(columns={"ID": "participant_id"})

# ---------- 4) Merge (left join) ----------
# Every segment row is kept; the text table contributes the 'o1pro-prediction' column.
# NOTE(review): if the text table holds several rows for one participant, a left
# join repeats that participant's segments — confirm IDs are unique.
df_merged = df_base.merge(df_text, on="participant_id", how="left")

# Segments without a prediction (NaN after the join) default to "not depressed"
df_merged["o1pro-prediction"] = df_merged["o1pro-prediction"].fillna("not depressed")

# ---------- 5) Save the result ----------
df_merged.to_csv(OUT_CSV, index=False)
print(f"✅ Saved {len(df_merged)} rows → {OUT_CSV}")

Mounted at /content/drive
✔️  Unzipped into /content/m5ov1_data
✅ Saved 7972 rows → /content/drive/MyDrive/DAIC-WOZ/dataset_info_all_text_ov1.csv


In [None]:
# ============================================================
# multimodal_daicwoz_v3_attnpool_dropout.py
# ------------------------------------------------------------
# Tri‑modal depression detector (audio + visual + text)
# This version introduces two changes that recent papers have
# proven effective for unbalanced multimodal settings where one
# modality dominates (DeepMLF, 2025; ECA‑MMDD, 2024; AVTF‑TBN, 2024):
#   1) AttentivePooling instead of plain average‑pool ⇒ lets the
#      network focus on depressive salient frames (turn‑level).
#   2) ModalityDropout ⇒ randomly masks each modality w.p.=p at
#      training time so that the network cannot over‑rely on text
#      alone and must learn complementary audio‑visual cues.
# The rest of the pipeline, hyper‑parameters, and CLI behaviour stay
# identical to v2 so you can drop‑in replace the script.
# ============================================================

"""
Prerequisites (same as v2):
  !pip install torch torchaudio transformers librosa scikit-learn pydub
  Upload the pre‑processed DAIC‑WOZ files and
  dataset_info_all_text_ov1.csv (columns: audio_path, visual_path,
  participant_id, label, o1pro-prediction)
  to Google Drive.
Run:
  python multimodal_daicwoz_v3_attnpool_dropout.py
"""

import os, gc, random, warnings
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2Model,
    Wav2Vec2Config,
)
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

# ------------------------------
# 0. Common utilities
# ------------------------------
SEED = 103


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy and PyTorch (all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Seed every generator once when the script/cell is executed.
set_seed(SEED)

# ------------------------------
# 1. Dataset & collate_fn
# ------------------------------
class AudioVisualTextDataset(Dataset):
    """Loads audio, OpenFace CSV, and text label.

    Each row of ``df`` must provide ``participant_id``, ``label``,
    ``audio_path`` and ``visual_path``; ``o1pro-prediction`` is optional and
    defaults to "not depressed".

    ``__getitem__`` returns ``(idx, pid, wav, sr, vis, text_label, label)``
    or ``None`` when the sample cannot be loaded (missing or unreadable file,
    unexpected visual dimension, empty audio/visual segment). ``None``
    entries are meant to be filtered out by the collate function.

    Parameters
    ----------
    df : pandas.DataFrame
        Segment table.
    target_sr : int
        Sample rate every waveform is resampled to.
    expected_vis_dim : int
        Number of numeric visual feature columns a segment must have
        (after dropping 'timestamp').
    verbose : bool
        When True, a warning is emitted for every skipped sample.
    """

    TEXT_MAP = {"depressed": 1.0, "not depressed": 0.0}

    def __init__(
        self,
        df: pd.DataFrame,
        target_sr: int = 16000,
        expected_vis_dim: int = 393,
        verbose: bool = False,
    ):
        self.df = df.reset_index(drop=True)
        self.target_sr = target_sr
        self.expected_vis_dim = expected_vis_dim
        self.verbose = verbose

    def __len__(self):
        return len(self.df)

    def _load_audio(self, audio_path):
        """Return ``(wav, sr)``: first channel only, resampled to ``target_sr``."""
        wav, sr = torchaudio.load(audio_path)
        if wav.size(0) > 1:
            wav = wav[:1]  # mono
        if sr != self.target_sr:
            wav = torchaudio.functional.resample(wav, sr, self.target_sr)
            sr = self.target_sr
        if wav.size(1) == 0:
            # A zero-length waveform cannot be encoded downstream.
            raise ValueError("empty audio segment")
        return wav, sr

    def _load_visual(self, visual_path):
        """Return the numeric visual features as a float32 tensor of shape (T, D).

        Non-finite values (NaN / +-inf) are replaced by 0 so that a single
        missing feature cannot turn the loss and the gradients into NaN.
        """
        df_v = pd.read_csv(visual_path)
        if "timestamp" in df_v.columns:
            df_v = df_v.drop(columns=["timestamp"])
        df_v = df_v.select_dtypes(include=[np.number])
        if df_v.shape[1] != self.expected_vis_dim:
            raise ValueError("visual dim mismatch")
        if df_v.shape[0] == 0:
            raise ValueError("empty visual segment")
        feats = np.nan_to_num(
            df_v.to_numpy(dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0
        )
        return torch.tensor(feats, dtype=torch.float32)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        pid = row["participant_id"]
        label = int(row["label"])
        audio_path = row["audio_path"]
        visual_path = row["visual_path"]
        # Unknown or missing prediction strings map to 0.0 ("not depressed").
        text_str = str(row.get("o1pro-prediction", "not depressed")).strip().lower()
        text_label = float(self.TEXT_MAP.get(text_str, 0.0))

        try:
            wav, sr = self._load_audio(audio_path)
            vis = self._load_visual(visual_path)
            return idx, pid, wav, sr, vis, text_label, label

        except Exception as e:
            # Deliberate best-effort: a broken sample is skipped, not fatal.
            if self.verbose:
                warnings.warn(f"skip {pid}: {e}")
            return None


def collate_fn(batch):
    """Collate dataset samples into padded batch tensors.

    ``None`` samples are dropped; if nothing remains, ``None`` is returned.

    Returns
    -------
    tuple
        ``(pids, wave_pad, sr_tensor, vis_pad, vis_len, txt_tensor, labels_t)``
        with ``wave_pad`` of shape (B, 1, T), ``vis_pad`` of shape
        (B, T_v, D), ``vis_len`` holding the unpadded visual lengths,
        ``txt_tensor`` float32 of shape (B,) and ``labels_t`` int64.
    """
    samples = [item for item in batch if item is not None]
    if not samples:
        return None

    # The leading dataset index of each sample is not needed downstream.
    _, pids, waves, srs, vis_list, txt_labels, labels = zip(*samples)

    # --- audio: drop the channel axis, zero-pad to the longest, restore the axis ---
    mono_waves = [wave.squeeze(0) for wave in waves]
    wave_pad = pad_sequence(mono_waves, batch_first=True).unsqueeze(1)  # (B, 1, T)
    sr_tensor = torch.tensor(srs)

    # --- visual: zero-pad along time and remember the true lengths ---
    vis_len = torch.tensor([frames.size(0) for frames in vis_list])
    vis_pad = pad_sequence(vis_list, batch_first=True)  # (B, T_v, D)

    # --- text label and ground-truth label ---
    txt_tensor = torch.tensor(txt_labels, dtype=torch.float32)  # (B,)
    labels_t = torch.tensor(labels, dtype=torch.long)

    return pids, wave_pad, sr_tensor, vis_pad, vis_len, txt_tensor, labels_t

# ------------------------------
# 2. Optional Focal Loss
# ------------------------------
class FocalLoss(nn.Module):
    def __init__(self, gamma: float = 2.0, weight: torch.Tensor | None = None):
        super().__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, weight=self.weight, reduction="none")
        pt = torch.exp(-ce)
        focal = ((1 - pt) ** self.gamma) * ce
        return focal.mean()

# ------------------------------
# 3. Model blocks: AttentivePool & ModalityDropout
# ------------------------------
class AttentivePool(nn.Module):
    """Attention‑weighted mean pooling over the time dimension."""

    def __init__(self, dim: int):
        super().__init__()
        self.query = nn.Parameter(torch.randn(dim))

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
        # x: (B, T, D)
        scores = (x * self.query).sum(-1)  # (B, T)
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)  # (B, T, 1)
        return (weights * x).sum(1)  # (B, D)

class ModalityDropout(nn.Module):
    """Randomly zeros out the given modality embedding with prob p.

    The decision is drawn once per forward call, so the whole input tensor
    (every sample of the batch) is either kept or zeroed together. In eval
    mode the input is returned unchanged.
    """

    def __init__(self, p: float = 0.3):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor):
        # Short-circuit keeps the random draw out of eval mode entirely.
        if self.training and torch.rand(1).item() < self.p:
            return torch.zeros_like(x)
        return x

# ------------------------------
# 4. Fusion model
# ------------------------------
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base", return_attention_mask=False)

def wave_to_input(wave_batch, sr_batch, target_sr=16000):
    """Convert a padded waveform batch into padded ``input_values`` for the audio encoder.

    Each waveform ``wave_batch[i, 0]`` is moved to the CPU, resampled to
    ``target_sr`` when its entry in ``sr_batch`` differs, passed through the
    module-level ``processor`` and the results are zero-padded to a common
    length.

    NOTE(review): the processor receives the already zero-padded waveform —
    confirm that any per-sample normalisation it applies is intended to
    include the padding.
    """
    def _prepare(sample, sample_sr):
        # `sample` is (1, T); index 0 selects the single channel.
        signal = sample[0].cpu().numpy()
        rate = sample_sr.item()
        if rate != target_sr:
            signal = torchaudio.functional.resample(torch.from_numpy(signal), rate, target_sr).numpy()
        return processor(signal, sampling_rate=target_sr, return_tensors="pt").input_values[0]

    prepared = [_prepare(wave_batch[k], sr_batch[k]) for k in range(wave_batch.size(0))]
    return pad_sequence(prepared, batch_first=True)

class CrossAttentionBlock(nn.Module):
    """Bidirectional cross-attention between an audio and a visual sequence.

    First the visual sequence queries the audio sequence and is updated
    (residual + LayerNorm); then the audio sequence queries the *updated*
    visual sequence and is updated the same way.
    """

    def __init__(self, dim: int, heads: int = 8, p: float = 0.1):
        super().__init__()
        self.av = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=p, batch_first=True)
        self.va = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=p, batch_first=True)
        self.n_a = nn.LayerNorm(dim)
        self.n_v = nn.LayerNorm(dim)

    def forward(self, A, V, mask_a=None, mask_v=None):
        """Return the updated ``(A, V)``; ``mask_a`` / ``mask_v`` are key-padding masks (True = ignore)."""
        # Visual queries attend over audio keys/values.
        audio_context = self.av(query=V, key=A, value=A, key_padding_mask=mask_a)[0]
        V = self.n_v(V + audio_context)
        # Audio queries attend over the refreshed visual keys/values.
        visual_context = self.va(query=A, key=V, value=V, key_padding_mask=mask_v)[0]
        A = self.n_a(A + visual_context)
        return A, V

class GatingUnit(nn.Module):
    """Learned convex combination of an audio vector and a visual vector.

    ``alpha`` has shape (2, dim); each row is averaged to one scalar and the
    two scalars are softmax-normalised, so the mixing weights are the same
    for every sample and every feature.
    """

    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(2, dim))

    def forward(self, a_vec, v_vec):
        mix = F.softmax(self.alpha.mean(dim=1), dim=0)  # (2,)
        w_audio, w_visual = mix.unbind(0)
        return w_audio * a_vec + w_visual * v_vec

class MultiModalModel(nn.Module):
    """Tri-modal classifier: wav2vec2 audio + BiLSTM visual + scalar text label.

    Pipeline: the audio encoder output is projected to ``hidden``; visual
    features pass through a 2-layer bidirectional LSTM (output size
    ``hidden``); both sequences exchange information in a cross-attention
    block, are pooled attentively and mixed by a gating unit. The text label
    (one float per sample) is embedded, subjected to modality dropout,
    concatenated with the audio-visual vector and classified into 2 classes.

    Parameters
    ----------
    audio_model_name : str
        Name of the pretrained wav2vec2 checkpoint.
    unfreeze_last_n : int
        Number of final encoder layers left trainable; all other audio
        encoder parameters are frozen.
    visual_dim : int
        Number of visual features per frame.
    hidden : int
        Common width of the audio and visual streams (must be even, since the
        BiLSTM uses ``hidden // 2`` units per direction).
    heads : int
        Number of attention heads in the cross-attention block.
    text_emb_dim : int
        Width of the text-label embedding.
    drop_p : float
        Probability used by the text modality dropout.
    """

    def __init__(
        self,
        audio_model_name="facebook/wav2vec2-base",
        unfreeze_last_n: int = 2,
        visual_dim: int = 393,
        hidden: int = 384,
        heads: int = 8,
        text_emb_dim: int = 128,
        drop_p: float = 0.3,
    ):
        super().__init__()
        # ---- audio encoder ----
        self.wav = Wav2Vec2Model.from_pretrained(audio_model_name)
        for p in self.wav.parameters():
            p.requires_grad = False
        if unfreeze_last_n > 0:
            for p in self.wav.encoder.layers[-unfreeze_last_n:].parameters():
                p.requires_grad = True
        # The loaded model already carries its config; no second download/parse needed.
        self.proj_a = nn.Linear(self.wav.config.hidden_size, hidden)

        # ---- visual encoder ----
        self.lstm = nn.LSTM(visual_dim, hidden // 2, num_layers=2,
                             bidirectional=True, batch_first=True)
        self.proj_v = nn.Identity()

        # ---- cross attention ----
        self.cross = CrossAttentionBlock(hidden, heads)
        self.pool_a = AttentivePool(hidden)
        self.pool_v = AttentivePool(hidden)

        # ---- gating ----
        self.gate = GatingUnit(hidden)

        # ---- text embedding ----
        self.text_emb = nn.Sequential(
            nn.Linear(1, text_emb_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(text_emb_dim, text_emb_dim),
        )
        self.text_md = ModalityDropout(drop_p)

        # ---- classifier ----
        self.cls = nn.Sequential(
            nn.LayerNorm(hidden + text_emb_dim),
            nn.Linear(hidden + text_emb_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 2),
        )

    def _encode_visual(self, vis, vis_mask):
        """Run the BiLSTM over ``vis`` (B, T_v, D) while ignoring padded frames.

        ``vis_mask`` is True on padded frames. Padding is assumed to be
        trailing (as produced by ``make_mask``), so the number of valid frames
        per sample is the count of False entries. Packing keeps the backward
        LSTM direction from starting inside the zero padding.
        """
        if vis_mask is None:
            out, _ = self.lstm(vis)
            return out
        # pack_padded_sequence needs CPU lengths >= 1.
        lengths = vis_mask.logical_not().sum(dim=1).clamp(min=1).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            vis, lengths, batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        # total_length restores the original T_v so the mask still lines up.
        out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=vis.size(1)
        )
        return out

    def forward(self, wav_inputs, vis, vis_mask, text_labels):
        """Return class logits of shape (B, 2).

        ``wav_inputs``: processed waveforms fed to the audio encoder;
        ``vis``: (B, T_v, visual_dim); ``vis_mask``: (B, T_v) bool, True on
        padded frames (or None); ``text_labels``: (B,) float.
        """
        # ---- audio ----
        # NOTE(review): no audio padding mask is used (mask_a=None, unmasked
        # pooling), so frames produced from zero-padded audio take part in
        # attention and pooling.
        A = self.wav(wav_inputs).last_hidden_state  # (B, T_a, H0)
        A = self.proj_a(A)

        # ---- visual ----
        V = self._encode_visual(vis, vis_mask)
        V = self.proj_v(V)

        # ---- cross attention ----
        A, V = self.cross(A, V, mask_a=None, mask_v=vis_mask)
        A_vec = self.pool_a(A)
        V_vec = self.pool_v(V, mask=vis_mask)
        av_fused = self.gate(A_vec, V_vec)  # (B, hidden)

        # ---- text embedding (with ModalityDropout) ----
        t_emb = self.text_emb(text_labels.unsqueeze(-1))  # (B, text_emb_dim)
        t_emb = self.text_md(t_emb)

        fused = torch.cat([av_fused, t_emb], dim=-1)
        return self.cls(fused)

# ------------------------------
# 5. Helper functions
# ------------------------------

def make_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a (len(lengths), max_len) bool mask that is True at positions ``>= length`` (padding)."""
    positions = torch.arange(max_len, device=lengths.device)
    # Broadcasting (1, max_len) against (B, 1) yields the (B, max_len) mask.
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)

# ------------------------------
# 6. Training / Evaluation
# ------------------------------

def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` at segment level and at participant level.

    Segment level: loss, accuracy, classification report and confusion matrix
    of predictions against the ground-truth labels, plus a report/matrix of
    predictions against the (rounded) text labels.

    Participant level: the logits of all segments of a participant are
    averaged and arg-maxed; the participant's label is that of its first
    segment, and its text label is the rounded mean of its segment text
    labels.

    All reports and confusion matrices are computed over the fixed label set
    ``[0, 1]``, so they stay well-defined (and the matrices stay 2x2) even
    when only one class occurs in a split.

    Returns
    -------
    tuple
        ``(avg_loss_seg, acc_seg, rpt_seg, cm_seg,
           acc_part, rpt_part, cm_part,
           rpt_text_seg, cm_text_seg,
           rpt_text_part, cm_text_part)``.
        When the loader yields no usable sample, zeros / empty strings /
        all-zero matrices are returned in the same layout.
    """
    class_ids = [0, 1]

    def _report(reference, predicted, names):
        # Fixing `labels` avoids the target_names/label-count mismatch error
        # when a class is absent; zero_division=0 silences undefined metrics.
        rpt = classification_report(reference, predicted, labels=class_ids,
                                    target_names=names, digits=4, zero_division=0)
        cm = confusion_matrix(reference, predicted, labels=class_ids).tolist()
        return rpt, cm

    model.eval()
    tot_loss, tot_samples, correct = 0.0, 0, 0

    y_true_seg, y_pred_seg, y_text_seg, pid_list = [], [], [], []
    part_logits, part_cnt, part_label, part_text_sum = {}, {}, {}, {}

    with torch.no_grad():
        for batch in loader:
            if batch is None:
                continue
            pids, wav, sr, vis, vlen, txt, y = batch
            wav, sr = wav.to(device), sr.to(device)
            vis, vlen = vis.to(device), vlen.to(device)
            txt, y = txt.to(device), y.to(device)
            txt_int = torch.round(txt).long()

            wav_in = wave_to_input(wav, sr).to(device)
            mask_v = make_mask(vlen, vis.size(1))
            logits = model(wav_in, vis, mask_v, txt)

            loss = criterion(logits, y)
            bs = y.size(0)
            tot_loss += loss.item() * bs
            tot_samples += bs

            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()

            y_true_seg.extend(y.cpu().tolist())
            y_pred_seg.extend(preds.cpu().tolist())
            y_text_seg.extend(txt_int.cpu().tolist())
            pid_list.extend(pids)

            for i in range(bs):
                pid = pids[i]
                if pid not in part_logits:
                    # clone(): on CPU `.cpu()` returns a view of `logits`, and the
                    # in-place `+=` below must not write into the batch tensor.
                    part_logits[pid] = logits[i].detach().cpu().clone()
                    part_cnt[pid] = 1
                    part_label[pid] = y[i].item()
                    part_text_sum[pid] = txt_int[i].item()
                else:
                    part_logits[pid] += logits[i].detach().cpu()
                    part_cnt[pid] += 1
                    part_text_sum[pid] += txt_int[i].item()

    if tot_samples == 0:
        return (0, 0, "", [[0,0],[0,0]], 0, "", [[0,0],[0,0]], "", [[0,0],[0,0]], "", [[0,0],[0,0]])

    # ---- segment level ----
    avg_loss_seg = tot_loss / tot_samples
    acc_seg = correct / tot_samples
    rpt_seg, cm_seg = _report(y_true_seg, y_pred_seg, ["not dep","dep"])
    rpt_text_seg, cm_text_seg = _report(y_text_seg, y_pred_seg, ["text=0","text=1"])

    # ---- participant level ----
    y_true_part, y_pred_part, y_text_part = [], [], []
    for pid, logit_sum in part_logits.items():
        avg_log = logit_sum / part_cnt[pid]
        y_pred_part.append(torch.argmax(avg_log).item())
        y_true_part.append(part_label[pid])
        txt_avg = part_text_sum[pid] / part_cnt[pid]
        y_text_part.append(int(round(txt_avg)))

    acc_part = (np.array(y_true_part) == np.array(y_pred_part)).mean()
    rpt_part, cm_part = _report(y_true_part, y_pred_part, ["not dep","dep"])
    rpt_text_part, cm_text_part = _report(y_text_part, y_pred_part, ["text=0","text=1"])

    return (
        avg_loss_seg, acc_seg, rpt_seg, cm_seg,
        acc_part, rpt_part, cm_part,
        rpt_text_seg, cm_text_seg,
        rpt_text_part, cm_text_part,
    )

# ------------------------------
# 7. Main
# ------------------------------
if __name__ == "__main__":
    CSV_PATH = "/content/drive/MyDrive/DAIC-WOZ/dataset_info_all_text_ov1.csv"
    df = pd.read_csv(CSV_PATH)

    # The CSV stores the paths used at generation time; point them at the
    # directory the archive was extracted into.
    df["audio_path"]  = df["audio_path"].str.replace(
        "/content/m5ov1_tmp", "/content/m5ov1_data", regex=False)
    df["visual_path"] = df["visual_path"].str.replace(
        "/content/m5ov1_tmp", "/content/m5ov1_data", regex=False)

    VIS_DIM = 393

    # Split by participant (60 / 20 / 20) so that segments of one participant
    # never appear in more than one split.
    ids = df["participant_id"].unique()
    train_ids, test_ids = train_test_split(ids, test_size=0.2, random_state=SEED)
    train_ids, dev_ids  = train_test_split(train_ids, test_size=0.25, random_state=SEED)

    train_df = df[df.participant_id.isin(train_ids)].copy()
    dev_df   = df[df.participant_id.isin(dev_ids)].copy()
    test_df  = df[df.participant_id.isin(test_ids)].copy()

    BS, EPOCHS, ACC_STEPS, LR = 2, 15, 16, 1e-5
    weight = torch.tensor([1.0, 2.5]); use_focal = False

    train_ds = AudioVisualTextDataset(train_df, expected_vis_dim=VIS_DIM)
    dev_ds   = AudioVisualTextDataset(dev_df,   expected_vis_dim=VIS_DIM)
    test_ds  = AudioVisualTextDataset(test_df,  expected_vis_dim=VIS_DIM)

    train_ld = DataLoader(train_ds, BS, shuffle=True,  collate_fn=collate_fn)
    dev_ld   = DataLoader(dev_ds,  BS, shuffle=False, collate_fn=collate_fn)
    test_ld  = DataLoader(test_ds, BS, shuffle=False, collate_fn=collate_fn)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model  = MultiModalModel(unfreeze_last_n=2, visual_dim=VIS_DIM).to(device)

    criterion = (FocalLoss(weight=weight.to(device)) if use_focal
                 else nn.CrossEntropyLoss(weight=weight.to(device)))
    opt = torch.optim.AdamW(model.parameters(), lr=LR)

    step = 0  # total number of processed micro-batches (across epochs)
    for epoch in range(EPOCHS):
        model.train()
        loss_sum, n = 0.0, 0
        opt.zero_grad()
        pending = 0  # micro-batches accumulated since the last optimizer step

        for batch in train_ld:
            if batch is None: continue
            pids, wav, sr, vis, vlen, txt, y = batch
            wav, sr = wav.to(device), sr.to(device)
            vis, vlen = vis.to(device), vlen.to(device)
            txt, y = txt.to(device), y.to(device)

            x = wave_to_input(wav, sr).to(device)
            out = model(x, vis, make_mask(vlen, vis.size(1)), txt)
            loss = criterion(out, y)
            # Only the gradient is scaled for accumulation; the logged loss
            # below stays on the same scale as the evaluation loss.
            (loss / ACC_STEPS).backward()

            loss_sum += loss.item() * y.size(0); n += y.size(0)
            pending += 1
            step += 1
            if pending == ACC_STEPS:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                opt.step(); opt.zero_grad()
                pending = 0

        # Apply the gradients of an incomplete final window instead of
        # discarding them at the next epoch's zero_grad().
        if pending:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            opt.step(); opt.zero_grad()

        print(f"[E{epoch+1}] train_loss = {loss_sum / n if n else 0:.4f} (samples={n})")

        # NOTE(review): the test split is evaluated after every epoch; do not
        # use these numbers for model/epoch selection.
        for name, loader in zip(["dev", "test"], [dev_ld, test_ld]):
            (l_seg, acc_seg, rpt_seg, cm_seg,
             acc_part, rpt_part, cm_part,
             rpt_text_seg, cm_text_seg,
             rpt_text_part, cm_text_part) = evaluate(model, loader, criterion, device)

            print(f"  === {name} (segment-level) ===")
            print(f"    loss={l_seg:.4f}, acc={acc_seg:.3f}")
            print("    GroundTruth vs Pred:");  print(rpt_seg)
            print(f"    CM:\n      {cm_seg}")
            print("    TextLabel vs Pred:");   print(rpt_text_seg)
            print(f"    CM:\n      {cm_text_seg}")

            print(f"  === {name} (participant-level) ===")
            print(f"    acc={acc_part:.3f}")
            print("    GroundTruth vs Pred:"); print(rpt_part)
            print(f"    CM:\n      {cm_part}")
            print("    TextLabel vs Pred:");   print(rpt_text_part)
            print(f"    CM:\n      {cm_text_part}")

        gc.collect(); torch.cuda.empty_cache()