In [None]:
# import pandas as pd
# import numpy as np
# from collections import defaultdict
# import regex as re
# from multiprocessing import Pool
# from support.find_chunk_boundaries import find_chunk_boundaries
# from memory_profiler import profile
# import time, tracemalloc
# from dataclasses import dataclass
# from typing import BinaryIO, Iterable, Iterator
# import random

import torch
import torch.nn as nn
from einops import rearrange, einsum, reduce

## Q2.1


In [442]:
# Random image batch laid out as (batch, height, width, channel).
images = torch.randn((64, 128, 128, 3))
# Ten evenly spaced dimming factors covering [0.0, 1.0] inclusive.
dim_by = torch.linspace(0.0, 1.0, 10)

In [449]:
# Approach 1: explicit broadcasting. Insert singleton axes so that
# (b, 1, h, w, c) * (1, dim_value, 1, 1, 1) broadcasts to (b, dim_value, h, w, c).
dim_value = rearrange(dim_by, "dim_value -> 1 dim_value 1 1 1")
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")
dimmed_images_broadcast = images_rearr * dim_value

# Approach 2: the same outer product expressed as a single einsum.
dimmed_images = einsum(
    images, dim_by,
    "batch height width channel, dim_value -> batch dim_value height width channel"
)

# Both approaches must agree; previously the first result was silently
# overwritten by the second, so the equivalence was never actually checked.
torch.testing.assert_close(dimmed_images_broadcast, dimmed_images)

In [450]:
# Per the einsum pattern above, the layout is (batch, dim_value, height, width, channel).
dimmed_images.shape

torch.Size([64, 10, 128, 128, 3])

# Q 3.4.2

In [454]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        ## Construct a linear transformation module. This function should accept the following parameters:
        self.in_features = in_features ## final dimension of the input
        self.out_features = out_features ## final dimension of the output
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        w = torch.empty(in_features, out_features)
        std = torch.sqrt(torch.tensor(2.0/(in_features+out_features)))
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std.item(),a=-3*std.item(),b=3*std.item()))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Apply the linear transformation to the input
        output = einsum(
            self.weight, x,
            "in_dim out_dim, in_dim -> out_dim"
        )
        return output

# Q 3.4.3

In [None]:
class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        ## Construct an embedding module
        super().__init__()
        self.num_embeddings = num_embeddings ## Size of the vocabulary
        self.embedding_dim = embedding_dim ## Dimension of the embedding vectors
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters
        
        w = torch.empty(num_embeddings, embedding_dim)
        std = 1.0
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std,a=-3,b=3))
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        ## Lookup the embedding vectors for the given token IDs.
        return self.weight[token_ids]


# Q 3.5.1

