# In [None]:
# fused_eeg_vit_co_processing.py
# -------------------------------------------------------------
# EEG (63x250) ↔ CLIP ViT-L/14 multi-stream co-processing
# - Keeps the original CLIP ViT embedding frozen/unchanged for contrast
# - Maps EEG → ViT token space (D=1024, L=256) via ATMS-style ts/sp conv + adapter
# - Preserves your Medformer (as BrainEncoder) before the adapter
# - Inserts bi-directional Cross-Attn bridges at several ViT blocks
# - Outputs a gated fused embedding comparable to the frozen ViT embedding
# -------------------------------------------------------------

from typing import Iterable, List, Optional, Tuple

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import CLIPVisionModel

# ====== Import your BrainEncoder bits (from ATMS.py) ======
# Assumes ATMS.py is importable in PYTHONPATH; otherwise use relative import or sys.path.insert.
from ATMS import Medformer, Config  # noqa

# ==========================================================
#                      LoRA helpers
# ==========================================================
class LoRALinear(nn.Module):
    """Linear layer with an optional low-rank (LoRA) residual branch.

    Computes ``main(x) + scaling * B(A(x))`` with ``A: in -> r`` and ``B: r -> out``.
    ``B`` is zero-initialised, so at construction the layer behaves exactly like
    ``main``. With ``r <= 0`` the low-rank branch is disabled and the layer is just
    ``main``.

    Args:
        in_features: input width of the wrapped ``nn.Linear``.
        out_features: output width of the wrapped ``nn.Linear``.
        r: rank of the low-rank branch (``<= 0`` disables it).
        alpha: LoRA scaling numerator; the branch is scaled by ``alpha / r``.
        bias: whether ``main`` has a bias term (``A``/``B`` never do).
        freeze_main: if True, ``main``'s parameters get ``requires_grad=False``.
            Note that ``main`` is freshly (randomly) initialised here, not copied
            from a pretrained layer.
    """

    def __init__(self, in_features, out_features, r: int = 8, alpha: int = 16, bias: bool = False, freeze_main: bool = True):
        super().__init__()
        self.main = nn.Linear(in_features, out_features, bias=bias)
        if freeze_main:
            self.main.requires_grad_(False)

        self.r = r
        use_lora = r > 0
        # Construction order (main, A, B, then the re-inits) is kept fixed so seeded runs are reproducible.
        self.A = nn.Linear(in_features, r, bias=False) if use_lora else None
        self.B = nn.Linear(r, out_features, bias=False) if use_lora else None
        if use_lora:
            nn.init.kaiming_uniform_(self.A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.B.weight)  # zero B => the LoRA delta starts at exactly 0
        self.scaling = alpha / r if use_lora else 0.0

    def forward(self, x):
        base = self.main(x)
        if self.r <= 0:
            return base
        return base + self.B(self.A(x)) * self.scaling


