In [1]:
import numpy as np
import torch

In [2]:
# Read names from file
with open("names.txt", "r") as f:
    names = f.readlines()

In [3]:
names = [name.strip() for name in names]
names[:3]

['emma', 'olivia', 'ava']

In [4]:
# Clean the names
import re
names = [re.sub('[-,.]', '', name) for name in names]
names = [re.sub(r'\(.*\)', '', name) for name in names]
names = [name.lower() for name in names]

In [5]:
len(names)

32033

In [6]:
letter_set = {'.'}

for name in names:
    for l in list(name):
        letter_set.add(l)
        
letter_set = sorted(list(letter_set))
len(letter_set)

27

In [7]:
stoi = {letter: pos for pos, letter in enumerate(letter_set)}
itos = {pos: letter for letter, pos in stoi.items()}

In [8]:
lookup_table = torch.ones((27, 27, 27), dtype=torch.int32)

In [9]:
for name in names:
    name = ['.', '.'] + list(name) + ['.']
    for char1, char2, char3 in zip(name, name[1:], name[2:]):
        p1, p2, p3 = stoi[char1], stoi[char2], stoi[char3]
        lookup_table[p1, p2, p3] += 1

In [10]:
def get_occurences(char1, char2, char3):
    return lookup_table[stoi[char1], stoi[char2], stoi[char3]]

In [11]:
get_occurences('.', '.', 'a')

tensor(4411, dtype=torch.int32)

In [12]:
lookup_table = torch.div(lookup_table, torch.sum(lookup_table, dim=2, keepdims=True))
torch.sum(lookup_table)

tensor(729.0001)

In [13]:
gen = torch.Generator().manual_seed(2147483648)
for _ in range(10):
    idx1, idx2 = 0, 0
    out = []
    while True:
        idx3 = torch.multinomial(lookup_table[idx1, idx2], num_samples=1, replacement=True, generator=gen).item()
        if idx3==0:
            break
        out.append(itos[idx3])
        idx1 = idx2
        idx2 = idx3
    print("Name: ", "".join(out))
        

Name:  cam
Name:  ainor
Name:  slea
Name:  em
Name:  mon
Name:  eiagianaven
Name:  kair
Name:  uzana
Name:  kentham
Name:  jara


In [14]:
log_likelihood = 0.0 
num_samples = 0
for name in names:
    name = ['.', '.'] + list(name) + ['.']
    for char1, char2, char3 in zip(name, name[1:], name[2:]):
        prob = lookup_table[stoi[char1], stoi[char2], stoi[char3]]
        log_likelihood += torch.log(prob)
        num_samples += 1
nll = -log_likelihood
print(f"nll_loss: {nll/num_samples}")

nll_loss: 2.2119739055633545


In [15]:
xs, ys = [], []
for name in names:
    name = ['.', '.'] + list(name) + ['.']
    for char1, char2, char3 in zip(name, name[1:], name[2:]):
        xs.append([stoi[char1], stoi[char2]])
        ys.append(stoi[char3])
        
xs = torch.tensor(xs)
ys = torch.tensor(ys)
print(f"Number of samples: {xs.nelement()}")

Number of samples: 456292


In [16]:
xs

tensor([[ 0,  0],
        [ 0,  5],
        [ 5, 13],
        ...,
        [26, 25],
        [25, 26],
        [26, 24]])

In [17]:
gen = torch.Generator().manual_seed(214748364)
x_oh = torch.nn.functional.one_hot(xs, num_classes=27).float()
weights = torch.randn((54, 27), requires_grad=True)
print(f"Shape of encoded inputs: {x_oh.shape}")
print(f"Shape of weights matrix: {weights.shape}")

Shape of encoded inputs: torch.Size([228146, 2, 27])
Shape of weights matrix: torch.Size([54, 27])


In [18]:
ys.shape

torch.Size([228146])

In [19]:
x_oh_reshaped = x_oh.view(x_oh.shape[0], x_oh.shape[1] * x_oh.shape[2])
x_oh_reshaped.shape

torch.Size([228146, 54])

In [20]:
for _ in range(700):
    logits = torch.matmul(x_oh_reshaped, weights)
    counts = logits.exp()
    probs = torch.div(counts, torch.sum(counts, dim=1, keepdims=True))
    loss = -probs[torch.arange(xs.shape[0]), ys].log().mean() + 0.01 * (weights ** 2).mean()
    print(f"Loss: {loss.item()}")
    weights.grad = None
    loss.backward()
    weights.data += -30 * weights.grad

