<a href="https://colab.research.google.com/github/sohini4roy/MELD/blob/master/merc_lstm_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import math
import random
import argparse
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, accuracy_score


# ---------------------------
# Utilities
# ---------------------------

def set_seed(seed: int = 1337):
    """Seed every random number generator this script relies on.

    Seeds, in order, Python's ``random``, NumPy's global generator, PyTorch's
    default generator and PyTorch's CUDA generators.

    Args:
        seed: Value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def lengths_to_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
    """Turn per-sequence lengths into a boolean validity mask.

    Args:
        lengths: (B,) long tensor of sequence lengths.
        max_len: Width of the mask; defaults to the largest entry of ``lengths``.

    Returns:
        (B, T) bool tensor, True at valid timesteps and False at padding.
    """
    width = int(lengths.max().item()) if max_len is None else max_len
    positions = torch.arange(width, device=lengths.device)
    # Broadcasting (1, T) against (B, 1) yields the (B, T) mask.
    return positions[None, :] < lengths[:, None]


# ---------------------------
# Dataset
# ---------------------------

class JSONMERC(Dataset):
    """
    Conversation dataset backed by a JSON file shaped like:
    [
      {
        "utterances": [
          {
            "text_feat": [float,...],   # dim_txt
            "audio_feat": [float,...],  # dim_aud
            "vision_feat": [float,...], # dim_vis
            "speaker": "A",             # or int id
            "label": 0                  # int class
          },
          ...
        ]
      },
      ...
    ]

    Each item is one conversation. Feature dimensions are read from the first
    utterance of the first conversation; ``num_classes`` is the largest label
    found plus one.
    """
    def __init__(self, path: str, speaker2id: Optional[Dict[str, int]] = None):
        super().__init__()
        with open(path, "r", encoding="utf-8") as handle:
            self.data = json.load(handle)

        # A caller-supplied map is reused and extended in place.
        self.speaker2id = {} if speaker2id is None else speaker2id
        self._build_speaker_map()

        # Feature sizes come from the very first utterance.
        first_utt = self.data[0]["utterances"][0]
        self.dim_txt = len(first_utt["text_feat"])
        self.dim_aud = len(first_utt["audio_feat"])
        self.dim_vis = len(first_utt["vision_feat"])

        # Labels are assumed to be 0-based class ids.
        all_labels = [
            int(utt["label"])
            for conv in self.data
            for utt in conv["utterances"]
        ]
        self.num_classes = int(max(all_labels)) + 1

    def _build_speaker_map(self):
        """Assign the next free integer id to every speaker not yet in the map."""
        for conv in self.data:
            for utt in conv["utterances"]:
                # Speakers are keyed by their string form, so 1 and "1" coincide.
                self.speaker2id.setdefault(str(utt["speaker"]), len(self.speaker2id))

    def __len__(self):
        """Number of conversations."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return conversation ``idx`` as a dict of NumPy arrays plus its length."""
        utterances = self.data[idx]["utterances"]

        def stack_feature(field):
            # One float32 row per utterance -> (T, D)
            rows = [np.array(utt[field], dtype=np.float32) for utt in utterances]
            return np.stack(rows, axis=0)

        text = stack_feature("text_feat")
        audio = stack_feature("audio_feat")
        vision = stack_feature("vision_feat")
        speakers = np.array(
            [self.speaker2id[str(utt["speaker"])] for utt in utterances],
            dtype=np.int64,
        )  # (T,)
        labels = np.array([int(utt["label"]) for utt in utterances], dtype=np.int64)  # (T,)
        return {
            "text": text,
            "audio": audio,
            "vision": vision,
            "speaker": speakers,
            "labels": labels,
            "length": text.shape[0]
        }


