# nano word2vec

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset

In [None]:
WEIGHT_PATH = 'weights.bak'

In [None]:
# hyperparameters
block_size = 8  # window length: chunk() builds windows of this size, generate() crops its context to it
n_embd = 96  # width of the token embedding vectors
n_hidden = 96  # width of the hidden Linear layer in LM
batch_size = 64  # number of (x, y) windows drawn per get_batch() call
learning_rate = 1e-4  # learning rate handed to the AdamW optimizer
max_iters = 500000  # number of iterations of the training loop
eval_interval = 500  # the training loop reports the estimated loss every this many iterations
eval_iters = 100  # number of batches averaged by estimate_loss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when torch can see one

In [None]:
# https://huggingface.co/datasets/generics_kb

# Load the "generics_kb_simplewiki" configuration of generics_kb and keep its "train" split.
datasets = load_dataset("generics_kb", "generics_kb_simplewiki")
dataset = datasets["train"]
print(f'{len(dataset)=} {dataset[0].keys()=}')


charset_whitelist = 'abcdefghijklmnopqrstuvwxyz- '
def sanitize(s):
    """Lower-case `s` and drop every character that is not in `charset_whitelist`."""
    kept = filter(charset_whitelist.__contains__, s.lower())
    return ''.join(kept)

# Normalise every sentence of the dataset with sanitize().
sentences = [sanitize(d['sentence']) for d in dataset]
print(f'{sentences[:3]=}')
# length of the longest sentence, counted in whitespace-separated words
print(f'{max([len(s.split()) for s in sentences])=}')

# vocabulary = the distinct words found across all sentences
vocab = set([w for s in sentences for w in s.split()])
print(f'{len(vocab)=} {list(vocab)[:3]=}')

# The number of samples per word seems really small, so this dataset probably won't work well.
# Could I get a dataset specialised on fruits instead, to run queries such as `lemon - yellow + green = lime`?
# NOTE: this is a substring test, so it also matches sentences that contain e.g. 'queens'.
queen = [s for s in sentences if 'queen' in s]
print(f'{len(queen)=} {queen[:3]=}')

In [None]:
# Id 0 is the end-of-sentence marker and id 1 the unknown-word marker.
# The vocabulary is sorted because iterating a set of strings yields a different
# order in every interpreter session; without a stable order the word ids (and
# therefore any weights saved to disk) would not line up after a restart.
vocab_list = ['<end>', '<???>'] + sorted(vocab)
vocab_size = len(vocab_list)
stoi = {w: i for i, w in enumerate(vocab_list)}  # word -> id
itos = {i: w for w, i in stoi.items()}  # id -> word

def encode(s):
    """Sanitize `s`, split it into words and return their ids followed by the '<end>' id.

    Words missing from `stoi` are encoded as id 1.
    """
    words = sanitize(s).split()
    words.append('<end>')
    ids = [stoi.get(word, 1) for word in words]
    return torch.tensor(ids, dtype=torch.long)

def decode(t):
    """Turn a sequence of token ids (tensor or plain list) into a space-separated string."""
    if isinstance(t, torch.Tensor):
        t = t.tolist()
    return ' '.join(itos[token_id] for token_id in t)

# words outside of the vocab do not raise here: encode() falls back to id 1 ('<???>') for them
for xs in ['I for one welcome our new robot overlords', 'The chicken cross the road']:
    print(f'{encode(xs)=}')
    print(f'{decode(encode(xs))=}')

In [None]:
# shape the data for training
def chunk(s):
    """Yield (input window, target window) pairs over an encoded sentence.

    The sentence is left-padded with `block_size` zeros; each target window is
    the input window shifted by one position.
    """
    padded = torch.cat((torch.zeros(block_size, dtype=torch.long), s))
    n_windows = len(padded) - block_size
    start = 0
    while start < n_windows:
        stop = start + block_size
        yield padded[start:stop], padded[start + 1:stop + 1]
        start += 1

# Flatten every sentence into its (input window, target window) pairs.
chunked = [c for s in sentences for c in chunk(encode(s))]
Xtrain = [c[0] for c in chunked]  # input windows
Ytrain = [c[1] for c in chunked]  # the same windows shifted by one token

# eyeball the first few pairs, raw and decoded
for i in range(3):
    print(Xtrain[i], Ytrain[i])
    print(f'{decode(Xtrain[i])=} {decode(Ytrain[i])=}')

