## Environment Imports

In [48]:
import torch
import math
from torch import nn

## 전체적인 구조
- 트랜스포머
    - 인코더
        - 멀티헤드 어텐션
        - Add & Norm
        - FFN (position-wise 피드포워드 네트워크)
        - Add & Norm
    - 디코더
        - 마스크 멀티헤드 어텐션
        - Add & Norm
        - 멀티헤드 어텐션 (인코더-디코더 어텐션)
        - Add & Norm
        - FFN (position-wise 피드포워드 네트워크)
        - Add & Norm
    - 리니어 레이어


### Multi-head attention 

In [57]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected by their own
    ``dim_num -> dim_num`` linear layer, split into ``head_num`` heads of size
    ``dim_num // head_num``, attended independently, then concatenated back to
    ``dim_num`` features and passed through an output projection.

    Args:
        dim_num: Feature dimension of the inputs and of the output.
        head_num: Number of attention heads; must be positive and divide
            ``dim_num``.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide
            ``dim_num``.
    """

    def __init__(self, dim_num=512, head_num=8) -> None:
        super().__init__()
        # Without this check the per-head ``view`` in ``forward`` can fail, or
        # silently reshape into the wrong sequence length, far from the cause.
        if head_num <= 0 or dim_num % head_num != 0:
            raise ValueError(
                f"dim_num ({dim_num}) must be divisible by a positive head_num ({head_num})"
            )
        self.head_num = head_num
        self.dim_num = dim_num

        self.query_embed = nn.Linear(dim_num, dim_num)
        self.key_embed = nn.Linear(dim_num, dim_num)
        self.value_embed = nn.Linear(dim_num, dim_num)
        self.output_embed = nn.Linear(dim_num, dim_num)

    def scaled_dot_product_attention(
        self, query, key, value, mask: bool = False, scale: "float | None" = None
    ) -> torch.Tensor:
        """Compute ``softmax(query @ key^T * scale) @ value``.

        Args:
            query: Tensor of shape ``(..., L, d_k)``.
            key: Tensor of shape ``(..., S, d_k)``.
            value: Tensor of shape ``(..., S, d_v)``.
            mask: If True, apply a causal mask so that query position ``i``
                only attends to key positions ``j <= i``.
            scale: Multiplier applied to the attention logits. Defaults to
                ``1 / sqrt(d_k)``.

        Returns:
            Tensor of shape ``(..., L, d_v)``.
        """
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        seq_q, seq_k = query.size(-2), key.size(-2)

        # Multiply (not divide) by 1/sqrt(d_k): the point is to shrink the
        # logits so the softmax does not saturate for large d_k.
        scores = query @ key.transpose(-2, -1) * scale_factor
        if mask:
            # Create the mask on the inputs' device so CUDA tensors work too.
            causal = torch.ones(
                seq_q, seq_k, dtype=torch.bool, device=query.device
            ).tril(diagonal=0)
            scores = scores.masked_fill(causal.logical_not(), float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        return weights @ value

    def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape ``(batch, seq, dim_num)`` into ``(batch, head_num, seq, dim_num // head_num)``."""
        # View as (batch, seq, head, head_dim) first, then move the head axis
        # forward. Viewing directly as (batch, head, seq, head_dim) would mix
        # features from different sequence positions into one head.
        return x.view(
            batch_size, -1, self.head_num, self.dim_num // self.head_num
        ).transpose(1, 2)

    def forward(self, query, key, value, mask: bool = False) -> torch.Tensor:
        """Apply multi-head attention.

        Args:
            query: Tensor of shape ``(batch, L, dim_num)``.
            key: Tensor of shape ``(batch, S, dim_num)``.
            value: Tensor of shape ``(batch, S, dim_num)``.
            mask: If True, use a causal mask (see
                ``scaled_dot_product_attention``).

        Returns:
            Tensor of shape ``(batch, L, dim_num)``.
        """
        batch_size = query.size(0)

        q = self._split_heads(self.query_embed(query), batch_size)
        k = self._split_heads(self.key_embed(key), batch_size)
        v = self._split_heads(self.value_embed(value), batch_size)

        heads = self.scaled_dot_product_attention(q, k, v, mask)
        # Back to (batch, L, head, head_dim), then concatenate heads. The
        # transpose makes the tensor non-contiguous, which view() rejects.
        merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, self.dim_num)
        return self.output_embed(merged)

In [58]:
# Smoke test of the full MultiHeadAttention forward pass (projections,
# head split, attention, head merge) on random (batch=4, seq_len=10, dim=64)
# inputs with 8 heads; the output must keep the input shape.
q = torch.rand((4, 10, 64))
k = torch.rand((4, 10, 64))
v = torch.rand((4, 10, 64))

mha = MultiHeadAttention(64, 8)

out = mha(q, k, v)
assert out.shape == q.shape
assert mha(q, k, v, mask=True).shape == q.shape  # causal variant

out.shape

torch.Size([4, 10, 64])