In [3]:
import math
import torch

class PositionalEncoding(torch.nn.Module):
  """Fixed sinusoidal positional encoding for sequence-first inputs.

  Precomputes a table of shape (max_len, 1, d_model) in which even feature
  columns hold sin(position * freq_i) and odd feature columns hold
  cos(position * freq_i), with freq_i = exp(-2i * ln(10000) / d_model).
  The table is registered as the non-trainable buffer ``pe``.

  Args:
    d_model: Size of the last (feature) dimension of the inputs.
    max_len: Number of positions to precompute.
  """

  def __init__(self, d_model, max_len = 5000):
    super(PositionalEncoding, self).__init__()

    # Table of shape (max_len, d_model), filled in below
    pe = torch.zeros(max_len, d_model)

    # Column vector of position indices, shape (max_len, 1)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

    # Division term for the sine and cosine functions, one per even column
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000)/d_model))

    # Shape (max_len, ceil(d_model / 2)); shared by the sine and cosine halves
    angles = position * div_term

    # Apply sine to even feature indices
    pe[:, 0::2] = torch.sin(angles)

    # Apply cosine to odd feature indices. When d_model is odd there is one
    # fewer odd column than even column, so drop the surplus frequency
    # instead of failing with a shape mismatch.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    # Insert a singleton axis at dim 1 so the table broadcasts over it
    pe = pe.unsqueeze(1)

    self.register_buffer('pe', pe)

  def forward(self, x):
    """Return ``x`` with the encoding for its first ``x.size(0)`` positions added.

    Args:
      x: Tensor whose dim 0 indexes positions and whose last dim is d_model;
        the buffer is (max_len, 1, d_model), so a 3-D input broadcasts over dim 1.

    Raises:
      RuntimeError: If ``x.size(0)`` exceeds the number of precomputed positions.
    """
    seq_len = x.size(0)
    # Fail loudly; with max_len == 1 plain broadcasting would otherwise reuse
    # position 0 for every step without any error.
    if seq_len > self.pe.size(0):
      raise RuntimeError(
        "sequence length %d exceeds max_len %d of the positional encoding"
        % (seq_len, self.pe.size(0))
      )
    return x + self.pe[:seq_len, :] # positional encoding to input embedding


In [12]:
def test_positional_encoding():
  """Smoke-test PositionalEncoding on an all-zero, sequence-first input.

  Because the input is all zeros, the module's output is exactly the
  encoding table broadcast over the batch axis, which makes the expected
  values easy to check. The output and its shape are printed for inspection.
  """
  d_model = 512
  max_len = 60
  batch_size = 2

  pos_enc = PositionalEncoding(d_model, max_len)

  # Shape (max_len, batch_size, d_model): positions on dim 0, batch on dim 1
  dummy_input = torch.zeros(max_len, batch_size, d_model)

  output = pos_enc(dummy_input)

  # Adding the encoding must not change the tensor shape
  assert output.size() == dummy_input.size(), "output shape differs from input shape"
  # The same encoding is broadcast to every batch entry
  assert torch.equal(output[:, 0], output[:, 1]), "batch entries received different encodings"
  # Position 0 is sin(0) = 0 on even columns and cos(0) = 1 on odd columns
  assert torch.all(output[0, :, 0::2] == 0), "even columns at position 0 should be 0"
  assert torch.all(output[0, :, 1::2] == 1), "odd columns at position 0 should be 1"

  print("Positional Encodings: ")
  print(output)
  print("shape: ",output.size())
test_positional_encoding()


Positional Encodings: 
tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00]],

        [[ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,
           1.0366e-04,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  8.2186e-01,  ...,  1.0000e+00,
           1.0366e-04,  1.0000e+00]],

        [[ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,
           2.0733e-04,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.3641e-01,  ...,  1.0000e+00,
           2.0733e-04,  1.0000e+00]],

        ...,

        [[ 4.3616e-01,  8.9987e-01, -9.9997e-01,  ...,  9.9998e-01,
           5.9088e-03,  9.9998e-01],
         [ 4.3616e-01,  8.9987e-01, -9.9997e-01,  ...,  9.9998e-01,
           5.9088e-03,  9.9998e-01]],

        [[ 9.9287e-01,  1.1918e-01, -5.6324e-01,  ...,  9.9998e-01,
           6.0124e-03,  9.9998e-01],
         [