In [None]:
import torch
import torch.nn as nn

In [None]:
N, B, d = 10, 5, 100

# Positions 0..N-1 broadcast over the batch: (B, N), then (B, N, 1) so they
# broadcast against the per-dimension frequencies below.
pos = torch.arange(0, N).expand(B, N)
print(pos.size(), pos[0, :][:4])
pos = pos.unsqueeze(-1)
print(pos.size())

# Sinusoidal encoding (Vaswani et al., 2017, Sec. 3.5):
#   PE[pos, 2i]   = sin(pos / 10000^(2i/d))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i/d))
# Each sin/cos pair shares ONE frequency, so only ceil(d/2) distinct
# frequencies exist; `two_i` enumerates the even indices 0, 2, 4, ...
two_i = torch.arange(0, d, 2)
print(two_i.size())
angles = pos / torch.pow(10000, two_i / d)   # B x N x ceil(d/2)
print(angles.size())

sin = torch.sin(angles)                      # B x N x ceil(d/2)
cos = torch.cos(angles[:, :, : d // 2])      # B x N x floor(d/2) -> sin + cos = d columns
print(sin.size(), cos.size())

# Sin block followed by cos block (instead of interleaving); same values, different layout.
embed = torch.cat([sin, cos], dim=-1)        # B x N x d
print(embed.size())

In [None]:
B, h, N = 5, 8, 10
QKT = torch.normal(mean=0, std=1, size=(B, h, N, N))
print(QKT.size())

# Causal mask: query i may only attend to keys j <= i (lower triangle incl. diagonal).
# A boolean (N, N) mask broadcasts over (B, h) inside masked_fill, so no
# (B, h, N, N) additive mask has to be materialised, and the
# multiply-then-add trick's 0 * inf = nan hazard is avoided.
lower_mask = torch.tril(torch.ones(size=(N, N))).bool()
masked_QKT = QKT.masked_fill(~lower_mask, float("-inf"))
print(masked_QKT.size())
print(masked_QKT[0][0])

### Batched Self-Attention

In [None]:
# Toy sizes for the batched multi-head projection demo.
B, N, d = 10, 128, 256
d_k, d_v = 300, 320
h = 8

X = torch.rand(size=(B, N, d))             # B x N x d input
# One projection matrix per head, stacked along a leading head axis (h x d x e).
W_Q = torch.ones(size=(h, d, d_k))
W_K = torch.ones(size=(h, d, d_k))
W_V = torch.ones(size=(h, d, d_v))

# Project with every head at once: (B, N, d) x (h, d, e) -> (B, h, N, e)
Q = torch.einsum('bnc,hce->bhne', X, W_Q)
print(Q.size())

K = torch.einsum('bnc,hce->bhne', X, W_K)
print(K.size())

V = torch.einsum('bnc,hce->bhne', X, W_V)
print(V.size())

In [None]:
# Scaled dot-product scores Q K^T / sqrt(d_k): B x h x N x N.
# Without the 1/sqrt(d_k) factor the logits grow with d_k and push the
# softmax into saturation (Vaswani et al., 2017, Sec. 3.2.1).
# Contracting on the shared last axis avoids an explicit transpose of K.
QKT = torch.einsum('bhid,bhjd->bhij', Q, K) / d_k ** 0.5
print(QKT.size())

In [None]:
# Softmax over the key axis (last dim) turns each query's scores into weights.
sm = nn.Softmax(dim=-1)
A = sm(QKT)

print(A[0,0,0,:].sum())    # they're indeed probability distributions

# Q x K^T scores -> softmax over keys -> attention weights A (B x h x N x N)



In [None]:
# Value projection for every head: (B, N, d) x (h, d, d_v) -> (B, h, N, d_v)
V = torch.einsum('bnc,hce->bhne', X, W_V)
print(V.shape)

In [None]:
# Attention-weighted sum of values per head: (B,h,N,N) x (B,h,N,d_v) -> (B,h,N,d_v)
AV = torch.einsum('bhqk,bhkv->bhqv', A, V)
print(AV.shape)

In [None]:
# Concatenate heads. AV is B x h x N x d_v: first move the head axis next to
# d_v (B x N x h x d_v) so each token's row becomes [head_0 | head_1 | ...].
# Reshaping B x h x N x d_v directly would mix rows of different tokens/heads.
AV_concat = AV.transpose(1, 2).reshape(B, N, h * d_v)
print(AV_concat.size())

In [None]:
# Output projection back to model width: (B, N, h*d_v) x (h*d_v, d) -> (B, N, d)
W_O = torch.ones(size=(h * d_v, d))
SA_out = torch.einsum('btk,km->btm', AV_concat, W_O)
print("B x N x d =", SA_out.shape)

### Batched Layer Norm

In [None]:
print("B =", B, "N =", N, "d =", d)

# Shapes: X is B x N x d; per-token statistics are B x N x 1 (keepdim=True)
# so they broadcast back against X.
eps = 1e-5  # same default as torch.nn.LayerNorm; keeps zero-variance rows finite

mu = torch.mean(X, dim=2, keepdim=True)
print("mu", mu.size())

# LayerNorm normalises with the biased (population) variance, i.e. without
# Bessel's correction; torch.std would default to the unbiased estimator.
std = torch.sqrt(torch.var(X, dim=2, keepdim=True, unbiased=False) + eps)
print("std", std.size())

X_hat = (X - mu) / std
print("X_hat", X_hat.size())
print(torch.mean(X_hat, dim=2)[0][0])                   # sanity check: ~0
print(torch.std(X_hat, dim=2, unbiased=False)[0][0])    # sanity check: ~1 (slightly below because of eps)

### Batched FFN

In [None]:
import torch.nn.functional as F

# Position-wise feed-forward demo: d -> d_ff -> d with a ReLU in between (no biases here).
d_ff = 512
W1 = torch.rand(size=(d, d_ff))
W2 = torch.rand(size=(d_ff, d))

hidden = X @ W1
X1 = F.relu(hidden)
print(X1.size())
S2 = X1 @ W2
print(S2.size())

### Module Tests

In [None]:
from main import load_data
from model import ShakespearModel
from parsing.CharDataSet import CharDataSet
from torch.utils.data.dataloader import DataLoader

# --- configuration for the end-to-end model smoke test ---
N_TOKENS = 128              # N
N_LAYERS = 12               # L
N_HEADS = 8                 # h
N_WORKERS = 2
BATCH_SIZE = 5              # B
D_MODEL = 768               # d
D_K = 64
D_V = D_K                   # value width tied to key width
D_FF = 2048
RAW_DATA_PATH = './datasets/shakespear_corpus.txt'

# --- data pipeline: raw text -> character dataset -> batches ---
raw_data = load_data(RAW_DATA_PATH)
tokenized_data = CharDataSet(N_TOKENS, raw_data)

loader_options = {
    "batch_size": BATCH_SIZE,
    "shuffle": False,          # keep batches in dataset order
    "num_workers": N_WORKERS,
}
data_loader = DataLoader(tokenized_data, **loader_options)

In [None]:
# Take the first batch (the loader was built with shuffle=False).
tokenized_sentence, _ = next(iter(data_loader))  # expected (B, N) = (BATCH_SIZE, N_TOKENS); TODO confirm CharDataSet yields N_TOKENS-long chunks

print(tokenized_sentence.size())
# Decode the first sequence of the batch back to text as a sanity check.
print(tokenized_data.decode(tokenized_sentence[0, :]))

In [None]:
# Default hyper-parameters suggested by the teaching assistant; worth tuning once the model works.
vocab_size = tokenized_data.get_vocab_size()
model = ShakespearModel(
    N_LAYERS, N_HEADS, D_MODEL, D_FF, D_K, D_V,
    BATCH_SIZE, N_TOKENS, vocab_size,
)

In [None]:
# Forward pass of the freshly constructed (untrained) model on the first batch.
out = model(tokenized_sentence)

In [None]:
vocab_size = tokenized_data.get_vocab_size()
print(out.size())   # presumably the last dim should equal the vocab size printed next
print(vocab_size)

In [None]:
# Prompt used to prime generation below, encoded with the dataset's vocabulary.
seed = "O God, O God!"
idx = tokenized_data.encode(seed)
print(idx.size())

In [None]:
# Ask the (untrained) model to continue the prompt; n_new_tokens presumably counts generated tokens.
new_tokens = model.generate(idx, n_new_tokens=500)

In [None]:
import torch

# NOTE: this cell OVERWRITES the `new_tokens` returned by `model.generate(...)`
# in the previous cell with a hard-coded snapshot of an earlier generation, so
# the decode cell that follows shows this snapshot, not the latest output.
# Token ids are integer vocabulary indices, so store them as int64; the float
# literals presumably come from pasting a float tensor's printed repr.
new_tokens = torch.tensor([ 1., 19., 53., 42.,  6.,  1., 27.,  1., 19., 53., 42.,  2., 33., 55.,
        34., 35., 33., 33., 24., 42.,  1.,  0., 52., 63., 11., 27., 59., 51.,
        28., 49., 35., 38., 62., 56., 33., 48., 60.,  7., 49.,  5., 14., 34.,
        60., 45.,  3.,  4.,  9., 44., 38., 28., 24., 36., 22., 51., 18., 47.,
         0.,  4., 44., 26., 20., 28., 24., 14., 52., 39., 27., 61., 31., 25.,
        34., 56.,  9., 39., 43.,  3., 55., 28., 61., 14., 34., 30.,  9., 43.,
        57.,  5., 64., 55., 62., 12., 40., 30.,  9.,  4., 44., 34., 63., 10.,
        12., 38., 17., 50., 26., 53., 51., 39., 12., 11.,  6., 28., 26., 26.,
        30., 30., 12., 22., 29., 42., 29., 36., 38., 57., 16., 39.,  1.,  2.,
        34., 34., 25., 15., 24., 59., 38., 35., 23., 12., 39., 30., 55., 55.,
        13., 18., 59., 32., 40., 13., 13., 47., 57., 25., 34., 17., 11., 32.,
        17., 15., 34., 26., 34., 58., 38., 27., 21., 47., 30.,  1., 48., 62.,
        29., 27., 15., 47., 47.,  6., 40.,  8., 54., 16., 14.,  0., 63.,  5.,
        41., 54., 31., 54., 22., 35., 27., 33., 59., 38., 10., 55.,  4.,  4.,
        25., 43., 56., 51., 34.,  2.,  6., 31., 55., 30., 14., 51., 64., 28.,
        51., 16., 16., 16., 17., 63.,  9., 56., 54., 51., 51., 33., 24., 47.,
        40., 61.,  0., 38.,  0., 42., 31., 17.,  0., 16., 55., 60.,  9., 22.,
        18., 61., 53., 59., 22.,  7., 52., 13., 41.,  7.,  2., 13., 61., 56.,
        17., 29., 40., 28., 43., 28., 53.,  8., 41.,  2., 61., 32., 15., 13.,
         3.,  0., 55., 29., 27.,  9., 44., 17.,  1., 38., 53., 36., 45.,  8.,
        62., 64., 53., 14., 32., 17., 47., 34.,  5.,  5., 49., 49., 58., 28.,
         9., 13., 40., 26., 33., 26.,  7.,  7.,  7.,  8.,  2., 37., 54., 48.,
        35., 52., 28., 59., 59., 59., 23., 10., 50., 62., 56., 43.,  9.,  5.,
        33., 13., 55.,  8., 39., 20., 18.,  5., 26., 56., 19., 33., 20., 40.,
        62., 47., 48., 45.,  3., 59., 12., 28., 45., 24., 27.,  8., 26., 54.,
        19., 12., 23., 29., 44.,  2., 37., 40., 20., 12.,  4., 55.,  5., 43.,
         1., 43., 19., 57., 61., 47., 62., 36., 56., 47., 63., 50., 37., 34.,
        62., 31.,  8., 15.,  0., 34., 52.,  9., 64., 44., 42., 43., 14., 37.,
        13.,  1., 33., 12.,  9.,  5., 34.,  5.,  0.,  7., 43., 45., 46., 38.,
        25., 18., 35.,  5.,  3., 41., 28., 25., 24., 62., 11., 14., 23., 42.,
        58., 47.,  0., 32., 40.,  7., 12.,  3., 28., 55.,  4.,  4., 16., 59.,
         1.,  3.,  5., 37., 50., 31., 53.,  7., 22., 13., 34., 50., 58.,  3.,
        13., 51., 28.,  9., 64.,  0., 12.,  6., 28., 27., 14., 39., 22., 13.,
        60., 33., 35., 54., 62., 20.,  7.,  0., 39., 40., 31., 34.,  0., 49.,
        33., 53., 11., 12., 50., 42., 13., 28., 58., 45.,  4., 25., 25., 23.,
        51., 33., 50.,  4.,  9., 26., 31., 55.,  0., 33.], dtype=torch.long)

In [None]:
decoded = tokenized_data.decode(new_tokens)
print(decoded)

In [None]:
from train_model import train_model

# --- training run configuration ---
N_EPOCHS = 2
N_TOKENS = 128      # N
N_LAYERS = 12       # L
N_HEADS = 8         # h
N_WORKERS = 2
# NOTE(review): batch size is tied to N_TOKENS instead of set independently; confirm this is intended.
BATCH_SIZE = N_TOKENS  # B
D_MODEL = 768       # d
D_K = 64
D_V = D_K
D_FF = 2048
RAW_DATA_PATH = './datasets/shakespear_corpus.txt'

# NOTE(review): N_WORKERS and RAW_DATA_PATH are not passed here; train_model presumably
# uses its own values. Arguments are positional, in train_model's expected order.
trained_mod = train_model(
    N_EPOCHS, N_TOKENS, N_LAYERS, N_HEADS, BATCH_SIZE,
    D_MODEL, D_K, D_V, D_FF,
)

In [None]:
from train_model import load_data
from parsing.CharDataSet import CharDataSet
from torch.utils.data.dataloader import DataLoader

# Rebuild the dataset with the training configuration so the prompt is
# encoded with (presumably) the same vocabulary the model was trained on.
raw_data = load_data(RAW_DATA_PATH)
tokenized_data = CharDataSet(N_TOKENS, raw_data)

loader_kwargs = dict(shuffle=False, batch_size=BATCH_SIZE, num_workers=N_WORKERS)
data_loader = DataLoader(tokenized_data, **loader_kwargs)

seed = "O God, O God!"
idx = tokenized_data.encode(seed)

In [None]:
# Generate a short continuation of the prompt with the trained model.
new_tokens = trained_mod.generate(idx, n_new_tokens=50)

In [None]:
generated_text = tokenized_data.decode(new_tokens)
print(generated_text)

In [None]:
# Inspect the trained parameters of the first transformer block's FFN.L1 submodule.
ffn_l1 = trained_mod.transformer.blocks[0].FFN.L1
for tensor in ffn_l1.parameters():
    print(tensor)

In [None]:
# List the named parameters of the first block's causal self-attention module.
attn = trained_mod.transformer.blocks[0].CausalSelfAttn
for param_name, tensor in attn.named_parameters():
    print(param_name, tensor)