# Positional Encoding

## 요약

### 구현 코드

In [1]:
import math
import torch
import torch.nn as nn
from torch import Tensor, LongTensor

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout=0.1, max_len = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)    # 0부터 2칸씩
        pe[:, 1::2] = torch.cos(position * div_term)    # 1부터 2칸씩
        self.pe = pe.unsqueeze(0).transpose(0, 1)
    
    def forward(self, x: Tensor):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

### 실행 결과

In [2]:
input_embedding = nn.Embedding(4376, 512)
pos_encoding = PositionalEncoding(512)

# 임베딩 입력은 무조건 정수로 받는다
input_tensor = LongTensor([[2, 1819, 1547, 1698, 230, 3869, 2661, 3596, 3744, 1341, 3155, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
embedded_tensor = input_embedding(input_tensor)

pos_encoding(embedded_tensor)

tensor([[[ 0.2425,  2.3717, -0.7872,  ...,  1.5883, -1.1851,  3.2972],
         [ 0.7976,  0.9716, -1.0545,  ...,  2.5640, -1.1092,  1.8013],
         [ 1.3725,  0.0000,  0.8368,  ...,  0.0000, -0.6474,  0.8655],
         ...,
         [-0.6386,  0.2735, -2.2655,  ...,  2.1753, -1.2832,  0.0000],
         [-0.6386,  0.2735, -2.2655,  ...,  2.1753, -1.2832,  0.0000],
         [-0.6386,  0.2735, -2.2655,  ...,  0.0000, -1.2832,  0.0000]]],
       grad_fn=<MulBackward0>)

# 수행 과정

In [6]:
max_len = 5
d_model = 8


In [7]:
pe = torch.zeros(max_len, d_model)
print(pe)

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])


In [8]:
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
print(position)

tensor([[0.],
        [1.],
        [2.],
        [3.],
        [4.]])


In [9]:
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
print(div_term)

tensor([1.0000e+00, 1.0000e-01, 1.0000e-02, 1.0000e-03])


In [10]:
pe[:, 0::2] = torch.sin(position * div_term)    # 0부터 2칸씩
pe[:, 1::2] = torch.cos(position * div_term)    # 1부터 2칸씩

print(pe)

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
          9.9995e-01,  1.0000e-03,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
          9.9980e-01,  2.0000e-03,  1.0000e+00],
        [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
          9.9955e-01,  3.0000e-03,  1.0000e+00],
        [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
          9.9920e-01,  4.0000e-03,  9.9999e-01]])


In [11]:
pe = pe.unsqueeze(0)
print(pe)
print()
pe = pe.transpose(0, 1)

print(pe)

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000e+00],
         [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
           9.9955e-01,  3.0000e-03,  1.0000e+00],
         [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
           9.9920e-01,  4.0000e-03,  9.9999e-01]]])

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
           1.0000e+00,  0.0000e+00,  1.0000e+00]],

        [[ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
           9.9995e-01,  1.0000e-03,  1.0000e+00]],

        [[ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
           9.9980e-01,  2.0000e-03,  1.0000