<p style="text-align: center; font-size:50px;">Word2Vec</p>

#### In this notebook, I will be attempting to implement a simple word2vec on the wikitext dataset. 
#### By vectorizing and embedding words in a certain dimensional vector, I hope to be able to capture word's meanings and context. 
#### This specific technique is relatively outdated but the whole field of vectorizing words and embedding them to a certain dimensional vector is paramount to any Language Processing tasks. 

In [10]:
from datasets import load_dataset
import gc
dataset = load_dataset("wikitext", "wikitext-2-v1")

Found cached dataset wikitext (/Users/kimhyunbin/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

In [11]:
import torch  
import torch.nn as nn 
embedding_dim = 300 
norm = 1 # max_norm is currently not supported on the mps backend

class Word2Vec(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.linear = nn.Linear(in_features=embedding_dim, out_features = vocab_size)

    def forward(self, x):
        x = self.embeddings(x)
        x = x.mean(axis=1)
        x = self.linear(x)
        return x

### Building up a vocab

In [12]:
from torchtext.vocab import build_vocab_from_iterator 
MIN_WORD_FREQUENCY = 50

def build_vocab(data_iter, tokenizer):
    vocab = build_vocab_from_iterator(
        map(tokenizer, data_iter),
        specials=["<unk>"],
        min_freq=MIN_WORD_FREQUENCY,
    )
    vocab.set_default_index(vocab["<unk>"])
    return vocab

In [13]:
train_set = dataset['train']['text']
val_set = dataset['validation']['text']
del dataset
gc.collect()

1177

In [14]:
from torchtext.data import get_tokenizer
tokenizer = get_tokenizer("basic_english")

vocab = build_vocab(train_set, tokenizer)

In [15]:
print(f"We have {len(vocab)} words in our vocab")

We have 4099 words in our vocab


#### Making our custom DataLoader

In [16]:
import torch 
CBOW_N_WORDS = 4 
MAX_SEQUENCE_LENGTH = 256  
def collate_cbow(batch, text_pipeline):
     batch_input, batch_output = [], []
     for text in batch:
         text_tokens_ids = text_pipeline(text)
         if len(text_tokens_ids) < CBOW_N_WORDS * 2 + 1:
             continue
         if MAX_SEQUENCE_LENGTH:
             text_tokens_ids = text_tokens_ids[:MAX_SEQUENCE_LENGTH]
         for idx in range(len(text_tokens_ids) - CBOW_N_WORDS * 2):
             token_id_sequence = text_tokens_ids[idx : (idx + CBOW_N_WORDS * 2 + 1)]
             output = token_id_sequence.pop(CBOW_N_WORDS)
             input_ = token_id_sequence
             batch_input.append(input_)
             batch_output.append(output)
     
     batch_input = torch.tensor(batch_input, dtype=torch.long)
     batch_output = torch.tensor(batch_output, dtype=torch.long)
     return batch_input, batch_output

text_pipeline = lambda x: vocab(tokenizer(x))

from torch.utils.data import DataLoader 
from functools import partial  

train_dataloader = DataLoader(
         train_set,
         batch_size=32,
         shuffle=True,         
         collate_fn=partial(collate_cbow, text_pipeline=text_pipeline),
)
val_dataloader = DataLoader(
         val_set,
         batch_size=32,
         shuffle=False,         
         collate_fn=partial(collate_cbow, text_pipeline=text_pipeline),
)

### Training Loop

In [17]:
import torchinfo
from torchinfo import summary

device = torch.device('mps')
print(f"Device set to {device}")

model = Word2Vec(len(vocab)).to(device)
summary(model)

Device set to mps


Layer (type:depth-idx)                   Param #
Word2Vec                                 --
├─Embedding: 1-1                         1,229,700
├─Linear: 1-2                            1,233,799
Total params: 2,463,499
Trainable params: 2,463,499
Non-trainable params: 0

In [18]:
from tqdm import tqdm

lr = 0.025
epochs = 10
log_interval = 100
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params = model.parameters(), lr = lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda = lambda epoch: 0.95 ** epoch)

for epoch in tqdm(range(epochs)):
    print(f"====Epoch {epoch}====")
    model.train()
    train_loss, train_count = 0, 0
    for batch, (X, y) in enumerate(train_dataloader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_count += 1

        if (batch+1) % log_interval == 0:
            print(f"Batch: {batch+1} | Training Loss: {train_loss/train_count}")

    # Validation 
    model.eval()
    val_loss, val_count = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(val_dataloader):
            X, y = X.to(device), y.to(device)
            outputs = model(X)
            loss = criterion(outputs, y)
            val_loss += loss.item()
            val_count += 1 
    print(f"Training Loss: {train_loss/train_count} | Validation Loss: {val_loss/val_count}")

    scheduler.step()

  0%|          | 0/10 [00:00<?, ?it/s]

====Epoch 0====
Batch: 100 | Training Loss: 5.700555171966553
Batch: 200 | Training Loss: 5.460875282287597
Batch: 300 | Training Loss: 5.327524420420329
Batch: 400 | Training Loss: 5.246071569919586
Batch: 500 | Training Loss: 5.195186476707459
Batch: 600 | Training Loss: 5.154842410087586
Batch: 700 | Training Loss: 5.12525526864188
Batch: 800 | Training Loss: 5.096050428152084
Batch: 900 | Training Loss: 5.0706856139500935
Batch: 1000 | Training Loss: 5.050419266700745
Batch: 1100 | Training Loss: 5.033205549933694


 10%|█         | 1/10 [02:09<19:23, 129.29s/it]

Training Loss: 5.025709538509622 | Validation Loss: 4.848667940850985
====Epoch 1====
Batch: 100 | Training Loss: 4.43030499458313
Batch: 200 | Training Loss: 4.46770458817482
Batch: 300 | Training Loss: 4.501820142269135
Batch: 400 | Training Loss: 4.521402886509895
Batch: 500 | Training Loss: 4.535906949520111
Batch: 600 | Training Loss: 4.55255509018898
Batch: 700 | Training Loss: 4.566694319929395
Batch: 800 | Training Loss: 4.576878043115139
Batch: 900 | Training Loss: 4.590406875875261
Batch: 1000 | Training Loss: 4.598806721925736
Batch: 1100 | Training Loss: 4.606646579178896


 20%|██        | 2/10 [04:17<17:10, 128.83s/it]

Training Loss: 4.610378376490563 | Validation Loss: 4.8840721825421864
====Epoch 2====
Batch: 100 | Training Loss: 4.245000936985016
Batch: 200 | Training Loss: 4.267647408246994
Batch: 300 | Training Loss: 4.3123674734433495
Batch: 400 | Training Loss: 4.3421886259317395
Batch: 500 | Training Loss: 4.37359833574295
Batch: 600 | Training Loss: 4.3952649104595185
Batch: 700 | Training Loss: 4.415440137386322
Batch: 800 | Training Loss: 4.435311916768551
Batch: 900 | Training Loss: 4.450406462086572
Batch: 1000 | Training Loss: 4.461797515153885
Batch: 1100 | Training Loss: 4.475394317670302


 30%|███       | 3/10 [06:31<15:16, 130.99s/it]

Training Loss: 4.481946669181465 | Validation Loss: 4.907598487401413
====Epoch 3====
Batch: 100 | Training Loss: 4.106807301044464
Batch: 200 | Training Loss: 4.170511000156402
Batch: 300 | Training Loss: 4.220629930496216
Batch: 400 | Training Loss: 4.250997334718704
Batch: 500 | Training Loss: 4.285760184288025
Batch: 600 | Training Loss: 4.308191334406535
Batch: 700 | Training Loss: 4.329201340675354
Batch: 800 | Training Loss: 4.3495030874013905
Batch: 900 | Training Loss: 4.363507532013787
Batch: 1000 | Training Loss: 4.376463479042053
Batch: 1100 | Training Loss: 4.3924466566606


 40%|████      | 4/10 [08:37<12:54, 129.13s/it]

Training Loss: 4.398389753979673 | Validation Loss: 4.922367108070244
====Epoch 4====
Batch: 100 | Training Loss: 4.0406603074073795
Batch: 200 | Training Loss: 4.104192316532135
Batch: 300 | Training Loss: 4.14416809240977
Batch: 400 | Training Loss: 4.1739281618595125
Batch: 500 | Training Loss: 4.209615782737732
Batch: 600 | Training Loss: 4.235032593011856
Batch: 700 | Training Loss: 4.256171023845672
Batch: 800 | Training Loss: 4.27484176069498
Batch: 900 | Training Loss: 4.289953530364566
Batch: 1000 | Training Loss: 4.304309509038925
Batch: 1100 | Training Loss: 4.317463500499725


 50%|█████     | 5/10 [10:27<10:10, 122.16s/it]

Training Loss: 4.32297665669943 | Validation Loss: 4.947966721098302
====Epoch 5====
Batch: 100 | Training Loss: 3.993021776676178
Batch: 200 | Training Loss: 4.038391724824906
Batch: 300 | Training Loss: 4.073778836727143
Batch: 400 | Training Loss: 4.110183302760124
Batch: 500 | Training Loss: 4.139734154224396
Batch: 600 | Training Loss: 4.168943432569503
Batch: 700 | Training Loss: 4.190500344889505
Batch: 800 | Training Loss: 4.212404606044292
Batch: 900 | Training Loss: 4.227655019495223
Batch: 1000 | Training Loss: 4.242129618406296
Batch: 1100 | Training Loss: 4.256721100156957


 60%|██████    | 6/10 [12:18<07:53, 118.32s/it]

Training Loss: 4.261824580227456 | Validation Loss: 4.968322523569657
====Epoch 6====
Batch: 100 | Training Loss: 3.93473486661911
Batch: 200 | Training Loss: 3.982893364429474
Batch: 300 | Training Loss: 4.022195676962535
Batch: 400 | Training Loss: 4.051426775455475
Batch: 500 | Training Loss: 4.082562015533448
Batch: 600 | Training Loss: 4.111786295572917
Batch: 700 | Training Loss: 4.135510178293501
Batch: 800 | Training Loss: 4.153130821287632
Batch: 900 | Training Loss: 4.174414698547787
Batch: 1000 | Training Loss: 4.190364904642105
Batch: 1100 | Training Loss: 4.204964989965612


 70%|███████   | 7/10 [14:00<05:38, 112.98s/it]

Training Loss: 4.209700869143217 | Validation Loss: 4.976158081474951
====Epoch 7====
Batch: 100 | Training Loss: 3.8992581605911254
Batch: 200 | Training Loss: 3.938395303487778
Batch: 300 | Training Loss: 3.981775953769684
Batch: 400 | Training Loss: 4.015457997918129
Batch: 500 | Training Loss: 4.045438268661499
Batch: 600 | Training Loss: 4.069670852820079
Batch: 700 | Training Loss: 4.088676358972277
Batch: 800 | Training Loss: 4.106462588906288
Batch: 900 | Training Loss: 4.12276763147778
Batch: 1000 | Training Loss: 4.13587255358696
Batch: 1100 | Training Loss: 4.150331523418426


 80%|████████  | 8/10 [15:42<03:39, 109.59s/it]

Training Loss: 4.154829735747613 | Validation Loss: 4.994739083920495
====Epoch 8====
Batch: 100 | Training Loss: 3.8436225986480714
Batch: 200 | Training Loss: 3.893544706106186
Batch: 300 | Training Loss: 3.93956129471461
Batch: 400 | Training Loss: 3.9707255399227144
Batch: 500 | Training Loss: 3.992917022705078
Batch: 600 | Training Loss: 4.016055106719335
Batch: 700 | Training Loss: 4.035325904573713
Batch: 800 | Training Loss: 4.053538103401661
Batch: 900 | Training Loss: 4.071452567577362
Batch: 1000 | Training Loss: 4.08826977467537
Batch: 1100 | Training Loss: 4.102739448764107


 90%|█████████ | 9/10 [17:17<01:45, 105.14s/it]

Training Loss: 4.1097867250027145 | Validation Loss: 5.028427144228401
====Epoch 9====
Batch: 100 | Training Loss: 3.8041112923622133
Batch: 200 | Training Loss: 3.8516740250587462
Batch: 300 | Training Loss: 3.8871687014897662
Batch: 400 | Training Loss: 3.9212903106212615
Batch: 500 | Training Loss: 3.9474253516197204
Batch: 600 | Training Loss: 3.9703133726119995
Batch: 700 | Training Loss: 3.9924007661002023
Batch: 800 | Training Loss: 4.014044698178768
Batch: 900 | Training Loss: 4.033193284935422
Batch: 1000 | Training Loss: 4.049390785694122
Batch: 1100 | Training Loss: 4.0628356998617


100%|██████████| 10/10 [18:48<00:00, 112.90s/it]

Training Loss: 4.069798013682149 | Validation Loss: 5.048944719767166





### Checking out the embeddings and relationships between different words

In [27]:
import numpy as np

# embedding from first model layer
embeddings = list(model.parameters())[0]
embeddings = embeddings.cpu().detach().numpy()

# normalization
norms = (embeddings ** 2).sum(axis=1) ** (1 / 2)
norms = np.reshape(norms, (len(norms), 1))
embeddings_norm = embeddings / norms
embeddings_norm.shape

(4099, 300)

In [50]:
def get_top_words(word, n = 5):
    idx = vocab[word]
    if idx == 0:
        print("Out of vocab word")
        return 

    word_embed = embeddings_norm[idx]
    distance = np.matmul(embeddings_norm, word_embed.reshape(-1,1)).flatten()
    descending_order = np.argsort(-1 * distance)[1: n+1]

    top_dict = {}
    for sim_word_id in descending_order:
        sim_word = vocab.lookup_token(sim_word_id)
        top_dict[sim_word] = distance[sim_word_id]
    return top_dict

In [62]:
for word, sim in get_top_words("queen").items():
    if word == None:
        print("Out of vocab")
    else:
        print("{}: {:.3f}".format(word, sim))

desire: 0.239
jane: 0.236
richard: 0.235
house: 0.227
battle: 0.227


### Let us try the famous King - Man + Woman = Queen problem.

In [69]:
unknown_embeddings = embeddings_norm[vocab['king']] - embeddings_norm[vocab['man']] + embeddings_norm[vocab['woman']]

In [73]:
def get_top_words_from_embedding(embedding, n = 5):

    distance = np.matmul(embeddings_norm, embedding.reshape(-1,1)).flatten()
    descending_order = np.argsort(-1 * distance)[1: n+1]

    top_dict = {}
    for sim_word_id in descending_order:
        sim_word = vocab.lookup_token(sim_word_id)
        top_dict[sim_word] = distance[sim_word_id]
    return top_dict

In [74]:
for word, sim in get_top_words_from_embedding(unknown_embeddings).items():
    if word == None:
        print("Out of vocab")
    else:
        print("{}: {:.3f}".format(word, sim))

woman: 0.755
prerogative: 0.386
goddess: 0.385
patrick: 0.359
church: 0.326


### Sadly, we didn't get the Queen that we were looking for. 
### Given that we only trained on a small dataset, it seems pretty reasonable. 
### With a bigger dataset and a more complex model architecture, I believe we can accomplish that task. 