## Position-wise Feed-Forward Networks
- 각 위치(Position)의 단어마다 개별적으로 적용되는 Fully Connected Network
- 모든 단어가 동일한 $W_1$, $W_2$를 공유
- 연산은 각 위치(단어)마다 독립적으로 적용 (위치 간 정보 교환 없음)

- $$FFN(x) = \max(0, xW_1 + b_1)W_2 + b_2$$


    - $x$: 입력 벡터 (Shape: [Batch, Seq_len, d_model])
    - $W_1$: 첫 번째 Linear 가중치 (Shape: [d_model, d_ff]) $\rightarrow$ 확장
    - $b_1$: 첫 번째 편향 (Bias) (Shape: [d_ff])
    - $W_2$: 두 번째 Linear 가중치 (Shape: [d_ff, d_model]) $\rightarrow$ 축소
    - $b_2$: 두 번째 편향 (Bias) (Shape: [d_model])

- 학습 파라미터
    - $W_1$, $W_2$, $b_1$, $b_2$

In [16]:
import math
from typing import List

import torch
import torch.nn as nn

In [7]:
B, L, d_model = 4, 10, 512
d_ff = 2048

# Expansion layer: W_1 maps d_model -> d_ff, with bias b_1 of size d_ff.
W_1 = torch.randn((d_model, d_ff), requires_grad=True)   # [d_model, d_ff]
b_1 = torch.zeros((d_ff,), requires_grad=True)            # [d_ff]

# Contraction layer: W_2 maps d_ff -> d_model, with bias b_2 of size d_model.
W_2 = torch.randn((d_ff, d_model), requires_grad=True)   # [d_ff, d_model]
b_2 = torch.zeros((d_model,), requires_grad=True)         # [d_model]

# One shape per line, in the order W_1, b_1, W_2, b_2.
print(W_1.shape, b_1.shape, W_2.shape, b_2.shape, sep="\n")


torch.Size([512, 2048])
torch.Size([2048])
torch.Size([2048, 512])
torch.Size([512])


### ReLU
- $$f(x) = \max(0, x)$$
    - $x > 0$ 이면: $x$ (그대로)
    - $x\le 0$ 이면: $0$ (차단)


In [10]:
# ReLU is just clamping from below at 0: negatives become 0, the rest pass through.
x = torch.tensor([0, 2, 5, 2, 3, -12])
torch.clamp(x, min=0)

tensor([0, 2, 5, 2, 3, 0])

In [11]:
class ReLU:
    """Element-wise rectified linear unit, f(x) = max(0, x).

    Args:
        inplace: if True, ``x`` is clamped in place and the very same tensor
            object is returned; otherwise a new tensor is returned and ``x``
            is left untouched.
    """

    def __init__(self, inplace: bool = False):
        self.inplace = inplace

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.inplace:
            return torch.clamp(x, min=0)
        # In-place path: mutate the caller's tensor and hand it back.
        x.clamp_(min=0)
        return x

    def parameters(self):
        # An activation has no learnable weights.
        return []

