In [1]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [2]:
# Load names & create bigrams
names = []
with open('names.txt') as f:
    names = f.readlines()
names = ['.' + x.strip() + '.' for x in names]
bigrams = []
for name in names:
    name_bigrams = [''.join(x) for x in zip(name[:-1], name[1:])]
    bigrams.extend(name_bigrams)
bigrams

['.e',
 'em',
 'mm',
 'ma',
 'a.',
 '.o',
 'ol',
 'li',
 'iv',
 'vi',
 'ia',
 'a.',
 '.a',
 'av',
 'va',
 'a.',
 '.i',
 'is',
 'sa',
 'ab',
 'be',
 'el',
 'll',
 'la',
 'a.',
 '.s',
 'so',
 'op',
 'ph',
 'hi',
 'ia',
 'a.',
 '.c',
 'ch',
 'ha',
 'ar',
 'rl',
 'lo',
 'ot',
 'tt',
 'te',
 'e.',
 '.m',
 'mi',
 'ia',
 'a.',
 '.a',
 'am',
 'me',
 'el',
 'li',
 'ia',
 'a.',
 '.h',
 'ha',
 'ar',
 'rp',
 'pe',
 'er',
 'r.',
 '.e',
 'ev',
 've',
 'el',
 'ly',
 'yn',
 'n.',
 '.a',
 'ab',
 'bi',
 'ig',
 'ga',
 'ai',
 'il',
 'l.',
 '.e',
 'em',
 'mi',
 'il',
 'ly',
 'y.',
 '.e',
 'el',
 'li',
 'iz',
 'za',
 'ab',
 'be',
 'et',
 'th',
 'h.',
 '.m',
 'mi',
 'il',
 'la',
 'a.',
 '.e',
 'el',
 'll',
 'la',
 'a.',
 '.a',
 'av',
 've',
 'er',
 'ry',
 'y.',
 '.s',
 'so',
 'of',
 'fi',
 'ia',
 'a.',
 '.c',
 'ca',
 'am',
 'mi',
 'il',
 'la',
 'a.',
 '.a',
 'ar',
 'ri',
 'ia',
 'a.',
 '.s',
 'sc',
 'ca',
 'ar',
 'rl',
 'le',
 'et',
 'tt',
 't.',
 '.v',
 'vi',
 'ic',
 'ct',
 'to',
 'or',
 'ri',
 'ia',
 'a.',

In [3]:
# create mapping from letters to indices
letters = []
for name in names:
    letters.extend(list(name))
letters = list(set(letters))
letters.sort()
mapping = dict(zip(letters, list(range(27))))
mapping

{'.': 0,
 'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26}

In [4]:
# create training data
x_y_idxs = [(mapping[bigram[0]], mapping[bigram[1]]) for bigram in bigrams]
x_idxs = [item[0] for item in x_y_idxs]
y_idxs = [ [ item[1] ] for item in x_y_idxs]
X = F.one_hot(torch.tensor(x_idxs), num_classes=27).to(torch.float32)
X

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.],
        [0., 0., 0.,  ..., 1., 0., 0.]])

In [5]:
W = torch.randn(27, 27, requires_grad=True)

In [6]:
for i in range(100):
    # forward pass
    P = X @ W
    Q = P.exp()
    R = Q / Q.sum(dim=1, keepdim=True)

    # negative log loss
    y_probs = R.gather(dim=1, index=torch.tensor(y_idxs))
    loss = - y_probs.log().mean()
    print(f"Iteration: {i}, Loss: {loss}")

    # backprop
    loss.backward()
    W.data += -50.0 * W.grad
    W.grad.zero_()


Iteration: 0, Loss: 3.930293560028076
Iteration: 1, Loss: 3.4809799194335938
Iteration: 2, Loss: 3.2049553394317627
Iteration: 3, Loss: 3.038757085800171
Iteration: 4, Loss: 2.934551477432251
Iteration: 5, Loss: 2.860172986984253
Iteration: 6, Loss: 2.8057162761688232
Iteration: 7, Loss: 2.764683485031128
Iteration: 8, Loss: 2.7326619625091553
Iteration: 9, Loss: 2.706887722015381
Iteration: 10, Loss: 2.6856637001037598
Iteration: 11, Loss: 2.6678831577301025
Iteration: 12, Loss: 2.6527602672576904
Iteration: 13, Loss: 2.6397037506103516
Iteration: 14, Loss: 2.6282691955566406
Iteration: 15, Loss: 2.6181252002716064
Iteration: 16, Loss: 2.60902738571167
Iteration: 17, Loss: 2.6007962226867676
Iteration: 18, Loss: 2.5932981967926025
Iteration: 19, Loss: 2.586432933807373
Iteration: 20, Loss: 2.580120801925659
Iteration: 21, Loss: 2.574298143386841
Iteration: 22, Loss: 2.5689127445220947
Iteration: 23, Loss: 2.563920259475708
Iteration: 24, Loss: 2.559281826019287
Iteration: 25, Loss: 2.

In [7]:
# Sample from NN
for i in range(10):
    x = F.one_hot(torch.tensor([0]), num_classes=27).to(torch.float32)
    char_idxs = [ 0, ]
    while True:
        # forward pass
        P = x @ W
        Q = P.exp()
        R = Q / Q.sum(dim=1, keepdim=True)
        
        # sample from output prob dist
        next_char = torch.multinomial(Q[0, :], 1, replacement=True).item()
        char_idxs.append(next_char)
        if next_char == 0:
            break
            
        x = F.one_hot(torch.tensor([next_char]), num_classes=27).to(torch.float32)
        
    name_pred = ''.join([letters[idx] for idx in char_idxs])
    print(name_pred)


.lionig.
.ka.
.tonn.
.zaimimalllelieyn.
.ame.
.a.
.ysoplarey.
.tlimadaha.
.olancescadraynom.
.lyster.
