## Positional Encoding

This notebook implements the sinusoidal positional encoding used in Transformer neural networks, step by step, with PyTorch.

In [17]:
import torch
import torch.nn as nn

# Toy dimensions, kept small so every tensor below can be printed in full.
d_model = 8              # width of each positional-encoding vector
max_sequence_length = 3  # number of positions to encode

$$
PE(\text{position}, 2i) = \sin\bigg( \frac{ \text{position} }{10000^\frac{2i}{d_{model}}} \bigg)
$$

$$
PE(\text{position}, 2i+1) = \cos\bigg( \frac{ \text{position} }{10000^\frac{2i}{d_{model}}} \bigg)
$$

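In the formulas above, $2i$ and $2i+1$ are the even and odd embedding indices. If we instead let $i$ denote the embedding index itself, we can rewrite these as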

$$
PE(\text{position}, i) = \sin\bigg( \frac{ \text{position} }{10000^\frac{i}{d_{model}}} \bigg) \text{ when } i \text{ is even}
$$

$$
PE(\text{position}, i) = \cos\bigg( \frac{ \text{position} }{10000^\frac{i-1}{d_{model}}} \bigg) \text{ when } i \text{ is odd}
$$

In [18]:
# Even embedding indices 0, 2, 4, ... created directly as floats.
even_i = torch.arange(0, d_model, 2, dtype=torch.float)
even_i

tensor([0., 2., 4., 6.])

In [19]:
# Denominator 10000^(i / d_model) for every even index i.
even_exponent = even_i / d_model
even_denominator = torch.pow(10000, even_exponent)
even_denominator

tensor([   1.,   10.,  100., 1000.])

In [20]:
# Odd embedding indices 1, 3, 5, ... created directly as floats.
odd_i = torch.arange(1, d_model, 2, dtype=torch.float)
odd_i

tensor([1., 3., 5., 7.])

In [21]:
# Denominator 10000^((i - 1) / d_model) for every odd index i.
# Stored under its own name so it does not overwrite `even_denominator`
# and the two can actually be compared.
odd_denominator = torch.pow(10000, (odd_i - 1) / d_model)
odd_denominator

tensor([   1.,   10.,  100., 1000.])

`even_denominator` and `odd_denominator` are the same! So we only need to compute one of them, and we call the resulting variable `denominator`.

In [22]:
# The even- and odd-index denominators are identical, so one copy serves both.
denominator = even_denominator

In [23]:
# Position indices 0 .. max_sequence_length - 1 as a flat float vector.
position = torch.arange(max_sequence_length).float()
position.shape

torch.Size([3])

In [24]:
# Add a trailing axis so `position` is a (max_sequence_length, 1) column that
# broadcasts against the 1-D denominator vector.
position = torch.arange(max_sequence_length, dtype=torch.float).unsqueeze(1)

In [25]:
# Show the shape and contents of `position` (one row per position).
print(position.shape)
position

torch.Size([3, 1])


tensor([[0.],
        [1.],
        [2.]])

In [26]:
# Broadcasting (positions, 1) / (columns,) yields one angle per
# (position, denominator) pair; sin fills the even slots, cos the odd ones.
angles = position / denominator
even_PE = torch.sin(angles)
odd_PE = torch.cos(angles)

In [27]:
# The denominators that each position was divided by, shown again for reference.
denominator

tensor([   1.,   10.,  100., 1000.])

In [28]:
# Sine terms: one row per position, one column per denominator.
even_PE

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [0.8415, 0.0998, 0.0100, 0.0010],
        [0.9093, 0.1987, 0.0200, 0.0020]])

In [29]:
# Half of the d_model columns come from the sine terms.
even_PE.shape

torch.Size([3, 4])

In [30]:
# Cosine terms: one row per position, one column per denominator.
odd_PE

tensor([[ 1.0000,  1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.9950,  0.9999,  1.0000],
        [-0.4161,  0.9801,  0.9998,  1.0000]])

In [31]:
# The other half of the d_model columns come from the cosine terms.
odd_PE.shape

torch.Size([3, 4])

In [32]:
# Pair each sine value with its cosine partner along a new last axis.
stacked = torch.stack((even_PE, odd_PE), dim=-1)
stacked.shape

torch.Size([3, 4, 2])

In [33]:
# Merge the last two axes so each row interleaves sin, cos, sin, cos, ...
PE = stacked.flatten(start_dim=1, end_dim=2)
PE

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
          9.9995e-01,  1.0000e-03,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
          9.9980e-01,  2.0000e-03,  1.0000e+00]])

In [34]:
# Final encoding shape: (max_sequence_length, d_model).
PE.shape

torch.Size([3, 8])

## Class

Let's combine all the code above into a cute class

In [31]:
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding table.

    Even embedding indices hold ``sin(position / 10000^(i / d_model))`` and
    odd indices hold ``cos(position / 10000^((i - 1) / d_model))``.

    Args:
        d_model: Number of columns (embedding width) of the encoding.
        max_sequence_length: Number of rows (positions) of the encoding.
    """

    def __init__(self, d_model, max_sequence_length):
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.d_model = d_model

    def forward(self):
        """Build the encoding.

        Returns:
            A float tensor of shape ``(max_sequence_length, d_model)``.
        """
        even_i = torch.arange(0, self.d_model, 2).float()
        denominator = torch.pow(10000, even_i / self.d_model)
        # Column vector of positions so it broadcasts against `denominator`.
        position = torch.arange(
            self.max_sequence_length, dtype=torch.float
        ).reshape(self.max_sequence_length, 1)
        even_PE = torch.sin(position / denominator)
        odd_PE = torch.cos(position / denominator)
        # Interleave sin/cos: (positions, pairs, 2) -> (positions, 2 * pairs).
        stacked = torch.stack([even_PE, odd_PE], dim=2)
        PE = torch.flatten(stacked, start_dim=1, end_dim=2)
        # For odd d_model the interleave yields d_model + 1 columns; drop the
        # surplus cosine column so the width always equals d_model.
        return PE[:, :self.d_model]

In [32]:
pe = PositionalEncoding(d_model=6, max_sequence_length=10)
# Call the module instance rather than .forward() directly, so that
# nn.Module's call machinery (e.g. registered hooks) is not bypassed.
pe()

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0464,  0.9989,  0.0022,  1.0000],
        [ 0.9093, -0.4161,  0.0927,  0.9957,  0.0043,  1.0000],
        [ 0.1411, -0.9900,  0.1388,  0.9903,  0.0065,  1.0000],
        [-0.7568, -0.6536,  0.1846,  0.9828,  0.0086,  1.0000],
        [-0.9589,  0.2837,  0.2300,  0.9732,  0.0108,  0.9999],
        [-0.2794,  0.9602,  0.2749,  0.9615,  0.0129,  0.9999],
        [ 0.6570,  0.7539,  0.3192,  0.9477,  0.0151,  0.9999],
        [ 0.9894, -0.1455,  0.3629,  0.9318,  0.0172,  0.9999],
        [ 0.4121, -0.9111,  0.4057,  0.9140,  0.0194,  0.9998]])

Happy Coding!