class SyntheticMERC(Dataset):
    """
    In-memory dataset of randomly generated multimodal conversations, meant
    for smoke-testing the pipeline without real data.

    Every conversation has a random length in [min_len, max_len], Gaussian
    text/audio/vision features and random speaker ids. Labels come from
    bucketing a per-utterance score (built from the first 8 dims of each
    modality plus a speaker offset) into ``num_classes`` quantile bins
    computed within that conversation.
    """
    def __init__(self,
                 num_convs: int = 200,
                 max_len: int = 12,
                 min_len: int = 4,
                 dim_txt: int = 300,
                 dim_aud: int = 50,
                 dim_vis: int = 64,
                 num_speakers: int = 4,
                 num_classes: int = 7,
                 seed: int = 123):
        super().__init__()
        rng = np.random.RandomState(seed)
        quantile_levels = np.linspace(0, 1, num_classes + 1)
        speaker_center = num_speakers / 2.0

        conversations = []
        for _ in range(num_convs):
            # The order of the rng draws below defines the generated data.
            n_utt = rng.randint(min_len, max_len + 1)
            text = rng.normal(size=(n_utt, dim_txt)).astype(np.float32)
            audio = rng.normal(size=(n_utt, dim_aud)).astype(np.float32)
            vision = rng.normal(size=(n_utt, dim_vis)).astype(np.float32)
            speakers = rng.randint(0, num_speakers, size=(n_utt,), dtype=np.int64)

            # Weak dependence of the label on all modalities and the speaker.
            score = text[:, :8].sum(axis=1) + 0.5 * audio[:, :8].sum(axis=1)
            score = score + 0.3 * vision[:, :8].sum(axis=1)
            score = score + (speakers.astype(np.float32) - speaker_center)

            # Interior quantile edges split the scores into num_classes buckets.
            inner_edges = np.quantile(score, quantile_levels)[1:-1]
            classes = np.digitize(score, inner_edges).astype(np.int64)

            conversations.append({
                "text": text, "audio": audio, "vision": vision,
                "speaker": speakers, "labels": classes, "length": n_utt
            })

        self.samples = conversations
        self.dim_txt = dim_txt
        self.dim_aud = dim_aud
        self.dim_vis = dim_vis
        self.num_classes = num_classes
        self.num_speakers = num_speakers

    def __len__(self):
        """Number of generated conversations."""
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return the stored sample dict for conversation ``idx``."""
        return self.samples[idx]


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Pad a list of conversation samples into one batch of tensors.

    Args:
        batch: Samples with "text", "audio", "vision", "speaker", "labels"
            and "length" entries; feature arrays are (T_i, D).

    Returns:
        Dict with "text"/"audio"/"vision" as (B, T, D) float32 tensors,
        "speaker" and "labels" as (B, T) long tensors, "lengths" as (B,)
        and "mask" as (B, T) bool, where T is the longest length in the
        batch. Padded feature and speaker slots are 0; padded labels are -100.
    """
    def as_tensor(value):
        # NumPy arrays are wrapped without copying; anything else is converted.
        if isinstance(value, np.ndarray):
            return torch.from_numpy(value)
        return torch.tensor(value)

    n_items = len(batch)
    lengths = torch.tensor([item["length"] for item in batch], dtype=torch.long)
    max_t = int(lengths.max().item())

    # Feature widths are taken from the first sample.
    first = batch[0]
    out = {
        name: torch.zeros(n_items, max_t, first[name].shape[1], dtype=torch.float32)
        for name in ("text", "audio", "vision")
    }
    out["speaker"] = torch.zeros(n_items, max_t, dtype=torch.long)
    # -100 is the ignore index used by the loss for padded positions.
    out["labels"] = torch.full((n_items, max_t), fill_value=-100, dtype=torch.long)

    for row, item in enumerate(batch):
        n_valid = item["length"]
        for name in ("text", "audio", "vision", "speaker", "labels"):
            out[name][row, :n_valid] = as_tensor(item[name])

    out["lengths"] = lengths
    out["mask"] = lengths_to_mask(lengths, max_t)  # (B, T)
    return out


# ---------------------------
# Model
# ---------------------------

