### Loading names

In [1]:
words = open('names.txt','r').read().splitlines()

In [3]:
len(words)

32033

In [5]:
words[:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

### Creating stoi and itos for ys

In [8]:
chars = sorted(list(set(''.join(words))) + ['.'])

In [10]:
len(chars)

27

In [12]:
stoi_ys = {i:e for e,i in enumerate(chars)}

In [14]:
itos_ys = {v:k for k,v in stoi_ys.items()}

In [18]:
stoi_ys

{'.': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26}

### Creating stoi and itos of xs

In [20]:
from itertools import product

In [22]:
stoi_xs = {}
for e,(ch1,ch2) in enumerate(product(chars,repeat=2)):
    ch = ch1+ch2
    stoi_xs[ch] = e 
    

In [24]:
len(stoi_xs)

729

In [28]:
itos_xs = {v:k for k,v in stoi_xs.items()}

In [32]:
stoi_xs

{'..': 0,
 '.a': 1,
 '.b': 2,
 '.c': 3,
 '.d': 4,
 '.e': 5,
 '.f': 6,
 '.g': 7,
 '.h': 8,
 '.i': 9,
 '.j': 10,
 '.k': 11,
 '.l': 12,
 '.m': 13,
 '.n': 14,
 '.o': 15,
 '.p': 16,
 '.q': 17,
 '.r': 18,
 '.s': 19,
 '.t': 20,
 '.u': 21,
 '.v': 22,
 '.w': 23,
 '.x': 24,
 '.y': 25,
 '.z': 26,
 'a.': 27,
 'aa': 28,
 'ab': 29,
 'ac': 30,
 'ad': 31,
 'ae': 32,
 'af': 33,
 'ag': 34,
 'ah': 35,
 'ai': 36,
 'aj': 37,
 'ak': 38,
 'al': 39,
 'am': 40,
 'an': 41,
 'ao': 42,
 'ap': 43,
 'aq': 44,
 'ar': 45,
 'as': 46,
 'at': 47,
 'au': 48,
 'av': 49,
 'aw': 50,
 'ax': 51,
 'ay': 52,
 'az': 53,
 'b.': 54,
 'ba': 55,
 'bb': 56,
 'bc': 57,
 'bd': 58,
 'be': 59,
 'bf': 60,
 'bg': 61,
 'bh': 62,
 'bi': 63,
 'bj': 64,
 'bk': 65,
 'bl': 66,
 'bm': 67,
 'bn': 68,
 'bo': 69,
 'bp': 70,
 'bq': 71,
 'br': 72,
 'bs': 73,
 'bt': 74,
 'bu': 75,
 'bv': 76,
 'bw': 77,
 'bx': 78,
 'by': 79,
 'bz': 80,
 'c.': 81,
 'ca': 82,
 'cb': 83,
 'cc': 84,
 'cd': 85,
 'ce': 86,
 'cf': 87,
 'cg': 88,
 'ch': 89,
 'ci': 90,
 'cj': 91

### Creating training set of trigrams

In [55]:
import torch

In [98]:
xs = []
ys = []
printed=False
for w in words:
    chs = ['.'] + list(w) + ['.']
    for i in range(len(chs)-2):
        ch1 = chs[i] + chs[i+1]
        ch2 = chs[i+2]
            
        ix1 = stoi_xs[ch1]
        ix2 = stoi_ys[ch2]
        if not printed:
            print('first trigram :')
            print(ch1,ch2)
            print(ix1,ix2)
            printed=True
        xs.append(ix1)
        ys.append(ix2)
                               
xs = torch.tensor(xs, dtype=torch.long)
ys = torch.tensor(ys, dtype=torch.long)

first trigram :
.e m
5 13


In [104]:
min(xs),max(xs)

(tensor(1), tensor(728))

In [106]:
min(ys),max(ys)

(tensor(0), tensor(26))

In [108]:
len(xs), len(ys)

(196113, 196113)

### Doing a single trigram pass

In [365]:
import torch.nn.functional as F

In [367]:
xenc = F.one_hot(xs, num_classes=729).float()

In [369]:
xenc.shape

torch.Size([196113, 729])

In [371]:
g = torch.Generator().manual_seed(42)
W = torch.randn((729,27), generator=g)

In [373]:
logits = xenc @ W
count = logits.exp()
probs = count / count.sum(dim=1, keepdims=True)

In [375]:
probs.shape

torch.Size([196113, 27])

In [377]:
xs[0].item()

5

In [379]:
itos_xs[xs[0].item()],itos_ys[ys[0].item()]

('.e', 'm')

In [381]:
probs[0,ys[0].item()]

tensor(0.0292)

### Training the trigram neural net

In [467]:
xenc = F.one_hot(xs, num_classes=729).float()
g = torch.Generator().manual_seed(42)
W = torch.randn((729,27), generator=g, requires_grad=True)

In [471]:
for iterations in range(50):
    logits = xenc @ W
    count = logits.exp()
    probs = count / count.sum(dim=1, keepdims=True)
    loss = -probs[torch.arange(len(xs)), ys].log().mean()
    print(f'Loss : {loss.item():.2f}')

    W.grad = None
    loss.backward()
    W.data += -50 * W.grad

Loss : 3.57
Loss : 3.51
Loss : 3.45
Loss : 3.39
Loss : 3.34
Loss : 3.29
Loss : 3.25
Loss : 3.21
Loss : 3.17
Loss : 3.14
Loss : 3.10
Loss : 3.07
Loss : 3.05
Loss : 3.02
Loss : 3.00
Loss : 2.98
Loss : 2.95
Loss : 2.93
Loss : 2.92
Loss : 2.90
Loss : 2.88
Loss : 2.86
Loss : 2.85
Loss : 2.83
Loss : 2.82
Loss : 2.81
Loss : 2.79
Loss : 2.78
Loss : 2.77
Loss : 2.76
Loss : 2.75
Loss : 2.74
Loss : 2.73
Loss : 2.72
Loss : 2.71
Loss : 2.70
Loss : 2.69
Loss : 2.68
Loss : 2.67
Loss : 2.66
Loss : 2.65
Loss : 2.65
Loss : 2.64
Loss : 2.63
Loss : 2.63
Loss : 2.62
Loss : 2.61
Loss : 2.60
Loss : 2.60
Loss : 2.59


### Predictions

In [473]:
{k:v for k,v in stoi_xs.items() if k[0] == '.' if k[1] != '.'} # valid starting bigrams

{'.a': 1,
 '.b': 2,
 '.c': 3,
 '.d': 4,
 '.e': 5,
 '.f': 6,
 '.g': 7,
 '.h': 8,
 '.i': 9,
 '.j': 10,
 '.k': 11,
 '.l': 12,
 '.m': 13,
 '.n': 14,
 '.o': 15,
 '.p': 16,
 '.q': 17,
 '.r': 18,
 '.s': 19,
 '.t': 20,
 '.u': 21,
 '.v': 22,
 '.w': 23,
 '.x': 24,
 '.y': 25,
 '.z': 26}

In [475]:
torch.tensor(2).shape

torch.Size([])

In [485]:
g = torch.Generator().manual_seed(99)
N = 5
for xs_input in torch.randint(low=1, high=27, size=(N,),generator=g): # get a valid starting bigram
    xs_input = xs_input.item()
    out = [itos_xs[xs_input]]
    while True:
        xenc = F.one_hot(torch.tensor([xs_input]), num_classes=729).float()
        logits = xenc @ W
        count = logits.exp()
        probs = count / count.sum(dim=1, keepdims=True)
        xs_input = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()
        # if itos_xs[xs_input][0] == '.':
        #     continue
        out.append(itos_xs[xs_input])
        if itos_xs[xs_input][1] == '.':
            break
    print(''.join(out))

.b.r.a.r.o.c.h.o.y.a.b.r.a.l.u.y.i.k.a.g.i.y.r.a.r.a.p.o..
.x.l..
.d.a.x.j.a.m.a.n.a.r.a.n.u.b.r.o.t.r.o.f.i.y.r.u.q.j.o.u.u.p.d.a.r.a.r.c.a.b.l.m.a.f.d.e.l.o.a.a.l.u.t.a.e.h.c.a.v.t.x..
.y.t.a.a.r.i.s.h.e.l.e.l.u.z.a.v.x.x.x.l.u.s.a.r.o.g.r.o.c.a.l.y.z.u.z.e.a.n.l.a.r.e.l.e.r.e.m.a.d.a.n.o.s.a.z.l.i.m.y.b.r.e.k.g.e.t.x.x.x.c.z.a.j.a.y.i.k.w.p.t.a.d.a.l.e.f.a.f.a.d.e.z.s.a.b.r.a.r.a.z.x.t.l..
.x.x.d.x.e.t.r.a.n.a.i.j.e.r.o.n.o.z.v.w.h.i.k.a.l.e.m.i.i.k.a.n.a.w..
