In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [4]:
words = open('./names.txt', 'r').read().splitlines()
print(len(words))

32033


In [5]:
# model hyperparameters
embedding_size = 2
size_of_tanh = 100

In [6]:
chars = sorted(list(set(''.join(words))))
stoi = {ch: i+1 for i, ch in enumerate(chars)}
stoi['.'] = 0
itos = {i: ch for ch, i in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [7]:
block_size = 3 # given 3 chars, predict the 4th

X, Y = [], []

for word in words:
    context = [0] * block_size
    for ch in word + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        context = context[1:] + [ix]

X = torch.tensor(X)
Y = torch.tensor(Y)

print(X.shape, Y.shape)


torch.Size([228146, 3]) torch.Size([228146])


In [8]:
for x in X[0]:
    print(itos[x.item()], end='')
print()
print (Y[0])

...
tensor(5)


In [9]:
C = torch.randn(len(chars)+1, embedding_size)
print(C.shape)

torch.Size([27, 2])


In [10]:
emb = C[X]
emb.shape

torch.Size([228146, 3, 2])

In [11]:
w1 = torch.randn((emb.shape[1]*emb.shape[2], size_of_tanh))
b1 = torch.randn(size_of_tanh)

In [12]:
h = torch.tanh(emb.view(emb.shape[0], emb.shape[1]*emb.shape[2]) @ w1 + b1)

In [13]:
h.shape

torch.Size([228146, 100])

In [14]:
w2 = torch.randn((size_of_tanh, len(chars) + 1))
b2 = torch.randn(len(chars) + 1)

In [15]:
logits = h@w2 + b2

In [16]:
logits.shape

torch.Size([228146, 27])

In [17]:
loss = F.cross_entropy(logits, Y)

In [18]:
print(C.shape)
print(w1.shape)
print(b1.shape)
print(w2.shape)
print(b2.shape)

torch.Size([27, 2])
torch.Size([6, 100])
torch.Size([100])
torch.Size([100, 27])
torch.Size([27])


In [19]:
paramaters = [C, w1, b1, w2, b2]

In [20]:
sum(p.nelement() for p in paramaters)

3481

In [21]:
lr = 10**-1

for p in paramaters:
    p.requires_grad = True

In [22]:
for i in range(101000):
    ix = torch.randint(0, X.shape[0], (32, ))
    emb = C[X[ix]]
    h = torch.tanh(emb.view(emb.shape[0], emb.shape[1]*emb.shape[2]) @ w1 + b1)
    logits = h@w2 + b2
    loss = F.cross_entropy(logits, Y[ix])

    print(loss.item())

    for p in paramaters:
        p.grad = None

    loss.backward()

    if (i == 100000):
        lr = 10**-2

    for p in paramaters:
        p.data += -lr*p.grad

16.894054412841797
14.731669425964355
17.34026527404785
13.779404640197754
17.4647159576416
14.89730453491211
12.566324234008789
12.740727424621582
13.510965347290039
11.62369155883789
10.003072738647461
13.95674991607666
8.695720672607422
14.909249305725098
9.326931953430176
7.638533592224121
7.693171977996826
9.082111358642578
8.119017601013184
5.947079181671143
9.478679656982422
6.255782604217529
9.005656242370605
5.887474536895752
8.168242454528809
6.190257549285889
5.770103454589844
5.510972023010254
7.601250171661377
6.949916839599609
7.144799709320068
6.566713809967041
7.14988899230957
9.215188026428223
5.487864971160889
7.1377387046813965
4.9913177490234375
4.892759799957275
5.390848636627197
5.442654132843018
4.732006549835205
6.629144668579102
6.1087141036987305
4.398899555206299
4.616567134857178
4.447115421295166
5.621018886566162
5.0348405838012695
4.805856227874756
3.9587204456329346
5.4046406745910645
4.463801383972168
4.963608741760254
4.686829090118408
5.87789487838745

In [23]:
for _ in range(100):
# print a test name
    context = [0] * block_size
    name = ''
    while True:
        emb = C[torch.tensor(context).unsqueeze(0)]
        h = torch.tanh(emb.view(1, emb.shape[1]*emb.shape[2]) @ w1 + b1)
        logits = h@w2 + b2
        p = F.softmax(logits, dim=1).squeeze(0)
        p = p.detach().numpy()
        ix = np.random.choice(len(chars)+1, p=p)
        if ix == 0:
            break
        name += itos[ix]
        context = context[1:] + [ix]

    print(name)

xood
jannile
ilahehille
xha
amistin
saya
bris
kei
creycariya
vison
marisssynie
kan
emahnah
smyn
samliwi
raya
cal
keith
deogh
alynn
jori
siyakuya
naakrera
corior
hed
bera
mirdsy
cyef
abryn
aab
krey
lae
dan
paesighriel
tes
lac
raidryan
danah
aul
zazith
ean
dolany
aleye
mogbynn
biorettan
milehnel
then
athelma
rimi
remion
katpin
duz
milezit
jez
asa
pry
kor
nelliah
lekelone
mavylay
kilen
kawionalyn
ema
aabeli
zaylonathi
hccine
sailer
fianna
karden
khanie
kali
jurta
reh
ifa
jadynully
kaylengahla
cire
ker
rina
kaer
kalisanna
eiloj
cola
azadyni
narimannavifphamaleie
kona
trry
con
nuh
khyna
shikar
mae
kimee
mokie
moce
serne
mive
grahledsy
arma
sarna
