In [None]:
# import pandas as pd
# import numpy as np
# from collections import defaultdict
# import regex as re
# from multiprocessing import Pool
# from support.find_chunk_boundaries import find_chunk_boundaries
# from memory_profiler import profile
# import time, tracemalloc
# from dataclasses import dataclass
# from typing import BinaryIO, Iterable, Iterator
# import random

import torch
import torch.nn as nn
from einops import rearrange, einsum, reduce, repeat

## Q2.1


In [442]:
images = torch.randn(64, 128, 128, 3) # (batch, height, width, channel)
dim_by = torch.linspace(start=0.0, end=1.0, steps=10)

In [449]:
dim_value = rearrange(dim_by, "dim_value -> 1 dim_value 1 1 1")
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")
dimmed_images = images_rearr * dim_value
dimmed_images = einsum(
    images, dim_by,
    "batch height width channel, dim_value -> batch dim_value height width channel"
)

In [450]:
dimmed_images.shape

torch.Size([64, 10, 128, 128, 3])

# Q 3.4.2

In [454]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        ## Construct a linear transformation module. This function should accept the following parameters:
        self.in_features = in_features ## final dimension of the input
        self.out_features = out_features ## final dimension of the output
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        w = torch.empty(in_features, out_features)
        std = torch.sqrt(torch.tensor(2.0/(in_features+out_features)))
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std.item(),a=-3*std.item(),b=3*std.item()))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Apply the linear transformation to the input
        output = einsum(
            self.weight, x,
            "in_dim out_dim, in_dim -> out_dim"
        )
        return output

# Q 3.4.3

In [None]:
class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        ## Construct an embedding module
        super().__init__()
        self.num_embeddings = num_embeddings ## Size of the vocabulary
        self.embedding_dim = embedding_dim ## Dimension of the embedding vectors
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters
        
        w = torch.empty(num_embeddings, embedding_dim)
        std = 1.0
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std,a=-3,b=3))
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        ## Lookup the embedding vectors for the given token IDs.
        return self.weight[token_ids]


# Q 3.5.1

In [527]:
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        ## Construct the RMSNorm module.
        super().__init__()
        self.d_model = d_model ## Hidden dimension of the model
        self.eps = eps ## Epsilon value for numerical stability
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        self.weights = nn.Parameter(torch.ones(d_model))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Process an input tensor of shape
        
        in_dtype = x.dtype
        x = x.to(torch.float32)
        x_squaremean = reduce(
            x**2, "... d_model -> ... 1", 'mean'
        )
        x_RMS = (x_squaremean+self.eps).sqrt()
        result = x / x_RMS * self.weights
        return result.to(in_dtype)

# Q 3.5.2

