In [1]:
import torch

In [30]:
torch.manual_seed(0)

# Toy sizes: batch B, heads n_heads, sequence length N,
# input feature width T, per-head query/key width d_k.
B, n_heads, N, T, d_k = 2, 4, 5, 6, 3

# Random per-head queries, keys and values, each (B, n_heads, N, d_k).
# The generator draws them in Q, K, V order, matching three sequential randn calls.
Q, K, V = (torch.randn((B, n_heads, N, d_k)) for _ in range(3))

In [31]:
def attn(Q, K, V, mask=None):
	"""Scaled dot-product attention over the last two dimensions.

	Args:
		Q: queries, shape (..., N_q, d_k).
		K: keys, shape (..., N_k, d_k).
		V: values, shape (..., N_k, d_v).
		mask: optional bool tensor broadcastable to (..., N_q, N_k).
			True marks positions to hide (filled with -inf before the
			softmax). None applies no masking.

	Returns:
		Tensor of shape (..., N_q, d_v): softmax-weighted sums of V.
		A query row whose keys are all masked comes out as NaN.
	"""
	# Scale by the actual key width instead of the global d_k. The notebook
	# rebinds d_k later, which would otherwise silently change this output.
	scale = Q.size(-1) ** -0.5
	scores = (Q @ K.transpose(-2, -1)) * scale
	if mask is not None:
		scores = scores.masked_fill(mask, float("-inf"))
	weights = torch.softmax(scores, dim=-1)
	return weights @ V

# Causal mask: True strictly above the diagonal marks future keys to hide.
attn(Q, K, V, mask=torch.ones((N, N), dtype=torch.bool).triu(diagonal=1)).size()

torch.Size([2, 4, 5, 3])

Just to check that the `@` operator does the batched matrix multiply I expect, let's flatten the batch and head dimensions and compare against `torch.bmm`.

In [32]:
torch.manual_seed(0)
# Two random (B, n_heads, N, d_k) operands for the @-vs-bmm comparison,
# drawn in a-then-b order.
a, b = (torch.randn((B, n_heads, N, d_k)) for _ in range(2))

In [33]:
# Fold (B, n_heads) into a single batch axis, since torch.bmm only takes 3-D inputs.
a_unrolled = a.flatten(0, 1)  # (B * n_heads, N, d_k)
b_unrolled = b.flatten(0, 1)  # (B * n_heads, N, d_k)
# The bmm result is freshly allocated, so viewing it back to 4-D needs no copy.
bmm = torch.bmm(a_unrolled, b_unrolled.transpose(1, 2)).view(B, n_heads, N, N)
regular = a @ b.transpose(-2, -1)  # @ broadcasts over the leading (B, n_heads) dims
torch.allclose(bmm, regular)

True

But how do I represent the weights that project the input into Q?

It was simple with single-head attention.

We just used `torch.nn.Linear` to project the $(B, N, T)$ input to $(B, N, d_k)$. That was as simple as creating `torch.nn.Linear(T, d_k)` and then calling that module on the input.

With multiple heads it gets a bit more complicated, but in essence I'm just applying this transformation n_heads times.

The input stays $(B, N, T)$, but now I want to project it into $(B, \text{n\_heads}, N, d_k)$. That is the same kind of projection applied n_heads times, each head with its own weights.

So first I need weights that represent the projection, shaped $(T, d_k)$. Note that `torch.nn.Linear` stores its weight transposed, as $(d_k, T)$, and computes $x W^\top$.

In the single-head case, $(B, N, T)$ is matrix multiplied with $(T, d_k)$ to get $(B, N, d_k)$. I want to do this n_heads times, so I'll create n_heads weight matrices stacked as $(\text{n\_heads}, T, d_k)$.

So essentially I want the following. `Q` is shaped $(B, \text{n\_heads}, N, d_k)$, so the head index is dimension 1:

```python
w = torch.zeros((n_heads, T, d_k))
for i in range(n_heads):
	Q[:, i, ...] = x@w[i, ...]
```

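Is there a single operation that does this? One option is to flatten the leading dimensions so the whole computation becomes one ordinary 2-D matmul. Broadcasting also works directly: both `x.unsqueeze(1) @ w` and `torch.einsum("bnt,htd->bhnd", x, w)` produce shape $(B, \text{n\_heads}, N, d_k)$.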

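For example, make the weight matrix $W_Q$ shape $(T, \text{n\_heads} \cdot d_k)$, with each head's $(T, d_k)$ block placed side by side. Make the input x shape $(B \cdot N, T)$.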

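Then multiplying x by $W_Q$ gives $(B \cdot N, \text{n\_heads} \cdot d_k)$. Reshaping that to $(B, N, \text{n\_heads}, d_k)$ and swapping the N and head axes gives $(B, \text{n\_heads}, N, d_k)$, which is what I want. Reshaping straight to $(B, \text{n\_heads}, N, d_k)$ would mix up tokens and heads. Let's try that!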

In [67]:
torch.manual_seed(0)

# Smaller sizes than the first cell, so every intermediate is easy to eyeball.
# NOTE: these rebind the B / N / T / n_heads / d_k globals defined earlier.
n_heads, d_k = 2, 2
B, N, T = 1, 5, 3

x = torch.randn((B, N, T))

# Bound 1/sqrt(fan_in) with fan_in = T. This is the U(-sqrt(k), sqrt(k)),
# k = 1/in_features range documented for torch.nn.Linear's default weights.
r = (1/T)**.5
w_Q = torch.nn.init.uniform_(torch.empty((n_heads, T, d_k)), a=-r, b=r)  # one (T, d_k) matrix per head

w_Q

tensor([[[-0.3742, -0.2658],
         [-0.4034, -0.5407],
         [-0.3370,  0.4963]],

        [[ 0.2576,  0.2798],
         [ 0.0304, -0.2960],
         [ 0.0977, -0.5391]]])

In [68]:
# Merge the batch and sequence axes: (B, N, T) -> (B * N, T).
# Despite the name, this is a flatten, not a transpose.
xT = x.flatten(0, -2)
xT

tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986],
        [ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820],
        [-0.8567,  1.1006, -1.0712]])

In [73]:
# (n_heads, T, d_k) -> (T, n_heads, d_k) -> (T, n_heads * d_k).
# Columns h*d_k .. (h+1)*d_k - 1 hold head h's weight matrix.
w_QT = w_Q.transpose(0, 1).reshape(T, -1)
w_QT

tensor([[-0.3742, -0.2658,  0.2576,  0.2798],
        [-0.4034, -0.5407,  0.0304, -0.2960],
        [-0.3370,  0.4963,  0.0977, -0.5391]])

In [74]:
# One ordinary 2-D matmul applies every head's projection at once.
out = torch.matmul(xT, w_QT)  # (B * N, n_heads * d_k)
out

tensor([[ 0.2760, -1.3322,  0.1753,  1.6926],
        [ 0.6961, -0.2588, -0.0231,  1.2340],
        [-0.2466, -0.9173,  0.0591,  0.2525],
        [ 0.3303,  0.5202, -0.1042, -0.0344],
        [ 0.2377, -0.8991, -0.2919,  0.0119]])

In [75]:
# Split the columns back into heads, then move the head axis in front of the sequence axis.
out_fmt = out.reshape(B, N, n_heads, d_k).transpose(1, 2).contiguous()  # (B, n_heads, N, d_k)
out_fmt

tensor([[[[ 0.2760, -1.3322],
          [ 0.6961, -0.2588],
          [-0.2466, -0.9173],
          [ 0.3303,  0.5202],
          [ 0.2377, -0.8991]],

         [[ 0.1753,  1.6926],
          [-0.0231,  1.2340],
          [ 0.0591,  0.2525],
          [-0.1042, -0.0344],
          [-0.2919,  0.0119]]]])

In [81]:
# Reference result: apply each head's (T, d_k) weights to x one head at a time.
out_fmt_2 = torch.empty((B, n_heads, N, d_k))
for i, w_head in enumerate(w_Q):
	out_fmt_2[:, i] = x @ w_head
torch.allclose(out_fmt_2, out_fmt)

True