# nano word2vec

## Setup

In [1]:
import torch
from datasets import load_dataset

In [66]:
# hyperparameters
block_size = 10
batch_size = 32


In [67]:
# https://huggingface.co/datasets/generics_kb

datasets = load_dataset("generics_kb", "generics_kb_simplewiki")
dataset = datasets["train"]
print(f'{len(dataset)=} {dataset[0].keys()=}')


charset_whitelist = 'abcdefghijklmnopqrstuvwxyz- '
def sanitize(s):
    return ''.join([c for c in s.lower() if c in charset_whitelist])

sentences = [sanitize(d['sentence']) for d in dataset]
print(f'{sentences[:3]=}')
print(f'{max([len(s.split()) for s in sentences])=}')

vocab = set([w for s in sentences for w in s.split()])
print(f'{len(vocab)=} {list(vocab)[:3]=}')

# The sample size for each word seems really small so this dataset probably won't work at all.
# can I get a dataset specialized on fruits maybe, to do queries of the type `lemon - yellow + green = lime`
queen = [s for s in sentences if 'queen' in s]
print(f'{len(queen)=} {queen[:3]=}')

len(dataset)=12765 dataset[0].keys()=dict_keys(['source_name', 'sentence', 'sentences_before', 'sentences_after', 'concept_name', 'quantifiers', 'id', 'bert_score', 'headings', 'categories'])
sentences[:3]=['sepsis happens when the bacterium enters the blood and make it form tiny clots', 'incubation period is only one to two days', 'scuba diving is a common tourist activity']
max([len(s.split()) for s in sentences])=22
len(vocab)=13477 list(vocab)[:3]=['occasionally', 'technological', 'welding']
len(queen)=4 queen[:3]=['monarch is a word that means king or queen', 'pregnant queens deliver their litters by themselves guided by instinct', 'most ant species have a system in which only the queen and breeding females can mate']


In [98]:
vocab_list = ['<end>', '<???>'] + list(vocab)
stoi = {w: i for i, w in enumerate(vocab_list)}
itos = {i: w for w, i in stoi.items()}

def encode(s):
    return torch.tensor([stoi.get(w, 1) for w in sanitize(s).split() + ['<end>']], dtype=torch.long)

def decode(t):
    return ' '.join([itos[i.item()] for i in t])

# careful here if we use words outside of vocab it'll explode
for xs in ['I for one welcome our new robot overlords', 'The chicken cross the road']:
    print(f'{encode(xs)=}')
    print(f'{decode(encode(xs))=}')

encode(xs)=tensor([    1, 10912,  3840, 12269,  9667,  8109,     1,     1,     0])
decode(encode(xs))='<???> for one welcome our new <???> <???> <end>'
encode(xs)=tensor([8951, 4067,  614, 8951, 9491,    0])
decode(encode(xs))='the chicken cross the road <end>'


In [99]:
# shape the data for training
def chunk(s):
    s = torch.cat((torch.zeros(block_size, dtype=torch.long), s))
    for i in range(0, len(s) - block_size):
        yield s[i: i + block_size], s[i + block_size: i + block_size + 1]

chunked = [c for s in sentences for c in chunk(encode(s))]
Xtrain = [c[0] for c in chunked]
Ytrain = [c[1] for c in chunked]

for i in range(3):
    print(Xtrain[i], Ytrain[i])
    print(f'{decode(Xtrain[i])=} {decode(Ytrain[i])=}')

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) tensor([6255])
decode(Xtrain[i])='<end> <end> <end> <end> <end> <end> <end> <end> <end> <end>' decode(Ytrain[i])='sepsis'
tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0, 6255]) tensor([8277])
decode(Xtrain[i])='<end> <end> <end> <end> <end> <end> <end> <end> <end> sepsis' decode(Ytrain[i])='happens'
tensor([   0,    0,    0,    0,    0,    0,    0,    0, 6255, 8277]) tensor([10733])
decode(Xtrain[i])='<end> <end> <end> <end> <end> <end> <end> <end> sepsis happens' decode(Ytrain[i])='when'


In [116]:
def get_batch():
    # TODO: swap between train and val
    ix = torch.randint(len(Xtrain), (batch_size,))
    x = torch.stack([Xtrain[i] for i in ix])
    y = torch.stack([Ytrain[i] for i in ix])
    # x, y = x.to(device), y.to(device)
    return x, y

xb, yb = get_batch()
print(xb[:2])
print(yb[:2])
print(f'{decode(xb[0])} -> {decode(yb[0])}')
print(f'{decode(xb[1])} -> {decode(yb[1])}')


tensor([[    0,     0,     0,     0,     0,     0,     0,     0, 10070,  8209],
        [    0,     0,     0,     0,     0, 10493,  9904,  6968,  1430,  5350]])
tensor([[3971],
        [7893]])
<end> <end> <end> <end> <end> <end> <end> <end> most humans -> have
<end> <end> <end> <end> <end> pyroelectricity is also a necessary -> consequence
