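# RMSNorm (Root Mean Square Normalization)
Llama uses Root Mean Square Normalization (RMSNorm) as its normalization layer.

Like LayerNorm, it is meant to mitigate internal covariate shift, but it only rescales each vector by its root mean square and skips the mean-centering step.

The RMSNorm formula is as follows.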

![equation of RMSNorm](https://miro.medium.com/v2/resize:fit:1400/1*tjjimBzPdzuWW73y444Uzg.jpeg)

(Paper: https://dl.acm.org/doi/pdf/10.5555/3454287.3455397)
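
Written out in LaTeX, and assuming the image above shows the standard definition from the linked paper, the normalization of an input vector $\mathbf{a} \in \mathbb{R}^n$ with a learnable gain $\mathbf{g}$ is roughly:

$$
\bar{a}_i = \frac{a_i}{\mathrm{RMS}(\mathbf{a})}\, g_i,
\qquad
\mathrm{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_{j=1}^{n} a_j^2}
$$

The llama3 code in this notebook additionally adds a small constant $\epsilon$ under the square root, so it divides by $\sqrt{\frac{1}{n}\sum_{j} a_j^2 + \epsilon}$, and the gain $\mathbf{g}$ corresponds to `self.weight`.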

In [2]:
import torch

In [3]:
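# RMSNorm class from llama3
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # small constant that keeps rsqrt finite
        # learnable per-feature scale, initialized to 1
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # divide x by its root mean square, taken over the last dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in float32, then cast back to the input dtype
        output = self._norm(x.float()).type_as(x)
        return output * self.weight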
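
As a quick sanity check of what the class computes, the sketch below applies the same formula to a random tensor and confirms that each normalized vector ends up with a root mean square close to 1. The function name `rms_norm` and the tensor shape `(2, 3, 8)` are assumptions made for illustration; they are not part of the llama3 code.

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Hypothetical helper: same math as RMSNorm._norm followed by the weight scaling.
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * inv_rms * weight


torch.manual_seed(0)
x = torch.randn(2, 3, 8) * 5.0  # assumed layout (batch, seq_len, dim), deliberately large scale
weight = torch.ones(8)          # same initial value as self.weight

y = rms_norm(x, weight)
rms_after = y.pow(2).mean(-1).sqrt()  # RMS of each normalized vector

print(y.shape)    # same shape as the input
print(rms_after)  # every entry is close to 1
```

With the weight left at its initial value of ones, the output scale no longer depends on the input scale. During training the weight can learn a different scale per feature.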

**Note: torch.rsqrt**

torch.rsqrt(x) returns the element-wise reciprocal square root, 1/sqrt(x).

Official documentation: https://pytorch.org/docs/stable/generated/torch.rsqrt.html

In [10]:
# torch.rsqrt
sample_tensor = torch.rand(4)
print("sample_tensor:", sample_tensor, "\n")

# manual implementation
custom_rsqrt = 1 / torch.sqrt(sample_tensor)

# compare the two results
print("rsqrt:", torch.rsqrt(sample_tensor))
print("manual rsqrt:", custom_rsqrt)

sample_tensor: tensor([0.5521, 0.0182, 0.0376, 0.7472])

rsqrt: tensor([1.3459, 7.4214, 5.1544, 1.1569])
manual rsqrt: tensor([1.3459, 7.4214, 5.1544, 1.1569])
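
The `eps` term in `_norm` matters because rsqrt of zero is infinite. The sketch below uses an all-zero vector as an assumed worst case: without `eps` the result is NaN, and with `eps` it stays finite. The variable names here are made up for the example.

```python
import torch

eps = 1e-6  # same default as in RMSNorm.__init__
x = torch.zeros(4)  # all-zero vector, so the mean of squares is exactly 0
mean_sq = x.pow(2).mean(-1, keepdim=True)

without_eps = x * torch.rsqrt(mean_sq)     # 0 * inf -> nan
with_eps = x * torch.rsqrt(mean_sq + eps)  # 0 * 1000 -> 0

print(torch.rsqrt(mean_sq))  # inf
print(without_eps)           # nan in every position
print(with_eps)              # zeros
```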


**Note: torch.mean**

mean computes the average, which most readers will already know.
What matters more here is the arguments: RMSNorm calls it with dim=-1 and keepdim=True.

In [30]:
# Comparing dim

sample_tensor = torch.rand((2, 3))
print(sample_tensor)
print("shape of sample_tensor:", sample_tensor.size())

# dim=None, keepdim=False => mean over every value in the tensor
mean_tensor = sample_tensor.mean()
print("\n ----------<dim=None, keepdim=False>----------")
print(mean_tensor)
print("shape of mean_tensor:", mean_tensor.size())

# dim=0, keepdim=False => mean along axis 0
mean_tensor = sample_tensor.mean(dim=0)
print("\n ----------<dim=0, keepdim=False>----------")
print(mean_tensor)
print("shape of mean_tensor:", mean_tensor.size())

# dim=1, keepdim=False => mean along axis 1
mean_tensor = sample_tensor.mean(dim=1)
print("\n ----------<dim=1, keepdim=False>----------")
print(mean_tensor)
print("shape of mean_tensor:", mean_tensor.size())

# dim=-1 => mean along the last axis.
# Whatever the shape of the incoming tensor, the reduction always runs over
# the last axis, so the same code works for inputs of any rank.
sample_tensor2 = torch.rand((2, 3, 4))
sample_tensor3 = torch.rand((2, 3, 4, 5))

print("\n ----------<dim=-1, keepdim=False>----------")
print("mean of shape(2,3): ", sample_tensor.mean(dim=-1).size())
print("mean of shape(2,3,4): ", sample_tensor2.mean(dim=-1).size())
print("mean of shape(2,3,4,5): ", sample_tensor3.mean(dim=-1).size())

tensor([[0.8432, 0.4588, 0.2093],
        [0.2540, 0.1243, 0.1277]])
shape of sample_tensor: torch.Size([2, 3])

 ----------<dim=None, keepdim=False>----------
tensor(0.3362)
shape of mean_tensor: torch.Size([])

 ----------<dim=0, keepdim=False>----------
tensor([0.5486, 0.2915, 0.1685])
shape of mean_tensor: torch.Size([3])

 ----------<dim=1, keepdim=False>----------
tensor([0.5038, 0.1687])
shape of mean_tensor: torch.Size([2])

 ----------<dim=-1, keepdim=False>----------
mean of shape(2,3):  torch.Size([2])
mean of shape(2,3,4):  torch.Size([2, 3])
mean of shape(2,3,4,5):  torch.Size([2, 3, 4])


In [33]:
# Comparing keepdim

sample_tensor = torch.rand((2, 3))
print(sample_tensor)
print("shape of sample_tensor:", sample_tensor.size())

# dim=0, keepdim=True => mean along axis 0, keeping the reduced axis with size 1
mean_tensor = sample_tensor.mean(dim=0, keepdim=True)
print("\n ----------<dim=0, keepdim=True>----------")
print(mean_tensor)
print("shape of mean_tensor:", mean_tensor.size())

# dim=1, keepdim=True => mean along axis 1, keeping the reduced axis with size 1
mean_tensor = sample_tensor.mean(dim=1, keepdim=True)
print("\n ----------<dim=1, keepdim=True>----------")
print(mean_tensor)
print("shape of mean_tensor:", mean_tensor.size())

tensor([[0.4798, 0.9732, 0.1887],
        [0.4485, 0.2879, 0.5027]])
shape of sample_tensor: torch.Size([2, 3])

 ----------<dim=0, keepdim=True>----------
tensor([[0.4642, 0.6305, 0.3457]])
shape of mean_tensor: torch.Size([1, 3])

 ----------<dim=1, keepdim=True>----------
tensor([[0.5472],
        [0.4130]])
shape of mean_tensor: torch.Size([2, 1])
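
One way to see why `_norm` passes keepdim=True: the reduced tensor is multiplied back onto `x`, so its shape has to broadcast against the input. The sketch below assumes a `(2, 3, 4)` input, read as (batch, seq_len, dim), and compares both settings. The variable names are illustrative only.

```python
import torch

x = torch.rand(2, 3, 4)  # assumed layout (batch, seq_len, dim)

kept = x.pow(2).mean(-1, keepdim=True)  # shape (2, 3, 1)
dropped = x.pow(2).mean(-1)             # shape (2, 3)

# (2, 3, 4) * (2, 3, 1): the size-1 axis broadcasts across the last dimension.
normalized = x * torch.rsqrt(kept)
print(normalized.shape)

# (2, 3, 4) * (2, 3): trailing sizes 4 and 3 do not match, so broadcasting fails.
try:
    x * torch.rsqrt(dropped)
    broadcast_failed = False
except RuntimeError as err:
    broadcast_failed = True
    print("broadcast error:", err)
```

With keepdim=False the mismatch only surfaces as an error when the trailing sizes differ. If they happen to match, the multiplication would broadcast along the wrong axis without raising anything, which is another reason to keep the reduced dimension.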
