# Neural Sparse Model Pre-training

This notebook trains a Neural Sparse encoder with:
- Query-document pairs (27K pairs)
- Korean-English synonym alignment (31,116 synonym pairs)
- Cross-lingual retrieval optimization

## Training Objectives
1. Ranking: Query-document matching
2. Cross-lingual: Korean-English synonym alignment
3. Sparsity: FLOPS regularization
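FLOPS regularization, as defined in the sparse-retrieval literature (and used by SPLADE), penalizes the squared mean activation of each vocabulary term over a batch. This drives rarely useful terms to exactly zero. The pure-Python sketch below implements that definition. Whether `CombinedLoss` in `src/training/losses.py` uses exactly this form is an assumption.

```python
from typing import List


def flops_loss(batch: List[List[float]]) -> float:
    """FLOPS regularizer: sum over vocabulary terms j of (mean_i |w_ij|)^2.

    `batch` is a list of sparse vectors (one per query or document), each given
    as a dense list of per-term weights with the same length (the vocab size).
    """
    if not batch:
        return 0.0
    n = len(batch)
    vocab_size = len(batch[0])
    total = 0.0
    for j in range(vocab_size):
        # Mean absolute activation of term j across the batch.
        mean_activation = sum(abs(vec[j]) for vec in batch) / n
        total += mean_activation ** 2
    return total


# Term 0 averages 2.0, term 1 averages 0.5, term 2 is never active.
print(flops_loss([[1.0, 1.0, 0.0], [3.0, 0.0, 0.0]]))
```

Because the batch mean is squared, a term that is active in many examples costs more than the same total weight spread over distinct terms. This pushes the encoder toward selective, example-specific terms instead of a few terms shared by everything.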

In [18]:
import sys
sys.path.append('../..')

import torch
from torch.utils.data import DataLoader
import yaml
from pathlib import Path

from src.models.neural_sparse_encoder import NeuralSparseEncoder
from src.training.losses import CombinedLoss
from src.training.data_collator import NeuralSparseDataCollator
from src.data.training_data_builder import TrainingDataBuilder
from src.training.trainer import NeuralSparseTrainer

## 1. Load Configuration

In [19]:
# Load config
with open('../../config/training_config.yaml', 'r') as f:
    config = yaml.safe_load(f)

print("Configuration loaded:")
print(f"  Model: {config['model']['name']}")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Learning rate: {config['training']['learning_rate']}")
print(f"  Epochs: {config['training']['num_epochs']}")

Configuration loaded:
  Model: klue/bert-base
  Batch size: 16
  Learning rate: 2e-05
  Epochs: 10
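Later cells read many nested keys from `training_config.yaml`. A small check like the following fails fast with the full list of missing keys, rather than raising a `KeyError` several cells later. The helper `missing_config_keys` is hypothetical and not part of `src/`. The key list was collected from the cells in this notebook.

```python
from typing import Dict, List, Tuple

# Keys read by the cells of this notebook, grouped by config section.
REQUIRED_KEYS: Dict[str, List[str]] = {
    "model": ["name", "max_length", "use_relu"],
    "data": ["qd_pairs_path", "documents_path", "synonyms_path", "num_negatives",
             "train_split", "query_max_length", "doc_max_length"],
    "training": ["batch_size", "learning_rate", "num_epochs", "weight_decay",
                 "warmup_steps", "use_amp", "gradient_accumulation_steps",
                 "max_grad_norm"],
    "loss": ["alpha_ranking", "beta_cross_lingual", "gamma_sparsity",
             "ranking_margin", "use_contrastive"],
    "hardware": ["num_workers", "pin_memory"],
    "evaluation": ["eval_batch_size"],
    "logging": ["output_dir", "save_steps", "eval_steps", "logging_steps"],
}


def missing_config_keys(config: dict) -> List[Tuple[str, str]]:
    """Return every (section, key) pair the notebook needs but `config` lacks."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        block = config.get(section) or {}  # treat an absent or empty section as {}
        for key in keys:
            if key not in block:
                missing.append((section, key))
    return missing


# In the notebook: assert not missing_config_keys(config), missing_config_keys(config)
print(missing_config_keys({"model": {"name": "klue/bert-base"}})[:2])
```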


## 2. Build Training Dataset

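If the query-document (QD) pairs file does not exist yet, the next cell creates it from Wikipedia articles. Each article title becomes a query and the article body becomes its positive document, and articles shorter than 200 characters are skipped. The cell then builds the train and validation datasets.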

In [20]:
import pickle
import json
from pathlib import Path

# Check if QD pairs exist
qd_pairs_path = Path(config['data']['qd_pairs_path'])
documents_path = Path(config['data']['documents_path'])
synonyms_path = Path(config['data']['synonyms_path'])

# Resolve relative config paths against the repository root (two levels above this notebook)
if not qd_pairs_path.is_absolute():
    qd_pairs_path = Path('../../') / qd_pairs_path
if not documents_path.is_absolute():
    documents_path = Path('../../') / documents_path
if not synonyms_path.is_absolute():
    synonyms_path = Path('../../') / synonyms_path

print(f"Paths:")
print(f"  QD pairs: {qd_pairs_path}")
print(f"  Documents: {documents_path}")
print(f"  Synonyms: {synonyms_path}")
print(f"  Synonyms exists: {synonyms_path.exists()}")

if not qd_pairs_path.exists():
    print("\nQD pairs not found. Creating from Wikipedia articles...")
    
    # Load Wikipedia articles
    ko_articles_path = "../../dataset/wikipedia/ko_articles.jsonl"
    en_articles_path = "../../dataset/wikipedia/en_articles.jsonl"
    
    # Create dataset directory
    qd_pairs_path.parent.mkdir(parents=True, exist_ok=True)
    
    # Load articles
    ko_articles = []
    if Path(ko_articles_path).exists():
        with open(ko_articles_path, 'r', encoding='utf-8') as f:
            for line in f:
                ko_articles.append(json.loads(line))
    
    en_articles = []
    if Path(en_articles_path).exists():
        with open(en_articles_path, 'r', encoding='utf-8') as f:
            for line in f:
                en_articles.append(json.loads(line))
    
    print(f"  Loaded {len(ko_articles)} Korean articles")
    print(f"  Loaded {len(en_articles)} English articles")
    
    # Create QD pairs from articles
    # Each article title becomes a query, and the text becomes the document
    qd_pairs = []
    documents = []
    
    # Korean articles
    for article in ko_articles:
        doc_id = f"ko_{article['id']}"
        query = article['title']
        text = article['text']
        
        # Skip very short articles
        if len(text) < 200:
            continue
        
        qd_pairs.append({
            'query': query,
            'doc_id': doc_id,
            'label': 1.0
        })
        
        documents.append({
            'id': doc_id,
            'text': text,
            'title': query,
            'language': 'ko'
        })
    
    # English articles
    for article in en_articles:
        doc_id = f"en_{article['id']}"
        query = article['title']
        text = article['text']
        
        # Skip very short articles
        if len(text) < 200:
            continue
        
        qd_pairs.append({
            'query': query,
            'doc_id': doc_id,
            'label': 1.0
        })
        
        documents.append({
            'id': doc_id,
            'text': text,
            'title': query,
            'language': 'en'
        })
    
    print(f"\n  Created {len(qd_pairs)} QD pairs")
    print(f"  Created {len(documents)} documents")
    
    # Save QD pairs
    with open(qd_pairs_path, 'wb') as f:
        pickle.dump(qd_pairs, f)
    print(f"  Saved QD pairs to {qd_pairs_path}")
    
    # Save documents
    with open(documents_path, 'w', encoding='utf-8') as f:
        json.dump(documents, f, ensure_ascii=False, indent=2)
    print(f"  Saved documents to {documents_path}")
else:
    print("\nQD pairs already exist.")

# Now build training dataset
builder = TrainingDataBuilder()

train_dataset, val_dataset = builder.build_training_dataset(
    qd_pairs_path=str(qd_pairs_path),
    documents_path=str(documents_path),
    synonyms_path=str(synonyms_path) if synonyms_path.exists() else None,
    num_negatives=config['data']['num_negatives'],
    train_split=config['data']['train_split'],
)

print(f"\nDataset summary:")
print(f"  Train: {len(train_dataset)} pairs")
print(f"  Val: {len(val_dataset)} pairs")

Paths:
  QD pairs: ../../dataset/base_model/qd_pairs_base.pkl
  Documents: ../../dataset/base_model/documents.json
  Synonyms: ../../dataset/synonyms/combined_synonyms.json
  Synonyms exists: True

QD pairs already exist.
Loaded 27939 QD pairs from ../../dataset/base_model/qd_pairs_base.pkl
Converted 27939 QD pairs to standard format
Loaded 9996 documents from ../../dataset/base_model/documents.json
Loaded 31116 synonym pairs from ../../dataset/synonyms/combined_synonyms.json

Dataset split:
  Train: 25145 pairs
  Val: 2794 pairs
Initialized NeuralSparseDataset:
  QD pairs: 25145
  Documents: 9996
  Synonyms: 31116
  Num negatives: 10
Initialized NeuralSparseDataset:
  QD pairs: 2794
  Documents: 9996
  Synonyms: 31116
  Num negatives: 10

Dataset summary:
  Train: 25145 pairs
  Val: 2794 pairs
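The Korean and English loops in the cell above apply identical logic. A minimal sketch factors it into one helper. The name `articles_to_qd_pairs` is hypothetical, and the sketch assumes, as the cell does, that each article is a dict with `id`, `title`, and `text` keys.

```python
from typing import Dict, Iterable, List, Tuple


def articles_to_qd_pairs(
    articles: Iterable[dict], language: str, min_chars: int = 200
) -> Tuple[List[Dict], List[Dict]]:
    """Turn Wikipedia articles into (title -> body) positive pairs.

    Mirrors the notebook cell: the title is the query, the body text is the
    document, and articles shorter than `min_chars` characters are skipped.
    """
    qd_pairs, documents = [], []
    for article in articles:
        text = article["text"]
        if len(text) < min_chars:
            continue  # too short to serve as a useful document
        doc_id = f"{language}_{article['id']}"
        qd_pairs.append({"query": article["title"], "doc_id": doc_id, "label": 1.0})
        documents.append({"id": doc_id, "text": text,
                          "title": article["title"], "language": language})
    return qd_pairs, documents


# Usage inside the cell:
# ko_pairs, ko_docs = articles_to_qd_pairs(ko_articles, "ko")
# en_pairs, en_docs = articles_to_qd_pairs(en_articles, "en")
# qd_pairs, documents = ko_pairs + en_pairs, ko_docs + en_docs
```

This construction yields exactly one pair per document. The loaded files, however, contain 27,939 pairs over 9,996 documents, so the existing pickle appears to have been produced by a different process than this cell.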


## 3. Initialize Model

In [12]:
# Initialize model
model = NeuralSparseEncoder(
    model_name=config['model']['name'],
    max_length=config['model']['max_length'],
    use_relu=config['model']['use_relu'],
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel parameters:")
print(f"  Total: {total_params:,}")
print(f"  Trainable: {trainable_params:,}")

Initialized NeuralSparseEncoder:
  Base model: klue/bert-base
  Hidden size: 768
  Vocab size: 32000
  Max length: 256
  Activation: ReLU

Model parameters:
  Total: 135,225,344
  Trainable: 135,225,344
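The reported 135,225,344 parameters can be reconciled with the BERT-base layout of klue/bert-base: 12 layers, hidden size 768, FFN size 3072, 512 positions, and a vocabulary of 32,000. The arithmetic below is consistent with a full BERT encoder, including its pooler, plus a separate untied 768 → 32,000 projection with bias. That decomposition is inferred from the count and is not confirmed by the notebook.

```python
def bert_base_params(vocab: int, hidden: int = 768, layers: int = 12,
                     ffn: int = 3072, positions: int = 512, type_vocab: int = 2) -> int:
    """Parameter count of a BERT encoder: embeddings + transformer layers + pooler."""
    def linear(n_in: int, n_out: int) -> int:
        return n_in * n_out + n_out  # weight matrix + bias vector

    layer_norm = 2 * hidden  # gain + bias
    embeddings = (vocab + positions + type_vocab) * hidden + layer_norm
    per_layer = (
        4 * linear(hidden, hidden)  # Q, K, V and attention output projections
        + layer_norm
        + linear(hidden, ffn)       # FFN up-projection
        + linear(ffn, hidden)       # FFN down-projection
        + layer_norm
    )
    pooler = linear(hidden, hidden)
    return embeddings + layers * per_layer + pooler


encoder = bert_base_params(vocab=32000)
projection = 768 * 32000 + 32000  # hypothetical untied vocabulary projection with bias
print(f"{encoder:,} + {projection:,} = {encoder + projection:,}")
```

As a sanity check, the same function gives 109,482,240 for the 30,522-token English `bert-base-uncased` vocabulary, which matches that model's commonly cited size.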


## 4. Create Data Loaders

In [13]:
# Initialize collator
collator = NeuralSparseDataCollator(
    tokenizer=model.tokenizer,
    query_max_length=config['data']['query_max_length'],
    doc_max_length=config['data']['doc_max_length'],
    num_negatives=config['data']['num_negatives'],
)

# Create dataloaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=True,
    collate_fn=collator,
    num_workers=config['hardware']['num_workers'],
    pin_memory=config['hardware']['pin_memory'],
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=config['evaluation']['eval_batch_size'],
    shuffle=False,
    collate_fn=collator,
    num_workers=config['hardware']['num_workers'],
    pin_memory=config['hardware']['pin_memory'],
)

print(f"Data loaders created:")
print(f"  Train batches: {len(train_dataloader)}")
print(f"  Val batches: {len(val_dataloader)}")

Data loaders created:
  Train batches: 1572
  Val batches: 88
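The logged sizes follow from simple arithmetic. The split is consistent with `train_split = 0.9`, with the training count truncated to an integer. `eval_batch_size = 32` is the only integer batch size that yields 88 validation batches. Both values are inferred from the logs, not read from the config. With `gradient_accumulation_steps = 2` (shown in the trainer output), each optimizer update sees 32 training pairs.

```python
import math

total_pairs = 27939
train_pairs = int(total_pairs * 0.9)         # assumed train_split = 0.9
val_pairs = total_pairs - train_pairs

train_batches = math.ceil(train_pairs / 16)  # batch_size = 16; last batch is partial
val_batches = math.ceil(val_pairs / 32)      # assumed eval_batch_size = 32

grad_accum = 2                               # gradient_accumulation_steps
effective_batch = 16 * grad_accum
optimizer_steps_per_epoch = math.ceil(train_batches / grad_accum)

print(train_pairs, val_pairs, train_batches, val_batches)
print(effective_batch, optimizer_steps_per_epoch)
```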


## 5. Initialize Loss and Optimizer

In [14]:
# Loss function
loss_fn = CombinedLoss(
    alpha_ranking=config['loss']['alpha_ranking'],
    beta_cross_lingual=config['loss']['beta_cross_lingual'],
    gamma_sparsity=config['loss']['gamma_sparsity'],
    ranking_margin=config['loss']['ranking_margin'],
    use_contrastive=config['loss']['use_contrastive'],
)

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['training']['learning_rate'],
    weight_decay=config['training']['weight_decay'],
)

# Learning rate scheduler: linear warmup from 0.1x to 1x the base LR, constant afterwards
from torch.optim.lr_scheduler import LinearLR, SequentialLR

total_steps = len(train_dataloader) * config['training']['num_epochs']
warmup_steps = config['training']['warmup_steps']

warmup_scheduler = LinearLR(
    optimizer,
    start_factor=0.1,
    end_factor=1.0,
    total_iters=warmup_steps,
)

print(f"\nTraining setup:")
print(f"  Total steps: {total_steps}")
print(f"  Warmup steps: {warmup_steps}")
print(f"  Loss weights: α={config['loss']['alpha_ranking']}, β={config['loss']['beta_cross_lingual']}, γ={config['loss']['gamma_sparsity']}")


Training setup:
  Total steps: 15720
  Warmup steps: 500
  Loss weights: α=1.0, β=0.3, γ=0.001
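Only `warmup_scheduler` is passed to the trainer. As a result, the learning rate ramps from 0.1 × 2e-5 to 2e-5 over the first 500 scheduler steps and then stays constant, and `total_steps` is computed but never used. If warmup followed by linear decay to zero is intended, PyTorch can chain two `LinearLR` schedulers with `SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])`. The pure-Python sketch below computes the multiplier such a chain would apply, ignoring any off-by-one at the milestone.

`total_steps` counts micro-batches. If `NeuralSparseTrainer` steps the scheduler only once per optimizer update (an assumption), the decay length should be `total_steps // gradient_accumulation_steps`.

```python
def lr_multiplier(step: int, warmup_steps: int, total_steps: int,
                  start_factor: float = 0.1) -> float:
    """LR multiplier for linear warmup followed by linear decay to zero.

    Approximates LinearLR(start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
    chained with LinearLR(start_factor=1.0, end_factor=0.0,
    total_iters=total_steps - warmup_steps).
    """
    if step < warmup_steps:
        return start_factor + (1.0 - start_factor) * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return 1.0 - progress


base_lr = 2e-5
for step in (0, 250, 500, 8110, 15720):
    print(step, base_lr * lr_multiplier(step, warmup_steps=500, total_steps=15720))
```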


## 6. Initialize Trainer

In [15]:
# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Initialize trainer
trainer = NeuralSparseTrainer(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    scheduler=warmup_scheduler,
    device=device,
    use_amp=config['training']['use_amp'],
    gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
    max_grad_norm=config['training']['max_grad_norm'],
    output_dir=config['logging']['output_dir'],
    save_steps=config['logging']['save_steps'],
    eval_steps=config['logging']['eval_steps'],
    logging_steps=config['logging']['logging_steps'],
)

Using device: cuda
  GPU: NVIDIA GB10
  Memory: 128.5 GB


    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    


Initialized NeuralSparseTrainer:
  Device: cuda
  Mixed precision: True
  Gradient accumulation: 2
  Output directory: outputs




## 7. Test Forward Pass

In [16]:
# Get a sample batch
sample_batch = next(iter(train_dataloader))

# Move to device
sample_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v 
                for k, v in sample_batch.items()}

# Test forward pass
model.to(device)
model.eval()

with torch.no_grad():
    query_outputs = model(
        input_ids=sample_batch['query_input_ids'],
        attention_mask=sample_batch['query_attention_mask'],
    )
    query_rep = query_outputs['sparse_rep']
    
    # Get sparsity stats
    stats = model.get_sparsity_stats(query_rep)
    
    print("\nSample sparse representation:")
    print(f"  Shape: {query_rep.shape}")
    print(f"  Avg non-zero terms: {stats['avg_nonzero_terms']:.0f}")
    print(f"  Sparsity ratio: {stats['sparsity_ratio']:.3f}")
    print(f"  Avg L1 norm: {stats['avg_l1_norm']:.2f}")
    
    # Show top activated terms
    print("\n  Top 10 activated terms:")
    top_terms = model.get_top_k_terms(query_rep[0], k=10)
    for i, (term, weight) in enumerate(top_terms, 1):
        print(f"    {i:2d}. {term:20s}: {weight:.4f}")


Sample sparse representation:
  Shape: torch.Size([16, 32000])
  Avg non-zero terms: 31823
  Sparsity ratio: 0.006
  Avg L1 norm: 29755.03

  Top 10 activated terms:
     1. é                   : 2.4682
     2. 적중                  : 2.4348
     3. ##점                 : 2.4095
     4. 그쳐                  : 2.3877
     5. 구약                  : 2.3775
     6. 가맹                  : 2.3772
     7. 월마트                 : 2.3714
     8. 말레이시아               : 2.3697
     9. 장엄                  : 2.3344
    10. ##무시                : 2.3083
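`get_sparsity_stats` belongs to the project's encoder, and its exact definition is not shown here. The logged numbers are consistent with the definitions sketched below: count the non-zero vocabulary weights, take the sparsity ratio as the fraction of zero weights, and take the L1 norm as the sum of absolute weights per vector. Under that reading, the untrained model activates 31,823 of 32,000 terms per query on average, so it is essentially dense before training.

```python
from typing import Dict, List


def sparsity_stats(batch: List[List[float]]) -> Dict[str, float]:
    """Average non-zero count, fraction of zero weights, and L1 norm per vector."""
    vocab_size = len(batch[0])
    nonzero = [sum(1 for w in vec if w != 0.0) for vec in batch]
    avg_nonzero = sum(nonzero) / len(batch)
    return {
        "avg_nonzero_terms": avg_nonzero,
        "sparsity_ratio": 1.0 - avg_nonzero / vocab_size,
        "avg_l1_norm": sum(sum(abs(w) for w in vec) for vec in batch) / len(batch),
    }


# Reproduce the logged ratio from the logged averages alone.
print(f"{1.0 - 31823 / 32000:.3f}")
print(sparsity_stats([[0.0, 2.0, 1.0, 0.0], [0.0, 0.0, 3.0, 0.0]]))
```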


## 8. Start Training (Small Scale Test)

First, let's do a quick test with 1 epoch to verify everything works.

In [17]:
# Quick test: 1 epoch
print("Starting quick test training (1 epoch)...\n")
trainer.train(num_epochs=1)

Starting quick test training (1 epoch)...


Starting training for 1 epochs...
Total training steps: 1572


Validation: 100%|██████████| 88/88 [01:00<00:00,  1.45it/s]



Validation at step 500:
  Val loss: 0.0075
Model saved to outputs/best_model


Epoch 0:  64%|██████▎   | 1000/1572 [07:36<3:16:14, 20.58s/it, loss=7.3479]

Checkpoint saved to outputs/best_model
  New best model saved!


Epoch 0: 100%|██████████| 1572/1572 [11:17<00:00,  2.32it/s, loss=4.8997]  



Epoch 0 completed:
  Train loss: 4.6754
  Ranking loss: 0.8827
  Cross-lingual loss: 0.0356
  Sparsity loss: 3781.9632
Model saved to outputs/epoch-0
Checkpoint saved to outputs/epoch-0

Training completed!
Best validation loss: 0.0075


## 9. Analyze Training Results

In [21]:
# Evaluate on validation set
model.eval()
val_losses = trainer.evaluate()

print("\nValidation Results:")
print(f"  Total loss: {val_losses['total_loss']:.4f}")
print(f"  Ranking loss: {val_losses['ranking_loss']:.4f}")
print(f"  Cross-lingual loss: {val_losses['cross_lingual_loss']:.4f}")
print(f"  Sparsity loss: {val_losses['sparsity_loss']:.4f}")

Validation: 100%|██████████| 88/88 [01:00<00:00,  1.46it/s]


Validation Results:
  Total loss: 0.0008
  Ranking loss: 0.0000
  Cross-lingual loss: 0.0000
  Sparsity loss: 0.7969
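The sketch below assumes that `CombinedLoss` returns the weighted sum α·ranking + β·cross-lingual + γ·sparsity, using the weights printed during setup. Under that assumption, the logged components reproduce the logged totals up to rounding. The check also shows that the sparsity term contributes about 81% of the epoch-0 training loss. The validation total is almost entirely γ times the sparsity loss, because the validation ranking and cross-lingual losses round to zero. That may also explain why the step-500 validation loss (0.0075) sits far below the training loss.

```python
ALPHA, BETA, GAMMA = 1.0, 0.3, 0.001  # from the "Loss weights" line in section 5


def combined_loss(ranking: float, cross_lingual: float, sparsity: float) -> float:
    """Weighted sum of the three objectives (assumed form of CombinedLoss)."""
    return ALPHA * ranking + BETA * cross_lingual + GAMMA * sparsity


train_total = combined_loss(0.8827, 0.0356, 3781.9632)  # epoch 0, logged 4.6754
val_total = combined_loss(0.0000, 0.0000, 0.7969)       # validation, logged 0.0008
sparsity_share = GAMMA * 3781.9632 / train_total

print(f"{train_total:.4f} {val_total:.4f} {sparsity_share:.0%}")
```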





## 10. Test Sparse Representations

In [22]:
# Test queries
test_queries = [
    "인공지능 모델 학습",
    "machine learning training",
    "검색 시스템 개발",
    "search system development",
]

print("\nTesting sparse representations:\n")
print("=" * 80)

model.eval()
with torch.no_grad():
    for query in test_queries:
        # Encode
        sparse_rep = model.encode([query], device=device)
        
        # Stats
        stats = model.get_sparsity_stats(sparse_rep)
        
        # Top terms
        top_terms = model.get_top_k_terms(sparse_rep[0], k=10)
        
        print(f"\nQuery: {query}")
        print(f"  Non-zero terms: {stats['avg_nonzero_terms']:.0f}")
        print(f"  Sparsity: {stats['sparsity_ratio']:.3f}")
        print(f"  Top 5 terms:")
        for term, weight in top_terms[:5]:
            print(f"    {term:20s}: {weight:.4f}")
        print("-" * 80)


Testing sparse representations:


Query: 인공지능 모델 학습
  Non-zero terms: 1
  Sparsity: 1.000
  Top 5 terms:
    시대                  : 0.8566
    [MASK]              : 0.0000
    $                   : 0.0000
    %                   : 0.0000
    #                   : 0.0000
--------------------------------------------------------------------------------

Query: machine learning training
  Non-zero terms: 1
  Sparsity: 1.000
  Top 5 terms:
    시대                  : 0.6683
    [MASK]              : 0.0000
    $                   : 0.0000
    %                   : 0.0000
    #                   : 0.0000
--------------------------------------------------------------------------------

Query: 검색 시스템 개발
  Non-zero terms: 1
  Sparsity: 1.000
  Top 5 terms:
    시대                  : 0.8425
    [MASK]              : 0.0000
    $                   : 0.0000
    %                   : 0.0000
    #                   : 0.0000
-------------------------------------------------------------------------------
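After one epoch, each of the three queries shown activates the same single term, 시대. The remaining "top" terms all have weight 0.0000, which suggests the sparse representations have collapsed rather than become selective. Two small checks make this visible: list only terms with positive weight, and count how many distinct terms are active across a set of queries. The sketch below works on `{term: weight}` dicts. Converting the model's dense vectors into that form (for example, from `get_top_k_terms` output) is left as an assumption.

```python
from typing import Dict, List, Tuple


def positive_top_k(weights: Dict[str, float], k: int = 10) -> List[Tuple[str, float]]:
    """Top-k terms by weight, dropping zero-weight entries that would only pad the list."""
    active = [(term, w) for term, w in weights.items() if w > 0.0]
    return sorted(active, key=lambda item: item[1], reverse=True)[:k]


def distinct_active_terms(reps: List[Dict[str, float]]) -> int:
    """Number of different vocabulary terms active in at least one representation."""
    return len({term for rep in reps for term, w in rep.items() if w > 0.0})


# Values mirror the logged output for the three queries above.
reps = [
    {"시대": 0.8566, "[MASK]": 0.0},
    {"시대": 0.6683, "$": 0.0},
    {"시대": 0.8425, "%": 0.0},
]
print(positive_top_k(reps[0]))
print(distinct_active_terms(reps))  # 1 distinct term for 3 different queries: collapse
```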

## 11. Test Cross-lingual Alignment

In [23]:
import torch.nn.functional as F
import json

# Load synonyms (synonyms_path was resolved in section 2)
with open(synonyms_path, 'r', encoding='utf-8') as f:
    synonyms = json.load(f)

# Test the first 10 synonym entries
test_synonyms = synonyms[:10]

print("\nCross-lingual Alignment Test:\n")
print("=" * 80)

all_cosine_sims = []
model.eval()
with torch.no_grad():
    for syn in test_synonyms:
        korean = syn['korean']
        english = syn['english']

        # Encode both sides once and reuse the vectors for every metric
        korean_rep = model.encode([korean], device=device)
        english_rep = model.encode([english], device=device)

        # Compute similarity
        cosine_sim = F.cosine_similarity(korean_rep, english_rep, dim=-1)
        dot_sim = torch.sum(korean_rep * english_rep, dim=-1)
        all_cosine_sims.append(cosine_sim.item())

        print(f"{korean:15s} ↔ {english:20s}")
        print(f"  Cosine similarity: {cosine_sim.item():.4f}")
        print(f"  Dot product: {dot_sim.item():.2f}")

# Average similarity
print("\n" + "=" * 80)
print(f"Average cosine similarity: {sum(all_cosine_sims) / len(all_cosine_sims):.4f}")


Cross-lingual Alignment Test:

A               ↔ A                   
  Cosine similarity: 1.0000
  Dot product: 0.78
A               ↔ a                   
  Cosine similarity: 1.0000
  Dot product: 0.81
ASCII           ↔ ASCII               
  Cosine similarity: 1.0000
  Dot product: 0.70
ASCII           ↔ ascii               
  Cosine similarity: 1.0000
  Dot product: 0.73
DNA             ↔ DNA                 
  Cosine similarity: 1.0000
  Dot product: 0.80
DNA             ↔ dna                 
  Cosine similarity: 1.0000
  Dot product: 0.77
E               ↔ E                   
  Cosine similarity: 1.0000
  Dot product: 0.80
E               ↔ e                   
  Cosine similarity: 1.0000
  Dot product: 0.82
F               ↔ F                   
  Cosine similarity: 1.0000
  Dot product: 0.82
F               ↔ f                   
  Cosine similarity: 1.0000
  Dot product: 0.84

Average cosine similarity: 1.0000
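Two things limit what this test shows. First, the first ten entries of `combined_synonyms.json` are English case variants (A ↔ a, DNA ↔ dna), not Korean-English pairs. Second, when both representations have a single active term and it is the same term, cosine similarity is 1.0 regardless of the weights. That matches the uniform 1.0000 above alongside dot products below 1.

The sketch below illustrates both points. It provides a sparse cosine over `{term: weight}` dicts, and a hypothetical filter that keeps only entries whose `korean` field contains precomposed Hangul syllables (U+AC00 to U+D7A3) and whose `english` field does not.

```python
import math
from typing import Dict, List


def sparse_cosine(a: Dict[str, float], b: Dict[str, float]) -> float:
    """Cosine similarity between two sparse {term: weight} vectors."""
    dot = sum(w * b[t] for t, w in a.items() if t in b)
    norm_a = math.sqrt(sum(w * w for w in a.values()))
    norm_b = math.sqrt(sum(w * w for w in b.values()))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def has_hangul(text: str) -> bool:
    """True if the string contains at least one precomposed Hangul syllable."""
    return any("\uac00" <= ch <= "\ud7a3" for ch in text)


def korean_english_pairs(synonyms: List[dict]) -> List[dict]:
    """Keep entries whose 'korean' side is Hangul and whose 'english' side is not."""
    return [s for s in synonyms
            if has_hangul(s["korean"]) and not has_hangul(s["english"])]


# Collapsed representations: same single term, different weights -> cosine 1.0.
print(sparse_cosine({"시대": 0.8566}, {"시대": 0.6683}))
print(korean_english_pairs([{"korean": "A", "english": "a"},
                            {"korean": "검색", "english": "search"}]))
```

In the notebook, `test_synonyms = korean_english_pairs(synonyms)[:10]` would restrict the test to genuine cross-lingual pairs.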


## 12. Full Training (Optional)

Uncomment below to run full training with all epochs.

In [None]:
# Full training
# print("Starting full training...\n")
# trainer.train(num_epochs=config['training']['num_epochs'])

# print("\nTraining completed!")
# print(f"Best validation loss: {trainer.best_val_loss:.4f}")
# print(f"Best model saved to: {config['logging']['output_dir']}/best_model")

## Summary

We've successfully:
1. ✅ Loaded 27K query-document pairs
2. ✅ Loaded 31,116 synonym pairs
3. ✅ Initialized Neural Sparse Encoder (klue/bert-base)
4. ✅ Set up training pipeline with mixed precision
5. ✅ Ran initial training experiment
6. ✅ Tested sparse representations
7. ✅ Ran a cross-lingual alignment check (the first 10 synonym entries tested are English case variants, not Korean-English pairs)

**Next steps:**
- Run full training (10 epochs)
- Evaluate on test set
- Export model for OpenSearch
- Build evaluation framework