In [133]:
# convert stoi and itos
NUM_CLASSES = 27

stoi = {}
itos = [None] * NUM_CLASSES

for i, s in enumerate('abcdefghijklmnopqrstuvwxyz'):
    itos[i+1] = s
    stoi[s] = i+1

itos[0] = '.'
stoi['.'] = 0

In [134]:
# read the words
words = open('../data/names.txt', 'r').read().splitlines()

In [135]:
len(words)

32033

In [136]:
import torch
import torch.nn.functional as F
# Check if MPS is available
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("MPS device not found, using CPU")

Using MPS device


In [137]:
# split into the training examples
xs, ys = [], []

for word in words:
    chs =  ['.', '.'] + list(word) + ['.', '.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]

        combined_idx = ix1 * NUM_CLASSES + ix2
        xs.append(combined_idx)
        ys.append(ix3)
    
xs = torch.tensor(xs, device=device)
ys = torch.tensor(ys, device=device)
n = len(ys)
n

260179

In [138]:
# initialise our weights and generator
g = torch.Generator(device=device).manual_seed(1)
SIZE = NUM_CLASSES**2
W = torch.randn(SIZE, NUM_CLASSES, generator=g, requires_grad=True, device=device)

In [139]:
# train our model on the training examples
for k in range(100):
    # forward pass
    xenc = F.one_hot(xs, num_classes=SIZE).float()
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)

    loss = -probs[torch.arange(n), ys].log().mean() + 0.01 * (W**2).mean()

    W.grad = None
    loss.backward()
    print('loss', loss.item())

    W.data += -50 * W.grad


loss 3.9376444816589355
loss 3.7200145721435547
loss 3.5653035640716553
loss 3.4320993423461914
loss 3.325392484664917
loss 3.24168062210083
loss 3.172454595565796
loss 3.11285400390625
loss 3.0606515407562256
loss 3.0143988132476807
loss 2.9730350971221924
loss 2.9357802867889404
loss 2.902039051055908
loss 2.871333360671997
loss 2.8432631492614746
loss 2.817483901977539
loss 2.793699026107788
loss 2.771648645401001
loss 2.7511119842529297
loss 2.731902599334717
loss 2.7138636112213135
loss 2.6968650817871094
loss 2.680799722671509
loss 2.6655774116516113
loss 2.651122808456421
loss 2.637371063232422
loss 2.624267339706421
loss 2.611762285232544
loss 2.599813222885132
loss 2.588380813598633
loss 2.5774307250976562
loss 2.566930055618286
loss 2.556849479675293
loss 2.5471620559692383
loss 2.537842273712158
loss 2.5288665294647217
loss 2.5202136039733887
loss 2.5118627548217773
loss 2.503796339035034
loss 2.495997190475464
loss 2.4884493350982666
loss 2.48113751411438
loss 2.47404980659

In [165]:
# prediction on inputs
for i in range(5):
    out = []
    # start off with the special character
    ix1 = 0
    ix2 = 0 
    while True:
        # continuously sample from our neural network
        combined_idx = NUM_CLASSES * ix1 + ix2

        xenc = F.one_hot(torch.tensor([combined_idx], device=device), num_classes=SIZE).float()
        logits = xenc @ W
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdims=True)

        ix = torch.multinomial(probs, num_samples=1, generator=g).item()
        if ix == 0:
            break

        out.append(itos[ix])
        ix1, ix2 = ix2, ix
    
    print("".join(out))


ren
ymarvomcqnhf
mellvqdan
joriana
aly
