# nano word2vec

## Setup

In [1]:
import torch
from datasets import load_dataset

In [34]:
# https://huggingface.co/datasets/generics_kb

datasets = load_dataset("generics_kb", "generics_kb_simplewiki")
dataset = datasets["train"]
print(f'{len(dataset)=} {dataset[0].keys()=}')


charset_whitelist = 'abcdefghijklmnopqrstuvwxyz- '
def sanitize(s):
    return ''.join([c for c in s.lower() if c in charset_whitelist])

sentences = [sanitize(d['sentence']) for d in dataset]
print(f'{sentences[:3]=}')

vocab = set([w for s in sentences for w in s.split()])
print(f'{len(vocab)=} {list(vocab)[:3]=}')

# The sample size for each word seems really small so this dataset probably won't work at all.
# can I get a dataset specialized on fruits maybe, to do queries of the type `lemon - yellow + green = lime`
queen = [s for s in sentences if 'queen' in s]
print(f'{len(queen)=} {queen[:3]=}')

len(dataset)=12765 dataset[0].keys()=dict_keys(['source_name', 'sentence', 'sentences_before', 'sentences_after', 'concept_name', 'quantifiers', 'id', 'bert_score', 'headings', 'categories'])
sentences[:3]=['sepsis happens when the bacterium enters the blood and make it form tiny clots', 'incubation period is only one to two days', 'scuba diving is a common tourist activity']
len(vocab)=13477 list(vocab)[:3]=['occasionally', 'technological', 'welding']
len(charset)=27 sorted(list(charset))=['-', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
len(queen)=4 queen[:3]=['monarch is a word that means king or queen', 'pregnant queens deliver their litters by themselves guided by instinct', 'most ant species have a system in which only the queen and breeding females can mate']


In [41]:
stoi = {w: i for i, w in enumerate(vocab)}
itos = {i: w for w, i in stoi.items()}

def encode(s):
    return torch.tensor([stoi[w] for w in sanitize(s).split()])

def decode(t):
    return ' '.join([itos[i.item()] for i in t])

# careful here if we use words outside of vocab it'll explode
xs = 'The chicken cross the road'
print(f'{encode(xs)=}')
print(f'{decode(encode(xs))=}')

encode(xs)=tensor([8949, 4065,  612, 8949, 9489])
decode(encode(xs))='the chicken cross the road'
