In [1]:
import torch
import torch.nn.functional as F
%matplotlib inline

In [2]:
words = open("names.txt", "r").read().split()

In [3]:
min(len(word) for word in words)

2

In [4]:
special_token = "."
alphabet = [special_token] + sorted(list(set("".join(words))))
itos = dict((i, s) for i, s in enumerate(alphabet))  # index to string
stoi = dict((s, i) for i, s in itos.items())  # string to index

In [5]:
num_grams = 3

shape = tuple(len(alphabet) for _ in range(num_grams))
N = torch.zeros(shape, dtype=torch.int32)

for word in words:
    processed_word = [special_token, special_token] + list(word) + [special_token]
    zipped_word = zip(*[processed_word[i:] for i in range(num_grams)])
    for chars in zipped_word:
        N[tuple(map(lambda x: stoi[x], chars))] += 1

In [6]:
P = (N + 1).float()
P /= P.sum(dim=2, keepdim=True)

In [7]:
P[0, 0]

tensor([3.1192e-05, 1.3759e-01, 4.0767e-02, 4.8129e-02, 5.2745e-02, 4.7785e-02,
        1.3038e-02, 2.0898e-02, 2.7293e-02, 1.8465e-02, 7.5577e-02, 9.2452e-02,
        4.9064e-02, 7.9195e-02, 3.5777e-02, 1.2321e-02, 1.6095e-02, 2.9008e-03,
        5.1154e-02, 6.4130e-02, 4.0830e-02, 2.4641e-03, 1.1759e-02, 9.6070e-03,
        4.2109e-03, 1.6719e-02, 2.9008e-02])

In [8]:
g = torch.Generator().manual_seed(0)

for i in range(10):
    ix1 = 0
    ix2 = 0
    name_chars = []
    while True:
        ix3 = torch.multinomial(P[ix1, ix2], 1, replacement=True, generator=g).item()
        if ix3 == 0:
            break
        name_chars.append(itos[ix3])
        ix1 = ix2
        ix2 = ix3
    print("".join(name_chars))

brona
kercur
jer
li
lum
advtqjyouna
suha
mer
trendoustaytel
zade


In [9]:
loss = .0
n = 0
for word in words:
    processed_word = [special_token] + list(word) + [special_token]
    zipped_word = zip(*[processed_word[i:] for i in range(num_grams)])
    for chars in zipped_word:
        loss += torch.log(P[tuple(map(lambda x: stoi[x], chars))]).item()
        n += 1
print(f"Loss: {-loss / n:.4f}")

Loss: 2.0931


# Gradient descent approach

In [10]:
xs = []
ys = []

for word in words:
    processed_word = [special_token] + list(word) + [special_token]
    zipped_word = zip(*[processed_word[i:] for i in range(num_grams)])
    for chars in zipped_word:
        xs.append(list(map(lambda x: stoi[x], chars[:-1])))
        ys.append(stoi[chars[-1]])
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num_examples = xs.nelement()
xenc = F.one_hot(xs, num_classes=len(stoi)).float().flatten(1, 2)

g = torch.Generator().manual_seed(0)
W = torch.randn((len(stoi) * (num_grams - 1), len(stoi)), generator=g, requires_grad=True)

In [11]:
regularization_param = 0.01
alpha = 50
epochs = 1000

for i in range(epochs):
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(dim=1, keepdims=True)
    loss = -probs[torch.arange(len(probs)), ys].log().mean() + regularization_param * (W ** 2).mean()
    print(f"Loss: {loss:.4f}")
    W.grad = None
    loss.backward()
    W.data += -alpha * W.grad

