In [1]:
import numpy as np
import torch
import torch.nn.functional as F

In [2]:
with open("../Makemore/names.txt") as f:
    names = f.readlines()

In [3]:
names = [name.strip() for name in names]

In [4]:
char_set = sorted(list({c for name in names for c in name}))
char_set.insert(0, '.')
len(char_set)

27

In [5]:
stoi = {char: i for i, char in enumerate(char_set)}
itos = {i: char for char, i in stoi.items()}

In [6]:
block_size = 3
X, Y = [], []
for name in names:
    name = '.' * block_size + name + '.'
    for i in range(len(name) - block_size):
        context = name[i: i+block_size]
        X.append([stoi[char] for char in context])
        Y.append(stoi[name[i+block_size]])

X = torch.tensor(X)
Y = torch.tensor(Y)

In [7]:
X.shape, Y.shape

(torch.Size([228146, 3]), torch.Size([228146]))

In [8]:
train_idx, val_idx, test_idx = torch.utils.data.random_split(range(X.shape[0]), [0.8, 0.1, 0.1])
X_train, Y_train = X[train_idx], Y[train_idx]
X_val, Y_val = X[val_idx], Y[val_idx]
X_test, Y_test = X[test_idx], Y[test_idx]
print(f"Number of train samples: {X_train.shape[0]}")
print(f"Number of validation samples: {X_val.shape[0]}")
print(f"Number of test samples: {X_test.shape[0]}")

Number of train samples: 182517
Number of validation samples: 22815
Number of test samples: 22814


In [99]:
gen = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 20), generator=gen)
W1 = torch.randn((60, 100), generator=gen)
B1 = torch.randn(100, generator=gen)
W2 = torch.randn((100, 100), generator=gen)
B2 = torch.randn(100, generator=gen)
W3 = torch.randn((100, 27), generator=gen)
B3 = torch.randn(27, generator=gen)
parameters = [C, W1, B1, W2, B2, W3, B3]

In [100]:
sum(p.nelement() for p in parameters)

19467

In [101]:
for p in parameters:
    p.requires_grad = True

In [102]:
@torch.no_grad()
def overall_train_loss(epoch_num):
    out = torch.tanh(torch.matmul(C[X_train].view(-1, W1.shape[0]), W1) + B1)
    out = torch.tanh(torch.matmul(out, W2) + B2)
    logits = torch.matmul(out, W3) + B3
    train_loss = F.cross_entropy(logits, Y_train)
    print(f"Epoch {epoch_num} \t Loss: {train_loss}")

In [103]:
W1.shape

torch.Size([60, 100])

In [104]:
for i in range(300001):
    ix = torch.randint(0, X_train.shape[0], (64, ))
    emb = C[X_train[ix]]
    out = torch.tanh(torch.matmul(emb.view(-1, W1.shape[0]), W1) + B1)
    out = torch.tanh(torch.matmul(out, W2) + B2)
    logits = torch.matmul(out, W3) + B3
    loss = F.cross_entropy(logits, Y_train[ix])
    if i%5000 == 0:
        overall_train_loss(i)
    loss.backward()
    lr_schedule = {0: 0.1, 150000: 0.05, 225000: 0.01}
    for threshold, item in lr_schedule.items():
        if i > threshold:
            lr = item
    for p in parameters:
        p.data += -lr * p.grad
        p.grad = None

Epoch 0 	 Loss: 20.473392486572266
Epoch 5000 	 Loss: 2.4926857948303223
Epoch 10000 	 Loss: 2.328019380569458
Epoch 15000 	 Loss: 2.266166925430298
Epoch 20000 	 Loss: 2.2342920303344727
Epoch 25000 	 Loss: 2.2100491523742676
Epoch 30000 	 Loss: 2.2153806686401367
Epoch 35000 	 Loss: 2.1834499835968018
Epoch 40000 	 Loss: 2.172321319580078
Epoch 45000 	 Loss: 2.1695892810821533
Epoch 50000 	 Loss: 2.164910316467285
Epoch 55000 	 Loss: 2.1615078449249268
Epoch 60000 	 Loss: 2.154019355773926
Epoch 65000 	 Loss: 2.1530113220214844
Epoch 70000 	 Loss: 2.1521811485290527
Epoch 75000 	 Loss: 2.1407599449157715
Epoch 80000 	 Loss: 2.1357221603393555
Epoch 85000 	 Loss: 2.12994122505188
Epoch 90000 	 Loss: 2.1308140754699707
Epoch 95000 	 Loss: 2.128620147705078
Epoch 100000 	 Loss: 2.126880407333374
Epoch 105000 	 Loss: 2.119631052017212
Epoch 110000 	 Loss: 2.1165781021118164
Epoch 115000 	 Loss: 2.1115503311157227
Epoch 120000 	 Loss: 2.1067399978637695
Epoch 125000 	 Loss: 2.125293016433

