In [10]:
import torch
import torch.nn as nn
import math

In [18]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to an embedded sequence, then apply dropout.

    The table follows "Attention Is All You Need":
        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: Embedding size (last dimension of the input). May be odd.
        seq_len: Maximum sequence length the table is precomputed for; inputs
            may be shorter than this.
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Column vector of positions, shape [seq_len, 1], so it broadcasts against div_term.
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # div_term[k] = 10000^(-2k / d_model): one frequency per (sin, cos) column pair.
        base = 10000.0 ** (-1.0 / d_model)
        div_term = torch.pow(base, torch.arange(0, d_model, 2).float())

        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cos column than sin column, so trim div_term
        # (for even d_model this slice is the whole tensor and changes nothing).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Leading batch axis of size 1 lets the table broadcast over any batch size.
        pe = pe.unsqueeze(0)  # [1, seq_len, d_model]
        # A buffer rather than a Parameter: it follows .to()/.cuda() and is saved in
        # state_dict, but is never updated by the optimizer.
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional table to ``x`` and apply dropout.

        Args:
            x: Embedded sequence, [batch_size, length, d_model], with
               ``length <= seq_len``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is longer than the precomputed table.
        """
        length = x.size(1)
        max_len = self.pe.size(1)
        if length > max_len:
            raise ValueError(
                f"input sequence length {length} exceeds the maximum seq_len {max_len} "
                "this PositionalEncoding was built for"
            )
        # Only the first `length` positions are needed; the buffer itself is left untouched.
        x = x + self.pe[:, :length]
        return self.dropout(x)  # [batch_size, length, d_model]

In [19]:
# Suppose we have a mini-batch of 2 sequences, each 50 tokens long,
# with every token embedded as a 300-dimensional vector.
torch.manual_seed(0)  # make the random input and the dropout mask reproducible

batch_size = 2
seq_len = 50
d_model = 300
dropout = 0.1

# Create an instance of our PositionalEncoding class
pos_encoding = PositionalEncoding(d_model, seq_len, dropout)

# Random stand-in for an embedded batch: [batch_size, seq_len, d_model].
# The shape is derived from the constants above so they cannot drift apart.
x = torch.rand(batch_size, seq_len, d_model)

# Pass our sequences through the positional encoding
encoder_input = pos_encoding(x)

# The encoding must preserve the input shape.
assert encoder_input.shape == (batch_size, seq_len, d_model)
print(encoder_input.shape)  # torch.Size([2, 50, 300])

torch.Size([2, 50, 300])


In [20]:
# Display the encoded batch. The module was not switched to eval mode above, so dropout
# is presumably active; that would explain the exact 0.0000e+00 entries in the output.
encoder_input

tensor([[[ 7.0031e-01,  1.3413e+00,  9.5727e-01,  ...,  1.8673e+00,
           1.3082e-01,  1.1909e+00],
         [ 1.3241e+00,  1.4957e+00,  1.7736e+00,  ...,  1.8759e+00,
           2.5413e-01,  1.9031e+00],
         [ 1.8533e+00,  8.7519e-02,  1.4987e+00,  ...,  1.9938e+00,
           1.9857e-01,  1.6533e+00],
         ...,
         [ 0.0000e+00, -9.0130e-01,  8.5059e-01,  ...,  1.5707e+00,
           3.4051e-01,  2.1623e+00],
         [ 2.4486e-01,  1.1713e-01,  1.4777e+00,  ...,  1.6192e+00,
           3.8397e-01,  2.1147e+00],
         [-0.0000e+00,  0.0000e+00,  1.7454e+00,  ...,  2.0527e+00,
           2.8044e-01,  2.2023e+00]],

        [[ 1.7076e-01,  1.3026e+00,  1.9648e-01,  ...,  1.8652e+00,
           2.0502e-03,  0.0000e+00],
         [ 1.6273e+00,  1.5889e+00,  1.8088e+00,  ...,  1.7651e+00,
           8.3192e-01,  1.4153e+00],
         [ 1.6959e+00, -1.0794e-01,  1.8971e+00,  ...,  1.2293e+00,
           2.2635e-01,  1.9268e+00],
         ...,
         [ 9.9376e-01, -0