In [1]:
import torch
import torch.nn.functional as F 
import matplotlib.pyplot as pyplot
%matplotlib inline

In [60]:
# Load the word list, one word per line. The `with` block closes the file handle
# deterministically, and the explicit encoding avoids depending on the platform's
# default locale encoding.
with open("words_alpha.txt", "r", encoding="utf-8") as f:
    words = f.read().splitlines()

words[1:10]

['aa', 'aaa', 'aah', 'aahed', 'aahing', 'aahs', 'aal', 'aalii', 'aaliis']

In [61]:
# Corpus size: number of words (lines) loaded from words_alpha.txt.
len(words)

370104

In [62]:
# lookup tables between characters and integer token ids.
# Id 0 is reserved for '.', which pads the start of a context and marks the end of a
# word; the real characters get ids 1..len(characters) in sorted order.

characters = sorted(set("".join(words)))
# '.' must not occur in the corpus: otherwise stoi["."] = 0 below would silently
# overwrite its id and leave a gap in itos.
assert "." not in characters, "'.' is reserved as the start/end token"
stoi = {s: i + 1 for i, s in enumerate(characters)}
stoi["."] = 0
itos = {i: s for s, i in stoi.items()}

In [63]:
# building the dataset: every character of every word, plus a terminating '.',
# yields one example whose input is the indices of the `block_size` preceding
# characters and whose target is the index of that character.

block_size = 3  # how many chars are used to predict the next character
max_words = 3   # prototype on a few words; set to None to use the whole corpus
X, y = [], []

for w in words[:max_words]:
    context = [0] * block_size  # pad the start of each word with '.' (index 0)
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        y.append(ix)
        context = context[1:] + [ix]  # slide the window: drop the oldest index, append the new one

# Explicit integer dtype and a fixed 2-D shape, so X can index an embedding table even
# when no examples were produced (torch.tensor([]) would otherwise be 1-D float).
X = torch.tensor(X, dtype=torch.long).view(-1, block_size)
y = torch.tensor(y, dtype=torch.long)

tensor([[[ 1.5190, -0.7319],
         [ 1.5190, -0.7319],
         [ 1.5190, -0.7319]],

        [[ 1.5190, -0.7319],
         [ 1.5190, -0.7319],
         [ 1.6754,  0.8582]],

        [[ 1.5190, -0.7319],
         [ 1.5190, -0.7319],
         [ 1.5190, -0.7319]],

        [[ 1.5190, -0.7319],
         [ 1.5190, -0.7319],
         [ 1.6754,  0.8582]],

        [[ 1.5190, -0.7319],
         [ 1.6754,  0.8582],
         [ 1.6754,  0.8582]],

        [[ 1.5190, -0.7319],
         [ 1.5190, -0.7319],
         [ 1.5190, -0.7319]],

        [[ 1.5190, -0.7319],
         [ 1.5190, -0.7319],
         [ 1.6754,  0.8582]],

        [[ 1.5190, -0.7319],
         [ 1.6754,  0.8582],
         [ 1.6754,  0.8582]],

        [[ 1.6754,  0.8582],
         [ 1.6754,  0.8582],
         [ 1.6754,  0.8582]]])

In [77]:
# Forward pass of a small MLP language model: embed each context character,
# concatenate the embeddings, apply one tanh hidden layer and score every token.
g = torch.Generator().manual_seed(2)  # dedicated generator so the initialisation is reproducible

vocab_size = len(stoi)    # number of tokens, including the '.' start/end token
context_len = X.shape[1]  # characters of context per example
emb_dim = 2               # each character is embedded into a 2-dim space
n_hidden = 100            # width of the hidden layer

# character embedding table: one emb_dim-vector per token
C = torch.randn((vocab_size, emb_dim), generator = g)

# hidden layer 1 consumes the concatenated context embeddings
W1 = torch.randn((context_len * emb_dim, n_hidden), generator = g)
b1 = torch.randn(n_hidden, generator = g)

# output layer: one logit per token
W2 = torch.randn((n_hidden, vocab_size), generator = g)
b2 = torch.randn(vocab_size, generator = g)

parameters = [C, W1, b1, W2, b2]

emb = C[X]  # (num_examples, context_len, emb_dim): embedding lookup by integer index
# Flatten each example's context embeddings into one vector. Equivalent but copying
# alternatives: torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], 1) or
# torch.cat(torch.unbind(emb, 1), 1); .view reuses the existing storage instead.
h = torch.tanh(emb.view(emb.shape[0], context_len * emb_dim) @ W1 + b1)
logits = h @ W2 + b2
# F.cross_entropy computes the same mean negative log-likelihood as
# exp -> normalise -> -log(prob of target) -> mean, directly from the logits,
# more efficiently and with better numerical stability.
loss = F.cross_entropy(logits, y)

loss

tensor(15.4757)