In [39]:
with torch.no_grad():
    out = torch.tanh(torch.matmul(C[X_val].view(-1, W1.shape[0]), W1) + B1)
    out = torch.tanh(torch.matmul(out, W2) + B2)
    logits = torch.matmul(out, W3) + B3
    val_loss = F.cross_entropy(logits, Y_val)
    print(f"Val loss: {val_loss}")

Val loss: 5.98593807220459


In [108]:
with torch.no_grad():
    out = torch.tanh(torch.matmul(C[X_test].view(-1, W1.shape[0]), W1) + B1)
    out = torch.tanh(torch.matmul(out, W2) + B2)
    logits = torch.matmul(out, W3) + B3
    test_loss = F.cross_entropy(logits, Y_test)
    print(f"test loss: {test_loss}")

test loss: 2.144402265548706


### E01: I did not get around to seeing what happens when you initialize all weights and biases to zero. Try this and train the neural net. You might think either that 1) the network trains just fine or 2) the network doesn't train at all, but actually it is 3) the network trains but only partially, and achieves a pretty bad final performance. Inspect the gradients and activations to figure out what is happening and why the network is only partially training, and what part is being trained exactly.


In [18]:
c_zero = torch.randn((27, 20))
w1_zero = torch.zeros((60, 300))
b1_zero = torch.zeros(300)
w2_zero = torch.zeros((300, 27))
b2_zero = torch.zeros(27)
params = [c_zero, w1_zero, b1_zero, w2_zero, b2_zero]

In [19]:
for p in params:
    p.requires_grad = True

In [20]:
for i in range(100000):
    ix = torch.randint(0, X_train.shape[0], (64, ))
    emb = c_zero[X_train[ix]]
#     print(f"emb: {emb.shape}")
    out = torch.tanh(torch.matmul(emb.view(-1, w1_zero.shape[0]), w1_zero) + b1_zero)
    logits = torch.matmul(out, w2_zero) + b2_zero
    loss = F.cross_entropy(logits, Y_train[ix])
    if i%5000 == 0:
        with torch.no_grad():
            out = torch.tanh(torch.matmul(c_zero[X_train].view(-1, w1_zero.shape[0]), w1_zero) + b1_zero)
            logits = torch.matmul(out, w2_zero) + b2_zero
            train_loss = F.cross_entropy(logits, Y_train)
            print(f"Epoch {i+1} \t Loss: {train_loss}")
    for p in params:
        p.grad = None
    loss.backward()
    lr_schedule = {0: 0.1, 100000: 0.05, 200000: 0.01, 300000: 0.001}
    for threshold, item in lr_schedule.items():
        if i >= threshold:
            lr = item
    for p in params:
#         print(f"{lr=} {p.grad=}")
        p.data += -lr * p.grad
#         p.grad = None

Epoch 1 	 Loss: 3.295837163925171
Epoch 5001 	 Loss: 2.823450803756714
Epoch 10001 	 Loss: 2.823431968688965
Epoch 15001 	 Loss: 2.823232650756836
Epoch 20001 	 Loss: 2.8230626583099365
Epoch 25001 	 Loss: 2.8233225345611572
Epoch 30001 	 Loss: 2.823059320449829
Epoch 35001 	 Loss: 2.8232638835906982
Epoch 40001 	 Loss: 2.823038339614868
Epoch 45001 	 Loss: 2.823099136352539
Epoch 50001 	 Loss: 2.8232247829437256
Epoch 55001 	 Loss: 2.823425769805908
Epoch 60001 	 Loss: 2.823354721069336
Epoch 65001 	 Loss: 2.8230814933776855
Epoch 70001 	 Loss: 2.823064088821411
Epoch 75001 	 Loss: 2.823176622390747
Epoch 80001 	 Loss: 2.8231592178344727
Epoch 85001 	 Loss: 2.823040723800659
Epoch 90001 	 Loss: 2.822995662689209
Epoch 95001 	 Loss: 2.823046922683716


