In [29]:
import torch
from torch.nn.functional import one_hot
import numpy as np 
import matplotlib.pyplot as plt

In [30]:
# Load the word list (one word per line). The `with` block closes the file as
# soon as it has been read. An explicit encoding makes the result independent of
# the OS locale. Stripping "\r\n" (not just "\n") keeps stray carriage returns
# from CRLF files out of the vocabulary.
with open("./words.txt", "r", encoding="utf-8") as f:
    words = [line.rstrip("\r\n") for line in f]
words[:3]

['a', 'abaca', 'abache']

In [31]:
# Vocabulary: every distinct character in the corpus, plus one boundary token.
# A single token marks both the start and the end of a word, so the model needs
# only one extra row/column instead of separate <start> and <end> tokens.
end_start = "."
alphabet = sorted(set("".join(words)))
# Index 0 is reserved for the boundary token. If the token also occurred in the
# corpus, its character index would be silently overwritten below.
assert end_start not in alphabet, f"boundary token {end_start!r} occurs in the word list"
sti = {s: i + 1 for i, s in enumerate(alphabet)}  # string -> index, characters start at 1
sti[end_start] = 0
its = {v: k for k, v in sti.items()}  # index -> string

In [32]:
# creating dataset: pad every word with the boundary token on both sides and
# turn each adjacent (current char, next char) pair into one training example,
# stored as vocabulary indices.
bigram_ids = [
    (sti[cur], sti[nxt])
    for word in words
    for padded in [[end_start, *word, end_start]]
    for cur, nxt in zip(padded, padded[1:])
]
xs = torch.tensor([cur for cur, _ in bigram_ids])  # inputs: current character
ys = torch.tensor([nxt for _, nxt in bigram_ids])  # targets: next character
num = xs.numel()  # number of training examples

In [33]:
# initialize model: a single square weight matrix whose row i holds the logits
# for the character that follows input index i. The values are drawn uniformly
# from [0, 1) with a fixed seed so runs are reproducible.
# NOTE(review): 27 assumes words.txt has exactly 26 distinct characters plus the boundary token — confirm.
g = torch.Generator()
g.manual_seed(42)
W = torch.rand(27, 27, generator=g, requires_grad=True)

In [34]:
# training: full-batch gradient descent on the mean negative log-likelihood of
# the next character given the current one.
learning_rate = 100
# Weight of an optional L2 penalty mean(W**2), e.g. 0.001. It pulls W toward 0,
# i.e. toward uniform next-character probabilities. 0.0 disables it.
reg_strength = 0.0
n_steps = 100

# The one-hot inputs never change, so encode them once instead of every step.
xenc = one_hot(xs, num_classes=W.shape[0]).float()
for step in range(n_steps):
    # forward pass: log_softmax equals log(exp(x) / exp(x).sum(dim=1)) but is
    # computed stably (no exp overflow, no log(0) for underflowed probabilities).
    logits = xenc @ W
    log_probs = logits.log_softmax(dim=1)
    nll = -log_probs[torch.arange(num), ys].mean()
    loss = nll + reg_strength * (W ** 2).mean()
    if step % 10 == 0 or step == n_steps - 1:
        print(f"step {step:3d}: loss {loss.item():.4f}")

    # backward
    W.grad = None
    loss.backward()

    # update: modify the leaf tensor in place without recording it in autograd
    with torch.no_grad():
        W -= learning_rate * W.grad
    

3.3548223972320557
2.8158023357391357
2.634100914001465
2.5466718673706055
2.4926984310150146
2.4561009407043457
2.4294817447662354
2.4094648361206055
2.3937582969665527
2.381303071975708
2.3710341453552246
2.362630605697632
2.3554279804229736
2.3494296073913574
2.3441085815429688
2.3396549224853516
2.335564374923706
2.332170009613037
2.328925132751465
2.3262979984283447
2.3236560821533203
2.321613311767578
2.3194122314453125
2.3178348541259766
2.3159594535827637
2.3147668838500977
2.3131279945373535
2.3122618198394775
2.3107857704162598
2.3101983070373535
2.3088219165802
2.30846905708313
2.3071389198303223
2.306980848312378
2.3056585788726807
2.3056585788726807
2.3043227195739746
2.3044521808624268
2.3030970096588135
2.303332567214966
2.301962375640869
2.302288055419922
2.300910234451294
2.301314353942871
2.299936532974243
2.300410032272339
2.299037456512451
2.2995738983154297
2.2982089519500732
2.2988028526306152
2.2974464893341064
2.2980940341949463
2.2967453002929688
2.297442674636

In [41]:
# generating text: sample characters one at a time from the trained bigram
# model, starting from the boundary token and stopping when it is drawn again.
g = torch.Generator().manual_seed(32)
with torch.no_grad():  # inference only: don't build an autograd graph through W
    for _ in range(5):  # generate 5 words
        out = []
        ix = 0  # index 0 is the boundary token (see vocabulary cell)
        while True:
            xenc = one_hot(torch.tensor([ix]), num_classes=W.shape[0]).float()
            logits = xenc @ W
            counts = logits.exp()
            # Normalize by the row SUM (as in training) so p is a real distribution.
            # Dividing by the mean gave weights summing to the vocab size instead.
            p = counts / counts.sum(keepdim=True, dim=1)
            ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
            out.append(its[ix])
            if ix == 0:
                break
        print("".join(out))

li.
sto.
rerasto.
anantiaghe.
cccatoscarova.
