In [1]:
import numpy
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import nltk
from nltk.corpus import stopwords
import string

import random
# Set seed for reproducibility
seed = 2542
random.seed(seed)

from tqdm import tqdm

In [2]:
torch.cuda.is_available() #check for CUDA

True

In [3]:
# Set the custom NLTK data directory
nltk_data_dir = '../data/external/'
if not os.path.exists(nltk_data_dir):
    os.makedirs(nltk_data_dir)
nltk.data.path.append(nltk_data_dir)

# Download the stopwords into the specified directory
nltk.download('stopwords', download_dir=nltk_data_dir)

[nltk_data] Downloading package stopwords to ../data/external/...
[nltk_data]   Package stopwords is already up-to-date!


True

#### pre-processing

In [4]:
# Define the stopwords
stop_words = set(stopwords.words('english'))

In [5]:
with open('../data/raw/tiny-shakespeare.txt', 'r') as file:
    # Read the entire file content as a single string
    words = file.read()

In [6]:
# Split the text into words
# Remove punctuation
translator = str.maketrans('', '', string.punctuation)
words = words.translate(translator)

words = words.lower().split()

# Remove stopwords and punctuation
cleaned_words = [word for word in words if word.lower() not in stop_words and word not in string.punctuation]

In [7]:
unique_words = list(set(cleaned_words))

# Initialize the word_to_ix dictionary
word_to_ix = {word: idx for idx, word in enumerate(unique_words)}

# Initialize the ix_to_word dictionary by inverting word_to_ix
ix_to_word = {idx: word for word, idx in word_to_ix.items()}

cleaned_words_to_ix = [word_to_ix[word] for word in cleaned_words]

#### Dataset Hyperparameters

In [8]:
vocab_length = len(unique_words) #12178
print(vocab_length)

context_window = 2 #each side
n_negatives = 4

12718


#### Prepare dataset

In [9]:
dataset = []
vocab_set = set(word_to_ix.values())

for text_index in tqdm(range(vocab_length, len(cleaned_words) - vocab_length)):
    center_word_indexes = [cleaned_words_to_ix[text_index]]# * (2 * context_window)
    
    context_word_indexes = [
        cleaned_words_to_ix[text_index - context_index] 
        for context_index in range(1, context_window + 1)
    ] + [
        cleaned_words_to_ix[text_index + context_index] 
        for context_index in range(1, context_window + 1)
    ]
    
    negative_sample_set = vocab_set - set(context_word_indexes)
    negative_sample_indexes = random.sample(negative_sample_set, n_negatives)

    dataset.append([center_word_indexes, context_word_indexes, negative_sample_indexes])

since Python 3.9 and will be removed in a subsequent version.
  negative_sample_indexes = random.sample(negative_sample_set, n_negatives)
100%|███████████████████████████████████| 81939/81939 [00:11<00:00, 7351.62it/s]


In [10]:
dataset[0]

[[6202], [8689, 1146, 4299, 991], [1034, 1112, 10359, 2202]]

In [11]:
#Define dataset class for data loader
class Word2VecDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        center_word_indexes, context_word_indexes, negative_sample_indexes = self.data[idx]
        return (
            torch.tensor(center_word_indexes, dtype=torch.long),
            torch.tensor(context_word_indexes, dtype=torch.long),
            torch.tensor(negative_sample_indexes, dtype=torch.long)
        )

#### Word2Vec Class

In [12]:
class Word2Vec(nn.Module):
    def __init__(self, vocab_length, vector_length):
        super().__init__()

        # Vector Embeddings
        self.center_embeddings = nn.Embedding(vocab_length, vector_length)
        self.context_embeddings = nn.Embedding(vocab_length, vector_length)
        
        # Move all model parameters to GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

    def forward(self, center_indices, context_indices, negative_indices):
        # Get embeddings
        center_embeddings = self.center_embeddings(center_indices)  
        context_embeddings = self.context_embeddings(context_indices)  
        negative_embeddings = self.context_embeddings(negative_indices)  

        # Compute dot product for positive pairs
        positive_dot_products = torch.bmm(center_embeddings, context_embeddings.transpose(1, 2)).squeeze(1)  

        # Compute dot product for negative pairs
        negative_dot_products = torch.bmm(center_embeddings, negative_embeddings.transpose(1, 2)).squeeze(1)  

        return positive_dot_products, negative_dot_products

embedding_dim = 100  # Length of word vector
word2vec_model = Word2Vec(vocab_length, embedding_dim)

In [16]:
def train(model, data_loader, optimizer, criterion, device, epoch):
    model.train()
    total_loss = 0

    # Wrap the data_loader with tqdm to show the progress bar
    progress_bar = tqdm(data_loader, desc=f'Epoch {epoch}', leave=False)
    
    for center_indices, context_indices, negative_indices in progress_bar:
        print(center_indices.shape)
        center_indices, context_indices, negative_indices = center_indices.to(device), context_indices.to(device), negative_indices.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        pos_scores, neg_scores = model(center_indices, context_indices, negative_indices)
        
        # True labels - 1s for positive samples, 0s for negative samples
        positive_labels = torch.ones_like(pos_scores)
        negative_labels = torch.zeros_like(neg_scores)
        
        # Calculate loss
        loss_pos = criterion(pos_scores, positive_labels)
        loss_neg = criterion(neg_scores, negative_labels)
        loss = loss_pos + loss_neg
        
        # Backward and optimize
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss

In [17]:
# Create the word2vec dataset and DataLoader
word2vec_dataset = Word2VecDataset(dataset)
batch_size = 256 # You can adjust the batch size according to your needs
data_loader = DataLoader(word2vec_dataset, batch_size=batch_size, shuffle=True)

#Model Hyperparameters
lr = 1e-4

optimizer = torch.optim.Adam(word2vec_model.parameters(), lr=lr)

# Loss function
criterion = nn.BCEWithLogitsLoss()

#Epochs
epochs = 1000

#### Notes to self

#### Look at PyTorch Broadcasting Documentation -> \<operation> a row vector to all rows in a matrix

#### Training

In [18]:
lowest_loss = float('inf')

#for epoch in range(epochs):
for epoch in range(1):
    it_loss = train(model=word2vec_model, data_loader=data_loader, optimizer=optimizer, criterion=criterion,
                    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), epoch=epoch)
    print(f"Epoch {epoch}: Loss {it_loss}")

    # Check if the current epoch's loss is the lowest and save the model
    if it_loss < lowest_loss:
        lowest_loss = it_loss
        torch.save(word2vec_model.state_dict(), '../models/word2vec/best_model.pth')
        print(f"Saved new best model with loss: {lowest_loss}")

Epoch 0:  11%|███▍                            | 35/321 [00:00<00:01, 175.11it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


Epoch 0:  24%|███████▋                        | 77/321 [00:00<00:01, 196.06it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


Epoch 0:  37%|███████████▍                   | 119/321 [00:00<00:00, 203.77it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


Epoch 0:  50%|███████████████▌               | 161/321 [00:00<00:00, 201.69it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


Epoch 0:  57%|█████████████████▌             | 182/321 [00:00<00:00, 194.17it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


Epoch 0:  69%|█████████████████████▌         | 223/321 [00:01<00:00, 194.05it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


Epoch 0:  83%|█████████████████████████▌     | 265/321 [00:01<00:00, 200.28it/s]

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])


                                                                                

torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([256, 1])
torch.Size([19, 1])
Epoch 0: Loss 2561.9285049438477




Saved new best model with loss: 2561.9285049438477
