# Word2Vec


## Preliminaries


In [1]:
%%capture
%pip install -r requirements.txt

In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import numpy as np
from typing import List, Dict, Tuple, Optional
import logging
from pathlib import Path
from datasets import load_dataset, DatasetDict

Create the directories used for data and saved models:


In [16]:
# Output folders for data and saved checkpoints; exist_ok makes re-running the cell harmless
for directory in ("./data/", "./models/"):
    Path(directory).mkdir(exist_ok=True)

## Models


In [3]:
class Word2VecModel(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, is_skip_gram: bool = True):
        """
        Word2Vec model implementation in PyTorch

        Index 0 is reserved for padding. Its embedding row starts at zero,
        receives no gradient, and is excluded from the CBOW context average.

        Args:
            vocab_size: Size of vocabulary (including the padding index 0)
            embedding_dim: Dimension of word embeddings
            is_skip_gram: If True, use Skip-gram model. If False, use CBOW
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.is_skip_gram = is_skip_gram

        # Input embedding layer. Index 0 is the padding slot: the dataset in this
        # notebook maps "<pad>" and out-of-vocabulary words to 0 and pads CBOW
        # contexts with 0. padding_idx keeps that row out of gradient updates.
        # (An empty vocabulary has no row 0, so no padding index is set then.)
        self.embedding = nn.Embedding(
            vocab_size, embedding_dim, padding_idx=0 if vocab_size > 0 else None
        )

        # Output layer: projects an embedding to one logit per vocabulary word
        self.output = nn.Linear(embedding_dim, vocab_size)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize embeddings and linear layer weights"""
        initrange = 0.5 / self.embedding_dim
        self.embedding.weight.data.uniform_(-initrange, initrange)
        if self.embedding.padding_idx is not None:
            # uniform_ overwrote the zero padding row created by nn.Embedding; restore it
            self.embedding.weight.data[self.embedding.padding_idx].zero_()
        self.output.weight.data.uniform_(-initrange, initrange)
        self.output.bias.data.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the model

        Args:
            x: Input tensor of word indices
               For Skip-gram: shape (batch_size, 1)
               For CBOW: shape (batch_size, context_size); entries equal to 0
               are padding and are ignored when averaging the context

        Returns:
            Output logits of shape (batch_size, vocab_size)
        """
        embeds = self.embedding(x)
        if self.is_skip_gram:
            return self.output(embeds.squeeze(1))

        # CBOW: mean over the real (non-padding) context words only, so padded
        # slots do not dilute the average. A context with no real words yields a
        # zero vector (clamp avoids dividing by zero).
        mask = (x != 0).unsqueeze(-1).to(embeds.dtype)
        counts = mask.sum(dim=1).clamp(min=1.0)
        context_mean = (embeds * mask).sum(dim=1) / counts
        return self.output(context_mean)

## Dataset


In [4]:
class Word2VecDataset(Dataset):
    def __init__(
        self,
        texts: List[List[str]],
        window_size: int,
        min_count: int,
        is_skip_gram: bool,
    ):
        """
        Dataset for training Word2Vec model

        Builds a vocabulary from ``texts`` and precomputes every (input, target)
        training pair.

        Args:
            texts: List of tokenized texts (each text is a list of word strings)
            window_size: Number of words taken on each side of the center word
            min_count: Minimum frequency for words to be included
            is_skip_gram: If True, generate Skip-gram pairs. If False, generate CBOW pairs

        Attributes:
            vocab: word -> index. "<pad>" is 0 and kept words get the contiguous
                indices 1..len(vocab)-1, so len(vocab) is a valid embedding size.
                Words below ``min_count`` are absent and map to 0 when encoding.
            inverse_vocab: index -> word
            pairs: list of (input, target) tensor pairs
        """
        self.window_size = window_size
        self.is_skip_gram = is_skip_gram

        # Build vocabulary. Only words that survive min_count filtering are
        # numbered: enumerating every counted word (filtered ones included) leaves
        # gaps, so indices could reach past len(vocab) and overflow the embedding
        # table sized from it. A literal "<pad>" token in the corpus is excluded
        # so it cannot collide with the reserved padding index.
        word_counts = Counter(word for text in texts for word in text)
        kept_words = [
            word
            for word, count in word_counts.items()
            if count >= min_count and word != "<pad>"
        ]
        self.vocab = {word: idx for idx, word in enumerate(kept_words, start=1)}
        self.vocab["<pad>"] = 0  # Reserve 0 for padding / out-of-vocabulary words
        self.inverse_vocab = {idx: word for word, idx in self.vocab.items()}

        # Generate training pairs
        self.pairs = self._generate_pairs(texts)

    def _generate_pairs(
        self, texts: List[List[str]]
    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Generate input-target pairs for training.

        Skip-gram: one (tensor([center]), tensor(context_word)) pair per real context word.
        CBOW: one (tensor(padded_context), tensor(center)) pair per center word, with
        the context right-padded with 0 to ``2 * window_size`` entries.
        """
        pairs = []
        ctx_size = 2 * self.window_size

        for text in texts:
            word_indices = [self.vocab.get(word, 0) for word in text]

            for i, center in enumerate(word_indices):
                # Generate context window
                context_start = max(0, i - self.window_size)
                context_end = min(len(word_indices), i + self.window_size + 1)
                context = (
                    word_indices[context_start:i] + word_indices[i + 1 : context_end]
                )

                if not context:
                    continue

                # Out-of-vocabulary / padding words (index 0) are never used as
                # a center word, in either training mode.
                if center == 0:
                    continue

                if self.is_skip_gram:
                    # Skip-gram: predict context words from center word
                    for ctx in context:
                        if ctx != 0:  # Skip padding
                            pairs.append((torch.tensor([center]), torch.tensor(ctx)))
                else:
                    # CBOW: predict center word from context words
                    if not any(context):
                        # Every neighbour is padding/OOV: no signal to learn from
                        continue

                    # Pad context to fixed size
                    ctx_padded = (context + [0] * ctx_size)[:ctx_size]
                    pairs.append((torch.tensor(ctx_padded), torch.tensor(center)))

        return pairs

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.pairs[idx]

