# nano word2vec

## Setup

In [235]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset

In [258]:
WEIGHT_PATH = 'weights.bak'

In [248]:
# hyperparameters
block_size = 8
n_embd = 96
n_hidden = 96
batch_size = 64
learning_rate = 1e-3
max_iters = 50000
eval_interval = 500
eval_iters = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [237]:
# https://huggingface.co/datasets/generics_kb

datasets = load_dataset("generics_kb", "generics_kb_simplewiki")
dataset = datasets["train"]
print(f'{len(dataset)=} {dataset[0].keys()=}')


charset_whitelist = 'abcdefghijklmnopqrstuvwxyz- '
def sanitize(s):
    return ''.join([c for c in s.lower() if c in charset_whitelist])

sentences = [sanitize(d['sentence']) for d in dataset]
print(f'{sentences[:3]=}')
print(f'{max([len(s.split()) for s in sentences])=}')

vocab = set([w for s in sentences for w in s.split()])
print(f'{len(vocab)=} {list(vocab)[:3]=}')

# The sample size for each word seems really small so this dataset probably won't work at all.
# can I get a dataset specialized on fruits maybe, to do queries of the type `lemon - yellow + green = lime`
queen = [s for s in sentences if 'queen' in s]
print(f'{len(queen)=} {queen[:3]=}')

len(dataset)=12765 dataset[0].keys()=dict_keys(['source_name', 'sentence', 'sentences_before', 'sentences_after', 'concept_name', 'quantifiers', 'id', 'bert_score', 'headings', 'categories'])
sentences[:3]=['sepsis happens when the bacterium enters the blood and make it form tiny clots', 'incubation period is only one to two days', 'scuba diving is a common tourist activity']
max([len(s.split()) for s in sentences])=22
len(vocab)=13477 list(vocab)[:3]=['occasionally', 'technological', 'welding']
len(queen)=4 queen[:3]=['monarch is a word that means king or queen', 'pregnant queens deliver their litters by themselves guided by instinct', 'most ant species have a system in which only the queen and breeding females can mate']


In [238]:
vocab_list = ['<end>', '<???>'] + list(vocab)
vocab_size = len(vocab_list)
stoi = {w: i for i, w in enumerate(vocab_list)}
itos = {i: w for w, i in stoi.items()}

def encode(s):
    return torch.tensor([stoi.get(w, 1) for w in sanitize(s).split() + ['<end>']], dtype=torch.long)

def decode(t):
    t = t.tolist() if isinstance(t, torch.Tensor) else t
    return ' '.join([itos[i] for i in t])

# careful here if we use words outside of vocab it'll explode
for xs in ['I for one welcome our new robot overlords', 'The chicken cross the road']:
    print(f'{encode(xs)=}')
    print(f'{decode(encode(xs))=}')

encode(xs)=tensor([    1, 10912,  3840, 12269,  9667,  8109,     1,     1,     0])
decode(encode(xs))='<???> for one welcome our new <???> <???> <end>'
encode(xs)=tensor([8951, 4067,  614, 8951, 9491,    0])
decode(encode(xs))='the chicken cross the road <end>'


In [239]:
# shape the data for training
def chunk(s):
    s = torch.cat((torch.zeros(block_size, dtype=torch.long), s))
    for i in range(0, len(s) - block_size):
        yield s[i: i + block_size], s[i + 1: i + block_size + 1]

chunked = [c for s in sentences for c in chunk(encode(s))]
Xtrain = [c[0] for c in chunked]
Ytrain = [c[1] for c in chunked]

for i in range(3):
    print(Xtrain[i], Ytrain[i])
    print(f'{decode(Xtrain[i])=} {decode(Ytrain[i])=}')

tensor([0, 0, 0, 0, 0, 0, 0, 0]) tensor([   0,    0,    0,    0,    0,    0,    0, 6255])
decode(Xtrain[i])='<end> <end> <end> <end> <end> <end> <end> <end>' decode(Ytrain[i])='<end> <end> <end> <end> <end> <end> <end> sepsis'
tensor([   0,    0,    0,    0,    0,    0,    0, 6255]) tensor([   0,    0,    0,    0,    0,    0, 6255, 8277])
decode(Xtrain[i])='<end> <end> <end> <end> <end> <end> <end> sepsis' decode(Ytrain[i])='<end> <end> <end> <end> <end> <end> sepsis happens'
tensor([   0,    0,    0,    0,    0,    0, 6255, 8277]) tensor([    0,     0,     0,     0,     0,  6255,  8277, 10733])
decode(Xtrain[i])='<end> <end> <end> <end> <end> <end> sepsis happens' decode(Ytrain[i])='<end> <end> <end> <end> <end> sepsis happens when'


