In [None]:
import numpy as np
import re
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset

from tqdm import tqdm

from collections import Counter

from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter

from gensim.utils import simple_preprocess
import spacy

# Turn on cuDNN's benchmark mode for the whole session.
torch.backends.cudnn.benchmark = True

# Load the small English spaCy pipeline with the parser, NER and tagger
# components disabled, then raise the per-text length limit to one billion.
nlp = spacy.load(
    "en_core_web_sm",
    disable=["parser", "ner", "tagger"],
)
nlp.max_length = 1_000_000_000

In [None]:
class KREmbedding(nn.Module):
    """Kernel-weighted context embedding model.

    A single embedding table is shared by context and center words. For each
    sample the context vectors are averaged with Gaussian-kernel weights that
    depend on their squared distance to the center vector, and the averaged
    vector is scored against every row of the table to produce logits.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector.
        sigma: Bandwidth of the Gaussian kernel used for the context weights.
        device: Optional device on which to create the table. ``None`` uses
            PyTorch's default device; the module can still be moved later
            with ``.to(...)``.
    """

    def __init__(self, vocab_size, embedding_dim, sigma=1.0, device=None):
        super().__init__()
        # Standard-normal init scaled by 1/sqrt(embedding_dim). The device is
        # an explicit argument so the class does not depend on a notebook
        # global that may not exist yet when the class is instantiated.
        initial = torch.randn(vocab_size, embedding_dim, dtype=torch.float32, device=device)
        self.embedding_weights = nn.Parameter(initial * (1.0 / embedding_dim ** 0.5))
        self.sigma = sigma

    def forward(self, context, center):
        """Return logits over the vocabulary for each sample.

        Args:
            context: Integer indices, batch_size * (winlen - 1).
            center: Integer indices, batch_size.

        Returns:
            Tensor of shape batch_size * vocab_size.
        """
        context_vecs = self.embedding_weights[context]  # batch_size * (winlen - 1) * embedding
        center_vec = self.embedding_weights[center]  # batch_size * embedding
        diff = context_vecs - center_vec.unsqueeze(1)  # batch_size * (winlen - 1) * embedding
        dist_sq = torch.sum(diff ** 2, dim=2)  # batch_size * (winlen - 1)
        # Gaussian kernel on the squared distance, then normalise per sample;
        # the epsilon guards against a zero denominator.
        weights = torch.exp(-dist_sq / (2 * self.sigma ** 2))  # batch_size * (winlen - 1)
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)  # batch_size * (winlen - 1)
        weighted_context = (weights.unsqueeze(2) * context_vecs).sum(dim=1)  # batch_size * embedding

        # Dot product of the pooled context with every embedding row.
        similarity_matrix = torch.mm(weighted_context, self.embedding_weights.t())
        return similarity_matrix  # compare to embeddings and output logits

    def getEmbedding(self, id):
        """Return the embedding row(s) selected by ``id``."""
        return self.embedding_weights[id]

In [8]:
class CorpusDataset(Dataset):
    """Map-style dataset that pairs ``data[i]`` with ``labels[i]``."""

    def __init__(self, data, labels):
        self.labels = labels
        self.data = data

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

    def __len__(self):
        # The size is taken from ``data`` alone; ``labels`` is not consulted.
        return len(self.data)

In [None]:
def preprocessing(text, min_count=5, threshold=1e-5, chunk_size=500000):
    """Tokenise, lemmatise and subsample a raw corpus, then build the vocabulary.

    Args:
        text: Raw corpus as a single string.
        min_count: Words occurring fewer than this many times (after
            lemmatisation) are dropped.
        threshold: Word2Vec-style subsampling threshold ``t``; a token whose
            word has relative frequency ``f`` is discarded with probability
            ``1 - sqrt(t / f)``.
        chunk_size: Number of tokens handed to spaCy per call.

    Returns:
        ``(subsampled, word2id, id2word)``: the surviving token sequence and
        the two vocabulary mappings. Ids follow the sorted order of the
        vocabulary, so the mapping does not vary with set iteration order.
        Subsampling itself draws from ``np.random`` and is only repeatable if
        the caller seeds it.
    """
    # Stage 1: Fast Gensim cleaning (memory-safe)
    words = simple_preprocess(text, deacc=True, min_len=2)

    # Stage 2: Chunked SpaCy lemmatization
    def process_chunk(chunk):
        doc = nlp(" ".join(chunk))
        return [
            token.lemma_ for token in doc
            if token.is_alpha and not token.is_punct and len(token) > 1
        ]

    # Split words into chunks to avoid SpaCy memory issues
    lemmatized = []
    for i in range(0, len(words), chunk_size):
        lemmatized.extend(process_chunk(words[i : i + chunk_size]))

    # Stage 3: Subsampling (Word2Vec style)
    word_counts = Counter(lemmatized)
    total_words = len(lemmatized)
    # The discard probability depends only on the word, so compute it once per
    # vocabulary entry rather than once per token. Words below min_count are
    # left out of the table and therefore always dropped.
    discard_prob = {
        word: 1 - np.sqrt(threshold / (count / total_words))
        for word, count in word_counts.items()
        if count >= min_count
    }
    # A random number is drawn only for tokens that pass the min_count filter.
    subsampled = [
        word for word in lemmatized
        if word in discard_prob and np.random.rand() > discard_prob[word]
    ]

    # Build vocab; sorting makes the id assignment independent of set order.
    vocab = sorted(set(subsampled))
    word2id = {w: i for i, w in enumerate(vocab)}
    id2word = {i: w for i, w in enumerate(vocab)}

    return subsampled, word2id, id2word