In [527]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature gain.

    Computes ``x / sqrt(mean(x**2, dim=-1) + eps) * weights`` over the last
    dimension. The gain ``weights`` is initialised to ones.
    """

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        """Create the gain parameter.

        Args:
            d_model: Size of the last (feature) dimension.
            eps: Constant added to the mean square before the square root.
            device: Device on which to allocate the gain (default device if None).
            dtype: Data type of the gain (default dtype if None).
        """
        super().__init__()
        self.d_model = d_model  ## Hidden dimension of the model
        self.eps = eps  ## Epsilon value for numerical stability
        self.device = device  ## Device to store the parameters on
        self.dtype = dtype  ## Data type of the parameters

        # Allocate on the requested device/dtype; these arguments were
        # previously stored but never applied to the parameter.
        self.weights = nn.Parameter(torch.ones(d_model, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension.

        Args:
            x: Tensor of shape ``(..., d_model)``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        in_dtype = x.dtype
        # Upcast so that squaring low-precision inputs does not overflow.
        x = x.to(torch.float32)
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        result = x / rms * self.weights
        return result.to(in_dtype)

# Q 3.5.2

In [None]:
class SwiGLU(nn.Module):
    def __init__(self, d_model: int, d_ff: int | None = None):
        super().__init__()
        self.d_model = d_model ## Hidden dimension of the model
        if d_ff is None:
            q = round(d_model*8/3/64)
            self.d_ff = q*64
        else:
            self.d_ff = d_ff
        
        self.w1_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model))
        self.w2_weight = nn.Parameter(torch.randn(self.d_model, self.d_ff))
        self.w3_weight = nn.Parameter(torch.randn(self.d_ff, self.d_model))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w1x = einsum(
            self.w1_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        w3x = einsum(
            self.w3_weight, x,
            "d_ff d_model, ... d_model -> ... d_ff"
        )
        SiLUw1x = w1x*torch.sigmoid(w1x)
        part2 = SiLUw1x * w3x
        result = einsum(
            self.w2_weight, part2,
            "d_model d_ff, ... d_ff -> ... d_model"
        )
        return result
        

In [533]:
# Tiny, hand-checkable setup for the SwiGLU projections.
d_model, d_ff = 2, 3
x = torch.randn(2, 3, d_model)
# All-ones weights turn each projection into a plain sum over the contracted axis.
w1_weight = torch.ones((d_ff, d_model))
w2_weight = torch.ones((d_model, d_ff))
w3_weight = torch.ones((d_ff, d_model))
x

tensor([[[-0.8280, -1.0512],
         [-0.9995,  0.6266],
         [-0.2763,  1.1610]],

        [[ 1.8855,  1.7226],
         [-0.2060, -0.7651],
         [ 0.1694,  1.8273]]])

In [536]:
# Sum of the first token's features; with the all-ones w1_weight this is the
# value every entry of w1x[0, 0] should take.
x[0,0,:].sum()

tensor(-1.8792)

In [534]:
# Project x with the all-ones w1_weight: contracts d_model and appends a d_ff axis.
w1x = einsum(w1_weight, x, "d_ff d_model, ... d_model -> ... d_ff")
w1x

tensor([[[-1.8792, -1.8792, -1.8792],
         [-0.3729, -0.3729, -0.3729],
         [ 0.8847,  0.8847,  0.8847]],

        [[ 3.6081,  3.6081,  3.6081],
         [-0.9711, -0.9711, -0.9711],
         [ 1.9966,  1.9966,  1.9966]]])

# Q 3.5.3

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE).

    Consecutive feature pairs ``(x[2k], x[2k+1])`` are rotated by the angle
    ``position / heta ** (2k / d_k)``. The cosine and sine tables for all
    positions below ``max_seq_len`` are precomputed once and stored as
    non-persistent buffers.
    """

    def __init__(self, heta: float, d_k: int, max_seq_len: int, device=None):
        """Precompute the rotation tables.

        Args:
            heta: The RoPE base (the Theta value); the parameter name is kept
                as-is for compatibility with existing callers.
            d_k: Dimension of the query/key vectors; must be even.
            max_seq_len: Largest position (exclusive) that will be looked up.
            device: Device on which to store the buffers.

        Raises:
            ValueError: If ``d_k`` is odd.
        """
        super().__init__()
        if d_k % 2 != 0:
            raise ValueError(f"d_k must be even for RoPE, got {d_k}")
        self.heta = heta  ## $\\Theta$ value for the RoPE
        self.d_k = d_k  ## dimension of query and key vectors
        self.max_seq_len = max_seq_len  ## Maximum sequence length that will be inputted
        self.device = device  ## Device to store the buffer on

        # One inverse frequency per feature pair: heta ** (-2k / d_k), k = 0 .. d_k/2 - 1.
        exponents = torch.arange(0, d_k, 2, device=device, dtype=torch.float32) / d_k
        inv_freq = 1.0 / (float(heta) ** exponents)
        positions = torch.arange(max_seq_len, device=device, dtype=torch.float32)
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, d_k // 2)

        # Buffers rather than parameters: the tables are fixed, not learned.
        self.register_buffer("cos_cached", angles.cos(), persistent=False)
        self.register_buffer("sin_cached", angles.sin(), persistent=False)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` according to ``token_positions``.

        Args:
            x: Tensor of shape ``(..., seq_len, d_k)``.
            token_positions: Integer tensor of shape ``(..., seq_len)`` whose
                leading dimensions broadcast against those of ``x``; values
                must lie in ``[0, max_seq_len)``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        in_dtype = x.dtype
        x = x.to(torch.float32)

        cos = self.cos_cached[token_positions]  # (..., seq_len, d_k // 2)
        sin = self.sin_cached[token_positions]

        x_even = x[..., 0::2]
        x_odd = x[..., 1::2]
        rot_even = x_even * cos - x_odd * sin
        rot_odd = x_even * sin + x_odd * cos

        # Re-interleave the rotated pairs back into the original feature order.
        out = torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)
        return out.to(in_dtype)

tensor([[ 0.8015,  1.2299, -0.0326,  1.3418],
        [-0.2687, -0.0065,  0.4680,  0.7164]])

In [530]:
# Initialise RMSNorm
# NOTE(review): `eps` is not defined in any visible cell, and `d_model` was last
# set to 2 above while the recorded outputs below have 4 features — presumably
# both come from cells no longer in the notebook; confirm before re-running.
rmsnorm = RMSNorm(d_model, eps)
rmsnorm.weights.data.fill_(1.0)  # gamma = 1

tensor([1., 1., 1., 1.])

In [531]:
# Forward pass
y = rmsnorm(x)
y

tensor([[ 0.7347,  1.1273, -0.0299,  1.2299],
        [-0.2463, -0.0059,  0.4290,  0.6567]], grad_fn=<MulBackward0>)

In [532]:
# Take the first row of x to re-derive the RMSNorm result by hand.
test = x[0]
test

tensor([ 0.8015,  1.2299, -0.0326,  1.3418])

In [526]:
# Manual RMS of the row: sqrt(sum of squares / d_model), without eps.
torch.sqrt((test**2).sum()/d_model)

tensor(1.1235)

In [524]:
# Manual normalisation with eps inside the square root, mirroring RMSNorm.forward.
# NOTE(review): the recorded output comes from execution In [524], earlier than the
# In [532] cell that set `test`, and does not match y[0] above — likely stale; re-run to confirm.
test/torch.sqrt((test**2).sum()/d_model+eps)

tensor([-0.4036,  1.1004, -1.6200, -0.0424])

In [None]:







# Check the RMS of each sample
rms = y.pow(2).mean(dim=-1).sqrt()

print("Input:", x)
print("Output:", y)
print("RMS per sample:", rms)