In [1]:
# import pandas as pd
# import numpy as np
# from collections import defaultdict
# import regex as re
# from multiprocessing import Pool
# from support.find_chunk_boundaries import find_chunk_boundaries
# from memory_profiler import profile
# import time, tracemalloc
# from dataclasses import dataclass
# from typing import BinaryIO, Iterable, Iterator
# import random

import torch
import torch.nn as nn
from einops import rearrange, einsum, reduce, repeat

# Q 3.4.2

In [2]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        ## Construct a linear transformation module. This function should accept the following parameters:
        self.in_features = in_features ## final dimension of the input
        self.out_features = out_features ## final dimension of the output
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        self.weights = nn.Parameter(
            torch.empty(out_features, in_features, device=self.device, dtype=self.dtype)
        )
        std = torch.sqrt(torch.tensor(2.0/(in_features+out_features)))
        nn.init.trunc_normal_(self.weights, mean=0.0, std=std.item(), a=-3*std.item(), b=3*std.item())
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Apply the linear transformation to the input
        output = einsum(
            x, self.weights,
            "... in_dim, out_dim in_dim -> ... out_dim"
        )
        return output

# Q 3.4.3

In [3]:
class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        ## Construct an embedding module
        super().__init__()
        self.num_embeddings = num_embeddings ## Size of the vocabulary
        self.embedding_dim = embedding_dim ## Dimension of the embedding vectors
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters
        
        self.weights = nn.Parameter(
            torch.empty(num_embeddings, embedding_dim, device=self.device, dtype=self.dtype)
        )
        std = 1.0
        nn.init.trunc_normal_(self.weights, mean=0.0, std=std,a=-3,b=3)
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        ## Lookup the embedding vectors for the given token IDs.
        return self.weights[token_ids]

# Q 3.5.1

In [4]:
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device: torch.device | None = None, dtype: torch.dtype | None = None):
        ## Construct the RMSNorm module.
        super().__init__()
        self.d_model = d_model ## Hidden dimension of the model
        self.eps = eps ## Epsilon value for numerical stability
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        self.weights = nn.Parameter(torch.ones(d_model, device=self.device, dtype=self.dtype))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Process an input tensor of shape
        in_dtype = x.dtype
        x = x.to(torch.float32)
        x_squaremean = reduce(
            x**2, "... d_model -> ... 1", 'mean'
        )
        x_RMS = (x_squaremean+self.eps).sqrt()
        result = x / x_RMS * self.weights
        return result.to(in_dtype)

# Q 3.5.2