class MLPEncoder(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Dropout -> Linear -> ReLU.

    Maps the last dimension of its input from ``in_dim`` to ``out_dim``.
    """
    def __init__(self, in_dim: int, out_dim: int, p_drop: float = 0.1):
        super().__init__()
        stages = [
            nn.Linear(in_dim, out_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p_drop),
            nn.Linear(out_dim, out_dim),
            nn.ReLU(inplace=True),
        ]
        # Positions inside the Sequential define the state-dict keys (net.0, net.3).
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x`` of shape (..., in_dim) into (..., out_dim)."""
        encoded = self.net(x)
        return encoded


class GatedFusion(nn.Module):
    """Fuse text, audio and vision features with one learned gate per modality.

    Each modality is encoded to ``d_out`` dims; a small network looks at the
    concatenated encodings and emits three sigmoid gates (one scalar per
    modality and timestep). The gated encodings are concatenated and
    projected back to ``d_out``.
    """
    def __init__(self, d_txt, d_aud, d_vis, d_out):
        super().__init__()
        self.txt_enc = MLPEncoder(d_txt, d_out)
        self.aud_enc = MLPEncoder(d_aud, d_out)
        self.vis_enc = MLPEncoder(d_vis, d_out)

        d_cat = d_out * 3
        gate_layers = [
            nn.Linear(d_cat, d_cat),
            nn.ReLU(inplace=True),
            nn.Linear(d_cat, 3),
            nn.Sigmoid(),
        ]
        self.gate = nn.Sequential(*gate_layers)

        self.proj = nn.Linear(d_cat, d_out)

    def forward(self, txt, aud, vis):
        """Return the fused representation, shape (B, T, d_out)."""
        # Per-modality encodings, each (B, T, d_out)
        encoded = (self.txt_enc(txt), self.aud_enc(aud), self.vis_enc(vis))
        # (B, T, 3) gates, split into three (B, T, 1) slices
        gates = self.gate(torch.cat(encoded, dim=-1)).split(1, dim=-1)
        weighted = [gate * feat for gate, feat in zip(gates, encoded)]
        return self.proj(torch.cat(weighted, dim=-1))


class AdditiveAttention(nn.Module):
    """Additive (Bahdanau-style) attention of every timestep over the sequence."""
    def __init__(self, d_ctx: int, d_attn: int):
        super().__init__()
        self.fc_h = nn.Linear(d_ctx, d_attn, bias=False)  # key projection
        self.fc_q = nn.Linear(d_ctx, d_attn, bias=False)  # query projection
        self.v = nn.Linear(d_attn, 1, bias=False)         # scoring vector

    def forward(self, H: torch.Tensor, q: torch.Tensor, mask: torch.Tensor):
        """
        H:    (B, T, d_ctx)  sequence attended over (keys and values)
        q:    (B, T, d_ctx)  one query per timestep
        mask: (B, T)         True where a position is valid
        returns: (B, T, d_ctx) context vector for every query position
        """
        keys = self.fc_h(H)[:, None, :, :]       # (B, 1, T, d_attn)
        queries = self.fc_q(q)[:, :, None, :]    # (B, T, 1, d_attn)
        # energy[b, i, j]: score of query position i against key position j
        energy = self.v(torch.tanh(keys + queries)).squeeze(-1)   # (B, T, T)
        # Padded keys get a large negative score so softmax assigns them ~0.
        padded_keys = ~mask[:, None, :].expand_as(energy)         # (B, T, T)
        weights = energy.masked_fill(padded_keys, -1e9).softmax(dim=-1)
        return torch.bmm(weights, H)                              # (B, T, d_ctx)


class MERCLSTM(nn.Module):
    """Multimodal (Bi)LSTM tagger that predicts one class per utterance.

    Pipeline: gated fusion of text/audio/vision features -> concatenation
    with a speaker embedding -> (Bi)LSTM over the conversation -> additive
    self-attention over the LSTM states -> MLP classifier on
    [state ; attended context].

    Args:
        dim_txt, dim_aud, dim_vis: Input feature sizes per modality.
        num_speakers: Size of the speaker embedding table.
        num_classes: Number of output classes.
        d_model: Fused feature size; also the LSTM output size when even.
        d_attn: Hidden size of the additive attention.
        d_spk: Speaker embedding size.
        lstm_layers: Number of stacked LSTM layers.
        dropout: Dropout in the classifier, and between LSTM layers when
            ``lstm_layers > 1``.
        bidirectional: Use a BiLSTM with ``d_model // 2`` units per direction.
    """
    def __init__(self,
                 dim_txt: int,
                 dim_aud: int,
                 dim_vis: int,
                 num_speakers: int,
                 num_classes: int,
                 d_model: int = 256,
                 d_attn: int = 128,
                 d_spk: int = 32,
                 lstm_layers: int = 1,
                 dropout: float = 0.2,
                 bidirectional: bool = True):
        super().__init__()

        self.fusion = GatedFusion(dim_txt, dim_aud, dim_vis, d_model)

        self.spk_emb = nn.Embedding(num_speakers, d_spk)

        self.lstm = nn.LSTM(
            input_size=d_model + d_spk,
            hidden_size=d_model // 2 if bidirectional else d_model,
            num_layers=lstm_layers,
            batch_first=True,
            # Inter-layer dropout only applies to stacked LSTMs.
            dropout=dropout if lstm_layers > 1 else 0.0,
            bidirectional=bidirectional
        )
        # Actual LSTM output width (differs from d_model when d_model is odd
        # and the LSTM is bidirectional).
        d_ctx = (d_model // 2) * 2 if bidirectional else d_model

        self.attn = AdditiveAttention(d_ctx, d_attn)
        self.classifier = nn.Sequential(
            nn.Linear(d_ctx * 2, d_ctx),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_ctx, num_classes)
        )

    def forward(self, text, audio, vision, speaker, mask):
        """
        text/audio/vision: (B, T, D*)
        speaker: (B, T) long
        mask: (B, T) bool, True at valid utterances
        returns logits: (B, T, C)
        """
        fused = self.fusion(text, audio, vision)  # (B, T, d_model)
        spk = self.spk_emb(speaker)               # (B, T, d_spk)
        x = torch.cat([fused, spk], dim=-1)       # (B, T, d_model+d_spk)

        # pack_padded_sequence needs the lengths on the CPU.
        lengths = mask.sum(dim=1).long().cpu()
        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # total_length keeps the time axis equal to the mask width even when
        # the batch is padded beyond its longest sequence; without it the
        # output would be truncated to max(lengths) and no longer line up
        # with `mask` in the attention below.
        H, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=mask.size(1)
        )  # (B, T, d_ctx)

        # Attention with H as both keys/values and query
        ctx = self.attn(H, H, mask)               # (B, T, d_ctx)
        out = torch.cat([H, ctx], dim=-1)         # (B, T, 2*d_ctx)
        logits = self.classifier(out)             # (B, T, C)

        # Logits at padded positions are returned as-is; callers are expected
        # to exclude them (e.g. via the loss ignore index or the mask).
        return logits


