In [1]:
# import pandas as pd
# import numpy as np
# from collections import defaultdict
# import regex as re
# from multiprocessing import Pool
# from support.find_chunk_boundaries import find_chunk_boundaries
# from memory_profiler import profile
# import time, tracemalloc
# from dataclasses import dataclass
# from typing import BinaryIO, Iterable, Iterator
# import random

import torch
import torch.nn as nn
from einops import rearrange, einsum, reduce, repeat

# Q 3.4.2

In [2]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        ## Construct a linear transformation module. This function should accept the following parameters:
        self.in_features = in_features ## final dimension of the input
        self.out_features = out_features ## final dimension of the output
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        w = torch.empty(in_features, out_features)
        std = torch.sqrt(torch.tensor(2.0/(in_features+out_features)))
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std.item(),a=-3*std.item(),b=3*std.item()))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Apply the linear transformation to the input
        output = einsum(
            self.weight, x,
            "in_dim out_dim, in_dim -> out_dim"
        )
        return output

# Q 3.4.3

In [3]:
class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        ## Construct an embedding module
        super().__init__()
        self.num_embeddings = num_embeddings ## Size of the vocabulary
        self.embedding_dim = embedding_dim ## Dimension of the embedding vectors
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters
        
        w = torch.empty(num_embeddings, embedding_dim)
        std = 1.0
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std,a=-3,b=3))
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        ## Lookup the embedding vectors for the given token IDs.
        return self.weight[token_ids]


# Q 3.5.1

In [4]:
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        ## Construct the RMSNorm module.
        super().__init__()
        self.d_model = d_model ## Hidden dimension of the model
        self.eps = eps ## Epsilon value for numerical stability
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        self.weights = nn.Parameter(torch.ones(d_model))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Process an input tensor of shape
        
        in_dtype = x.dtype
        x = x.to(torch.float32)
        x_squaremean = reduce(
            x**2, "... d_model -> ... 1", 'mean'
        )
        x_RMS = (x_squaremean+self.eps).sqrt()
        result = x / x_RMS * self.weights
        return result.to(in_dtype)

# Q 3.5.2

