In [1]:
# Re-import edited project modules (e.g. cs336_basics) automatically before every cell runs,
# so changes to the library show up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

# Integer-tensor ("advanced") indexing of an embedding table: every entry of `indices`
# selects one row, so a (14, 89, 25) index tensor gathers a (14, 89, 25, 1024) result.
axis = 1  # NOTE(review): not used in this cell
num_embeddings = 500
embeddings_dim = 1024
embeddings = torch.randn(num_embeddings, embeddings_dim)
indices = torch.randint(low=0, high=num_embeddings, size=(14, 89, 25))

embeddings[indices].size()

torch.Size([14, 89, 25, 1024])

In [3]:
from einops import rearrange

# Split a length-6 axis into (a, b) with b=2; einops infers a = 3, giving rows [0, 1], [2, 3], [4, 5].
rearrange(torch.arange(0, 6), "(a b) -> a b", b=2)

tensor([[0, 1],
        [2, 3],
        [4, 5]])

In [4]:
import torch
from einops import einsum, rearrange, repeat
from jaxtyping import Float, Int
from torch import Tensor

max_seq_len = 32
# RoPE base. 2 / (2 * pi / max_seq_len) simplifies to max_seq_len / pi (~10.19 for max_seq_len = 32).
# Other bases tried: 10_000 (the value used in the RoFormer paper) and 1 / (pi / 2**15).
theta = 2 / (2 * torch.pi / max_seq_len)
d_k = 64

batch_size = 7
sequence_length = 16

# Rotation angle for token position i and feature pair k:
#     angle[i, k] = i / theta ** (2 * k / d_k),   i in [0, max_seq_len),  k in [0, d_k // 2)
# Both ranges must start at 0: position 0 is the identity rotation and pair 0 turns at
# frequency theta ** 0 = 1. Starting them at 1 shifts every token forward by one position
# and gives every pair the next pair's (lower) frequency, which no longer matches RoPE.
sequence_position, feature_position = torch.meshgrid(
    torch.arange(max_seq_len),
    torch.arange(d_k // 2),
    indexing="ij",
)

angle = sequence_position / (theta ** ((2 * feature_position) / d_k))

cos = torch.cos(angle)
sin = torch.sin(angle)

in_query_or_key: Float[Tensor, "... sequence_length d_k"] = torch.randn((batch_size, sequence_length, d_k))
token_positions: Int[Tensor, "... sequence_length"] = repeat(
    torch.arange(0, sequence_length), "position -> batch_size position", batch_size=batch_size
)

# all_rotations[i, k] = [[cos, -sin], [sin, cos]]; shape (max_seq_len, d_k // 2, 2, 2).
all_rotations = torch.stack((torch.stack((cos, -sin), dim=-1), torch.stack((sin, cos), dim=-1)), dim=-2)
# Look up each token's matrices by its position: (batch_size, sequence_length, d_k // 2, 2, 2).
rotations = all_rotations[token_positions]
# Pair adjacent features (x[2k], x[2k+1]): (batch_size, sequence_length, d_k // 2, 2).
in_query_or_key_grouped = rearrange(
    in_query_or_key,
    "... (rotation_groups rotation_group_size) -> ... rotation_groups rotation_group_size",
    rotation_group_size=2,
)
# Apply each 2x2 rotation matrix to its feature pair (matrix-vector product).
in_query_or_key_grouped_rotated = einsum(
    rotations, in_query_or_key_grouped, "... group_out group_in, ... group_in -> ... group_out"
)
# Flatten the pairs back into the original feature layout: (batch_size, sequence_length, d_k).
in_query_or_key_rotated = rearrange(
    in_query_or_key_grouped_rotated,
    "... rotation_groups rotation_group_size -> ... (rotation_groups rotation_group_size)",
)

# Sanity checks: rotations preserve each vector's norm, and position 0 is left unchanged.
assert torch.allclose(in_query_or_key_rotated.norm(dim=-1), in_query_or_key.norm(dim=-1), atol=1e-5)
assert torch.allclose(in_query_or_key_rotated[..., 0, :], in_query_or_key[..., 0, :])

In [8]:
import torch
from jaxtyping import Bool, Float
from torch import Tensor

from cs336_basics.models.transformer_lm import scaled_dot_product_attention

# Fix the RNG so the random Q, K, V (and the outputs below) are reproducible.
torch.random.manual_seed(0)

seq_len = 5
d_k = 8
d_v = 8

Q: Float[Tensor, "... queries d_k"] = torch.randn((seq_len, d_k))
K: Float[Tensor, "... keys d_k"] = torch.randn((seq_len, d_k))
V: Float[Tensor, "... values d_v"] = torch.randn((seq_len, d_v))
# Boolean causal mask: True where key index <= query index (lower triangle incl. diagonal).
# The scores printed below are -inf exactly where this mask is False.
mask: Bool[Tensor, "... queries keys"] = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))