In [240]:
def get_batch():
    # TODO: swap between train and val
    ix = torch.randint(len(Xtrain), (batch_size,))
    x = torch.stack([Xtrain[i] for i in ix])
    y = torch.stack([Ytrain[i] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

xb, yb = get_batch()
print(xb[:2])
print(yb[:2])
print(f'{decode(xb[0])} -> {decode(yb[0])}')
print(f'{decode(xb[1])} -> {decode(yb[1])}')


tensor([[12292, 11441,  1585,   993,  5769,  1835, 12068,  1221],
        [    0,     0,     0,     0,     0,     0,     0,  2542]],
       device='cuda:0')
tensor([[11441,  1585,   993,  5769,  1835, 12068,  1221,  8131],
        [    0,     0,     0,     0,     0,     0,  2542,  3784]],
       device='cuda:0')
other drums because they are tuned to certain -> drums because they are tuned to certain musical
<end> <end> <end> <end> <end> <end> <end> felsic -> <end> <end> <end> <end> <end> <end> felsic magma


In [249]:
@torch.no_grad()
def estimate_loss():
    model.eval()
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
        X, Y = get_batch()
        logits, loss = model(X, Y)
        losses[k] = loss.item()
    out = losses.mean()
    model.train()
    return out

## Implem the model

In [254]:
torch.manual_seed(0xdeadbeef) # for reproducibility

class Bnorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        # /!\
        # /!\ it looks insanely expensive, this 10x the training time
        # /!\
        return self.bn(x.transpose(1, 2)).transpose(1, 2)

class LM(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.layers = nn.Sequential(
            # nn.Linear(n_embd, n_hidden), Bnorm(n_hidden), nn.ReLU(),
            nn.Linear(n_embd, n_hidden), nn.ReLU(),
        )
        self.lm_head = nn.Linear(n_hidden, vocab_size)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape
        # print(f'{idx.shape=} {targets.shape=}')
        tok_emb = self.token_embedding_table(idx) # (B, T, n_embd)
        x = self.layers(tok_emb)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # juggle with tensor shapes to match pytorch's cross_entropy
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            # crop the context to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = LM()
m = model.to(device)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss.item())
print(logits[0])

torch.Size([256, 13479])
9.440207481384277
tensor([ 0.3658,  0.2438, -0.1715,  ..., -0.0156,  0.1206,  0.1259],
       device='cuda:0', grad_fn=<SelectBackward0>)


In [255]:
# create a pytorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [256]:
# train
for iter in range(max_iters):
    if iter % eval_interval == 0:
        loss = estimate_loss()
        print(f'step {iter}: train loss {loss:.4f}')

    xb, yb = get_batch()
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


step 0: train loss 9.4127
step 500: train loss 4.4257
step 1000: train loss 4.2177
step 1500: train loss 4.0268
step 2000: train loss 3.9412
step 2500: train loss 3.8143
step 3000: train loss 3.7047
step 3500: train loss 3.6555
step 4000: train loss 3.6052
step 4500: train loss 3.4570


In [257]:
# sample from the model
context = torch.zeros((1, block_size), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=300)[0].tolist()))

<end> <end> <end> <end> <end> <end> <end> <end> <end> movie <end> <end> alternative only save that opposes authorities areas back to uranus versions of all atoms go as background indicates levels are used in soil slightly or copper is very unstable and north pole and calculate is sacred and restore from conifers about their volunteer contain people consider their toes in the dogs are a solo bodies at the dragonfly that the toes on their own usually populations steel but they feed in often prefer have eyes are made up to the opposite items created by a like a different forms by measuring bird known from the markup oil is with gases is the important by pride up the sides of foil they live are named within cancer is special devices can lead compounds are served on our china has been form of millions they are made by breaking the total to new infection calcium take five minutes eyes have shot are served with males with light can be cheap theres found in a type of dark chocolate of the star

In [259]:
# backup to disk
torch.save(model.state_dict(), WEIGHT_PATH)

In [261]:
# load from disk
m2 = LM()
m2.load_state_dict(torch.load(WEIGHT_PATH))
m2 = m2.to(device)
m2.eval()

# context = torch.zeros((1, block_size), dtype=torch.long, device=device)
# print(decode(m2.generate(context, max_new_tokens=300)[0].tolist()))

<end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> <end> cyclic a seven-pointed services are a conservative and access to the distances weather more with a long and sizes made up of the best and barren island is a higher boost at higher frequencies are omnivores electrical big only reaches exist based on hunting ions are people sometimes star paper tail and demand and materials is used at the production and female folding is to build organisms is one of making problem off deciduous <end> <end> <end> <end> trilobites use the community of tiny teeth-like <end> <end> <end> <end> <end> some beers ingredients of hydrogen bonds keep services are performed of interaction in forests and order that are trained to act <end> <end> chemical nest known into a very hard every police uses a spell species for domestic yaks of an outward the source of their mainly by a problem in gardens also a useful than sequence of freely online than males use a ritual reminder needs to th