In [7]:
import math
import torch
import torch.nn as nn

In [14]:
class PositionalEncoding(nn.Module):
    """
    Implements the sinusoidal positional encoding for transformer models.

    For position ``pos`` and column index ``j`` the table holds::

        PE[pos, 2k]     = sin(pos * 10000^(-2k / d_model))
        PE[pos, 2k + 1] = cos(pos * 10000^(-2k / d_model))

    The table is precomputed once for ``max_len`` positions and registered as a
    buffer, so it is not a trainable parameter but does move with the module
    on ``.to(...)`` and is included in ``state_dict``.
    """

    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
        """
        Args:
            d_model: dimension of the embeddings. Odd values are supported; the
                last (even-indexed) column then has no paired cosine column.
            max_len: maximum length of input sequences
            dropout: dropout probability to apply after adding positional encodings
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Frequencies 10000^(-2k / d_model), one per even column, computed via exp/log.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        # Apply sin to even indices; cos to odd indices. When d_model is odd there
        # are ceil(d_model/2) even columns but only floor(d_model/2) odd ones, so
        # the cosine angles must be truncated or the assignment fails on shape.
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Tensor of same shape with positional encodings added, then dropout applied.
        Raises:
            ValueError: if seq_len exceeds the max_len the encoding table was built for.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of PositionalEncoding"
            )
        # Add positional encoding up to the input length. Casting the float32 table
        # to x's dtype keeps half/bfloat16 inputs from being promoted to float32.
        x = x + self.pe[:, :seq_len, :].to(x.dtype)
        return self.dropout(x)

In [15]:
if __name__ == "__main__":
    # Smoke test of PositionalEncoding. Note: inside a Jupyter kernel __name__ is
    # "__main__", so this cell always runs; the guard only matters if this code is
    # exported to an importable module.
    batch_size, seq_len, d_model = 2, 10, 512
    pe = PositionalEncoding(d_model=d_model, max_len=100)
    sample_input = torch.zeros(batch_size, seq_len, d_model)
    output = pe(sample_input)
    # The encoding must preserve the (batch_size, seq_len, d_model) shape.
    assert output.shape == (batch_size, seq_len, d_model), output.shape
    print(output.shape)  # torch.Size([2, 10, 512])

torch.Size([2, 10, 512])
