In [1]:
import torch

In [2]:
# Toy inputs for checking the rotation by hand: an all-ones 4-vector plus rounded
# cos/sin values. They appear to be cos/sin of the angles
# 1, 10000**(-1/64), 10000**(-2/64), 10000**(-3/64) (~1, 0.866, 0.750, 0.649),
# i.e. position 1 of the first four frequencies of a 128-wide head.
x = torch.ones(4)
cos = torch.tensor([0.5403, 0.6479, 0.7318, 0.7965])
sin = torch.tensor([0.84147, 0.76172, 0.68156, 0.60469])

In [3]:
def rotate_half(x):
    """Return ``cat(-back, front)`` along the last dimension.

    ``front`` is ``x[..., :n // 2]`` and ``back`` is ``x[..., n // 2:]`` with
    ``n = x.shape[-1]``; for odd ``n`` the front half is the shorter one.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat([-back, front], dim=-1)

def apply(x):
    """Rotate ``x`` as ``x * cos + rotate_half(x) * sin``.

    NOTE: ``cos`` and ``sin`` are not parameters; they are looked up as module-level
    globals at call time, so the result depends on whatever those names are bound
    to in the notebook when this is called.
    """
    rotated_sin = rotate_half(x) * sin
    return x * cos + rotated_sin

In [4]:
# Spot-check one element by hand: element 3 of x * cos + rotate_half(x) * sin.
x[3] * cos[3] + rotate_half(x)[3] * sin[3]

tensor(1.4012)

In [5]:
# rotate_half on the all-ones x: the second half is negated and moved to the front.
rotate_half(x)

tensor([-1., -1.,  1.,  1.])

In [6]:
# Full rotation using the module-level cos/sin; element 3 should match the spot-check above.
apply(x)

tensor([-0.3012, -0.1138,  1.4134,  1.4012])

In [7]:
def hardcoded_cos_rot_sin(x, cos, sin):
    """
    Spell out, for d = 4, the matrix M with ``x @ M == x * cos + rotate_half(x) * sin``.

    M is the sum of two parts:
      * ``m_cos``: ``cos`` on the diagonal (the ``x * cos`` term);
      * ``m_rot_sin``: ``sin`` placed so that element ``j`` of ``x @ m_rot_sin`` is
        ``-x[j + 2] * sin[j]`` for ``j < 2`` and ``x[j - 2] * sin[j]`` for ``j >= 2``.

    Args:
        x: Unused; kept so existing ``hardcoded_cos_rot_sin(x, cos, sin)`` calls still work.
        cos: Indexable of 4 cosine values.
        sin: Indexable of 4 sine values.

    Returns:
        torch.Tensor: the 4 x 4 matrix ``m_cos + m_rot_sin``.
    """
    m_cos = torch.tensor([
        [cos[0], 0., 0., 0.],
        [0., cos[1], 0., 0.],
        [0., 0., cos[2], 0.],
        [0., 0., 0., cos[3]],
    ])
    m_rot_sin = torch.tensor([
        [0., 0., sin[2], 0.],
        [0., 0., 0., sin[3]],
        [-sin[0], 0., 0., 0.],
        [0., -sin[1], 0., 0.],
    ])
    # Show the rotated-sin part on its own (this is the matrix printed in the cell output).
    print(m_rot_sin)
    # Build the result from the two parts instead of re-typing every entry by hand,
    # so the combined matrix cannot drift out of sync with m_cos / m_rot_sin.
    return m_cos + m_rot_sin

# The matrix form should reproduce apply(x); assert it rather than comparing printed tensors by eye.
rotated_hardcoded = x @ hardcoded_cos_rot_sin(x, cos, sin)
assert torch.allclose(rotated_hardcoded, apply(x))
rotated_hardcoded

tensor([[ 0.0000,  0.0000,  0.6816,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.6047],
        [-0.8415,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.7617,  0.0000,  0.0000]])


tensor([-0.3012, -0.1138,  1.4134,  1.4012])

In [8]:
def general_cos_rot_sin(cos, sin):
    """
    Extend the hard-coded 4 x 4 construction to any length ``d = len(sin)``.

    Returns the ``(d, d)`` matrix ``M`` such that, for a row vector ``x`` of length
    ``d``, ``x @ M == x * cos + rotate_half(x) * sin``.

    Args:
        cos: 1-D tensor of ``d`` cosine values.
        sin: 1-D tensor of ``d`` sine values.

    Returns:
        torch.Tensor: ``diag(cos)`` plus the rotated-sin matrix.
    """
    m_cos = torch.diag(cos)
    m_sin = torch.diag(sin)
    d = len(sin)
    # rotate_half emits -x[d//2:] (d - d//2 elements) first, then x[:d//2].
    # Splitting the sin rows at d - d//2 keeps M consistent with it for odd d;
    # for even d this equals d // 2, i.e. the previous behavior.
    split = d - d // 2
    m_rot_sin = torch.cat([m_sin[split:], -m_sin[:split]])
    return m_cos + m_rot_sin

# The general construction should also reproduce apply(x); assert it explicitly.
rotated_general = x @ general_cos_rot_sin(cos, sin)
assert torch.allclose(rotated_general, apply(x))
rotated_general

tensor([-0.3012, -0.1138,  1.4134,  1.4012])

In [26]:
def precompute_freqs(dim: int, end: int, theta: float = 10000.0):
    """
    Tabulate cosine and sine of the rotary angles for positions ``0 .. end - 1``.

    Args:
        dim (int): Rotary dimension (e.g. the head size). Frequencies come from the
            even indices ``0, 2, ...``, truncated to ``dim // 2`` of them.
        end (int): Number of positions to tabulate.
        theta (float, optional): Base of the geometric frequency progression.
            Grok-1 uses 10000.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(cos, sin)``, each float32 of shape
        ``(end, dim // 2)``; row ``t`` holds the values for position ``t``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end)
    # angles[t, k] = t * inv_freq[k]
    angles = torch.outer(positions, inv_freq).float()
    return torch.cos(angles), torch.sin(angles)


def freq_row_to_rotation_matrix(cos_row, sin_row):
    """
    Transform one row of cos/sin values into a ``d x d`` matrix ``M``
    (``d = len(sin_row)``) such that, for a row vector ``x`` of length ``d``,
    ``x @ M == x * cos_row + rotate_half(x) * sin_row``.

    Args:
        cos_row: 1-D tensor of ``d`` cosine values.
        sin_row: 1-D tensor of ``d`` sine values.

    Returns:
        torch.Tensor: ``(d, d)`` matrix ``diag(cos_row)`` plus the rotated-sin part.
    """
    d = len(sin_row)
    # rotate_half emits -x[d//2:] (d - d//2 elements) before x[:d//2]; split the
    # sin rows at that length so odd d is handled too (equals d // 2 for even d).
    split = d - d // 2
    m_sin = torch.diag(sin_row)
    m_rot_sin = torch.cat([m_sin[split:], -m_sin[:split]])
    return torch.diag(cos_row) + m_rot_sin


def get_rotation_mat(dhead, end, full_dim=False):
    """
    Build one dense rotation matrix per position ``0 .. end - 1``.

    Args:
        dhead: Head dimension passed to ``precompute_freqs``.
        end: Number of positions.
        full_dim: If False (default, original behavior), the ``dhead // 2`` frequencies
            from ``precompute_freqs`` are used as-is, so each matrix is
            ``(dhead // 2, dhead // 2)``, e.g. 64 x 64 for ``dhead=128``. If True, each
            frequency row is duplicated as ``cat((f, f))`` first, giving ``(dhead, dhead)``
            matrices that apply ``x * cos + rotate_half(x) * sin`` to a full
            ``dhead``-length vector (the duplicated-frequency layout used by
            rotate_half-style RoPE).

    Returns:
        list[torch.Tensor]: ``end`` matrices, one per position.
    """
    cos, sin = precompute_freqs(dhead, end)
    if full_dim:
        # NOTE(review): the default half-size matrices cannot multiply a dhead-length
        # head vector; confirm which layout downstream code expects.
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
    # Memory grows as end * d * d floats: get_rotation_mat(128, 16384) with the default
    # layout holds 16384 float32 matrices of 64 x 64 (~256 MiB); full_dim=True is 4x that.
    return [freq_row_to_rotation_matrix(c, s) for c, s in zip(cos, sin)]


In [27]:
# Use distinct names so the 4-element demo `cos`/`sin` (read as globals by `apply`)
# are not clobbered; re-running the earlier cells after this one stays valid.
cos_table, sin_table = precompute_freqs(8, 100)

cos_table.shape

torch.Size([100, 4])

In [32]:
# One rotation matrix per position: head size 128, 16384 positions.
r = get_rotation_mat(dhead=128, end=16384)

In [34]:
# Each matrix is (dhead // 2) x (dhead // 2): precompute_freqs yields dim // 2 frequencies per position.
r[0].shape

torch.Size([64, 64])