In [5]:
class SwiGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int | None = None):
        super().__init__()
        self.d_model = d_model ## Hidden dimension of the model
        if d_ff is None:
            q = round(d_model*8/3/64)
            self.d_ff = q*64
        else:
            self.d_ff = d_ff
        
        self.w1_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model))
        self.w2_weight = nn.Parameter(torch.randn(self.d_model, self.d_ff))
        self.w3_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w1x = einsum(
            self.w1_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        w3x = einsum(
            self.w3_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        SiLUw1x = w1x*torch.sigmoid(w1x)
        part2 = SiLUw1x * w3x
        result = einsum(
            self.w2_weight, part2,
            "d_model d_ff, ... d_ff -> ... d_model"
        )
        return result

# Q 3.5.3

In [6]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None):
        ## Construct the RoPE module and create buffers if needed.
        super().__init__()
        assert d_k % 2 == 0, "RoPE requires even head dimension (pairs of features)"
        self.theta = theta ## $\\Theta$ value for the RoPE
        self.d_k = d_k ## dimension of query and key vectors
        self.max_seq_len = max_seq_len ## Maximum sequence length that will be inputted
        self.device = device ## Device to store the buffer on

        dim_index = torch.arange(self.d_k // 2, dtype=torch.float32)
        position_index = torch.arange(self.max_seq_len, dtype=torch.float32)
        theta_inv_index = self.theta**(-2*dim_index/d_k)
        theta_ik = einsum(
            position_index, theta_inv_index,
            "s, d -> s d"
        )
        sin = torch.sin(theta_ik)
        cos = torch.cos(theta_ik)
        
        self.register_buffer("sin", sin, persistent=False)
        self.register_buffer("cos", cos, persistent=False)
        
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        
        assert x.shape[-1] == self.d_k, "The last dim of input should be equal to dim of embedding."
        assert x.shape[-2] == token_positions.shape[-1], "token_positions length must match sequence length"
        sin_expend = self.sin[token_positions]
        cos_expend = self.cos[token_positions]

        x_even = x[...,::2]
        x_odd = x[...,1::2]

        y_even = x_even*cos_expend-x_odd*sin_expend
        y_odd = x_even*sin_expend+x_odd*cos_expend
        y = rearrange(torch.stack([y_even, y_odd], dim=-1), '... s d two -> ... s (d two)')
        return y

# Q 3.5.4

In [7]:
def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    x_max = torch.max(x, dim=dim, keepdim=True).values
    x_subtract_max = x-x_max
    x_subtract_max_exp = torch.exp(x_subtract_max)
    x_subtract_max_exp_sum = torch.sum(x_subtract_max_exp, dim=dim, keepdim=True)
    y = x_subtract_max_exp/x_subtract_max_exp_sum
    return y

In [8]:
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    d_k = Q.shape[-1]
    QK = einsum(
        Q, K, "... n d_k, ... m d_k -> ... n m"
    )
    QK_scaled = QK/torch.tensor(d_k).sqrt()
    if mask is not None:
        M = torch.where(mask, torch.tensor(0.0), torch.tensor(float('-inf')))
        QK_scaled += M
    QK_scaled_softmax = softmax(QK_scaled, Q.dim()-1)
    y = einsum(
        QK_scaled_softmax, V, "... n m, ... m d_v -> ... n d_v"
    )
    return y


# 3.5.5

In [40]:
import numpy as np

# 加载 snapshot 文件
snapshot_path = "tests/_snapshots/test_multihead_self_attention.npz"
data = np.load(snapshot_path)

print("包含的 keys:", list(data.keys()))
correct_output = data["array"]   # pytest 存的时候默认 key="array"


包含的 keys: ['array']


In [53]:
import torch
import numpy as np
import json
from pathlib import Path
from tests.adapters import run_multihead_self_attention
from tests.common import FIXTURES_PATH  # 你项目里有这个路径定义

# --------------------------
# 1. 准备参数（和 conftest.py 一致）
# --------------------------
batch_size = 4
n_queries = 12
n_heads = 4
d_head = 16
d_model = n_heads * d_head  # = 64

# 固定随机种子，生成输入
torch.manual_seed(4)
in_embeddings = torch.randn(batch_size, n_queries, d_model)

# 加载权重和配置
state_dict = torch.load(FIXTURES_PATH / "ts_tests" / "model.pt", map_location="cpu")
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

q_proj_weight = state_dict["layers.0.attn.q_proj.weight"]
k_proj_weight = state_dict["layers.0.attn.k_proj.weight"]
v_proj_weight = state_dict["layers.0.attn.v_proj.weight"]
o_proj_weight = state_dict["layers.0.attn.output_proj.weight"]


# --------------------------
# 2. 跑你的实现
# --------------------------
output = run_multihead_self_attention(
    d_model=d_model,
    num_heads=n_heads,
    q_proj_weight=q_proj_weight,
    k_proj_weight=k_proj_weight,
    v_proj_weight=v_proj_weight,
    o_proj_weight=o_proj_weight,
    in_features=in_embeddings,
)

output_np = output.detach().cpu().numpy()

# --------------------------
# 3. 加载 snapshot 标准答案
# --------------------------
snapshot_path = Path("tests/_snapshots/test_multihead_self_attention.npz")
snapshot_output = np.load(snapshot_path)["array"]

# --------------------------
# 4. 对比
# --------------------------
diff = output_np - snapshot_output
print("✅ in_embeddings.shape:", in_embeddings.shape)
print("✅ output.shape:", output_np.shape)
print("📊 最大误差:", abs(diff).max())
print("📊 平均误差:", abs(diff).mean())


✅ in_embeddings.shape: torch.Size([4, 12, 64])
✅ output.shape: (4, 12, 64)
📊 最大误差: 1.0861819
📊 平均误差: 0.16066213


In [45]:
batch_index = 0
queries_index = 0

In [46]:
output_np[batch_index,queries_index,:]

array([-0.17197564, -0.08255363, -0.0124781 , -0.13603452,  0.00127764,
       -0.12001598,  0.10902623, -0.09237173, -0.02512559, -0.00557411,
       -0.11389705, -0.20467345, -0.01746393, -0.10730196, -0.0450832 ,
        0.00112269, -0.13262169, -0.01873078,  0.20487486,  0.05051364,
       -0.16577597,  0.01586309,  0.3289656 ,  0.01362708, -0.1481699 ,
        0.08887245,  0.3209791 ,  0.1910902 ,  0.14688256, -0.15865944,
       -0.25018042, -0.08816586, -0.0465146 , -0.06768438, -0.15954879,
        0.1471851 , -0.14069512, -0.10784942, -0.4013285 ,  0.21584041,
       -0.20935541,  0.00850761, -0.07735043,  0.19367744, -0.10390159,
       -0.05698189, -0.12960042, -0.12811019, -0.05007379,  0.03976776,
        0.1602968 , -0.26809353,  0.2720321 ,  0.01819955, -0.29095235,
       -0.04509705,  0.050052  , -0.08730943,  0.04727838, -0.03699401,
        0.07663991, -0.14175309,  0.02848481, -0.24121457], dtype=float32)

In [47]:
snapshot_output[batch_index,queries_index,:]

array([-0.24033818,  0.14196444, -0.04016975,  0.12584177,  0.2838859 ,
        0.45506647, -0.3701429 ,  0.3658208 , -0.26944515, -0.00672541,
       -0.40027574, -0.03942056, -0.47971496, -0.30313554,  0.02359407,
        0.10665806,  0.5105323 ,  0.22533073, -0.3322856 , -0.05323157,
        0.24506564, -1.0703188 ,  0.05229701, -0.5186902 ,  0.07613137,
        0.50190294,  0.02792899, -0.16633943,  0.18831418, -0.3904159 ,
       -0.5906472 ,  0.23029123, -0.13024016,  0.7513475 ,  0.03903922,
        0.39151755, -0.55778044, -0.44849786,  0.26461315, -0.25431275,
       -0.8647511 , -0.75121766,  0.37779185,  0.16481523, -0.1681257 ,
       -0.6019287 , -0.49165186,  0.25521484, -0.63498884,  0.13234214,
       -0.1674973 , -0.00461843, -0.03648568, -0.28829518, -0.3748052 ,
        0.12815759,  0.429467  ,  0.07528834,  0.22757603,  0.5270285 ,
        0.7546846 , -0.7465039 , -0.34337416, -0.19546816], dtype=float32)

### Step1: 让我们从没有head 和 batch的简单情况开始

In [54]:
# 设置维度参数
batch = 1
seq_len = 3
d_in = 4
d_k = 6
d_v = 8
num_heads = 1
torch.manual_seed(4)
# 输入张量
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(d_k * num_heads, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(d_k * num_heads, d_in)
v_proj_weight = torch.randn(d_v * num_heads, d_in)
o_proj_weight = torch.randn(d_in, d_v * num_heads)

In [39]:
d_k = torch.tensor(q_proj_weight.shape[-2])

Q = einsum(
    q_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
K = einsum(
    k_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
V = einsum(
    v_proj_weight, in_features,
    "d_v d_in, ... d_in -> ... d_v"
)
head = scaled_dot_product_attention(Q,K,V)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[ 6.3022,  3.5587,  0.0493,  2.1251],
         [ 6.2906,  3.5453,  0.0229,  2.2065],
         [ 6.2706,  3.5209, -0.0067,  2.2971]]])

### Step2: 加入batch

In [55]:
# 设置维度参数
batch = 2
seq_len = 3
d_in = 4
d_k = 6
d_v = 8
num_heads = 1
torch.manual_seed(4)
# 输入张量
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(d_k * num_heads, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(d_k * num_heads, d_in)
v_proj_weight = torch.randn(d_v * num_heads, d_in)
o_proj_weight = torch.randn(d_in, d_v * num_heads)

In [56]:
d_k = torch.tensor(q_proj_weight.shape[-2])

Q = einsum(
    q_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
K = einsum(
    k_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
V = einsum(
    v_proj_weight, in_features,
    "d_v d_in, ... d_in -> ... d_v"
)
head = scaled_dot_product_attention(Q,K,V)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[ -2.4333,  -2.2128,   1.1368,  -2.7840],
         [ -4.1671,  -0.1571,   1.0785, -19.5492],
         [ -3.1011,  -2.1886,   1.3718,  -5.7030]],

        [[  3.7037,   2.6312,  -1.1190,  13.7237],
         [  0.0890,  -2.3049,  -0.1087,  -2.6351],
         [  0.1030,  -2.2935,  -0.1119,  -2.5340]]])

### Step3: 加入head

In [57]:
# 设置维度参数
batch = 2
seq_len = 3
d_in = 4
d_k = 6
d_v = 8
num_heads = 2
torch.manual_seed(4)
# 输入张量
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(d_k * num_heads, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(d_k * num_heads, d_in)
v_proj_weight = torch.randn(d_v * num_heads, d_in)
o_proj_weight = torch.randn(d_in, d_v * num_heads)

In [64]:
d_k = torch.tensor(q_proj_weight.shape[-2])

Q = einsum(
    q_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
Q_head = rearrange(
    Q, "... seq_len (n d_k) -> ... n seq_len d_k", n = num_heads
)
K = einsum(
    k_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
K_head = rearrange(
    K, "... seq_len (n d_k) -> ... n seq_len d_k", n = num_heads
)
V = einsum(
    v_proj_weight, in_features,
    "nd_v d_in, ... d_in -> ... nd_v"
)
V_head = rearrange(
    V, "... seq_len (n d_v) -> ... n seq_len d_v", n = num_heads
)
head = scaled_dot_product_attention(Q_head,K_head,V_head)
head = rearrange(
    head, "... n seq_len d_v -> ... seq_len (n d_v)"
)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[-0.9310,  6.0378, -0.0974, -0.8215],
         [10.2525,  5.0840, -8.1525, -7.2442],
         [-3.5506,  2.4867,  1.0241,  1.9592]],

        [[-3.3017, -6.0040,  7.7340,  6.5704],
         [ 3.6543, 11.7540,  0.8076, -8.0421],
         [ 4.2321, -2.0617, -0.4722, -4.8180]]])

In [65]:
[attention.shape, in_features.shape,head.shape]

[torch.Size([2, 3, 4]), torch.Size([2, 3, 4]), torch.Size([2, 3, 16])]

### Step4: 加入mask

In [86]:
# 设置维度参数
batch = 2
seq_len = 3
d_in = 4
d_k = 6
d_v = 8
num_heads = 2
torch.manual_seed(4)
# 输入张量
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(d_k * num_heads, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(d_k * num_heads, d_in)
v_proj_weight = torch.randn(d_v * num_heads, d_in)
o_proj_weight = torch.randn(d_in, d_v * num_heads)

In [None]:
seq_len = in_features.shape[-2]

mask = torch.tril(torch.ones(seq_len,seq_len,dtype=torch.bool))
expend_shape = (*Q_head.shape[:-1], seq_len)
mask_boardcasted = mask.expand(expend_shape)
Q = einsum(
    q_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
Q_head = rearrange(
    Q, "... seq_len (n d_k) -> ... n seq_len d_k", n = num_heads
)
K = einsum(
    k_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
K_head = rearrange(
    K, "... seq_len (n d_k) -> ... n seq_len d_k", n = num_heads
)
V = einsum(
    v_proj_weight, in_features,
    "nd_v d_in, ... d_in -> ... nd_v"
)
V_head = rearrange(
    V, "... seq_len (n d_v) -> ... n seq_len d_v", n = num_heads
)
head = scaled_dot_product_attention(Q_head,K_head,V_head,mask_boardcasted)
head = rearrange(
    head, "... n seq_len d_v -> ... seq_len (n d_v)"
)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[ -0.6790,  -2.4998,  -8.9383,   4.1447],
         [ 10.3599,   5.2157,  -8.1922,  -7.3385],
         [ -3.5506,   2.4867,   1.0241,   1.9592]],

        [[-12.1002,  -3.7207,   2.2474,   7.5827],
         [ -8.2128,   6.3416,  -2.3230,  -5.1408],
         [  4.2321,  -2.0617,  -0.4722,  -4.8180]]])

### Step5: 加入RoPE

In [88]:
# 设置维度参数
batch = 2
seq_len = 3
d_in = 4
d_k = 6
d_v = 8
num_heads = 2
torch.manual_seed(4)
# 输入张量
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(d_k * num_heads, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(d_k * num_heads, d_in)
v_proj_weight = torch.randn(d_v * num_heads, d_in)
o_proj_weight = torch.randn(d_in, d_v * num_heads)

In [83]:
d_k = torch.tensor(q_proj_weight.shape[-2])

theta = 10000
max_seq_len = 10
rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len)
position = torch.arange(seq_len)


seq_len = in_features.shape[-2]

mask = torch.tril(torch.ones(seq_len,seq_len,dtype=torch.bool))
expend_shape = (*Q_head.shape[:-1], seq_len)
mask_boardcasted = mask.expand(expend_shape)
Q = einsum(
    q_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
Q_head = rearrange(
    Q, "... seq_len (n d_k) -> ... n seq_len d_k", n = num_heads
)
position_expend_shape = (Q_head.shape[:-1])
position_boardcasted = position.expand(position_expend_shape)
Q_head_rope = rope(Q_head, position_boardcasted)
K = einsum(
    k_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
K_head = rearrange(
    K, "... seq_len (n d_k) -> ... n seq_len d_k", n = num_heads
)
K_head_rope = rope(K_head, position_boardcasted)
V = einsum(
    v_proj_weight, in_features,
    "nd_v d_in, ... d_in -> ... nd_v"
)
V_head = rearrange(
    V, "... seq_len (n d_v) -> ... n seq_len d_v", n = num_heads
)
head = scaled_dot_product_attention(Q_head_rope,K_head_rope,V_head,mask_boardcasted)
head = rearrange(
    head, "... n seq_len d_v -> ... seq_len (n d_v)"
)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

AssertionError: The last dim of input should be equal to dim of embedding.

In [84]:
[Q_head.shape, position_boardcasted.shape]

[torch.Size([2, 2, 3, 6]), torch.Size([2, 2, 3])]

In [85]:
d_k

tensor(12)

In [82]:
position_boardcasted.shape

torch.Size([2, 2, 3])

### Step3: 加入RoPEd

In [12]:
d_k = q_proj_weight.shape[-2]
seq_len = in_features.shape[-2]

mask = torch.tril(torch.ones(seq_len,seq_len,dtype=torch.bool))
mask_expend_shape = (*in_features.shape[:-2], seq_len, seq_len)
mask_boardcasted = mask.expand(mask_expend_shape)

theta = 10000
max_seq_len = 10
rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len)
position = torch.arange(seq_len)
position_expend_shape = (in_features.shape[:-1])
position_boardcasted = position.expand(position_expend_shape)

Q = einsum(
    q_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
Q_rope = rope(Q, position_boardcasted)
K = einsum(
    k_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
K_rope = rope(K, position_boardcasted)
V = einsum(
    v_proj_weight, in_features,
    "d_v d_in, ... d_in -> ... d_v"
)
head = scaled_dot_product_attention(Q_rope,K_rope,V,mask_boardcasted)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[-8.8476,  3.6411,  6.7984,  0.1213],
        [ 3.6279, -1.3175, -3.8935, -0.5140],
        [ 4.2604,  0.5462, -4.2576, -1.6795]])

### Step4: 加入head

In [1062]:
# 设置维度参数
batch = 1
seq_len = 3
d_in = 4
d_k = 6
d_v = 8
num_heads = 2

# 输入张量
# in_features = torch.randn(batch, seq_len, d_in)
in_features = torch.randn(seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(d_k * num_heads, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(d_k * num_heads, d_in)
v_proj_weight = torch.randn(d_v * num_heads, d_in)
o_proj_weight = torch.randn(d_in, d_v * num_heads)

In [13]:
d_k = q_proj_weight.shape[-2]//num_heads
seq_len = in_features.shape[-2]

mask = torch.tril(torch.ones(seq_len,seq_len,dtype=torch.bool))
mask_expend_shape = (*in_features.shape[:-2], seq_len, seq_len)
mask_boardcasted = mask.expand(mask_expend_shape)

theta = 10000
max_seq_len = 10
rope = RotaryPositionalEmbedding(theta, d_k*num_heads, max_seq_len)
position = torch.arange(seq_len)
position_expend_shape = (in_features.shape[:-1])
position_boardcasted = position.expand(position_expend_shape)

Q = einsum(
    q_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
Q_rope = rope(Q, position_boardcasted)
K = einsum(
    k_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
K_rope = rope(K, position_boardcasted)
V = einsum(
    v_proj_weight, in_features,
    "d_v d_in, ... d_in -> ... d_v"
)
head = scaled_dot_product_attention(Q_rope,K_rope,V,mask_boardcasted)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[-8.8476,  3.6411,  6.7984,  0.1213],
        [ 3.6279, -1.3175, -3.8935, -0.5140],
        [ 4.2604,  0.5462, -4.2576, -1.6795]])

### Step5: 加入batch

In [15]:
# 设置维度参数
batch = 2
seq_len = 3
d_in = 4
d_k = 6
d_v = 8
num_heads = 2

# 输入张量
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(d_k * num_heads, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(d_k * num_heads, d_in)
v_proj_weight = torch.randn(d_v * num_heads, d_in)
o_proj_weight = torch.randn(d_in, d_v * num_heads)

In [16]:
d_k = q_proj_weight.shape[-2]//num_heads
seq_len = in_features.shape[-2]

mask = torch.tril(torch.ones(seq_len,seq_len,dtype=torch.bool))
mask_expend_shape = (*in_features.shape[:-2], seq_len, seq_len)
mask_boardcasted = mask.expand(mask_expend_shape)

theta = 10000
max_seq_len = 10
rope = RotaryPositionalEmbedding(theta, d_k*num_heads, max_seq_len)
position = torch.arange(seq_len)
position_expend_shape = (in_features.shape[:-1])
position_boardcasted = position.expand(position_expend_shape)

Q = einsum(
    q_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
Q_rope = rope(Q, position_boardcasted)
K = einsum(
    k_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
K_rope = rope(K, position_boardcasted)
V = einsum(
    v_proj_weight, in_features,
    "d_v d_in, ... d_in -> ... d_v"
)
head = scaled_dot_product_attention(Q_rope,K_rope,V,mask_boardcasted)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[-10.1454, -15.4852,  20.6310, -15.1790],
         [ -5.7551, -12.4943,  16.7282, -12.8813],
         [  6.5436,  -4.1160,   5.7951,  -6.4449]],

        [[ 13.7090,  -1.6341,   3.2373,  -4.6695],
         [ 13.6951,  -1.6374,   3.2414,  -4.6702],
         [ 13.2631,  -0.5656,   4.9205,  -3.4586]]])

In [None]:
class MultiheadSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, device: torch.device | None = None):
        super().__init__()
        ## mplement causal multi-head self-attention
        self.d_model = d_model ## final dimension of the input
        self.num_heads = num_heads ## number of heads
        self.device = device ## Device to store the parameters on

        assert d_model%num_heads == 0, "d_model/num_heads need to be int"
        d_k = d_model//num_heads
        d_v = d_model//num_heads
        
        self.q_proj_weight = nn.Parameter(torch.randn(d_k * num_heads, d_model))
        self.k_proj_weight = nn.Parameter(torch.randn(d_k * num_heads, d_model))
        self.v_proj_weight = nn.Parameter(torch.randn(d_v * num_heads, d_model))
        self.o_proj_weight = nn.Parameter(torch.randn(d_model, d_v * num_heads))
        
    
    def forward(self, in_features: torch.Tensor) -> torch.Tensor:
        
        Q = einsum(
            self.q_proj_weight, in_features,
            "d_k d_in, ... d_in -> ... d_k"
        )
        K = einsum(
            self.k_proj_weight, in_features,
            "d_k d_in, ... d_in -> ... d_k"
        )
        V = einsum(
            self.v_proj_weight, in_features,
            "d_v d_in, ... d_in -> ... d_v"
        )
        head = scaled_dot_product_attention(Q,K,V)
        attention = einsum(
            head, self.o_proj_weight,
            "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
        )
        return attention