In [33]:
words = open('names.txt', 'r').read().splitlines()
chars = list(set(''.join(words)))
chars.sort()
ctoi = {c: i + 1 for i, c in enumerate(chars)}
ctoi["."] = 0
itoc = {i: c for c, i in ctoi.items()}

In [34]:
import torch
context_size = 3
xs = []
ys = []

for word in words:
    context = [0] * context_size
    for char in word:
        xs.append(context)
        ys.append(ctoi[char])
        context = context[1:] + [ctoi[char]]
    
    xs.append(context)
    ys.append(0)


xs = torch.tensor(xs)
ys = torch.tensor(ys)

In [35]:
train_size = int(xs.shape[0] * 0.8)
dev_size = int(xs.shape[0] * 0.1)

train_xs = xs[:train_size]
train_ys = ys[:train_size]

dev_xs = xs[train_size:train_size + dev_size]
dev_ys = ys[train_size:train_size + dev_size]

test_xs = xs[train_size + dev_size:]
test_ys = ys[train_size + dev_size:]

In [82]:
embedding_size = 10
embedding_context = embedding_size * context_size
hidden_layer_size = 200
g = torch.Generator().manual_seed(2147483647)

C = torch.randn((len(ctoi.keys()), embedding_size), generator=g)
W1 = torch.randn((embedding_context, hidden_layer_size), generator=g) *5/3 / (embedding_context) ** 0.5
B1 = torch.randn((hidden_layer_size), generator=g)
GAMMA1 = torch.ones((1, hidden_layer_size))
BETA1 = torch.zeros((1, hidden_layer_size))

W2 = torch.randn((hidden_layer_size, len(ctoi.keys())), generator=g) / (hidden_layer_size) ** 0.5
B2 = torch.randn(len(ctoi.keys()), generator=g)
params = [C, W1, B1, GAMMA1, BETA1, W2, B2]

running_mean = torch.zeros((1, hidden_layer_size))
running_std = torch.ones((1, hidden_layer_size))

for p in params:
    p.requires_grad = True

In [96]:
import torch.nn.functional as F

batch_size = 32

for i in range(200000):
    #minibatch
    mini_batch = torch.randint(0, train_xs.shape[0], (batch_size,), generator=g)

    embed = C[train_xs[mini_batch]]
    first_layer = embed.view(-1,embedding_context) @ W1 + B1
    
    first_layer_mean = first_layer.mean(0, keepdim = True)
    first_layer_std = first_layer.std(0, keepdim = True)

    with torch.no_grad():
        running_mean = 0.9 * running_mean + 0.1 * first_layer_mean
        running_std = 0.9 * running_std + 0.1 * first_layer_std

    normalized_first_layer =  GAMMA1 * ((first_layer - first_layer_mean)/ first_layer_std) + BETA1
    firt_layer_logits = torch.tanh(normalized_first_layer)

    second_layer_logits = firt_layer_logits @ W2 + B2
    loss = F.cross_entropy(second_layer_logits, train_ys[mini_batch])

    # backward
    for p in params:
        p.grad = None

    loss.backward()

    lr = 0.1 if i < 100000 else 0.01 
    for p in params:
        p.data += -lr * p.grad


print(loss.item())

2.031841993331909


In [97]:
with torch.no_grad():
    embed = C[dev_xs]
    first_layer = torch.tanh(embed.view(-1,embedding_context) @ W1 + B1)
    normalized_first_layer =  GAMMA1 * ((first_layer - running_mean)/ running_std) + BETA1
    firt_layer_logits = torch.tanh(normalized_first_layer)

    second_layer_logits = firt_layer_logits @ W2 + B2
    loss = F.cross_entropy(second_layer_logits, dev_ys)
    print(loss.item())

3.6338751316070557


In [86]:
with torch.no_grad():
    for i in range(15):
        word = []
        idxs = [0] * context_size
        while True:
            embed = C[torch.tensor([idxs])]
            firt_layer_logits = torch.tanh(embed.view(-1,embedding_context) @ W1 + B1)
            second_layer_logits = firt_layer_logits @ W2 + B2
            probs = F.softmax(second_layer_logits, dim=1)
            ix = torch.multinomial(probs, num_samples=1).item()
            word.append(itoc[ix])
            if ix == 0:
                break
            
            idxs = idxs[1:] + [ix]
        
        print("".join(word))





zayzussffwsnnstsenassannadangssinahlangrnwynzangsczozongunsongissznassanaqukeawasynsstenanqansusirangingannstannssangsonssangolcnssnstonnyszrassfssssannssspssffannstanghjanelessanssevannassanousgzizmangstfnsnnsassonsannvssnnstanghnysanghjvunmangrnyszbastsopshfwangjnay.
zbnggonnssnuz.
zcflinazzashbwanthlvnyzmminjasnnssinonsazspswfsssqmnssonssttangaskensonnghsathlannzqunnghnnlonankstsslina.
nyfsanahjanmosefuzsangpannoksavannclnnxthsnassasajinndgtonnnjritzoustfnandclanghanzinosasszofsannonsz.
sangongrnfacsevsinossnyntwzoninnsxsanjssnnasissabbazsasufffstfannsonnigknjannanchunnghnassfsssinnansfnostinzsanszann.
zangssnanthlusslps.
zahmasssiffnnysssanjonynnsznogwaspflongwnzsosgwennstasssannglnnssanphannxsonninstrssjynnssynazviszavigwnnsssnozlissnnsznasson.
sannssannsusfzosstosfnansossvni.
kogtnysssfnasssonshsannsssannannasalfanngssfnanwanzosgwonthnnghnysanrfswnsin.
sswandanzmmourassenassslynnsssongalvnnossaspsssssponnn.
assffannstzf.
bzanggxonnnssspjovannosonjynnsssonyssanotsonsm.
krwnassajnn