In [None]:
# cmu_cmt.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossModalTransformerBlock(nn.Module):
    """Post-norm cross-attention block: one modality (query) attends to another (key/value).

    The block applies multi-head cross-attention followed by a position-wise
    feed-forward network, each wrapped in a residual connection and a
    ``LayerNorm``.

    Args:
        dim_q: Feature width of the query sequence; also the output width.
        dim_kv: Feature width of the key/value sequence. May differ from
            ``dim_q``; when equal, the attention layer keeps its packed
            ``in_proj_weight`` parameter layout.
        n_heads: Number of attention heads (``dim_q`` must be divisible by it).
        dropout: Dropout probability used inside attention and the FFN.
    """

    def __init__(self, dim_q, dim_kv, n_heads, dropout):
        super().__init__()
        # kdim/vdim make the key/value projections accept dim_kv-wide inputs;
        # without them the layer silently assumes dim_kv == dim_q.
        self.attn = nn.MultiheadAttention(
            embed_dim=dim_q,
            num_heads=n_heads,
            dropout=dropout,
            kdim=dim_kv,
            vdim=dim_kv,
            batch_first=True,
        )
        self.norm1 = nn.LayerNorm(dim_q)
        self.ffn = nn.Sequential(
            nn.Linear(dim_q, dim_q * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_q * 4, dim_q)
        )
        self.norm2 = nn.LayerNorm(dim_q)

    def forward(self, q, kv, kv_mask=None):
        """Run cross-attention from ``q`` onto ``kv``.

        Args:
            q: Query sequence, ``(batch, len_q, dim_q)`` (batch-first).
            kv: Key/value sequence, ``(batch, len_kv, dim_kv)``.
            kv_mask: Optional key padding mask of shape ``(batch, len_kv)``,
                forwarded as ``key_padding_mask`` (for a bool mask, ``True``
                marks positions to ignore).

        Returns:
            Tuple ``(out, attn_weights)`` where ``out`` is
            ``(batch, len_q, dim_q)`` and ``attn_weights`` are the attention
            weights returned by ``nn.MultiheadAttention`` (head-averaged by
            default in recent PyTorch releases), ``(batch, len_q, len_kv)``.
        """
        attn_output, attn_weights = self.attn(q, kv, kv, key_padding_mask=kv_mask, need_weights=True)
        # Residual + norm around attention, then around the feed-forward net.
        x = self.norm1(q + attn_output)
        ffn_output = self.ffn(x)
        out = self.norm2(x + ffn_output)
        return out, attn_weights


class CrossModalTransformer(nn.Module):
    """Fuse text, audio and visual sequences through pairwise cross-modal attention.

    Each modality is projected to ``dim_model`` by a small 2-layer MLP. Every
    ordered pair of distinct modalities gets its own
    ``CrossModalTransformerBlock`` (query modality attends to the other one).
    The six block outputs are mean-pooled over time, concatenated, layer
    normalised and fed to an MLP classifier.

    Args:
        dim_text: Input feature width of the text sequence.
        dim_audio: Input feature width of the audio sequence.
        dim_visual: Input feature width of the visual sequence.
        dim_model: Shared width used by all cross-modal blocks.
        n_heads: Number of attention heads per block.
        dropout: Dropout probability for the blocks and the classifier.
        n_classes: Number of classifier outputs. Defaults to 7 (the original
            code targeted MOSEI 7-way emotion soft labels, noting that the
            regression task can be recast as classification).
    """

    # (query modality, key/value modality). The order is load-bearing: it fixes
    # the ModuleDict registration order and the concatenation order that
    # ``fusion_norm`` and ``classifier`` were sized and trained against.
    _PAIRS = (
        ('text', 'audio'),
        ('text', 'visual'),
        ('audio', 'text'),
        ('audio', 'visual'),
        ('visual', 'text'),
        ('visual', 'audio'),
    )

    def __init__(self, dim_text=300, dim_audio=74, dim_visual=35, dim_model=128, n_heads=4, dropout=0.1,
                 n_classes=7):
        super().__init__()

        # Per-modality projection: Linear -> 2-layer MLP into the shared width.
        self.text_proj = nn.Sequential(
            nn.Linear(dim_text, 256),
            nn.ReLU(),
            nn.Linear(256, dim_model)
        )
        self.audio_proj = nn.Sequential(
            nn.Linear(dim_audio, 128),
            nn.ReLU(),
            nn.Linear(128, dim_model)
        )
        self.visual_proj = nn.Sequential(
            nn.Linear(dim_visual, 64),
            nn.ReLU(),
            nn.Linear(64, dim_model)
        )

        # One block per ordered modality pair, keyed '<query>_<keyvalue>'.
        self.cross_blocks = nn.ModuleDict({
            f'{q_name}_{kv_name}': CrossModalTransformerBlock(dim_model, dim_model, n_heads, dropout)
            for q_name, kv_name in self._PAIRS
        })

        fused_dim = dim_model * len(self._PAIRS)
        self.fusion_norm = nn.LayerNorm(fused_dim)
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, n_classes)
        )

    @staticmethod
    def _masked_mean(x, pad_mask):
        """Mean of ``x`` over dim 1, skipping positions flagged in ``pad_mask``."""
        if pad_mask is None:
            return x.mean(1)
        keep = pad_mask.logical_not().unsqueeze(-1).to(x.dtype)
        # clamp avoids 0/0 for a sequence that is entirely padding.
        return (x * keep).sum(1) / keep.sum(1).clamp(min=1.0)

    def forward(self, text, audio, visual, text_mask=None, audio_mask=None, visual_mask=None):
        """Classify a batch of aligned or unaligned multimodal sequences.

        Args:
            text: ``(batch, len_text, dim_text)`` tensor.
            audio: ``(batch, len_audio, dim_audio)`` tensor.
            visual: ``(batch, len_visual, dim_visual)`` tensor.
            text_mask: Optional bool padding mask ``(batch, len_text)``;
                ``True`` marks padded timesteps. ``None`` (default) treats
                every timestep as valid, matching the previous behaviour.
            audio_mask: Same as ``text_mask`` for the audio sequence.
            visual_mask: Same as ``text_mask`` for the visual sequence.

        Returns:
            Tuple ``(logits, log)``: ``logits`` is ``(batch, n_classes)`` and
            ``log`` maps ``'<query>-><keyvalue>'`` to the scalar mean (a Python
            float) of that block's attention weights.

        Note:
            A sample whose key/value sequence is fully masked will yield NaN
            attention output; callers should avoid all-padding sequences.
        """
        feats = {
            'text': self.text_proj(text),
            'audio': self.audio_proj(audio),
            'visual': self.visual_proj(visual),
        }
        masks = {'text': text_mask, 'audio': audio_mask, 'visual': visual_mask}

        pooled = []
        log = {}
        for q_name, kv_name in self._PAIRS:
            block = self.cross_blocks[f'{q_name}_{kv_name}']
            # Keys/values are masked by the kv modality's padding mask ...
            out, weights = block(feats[q_name], feats[kv_name], kv_mask=masks[kv_name])
            # ... while pooling skips padded *query* timesteps.
            pooled.append(self._masked_mean(out, masks[q_name]))
            # NOTE(review): .item() forces a device sync on every call, and as
            # attention rows sum to 1 (before dropout) this mean is roughly
            # 1 / len_kv -- consider a more informative diagnostic.
            log[f'{q_name}->{kv_name}'] = weights.mean().item()

        fusion = self.fusion_norm(torch.cat(pooled, dim=-1))
        logits = self.classifier(fusion)

        return logits, log