In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [5]:
# Toy dimensions for comparing two ways of building multi-head q/k/v projections.
B = 2          # batch size
T = 4          # sequence length (tokens)
C = 8          # embedding size
nh = 4         # number of attention heads
dropout = 0.2  # not used by the cells in this section — presumably for a later attention step
# Approach 2 reshapes C into nh heads of size C // nh, which needs an exact division;
# fail here with a clear message rather than with an opaque `.view` error later.
assert C % nh == 0, "C must be divisible by nh"
# Constant input: only the shapes of the projections are inspected below.
x = torch.ones(B, T, C)

# Approach 1 - mine: three separate projections, each mapping C -> C * nh (so every one of the nh heads gets size C)

In [21]:
torch.manual_seed(42)
# Independent key / query / value projections, created in that order (so the seeded
# weights match building them one by one). Each maps C -> C * nh features.
lin_key, lin_query, lin_value = (nn.Linear(C, C * nh, bias=False) for _ in range(3))

In [22]:
# Project the same input three ways.
q, k, v = lin_query(x), lin_key(x), lin_value(x)
k.shape, q.shape, v.shape  # each (B, T, C * nh)

(torch.Size([2, 4, 32]), torch.Size([2, 4, 32]), torch.Size([2, 4, 32]))

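# Approach 2 - The correct one (standard layout): one fused C -> 3C projection, split into q, k, v, each divided into nh heads of size C // nh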

In [24]:
torch.manual_seed(42)
# A single fused projection whose output holds q, k and v side by side: C -> 3 * C features.
lin_attn = nn.Linear(in_features=C, out_features=3 * C, bias=False)

In [25]:
# lin_attn computes x @ weight.T with weight of shape (3C, C): (B, T, C) -> (B, T, 3C).
# Chop the last dim into three (B, T, C) chunks (q, k, v in that order), then reshape
# each chunk into nh heads of size hs = C // nh: (B, T, nh, hs).
q, k, v = (
    chunk.view(B, T, nh, C // nh)
    for chunk in lin_attn(x).split(C, dim=2)
)

In [26]:
# All three projections were reshaped, so check all of them, not just k.
assert q.shape == k.shape == v.shape == (B, T, nh, C // nh)
k.shape  # (B, T, nh, hs) with hs = C // nh

torch.Size([2, 4, 4, 2])

# Visualizing

Basic Matrix Multiplication:
```
(
    a1, a2
    b1, b2
)
@
(
    c1, c2
    d1, d2
)
=
(
    a1c1 + a2d1, a1c2 + a2d2
    b1c1 + b2d1, b1c2 + b2d2
)
```

Approach 2 (toy example with T = 2 tokens and C = 2, so the fused weight is C x 3C = 2 x 6):
```
(
    x1, x2
    x3, x4
)
@
(
    l1, l2, l3, l4, l5, l6
    l7, l8, l9, l10, l11, l12
)
=
(
    x1l1 + x2l7, x1l2 + x2l8, x1l3 + x2l9, x1l4 + x2l10, x1l5 + x2l11, x1l6 + x2l12
    x3l1 + x4l7, x3l2 + x4l8, x3l3 + x4l9, x3l4 + x4l10, x3l5 + x4l11, x3l6 + x4l12
)
split the last dimension into 3 chunks of size C (= 2 here) ->
q = (
    x1l1 + x2l7, x1l2 + x2l8,
    x3l1 + x4l7, x3l2 + x4l8,
)
k = (
    x1l3 + x2l9, x1l4 + x2l10,
    x3l3 + x4l9, x3l4 + x4l10
)
v = (
    x1l5 + x2l11, x1l6 + x2l12,
    x3l5 + x4l11, x3l6 + x4l12
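)
```

I won't write it out in full, but it is easy to see that this is equivalent to three separate projections, one per column block of the fused weight:
```
q = x @ ( l1, l2
          l7, l8 )
k = x @ ( l3, l4
          l9, l10 )
v = x @ ( l5, l6
          l11, l12 )
```

Note that `nn.Linear` stores its weight as a `(3C, C)` matrix and computes `x @ weight.T`, so the `l` matrix above corresponds to `lin_attn.weight.T`. Splitting one fused `C -> 3C` projection is therefore the same as using three independent `C -> C` projections. Approach 1 differs in that each of its projections maps `C -> C * nh`, giving every head size `C` instead of `C // nh`.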