In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline
from time import time

### Building the dataset

In [3]:
# Load the names, one per line. The `with` block closes the file handle
# (a bare open() left it open), and an explicit encoding avoids
# platform-dependent decoding.
with open('../2_makemore_bigrams/names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()

In [4]:
num_names = len(words)  # one entry per line of names.txt
num_names

32033

In [5]:
# Peek at the first five names
words[0:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [7]:
# Character vocabulary: each distinct character (in sorted order) gets an id
# starting at 1; '.' is reserved as id 0 and used as the padding/terminator token.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}

In [12]:
# build the dataset
block_size = 3  # context length: number of previous characters used to predict the next

def build_dataset(words):
    """Turn a list of names into (context, next-character) training pairs.

    Each name is terminated with '.', and every character (including the
    terminator) is predicted from the `block_size` characters before it.
    The context is left-padded with index 0, which is stoi['.'].

    Args:
        words: iterable of strings whose characters are all keys of `stoi`.

    Returns:
        X: long tensor of shape (N, block_size) holding context indices.
        Y: long tensor of shape (N,) holding the target index for each row of X.
    """
    X, Y = [], []
    for w in words:
        context = [0] * block_size
        for ch in w + '.':
            ix = stoi[ch]
            X.append(context)
            Y.append(ix)
            # Rebind (not mutate) so the lists already appended to X stay intact.
            context = context[1:] + [ix]
    # Explicit dtype + reshape keep an empty split as a (0, block_size) long
    # tensor; torch.tensor([]) alone would be a 1-D float32 tensor.
    X = torch.tensor(X, dtype=torch.long).reshape(-1, block_size)
    Y = torch.tensor(Y, dtype=torch.long)
    print(X.shape, Y.shape)
    return X, Y

In [14]:
import random

# Fixed seed so the train/dev/test split below is reproducible.
# NOTE: `words` is shuffled in place, so re-running this cell on its own
# shuffles the already-shuffled list and generally gives a different split.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)
random.shuffle(words)

In [16]:
# Split boundaries: [0, n1) train (80%), [n1, n2) dev (10%), [n2, end) test (10%)
n1, n2 = (int(len(words) * frac) for frac in (0.8, 0.9))

In [18]:
# Materialise each split as (context, target) tensors; each call prints its shapes.
train_words, dev_words, test_words = words[:n1], words[n1:n2], words[n2:]
Xtr, Ytr = build_dataset(train_words)
Xdev, Ydev = build_dataset(dev_words)
Xte, Yte = build_dataset(test_words)

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


In [20]:
# Prefer the Apple-silicon GPU (MPS), then CUDA, and fall back to CPU.
# A hard-coded 'mps' device makes every .to(device) below fail on machines
# without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

Xtr, Ytr = Xtr.to(device), Ytr.to(device)
Xdev, Ydev = Xdev.to(device), Ydev.to(device)
Xte, Yte = Xte.to(device), Yte.to(device)

In [28]:
# Quick check of log10 on a tensor (10 -> 1.0)
torch.log10(torch.tensor(10))

tensor(1.)

In [47]:
### Loss if all probs are equally likely
# If every token gets probability 1/len(stoi), the cross-entropy is
# -log(1/len(stoi)); the initialisation below aims to reproduce this value.
probs = torch.tensor(1 / len(stoi))
uniform_loss = probs.log().neg()
uniform_loss

tensor(3.2958)

### Initializing the network so that its first-pass loss matches the uniform loss

In [60]:
n_emb = 100
n_hidden = 100
vocab_size = len(stoi)  # number of tokens, including '.'
batch_size = 32
g = torch.Generator(device=device).manual_seed(99)

# All-ones init: every hidden pre-activation is identical (and tanh saturates
# at 1), so every logit is identical, the softmax is uniform, and the loss
# equals uniform_loss.
# NOTE(review): identical W1 columns/b1 entries mean all hidden units compute
# the same function and receive the same gradient, so this works as a sanity
# check but not as a training init.
C = torch.ones((vocab_size, n_emb), device=device)
W1 = torch.ones((n_emb * block_size, n_hidden), device=device)
b1 = torch.ones(n_hidden, device=device)
W2 = torch.ones((n_hidden, vocab_size), device=device)
b2 = torch.ones(vocab_size, device=device)
parameters = [C, W1, W2, b1, b2]
print(f"Total parameters : {sum(p.nelement() for p in parameters)}")

# Sample the minibatch from the seeded generator (it was previously unused,
# so the batch changed on every run). Create the indices on the same device
# as Xtr/Ytr.
minibatch = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g, device=device)
emb = C[Xtr[minibatch]].view(-1, n_emb * block_size)
l1 = torch.tanh(emb @ W1 + b1)
logits = l1 @ W2 + b2
loss_train = F.cross_entropy(logits, Ytr[minibatch])
print(f"Loss at first pass : {round(loss_train.item(),3)}")

Total parameters : 35527
Loss at first pass : 3.296
