In [None]:
import torch
import torch.nn as nn

In [None]:
class LearnablePositionalEmbedding(nn.Module):
    """Add a learned, per-position vector to a batch of token embeddings.

    The module owns a single trainable table of shape
    ``(seq_len, embedding_dim)``. On each forward pass the first ``L`` rows
    (``L`` being the input's sequence length) are added to every sequence in
    the batch.
    """

    def __init__(self, seq_len, embedding_dim):
        """
        seq_len: Maximum sequence length the model can handle (number of
            rows in the positional table).
        embedding_dim: Dimension of the token embeddings.
        """
        super().__init__()

        # Learnable positional table, initialised from a standard normal.
        # Shape: (seq_len, embedding_dim)
        self.positional_embeddings = nn.Parameter(torch.randn(seq_len, embedding_dim))

    def forward(self, x):
        """
        x: Tensor of token embeddings of shape (batch_size, seq_len, embedding_dim).

        Returns a tensor of the same shape as ``x`` with the positional
        embeddings added.

        Raises RuntimeError if the input sequence is longer than the table,
        or if its last dimension differs from the table's embedding dimension.
        """
        _, seq_len, embedding_dim = x.shape
        max_seq_len, expected_dim = self.positional_embeddings.shape

        # Validate explicitly: slicing past the end of the table silently
        # truncates, after which the addition either fails with an opaque
        # broadcasting error or (for size-1 dimensions) silently broadcasts a
        # single vector across all positions/features. RuntimeError is the
        # type torch itself raises for shape mismatches.
        if seq_len > max_seq_len:
            raise RuntimeError(
                f"Input sequence length {seq_len} exceeds the maximum "
                f"supported length {max_seq_len}."
            )
        if embedding_dim != expected_dim:
            raise RuntimeError(
                f"Input embedding dimension {embedding_dim} does not match "
                f"the positional embedding dimension {expected_dim}."
            )

        # Keep only the rows for the positions present in this batch and add
        # a leading axis so the table broadcasts over the batch dimension.
        positional_embedding = self.positional_embeddings[:seq_len, :].unsqueeze(0)  # Shape: (1, seq_len, embedding_dim)

        # Add the positional embeddings to the input token embeddings
        return x + positional_embedding  # Shape: (batch_size, seq_len, embedding_dim)
