In [1]:
import torch
import torch.nn as nn

In [15]:
N, B, d = 10, 5, 100

# Positions 0..N-1, repeated for each of the B sequences in the batch.
pos = torch.arange(0, N).expand(B, N)
print(pos.size(), pos[0, :][:4])
pos = pos.unsqueeze(-1)                    # B x N x 1, broadcasts over the d feature dims
print(pos.size())
i = torch.arange(d)
print(i.size())
# Sinusoidal encoding (Vaswani et al., 2017, Sec. 3.5):
#   PE[pos, 2k]   = sin(pos / 10000^(2k/d))
#   PE[pos, 2k+1] = cos(pos / 10000^(2k/d))
# Feature dims 2k and 2k+1 share one frequency, so the exponent uses the pair
# index i // 2. Using i directly gives the cos half different frequencies
# than the sin half and exponents up to ~2 instead of < 1.
angles = pos / torch.pow(10000, (2 * (i // 2)) / d)
print(angles.size())
print(angles[:, :, 0::2].size())   # B x N x d/2

sin = torch.sin(angles[:, :, 0::2])
cos = torch.cos(angles[:, :, 1::2])

print(sin.size(), cos.size())
# Sin half followed by cos half (concatenated rather than interleaved).
embed = torch.cat([sin, cos], dim=-1)   # B x N x d
print(embed.size())

torch.Size([5, 10]) tensor([0, 1, 2, 3])
torch.Size([5, 10, 1])
torch.Size([100])
torch.Size([5, 10, 100])
torch.Size([5, 10, 50])
torch.Size([5, 10, 50]) torch.Size([5, 10, 50])


In [8]:
B, h, N = 5, 8, 10
QKT = torch.normal(mean=0, std=1, size=(B, h, N, N))
print(QKT.size())

# Boolean (N, N) mask: True where query i may attend to key j (j <= i).
lower_mask = torch.ones(size=(N, N), dtype=torch.bool).tril()

# masked_fill broadcasts the (N, N) mask over (B, h), so no (B, h, N, N)
# additive mask is materialised. Future positions become exactly -inf even
# when a score there is +/-inf; the additive form `QKT * mask + (-inf mask)`
# produced inf * 0 = NaN in that case.
masked_QKT = QKT.masked_fill(~lower_mask, float("-inf"))
print(masked_QKT.size())
print(masked_QKT[0][0])

torch.Size([5, 8, 10, 10])
torch.Size([5, 8, 10, 10])
tensor([[-2.0072,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.5019,  0.1154,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.9875,  1.4404, -0.9922,    -inf,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 0.0352,  0.8409,  0.3405,  0.1072,    -inf,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [ 1.2102, -0.9062, -0.7747,  0.0468,  0.4849,    -inf,    -inf,    -inf,
            -inf,    -inf],
        [-0.2564,  1.6343,  0.7520, -0.4366, -0.4817,  1.4208,    -inf,    -inf,
            -inf,    -inf],
        [-0.5564, -0.3971,  1.1115, -0.2522,  1.5255,  0.6649, -0.5406,    -inf,
            -inf,    -inf],
        [ 0.4964, -0.6005,  0.2156, -1.8611,  1.4616,  1.3107,  2.4214,  0.5683,
            -inf,    -inf],
        [-0.8565, -0.6299,  2.0903, -0.3663,  0.7363, -0.9004, -1.2593, -0

### Batched Self-Attention

In [None]:
B, N, d = 10, 128, 256
d_k, d_v = 300, 320
h = 8

X = torch.rand(size=(B, N, d))
# One projection matrix per head, stacked along dim 0. All-ones weights:
# this cell only demonstrates the shapes.
W_Q = torch.ones(size=(h, d, d_k))
W_K = torch.ones(size=(h, d, d_k))
W_V = torch.ones(size=(h, d, d_v))


def project_heads(x, w):
    """Project x (B, N, d) with per-head weights w (h, d, e) into (B, h, N, e)."""
    return torch.einsum('bnd,hde->bhne', x, w)


Q = project_heads(X, W_Q)
print(Q.size())

K = project_heads(X, W_K)
print(K.size())

V = project_heads(X, W_V)
print(V.size())

In [None]:
# Scaled dot-product scores QK^T / sqrt(d_k), shape (B, h, N, N).
# Without the 1/sqrt(d_k) factor the logits grow with d_k and the softmax in
# the next cell saturates (Vaswani et al., 2017, Sec. 3.2.1).
# Contracting over the last axis of both Q and K avoids an explicit transpose.
QKT = torch.einsum('bhid,bhjd->bhij', Q, K) / d_k ** 0.5
print(QKT.size())

In [None]:
# Softmax over the last (key) axis, so each row A[b, h, i, :] sums to 1.
sm = nn.Softmax(dim=-1)
A = sm(QKT)

print(A[0, 0, 0].sum())    # they're indeed probability distributions


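$Q K^\top$: the attention scores, one $N \times N$ matrix per batch element and head ($B \times h \times N \times N$).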



In [None]:
# NOTE(review): recomputes the same V as the Q/K/V projection cell; this is
# redundant unless X or W_V were changed in between.
V = torch.einsum('bnd,hde->bhne', X, W_V)
print(V.size())

In [None]:
# Attention-weighted values: (B, h, N, N) x (B, h, N, d_v) -> (B, h, N, d_v).
AV = torch.einsum('bhnm,bhme->bhne', A, V)
print(AV.size())

In [None]:
# Concatenate head outputs per token:
#   (B, h, N, d_v) -> (B, N, h, d_v) -> (B, N, h*d_v).
# The transpose is required. Reshaping (B, h, N, d_v) directly to
# (B, N, h*d_v) would mix positions and heads instead of concatenating each
# token's head outputs.
AV_concat = AV.transpose(1, 2).reshape(B, N, h * d_v)
print(AV_concat.size())

In [None]:
# Output projection W_O maps the concatenated heads (h*d_v) back to d.
W_O = torch.ones(size=(h * d_v, d))
SA_out = torch.einsum('bnc,ce->bne', AV_concat, W_O)
print("B x N x d =", SA_out.size())

### Batched Layer Norm

In [None]:
print("B =", B, "N =", N, "d =", d)

# X: B x N x d  ->  mu, std: B x N x 1 (statistics over the feature axis)
eps = 1e-5  # same default as torch.nn.LayerNorm; keeps the division finite

mu = torch.mean(X, dim=2, keepdim=True)
print("mu", mu.size())

# LayerNorm normalises with the biased (population) variance, not the
# unbiased (Bessel-corrected) estimator that torch.std uses by default.
var = torch.var(X, dim=2, unbiased=False, keepdim=True)
std = torch.sqrt(var + eps)
print("std", std.size())

X_hat = (X - mu) / std
print("X_hat", X_hat.size())
print(torch.mean(X_hat, dim=2)[0][0])                  # sanity check: ~0
print(torch.std(X_hat, dim=2, unbiased=False)[0][0])   # sanity check: ~1 (slightly below 1 if var is not >> eps)

### Batched FFN

In [None]:
import torch.nn.functional as F

# Position-wise FFN shapes: (B, N, d) -> (B, N, d_ff) -> (B, N, d).
# Biases are omitted; random weights, shape demo only.
d_ff = 512
W1 = torch.rand(size=(d, d_ff))
W2 = torch.rand(size=(d_ff, d))

X1 = F.relu(X @ W1)   # expand to d_ff, then ReLU
print(X1.size())
S2 = X1 @ W2          # project back to d
print(S2.size())

### Module Tests

In [1]:
from main import load_data
from model import ShakespearModel
from parsing.CharDataSet import CharDataSet
from torch.utils.data.dataloader import DataLoader


# Hyper-parameters; the symbol in each comment matches the notation of the
# shape experiments above.
N_TOKENS = 128              # N (sequence length)
N_LAYERS = 12               # L
N_HEADS = 8                 # h
N_WORKERS = 2               # DataLoader worker processes
BATCH_SIZE = 5              # B
D_MODEL = 768               # d
D_K = 64                    # d_k
D_V = D_K                   # d_v, tied to d_k
D_FF = 2048                 # d_ff
RAW_DATA_PATH = './datasets/shakespear_corpus.txt'

raw_data = load_data(RAW_DATA_PATH)
tokenized_data = CharDataSet(N_TOKENS, raw_data)

# Keep corpus order so the batches below are reproducible between runs.
data_loader = DataLoader(
    tokenized_data,
    batch_size=BATCH_SIZE,
    num_workers=N_WORKERS,
    shuffle=False,
)

In [6]:
# Take the first batch of the (unshuffled) loader: token ids of shape (B, N).
tokenized_sentence, _ = next(iter(data_loader))

print(tokenized_sentence.size())   # (BATCH_SIZE, N_TOKENS) = (5, 128)
print(tokenized_data.decode(tokenized_sentence[0, :]))

torch.Size([5, 128])
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to 


In [7]:
# Hyper-parameters are the defaults suggested by the teaching assistant;
# tune them once the model trains correctly.
model = ShakespearModel(N_LAYERS, N_HEADS,
                        D_MODEL, D_FF, D_K, D_V, BATCH_SIZE, N_TOKENS, tokenized_data.get_vocab_size())

In [8]:
# Forward pass used only to check the output shape, so skip building the
# autograd graph through all N_LAYERS blocks.
with torch.no_grad():
    out = model(tokenized_sentence)

In [9]:
print(out.shape)                          # last dim should equal the vocab size below
print(tokenized_data.get_vocab_size())

torch.Size([5, 128, 65])
65


In [10]:
seed = "O God, O God!"
idx = tokenized_data.encode(seed)   # a batch of one prompt: (1, len(seed))
print(idx.shape)

torch.Size([1, 13])


In [11]:
# Sample 500 new characters after the seed. Generation needs no gradients,
# so run it under no_grad to avoid recording an autograd graph per step.
# NOTE(review): generate() prints tensor sizes on every step (see the output
# below); those debug prints should be removed from model.py.
with torch.no_grad():
    new_tokens = model.generate(idx, n_new_tokens=500)

torch.Size([5, 128])
torch.Size([5, 128])
torch.Size([640]) 13
torch.Size([5, 128])
torch.Size([640]) 14
torch.Size([5, 128])
torch.Size([640]) 15
torch.Size([5, 128])
torch.Size([640]) 16
torch.Size([5, 128])
torch.Size([640]) 17
torch.Size([5, 128])
torch.Size([640]) 18
torch.Size([5, 128])
torch.Size([640]) 19
torch.Size([5, 128])
torch.Size([640]) 20
torch.Size([5, 128])
torch.Size([640]) 21
torch.Size([5, 128])
torch.Size([640]) 22
torch.Size([5, 128])
torch.Size([640]) 23
torch.Size([5, 128])
torch.Size([640]) 24
torch.Size([5, 128])
torch.Size([640]) 25
torch.Size([5, 128])
torch.Size([640]) 26
torch.Size([5, 128])
torch.Size([640]) 27
torch.Size([5, 128])
torch.Size([640]) 28
torch.Size([5, 128])
torch.Size([640]) 29
torch.Size([5, 128])
torch.Size([640]) 30
torch.Size([5, 128])
torch.Size([640]) 31
torch.Size([5, 128])
torch.Size([640]) 32
torch.Size([5, 128])
torch.Size([640]) 33
torch.Size([5, 128])
torch.Size([640]) 34
torch.Size([5, 128])
torch.Size([640]) 35
torch.Size([5

KeyboardInterrupt: 

In [12]:
import torch

# Token ids that appear to be pasted from the output of an earlier generate()
# run (the generate cell above was interrupted, so they are not reproducible
# from this notebook). Ids are integers, so store them as int64 rather than
# the float tensor they were printed as.
new_tokens = torch.tensor([ 1., 19., 53., 42.,  6.,  1., 27.,  1., 19., 53., 42.,  2., 33., 55.,
        34., 35., 33., 33., 24., 42.,  1.,  0., 52., 63., 11., 27., 59., 51.,
        28., 49., 35., 38., 62., 56., 33., 48., 60.,  7., 49.,  5., 14., 34.,
        60., 45.,  3.,  4.,  9., 44., 38., 28., 24., 36., 22., 51., 18., 47.,
         0.,  4., 44., 26., 20., 28., 24., 14., 52., 39., 27., 61., 31., 25.,
        34., 56.,  9., 39., 43.,  3., 55., 28., 61., 14., 34., 30.,  9., 43.,
        57.,  5., 64., 55., 62., 12., 40., 30.,  9.,  4., 44., 34., 63., 10.,
        12., 38., 17., 50., 26., 53., 51., 39., 12., 11.,  6., 28., 26., 26.,
        30., 30., 12., 22., 29., 42., 29., 36., 38., 57., 16., 39.,  1.,  2.,
        34., 34., 25., 15., 24., 59., 38., 35., 23., 12., 39., 30., 55., 55.,
        13., 18., 59., 32., 40., 13., 13., 47., 57., 25., 34., 17., 11., 32.,
        17., 15., 34., 26., 34., 58., 38., 27., 21., 47., 30.,  1., 48., 62.,
        29., 27., 15., 47., 47.,  6., 40.,  8., 54., 16., 14.,  0., 63.,  5.,
        41., 54., 31., 54., 22., 35., 27., 33., 59., 38., 10., 55.,  4.,  4.,
        25., 43., 56., 51., 34.,  2.,  6., 31., 55., 30., 14., 51., 64., 28.,
        51., 16., 16., 16., 17., 63.,  9., 56., 54., 51., 51., 33., 24., 47.,
        40., 61.,  0., 38.,  0., 42., 31., 17.,  0., 16., 55., 60.,  9., 22.,
        18., 61., 53., 59., 22.,  7., 52., 13., 41.,  7.,  2., 13., 61., 56.,
        17., 29., 40., 28., 43., 28., 53.,  8., 41.,  2., 61., 32., 15., 13.,
         3.,  0., 55., 29., 27.,  9., 44., 17.,  1., 38., 53., 36., 45.,  8.,
        62., 64., 53., 14., 32., 17., 47., 34.,  5.,  5., 49., 49., 58., 28.,
         9., 13., 40., 26., 33., 26.,  7.,  7.,  7.,  8.,  2., 37., 54., 48.,
        35., 52., 28., 59., 59., 59., 23., 10., 50., 62., 56., 43.,  9.,  5.,
        33., 13., 55.,  8., 39., 20., 18.,  5., 26., 56., 19., 33., 20., 40.,
        62., 47., 48., 45.,  3., 59., 12., 28., 45., 24., 27.,  8., 26., 54.,
        19., 12., 23., 29., 44.,  2., 37., 40., 20., 12.,  4., 55.,  5., 43.,
         1., 43., 19., 57., 61., 47., 62., 36., 56., 47., 63., 50., 37., 34.,
        62., 31.,  8., 15.,  0., 34., 52.,  9., 64., 44., 42., 43., 14., 37.,
        13.,  1., 33., 12.,  9.,  5., 34.,  5.,  0.,  7., 43., 45., 46., 38.,
        25., 18., 35.,  5.,  3., 41., 28., 25., 24., 62., 11., 14., 23., 42.,
        58., 47.,  0., 32., 40.,  7., 12.,  3., 28., 55.,  4.,  4., 16., 59.,
         1.,  3.,  5., 37., 50., 31., 53.,  7., 22., 13., 34., 50., 58.,  3.,
        13., 51., 28.,  9., 64.,  0., 12.,  6., 28., 27., 14., 39., 22., 13.,
        60., 33., 35., 54., 62., 20.,  7.,  0., 39., 40., 31., 34.,  0., 49.,
        33., 53., 11., 12., 50., 42., 13., 28., 58., 45.,  4., 25., 25., 23.,
        51., 33., 50.,  4.,  9., 26., 31., 55.,  0., 33.], dtype=torch.long)

In [13]:
# NOTE(review): the decoded text begins " God, O God!", so the seed's
# leading "O" is missing. Check whether generate() drops the first prompt token.
print(tokenized_data.decode(new_tokens))

 God, O God!UqVWUULd 
ny;OumPkWZxrUjv-k'BVvg$&3fZPLXJmFi
&fNHPLBnaOwSMVr3ae$qPwBVR3es'zqx?bR3&fVy:?ZElNoma?;,PNNRR?JQdQXZsDa !VVMCLuZWK?aRqqAFuTbAAisMVE;TECVNVtZOIiR jxQOCii,b.pDB
y'cpSpJWOUuZ:q&&MermV!,SqRBmzPmDDDEy3rpmmULibw
Z
dSE
Dqv3JFwouJ-nAc-!AwrEQbPePo.c!wTCA$
qQO3fE ZoXg.xzoBTEiV''kktP3AbNUN---.!YpjWnPuuuK:lxre3'UAq.aHF'NrGUHbxijg$u?PgLO.NpG?KQf!YbH?&q'e eGswixXriylYVxS.C
Vn3zfdeBYA U?3'V'
-eghZMFW'$cPMLx;BKdti
Tb-?$Pq&&Du $'YlSo-JAVlt$AmP3z
?,POBaJAvUWpxH-
abSV
kUo;?ldAPtg&MMKmUl&3NSq
U


In [2]:
from train_model import train_model

# NOTE(review): the logged loss stays around 4.17-4.20, which matches
# ln(65) ~= 4.174. If the loss is cross-entropy (not visible here), that is
# a uniform guess over the 65-symbol vocabulary, i.e. the model is not
# learning. Compare with the near-zero weights printed further down.
trained_mod = train_model()

batch 0 - Loss: 4.183255195617676
batch 1 - Loss: 4.171545505523682
batch 2 - Loss: 4.198949813842773
batch 3 - Loss: 4.199731349945068
batch 4 - Loss: 4.200317859649658
batch 5 - Loss: 4.2002973556518555
batch 6 - Loss: 4.20034122467041
batch 7 - Loss: 4.2002153396606445
batch 8 - Loss: 4.199623107910156
batch 9 - Loss: 4.198882579803467
batch 10 - Loss: 4.199898719787598
batch 11 - Loss: 4.200346946716309
batch 12 - Loss: 4.199868679046631
batch 13 - Loss: 4.19860315322876
batch 14 - Loss: 4.196768283843994
batch 15 - Loss: 4.196898460388184
batch 16 - Loss: 4.185666084289551
batch 17 - Loss: 4.184398174285889
batch 18 - Loss: 4.17842960357666
batch 19 - Loss: 4.1671223640441895
batch 20 - Loss: 4.168938636779785
batch 21 - Loss: 4.181916236877441
batch 22 - Loss: 4.1858625411987305
batch 23 - Loss: 4.171082496643066
batch 24 - Loss: 4.176156044006348
batch 25 - Loss: 4.173542022705078
batch 26 - Loss: 4.190957069396973
batch 27 - Loss: 4.197818279266357
batch 28 - Loss: 4.1960515975

In [3]:
seed = "O God, O God!"
idx = tokenized_data.encode(seed)   # printed as (1, 13) for this seed earlier

In [5]:
# Sample 500 characters from the trained model. Generation needs no
# gradients, so run it under no_grad to avoid recording a graph per step.
with torch.no_grad():
    new_tokens = trained_mod.generate(idx, n_new_tokens=500)

torch.Size([20, 64])
torch.Size([20, 64])
torch.Size([1280]) 13
torch.Size([20, 64])
torch.Size([1280]) 14
torch.Size([20, 64])
torch.Size([1280]) 15
torch.Size([20, 64])
torch.Size([1280]) 16
torch.Size([20, 64])
torch.Size([1280]) 17
torch.Size([20, 64])
torch.Size([1280]) 18
torch.Size([20, 64])
torch.Size([1280]) 19
torch.Size([20, 64])
torch.Size([1280]) 20
torch.Size([20, 64])
torch.Size([1280]) 21
torch.Size([20, 64])
torch.Size([1280]) 22
torch.Size([20, 64])
torch.Size([1280]) 23
torch.Size([20, 64])
torch.Size([1280]) 24
torch.Size([20, 64])
torch.Size([1280]) 25
torch.Size([20, 64])
torch.Size([1280]) 26
torch.Size([20, 64])
torch.Size([1280]) 27
torch.Size([20, 64])
torch.Size([1280]) 28
torch.Size([20, 64])
torch.Size([1280]) 29
torch.Size([20, 64])
torch.Size([1280]) 30
torch.Size([20, 64])
torch.Size([1280]) 31
torch.Size([20, 64])
torch.Size([1280]) 32
torch.Size([20, 64])
torch.Size([1280]) 33
torch.Size([20, 64])
torch.Size([1280]) 34
torch.Size([20, 64])
torch.Size([

In [6]:
# NOTE(review): after the seed the output is a single repeated "g",
# consistent with the flat training loss and collapsed weights.
print(tokenized_data.decode(new_tokens))

 God, O God!gggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg


In [12]:
# Print the parameters of block 0's first FFN linear layer (matrix, then vector).
# NOTE(review): most printed values are ~1e-30 to ~1e-45 (denormal floats),
# i.e. the weights have collapsed towards zero. Check weight decay and the
# learning rate in train_model.
first_ffn_layer = trained_mod.transformer.blocks[0].FFN.L1
for p in first_ffn_layer.parameters():
    print(p)

Parameter containing:
tensor([[-4.9045e-44, -1.1771e-43,  7.0065e-45,  ..., -1.7236e-43,
          6.0256e-44, -3.0829e-44],
        [-3.4722e-34, -1.9619e-35, -8.0870e-35,  ..., -2.4657e-35,
         -1.2622e-34, -2.8339e-35],
        [ 1.9707e-37,  1.8528e-37, -3.7294e-36,  ..., -2.0965e-37,
          4.7421e-37, -3.3137e-36],
        ...,
        [ 1.1345e-38,  1.3107e-37, -2.0920e-36,  ...,  8.0493e-37,
          8.9756e-37,  6.1864e-37],
        [ 8.4078e-45,  4.3440e-44, -3.5032e-44,  ...,  5.4651e-44,
          4.7644e-44, -3.3631e-44],
        [ 1.6867e-34,  1.4720e-34, -9.9768e-35,  ...,  2.9880e-34,
         -9.0866e-35,  4.1784e-34]], requires_grad=True)
Parameter containing:
tensor([-3.6177e-30, -1.5254e-11,  9.7871e-10,  ..., -6.6818e-20,
        -3.7555e-43,  2.8368e-16], requires_grad=True)


In [15]:
# Print every named parameter of block 0's causal self-attention module.
attn_module = trained_mod.transformer.blocks[0].CausalSelfAttn
for param_name, value in attn_module.named_parameters():
    print(param_name, value)

to_Q.weight Parameter containing:
tensor([[-9.8091e-45, -1.1210e-44,  4.2039e-44,  ..., -1.2612e-44,
          0.0000e+00, -5.6052e-45],
        [-7.5670e-44, -3.2230e-44, -1.9618e-44,  ...,  0.0000e+00,
          1.4013e-44, -7.0065e-45],
        [-7.0065e-45, -2.9427e-44, -4.2039e-45,  ...,  2.2421e-44,
         -5.6052e-45, -1.5414e-44],
        ...,
        [-6.3058e-44,  1.8217e-44, -6.0256e-44,  ...,  1.4013e-45,
          2.8026e-45,  3.5032e-44],
        [-5.6052e-44, -3.0829e-44, -1.6816e-44,  ...,  4.3440e-44,
          4.0638e-44, -1.1210e-44],
        [-1.4013e-45,  5.6052e-45,  1.1210e-44,  ..., -5.0447e-44,
         -1.2612e-44,  4.3440e-44]], requires_grad=True)
to_Q.bias Parameter containing:
tensor([ 4.2319e-43, -1.8217e-44,  4.7644e-44,  4.4842e-44,  2.5966e-42,
        -1.9352e-42,  3.9236e-44,  3.7835e-44,  2.6625e-44,  3.3911e-43,
         5.1083e-10, -1.0370e-43,  2.3402e-43,  3.7555e-43,  2.9427e-44,
         1.8217e-44,  9.5288e-44, -9.5288e-44, -6.3213e-42, -8.