In [1]:
import numpy as np
import torch

In [2]:
# Read names from file
with open("names.txt", "r") as f:
    names = f.readlines()

In [3]:
names = [name.strip() for name in names]
names[:3]

['emma', 'olivia', 'ava']

In [4]:
# Clean the names
import re
names = [re.sub('[-,.]', '', name) for name in names]
names = [re.sub(r'\(.*\)', '', name) for name in names]
names = [name.lower() for name in names]

In [5]:
len(names)

32033

In [6]:
letter_set = {'.'}

for name in names:
    for l in list(name):
        letter_set.add(l)
        
letter_set = sorted(list(letter_set))
len(letter_set)

27

In [7]:
stoi = {letter: pos for pos, letter in enumerate(letter_set)}
itos = {pos: letter for letter, pos in stoi.items()}

In [8]:
lookup_table = torch.ones((27, 27, 27), dtype=torch.int32)
lookup_table

tensor([[[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]],

        ...,

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         ...,
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]],

        [[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1],
         [1,

In [9]:
for name in names:
    name = ['.', '.'] + list(name) + ['.']
    for char1, char2, char3 in zip(name, name[1:], name[2:]):
        p1, p2, p3 = stoi[char1], stoi[char2], stoi[char3]
        lookup_table[p1, p2, p3] += 1

In [10]:
def get_occurences(char1, char2, char3):
    return lookup_table[stoi[char1], stoi[char2], stoi[char3]]

In [11]:
get_occurences('.', '.', 'a')

tensor(4411, dtype=torch.int32)

In [12]:
lookup_table = torch.div(lookup_table, torch.sum(lookup_table, dim=2, keepdims=True))
torch.sum(lookup_table)

tensor(729.0001)

In [13]:
gen = torch.Generator().manual_seed(2147483647)
for _ in range(10):
    idx1, idx2 = 0, 0
    out = []
    while True:
        idx3 = torch.multinomial(lookup_table[idx1, idx2], num_samples=1, replacement=True, generator=gen).item()
        if idx3==0:
            break
        out.append(itos[idx3])
        idx1 = idx2
        idx2 = idx3
    print("Name: ", "".join(out))
        

Name:  junide
Name:  jakasid
Name:  prelay
Name:  adin
Name:  kairritoper
Name:  sathen
Name:  sameia
Name:  yanileniassibduinrwin
Name:  lessiyanayla
Name:  te


In [14]:
log_likelihood = 0.0 
num_samples = 0
for name in names[:2]:
    name = ['.', '.'] + list(name) + ['.']
    for char1, char2, char3 in zip(name, name[1:], name[2:]):
        prob = lookup_table[stoi[char1], stoi[char2], stoi[char3]]
        log_likelihood += torch.log(prob)
        num_samples += 1
nll = -log_likelihood
print(f"nll_loss: {nll/num_samples}")

nll_loss: 2.206512451171875


In [15]:
xs, ys = [], []
for name in names:
    name = ['.', '.'] + list(name) + ['.']
    for char1, char2, char3 in zip(name, name[1:], name[2:]):
        xs.append([stoi[char1], stoi[char2]])
        ys.append(stoi[char3])
        
xs = torch.tensor(xs)
ys = torch.tensor(ys)
print(f"Number of samples: {xs.nelement()}")

Number of samples: 456292


In [16]:
xs

tensor([[ 0,  0],
        [ 0,  5],
        [ 5, 13],
        ...,
        [26, 25],
        [25, 26],
        [26, 24]])

In [17]:
gen = torch.Generator().manual_seed(2147483647)
x_oh = torch.nn.functional.one_hot(xs, num_classes=27).float()
weights = torch.randn((54, 27), requires_grad=True, generator=gen)
print(f"Shape of encoded inputs: {x_oh.shape}")
print(f"Shape of weights matrix: {weights.shape}")

Shape of encoded inputs: torch.Size([228146, 2, 27])
Shape of weights matrix: torch.Size([54, 27])


In [18]:
x_oh.shape, weights.shape

(torch.Size([228146, 2, 27]), torch.Size([54, 27]))

In [19]:
ys.shape

torch.Size([228146])

In [20]:
x_oh_reshaped = x_oh.view(x_oh.shape[0], x_oh.shape[1] * x_oh.shape[2])

In [22]:
x_oh_reshaped.shape

torch.Size([228146, 54])

In [21]:
for _ in range(1000):
    logits = torch.matmul(x_oh_reshaped, weights)
    counts = logits.exp()
    probs = torch.div(counts, torch.sum(counts, dim=1, keepdims=True))
#     print(probs.shape)
    loss = 0.0
#     for i in range(2):
    loss += -probs[torch.arange(xs.shape[0]), ys].log().mean()
#     loss = -probs[torch.arange(xs.shape[0]), torch.arange(xs.shape[1]), ys].log().mean()
    print(f"Loss: {loss.item()}")
    weights.grad = None
    loss.backward()
    weights.data += -40 * weights.grad

Loss: 4.232541561126709
Loss: 3.553056001663208
Loss: 3.27421236038208
Loss: 3.0445430278778076
Loss: 2.9447267055511475
Loss: 2.859506845474243
Loss: 2.814803123474121
Loss: 2.7571210861206055
Loss: 2.7277557849884033
Loss: 2.6887495517730713
Loss: 2.6680307388305664
Loss: 2.6388931274414062
Loss: 2.624164342880249
Loss: 2.601041555404663
Loss: 2.590730905532837
Loss: 2.571537971496582
Loss: 2.5645530223846436
Loss: 2.548041343688965
Loss: 2.543558120727539
Loss: 2.528939723968506
Loss: 2.526331663131714
Loss: 2.5130996704101562
Loss: 2.511906862258911
Loss: 2.4997267723083496
Loss: 2.499626874923706
Loss: 2.4882683753967285
Loss: 2.4890353679656982
Loss: 2.47832989692688
Loss: 2.4798033237457275
Loss: 2.4696247577667236
Loss: 2.4716858863830566
Loss: 2.46193528175354
Loss: 2.464493989944458
Loss: 2.4550940990448
Loss: 2.458078145980835
Loss: 2.4489693641662598
Loss: 2.4523208141326904
Loss: 2.4434547424316406
Loss: 2.4471256732940674
Loss: 2.438464641571045
Loss: 2.442415237426758
Lo

Loss: 2.3543779850006104
Loss: 2.3610565662384033
Loss: 2.3543145656585693
Loss: 2.3609933853149414
Loss: 2.3542513847351074
Loss: 2.3609306812286377
Loss: 2.354189872741699
Loss: 2.3608686923980713
Loss: 2.35412859916687
Loss: 2.3608078956604004
Loss: 2.3540680408477783
Loss: 2.3607475757598877
Loss: 2.354008197784424
Loss: 2.3606879711151123
Loss: 2.3539493083953857
Loss: 2.3606293201446533
Loss: 2.3538906574249268
Loss: 2.3605709075927734
Loss: 2.3538331985473633
Loss: 2.36051344871521
Loss: 2.353776216506958
Loss: 2.360456943511963
Loss: 2.35371994972229
Loss: 2.360400438308716
Loss: 2.3536643981933594
Loss: 2.3603451251983643
Loss: 2.353609561920166
Loss: 2.360290765762329
Loss: 2.35355544090271
Loss: 2.360236406326294
Loss: 2.353501796722412
Loss: 2.360182762145996
Loss: 2.3534488677978516
Loss: 2.3601300716400146
Loss: 2.353396415710449
Loss: 2.3600780963897705
Loss: 2.353344678878784
Loss: 2.3600263595581055
Loss: 2.3532934188842773
Loss: 2.3599753379821777
Loss: 2.353243112564

KeyboardInterrupt: 

In [29]:
gen = torch.Generator().manual_seed(2147483647)
for _ in range(10):
    out = []
    idx1 = 0
    idx2 = 0
    
    while True:
        x_enc_1 = torch.nn.functional.one_hot(torch.tensor([idx1]), num_classes=27).float()
        x_enc_2 = torch.nn.functional.one_hot(torch.tensor([idx2]), num_classes=27).float()
        
        logits = torch.matmul(torch.hstack((x_enc_1, x_enc_2)), weights)
        counts = logits.exp()
        probs = torch.div(counts, torch.sum(counts, dim=1, keepdims=True))
        
        idx3 = torch.multinomial(probs, num_samples=1, replacement=True, generator=gen).item()
        if idx3 == 0:
#             print("break")
            break
        idx1 = idx2
        idx2 = idx3
        out.append(itos[idx3])
        
#     print(probs.shape)
    print("".join(out))

junide
janasad
pres
yon
na
koi
ritoleras
tee
kilania
yanileniassibdainrwi


In [None]:
ab = torch.tensor([[1,2], [3, 4]])
ab.view(1,4)

In [23]:
idx1 = 0
idx2 = 0

x_enc_1 = torch.nn.functional.one_hot(torch.tensor([idx1]), num_classes=27).float()
x_enc_2 = torch.nn.functional.one_hot(torch.tensor([idx2]), num_classes=27).float()
torch.hstack((x_enc_1, x_enc_2)).shape

torch.Size([1, 54])

In [None]:
torch.hstack()