Loss: 4.3783
Loss: 3.5119
Loss: 3.1645
Loss: 2.9630
Loss: 2.8306
Loss: 2.7412
Loss: 2.6765
Loss: 2.6273
Loss: 2.5881
Loss: 2.5562
Loss: 2.5296
Loss: 2.5070
Loss: 2.4876
Loss: 2.4708
Loss: 2.4560
Loss: 2.4430
Loss: 2.4314
Loss: 2.4210
Loss: 2.4117
Loss: 2.4032
Loss: 2.3955
Loss: 2.3885
Loss: 2.3821
Loss: 2.3762
Loss: 2.3707
Loss: 2.3656
Loss: 2.3609
Loss: 2.3565
Loss: 2.3525
Loss: 2.3487
Loss: 2.3451
Loss: 2.3417
Loss: 2.3386
Loss: 2.3356
Loss: 2.3328
Loss: 2.3302
Loss: 2.3277
Loss: 2.3253
Loss: 2.3231
Loss: 2.3209
Loss: 2.3189
Loss: 2.3170
Loss: 2.3152
Loss: 2.3134
Loss: 2.3117
Loss: 2.3101
Loss: 2.3086
Loss: 2.3072
Loss: 2.3058
Loss: 2.3044
Loss: 2.3032
Loss: 2.3019
Loss: 2.3007
Loss: 2.2996
Loss: 2.2985
Loss: 2.2975
Loss: 2.2964
Loss: 2.2955
Loss: 2.2945
Loss: 2.2936
Loss: 2.2927
Loss: 2.2919
Loss: 2.2911
Loss: 2.2903
Loss: 2.2895
Loss: 2.2888
Loss: 2.2881
Loss: 2.2874
Loss: 2.2867
Loss: 2.2861
Loss: 2.2854
Loss: 2.2848
Loss: 2.2842
Loss: 2.2837
Loss: 2.2831
Loss: 2.2825
Loss: 2.2820

In [12]:
# sample
g = torch.Generator().manual_seed(0)

for i in range(10):
    out = []
    ix1 = 0
    ix2 = 0
    while True:
        xenc = F.one_hot(torch.tensor([[ix1, ix2]]), num_classes=len(stoi)).float().flatten(1, 2)
        logits = xenc @ W
        # softmax
        counts = logits.exp()
        probs = counts / counts.sum(dim=1, keepdims=True)
        ix3 = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()
        if ix3 == 0:
            break
        out.append(itos[ix3])
        ix1 = ix2
        ix2 = ix3
    print(''.join(out))

tensor([[-4.0416,  2.5324, -0.2432, -0.2364,  0.0305,  2.4533, -1.2079, -1.3379,
          1.0737,  1.3821, -0.7340, -1.0008,  0.4592,  1.5101, -0.8770,  0.7643,
         -1.2371, -2.4235,  1.2894,  0.2276, -0.0208,  2.1168,  0.4154, -1.4197,
         -0.7245,  1.3680,  1.0385]], grad_fn=<MmBackward0>)
bhria
tensor([[-4.0416,  2.5324, -0.2432, -0.2364,  0.0305,  2.4533, -1.2079, -1.3379,
          1.0737,  1.3821, -0.7340, -1.0008,  0.4592,  1.5101, -0.8770,  0.7643,
         -1.2371, -2.4235,  1.2894,  0.2276, -0.0208,  2.1168,  0.4154, -1.4197,
         -0.7245,  1.3680,  1.0385]], grad_fn=<MmBackward0>)
evackey
tensor([[-4.0416,  2.5324, -0.2432, -0.2364,  0.0305,  2.4533, -1.2079, -1.3379,
          1.0737,  1.3821, -0.7340, -1.0008,  0.4592,  1.5101, -0.8770,  0.7643,
         -1.2371, -2.4235,  1.2894,  0.2276, -0.0208,  2.1168,  0.4154, -1.4197,
         -0.7245,  1.3680,  1.0385]], grad_fn=<MmBackward0>)
ren
tensor([[-4.0416,  2.5324, -0.2432, -0.2364,  0.0305,  2.4533, -1.2079

In [13]:
loss.item()

2.252534866333008