## Train Loop


In [5]:
class Word2VecTrainer:
    def __init__(
        self,
        embedding_dim: int = 100,
        window_size: int = 5,
        min_count: int = 5,
        batch_size: int = 32,
        epochs: int = 5,
        learning_rate: float = 0.001,
        is_skip_gram: bool = True,
    ):
        """
        Initialize Word2Vec trainer

        Args:
            embedding_dim: Dimension of word embeddings
            window_size: Size of context window
            min_count: Minimum frequency of words to consider
            batch_size: Training batch size
            epochs: Number of training epochs
            learning_rate: Learning rate for optimizer
            is_skip_gram: If True, use Skip-gram model. If False, use CBOW
        """
        self.embedding_dim = embedding_dim
        self.window_size = window_size
        self.min_count = min_count
        self.batch_size = batch_size
        self.epochs = epochs
        self.learning_rate = learning_rate
        self.is_skip_gram = is_skip_gram

        # Setup logging (configures the root logger for the whole process)
        logging.basicConfig(
            format="%(asctime)s : %(levelname)s : %(message)s", level=logging.INFO
        )

    def train(
        self,
        texts: List[List[str]],
        output_path: Optional[str] = None,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ) -> Tuple[Word2VecModel, Dict[str, torch.Tensor]]:
        """
        Train Word2Vec model

        Args:
            texts: List of tokenized texts
            output_path: Optional path to save trained model
            device: Device to train on ('cuda' or 'cpu')

        Returns:
            Trained model and word embeddings dictionary. The embedding tensors
            are detached CPU copies, independent of the model's parameters.

        Raises:
            ValueError: If no training pairs can be generated from ``texts``
                (e.g. every word falls below ``min_count``).
        """
        # Create dataset
        dataset = Word2VecDataset(
            texts, self.window_size, self.min_count, self.is_skip_gram
        )
        if len(dataset) == 0:
            # Fail early with an actionable message instead of an opaque sampler
            # error or a division by zero when averaging the epoch loss.
            raise ValueError(
                "No training pairs were generated from the input texts; "
                "check min_count and window_size"
            )

        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        num_batches = len(dataloader)

        # Initialize model
        model = Word2VecModel(
            len(dataset.vocab), self.embedding_dim, self.is_skip_gram
        ).to(device)

        # Initialize optimizer and loss
        optimizer = optim.Adam(model.parameters(), lr=self.learning_rate)
        criterion = nn.CrossEntropyLoss()

        # Training loop
        print(f"Training Word2Vec model on {len(texts)} texts...")
        model.train()
        for epoch in range(self.epochs):
            total_loss = 0.0
            for batch_idx, (x, y) in enumerate(dataloader, start=1):
                x, y = x.to(device), y.to(device)

                # Forward pass
                optimizer.zero_grad()
                output = model(x)
                loss = criterion(output, y)

                # Backward pass
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

                if batch_idx % 100 == 0:
                    print(
                        f"Epoch {epoch + 1}/{self.epochs}, "
                        f"Batch {batch_idx}/{num_batches}, "
                        f"Loss: {total_loss / batch_idx:.4f}"
                    )

            print(
                f"Epoch {epoch + 1} completed, "
                f"Average Loss: {total_loss / num_batches:.4f}"
            )

        # Save model if output path provided
        if output_path:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "vocab": dataset.vocab,
                    "embedding_dim": self.embedding_dim,
                    "is_skip_gram": self.is_skip_gram,
                },
                output_path,
            )
            print(f"Model saved to {output_path}")

        # Create word embeddings dictionary. On CPU, .cpu() returns the very same
        # tensor (no copy), so without clone() the returned vectors would alias the
        # model's live weights and change if the model is trained further.
        weights = model.embedding.weight.detach().cpu().clone()
        embeddings = {word: weights[idx] for word, idx in dataset.vocab.items()}

        return model, embeddings

    def load_model(
        self, path: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ) -> Tuple[Word2VecModel, Dict[str, int]]:
        """
        Load saved model

        Args:
            path: Path to saved model (as written by ``train``)
            device: Device to load model on

        Returns:
            Loaded model and vocabulary
        """
        # SECURITY NOTE(review): torch.load may unpickle arbitrary objects depending
        # on the installed torch version's `weights_only` default; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=device)

        model = Word2VecModel(
            len(checkpoint["vocab"]),
            checkpoint["embedding_dim"],
            checkpoint["is_skip_gram"],
        ).to(device)

        model.load_state_dict(checkpoint["model_state_dict"])
        return model, checkpoint["vocab"]