def generateData(words, word2id, winlen):
    """Build (context, center) training pairs from sliding windows over ``words``.

    Every window of ``winlen`` consecutive tokens yields one sample: the
    middle token is the center and the remaining ``winlen - 1`` tokens, in
    their original left-to-right order, are the context.

    Args:
        words: Token sequence; every token must be a key of ``word2id``.
        word2id: Mapping from token to integer id.
        winlen: Window length; must be a positive odd integer.

    Returns:
        ``(context_train, center_train, vocab_size, word_size)`` where
        ``context_train`` is an int32 tensor of shape
        ``(num_windows, winlen - 1)``, ``center_train`` is an int32 tensor of
        shape ``(num_windows,)``, ``vocab_size`` is ``len(word2id)`` and
        ``word_size`` is ``len(words)``. If the sequence is shorter than the
        window, both tensors are empty.

    Raises:
        ValueError: If ``winlen`` is not a positive odd integer.
        KeyError: If a token is missing from ``word2id``.
    """
    if winlen < 1 or winlen % 2 == 0:
        raise ValueError(f"winlen must be a positive odd integer, got {winlen}")

    vocab_size = len(word2id)
    word_size = len(words)
    half = winlen // 2
    # Clamp at zero so a too-short corpus yields empty tensors, not an error.
    num_windows = max(word_size - winlen + 1, 0)

    # Encode the corpus once as int32 ids; the window columns below are then
    # plain shifted slices of this array instead of a per-element Python loop.
    ids = np.fromiter((word2id[w] for w in words), dtype=np.int32)

    center_train = ids[half : half + num_windows].copy()
    context_train = np.empty((num_windows, winlen - 1), dtype=np.int32)
    column = 0
    for offset in range(winlen):
        if offset == half:
            continue  # the middle position is the center, not context
        context_train[:, column] = ids[offset : offset + num_windows]
        column += 1

    return torch.from_numpy(context_train), torch.from_numpy(center_train), vocab_size, word_size
        
# Read the training corpus. The encoding is given explicitly so decoding does
# not depend on the platform's locale default.
with open("wiki-103.train.tokens", 'r', encoding="utf-8") as f:
    text = f.read()

# Clean the text, then cut it into 7-token windows (3 context words on each
# side of the center word).
words, word2id, id2word = preprocessing(text)
context_train, center_train, vocab_size, word_size = generateData(words, word2id, 7)
print(context_train.shape, center_train.shape, vocab_size)

torch.Size([82586969, 6]) torch.Size([82586969]) 89807


In [None]:
def train(model, optimizer, dataloader, num_epoches=100):
    """Train ``model`` with cross-entropy on (context, center) batches.

    Args:
        model: Module called as ``model(batch_context, batch_center)`` and
            returning logits over the vocabulary.
        optimizer: Optimizer bound to ``model``'s parameters.
        dataloader: Iterable yielding ``(batch_context, batch_center)`` pairs.
        num_epoches: Maximum number of passes over ``dataloader``.

    Returns:
        The ``(model, optimizer)`` pair after training.

    Training stops early once more than 20 epochs have run and the mean epoch
    loss exceeds the best mean epoch loss seen so far by more than 5%.
    """
    criterion = nn.CrossEntropyLoss()
    # Use the device the model already lives on rather than a notebook global.
    device = next(model.parameters()).device

    best_loss = float('inf')
    for epoch in range(num_epoches):
        model.train()
        epoch_loss = 0.0
        samples_seen = 0

        for batch_context, batch_center in tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epoches}'):
            # Move data to device
            batch_context = batch_context.to(device)
            batch_center = batch_center.to(device)

            optimizer.zero_grad()

            # Forward pass
            output = model(batch_context, batch_center)
            target = batch_center.long()

            # Compute loss
            loss = criterion(output, target)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch is averaged correctly.
            batch_len = len(batch_context)
            epoch_loss += loss.item() * batch_len
            samples_seen += batch_len

        # Average over the samples actually processed this epoch instead of the
        # length of an unrelated global tensor.
        epoch_loss /= max(samples_seen, 1)
        if epoch % 5 == 0:
            print(f"Epoch {epoch}, Loss: {epoch_loss}")

        if epoch_loss < best_loss:
            best_loss = epoch_loss
        # Compare the epoch mean, not the last batch's loss tensor, to the best.
        elif epoch > 20 and epoch_loss > best_loss * 1.05:
            print(f"Early stopping at epoch {epoch}")
            break

    return model, optimizer

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# 256-dimensional embeddings with kernel bandwidth sigma = 7, optimised by Adam.
model = KREmbedding(vocab_size, 256, 7)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

