In [None]:
# import pandas as pd
# import numpy as np
# from collections import defaultdict
# import regex as re
# from multiprocessing import Pool
# from support.find_chunk_boundaries import find_chunk_boundaries
# from memory_profiler import profile
# import time, tracemalloc
# from dataclasses import dataclass
# from typing import BinaryIO, Iterable, Iterator
# import random

import torch
import torch.nn as nn
from einops import rearrange, einsum, reduce

## Q2.1


In [442]:
# Random stand-in for a batch of 3-channel images, laid out as (batch, height, width, channel).
images = torch.randn((64, 128, 128, 3))
# Ten evenly spaced scaling factors running from 0.0 to 1.0 (both endpoints included).
dim_by = torch.linspace(0.0, 1.0, 10)

In [449]:
# Approach 1: explicit broadcasting. Singleton axes are inserted so that
#   dim_value    has shape (1, dim_value, 1, 1, 1)
#   images_rearr has shape (batch, 1, height, width, channel)
# and their product broadcasts to (batch, dim_value, height, width, channel).
dim_value = rearrange(dim_by, "dim_value -> 1 dim_value 1 1 1")
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel")
dimmed_images_broadcast = images_rearr * dim_value

# Approach 2: the same outer product expressed as a single einsum.
dimmed_images = einsum(
    images, dim_by,
    "batch height width channel, dim_value -> batch dim_value height width channel"
)

# The broadcast result used to be stored in `dimmed_images` and then silently
# overwritten by the einsum, so it was never looked at. Keep it under its own
# name and verify that the two formulations really agree.
torch.testing.assert_close(dimmed_images, dimmed_images_broadcast)

In [450]:
# Axis order follows the einsum output pattern: (batch, dim_value, height, width, channel).
dimmed_images.shape

torch.Size([64, 10, 128, 128, 3])

# Q 3.4.2

In [454]:
class Linear(nn.Module):
    def __init__(self, in_features: int, out_features: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        super().__init__()
        ## Construct a linear transformation module. This function should accept the following parameters:
        self.in_features = in_features ## final dimension of the input
        self.out_features = out_features ## final dimension of the output
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters

        w = torch.empty(in_features, out_features)
        std = torch.sqrt(torch.tensor(2.0/(in_features+out_features)))
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std.item(),a=-3*std.item(),b=3*std.item()))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ## Apply the linear transformation to the input
        output = einsum(
            self.weight, x,
            "in_dim out_dim, in_dim -> out_dim"
        )
        return output

# Q 3.4.3

In [455]:
class Embedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, device: torch.device | None = None, dtype: torch.dtype | None = None):
        ## Construct an embedding module
        super().__init__()
        self.num_embeddings = num_embeddings ## Size of the vocabulary
        self.embedding_dim = embedding_dim ## Dimension of the embedding vectors
        self.device = device ## Device to store the parameters on
        self.dtype = dtype ## Data type of the parameters
        
        w = torch.empty(num_embeddings, embedding_dim)
        std = 1.0
        self.weight = nn.Parameter(nn.init.trunc_normal_(w, mean=0.0, std=std,a=-3,b=3))
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        ## Lookup the embedding vectors for the given token IDs.
        return self.weight[token_ids]


# Q 3.5.1

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation over the last dimension.

    Computes ``y = x / sqrt(mean(x**2, dim=-1) + eps) * weights`` where
    ``weights`` is a learnable gain of shape ``(d_model,)`` initialised to ones.

    Args:
        d_model: Hidden dimension of the model (size of the last input axis).
        eps: Epsilon value added under the square root for numerical stability.
        device: Device to store the parameters on (``None`` uses the PyTorch default).
        dtype: Data type of the parameters (``None`` uses the PyTorch default).
    """

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model  ## Hidden dimension of the model
        self.eps = eps  ## Epsilon value for numerical stability
        self.device = device  ## Device to store the parameters on
        self.dtype = dtype  ## Data type of the parameters

        # Create the gain with the requested device/dtype; these arguments
        # were previously accepted but never applied.
        self.weights = nn.Parameter(torch.ones(d_model, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` along its last dimension.

        Args:
            x: Tensor of shape ``(..., d_model)``.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        in_dtype = x.dtype
        # Upcast so the squared values are accumulated in float32.
        x = x.to(torch.float32)
        # `x ^ 2` is bitwise XOR in Python (and fails on float tensors);
        # squaring must be spelled `x ** 2` / `x.pow(2)`.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = torch.sqrt(mean_square + self.eps)
        result = x / rms * self.weights
        # Hand the result back in the caller's original dtype.
        return result.to(in_dtype)

In [491]:
import torch
# NOTE(review): this import rebinds the name RMSNorm, so the cells below exercise
# the implementation from support.transformer, not the RMSNorm class defined
# earlier in this notebook.
from support.transformer import RMSNorm

# Hyperparameters
d_model = 4
eps = 1e-5


In [493]:
# Build a random input with 2 rows of d_model features each
x = torch.randn(2, d_model)
x

tensor([[-0.3636, -0.3152,  0.8715, -1.3661],
        [-0.8370, -0.5851, -0.4513, -1.0230]])

In [494]:
# Initialise RMSNorm
rmsnorm = RMSNorm(d_model, eps)
# Set the gain to 1; fill_ returns the filled tensor, which is what this cell displays.
rmsnorm.weights.data.fill_(1.0)  # gamma = 1

tensor([1., 1., 1., 1.])

In [495]:
# Forward pass
y = rmsnorm(x)

In [496]:
# Display the RMSNorm output for inspection.
y

tensor([[-0.3204, -0.2778,  0.7680, -1.2039],
        [-0.7376, -0.5157, -0.3977, -0.9015]], grad_fn=<MulBackward0>)

In [497]:
# Second input row, used for the manual check in the next cell.
x[1]

tensor([-0.8370, -0.5851, -0.4513, -1.0230])

In [502]:
# Manual reference for the second row: x / sqrt(sum(x**2) / d_model + eps), i.e. RMSNorm with unit gain.
# NOTE(review): the recorded output of this cell does not match y[1] displayed above —
# confirm that the imported RMSNorm reduces over the last dimension only.
x[1]/torch.sqrt((x[1]**2).sum()/d_model+eps)

tensor([-1.1055, -0.7728, -0.5960, -1.3511])

In [None]:







# Check the root-mean-square of each output row (reduced over the last axis).
rms = y.pow(2).mean(dim=-1).sqrt()

# Print each labelled tensor in turn.
for label, value in (("Input:", x), ("Output:", y), ("RMS per sample:", rms)):
    print(label, value)