class CrossAttentionAdapter(nn.Module):
    """One direction of a cross-attention exchange.

    Queries come from stream A (``q_tokens``, layer-normalised first) and
    keys/values come from stream B (``kv_tokens``, used as-is). The attended
    context is projected back to ``d_model`` and added onto ``q_tokens`` through
    a learnable ``tanh(gate)`` scalar. The gate starts at 0, so the module is an
    identity map on ``q_tokens`` at initialisation. ``CoProcessingBridge``
    composes two of these to get a bi-directional exchange.
    """

    def __init__(self, d_model: int, n_heads: int, d_kv: Optional[int] = None,
                 lora_r: int = 8, lora_alpha: int = 16, dropout: float = 0.0):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        # Per-head width; falls back to an even split of d_model across heads.
        self.d_kv = d_kv or (d_model // n_heads)
        self.scale = self.d_kv ** -0.5
        inner = n_heads * self.d_kv

        lora_kw = dict(r=lora_r, alpha=lora_alpha)
        self.q = LoRALinear(d_model, inner, **lora_kw)
        self.k = LoRALinear(d_model, inner, **lora_kw)
        self.v = LoRALinear(d_model, inner, **lora_kw)
        self.o = LoRALinear(inner, d_model, **lora_kw)

        self.dropout = nn.Dropout(dropout)
        self.norm_q = nn.LayerNorm(d_model)
        self.gate = nn.Parameter(torch.zeros(1))  # learnable scalar gate, tanh(0) = 0 at init

    def forward(self, q_tokens: torch.Tensor, kv_tokens: torch.Tensor) -> torch.Tensor:
        # q_tokens: [B, Lq, D], kv_tokens: [B, Lk, D] (rank fixed by the unpack below)
        batch, n_q, _ = q_tokens.shape
        n_kv = kv_tokens.shape[1]
        heads, width = self.n_heads, self.d_kv

        def split_heads(t: torch.Tensor, length: int) -> torch.Tensor:
            # [B, L, H*d] -> [B, H, L, d]
            return t.view(batch, length, heads, width).transpose(1, 2)

        q = split_heads(self.q(self.norm_q(q_tokens)), n_q)
        k = split_heads(self.k(kv_tokens), n_kv)
        v = split_heads(self.v(kv_tokens), n_kv)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, H, Lq, Lk]
        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)  # [B, H, Lq, d]
        context = context.transpose(1, 2).contiguous().view(batch, n_q, heads * width)

        # Gated residual injection back into the query stream.
        return q_tokens + torch.tanh(self.gate) * self.o(context)


