In [1]:
import torch
import torch.nn as nn

### Batched Self-Attention

In [55]:
# Toy problem sizes: batch size, tokens per sequence, model width.
B, N, d = 10, 128, 256
# Width of each head's query/key vectors and of its value vectors.
d_k, d_v = 300, 320
# Number of attention heads.
h = 8

# Random input batch in [0, 1) and all-ones projection weights,
# stored as one (d, d_k) or (d, d_v) matrix per head.
X = torch.rand(B, N, d)
W_Q, W_K = torch.ones(h, d, d_k), torch.ones(h, d, d_k)
W_V = torch.ones(h, d, d_v)

# Project the whole batch through every head with one einsum per projection:
# (B, N, d) x (h, d, d_*) -> (B, h, N, d_*).
Q, K, V = (torch.einsum('bik,hkj->bhij', X, W) for W in (W_Q, W_K, W_V))

print(Q.shape, K.shape, V.shape, sep='\n')

torch.Size([10, 8, 128, 300])
torch.Size([10, 8, 128, 300])
torch.Size([10, 8, 128, 320])


In [56]:
# Raw attention logits: every query dotted with every key, per batch and head.
# (B, h, N, d_k) x (B, h, d_k, N) -> (B, h, N, N)
QKT = torch.einsum('bhik,bhkj->bhij', Q, K.transpose(2, 3))
print(QKT.shape)

torch.Size([10, 8, 128, 128])


In [62]:
# Scaled dot-product attention weights: A = softmax(Q K^T / sqrt(d_k)).
# The logits are sums over d_k products, so their magnitude grows with d_k;
# dividing by sqrt(d_k) compensates for that before the softmax is applied.
sm = nn.Softmax(dim=3)                  # normalise over the key axis (last dim)
A = sm(QKT / (Q.size(-1) ** 0.5))       # Q.size(-1) is d_k

print(A[0,0,0,:].sum())    # they're indeed probability distributions


Q x K
Q x K 



tensor(1.)


In [39]:
# Value projection for every head: (B, N, d) x (h, d, d_v) -> (B, h, N, d_v).
# NOTE(review): this recomputes the same V as the projection cell earlier in
# the notebook; it is kept so this cell can be run on its own.
V = torch.einsum('bik,hkj->bhij', X, W_V)
print(V.shape)

torch.Size([10, 8, 128, 320])


In [40]:
# Attention output per head: each row of A weights the value vectors.
# (B, h, N, N) x (B, h, N, d_v) -> (B, h, N, d_v)
AV = torch.einsum('bhik,bhkj->bhij', A, V)
print(AV.shape)

torch.Size([10, 8, 128, 320])


In [41]:
# Concatenate the heads for each token: (B, h, N, d_v) -> (B, N, h*d_v).
# The head axis has to be moved next to the feature axis first. Reshaping
# (B, h, N, d_v) directly keeps the memory order (head, token, feature), so
# each output row would mix features of different tokens from the same head
# instead of stacking the h head outputs of one token.
AV_concat = AV.permute(0, 2, 1, 3).reshape(B, N, h*d_v)
print(AV_concat.size())

torch.Size([10, 128, 2560])


In [42]:
# Output projection: map the concatenated heads back to the model width.
# (B, N, h*d_v) x (h*d_v, d) -> (B, N, d)
W_O = torch.ones(h * d_v, d)
SA_out = torch.einsum('bni,id->bnd', AV_concat, W_O)
print("B x N x d =", SA_out.shape)

B x N x d = torch.Size([10, 128, 256])


### Batched Layer Norm

In [43]:
print("B =", B, "N =", N, "d =", d)

# Shapes: X is (B, N, d); mu and std are (B, N, 1) so they broadcast over d.
eps = 1e-5    # keeps the division finite when all features of a token are equal

# Per-token mean over the feature axis.
mu = torch.mean(X, dim=2, keepdim=True)
print("mu", mu.size())

# Layer norm uses the biased (population) variance, with eps added inside the
# square root -- the same formula as torch.nn.LayerNorm.
var = torch.var(X, dim=2, unbiased=False, keepdim=True)
std = torch.sqrt(var + eps)
print("std", std.size())

X_hat = (X - mu) / std
print("X_hat", X_hat.size())
print(torch.mean(X_hat, dim=2)[0][0])                   # sanity check: ~0
print(torch.std(X_hat, dim=2, unbiased=False)[0][0])    # sanity check: ~1 (marginally below 1 because of eps)

N = 128 B = 10 d = 256
mu torch.Size([10, 128, 1])
std torch.Size([10, 128, 1])
X_hat torch.Size([10, 128, 256])
tensor(5.7742e-08)
tensor(1.)


### Batched FFN

In [50]:
import torch.nn.functional as F

# Feed-forward block applied to every token: expand to d_ff, ReLU, project back to d.
d_ff = 512
W1 = torch.rand(d, d_ff)
W2 = torch.rand(d_ff, d)

# (B, N, d) @ (d, d_ff) -> (B, N, d_ff)
X1 = F.relu(X @ W1)
print(X1.shape)
# (B, N, d_ff) @ (d_ff, d) -> (B, N, d)
S2 = X1 @ W2
print(S2.shape)

torch.Size([10, 128, 512])
torch.Size([10, 128, 256])


### Module Tests

In [9]:
from main import load_data
from model import ShakespearModel
from parsing.CharDataSet import CharDataSet
from torch.utils.data.dataloader import DataLoader


# --- Sequence / batching ---
N_TOKENS = 128              # N
BATCH_SIZE = 20             # B
N_WORKERS = 2

# --- Model dimensions ---
N_LAYERS = 12               # L
N_HEADS = 8                 # h
D_MODEL = 768               # d
D_K = 64
D_V = D_K                   # value width tied to the key width
D_FF = 2048

RAW_DATA_PATH = './datasets/shakespear_corpus.txt'

# Read the raw corpus and wrap it in the project's character-level dataset.
raw_data = load_data(RAW_DATA_PATH)
tokenized_data = CharDataSet(N_TOKENS, raw_data)

# shuffle=False: batches are drawn in dataset order.
data_loader = DataLoader(
    dataset=tokenized_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=N_WORKERS,
)

In [10]:
# Pull the first batch from the loader; it yields pairs and only the first
# element is used here. With BATCH_SIZE = 20 and N_TOKENS = 128 the batch
# has shape (20, 128) = (B, N).
tokenized_sentence, _ = next(iter(data_loader))

print(tokenized_sentence.shape)
# Decode the first sequence of the batch back to text.
print(tokenized_data.decode(tokenized_sentence[0]))

torch.Size([20, 128])
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to 


In [13]:
# Hyperparameters are the default values suggested by the teaching assistant;
# experiment with them once the model is working.
model = ShakespearModel(
    N_LAYERS,
    N_HEADS,
    D_MODEL,
    D_FF,
    D_K,
    D_V,
    BATCH_SIZE,
    N_TOKENS,
    tokenized_data.get_vocab_size(),
)

In [14]:
# Run the model on the first batch of tokenized sentences.
out = model(tokenized_sentence)

In [16]:
# Compare the output's last dimension against the vocabulary size.
print(out.shape)
print(tokenized_data.get_vocab_size())

torch.Size([20, 128, 65])
65