# ---------------------------
# Training / Evaluation
# ---------------------------

def masked_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, ignore_index: int = -100):
    """Mean cross-entropy over all positions whose target is not ``ignore_index``.

    Args:
        logits: (B, T, C) unnormalised class scores.
        targets: (B, T) long class ids; entries equal to ``ignore_index``
            (padding) are skipped.
        ignore_index: Target value excluded from the loss.

    Returns:
        Scalar loss tensor.
    """
    num_classes = logits.size(-1)
    # reshape (not view) so non-contiguous inputs, e.g. transposed logits,
    # are accepted instead of raising.
    flat_logits = logits.reshape(-1, num_classes)
    flat_targets = targets.reshape(-1)
    return nn.functional.cross_entropy(flat_logits, flat_targets, ignore_index=ignore_index)


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Dict[str, float]:
    """Score ``model`` on every valid (non-padded) utterance yielded by ``loader``.

    Puts the model in eval mode, runs it batch by batch without gradients and
    compares arg-max predictions with the gold labels at masked-in positions.

    Returns:
        {"accuracy": ..., "f1_macro": ...} computed over all valid positions.
    """
    model.eval()
    gold_chunks, pred_chunks = [], []
    for batch in loader:
        on_device = {
            name: batch[name].to(device)
            for name in ("text", "audio", "vision", "speaker", "labels", "mask")
        }
        logits = model(
            on_device["text"], on_device["audio"], on_device["vision"],
            on_device["speaker"], on_device["mask"],
        )  # (B, T, C)

        # Keep only real utterances; boolean indexing flattens in row-major order.
        keep = on_device["mask"].bool()
        pred_chunks.append(logits.argmax(dim=-1)[keep].cpu().numpy())
        gold_chunks.append(on_device["labels"][keep].cpu().numpy())

    y_true = np.concatenate(gold_chunks, axis=0)
    y_pred = np.concatenate(pred_chunks, axis=0)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, average="macro"),
    }