In [None]:
def get_batch():
    """Draw `batch_size` random training windows (with replacement) and move them to `device`."""
    # TODO: swap between train and val
    picks = torch.randint(len(Xtrain), (batch_size,))
    inputs = torch.stack([Xtrain[p] for p in picks]).to(device)
    targets = torch.stack([Ytrain[p] for p in picks]).to(device)
    return inputs, targets

# smoke test: draw one batch and show its first two rows, raw and decoded
xb, yb = get_batch()
print(xb[:2])
print(yb[:2])
print(f'{decode(xb[0])} -> {decode(yb[0])}')
print(f'{decode(xb[1])} -> {decode(yb[1])}')


In [None]:
@torch.no_grad()
def estimate_loss():
    """Return the mean loss of `model` over `eval_iters` batches from get_batch().

    The model is put in eval mode for the measurement and back in train mode afterwards.
    """
    model.eval()
    samples = torch.zeros(eval_iters)
    for k in range(eval_iters):
        _, batch_loss = model(*get_batch())
        samples[k] = batch_loss.item()
    mean_loss = samples.mean()
    model.train()
    return mean_loss

## Implement the model

In [None]:
torch.manual_seed(0xdeadbeef) # for reproducibility

class Bnorm(nn.Module):
    """nn.BatchNorm1d wrapper that swaps axes 1 and 2 of its input before and after normalising."""

    def __init__(self, dim):
        super().__init__()
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        # /!\ this looked insanely expensive: it made training about 10x slower /!\
        swapped = torch.transpose(x, 1, 2)
        normalised = self.bn(swapped)
        return torch.transpose(normalised, 1, 2)

