# Multiple heads

In [87]:
# version 1: run each head separately, then concat
import torch

# Seed the RNG so the random tensors are the same on every run.
torch.manual_seed(123)

# Problem sizes: batch, time, channel.
B = 2
T = 4
C = 8

# Random query / key / value inputs, drawn in this order; each has shape (B, T, C).
q = torch.randn(B, T, C)
k = torch.randn(B, T, C)
v = torch.randn(B, T, C)

def head(q, k, v):
    """Single causal self-attention head (scaled dot-product attention).

    Args:
        q: queries, shape (B, T, C).
        k: keys, shape (B, T, C).
        v: values, shape (B, T, C).

    Returns:
        Tensor of shape (B, T, C). Row t is a softmax-weighted average of
        v[:, :t + 1], so position t never sees later positions.
    """
    # Read the sizes from the inputs rather than from the notebook globals T and C,
    # so the function works for any batch / sequence / channel size.
    seq_len = q.size(-2)
    d_k = k.size(-1)

    att = (q @ k.transpose(-2, -1)) / d_k**0.5 # B, T, T
    # Lower-triangular mask, built on the same device as the scores.
    tril = torch.tril(torch.ones(seq_len, seq_len, device=att.device)) # T, T
    # -inf scores become exactly 0 after softmax, hiding future positions.
    att = att.masked_fill(tril == 0, float('-inf')) # B, T, T
    att = torch.softmax(att, dim=-1) # B, T, T

    out = att @ v # B, T, C
    return out

# Apply the head to the random inputs; the result has shape (B, T, C).
out = head(q=q, k=k, v=v)

In [77]:
# Value rows of the first batch element, for comparison with out[0].
v[0]

tensor([[-0.2582, -2.0407, -0.8016, -0.8183, -1.1820, -0.2877, -0.6043,  0.6002],
        [-1.4053, -0.5922, -0.2548,  1.1517, -0.0179,  0.4264, -0.7657, -0.0545],
        [-1.2743,  0.4513, -0.2280,  0.9224,  0.2056, -0.4970,  0.5821,  0.2053],
        [-0.3018, -0.6703, -0.6171, -0.8334,  0.4839, -0.1349,  0.2119, -0.8714]])

In [78]:
# Head output for the first batch element (each row is a weighted mix of rows of v[0]).
out[0]

tensor([[-0.2582, -2.0407, -0.8016, -0.8183, -1.1820, -0.2877, -0.6043,  0.6002],
        [-0.5085, -1.7247, -0.6823, -0.3885, -0.9280, -0.1319, -0.6395,  0.4574],
        [-1.2056, -0.2033, -0.3026,  0.8066, -0.0315, -0.1442, -0.0328,  0.1576],
        [-0.8482, -0.1931, -0.4107,  0.1548,  0.2657, -0.2460,  0.2601, -0.2675]])

In [79]:
# `att` is a local variable of head(), so it does not exist at notebook scope.
# Recompute the attention weights here with the same steps, so this cell also
# runs after Restart Kernel -> Run All instead of raising NameError.
att = (q @ k.transpose(-2, -1)) / C**0.5 # B, T, T
att = att.masked_fill(torch.tril(torch.ones(T, T)) == 0, float('-inf')) # B, T, T
att = torch.softmax(att, dim=-1) # B, T, T
att[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.7818, 0.2182, 0.0000, 0.0000],
        [0.1135, 0.3563, 0.5302, 0.0000],
        [0.0258, 0.0995, 0.4501, 0.4246]])

In [80]:
# Position 0 can only attend to itself, so its output must equal its own value row.
torch.allclose(v[0, 0], out[0, 0])

True

In [81]:
# Output row 1 should be the attention-weighted mix of the first two value rows.
# The weights are the 4-decimal values from the printed attention matrix, hence the loose tolerance.
torch.allclose(
    0.7818 * v[0, 0] + 0.2182 * v[0, 1],
    out[0, 1],
    atol=1e-02,
)

True

In [94]:
# Independent random inputs for two heads, drawn in the order q1, q2, k1, k2, v1, v2.
# Every tensor has shape (B, T, C).
q1 = torch.randn(B, T, C)
q2 = torch.randn(B, T, C)

k1 = torch.randn(B, T, C)
k2 = torch.randn(B, T, C)

v1 = torch.randn(B, T, C)
v2 = torch.randn(B, T, C)

# Run each head on its own inputs.
head1 = head(q1, k1, v1)  # (B, T, C)
head2 = head(q2, k2, v2)  # (B, T, C)

# Join the two outputs along the channel axis.
heads = torch.cat((head1, head2), dim=-1)  # (B, T, 2*C)

In [101]:
# Slicing the concatenation along the last axis must give back each head unchanged.
(
    torch.allclose(head1, heads[..., :C]),
    torch.allclose(head2, heads[..., C:]),
)

(True, True)

In [109]:
# version 2: one 4-D tensor per operand instead of separate per-head tensors.
# NOTE(review): the leading axis of size 2 presumably indexes the head — confirm.
a = torch.arange(2 * B * T * C, dtype=torch.float32).reshape(2, B, T, C)
b = torch.ones((2, B, T, C))

In [111]:
# Batched matmul over the two leading axes: (2, B, T, C) @ (2, B, C, T) -> (2, B, T, T).
c = a @ b.transpose(-1, -2)
# End the cell with an expression: an assignment alone displays nothing,
# so the shape shown in the cell output needs this line.
c.shape


torch.Size([2, 2, 4, 4])

### Convert this notebook to Markdown

In [2]:
from IPython.core.display import Javascript

In [3]:
%%js
IPython.notebook.kernel.execute('this_notebook = "' + IPython.notebook.notebook_name + '"')

<IPython.core.display.Javascript object>

In [3]:
# Show the file name stored by the %%js cell, to confirm it was set before converting.
this_notebook

'2022-09-22-blogging-with-jupyter-notebooks.ipynb'

In [4]:
!jupyter nbconvert --to markdown {this_notebook} --output-dir=../_posts

[NbConvertApp] Converting notebook 2022-09-22-blogging-with-jupyter-notebooks.ipynb to markdown
[NbConvertApp] Writing 725 bytes to ../_posts/2022-09-22-blogging-with-jupyter-notebooks.md