def train_loop(model: nn.Module,
               train_loader: DataLoader,
               valid_loader: Optional[DataLoader],
               device: torch.device,
               epochs: int = 20,
               lr: float = 1e-3,
               weight_decay: float = 1e-4,
               grad_clip: float = 1.0,
               ckpt_path: Optional[str] = None):
    """Train ``model`` with AdamW and print one summary line per epoch.

    Args:
        model: Network called as ``model(text, audio, vision, speaker, mask)``.
        train_loader: Yields batches produced by the collate function.
        valid_loader: Optional loader evaluated after each epoch.
        device: Device the batch tensors are moved to.
        epochs: Number of passes over ``train_loader``.
        lr, weight_decay: AdamW hyper-parameters.
        grad_clip: Max gradient norm; ``None`` disables clipping.
        ckpt_path: If set, the model is saved here whenever the validation
            macro-F1 improves.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_f1 = -1.0
    batch_fields = ("text", "audio", "vision", "speaker", "labels", "mask")

    for epoch in range(1, epochs + 1):
        model.train()
        loss_total = 0.0
        step_count = 0
        for batch in train_loader:
            text, audio, vision, speaker, labels, mask = (
                batch[name].to(device) for name in batch_fields
            )

            loss = masked_cross_entropy(model(text, audio, vision, speaker, mask), labels)

            optimizer.zero_grad()
            loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

            loss_total += loss.item()
            step_count += 1

        # max(..., 1) guards the average against an empty training loader.
        report = f"[Epoch {epoch:02d}] Train Loss: {loss_total / max(step_count, 1):.4f}"

        if valid_loader is not None:
            scores = evaluate(model, valid_loader, device)
            val_acc, val_f1 = scores["accuracy"], scores["f1_macro"]
            report += f" | Val Acc: {val_acc:.4f} | Val F1m: {val_f1:.4f}"
            # Checkpoint only on a strict macro-F1 improvement.
            if ckpt_path and val_f1 > best_f1:
                best_f1 = val_f1
                checkpoint = {
                    "model_state": model.state_dict(),
                    "config": {
                        "num_classes": model.classifier[-1].out_features
                    }
                }
                torch.save(checkpoint, ckpt_path)
                report += "  [*saved*]"
        print(report)


# ---------------------------
# Main / CLI
# ---------------------------

def build_loaders(args) -> Tuple[DataLoader, Optional[DataLoader], Dict[str, int], Dict[str, int]]:
    """Create train/validation loaders from a JSON dataset or synthetic data.

    If ``args.data_json`` names an existing file it is loaded with
    ``JSONMERC``; otherwise a ``SyntheticMERC`` dataset is generated from the
    ``args.syn_*`` settings. A non-empty ``args.data_json`` that is not a file
    still falls back to synthetic data, but a warning is emitted so the
    substitution does not go unnoticed. Conversations are shuffled with the
    ``random`` module and split 80/20.

    Args:
        args: Parsed CLI namespace (``data_json``, ``batch_size``, ``seed``
            and the ``syn_*`` options are read).

    Returns:
        (train_loader, val_loader, dims, meta) where ``dims`` holds
        "dim_txt"/"dim_aud"/"dim_vis" and ``meta`` holds
        "num_speakers"/"num_classes".
    """
    import warnings  # local import keeps this block self-contained

    use_json = bool(args.data_json) and os.path.isfile(args.data_json)
    if args.data_json and not use_json:
        warnings.warn(
            f"--data_json path {args.data_json!r} is not a file; "
            "falling back to synthetic data.",
            RuntimeWarning,
            stacklevel=2,
        )

    if use_json:
        dataset = JSONMERC(args.data_json)
        num_speakers = len(dataset.speaker2id)
    else:
        dataset = SyntheticMERC(
            num_convs=args.syn_num_convs,
            max_len=args.syn_max_len,
            min_len=args.syn_min_len,
            dim_txt=args.syn_dim_txt,
            dim_aud=args.syn_dim_aud,
            dim_vis=args.syn_dim_vis,
            num_speakers=args.syn_num_speakers,
            num_classes=args.syn_num_classes,
            seed=args.seed
        )
        num_speakers = dataset.num_speakers

    # Shared 80/20 split at conversation level.
    idxs = list(range(len(dataset)))
    random.shuffle(idxs)
    cut = int(0.8 * len(dataset))
    train_data = torch.utils.data.Subset(dataset, idxs[:cut])
    val_data = torch.utils.data.Subset(dataset, idxs[cut:])

    train_loader = DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_data, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn
    )

    dims = {"dim_txt": dataset.dim_txt, "dim_aud": dataset.dim_aud, "dim_vis": dataset.dim_vis}
    meta = {"num_speakers": num_speakers, "num_classes": dataset.num_classes}
    return train_loader, val_loader, dims, meta


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI parser with training, synthetic-data and model options."""
    parser = argparse.ArgumentParser(description="Multimodal LSTM for Emotion Recognition in Conversation")
    parser.add_argument("--data_json", type=str, default="", help="Path to JSON dataset. If empty, uses synthetic.")
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--weight_decay", type=float, default=1e-4)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=1337)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--ckpt", type=str, default="best_merc_lstm.pt")

    # Synthetic data settings (used if --data_json not provided)
    parser.add_argument("--syn_num_convs", type=int, default=200)
    parser.add_argument("--syn_min_len", type=int, default=4)
    parser.add_argument("--syn_max_len", type=int, default=12)
    parser.add_argument("--syn_dim_txt", type=int, default=300)
    parser.add_argument("--syn_dim_aud", type=int, default=50)
    parser.add_argument("--syn_dim_vis", type=int, default=64)
    parser.add_argument("--syn_num_speakers", type=int, default=4)
    parser.add_argument("--syn_num_classes", type=int, default=7)

    # Model hyperparams
    parser.add_argument("--d_model", type=int, default=256)
    parser.add_argument("--d_attn", type=int, default=128)
    parser.add_argument("--d_spk", type=int, default=32)
    parser.add_argument("--lstm_layers", type=int, default=1)
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--bidirectional", action="store_true", help="Use BiLSTM")
    parser.add_argument("--no-bidirectional", dest="bidirectional", action="store_false")
    parser.set_defaults(bidirectional=True)
    return parser


