In [42]:
import torch
from torch import nn,optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

In [43]:
# Load the dataset: one name per line, lower-cased.
# Iterating line by line and skipping blank lines avoids the empty string that
# read().split('\n') produces after a trailing newline (or for a blank line);
# an empty word would otherwise add a spurious '.' -> '.' n-gram to the counts
# and to the evaluation.
with open("data/names.txt", "r", encoding="utf-8") as f:
    words = [line.strip().lower() for line in f if line.strip()]
words[:5]

['emma', 'olivia', 'ava', 'isabella', 'sophia']

# n-gram


In [44]:
# Vocabulary: each distinct character of the corpus gets an index from 1
# upward in sorted order; index 0 is reserved for '.', the start/end marker.
chars = sorted(set(''.join(words)))
str_to_idx = {ch: pos for pos, ch in enumerate(chars, start=1)}
str_to_idx['.'] = 0
# Inverse lookup: index -> character.
idx_to_str = dict(zip(str_to_idx.values(), str_to_idx.keys()))

In [45]:
# n-gram order: every tuple produced by contextized_iter holds block_size
# characters, i.e. block_size - 1 characters of context plus the target.
block_size = 2

def contextized_iter(word, block_size=1):
    """Return the overlapping character n-grams of `word`, with boundary markers.

    The word is left-padded with block_size - 1 '.' markers, so the first
    character is predicted from an all-'.' context. It is right-padded with a
    single '.', the end-of-word token. Each item is a tuple of block_size
    characters: the context followed by the character to predict.

    Example: contextized_iter('ab', 3) yields
    ('.', '.', 'a'), ('.', 'a', 'b'), ('a', 'b', '.').

    Args:
        word: the string (or sequence of characters) to split.
        block_size: n-gram length; must be >= 1.

    Returns:
        An iterator (zip object) of block_size-tuples.

    Raises:
        ValueError: if block_size < 1.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")
    # Exactly ONE end marker. Padding the end with block_size - 1 markers (as
    # at the start) would emit extra n-grams like ('b', '.', '.') whose context
    # already contains the end token: sampling never uses them, yet they would
    # be counted in freqs and in the evaluation NLL. With block_size == 1 this
    # still emits the end token, so a unigram model can learn to stop.
    padded = ['.'] * (block_size - 1) + list(word) + ['.']
    return zip(*(padded[i:] for i in range(block_size)))

# Show the n-grams produced for one sample word.
print(*contextized_iter('leong', block_size), sep='\n')

# Count every n-gram of the corpus into a dense table with one axis of length
# n_letters per position: freqs[c_1, ..., c_block_size] is the number of times
# that index sequence occurs.
n_letters = len(str_to_idx)
counts = np.zeros((n_letters,) * block_size, dtype=np.int32)
all_ngrams = (chs for w in words for chs in contextized_iter(w, block_size))
for chs in all_ngrams:
    counts[tuple(map(str_to_idx.__getitem__, chs))] += 1
freqs = torch.tensor(counts)
print(freqs.shape)

('.', 'l')
('l', 'e')
('e', 'o')
('o', 'n')
('n', 'g')
('g', '.')
torch.Size([27, 27])


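$$
P(w_i \mid w_{i-n+1}, \ldots, w_{i-1}) =
\frac{P(w_{i-n+1}, \ldots, w_{i-1}, w_i)}{P(w_{i-n+1}, \ldots, w_{i-1})}
\approx
\frac{C(w_{i-n+1}, \ldots, w_{i-1}, w_i) + \alpha}{\sum_{w} \left[ C(w_{i-n+1}, \ldots, w_{i-1}, w) + \alpha \right]}
$$

where $n$ is `block_size`, $C(\cdot)$ counts occurrences of a character sequence in the training names, and $\alpha$ is the add-$\alpha$ smoothing constant (10 in the code below).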


In [46]:
# Turn counts into conditional probabilities P(next char | context).
# Add-k smoothing: every n-gram gets `smoothing` pseudo-counts, so no entry is
# zero (a zero would make its log-probability -inf during evaluation).
smoothing = 10
probs = (freqs + smoothing).float()
# Normalise over the last axis (the predicted character). keepdim=True keeps a
# trailing size-1 axis so the division broadcasts correctly for any block_size.
probs /= probs.sum(dim=-1, keepdim=True)
row_sums = probs.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums)), "rows must sum to 1"
probs.shape

torch.Size([27, 27])

In [47]:
# Model evaluation: negative log-likelihood (NLL) of the training names under
# the smoothed n-gram model. Lower is better; a uniform model over n_letters
# symbols would score log(n_letters) per character.
log_probs = probs.log()  # hoisted: take the log of the table once
log_likelihood = 0.0     # Python float, i.e. float64 accumulator
n = 0                    # number of predicted characters

for w in words:
    for chs in contextized_iter(w, block_size):
        idxs = tuple(str_to_idx[ch] for ch in chs)
        # .item() keeps the running sum in double precision; summing a large
        # number of terms into a float32 tensor loses several decimal digits.
        log_likelihood += log_probs[idxs].item()
        n += 1

nll = -log_likelihood
print(f'log_likelihood = {log_likelihood:.4f}')
print(f'nll = {nll:.4f}')
print(f'avg nll = {nll / n:.4f}')

log_likelihood = -561615.312500
561615.312500
2.461649


In [51]:
# Sample names from the n-gram model with torch.multinomial.
# g is created once here, so successive calls continue the same random stream;
# re-running only a sampling cell gives new names, re-running this cell resets.
g = torch.Generator().manual_seed(42)
def ngram_generate():
    """Sample one name from the n-gram model.

    Starts from an all-'.' context of length block_size - 1 and repeatedly
    draws the next character index from the distribution probs[context],
    stopping once index 0 ('.') is drawn. The returned string includes that
    trailing '.'.

    Reads the notebook globals probs, block_size, idx_to_str and g.
    """
    out = []
    context = [0] * (block_size - 1)
    while True:
        p = probs[tuple(context)]
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        # Slide the window: append the new index and drop the oldest one.
        # Written as (context + [ix])[1:] so an empty context (block_size == 1,
        # unigram model) stays empty instead of growing to length 1.
        context = (context + [ix])[1:]
        out.append(idx_to_str[ix])
        if ix == 0:
            break
    return ''.join(out)
# Draw and print five sample names, one per line.
print(*(ngram_generate() for _ in range(5)), sep='\n')

ya.
syahle.
ahe.
dleekahmangonya.
tryahe.
