In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline


In [2]:
# Create Dataset
# Read one name per line. The context manager guarantees the file handle is
# closed as soon as the contents are loaded instead of lingering until GC.
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()
print(words[:10])

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn']


In [3]:
# Tokenizer - map every character to an integer id the network can consume

chars = sorted(set("".join(words)))

# Characters get ids 1..len(chars) in sorted order; id 0 is reserved for ".".
# "." is inserted last so the displayed dict keeps the same ordering.
char_to_int = {symbol: index for index, symbol in enumerate(chars, start=1)}
char_to_int["."] = 0

# Inverse lookup: the position in this list is the integer id.
int_to_chars = ["."] + chars
char_to_int

{'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 '.': 0}

In [20]:
# Prepare Dataset
# Every example pairs a window of three characters (input) with the character
# that follows it (target), both stored as integer ids.
xs, ys = [], []

for w in words[:1]:  # only the first word is used in this cell

    window = [".", ".", "."]  # leading "." padding before the first letter
    for target in list(w) + ["."]:  # the trailing "." is the final target
        xs.append([char_to_int[symbol] for symbol in window])
        ys.append(char_to_int[target])

        print("{}{}{} ----> {}".format(*window, target))

        # Slide the window one character to the right.
        window = window[1:] + [target]

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num_examples = len(xs)
print("Num of examples: ", num_examples)

... ----> e
..e ----> m
.em ----> m
emm ----> a
mma ----> .
Num of examples:  5


In [23]:
# Hyperparameters
vocab_size = 27       # number of distinct token ids
context_length = 3    # token ids per input example
embedding_size = 2    # dimensions of each learned embedding vector
W1_size = 100         # width of the hidden layer

# Dedicated generator so the parameter initialisation is reproducible.
g = torch.Generator().manual_seed(2147483647)

# Create NN Parameters
# Every tensor is sampled from a standard normal distribution (mean 0, std 1).
# The draw order below determines the values, so keep it stable.
C = torch.randn(vocab_size, embedding_size, generator=g)  # embedding lookup table
W1 = torch.randn(context_length * embedding_size, W1_size, generator=g)
b1 = torch.randn(W1_size, generator=g)
W2 = torch.randn(W1_size, vocab_size, generator=g)
b2 = torch.randn(vocab_size, generator=g)

params = [C, W1, b1, W2, b2]

# Opt every parameter into autograd tracking.
for param in params:
    param.requires_grad_(True)

n_params = sum(map(torch.numel, params))
print("Num of params: ", n_params)

Num of params:  3481


In [24]:
# Training loop: forward pass, backward pass, and a plain gradient-descent step

learning_rate = 0.1  # step size of each gradient-descent update
num_steps = 100      # number of full-batch updates

for i in range(num_steps):
    # Forward pass: look up the embedding of every context token, then flatten
    # each example's embeddings into one row. The view is equivalent to
    # concatenating the per-position embeddings along dim 1.
    xs_embed = C[xs]
    xs_embed_stack = xs_embed.view(-1, context_length * embedding_size)

    l1 = torch.tanh(xs_embed_stack @ W1 + b1)
    logits = l1 @ W2 + b2
    loss = F.cross_entropy(logits, ys)
    print(loss.item())

    # Backward pass: drop stale gradients first, since backward() accumulates.
    for p in params:
        p.grad = None

    loss.backward()

    # Parameter update. Running the in-place step under no_grad keeps it out
    # of the autograd graph without going through `.data`, which would bypass
    # autograd's in-place correctness checks.
    with torch.no_grad():
        for p in params:
            p -= learning_rate * p.grad

24.941762924194336
15.248147964477539
8.964353561401367
4.224044322967529
1.2105400562286377
0.29714006185531616
0.09521935135126114
0.06758875399827957
0.053451500833034515
0.04457003250718117
0.03839101642370224
0.03381089121103287
0.03026415780186653
0.027428239583969116
0.02510356903076172
0.023160241544246674
0.021509408950805664
0.020088452845811844
0.01885126158595085
0.01776372827589512
0.016799679026007652
0.01593897119164467
0.015165427699685097
0.014466166496276855
0.013830659911036491
0.013250870630145073
0.012719196267426014
0.012230100110173225
0.011778381653130054
0.01135995052754879
0.010971111245453358
0.01060895249247551
0.01027065608650446
0.009953794069588184
0.009656643494963646
0.009377219714224339
0.009113977663218975
0.00886558834463358
0.008630714379251003
0.008408376947045326
0.008197424001991749
0.007997202686965466
0.007806727197021246
0.0076254382729530334
0.0074525573290884495
0.007287570275366306
0.007129936013370752
0.006979162339121103
0.006834803614765

The key idea is that we are going to learn the embeddings (is this an encoder v1?)

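Naively, if you feed the raw integer codes a=1, b=2, c=3 into the network, you introduce a bias: because 1 and 2 are numerically close, the network is told that a and b are similar, even though the numbering is arbitrary.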

Instead, we can represent each symbol as a point in n-dimensional space and have the network learn where that point should be. We would then expect conceptually similar symbols to end up close together in that space!

In [26]:
# Pick the highest-scoring class for every example, show the raw ids,
# then decode each id back to its character.
best_ids = logits.max(dim=1).indices
print(best_ids)

for token_id in best_ids.tolist():
    print(int_to_chars[token_id])


tensor([ 5, 13, 13,  1,  0])
e
m
m
a
.
