In [240]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torch.nn as nn

In [241]:
words = open('names.txt', 'r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [242]:
len(words)

32033

In [243]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [244]:
context_window = 3
X, Y = [], []

for w in words:
    context = [0] * context_window
#     print(w)
    for i in w + '.':
        ix = stoi[i]
        X.append(context)
        Y.append(ix)
#         print(''.join(itos[i] for i in context), '--->', itos[ix])
        context = context[1:] + [ix]
        
X = torch.tensor(X)      # This is the input part which has all the trigrams i.e. 196113 in a matrix of 196113 x 3
Y = torch.tensor(Y)      # This contains the right index for the predicting char

In [245]:
X.size(), Y.size()

(torch.Size([228146, 3]), torch.Size([228146]))

In [246]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([228146, 3]), torch.int64, torch.Size([228146]), torch.int64)

In [247]:
# <----- DATASET LAOADED ----->
# Time to start with with embeddings and then squeeze the dimension
# Embeddings ---> Representation of inputs, just like GPT uses tokenization

In [248]:
# First embedding then one hot encoding

In [249]:
vocab_size = len(stoi)  # Size of the vocabulary
embedding_dim = 30  # Desired dimensionality of embeddings

embedding_input = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

embedding_output = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

In [250]:
embedded_X = embedding_input(X)
embedded_Y = embedding_output(Y) 

In [251]:
embedded_X.size(), embedded_Y.size(), embedded_X.dtype, embedded_Y.dtype
# 196113 - samples, 3 - context-window, 30 - each word represented by vector length 30

(torch.Size([228146, 3, 30]),
 torch.Size([228146, 30]),
 torch.float32,
 torch.float32)

In [252]:
# g = torch.Generator().manual_seed(2147483647) # for reproducibility
# C = torch.randn((27, 10), generator=g)

W1 = torch.randn((30, 200), generator=g)
b1 = torch.randn(200, generator=g)
W2 = torch.randn((200, 27), generator=g)
b2 = torch.randn(27, generator=g)

parameters = [embedded_X, W1, b1, W2, b2]

In [253]:
sum(p.nelement() for p in parameters) # number of parameters in total

20544767

In [254]:
h = torch.tanh(embedded_X @ W1 + b1)

In [255]:
h

