In [13]:
import torch

In [14]:
# Sizes for the toy attention tensors built in the following cells.
BATCH_SIZE, SEQ_LENGTH = 5, 10
N_HEADS, D_HEAD = 4, 20

In [15]:
# Draw query, key and value (in that order) as independent standard-normal
# tensors of shape (BATCH_SIZE, N_HEADS, SEQ_LENGTH, D_HEAD).
query_states, key_states, value_states = (
    torch.randn(BATCH_SIZE, N_HEADS, SEQ_LENGTH, D_HEAD) for _ in range(3)
)

In [16]:
key_states.shape

torch.Size([5, 4, 10, 20])

In [17]:
value_states.shape

torch.Size([5, 4, 10, 20])

In [18]:
# K^T V over the last two axes: (..., D_HEAD, SEQ_LENGTH) x (..., SEQ_LENGTH, D_HEAD).
torch.matmul(key_states.transpose(-2, -1), value_states).shape

torch.Size([5, 4, 20, 20])

In [19]:
import torch

In [20]:
# Dimensions of the small integer grid used to inspect indexing below.
n_heads = 10
seq_length = 5

In [21]:
# Consecutive integers laid out row-major as an (n_heads, seq_length) grid.
states = torch.arange(n_heads * seq_length).reshape(n_heads, seq_length)

In [22]:
states

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49]])

In [23]:
# Prepend a size-1 axis: (n_heads, seq_length) -> (1, n_heads, seq_length).
states.unsqueeze(0)

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19],
         [20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39],
         [40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49]]])

In [24]:
import math

In [25]:
# sqrt(1/1024) == 1/sqrt(1024) == 1/32
1 / math.sqrt(1024)

0.03125

### Grouped-Query Attention (GQA)

In [26]:
import torch
from einops import repeat

In [27]:
# Sizes for the grouped-query example: 4 query heads, 2 key/value heads.
batch_size = 1
n_q_heads, n_kv_heads = 4, 2
d_k = d_v = 5

In [28]:
# Consecutive integers, so every (d_k, d_v) slice is recognisable by its values.
prev_memory = torch.arange(
    batch_size * n_kv_heads * d_k * d_v
).reshape(batch_size, n_kv_heads, d_k, d_v)

In [29]:
prev_memory

tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24]],

         [[25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]]]])

In [35]:
prev_memory

tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24]],

         [[25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]]]])

In [37]:
# NOTE: the repeat factor sits on dim 2 (d_k), so this stacks copies of each
# head's rows and yields shape (1, 2, 10, 5); the head axis (dim 1) keeps size 2.
# To replicate the KV heads themselves, the factor would have to go on dim 1.
prev_memory.repeat(1, 1, n_q_heads // n_kv_heads, 1)

tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24],
          [ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24]],

         [[25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49],
          [25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]]]])

In [40]:
# "(n_kv_heads n)": n is the fastest-varying index of the merged axis, so the
# copies of each KV head sit next to each other: head0, head0, head1, head1.
repeat(
    prev_memory, "batch_size n_kv_heads d_k d_v -> batch_size (n_kv_heads n) d_k d_v",
    n=n_q_heads // n_kv_heads
)

tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24]],

         [[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24]],

         [[25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]],

         [[25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]]]])

In [43]:
# "(n n_kv_heads)": n is the slowest-varying index of the merged axis, so the
# whole set of KV heads is tiled: head0, head1, head0, head1.
repeat(
    prev_memory, "batch_size n_kv_heads d_k d_v -> batch_size (n n_kv_heads) d_k d_v",
    n=n_q_heads // n_kv_heads
)

tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24]],

         [[25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]],

         [[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24]],

         [[25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]]]])

In [44]:
# Adjacent-copy ordering "(n_kv_heads n)": head0, head0, head1, head1, i.e. each
# KV head is immediately followed by its own copies.
repeat(
    prev_memory, "batch_size n_kv_heads d_k d_v -> batch_size (n_kv_heads n) d_k d_v",
    n=n_q_heads // n_kv_heads
)

tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24]],

         [[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24]],

         [[25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]],

         [[25, 26, 27, 28, 29],
          [30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44],
          [45, 46, 47, 48, 49]]]])

In [60]:
import torch

# Grouped-query setup: 8 query heads, 2 key/value heads, head width 10.
n_q_heads, n_kv_heads = 8, 2
d_head = 10

# One uniformly random query vector per query head.
q = torch.rand(n_q_heads, d_head)

# Keys and values both start as the ramp 0 .. n_kv_heads * d_head - 1, one row
# per KV head, held in two separate float32 tensors.
k = torch.arange(n_kv_heads * d_head, dtype=torch.float32).reshape(n_kv_heads, d_head)
v = k.clone()

