In [8]:
import random
import time
import math
from collections import Counter
from itertools import combinations_with_replacement

import nltk
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from nltk.corpus import reuters
from scipy.stats import spearmanr

# Device Configuration
device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))
print(f"Using device: {device}")

# Download necessary NLTK datasets
nltk.download("reuters")
nltk.download("punkt")

# Data Preparation
sample_size = 1000
window_size = 2
min_word_freq = 5  # Frequency threshold for vocabulary trimming


def build_corpus():
    corpus = []
    for file_id in reuters.fileids()[:sample_size]:
        sentences = reuters.words(file_id)
        sentences = [word.lower() for word in sentences if word.isalpha()]
        corpus.append(sentences)
    return corpus


def build_vocab(corpus):
    words = [word for sentence in corpus for word in sentence]
    word_counts = Counter(words)
    vocab = [word for word, count in word_counts.items() if count >= min_word_freq]
    vocab.append("<UNKNOWN>")
    word2index = {word: idx for idx, word in enumerate(vocab)}
    word2index["<UNKNOWN>"] = 0
    return vocab, len(vocab), word2index, word_counts


def build_skipgrams(corpus, word2index, window_size):
    skip_grams = []
    for sentence in corpus:
        for pos, center_word in enumerate(sentence):
            center_idx = word2index.get(center_word, word2index["<UNKNOWN>"])
            context_indices = [
                word2index.get(sentence[i], word2index["<UNKNOWN>"])
                for i in range(max(pos - window_size, 0), min(pos + window_size + 1, len(sentence)))
                if i != pos
            ]
            for context_idx in context_indices:
                skip_grams.append((center_idx, context_idx))
    return skip_grams


def weighting_function(x_ij, x_max=100, alpha=0.75):
    return (x_ij / x_max) ** alpha if x_ij < x_max else 1


# GloVe Model
class GloVe(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(GloVe, self).__init__()
        self.embedding_v = nn.Embedding(vocab_size, embed_size)
        self.embedding_u = nn.Embedding(vocab_size, embed_size)
        self.v_bias = nn.Embedding(vocab_size, 1)
        self.u_bias = nn.Embedding(vocab_size, 1)

    def forward(self, center_words, target_words, co_occurrences, weightings):
        center_embed = self.embedding_v(center_words)
        target_embed = self.embedding_u(target_words)
        center_bias = self.v_bias(center_words).squeeze(1)
        target_bias = self.u_bias(target_words).squeeze(1)
        inner_product = (center_embed * target_embed).sum(dim=1)
        loss = weightings * torch.pow(inner_product + center_bias + target_bias - co_occurrences, 2)
        return loss.mean()


def prepare_training_data(skip_grams, co_occurrence_matrix, word2index):
    training_data = []
    for center, context in skip_grams:
        co_occurrence = co_occurrence_matrix.get((center, context), 1)
        weight = weighting_function(co_occurrence)
        training_data.append((center, context, math.log(co_occurrence + 1), weight))
    return training_data


# Training Function
def train_glove_model(model, training_data, epochs, batch_size, learning_rate):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    for epoch in range(epochs):
        start_time = time.time()
        total_loss = 0
        random.shuffle(training_data)
        for i in range(0, len(training_data), batch_size):
            batch = training_data[i : i + batch_size]
            centers, contexts, coocs, weights = zip(*batch)
            centers = torch.LongTensor(centers).to(device)
            contexts = torch.LongTensor(contexts).to(device)
            coocs = torch.FloatTensor(coocs).to(device)
            weights = torch.FloatTensor(weights).to(device)
            optimizer.zero_grad()
            loss = model(centers, contexts, coocs, weights)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}, Time: {time.time() - start_time:.2f}s")


# Main Execution
corpus = build_corpus()
vocab, vocab_size, word2index, word_counts = build_vocab(corpus)
skip_grams = build_skipgrams(corpus, word2index, window_size)
co_occurrence_matrix = Counter(skip_grams)
training_data = prepare_training_data(skip_grams, co_occurrence_matrix, word2index)

embedding_dim = 100
epochs = 10
batch_size = 128
learning_rate = 0.001

model = GloVe(vocab_size, embedding_dim).to(device)
train_glove_model(model, training_data, epochs, batch_size, learning_rate)

# Save the model
torch.save({"model_state_dict": model.state_dict(), "word2index": word2index, "vocab": vocab}, "glove_model.pth")
print("Model and vocabulary saved to glove_model.pth")


Using device: mps


[nltk_data] Downloading package reuters to
[nltk_data]     /Users/silanm/Developer/nltk_data...
[nltk_data]   Package reuters is already up-to-date!
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/silanm/Developer/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Epoch 1, Loss: 32895.9103, Time: 7.85s
Epoch 2, Loss: 5698.0985, Time: 7.67s
Epoch 3, Loss: 2316.7251, Time: 7.61s
Epoch 4, Loss: 1125.0549, Time: 7.65s
Epoch 5, Loss: 601.5070, Time: 7.60s
Epoch 6, Loss: 344.5856, Time: 7.63s
Epoch 7, Loss: 211.7330, Time: 7.64s
Epoch 8, Loss: 140.3280, Time: 7.62s
Epoch 9, Loss: 101.4641, Time: 7.61s
Epoch 10, Loss: 79.2047, Time: 7.56s
Model and vocabulary saved to glove_model.pth
