In [1]:
import torch
import torch.nn.functional as F
import matplotlib

%matplotlib inline

In [2]:
words = open('names.txt', 'r').read().splitlines()

chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}


In [3]:
context_size = 3

def build_xy(words):
    """
    Return X, Y tuple of training data and labels given words.

    X will contain one row for each example. Each example will contain `context_size`
    elements representing character indices.

    Y will contain a character index label for each example.
    """

    xs = []
    ys = []
    for word in words:
        context = [0] * context_size
        for ch in word:
            ich = stoi[ch]
            xs.append(context)
            ys.append(ich)
            context = context[1:] + [ich]
        xs.append(context)
        ys.append(0)
    assert len(xs) == len(ys)
    X = torch.tensor(xs)
    Y = torch.tensor(ys)
    return X, Y

X, Y = build_xy(words[:3])
print(X)
print(Y)

tensor([[ 0,  0,  0],
        [ 0,  0,  5],
        [ 0,  5, 13],
        [ 5, 13, 13],
        [13, 13,  1],
        [ 0,  0,  0],
        [ 0,  0, 15],
        [ 0, 15, 12],
        [15, 12,  9],
        [12,  9, 22],
        [ 9, 22,  9],
        [22,  9,  1],
        [ 0,  0,  0],
        [ 0,  0,  1],
        [ 0,  1, 22],
        [ 1, 22,  1]])
tensor([ 5, 13, 13,  1,  0, 15, 12,  9, 22,  9,  1,  0,  1, 22,  1,  0])


In [4]:

num_examples = Y.shape[0]

# Number of characters in our alphabet (the very first one is the terminator character).
vocab_size = 27

# Number of dimensions in vector space that we map each character to.
embedding_dims = 2

# The length of a context as a "flattened" array of each of its character's embeddings.
embedded_context_dims = context_size * embedding_dims

g = torch.Generator().manual_seed(2147483647)

# Matrix containing a "lookup table" from character indices to their embeddings in the vector space.
C = torch.randn((vocab_size, embedding_dims), dtype=torch.float, generator=g)

# Number of neurons in the hidden layer
w1_neurons = 10

# Hidden tanh layer
W1 = torch.randn((embedded_context_dims, w1_neurons), dtype=torch.float, generator=g)

# Final softmax layer
W2 = torch.randn((w1_neurons, vocab_size), dtype=torch.float, generator=g)

params = [C, W1, W2]

for param in params:
    param.requires_grad = True



In [5]:
num_training_rounds = 100

learning_rate = 0.1

def forward(X):
    num_examples = X.shape[0]

    # Each row is an example consisting of a "flattened" tensor of each character in the context.
    CX = C[X].view(num_examples, embedded_context_dims)

    # Make sure the very first example's first context item is the terminator character.
    # Commenting this out b/c we want this code to be used for more than just training!
    #terminator = C[0]
    #assert CX[0][:embedding_dims].tolist() == terminator.tolist()

    CXW1 = torch.tanh(CX @ W1)

    assert list(CXW1.shape) == [num_examples, w1_neurons]

    logits = CXW1 @ W2

    assert list(logits.shape) == [num_examples, vocab_size]

    # TODO: Use torch's softmax here to improve efficiency.
    fake_counts = logits.exp()

    probs = fake_counts / torch.sum(fake_counts, dim=1, keepdim=True)

    # Ensure the probabilities of all characters in the first example sum to approximately 1.0.
    assert probs[0].sum() - 1.0 < 0.000001

    return probs

for i in range(num_training_rounds):
    probs = forward(X)

    # TODO: Use torch's cross-entropy loss here to improve efficiency.
    loss = -probs[range(num_examples), Y].log().mean()

    print("LOSS: ", loss.item())

    for param in params:
        param.grad = None
    
    loss.backward()

    for param in params:
        param.data += -learning_rate * param.grad



LOSS:  6.117327690124512
LOSS:  5.758670330047607
LOSS:  5.432277202606201
LOSS:  5.146899700164795
LOSS:  4.900434970855713
LOSS:  4.682767391204834
LOSS:  4.48641300201416
LOSS:  4.306982517242432
LOSS:  4.140802383422852
LOSS:  3.9846444129943848
LOSS:  3.8361637592315674
LOSS:  3.6939735412597656
LOSS:  3.557457447052002
LOSS:  3.4264895915985107
LOSS:  3.301115036010742
LOSS:  3.1812713146209717
LOSS:  3.066671371459961
LOSS:  2.9568614959716797
LOSS:  2.8513476848602295
LOSS:  2.7496840953826904
LOSS:  2.65151309967041
LOSS:  2.5565662384033203
LOSS:  2.4646553993225098
LOSS:  2.3756566047668457
LOSS:  2.289499044418335
LOSS:  2.206149101257324
LOSS:  2.125601291656494
LOSS:  2.0478672981262207
LOSS:  1.972965955734253
LOSS:  1.9009182453155518
LOSS:  1.831740140914917
LOSS:  1.7654387950897217
LOSS:  1.7020087242126465
LOSS:  1.641431212425232
LOSS:  1.5836721658706665
LOSS:  1.5286822319030762
LOSS:  1.4763970375061035
LOSS:  1.4267381429672241
LOSS:  1.3796145915985107
LOSS:  

In [6]:
def predict_greedy(context_str='', num_chars=1):
    """
    Given a starting context, returns the context with the most likely next characters.
    """

    context = ([0] * context_size + [stoi[ch] for ch in context_str])[-context_size:]
    X = torch.tensor([context])
    probs = forward(X)
    new_context_str = context_str + itos[probs[0].argmax().item()]
    num_chars -= 1
    if num_chars == 0:
        return new_context_str
    return predict_greedy(new_context_str, num_chars)

# This is funny
predict_greedy('e', 100)


'emma...ava...ava...ava...ava...ava...ava...ava...ava...ava...ava...ava...ava...ava...ava...ava...ava.'