In [26]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import random
%matplotlib inline

In [27]:
DATASET_PATH = "./datasets/birds/birds.csv"

birds = open(DATASET_PATH, "r").read().splitlines()

print("First 10 birds in the dataset:\n")
print(", ".join(birds[:10]))
print(f"There are {len(birds):,d} birds in the dataset.")

min_length = map(len, birds)
max_length = map(len, birds)
print(f"\nThe shortest character name has {min(min_length)} characters.")
print(f"The longest character name has {max(max_length)} characters.")

First 10 birds in the dataset:

Abbott's babbler, Abbott's booby, Abbott's starling, Abbott's sunbird, Abd al-Kuri sparrow, Abdim's stork, Aberdare cisticola, Aberrant bush warbler, Abert's towhee, Abyssinian catbird
There are 10,976 birds in the dataset.

The shortest character name has 3 characters.
The longest character name has 35 characters.


In [28]:
from unidecode import unidecode

def clean_name(name):
    """
    Clean the bird name by:
    - Removing leading and trailing whitespaces
    - Converting to lowercase
    - Removing accents
    - Removing special characters
    - Replacing spaces with underscores
    """

    name = name.strip().lower()
    # replace special characters with a space
    name = ''.join(char if char.isalnum() or char.isspace() else ' ' for char in name)
    name = name.replace("`", "_")  # Remove apostrophes
    name = name.replace(" ", "_")
    name = unidecode(name)
    return name

In [29]:
# clean all names in the dataset
birds = list(map(clean_name, birds))

# create a mapping from tokens to indices
unique_tokens = set([c for w in birds for c in w])
SPECIAL_TOKEN = "."
index_to_token = {i: t for i, t in enumerate(unique_tokens, start=1)}
token_to_index = {v: k for k, v in index_to_token.items()}
index_to_token[0] = SPECIAL_TOKEN
token_to_index[SPECIAL_TOKEN] = 0

# log information about the tokenization
print(f"Number of unique tokens: {len(unique_tokens)}")
print(", ".join(sorted(unique_tokens)))
print(f"\nToken mapping: {index_to_token}")

Number of unique tokens: 28
_, `, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u, v, w, x, y, z

Token mapping: {1: 'x', 2: 't', 3: 'y', 4: 's', 5: 'u', 6: 'q', 7: 'l', 8: 'e', 9: 'w', 10: 'z', 11: 'a', 12: 'o', 13: '`', 14: 'j', 15: 'n', 16: 'v', 17: 'b', 18: 'h', 19: 'd', 20: '_', 21: 'i', 22: 'c', 23: 'g', 24: 'r', 25: 'f', 26: 'p', 27: 'k', 28: 'm', 0: '.'}


In [30]:
# Model parameters
CONTEXT_SIZE = 3
N_EMBEDDINGS = 10
N_HIDDEN = 64
N_TOKEN = len(token_to_index)

# Training parameters
TRAINING_SET_PORTION = 0.8
DEVELOPMENT_SET_PORTION = 0.1
TEST_SET_PORTION = 1 - (TRAINING_SET_PORTION + DEVELOPMENT_SET_PORTION)

