In [1]:
import numpy as np
import torch

In [2]:
# Read names from file
with open("names.txt", "r") as f:
    names = f.readlines()

In [3]:
names = [name.strip() for name in names]
names[:3]

['emma', 'olivia', 'ava']

In [4]:
# Clean the names
import re
names = [re.sub('[-,.]', '', name) for name in names]
names = [re.sub(r'\(.*\)', '', name) for name in names]
names = [name.lower() for name in names]

In [5]:
len(names)

32033

In [6]:
letter_set = {'.'}

for name in names:
    for l in list(name):
        letter_set.add(l)
        
letter_set = sorted(list(letter_set))
len(letter_set)

27

In [7]:
stoi = {letter: pos for pos, letter in enumerate(letter_set)}
itos = {pos: letter for letter, pos in stoi.items()}

In [8]:
lookup_table = torch.ones((27, 27, 27), dtype=torch.int32)

In [9]:
for name in names:
    name = ['.', '.'] + list(name) + ['.']
    for char1, char2, char3 in zip(name, name[1:], name[2:]):
        p1, p2, p3 = stoi[char1], stoi[char2], stoi[char3]
        lookup_table[p1, p2, p3] += 1

In [10]:
def get_occurences(char1, char2, char3):
    return lookup_table[stoi[char1], stoi[char2], stoi[char3]]

In [11]:
get_occurences('.', '.', 'a')

tensor(4411, dtype=torch.int32)

In [12]:
lookup_table = torch.div(lookup_table, torch.sum(lookup_table, dim=2, keepdims=True))
torch.sum(lookup_table)

tensor(729.0001)

In [27]:
gen = torch.Generator().manual_seed(2147483648)
for _ in range(10):
    idx1, idx2 = 0, 0
    out = []
    while True:
        idx3 = torch.multinomial(lookup_table[idx1, idx2], num_samples=1, replacement=True, generator=gen).item()
        if idx3==0:
            break
        out.append(itos[idx3])
        idx1 = idx2
        idx2 = idx3
    print("Name: ", "".join(out))
        

Name:  cam
Name:  ainor
Name:  slea
Name:  em
Name:  mon
Name:  eiagianaven
Name:  kair
Name:  uzana
Name:  kentham
Name:  jara


In [15]:
log_likelihood = 0.0 
num_samples = 0
for name in names[:2]:
    name = ['.', '.'] + list(name) + ['.']
    for char1, char2, char3 in zip(name, name[1:], name[2:]):
        prob = lookup_table[stoi[char1], stoi[char2], stoi[char3]]
        log_likelihood += torch.log(prob)
        num_samples += 1
nll = -log_likelihood
print(f"nll_loss: {nll/num_samples}")

nll_loss: 2.206512451171875


In [17]:
xs, ys = [], []
for name in names:
    name = ['.', '.'] + list(name) + ['.']
    for char1, char2, char3 in zip(name, name[1:], name[2:]):
        xs.append([stoi[char1], stoi[char2]])
        ys.append(stoi[char3])
        
xs = torch.tensor(xs)
ys = torch.tensor(ys)
print(f"Number of samples: {xs.nelement()}")

Number of samples: 456292


In [18]:
xs

tensor([[ 0,  0],
        [ 0,  5],
        [ 5, 13],
        ...,
        [26, 25],
        [25, 26],
        [26, 24]])

In [19]:
gen = torch.Generator().manual_seed(214748364)
x_oh = torch.nn.functional.one_hot(xs, num_classes=27).float()
weights = torch.randn((54, 27), requires_grad=True)
print(f"Shape of encoded inputs: {x_oh.shape}")
print(f"Shape of weights matrix: {weights.shape}")

Shape of encoded inputs: torch.Size([228146, 2, 27])
Shape of weights matrix: torch.Size([54, 27])


In [20]:
ys.shape

torch.Size([228146])

In [23]:
x_oh_reshaped = x_oh.view(x_oh.shape[0], x_oh.shape[1] * x_oh.shape[2])
x_oh_reshaped.shape

torch.Size([228146, 54])

In [25]:
for _ in range(700):
    logits = torch.matmul(x_oh_reshaped, weights)
    counts = logits.exp()
    probs = torch.div(counts, torch.sum(counts, dim=1, keepdims=True))
    loss = -probs[torch.arange(xs.shape[0]), ys].log().mean() + 0.01 * (weights ** 2).mean()
    print(f"Loss: {loss.item()}")
    weights.grad = None
    loss.backward()
    weights.data += -30 * weights.grad

