In [None]:
# import pandas as pd
# import numpy as np
# from collections import defaultdict
# import regex as re
# from multiprocessing import Pool
# from support.find_chunk_boundaries import find_chunk_boundaries
# from memory_profiler import profile
# import time, tracemalloc
# from dataclasses import dataclass
# from typing import BinaryIO, Iterable, Iterator
# import random

import torch
import torch.nn as nn
from einops import rearrange, einsum, reduce, repeat

# Q 3.4.2

In [1014]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        ## Construct a linear transformation module. This function should accept the following parameters:
        self.in_features = in_features ## final dimension of the input
        self.out_features = out_features ## final dimension of the output
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        w = torch.empty(in_features, out_features)
        std = torch.sqrt(torch.tensor(2.0/(in_features+out_features)))
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std.item(),a=-3*std.item(),b=3*std.item()))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Apply the linear transformation to the input
        output = einsum(
            self.weight, x,
            "in_dim out_dim, in_dim -> out_dim"
        )
        return output

# Q 3.4.3

In [1015]:
class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        ## Construct an embedding module
        super().__init__()
        self.num_embeddings = num_embeddings ## Size of the vocabulary
        self.embedding_dim = embedding_dim ## Dimension of the embedding vectors
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters
        
        w = torch.empty(num_embeddings, embedding_dim)
        std = 1.0
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std,a=-3,b=3))
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        ## Lookup the embedding vectors for the given token IDs.
        return self.weight[token_ids]


# Q 3.5.1

In [1016]:
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        ## Construct the RMSNorm module.
        super().__init__()
        self.d_model = d_model ## Hidden dimension of the model
        self.eps = eps ## Epsilon value for numerical stability
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        self.weights = nn.Parameter(torch.ones(d_model))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Process an input tensor of shape
        
        in_dtype = x.dtype
        x = x.to(torch.float32)
        x_squaremean = reduce(
            x**2, "... d_model -> ... 1", 'mean'
        )
        x_RMS = (x_squaremean+self.eps).sqrt()
        result = x / x_RMS * self.weights
        return result.to(in_dtype)

# Q 3.5.2

In [1017]:
class SwiGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int | None = None):
        super().__init__()
        self.d_model = d_model ## Hidden dimension of the model
        if d_ff is None:
            q = round(d_model*8/3/64)
            self.d_ff = q*64
        else:
            self.d_ff = d_ff
        
        self.w1_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model))
        self.w2_weight = nn.Parameter(torch.randn(self.d_model, self.d_ff))
        self.w3_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w1x = einsum(
            self.w1_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        w3x = einsum(
            self.w3_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        SiLUw1x = w1x*torch.sigmoid(w1x)
        part2 = SiLUw1x * w3x
        result = einsum(
            self.w2_weight, part2,
            "d_model d_ff, ... d_ff -> ... d_model"
        )
        return result

# Q 3.5.3

In [1018]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None):
        ## Construct the RoPE module and create buffers if needed.
        super().__init__()
        assert d_k % 2 == 0, "RoPE requires even head dimension (pairs of features)"
        self.theta = theta ## $\\Theta$ value for the RoPE
        self.d_k = d_k ## dimension of query and key vectors
        self.max_seq_len = max_seq_len ## Maximum sequence length that will be inputted
        self.device = device ## Device to store the buffer on

        dim_index = torch.arange(self.d_k // 2, dtype=torch.float32)
        position_index = torch.arange(self.max_seq_len, dtype=torch.float32)
        theta_inv_index = self.theta**(-2*dim_index/d_k)
        theta_ik = einsum(
            position_index, theta_inv_index,
            "s, d -> s d"
        )
        sin = torch.sin(theta_ik)
        cos = torch.cos(theta_ik)
        
        self.register_buffer("sin", sin, persistent=False)
        self.register_buffer("cos", cos, persistent=False)
        
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        
        assert x.shape[-1] == self.d_k, "The last dim of input should be equal to dim of embedding."
        assert x.shape[-2] == token_positions.shape[-1], "token_positions length must match sequence length"
        sin_expend = self.sin[token_positions]
        cos_expend = self.cos[token_positions]

        x_even = x[...,::2]
        x_odd = x[...,1::2]

        y_even = x_even*cos_expend-x_odd*sin_expend
        y_odd = x_even*sin_expend+x_odd*cos_expend
        y = rearrange(torch.stack([y_even, y_odd], dim=-1), '... s d two -> ... s (d two)')
        return y

# Q 3.5.4

In [1019]:
def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    x_max = torch.max(x, dim=dim, keepdim=True).values
    x_subtract_max = x-x_max
    x_subtract_max_exp = torch.exp(x_subtract_max)
    x_subtract_max_exp_sum = torch.sum(x_subtract_max_exp, dim=dim, keepdim=True)
    y = x_subtract_max_exp/x_subtract_max_exp_sum
    return y

In [1020]:
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    d_k = Q.shape[-1]
    QK = einsum(
        Q, K, "... n d_k, ... m d_k -> ... n m"
    )
    QK_scaled = QK/torch.tensor(d_k).sqrt()
    if mask is not None:
        M = torch.where(mask, torch.tensor(0.0), torch.tensor(float('-inf')))
        QK_scaled += M
    QK_scaled_softmax = softmax(QK_scaled, Q.dim()-1)
    y = einsum(
        QK_scaled_softmax, V, "... n m, ... m d_v -> ... n d_v"
    )
    return y


# 3.5.5

In [1023]:
def run_multihead_self_attention(
    d_model: int,
    num_heads: int,
    q_proj_weight: Float[Tensor, " d_k d_in"],
    k_proj_weight: Float[Tensor, " d_k d_in"],
    v_proj_weight: Float[Tensor, " d_v d_in"],
    o_proj_weight: Float[Tensor, " d_model d_v"],
    in_features: Float[Tensor, " ... sequence_length d_in"],
) -> Float[Tensor, " ... sequence_length d_out"]:
    pass

NameError: name 'Float' is not defined

In [1029]:
# 设置维度参数
batch = 2
seq_len = 3
d_in = 4
d_k = 6
d_v = 5
num_heads = 2

# 输入张量
in_features = torch.randn(batch, seq_len, d_in)
in_features = torch.randn(seq_len, d_in)
# Projection weights
q_proj_weight = torch.randn(d_k * num_heads, d_in)   # (d_k * h, d_in)
k_proj_weight = torch.randn(d_k * num_heads, d_in)
v_proj_weight = torch.randn(d_v * num_heads, d_in)
o_proj_weight = torch.randn(d_in, d_v * num_heads)


In [None]:
class multihead_self_attention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, max_seq_len: int, device: torch.device | None = None):
        ## mplement causal multi-head self-attention
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len

In [1056]:
num_heads = 2
d_in = 16

seq_len = in_features.shape[-2]


def _generate_QKV(in_features, proj_weight,num_heads):
    A = einsum(
        in_features, proj_weight,
        "... d_in, nd_k d_in -> ... nd_k"
    )
    A = rearrange(
        A, "... (n d_k) -> ... n d_k", n = num_heads
    )
    return A

theta = 10000
max_seq_len = 10
rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len)