scaled_dot_product_attention(Q, K, V, mask)

tensor([[ 3.3334,    -inf,    -inf,    -inf,    -inf],
        [ 2.9558, -1.0914,    -inf,    -inf,    -inf],
        [-3.5900,  2.7503, -2.6823,    -inf,    -inf],
        [-1.1395,  1.4364, -0.9490,  2.5124,    -inf],
        [-1.6919,  0.8984,  2.1943,  4.1622, -2.8925]])
tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.8283e-01, 1.7172e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.7529e-03, 9.9390e-01, 4.3450e-03, 0.0000e+00, 0.0000e+00],
        [1.8552e-02, 2.4384e-01, 2.2446e-02, 7.1516e-01, 0.0000e+00],
        [2.4270e-03, 3.2362e-02, 1.1826e-01, 8.4622e-01, 7.3060e-04]])


tensor([[-0.2188, -2.4351, -0.0729, -0.0340,  0.9625,  0.3492, -0.9215, -0.0562],
        [-0.2258, -2.4012, -0.0387, -0.0403,  0.9481,  0.3632, -0.8898, -0.0314],
        [-0.6231, -0.4670,  1.9064, -0.3977,  0.1275,  1.1603,  0.9246,  1.3741],
        [-0.7689, -1.1741,  1.6126,  0.0371, -1.0907, -0.0419,  0.1493,  0.0865],
        [-0.8266, -1.2606,  1.3439,  0.1996, -1.2845, -0.3120,  0.1377, -0.3500]])

In [27]:
import torch
from einops import parse_shape, rearrange
from jaxtyping import Float
from torch import Tensor

from cs336_basics.models.transformer_lm import scaled_dot_product_attention

# Re-seed so this cell's Q, K, V do not depend on how many random draws earlier cells made.
torch.random.manual_seed(0)

num_heads = 8

# Here d_k is the full projection width shared by all heads; each head gets d_k // num_heads features.
d_k = 16
d_v = d_k  # values use the same width as queries/keys in this example
seq_len = 7

Q: Float[Tensor, "... queries d_k"] = torch.randn((seq_len, d_k))
K: Float[Tensor, "... keys d_k"] = torch.randn((seq_len, d_k))
V: Float[Tensor, "... values d_v"] = torch.randn((seq_len, d_v))


torch.Size([8, 7, 2]) torch.Size([7, 7])


torch.Size([7, 16])

In [None]:
# Number of positions in Q: the second-to-last axis.
sequence_length = parse_shape(Q, "... sequence_length d_in")["sequence_length"]


def rearrange_to_heads(
    X: Float[Tensor, "... sequence_length d_k"],
) -> Float[Tensor, "... num_heads sequence_length d_head"]:
    """Split the last axis of ``X`` into ``num_heads`` heads and move the head axis before positions.

    ``(..., sequence_length, num_heads * d_head) -> (..., num_heads, sequence_length, d_head)``.
    Reads the notebook-global ``num_heads``; einops raises if the last axis is not divisible by it.
    """
    return rearrange(
        X, "... sequence_length (num_heads d_head) -> ... num_heads sequence_length d_head", num_heads=num_heads
    )


def rearrange_from_heads(
    X: Float[Tensor, "... num_heads sequence_length d_head"],
) -> Float[Tensor, "... sequence_length d_k"]:
    """Inverse of ``rearrange_to_heads``: concatenate the heads back into one feature axis.

    ``(..., num_heads, sequence_length, d_head) -> (..., sequence_length, num_heads * d_head)``.
    Passing the notebook-global ``num_heads`` makes einops check that X's head axis has that size.
    """
    return rearrange(
        X, "... num_heads sequence_length d_head -> ... sequence_length (num_heads d_head)", num_heads=num_heads
    )


Q_heads = rearrange_to_heads(Q)
K_heads = rearrange_to_heads(K)
V_heads = rearrange_to_heads(V)

# Causal mask with True = "may attend": query i sees keys 0..i (lower triangle incl. diagonal),
# the same torch.tril convention as the single-head example. torch.triu would do the opposite
# and let each query attend only to itself and *later* keys.
mask = torch.tril(torch.ones((sequence_length, sequence_length), dtype=torch.bool))

print(Q_heads.shape, mask.shape)

# NOTE(review): one (sequence_length, sequence_length) mask for all heads; presumably
# scaled_dot_product_attention broadcasts it over the head axis — confirm.
attention_output_heads = scaled_dot_product_attention(Q_heads, K_heads, V_heads, mask)

attention_output = rearrange_from_heads(attention_output_heads)

attention_output.shape

In [30]:
# Strictly lower-triangular boolean matrix: True where column index < row index (diagonal is False).
torch.ones((sequence_length, sequence_length), dtype=torch.bool).tril(diagonal=-1)

tensor([[False, False, False, False, False, False, False],
        [ True, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True, False]])