In [1]:
import random
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
with open('names.txt', 'r') as f:
    names = f.read()

In [3]:
names = names.split()

In [4]:
names_original = names.copy()

In [5]:
bigram_to_idx = {}
idx_to_bigram = {}
char_to_idx = {}
idx_to_char = {}

In [6]:
import string

letters = ['.'] + list(string.ascii_lowercase)
bigrams = [a + b for a in letters for b in letters]
bigram_to_idx = {bigram: index for index, bigram in enumerate(bigrams)}
char_to_idx = {char: index for index, char in enumerate(letters)}
idx_to_bigram = {index: bigram for index, bigram in enumerate(bigrams)}
idx_to_char = {index: char for index, char in enumerate(letters)}

BIGRAM_SIZE = len(bigram_to_idx)
CHAR_SIZE = len(char_to_idx)

In [7]:
names = ['.' + name + '.' for name in names]

In [8]:
def get_sample(print_name=False):
    sample = random.choice(names)
    if print_name:
        print(sample)
    X, y = [], []
    for ch, ch_nxt, ch_ans in zip(sample, sample[1:], sample[2:]):
        bigram = ch + ch_nxt
        bigram = bigram_to_idx[bigram]
        bigram = torch.tensor(bigram)
        sample_x = F.one_hot(bigram, BIGRAM_SIZE).float()
        X.append(sample_x)

        ans = char_to_idx[ch_ans]
        ans = torch.tensor(ans)
        sample_y = F.one_hot(ans, CHAR_SIZE).float()
        y.append(sample_y)
    return X, y

In [9]:
def make_dataset():
    dataset_x, dataset_y = [], []
    for sample in names:
        for ch, ch_nxt, ch_ans in zip(sample, sample[1:], sample[2:]):
            bigram = ch + ch_nxt
            bigram = bigram_to_idx[bigram]
            bigram = torch.tensor(bigram)
            sample_x = bigram
            # sample_x = F.one_hot(bigram, BIGRAM_SIZE).float()
            dataset_x.append(sample_x)
    
            ans = char_to_idx[ch_ans]
            ans = torch.tensor(ans)
            sample_y = ans
            # sample_y = F.one_hot(ans, CHAR_SIZE).float()
            dataset_y.append(sample_y)
    return dataset_x, dataset_y

In [10]:
dataset_x, dataset_y = make_dataset()

In [11]:
for i, j in zip(dataset_x[:10], dataset_y):
    print(idx_to_bigram[i.item()], idx_to_char[j.item()])

.e m
em m
mm a
ma .
.o l
ol i
li v
iv i
vi a
ia .


In [14]:
class TrigramModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(BIGRAM_SIZE, CHAR_SIZE)

    def forward(self, x):
        return self.fc(x)

In [15]:
def train_loop(epochs=1):
    for epoch in range(epochs):
        avg_loss = 0
        for sample_x, sample_true in zip(dataset_x, dataset_y):
            logits = model(sample_x)
            loss = criterion(logits, sample_true)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_loss += loss

        avg_loss /= len(dataset_x)

        print(f'{epoch}: {avg_loss:4f}')
    

In [36]:
model = TrigramModel()
optimizer = optim.Adam(model.parameters(), lr=1, weight_decay=1e-2)  # weight_decay is the L2 regularization term
criterion = nn.CrossEntropyLoss()

In [37]:
train_loop()

0: 4.081882


In [16]:
def generate_name():
    ch1 = '.'
    ch2 = random.choice(list(string.ascii_lowercase))
    generated_name = [ch1, ch2]
    
    while True:
        bigram = ''.join(generated_name[-2:])
        bigram = bigram_to_idx[bigram]
        bigram = torch.tensor(bigram)
        sample_x = F.one_hot(bigram, BIGRAM_SIZE).float()
        
        logits = model(sample_x)
        probs = F.softmax(logits, dim=0)
        pred = torch.multinomial(probs, num_samples=1)
        pred_chr = idx_to_char[pred[0].item()]
    
        if pred_chr == '.':
            return ''.join(generated_name)
        else:
            generated_name.append(pred_chr)

In [55]:
generate_name()

'.qnzx'

In [12]:
dataset_x = torch.stack(dataset_x)
dataset_y = torch.stack(dataset_y)
num = len(dataset_x)

In [18]:
dataset_x = F.one_hot(dataset_x, BIGRAM_SIZE).float()

In [19]:
W = torch.randn((BIGRAM_SIZE, CHAR_SIZE), requires_grad=True)

In [74]:
for i in range(1000):
    logits = dataset_x @ W
    counts = logits.exp() # counts, equivalent to N
    probs = counts / counts.sum(1, keepdims=True)
    loss = -probs[torch.arange(num), dataset_y].log().mean() + 0.01*(W**2).mean()
    print(loss.item())

    W.grad = None
    loss.backward()
    W.data += -50 * W.grad

2.17934513092041
2.1791465282440186
2.1789488792419434
2.1787519454956055
2.178555965423584
2.1783604621887207
2.178166389465332
2.1779725551605225
2.1777799129486084
2.1775877475738525
2.177396535873413
2.17720627784729
2.177016258239746
2.1768271923065186
2.1766393184661865
2.1764516830444336
2.176265239715576
2.176079511642456
2.175894260406494
2.1757097244262695
2.1755259037017822
2.1753430366516113
2.175161123275757
2.1749794483184814
2.1747987270355225
2.174618721008301
2.1744391918182373
2.1742606163024902
2.1740827560424805
2.173905849456787
2.173729181289673
2.173553228378296
2.1733782291412354
2.173203706741333
2.173029899597168
2.1728568077087402
2.17268443107605
2.1725127696990967
2.1723415851593018
2.1721713542938232
2.172001361846924
2.171832323074341
2.171663761138916
2.1714961528778076
2.1713290214538574
2.1711623668670654
2.1709964275360107
2.1708312034606934
2.1706666946411133
2.1705024242401123
2.170339345932007
2.1701765060424805
2.1700143814086914
2.169852972030639

In [92]:
def generate_name():
    ch1 = '.'
    ch2 = random.choice(list(string.ascii_lowercase))
    generated_name = [ch1, ch2]
    
    while True:
        bigram = ''.join(generated_name[-2:])
        bigram = bigram_to_idx[bigram]
        bigram = torch.tensor(bigram)
        sample_x = F.one_hot(bigram, BIGRAM_SIZE).float()
        
        logits = sample_x @ W
        # probs = F.softmax(logits, dim=0)
        
        counts = logits.exp() # counts, equivalent to N
        p = counts / counts.sum(0, keepdims=True) # probabilities for next character
        ix = torch.multinomial(p, num_samples=1, replacement=True).item()
        pred_chr = idx_to_char[ix]
    
        if pred_chr == '.':
            return ''.join(generated_name[1:])
        else:
            generated_name.append(pred_chr)

In [105]:
generate_name()

'letahailairalynn'