# Neural Sparse Model Pre-training

This notebook trains a Neural Sparse encoder with:
- Query-document pairs (27K pairs)
- Korean-English synonym alignment (74 pairs)
- Cross-lingual retrieval optimization

## Training Objectives
1. Ranking: Query-document matching
2. Cross-lingual: Korean-English synonym alignment
3. Sparsity: FLOPS regularization

In [None]:
import sys
sys.path.append('../..')

import torch
from torch.utils.data import DataLoader
import yaml
from pathlib import Path

from src.models.neural_sparse_encoder import NeuralSparseEncoder
from src.training.losses import CombinedLoss
from src.training.data_collator import NeuralSparseDataCollator
from src.data.training_data_builder import TrainingDataBuilder
from src.training.trainer import NeuralSparseTrainer

## 1. Load Configuration

In [None]:
# Load the training configuration from YAML. safe_load is used so the file is
# parsed without constructing arbitrary Python objects.
# The file is read as UTF-8 explicitly (like the synonyms file later in this
# notebook) so loading does not depend on the platform's default encoding.
with open('../../config/training_config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# An empty file makes safe_load return None (and a bare scalar returns that
# scalar); fail here with a clear message instead of a TypeError on the first
# config['...'] lookup below.
if not isinstance(config, dict):
    raise ValueError(
        "training_config.yaml must contain a mapping at the top level, "
        f"got {type(config).__name__}"
    )

print("Configuration loaded:")
print(f"  Model: {config['model']['name']}")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Learning rate: {config['training']['learning_rate']}")
print(f"  Epochs: {config['training']['num_epochs']}")

## 2. Build Training Dataset

In [None]:
# Build the train/validation splits from the query-document pairs, the
# document collection and the Korean-English synonym list. All paths, the
# negative count and the train split come from the `data` config section.
builder = TrainingDataBuilder()
data_cfg = config['data']

train_dataset, val_dataset = builder.build_training_dataset(
    qd_pairs_path=data_cfg['qd_pairs_path'],
    documents_path=data_cfg['documents_path'],
    synonyms_path=data_cfg['synonyms_path'],
    num_negatives=data_cfg['num_negatives'],
    train_split=data_cfg['train_split'],
)

print("\nDataset summary:")
for split_label, split_data in (("Train", train_dataset), ("Val", val_dataset)):
    print(f"  {split_label}: {len(split_data)} pairs")

## 3. Initialize Model

In [None]:
# Instantiate the sparse encoder from the `model` section of the config.
model_cfg = config['model']
model = NeuralSparseEncoder(
    model_name=model_cfg['name'],
    max_length=model_cfg['max_length'],
    use_relu=model_cfg['use_relu'],
)

# Tally total and trainable parameter counts in a single pass over the model.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elems = param.numel()
    total_params += n_elems
    if param.requires_grad:
        trainable_params += n_elems

print("\nModel parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

## 4. Create Data Loaders

In [None]:
# Batch collator built on the model's own tokenizer, with the query/document
# length limits and negative count taken from the `data` config section.
data_cfg = config['data']
collator = NeuralSparseDataCollator(
    tokenizer=model.tokenizer,
    query_max_length=data_cfg['query_max_length'],
    doc_max_length=data_cfg['doc_max_length'],
    num_negatives=data_cfg['num_negatives'],
)

# Loader options shared by the train and validation loaders.
hw_cfg = config['hardware']
loader_kwargs = {
    'collate_fn': collator,
    'num_workers': hw_cfg['num_workers'],
    'pin_memory': hw_cfg['pin_memory'],
}

# Training batches are shuffled every epoch; validation order is fixed.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=True,
    **loader_kwargs,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=config['evaluation']['eval_batch_size'],
    shuffle=False,
    **loader_kwargs,
)

print("Data loaders created:")
print(f"  Train batches: {len(train_dataloader)}")
print(f"  Val batches: {len(val_dataloader)}")

## 5. Initialize Loss and Optimizer

In [None]:
# Combined training objective: the α/β/γ weights from the `loss` section
# balance the ranking, cross-lingual and sparsity terms.
loss_fn = CombinedLoss(
    alpha_ranking=config['loss']['alpha_ranking'],
    beta_cross_lingual=config['loss']['beta_cross_lingual'],
    gamma_sparsity=config['loss']['gamma_sparsity'],
    ranking_margin=config['loss']['ranking_margin'],
    use_contrastive=config['loss']['use_contrastive'],
)

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['training']['learning_rate'],
    weight_decay=config['training']['weight_decay'],
)

# Learning rate scheduler
import math

from torch.optim.lr_scheduler import LinearLR, SequentialLR

num_epochs = config['training']['num_epochs']
batches_per_epoch = len(train_dataloader)
total_steps = batches_per_epoch * num_epochs
warmup_steps = config['training']['warmup_steps']

# With gradient accumulation the optimizer updates only once every
# `gradient_accumulation_steps` batches, so real updates are fewer than
# batches. (Whether NeuralSparseTrainer steps the scheduler per batch or per
# optimizer update is not visible here — TODO confirm.) A 0/None setting is
# treated as 1 to avoid dividing by zero.
accumulation_steps = max(1, config['training']['gradient_accumulation_steps'] or 1)
optimizer_steps = math.ceil(batches_per_epoch / accumulation_steps) * num_epochs

if warmup_steps > 0:
    # Linear warmup from 10% to 100% of the base LR over `warmup_steps`
    # scheduler steps; after that the LR stays at its base value.
    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=warmup_steps,
    )
else:
    # LinearLR with total_iters=0 applies start_factor once at construction
    # and never ramps back up, which would pin the LR at 10% of its base value
    # for the entire run. With no warmup requested, keep the LR constant.
    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=1.0,
        total_iters=1,
    )

print("\nTraining setup:")
print(f"  Total steps: {total_steps}")
print(f"  Optimizer steps (gradient accumulation x{accumulation_steps}): {optimizer_steps}")
print(f"  Warmup steps: {warmup_steps}")
print(f"  Loss weights: α={config['loss']['alpha_ranking']}, β={config['loss']['beta_cross_lingual']}, γ={config['loss']['gamma_sparsity']}")

# The scheduler is stepped at most once per batch, so a warmup longer than the
# total number of batches can never complete.
if warmup_steps > total_steps:
    print(
        f"  WARNING: warmup_steps ({warmup_steps}) exceeds total steps "
        f"({total_steps}); the LR will not reach its base value."
    )

## 6. Initialize Trainer

In [None]:
# Select the accelerator: CUDA when available, otherwise CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
print(f"Using device: {device}")

if has_cuda:
    # Report the first GPU's name and total memory.
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"  Memory: {total_gb:.1f} GB")

# Wire the model, data, objective and optimization state into the trainer;
# AMP/accumulation/clipping come from `training`, checkpoint and logging
# cadence from `logging`.
train_cfg = config['training']
log_cfg = config['logging']
trainer = NeuralSparseTrainer(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=warmup_scheduler,
    device=device,
    use_amp=train_cfg['use_amp'],
    gradient_accumulation_steps=train_cfg['gradient_accumulation_steps'],
    max_grad_norm=train_cfg['max_grad_norm'],
    output_dir=log_cfg['output_dir'],
    save_steps=log_cfg['save_steps'],
    eval_steps=log_cfg['eval_steps'],
    logging_steps=log_cfg['logging_steps'],
)

## 7. Test Forward Pass

In [None]:
# Sanity-check the encoder on one training batch before training starts.
sample_batch = next(iter(train_dataloader))

# Move tensor entries to the target device; non-tensor entries pass through.
moved_batch = {}
for field_name, field_value in sample_batch.items():
    moved_batch[field_name] = (
        field_value.to(device) if isinstance(field_value, torch.Tensor) else field_value
    )
sample_batch = moved_batch

model.to(device)
model.eval()

with torch.no_grad():
    # Encode only the query side of the batch.
    query_outputs = model(
        input_ids=sample_batch['query_input_ids'],
        attention_mask=sample_batch['query_attention_mask'],
    )
    query_rep = query_outputs['sparse_rep']

    stats = model.get_sparsity_stats(query_rep)

    print("\nSample sparse representation:")
    print(f"  Shape: {query_rep.shape}")
    for label, stat_key, spec in (
        ("Avg non-zero terms", 'avg_nonzero_terms', ".0f"),
        ("Sparsity ratio", 'sparsity_ratio', ".3f"),
        ("Avg L1 norm", 'avg_l1_norm', ".2f"),
    ):
        print(f"  {label}: {stats[stat_key]:{spec}}")

    # Highest-weighted vocabulary terms for the first query in the batch.
    print("\n  Top 10 activated terms:")
    top_terms = model.get_top_k_terms(query_rep[0], k=10)
    for rank, (term, weight) in enumerate(top_terms, start=1):
        print(f"    {rank:2d}. {term:20s}: {weight:.4f}")

## 8. Start Training (Small Scale Test)

First, let's do a quick test with 1 epoch to verify everything works.

In [None]:
# Smoke-test training: run a single epoch to confirm the pipeline works end to
# end before committing to the full schedule.
quick_test_epochs = 1
print(f"Starting quick test training ({quick_test_epochs} epoch)...\n")
trainer.train(num_epochs=quick_test_epochs)

## 9. Analyze Training Results

In [None]:
# Score the validation split with the current weights and report each loss term.
model.eval()
val_losses = trainer.evaluate()

print("\nValidation Results:")
for label, loss_key in (
    ("Total loss", 'total_loss'),
    ("Ranking loss", 'ranking_loss'),
    ("Cross-lingual loss", 'cross_lingual_loss'),
    ("Sparsity loss", 'sparsity_loss'),
):
    print(f"  {label}: {val_losses[loss_key]:.4f}")

## 10. Test Sparse Representations

In [None]:
# Probe queries: Korean queries each followed by a rough English counterpart,
# so their sparse expansions can be compared side by side.
test_queries = [
    "인공지능 모델 학습",
    "machine learning training",
    "검색 시스템 개발",
    "search system development",
]

print("\nTesting sparse representations:\n")
print("=" * 80)

model.eval()
with torch.no_grad():
    for query in test_queries:
        # Encode a single-query batch and summarize it.
        sparse_rep = model.encode([query], device=device)
        stats = model.get_sparsity_stats(sparse_rep)
        top_terms = model.get_top_k_terms(sparse_rep[0], k=10)

        # Assemble the per-query report and emit it in one print.
        report_lines = [
            f"\nQuery: {query}",
            f"  Non-zero terms: {stats['avg_nonzero_terms']:.0f}",
            f"  Sparsity: {stats['sparsity_ratio']:.3f}",
            "  Top 5 terms:",
        ]
        report_lines.extend(
            f"    {term:20s}: {weight:.4f}" for term, weight in top_terms[:5]
        )
        report_lines.append("-" * 80)
        print("\n".join(report_lines))

## 11. Test Cross-lingual Alignment

In [None]:
import torch.nn.functional as F
import json

# Load the Korean-English synonym pairs; each entry is read via its
# 'korean' and 'english' keys below.
with open('../../dataset/synonyms/combined_synonyms.json', 'r', encoding='utf-8') as f:
    synonyms = json.load(f)

# Probe the first 10 pairs.
test_synonyms = synonyms[:10]

print("\nCross-lingual Alignment Test:\n")
print("=" * 80)

# Encode each pair exactly once. The per-pair report and the average below
# both reuse the same similarity values.
all_cosine_sims = []

model.eval()
with torch.no_grad():
    for syn in test_synonyms:
        korean = syn['korean']
        english = syn['english']

        # Encode both sides as single-item batches.
        korean_rep = model.encode([korean], device=device)
        english_rep = model.encode([english], device=device)

        # Cosine ignores vector magnitude; the dot product is the raw
        # sparse-retrieval style score.
        cosine_sim = F.cosine_similarity(korean_rep, english_rep, dim=-1)
        dot_sim = torch.sum(korean_rep * english_rep, dim=-1)
        all_cosine_sims.append(cosine_sim.item())

        print(f"{korean:15s} ↔ {english:20s}")
        print(f"  Cosine similarity: {cosine_sim.item():.4f}")
        print(f"  Dot product: {dot_sim.item():.2f}")

print("\n" + "=" * 80)
# Guard against an empty synonym file, which would otherwise divide by zero.
if all_cosine_sims:
    print(f"Average cosine similarity: {sum(all_cosine_sims) / len(all_cosine_sims):.4f}")
else:
    print("Average cosine similarity: n/a (no synonym pairs loaded)")

## 12. Full Training (Optional)

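Uncomment the cell below to run full training for the configured number of epochs (`config['training']['num_epochs']`). Note that it continues training the same model and optimizer used in the quick test above.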

In [None]:
# Optional: full training for the configured number of epochs
# (config['training']['num_epochs']).
# NOTE(review): this reuses the same `trainer`, and therefore the same model
# and optimizer, that already ran the 1-epoch quick test above. It would
# continue from those weights rather than start fresh; TODO confirm that is
# intended.
# NOTE(review): `trainer.best_val_loss` is not referenced anywhere else in this
# notebook; confirm NeuralSparseTrainer exposes it before uncommenting.
# # Full training
# print("Starting full training...\n")
# trainer.train(num_epochs=config['training']['num_epochs'])
# 
# print("\nTraining completed!")
# print(f"Best validation loss: {trainer.best_val_loss:.4f}")
# print(f"Best model saved to: {config['logging']['output_dir']}/best_model")

## Summary

We've successfully:
1. ✅ Loaded 27K query-document pairs
2. ✅ Loaded 74 Korean-English synonym pairs
3. ✅ Initialized Neural Sparse Encoder (klue/bert-base)
4. ✅ Set up the training pipeline (mixed precision when `use_amp` is enabled in the config)
5. ✅ Ran an initial 1-epoch training experiment
6. ✅ Tested sparse representations
7. ✅ Measured cross-lingual similarity on sample Korean-English synonym pairs

**Next steps:**
- Run full training (10 epochs)
- Evaluate on test set
- Export model for OpenSearch
- Build evaluation framework