class LM(nn.Module):
    """Next-token model: embedding -> Linear + ReLU -> vocabulary logits.

    Every layer acts on the last axis only, so nothing is mixed across the
    time axis: the prediction at a position depends solely on the token at
    that position.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.layers = nn.Sequential(
            # nn.Linear(n_embd, n_hidden), Bnorm(n_hidden), nn.ReLU(),
            nn.Linear(n_embd, n_hidden), nn.ReLU(),
        )
        self.lm_head = nn.Linear(n_hidden, vocab_size)

    def forward(self, idx, targets=None):
        """Return `(logits, loss)` for a batch of token ids.

        Without `targets`, logits keep the shape (B, T, vocab_size) and loss is
        None. With `targets`, logits are returned flattened to
        (B * T, vocab_size) and loss is their cross-entropy against the targets.
        """
        tok_emb = self.token_embedding_table(idx)  # (B, T, n_embd)
        x = self.layers(tok_emb)  # (B, T, n_hidden)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            return logits, None

        # juggle with tensor shapes to match pytorch's cross_entropy
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()  # sampling never backpropagates, so do not build an autograd graph
    def generate(self, idx, max_new_tokens):
        """Append `max_new_tokens` sampled tokens to `idx` (shape (B, T)) and return the result."""
        for _ in range(max_new_tokens):
            # crop the context to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            # only the last position is used to pick the next token
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
    
model = LM()
m = model.to(device)  # the model moved to `device`
logits, loss = m(xb, yb)  # sanity-check forward pass on the batch drawn earlier
print(logits.shape)
print(loss.item())
print(logits[0])

In [None]:
# create a pytorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# train
# The loop counter is named `step` rather than `iter` so that the builtin
# iter() is not shadowed for the rest of the notebook session.
for step in range(max_iters):
    # periodically report the loss averaged over eval_iters freshly drawn batches
    if step % eval_interval == 0:
        loss = estimate_loss()
        print(f'step {step}: train loss {loss:.4f}')

    # one optimisation step on a random batch
    xb, yb = get_batch()
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


In [None]:
# sample from the model
# the prompt is block_size tokens of id 0, the same value chunk() pads sentences with
context = torch.zeros((1, block_size), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=300)[0].tolist()))

In [None]:
# backup to disk
# torch.save(model.state_dict(), WEIGHT_PATH)

In [None]:
# load from disk
# m2 = LM()
# m2.load_state_dict(torch.load(WEIGHT_PATH))
# m2 = m2.to(device)
# m2.eval()

# context = torch.zeros((1, block_size), dtype=torch.long, device=device)
# print(decode(m2.generate(context, max_new_tokens=300)[0].tolist()))

## Can we do anything with embeddings?

In [None]:
# Euclidean distance
def euclidean_dist(a, b):
    """Return the L2 distance between tensors `a` and `b` as a 0-d tensor."""
    diff = a - b
    return (diff ** 2).sum().sqrt()

# Cosine distance
def cosine_dist(a, b):
    """Return 1 - cos(angle between `a` and `b`).

    For 1-D vectors this matches
    `1 - torch.nn.functional.cosine_similarity(a, b, dim=0)` up to rounding.
    """
    # `.T` on a tensor that is not 2-D is deprecated in PyTorch, so transpose
    # only when `b` really is a matrix; for 1-D vectors `a @ b` is already the
    # dot product.
    b_t = b.T if b.dim() == 2 else b
    norms = torch.sqrt(torch.sum(a**2)) * torch.sqrt(torch.sum(b**2))
    return 1 - (a @ b_t) / norms

# Sanity checks on random vectors. Floating-point rounding makes the results
# only approximately equal to their ideal values, so they are compared with a
# tolerance instead of `==`.
a = torch.randn(5)
b = torch.randn(5)

assert torch.allclose(euclidean_dist(a, a), torch.zeros(()), atol=1e-6), 'identity'

assert torch.allclose(cosine_dist(a, a), torch.zeros(()), atol=1e-6), 'identity'
assert torch.allclose(cosine_dist(a, b), cosine_dist(b, a), atol=1e-6), 'commutativity'
assert torch.allclose(cosine_dist(a, b), 1 - torch.nn.functional.cosine_similarity(a, b, dim=0), atol=1e-6), 'check formula'

In [None]:
def get_embedding(word):
    """Return the embedding vector the model learned for `word`.

    Raises KeyError when `word` is not a key of `stoi`.
    """
    token = torch.tensor(stoi[word], dtype=torch.long, device=device)
    # The vector is only inspected, never trained through, so skip autograd
    # tracking; otherwise every distance computed from it would carry a graph.
    with torch.no_grad():
        return m.token_embedding_table(token)

king = get_embedding('king')
queen = get_embedding('queen')

print(f'{euclidean_dist(king, queen)=}')
print(f'{cosine_dist(king, queen)=}')

In [None]:
# seems pretty random to me: cabbage is closer to queen than king is :/
cabbage = get_embedding('cabbage')
print(f'{euclidean_dist(cabbage, queen)=}')
print(f'{cosine_dist(cabbage, queen)=}')

shadows = get_embedding('shadows')
print(f'{euclidean_dist(shadows, queen)=}')
print(f'{cosine_dist(shadows, queen)=}')

In [None]:
# compute all embeddings
embds = torch.stack([get_embedding(w) for w in vocab_list])  # one row per vocab_list entry, in the same order
embds.shape

In [None]:
# compute all pairwise distances
def pairwise_euclidean_distance(embds):
    """Return the (N, N) matrix of Euclidean distances between the rows of `embds`.

    Uses the expansion ||x - y||^2 = ||x||^2 - 2 x.y + ||y||^2, then takes the
    square root so the values agree with euclidean_dist().
    """
    sq_norms = torch.sum(embds**2, dim=1)  # (N,)
    gram = embds @ embds.T  # (N, N) dot products
    sq_dist = sq_norms.view(-1, 1) - 2 * gram + sq_norms
    # rounding can push tiny squared distances slightly below zero; clamp before sqrt
    return torch.sqrt(torch.clamp(sq_dist, min=0))

pwed = pairwise_euclidean_distance(embds)  
print(pwed.shape)



In [None]:
# Compute the closest word for each other word in the vocab:
#
# This is totally nonsensical :(
# I don't see any pattern in the results.
# hyp 1: the dataset is too sparse, with only a few mentions of each word
# hyp 2: the model is too small and is not able to learn anything
# hyp 3: unknown unknown, I messed something up
e = pwed
# Only the diagonal is masked with +inf, so that a word cannot be its own
# nearest neighbour. Masking the whole lower triangle instead would restrict
# each word's search to the words that come after it in the vocabulary.
mask = torch.diag(torch.full((e.shape[0],), float('inf'), dtype=e.dtype, device=e.device))
vals, ind = torch.min(e + mask, dim=1)

# one line per vocabulary word: the word followed by its nearest neighbour
for i, j in enumerate(ind):
    print(f'{itos[i]} {itos[j.item()]}')

# Try 2

## Setup

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
# hyperparameters
# NOTE(review): n_embd, learning_rate, max_iters and eval_interval are presumably
# consumed by the model/training cells further down; confirm there.
threshold = 10  # words seen fewer times than this are dropped from the corpus
context_size = 2 # 2 words on each side
n_embd = 96
batch_size = 64  # number of (word, context) rows drawn per get_batch() call
learning_rate = 1e-3
max_iters = 5000
eval_interval = 500
eval_iters = 100  # number of batches averaged by estimate_loss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when torch can see one

In [2]:
# read another dataset
# http://mattmahoney.net/dc/textdata.html

# the whole file is read into memory as a single string
with open('text8', 'r') as f:
    text = f.read()

print(f'{len(text)=}')
print(f'{len(text.split())=}')
print(f'alphabet = "{"".join(sorted(set(text)))}"')

# distinct whitespace-separated words, in whatever order the set yields them
vocab = list(set(text.split()))
print(f'{vocab[:10]=}')
print(f'{len(vocab)=}')

len(text)=100000000
len(text.split())=17005207
alphabet = " abcdefghijklmnopqrstuvwxyz"
vocab[:10]=['noho', 'bunia', 'cupped', 'elemond', 'elles', 'biometricians', 'shadowings', 'teodoro', 'sidor', 'ludwich']
len(vocab)=253854


In [3]:
# how crappy is my dataset ? :(
from collections import Counter

# looking at the sorted vocab gives me very low confidence in the dataset quality
print(f'{sorted(vocab)[:100]=}')
cs = Counter(text.split())  # word -> number of occurrences in the corpus
print(f'{cs.most_common(100)=}')
print(f'{cs["aaaaaacceglllnorst"]=}')

# lots of words are only mentioned once
ccs = Counter(cs.values())  # occurrence count -> number of distinct words with that count
print(ccs.most_common(10))

# on the plus side 'queen' and 'king' seem well represented
print(f'{cs["king"]=} {cs["queen"]=}')


sorted(vocab)[:100]=['a', 'aa', 'aaa', 'aaaa', 'aaaaaacceglllnorst', 'aaaaaaccegllnorrst', 'aaaaaah', 'aaaaaalmrsstt', 'aaaaaannrstyy', 'aaaaabbcdrr', 'aaaaargh', 'aaaargh', 'aaaassembly', 'aaab', 'aaabbbccc', 'aaahh', 'aaai', 'aaake', 'aaan', 'aaargh', 'aaas', 'aaate', 'aab', 'aababb', 'aabach', 'aabba', 'aabbcc', 'aabbirem', 'aabebwuvev', 'aabehlpt', 'aabmup', 'aabre', 'aabybro', 'aac', 'aaca', 'aacca', 'aaccording', 'aachen', 'aachener', 'aachtopf', 'aaci', 'aacis', 'aacisuan', 'aacplus', 'aacr', 'aacs', 'aacvd', 'aad', 'aadgad', 'aadl', 'aadlik', 'aadnani', 'aadvantage', 'aadyam', 'aaemu', 'aaf', 'aafc', 'aafjes', 'aafk', 'aafp', 'aag', 'aagaard', 'aagama', 'aagard', 'aage', 'aagesen', 'aagsin', 'aah', 'aahaaram', 'aahc', 'aahe', 'aahl', 'aahz', 'aai', 'aaib', 'aaiieee', 'aaimmah', 'aairpass', 'aaiun', 'aaiyangar', 'aaj', 'aajker', 'aak', 'aakirkeby', 'aakjaer', 'aakkram', 'aal', 'aalberg', 'aalborg', 'aalborghus', 'aalborgt', 'aalcc', 'aale', 'aalen', 'aalens', 'aalesund', 'aalesu

In [4]:
# let's butcher the dataset ¯\_(ツ)_/¯
# remove all the words that are mentioned fewer than `threshold` times
butchered_vocab = [w for w, c in cs.items() if c >= threshold]  # the words that are kept
butchered_vocab_s = set(butchered_vocab)  # set copy used for the membership test below
butchered_text = [w for w in text.split() if w in butchered_vocab_s]  # corpus without the dropped words

print(f'{len(butchered_vocab)=}')
print(f'{len(butchered_text)=}')

len(butchered_vocab)=47134
len(butchered_text)=16561031


In [18]:
# encode/decode helpers
vocab_size = len(butchered_vocab)
stoi = {w: i for i, w in enumerate(butchered_vocab)}  # word -> id (position in butchered_vocab)
itos = {i: w for w, i in stoi.items()}  # id -> word

def encode(ws):
    """Map an iterable of words to a 1-D long tensor of ids.

    Raises KeyError for a word that is not a key of `stoi`.
    """
    ids = []
    for word in ws:
        ids.append(stoi[word])
    return torch.tensor(ids, dtype=torch.long)

def decode(t):
    """Turn ids (a tensor, a list of ints or a single int) into a space-separated string."""
    if isinstance(t, torch.Tensor):
        t = t.tolist()
    if isinstance(t, int):
        # a 0-d tensor or a bare int stands for a single word
        t = [t]
    return ' '.join(itos[token_id] for token_id in t)

for xs in ['i for one welcome our new robot overlords', 'the chicken cross the road']:
    print(f'{encode(xs.split())=}')
    print(f'{decode(encode(xs.split()))=}')

encode(xs.split())=tensor([  412,   305,   192, 20460,   785,   439,  2217, 30480])
decode(encode(xs.split()))='i for one welcome our new robot overlords'
encode(xs.split())=tensor([   15, 15026,  3282,    15,  3098])
decode(encode(xs.split()))='the chicken cross the road'


In [32]:
# shape the data for training
# using the skip-gram method
def chunk(ws, window=None):
    """Build skip-gram training pairs from a 1-D tensor of word ids.

    Args:
        ws: 1-D tensor of word ids.
        window: number of context words taken on each side of the centre
            word; defaults to the global `context_size`.

    Returns:
        `(x, y)` where `x` has shape (n, 1) and holds the centre words and `y`
        has shape (n, 2 * window) and holds their context words, with
        n = len(ws) - 2 * window. Both are empty (n = 0) when `ws` is too
        short to contain a single full window.
    """
    window = context_size if window is None else window
    # miss a few words at the beginning and end of the text, w/e
    n = len(ws) - 2 * window
    if n <= 0:
        return ws.new_empty((0, 1)), ws.new_empty((0, 2 * window))

    # TODO: a possible optimization would be to probabilistically discard some of the most common words.
    # The paper suggests keeping a word with probability
    # $P(w_i) = ({\sqrt {z(w_i) \over 0.001} + 1}) . {0.001 \over z(w_i)}$
    # z(w_i) being the frequency of the word in the corpus

    # Row i of `windows` is ws[i: i + 2 * window + 1]: the context on the left,
    # the centre word, then the context on the right. Slicing columns replaces
    # a Python-level loop over every word.
    windows = ws.unfold(0, 2 * window + 1, 1)
    x = windows[:, window].clone().view(-1, 1)
    y = torch.cat((windows[:, :window], windows[:, window + 1:]), dim=1)
    return x, y

# NOTE(review): only the first 10 words of the corpus are chunked here, so X and Y
# hold just 6 samples, and get_batch() draws from these globals. Widen the slice
# before training for real.
X, Y = chunk(encode(butchered_text[:10]))
print(X.shape, Y.shape)

# eyeball the first few (centre word, context words) pairs
for i in range(3):
    print(X[i], Y[i])
    print(f'{decode(X[i])=} {decode(Y[i])=}')

torch.Size([6, 1]) torch.Size([6, 4])
tensor([2]) tensor([0, 1, 3, 4])
decode(X[i])='as' decode(Y[i])='anarchism originated a term'
tensor([3]) tensor([1, 2, 4, 5])
decode(X[i])='a' decode(Y[i])='originated as term of'
tensor([4]) tensor([2, 3, 5, 6])
decode(X[i])='term' decode(Y[i])='as a of abuse'


In [29]:
def get_batch():
    """Draw `batch_size` random rows of X and Y (with replacement) and move them to `device`."""
    rows = torch.randint(len(X), (batch_size,))
    centre_words = X[rows].to(device)
    context_words = Y[rows].to(device)
    return centre_words, context_words

xb, yb = get_batch()
print(xb[:2])
print(yb[:2])
print(f'{decode(xb[0])} -> {decode(yb[0])}')
print(f'{decode(xb[1])} -> {decode(yb[1])}')

tensor([[7],
        [2]], device='cuda:0')
tensor([[5, 6, 8, 9],
        [0, 1, 3, 4]], device='cuda:0')
first -> of abuse used against
as -> anarchism originated a term


In [30]:
@torch.no_grad()
def estimate_loss():
    """Return the mean loss of `model` over `eval_iters` batches from get_batch().

    The model is put in eval mode for the measurement and back in train mode afterwards.
    """
    model.eval()
    samples = torch.zeros(eval_iters)
    for k in range(eval_iters):
        _, batch_loss = model(*get_batch())
        samples[k] = batch_loss.item()
    mean_loss = samples.mean()
    model.train()
    return mean_loss

## Skip-gram model
Given a word, guess the `context_size` words on each side of it.
e.g. "I for one welcome our robot overlords"

welcome -> for, one, our, robot
