In [1]:
import torch
import torch.nn.functional as F
import matplotlib
import random

%matplotlib inline

In [2]:
words = open('names.txt', 'r').read().splitlines()

random.shuffle(words)

chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}


In [3]:
context_size = 4

def build_training_data(words):
    """
    Return X, Y tuple of training data and labels given words.

    X will contain one row for each example. Each example will contain `context_size`
    elements representing character indices.

    Y will contain a character index label for each example.
    """

    xs = []
    ys = []
    for word in words:
        context = [0] * context_size
        for ch in word:
            ich = stoi[ch]
            xs.append(context)
            ys.append(ich)
            context = context[1:] + [ich]
        xs.append(context)
        ys.append(0)
    assert len(xs) == len(ys)
    X = torch.tensor(xs)
    Y = torch.tensor(ys)
    return X, Y

train_cutoff = int(0.8 * len(words))
dev_cutoff = int(0.9 * len(words))

X_train, Y_train = build_training_data(words[:train_cutoff])
X_dev, Y_dev = build_training_data(words[train_cutoff:dev_cutoff])
X_test, Y_test = build_training_data(words[dev_cutoff:])

print("training examples:", len(Y_train))
print("dev examples:", len(Y_dev))
print("test examples:", len(Y_test))



training examples: 182437
dev examples: 22819
test examples: 22890


In [4]:

# Number of characters in our alphabet (the very first one is the terminator character).
vocab_size = 27

# Number of dimensions in vector space that we map each character to.
embedding_dims = 3

# The length of a context as a "flattened" array of each of its character's embeddings.
embedded_context_dims = context_size * embedding_dims

g = torch.Generator().manual_seed(2147483647)

# Matrix containing a "lookup table" from character indices to their embeddings in the vector space.
C = torch.randn((vocab_size, embedding_dims), dtype=torch.float, generator=g)

# Number of neurons in the hidden layer
w1_neurons = 200

# Hidden tanh layer
W1 = torch.randn((embedded_context_dims, w1_neurons), dtype=torch.float, generator=g)

b1 = torch.randn(w1_neurons, dtype=torch.float, generator=g)

# Final softmax layer, scaled by 0.1 to make the initial weights be as similar to
# each other as possible to start, thus ultimately giving each character an equal
# probability, which results in a much better initial loss (described in beginning
# of lecture 4).
W2 = torch.randn((w1_neurons, vocab_size), dtype=torch.float, generator=g) * 0.1

# Initialize softmax biases to 0 so every character has equal probability (see above).
b2 = torch.randn(vocab_size, dtype=torch.float, generator=g) * 0

params = [C, W1, b1, W2, b2]

for param in params:
    param.requires_grad = True

print("Total parameters:", sum([param.numel() for param in params]))


Total parameters: 8108


In [5]:
def forward(X):
    num_examples = X.shape[0]

    # Each row is an example consisting of a "flattened" tensor of each character in the context.
    CX = C[X].view(num_examples, embedded_context_dims)

    # Make sure the very first example's first context item is the terminator character.
    # Commenting this out b/c we want this code to be used for more than just training!
    #terminator = C[0]
    #assert CX[0][:embedding_dims].tolist() == terminator.tolist()

    CXW1 = torch.tanh(CX @ W1 + b1)

    assert list(CXW1.shape) == [num_examples, w1_neurons]

    logits = CXW1 @ W2 + b2

    assert list(logits.shape) == [num_examples, vocab_size]

    # TODO: Use torch's softmax here to improve efficiency.
    fake_counts = logits.exp()

    probs = fake_counts / torch.sum(fake_counts, dim=1, keepdim=True)

    # Ensure the probabilities of all characters in the first example sum to approximately 1.0.
    assert probs[0].sum() - 1.0 < 0.000001

    return probs

def calc_loss(probs, Y):
    num_examples = probs.shape[0]
    assert num_examples == Y.shape[0]

    # TODO: Use torch's cross-entropy loss here to improve efficiency.
    loss = -probs[range(num_examples), Y].log().mean()

    return loss

def train(rounds, learning_rate, minibatch_size=500, X=X_train, Y=Y_train):
    num_examples = X.shape[0]

    for i in range(rounds):
        minibatch_indexes = torch.randint(0, num_examples, (minibatch_size,))
        minibatch = X[minibatch_indexes]

        probs = forward(minibatch)

        loss = calc_loss(probs, Y[minibatch_indexes])
        print("minibatch loss: ", loss.item())

        for param in params:
            param.grad = None
        
        loss.backward()

        for param in params:
            param.data += -learning_rate * param.grad

train(rounds=10_000, learning_rate=0.1)
train(rounds=10_000, learning_rate=0.01)

print("Final loss over training set:", calc_loss(forward(X_train), Y_train).item())
print("Final loss over dev set:", calc_loss(forward(X_dev), Y_dev).item())


minibatch loss:  3.9783127307891846
minibatch loss:  3.8291420936584473
minibatch loss:  3.799957036972046
minibatch loss:  3.738525867462158
minibatch loss:  3.523550271987915
minibatch loss:  3.4754014015197754
minibatch loss:  3.451584577560425
minibatch loss:  3.336010456085205
minibatch loss:  3.3330812454223633
minibatch loss:  3.2760818004608154
minibatch loss:  3.3099379539489746
minibatch loss:  3.1809022426605225
minibatch loss:  3.1116833686828613
minibatch loss:  3.08756685256958
minibatch loss:  3.128166437149048
minibatch loss:  3.126695394515991
minibatch loss:  3.1019556522369385
minibatch loss:  3.127678155899048
minibatch loss:  3.107304096221924
minibatch loss:  2.9658963680267334
minibatch loss:  2.9726381301879883
minibatch loss:  3.0871078968048096
minibatch loss:  2.936527967453003
minibatch loss:  2.98783802986145
minibatch loss:  3.078272819519043
minibatch loss:  2.904601573944092
minibatch loss:  2.942866802215576
minibatch loss:  2.9598991870880127
minibatch

In [6]:
def predict(context_str='', num_chars=1000, stop_on_terminator=True, greedy=False):
    """
    Given an optional starting context, predicts next character(s) in the sequence.
    """

    while num_chars > 0:
        context = ([0] * context_size + [stoi[ch] for ch in context_str])[-context_size:]
        X = torch.tensor([context])
        probs = forward(X)
        if greedy:
            next_idx = probs[0].argmax().item()
        else:
            next_idx = torch.multinomial(probs[0], 1, replacement=True).item()
        if next_idx == 0 and stop_on_terminator:
            break
        context_str = context_str + itos[next_idx]
        num_chars -= 1
    return context_str

for _ in range(10):
    print(predict(''))


adricleu
calie
lashrii
elian
vonlia
xael
kathulmy
lileigh
koie
braan


In [7]:
for _ in range(10): print(predict('atu'))

aturiavabe
atur
atuideeterianna
aturi
atunbnn
atua
atua
atush
atur
atullinn
