In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [158]:
class InputEmbeddings(nn.Module):
    """Learned lookup table mapping integer token ids to dense vectors.

    Args:
        vocab_size: number of rows in the embedding table (distinct token ids).
        embdim: length of the vector stored for each token id.
    """

    def __init__(self, vocab_size, embdim):
        super().__init__()
        # Kept under the name `embeddings`; code outside the class reads
        # `.embeddings.weight` directly.
        self.embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embdim,
        )

    def forward(self, x):
        """Return the embedding vector for every token id in ``x``.

        The result has the shape of ``x`` with a trailing ``embdim`` axis
        appended.
        """
        token_vectors = self.embeddings(x)
        return token_vectors

In [165]:
class PositionEmbeddings(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch of embeddings.

    For position ``p`` and even channel index ``2i`` the table holds
    ``sin(p / 10000 ** (2i / embdim))``; the following odd channel ``2i + 1``
    holds the cosine of the same angle.

    The table is stored as a non-persistent buffer named ``pe`` of shape
    ``(1, seq_len, embdim)``: it follows the module through ``.to()`` /
    ``.cuda()`` / ``.double()`` but is not written to ``state_dict``.

    Args:
        seq_len: maximum sequence length the table supports.
        embdim: number of embedding channels (odd values are supported).
    """

    def __init__(self, seq_len, embdim):
        super().__init__()
        # positions is the column vector [0, 1, ..., seq_len - 1]
        positions = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
        # even channel indices 0, 2, 4, ...
        emb_skip_dim = torch.arange(0, embdim, step=2, dtype=torch.float32)

        # z[p, i] is the angle shared by channels 2i (sin) and 2i + 1 (cos)
        z = positions / (10000 ** (emb_skip_dim / embdim))

        # pe is the lookup matrix where row p encodes sequence position p
        pe = torch.zeros(seq_len, embdim, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(z)
        # With an odd embdim there is one fewer odd channel than even channel,
        # so drop the last angle column to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(z[:, : embdim // 2])

        # A buffer (rather than a plain attribute) moves with the module across
        # devices/dtypes; persistent=False keeps state_dict keys unchanged.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``T`` positions.

        Args:
            x: tensor of shape ``(B, T, C)`` with ``T <= seq_len`` and ``C``
                matching ``embdim``.

        Raises:
            RuntimeError: if ``T`` exceeds the ``seq_len`` the table was built
                for.
        """
        B, T, C = x.shape

        max_len = self.pe.shape[1]
        if T > max_len:
            # Slicing would silently truncate to max_len and then either fail
            # with an opaque broadcast error or (max_len == 1) broadcast one
            # position over the whole sequence.
            raise RuntimeError(
                f"sequence length {T} exceeds the maximum supported length {max_len}"
            )

        return x + self.pe[:, :T, :]

In [166]:
# Toy problem sizes used by the cells below.
B = 2  # batch size
T = 4  # tokens per sequence
C = 8  # embedding channels
vocab_size = 16
# Random integer token ids drawn from [0, vocab_size), one row per batch element.
x = torch.randint(low=0, high=vocab_size, size=(B, T))
x

tensor([[ 0,  0, 13,  2],
        [ 0,  2, 10,  5]])

In [172]:
# Token ids -> learned embeddings -> embeddings plus position encodings.
# The table is sized from `vocab_size` (the same bound used to sample `x`)
# instead of a repeated literal, so the two cannot drift apart.
model = InputEmbeddings(vocab_size, C)
x1 = model(x)
model2 = PositionEmbeddings(T, C)
out = model2(x1)

In [173]:
# Collapse the output to a single scalar so it can serve as a backward() root.
out = torch.sum(out)

In [174]:
# Backpropagate from `out`; assumes `out` has already been reduced to a scalar
# (backward() without an explicit gradient argument requires one).
out.backward()

In [182]:
# Gradient of the summed output with respect to the embedding table.
# Each looked-up row contributes 1 per channel to the sum, so row i equals the
# number of times token id i occurs in `x`; ids that never occur stay at zero.
model.embeddings.weight.grad

tensor([[3., 3., 3., 3., 3., 3., 3., 3.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])