In [61]:
k

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]])

In [62]:
# Tile the KV heads so there is one key/value row per query head
# (row order: head0, head1, head0, head1, ...). The k display and the attention
# product in the cells below rely on k having n_q_heads rows.
# Run once per fresh k/v: re-running tiles the already-tiled tensors again.
k = k.repeat(n_q_heads // n_kv_heads, 1)
v = v.repeat(n_q_heads // n_kv_heads, 1)

In [63]:
k

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]])

In [64]:
# Fresh uniformly random queries, one row per query head.
q = torch.rand((n_q_heads, d_head))


In [65]:
# Raw dot-product scores: attn[i, j] is the dot product of q[i] and k[j].
attn = q @ k.T

In [67]:
attn

tensor([[28.0864, 84.9314, 28.0864, 84.9314, 28.0864, 84.9314, 28.0864, 84.9314],
        [21.6458, 64.5257, 21.6458, 64.5257, 21.6458, 64.5257, 21.6458, 64.5257],
        [26.5185, 90.9190, 26.5185, 90.9190, 26.5185, 90.9190, 26.5185, 90.9190],
        [18.6815, 51.9978, 18.6815, 51.9978, 18.6815, 51.9978, 18.6815, 51.9978],
        [15.0524, 48.3450, 15.0524, 48.3450, 15.0524, 48.3450, 15.0524, 48.3450],
        [24.1377, 79.5783, 24.1377, 79.5783, 24.1377, 79.5783, 24.1377, 79.5783],
        [17.2669, 61.5019, 17.2669, 61.5019, 17.2669, 61.5019, 17.2669, 61.5019],
        [23.2112, 84.4215, 23.2112, 84.4215, 23.2112, 84.4215, 23.2112, 84.4215]])

##### Try 5

In [78]:
import torch

# Same grouped-query sizes as before: 8 query heads over 2 key/value heads.
n_q_heads, n_kv_heads, d_head = 8, 2, 10

# Random queries; keys and values are identical integer ramps cast to float32.
q = torch.rand(n_q_heads, d_head).to(torch.float32)
k = torch.arange(n_kv_heads * d_head).view(n_kv_heads, d_head).to(torch.float32)
v = torch.arange(n_kv_heads * d_head).view(n_kv_heads, d_head).to(torch.float32)

In [79]:
k

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]])

In [80]:
v

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]])

In [84]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    The head axis is the second-to-last dimension. Copies of a head are placed
    next to each other (h0, h0, ..., h1, h1, ...), which is the ordering of
    ``torch.repeat_interleave(x, n_rep, dim=-2)``.

    Args:
        x: Tensor of shape ``(..., n_kv_heads, head_dim)``. Any number of
            leading dimensions is accepted, e.g. ``(n_kv_heads, head_dim)`` or
            ``(B, Seq_Len, N_KV_Heads, Head_Dim)``.
        n_rep: Number of copies to make of each head.

    Returns:
        Tensor of shape ``(..., n_kv_heads * n_rep, head_dim)``. When
        ``n_rep == 1`` the input tensor itself is returned (no copy).

    Raises:
        ValueError: If ``x`` has fewer than two dimensions.
    """
    # Star-unpacking keeps the 2-D case working and admits leading batch dims.
    *lead, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        # (..., N_KV_Heads, 1, Head_Dim)
        x[..., :, None, :]
        # (..., N_KV_Heads, N_Rep, Head_Dim) -- broadcast only, no data copied yet
        .expand(*lead, n_kv_heads, n_rep, head_dim)
        # (..., N_KV_Heads * N_Rep, Head_Dim)
        .reshape(*lead, n_kv_heads * n_rep, head_dim)
    )

In [85]:
# n_rep=2 turns the 2 KV heads into 4 key rows; covering all n_q_heads = 8
# query heads would need n_rep = n_q_heads // n_kv_heads = 4.
keys = repeat_kv(k, 2)

In [86]:
keys

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]])

In [87]:
from einops import repeat

In [90]:
# "(n n_kv_heads)" tiles the whole block of heads (h0, h1, h0, h1), whereas
# repeat_kv keeps the copies of a head adjacent (h0, h0, h1, h1).
repeat(
    k,
    "n_kv_heads d_k -> (n n_kv_heads) d_k",
    n=2,
)

tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]])

In [92]:
# All-zero float tensor of shape (1, 3, 1, 1).
torch.zeros(1, 3, 1, 1)

tensor([[[[0.]],

         [[0.]],

         [[0.]]]])