In [22]:
for param in params:
    print(torch.sum(param.grad))

tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(-1.2340e-07)


<p> It looks like all the gradients are zero except the gradients of the last bias tensor. That is why the optimization was quick and plateaued immediately at 2.82. </p>

### E02: BatchNorm, unlike other normalization layers like LayerNorm/GroupNorm etc. has the big advantage that after training, the batchnorm gamma/beta can be "folded into" the weights of the preceeding Linear layers, effectively erasing the need to forward it at test time. Set up a small 3-layer MLP with batchnorms, train the network, then "fold" the batchnorm gamma/beta into the preceeding Linear layer's W,b by creating a new W2, b2 and erasing the batch norm. Verify that this gives the same forward pass during inference. i.e. we see that the batchnorm is there just for stabilizing the training, and can be thrown out after training is done! pretty cool.


In [84]:
gen = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 15), generator=gen)
W1 = torch.randn((45, 100), generator=gen)
B1 = torch.randn(100, generator=gen)
W2 = torch.randn((100, 100), generator=gen)
B2 = torch.randn(100, generator=gen)
W3 = torch.randn((100, 27), generator=gen)
B3 = torch.randn(27, generator=gen)
parameters = [C, W1, B1, W2, B2, W3, B3]

In [85]:
for p in parameters:
    p.requires_grad = True

In [83]:
batch_norm = torch.nn.BatchNorm1d(100, eps=1e-5, momentum=0.22)
batch_norm2 = torch.nn.BatchNorm1d(100, eps=1e-5, momentum=0.27)

In [89]:
for i in range(100001):
    idx = torch.randint(0, X_train.shape[0], (64,))
    emb = C[X_train[idx]]
    out = torch.matmul(emb.view(-1, 45), W1) + B1
    out = batch_norm(out)
    out = torch.tanh(out)
    out = torch.matmul(out, W2) + B2
    out = batch_norm2(out)
    out = torch.tanh(out)
    out = torch.matmul(out, W3) + B3
    loss = F.cross_entropy(out, Y_train[idx])
    for p in parameters:
        p.grad = None
    loss.backward()
    for p in parameters:
        p.data += -0.01 * p.grad
    if i % 5000 == 0:
        with torch.no_grad():
            out = torch.matmul(C[X_train].view(-1, W1.shape[0]), W1) + B1
            out = batch_norm(out)
            out = torch.tanh(out)
            out = torch.matmul(out, W2) + B2
            out = batch_norm2(out)
            out = torch.tanh(out)
            logits = torch.matmul(out, W3) + B3
            train_loss = F.cross_entropy(logits, Y_train)
            print(f"Loss: {train_loss}")

Loss: 2.1329731941223145
Loss: 2.1265265941619873
Loss: 2.125951051712036
Loss: 2.1251254081726074
Loss: 2.1256470680236816
Loss: 2.1250040531158447
Loss: 2.1252458095550537
Loss: 2.124931812286377
Loss: 2.1244149208068848
Loss: 2.1247334480285645
Loss: 2.124326705932617
Loss: 2.1239967346191406
Loss: 2.1232070922851562
Loss: 2.123250961303711
Loss: 2.1231024265289307
Loss: 2.123734951019287
Loss: 2.1231186389923096
Loss: 2.1229145526885986
Loss: 2.122171401977539
Loss: 2.122885227203369
Loss: 2.122119903564453


In [87]:
with torch.no_grad():
    out = torch.matmul(C[X_val].view(-1, W1.shape[0]), W1) + B1
    out = torch.tanh(batch_norm(out))
    out = torch.matmul(out, W2) + B2
    out = torch.tanh(batch_norm2(out))
    logits = torch.matmul(out, W3) + B3
    val_loss = F.cross_entropy(logits, Y_val)
    print(f"Val loss: {val_loss}")

Val loss: 2.191786766052246