In [18]:
class Dropout:
    """Inverted dropout on an arbitrary tensor.

    In training mode each element is kept with probability ``1 - p`` and
    zeroed otherwise; kept elements are divided by ``1 - p`` so each output
    element has the same expected value as its input. In eval mode, or when
    ``p == 0``, the input tensor itself is returned unchanged.

    Args:
        p: drop probability; must satisfy ``0 <= p < 1``.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1)``.
    """

    def __init__(self, p: float):
        if not (0.0 <= p < 1.0):
            raise ValueError("p must be in [0,1)")
        self.p = p
        self.training = True

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply dropout to ``x`` (identity when not training or ``p == 0``)."""
        if not self.training or self.p == 0.0:
            return x

        keep_prob = 1.0 - self.p
        # Boolean mask: True where the element survives (probability keep_prob).
        mask = torch.rand_like(x) < keep_prob
        # Rescale survivors so the expectation matches the un-dropped input.
        return (x * mask) / keep_prob

    def train(self, mode: bool = True):
        """Set training mode (``True``) or eval mode (``False``); returns ``self``."""
        self.training = mode
        return self

    def eval(self):
        """Switch to eval mode; returns ``self``."""
        return self.train(False)

    def parameters(self):
        # Dropout has no learnable weights; mirrors the other layer classes.
        return []

In [19]:
class PositionWiseFeedForward:
    """Position-wise feed-forward network: FFN(x) = max(0, x W1 + b1) W2 + b2.

    The same two linear maps are applied independently at every position,
    i.e. along the last dimension of ``x``; any leading dimensions (such as
    ``[B, L]``) are carried through by ``torch.matmul`` broadcasting. Dropout
    is applied to the hidden activations between the two linear maps.

    Args:
        d_model: size of the input/output feature dimension.
        d_ff: size of the hidden (expanded) dimension.
        dropout_p: dropout probability on the hidden activations, in ``[0, 1)``.
        bias: whether to create and add the bias vectors ``b1`` and ``b2``.

    Raises:
        ValueError: if ``d_model`` or ``d_ff`` is not positive, or if
            ``dropout_p`` is outside ``[0, 1)``.
    """

    def __init__(
            self,
            d_model: int,
            d_ff: int,
            dropout_p: float = 0.1,
            bias: bool = True,
    ) -> None:
        if d_model <= 0 or d_ff <= 0:
            raise ValueError("d_model and d_ff must be positive.")
        if not (0.0 <= dropout_p < 1.0):
            raise ValueError("dropout_p must be in [0,1)")

        self.d_model = d_model
        self.d_ff = d_ff
        self.training = True
        self.bias = bias

        # Weights are drawn from N(0, 1/fan_in) so that, for roughly independent
        # inputs, each linear map approximately preserves activation variance.
        self.w1 = torch.randn(d_model, d_ff) / math.sqrt(d_model)  # [d_model, d_ff]
        self.w1.requires_grad_()
        self.b1 = torch.zeros(d_ff, requires_grad=True) if bias else None

        self.w2 = torch.randn(d_ff, d_model) / math.sqrt(d_ff)  # [d_ff, d_model]
        self.w2.requires_grad_()
        self.b2 = torch.zeros(d_model, requires_grad=True) if bias else None

        self.activation = ReLU()
        self.dropout = Dropout(p=dropout_p)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the FFN to ``x`` whose last dimension is ``d_model``.

        Raises:
            ValueError: if ``x`` is 0-dimensional or its last dimension is
                not ``d_model``. (An explicit raise rather than ``assert``, so
                the check survives ``python -O``.)
        """
        if x.dim() == 0 or x.shape[-1] != self.d_model:
            raise ValueError(
                f"Last dim must be d_model ({self.d_model}); "
                f"got shape {tuple(x.shape)}."
            )

        # [..., d_model] -> [..., d_ff]
        x = torch.matmul(x, self.w1)
        if self.b1 is not None:
            x = x + self.b1

        x = self.activation(x)
        x = self.dropout(x)

        # [..., d_ff] -> [..., d_model]
        x = torch.matmul(x, self.w2)
        if self.b2 is not None:
            x = x + self.b2

        return x

    def parameters(self) -> List[torch.Tensor]:
        """Return the learnable tensors: ``[w1, w2]`` followed by ``b1``, ``b2`` if present."""
        params = [self.w1, self.w2]
        if self.b1 is not None:
            params.append(self.b1)
        if self.b2 is not None:
            params.append(self.b2)
        return params

    def zero_grad(self) -> None:
        """Zero any accumulated gradients in place (parameters without a grad are skipped)."""
        for param in self.parameters():
            if param.grad is not None:
                param.grad.zero_()

    def train(self, mode: bool = True):
        """Set training mode on this module and its dropout layer; returns ``self``."""
        self.training = mode
        self.dropout.train(mode)
        return self

    def eval(self):
        """Switch to eval mode (dropout disabled); returns ``self``."""
        return self.train(False)
