In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class InputEmbeddings(nn.Module):
  """Look up token embeddings and scale them by sqrt(d_model).

  Args:
    d_model: Width of each embedding vector.
    vocab_size: Number of rows in the embedding table.
  """

  def __init__(self, d_model:int, vocab_size:int) -> None:
    super().__init__()
    self.d_model, self.vocab_size = d_model, vocab_size
    self.embedding = nn.Embedding(vocab_size, d_model)

  def forward(self, x):
    """Return embedding(x) * sqrt(d_model).

    Args:
      x: Index tensor of token ids (nn.Embedding requires integer indices).
    """
    # The Transformer's embedding layer multiplies its weights by sqrt(d_model).
    scale = math.sqrt(self.d_model)
    vectors = self.embedding(x)
    return vectors * scale

class PositionalEmbedding(nn.Module):
  """Add fixed sinusoidal positional encodings to a batch of embeddings, then apply dropout.

  Encoding table:
    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

  Args:
    maxlen: Largest sequence length the precomputed table supports.
    d_model: Embedding width (odd values are supported).
    dropout: Dropout probability applied after adding the encodings.
  """

  def __init__(self, maxlen:int, d_model:int, dropout:float) -> None:
    super().__init__()
    self.maxlen = maxlen
    self.dropout = nn.Dropout(p=dropout)
    self.d_model = d_model

    # Build the table in a local variable. register_buffer raises KeyError
    # if `pos_emb` already exists as a plain attribute on the module.
    pos_emb = torch.zeros(maxlen, d_model)
    # Token positions as a column vector (maxlen, 1) so it broadcasts against div.
    position = torch.arange(0, maxlen, dtype=torch.float32).unsqueeze(1)
    # 1 / 10000^(2i / d_model) for each even column 2i, computed in log space.
    div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
    pos_emb[:, 0::2] = torch.sin(position * div)
    # When d_model is odd there is one fewer odd column than even column, so trim div.
    pos_emb[:, 1::2] = torch.cos(position * div[: d_model // 2])
    # Registered as a buffer: it is moved by .to()/.cuda() and stored in
    # state_dict, but it is not a trainable parameter.
    self.register_buffer('pos_emb', pos_emb.unsqueeze(0))  # (1, maxlen, d_model)

  def forward(self, x):
    """Return dropout(x + PE[:seq_len]).

    Args:
      x: Tensor of shape (batch, seq_len, d_model) with seq_len <= maxlen.

    Raises:
      ValueError: If seq_len exceeds maxlen.
    """
    seq_len = x.shape[1]
    if seq_len > self.maxlen:
      raise ValueError(f"sequence length {seq_len} exceeds maxlen {self.maxlen}")
    # Buffers never require grad, so no gradient bookkeeping is needed here.
    x = x + self.pos_emb[:, :seq_len, :]  # (batch, seq_len, d_model)
    return self.dropout(x)