# Problem 8: Word Vectors - PyTorch NN (Text8)

**Dataset:** Text8 (200K words)  
**Method:** Neural network trained on PPMI matrix (PyTorch + SGD)

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import Counter
import re
from tqdm import tqdm

## 1. Preprocessing

In [2]:
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r"[^a-z ]+", ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.split()

def build_vocabulary(words, vocab_size=5000, min_count=3):
    stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'from',
                  'is', 'was', 'were', 'be', 'been', 'have', 'has', 'had', 'this', 'that', 'it', 'as'}
    
    word_counts = Counter(words)
    filtered = [(w, c) for w, c in word_counts.items() 
                if c >= min_count and len(w) >= 3 and w not in stop_words]
    
    most_common = sorted(filtered, key=lambda x: x[1], reverse=True)[:vocab_size]
    word_to_id = {word: idx for idx, (word, _) in enumerate(most_common)}
    id_to_word = {idx: word for word, idx in word_to_id.items()}
    corpus = [word_to_id[word] for word in words if word in word_to_id]
    
    return word_to_id, id_to_word, corpus

## 2. Co-occurrence Matrix

In [3]:
def build_cooccurrence_matrix(corpus, vocab_size, window_size=5):
    cooccur = torch.zeros((vocab_size, vocab_size), dtype=torch.float32)
    
    for i in tqdm(range(len(corpus)), desc="Building co-occurrence"):
        center = corpus[i]
        start = max(0, i - window_size)
        end = min(len(corpus), i + window_size + 1)
        
        for j in range(start, end):
            if i != j:
                context = corpus[j]
                distance = abs(i - j)
                weight = 1.0 / distance
                cooccur[center, context] += weight
    
    return cooccur

## 3. PPMI Computation

In [4]:
def compute_ppmi(cooccur_matrix):
    total = cooccur_matrix.sum()
    word_counts = cooccur_matrix.sum(dim=1)
    context_counts = cooccur_matrix.sum(dim=0)
    
    ppmi = torch.zeros_like(cooccur_matrix)
    
    for i in tqdm(range(cooccur_matrix.shape[0]), desc="Computing PPMI"):
        for j in range(cooccur_matrix.shape[1]):
            if cooccur_matrix[i, j] > 0:
                p_ij = cooccur_matrix[i, j] / total
                p_i = word_counts[i] / total
                p_j = context_counts[j] / total
                
                pmi = torch.log(p_ij / (p_i * p_j + 1e-10))
                ppmi[i, j] = torch.clamp(pmi, min=0)
    
    ppmi.fill_diagonal_(0)
    row_sums = ppmi.sum(dim=1, keepdim=True)
    ppmi_normalized = ppmi / (row_sums + 1e-10)
    
    return ppmi_normalized

## 4. Neural Network

In [5]:
class WordVectorNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.W1 = nn.Linear(vocab_size, embedding_dim, bias=False)
        self.W2 = nn.Linear(embedding_dim, vocab_size, bias=False)
        
        nn.init.normal_(self.W1.weight, mean=0, std=0.1)
        nn.init.normal_(self.W2.weight, mean=0, std=0.1)
    
    def forward(self, one_hot):
        hidden = self.W1(one_hot)
        output = self.W2(hidden)
        probs = torch.softmax(output, dim=-1)
        return probs
    
    def get_embeddings(self):
        return self.W1.weight.t().detach().cpu().numpy()

## 5. Training

In [6]:
def manual_ce_loss(probs, target):
    return -torch.sum(target * torch.log(probs + 1e-10))

def train_model(model, ppmi_matrix, optimizer, epochs=10, device='cpu'):
    model = model.to(device)
    ppmi_matrix = ppmi_matrix.to(device)
    vocab_size = ppmi_matrix.shape[0]
    
    for epoch in range(epochs):
        total_loss = 0
        indices = torch.randperm(vocab_size)
        
        for idx in tqdm(indices, desc=f"Epoch {epoch+1}/{epochs}"):
            one_hot = torch.zeros(vocab_size, device=device)
            one_hot[idx] = 1
            target = ppmi_matrix[idx]
            
            optimizer.zero_grad()
            probs = model(one_hot)
            loss = manual_ce_loss(probs, target)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / vocab_size
        print(f"Epoch {epoch+1}/{epochs}, Avg Loss: {avg_loss:.4f}")

## 6. Evaluation

In [7]:
def cosine_similarity(v1, v2):
    norm1 = np.linalg.norm(v1)
    norm2 = np.linalg.norm(v2)
    if norm1 == 0 or norm2 == 0:
        return 0
    return np.dot(v1, v2) / (norm1 * norm2)

