# nano word2vec

## Setup

In [159]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset

In [160]:
# hyperparameters
block_size = 8
n_embd = 64
batch_size = 32
learning_rate = 1e-3

max_iters = 5000
eval_interval = 500
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# eval_iters = 200

In [161]:
# https://huggingface.co/datasets/generics_kb

datasets = load_dataset("generics_kb", "generics_kb_simplewiki")
dataset = datasets["train"]
print(f'{len(dataset)=} {dataset[0].keys()=}')


charset_whitelist = 'abcdefghijklmnopqrstuvwxyz- '
def sanitize(s):
    return ''.join([c for c in s.lower() if c in charset_whitelist])

sentences = [sanitize(d['sentence']) for d in dataset]
print(f'{sentences[:3]=}')
print(f'{max([len(s.split()) for s in sentences])=}')

vocab = set([w for s in sentences for w in s.split()])
print(f'{len(vocab)=} {list(vocab)[:3]=}')

# The sample size for each word seems really small so this dataset probably won't work at all.
# can I get a dataset specialized on fruits maybe, to do queries of the type `lemon - yellow + green = lime`
queen = [s for s in sentences if 'queen' in s]
print(f'{len(queen)=} {queen[:3]=}')

len(dataset)=12765 dataset[0].keys()=dict_keys(['source_name', 'sentence', 'sentences_before', 'sentences_after', 'concept_name', 'quantifiers', 'id', 'bert_score', 'headings', 'categories'])
sentences[:3]=['sepsis happens when the bacterium enters the blood and make it form tiny clots', 'incubation period is only one to two days', 'scuba diving is a common tourist activity']
max([len(s.split()) for s in sentences])=22
len(vocab)=13477 list(vocab)[:3]=['occasionally', 'technological', 'welding']
len(queen)=4 queen[:3]=['monarch is a word that means king or queen', 'pregnant queens deliver their litters by themselves guided by instinct', 'most ant species have a system in which only the queen and breeding females can mate']


In [162]:
vocab_list = ['<end>', '<???>'] + list(vocab)
vocab_size = len(vocab_list)
stoi = {w: i for i, w in enumerate(vocab_list)}
itos = {i: w for w, i in stoi.items()}

def encode(s):
    return torch.tensor([stoi.get(w, 1) for w in sanitize(s).split() + ['<end>']], dtype=torch.long)

def decode(t):
    return ' '.join([itos[i.item()] for i in t])

# careful here if we use words outside of vocab it'll explode
for xs in ['I for one welcome our new robot overlords', 'The chicken cross the road']:
    print(f'{encode(xs)=}')
    print(f'{decode(encode(xs))=}')

encode(xs)=tensor([    1, 10912,  3840, 12269,  9667,  8109,     1,     1,     0])
decode(encode(xs))='<???> for one welcome our new <???> <???> <end>'
encode(xs)=tensor([8951, 4067,  614, 8951, 9491,    0])
decode(encode(xs))='the chicken cross the road <end>'


In [163]:
# shape the data for training
def chunk(s):
    s = torch.cat((torch.zeros(block_size, dtype=torch.long), s))
    for i in range(0, len(s) - block_size):
        yield s[i: i + block_size], s[i + 1: i + block_size + 1]

chunked = [c for s in sentences for c in chunk(encode(s))]
Xtrain = [c[0] for c in chunked]
Ytrain = [c[1] for c in chunked]

for i in range(3):
    print(Xtrain[i], Ytrain[i])
    print(f'{decode(Xtrain[i])=} {decode(Ytrain[i])=}')

tensor([0, 0, 0, 0, 0, 0, 0, 0]) tensor([   0,    0,    0,    0,    0,    0,    0, 6255])
decode(Xtrain[i])='<end> <end> <end> <end> <end> <end> <end> <end>' decode(Ytrain[i])='<end> <end> <end> <end> <end> <end> <end> sepsis'
tensor([   0,    0,    0,    0,    0,    0,    0, 6255]) tensor([   0,    0,    0,    0,    0,    0, 6255, 8277])
decode(Xtrain[i])='<end> <end> <end> <end> <end> <end> <end> sepsis' decode(Ytrain[i])='<end> <end> <end> <end> <end> <end> sepsis happens'
tensor([   0,    0,    0,    0,    0,    0, 6255, 8277]) tensor([    0,     0,     0,     0,     0,  6255,  8277, 10733])
decode(Xtrain[i])='<end> <end> <end> <end> <end> <end> sepsis happens' decode(Ytrain[i])='<end> <end> <end> <end> <end> sepsis happens when'


