## Q-Former example ##

In [None]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
@dataclass
class QFormerConfig:
    """Hyperparameters used to build a ``QFormer`` (see ``QFormer.__init__``)."""
    d_model: int              # hidden size of the query tokens and of the output
    d_kv: int                 # channel size of the memory A' (key/value input, before projections)
    n_heads: int              # number of attention heads
    K: int                    # number of query tokens (fixed length of the Q-Prompt)
    L: int                    # number of Q-Former blocks (layers)
    dropout: float = 0.1      # probability passed to the nn.Dropout layers
    use_mem_posenc: bool = False  # add a sinusoidal positional encoding to the memory A'



In [None]:
def _split_heads(x: torch.Tensor, n_heads: int) -> torch.Tensor:
    # [B,L,D] -> [B,H,L,Dh]
    B, L, D = x.shape
    Dh = D // n_heads
    x = x.view(B, L, n_heads, Dh).transpose(1, 2)  # [B,H,L,Dh]
    return x


def _merge_heads(x: torch.Tensor) -> torch.Tensor:
    # [B,H,L,Dh] -> [B,L,H*Dh]
    B, H, L, Dh = x.shape
    return x.transpose(1, 2).contiguous().view(B, L, H * Dh)

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional encoding along the time axis.

    The encoding has no learnable parameters and is recomputed on every call
    for the actual sequence length and channel width of the input.

    Args:
        dim: nominal channel width. It is only stored on the module; the
            width actually used is read from the input tensor.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the positional encoding.

        Args:
            x: tensor of shape ``[B, T, D]``.

        Returns:
            Tensor of shape ``[B, T, D]`` with the same dtype as ``x``.
        """
        _, T, D = x.shape
        device = x.device
        position = torch.arange(T, device=device, dtype=torch.float32).unsqueeze(1)  # [T,1]
        div_term = torch.exp(
            torch.arange(0, D, 2, device=device, dtype=torch.float32) * (-math.log(10000.0) / D)
        )  # [ceil(D/2)]
        angles = position * div_term  # [T, ceil(D/2)]

        # Build the table in float32 for accuracy, whatever the dtype of x is.
        pe = torch.zeros(T, D, device=device, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angles)
        # For odd D there is one cosine column fewer than sine columns, so the
        # last frequency is dropped here to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : D // 2])

        # Cast to the input dtype so half-precision inputs are not silently
        # promoted to float32 by the addition.
        return x + pe.to(x.dtype).unsqueeze(0)  # [1,T,D] broadcast over B

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over one sequence.

    Args:
        d_model: channel width of the input and of the output.
        n_heads: number of attention heads; must divide ``d_model``.
        dropout: probability used both for the attention weights and for the
            projected output.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model, bias=True)
        self.W_k = nn.Linear(d_model, d_model, bias=True)
        self.W_v = nn.Linear(d_model, d_model, bias=True)
        self.W_o = nn.Linear(d_model, d_model, bias=True)
        self.drop_attn = nn.Dropout(dropout)
        self.drop_res = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend every position of ``x`` to every position of ``x``.

        Args:
            x: tensor of shape ``[B, L, D]``.
            attn_mask: optional mask broadcastable to ``[B, H, L, L]``
                (e.g. ``[B, 1, 1, L]``). A floating-point mask is added to
                the attention scores (use ``-inf`` to block a position). Any
                other dtype (bool or integer) is read as a keep-mask:
                non-zero / ``True`` keeps a position, zero / ``False`` masks it.

        Returns:
            Tensor of shape ``[B, L, D]``.
        """
        B, L, D = x.shape
        q = _split_heads(self.W_q(x), self.n_heads)  # [B,H,L,Dh]
        k = _split_heads(self.W_k(x), self.n_heads)  # [B,H,L,Dh]
        v = _split_heads(self.W_v(x), self.n_heads)  # [B,H,L,Dh]

        # scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(D // self.n_heads)  # [B,H,L,L]
        if attn_mask is not None:
            if attn_mask.is_floating_point():
                # additive mask
                scores = scores + attn_mask
            else:
                # Convert integer masks to bool before inverting: `~` on a 0/1
                # integer tensor is a bitwise NOT and never yields a valid
                # boolean mask.
                keep = attn_mask if attn_mask.dtype == torch.bool else attn_mask != 0
                scores = scores.masked_fill(~keep, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        attn = self.drop_attn(attn)
        out = torch.matmul(attn, v)  # [B,H,L,Dh]
        out = _merge_heads(out)      # [B,L,D]
        out = self.W_o(out)
        out = self.drop_res(out)
        return out


class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention: queries attend to a separate memory.

    Keys and values are projected from the memory width into the query width,
    so the attention itself runs in a space of size ``d_model_q``.

    Args:
        d_model_q: channel width of the queries and of the output.
        d_model_mem: channel width of the memory.
        n_heads: number of attention heads; must divide ``d_model_q``.
        dropout: probability used both for the attention weights and for the
            projected output.
    """

    def __init__(self, d_model_q: int, d_model_mem: int, n_heads: int, dropout: float):
        super().__init__()
        assert d_model_q % n_heads == 0, "d_model_q must be divisible by n_heads"
        self.d_model_q = d_model_q
        self.d_model_mem = d_model_mem
        self.n_heads = n_heads

        # Query side stays in d_model_q; memory side is mapped into it.
        self.W_q = nn.Linear(d_model_q, d_model_q, bias=True)
        self.W_k = nn.Linear(d_model_mem, d_model_q, bias=True)
        self.W_v = nn.Linear(d_model_mem, d_model_q, bias=True)
        self.W_o = nn.Linear(d_model_q, d_model_q, bias=True)
        self.drop_attn = nn.Dropout(dropout)
        self.drop_res = nn.Dropout(dropout)

    def forward(
        self,
        q_inp: torch.Tensor,
        mem: torch.Tensor,
        mem_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Let each query attend over the memory.

        Args:
            q_inp: queries, shape ``[B, K, d_model_q]``.
            mem: memory, shape ``[B, T, d_model_mem]``.
            mem_mask: optional ``[B, T]`` mask; ``True`` / non-zero keeps a
                memory position, ``False`` / zero masks it out.

        Returns:
            Tensor of shape ``[B, K, d_model_q]``.
        """
        batch, _, q_width = q_inp.shape
        _, mem_len, _ = mem.shape
        heads = self.n_heads

        queries = _split_heads(self.W_q(q_inp), heads)  # [B,H,K,Dh]
        keys = _split_heads(self.W_k(mem), heads)       # [B,H,T,Dh]
        values = _split_heads(self.W_v(mem), heads)     # [B,H,T,Dh]

        head_dim = q_width // heads
        scores = (queries @ keys.transpose(-2, -1)) / math.sqrt(head_dim)  # [B,H,K,T]

        if mem_mask is not None:
            keep = mem_mask if mem_mask.dtype == torch.bool else (mem_mask != 0)
            # [B,T] -> [B,1,1,T] so the mask broadcasts over heads and queries.
            keep = keep.view(batch, 1, 1, mem_len)
            scores = scores.masked_fill(~keep, float("-inf"))

        weights = self.drop_attn(F.softmax(scores, dim=-1))  # normalised over T
        merged = _merge_heads(weights @ values)               # [B,K,Dq]
        return self.drop_res(self.W_o(merged))


class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout.

    Args:
        d_model: input and output channel width; the inner width is
            ``4 * d_model``.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        inner_dim = 4 * d_model
        self.lin1 = nn.Linear(d_model, inner_dim)
        self.lin2 = nn.Linear(inner_dim, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network to the last axis of ``x``; the shape is preserved."""
        hidden = self.lin1(x)
        hidden = F.gelu(hidden)
        projected = self.lin2(hidden)
        return self.drop(projected)

In [None]:
class QFormerBlock(nn.Module):
    """One Q-Former layer built from three pre-LayerNorm residual sub-layers.

    Order of sub-layers: self-attention over the query tokens, cross-attention
    from the query tokens to the memory A', then a feed-forward network.

    Args:
        d_model: channel width of the query tokens.
        d_kv: channel width of the memory.
        n_heads: number of attention heads.
        dropout: dropout probability passed to every sub-layer.
    """

    def __init__(self, d_model: int, d_kv: int, n_heads: int, dropout: float):
        super().__init__()
        self.ln_q1 = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadSelfAttention(d_model=d_model, n_heads=n_heads, dropout=dropout)
        self.ln_q2 = nn.LayerNorm(d_model)
        self.cross_attn = MultiHeadCrossAttention(
            d_model_q=d_model, d_model_mem=d_kv, n_heads=n_heads, dropout=dropout
        )
        self.ln_q3 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model=d_model, dropout=dropout)
        # Registered on the module but not applied in forward().
        self.drop = nn.Dropout(dropout)

    def forward(self, q_tokens: torch.Tensor, mem: torch.Tensor, mem_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Update the query tokens with one pass over the memory.

        Args:
            q_tokens: query tokens, shape ``[B, K, d_model]``.
            mem: memory A', shape ``[B, T, d_kv]``.
            mem_mask: optional ``[B, T]`` keep-mask for the memory, or ``None``.

        Returns:
            Updated query tokens, shape ``[B, K, d_model]``.
        """
        # Self-attention among the queries (residual).
        self_out = self.self_attn(self.ln_q1(q_tokens))
        hidden = q_tokens + self_out

        # Cross-attention: queries read from the memory (residual).
        cross_out = self.cross_attn(self.ln_q2(hidden), mem, mem_mask)
        hidden = hidden + cross_out

        # Feed-forward (residual).
        ffn_out = self.ffn(self.ln_q3(hidden))
        return hidden + ffn_out

In [None]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm encoder block: self-attention followed by a feed-forward network.

    Args:
        d_model: channel width of the sequence.
        d_kv: accepted but not used by this block.
        n_heads: number of attention heads.
        dropout: dropout probability passed to both sub-layers.
    """

    def __init__(self, d_model: int, d_kv: int, n_heads: int, dropout: float):
        super().__init__()
        self.ln_q1 = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadSelfAttention(d_model=d_model, n_heads=n_heads, dropout=dropout)
        self.ln_q2 = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model=d_model, dropout=dropout)

    def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Run one block over ``x``.

        Args:
            x: tensor of shape ``[B, L, d_model]``.
            x_mask: attention mask forwarded to the self-attention layer, or ``None``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention (residual).
        attended = self.self_attn(self.ln_q1(x), x_mask)
        hidden = x + attended

        # Feed-forward (residual).
        transformed = self.ffn(self.ln_q2(hidden))
        return hidden + transformed

In [None]:
class QFormer(nn.Module):
    """
    Minimal Q-Former.

    - learnable query tokens of shape [1, K, d_model]
    - L blocks, each: SelfAttn(Q) -> CrossAttn(Q <- A') -> FFN
    - memory input: A' [B, T, d_kv] with an optional mask [B, T]
    - output: Q-Prompt [B, K, d_model]
    """

    def __init__(self, cfg: QFormerConfig):
        super().__init__()
        self.cfg = cfg

        # One shared set of queries, initialised with small Gaussian noise.
        initial_queries = torch.randn(1, cfg.K, cfg.d_model) * 0.02
        self.query_tokens = nn.Parameter(initial_queries)

        blocks = []
        for _ in range(cfg.L):
            blocks.append(
                QFormerBlock(
                    d_model=cfg.d_model,
                    d_kv=cfg.d_kv,
                    n_heads=cfg.n_heads,
                    dropout=cfg.dropout,
                )
            )
        self.layers = nn.ModuleList(blocks)

        if cfg.use_mem_posenc:
            self.mem_pos = SinusoidalPositionalEncoding(cfg.d_kv)
        else:
            self.mem_pos = None

    @torch.no_grad()
    def num_parameters(self) -> int:
        """Return the total number of elements over all parameters of the module."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total

    def forward(
        self,
        A_prime: torch.Tensor,
        mem_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compress the memory into a fixed number of query tokens.

        Args:
            A_prime: memory of shape ``[B, T, d_kv]``.
            mem_mask: optional ``[B, T]`` mask (1 / ``True`` = keep).

        Returns:
            The Q-Prompt, shape ``[B, K, d_model]``.
        """
        d_kv = self.cfg.d_kv
        assert A_prime.dim() == 3 and A_prime.size(-1) == d_kv, \
            f"A' must be [B,T,{self.cfg.d_kv}]"

        memory = A_prime
        if self.mem_pos is not None:
            memory = self.mem_pos(A_prime)

        batch = A_prime.size(0)
        # Share the same learnable queries across the batch without copying.
        queries = self.query_tokens.expand(batch, -1, -1)  # [B,K,d_model]
        for block in self.layers:
            queries = block(queries, memory, mem_mask)     # [B,K,d_model]
        return queries

In [None]:
class MLP(nn.Module):
    """Two-layer projector: Linear -> Tanh -> Linear.

    Originally described as mapping a CLIP embedding (512) to
    ``prefix_length * hidden_size_LM`` values.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        hid_dim: hidden layer size.
    """
    def __init__(self, in_dim: int, out_dim: int, hid_dim: int = 2048):
        super().__init__()
        stages = [
            nn.Linear(in_dim, hid_dim),
            nn.Tanh(),
            nn.Linear(hid_dim, out_dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Project the last axis of ``x`` from ``in_dim`` to ``out_dim``."""
        projected = self.net(x)
        return projected

class ClipCapPrefix(nn.Module):
    """
    Frozen language model plus a trainable Q-Former projector.

    Only the projector is trained: it turns the input features into
    ``cfg.K`` prefix embeddings that are prepended to the text embeddings
    before they are fed to the language model.

    Args:
        gpt_model: language model exposing ``get_input_embeddings()`` and
            accepting ``inputs_embeds``, ``attention_mask`` and ``labels``
            (presumably a Hugging Face GPT-2 LM head model — confirm with the
            caller). All of its parameters are frozen here.
        prefix_length: number of prefix positions; must equal ``cfg.K``.
        cfg: Q-Former configuration; ``cfg.d_model`` must equal the LM
            embedding size.
        prefix_size: unused; kept only so existing calls keep working.

    Raises:
        ValueError: if ``cfg.d_model`` differs from the LM embedding size or
            ``cfg.K`` differs from ``prefix_length``.
    """
    def __init__(self,
                 gpt_model,
                 prefix_length: int,
                 cfg: QFormerConfig,
                 prefix_size: int = 512,

                 ):
        super().__init__()
        self.gpt = gpt_model
        self.prefix_length = prefix_length

        # Fetch the input embedding layer through the generic accessor.
        self.wte = self.gpt.get_input_embeddings()
        self.hidden = self.wte.embedding_dim  # embedding size of the LM

        # The projector output is concatenated with the text embeddings, so
        # its width and length have to match what forward() builds around it.
        if cfg.d_model != self.hidden:
            raise ValueError(
                f"cfg.d_model ({cfg.d_model}) must equal the LM embedding size ({self.hidden})"
            )
        if cfg.K != prefix_length:
            raise ValueError(
                f"cfg.K ({cfg.K}) must equal prefix_length ({prefix_length})"
            )

        # Projector. An earlier variant used
        # MLP(prefix_size, prefix_length * self.hidden); a stack of
        # TransformerBlock layers or convolutions are other options.
        self.project = QFormer(cfg)

        # Freeze the LM: only the projector is trained.
        for p in self.gpt.parameters():
            p.requires_grad_(False)

    def forward(self, input_ids: torch.Tensor, prefix: torch.Tensor, attention_mask: torch.Tensor):
        """
        Run the LM on ``[prefix embeddings ; text embeddings]``.

        Args:
            input_ids: token ids, shape ``[B, T]``.
            prefix: features for the projector, shape ``[B, T2, cfg.d_kv]``.
            attention_mask: text mask, shape ``[B, T]`` (0 marks padding).

        Returns:
            Whatever ``gpt_model`` returns for the given ``inputs_embeds``,
            ``attention_mask`` and ``labels``.
        """
        B, _ = input_ids.shape

        # Text embeddings: [B, T, H]
        text_emb = self.wte(input_ids)

        # Prefix embeddings: [B, P, H]
        pref = self.project(prefix)
        # Size the mask and labels from the real prefix length so they cannot
        # drift from inputs_embeds.
        P = pref.size(1)

        # Fusion: concat(prefix, text) -> [B, P+T, H]
        inputs_embeds = torch.cat([pref, text_emb], dim=1)

        # The prefix is always attended to (all ones): [B, P+T]
        prefix_attn = attention_mask.new_ones((B, P))
        full_attn = torch.cat([prefix_attn, attention_mask], dim=1)

        # Labels: the prefix and padded text positions are set to -100 so the
        # loss ignores them; real text tokens are kept. The label shift is
        # expected to happen inside the LM.
        prefix_labels = input_ids.new_full((B, P), -100)
        text_labels = input_ids.masked_fill(attention_mask == 0, -100)
        labels = torch.cat([prefix_labels, text_labels], dim=1)

        out = self.gpt(inputs_embeds=inputs_embeds, attention_mask=full_attn, labels=labels)
        return out

prefix_length = 10

# ClipCapPrefix requires a QFormerConfig; calling it without `cfg` raises TypeError.
# d_model must match the LM embedding size and K must match prefix_length.
# NOTE(review): d_kv=512 assumes 512-dimensional input features (the default
# prefix_size of ClipCapPrefix); n_heads=8 and L=2 are example values and
# n_heads must divide d_model — adjust to the actual setup.
qformer_cfg = QFormerConfig(
    d_model=gpt.get_input_embeddings().embedding_dim,
    d_kv=512,
    n_heads=8,
    K=prefix_length,
    L=2,
)
clipcap = ClipCapPrefix(gpt_model=gpt, prefix_length=prefix_length, cfg=qformer_cfg).to(device)

# Sanity check: only the projector parameters should be trainable.
print("Trainable params (должен быть только projector):",
      sum(p.numel() for p in clipcap.parameters() if p.requires_grad))