## LayerNorm

1. **Compute the mean ($\mu$) and variance ($\sigma^2$)**
   - Compute the statistics across the $d$ features of the input vector $x$
   $$\mu = \frac{1}{d} \sum_{i=1}^{d} x_i$$
   $$\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$

2. **Normalization**
   - Adjust the values so that the mean is 0 and the variance is 1 ($\epsilon$ is a small constant that keeps the denominator from being 0)
   $$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

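3. **Scale & Shift**
   - Normalization alone can remove some of the representational power of the data, because every output is forced to zero mean and unit variance
   - Introduce learnable parameters gamma ($\gamma$) and beta ($\beta$) so the model can learn an appropriate output range by itself
   $$y_i = \gamma \cdot \hat{x}_i + \beta$$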
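
The three steps above can be sketched in a few lines. This is a minimal illustration rather than a reference implementation: the helper name `layer_norm_steps` and the constant values chosen for `gamma` (2.0) and `beta` (0.5) are assumptions made for the example.

```python
import torch

def layer_norm_steps(x, gamma, beta, eps=1e-5):
    # Step 1: statistics over the last (feature) dimension
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    # Step 2: normalize to zero mean and unit variance
    x_hat = (x - mu) / torch.sqrt(var + eps)
    # Step 3: scale and shift with the learnable parameters
    return gamma * x_hat + beta

torch.manual_seed(0)
d = 8
x = torch.randn(3, d) * 5 + 10      # inputs far from zero mean / unit variance
gamma = torch.full((d,), 2.0)       # assumed example value
beta = torch.full((d,), 0.5)        # assumed example value

y = layer_norm_steps(x, gamma, beta)
# With constant gamma and beta, each row ends up with mean ~ beta and std ~ gamma
print(y.mean(dim=-1))
print(y.std(dim=-1, unbiased=False))
```

Because `gamma` and `beta` are constant across features here, the per-row mean and standard deviation of `y` land close to 0.5 and 2.0, which shows how the affine step lets the model move away from the fixed (0, 1) statistics.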

In [12]:
import os
import torch
import torch.nn as nn

In [4]:
x = torch.tensor([1.0, 10.0, 100.0])
print(x)

mean = x.mean()
var = x.var(unbiased=False) # population variance (divide by d, not d - 1)
std = torch.sqrt(var + 1e-6)

print(mean)
print(var)
print(std)

x_norm = (x - mean) / std
print(x_norm)

tensor([  1.,  10., 100.])
tensor(37.)
tensor(1998.)
tensor(44.6990)
tensor([-0.8054, -0.6040,  1.4094])
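
As a sanity check, the manual result can be compared with PyTorch's functional API. The sketch below assumes `torch.nn.functional.layer_norm` is called without `weight` and `bias`, so it performs only the normalization step, and it reuses the same `eps` of `1e-6` as the cell above.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([1.0, 10.0, 100.0])
eps = 1e-6

# Manual normalization, same steps as the cell above
mean = x.mean()
var = x.var(unbiased=False)
x_norm = (x - mean) / torch.sqrt(var + eps)

# Built-in functional version; normalized_shape covers the trailing dimension(s)
x_ref = F.layer_norm(x, normalized_shape=(3,), eps=eps)

print(x_norm)
print(x_ref)
```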


In [5]:
# L, dim
L, dim = 10, 256
x = torch.randn(L, dim)
print(x.shape)

mean = torch.mean(x, dim=-1, keepdim=True)
var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
std = torch.sqrt(var + 1e-6)

x_norm = (x-mean) / std
print(x_norm.shape)

gamma = torch.ones(dim)
beta = torch.zeros(dim)

y = x_norm * gamma + beta

print(y.shape)

torch.Size([10, 256])
torch.Size([10, 256])
torch.Size([10, 256])


In [6]:
# Layer Normalization
# B, L, dim
B, L, dim = 4, 10, 512
x = torch.randn(B, L, dim)
print(x.shape)

mean = torch.mean(x, dim=-1, keepdim=True)
var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
std = torch.sqrt(var + 1e-6)

x_norm = (x-mean) / std
print(x_norm.shape)

gamma = torch.ones(dim)
beta = torch.zeros(dim)

y = x_norm * gamma + beta

print(y.shape)


torch.Size([4, 10, 512])
torch.Size([4, 10, 512])
torch.Size([4, 10, 512])
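
Since the statistics are taken over the last dimension only, each token is normalized independently of the rest of the batch. The following sketch illustrates that property; the helper name `layer_norm` is an assumption for the example, and the affine step is left out for brevity.

```python
import torch

def layer_norm(x, eps=1e-6):
    # Per-token statistics over the feature dimension
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

torch.manual_seed(0)
x = torch.randn(4, 10, 512)

y_batch = layer_norm(x)        # normalize the whole batch at once
y_single = layer_norm(x[:1])   # normalize only the first sample

# The first sample's output should not depend on the other samples
print((y_batch[:1] - y_single).abs().max().item())
```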


In [8]:
class LayerNorm:
    def __init__(
            self,
            normalized_shape: int,
            eps: float = 1e-5,
            elementwise_affine: bool = True,
            bias: bool = True,
    ):
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.training = True

        if elementwise_affine:
            self.gamma = torch.ones(self.normalized_shape, requires_grad=True)
            self.beta = None
            if bias:
                self.beta = torch.zeros(self.normalized_shape, requires_grad=True)
        else:
            self.gamma = None
            self.beta = None    

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape[-1] == self.normalized_shape
        # Actual input: [B, L, dim]
        # [B, L, 1]
        mean = torch.mean(x, dim=-1, keepdim=True) 
        # [B, L, 1]
        var = torch.var(x, dim=-1, unbiased=False, keepdim=True)
        #var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        # [B, L, 1]
        std = torch.sqrt(var + self.eps)

        # [B, L, dim]
        x_norm = (x - mean) / std

        if self.gamma is None:
            y = x_norm
        elif self.beta is None:
            y = self.gamma * x_norm
        else:
            y = self.gamma * x_norm + self.beta
        
        return y
    
    def parameters(self):
        if self.gamma is None:
            return []
        elif self.beta is None:
            return [self.gamma]
        return [self.gamma, self.beta]

    def zero_grad(self) -> None:
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()

    def train(self, mode: bool = True):
        self.training = mode
        return self
    
    def eval(self):
        return self.train(False)
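
For comparison, here is a sketch of the same layer written as an `nn.Module` subclass. The class name `LayerNormModule` is hypothetical and not part of the notebook; the point is only that wrapping `gamma` and `beta` in `nn.Parameter` lets `nn.Module` supply `parameters()`, `zero_grad()`, `train()`, and `eval()` instead of defining them by hand.

```python
import torch
import torch.nn as nn

class LayerNormModule(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # nn.Parameter registers the tensors with the module
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=-1, keepdim=True)                # [..., 1]
        var = x.var(dim=-1, unbiased=False, keepdim=True)  # [..., 1]
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

torch.manual_seed(0)
ln = LayerNormModule(16)
x = torch.randn(2, 5, 16)
y = ln(x)

print(y.shape)
print([name for name, _ in ln.named_parameters()])
```

The hand-written class in the cell above makes the bookkeeping explicit, which is useful for learning; the `nn.Module` version is closer to how such a layer would usually be plugged into a larger model or an optimizer.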

In [None]:
### Test
B, L, D = 4, 10, 512
x = torch.randn(B, L, D)

ln = LayerNorm(D, eps=1e-5, elementwise_affine=False, bias=False)
y = ln(x)

assert y.shape == x.shape, "Output shape must match input shape."

mean = y.mean(dim=-1)
var = y.var(dim=-1, unbiased=False)

assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-4, rtol=0), "Mean not ~0."
assert torch.allclose(var, torch.ones_like(var), atol=1e-3, rtol=0), "Var not ~1."

x = torch.randn(B, L, D)

ln = LayerNorm(D, eps=1e-5, elementwise_affine=True, bias=True)
test = nn.LayerNorm(D, eps=1e-5, elementwise_affine=True, bias=True)

with torch.no_grad():
    test.weight.copy_(ln.gamma)
    test.bias.copy_(ln.beta)

y = ln(x)
y_test = test(x)

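max_abs_err = (y - y_test).abs().max().item()

# Report the largest deviation from nn.LayerNorm if the comparison fails
assert torch.allclose(y, y_test, atol=1e-6, rtol=0), f"max abs err = {max_abs_err:.2e}"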


In [16]:
torch.manual_seed(0)

B, L, D = 4, 5, 16
x = torch.randn(B, L, D, requires_grad=True)

ln = LayerNorm(D, elementwise_affine=True, bias=True)

y = ln(x)
loss = (y ** 2).mean()
loss.backward()

grads = [p.grad for p in ln.parameters()]
assert all(g is not None for g in grads), "Some parameters have no grad."
assert all(torch.isfinite(g).all() for g in grads), "Some grads have NaN/Inf."

print("[OK] gradient flow test (gamma/beta grads exist & finite)")

[OK] gradient flow test (gamma/beta grads exist & finite)
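
The test above only confirms that gradients exist and are finite. A stricter check, sketched below, compares autograd's gradients against finite differences with `torch.autograd.gradcheck`. The functional helper `layer_norm` and the small tensor sizes are assumptions for the example; `gradcheck` is run on double-precision inputs because its default tolerances are tuned for that precision.

```python
import torch

def layer_norm(x, gamma, beta, eps=1e-5):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

torch.manual_seed(0)
D = 6
# Small double-precision inputs keep the finite-difference check fast and accurate
x = torch.randn(2, 3, D, dtype=torch.double, requires_grad=True)
gamma = torch.randn(D, dtype=torch.double, requires_grad=True)
beta = torch.randn(D, dtype=torch.double, requires_grad=True)

# Raises an error if analytical and numerical gradients disagree
ok = torch.autograd.gradcheck(layer_norm, (x, gamma, beta), eps=1e-6, atol=1e-4)
print(ok)
```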


## Batch Normalization
1. Compute the mean ($\mu_{\mathcal{B}}$) and variance ($\sigma_{\mathcal{B}}^2$)
    - Look across the whole batch ($m$ samples) and compute statistics over the values at the same position (feature)
    - $m$: batch size
    $$\mu_{\mathcal{B}} = \frac{1}{m} \sum_{i=1}^{m} x_i$$
    $$\sigma_{\mathcal{B}}^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\mathcal{B}})^2$$
2. Normalization
    - Adjust the values so that the mean is 0 and the variance is 1 within the batch ($\epsilon$ is a small constant for numerical stability)
    $$\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}$$
3. Scale & Shift
    - Adjust the normalized values so that the data does not lose its distinctive characteristics
    - Use learnable parameters gamma ($\gamma$) and beta ($\beta$), one per channel/feature
    $$y_i = \gamma \cdot \hat{x}_i + \beta$$

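Note: During training, the computation above runs on every batch, but at inference (eval) time BatchNorm uses the moving average of the statistics accumulated during training. This is the biggest difference from LayerNorm, which computes its statistics the same way in both modes.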
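
The train/eval difference can be sketched as follows. The notebook does not spell out the moving-average rule, so this example assumes the one used by `nn.BatchNorm1d`: an exponential moving average with `momentum=0.1`, running statistics initialized to mean 0 and variance 1, and the unbiased batch variance stored in the running estimate. The sizes `m` and `C` are arbitrary example values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
m, C = 32, 4                   # batch size and number of features (example values)
momentum, eps = 0.1, 1e-5      # assumed to match nn.BatchNorm1d defaults
x = torch.randn(m, C) * 3 + 2

# Running statistics start at mean 0 and variance 1
running_mean = torch.zeros(C)
running_var = torch.ones(C)

# Training step: normalize with batch statistics, then update the running ones
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
y_train = (x - mu) / torch.sqrt(var + eps)
running_mean = (1 - momentum) * running_mean + momentum * mu
# The running estimate uses the unbiased batch variance
running_var = (1 - momentum) * running_var + momentum * x.var(dim=0, unbiased=True)

# Eval step: normalize with the stored running statistics instead
y_eval = (x - running_mean) / torch.sqrt(running_var + eps)

# Compare against the built-in layer (gamma=1, beta=0 at initialization)
bn = nn.BatchNorm1d(C, eps=eps, momentum=momentum)
bn.train()
ref_train = bn(x)
bn.eval()
ref_eval = bn(x)

print((y_train - ref_train).abs().max().item())
print((y_eval - ref_eval).abs().max().item())
```

After a single training step the running statistics are still far from the true data statistics, so `y_eval` differs noticeably from `y_train`; the two converge only as more batches are seen.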

In [26]:
# Batch Normalization
# B, L, dim
B, L, dim = 4, 10, 512
x = torch.randn(B, L, dim)
print(x.shape)

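# Statistics over the batch axis only (dim=0): each (position, feature) pair
# gets its own mean and variance, with shape [1, L, dim].
# Note: nn.BatchNorm1d on [B, C, L] input instead reduces over both the batch
# and length axes, giving one mean/variance per feature.
mean_bn = torch.mean(x, dim=0, keepdim=True)
var_bn = torch.var(x, dim=0, unbiased=False, keepdim=True)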
std_bn = torch.sqrt(var_bn + 1e-6)

x_norm = (x-mean_bn) / std_bn
print(x_norm.shape)

gamma = torch.ones(dim)
beta = torch.zeros(dim)

y = x_norm * gamma + beta

print(y.shape)


torch.Size([4, 10, 512])
torch.Size([4, 10, 512])
torch.Size([4, 10, 512])
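
For sequence data, `nn.BatchNorm1d` keeps one mean and variance per feature and reduces over both the batch and length axes. The sketch below shows that variant and checks it against the built-in layer. It is an alternative to the cell above, not a correction of it, and the `eps` value of `1e-5` is assumed to match the `nn.BatchNorm1d` default.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, dim = 4, 10, 512
x = torch.randn(B, L, dim)
eps = 1e-5

# Reduce over batch and sequence axes: one mean/variance per feature
mean = x.mean(dim=(0, 1), keepdim=True)                 # [1, 1, dim]
var = x.var(dim=(0, 1), unbiased=False, keepdim=True)   # [1, 1, dim]
y = (x - mean) / torch.sqrt(var + eps)

# nn.BatchNorm1d expects [B, C, L], so move the feature axis to position 1
bn = nn.BatchNorm1d(dim, eps=eps)
bn.train()
y_ref = bn(x.transpose(1, 2)).transpose(1, 2)

print(mean.shape)
print((y - y_ref).abs().max().item())
```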
