In [1]:
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import torch
import torch.nn.functional as F

In [3]:
# Load the training corpus: one name per line.
# The context manager closes the file handle deterministically (the bare
# `open(...).read()` form leaves it open until garbage collection), and the
# explicit encoding keeps the result independent of the platform's locale.
with open('names.txt', encoding='utf-8') as f:
    words = f.read().splitlines()
len(words), words[:8]

(32033,
 ['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia'])

In [4]:
# Vocabulary. There is a single special token, '.', which pads the start of a
# context and terminates each name; it gets index 0, and the real characters
# are shifted to start at index 1.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}  # inverse mapping: index -> character
num_classes = len(stoi)  # special token + distinct characters

In [5]:
# Build the (context -> next character) training pairs.
# Only the first `n_words` names are used for now, to keep the example small.

block_size = 3  # context length: 3 previous characters are used to predict the 4th
n_words = 5     # how many names to include; set to None to use every name

X, Y = [], []  # X: integer contexts, one row of block_size indices per example | Y: next-char labels

for w in words[:n_words]:
    context = [0] * block_size  # a name starts with a context of all '.' (index 0)
    for ch in w + '.':          # the trailing '.' is the label marking the end of the name
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        # Slide the window: drop the oldest index and append the newest. This builds a
        # new list, so the rows already appended to X are never mutated.
        context = context[1:] + [ix]

X = torch.tensor(X)
Y = torch.tensor(Y)

X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [6]:
# Parameters of a small MLP:
#   C        - character embedding table, one n_embd-dim row per class
#   W1, b1   - hidden layer over the concatenated context embeddings
#   W2, b2   - output layer producing one logit per class
# Shapes come from the vocabulary size and context length rather than hard-coded
# 27 and 3*2, so they stay in sync with the vocabulary and dataset cells.
n_embd = 2      # dimensionality of each character embedding
n_hidden = 100  # number of hidden units

g = torch.Generator().manual_seed(2147483647)  # fixed seed -> reproducible initialization
C = torch.randn((num_classes, n_embd), generator=g)
W1 = torch.randn((block_size * n_embd, n_hidden), generator=g)
b1 = torch.randn(n_hidden, generator=g)
W2 = torch.randn((n_hidden, num_classes), generator=g)
b2 = torch.randn(num_classes, generator=g)
parameters = [C, W1, b1, W2, b2]


sum(p.nelement() for p in parameters) # total number of parameters

3481

In [7]:
# Forward pass and loss, written out by hand.
emb = C[X]                                    # embed every context character: [32, 3, 2]
hidden_in = emb.view(-1, W1.shape[0])         # concatenate each context's embeddings: [32, 6]
h = (hidden_in @ W1 + b1).tanh()              # hidden activations: [32, 100]
logits = h @ W2 + b2                          # unnormalized score per class: [32, 27]
counts = logits.exp()                         # exponentiate to get positive "counts"
probs = counts / counts.sum(1, keepdim=True)  # normalize each row into a distribution
# Negative log-likelihood of the correct next character, averaged over examples.
rows = torch.arange(Y.shape[0])
loss = -probs[rows, Y].log().mean()
loss

tensor(17.7697)

In [8]:
# BUT WAIT - the same loss via PyTorch's built-in: it takes the raw logits and
# the integer targets directly, with no explicit exp / normalize / log steps.
F.cross_entropy(input=logits, target=Y)

tensor(17.7697)

In [9]:
# So, once you have the logits,
# just call F.cross_entropy directly -
# don't roll your own softmax + negative log-likelihood:
# - Forward pass: the separate counts / probs tensors are never materialized
# - Fused kernels are used
# - Backward pass is simpler: the gradient of softmax + NLL w.r.t. the logits has a simple
#   closed form (softmax(logits) - one_hot(target), divided by N for mean reduction), derived analytically
# - Cross entropy loss is numerically well behaved, even for extreme logits (see below)

In [10]:
# Well-behaved case: modest logits exponentiate and normalize without trouble.
logits = torch.tensor([-100, -3, 0, 1])
counts = torch.exp(logits)
total = counts.sum()
probs = counts / total
counts, probs

(tensor([3.7835e-44, 4.9787e-02, 1.0000e+00, 2.7183e+00]),
 tensor([9.8091e-45, 1.3213e-02, 2.6539e-01, 7.2140e-01]))

In [11]:
# When a logit is a large positive number, exp() overflows to inf (float32 overflows above roughly e^88.7), which ruins the normalized probabilities

In [12]:
# Overflow case: exp(100) does not fit in float32 and becomes inf, so normalizing
# gives 0 for every finite entry and inf / inf = nan for the last one.
logits = torch.tensor([-100, -3, 0, 100])
counts = torch.exp(logits)
total = counts.sum()
probs = counts / total
counts, probs

(tensor([3.7835e-44, 4.9787e-02, 1.0000e+00,        inf]),
 tensor([0., 0., 0., nan]))

In [13]:
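# But if we subtract the max logit from every logit first, it is stable again: softmax is unchanged
# by subtracting the same constant from all logits, and the largest exponent becomes exp(0) = 1, so nothing overflows.
# This trick is what the cross entropy implementation uses internally.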

In [14]:
# Stable version: shift the logits so the largest one is 0 before exponentiating.
# Softmax is invariant to subtracting a constant from every logit, and every
# exp() now lies in [0, 1] (tiny values may underflow to 0, which is harmless).
_logits = torch.tensor([-100, -3, 0, 100])
# Use _logits on both sides: this cell must not depend on `logits` left over
# from an earlier cell, or it breaks on a fresh kernel / out-of-order runs.
logits = _logits - _logits.max()
counts = logits.exp()
probs = counts / counts.sum()
counts, probs

(tensor([0.0000e+00, 1.4013e-45, 3.7835e-44, 1.0000e+00]),
 tensor([0.0000e+00, 1.4013e-45, 3.7835e-44, 1.0000e+00]))

In [15]:
# Largest of the example logits: the shift subtracted in the stable version.
torch.max(torch.tensor([-100, -3, 0, 100]))

tensor(100)

In [16]:
# REASONS for using F.cross_entropy
# - Forward pass is efficient: fused kernels, no separate counts / probs tensors
# - Backward pass is efficient: a simpler gradient expression, derived analytically in the implementation
# - Numerically well behaved (stable), even for extreme logits