cpu


In [None]:
# Shuffled mini-batches of 4096 windows, loaded by 4 worker processes.
# Pinned host memory is requested only when CUDA is available.
dataloader = DataLoader(
    CorpusDataset(context_train, center_train),
    batch_size=4096,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)
print(len(dataloader))

# Show each parameter's shape and device before training starts.
for param_name, param_tensor in model.named_parameters():
    print(f"{param_name}: {param_tensor.shape} | Device: {param_tensor.device}")
model, optimizer = train(model, optimizer, dataloader)

In [None]:
model.eval()

# Collect one unit-length CPU vector per vocabulary word. The loop variable is
# named word_id so the builtin id() is not shadowed for the rest of the
# notebook.
word2embed = {}
for word, word_id in word2id.items():
    embedding = model.getEmbedding(word_id).detach().to("cpu")
    embedding = embedding / torch.norm(embedding)
    word2embed[word] = embedding

# Dump the vectors as text, one word per line. The encoding is explicit so
# non-ASCII words are written the same way on every platform.
# NOTE(review): the word and its vector are written with no separator between
# them, and element formatting follows NumPy's scalar repr — confirm this is
# the format downstream readers expect.
with open("output-cbow.txt", 'w', encoding="utf-8") as f:
    for word, embed in word2embed.items():
        f.write(word)
        f.write(str(list(embed.numpy())))
        f.write('\n')

In [None]:
def getClose(target, k = 5):
    """Return the ``k`` words whose embeddings are most cosine-similar to ``target``.

    Every entry of the module-level ``word2embed`` table is scored against
    ``target``; words are returned in order of decreasing similarity.
    """
    cosine = nn.CosineSimilarity(dim=0)
    scored = [
        (candidate, cosine(vector, target).item())
        for candidate, vector in word2embed.items()
    ]
    # Highest similarity first.
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [pair[0] for pair in scored[:k]]


import random

correctCount = 0
totalCount = 0
with open("questions-words.txt", 'r', encoding="utf-8") as f:
    qs = f.read()

# Keep only lines with exactly four tokens (a b c d). Blank lines and
# two-token section headers would otherwise raise an IndexError below.
qs_s = [line for line in qs.split('\n') if len(line.split()) == 4]
random.shuffle(qs_s)

# Score a random sample of 100 analogies "a is to b as c is to ?": the answer
# counts as correct when d appears among the 10 nearest words to c + b - a.
for q in qs_s[:100]:
    # Separate names are used so the corpus token list `words` from the
    # preprocessing cell is not overwritten.
    word_a, word_b, word_c, word_d = (token.lower() for token in q.split())
    try:
        target = word2embed[word_c] + word2embed[word_b] - word2embed[word_a]
    except KeyError:
        # A query word is missing from the vocabulary; skip this question.
        continue
    target = target / torch.norm(target)
    ans = getClose(target, 10)
    if word_d in ans:
        correctCount += 1
    print(word_a, word_b, word_c, ans)
    totalCount += 1

print(correctCount, totalCount)

# Checkpoint the weights, optimizer state and vocabulary mappings together.
state = {
    "state_dict" : model.state_dict(),  # model parameters
    "optimizer" : optimizer.state_dict(),  # optimizer state
    "word2id" : word2id, 
    "id2word" : id2word
}
torch.save(state, 'model-kop103.pth')

In [None]:
completer = WordCompleter(list(word2embed.keys()))

# Interactive nearest-neighbour lookup. Ctrl-C or Ctrl-D ends the session.
while True:
    try:
        word = prompt("Word: ", completer = completer)
        raw_count = input("Input closest count: ")
    except (EOFError, KeyboardInterrupt):
        break

    try:
        cc = int(raw_count)
    except ValueError:
        # A non-numeric count should not kill the loop; ask again.
        print("Count must be an integer.")
        continue

    try:
        embed = word2embed[word]
        print(embed, *getClose(embed, cc))
    except KeyError:
        print("Nonexistent word.")