def find_similar_words(word, word_to_id, id_to_word, embeddings, top_k=15):
    if word not in word_to_id:
        return None
    
    word_idx = word_to_id[word]
    word_emb = embeddings[word_idx]
    
    similarities = []
    for idx in range(len(embeddings)):
        if idx != word_idx:
            other_emb = embeddings[idx]
            sim = cosine_similarity(word_emb, other_emb)
            similarities.append((id_to_word[idx], sim))
    
    similarities.sort(key=lambda x: x[1], reverse=True)
    return similarities[:top_k]

def evaluate_model(embeddings, test_words, word_to_id, id_to_word):
    for word in test_words:
        similar = find_similar_words(word, word_to_id, id_to_word, embeddings)
        if similar:
            print(f"\n{word.upper()}:")
            for w, sim in similar:
                print(f"  {w:20s} {sim:.4f}")

## 7. Run Pipeline

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

print("\nLoading Text8...")
with open('../text8_200K.txt', 'r') as f:
    text = f.read()

words = preprocess_text(text)
print(f"Total words: {len(words):,}")

print("\nBuilding vocabulary...")
word_to_id, id_to_word, corpus = build_vocabulary(words, vocab_size=5000)
print(f"Vocab: {len(word_to_id):,}, Corpus: {len(corpus):,}")

print("\nBuilding co-occurrence matrix...")
cooccur = build_cooccurrence_matrix(corpus, len(word_to_id), window_size=5)

print("\nComputing PPMI...")
ppmi = compute_ppmi(cooccur)

print("\nInitializing model...")
model = WordVectorNN(vocab_size=len(word_to_id), embedding_dim=200)
optimizer = optim.SGD(model.parameters(), lr=0.789)

print("\nTraining...")
train_model(model, ppmi, optimizer, epochs=10, device=device)

print("\nExtracting embeddings...")
embeddings = model.get_embeddings()
print(f"Embeddings shape: {embeddings.shape}")

print("\n" + "="*60)
print("EVALUATION")
print("="*60)
test_words = ["china", "computer", "phone", "napoleon", "god", "catholic"]
evaluate_model(embeddings, test_words, word_to_id, id_to_word)

Device: cpu

Loading Text8...
Total words: 199,999

Building vocabulary...
Vocab: 5,000, Corpus: 115,831

Building co-occurrence matrix...


Building co-occurrence: 100%|██████████| 115831/115831 [00:03<00:00, 35132.70it/s]



Computing PPMI...


Computing PPMI: 100%|██████████| 5000/5000 [00:59<00:00, 84.35it/s]



Initializing model...

Training...


Epoch 1/10: 100%|██████████| 5000/5000 [00:05<00:00, 965.84it/s]


Epoch 1/10, Avg Loss: 8.5270


Epoch 2/10: 100%|██████████| 5000/5000 [00:05<00:00, 968.11it/s]


Epoch 2/10, Avg Loss: 8.4681


Epoch 3/10: 100%|██████████| 5000/5000 [00:05<00:00, 942.01it/s]


Epoch 3/10, Avg Loss: 8.4030


Epoch 4/10: 100%|██████████| 5000/5000 [00:05<00:00, 956.70it/s]


Epoch 4/10, Avg Loss: 8.3180


Epoch 5/10: 100%|██████████| 5000/5000 [00:05<00:00, 942.23it/s]


Epoch 5/10, Avg Loss: 8.1910


Epoch 6/10: 100%|██████████| 5000/5000 [00:05<00:00, 942.83it/s]


Epoch 6/10, Avg Loss: 7.9882


Epoch 7/10: 100%|██████████| 5000/5000 [00:05<00:00, 954.88it/s]


Epoch 7/10, Avg Loss: 7.6973


Epoch 8/10: 100%|██████████| 5000/5000 [00:05<00:00, 945.24it/s]


Epoch 8/10, Avg Loss: 7.3849


Epoch 9/10: 100%|██████████| 5000/5000 [00:05<00:00, 941.10it/s]


Epoch 9/10, Avg Loss: 7.0644


Epoch 10/10: 100%|██████████| 5000/5000 [00:05<00:00, 959.35it/s]

Epoch 10/10, Avg Loss: 6.7417

Extracting embeddings...
Embeddings shape: (5000, 200)

EVALUATION

CHINA:
  japan                0.6511
  korea                0.6141
  vietnam              0.6137
  india                0.6046
  myanmar              0.6027
  singapore            0.6023
  mongolia             0.5929
  thailand             0.5914
  cambodia             0.5649
  malaysia             0.5616
  laos                 0.5265
  southeast            0.5148
  buddhism             0.5146
  siberia              0.5063
  indonesia            0.4978

COMPUTER:
  ask                  0.3731
  attitude             0.3393
  representation       0.3201
  why                  0.3155
  animated             0.3088
  true                 0.3078
  internet             0.3023
  comments             0.2973
  variants             0.2955
  windows              0.2938
  design               0.2912
  realized             0.2870
  plutarch             0.2867
  sequences            0.2867
  enter      


