# NAOMI-II: Semantic-Driven Transparent Dimension Training

**Revolutionary approach: Dimensions emerge through semantic clustering!**

Uses your existing pre-parsed Wikipedia + WordNet data:
- **197K words** (WordNet + Wikipedia vocabulary)
- **1.79M triples** (1.56M parse + 230K WordNet)
- **971 antonym pairs** for semantic axis discovery (648 of them match vocabulary entries; see Step 3)
- **100K Wikipedia sentences** (already parsed)

## Key Innovation

**Old approach (WRONG):**
- Force 15 "polarity dimensions"
- Train for opposition as goal
- Sparsity as penalty

**New approach (RIGHT):**
- ALL dimensions = semantic axes (morality, temperature, size, etc.)
- Words positioned by meaning
- Opposition emerges from opposite meanings
- Sparsity emerges from semantic irrelevance

## Architecture

**Semantic-driven embeddings:**
- `embeddings`: Position on each semantic axis
- `relevance`: Which axes matter for this word's meaning
- Final embedding = embeddings × sigmoid(relevance_logits), so axes irrelevant to a word are gated toward zero (automatic semantic sparsity!)

**Example:**
```
"good" on morality axis: value=+0.8, relevance=1.0 → +0.8 ✓
"good" on temperature: value=0.1, relevance=0.0 → 0.0 ✓ (semantically correct!)
```
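A plain-Python sketch of the gate for one word on two axes. It assumes relevance is `sigmoid(relevance_logit)`, as in the model defined in Step 2. The positions and the logits ±6 are made-up stand-ins for "relevant" and "irrelevant", not learned values.

```python
import math

def sigmoid(x: float) -> float:
    """Logistic function that turns a relevance logit into a gate in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def gated_value(position: float, relevance_logit: float) -> float:
    """Final coordinate on one axis = position * sigmoid(relevance_logit)."""
    return position * sigmoid(relevance_logit)

# Illustrative values for "good" (hypothetical, not taken from a trained model)
morality = gated_value(0.8, 6.0)       # relevant axis: gate is about 0.998
temperature = gated_value(0.1, -6.0)   # irrelevant axis: gate is about 0.0025

print(f"morality:    {morality:.4f}")
print(f"temperature: {temperature:.4f}")
```

Because the gate is a smooth sigmoid rather than a hard 0/1 mask, gradients keep flowing to both the position and the relevance logit during training.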

## Expected Results

- 10-30 semantic axes emerge naturally
- Each axis = one global concept (morality, size, temperature, politics, quality, etc.)
- Words activate 5-20 axes on average (semantic sparsity)
- Antonyms naturally oppose on shared relevant axes
- Clear, interpretable dimensional structure

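**Runtime:** originally estimated at ~10-12 hours on an A100 GPU. With `batch_size=262144` there are only 6 training batches per epoch, and the epochs in the log below take about 5 seconds each.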

---

## Prerequisites

Upload to Google Drive at `MyDrive/NAOMI-II-data/wikipedia_100k_graph/` (mounted as `/content/drive/MyDrive/NAOMI-II-data/wikipedia_100k_graph`):
- `vocabulary.json` (197K words)
- `triples.pkl` (1.79M triples)
- `training_examples.pkl` (1.66M examples)
- `graph_stats.json`
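It can save a failed run to confirm all four files are present before loading anything. This is a minimal sketch that uses only `pathlib`. `missing_prerequisites` is a hypothetical helper, and its file list mirrors the bullets above.

```python
from pathlib import Path

REQUIRED_FILES = [
    "vocabulary.json",
    "triples.pkl",
    "training_examples.pkl",
    "graph_stats.json",
]

def missing_prerequisites(data_path: str) -> list[str]:
    """Return the required files that are not present in data_path."""
    root = Path(data_path)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]

# Usage in Colab (after drive.mount):
# missing = missing_prerequisites("/content/drive/MyDrive/NAOMI-II-data/wikipedia_100k_graph")
# if missing:
#     raise FileNotFoundError(f"Upload these files first: {missing}")
```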

## Step 1: Setup and Load Pre-Parsed Data

In [None]:
# Install dependencies
!pip install -q torch numpy nltk tqdm scikit-learn matplotlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
import torch.optim as optim

import numpy as np
import json
import pickle
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
import time
import os
import random

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Paths
DATA_PATH = "/content/drive/MyDrive/NAOMI-II-data/wikipedia_100k_graph"
RESULTS_DIR = "/content/drive/MyDrive/NAOMI-II-results/semantic_dims"
CHECKPOINT_DIR = "/content/checkpoints"

os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Verify GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✓ GPU: {gpu_name} ({vram_gb:.1f} GB)")
    if 'A100' not in gpu_name:
        print(f"⚠️  WARNING: Not A100! Training will be slower.")
else:
    print("⚠️  WARNING: No GPU detected!")

print(f"✓ Device: {device}")
print(f"✓ Data path: {DATA_PATH}")
print(f"✓ Results path: {RESULTS_DIR}")

Mounted at /content/drive
✓ GPU: NVIDIA A100-SXM4-80GB (85.2 GB)
✓ Device: cuda
✓ Data path: /content/drive/MyDrive/NAOMI-II-data/wikipedia_100k_graph
✓ Results path: /content/drive/MyDrive/NAOMI-II-results/semantic_dims


In [None]:
# Load pre-parsed data from Google Drive
print("="*70)
print("LOADING PRE-PARSED WIKIPEDIA + WORDNET DATA")
print("="*70)

# Load vocabulary
print("\n[1/4] Loading vocabulary...")
with open(f"{DATA_PATH}/vocabulary.json", 'r') as f:
    vocab_data = json.load(f)
    word_to_id = vocab_data['word_to_id']
    vocabulary = list(word_to_id.keys())
    id_to_word = {idx: word for word, idx in word_to_id.items()}

print(f"  ✓ Vocabulary: {len(vocabulary):,} words")

# Load triples
print("\n[2/4] Loading triples...")
with open(f"{DATA_PATH}/triples.pkl", 'rb') as f:
    triples = pickle.load(f)

print(f"  ✓ Triples: {len(triples):,}")

# Load training examples
print("\n[3/4] Loading training examples...")
with open(f"{DATA_PATH}/training_examples.pkl", 'rb') as f:
    training_examples = pickle.load(f)

print(f"  ✓ Training examples: {len(training_examples):,}")

# Extract antonym pairs
print("\n[4/4] Extracting antonym pairs...")
antonym_pairs = []
for triple in triples:
    source_word, relation, target_word = triple
    if 'antonym' in relation.lower():
        antonym_pairs.append((source_word, target_word))

print(f"  ✓ Antonym pairs: {len(antonym_pairs):,}")

# Load stats
with open(f"{DATA_PATH}/graph_stats.json", 'r') as f:
    stats = json.load(f)

print("\n" + "="*70)
print("DATA LOADED SUCCESSFULLY")
print("="*70)
print(f"  Vocabulary: {stats['vocabulary_size']:,} sense-tagged words")
print(f"  Triples: {stats['num_triples']:,}")
print(f"  Training examples: {stats['num_training_examples']:,}")
print(f"  Source sentences: {stats['num_sentences']:,}")
print(f"  Antonym pairs: {len(antonym_pairs):,} ← KEY FOR SEMANTIC AXES!")
print("="*70)

# Sample antonyms
print("\nSample antonym pairs:")
for word1, word2 in antonym_pairs[:10]:
    w1_display = word1.split('_wn.')[0] if '_wn.' in word1 else word1
    w2_display = word2.split('_wn.')[0] if '_wn.' in word2 else word2
    print(f"  {w1_display} ↔ {w2_display}")

LOADING PRE-PARSED WIKIPEDIA + WORDNET DATA

[1/4] Loading vocabulary...
  ✓ Vocabulary: 197,095 words

[2/4] Loading triples...
  ✓ Triples: 1,785,340

[3/4] Loading training examples...
  ✓ Training examples: 1,662,260

[4/4] Extracting antonym pairs...
  ✓ Antonym pairs: 971

DATA LOADED SUCCESSFULLY
  Vocabulary: 197,095 sense-tagged words
  Triples: 1,785,340
  Training examples: 1,662,260
  Source sentences: 100,000
  Antonym pairs: 971 ← KEY FOR SEMANTIC AXES!

Sample antonym pairs:
  comparably ↔ incomparably
  leeward ↔ windward
  noblewoman ↔ nobleman
  cash ↔ credit
  diapsid ↔ anapsid
  hardware ↔ software
  ascent ↔ descent
  complexity ↔ simplicity
  intelligence ↔ stupidity
  king ↔ queen
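The extraction above identifies antonyms by matching the substring `antonym` in the relation name. A minimal sketch for checking that assumption and collapsing duplicate pairs follows. It assumes each triple is `(source_word, relation, target_word)` with string relations, as the cell above does. Whether the graph stores each antonym pair in both directions is not stated, so the dedup step is an assumption to verify with the counter. Both helper names are hypothetical.

```python
from collections import Counter

def relation_counts(triples):
    """Count how often each relation name occurs (to confirm antonym naming)."""
    return Counter(relation for _, relation, _ in triples)

def unique_antonym_pairs(triples):
    """Antonym pairs with (a, b) and (b, a) collapsed to one unordered pair."""
    seen = set()
    pairs = []
    for source, relation, target in triples:
        if "antonym" not in relation.lower():
            continue
        key = frozenset((source, target))
        if key not in seen:
            seen.add(key)
            pairs.append((source, target))
    return pairs

# Usage with the loaded data:
# print(relation_counts(triples).most_common(20))
# print(len(unique_antonym_pairs(triples)), "unique antonym pairs")
```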


## Step 2: Initialize Semantic-Driven Model

**Key innovation:** Each word learns which semantic axes are relevant to its meaning!

In [None]:
# Configuration
CONFIG = {
    'vocab_size': len(vocabulary),
    'embedding_dim': 256,
    'num_anchors': 51,
    'epochs': 150,
    'batch_size': 262144,
    'lr': 0.01,

    # Semantic-driven weights
    'semantic_clustering_weight': 2.0,
    'relevance_coherence_weight': 1.0,
    'target_dims_per_word': 10.0,      # Target: 10 dimensions per word on average
    'relevance_sparsity_weight': 0.05,  # Increased from 0.01 - stronger push toward target
    'relevance_commitment_weight': 0.5,  # Force binary commitment
    'reg_weight': 0.05,
    'parse_weight': 0.3,
    'wordnet_weight': 0.7,
    'patience': 30,
    'mixed_precision': True,
}

print("Configuration:")
print("="*70)
for key, value in CONFIG.items():
    print(f"  {key}: {value}")
print("="*70)

# Model
class SemanticTransparentEmbedding(nn.Module):
    """
    Embeddings with semantic-driven sparsity.

    Key components:
    - embeddings: Position of word on each semantic axis
    - relevance_logits: Which axes are relevant to this word's meaning
    - Final embedding = embeddings * sigmoid(relevance)

    Example:
      'good' on morality axis: embedding=+0.8, relevance=1.0 → +0.8
      'good' on temperature axis: embedding=0.1, relevance=0.0 → 0.0
    """
    def __init__(self, vocab_size, embedding_dim, num_anchors):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.num_anchors = num_anchors

        # Core embeddings (position on each semantic axis)
        self.embeddings = nn.Parameter(torch.randn(vocab_size, embedding_dim) * 0.01)

        # Dimension relevance (which axes matter for this word?)
        # Initialize with negative bias for sparsity
        # sigmoid(-3.0) ≈ 0.047 → starts with ~10 dims/word instead of ~100
        self.relevance_logits = nn.Parameter(torch.randn(vocab_size, embedding_dim) * 0.5 - 3.0)

        # Initialize anchors
        with torch.no_grad():
            self.embeddings[:, :num_anchors] = 0.0
            self.relevance_logits[:, :num_anchors] = -10.0  # Anchors always off

    def get_masked_embeddings(self, word_ids):
        """Get embeddings gated by semantic relevance."""
        emb = self.embeddings[word_ids]
        relevance = torch.sigmoid(self.relevance_logits[word_ids])
        return emb * relevance

    def get_learned_dims_mask(self):
        mask = torch.zeros(self.embedding_dim, dtype=torch.bool)
        mask[self.num_anchors:] = True
        return mask

model = SemanticTransparentEmbedding(
    CONFIG['vocab_size'],
    CONFIG['embedding_dim'],
    CONFIG['num_anchors']
).to(device)

print(f"\n✓ Semantic-driven model initialized")
print(f"  Vocabulary: {CONFIG['vocab_size']:,} words")
print(f"  Embedding dim: {CONFIG['embedding_dim']} ({CONFIG['num_anchors']} anchor + {CONFIG['embedding_dim']-CONFIG['num_anchors']} learned)")
print(f"  Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"\n  KEY FEATURES:")
print(f"  - Each word learns which semantic axes are relevant")
print(f"  - Target: {CONFIG['target_dims_per_word']:.0f} dimensions per word (semantic sweet spot)")
print(f"  - Initialized sparse (bias toward -3.0) to break symmetry")
print(f"  - Stronger sparsity weight (0.05) to push toward target")

Configuration:
  vocab_size: 197095
  embedding_dim: 256
  num_anchors: 51
  epochs: 150
  batch_size: 262144
  lr: 0.01
  semantic_clustering_weight: 2.0
  relevance_coherence_weight: 1.0
  target_dims_per_word: 10.0
  relevance_sparsity_weight: 0.05
  reg_weight: 0.05
  parse_weight: 0.3
  wordnet_weight: 0.7
  patience: 30
  mixed_precision: True

✓ Semantic-driven model initialized
  Vocabulary: 197,095 words
  Embedding dim: 256 (51 anchor + 205 learned)
  Parameters: 100,912,640

  KEY FEATURES:
  - Each word learns which semantic axes are relevant
  - Target: 10 dimensions per word (semantic sweet spot)
  - Initialized sparse (bias toward -3.0) to break symmetry
  - Stronger sparsity weight (0.05) to push toward target
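You can reproduce the numbers printed by this cell by hand. The small arithmetic sketch below uses values copied from CONFIG and the output above. It recomputes the parameter count and the soft dims/word implied by the -3.0 logit bias. It ignores the N(0, 0.5²) initialization noise, which raises the mean gate slightly because the sigmoid is convex for negative inputs.

```python
import math

vocab_size, embedding_dim, num_anchors = 197_095, 256, 51
learned_dims = embedding_dim - num_anchors  # 205

# Two (vocab_size x embedding_dim) tables: embeddings and relevance_logits
num_params = 2 * vocab_size * embedding_dim

# Soft dims/word at init if every learned logit sat exactly at the -3.0 bias
gate_at_bias = 1.0 / (1.0 + math.exp(3.0))
initial_soft_count = learned_dims * gate_at_bias

print(f"learned dims:       {learned_dims}")
print(f"parameters:         {num_params:,}")
print(f"gate at -3.0:       {gate_at_bias:.4f}")
print(f"initial soft count: {initial_soft_count:.2f}")
```

A soft count near 10 is therefore reached with every gate around 0.05, far below the 0.5 threshold the training loop uses for its "hard" count.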


## Step 3: Semantic-Driven Loss Functions

No forced polarity structure - dimensions emerge through semantic clustering!

In [None]:
# Match antonym pairs to IDs (with fuzzy matching)
print("Matching antonym pairs to vocabulary...")
antonym_indices = []
for word1, word2 in antonym_pairs:
    idx1 = word_to_id.get(word1)
    idx2 = word_to_id.get(word2)

    # Fuzzy match if needed
    if idx1 is None:
        for w in word_to_id:
            if w.startswith(word1.split('_wn')[0] + "_wn."):
                idx1 = word_to_id[w]
                break
    if idx2 is None:
        for w in word_to_id:
            if w.startswith(word2.split('_wn')[0] + "_wn."):
                idx2 = word_to_id[w]
                break

    if idx1 is not None and idx2 is not None:
        antonym_indices.append((idx1, idx2))

antonym_tensor = torch.tensor(antonym_indices, dtype=torch.long, device=device)
print(f"✓ Matched {len(antonym_indices):,} / {len(antonym_pairs):,} antonym pairs\n")

# Loss functions
def distance_loss(model, edge_samples):
    """Semantic distance loss with parse/WordNet weighting."""
    if len(edge_samples) == 0:
        return torch.tensor(0.0, device=device)

    parse_samples = []
    wordnet_samples = []

    for source_id, target_id, target_dist, relation in edge_samples:
        if relation.startswith('RelationType.'):
            parse_samples.append((source_id, target_id, target_dist))
        else:
            wordnet_samples.append((source_id, target_id, target_dist))

    total_loss = torch.tensor(0.0, device=device)

    if parse_samples:
        source_ids = torch.tensor([s for s, t, d in parse_samples], device=device)
        target_ids = torch.tensor([t for s, t, d in parse_samples], device=device)
        target_dists = torch.tensor([d for s, t, d in parse_samples], device=device, dtype=torch.float32)
        source_emb = model.get_masked_embeddings(source_ids)
        target_emb = model.get_masked_embeddings(target_ids)
        actual_dists = torch.norm(source_emb - target_emb, dim=1)
        parse_loss = F.mse_loss(actual_dists, target_dists)
        total_loss += CONFIG['parse_weight'] * parse_loss

    if wordnet_samples:
        source_ids = torch.tensor([s for s, t, d in wordnet_samples], device=device)
        target_ids = torch.tensor([t for s, t, d in wordnet_samples], device=device)
        target_dists = torch.tensor([d for s, t, d in wordnet_samples], device=device, dtype=torch.float32)
        source_emb = model.get_masked_embeddings(source_ids)
        target_emb = model.get_masked_embeddings(target_ids)
        actual_dists = torch.norm(source_emb - target_emb, dim=1)
        wordnet_loss = F.mse_loss(actual_dists, target_dists)
        total_loss += CONFIG['wordnet_weight'] * wordnet_loss

    return total_loss

def semantic_clustering_loss(model, antonym_tensor):
    """
    Encourage semantically-related antonym pairs to use same dimensions.

    Mechanism:
    1. Compute pair centroids (semantic similarity)
    2. Compute dimensional signatures (which dims they use)
    3. Force signature similarity to match semantic similarity

    Result: (good/bad) and (right/wrong) both use "morality" dimension!
    """
    if len(antonym_tensor) == 0:
        return torch.tensor(0.0, device=device)

    # Get masked embeddings (respects relevance)
    emb1 = model.get_masked_embeddings(antonym_tensor[:, 0])
    emb2 = model.get_masked_embeddings(antonym_tensor[:, 1])

    # Semantic signatures (what concepts are these pairs about?)
    pair_centroids = (emb1 + emb2) / 2
    semantic_sim = torch.mm(F.normalize(pair_centroids, dim=1),
                           F.normalize(pair_centroids, dim=1).T)

    # Dimensional signatures (which dimensions do they use?)
    diff_vectors = emb1 - emb2
    dimensional_sim = torch.mm(F.normalize(diff_vectors, dim=1),
                              F.normalize(diff_vectors, dim=1).T)

    # Force dimensional usage to match semantic similarity
    # Similar pairs → use same dimensions
    # Different pairs → use different dimensions
    return F.mse_loss(dimensional_sim, semantic_sim)

def relevance_coherence_loss(model, antonym_tensor):
    """
    Antonym pairs should activate the same dimensions.

    If 'good' activates morality dimension, 'bad' should too
    (even though they have opposite values).
    """
    if len(antonym_tensor) == 0:
        return torch.tensor(0.0, device=device)

    rel1 = torch.sigmoid(model.relevance_logits[antonym_tensor[:, 0]])
    rel2 = torch.sigmoid(model.relevance_logits[antonym_tensor[:, 1]])

    # Cosine similarity of relevance patterns
    relevance_similarity = F.cosine_similarity(rel1, rel2, dim=1)

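    # Maximize similarity (both activate same axes). Sigmoid gates are non-negative,
    # so the cosine lies in [0, 1] and this term lies in [-1, 0]; this is why the
    # total training loss can be negative.
    return -torch.mean(relevance_similarity)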

def relevance_sparsity_loss(model, target_dims):
    """
    Target-based sparsity: aim for specific number of dimensions per word.

    Don't just minimize activation - target the semantic sweet spot!
    Target: 10 dimensions per word on average.

    Uses SOFT counting (sum of sigmoid values) for differentiability.
    """
    relevance = torch.sigmoid(model.relevance_logits[:, model.num_anchors:])

    # Average number of active dimensions per word (soft count)
    # Sum sigmoid values across dimensions for each word, then average
    avg_active_dims = torch.mean(torch.sum(relevance, dim=1))

    # Penalize deviation from target
    # If avg = 10 and target = 10, loss = 0 (perfect!)
    # If avg = 102 and target = 10, loss = high (too dense!)
    # If avg = 0 and target = 10, loss = high (too sparse!)
    target_tensor = torch.tensor(target_dims, device=avg_active_dims.device, dtype=avg_active_dims.dtype)
    return F.mse_loss(avg_active_dims, target_tensor)


def relevance_commitment_loss(model):
    """
    Force dimensions to commit: active (1.0) or inactive (0.0).
    
    Uses entropy penalty - sigmoid at 0.5 has HIGH entropy (uncertain).
    We want LOW entropy (committed to 0 or 1).
    """
    relevance = torch.sigmoid(model.relevance_logits[:, model.num_anchors:])
    
    # Entropy of Bernoulli distribution
    eps = 1e-8
    entropy = -(relevance * torch.log(relevance + eps) + 
                (1 - relevance) * torch.log(1 - relevance + eps))
    
    return torch.mean(entropy)

def regularization_loss(embeddings, num_anchors):
    """L2 regularization on learned dimensions."""
    return torch.mean(embeddings[:, num_anchors:] ** 2)

print("✓ Semantic-driven loss functions initialized")
print("  - Semantic clustering: Pairs with similar meanings use same dimensions")
print("  - Relevance coherence: Antonyms activate same dimensions")
print("  - Relevance sparsity: TARGET-BASED (aim for 10 dims/word, not minimize!)")
print("  - Uses SOFT counting for differentiable gradients")
print("  - Relevance commitment: Binary decisions (forces 0 or 1)")
print("  - Natural emergence of semantic axes!")

Matching antonym pairs to vocabulary...
✓ Matched 648 / 971 antonym pairs

✓ Semantic-driven loss functions initialized
  - Semantic clustering: Pairs with similar meanings use same dimensions
  - Relevance coherence: Antonyms activate same dimensions
  - Relevance sparsity: TARGET-BASED (aim for 10 dims/word, not minimize!)
  - Uses SOFT counting for differentiable gradients
  - Natural emergence of semantic axes!
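The fuzzy fallback above scans the whole vocabulary for every unmatched word, which is O(pairs × vocabulary). The sketch below builds a one-pass prefix index that returns the same first sense in dict order. It assumes sense-tagged keys look like `word_wn.<pos>.<nn>` and that base words never contain `_wn` themselves. `build_base_index` and `lookup` are hypothetical helpers, not part of the notebook.

```python
def build_base_index(word_to_id):
    """Map each base word (text before '_wn.') to the id of its first sense key.

    dict iteration follows insertion order, so setdefault keeps the same
    'first match' that the linear startswith() scan would have found.
    """
    index = {}
    for word, idx in word_to_id.items():
        if "_wn." in word:
            index.setdefault(word.split("_wn.")[0], idx)
    return index

def lookup(word, word_to_id, base_index):
    """Exact match first, then fall back to the first sense of the base word."""
    idx = word_to_id.get(word)
    if idx is None:
        idx = base_index.get(word.split("_wn")[0])
    return idx

# Usage with the notebook's data:
# base_index = build_base_index(word_to_id)
# antonym_indices = []
# for w1, w2 in antonym_pairs:
#     i, j = lookup(w1, word_to_id, base_index), lookup(w2, word_to_id, base_index)
#     if i is not None and j is not None:
#         antonym_indices.append((i, j))
```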

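The three relevance terms can be evaluated by hand on two toy patterns. The first pattern sets every learned gate to 10/205. The second uses 10 fully-on gates per word, with the antonym partner overlapping on 5 of them. Both patterns are illustrative, not read from the trained model. The weights are copied from CONFIG. Both patterns meet the soft-count target exactly. The coherence term pays its full -1 whenever both antonyms share an identical pattern, which a uniform pattern achieves for free. The entropy term charges that uniform pattern only about 0.5 × 0.195 ≈ 0.10. This is consistent with, though it does not prove, the "soft ≈ 10, hard = 0.0" pattern in the Step 5 log.

```python
import math

LEARNED_DIMS = 205   # 256 embedding dims - 51 anchors
TARGET = 10.0        # CONFIG['target_dims_per_word']
W_COHERENCE, W_SPARSITY, W_COMMIT = 1.0, 0.05, 0.5   # copied from CONFIG

def bernoulli_entropy(r):
    """Entropy of one gate; 0*log(0) is taken as 0 (the notebook adds an eps instead)."""
    if r <= 0.0 or r >= 1.0:
        return 0.0
    return -(r * math.log(r) + (1 - r) * math.log(1 - r))

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def relevance_terms(rel_a, rel_b):
    """Soft/hard counts and the summed weighted relevance penalties for one antonym pair."""
    soft = sum(rel_a)
    hard = sum(1 for r in rel_a if r > 0.5)
    sparsity = W_SPARSITY * (soft - TARGET) ** 2
    coherence = W_COHERENCE * -cosine(rel_a, rel_b)
    commit = W_COMMIT * sum(bernoulli_entropy(r) for r in rel_a) / len(rel_a)
    return soft, hard, sparsity + coherence + commit

# Every gate at 10/205: both antonyms share the identical, diffuse pattern
uniform = [TARGET / LEARNED_DIMS] * LEARNED_DIMS
# 10 fully-on gates per word, with the antonym partner overlapping on 5 of them
word_a = [1.0] * 10 + [0.0] * (LEARNED_DIMS - 10)
word_b = [0.0] * 5 + [1.0] * 10 + [0.0] * (LEARNED_DIMS - 15)

for name, a, b in [("uniform", uniform, uniform), ("committed", word_a, word_b)]:
    soft, hard, total = relevance_terms(a, b)
    print(f"{name:9s} soft={soft:.1f} hard={hard} weighted_total={total:.3f}")
```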

## Step 4: Create DataLoaders

In [None]:
class KnowledgeGraphDataset(Dataset):
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        source_id, relation, target_id = self.examples[idx]
        target_distance = 0.5
        if 'synonym' in relation.lower():
            target_distance = 0.1
        elif 'antonym' in relation.lower():
            target_distance = 0.9
        return source_id, target_id, target_distance, relation

# Split dataset
random.seed(42)
random.shuffle(training_examples)

split_idx = int(0.9 * len(training_examples))
train_examples = training_examples[:split_idx]
val_examples = training_examples[split_idx:]

train_dataset = KnowledgeGraphDataset(train_examples)
val_dataset = KnowledgeGraphDataset(val_examples)

train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=12, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=12, pin_memory=True)

print(f"✓ Train examples: {len(train_examples):,}")
print(f"✓ Val examples: {len(val_examples):,}")
print(f"✓ Train batches: {len(train_loader):,}")

✓ Train examples: 1,496,034
✓ Val examples: 166,226
✓ Train batches: 6
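The split sizes and batch counts above follow directly from the 90/10 split and `batch_size=262144`. This arithmetic sketch assumes `drop_last=False`, the DataLoader default, so a partial final batch still counts. It shows that each epoch is only six optimizer steps, so the whole 150-epoch schedule is 900 Adam updates.

```python
import math

def split_and_batches(num_examples, train_frac, batch_size):
    """Reproduce the split sizes and per-epoch batch counts (drop_last=False)."""
    n_train = int(train_frac * num_examples)
    n_val = num_examples - n_train
    return n_train, n_val, math.ceil(n_train / batch_size), math.ceil(n_val / batch_size)

n_train, n_val, train_batches, val_batches = split_and_batches(1_662_260, 0.9, 262_144)
print(n_train, n_val, train_batches, val_batches)
print("optimizer steps over 150 epochs:", 150 * train_batches)
```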


## Step 5: Training Loop

Dimensions emerge through semantic clustering - no forced structure!
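Anchor gradients have to be cleared between `backward()` and `optimizer.step()`. Clearing them after the step leaves the update that was just applied in place. Because `optimizer.zero_grad()` clears all gradients at the start of the next iteration anyway, zeroing after the step has the same effect as never zeroing. Adam makes this matter even for very small anchor gradients, because it divides each coordinate's step by that coordinate's own gradient scale. A tiny but persistent gradient therefore still moves the anchor by roughly `lr` per step. The sketch below is a plain-Python version of Adam using `torch.optim.Adam`'s default betas=(0.9, 0.999) and eps=1e-8, run on a two-coordinate parameter [anchor, learned]. The gradient values are illustrative.

```python
import math

def adam_anchor_demo(grads, zero_anchor_before_step, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
    """Plain-Python Adam on params [anchor, learned], both starting at 0.0.

    If zero_anchor_before_step is True, the anchor gradient is cleared before
    each update, mirroring grad[:, :num_anchors] = 0 placed before optimizer.step().
    """
    b1, b2 = betas
    params, m, v = [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]
    for t, grad in enumerate(grads, start=1):
        g = list(grad)
        if zero_anchor_before_step:
            g[0] = 0.0
        for i in range(2):
            m[i] = b1 * m[i] + (1 - b1) * g[i]
            v[i] = b2 * v[i] + (1 - b2) * g[i] ** 2
            m_hat = m[i] / (1 - b1 ** t)   # bias-corrected first moment
            v_hat = v[i] / (1 - b2 ** t)   # bias-corrected second moment
            params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return params

# A tiny anchor gradient (1e-4) next to a normal learned gradient (0.5), for 3 steps
grads = [(1e-4, 0.5)] * 3
print("zeroed before step:", adam_anchor_demo(grads, True))
print("never zeroed:      ", adam_anchor_demo(grads, False))
```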

In [None]:
# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
scaler = GradScaler('cuda') if CONFIG['mixed_precision'] else None

# History
history = {'train_loss': [], 'val_loss': [], 'avg_relevance_soft': [], 'avg_relevance_hard': []}
best_val_loss = float('inf')
patience_counter = 0

@torch.no_grad()  # monitoring only: avoid building an autograd graph over the full table
def compute_avg_relevance(model):
    """
    Compute average dimensions per word.
    Returns both soft count (what the loss sees) and hard count (>0.5 threshold).
    """
    rel = torch.sigmoid(model.relevance_logits[:, model.num_anchors:])

    # Soft count: sum of sigmoid values (matches loss function)
    soft_count = torch.mean(torch.sum(rel, dim=1)).item()

    # Hard count: number of dimensions > 0.5 (for interpretability)
    hard_count = torch.mean(torch.sum(rel > 0.5, dim=1).float()).item()

    return soft_count, hard_count

print("="*70)
print("STARTING SEMANTIC-DRIVEN TRAINING")
print("="*70)
print("Key difference: Dimensions emerge through semantic clustering")
print("No forced polarity structure - natural semantic axes!")
print(f"Target: {CONFIG['target_dims_per_word']:.0f} dimensions per word (soft count)")
print("="*70)
print()

start_time = time.time()

for epoch in range(1, CONFIG['epochs'] + 1):
    # Train
    model.train()
    total_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for source_ids, target_ids, target_dists, relations in pbar:
        edge_samples = list(zip(source_ids.tolist(), target_ids.tolist(), target_dists.tolist(), relations))
        optimizer.zero_grad()

        # Forward pass: autocast is a no-op when mixed precision is disabled
        with autocast('cuda', enabled=CONFIG['mixed_precision']):
            d_loss = distance_loss(model, edge_samples)
            sc_loss = semantic_clustering_loss(model, antonym_tensor)
            rc_loss = relevance_coherence_loss(model, antonym_tensor)
            rs_loss = relevance_sparsity_loss(model, CONFIG['target_dims_per_word'])
            commit_loss = relevance_commitment_loss(model)
            r_loss = regularization_loss(model.embeddings, CONFIG['num_anchors'])

            total = (
                d_loss +
                CONFIG['semantic_clustering_weight'] * sc_loss +
                CONFIG['relevance_coherence_weight'] * rc_loss +
                CONFIG['relevance_sparsity_weight'] * rs_loss +
                CONFIG['relevance_commitment_weight'] * commit_loss +
                CONFIG['reg_weight'] * r_loss
            )

        # Backward pass (loss scaled when using mixed precision)
        if scaler is not None:
            scaler.scale(total).backward()
        else:
            total.backward()

        # Preserve anchors: zero their gradients BEFORE the optimizer step.
        # Zeroing after optimizer.step() has no effect on the update just applied.
        with torch.no_grad():
            if model.embeddings.grad is not None:
                model.embeddings.grad[:, :CONFIG['num_anchors']] = 0.0
            if model.relevance_logits.grad is not None:
                model.relevance_logits.grad[:, :CONFIG['num_anchors']] = 0.0

        if scaler is not None:
            scaler.step(optimizer)  # unscales gradients first; zeroed entries stay zero
            scaler.update()
        else:
            optimizer.step()

        total_loss += total.item()
        pbar.set_postfix({'loss': f"{total.item():.4f}", 'sc': f"{sc_loss.item():.4f}"})

    # Validate
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for source_ids, target_ids, target_dists, relations in val_loader:
            edge_samples = list(zip(source_ids.tolist(), target_ids.tolist(), target_dists.tolist(), relations))
            d_loss = distance_loss(model, edge_samples)
            sc_loss = semantic_clustering_loss(model, antonym_tensor)
            rc_loss = relevance_coherence_loss(model, antonym_tensor)
            rs_loss = relevance_sparsity_loss(model, CONFIG['target_dims_per_word'])
            commit_loss = relevance_commitment_loss(model)
            r_loss = regularization_loss(model.embeddings, CONFIG['num_anchors'])

            total = (
                d_loss +
                CONFIG['semantic_clustering_weight'] * sc_loss +
                CONFIG['relevance_coherence_weight'] * rc_loss +
                CONFIG['relevance_sparsity_weight'] * rs_loss +
                CONFIG['relevance_commitment_weight'] * commit_loss +
                CONFIG['reg_weight'] * r_loss
            )
            val_loss += total.item()

    train_loss = total_loss / len(train_loader)
    val_loss = val_loss / len(val_loader)
    avg_soft, avg_hard = compute_avg_relevance(model)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['avg_relevance_soft'].append(avg_soft)
    history['avg_relevance_hard'].append(avg_hard)

    print(f"\nEpoch {epoch}/{CONFIG['epochs']}")
    print(f"  Train Loss: {train_loss:.6f}")
    print(f"  Val Loss: {val_loss:.6f}")
    print(f"  Avg dims/word (soft): {avg_soft:.1f} (target: {CONFIG['target_dims_per_word']:.0f})")
    print(f"  Avg dims/word (hard): {avg_hard:.1f} (>0.5 threshold)")

    # Save best
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        print(f"  ✓ New best model!")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'config': CONFIG,
        }, f"{CHECKPOINT_DIR}/best_model.pt")
    else:
        patience_counter += 1
        if patience_counter >= CONFIG['patience']:
            print(f"\nEarly stopping at epoch {epoch}")
            break
    print()

elapsed = time.time() - start_time
print("="*70)
print("TRAINING COMPLETE")
print("="*70)
print(f"Time: {elapsed/3600:.1f} hours")
print(f"Best val loss: {best_val_loss:.6f}")
print(f"Final avg dims/word (soft): {history['avg_relevance_soft'][-1]:.1f}")
print(f"Final avg dims/word (hard): {history['avg_relevance_hard'][-1]:.1f}")
print("="*70)

STARTING SEMANTIC-DRIVEN TRAINING
Key difference: Dimensions emerge through semantic clustering
No forced polarity structure - natural semantic axes!
Target: 10 dimensions per word (soft count)



Epoch 1: 100%|██████████| 6/6 [00:05<00:00,  1.07it/s, loss=-0.6554, sc=0.0096]



Epoch 1/150
  Train Loss: -0.608882
  Val Loss: -0.670916
  Avg dims/word (soft): 10.4 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 2: 100%|██████████| 6/6 [00:04<00:00,  1.32it/s, loss=-0.7454, sc=0.0087]



Epoch 2/150
  Train Loss: -0.710872
  Val Loss: -0.752981
  Avg dims/word (soft): 10.1 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 3: 100%|██████████| 6/6 [00:04<00:00,  1.31it/s, loss=-0.8116, sc=0.0082]



Epoch 3/150
  Train Loss: -0.785961
  Val Loss: -0.815974
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 4: 100%|██████████| 6/6 [00:04<00:00,  1.33it/s, loss=-0.8638, sc=0.0078]



Epoch 4/150
  Train Loss: -0.843739
  Val Loss: -0.864827
  Avg dims/word (soft): 9.9 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 5: 100%|██████████| 6/6 [00:04<00:00,  1.33it/s, loss=-0.9024, sc=0.0076]



Epoch 5/150
  Train Loss: -0.887453
  Val Loss: -0.900860
  Avg dims/word (soft): 9.9 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 6: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9274, sc=0.0075]



Epoch 6/150
  Train Loss: -0.917956
  Val Loss: -0.924524
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 7: 100%|██████████| 6/6 [00:04<00:00,  1.32it/s, loss=-0.9416, sc=0.0073]



Epoch 7/150
  Train Loss: -0.936592
  Val Loss: -0.938152
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 8: 100%|██████████| 6/6 [00:04<00:00,  1.34it/s, loss=-0.9499, sc=0.0073]



Epoch 8/150
  Train Loss: -0.947045
  Val Loss: -0.946048
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 9: 100%|██████████| 6/6 [00:04<00:00,  1.34it/s, loss=-0.9559, sc=0.0072]



Epoch 9/150
  Train Loss: -0.953781
  Val Loss: -0.951661
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 10: 100%|██████████| 6/6 [00:04<00:00,  1.32it/s, loss=-0.9611, sc=0.0072]



Epoch 10/150
  Train Loss: -0.959207
  Val Loss: -0.956011
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 11: 100%|██████████| 6/6 [00:04<00:00,  1.35it/s, loss=-0.9649, sc=0.0071]



Epoch 11/150
  Train Loss: -0.963546
  Val Loss: -0.959153
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 12: 100%|██████████| 6/6 [00:04<00:00,  1.34it/s, loss=-0.9675, sc=0.0071]



Epoch 12/150
  Train Loss: -0.966737
  Val Loss: -0.961354
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 13: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9698, sc=0.0071]



Epoch 13/150
  Train Loss: -0.969095
  Val Loss: -0.963020
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 14: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9715, sc=0.0071]



Epoch 14/150
  Train Loss: -0.970918
  Val Loss: -0.964389
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 15: 100%|██████████| 6/6 [00:04<00:00,  1.30it/s, loss=-0.9722, sc=0.0071]



Epoch 15/150
  Train Loss: -0.972373
  Val Loss: -0.965517
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 16: 100%|██████████| 6/6 [00:04<00:00,  1.32it/s, loss=-0.9742, sc=0.0071]



Epoch 16/150
  Train Loss: -0.973633
  Val Loss: -0.966435
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 17: 100%|██████████| 6/6 [00:04<00:00,  1.30it/s, loss=-0.9752, sc=0.0070]



Epoch 17/150
  Train Loss: -0.974636
  Val Loss: -0.967180
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 18: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9757, sc=0.0070]



Epoch 18/150
  Train Loss: -0.975452
  Val Loss: -0.967798
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 19: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9766, sc=0.0070]



Epoch 19/150
  Train Loss: -0.976168
  Val Loss: -0.968317
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 20: 100%|██████████| 6/6 [00:04<00:00,  1.30it/s, loss=-0.9766, sc=0.0070]



Epoch 20/150
  Train Loss: -0.976747
  Val Loss: -0.968748
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 21: 100%|██████████| 6/6 [00:04<00:00,  1.31it/s, loss=-0.9776, sc=0.0070]



Epoch 21/150
  Train Loss: -0.977285
  Val Loss: -0.969114
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 22: 100%|██████████| 6/6 [00:04<00:00,  1.31it/s, loss=-0.9775, sc=0.0070]



Epoch 22/150
  Train Loss: -0.977705
  Val Loss: -0.969428
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 23: 100%|██████████| 6/6 [00:05<00:00,  1.08it/s, loss=-0.9784, sc=0.0070]



Epoch 23/150
  Train Loss: -0.978120
  Val Loss: -0.969700
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 24: 100%|██████████| 6/6 [00:04<00:00,  1.30it/s, loss=-0.9783, sc=0.0070]



Epoch 24/150
  Train Loss: -0.978442
  Val Loss: -0.969940
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 25: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9789, sc=0.0070]



Epoch 25/150
  Train Loss: -0.978756
  Val Loss: -0.970147
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 26: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9794, sc=0.0070]



Epoch 26/150
  Train Loss: -0.979036
  Val Loss: -0.970331
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 27: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9792, sc=0.0070]



Epoch 27/150
  Train Loss: -0.979254
  Val Loss: -0.970495
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 28: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9798, sc=0.0070]



Epoch 28/150
  Train Loss: -0.979486
  Val Loss: -0.970641
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 29: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9797, sc=0.0070]



Epoch 29/150
  Train Loss: -0.979669
  Val Loss: -0.970774
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 30: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9800, sc=0.0070]



Epoch 30/150
  Train Loss: -0.979854
  Val Loss: -0.970894
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 31: 100%|██████████| 6/6 [00:04<00:00,  1.26it/s, loss=-0.9799, sc=0.0070]



Epoch 31/150
  Train Loss: -0.980008
  Val Loss: -0.971004
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 32: 100%|██████████| 6/6 [00:05<00:00,  1.16it/s, loss=-0.9798, sc=0.0070]



Epoch 32/150
  Train Loss: -0.980146
  Val Loss: -0.971102
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 33: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9806, sc=0.0070]



Epoch 33/150
  Train Loss: -0.980319
  Val Loss: -0.971194
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 34: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9806, sc=0.0070]



Epoch 34/150
  Train Loss: -0.980450
  Val Loss: -0.971280
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 35: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9804, sc=0.0070]



Epoch 35/150
  Train Loss: -0.980563
  Val Loss: -0.971362
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 36: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9806, sc=0.0070]



Epoch 36/150
  Train Loss: -0.980689
  Val Loss: -0.971445
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 37: 100%|██████████| 6/6 [00:04<00:00,  1.26it/s, loss=-0.9806, sc=0.0070]



Epoch 37/150
  Train Loss: -0.980797
  Val Loss: -0.971519
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 38: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9807, sc=0.0070]



Epoch 38/150
  Train Loss: -0.980912
  Val Loss: -0.971594
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 39: 100%|██████████| 6/6 [00:04<00:00,  1.31it/s, loss=-0.9809, sc=0.0070]



Epoch 39/150
  Train Loss: -0.981021
  Val Loss: -0.971665
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 40: 100%|██████████| 6/6 [00:05<00:00,  1.11it/s, loss=-0.9813, sc=0.0070]



Epoch 40/150
  Train Loss: -0.981140
  Val Loss: -0.971734
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 41: 100%|██████████| 6/6 [00:04<00:00,  1.31it/s, loss=-0.9815, sc=0.0070]



Epoch 41/150
  Train Loss: -0.981247
  Val Loss: -0.971797
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 42: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9813, sc=0.0070]



Epoch 42/150
  Train Loss: -0.981327
  Val Loss: -0.971857
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 43: 100%|██████████| 6/6 [00:04<00:00,  1.31it/s, loss=-0.9812, sc=0.0070]



Epoch 43/150
  Train Loss: -0.981409
  Val Loss: -0.971910
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 44: 100%|██████████| 6/6 [00:04<00:00,  1.24it/s, loss=-0.9819, sc=0.0070]



Epoch 44/150
  Train Loss: -0.981520
  Val Loss: -0.971961
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 45: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9817, sc=0.0070]



Epoch 45/150
  Train Loss: -0.981586
  Val Loss: -0.972012
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 46: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9817, sc=0.0070]



Epoch 46/150
  Train Loss: -0.981662
  Val Loss: -0.972059
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 47: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9817, sc=0.0070]



Epoch 47/150
  Train Loss: -0.981734
  Val Loss: -0.972102
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 48: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s, loss=-0.9818, sc=0.0070]



Epoch 48/150
  Train Loss: -0.981805
  Val Loss: -0.972146
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 49: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9821, sc=0.0070]



Epoch 49/150
  Train Loss: -0.981887
  Val Loss: -0.972186
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 50: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9820, sc=0.0070]



Epoch 50/150
  Train Loss: -0.981946
  Val Loss: -0.972226
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 51: 100%|██████████| 6/6 [00:04<00:00,  1.26it/s, loss=-0.9820, sc=0.0070]



Epoch 51/150
  Train Loss: -0.982009
  Val Loss: -0.972265
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 52: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s, loss=-0.9817, sc=0.0070]



Epoch 52/150
  Train Loss: -0.982053
  Val Loss: -0.972300
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 53: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9822, sc=0.0070]



Epoch 53/150
  Train Loss: -0.982138
  Val Loss: -0.972333
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 54: 100%|██████████| 6/6 [00:04<00:00,  1.26it/s, loss=-0.9825, sc=0.0070]



Epoch 54/150
  Train Loss: -0.982205
  Val Loss: -0.972366
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 55: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9823, sc=0.0070]



Epoch 55/150
  Train Loss: -0.982250
  Val Loss: -0.972400
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 56: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9822, sc=0.0070]



Epoch 56/150
  Train Loss: -0.982303
  Val Loss: -0.972433
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 57: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9823, sc=0.0070]



Epoch 57/150
  Train Loss: -0.982357
  Val Loss: -0.972461
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 58: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9823, sc=0.0070]



Epoch 58/150
  Train Loss: -0.982411
  Val Loss: -0.972487
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 59: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9825, sc=0.0070]



Epoch 59/150
  Train Loss: -0.982467
  Val Loss: -0.972513
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 60: 100%|██████████| 6/6 [00:04<00:00,  1.24it/s, loss=-0.9826, sc=0.0070]



Epoch 60/150
  Train Loss: -0.982522
  Val Loss: -0.972545
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 61: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9825, sc=0.0070]



Epoch 61/150
  Train Loss: -0.982566
  Val Loss: -0.972574
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 62: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9825, sc=0.0070]



Epoch 62/150
  Train Loss: -0.982612
  Val Loss: -0.972603
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 63: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9826, sc=0.0070]



Epoch 63/150
  Train Loss: -0.982662
  Val Loss: -0.972625
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 64: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9828, sc=0.0070]



Epoch 64/150
  Train Loss: -0.982714
  Val Loss: -0.972652
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 65: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9828, sc=0.0070]



Epoch 65/150
  Train Loss: -0.982757
  Val Loss: -0.972675
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 66: 100%|██████████| 6/6 [00:04<00:00,  1.24it/s, loss=-0.9829, sc=0.0070]



Epoch 66/150
  Train Loss: -0.982803
  Val Loss: -0.972699
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 67: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9829, sc=0.0070]



Epoch 67/150
  Train Loss: -0.982841
  Val Loss: -0.972724
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 68: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9830, sc=0.0070]



Epoch 68/150
  Train Loss: -0.982888
  Val Loss: -0.972744
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 69: 100%|██████████| 6/6 [00:04<00:00,  1.28it/s, loss=-0.9832, sc=0.0070]



Epoch 69/150
  Train Loss: -0.982933
  Val Loss: -0.972768
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 70: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9830, sc=0.0070]



Epoch 70/150
  Train Loss: -0.982966
  Val Loss: -0.972789
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 71: 100%|██████████| 6/6 [00:04<00:00,  1.25it/s, loss=-0.9831, sc=0.0070]



Epoch 71/150
  Train Loss: -0.983006
  Val Loss: -0.972811
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 72: 100%|██████████| 6/6 [00:04<00:00,  1.30it/s, loss=-0.9830, sc=0.0070]



Epoch 72/150
  Train Loss: -0.983035
  Val Loss: -0.972830
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 73: 100%|██████████| 6/6 [00:04<00:00,  1.29it/s, loss=-0.9831, sc=0.0070]



Epoch 73/150
  Train Loss: -0.983075
  Val Loss: -0.972849
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 74: 100%|██████████| 6/6 [00:04<00:00,  1.26it/s, loss=-0.9832, sc=0.0070]



Epoch 74/150
  Train Loss: -0.983114
  Val Loss: -0.972866
  Avg dims/word (soft): 10.0 (target: 10)
  Avg dims/word (hard): 0.0 (>0.5 threshold)
  ✓ New best model!



Epoch 75: 100%|██████████| 6/6 [00:04<00:00,  1.27it/s, loss=-0.9832, sc=0.0070]


## Step 6: Save Results

In [None]:
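# Save raw embeddings, sigmoid relevance, and their product (the gated embedding).
# These come from the current in-memory model, which may be newer than best_model.pt.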
final_embeddings = model.embeddings.detach().cpu().numpy()
final_relevance = torch.sigmoid(model.relevance_logits).detach().cpu().numpy()
masked_embeddings = final_embeddings * final_relevance

np.save(f"{CHECKPOINT_DIR}/embeddings.npy", final_embeddings)
np.save(f"{CHECKPOINT_DIR}/relevance.npy", final_relevance)
np.save(f"{CHECKPOINT_DIR}/masked_embeddings.npy", masked_embeddings)

# Save vocabulary
with open(f"{CHECKPOINT_DIR}/vocabulary.json", 'w') as f:
    json.dump({'word_to_id': word_to_id, 'id_to_word': {str(k): v for k, v in id_to_word.items()}}, f)

# Save history
with open(f"{CHECKPOINT_DIR}/history.json", 'w') as f:
    json.dump(history, f, indent=2)

# Copy to Drive (create the target folder first in case it does not exist yet)
!mkdir -p {RESULTS_DIR}
!cp -r {CHECKPOINT_DIR}/* {RESULTS_DIR}/

print(f"✓ Results saved to {RESULTS_DIR}")
print(f"  - embeddings.npy ({final_embeddings.shape})")
print(f"  - relevance.npy ({final_relevance.shape})")
print(f"  - masked_embeddings.npy ({masked_embeddings.shape})")
print(f"  - vocabulary.json ({len(word_to_id):,} words)")
print("  - history.json")
print("  - best_model.pt")
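To reuse these files in a later session, the arrays can be read back with `np.load`. Because JSON object keys are always strings, the `id_to_word` mapping saved above comes back with string keys and has to be converted to integers before it can be indexed with NumPy row indices. A minimal sketch, assuming the file layout written by the cell above; `load_results` is a hypothetical helper, and the demo writes tiny stand-in files to a temporary directory (prefixed `demo_`/`stand_in_` so it does not overwrite the notebook's real variables).

```python
import json
import os
import tempfile

import numpy as np


def load_results(results_dir):
    """Hypothetical helper: reload the artifacts written in Step 6."""
    embeddings = np.load(os.path.join(results_dir, "embeddings.npy"))
    relevance = np.load(os.path.join(results_dir, "relevance.npy"))
    with open(os.path.join(results_dir, "vocabulary.json")) as f:
        vocab = json.load(f)
    word_to_id = vocab["word_to_id"]
    # JSON keys are always strings; convert back to integer row indices
    id_to_word = {int(k): v for k, v in vocab["id_to_word"].items()}
    # Recompute the gated embedding from its two factors
    masked = embeddings * relevance
    return embeddings, relevance, masked, word_to_id, id_to_word


# Demo with tiny stand-in data (not real training output)
with tempfile.TemporaryDirectory() as tmp:
    stand_in_emb = np.array([[0.8, 0.1], [-0.7, 0.2]], dtype=np.float32)
    stand_in_rel = np.array([[1.0, 0.0], [0.9, 0.0]], dtype=np.float32)
    np.save(os.path.join(tmp, "embeddings.npy"), stand_in_emb)
    np.save(os.path.join(tmp, "relevance.npy"), stand_in_rel)
    stand_in_w2i = {"good": 0, "bad": 1}
    with open(os.path.join(tmp, "vocabulary.json"), "w") as f:
        json.dump({"word_to_id": stand_in_w2i,
                   "id_to_word": {str(i): w for w, i in stand_in_w2i.items()}}, f)
    _, _, demo_masked, demo_w2i, demo_i2w = load_results(tmp)

print(demo_i2w[1], demo_masked[1])
```

Recomputing `masked` from its factors keeps it consistent with whatever `embeddings.npy` and `relevance.npy` contain, even if `masked_embeddings.npy` is missing.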

## Step 7: Discover Semantic Axes

Analyze which dimensions became semantic axes: on a genuine axis, antonym pairs sit far apart and differ in a consistent direction.

In [None]:
print("="*70)
print("DISCOVERING SEMANTIC AXES")
print("="*70)

# Get final embeddings and relevance
embeddings_np = model.embeddings.detach().cpu().numpy()
relevance_np = torch.sigmoid(model.relevance_logits).detach().cpu().numpy()
masked_embeddings_np = embeddings_np * relevance_np

# Analyze antonym behavior on each dimension
idx1 = antonym_tensor[:, 0].cpu().numpy()
idx2 = antonym_tensor[:, 1].cpu().numpy()

emb1 = masked_embeddings_np[idx1]
emb2 = masked_embeddings_np[idx2]
diffs = emb1 - emb2

# Find dimensions with semantic structure
semantic_axes = []
for dim in range(CONFIG['num_anchors'], embeddings_np.shape[1]):
    dim_diffs = diffs[:, dim]

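    # mean_abs_diff: how far apart antonyms sit on this dim
    # sign_consistency: |mean sign|, 1.0 if every pair differs in the same direction
    # (this depends on the word order within each antonym pair)
    mean_abs_diff = np.mean(np.abs(dim_diffs))
    sign_consistency = np.abs(np.mean(np.sign(dim_diffs)))

    # How many words treat this dimension as relevant?
    # If training reports a hard count near 0 (no relevance > 0.5), this is 0 for every dim
    num_active = np.sum(relevance_np[:, dim] > 0.5)

    # Keep dims with real antonym separation and >10 relevant words; score = separation × consistency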
    if mean_abs_diff > 0.01 and num_active > 10:
        semantic_axes.append({
            'dim': dim,
            'consistency': sign_consistency,
            'mean_diff': mean_abs_diff,
            'num_active_words': num_active,
            'score': mean_abs_diff * sign_consistency
        })

semantic_axes.sort(key=lambda x: x['score'], reverse=True)

print(f"\nFound {len(semantic_axes)} candidate semantic axes:\n")
print(f"{'Dim':<6} {'Consistency':<12} {'Mean Diff':<12} {'Active Words':<15} {'Score':<10}")
print("-" * 70)

for axis in semantic_axes[:20]:
    print(f"{axis['dim']:<6} {axis['consistency']:<12.4f} {axis['mean_diff']:<12.4f} "
          f"{axis['num_active_words']:<15} {axis['score']:<10.4f}")

print(f"\n✓ Natural emergence of {len([a for a in semantic_axes if a['consistency'] > 0.2])} strong axes (>20% consistency)!")
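The score above multiplies two quantities: how far apart antonyms sit on a dimension (`mean_diff`) and how often they differ in the same direction (`consistency`, the absolute mean of the signs). The training log above reports a hard count of 0.0 dims/word at the 0.5 threshold, so `relevance_np[:, dim] > 0.5` may select no words at all and leave `semantic_axes` empty. The sketch below is one possible workaround, not the notebook's method: it treats a word as active on its `k` highest-relevance dimensions instead of using a fixed cutoff, then applies the same filter and score on small synthetic arrays. `topk_active_mask` and `score_dimensions` are hypothetical helpers.

```python
import numpy as np


def topk_active_mask(relevance, k):
    """Mark each word's k highest-relevance dimensions as active.

    Alternative to a fixed 0.5 cutoff, for when sigmoid relevances are
    spread thinly (soft count near target, hard count near zero).
    """
    # argpartition places the k largest entries of each row in the last k slots
    top = np.argpartition(relevance, -k, axis=1)[:, -k:]
    mask = np.zeros_like(relevance, dtype=bool)
    np.put_along_axis(mask, top, True, axis=1)
    return mask


def score_dimensions(masked_emb, pairs, active, start_dim=0,
                     min_active=10, min_diff=0.01):
    """Score each dimension by mean |antonym diff| × sign consistency."""
    diffs = masked_emb[pairs[:, 0]] - masked_emb[pairs[:, 1]]
    results = []
    for dim in range(start_dim, masked_emb.shape[1]):
        d = diffs[:, dim]
        mean_diff = np.mean(np.abs(d))
        consistency = np.abs(np.mean(np.sign(d)))
        num_active = int(active[:, dim].sum())
        if mean_diff > min_diff and num_active > min_active:
            results.append({"dim": dim, "consistency": consistency,
                            "mean_diff": mean_diff,
                            "num_active_words": num_active,
                            "score": mean_diff * consistency})
    return sorted(results, key=lambda a: a["score"], reverse=True)


# Synthetic demo: 40 words, 6 dims, every relevance below 0.5
demo_rng = np.random.default_rng(0)
demo_relevance = demo_rng.uniform(0.05, 0.3, size=(40, 6))
demo_emb = demo_rng.normal(0.0, 0.1, size=(40, 6))
# Words 0-19 and 20-39 form antonym pairs that oppose on dim 2
demo_emb[:20, 2] = 1.0
demo_emb[20:, 2] = -1.0
demo_relevance[:, 2] = 0.45  # clearly the most relevant dim, yet below 0.5
demo_pairs = np.stack([np.arange(20), np.arange(20, 40)], axis=1)

demo_active = topk_active_mask(demo_relevance, k=2)
demo_axes = score_dimensions(demo_emb * demo_relevance, demo_pairs, demo_active)
print([(a["dim"], round(a["consistency"], 2)) for a in demo_axes])
```

In the notebook this could be called as `score_dimensions(masked_embeddings_np, antonym_tensor.cpu().numpy(), topk_active_mask(relevance_np, k=10), start_dim=CONFIG['num_anchors'])`; choosing `k=10` to match the soft target in the log is an assumption, not something the training objective guarantees.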

## Step 8: Detailed Semantic Axis Analysis

Deep dive into the top semantic axes to see what concepts emerged!

In [None]:
# Analyze top 5 semantic axes in detail
print("\n" + "="*70)
print("TOP 5 SEMANTIC AXES - DETAILED ANALYSIS")
print("="*70)

for rank, axis in enumerate(semantic_axes[:5], 1):
    dim_idx = axis['dim']

    print(f"\n{'='*70}")
    print(f"Semantic Axis {dim_idx} (Rank #{rank})")
    print(f"  Consistency: {axis['consistency']:.1%}")
    print(f"  Active words: {axis['num_active_words']}")
    print('='*70)

    # Get dimension values (with relevance gating)
    dim_values = masked_embeddings_np[:, dim_idx]
    dim_relevance = relevance_np[:, dim_idx]

    # Show words that activate this dimension
    active_words = np.where(dim_relevance > 0.5)[0]
    active_values = dim_values[active_words]

    sorted_active = active_words[np.argsort(active_values)]

    print("\n  POSITIVE POLE (highest values among words with relevance > 0.5):")
    for idx in sorted_active[-10:][::-1]:
        word_display = id_to_word[idx].split('_wn.')[0]  # strip WordNet sense suffix
        rel = dim_relevance[idx]
        val = dim_values[idx]
        print(f"    {word_display:25s} value={val:+.3f}  relevance={rel:.3f}")

    print("\n  NEGATIVE POLE (lowest values among words with relevance > 0.5):")
    for idx in sorted_active[:10]:
        word_display = id_to_word[idx].split('_wn.')[0]  # strip WordNet sense suffix
        rel = dim_relevance[idx]
        val = dim_values[idx]
        print(f"    {word_display:25s} value={val:+.3f}  relevance={rel:.3f}")

    # Show antonym pairs on this axis
    print("\n  ANTONYM PAIRS (up to 10 where both words have relevance > 0.5):")
    shown = 0
    for i in range(len(idx1)):
        w1_rel = relevance_np[idx1[i], dim_idx]
        w2_rel = relevance_np[idx2[i], dim_idx]

        if w1_rel > 0.5 and w2_rel > 0.5:  # Both relevant
            word1 = id_to_word[idx1[i]]
            word2 = id_to_word[idx2[i]]
            val1 = masked_embeddings_np[idx1[i], dim_idx]
            val2 = masked_embeddings_np[idx2[i], dim_idx]

            # split() returns the whole word when '_wn.' is absent
            w1_display = word1.split('_wn.')[0]
            w2_display = word2.split('_wn.')[0]

            print(f"    {w1_display:15s} ({val1:+.3f}) ↔ {w2_display:15s} ({val2:+.3f})")
            shown += 1
            if shown >= 10:
                break

print("\n" + "="*70)
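The pole listings above repeat one pattern: keep words whose relevance clears a threshold, sort them by gated value, and strip the WordNet sense suffix for display. A compact, self-contained version of that pattern is sketched below; `axis_poles` and `display_name` are hypothetical helpers, the `'_wn.'` suffix convention is taken from the cell above, and the demo vocabulary and values are made up.

```python
import numpy as np


def display_name(word):
    # str.split returns [word] when the separator is absent, so no check is needed
    return word.split("_wn.")[0]


def axis_poles(values, relevance, id_to_word, dim, threshold=0.5, k=10):
    """Return (positive_pole, negative_pole) as lists of (word, value) for one dim."""
    active = np.where(relevance[:, dim] > threshold)[0]
    order = active[np.argsort(values[active, dim])]  # ascending by gated value
    pos = [(display_name(id_to_word[i]), float(values[i, dim]))
           for i in order[::-1][:k]]
    neg = [(display_name(id_to_word[i]), float(values[i, dim]))
           for i in order[:k]]
    return pos, neg


demo_words = {0: "good_wn.a.01", 1: "bad_wn.a.01", 2: "hot", 3: "fine_wn.a.02"}
demo_values = np.array([[0.8], [-0.7], [0.0], [0.5]])
demo_relevance = np.array([[0.9], [0.8], [0.1], [0.7]])

pos, neg = axis_poles(demo_values, demo_relevance, demo_words, dim=0, k=2)
print(pos)
print(neg)
```

As in the notebook cell, when fewer than `2 * k` words clear the threshold the two poles overlap (here `fine` appears in both), so a short listing does not by itself mean the axis has two distinct ends.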