## Utilities


In [6]:
def find_similar(
    word: str, embeddings: Dict[str, torch.Tensor], n: int = 5
) -> List[Tuple[str, float]]:
    """Return the ``n`` words whose embeddings are most cosine-similar to ``word``.

    Args:
        word: Query word.
        embeddings: Mapping from word to a 1-D embedding vector; all vectors
            must share the same length and dtype.
        n: Maximum number of neighbours to return.

    Returns:
        Up to ``n`` ``(word, cosine_similarity)`` pairs, most similar first. The
        query word itself is excluded. Returns ``[]`` when ``word`` is not in
        ``embeddings``, when there are no other words, or when ``n <= 0``.
    """
    if n <= 0 or word not in embeddings:
        return []

    candidates = [w for w in embeddings if w != word]
    if not candidates:
        return []

    # One batched call instead of a Python-level loop per word: the (1, D) query
    # broadcasts against the (V, D) matrix, giving one score per candidate.
    query = embeddings[word].unsqueeze(0)
    matrix = torch.stack([embeddings[w] for w in candidates])
    scores = nn.functional.cosine_similarity(query, matrix, dim=1)

    ranked = sorted(zip(candidates, scores.tolist()), key=lambda pair: pair[1], reverse=True)
    return ranked[:n]

In [47]:
# Download TinyStories (train / validate / test splits) with the output format set to "torch".
# Note: the training demo further down uses a tiny hand-written corpus, not `ds`.
ds = load_dataset("sdwalker62/TinyStoriesWithValidationSet").with_format("torch")

`ds` is a `DatasetDict` object containing three splits: `train` for training, `validate` for tuning, and `test` for final evaluation:


In [49]:
# Display the available splits and their row counts
ds

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2098521
    })
    validate: Dataset({
        features: ['text'],
        num_rows: 21198
    })
    test: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})

In [None]:
# Peek at the first training example: a dict with a single "text" field
ds["train"][0]

{'text': 'Timmy was talking to his mom and dad. He said, "Can I play?" His mom said, "No Timmy, it\'s time to do your work." Timmy said, "But I don\'t want to do it!" His dad said, "It\'s ok, Timmy. You can do it. I recommend that you work a little bit." Timmy said, "Ok, dad." \n\nSo, Timmy got his work ready and he worked hard. After a few minutes, he was done and it looked perfect. His dad said, "Very good work, Timmy. You did a great job!" Timmy said, "Thanks, dad!" He was very proud of his work and smiled. \n\nThe end.'}

In [9]:
# Two tiny tokenized sentences: just enough to exercise the training loop.
texts = [
    "the quick brown fox jumps over the lazy dog".split(),
    "machine learning is a subset of artificial intelligence".split(),
]

# Skip-gram with a window of 2; min_count=1 so no word of this tiny corpus is dropped
trainer = Word2VecTrainer(
    is_skip_gram=True,
    epochs=5,
    min_count=1,
    window_size=2,
    embedding_dim=100,
)

# Train, then checkpoint the model to models/word2vec.pt
model, embeddings = trainer.train(texts, output_path="models/word2vec.pt")

Training Word2Vec model on 2 texts...
Epoch 1 completed, Average Loss: 2.8333
Epoch 2 completed, Average Loss: 2.8324
Epoch 3 completed, Average Loss: 2.8314
Epoch 4 completed, Average Loss: 2.8303
Epoch 5 completed, Average Loss: 2.8290
Model saved to models/word2vec.pt


In [10]:
# find_similar returns [] (it never raises KeyError) when the word is missing or
# there is nothing to compare it with, so check the result instead of catching.
similar_words = find_similar("fox", embeddings, n=5)
if similar_words:
    print("\nWords most similar to 'fox':")
    for word, score in similar_words:
        print(f"{word}: {score:.4f}")
else:
    print("Word 'fox' not in vocabulary or insufficient training data")


Words most similar to 'fox':
the: 0.6072
brown: 0.4532
jumps: 0.3941
lazy: 0.2456
quick: 0.1852