def main():
    """Parse options, build data and model, train, then report validation metrics.

    After training, the checkpoint at ``--ckpt`` is reloaded if that file
    exists, and the model is evaluated once more on the validation split.
    """
    parser = _build_arg_parser()
    # parse_known_args instead of parse_args: notebook kernels (Jupyter/Colab)
    # put their own `-f <connection-file>` into sys.argv, which would make
    # parse_args exit with "unrecognized arguments".
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Ignoring unrecognized arguments: {unknown}")
    set_seed(args.seed)

    train_loader, val_loader, dims, meta = build_loaders(args)
    device = torch.device(args.device)
    print(f"Using device: {device}")

    model = MERCLSTM(
        dim_txt=dims["dim_txt"], dim_aud=dims["dim_aud"], dim_vis=dims["dim_vis"],
        num_speakers=meta["num_speakers"], num_classes=meta["num_classes"],
        d_model=args.d_model, d_attn=args.d_attn, d_spk=args.d_spk,
        lstm_layers=args.lstm_layers, dropout=args.dropout,
        bidirectional=args.bidirectional
    ).to(device)

    print(model)
    train_loop(model, train_loader, val_loader, device,
               epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay,
               grad_clip=args.grad_clip, ckpt_path=args.ckpt)

    # Final evaluation on validation set with best checkpoint if available
    if os.path.isfile(args.ckpt):
        print(f"Loading best checkpoint: {args.ckpt}")
        ckpt = torch.load(args.ckpt, map_location=device)
        model.load_state_dict(ckpt["model_state"])
    metrics = evaluate(model, val_loader, device)
    print(f"[Final] Val Acc: {metrics['accuracy']:.4f} | Val F1m: {metrics['f1_macro']:.4f}")


# Run the training entry point only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
