# In [None]:
# requirements: torch, transformers, einops (optional)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoConfig

# -------------------------
# 基本构件：跨流 CrossAttention
# -------------------------
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``query`` onto ``key_value``.

    Both inputs are linearly projected to ``dim`` and split into ``n_heads``
    heads of width ``dim // n_heads``.

    Args:
        dim: width of both the query and key/value streams; must be divisible
            by ``n_heads``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights before
            they are used to mix the values.

    Raises:
        ValueError: if ``dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, dim, n_heads, dropout=0.1):
        super().__init__()
        if dim % n_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by n_heads ({n_heads})")
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads

        self.q = nn.Linear(dim, dim, bias=False)
        self.k = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, dim, bias=False)
        self.out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key_value, attn_mask=None):
        """Attend from ``query`` to ``key_value``.

        Args:
            query: ``[B, Tq, dim]`` tensor.
            key_value: ``[B, Tk, dim]`` tensor used as both keys and values.
            attn_mask: optional tensor broadcastable to ``[B, n_heads, Tq, Tk]``;
                positions where it equals 0 are excluded. A query row whose
                keys are all masked produces NaN weights.

        Returns:
            ``(out, attn)`` where ``out`` is ``[B, Tq, dim]`` and ``attn`` is the
            ``[B, n_heads, Tq, Tk]`` softmax weights *before* dropout, returned
            for visualization.
        """
        B, Tq, _ = query.shape
        Tk = key_value.shape[1]
        # Split into heads: [B, T, dim] -> [B, H, T, Hd]
        q = self.q(query).view(B, Tq, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k(key_value).view(B, Tk, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v(key_value).view(B, Tk, self.n_heads, self.head_dim).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # [B, H, Tq, Tk]
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        # Dropout is applied only to the weights used for mixing, so the
        # returned `attn` stays a clean distribution for visualization.
        out = torch.matmul(self.dropout(attn), v)  # [B, H, Tq, Hd]
        out = out.transpose(1, 2).contiguous().view(B, Tq, self.dim)
        return self.out(out), attn

# -------------------------
# Brain Transformer Block（带 CrossFromModel）
# -------------------------
class BrainBlock(nn.Module):
    """Transformer block for the brain stream with optional cross-attention to the model stream.

    Order of operations in ``forward``: self-attention + residual, then
    ``norm1``; optional cross-attention onto ``model_kv`` + residual; ``norm2``;
    FFN applied to ``norm3(x)`` + residual.

    Args:
        dim: width of the brain stream; must be divisible by ``n_heads``.
        n_heads: attention heads for both self- and cross-attention.
        mlp_ratio: FFN hidden width as a multiple of ``dim``.
        dropout: dropout probability used by self-attention, cross-attention
            and the FFN.
    """

    def __init__(self, dim, n_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        # MultiheadAttention's own dropout defaults to 0; pass the block's value
        # so self-attention is regularised like the cross-attention and FFN.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=dim, num_heads=n_heads, dropout=dropout, batch_first=True
        )
        self.cross_from_model = CrossAttention(dim=dim, n_heads=n_heads, dropout=dropout)
        hidden = int(dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, model_kv=None, attn_mask=None):
        """Apply the block.

        Args:
            x: ``[B, T_b, dim]`` brain-stream tensor (batch-first).
            model_kv: optional ``[B, T_m, dim]`` model-stream tensor, already
                projected to ``dim``, used as keys/values for cross-attention.
            attn_mask: optional mask for the cross-attention only; the
                self-attention is unmasked.

        Returns:
            ``(x, attn)`` where ``attn`` is the cross-attention weights, or
            ``None`` when ``model_kv`` is not given.
        """
        # Self-attention weights are discarded, so don't compute them.
        x2, _ = self.self_attn(x, x, x, need_weights=False)
        x = self.norm1(x + x2)
        attn = None  # bound up front so the return does not depend on the branch
        if model_kv is not None:
            cross_out, attn = self.cross_from_model(x, model_kv, attn_mask=attn_mask)
            x = x + cross_out
        # NOTE(review): norm2 is immediately followed by the pre-FFN norm3, so the
        # FFN input is normalised twice; kept as-is to preserve trained behaviour.
        x = self.norm2(x)
        x = x + self.ffn(self.norm3(x))
        return x, attn

# -------------------------
# Model Stream wrapper (pretrained LM with insertion point)
# -------------------------
class LMWithCross(nn.Module):
    """Pretrained causal LM whose hidden states can cross-attend to brain-stream features.

    The LM backbone runs unchanged. Afterwards, for each index in
    ``insert_layers`` that has an entry in ``brain_repr_by_layer``, that entry of
    the backbone's ``hidden_states`` attends to the projected brain features,
    and the result is added residually.

    This is a post-hoc approximation. A modified intermediate state is NOT fed
    into later LM layers. Logits are computed from the last element of
    ``hidden_states`` only, so only an insertion at that final index changes
    the returned logits.

    Args:
        lm_name: model name or path passed to ``from_pretrained``.
        cross_dim: width of the incoming brain features.
        insert_layers: indices into ``hidden_states`` (index 0 is the embedding
            output). Defaults to ``[4, 8, 12]``.
        cross_heads: number of heads in each per-layer cross-attention; it
            should divide the LM hidden size.
    """

    def __init__(self, lm_name='gpt2', cross_dim=512, insert_layers=None, cross_heads=8):
        super().__init__()
        # Load the pretrained model and ask it to expose per-layer hidden states.
        self.config = AutoConfig.from_pretrained(lm_name, output_hidden_states=True)
        self.lm = AutoModelForCausalLM.from_pretrained(lm_name, config=self.config)
        # None default + copy: avoids one mutable list shared across instances.
        self.insert_layers = [4, 8, 12] if insert_layers is None else list(insert_layers)
        # Project brain features into the LM hidden width.
        self.brain_to_lm_proj = nn.Linear(cross_dim, self.config.hidden_size)
        # One cross-attention per insertion index (ModuleDict keys must be strings).
        self.cross_modules = nn.ModuleDict({
            str(i): CrossAttention(self.config.hidden_size, n_heads=cross_heads)
            for i in self.insert_layers
        })

    def forward(self, input_ids, brain_repr_by_layer=None, attention_mask=None):
        """Run the LM, inject brain cross-attention, and return ``(logits, attn_maps)``.

        Args:
            input_ids: token ids passed to the LM backbone.
            brain_repr_by_layer: optional dict mapping an index in
                ``insert_layers`` to a ``[B, T_b, cross_dim]`` tensor; indices
                without an entry are left untouched.
            attention_mask: optional mask forwarded to the LM backbone only (it
                is not applied to the cross-attention).

        Returns:
            ``logits`` from the LM output head applied to the last hidden state,
            and ``attn_maps`` mapping each injected index to its cross-attention
            weights.

        Raises:
            IndexError: if an injected index is outside ``hidden_states``.
        """
        # `base_model` resolves the backbone generically (`transformer` on GPT-2,
        # `model` on Llama-style models); `self.lm.model` does not exist on GPT-2.
        outputs = self.lm.base_model(
            input_ids, output_hidden_states=True, attention_mask=attention_mask
        )
        hs = list(outputs.hidden_states)  # (embeddings, layer1, layer2, ...)
        attn_maps = {}
        if brain_repr_by_layer:
            for idx in self.insert_layers:
                if idx not in brain_repr_by_layer:
                    continue
                layer_h = hs[idx]  # [B, T_m, H]
                br_proj = self.brain_to_lm_proj(brain_repr_by_layer[idx])  # [B, T_b, H]
                cross_out, attn = self.cross_modules[str(idx)](layer_h, br_proj)
                # Residual add on the stored state only; later LM layers never see it.
                hs[idx] = layer_h + cross_out
                attn_maps[idx] = attn
        # get_output_embeddings() is the model's output head (`lm_head` on GPT-2/Llama).
        logits = self.lm.get_output_embeddings()(hs[-1])
        return logits, attn_maps

# -------------------------
# 顶层并行模型组合
# -------------------------
class CoProcessingModel(nn.Module):
    """Top-level model combining a brain-signal encoder with a cross-attending LM.

    Brain features are embedded and passed through a stack of ``BrainBlock``s.
    Selected brain-layer outputs are handed to ``LMWithCross``, one per LM
    insertion index.

    Args:
        brain_dim: width of the brain stream.
        n_brain_layers: number of ``BrainBlock``s; must be at least 1.
        lm_name: pretrained LM name or path.
        insert_layers: LM hidden-state indices to inject into; defaults to ``[4, 8]``.
        brain_feat_dim: feature width of the raw brain input (previously fixed at 64).

    Raises:
        ValueError: if ``n_brain_layers`` is less than 1.
    """

    def __init__(self, brain_dim=512, n_brain_layers=6, lm_name='gpt2', insert_layers=None,
                 brain_feat_dim=64):
        super().__init__()
        if n_brain_layers < 1:
            # The layer mapping in forward() takes a modulo over the block count.
            raise ValueError(f"n_brain_layers must be >= 1, got {n_brain_layers}")
        insert_layers = [4, 8] if insert_layers is None else list(insert_layers)
        self.brain_embed = nn.Linear(in_features=brain_feat_dim, out_features=brain_dim)
        self.brain_blocks = nn.ModuleList(
            [BrainBlock(brain_dim, n_heads=8) for _ in range(n_brain_layers)]
        )
        self.lm_with_cross = LMWithCross(
            lm_name=lm_name, cross_dim=brain_dim, insert_layers=insert_layers
        )

    def forward(self, brain_inputs, input_ids, brain_time_to_layer_map=None, attention_mask=None):
        """Encode brain inputs, then run the LM with brain cross-attention.

        Args:
            brain_inputs: ``[B, T_b, brain_feat_dim]`` raw brain features.
            input_ids: token ids for the LM.
            brain_time_to_layer_map: currently ignored; kept for interface stability.
            attention_mask: optional LM attention mask, forwarded to the LM backbone.

        Returns:
            ``(logits, attn_maps)`` as returned by ``LMWithCross``.
        """
        x = self.brain_embed(brain_inputs)
        brain_by_layer = {}
        for i, blk in enumerate(self.brain_blocks):
            # No LM hidden states are fed back into the brain stream here.
            x, _ = blk(x, model_kv=None)
            brain_by_layer[i] = x
        # Simple fixed mapping: LM index -> brain layer (LM index mod block count).
        n_blocks = len(self.brain_blocks)
        mapped = {
            lm_idx: brain_by_layer[lm_idx % n_blocks]
            for lm_idx in self.lm_with_cross.insert_layers
        }
        return self.lm_with_cross(
            input_ids, brain_repr_by_layer=mapped, attention_mask=attention_mask
        )