Loss: 4.380014896392822
Loss: 3.8032491207122803
Loss: 3.4484012126922607
Loss: 3.2576043605804443
Loss: 3.121537446975708
Loss: 3.0197627544403076
Loss: 2.9405405521392822
Loss: 2.876847982406616
Loss: 2.8247926235198975
Loss: 2.7817368507385254
Loss: 2.7457733154296875
Loss: 2.7153985500335693
Loss: 2.6894333362579346
Loss: 2.666961669921875
Loss: 2.64729380607605
Loss: 2.6299099922180176
Loss: 2.6144187450408936
Loss: 2.6005194187164307
Loss: 2.5879757404327393
Loss: 2.576596736907959
Loss: 2.566227912902832
Loss: 2.5567405223846436
Loss: 2.5480265617370605
Loss: 2.5399954319000244
Loss: 2.532569408416748
Loss: 2.5256826877593994
Loss: 2.519277811050415
Loss: 2.513305902481079
Loss: 2.5077242851257324
Loss: 2.502495765686035
Loss: 2.4975879192352295
Loss: 2.4929721355438232
Loss: 2.4886229038238525
Loss: 2.484518051147461
Loss: 2.480637311935425
Loss: 2.476963520050049
Loss: 2.473479747772217
Loss: 2.470172882080078
Loss: 2.467028856277466
Loss: 2.464036703109741
Loss: 2.46118545532

Loss: 2.358928918838501
Loss: 2.3588919639587402
Loss: 2.3588552474975586
Loss: 2.358818292617798
Loss: 2.3587820529937744
Loss: 2.358746290206909
Loss: 2.358710289001465
Loss: 2.3586747646331787
Loss: 2.3586394786834717
Loss: 2.3586044311523438
Loss: 2.358569622039795
Loss: 2.358534812927246
Loss: 2.3585007190704346
Loss: 2.358466386795044
Loss: 2.3584325313568115
Loss: 2.3583991527557373
Loss: 2.358365535736084
Loss: 2.358332395553589
Loss: 2.3582990169525146
Loss: 2.3582661151885986
Loss: 2.3582334518432617
Loss: 2.358201265335083
Loss: 2.3581693172454834
Loss: 2.3581371307373047
Loss: 2.358105182647705
Loss: 2.3580737113952637
Loss: 2.3580427169799805
Loss: 2.358011484146118
Loss: 2.357980728149414
Loss: 2.35794997215271
Loss: 2.357919216156006
Loss: 2.357889175415039
Loss: 2.3578591346740723
Loss: 2.3578290939331055
Loss: 2.3577992916107178
Loss: 2.3577699661254883
Loss: 2.3577404022216797
Loss: 2.3577115535736084
Loss: 2.357682228088379
Loss: 2.3576533794403076
Loss: 2.3576252460

Loss: 2.3534698486328125
Loss: 2.3534634113311768
Loss: 2.353456735610962
Loss: 2.3534505367279053
Loss: 2.3534438610076904
Loss: 2.353437662124634
Loss: 2.353430986404419
Loss: 2.3534250259399414
Loss: 2.3534185886383057
Loss: 2.353412389755249
Loss: 2.3534059524536133
Loss: 2.3533997535705566
Loss: 2.353393793106079
Loss: 2.3533873558044434
Loss: 2.353381395339966
Loss: 2.35337495803833
Loss: 2.3533689975738525
Loss: 2.353363275527954
Loss: 2.3533568382263184
Loss: 2.3533506393432617
Loss: 2.353344678878784
Loss: 2.3533387184143066
Loss: 2.353332996368408
Loss: 2.3533270359039307
Loss: 2.3533213138580322
Loss: 2.3533151149749756
Loss: 2.353309392929077
Loss: 2.3533034324645996
Loss: 2.353297710418701


In [28]:
gen = torch.Generator().manual_seed(2147483648)
for _ in range(10):
    out = []
    idx1 = 0
    idx2 = 0
    
    while True:
        x_enc_1 = torch.nn.functional.one_hot(torch.tensor([idx1]), num_classes=27).float()
        x_enc_2 = torch.nn.functional.one_hot(torch.tensor([idx2]), num_classes=27).float()
        
        logits = torch.matmul(torch.hstack((x_enc_1, x_enc_2)), weights)
        counts = logits.exp()
        probs = torch.div(counts, torch.sum(counts, dim=1, keepdims=True))
        
        idx3 = torch.multinomial(probs, num_samples=1, replacement=True, generator=gen).item()
        if idx3 == 0:
            break
        idx1 = idx2
        idx2 = idx3
        out.append(itos[idx3])
        
    print("".join(out))

can
ahior
slea
em
molariagialaven
kali
ustia
kentham
jara
cyl
