In [1]:
import torch
import torch.nn.functional as F
import matplotlib

%matplotlib inline

In [2]:
words = open('names.txt', 'r').read().splitlines()

chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}


In [3]:
context_size = 3

def build_training_data(words):
    """
    Return X, Y tuple of training data and labels given words.

    X will contain one row for each example. Each example will contain `context_size`
    elements representing character indices.

    Y will contain a character index label for each example.
    """

    xs = []
    ys = []
    for word in words:
        context = [0] * context_size
        for ch in word:
            ich = stoi[ch]
            xs.append(context)
            ys.append(ich)
            context = context[1:] + [ich]
        xs.append(context)
        ys.append(0)
    assert len(xs) == len(ys)
    X = torch.tensor(xs)
    Y = torch.tensor(ys)
    return X, Y

X, Y = build_training_data(words)
print(X)
print(Y)

tensor([[ 0,  0,  0],
        [ 0,  0,  5],
        [ 0,  5, 13],
        ...,
        [26, 26, 25],
        [26, 25, 26],
        [25, 26, 24]])
tensor([ 5, 13, 13,  ..., 26, 24,  0])


In [4]:

# Number of characters in our alphabet (the very first one is the terminator character).
vocab_size = 27

# Number of dimensions in vector space that we map each character to.
embedding_dims = 2

# The length of a context as a "flattened" array of each of its character's embeddings.
embedded_context_dims = context_size * embedding_dims

g = torch.Generator().manual_seed(2147483647)

# Matrix containing a "lookup table" from character indices to their embeddings in the vector space.
C = torch.randn((vocab_size, embedding_dims), dtype=torch.float, generator=g)

# Number of neurons in the hidden layer
w1_neurons = 100

# Hidden tanh layer
W1 = torch.randn((embedded_context_dims, w1_neurons), dtype=torch.float, generator=g)

# Final softmax layer
W2 = torch.randn((w1_neurons, vocab_size), dtype=torch.float, generator=g)

params = [C, W1, W2]

for param in params:
    param.requires_grad = True



In [9]:
def forward(X):
    num_examples = X.shape[0]

    # Each row is an example consisting of a "flattened" tensor of each character in the context.
    CX = C[X].view(num_examples, embedded_context_dims)

    # Make sure the very first example's first context item is the terminator character.
    # Commenting this out b/c we want this code to be used for more than just training!
    #terminator = C[0]
    #assert CX[0][:embedding_dims].tolist() == terminator.tolist()

    CXW1 = torch.tanh(CX @ W1)

    assert list(CXW1.shape) == [num_examples, w1_neurons]

    logits = CXW1 @ W2

    assert list(logits.shape) == [num_examples, vocab_size]

    # TODO: Use torch's softmax here to improve efficiency.
    fake_counts = logits.exp()

    probs = fake_counts / torch.sum(fake_counts, dim=1, keepdim=True)

    # Ensure the probabilities of all characters in the first example sum to approximately 1.0.
    assert probs[0].sum() - 1.0 < 0.000001

    return probs

def calc_loss(probs, Y):
    num_examples = probs.shape[0]
    assert num_examples == Y.shape[0]

    # TODO: Use torch's cross-entropy loss here to improve efficiency.
    loss = -probs[range(num_examples), Y].log().mean()

    return loss

def train(rounds=10000, learning_rate=0.1, minibatch_size=100):
    num_examples = X.shape[0]

    for i in range(rounds):
        minibatch_indexes = torch.randint(0, num_examples, (minibatch_size,))
        minibatch = X[minibatch_indexes]

        probs = forward(minibatch)

        loss = calc_loss(probs, Y[minibatch_indexes])
        print("minibatch loss: ", loss.item())

        for param in params:
            param.grad = None
        
        loss.backward()

        for param in params:
            param.data += -learning_rate * param.grad

train(rounds=10_000, learning_rate=0.1)
train(rounds=1_000, learning_rate=0.01)

print("Final loss over entire training set:", calc_loss(forward(X), Y).item())


minibatch loss:  2.323869228363037
minibatch loss:  2.230686664581299
minibatch loss:  2.353245735168457
minibatch loss:  2.4613289833068848
minibatch loss:  2.369802951812744
minibatch loss:  2.1387622356414795
minibatch loss:  2.3100905418395996
minibatch loss:  2.519531488418579
minibatch loss:  2.588667392730713
minibatch loss:  2.2401325702667236
minibatch loss:  2.2943525314331055
minibatch loss:  2.5229642391204834
minibatch loss:  2.509636640548706
minibatch loss:  2.3954758644104004
minibatch loss:  2.5106048583984375
minibatch loss:  2.3783750534057617
minibatch loss:  2.4074313640594482
minibatch loss:  2.31145977973938
minibatch loss:  2.3644659519195557
minibatch loss:  2.3104021549224854
minibatch loss:  2.1965348720550537
minibatch loss:  2.600968599319458
minibatch loss:  2.171712875366211
minibatch loss:  2.3615829944610596
minibatch loss:  2.609135389328003
minibatch loss:  2.1594557762145996
minibatch loss:  2.434515953063965
minibatch loss:  2.2978148460388184
minib

In [21]:
def predict(context_str='', num_chars=1000, stop_on_terminator=True, greedy=False):
    """
    Given an optional starting context, predicts next character(s) in the sequence.
    """

    while num_chars > 0:
        context = ([0] * context_size + [stoi[ch] for ch in context_str])[-context_size:]
        X = torch.tensor([context])
        probs = forward(X)
        if greedy:
            next_idx = probs[0].argmax().item()
        else:
            next_idx = torch.multinomial(probs[0], 1, replacement=True).item()
        if next_idx == 0 and stop_on_terminator:
            break
        context_str = context_str + itos[next_idx]
        num_chars -= 1
    return context_str

for _ in range(10):
    print(predict(''))


jaidannonca
ela
amulishere
mavena
likah
khekt
piaxin
sri
brocan
jamanna