In [31]:
def build_datasets(words: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build datasets from a list of words by creating input and target tensors.
    
    Args:
        words (list[str]): List of words to build the datasets from.
        
    Returns:
        tuple[torch.Tensor, torch.Tensor]: Tuple containing the input tensor X and target tensor Y.
    """
    # Create a mapping from tokens to indices
    X, Y = [], []
    
    # Create the context for each character in the words
    for w in words:
        context = [0] * CONTEXT_SIZE
        for ch in w + SPECIAL_TOKEN:  # Add special token at the end
            ix = token_to_index[ch]
            X.append(context)
            Y.append(ix)
            # Update the context by shifting it and adding the new index 
            context = context[1:] + [ix]

    # Convert lists to tensors
    X = torch.tensor(X, dtype=torch.int64)
    Y = torch.tensor(Y, dtype=torch.int64)

    return X, Y
    
# Shuffle the words
random.seed(1234)
random.shuffle(birds)

# Split the dataset into training, development, and test sets
train_size = int(TRAINING_SET_PORTION * len(birds))
dev_size = int(DEVELOPMENT_SET_PORTION * len(birds))

X_train, Y_train = build_datasets(birds[:train_size])
X_dev, Y_dev = build_datasets(birds[train_size:train_size + dev_size])
X_test, Y_test = build_datasets(birds[train_size + dev_size:])

# print tensor shapes
print("Training set shape:", X_train.shape, Y_train.shape)
print("Development set shape:", X_dev.shape, Y_dev.shape)
print("Test set shape:", X_test.shape, Y_test.shape)

Training set shape: torch.Size([172513, 3]) torch.Size([172513])
Development set shape: torch.Size([21531, 3]) torch.Size([21531])
Test set shape: torch.Size([21461, 3]) torch.Size([21461])


In [94]:
def cmp(s, dt, t):
    """
    Compare the true gradient dt and the approximate gradient t.grad.
    Print the results in a table format.
    """
    assert t.grad.shape == dt.shape, f"Shape mismatch: expected {t.grad.shape}, got {dt.shape}"

    exact = torch.all(dt==t.grad).item()
    approx = torch.allclose(dt, t.grad)
    max_diff = (dt - t.grad).abs().max().item()
    print(f'{s:15s} | exact: {str(exact):5s} | approximate: {str(approx):5s} | maxdiff: {max_diff}')

In [33]:
g = torch.Generator().manual_seed(123456789) # for reproducibility

# Embedding matrix
C = torch.randn((N_TOKEN, N_EMBEDDINGS), generator=g)

# Layer 1
W1 = torch.randn((N_EMBEDDINGS * CONTEXT_SIZE, N_HIDDEN), generator=g) * (5/3) / (N_EMBEDDINGS * CONTEXT_SIZE)**0.5
b1 = torch.randn(N_HIDDEN, generator=g) * 0.1

# Layer 2
W2 = torch.randn((N_HIDDEN, N_TOKEN), generator=g) * 0.1
b2 = torch.randn(N_TOKEN, generator=g) * 0.1

# Batch normalization
bngain = torch.randn((1, N_HIDDEN), generator=g) * 0.1 + 1.0
bnbias = torch.randn((1, N_HIDDEN), generator=g) * 0.1

# Parameters
parameters = [C, W1, b1, W2, b2, bngain, bnbias]

# Model size
print(f"Model size: {sum(p.numel() for p in parameters)}")

# Turn on gradient tracking
for p in parameters:
    p.requires_grad = True

Model size: 4287


In [34]:
batch_size = 32
ix = torch.randint(0, X_train.shape[0], (batch_size,), generator=g)
Xb, Yb = X_train[ix], Y_train[ix]

In [199]:
# Forward pass on a single batch
n = batch_size
emb = C[Xb]                                            # shape (batch_size, context_size, embedding_size)
embcat = emb.view(emb.shape[0], -1)                    # shape (batch_size, context_size * embedding_size)

# Layer 1
hprebn = embcat @ W1 + b1                              # shape (batch_size, hidden_size)

# Batch normalization
bnmeani = hprebn.sum(dim=0, keepdim=True) / n          # shape (1, hidden_size)
bndiff = hprebn - bnmeani                              # shape (batch_size, hidden_size)
bndiff2 = bndiff ** 2                                  # shape (batch_size, hidden_size)
bnvar = bndiff2.sum(dim=0, keepdim=True) / (n - 1)     # shape (1, hidden_size)
bnvar_inv = 1 / torch.sqrt(bnvar + 1e-5)               # shape (1, hidden_size)
bnraw = bndiff * bnvar_inv                             # shape (batch_size, hidden_size)
hpreact = bngain * bnraw + bnbias                      # shape (batch_size, hidden_size)

# Non-linearity
h = torch.tanh(hpreact)                                # shape (batch_size, hidden_size)
logits = h @ W2 + b2                                   # shape (batch_size, vocab_size)

# Cross-entropy loss
logit_maxes = logits.max(dim=1, keepdim=True).values   # shape (batch_size, 1)
norm_logits = logits - logit_maxes                     # shape (batch_size, vocab_size)
counts = norm_logits.exp()                             # shape (batch_size, vocab_size)
counts_sum = counts.sum(dim=1, keepdim=True)           # shape (batch_size, 1)
counts_sum_inv = counts_sum ** -1                      # shape (batch_size, 1)
probs = counts * counts_sum_inv                        # shape (batch_size, vocab_size)
logprobs = probs.log()                                 # shape (batch_size, vocab_size)
loss = - logprobs[range(logprobs.shape[0]), Yb].mean() # shape (1)  

# PyTorch backward pass
for p in parameters:
    p.grad = None
for t in [
    logprobs, probs, counts_sum_inv, counts_sum, counts,
    norm_logits, logit_maxes, logits, h, hpreact,
    bnraw, bnvar_inv, bnvar, bndiff, bndiff2,
    bnmeani, hprebn, embcat, emb
    ]:
    t.retain_grad()
loss.backward()
loss

tensor(3.7229, grad_fn=<NegBackward0>)

In [None]:
# Compute all the gradients manually and compare with PyTorch backward pass
dlogprobs = torch.zeros_like(logprobs).index_put((torch.Tensor(range(n)).int(), Yb), torch.tensor(-1/n))        # shape (batch_size, vocab_size)
dprobs = dlogprobs * (1 / probs)                                                                                # shape (batch_size, vocab_size)
dcounts_sum_inv = (dprobs * counts).sum(dim=1, keepdim=True)                                                    # shape (batch_size, 1)
dcounts_sum = dcounts_sum_inv * ( - 1 / counts_sum ** 2 )                                                       # shape (batch_size, 1)
dcounts = counts_sum_inv * dprobs + torch.ones_like(counts) * dcounts_sum                                       # shape (batch_size, vocab_size)
dnorm_logits = dcounts * counts                                                                                 # shape (batch_size, vocab_size)
dlogit_maxes = (-dnorm_logits).sum(dim=1, keepdim=True)                                                         # shape (batch_size, 1)
dlogits = dlogit_maxes * F.one_hot(logits.max (1).indices, num_classes=logits.shape [1]) + dnorm_logits.clone() # shape (batch_size, vocab_size)
dh = dlogits @ W2.T                                                                                             # shape (batch_size, hidden_size)
dW2 = h.T @ dlogits                                                                                             # shape (hidden_size, hidden_size)
db2 = dlogits.sum(dim=0, keepdim=False)                                                                         # shape (hidden_size)
dhpreact = dh * (1 - h ** 2)                                                                                    # shape (batch_size, hidden_size)
dbngain = (dhpreact * bnraw).sum(dim=0, keepdim=True)                                                           # shape (1, hidden_size)
dbnbias = dhpreact.sum(dim=0, keepdim=True)                                                                     # shape (1, hidden_size)
dbnraw = (dhpreact * bngain)                                                                                    # shape (batch_size, hidden_size)
dbnvar_inv = (dbnraw * bndiff).sum(dim=0, keepdim=True)                                                         # shape (1, hidden_size)
dbnvar = ( -0.5 * (bnvar + 1e-5) ** -1.5 ) * dbnvar_inv                                                         # shape (1, hidden_size)

In [221]:
cmp('logprobs', dlogprobs, logprobs)
cmp('probs', dprobs, probs)
cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)
cmp('counts_sum', dcounts_sum, counts_sum)
cmp('counts', dcounts, counts)
cmp('norm_logits', dnorm_logits, norm_logits)
cmp('logit_maxes', dlogit_maxes, logit_maxes)
cmp('logits', dlogits, logits)
cmp('h', dh, h)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)
cmp('hpreact', dhpreact, hpreact)
cmp('bngain', dbngain, bngain)
cmp('bnbias', dbnbias, bnbias)
cmp('bnraw', dbnraw, bnraw)
cmp('bnvar_inv', dbnvar_inv, bnvar_inv)
cmp('bnvar', dbnvar, bnvar)
# cmp('bndiff2', dbndiff2, bndiff2)
# cmp('bndiff', dbndiff, bndiff)
# cmp('bnmeani', dbnmeani, bnmeani)
# cmp('hprebn', dhprebn, hprebn)
# cmp('embcat', dembcat, embcat)
# cmp('W1', dW1, W1)
# cmp('b1', db1, b1)
# cmp('emb', demb, emb)
# cmp('C', dC, C)



logprobs        | exact: True  | approximate: True  | maxdiff: 0.0
probs           | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum_inv  | exact: True  | approximate: True  | maxdiff: 0.0
counts_sum      | exact: True  | approximate: True  | maxdiff: 0.0
counts          | exact: True  | approximate: True  | maxdiff: 0.0
norm_logits     | exact: True  | approximate: True  | maxdiff: 0.0
logit_maxes     | exact: True  | approximate: True  | maxdiff: 0.0
logits          | exact: True  | approximate: True  | maxdiff: 0.0
h               | exact: True  | approximate: True  | maxdiff: 0.0
W2              | exact: True  | approximate: True  | maxdiff: 0.0
b2              | exact: True  | approximate: True  | maxdiff: 0.0
hpreact         | exact: True  | approximate: True  | maxdiff: 0.0
bngain          | exact: True  | approximate: True  | maxdiff: 0.0
bnbias          | exact: True  | approximate: True  | maxdiff: 0.0
bnraw           | exact: True  | approximate: True  | maxdiff:

In [None]:
# 