In [1]:
from collections import Counter
from pathlib import Path
import urllib.request
import zipfile

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl

In [2]:
# Download data and extract to ../data/text8
url = "http://mattmahoney.net/dc/text8.zip"
data_dir = Path("../data")
zip_path = data_dir / "text8.zip"
# The archive's corpus member is named "text8", so extracting into data_dir
# yields data_dir / "text8".
extract_dir = data_dir

# urlretrieve does not create missing parent directories.
data_dir.mkdir(parents=True, exist_ok=True)

if not zip_path.exists():
    print("Downloading text8.zip ...")
    # Download under a temporary name and rename only on success, so an
    # interrupted download never leaves a truncated text8.zip that later
    # runs would skip re-downloading.
    partial_path = zip_path.with_name(zip_path.name + ".part")
    urllib.request.urlretrieve(url, partial_path)
    partial_path.replace(zip_path)

# Extract if needed
extracted_file = extract_dir / "text8"     # the corpus file (no extension)
if not extracted_file.exists():
    print(f"Unzipping to {extract_dir} ...")
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(extract_dir)
    print("Done.")
else:
    print("text8 already extracted.")

print("Corpus path:", extracted_file.resolve())

text8 already extracted.
Corpus path: /mnt/custom-file-systems/efs/fs-038956e0ab0e15389_fsap-093b86bb45d55576e/data/text8


In [3]:
# Convert corpus to tokens.
# text8 is distributed as a single line of space-separated words. Read the
# whole file (not just the first line) with an explicit encoding so the result
# does not depend on the platform's default locale.
corpus = extracted_file.read_text(encoding="utf-8")

tokens = corpus.lower().split()
print(f"Raw tokens: {len(tokens):,}")

Raw tokens: 17,005,207


In [4]:
# Build vocabulary with minimum count cut-off.
# Words seen fewer than `min_count` times are dropped. Survivors keep the order
# in which they first appear in `tokens` (Counter preserves insertion order).
min_count = 5
freq = Counter(tokens)
kept_pairs = [(word, n) for word, n in freq.items() if n >= min_count]
itos = [word for word, _ in kept_pairs]                          # id -> word
stoi = {word: idx for idx, word in enumerate(itos)}              # word -> id
counts = np.array([n for _, n in kept_pairs], dtype=np.int64)    # counts[id] = corpus frequency
print(f"Vocab size (min_count={min_count}): {len(itos):,}")

Vocab size (min_count=5): 71,290


In [5]:
# Subsampling of frequent words (Mikolov et al.)
# https://arxiv.org/pdf/1310.4546
# Each in-vocabulary occurrence of word w is kept with probability
# min(1, sqrt(t / f(w))), where f(w) is w's share of all in-vocabulary tokens.
t = 1e-5
total = counts.sum()
freqs = counts / total
p_keep = np.minimum(1.0, np.sqrt(t / freqs))

rng = np.random.default_rng(0)
# Vectorised instead of a Python loop with one RNG call per token (~17M tokens):
# map every token to its id (-1 for words removed by min_count), drop those,
# then draw one uniform per remaining token in a single batched call.
token_ids = np.fromiter((stoi.get(w, -1) for w in tokens), dtype=np.int64, count=len(tokens))
in_vocab_ids = token_ids[token_ids >= 0]
keep_mask = rng.random(in_vocab_ids.size) < p_keep[in_vocab_ids]
ids = in_vocab_ids[keep_mask]
del token_ids, in_vocab_ids, keep_mask  # free the large intermediates

print(f"After subsampling: {len(ids):,} tokens")

After subsampling: 4,670,984 tokens


In [6]:
# Negative sampling distribution (unigram^0.75)
unigram = counts ** 0.75
unigram = unigram / unigram.sum()
cdf = np.cumsum(unigram)  # for fast sampling via inverse CDF
# Round-off can leave cdf[-1] a hair below 1.0. A uniform draw above it would
# make np.searchsorted return len(cdf), i.e. an id outside the vocabulary.
if cdf.size:
    cdf[-1] = 1.0


def draw_negatives(k, forbidden_ids, rng, table=None):
    """Sample ``k`` negative word ids, skipping any id in ``forbidden_ids``.

    Ids are drawn by inverse-CDF sampling: a uniform ``r`` from ``rng`` maps to
    the first index whose cumulative probability is >= ``r``. Draws that land
    in ``forbidden_ids`` are rejected and redrawn. The loop therefore never ends
    if every reachable id is forbidden.

    Args:
        k: number of negatives to return.
        forbidden_ids: ids to exclude (e.g. center and positive). A set gives
            O(1) membership tests.
        rng: a ``numpy.random.Generator`` (anything with ``.random()``).
        table: cumulative distribution to sample from. Defaults to the
            module-level ``cdf`` (unigram^0.75).

    Returns:
        ``np.ndarray`` of dtype int64 with ``max(k, 0)`` entries.
    """
    if table is None:
        table = cdf
    last_id = len(table) - 1
    out = []
    while len(out) < k:
        r = rng.random()
        # Clamp: if round-off leaves table[-1] < 1.0 and r lands above it,
        # searchsorted returns len(table), which is not a valid id.
        wid = min(int(np.searchsorted(table, r)), last_id)
        if wid not in forbidden_ids:
            out.append(wid)
    return np.array(out, dtype=np.int64)