In [5]:
class SwiGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int | None = None, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        self.d_model = d_model ## Hidden dimension of the model
        self.device = device
        self.dtype = dtype
        if d_ff is None:
            q = round(d_model*8/3/64)
            self.d_ff = q*64
        else:
            self.d_ff = d_ff
        
        self.w1_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model, device=self.device, dtype=self.dtype))
        self.w2_weight = nn.Parameter(torch.randn(self.d_model, self.d_ff, device=self.device, dtype=self.dtype))
        self.w3_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model, device=self.device, dtype=self.dtype))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w1x = einsum(
            self.w1_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        w3x = einsum(
            self.w3_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        SiLUw1x = w1x*torch.sigmoid(w1x)
        part2 = SiLUw1x * w3x
        result = einsum(
            self.w2_weight, part2,
            "d_model d_ff, ... d_ff -> ... d_model"
        )
        return result

# Q 3.5.3

In [6]:
class RotaryPositionalEmbedding(nn.Module):
    """Rotary position embedding applied to consecutive feature pairs.

    Sine and cosine tables of shape ``(max_seq_len, d_k // 2)`` are computed
    once and kept as non-persistent buffers ``sin`` and ``cos``.
    """

    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        """Precompute the rotation tables.

        Args:
            theta: base of the geometric frequency progression.
            d_k: size of the last axis of the tensors to rotate; must be even.
            max_seq_len: number of positions covered by the tables.
            device: device the tables are created on.
            dtype: accepted but not used by this constructor; the tables are float32.
        """
        super().__init__()
        assert d_k % 2 == 0, "RoPE requires even head dimension (pairs of features)"
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        self.device = device

        pair_index = torch.arange(d_k // 2, device=device, dtype=torch.float32)
        positions = torch.arange(max_seq_len, device=device, dtype=torch.float32)
        inv_freq = theta ** (-2 * pair_index / d_k)
        # Outer product: angle[s, d] = position[s] * inv_freq[d].
        angles = einsum(positions, inv_freq, "s, d -> s d")

        self.register_buffer("sin", torch.sin(angles), persistent=False)
        self.register_buffer("cos", torch.cos(angles), persistent=False)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """Rotate each (even, odd) feature pair of ``x`` by its position's angle.

        ``token_positions`` indexes the tables; its last axis must have the
        same length as the second-to-last axis of ``x``.
        """
        assert x.shape[-1] == self.d_k, "The last dim of input should be equal to dim of embedding."
        assert x.shape[-2] == token_positions.shape[-1], "token_positions length must match sequence length"
        sin = self.sin[token_positions]
        cos = self.cos[token_positions]

        first, second = x[..., ::2], x[..., 1::2]
        rotated_pairs = torch.stack(
            [first * cos - second * sin, first * sin + second * cos],
            dim=-1,
        )
        # Interleave the rotated halves back into the original feature order.
        return rearrange(rotated_pairs, '... s d two -> ... s (d two)')

# Q 3.5.4

In [7]:
def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    """Softmax of ``x`` along ``dim``.

    The maximum along ``dim`` is subtracted before exponentiating, so large
    inputs do not overflow.
    """
    shifted = x - x.max(dim=dim, keepdim=True).values
    exps = shifted.exp()
    return exps / exps.sum(dim=dim, keepdim=True)

In [8]:
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        Q: queries, shape ``(..., n, d_k)``.
        K: keys, shape ``(..., m, d_k)``.
        V: values, shape ``(..., m, d_v)``.
        mask: optional boolean tensor broadcastable to the ``(..., n, m)``
            scores. ``True`` keeps a score; ``False`` replaces it with ``-inf``
            so it receives zero weight.

    Returns:
        Tensor of shape ``(..., n, d_v)``.
    """
    d_k = Q.shape[-1]
    scores = einsum(
        Q, K, "... n d_k, ... m d_k -> ... n m"
    ) / d_k ** 0.5
    if mask is not None:
        # Out-of-place fill: no float32 CPU helper tensors and no in-place
        # write into the score tensor.
        scores = scores.masked_fill(~mask, float("-inf"))
    # Always normalise over the key axis, which is the last axis of the scores
    # even when Q and K have different numbers of leading dimensions.
    weights = softmax(scores, dim=-1)
    return einsum(
        weights, V, "... n m, ... m d_v -> ... n d_v"
    )


# Q 3.5.5

### Step1: 让我们从没有head 和 batch的简单情况开始

In [11]:
# Dimension settings
batch, seq_len, d_in = 1, 3, 4
d_k, d_v = 6, 8
num_heads = 1
torch.manual_seed(4)
# Input tensor
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
v_proj_weight = torch.randn(num_heads * d_v, d_in)   # (d_v * h, d_in)
o_proj_weight = torch.randn(d_in, num_heads * d_v)   # (d_in, d_v * h)

In [12]:
# Single-head attention without a mask.
# Note: the integer `d_k` from the setup cell is deliberately left untouched;
# scaled_dot_product_attention reads the key dimension from Q itself.
Q = einsum(
    q_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
K = einsum(
    k_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
V = einsum(
    v_proj_weight, in_features,
    "d_v d_in, ... d_in -> ... d_v"
)
head = scaled_dot_product_attention(Q, K, V)
# Output projection back to the input width.
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[ 6.3022,  3.5587,  0.0493,  2.1251],
         [ 6.2906,  3.5453,  0.0229,  2.2065],
         [ 6.2706,  3.5209, -0.0067,  2.2971]]])

### Step2: 加入batch

In [13]:
# Dimension settings
batch, seq_len, d_in = 2, 3, 4
d_k, d_v = 6, 8
num_heads = 1
torch.manual_seed(4)
# Input tensor
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
v_proj_weight = torch.randn(num_heads * d_v, d_in)   # (d_v * h, d_in)
o_proj_weight = torch.randn(d_in, num_heads * d_v)   # (d_in, d_v * h)

In [14]:
# Single-head attention over a batch, without a mask.
# Note: the integer `d_k` from the setup cell is deliberately left untouched;
# scaled_dot_product_attention reads the key dimension from Q itself.
Q = einsum(
    q_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
K = einsum(
    k_proj_weight, in_features,
    "d_k d_in, ... d_in -> ... d_k"
)
V = einsum(
    v_proj_weight, in_features,
    "d_v d_in, ... d_in -> ... d_v"
)
head = scaled_dot_product_attention(Q, K, V)
# Output projection back to the input width.
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[ -2.4333,  -2.2128,   1.1368,  -2.7840],
         [ -4.1671,  -0.1571,   1.0785, -19.5492],
         [ -3.1011,  -2.1886,   1.3718,  -5.7030]],

        [[  3.7037,   2.6312,  -1.1190,  13.7237],
         [  0.0890,  -2.3049,  -0.1087,  -2.6351],
         [  0.1030,  -2.2935,  -0.1119,  -2.5340]]])

### Step3: 加入head

In [15]:
# Dimension settings
batch, seq_len, d_in = 2, 3, 4
d_k, d_v = 6, 8
num_heads = 2
torch.manual_seed(4)
# Input tensor
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
v_proj_weight = torch.randn(num_heads * d_v, d_in)   # (d_v * h, d_in)
o_proj_weight = torch.randn(d_in, num_heads * d_v)   # (d_in, d_v * h)

In [16]:
# Multi-head attention without a mask.
# Note: the integer `d_k` from the setup cell is deliberately left untouched.
# Here the first axis of q_proj_weight is d_k * num_heads, not d_k, and
# scaled_dot_product_attention reads the key dimension from Q itself.
Q = einsum(
    q_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
K = einsum(
    k_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
V = einsum(
    v_proj_weight, in_features,
    "nd_v d_in, ... d_in -> ... nd_v"
)
# Split the fused projection axis into (head, per-head feature) and move the
# head axis in front of the sequence axis.
Q_head = rearrange(
    Q, "... seq_len (n d_k) -> ... n seq_len d_k", n=num_heads
)
K_head = rearrange(
    K, "... seq_len (n d_k) -> ... n seq_len d_k", n=num_heads
)
V_head = rearrange(
    V, "... seq_len (n d_v) -> ... n seq_len d_v", n=num_heads
)
head = scaled_dot_product_attention(Q_head, K_head, V_head)
# Concatenate the heads again along the feature axis.
head = rearrange(
    head, "... n seq_len d_v -> ... seq_len (n d_v)"
)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[-0.9310,  6.0378, -0.0974, -0.8215],
         [10.2525,  5.0840, -8.1525, -7.2442],
         [-3.5506,  2.4867,  1.0241,  1.9592]],

        [[-3.3017, -6.0040,  7.7340,  6.5704],
         [ 3.6543, 11.7540,  0.8076, -8.0421],
         [ 4.2321, -2.0617, -0.4722, -4.8180]]])

In [17]:
# Shapes of the attention output, the input, and the concatenated heads.
[tensor.shape for tensor in (attention, in_features, head)]

[torch.Size([2, 3, 4]), torch.Size([2, 3, 4]), torch.Size([2, 3, 16])]

### Step4: 加入mask

In [18]:
# Dimension settings
batch, seq_len, d_in = 2, 3, 4
d_k, d_v = 6, 8
num_heads = 2
torch.manual_seed(4)
# Input tensor
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
v_proj_weight = torch.randn(num_heads * d_v, d_in)   # (d_v * h, d_in)
o_proj_weight = torch.randn(d_in, num_heads * d_v)   # (d_in, d_v * h)

In [19]:
# Causal multi-head attention.
seq_len = in_features.shape[-2]

Q = einsum(
    q_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
Q_head = rearrange(
    Q, "... seq_len (n d_k) -> ... n seq_len d_k", n=num_heads
)
K = einsum(
    k_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
K_head = rearrange(
    K, "... seq_len (n d_k) -> ... n seq_len d_k", n=num_heads
)
V = einsum(
    v_proj_weight, in_features,
    "nd_v d_in, ... d_in -> ... nd_v"
)
V_head = rearrange(
    V, "... seq_len (n d_v) -> ... n seq_len d_v", n=num_heads
)

# The mask shape is derived from this cell's own Q_head, so it is built only
# after Q_head exists and the cell does not depend on an earlier cell's value.
# Lower-triangular: a query position may attend to itself and earlier keys.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
expend_shape = (*Q_head.shape[:-1], seq_len)
mask_boardcasted = mask.expand(expend_shape)

head = scaled_dot_product_attention(Q_head, K_head, V_head, mask_boardcasted)
head = rearrange(
    head, "... n seq_len d_v -> ... seq_len (n d_v)"
)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[ -0.6790,  -2.4998,  -8.9383,   4.1447],
         [ 10.3599,   5.2157,  -8.1922,  -7.3385],
         [ -3.5506,   2.4867,   1.0241,   1.9592]],

        [[-12.1002,  -3.7207,   2.2474,   7.5827],
         [ -8.2128,   6.3416,  -2.3230,  -5.1408],
         [  4.2321,  -2.0617,  -0.4722,  -4.8180]]])

### Step5: 加入RoPE

In [20]:
# Dimension settings
batch, seq_len, d_in = 2, 3, 4
d_k, d_v = 6, 8
num_heads = 2
torch.manual_seed(4)
# Input tensor
in_features = torch.randn(batch, seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(num_heads * d_k, d_in)   # (d_k * h, d_in)
v_proj_weight = torch.randn(num_heads * d_v, d_in)   # (d_v * h, d_in)
o_proj_weight = torch.randn(d_in, num_heads * d_v)   # (d_in, d_v * h)

In [21]:
# Shapes of the input and of the per-head queries (Q_head comes from the previous step).
[tensor.shape for tensor in (in_features, Q_head)]

[torch.Size([2, 3, 4]), torch.Size([2, 2, 3, 6])]

In [None]:
# Causal multi-head attention with rotary position embeddings on Q and K.
theta = 10000
max_seq_len = 10
rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len)

# Read the sequence length from the input before anything is derived from it.
seq_len = in_features.shape[-2]
position = torch.arange(seq_len)

Q = einsum(
    q_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
Q_head = rearrange(
    Q, "... seq_len (n d_k) -> ... n seq_len d_k", n=num_heads
)
K = einsum(
    k_proj_weight, in_features,
    "nd_k d_in, ... d_in -> ... nd_k"
)
K_head = rearrange(
    K, "... seq_len (n d_k) -> ... n seq_len d_k", n=num_heads
)
V = einsum(
    v_proj_weight, in_features,
    "nd_v d_in, ... d_in -> ... nd_v"
)
V_head = rearrange(
    V, "... seq_len (n d_v) -> ... n seq_len d_v", n=num_heads
)

# Both the mask and the position index take their leading shape from this
# cell's own Q_head, so they are built only after Q_head exists.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
expend_shape = (*Q_head.shape[:-1], seq_len)
mask_boardcasted = mask.expand(expend_shape)
position_expend_shape = (Q_head.shape[:-1])
position = position.expand(position_expend_shape)

# Values are not rotated; only queries and keys carry position information.
Q_head_rope = rope(Q_head, position)
K_head_rope = rope(K_head, position)

head = scaled_dot_product_attention(Q_head_rope, K_head_rope, V_head, mask_boardcasted)
head = rearrange(
    head, "... n seq_len d_v -> ... seq_len (n d_v)"
)
attention = einsum(
    head, o_proj_weight,
    "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
)
attention

tensor([[[ -0.6790,  -2.4998,  -8.9383,   4.1447],
         [  7.7346,   1.0026,  -6.7823,  -3.6310],
         [ -7.4218,  -2.0985,   2.3776,   5.1274]],

        [[-12.1002,  -3.7207,   2.2474,   7.5827],
         [ -8.0947,   6.5038,  -2.3712,  -5.3683],
         [  9.2953,   0.3954,   7.2983,  -6.3640]]])

In [29]:
class MultiheadSelfAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, 
                theta: float|None=None, max_seq_len: int|None=None,
                device: torch.device | None = None,dtype: torch.dtype | None = None,):
        super().__init__()
        ## mplement causal multi-head self-attention
        self.d_model = d_model ## final dimension of the input
        self.num_heads = num_heads ## number of heads
        self.device = device ## Device to store the parameters on
        self.theta = theta
        self.max_seq_len = max_seq_len
        
        assert d_model%num_heads == 0, "d_model/num_heads need to be int"
        self.d_k = d_model//num_heads
        self.d_v = d_model//num_heads
        
        # self.W_q = Linear(d_model,self.d_k * num_heads,device,dtype)
        # self.W_k = Linear(d_model,self.d_k * num_heads,device,dtype)
        # self.W_v = Linear(d_model,self.d_v * num_heads,device,dtype)
        # self.W_o = Linear(self.d_k * num_heads,d_model,device,dtype)
        
        self.W_q = nn.Parameter(torch.randn(self.d_k * num_heads, d_model))
        self.W_k = nn.Parameter(torch.randn(self.d_k * num_heads, d_model))
        self.W_v = nn.Parameter(torch.randn(self.d_v * num_heads, d_model))
        self.W_o = nn.Parameter(torch.randn(d_model, self.d_v * num_heads))
        
        self.rope = None
        if (theta is not None) and (max_seq_len is not None):
            self.rope = RotaryPositionalEmbedding(theta, self.d_k, max_seq_len,device,dtype)
    
    def forward(self, in_features: torch.Tensor, token_positions: torch.Tensor|None=None) -> torch.Tensor:
        seq_len = in_features.shape[-2]
        mask = torch.tril(torch.ones(seq_len,seq_len,dtype=torch.bool))
        # Q = self.W_q(in_features)
        Q = einsum(
            self.W_q, in_features,
            "nd_k d_in, ... d_in -> ... nd_k"
        )
        Q_head = rearrange(
            Q, "... seq_len (n d_k) -> ... n seq_len d_k", n = self.num_heads
        )
        # K = self.W_k(in_features)
        K = einsum(
            self.W_k, in_features,
            "nd_k d_in, ... d_in -> ... nd_k"
        )
        K_head = rearrange(
            K, "... seq_len (n d_k) -> ... n seq_len d_k", n = self.num_heads
        )
        if (self.rope is not None) and (token_positions is not None):
            position = repeat(
                token_positions, " ... seq_len -> ... n seq_len", n = self.num_heads
            )
            Q_head = self.rope(Q_head, position)
            K_head = self.rope(K_head, position)
        # V = self.W_v(in_features)
        V = einsum(
            self.W_v, in_features,
            "nd_v d_in, ... d_in -> ... nd_v"
        )
        V_head = rearrange(
            V, "... seq_len (n d_v) -> ... n seq_len d_v", n = self.num_heads
        )
        expend_shape = (*Q_head.shape[:-1], seq_len)
        mask_boardcasted = mask.expand(expend_shape)
        head = scaled_dot_product_attention(Q_head,K_head,V_head,mask_boardcasted)
        head = rearrange(
            head, "... n seq_len d_v -> ... seq_len (n d_v)"
        )
        # attention = self.W_o(head)
        attention = einsum(
            head, self.W_o,
            "... seq_len d_v,  d_in d_v -> ... seq_len d_in"
        )
        return attention

# Q 3.6

In [24]:
# Model hyper-parameters
d_model, num_heads, d_ff = 4, 1, 6
device = dtype = None

# Input shape
batch, seq_len = 1, 3

# Input tensor
x = torch.randn(batch, seq_len, d_model)

In [25]:
# Assemble one pre-norm Transformer block by hand from its sub-layers.
rms_norm1 = RMSNorm(d_model, device=device, dtype=dtype)
rms_norm2 = RMSNorm(d_model, device=device, dtype=dtype)
multi_head_attention = MultiheadSelfAttention(
    d_model, num_heads, device=device, dtype=dtype
)
ffn = SwiGLU(d_model, d_ff, device=device, dtype=dtype)

# Residual updates, written in place: `x` itself is modified by each step.
x += multi_head_attention(rms_norm1(x))
x += ffn(rms_norm2(x))


In [26]:
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: attention and feed-forward, each with a residual.

    Computes ``y = x + MHA(RMSNorm(x))`` followed by ``y + FFN(RMSNorm(y))``.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int,
                theta: float|None=None, max_seq_len: int|None=None,
                device: torch.device | None = None, dtype: torch.dtype | None = None):
        """Build the two norms, the attention layer and the feed-forward layer.

        Args:
            d_model: size of the last axis of the input and output.
            num_heads: number of attention heads.
            d_ff: inner width of the feed-forward layer.
            theta: RoPE base, forwarded to the attention layer.
            max_seq_len: RoPE table length, forwarded to the attention layer.
            device: device the sub-layers are created on.
            dtype: data type of the sub-layers' parameters.
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.theta = theta
        self.max_seq_len = max_seq_len
        self.device = device
        self.dtype = dtype

        self.rms_norm1 = RMSNorm(d_model, device = device, dtype = dtype)
        self.rms_norm2 = RMSNorm(d_model, device = device, dtype = dtype)
        self.mha = MultiheadSelfAttention(d_model, num_heads, theta=theta,max_seq_len=max_seq_len, device = device, dtype = dtype)
        self.ffn = SwiGLU(d_model=d_model, d_ff=d_ff ,device = device, dtype = dtype)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor | None = None) -> torch.Tensor:
        """Run the block on ``x`` (last two axes ``(seq_len, d_model)``).

        Args:
            x: input tensor; it is not modified.
            token_positions: optional positions for RoPE. When omitted and the
                attention layer has RoPE tables, positions ``0 .. seq_len - 1``
                are used so that the configured RoPE is actually applied.

        Returns:
            A new tensor with the same shape as ``x``.
        """
        if token_positions is None and self.mha.rope is not None:
            token_positions = torch.arange(x.shape[-2], device=x.device)
        # Out-of-place residual adds: in-place `+=` would overwrite the caller's
        # tensor and a value that autograd has saved for the backward pass.
        x = x + self.mha(self.rms_norm1(x), token_positions)
        x = x + self.ffn(self.rms_norm2(x))
        return x

In [27]:
# Build a block with RoPE, reusing `theta` and `max_seq_len` from the RoPE step above.
model = TransformerBlock(
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    theta=theta,
    max_seq_len=max_seq_len,
)

In [28]:
# Display the query projection parameter of the block's attention layer.
model.mha.W_q

Parameter containing:
tensor([[-1.3757, -0.5687,  0.0147,  0.1320],
        [ 1.6708, -0.2414,  1.5715, -2.7058],
        [ 1.3684, -0.2810,  0.6613,  0.6166],
        [-2.2904,  1.2070,  0.0746,  0.7860]], requires_grad=True)