Loss: 4.330939292907715
Loss: 3.7552990913391113
Loss: 3.449273109436035
Loss: 3.249814987182617
Loss: 3.1080565452575684
Loss: 3.0071299076080322
Loss: 2.9291675090789795
Loss: 2.867516040802002
Loss: 2.817561388015747
Loss: 2.7765862941741943
Loss: 2.7424776554107666
Loss: 2.713681936264038
Loss: 2.689020872116089
Loss: 2.66762113571167
Loss: 2.648831367492676
Loss: 2.6321637630462646
Loss: 2.6172494888305664
Loss: 2.6038060188293457
Loss: 2.5916123390197754
Loss: 2.580493211746216
Loss: 2.5703063011169434
Loss: 2.560936212539673
Loss: 2.552286386489868
Loss: 2.544275999069214
Loss: 2.5368359088897705
Loss: 2.5299081802368164
Loss: 2.5234410762786865
Loss: 2.5173912048339844
Loss: 2.5117194652557373
Loss: 2.50639271736145
Loss: 2.501380443572998
Loss: 2.496656656265259
Loss: 2.4921979904174805
Loss: 2.4879822731018066
Loss: 2.4839913845062256
Loss: 2.480208396911621
Loss: 2.4766173362731934
Loss: 2.4732048511505127
Loss: 2.4699580669403076
Loss: 2.4668657779693604
Loss: 2.46391749382

Loss: 2.3590869903564453
Loss: 2.35904860496521
Loss: 2.359011173248291
Loss: 2.358973503112793
Loss: 2.358936071395874
Loss: 2.3588995933532715
Loss: 2.3588626384735107
Loss: 2.358825922012329
Loss: 2.3587899208068848
Loss: 2.3587536811828613
Loss: 2.358717679977417
Loss: 2.358682155609131
Loss: 2.358646869659424
Loss: 2.358612060546875
Loss: 2.358577251434326
Loss: 2.3585424423217773
Loss: 2.358508348464966
Loss: 2.358474016189575
Loss: 2.3584399223327637
Loss: 2.3584063053131104
Loss: 2.358372926712036
Loss: 2.358339548110962
Loss: 2.358306646347046
Loss: 2.35827374458313
Loss: 2.358241081237793
Loss: 2.3582088947296143
Loss: 2.3581762313842773
Loss: 2.358144521713257
Loss: 2.3581130504608154
Loss: 2.358081340789795
Loss: 2.3580501079559326
Loss: 2.358018636703491
Loss: 2.357987880706787
Loss: 2.357957124710083
Loss: 2.3579261302948
Loss: 2.357896089553833
Loss: 2.357866048812866
Loss: 2.3578360080718994
Loss: 2.357806444168091
Loss: 2.357776641845703
Loss: 2.3577473163604736
Loss: 

Loss: 2.3534996509552
Loss: 2.3534932136535645
Loss: 2.353487253189087
Loss: 2.353480577468872
Loss: 2.3534741401672363
Loss: 2.3534679412841797
Loss: 2.353461742401123
Loss: 2.3534555435180664
Loss: 2.3534493446350098
Loss: 2.3534433841705322
Loss: 2.3534369468688965
Loss: 2.3534305095672607
Loss: 2.3534247875213623
Loss: 2.3534185886383057
Loss: 2.3534128665924072
Loss: 2.3534061908721924
Loss: 2.353400468826294
Loss: 2.3533945083618164
Loss: 2.353388786315918
Loss: 2.3533825874328613
Loss: 2.353376865386963
Loss: 2.3533709049224854
Loss: 2.353364944458008
Loss: 2.3533589839935303
Loss: 2.3533530235290527
Loss: 2.3533477783203125
Loss: 2.353341579437256
Loss: 2.3533360958099365


In [21]:
gen = torch.Generator().manual_seed(2147483648)
for _ in range(10):
    out = []
    idx1 = 0
    idx2 = 0
    
    while True:
        x_enc_1 = torch.nn.functional.one_hot(torch.tensor([idx1]), num_classes=27).float()
        x_enc_2 = torch.nn.functional.one_hot(torch.tensor([idx2]), num_classes=27).float()
        
        logits = torch.matmul(torch.hstack((x_enc_1, x_enc_2)), weights)
        counts = logits.exp()
        probs = torch.div(counts, torch.sum(counts, dim=1, keepdims=True))
        
        idx3 = torch.multinomial(probs, num_samples=1, replacement=True, generator=gen).item()
        if idx3 == 0:
            break
        idx1 = idx2
        idx2 = idx3
        out.append(itos[idx3])
        
    print("".join(out))

can
ahior
slea
em
molariagialaven
kali
ustia
kentham
jara
cyl