In [7]:
# PyTorch Dataset: Skip-Gram with dynamic window + on-the-fly negatives
class SkipGramNSDataset(Dataset):
    """Skip-Gram training triples with a dynamic window and sampled negatives.

    Item ``i`` is ``(center, pos, negs)`` where:
      - ``center`` is ``word_ids[i]``;
      - ``pos`` is drawn uniformly from the other tokens within a window of
        random radius ``R`` in ``[1, window_max]`` around ``i``;
      - ``negs`` is the int64 array returned by ``draw_negatives`` for
        ``num_negatives`` ids, excluding ``center`` and ``pos``.

    Args:
        word_ids: sequence of integer token ids (at least 2 tokens).
        window_max: maximum window radius (>= 1).
        num_negatives: negatives per item.
        seed: seed for this dataset's numpy Generator.

    Note:
        All randomness comes from ``self.rng``. With ``DataLoader(num_workers > 0)``
        each worker process gets its own copy of this Generator in the same
        state, so workers would draw identical windows/negatives unless
        reseeded (e.g. via ``worker_init_fn``).
    """

    def __init__(self, word_ids, window_max=5, num_negatives=5, seed=1234):
        self.word_ids = np.asarray(word_ids, dtype=np.int64)
        if len(self.word_ids) < 2:
            # A context word needs at least one token besides the center. The
            # old "resample another index" fallback recursed forever here.
            raise ValueError(
                f"need at least 2 tokens to form skip-gram pairs, got {len(self.word_ids)}"
            )
        if window_max < 1:
            raise ValueError(f"window_max must be >= 1, got {window_max}")
        self.window_max = window_max
        self.num_negatives = num_negatives
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        # Each index returns one (center, pos, negs) triple (random positive)
        return len(self.word_ids)

    def __getitem__(self, i):
        n = len(self.word_ids)
        if not -n <= i < n:
            raise IndexError(f"index {i} out of range for {n} tokens")
        i %= n  # normalise negative indices so the window arithmetic holds
        center = self.word_ids[i]
        R = int(self.rng.integers(1, self.window_max + 1))  # dynamic window radius in [1, window_max]
        left = max(0, i - R)
        right = min(n, i + R + 1)

        # Since R >= 1 and n >= 2 (checked in __init__), the window always
        # contains at least one token besides the center.
        window = self.word_ids[left:right]
        candidates = np.delete(window, i - left)  # drop the center position
        pos = int(self.rng.choice(candidates))

        # Draw negative samples (avoid center & pos)
        negs = draw_negatives(self.num_negatives, {int(center), pos}, self.rng)

        return int(center), int(pos), negs

In [8]:
def collate_batch(batch):
    """Stack a list of ``(center, pos, negs)`` items into three LongTensors.

    Returns:
        ``(centers [B], positives [B], negatives [B, K])``, all ``torch.long``.
    """
    center_col, context_col, negatives_col = zip(*batch)
    return (
        torch.tensor(center_col, dtype=torch.long),
        torch.tensor(context_col, dtype=torch.long),
        torch.as_tensor(np.stack(negatives_col), dtype=torch.long),
    )

In [9]:
dataset = SkipGramNSDataset(ids, window_max=5, num_negatives=5, seed=2024)
# A seeded generator makes the shuffle order reproducible. Without it,
# RandomSampler seeds itself from torch's global RNG, which this notebook never seeds.
loader = DataLoader(
    dataset,
    batch_size=1024,
    shuffle=True,
    # Keep 0: the dataset draws from a single numpy Generator, which each worker
    # process would copy in the same state (duplicating windows/negatives).
    num_workers=0,
    collate_fn=collate_batch,
    generator=torch.Generator().manual_seed(2024),
)

In [10]:
# Quick sanity check: fetch one batch, print its shapes and one example row.
centers, positives, negatives = next(iter(loader))
shapes = (centers.shape, positives.shape, negatives.shape)
print("Batch shapes:", *shapes)
first_center, first_pos = centers[0].item(), positives[0].item()
print("Example:", first_center, first_pos, negatives[0][:5].tolist())

Batch shapes: torch.Size([1024]) torch.Size([1024]) torch.Size([1024, 5])
Example: 2002 2912 [439, 1294, 22165, 69327, 231]


