In [1]:
import torch

In [30]:
torch.manual_seed(0)

# Toy sizes for the attention sanity check.
d_k, B, N, T, n_heads = 3, 2, 5, 6, 4

# Draw Q, K, V (in that order) as (B, n_heads, N, d_k) random tensors.
Q, K, V = (torch.randn((B, n_heads, N, d_k)) for _ in range(3))

In [31]:
def attn(Q, K, V, mask):
	"""Masked scaled dot-product attention.

	Args:
		Q: queries, shape (..., N_q, d_k).
		K: keys, shape (..., N_k, d_k).
		V: values, shape (..., N_k, d_v).
		mask: boolean tensor broadcastable to (..., N_q, N_k). True marks
			positions that must NOT be attended to; they are filled with -inf
			before the softmax, so a row that is entirely masked yields NaN.

	Returns:
		Tensor of shape (..., N_q, d_v).
	"""
	# Scale by the key size of the inputs themselves rather than a module-level
	# `d_k`, so the function is correct for any head size it is called with.
	scale = Q.size(-1) ** -0.5
	scores = (Q @ K.transpose(-2, -1)) * scale
	scores = scores.masked_fill(mask, float("-inf"))
	weights = torch.softmax(scores, dim=-1)
	return weights @ V

# Causal mask: True strictly above the diagonal (positions that are masked out).
attn(
	Q, K, V,
	mask=torch.ones((N, N), dtype=torch.bool).triu(diagonal=1),
).shape

torch.Size([2, 4, 5, 3])

To check that the `@` operator does the batched matrix multiply I expect, let's unroll the batch and head dimensions and compare against `torch.bmm`.

In [32]:
torch.manual_seed(0)
# Two random tensors shaped like Q/K, (B, n_heads, N, d_k), drawn a first then b.
a, b = (torch.randn((B, n_heads, N, d_k)) for _ in range(2))

In [33]:
# Fold (B, n_heads) into a single batch axis so torch.bmm can be used.
a_unrolled = a.view(-1, N, d_k)  # (B*n_heads, N, d_k)
b_unrolled = b.view(-1, N, d_k)  # (B*n_heads, N, d_k)
bmm = torch.bmm(a_unrolled, b_unrolled.transpose(1, 2)).view(B, n_heads, N, N)
regular = torch.matmul(a, b.transpose(-2, -1))  # (B, n_heads, N, N) via broadcasting
torch.allclose(bmm, regular)

True

But how do I encode the weights from which we get Q?

It was super simple with one headed attention.

We just use `torch.nn.Linear` to project $(B, N, T)$ to $(B, N, d_k)$, which was as simple as constructing
`torch.nn.Linear(T, d_k)` and then calling that module on the input.

With multiple heads it gets a bit more complicated, but in essence I'm just applying this transformation n_heads times, each time with its own weights.

The input stays $(B, N, T)$, but now I want to project it into $(B, \text{n\_heads}, N, d_k)$, i.e. the same projection applied once per head.

So first I need weights that represent the projection. For a single head they are shaped $(T, d_k)$.

In the single-head case, $(B, N, T)$ is matrix-multiplied with $(T, d_k)$ to get $(B, N, d_k)$. I want to do this n_heads times, so I'll create n_heads weight matrices stacked as $(\text{n\_heads}, T, d_k)$.

So essentially I want (with Q shaped $(B, \text{n\_heads}, N, d_k)$, so the head index is the second axis):

```python
w = torch.zeros((n_heads, T, d_k))
for i in range(n_heads):
	Q[:, i, ...] = x@w[i, ...]
```

Is there a single operation that does that? One option is to flatten the leading dimensions and do one large matrix multiply instead of a loop.

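Concretely: reshape the weight matrix to $(T, \text{n\_heads} \cdot d_k)$ and the input x to $(B \cdot N, T)$. For the weight reshape to be a plain `view`, the weights are stored as $(T, \text{n\_heads}, d_k)$ rather than $(\text{n\_heads}, T, d_k)$.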

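Multiplying x by the weights then gives $(B \cdot N, \text{n\_heads} \cdot d_k)$. Reshaping that to $(B, N, \text{n\_heads}, d_k)$ and swapping the $N$ and n_heads axes gives $(B, \text{n\_heads}, N, d_k)$, which is what I want. Let's try that!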

In [114]:
# Smaller sizes so every intermediate tensor can be inspected by eye.
n_heads, B, N, d_k, T = 2, 2, 5, 2, 3

torch.manual_seed(0)
x = torch.randn((B, N, T))  # (B, N, T)

# Re-seeding makes w_Q's values independent of the draws used for x.
torch.manual_seed(0)
r = (1/T)**.5  # U(-1/sqrt(T), 1/sqrt(T)), the same range nn.Linear uses
w_Q = torch.nn.init.uniform_(torch.empty((T, n_heads, d_k)), a=-r, b=r)  # (T, n_heads, d_k)

w_Q

tensor([[[-0.0043,  0.3097],
         [-0.4752, -0.4249]],

        [[-0.2224,  0.1548],
         [-0.0114,  0.4578]],

        [[-0.0512,  0.1528],
         [-0.1745, -0.1135]]])

In [115]:
xT = x.view(B * N, T)  # merge batch and sequence axes: (B * N, T)
xT

tensor([[-1.1258, -1.1524, -0.2506],
        [-0.4339,  0.8487,  0.6920],
        [-0.3160, -2.1152,  0.3223],
        [-1.2633,  0.3500,  0.3081],
        [ 0.1198,  1.2377, -0.1435],
        [-0.1116, -0.6136,  0.0316],
        [-0.4927,  0.2484,  0.4397],
        [ 0.1124, -0.8411, -2.3160],
        [-0.1023,  0.7924, -0.2897],
        [ 0.0525,  0.5229,  2.3022]])

In [116]:
w_QT = w_Q.flatten(start_dim=1)  # merge head and d_k axes: (T, n_heads * d_k)
w_QT

tensor([[-0.0043,  0.3097, -0.4752, -0.4249],
        [-0.2224,  0.1548, -0.0114,  0.4578],
        [-0.0512,  0.1528, -0.1745, -0.1135]])

In [117]:
out = torch.matmul(xT, w_QT)  # all heads in one matmul: (B*N, n_heads*d_k)
out

tensor([[ 0.2740, -0.5654,  0.5919, -0.0207],
        [-0.2223,  0.1027,  0.0757,  0.4943],
        [ 0.4552, -0.3761,  0.1181, -0.8706],
        [-0.0882, -0.2900,  0.5426,  0.6620],
        [-0.2684,  0.2068, -0.0461,  0.5319],
        [ 0.1353, -0.1247,  0.0545, -0.2370],
        [-0.0756, -0.0470,  0.1546,  0.2732],
        [ 0.3052, -0.4492,  0.3603, -0.1699],
        [-0.1609,  0.0467,  0.0901,  0.4391],
        [-0.2345,  0.4489, -0.4326, -0.0442]])

In [118]:
# (B*N, n_heads*d_k) -> (B, N, n_heads, d_k) -> (B, n_heads, N, d_k)
out_fmt = out.view(B, N, n_heads, d_k).permute(0, 2, 1, 3).contiguous()
out_fmt

tensor([[[[ 0.2740, -0.5654],
          [-0.2223,  0.1027],
          [ 0.4552, -0.3761],
          [-0.0882, -0.2900],
          [-0.2684,  0.2068]],

         [[ 0.5919, -0.0207],
          [ 0.0757,  0.4943],
          [ 0.1181, -0.8706],
          [ 0.5426,  0.6620],
          [-0.0461,  0.5319]]],


        [[[ 0.1353, -0.1247],
          [-0.0756, -0.0470],
          [ 0.3052, -0.4492],
          [-0.1609,  0.0467],
          [-0.2345,  0.4489]],

         [[ 0.0545, -0.2370],
          [ 0.1546,  0.2732],
          [ 0.3603, -0.1699],
          [ 0.0901,  0.4391],
          [-0.4326, -0.0442]]]])

In [119]:
# Reference: project each head separately and write head i into slot i.
out_fmt_2 = torch.empty((B, n_heads, N, d_k))
for i in range(n_heads):
	out_fmt_2[:, i] = torch.matmul(x, w_Q[:, i])
torch.allclose(out_fmt_2, out_fmt)

True

In [125]:
def uniform_parameter(size, a, b, requires_grad=True):
	"""Create an ``nn.Parameter`` of shape ``size`` initialised from U(a, b).

	Args:
		size: shape of the parameter.
		a: lower bound of the uniform distribution.
		b: upper bound of the uniform distribution.
		requires_grad: whether the returned parameter is trainable.

	Returns:
		torch.nn.Parameter whose ``requires_grad`` matches the argument.
	"""
	t = torch.empty(size)
	torch.nn.init.uniform_(t, a=a, b=b)
	# nn.Parameter defaults to requires_grad=True, so the flag must be passed
	# here explicitly, otherwise requires_grad=False is silently ignored.
	return torch.nn.Parameter(t, requires_grad=requires_grad)

class MHLinear(torch.nn.Module):
	"""Bias-free linear projection applied independently for each of ``n_heads`` heads.

	Maps an input of shape (B, N, T) to (B, n_heads, N, d_k). The weights are
	stored as (T, n_heads, d_k) so they can be viewed as one (T, n_heads*d_k)
	matrix and all heads computed with a single matmul.
	"""

	def __init__(self, n_heads, T, d_k):
		super().__init__()
		# U(-1/sqrt(T), 1/sqrt(T)), the same range nn.Linear uses for its weights.
		b = T**-.5
		a = -b
		self.weights = uniform_parameter((T, n_heads, d_k), a, b)

	def forward(self, X):
		"""Project X of shape (B, N, T) to (B, n_heads, N, d_k)."""
		B, N, T = X.shape
		# Read the head layout from the parameter itself, not from module-level
		# globals, so the module is correct whatever `n_heads`/`d_k` are in scope.
		_, n_heads, d_k = self.weights.shape
		X_unrolled = X.reshape(B * N, T)  # (B*N, T); reshape also handles non-contiguous X
		W_unrolled = self.weights.reshape(-1, n_heads * d_k)  # (T, n_heads*d_k)
		projected = X_unrolled @ W_unrolled  # (B*N, n_heads*d_k)
		# (B*N, n_heads*d_k) -> (B, N, n_heads, d_k) -> (B, n_heads, N, d_k)
		return projected.view(B, N, n_heads, d_k).transpose(1, 2).contiguous()

class MHA(torch.nn.Module):
	"""Causal multi-head self-attention.

	Projects X (B, N, T) to per-head Q, K, V of shape (B, n_heads, N, d_k), runs
	masked scaled dot-product attention via ``attn``, concatenates the heads back
	to (B, N, n_heads*d_k) and projects that to (B, N, d_out).

	Args:
		n_heads: number of attention heads.
		N: maximum sequence length; the causal mask is built as N x N.
		T: input feature size.
		d_k: per-head projection size.
		d_out: output feature size.
	"""

	def __init__(self, n_heads, N, T, d_k, d_out):
		super().__init__()
		self.Q, self.K, self.V = MHLinear(n_heads, T, d_k), MHLinear(n_heads, T, d_k), MHLinear(n_heads, T, d_k)
		self.out = torch.nn.Linear(n_heads*d_k, d_out)
		# True strictly above the diagonal = positions that are masked out.
		# A non-persistent buffer follows .to(device) like the parameters but is
		# kept out of the state_dict (as the plain attribute was).
		self.register_buffer(
			"mask",
			torch.triu(torch.ones((N, N), dtype=torch.bool), diagonal=1),
			persistent=False,
		)

	def forward(self, X):
		"""Apply causal multi-head attention to X of shape (B, N, T); returns (B, N, d_out)."""
		B, N, _ = X.shape
		# Slice so sequences shorter than the construction-time N also work.
		mask = self.mask[:N, :N]
		mha = attn(Q=self.Q(X), K=self.K(X), V=self.V(X), mask=mask)  # (B, n_heads, N, d_k)
		mha = mha.transpose(1, 2).reshape(B, N, -1)  # (B, N, n_heads*d_k)
		return self.out(mha)  # project to (B, N, d_out)

torch.manual_seed(0)  # MHA's projections are randomly initialised
m = MHA(n_heads, N, T, d_k, T)  # d_out = T: output keeps the input's feature size
m(x).shape  # expected: (B, N, T)

torch.Size([2, 5, 3])