In [12]:
import torch.nn as nn
import torch

### Batched Self-Attention

In [36]:
B, N, d = 10, 128, 256
d_k, d_v = 300, 320
h = 8

X = torch.rand(size=(B, N, d))
W_Q = torch.ones(size=(h, d, d_k))
W_K = torch.ones(size=(h, d, d_k))
W_V = torch.ones(size=(h, d, d_v))

Q = torch.einsum('bik,hkj->bhij', X, W_Q)
print(Q.size())

K = torch.einsum('bik,hkj->bhij', X, W_K)
print(K.size())

V = torch.einsum('bik,hkj->bhij', X, W_V)
print(V.size())

torch.Size([10, 8, 128, 300])
torch.Size([10, 8, 128, 300])
torch.Size([10, 8, 128, 320])


In [37]:
QKT = torch.einsum('bhik,bhkj->bhij', Q, torch.transpose(K, 2, 3))
print(QKT.size())

torch.Size([10, 8, 128, 128])


In [38]:
sm = nn.Softmax(dim=2)
A = sm(QKT)

print(A[0,0,:,:].sum(dim=1))    # they're indeed probability distributions

tensor([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 128.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.])


In [39]:
V = torch.einsum('bik,hkj->bhij', X, W_V)
print(V.size())

torch.Size([10, 8, 128, 320])


In [40]:
AV = torch.einsum('bhik,bhkj->bhij', A, V)
print(AV.size())

torch.Size([10, 8, 128, 320])


In [41]:
AV_concat = torch.reshape(AV, shape=(B, N, h*d_v))
print(AV_concat.size())

torch.Size([10, 128, 2560])


In [42]:
W_O = torch.ones(size=(h*d_v, d))
SA_out = torch.einsum('bni,id->bnd', AV_concat, W_O)
print("B x N x d =", SA_out.size())

B x N x d = torch.Size([10, 128, 256])


### Batched Layer Norm

In [43]:
print("N =", N, "B =", B, "d =", d)

mu = torch.mean(X, dim=2).unsqueeze(-1)
print("mu", mu.size())

std = torch.std(X, dim=2).unsqueeze(-1)
print("std", std.size())

X_hat = (X - mu) / std 
print("X_hat", X_hat.size())
print(torch.mean(X_hat, dim=2)[0][0])   # sanity checks
print(torch.std(X_hat, dim=2)[0][0])    # sanity checks, note if std is 0 everywhere this can yield nan

N = 128 B = 10 d = 256
mu torch.Size([10, 128, 1])
std torch.Size([10, 128, 1])
X_hat torch.Size([10, 128, 256])
tensor(5.7742e-08)
tensor(1.)


### Batched FFN

In [50]:
import torch.nn.functional as F

d_ff = 512
W1 = torch.rand(size=(d, d_ff))
W2 = torch.rand(size=(d_ff, d))

X1 = F.relu(torch.matmul(X, W1))
print(X1.size())
S2 = torch.matmul(X1, W2)
print(S2.size())

torch.Size([10, 128, 512])
torch.Size([10, 128, 256])


## Todo : 

- Define batch einsum operations on layer norm
- Same on feedforward
- Implement a Transformer block using that.