def _rope_QKV(A, rope, seq_len):
    A_rearrange = rearrange(
        A, "... seq n d_k -> ... n seq d_k"
    )
    posi_shape = A_rearrange.shape[:-1]
    posi_basic = torch.arange(seq_len)
    token_positions = posi_basic.expand(posi_shape)
    A_rope = rope(A_rearrange, token_positions)
    A_final = rearrange(
        A_rope, "... n seq d_k -> ... seq n d_k"
    )
    return A_final

Qi = _generate_QKV(in_features, q_proj_weight, num_heads)
Qi = _rope_QKV(Qi, rope, seq_len)
Ki = _generate_QKV(in_features, k_proj_weight, num_heads)
Ki = _rope_QKV(Ki, rope, seq_len)
Vi = _generate_QKV(in_features, v_proj_weight, num_heads)

mask_shape = Qi.shape[:-2] + (num_heads,num_heads)
upper_tri = torch.triu(torch.ones((num_heads, num_heads), dtype=torch.bool), diagonal=0)
mask = upper_tri.expand(mask_shape)
MHi = scaled_dot_product_attention(Qi,Ki,Vi,mask)
MH = rearrange(
    MHi, "... n d_v -> ... (n d_v)", n = num_heads
)
y = einsum(
    o_proj_weight, MH,
    "d_model nd_v, ... nd_v -> ... d_model"
)

In [None]:

Qi_t = rearrange(
    Qi, "... seq n d_k -> ... n seq d_k"
)
posi_shape = Qi_t.shape[:-1]
posi_basic = torch.arange(seq_len)
token_positions = posi_basic.expand(posi_shape)
Qi_t_rope = rope(Qi_t, token_positions)
Qi_final = rearrange(
    Qi_t_rope, "... n seq d_k -> ... seq n d_k"
)

In [1054]:
Qi_t_rope.shape

torch.Size([2, 3, 6])

In [None]:
output = rope(in_query_or_key, token_positions)

In [1035]:
[in_features.shape,Qi.shape,Q.shape]

[torch.Size([3, 4]), torch.Size([3, 2, 6]), torch.Size([3, 12])]

tensor([0, 1, 2])

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None):
        ## Construct the RoPE module and create buffers if needed.
        super().__init__()
        assert d_k % 2 == 0, "RoPE requires even head dimension (pairs of features)"
        self.theta = theta ## $\\Theta$ value for the RoPE
        self.d_k = d_k ## dimension of query and key vectors
        self.max_seq_len = max_seq_len ## Maximum sequence length that will be inputted
        self.device = device ## Device to store the buffer on

        dim_index = torch.arange(self.d_k // 2, dtype=torch.float32)
        position_index = torch.arange(self.max_seq_len, dtype=torch.float32)
        theta_inv_index = self.theta**(-2*dim_index/d_k)
        theta_ik = einsum(
            position_index, theta_inv_index,
            "s, d -> s d"
        )
        sin = torch.sin(theta_ik)
        cos = torch.cos(theta_ik)
        
        self.register_buffer("sin", sin, persistent=False)
        self.register_buffer("cos", cos, persistent=False)
        
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        
        assert x.shape[-1] == self.d_k, "The last dim of input should be equal to dim of embedding."
        assert x.shape[-2] == token_positions.shape[-1], "token_positions length must match sequence length"
        sin_expend = self.sin[token_positions]
        cos_expend = self.cos[token_positions]

        x_even = x[...,::2]
        x_odd = x[...,1::2]

        y_even = x_even*cos_expend-x_odd*sin_expend
        y_odd = x_even*sin_expend+x_odd*cos_expend
        y = rearrange(torch.stack([y_even, y_odd], dim=-1), '... s d two -> ... s (d two)')
        return y