In [164]:
def get_batch():
    # TODO: swap between train and val
    ix = torch.randint(len(Xtrain), (batch_size,))
    x = torch.stack([Xtrain[i] for i in ix])
    y = torch.stack([Ytrain[i] for i in ix])
    # x, y = x.to(device), y.to(device)
    return x, y

xb, yb = get_batch()
print(xb[:2])
print(yb[:2])
print(f'{decode(xb[0])} -> {decode(yb[0])}')
print(f'{decode(xb[1])} -> {decode(yb[1])}')


tensor([[    0,     0,     0,     0,     0,     0,  3835,  8826],
        [ 5054,  6302,  3088, 10753,  6977,  2188,    11,  4845]])
tensor([[    0,     0,     0,     0,     0,  3835,  8826,   928],
        [ 6302,  3088, 10753,  6977,  2188,    11,  4845,  5054]])
<end> <end> <end> <end> <end> <end> variations involve -> <end> <end> <end> <end> <end> variations involve replacing
and can see long distances shapes shadows color -> can see long distances shapes shadows color and


## Implem the model

In [165]:
torch.manual_seed(0xdeadbeef) # for reproducibility

class LM(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.layers = nn.Sequential(
            nn.Linear(n_embd, n_embd), nn.ReLU(),
            # nn.Linear(n_embd, n_embd), nn.BatchNorm1d(n_embd), nn.ReLU(),
            # nn.Linear(n_embd, n_embd), nn.BatchNorm1d(num_features=n_embd), nn.ReLU(),
        )
        # self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        # self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape
        # print(f'{idx.shape=} {targets.shape=}')
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        x = self.layers(tok_emb)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # juggle with tensor shapes to match pytorch's cross_entropy
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop the context to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = LM()
# m = model.to(device)
m = model
logits, loss = m(xb, yb)
print(logits.shape)
print(loss.item())
print(logits[0])

torch.Size([256, 13479])
9.364344596862793
tensor([ 0.5058,  0.2564,  0.0894,  ..., -0.0993,  0.1320, -0.3849],
       grad_fn=<SelectBackward0>)


In [166]:
# create a pytorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [167]:
# train
for iter in range(max_iters):
    # if iter % eval_interval == 0:
    #     losses = estimate_loss()
    #     print(f'step {iter}: train loss {losses["train"]:.4f}, val loss {losses["val"]:.4f}')

    xb, yb = get_batch()
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if iter % eval_interval == 0:
        print(f'step {iter}: train loss {loss:.4f}')


step 0: train loss 9.3797


step 500: train loss 3.8205
step 1000: train loss 4.1586
step 1500: train loss 4.1597
step 2000: train loss 4.0523
step 2500: train loss 3.4045
step 3000: train loss 3.4417
step 3500: train loss 3.5563
step 4000: train loss 3.7179
step 4500: train loss 3.3655


In [170]:
# sample from the model
def decode(t):
    return ' '.join([itos[i] for i in t])

device='cpu'
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=300)[0].tolist()))

<end> <end> <end> <end> <end> <end> <end> heat or nine emperors <end> <end> steel and off from the out thinly flowers and sizes of their spending a milky are succulent are former bays that are omnivorous themselves like more difficult when they have important role at institutions are constant muscles of families with a form of stroke days have a parameter or near their molecules join in cold jobs the making enter protective trees use the harness react actually different substances <end> <end> <end> <end> <end> environmentalists dead holes to british different music music <end> <end> <end> <end> children have fewer services shows and an organism microbes that they belong to make tools on earth if different countries about what caused characters for an lead compounds are sloped out sizes and make a solvent extraction relating facts for the begin to use page are noise to hear when a style doing are made of papier in the ropes are larger than open to find based on the universe are a stairc