In [11]:
# LightningModule for Skip-Gram + Negative Sampling
class SkipGramNegativeSampling(pl.LightningModule):
    """
    Skip-Gram with Negative Sampling (SGNS).
    - Two embedding tables: input (for centers) and output (for contexts).
    - Loss per example: -log σ(c·p) - Σ_k log σ(-c·n_k), averaged over the batch.

    Args:
        vocab_size: number of rows in each embedding table.
        dim: embedding dimension D.
        lr: Adam learning rate.
        num_negatives: recorded as a hyperparameter only. The number of
            negatives actually used is the K dimension of the ``negatives`` batch.
    """

    def __init__(self, vocab_size: int, dim: int = 300, lr: float = 2.5e-3,
                 num_negatives: int = 5):
        super().__init__()
        self.save_hyperparameters()
        self.vocab_size = vocab_size
        self.dim = dim
        self.lr = lr
        self.num_negatives = num_negatives

        # Input (center) embeddings and Output (context) embeddings
        self.in_embed = nn.Embedding(vocab_size, dim)
        self.out_embed = nn.Embedding(vocab_size, dim)

        self.reset_parameters()

    def reset_parameters(self):
        """word2vec-style init: input vectors ~ U(-0.5/D, 0.5/D), output vectors = 0."""
        bound = 0.5 / self.dim
        nn.init.uniform_(self.in_embed.weight,  -bound, bound)
        nn.init.zeros_(self.out_embed.weight)  # often initialized at 0 for out vectors

    @torch.no_grad()
    def get_embeddings(self, normalize: bool = True) -> torch.Tensor:
        """Return the input (center) embeddings, shape [vocab_size, dim].

        Both branches return a new tensor that does not require grad and does
        not alias the live parameter. Previously ``normalize=False`` handed back
        the ``Parameter`` itself: ``.numpy()`` failed on it, and in-place edits
        by the caller would have changed the training weights.

        Args:
            normalize: if True, L2-normalise each row (useful for cosine similarity).
        """
        E = self.in_embed.weight.detach()
        if normalize:
            return F.normalize(E, dim=1)  # allocates a new tensor
        return E.clone()

    def forward(self, centers: torch.LongTensor, contexts: torch.LongTensor,
                negatives: torch.LongTensor):
        """
        centers:   [B]
        contexts:  [B]
        negatives: [B, K]
        Returns (pos_score [B], neg_score [B, K]): raw dot products (logits).
        """
        c = self.in_embed(centers)          # [B, D]
        p = self.out_embed(contexts)        # [B, D]
        n = self.out_embed(negatives)       # [B, K, D]

        # Positive scores: dot(c, p)
        pos_score = (c * p).sum(dim=-1)     # [B]

        # Negative scores: dot(c, n_k) for each k
        # Expand c to [B, 1, D] to broadcast across K negatives
        neg_score = (n * c.unsqueeze(1)).sum(dim=-1)  # [B, K]

        return pos_score, neg_score

    def sgns_loss(self, pos_score, neg_score):
        """Mean over the batch of -log σ(pos) - Σ_k log σ(-neg_k).

        logsigmoid is used instead of log(sigmoid(.)) for numerical stability.
        """
        loss_pos = F.logsigmoid(pos_score)              # [B]
        loss_neg = F.logsigmoid(-neg_score).sum(dim=1)  # [B]
        loss = -(loss_pos + loss_neg).mean()
        return loss

    def training_step(self, batch, batch_idx):
        """Compute the SGNS loss for one batch and log loss plus mean pos/neg probabilities."""
        centers, positives, negatives = batch   # shapes: [B], [B], [B,K]
        pos_score, neg_score = self(centers, positives, negatives)
        loss = self.sgns_loss(pos_score, neg_score)

        # Diagnostics only: mean σ(score) should rise for positives and fall for negatives.
        with torch.no_grad():
            pos_prob = torch.sigmoid(pos_score).mean()
            neg_prob = torch.sigmoid(neg_score).mean()

        self.log("train/loss", loss, prog_bar=True, on_step=True,
                 on_epoch=True)
        self.log("train/pos_prob", pos_prob, prog_bar=False, on_step=True,
                 on_epoch=True)
        self.log("train/neg_prob", neg_prob, prog_bar=False, on_step=True,
                 on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Dense Adam over both embedding tables.

        Alternative: nn.Embedding(sparse=True) combined with torch.optim.SparseAdam
        would update only the rows touched by each batch.
        """
        opt = torch.optim.Adam(self.parameters(), lr=self.lr)
        return opt

In [12]:
# Seed torch's global RNG (plus Python's and NumPy's legacy global RNGs) so the
# embedding initialisation, and any other draw from torch's default generator
# (e.g. an unseeded DataLoader shuffle), is reproducible across runs.
pl.seed_everything(42)

vocab_size = len(itos)
model = SkipGramNegativeSampling(vocab_size=vocab_size, dim=300, lr=2.5e-3,
                                 num_negatives=5)

trainer = pl.Trainer(
    max_epochs=3,
    accelerator="auto",
    devices="auto",
    precision="32-true",   # "16-mixed" works too if you like
    log_every_n_steps=25,
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [13]:
trainer.fit(model=model, train_dataloaders=loader)  # loader yields (centers, positives, negatives) batches

2025-10-24 06:08:34.739581: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-24 06:08:34.753756: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761286114.771977    3889 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761286114.777655    3889 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-10-24 06:08:34.795727: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