class CoProcessingBridge(nn.Module):
    """Bi-directional exchange between the ViT token stream and the brain token stream.

    Both streams are linearly projected to a shared width (``d_shared``, default
    ``d_vit``). Each stream attends to the *other* stream's pre-exchange
    projection. Only the resulting change is projected back and added as a
    residual on the original tokens, so input and output widths are unchanged.
    """

    def __init__(self, d_vit: int, d_brain: int, n_heads: int = 8, d_shared: Optional[int] = None,
                 lora_r: int = 8, lora_alpha: int = 16, dropout: float = 0.0):
        super().__init__()
        shared = d_shared or d_vit
        self.d_shared = shared
        self.v_proj = nn.Linear(d_vit, shared, bias=False)
        self.b_proj = nn.Linear(d_brain, shared, bias=False)
        attn_kw = dict(lora_r=lora_r, lora_alpha=lora_alpha, dropout=dropout)
        self.v2b = CrossAttentionAdapter(shared, n_heads, **attn_kw)  # brain queries attend to ViT
        self.b2v = CrossAttentionAdapter(shared, n_heads, **attn_kw)  # ViT queries attend to brain
        self.v_unproj = nn.Linear(shared, d_vit, bias=False)
        self.b_unproj = nn.Linear(shared, d_brain, bias=False)

    def forward(self, vit_tokens: torch.Tensor, brain_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        vit_shared = self.v_proj(vit_tokens)
        brain_shared = self.b_proj(brain_tokens)
        # Both directions read the pre-exchange projections, so the update is symmetric.
        vit_delta = self.b2v(vit_shared, brain_shared) - vit_shared
        brain_delta = self.v2b(brain_shared, vit_shared) - brain_shared
        new_vit = vit_tokens + self.v_unproj(vit_delta)
        new_brain = brain_tokens + self.b_unproj(brain_delta)
        return new_vit, new_brain


# ==========================================================
#    EEG → ViT token adapter (ATMS-style + Medformer)
# ==========================================================
class BrainEncoderAdapter(nn.Module):
    """Encode EEG with Medformer, then map the features onto the ViT patch-token grid.

    Pipeline: ``Medformer(eeg)`` -> ``Linear(d_model_med -> vit_hidden)`` ->
    linear interpolation along the token axis to ``vit_token_len`` tokens ->
    ``LayerNorm(vit_hidden)``.

    Input:  EEG, presumably ``[B, enc_in, seq_len]`` (defaults ``[B, 63, 250]``).
    Output: brain tokens ``[B, vit_token_len, vit_hidden]`` (defaults ``[B, 256, 1024]``,
            matching the ViT-L/14 patch tokens).
    """

    def __init__(self, seq_len: int = 250, enc_in: int = 63, d_model_med: int = 250,
                 vit_hidden: int = 1024, vit_token_len: int = 256, depth: int = 4):
        super().__init__()
        # Medformer configured as in ATMS; seq_len / enc_in / d_model override the Config values.
        cfg = Config(depth=depth)
        cfg.seq_len = seq_len
        cfg.enc_in = enc_in
        cfg.d_model = d_model_med
        self.medformer = Medformer(cfg)

        # Project Medformer features -> ViT hidden width.
        self.to_vit_hidden = nn.Linear(d_model_med, vit_hidden)

        # Token-count adapter: variable L -> fixed vit_token_len via 1D interpolation on the token axis.
        self.vit_token_len = vit_token_len
        self.norm = nn.LayerNorm(vit_hidden)

    def _length_to_256(self, x: torch.Tensor) -> torch.Tensor:
        """Resample ``x`` from ``[B, L, D]`` to ``[B, vit_token_len, D]`` along the token axis.

        This must stay differentiable. It sits between ``to_vit_hidden`` and the
        bridges, so wrapping it in ``torch.no_grad()`` (as an earlier version did)
        cuts the graph and leaves Medformer and the projection without gradients
        whenever ``L != vit_token_len``. The name is kept for backward compatibility;
        the target length is ``self.vit_token_len``, not necessarily 256.
        """
        _, length, _ = x.shape
        if length == self.vit_token_len:
            return x
        # F.interpolate resamples the last dim, so move tokens there and back.
        x_t = F.interpolate(x.transpose(1, 2), size=self.vit_token_len, mode="linear", align_corners=False)
        return x_t.transpose(1, 2)

    def forward(self, eeg: torch.Tensor) -> torch.Tensor:
        # eeg: presumably [B, 63, 250] — TODO confirm against the data loader.
        # NOTE(review): Medformer's output is assumed to be [B, Lm, d_model_med]
        # (per ATMS.Medformer); to_vit_hidden and _length_to_256 rely on that.
        feats = self.medformer(eeg)
        feats = self.to_vit_hidden(feats)        # [B, Lm, vit_hidden]
        feats = self._length_to_256(feats)       # [B, vit_token_len, vit_hidden]
        return self.norm(feats)


# ==========================================================
#       Fused model that runs CLIP ViT with bridges
# ==========================================================
class FusedEEGViT(nn.Module):
    """Frozen CLIP vision tower co-processed with an EEG stream through cross-attention bridges.

    ``forward`` makes two passes over the same image:
      * a plain, frozen CLIP pass giving ``z_pure`` (CLS after CLIP's ``post_layernorm``);
      * a manual pass over the same frozen blocks. After each 1-indexed block listed
        in ``insert_layers``, a ``CoProcessingBridge`` exchanges information between
        the patch tokens and the EEG tokens. Its CLS (plus ``final_norm``) gives ``z_fused``.
    A per-sample sigmoid gate then mixes the two into ``z_gate``.
    """

    def __init__(self,
                 insert_layers: Iterable[int] = (6, 12, 18),
                 n_heads_bridge: int = 8,
                 lora_r: int = 8,
                 med_depth: int = 4,
                 vit_model_name: str = 'openai/clip-vit-large-patch14'):
        super().__init__()

        # 1) Load the CLIP vision tower (ViT-L/14 by default) and freeze it.
        self.clip = CLIPVisionModel.from_pretrained(vit_model_name)
        for p in self.clip.parameters():
            p.requires_grad = False
        self.clip.eval()

        self.vit_hidden = self.clip.config.hidden_size  # 1024 for ViT-L/14
        self.vit_token_len = (self.clip.config.image_size // self.clip.config.patch_size) ** 2  # 16*16=256 for ViT-L/14

        # 2) Brain side (Medformer + adapter to ViT token space)
        self.brain_adapter = BrainEncoderAdapter(seq_len=250, enc_in=63,
                                                 d_model_med=250, vit_hidden=self.vit_hidden,
                                                 vit_token_len=self.vit_token_len, depth=med_depth)

        # 3) Multi-spot bridges.
        # Materialise insert_layers exactly once. Iterating it twice (set() then the loop)
        # would silently build no bridges for a one-shot iterable such as a generator,
        # and forward would then fail with a KeyError.
        layers = list(dict.fromkeys(int(i) for i in insert_layers))
        n_blocks = self.clip.config.num_hidden_layers
        out_of_range = [i for i in layers if not 1 <= i <= n_blocks]
        if out_of_range:
            raise ValueError(
                f"insert_layers must be 1-indexed block numbers in [1, {n_blocks}], got {out_of_range}")
        self.insert_layers = set(layers)
        self.bridges = nn.ModuleDict()
        for idx in layers:
            self.bridges[str(idx)] = CoProcessingBridge(d_vit=self.vit_hidden, d_brain=self.vit_hidden,
                                                        n_heads=n_heads_bridge, lora_r=lora_r, lora_alpha=16,
                                                        dropout=0.0)

        # 4) Gating for final fused embedding vs original CLIP embedding
        self.gate_head = nn.Sequential(
            nn.LayerNorm(self.vit_hidden),
            nn.Linear(self.vit_hidden, self.vit_hidden // 4),
            nn.GELU(),
            nn.Linear(self.vit_hidden // 4, 1),
            nn.Sigmoid(),  # per-sample scalar g in [0,1]
        )

        # Extra LayerNorm on the fused CLS. NOTE: this is a fresh, trainable LayerNorm,
        # NOT CLIP's post_layernorm (run_encoder_with_bridges already applies that).
        # So z_fused generally differs from z_pure even when every bridge gate is 0.
        self.final_norm = nn.LayerNorm(self.vit_hidden)

    def train(self, mode: bool = True):
        """Set train/eval mode on the trainable parts; the frozen CLIP tower always stays in eval mode."""
        super().train(mode)
        self.clip.eval()
        return self

    def vit_embed_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run patch embedding + CLS + positional embedding + pre-layernorm.

        Returns the tokens that enter the first encoder block:
        ``[B, 1 + num_patches, hidden]`` (``[B, 257, 1024]`` for ViT-L/14 at 224px).
        """
        vm = self.clip.vision_model
        emb = vm.embeddings
        # Cast to the conv weight dtype so a half-precision tower accepts fp32 pixels.
        x = emb.patch_embedding(pixel_values.to(dtype=emb.patch_embedding.weight.dtype))  # [B, D, h, w]
        x = x.flatten(2).transpose(1, 2)                                                  # [B, h*w, D]
        cls = emb.class_embedding.to(x.dtype).expand(x.size(0), 1, -1)                    # [B, 1, D]
        x = torch.cat([cls, x], dim=1)                                                    # [B, 1+h*w, D]
        x = x + emb.position_embedding(emb.position_ids[:, : x.size(1)])
        return vm.pre_layrnorm(x)  # (sic) HF's attribute name

    def run_encoder_with_bridges(self, tokens: torch.Tensor, brain_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward through the frozen ViT encoder, running a bridge after each selected block.

        tokens: [B, 1+num_patches, D]; brain_tokens: [B, num_patches, D].
        Only the patch tokens take part in the exchange; the CLS token is passed through.
        Returns the updated ``(tokens, brain_tokens)``, with CLIP's ``post_layernorm``
        applied to ``tokens``.
        """
        vm = self.clip.vision_model
        for i, blk in enumerate(vm.encoder.layers):
            # The encoder layer takes (hidden_states, attention_mask, causal_attention_mask, ...).
            # The two masks are required positionals and the vision tower uses neither.
            # The layer returns a tuple whose first item is the hidden states; a bare
            # tensor is accepted too.
            out = blk(tokens, None, None)
            tokens = out if torch.is_tensor(out) else out[0]
            block_no = i + 1  # insert_layers is 1-indexed: "after block k"
            if block_no in self.insert_layers:
                cls_tok, patch_tok = tokens[:, :1, :], tokens[:, 1:, :]
                patch_tok, brain_tokens = self.bridges[str(block_no)](patch_tok, brain_tokens)
                tokens = torch.cat([cls_tok, patch_tok], dim=1)
        # post_layernorm belongs to the vision transformer, not to its encoder.
        tokens = vm.post_layernorm(tokens)
        return tokens, brain_tokens

    @staticmethod
    def _cls_pool(tokens: torch.Tensor) -> torch.Tensor:
        """Return the CLS (first) token: [B, L, D] -> [B, D]."""
        return tokens[:, 0]

    def forward(self, pixel_values: torch.Tensor, eeg: torch.Tensor) -> dict:
        """
        pixel_values: [B,3,224,224] (already normalized as CLIP expects)
        eeg: [B,63,250]
        Returns a dict with the original CLIP embedding, the fused embedding, the
        gated mix and the per-sample gate value.
        """
        # 1) Original (pure) CLIP embedding — frozen, used as contrast target.
        with torch.no_grad():
            pure_outputs = self.clip(pixel_values=pixel_values)
            # last_hidden_state is taken before post_layernorm; normalising it and
            # taking CLS reproduces CLIP's pooled embedding.
            pure_tokens = self.clip.vision_model.post_layernorm(pure_outputs.last_hidden_state)
            z_pure = self._cls_pool(pure_tokens)  # [B, D]

        # 2) EEG -> ViT token space
        brain_tokens = self.brain_adapter(eeg)  # [B, num_patches, D]

        # 3) ViT tokens before the encoder
        vit_tokens = self.vit_embed_tokens(pixel_values)

        # 4) Run the encoder with bridges (multi-exchange)
        vit_tokens, brain_tokens = self.run_encoder_with_bridges(vit_tokens, brain_tokens)

        # 5) Fused embedding from the class token
        z_fused = self.final_norm(self._cls_pool(vit_tokens))  # [B, D]

        # 6) Per-sample gate g in [0,1], conditioned on the EEG-aware CLS
        g = self.gate_head(z_fused).squeeze(-1)  # [B]
        z_gate = g.unsqueeze(-1) * z_fused + (1.0 - g).unsqueeze(-1) * z_pure

        return {
            'z_pure': z_pure.detach(),      # frozen CLIP embedding (no grad)
            'z_fused': z_fused,             # fused (with bridges) embedding
            'z_gate': z_gate,               # gated final embedding (for contrastive / retrieval)
            'gate': g,                      # [B]
        }


# ==========================================================
#                 Losses / Training helpers
# ==========================================================
class InfoNCE(nn.Module):
    """Symmetric (CLIP-style) InfoNCE loss over a batch of paired embeddings.

    Row ``i`` of ``z_a`` and row ``i`` of ``z_b`` form the positive pair; all other
    rows in the batch act as negatives. Embeddings are L2-normalised, and the
    cosine logits are multiplied by ``exp(logit_scale)``. ``logit_scale`` is a
    learnable scalar initialised to ``log(1 / temperature)``. The loss is the mean
    of the a->b and b->a cross-entropies.
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        init_log_scale = math.log(1.0 / temperature)
        self.logit_scale = nn.Parameter(torch.full([], init_log_scale))

    def forward(self, z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
        # z_a, z_b: [B, D]
        a_unit = F.normalize(z_a, dim=-1)
        b_unit = F.normalize(z_b, dim=-1)
        logits = self.logit_scale.exp() * (a_unit @ b_unit.t())  # [B, B]
        targets = torch.arange(a_unit.size(0), device=a_unit.device)
        loss_a_to_b = F.cross_entropy(logits, targets)
        loss_b_to_a = F.cross_entropy(logits.t(), targets)
        return (loss_a_to_b + loss_b_to_a) / 2


def contrastive_step(model: FusedEEGViT, batch: dict, temperature: float = 0.07,
                     lambda_sync: float = 0.1,
                     criterion: Optional[InfoNCE] = None) -> Tuple[torch.Tensor, dict]:
    """Compute the loss for one training step.

    Args:
        model: the fused model; called as ``model(batch['pixel_values'], batch['eeg'])``.
        batch: dict containing
            - ``pixel_values``: [B,3,224,224] (CLIP preprocessed)
            - ``eeg``: [B,63,250]
        temperature: InfoNCE temperature, used only when ``criterion`` is None.
        lambda_sync: weight of the cosine "sync" term.
        criterion: an ``InfoNCE`` module to reuse across steps. Pass one, and include
            its parameters in the optimizer, if the learnable temperature should
            actually be trained. When omitted, a fresh ``InfoNCE(temperature)`` is
            built on every call, so its ``logit_scale`` never moves from its initial value.

    Returns:
        ``(loss, log)``. ``loss`` is the differentiable total; ``log`` holds Python
        floats for ``loss_contrast``, ``loss_sync`` and ``gate_mean``.

    NOTE(review): both loss terms pull ``z_gate`` toward ``z_pure``, and the model
    computes ``z_gate = g*z_fused + (1-g)*z_pure``. Setting the gate g to 0 therefore
    makes ``z_gate == z_pure`` (zeroing the sync term) without using the EEG at all.
    Confirm this objective is intended.
    """
    outputs = model(batch['pixel_values'], batch['eeg'])

    z_pure = outputs['z_pure']
    z_gate = outputs['z_gate']

    # Contrastive loss: bring z_gate close to z_pure to preserve CLIP semantics.
    nce = criterion if criterion is not None else InfoNCE(temperature)
    loss_contrast = nce(z_gate, z_pure)

    # Sync regularizer: cosine distance between the gated and pure CLS embeddings.
    loss_sync = 1.0 - F.cosine_similarity(z_gate, z_pure, dim=-1).mean()

    loss = loss_contrast + lambda_sync * loss_sync
    log = {
        'loss_contrast': loss_contrast.item(),
        'loss_sync': loss_sync.item(),
        'gate_mean': outputs['gate'].mean().item(),
    }
    return loss, log


# ==========================================================
#            How to unfreeze only what we need
# ==========================================================
def set_trainable_parameters(model: FusedEEGViT):
    """Freeze CLIP and enable training for the bridges, gate head, final_norm and brain adapter.

    Inside the bridges, any parameter whose dotted name has a ``main`` component
    (i.e. the base ``nn.Linear`` of a ``LoRALinear``) is left untouched. A LoRA layer
    built with ``freeze_main=True`` therefore keeps its base projection frozen, and
    only its low-rank ``A``/``B`` factors train, together with the bridges'
    projections, norms and gates. Setting every bridge parameter trainable would
    silently unfreeze those base weights.

    Returns:
        The same ``model`` (modified in place), for chaining.
    """
    # 1) CLIP stays frozen; re-asserted so this is safe to call on any instance.
    for p in model.clip.parameters():
        p.requires_grad = False

    # 2) Bridges (everything except LoRA base weights), the gate and the final norm.
    for name, p in model.bridges.named_parameters():
        if 'main' in name.split('.'):
            continue  # LoRALinear.main keeps whatever freeze_main decided
        p.requires_grad = True
    for module in (model.gate_head, model.final_norm):
        for p in module.parameters():
            p.requires_grad = True

    # 3) Brain side: fully trainable. Add a name-based rule here to freeze early
    #    Medformer layers if needed.
    for p in model.brain_adapter.parameters():
        p.requires_grad = True

    return model


# ==========================================================
#                       Usage sketch
# ==========================================================
if __name__ == "__main__":
    # Smoke run: push a random image + EEG batch through the fused model and print each output's shape.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = FusedEEGViT(insert_layers=(6, 12, 18), n_heads_bridge=8, lora_r=8).to(device)
    set_trainable_parameters(model)

    batch_size = 2
    pixel_values = torch.randn(batch_size, 3, 224, 224, device=device)
    eeg = torch.randn(batch_size, 63, 250, device=device)

    outputs = model(pixel_values, eeg)
    for key, value in outputs.items():
        shown = tuple(value.shape) if isinstance(value, torch.Tensor) else value
        print(key, shown)
