In [30]:
import torch
import torch.nn.functional as F

# Load the corpus: one name per line. The context manager guarantees the
# file handle is closed as soon as the contents have been read.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

In [47]:
# Collect every two-character context and every character observed right
# after such a context. Each word is padded with '.' on both sides so the
# boundary marker takes part in the contexts and targets.
seen_contexts, seen_targets = set(), set()
for word in words:
    padded = '.' + word + '.'
    for k in range(len(padded) - 2):
        seen_contexts.add(padded[k:k + 2])
        seen_targets.add(padded[k + 2])

# Sorted lists give every symbol a stable position.
bichars = sorted(seen_contexts)
chars = sorted(seen_targets)
        

In [49]:
# Lookup tables between symbols and their positions in the sorted lists:
# index -> symbol first, then the inverse symbol -> index mapping.
itobichars = dict(enumerate(bichars))
bicharstoi = {pair: idx for idx, pair in itobichars.items()}
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}


In [50]:
# Create dataset
# Each example pairs a two-character context (as a bigram index) with the
# index of the character that follows it.
context_ids, target_ids = [], []
for word in words:
    padded = '.' + word + '.'
    for k in range(len(padded) - 2):
        context_ids.append(bicharstoi[padded[k:k + 2]])
        target_ids.append(stoi[padded[k + 2]])

xs = torch.tensor(context_ids)
ys = torch.tensor(target_ids)

num = xs.numel()  # number of elements in xs, i.e. number of examples
print(f'n of examples: {num}')

n of examples: 196113


In [51]:
# Weight matrix with one row per bigram context and one column per target
# character, drawn by torch.rand and tracked by autograd for training.
W = torch.rand(len(bichars), len(chars), requires_grad=True)

In [53]:
# One-hot encode the bigram contexts once; the encoding does not change
# between optimisation steps.
xenc = F.one_hot(xs, num_classes=len(bichars)).float()

for k in range(200):
    # Forward pass: compute the logits once and reuse them below.
    logits = xenc @ W

    # Softmax over the target characters.
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)

    # Mean negative log-likelihood of the observed next character.
    # (A regularisation term, if wanted, has to be added to `loss` here,
    # before backward(), for it to influence the gradient.)
    loss = -probs[torch.arange(num), ys].log().mean()

    print(loss.item())

    # Backward pass: clear the old gradient, then backpropagate.
    W.grad = None
    loss.backward()

    # Gradient-descent step with learning rate 100. The update runs under
    # no_grad so autograd does not record it.
    with torch.no_grad():
        W -= 100 * W.grad

2.263913631439209
2.2619376182556152
2.2600033283233643
2.2581095695495605
2.2562549114227295
2.2544379234313965
2.252657175064087
2.2509121894836426
2.249201536178589
2.2475244998931885
2.2458791732788086
2.244265556335449
2.2426822185516357
2.24112868309021
2.2396039962768555
2.2381069660186768
2.2366368770599365
2.2351937294006348
2.2337758541107178
2.2323832511901855
2.2310144901275635
2.2296695709228516
2.2283477783203125
2.227048397064209
2.2257704734802246
2.2245144844055176
2.223278522491455
2.2220633029937744
2.22086763381958
2.219691276550293
2.218533515930176
2.2173941135406494
2.2162725925445557
2.2151684761047363
2.2140815258026123
2.213010787963867
2.211956739425659
2.2109181880950928
2.209895372390747
2.208887815475464
2.207895040512085
2.206916332244873
2.2059521675109863
2.2050018310546875
2.2040648460388184
2.203141450881958
2.202230930328369
2.2013328075408936
2.2004473209381104
2.1995739936828613
2.1987125873565674
2.1978628635406494
2.1970245838165283
2.19619727134

In [71]:
def name_generator(min_len=3, generator=None):
    """Sample one name from the model stored in the global weight matrix W.

    The first letter is drawn uniformly from the bigrams that begin with the
    '.' boundary marker. Every following character is drawn from the softmax
    of the W row belonging to the current two-character context.

    Characters that cannot be used at the current step are masked out before
    sampling, which yields the same conditional distribution as resampling
    until an acceptable character appears:
      * '.' (end of name) is not allowed while the name is shorter than
        `min_len`;
      * a character whose resulting bigram is missing from `bicharstoi` is
        not allowed, because there would be no W row to continue from.

    Args:
        min_len: minimum number of letters before the name may end.
        generator: optional torch.Generator for reproducible sampling; None
            uses the global torch RNG.

    Returns:
        The generated name as a string, without '.' markers. It can be
        shorter than `min_len` only when no allowed character has non-zero
        probability.

    Raises:
        ValueError: if no known bigram starts with '.'.
    """
    # Indices of the start bigrams ('.' followed by a first letter), looked up
    # from the table instead of assuming they occupy a fixed index range.
    start_ids = sorted(i for i, s in itobichars.items() if s[0] == '.')
    if not start_ids:
        raise ValueError("no bigram starts with the '.' boundary marker")

    pick = torch.randint(0, len(start_ids), (1,), generator=generator).item()
    idbi = start_ids[pick]
    bich = itobichars[idbi]
    name = bich[-1]

    while True:
        # Next-character distribution for the current context. Sampling does
        # not need gradients.
        with torch.no_grad():
            p = F.softmax(W[idbi], dim=0)

        # Mask of characters that may be emitted at this step.
        tail = bich[-1]
        allowed = torch.tensor(
            [
                (len(name) >= min_len) if itos[i] == '.'
                else (tail + itos[i] in bicharstoi)
                for i in range(p.numel())
            ],
            device=p.device,
        )
        p = p * allowed

        # Dead end: nothing admissible carries probability mass.
        if p.sum().item() <= 0:
            return name

        # multinomial treats the masked values as unnormalised weights.
        idch = torch.multinomial(p, num_samples=1, generator=generator).item()
        ch = itos[idch]

        if ch == '.':
            return name

        name += ch

        # Slide the context window forward by exactly one character.
        bich = tail + ch
        idbi = bicharstoi[bich]


In [77]:
# Draw one name from the model and show it as a single-element list.
[name_generator()]

['jayce']

In [42]:
# Inspect which bigram is stored at index 25 of the index -> bigram table.
itobichars[25]

'.z'