## This note focuses on embeddings

In a Transformer, the final input embedding is the sum of the token embedding and the position embedding.

The Transformer computes the position encoding with the following formula:

<img src="./assert/position-embedding.png" width="50%" alt="Position Embedding">



In [3]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch of embeddings.

    For a position ``pos`` and an even channel index ``2i`` the table holds::

        PE[pos, 2i]     = sin(pos / 10000 ** (2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000 ** (2i / d_model))

    The table is computed once in ``__init__`` and stored as a non-persistent
    buffer, so it follows the module across ``.to(device)`` calls but is not
    written to ``state_dict``.

    Args:
        d_model: size of the embedding (last) dimension.
        max_len: number of positions to precompute, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # position: (max_len, 1), the index of each token in the sequence
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # even_index already holds the "2i" values of the formula (0, 2, 4, ...),
        # so it must NOT be multiplied by 2 again in the exponent.
        even_index = torch.arange(0, d_model, step=2, dtype=torch.float)

        # Broadcasting (max_len, 1) / (ceil(d_model / 2),) gives a matrix of
        # angles with shape (max_len, ceil(d_model / 2)).
        #
        # Example with max_len=3, d_model=4:
        #   position   = [[0], [1], [2]]
        #   even_index = [0, 2]
        #   divisor    = 10000 ** ([0, 2] / 4) = [1, 100]
        #   div_term   = [[0, 0   ],
        #                 [1, 0.01],
        #                 [2, 0.02]]
        div_term = position / torch.pow(10000.0, even_index / d_model)

        # even columns get sin, odd columns get cos
        position_encoding = torch.zeros(max_len, d_model)
        position_encoding[:, 0::2] = torch.sin(div_term)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the surplus angle column before filling the cos slots.
        position_encoding[:, 1::2] = torch.cos(div_term[:, : d_model // 2])

        # Buffers (rather than plain attributes) move with the module on
        # .to()/.cuda(); persistent=False keeps state_dict unchanged.
        self.register_buffer("position", position, persistent=False)
        self.register_buffer("div_term", div_term, persistent=False)
        self.register_buffer("position_encoding", position_encoding, persistent=False)

    def forward(self, x):
        """Add the positional encoding to ``x``.

        Args:
            x: tensor of shape ``(..., seq_len, d_model)``, typically
                ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of the same shape as ``x`` with the first ``seq_len`` rows
            of the encoding table added to every sequence.

        Raises:
            ValueError: if ``x`` has fewer than two dimensions, its last
                dimension is not ``d_model``, or ``seq_len`` exceeds ``max_len``.
        """
        if x.dim() < 2:
            raise ValueError(
                f"expected input of shape (..., seq_len, d_model), got {tuple(x.shape)}"
            )
        seq_len, width = x.size(-2), x.size(-1)
        if width != self.d_model:
            raise ValueError(
                f"last dimension of input ({width}) must equal d_model ({self.d_model})"
            )
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )
        # (seq_len, d_model) broadcasts over any leading batch dimensions
        return x + self.position_encoding[:seq_len, :]
        