tensor([[[-0.9918, -1.0000,  1.0000,  ...,  1.0000,  0.9965,  0.9134],
         [-0.9918, -1.0000,  1.0000,  ...,  1.0000,  0.9965,  0.9134],
         [-0.9918, -1.0000,  1.0000,  ...,  1.0000,  0.9965,  0.9134]],

        [[-0.9918, -1.0000,  1.0000,  ...,  1.0000,  0.9965,  0.9134],
         [-0.9918, -1.0000,  1.0000,  ...,  1.0000,  0.9965,  0.9134],
         [-0.5329, -0.9998,  1.0000,  ...,  1.0000, -0.7199, -0.9965]],

        [[-0.9918, -1.0000,  1.0000,  ...,  1.0000,  0.9965,  0.9134],
         [-0.5329, -0.9998,  1.0000,  ...,  1.0000, -0.7199, -0.9965],
         [ 0.9936,  1.0000,  0.9987,  ...,  0.8696,  0.9998, -0.9799]],

        ...,

        [[ 0.9859, -0.9902,  0.9998,  ...,  1.0000,  0.7161, -0.9960],
         [ 0.9859, -0.9902,  0.9998,  ...,  1.0000,  0.7161, -0.9960],
         [ 0.9936, -1.0000,  0.9712,  ...,  0.8225,  1.0000,  0.9940]],

        [[ 0.9859, -0.9902,  0.9998,  ...,  1.0000,  0.7161, -0.9960],
         [ 0.9936, -1.0000,  0.9712,  ...,  0.8225,  1.

In [256]:
logits = (h @ W2 + b2)

In [257]:
counts = logits.exp()

In [258]:
probs = counts/counts.sum(1, keepdims=True)

In [259]:
probs.shape

torch.Size([228146, 3, 27])

In [260]:
# < ----------------------------- >

In [261]:
C = torch.randn(27,2)

In [262]:
C[X]

tensor([[[ 0.4501,  0.3802],
         [ 0.4501,  0.3802],
         [ 0.4501,  0.3802]],

        [[ 0.4501,  0.3802],
         [ 0.4501,  0.3802],
         [ 0.9664,  0.7379]],

        [[ 0.4501,  0.3802],
         [ 0.9664,  0.7379],
         [ 1.8702,  0.0416]],

        ...,

        [[ 0.5569, -0.1607],
         [ 0.5569, -0.1607],
         [ 1.8627, -0.6295]],

        [[ 0.5569, -0.1607],
         [ 1.8627, -0.6295],
         [ 0.5569, -0.1607]],

        [[ 1.8627, -0.6295],
         [ 0.5569, -0.1607],
         [ 0.8294,  1.3327]]])

In [263]:
C[X].shape 
# (27, 2) (32, 3) ---> (32, 3, 2)

torch.Size([228146, 3, 2])

In [264]:
emb = C[X]
emb.shape
# 2 dimensional embeddings times 3 
# 32 tells about the input

torch.Size([228146, 3, 2])

In [265]:
w1 = torch.randn(6,100) # 3 x 2 and 100 weights
b1 = torch.randn(100)

In [266]:
h = torch.tanh(emb.view(-1, 6) @ w1 + b1)

In [267]:
h.shape

torch.Size([228146, 100])

In [268]:
w2 = torch.randn(100, 27) 
b2 = torch.randn(27)

In [269]:
logits = h @ w2 + b2
logits

tensor([[  4.5976,  -2.8673,  -8.9924,  ...,   5.3669,  -2.7572,  -6.9778],
        [  3.5194,  -0.4457,  -6.0251,  ...,   8.2231,   2.5794,  -3.8818],
        [  2.8355,  -5.6445,  -3.0304,  ...,   6.3688,   5.2462,  -4.9900],
        ...,
        [ -0.6152,  -2.2953,   3.9285,  ...,   1.0909,   6.1042,  -2.7090],
        [  8.3908,  -3.9000, -12.1616,  ...,   0.4957,  -5.3705,   4.3236],
        [  5.3551,   2.3366,  -4.3220,  ...,   4.4389,   0.8210,  -2.2247]])

In [270]:
counts = logits.exp()

In [278]:
prob = counts / counts.sum(1, keepdims=True)

In [279]:
prob.shape

torch.Size([228146, 27])

In [280]:
Y

tensor([ 5, 13, 13,  ..., 26, 24,  0])

In [282]:
loss = -prob[torch.arange(32), Y].log().mean()

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [32], [228146]

In [None]:
# --------------------------------------------------------------------------------------------------------------- #

In [283]:
X.shape, Y.shape

(torch.Size([228146, 3]), torch.Size([228146]))

In [284]:
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((27, 2), generator=g)
W1 = torch.randn((6, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, 27), generator=g)
b2 = torch.randn(27, generator=g)
parameters = [C, W1, b1, W2, b2]

In [285]:
sum(p.nelement() for p in parameters) # total parameters

3481

In [286]:
lre = torch.linspace(-3, 0, 1000)
lrs = 10**lre

In [287]:
for p in parameters:
    p.requires_grad = True

In [288]:
lri = []
losses = []

In [289]:
for i in range(100):
    # Forward pass
    emb = C[X]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
    logits = h @ W2 + b2
    # One liner to calc prob
    loss = F.cross_entropy(logits, Y)
    print(loss.item())
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # update
    lr = lrs[i]
    for p in parameters:
        p.data += -lr * p.grad

#     # track stats
#     lri.append(lr[i])
# #     stepi.append(i)
#     losses.append(loss.log10().item())


19.505229949951172
19.467477798461914
19.429882049560547
19.392467498779297
19.355257034301758
19.318273544311523
19.281545639038086
19.245088577270508
19.208911895751953
19.17304229736328
19.137487411499023
19.10225486755371
19.067346572875977
19.032777786254883
18.99853515625
18.96462631225586
18.93104362487793
18.897775650024414
18.864818572998047
18.832162857055664
18.799793243408203
18.767704010009766
18.735877990722656
18.704299926757812
18.672956466674805
18.641841888427734
18.610933303833008
18.58022689819336
18.54970359802246
18.51935386657715
18.48916244506836
18.45912742614746
18.429231643676758
18.399465560913086
18.369821548461914
18.340293884277344
18.31087303161621
18.281551361083984
18.252321243286133
18.223176956176758
18.194114685058594
18.165128707885742
18.136211395263672
18.107364654541016
18.078582763671875
18.04985809326172
18.021190643310547
17.992578506469727
17.96401596069336
17.935504913330078
17.907041549682617
17.878623962402344
17.850248336791992
17.821920

In [294]:
# sample from the model
g = torch.Generator().manual_seed(2147483647 + 10)
block_size = 3
for _ in range(5):
    
    out = []
    context = [0] * block_size # initialize with all ...
    while True:
      emb = C[torch.tensor([context])] # (1,block_size,d)
      h = torch.tanh(emb.view(1, -1) @ W1 + b1)
      logits = h @ W2 + b2
      probs = F.softmax(logits, dim=1)
      ix = torch.multinomial(probs, num_samples=1, generator=g).item()
      context = context[1:] + [ix]
      out.append(ix)
      if ix == 0:
        break
        
    print(''.join(itos[i] for i in out))

njjnji.
ijjjjjjiibjiinjwtvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvpyjkqjvjjnjinjwvgywlywrfwtvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv

ghjrjwtlywvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvpyjkyovvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv