### Rotary Embedding

In [2]:
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the rotary-embedding rotation table for positions ``0 .. end - 1``.

    Frequency ``i`` (for ``i`` in ``0 .. dim // 2 - 1``) is
    ``theta ** (-2 * i / dim)`` radians per position, and position ``m`` is
    rotated by the angle ``m * freq_i``.

    Args:
        dim: Head dimension; ``dim // 2`` frequencies are produced.
        end: Number of positions to precompute.
        theta: Base of the geometric frequency progression.

    Returns:
        A tuple ``(freqs_cis, freqs_cos, freqs_sin)``, each of shape
        ``(end, dim // 2)``: the unit complex numbers ``exp(i * angle)``
        (complex64), followed by the cosine (real part) and sine
        (imaginary part) of the same angles as float tensors.
    """
    n_freqs = dim // 2
    # Even indices 0, 2, 4, ... divided by dim; an odd dim yields one extra index, trimmed here.
    exponents = torch.arange(0, dim, 2)[:n_freqs].float() / dim
    # Angular frequencies in radians per position (see https://en.wikipedia.org/wiki/Sine_wave);
    # they shrink geometrically as the index grows.
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()  # (end, dim // 2)
    unit_rotations = torch.polar(torch.ones_like(angles), angles)  # complex64, magnitude 1
    return unit_rotations, torch.cos(angles), torch.sin(angles)


# Rotation table for 10 positions of a 10-dim head; every returned tensor has
# shape (end, dim // 2) = (10, 5).
cis, freqs_cos, freqs_sin = precompute_freqs_cis(10, 10)

# Print position 5 in all three representations. freqs_cis (complex64) carries the
# same values as the pair (freqs_cos, freqs_sin), which are its real and imaginary
# parts respectively.
for table in (cis, freqs_cos, freqs_sin):
    print(table[5])

tensor([0.2837-0.9589j, 0.7021+0.7121j, 0.9921+0.1253j, 0.9998+0.0199j,
        1.0000+0.0032j])
tensor([0.2837, 0.7021, 0.9921, 0.9998, 1.0000])
tensor([-0.9589,  0.7121,  0.1253,  0.0199,  0.0032])


In [42]:
import altair as alt
import pandas as pd

n_positions = 100
cis100, freqs_cos100, freqs_sin100 = precompute_freqs_cis(10, n_positions)

# view_as_real splits each complex entry into a trailing (real, imag) pair and
# flatten(1) interleaves them, so column 2k is the cosine and column 2k+1 the sine
# of frequency k.
emb = torch.view_as_real(cis100).flatten(1)  # (n_positions, 10)
print(emb.shape)

# NOTE(review): plotting starts at column 2, which leaves out columns 0-1 (cos/sin
# of the highest frequency); presumably to keep the plot readable — confirm.
first_dim = 2
positions = list(range(emb.shape[0]))  # derived from the data, not a second hard-coded 100
data = pd.concat(
    [
        pd.DataFrame(
            {
                # Convert explicitly rather than relying on pandas to coerce a torch tensor.
                "embedding": emb[:, dim].numpy(),
                "dimension": dim,
                "position": positions,
            }
        )
        for dim in range(first_dim, emb.shape[-1])
    ]
)
(
    alt.Chart(data)
    .mark_line()
    .properties(width=800)
    .encode(x="position", y="embedding", color="dimension:N")
    .interactive()
)

# Each column traces a sinusoid over position. Columns 2k and 2k+1 share frequency k
# (cos and sin respectively), and that frequency shrinks geometrically as k grows.
# Why use many frequencies instead of one? A single sinusoid is periodic, so positions
# exactly one period apart would receive identical values. Here the periods range from
# about 6 positions (k = 0, 1 radian per step) to roughly 10,000 positions (the last k
# for dim=10, theta=10000). Together they make the combined pattern effectively unique
# for every position at practical sequence lengths.

torch.Size([100, 10])


In [18]:
# Random query/key rows of width 10, each viewed as 5 complex numbers (consecutive pairs).
xq = torch.randn(1, 10)
xq = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # (1, 5)
xk = torch.randn(1, 10)
xk = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # (1, 5)


def _rotate_to_real(x: torch.Tensor, rotation: torch.Tensor) -> torch.Tensor:
    """Rotate each complex entry of ``x`` by ``rotation`` and return it as a real row.

    ``torch.view_as_real`` turns every complex entry into a trailing (real, imag)
    pair and ``flatten(1)`` interleaves them, so a (1, 5) complex input becomes a
    (1, 10) real row.
    """
    return torch.view_as_real(x * rotation).flatten(1)


# For interleaved (real, imag) rows, the real dot product equals Re(sum(q * conj(k))).
# Rotating q by m*theta and k by n*theta therefore gives Re(sum(q * conj(k) * e^{i(m-n)theta})),
# which depends only on the offset m - n.
xq1 = _rotate_to_real(xq, cis[1])
xk5 = _rotate_to_real(xk, cis[5])
score_1_5 = xq1 @ xk5.T
print(score_1_5)

xq2 = _rotate_to_real(xq, cis[2])
xk6 = _rotate_to_real(xk, cis[6])
score_2_6 = xq2 @ xk6.T
print(score_2_6)

# Check the claim numerically instead of comparing printed digits by eye.
assert torch.allclose(score_1_5, score_2_6, atol=1e-5)

# Shifting both tokens by the same amount (positions 1, 5 -> 2, 6) keeps their relative
# distance and leaves their dot product unchanged. This matches the intuition that a
# phrase keeps its meaning wherever it appears in a sentence.

tensor([[-0.8709]])
tensor([[-0.8709]])


### Grouped-Query Attention