In [28]:
import torch.nn as nn
import torch
from jaxtyping import Float, Int
from torch import Tensor
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device = None):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        pair_idx = torch.arange(d_k//2, device=device)
        freq = theta ** (2 * pair_idx / d_k)
        positions = torch.arange(max_seq_len, device=device)
        angles = positions[:, None] / freq[None, :]
        cos = torch.cos(angles)
        sin = torch.sin(angles)

        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        assert x.shape[-1] % 2 == 0
        cos = self.cos[token_positions]
        sin = self.sin[token_positions]
        *batch_dims, seq_len, d_k = x.shape
        x = x.reshape(*batch_dims, seq_len, d_k // 2, 2)
        x1 = x[..., 0]
        x2 = x[..., 1]
        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x1 * sin + x2 * cos
        rotated = torch.stack((rotated_x1, rotated_x2), dim=-1)
        rotated = rotated.reshape(*batch_dims, seq_len, d_k)
        return rotated



In [29]:


seq_len = 256
d_k = 32
input_token = torch.randn(8, seq_len, d_k)
token_positions = torch.randint(0,100, size=(8, seq_len))
ff = RotaryPositionalEmbedding(10000.00, d_k, seq_len*2)
ff(input_token, token_positions)

tensor([[[ 8.5809e-01,  6.4918e-01,  6.6892e-01,  ..., -1.6839e-01,
           3.9509e-01,  1.2025e+00],
         [-9.6578e-01, -9.2661e-01, -6.1186e-02,  ...,  8.5194e-01,
          -1.5859e+00,  2.6573e-01],
         [ 4.4932e-01, -2.8551e-01, -1.3897e-01,  ...,  3.5243e-02,
           5.2622e-01,  2.4745e-01],
         ...,
         [-7.9324e-01,  1.9731e+00, -7.0383e-01,  ...,  2.9834e+00,
           2.0680e-01,  4.3375e-02],
         [-1.8401e+00, -8.6092e-01,  2.9436e-02,  ..., -2.1575e-01,
           1.3786e+00,  1.9359e+00],
         [-1.5590e+00,  6.0942e-02,  1.8346e+00,  ..., -1.4881e+00,
          -7.7315e-01, -5.1698e-01]],

        [[-7.4668e-01, -3.5314e-01,  1.5594e+00,  ..., -1.2057e-01,
          -7.8813e-01,  2.2185e+00],
         [ 4.0950e-01, -4.9746e-01,  3.7888e-01,  ...,  8.3834e-01,
          -1.7501e-01, -2.9479e-02],
         [ 3.2082e-03,  7.0197e-01,  1.0999e+00,  ...,  6.0478e-02,
          -6.2677e-02,  8.8634e-02],
         ...,
         [ 1.0650e-01,  8