In [None]:
class SwiGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int | None = None):
        super().__init__()
        self.d_model = d_model ## Hidden dimension of the model
        if d_ff is None:
            q = round(d_model*8/3/64)
            self.d_ff = q*64
        else:
            self.d_ff = d_ff
        
        self.w1_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model))
        self.w2_weight = nn.Parameter(torch.randn(self.d_model, self.d_ff))
        self.w3_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w1x = einsum(
            self.w1_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        w3x = einsum(
            self.w3_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        SiLUw1x = w1x*torch.sigmoid(w1x)
        part2 = SiLUw1x * w3x
        result = einsum(
            self.w2_weight, part2,
            "d_model d_ff, ... d_ff -> ... d_model"
        )
        return result

In [533]:
d_model = 2
d_ff = 3
x = torch.randn(2, 3, d_model) 
w1_weight = torch.ones(d_ff, d_model)
w2_weight = torch.ones(d_model, d_ff)
w3_weight = torch.ones(d_ff, d_model)
x

tensor([[[-0.8280, -1.0512],
         [-0.9995,  0.6266],
         [-0.2763,  1.1610]],

        [[ 1.8855,  1.7226],
         [-0.2060, -0.7651],
         [ 0.1694,  1.8273]]])

In [536]:
x[0,0,:].sum()

tensor(-1.8792)

In [534]:
w1x = einsum(
            w1_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
w1x

tensor([[[-1.8792, -1.8792, -1.8792],
         [-0.3729, -0.3729, -0.3729],
         [ 0.8847,  0.8847,  0.8847]],

        [[ 3.6081,  3.6081,  3.6081],
         [-0.9711, -0.9711, -0.9711],
         [ 1.9966,  1.9966,  1.9966]]])

# Q 3.5.3

In [529]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device: torch.device | None = None):
        ## Construct the RoPE module and create buffers if needed.
        super().__init__()
        self.theta = theta ## $\\Theta$ value for the RoPE
        self.d_k = d_k ## dimension of query and key vectors
        self.max_seq_len = max_seq_len ## Maximum sequence length that will be inputted
        self.device = device ## Device to store the buffer on

        theta_ik = torch.Tensor([[i/(self.theta**((2*k-1)/self.d_k)) for k in range(1,self.d_k//2+2)] for i in range(self.max_seq_len)])
        sin = torch.sin(theta_ik)
        cos = torch.cos(theta_ik)
        
        self.register_buffer("sin", sin, persistent=False)
        self.register_buffer("cos", cos, persistent=False)
        
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        
        sin_expend = self.sin[token_positions]
        cos_expend = self.cos[token_positions]

        x_even = x[...,::2]
        x_odd = x[...,1::2]

        y_even = x_even*cos_expend-x_odd*sin_expend
        y_odd = x_odd*sin_expend+x_even*cos_expend
        y = rearrange(torch.stack([y_even, y_odd], dim=-1), '... s d two -> ... s (two d)')
        return y

tensor([[ 0.8015,  1.2299, -0.0326,  1.3418],
        [-0.2687, -0.0065,  0.4680,  0.7164]])

In [660]:
theta = 10000
d_k = 4
max_seq_len = 100
seq_len = 3

In [733]:
x = [1,2,3]
x[:-1]

[1, 2]

In [732]:
dim_index = torch.arange(d_k // 2, dtype=torch.float32)
position_index = torch.arange(max_seq_len, dtype=torch.float32)
theta_inv_index = theta**(-2*dim_index/d_k)
theta_ik = einsum(
    position_index, theta_inv_index,
    "s, d -> s d"
)
theta_ik

tensor([[0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-02],
        [2.0000e+00, 2.0000e-02],
        [3.0000e+00, 3.0000e-02],
        [4.0000e+00, 4.0000e-02],
        [5.0000e+00, 5.0000e-02],
        [6.0000e+00, 6.0000e-02],
        [7.0000e+00, 7.0000e-02],
        [8.0000e+00, 8.0000e-02],
        [9.0000e+00, 9.0000e-02],
        [1.0000e+01, 1.0000e-01],
        [1.1000e+01, 1.1000e-01],
        [1.2000e+01, 1.2000e-01],
        [1.3000e+01, 1.3000e-01],
        [1.4000e+01, 1.4000e-01],
        [1.5000e+01, 1.5000e-01],
        [1.6000e+01, 1.6000e-01],
        [1.7000e+01, 1.7000e-01],
        [1.8000e+01, 1.8000e-01],
        [1.9000e+01, 1.9000e-01],
        [2.0000e+01, 2.0000e-01],
        [2.1000e+01, 2.1000e-01],
        [2.2000e+01, 2.2000e-01],
        [2.3000e+01, 2.3000e-01],
        [2.4000e+01, 2.4000e-01],
        [2.5000e+01, 2.5000e-01],
        [2.6000e+01, 2.6000e-01],
        [2.7000e+01, 2.7000e-01],
        [2.8000e+01, 2.8000e-01],
        [2.900

In [731]:
theta_ik = torch.Tensor([[i/(theta**((2*k-2)/d_k)) for k in range(1,d_k//2+1)] for i in range(max_seq_len)])
sin = torch.sin(theta_ik)
cos = torch.cos(theta_ik)
theta_ik

tensor([[0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-02],
        [2.0000e+00, 2.0000e-02],
        [3.0000e+00, 3.0000e-02],
        [4.0000e+00, 4.0000e-02],
        [5.0000e+00, 5.0000e-02],
        [6.0000e+00, 6.0000e-02],
        [7.0000e+00, 7.0000e-02],
        [8.0000e+00, 8.0000e-02],
        [9.0000e+00, 9.0000e-02],
        [1.0000e+01, 1.0000e-01],
        [1.1000e+01, 1.1000e-01],
        [1.2000e+01, 1.2000e-01],
        [1.3000e+01, 1.3000e-01],
        [1.4000e+01, 1.4000e-01],
        [1.5000e+01, 1.5000e-01],
        [1.6000e+01, 1.6000e-01],
        [1.7000e+01, 1.7000e-01],
        [1.8000e+01, 1.8000e-01],
        [1.9000e+01, 1.9000e-01],
        [2.0000e+01, 2.0000e-01],
        [2.1000e+01, 2.1000e-01],
        [2.2000e+01, 2.2000e-01],
        [2.3000e+01, 2.3000e-01],
        [2.4000e+01, 2.4000e-01],
        [2.5000e+01, 2.5000e-01],
        [2.6000e+01, 2.6000e-01],
        [2.7000e+01, 2.7000e-01],
        [2.8000e+01, 2.8000e-01],
        [2.900

In [719]:
dim_index = torch.arange(0, d_k, 2, dtype=torch.float32)
inv_freq = theta ** (-dim_index / d_k)
positions = torch.arange(
            max_seq_len, dtype=torch.float32
        )
freqs = einsum(
            positions, inv_freq, "max_seq_len, d_k_2 -> max_seq_len d_k_2"
        ) 
cos_cached = freqs.cos()

In [722]:
freqs

tensor([[0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-02],
        [2.0000e+00, 2.0000e-02],
        [3.0000e+00, 3.0000e-02],
        [4.0000e+00, 4.0000e-02],
        [5.0000e+00, 5.0000e-02],
        [6.0000e+00, 6.0000e-02],
        [7.0000e+00, 7.0000e-02],
        [8.0000e+00, 8.0000e-02],
        [9.0000e+00, 9.0000e-02],
        [1.0000e+01, 1.0000e-01],
        [1.1000e+01, 1.1000e-01],
        [1.2000e+01, 1.2000e-01],
        [1.3000e+01, 1.3000e-01],
        [1.4000e+01, 1.4000e-01],
        [1.5000e+01, 1.5000e-01],
        [1.6000e+01, 1.6000e-01],
        [1.7000e+01, 1.7000e-01],
        [1.8000e+01, 1.8000e-01],
        [1.9000e+01, 1.9000e-01],
        [2.0000e+01, 2.0000e-01],
        [2.1000e+01, 2.1000e-01],
        [2.2000e+01, 2.2000e-01],
        [2.3000e+01, 2.3000e-01],
        [2.4000e+01, 2.4000e-01],
        [2.5000e+01, 2.5000e-01],
        [2.6000e+01, 2.6000e-01],
        [2.7000e+01, 2.7000e-01],
        [2.8000e+01, 2.8000e-01],
        [2.900

In [721]:
theta_ik

tensor([[0.0000e+00, 0.0000e+00],
        [1.0000e+00, 1.0000e-02],
        [2.0000e+00, 2.0000e-02],
        [3.0000e+00, 3.0000e-02],
        [4.0000e+00, 4.0000e-02],
        [5.0000e+00, 5.0000e-02],
        [6.0000e+00, 6.0000e-02],
        [7.0000e+00, 7.0000e-02],
        [8.0000e+00, 8.0000e-02],
        [9.0000e+00, 9.0000e-02],
        [1.0000e+01, 1.0000e-01],
        [1.1000e+01, 1.1000e-01],
        [1.2000e+01, 1.2000e-01],
        [1.3000e+01, 1.3000e-01],
        [1.4000e+01, 1.4000e-01],
        [1.5000e+01, 1.5000e-01],
        [1.6000e+01, 1.6000e-01],
        [1.7000e+01, 1.7000e-01],
        [1.8000e+01, 1.8000e-01],
        [1.9000e+01, 1.9000e-01],
        [2.0000e+01, 2.0000e-01],
        [2.1000e+01, 2.1000e-01],
        [2.2000e+01, 2.2000e-01],
        [2.3000e+01, 2.3000e-01],
        [2.4000e+01, 2.4000e-01],
        [2.5000e+01, 2.5000e-01],
        [2.6000e+01, 2.6000e-01],
        [2.7000e+01, 2.7000e-01],
        [2.8000e+01, 2.8000e-01],
        [2.900

In [734]:
data = [
    [1,2,3,4],
    [5,6,7,8],
    [9,10,11,12]
]
x = torch.tensor(data)
data = [2,0,1]
position = torch.tensor(data)
[x.shape, position.shape]

[torch.Size([3, 4]), torch.Size([3])]

In [739]:
[x.shape[:-1]==position.shape[:]]

[True]

In [702]:
sin_expend = sin[position]
cos_expend = cos[position]

x_even = x[...,::2]
x_odd = x[...,1::2]

y_even = x_even*cos_expend-x_odd*sin_expend
y_odd = x_even*sin_expend+x_odd*cos_expend
y = rearrange(torch.stack([y_even, y_odd], dim=-1), '... s d two -> ... s (d two)')

y

tensor([[ 0.5827,  2.1588,  2.9920,  4.0060],
        [ 5.0000,  6.0000,  7.0000,  8.0000],
        [ 7.9567, 10.8485, 10.9880, 12.0110]])

In [697]:
seq_index = 0
posi = position[seq_index]
x_sub = x[seq_index,:]
[x_sub[0]*cos[posi,0]-x_sub[1]*sin[posi,0],x_sub[0]*sin[posi,0]+x_sub[1]*cos[posi,0],x_sub[2]*cos[posi,1]-x_sub[3]*sin[posi,1],x_sub[2]*sin[posi,1]+x_sub[3]*cos[posi,1]]

[tensor(0.5827), tensor(2.1588), tensor(2.9920), tensor(4.0060)]

In [698]:
y_even

tensor([[ 0.5827,  2.9920],
        [ 5.0000,  7.0000],
        [ 7.9567, 10.9880]])

In [700]:
temp = torch.stack([y_even, y_odd], dim=-1)
temp.shape

torch.Size([3, 2, 2])

In [701]:
temp

tensor([[[ 0.5827,  2.1588],
         [ 2.9920,  4.0060]],

        [[ 5.0000,  6.0000],
         [ 7.0000,  8.0000]],

        [[ 7.9567, 10.8485],
         [10.9880, 12.0110]]])

In [680]:
sin_expend

tensor([[0.1987, 0.0020],
        [0.0000, 0.0000],
        [0.0998, 0.0010]])

In [681]:
sin[posi,0]

tensor(0.1987)

In [675]:
cos[1,0]

tensor(0.9950)

In [676]:
cos

tensor([[ 1.0000,  1.0000],
        [ 0.9950,  1.0000],
        [ 0.9801,  1.0000],
        [ 0.9553,  1.0000],
        [ 0.9211,  1.0000],
        [ 0.8776,  1.0000],
        [ 0.8253,  1.0000],
        [ 0.7648,  1.0000],
        [ 0.6967,  1.0000],
        [ 0.6216,  1.0000],
        [ 0.5403,  0.9999],
        [ 0.4536,  0.9999],
        [ 0.3624,  0.9999],
        [ 0.2675,  0.9999],
        [ 0.1700,  0.9999],
        [ 0.0707,  0.9999],
        [-0.0292,  0.9999],
        [-0.1288,  0.9999],
        [-0.2272,  0.9998],
        [-0.3233,  0.9998],
        [-0.4161,  0.9998],
        [-0.5048,  0.9998],
        [-0.5885,  0.9998],
        [-0.6663,  0.9997],
        [-0.7374,  0.9997],
        [-0.8011,  0.9997],
        [-0.8569,  0.9997],
        [-0.9041,  0.9996],
        [-0.9422,  0.9996],
        [-0.9710,  0.9996],
        [-0.9900,  0.9996],
        [-0.9991,  0.9995],
        [-0.9983,  0.9995],
        [-0.9875,  0.9995],
        [-0.9668,  0.9994],
        [-0.9365,  0

In [635]:
theta = 10000
d_k = 4
max_seq_len = 100
batch_size = 2
seq_len = 3

In [None]:
data = [
    [
        [1,2,3,4],
        [5,6,7,8],
        [9,10,11,12]
    ],
    [
        [11,12,13,14],
        [15,16,17,18],
        [19,20,21,22]
    ]
]
x = torch.tensor(data)
data = [
    [0,1,2],
    [2,0,1]
]
position = torch.tensor(data)
[x.shape, position.shape]

[torch.Size([2, 3, 4]), torch.Size([2, 3])]

In [741]:
[x.shape[:-1]==position.shape[:]]

[True]

# Q 3.5.4

In [877]:
def softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    x_max = torch.max(x, dim=dim, keepdim=True).values
    x_subtract_max = x-x_max
    x_subtract_max_exp = torch.exp(x_subtract_max)
    x_subtract_max_exp_sum = torch.sum(x_subtract_max_exp, dim=dim, keepdim=True)
    y = x_subtract_max_exp/x_subtract_max_exp_sum
    return y

def _softmax_1dim(x: torch.Tensor) -> torch.Tensor:
    x_subtract_max = x-x.max()
    x_subtract_max_exp = torch.exp(x_subtract_max)
    return x_subtract_max_exp/x_subtract_max_exp.sum()



In [869]:
data = [
    [
        [1,2,3,4],
        [5,6,7,8],
        [9,10,11,12]
    ],
    [
        [11,12,13,14],
        [15,16,17,18],
        [19,110,111,112]
    ]
]
x = torch.tensor(data)
dim = 0

In [875]:
x_max = torch.max(x, dim=dim, keepdim=True).values
x_subtract_max = x-x.max()
x_subtract_max_exp = torch.exp(x_subtract_max)
x_subtract_max_exp_sum = torch.sum(x_subtract_max_exp, dim=dim, keepdim=True)
x_subtract_max_exp/x_subtract_max_exp_sum

tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [3.4348e-05, 4.2039e-44, 3.7835e-44, 3.7835e-44]],

        [[1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
         [1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00],
         [9.9997e-01, 1.0000e+00, 1.0000e+00, 1.0000e+00]]])

In [876]:
x_subtract_max_exp_sum

tensor([[[1.4013e-44, 3.7835e-44, 1.0089e-43, 2.7465e-43],
         [7.4689e-43, 2.0305e-42, 5.5211e-42, 1.5008e-41],
         [4.0797e-41, 1.3534e-01, 3.6788e-01, 1.0000e+00]]])

In [None]:
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> torch.Tensor:
    d_k = Q.shape[-1]
    QK = einsum(
        Q, K, "... n d_k, ... m d_k -> ... n m"
    )
    QK_scaled = QK/torch.tensor(d_k).sqrt()
    if mask is not None:
        M = torch.where(mask, torch.tensor(0.0), torch.tensor(float('-inf')))
        QK_scaled += M
    QK_scaled_softmax = softmax(QK_scaled, Q.dim()-1)
    y = einsum(
        QK_scaled_softmax, V, "... n m, ... m d_v -> ... n d_v"
    )
    return y


In [908]:
batch = 1
head = 1
n = 3
m = 4
d_k = 8
d_v = 6


# 构造 q, k, v：每个维度是 (batch * head, seq, d)
q = torch.randn(batch * head, n, d_k)
k = torch.randn(batch * head, m, d_k)
v = torch.randn(batch * head, m, d_v)

# 构造 mask: (batch * head, query_len, key_len)
mask = torch.randint(0, 2, (batch * head, n, m)).float()

# reshape 到标准的 attention 输入格式
Q, K, V = (rearrange(x, "(batch head) seq d -> batch head seq d", head=head) for x in (q, k, v))
mask = rearrange(mask, "(batch head) query key -> batch head query key", head=head)


In [916]:
d_k = Q.shape[-1]
QK = einsum(
    Q, K, "... n d_k, ... m d_k -> ... n m"
)
QK_scaled = QK/np.sqrt(d_k)
QK_scaled_masked = QK_scaled # + mask*torch.tensor(float('-inf')) 
QK_scaled_masked_softmax = softmax(QK_scaled_masked, Q.dim()-1)
y = einsum(
    QK_scaled_masked_softmax, V, "... n m, ... m d_v -> ... n d_v"
)
y


tensor([[[[-0.1039,  0.6112, -0.5392, -0.1096,  0.3800, -1.3136],
          [ 0.2571,  1.3238,  0.3052,  0.6354,  0.9631, -0.6820],
          [ 0.1931,  1.1706,  0.1135,  0.5356,  0.8351, -0.8478]]]])

In [913]:
QK_scaled_masked

tensor([[[[-0.3037, -0.1392,  1.5195,  1.0188],
          [ 0.1006, -2.0256, -2.0816,  0.9052],
          [ 0.6530, -0.7269,  0.1320,  1.6333]]]])

In [914]:
QK_scaled_masked_softmax

tensor([[[[0.0825, 0.0972, 0.5107, 0.3096],
          [0.2883, 0.0344, 0.0325, 0.6447],
          [0.2217, 0.0558, 0.1317, 0.5909]]]])

In [897]:
[QK_scaled.shape,mask.shape]

[torch.Size([2, 2, 3, 4]), torch.Size([2, 2, 3, 4])]