# OpenSearch Neural Sparse Model v2 Training

이 노트북은 최신 연구 논문 "Towards Competitive Search Relevance For Inference-Free Learned Sparse Retrievers" (arXiv:2411.04403v2)를 기반으로 OpenSearch Neural Sparse Model v2를 학습합니다.

## 주요 개선사항

### 1. IDF-Aware Penalty
- 기존 FLOPS regularization의 문제점 해결
- 중요한 저빈도 토큰 보존
- IDF 가중치를 활용한 차별화된 패널티 적용

### 2. Heterogeneous Ensemble Knowledge Distillation
- Dense와 Sparse teacher 모델의 앙상블
- 대규모 pre-training 데이터에 적용 가능
- Cross-encoder보다 효율적인 teacher 모델

### 3. Hardware Optimization
- Nvidia DGX Spark GPU 최적화
- Mixed Precision Training (FP16/BF16)
- Gradient Accumulation
- Distributed Training Support

## 환경 정보
- GPU: Nvidia GB10 (Compute Capability 12.1)
- CUDA: 13.0
- Python: 3.12.3
- PyTorch: 2.5.1


## 1. Setup and Imports

In [None]:
import sys
import os
from pathlib import Path

# Put the directory two levels above the current working directory at the
# front of the import path so that the `src.*` packages resolve.
project_root = Path.cwd().parent.parent
_root_entry = str(project_root)
sys.path.insert(0, _root_entry)

print(f"Project root: {project_root}")
print(f"Working directory: {os.getcwd()}")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import pandas as pd
from typing import Dict, List, Optional, Tuple
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import json
import yaml

from transformers import (
    AutoTokenizer,
    AutoModel,
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)

# Import custom modules
from src.models.neural_sparse_encoder import NeuralSparseEncoder
from src.training.trainer import NeuralSparseTrainer
from src.training.losses import CombinedLoss
from src.training.data_collator import NeuralSparseDataCollator

# Report the PyTorch build and, when a CUDA device is visible, its properties.
print(f"PyTorch version: {torch.__version__}")
_has_cuda = torch.cuda.is_available()
print(f"CUDA available: {_has_cuda}")
if _has_cuda:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU memory: {_total_bytes / 1e9:.2f} GB")

## 2. Configuration

### Training Configuration Based on Paper Findings

In [None]:
# Base configuration
# One nested dict holding every hyperparameter of this notebook. It is dumped
# as JSON right after its definition and indexed by the later setup cells.
CONFIG = {
    # Model configuration
    "model": {
        "base_model": "opensearch-project/opensearch-neural-sparse-encoding-multilingual-v1",
        "max_query_length": 64,
        "max_doc_length": 256,
        "use_relu": True,
    },
    
    # Pre-training configuration (based on paper Section 5.1.2)
    # NOTE(review): the setup cells visible in this part of the notebook read
    # only the "finetuning" values; confirm this section is consumed by the
    # training loop further down.
    "pretraining": {
        "enabled": True,
        "num_steps": 150000,
        "batch_size": 48,  # Per device
        "gradient_accumulation_steps": 1,
        "num_hard_negatives": 7,  # Paper: 7 hard negatives per query
        "learning_rate": 5e-5,
        "lambda_flops": 1e-7,  # Small coefficient for pre-training
        "warmup_steps": 5000,
        "max_grad_norm": 1.0,
        "scale_constant_S": 10,  # For scaling ensemble scores
    },
    
    # Fine-tuning configuration (based on paper Section 5.1.2)
    # These values feed the loss, the teacher scale constant, the collator's
    # negative count and the DataLoader batch size in the cells below.
    "finetuning": {
        "enabled": True,
        "num_steps": 50000,
        "batch_size": 40,  # Per device
        "gradient_accumulation_steps": 1,
        "num_hard_negatives": 10,  # Paper: 10 hard negatives per query
        "learning_rate": 2e-5,
        "lambda_flops": 0.02,  # Balance relevance and efficiency
        "warmup_steps": 2000,
        "max_grad_norm": 1.0,
        "scale_constant_S": 30,  # For fine-tuning
    },
    
    # IDF-aware penalty configuration (Section 4.1)
    "idf_penalty": {
        "enabled": True,
        "idf_source": "msmarco",  # Dataset to compute IDF from
        "default_idf": 1.0,  # For unseen tokens
    },
    
    # Knowledge distillation configuration (Section 4.2)
    # IMPORTANT: OpenSearch sparse models on HuggingFace Hub may not have
    # pretrained sparse projection head weights. Recommended approaches:
    # Option 1 (Recommended): Use only dense teacher initially
    #   - Set teacher_weights.sparse = 0
    # Option 2: Train sparse teacher first, then use for distillation
    #   - Train a model, save it, then set sparse_teacher to saved path
    "knowledge_distillation": {
        "enabled": True,
        "dense_teacher": "Alibaba-NLP/gte-large-en-v1.5",
        # Use None or set sparse weight to 0 to skip sparse teacher
        "sparse_teacher": None,  # No sparse teacher by default; set to a saved model path to enable one
        # NOTE(review): "cross_encoder", "temperature" and "use_normalization"
        # are not read by the teacher code visible in this section — verify
        # they are used elsewhere before relying on them.
        "cross_encoder": "cross-encoder/ms-marco-MiniLM-L-12-v2",  # For fine-tuning
        "teacher_weights": {
            "dense": 1.0,  # Use only dense teacher
            "sparse": 0.0,  # Skip sparse teacher
        },
        "temperature": 1.0,
        "use_normalization": True,  # Min-max normalization before ensemble
    },
    
    # Hardware optimization for Nvidia DGX Spark
    "hardware": {
        "mixed_precision": True,
        "precision": "bf16",  # BF16 for better stability
        "compile_model": True,  # Wrap the model with torch.compile during setup
        "num_workers": 8,  # DataLoader worker processes
        "pin_memory": True,
        "persistent_workers": True,
    },
    
    # Logging and checkpointing
    "logging": {
        "output_dir": "outputs/opensearch-neural-v2",
        "logging_steps": 100,
        "eval_steps": 500,
        "save_steps": 1000,
        "save_total_limit": 5,
    },
    
    # Data filtering (Section 4.2)
    "data_filtering": {
        "enabled": True,
        "top_k_filter": 10,  # Filter samples where positive not in top-10
    },
}

# Dump the full configuration, then a fixed banner of usage notes.
print("Training Configuration:")
print(json.dumps(CONFIG, indent=2))

_rule = "=" * 80
_config_notes = [
    "\n" + _rule,
    "CONFIGURATION NOTES",
    _rule,
    "\n1. Sparse Teacher Configuration:",
    "   - Currently set to use ONLY dense teacher (sparse_weight=0)",
    "   - This avoids issues with models lacking pretrained sparse heads",
    "   - For two-stage approach:",
    "     a) First train a sparse model (current config)",
    "     b) Then use trained model as sparse teacher for next iteration",
    "\n2. Model Loading:",
    "   - Base model will be loaded from HuggingFace Hub",
    "   - Projection layer will be randomly initialized",
    "   - Training is required before using for retrieval",
    "\n3. See docs/MODEL_LOADING_GUIDE.md for detailed information",
    _rule,
]
# A single print of the newline-joined notes emits the same text as one
# print call per line.
print("\n".join(_config_notes))

## 3. IDF Computation

### Compute IDF weights from corpus (Section 4.1)

In [None]:
def compute_idf_weights(
    documents: List[str],
    tokenizer: AutoTokenizer,
    save_path: Optional[str] = None,
) -> Dict[int, float]:
    """
    Compute IDF weights for every token id that occurs in a corpus.

    IDF(t) = log(N / df(t))
    where N = total documents, df(t) = number of documents containing t.
    Tokens that never occur are absent from the result.

    Args:
        documents: List of document texts
        tokenizer: HuggingFace tokenizer; only ``encode(text,
            add_special_tokens=False)`` is called on it
        save_path: Optional path to save the weights as JSON. Missing parent
            directories are created. JSON turns the integer keys into strings.

    Returns:
        Dictionary mapping token IDs to IDF weights (empty for an empty corpus)
    """
    from collections import Counter
    from pathlib import Path

    num_documents = len(documents)
    print(f"Computing IDF weights from {num_documents} documents...")

    # Document frequency: each token is counted at most once per document.
    df_counter = Counter()
    progress = tqdm(documents, desc="Processing documents")
    for processed, doc in enumerate(progress, start=1):
        df_counter.update(set(tokenizer.encode(doc, add_special_tokens=False)))

        # Counting from 1 makes the message report documents actually
        # finished, and avoids a spurious "Processed 0 documents" line.
        if processed % 10000 == 0:
            print(f"Processed {processed} documents, unique tokens: {len(df_counter)}")

    # df >= 1 for every key, so the division is always defined.
    idf_weights = {
        token_id: float(np.log(num_documents / df))
        for token_id, df in df_counter.items()
    }

    print(f"Computed IDF for {len(idf_weights)} tokens")
    if idf_weights:
        idf_values = list(idf_weights.values())
        print(f"IDF statistics:")
        print(f"  Min: {min(idf_values):.4f}")
        print(f"  Max: {max(idf_values):.4f}")
        print(f"  Mean: {np.mean(idf_values):.4f}")
        print(f"  Median: {np.median(idf_values):.4f}")
    else:
        # min()/max() raise on an empty sequence, so skip the summary.
        print("No tokens found; IDF statistics skipped")

    # Save if path provided
    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        with open(save_path, 'w') as f:
            json.dump(idf_weights, f, indent=2)
        print(f"IDF weights saved to {save_path}")

    return idf_weights


def load_idf_weights(path: str) -> Dict[int, float]:
    """
    Read an IDF table from a JSON file.

    JSON object keys are always strings, so each key is converted back to an
    integer token id and each value to a float.

    Args:
        path: Path to IDF weights JSON file

    Returns:
        Dictionary mapping token IDs to IDF weights
    """
    with open(path, 'r') as handle:
        raw_table = json.load(handle)

    parsed: Dict[int, float] = {}
    for token_key, weight in raw_table.items():
        parsed[int(token_key)] = float(weight)

    print(f"Loaded IDF weights for {len(parsed)} tokens from {path}")
    return parsed

## 4. IDF-Aware Loss Functions

### Implementation of IDF-aware penalty (Section 4.1)

In [None]:
class IDFAwareLoss(nn.Module):
    """
    IDF-aware ranking loss with FLOPS regularization.

    Described by the notebook as following paper Section 4.1:
    - Relevance scores are dot products weighted per token by IDF
    - A FLOPS term pushes document representations toward sparsity
    - The ranking term is cross-entropy, or KL divergence against teacher
      scores when those are supplied
    """

    def __init__(
        self,
        idf_weights: Dict[int, float],
        vocab_size: int,
        lambda_flops: float = 0.02,
        default_idf: float = 1.0,
        device: Optional[torch.device] = None,
    ):
        """
        Initialize IDF-aware loss.

        Args:
            idf_weights: Dictionary mapping token IDs to IDF values. IDs
                outside ``[0, vocab_size)`` are ignored.
            vocab_size: Vocabulary size (length of the IDF vector)
            lambda_flops: Weight for FLOPS regularization
            default_idf: IDF used for token IDs missing from ``idf_weights``
            device: Device to put IDF tensor on
        """
        super().__init__()
        self.lambda_flops = lambda_flops

        # Dense IDF vector, pre-filled with the default for unseen tokens.
        idf_tensor = torch.full((vocab_size,), float(default_idf))
        for token_id, idf_val in idf_weights.items():
            # The lower bound matters: a negative id would otherwise index
            # from the end of the tensor and overwrite an unrelated token.
            if 0 <= token_id < vocab_size:
                idf_tensor[token_id] = idf_val

        if device is not None:
            idf_tensor = idf_tensor.to(device)

        # Register as buffer (moves with the module, not a trainable parameter)
        self.register_buffer('idf_weights', idf_tensor)

        print(f"Initialized IDF-aware loss with {len(idf_weights)} IDF weights")

    def compute_idf_weighted_score(
        self,
        query_rep: torch.Tensor,
        doc_rep: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute IDF-weighted similarity score (cited as Equation 5 in paper).

        s(q, d) = Σ_t idf(t) * q_t * d_t

        Args:
            query_rep: Query sparse representation [batch_size, vocab_size]
            doc_rep: Document sparse representation, either
                [batch_size, vocab_size] or [batch_size, num_docs, vocab_size]

        Returns:
            IDF-weighted scores, [batch_size] or [batch_size, num_docs]
        """
        # Element-wise multiplication with IDF weights
        weighted_query = query_rep * self.idf_weights.unsqueeze(0)

        # With one extra document axis, broadcast the query over it.
        if doc_rep.dim() == weighted_query.dim() + 1:
            weighted_query = weighted_query.unsqueeze(1)

        # Dot product over the vocabulary axis
        return torch.sum(weighted_query * doc_rep, dim=-1)

    def forward(
        self,
        query_rep: torch.Tensor,
        pos_doc_rep: torch.Tensor,
        neg_doc_reps: torch.Tensor,
        teacher_scores: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute IDF-aware loss.

        Args:
            query_rep: Query representations [batch_size, vocab_size]
            pos_doc_rep: Positive doc representations [batch_size, vocab_size]
            neg_doc_reps: Negative doc representations
                [batch_size, num_neg, vocab_size]; ``num_neg`` may be 0
            teacher_scores: Optional teacher scores [batch_size, 1+num_neg],
                positive first

        Returns:
            Dictionary with 'total_loss', 'ranking_loss' and 'flops_loss'

        Raises:
            ValueError: If ``teacher_scores`` does not have the same shape as
                the student score matrix.
        """
        batch_size = query_rep.shape[0]

        pos_scores = self.compute_idf_weighted_score(query_rep, pos_doc_rep)

        # All negatives in one broadcast. Unlike stacking a per-negative list
        # this also works when num_neg == 0, and it uses the same
        # multiply-then-sum arithmetic as the positive score.
        neg_scores = self.compute_idf_weighted_score(query_rep, neg_doc_reps)

        # Column 0 is the positive document.
        all_scores = torch.cat([pos_scores.unsqueeze(1), neg_scores], dim=1)

        # Ranking loss
        if teacher_scores is not None:
            if teacher_scores.shape != all_scores.shape:
                # A mismatched shape could otherwise broadcast silently.
                raise ValueError(
                    f"teacher_scores shape {tuple(teacher_scores.shape)} does not "
                    f"match student scores shape {tuple(all_scores.shape)}"
                )
            # Knowledge distillation with KL divergence
            student_log_probs = F.log_softmax(all_scores, dim=-1)
            teacher_probs = F.softmax(teacher_scores, dim=-1)
            ranking_loss = F.kl_div(
                student_log_probs,
                teacher_probs,
                reduction='batchmean',
            )
        else:
            # Standard cross-entropy (positive is first)
            labels = torch.zeros(batch_size, dtype=torch.long, device=query_rep.device)
            ranking_loss = F.cross_entropy(all_scores, labels)

        # FLOPS regularization (cited as Equation 3 in paper): squared mean
        # activation per vocabulary entry, summed over the vocabulary.
        # NOTE(review): only the positive documents are regularized here;
        # confirm that leaving out the negatives is intended.
        doc_flops = torch.sum(pos_doc_rep.mean(dim=0) ** 2)
        flops_loss = self.lambda_flops * doc_flops

        # Total loss
        total_loss = ranking_loss + flops_loss

        return {
            'total_loss': total_loss,
            'ranking_loss': ranking_loss,
            'flops_loss': flops_loss,
        }

# Cell-completion marker: printed after the class definition above has executed.
print("IDF-aware loss functions defined")

## 5. Teacher Models for Knowledge Distillation

### Heterogeneous ensemble of dense and sparse teachers (Section 4.2)

In [None]:
class EnsembleTeacher:
    """
    Ensemble of heterogeneous teacher models.

    Described by the notebook as following paper Section 4.2:
    - Combines dense and sparse siamese retrievers
    - Min-max normalization before ensemble
    - Weighted sum of normalized scores, multiplied by a scale constant

    Note: The sparse teacher may not have pretrained sparse head weights.
    In that case, it will be initialized with random projection weights.
    For best results, either:
    1. Train a sparse teacher first, or
    2. Use only the dense teacher (set sparse_weight=0)
    """

    def __init__(
        self,
        dense_teacher_name: str,
        sparse_teacher_name: Optional[str] = None,
        dense_weight: float = 0.5,
        sparse_weight: float = 0.5,
        scale_constant: float = 10.0,
        device: Optional[torch.device] = None,
    ):
        """
        Initialize ensemble teacher.

        Args:
            dense_teacher_name: HuggingFace model name for dense teacher
            sparse_teacher_name: HuggingFace model name for sparse teacher (optional)
            dense_weight: Weight for dense teacher (only applied when the
                sparse teacher is active)
            sparse_weight: Weight for sparse teacher; 0 disables it
            scale_constant: Scaling constant S (cited as paper Equation 9)
            device: Device to load models on; defaults to CUDA when available
        """
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dense_weight = dense_weight
        self.sparse_weight = sparse_weight
        self.scale_constant = scale_constant
        self.use_sparse_teacher = sparse_teacher_name is not None and sparse_weight > 0
        # Always defined, so the attribute exists on every code path
        # (including the load-failure fallback below).
        self.sparse_teacher = None

        # Load dense teacher
        print(f"Loading dense teacher: {dense_teacher_name}")
        from sentence_transformers import SentenceTransformer
        self.dense_teacher = SentenceTransformer(
            dense_teacher_name,
            device=str(self.device),
            trust_remote_code=True,
        )
        self.dense_teacher.eval()
        print("Dense teacher loaded successfully")

        # Load sparse teacher if specified
        if self.use_sparse_teacher:
            print(f"\nLoading sparse teacher: {sparse_teacher_name}")
            try:
                # Load sparse teacher (OpenSearch neural sparse model)
                sparse_teacher = NeuralSparseEncoder.from_pretrained(
                    sparse_teacher_name
                )
                sparse_teacher.to(self.device)
                sparse_teacher.eval()
            except Exception as e:
                # Deliberate best-effort: any load failure degrades to the
                # dense-only configuration instead of aborting the notebook.
                print(f"ERROR loading sparse teacher: {e}")
                print("Falling back to dense teacher only")
                self.use_sparse_teacher = False
                self.dense_weight = 1.0
                self.sparse_weight = 0.0
            else:
                self.sparse_teacher = sparse_teacher
                print("Sparse teacher loaded successfully")
                print("WARNING: If the projection layer was randomly initialized,")
                print("         the sparse teacher may not provide meaningful scores.")
                print("         Consider training a sparse teacher first or using only")
                print("         the dense teacher (set sparse_weight=0 in CONFIG).")
        else:
            print("\nSparse teacher disabled (using dense teacher only)")

        print("\nEnsemble teacher initialized:")
        print(f"  Dense weight: {self.dense_weight}")
        print(f"  Sparse weight: {self.sparse_weight}")
        print(f"  Scale constant: {scale_constant}")

    def min_max_normalize(
        self,
        scores: torch.Tensor,
        dim: int = -1,
    ) -> torch.Tensor:
        """
        Min-max normalization (cited as paper Equation 8).

        ŝ_i = (s_i - min(s)) / (max(s) - min(s))

        Args:
            scores: Scores tensor
            dim: Dimension to normalize over

        Returns:
            Normalized scores in [0, 1]; a constant slice maps to all zeros
        """
        min_scores = scores.min(dim=dim, keepdim=True)[0]
        max_scores = scores.max(dim=dim, keepdim=True)[0]

        # Clamp the range so a constant slice does not divide by zero.
        range_scores = torch.clamp(max_scores - min_scores, min=1e-8)

        return (scores - min_scores) / range_scores

    @torch.no_grad()
    def get_scores(
        self,
        queries: List[str],
        documents: List[List[str]],  # [batch_size, num_docs]
    ) -> torch.Tensor:
        """
        Get ensemble teacher scores.

        Args:
            queries: List of query strings
            documents: List of document lists (pos + negs for each query);
                every inner list must have the same length

        Returns:
            Ensemble scores [batch_size, num_docs]. An empty batch, or a
            batch whose document lists are all empty, yields an empty tensor
            of the matching shape.

        Raises:
            ValueError: If ``queries`` and ``documents`` differ in length or
                the document lists are ragged.
        """
        batch_size = len(queries)
        if len(documents) != batch_size:
            raise ValueError(
                f"Got {batch_size} queries but {len(documents)} document lists"
            )
        if batch_size == 0:
            return torch.empty((0, 0), device=self.device)

        num_docs = len(documents[0])
        if any(len(docs) != num_docs for docs in documents):
            raise ValueError(
                "Every query must come with the same number of documents"
            )
        if num_docs == 0:
            return torch.empty((batch_size, 0), device=self.device)

        # Dense teacher scores
        query_embeddings = self.dense_teacher.encode(
            queries,
            convert_to_tensor=True,
            show_progress_bar=False,
        )

        # Encode every candidate in ONE call instead of one call per query,
        # then fold the flat result back to [batch_size, num_docs, dim].
        flat_docs = [doc for docs in documents for doc in docs]
        doc_embeddings = self.dense_teacher.encode(
            flat_docs,
            convert_to_tensor=True,
            show_progress_bar=False,
        ).reshape(batch_size, num_docs, -1)

        # Cosine similarity of each query against its own candidates.
        dense_scores = torch.cosine_similarity(
            query_embeddings.unsqueeze(1),
            doc_embeddings,
            dim=-1,
        )  # [batch_size, num_docs]

        # Sparse teacher scores (if enabled)
        if self.use_sparse_teacher:
            query_sparse_reps = self.sparse_teacher.encode(queries, device=self.device)

            # Kept per query: sparse representations are vocabulary-sized, so
            # encoding the whole batch at once could be very memory hungry.
            sparse_scores = []
            for i, docs in enumerate(documents):
                doc_sparse_reps = self.sparse_teacher.encode(docs, device=self.device)
                # Dot product similarity
                scores = torch.sum(
                    query_sparse_reps[i].unsqueeze(0) * doc_sparse_reps,
                    dim=-1,
                )
                sparse_scores.append(scores)

            sparse_scores = torch.stack(sparse_scores)  # [batch_size, num_docs]

            # Normalize each teacher per query (cited as Equation 8)
            dense_norm = self.min_max_normalize(dense_scores, dim=1)
            sparse_norm = self.min_max_normalize(sparse_scores, dim=1)

            # Weighted ensemble (cited as Equation 9)
            ensemble_scores = (
                self.dense_weight * dense_norm +
                self.sparse_weight * sparse_norm
            )
        else:
            # Use only dense teacher
            ensemble_scores = self.min_max_normalize(dense_scores, dim=1)

        # Scale back with constant S
        return self.scale_constant * ensemble_scores

# Cell-completion marker: printed after the class definition above has executed.
print("Ensemble teacher class defined")

## 6. Data Loading

### Dataset with hard negative mining and filtering

In [None]:
class SparseRetrievalDataset(Dataset):
    """
    Dataset of (query, positive document, negative documents) text triples.

    Optionally applies the consistency-based filtering the notebook
    attributes to paper Section 4.2: a sample is kept only when a miner model
    ranks its positive document within the top-k of its own candidates.
    The negatives themselves are supplied by the caller; no mining is done
    here.
    """

    def __init__(
        self,
        queries: List[str],
        positive_docs: List[str],
        negative_docs: List[List[str]],
        filter_top_k: Optional[int] = None,
        miner_model: Optional[nn.Module] = None,
    ):
        """
        Initialize dataset.

        Args:
            queries: List of query strings
            positive_docs: List of positive documents (parallel to queries)
            negative_docs: List of negative document lists (parallel to queries)
            filter_top_k: Filter samples where positive not in top-k
            miner_model: Model used to score candidates for the filter; it
                must expose ``encode(list_of_texts)``. Filtering runs only
                when both this and ``filter_top_k`` are given.
        """
        self.queries = queries
        self.positive_docs = positive_docs
        self.negative_docs = negative_docs

        # Apply filtering if specified
        if filter_top_k is not None and miner_model is not None:
            self._apply_consistency_filter(filter_top_k, miner_model)

        print(f"Dataset initialized with {len(self)} samples")

    @staticmethod
    def _positive_in_top_k(scores: torch.Tensor, top_k: int) -> bool:
        """Return True when the positive (index 0) has fewer than ``top_k`` candidates scoring strictly above it."""
        num_better = int((scores[1:] > scores[0]).sum().item())
        return num_better < top_k

    def _apply_consistency_filter(
        self,
        top_k: int,
        miner_model: nn.Module,
    ) -> None:
        """
        Apply consistency-based filtering.

        Filters out samples where the positive document is not among the
        ``top_k`` best-scoring candidates (positive plus ALL of its
        negatives) under ``miner_model``. Ties count in the positive's
        favour.
        """
        print(f"Applying consistency filter (top-{top_k})...")

        # Remember the size now: the lists are replaced below, and the
        # retained percentage must be relative to the unfiltered dataset.
        original_size = len(self)
        filtered_indices = []

        miner_model.eval()
        with torch.no_grad():
            for i in tqdm(range(original_size), desc="Filtering"):
                # Rank the positive against every negative. Truncating the
                # candidates to top_k entries first would make the check
                # trivially true for every sample.
                all_docs = [self.positive_docs[i]] + list(self.negative_docs[i])

                # Encode and score (dot product over the last axis)
                query_rep = miner_model.encode([self.queries[i]])[0]
                doc_reps = miner_model.encode(all_docs)
                scores = torch.sum(query_rep * doc_reps, dim=-1)

                if self._positive_in_top_k(scores, top_k):
                    filtered_indices.append(i)

        # Update dataset
        self.queries = [self.queries[i] for i in filtered_indices]
        self.positive_docs = [self.positive_docs[i] for i in filtered_indices]
        self.negative_docs = [self.negative_docs[i] for i in filtered_indices]

        retained_pct = len(filtered_indices) / original_size * 100 if original_size else 0.0
        print(f"Filtered to {len(self)} samples ({retained_pct:.1f}% retained)")

    def __len__(self) -> int:
        return len(self.queries)

    def __getitem__(self, idx: int) -> Dict[str, object]:
        # Raw text only: tokenization is left to the collate function.
        return {
            'query': self.queries[idx],
            'positive_doc': self.positive_docs[idx],
            'negative_docs': self.negative_docs[idx],
        }

# Cell-completion marker: printed after the class definition above has executed.
print("Dataset class defined")

## 7. Training Setup

### Initialize model, tokenizer, and training components

In [None]:
# Set device: prefer CUDA when a GPU is visible, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the tokenizer that matches the base model
tokenizer = AutoTokenizer.from_pretrained(
    CONFIG['model']['base_model']
)
print(f"Loaded tokenizer: {CONFIG['model']['base_model']}")

# Initialize model
model = NeuralSparseEncoder(
    model_name=CONFIG['model']['base_model'],
    max_length=CONFIG['model']['max_doc_length'],
    use_relu=CONFIG['model']['use_relu'],
)
model = model.to(device)

# Compile model for faster training (PyTorch 2.0+).
# Feature-detect torch.compile rather than comparing version strings:
# string comparison is lexicographic, so e.g. '10.0' >= '2.0' is False.
if CONFIG['hardware']['compile_model'] and hasattr(torch, 'compile'):
    print("Compiling model with torch.compile...")
    model = torch.compile(model)

print(f"Model initialized and moved to {device}")

## 8. Load or Compute IDF Weights

In [None]:
# Location of the pre-computed IDF table
idf_path = project_root / "data" / "idf_weights_msmarco.json"

if not idf_path.exists():
    # No cached table: show how to build one, then continue with placeholders.
    print("IDF weights not found. You need to compute them from your corpus.")
    print("Example:")
    print("""\n
# Load your corpus
documents = [...]  # Your document corpus

# Compute IDF
idf_weights = compute_idf_weights(
    documents=documents,
    tokenizer=tokenizer,
    save_path=str(idf_path),
)
    """)

    # Uniform weight of 1.0 for every id below the tokenizer's vocab_size
    print("\nCreating dummy IDF weights for demonstration...")
    idf_weights = dict.fromkeys(range(tokenizer.vocab_size), 1.0)
else:
    print(f"Loading pre-computed IDF weights from {idf_path}")
    idf_weights = load_idf_weights(str(idf_path))

## 9. Initialize Loss Function

In [None]:
# Build the IDF-aware loss from the fine-tuning FLOPS weight and the
# configured fallback IDF.
_finetune_cfg = CONFIG['finetuning']
_idf_cfg = CONFIG['idf_penalty']

loss_fn = IDFAwareLoss(
    idf_weights=idf_weights,
    vocab_size=tokenizer.vocab_size,
    lambda_flops=_finetune_cfg['lambda_flops'],
    default_idf=_idf_cfg['default_idf'],
    device=device,
)

print("Loss function initialized")

## 10. Initialize Teacher Models (Optional)

In [None]:
_kd_cfg = CONFIG['knowledge_distillation']

if not _kd_cfg['enabled']:
    teacher = None
    print("Training without knowledge distillation")
else:
    print("Initializing ensemble teacher models...")
    print("=" * 80)

    # A zero sparse weight disables the sparse teacher regardless of the
    # configured model name.
    sparse_teacher_name = _kd_cfg['sparse_teacher']
    sparse_weight = _kd_cfg['teacher_weights']['sparse']

    if sparse_weight == 0:
        print("Note: Sparse teacher weight is 0 - using dense teacher only")
        sparse_teacher_name = None

    if sparse_teacher_name:
        # One print of the newline-joined notice emits the same text as one
        # print call per line.
        print("\n".join([
            "\nIMPORTANT: Sparse Teacher Information",
            "-" * 80,
            "The OpenSearch sparse models on HuggingFace Hub may not have",
            "pretrained sparse projection head weights (neural_sparse_head.pt).",
            "\nIf the model doesn't have pretrained weights, you have two options:",
            "  1. Use only the dense teacher (recommended for initial training)",
            "     Set CONFIG['knowledge_distillation']['teacher_weights']['sparse'] = 0",
            "  2. Train a sparse teacher first, then use it for distillation",
            "\nProceeding with sparse teacher loading...",
            "-" * 80,
        ]))

    teacher = EnsembleTeacher(
        dense_teacher_name=_kd_cfg['dense_teacher'],
        sparse_teacher_name=sparse_teacher_name,
        dense_weight=_kd_cfg['teacher_weights']['dense'],
        sparse_weight=sparse_weight,
        scale_constant=CONFIG['finetuning']['scale_constant_S'],
        device=device,
    )

    print("\n" + "=" * 80)
    print("Teacher models initialization complete")
    print("=" * 80)

## 11. Load Training Data

In [None]:
# TODO: Load your training data
# This is a placeholder - replace with your actual data loading

print("Loading training data...")
print("NOTE: This is a placeholder. Load your actual training data here.")

# Expected data structure: three parallel lists with one entry per query.
# Kept as a comment on purpose: executable lists written with `...` would
# contain literal Ellipsis objects, which are not valid training text.
#
#   train_queries = ["query 1", "query 2", ...]
#   train_positive_docs = ["pos doc 1", "pos doc 2", ...]
#   train_negative_docs = [            # hard negatives for each query
#       ["neg 1", "neg 2", ...],       # negatives for query 1
#       ["neg 1", "neg 2", ...],       # negatives for query 2
#       ...
#   ]

# For demonstration, create dummy data
print("Creating dummy data for demonstration...")
num_samples = 100
train_queries = [f"This is query {i}" for i in range(num_samples)]
train_positive_docs = [f"This is positive document {i}" for i in range(num_samples)]
train_negative_docs = [
    [f"Negative doc {i}-{j}" for j in range(CONFIG['finetuning']['num_hard_negatives'])]
    for i in range(num_samples)
]

print(f"Loaded {len(train_queries)} training samples")

## 12. Create Dataset and DataLoader

In [None]:
# Create dataset (the consistency filter runs only when filtering is enabled)
filter_enabled = CONFIG['data_filtering']['enabled']
train_dataset = SparseRetrievalDataset(
    queries=train_queries,
    positive_docs=train_positive_docs,
    negative_docs=train_negative_docs,
    filter_top_k=CONFIG['data_filtering']['top_k_filter'] if filter_enabled else None,
    miner_model=model if filter_enabled else None,
)

# Create data collator
data_collator = NeuralSparseDataCollator(
    tokenizer=tokenizer,
    query_max_length=CONFIG['model']['max_query_length'],
    doc_max_length=CONFIG['model']['max_doc_length'],
    num_negatives=CONFIG['finetuning']['num_hard_negatives'],
)

# Create dataloader
# DataLoader rejects persistent_workers=True when num_workers is 0, so the
# flag is only honoured when worker processes are actually used.
num_workers = CONFIG['hardware']['num_workers']
train_dataloader = DataLoader(
    train_dataset,
    batch_size=CONFIG['finetuning']['batch_size'],
    shuffle=True,
    collate_fn=data_collator,
    num_workers=num_workers,
    pin_memory=CONFIG['hardware']['pin_memory'],
    persistent_workers=CONFIG['hardware']['persistent_workers'] and num_workers > 0,
)

print(f"Created dataloader with {len(train_dataloader)} batches")

In [None]:
# Override DataCollator to include raw text for teacher model
# This fixes the KeyError: 'queries' issue

class FixedDataCollator(NeuralSparseDataCollator):
    """
    Data collator that forwards the untokenized text next to the tokenized batch.

    The parent collator produces the tokenized inputs; this subclass adds the
    original strings under the keys 'queries', 'positive_docs' and
    'negative_docs' so a teacher model can consume raw text.
    """

    def __call__(self, features: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Collate a list of samples and attach their raw text.

        Args:
            features: Samples to batch. Each one is indexed with the keys
                'query', 'positive_doc' and 'negative_docs'.

        Returns:
            The batch returned by the parent collator, extended with three
            extra entries holding raw text. Note that these three entries are
            plain Python lists (one item per sample), not tensors.
        """
        # Capture the text first, before the parent collator sees the samples
        raw_text = {
            'queries': [f['query'] for f in features],
            'positive_docs': [f['positive_doc'] for f in features],
            'negative_docs': [f['negative_docs'] for f in features],
        }

        # Tokenized inputs come from the parent implementation
        collated = super().__call__(features)

        # Merge the raw text into the tokenized batch
        for name, texts in raw_text.items():
            collated[name] = texts

        return collated

# Use fixed collator instead of original
data_collator = FixedDataCollator(
    tokenizer=tokenizer,
    query_max_length=CONFIG['model']['max_query_length'],
    doc_max_length=CONFIG['model']['max_doc_length'],
    num_negatives=CONFIG['finetuning']['num_hard_negatives'],
)

# A DataLoader keeps the collate_fn it was constructed with, so rebinding
# `data_collator` alone does not change the batches it yields. Rebuild the
# loader so batches actually contain the raw-text keys.
fixed_loader_num_workers = CONFIG['hardware']['num_workers']
train_dataloader = DataLoader(
    train_dataset,
    batch_size=CONFIG['finetuning']['batch_size'],
    shuffle=True,
    collate_fn=data_collator,
    num_workers=fixed_loader_num_workers,
    pin_memory=CONFIG['hardware']['pin_memory'],
    # persistent_workers=True is rejected by DataLoader when num_workers is 0
    persistent_workers=(
        bool(CONFIG['hardware']['persistent_workers'])
        and fixed_loader_num_workers > 0
    ),
)

print("✓ Using FixedDataCollator with raw text pass-through for teacher model")

In [None]:
# Validate batch structure before training
print("Validating batch structure...")
print("=" * 80)

# Get a sample batch
sample_batch_iter = iter(train_dataloader)
sample_batch = next(sample_batch_iter)

# Describe every entry of the batch: tensors by shape, lists by length
print("\nBatch keys:")
for key in sample_batch.keys():
    value = sample_batch[key]
    if isinstance(value, torch.Tensor):
        print(f"  {key:30s}: torch.Tensor {value.shape}")
    elif isinstance(value, list):
        print(f"  {key:30s}: List[...] (length={len(value)})")
        if value:
            if isinstance(value[0], str):
                print(f"    └─ Sample: {value[0][:50]}...")
            elif isinstance(value[0], list):
                print(f"    └─ Nested list with {len(value[0])} items")
    else:
        print(f"  {key:30s}: {type(value)}")

print("\n" + "=" * 80)
print("VALIDATION RESULTS")
print("=" * 80)

# Check required keys for student model (tokenized inputs)
student_keys = [
    'query_input_ids', 'query_attention_mask',
    'pos_doc_input_ids', 'pos_doc_attention_mask',
    'neg_doc_input_ids', 'neg_doc_attention_mask'
]

print("\nStudent model keys (tokenized):")
for key in student_keys:
    status = "✓" if key in sample_batch else "✗"
    print(f"  {status} {key}")

# Check required keys for teacher model (raw text)
teacher_keys = ['queries', 'positive_docs', 'negative_docs']

print("\nTeacher model keys (raw text):")
for key in teacher_keys:
    status = "✓" if key in sample_batch else "✗"
    print(f"  {status} {key}")

# Determine the batch size without assuming any key exists, so that a missing
# key is reported below instead of aborting this cell with a KeyError.
if 'queries' in sample_batch:
    batch_size = len(sample_batch['queries'])
elif 'query_input_ids' in sample_batch:
    batch_size = sample_batch['query_input_ids'].shape[0]
else:
    batch_size = None  # unknown; treated as a wildcard in the shape checks
num_negatives = CONFIG['finetuning']['num_hard_negatives']

print(f"\nBatch size: {batch_size}")
print(f"Number of negatives: {num_negatives}")

# Expected shapes; None means "any size" for that dimension
shape_checks = [
    ("query_input_ids", (batch_size, None)),
    ("pos_doc_input_ids", (batch_size, None)),
    ("neg_doc_input_ids", (batch_size, num_negatives, None)),
]

print("\nShape validation:")
shapes_ok = True
for key, expected_shape in shape_checks:
    if key not in sample_batch:
        shapes_ok = False
        print(f"  ✗ {key:30s}: <missing>")
        continue
    actual_shape = sample_batch[key].shape
    # Require the same rank and agreement on every dimension that is pinned
    match = len(actual_shape) == len(expected_shape) and all(
        expected is None or actual == expected
        for actual, expected in zip(actual_shape, expected_shape)
    )
    shapes_ok = shapes_ok and match
    status = "✓" if match else "✗"
    print(f"  {status} {key:30s}: {actual_shape}")

# Final verdict: missing keys first, then shape problems, then success
missing = [k for k in student_keys + teacher_keys if k not in sample_batch]

print("\n" + "=" * 80)
if missing:
    print("ERROR: Missing required keys in batch!")
    print(f"Missing: {missing}")
elif not shapes_ok:
    print("ERROR: Batch tensor shapes do not match the expected layout!")
else:
    print("SUCCESS: Batch structure is correct!")
    print("Ready to start training with knowledge distillation.")
print("=" * 80)

## 13. Initialize Optimizer and Scheduler

In [None]:
# Step budget and warmup length for the learning-rate schedule
num_training_steps = CONFIG['finetuning']['num_steps']
num_warmup_steps = CONFIG['finetuning']['warmup_steps']

# AdamW over all model parameters; only the learning rate comes from CONFIG
adamw_settings = dict(
    lr=CONFIG['finetuning']['learning_rate'],
    weight_decay=0.01,
    betas=(0.9, 0.999),
    eps=1e-8,
)
optimizer = torch.optim.AdamW(model.parameters(), **adamw_settings)

# Linear schedule with warmup, driven by the step counts above
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

# A gradient scaler is created only when mixed precision is enabled
if CONFIG['hardware']['mixed_precision']:
    scaler = GradScaler()
else:
    scaler = None

print(f"Optimizer and scheduler initialized")
print(f"Total training steps: {num_training_steps}")
print(f"Warmup steps: {num_warmup_steps}")
print(f"Mixed precision: {CONFIG['hardware']['mixed_precision']}")

## 14. Training Loop

### Main training loop with mixed precision (one optimizer step per batch; gradient accumulation is not implemented in this loop)

In [None]:
def train_step(
    batch: Dict[str, torch.Tensor],
    model: nn.Module,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: Optional[GradScaler],
    teacher: Optional[EnsembleTeacher] = None,
) -> Dict[str, float]:
    """
    Run one optimization step on a single batch.

    Encodes the query, the positive document and the negative documents with
    the student model, computes the loss, back-propagates, clips gradients to
    CONFIG['finetuning']['max_grad_norm'] and applies one optimizer step.
    Gradients are not accumulated across calls. Relies on the notebook-level
    names `device` and `CONFIG`.

    Args:
        batch: Collated batch. Must contain the tokenized keys
            'query_input_ids', 'query_attention_mask', 'pos_doc_input_ids',
            'pos_doc_attention_mask', 'neg_doc_input_ids' and
            'neg_doc_attention_mask'; the two negative tensors are unpacked as
            three-dimensional (batch, negatives, sequence). When `teacher` is
            given it must also contain the raw-text keys 'queries',
            'positive_docs' and 'negative_docs'.
        model: Student model; its output is indexed with 'sparse_rep'.
        loss_fn: Loss callable returning a mapping that includes 'total_loss'.
        optimizer: Optimizer stepped once per call.
        scaler: Gradient scaler. Passing None disables both loss scaling and
            autocast.
        teacher: Optional teacher exposing `get_scores(queries, docs)`.

    Returns:
        The loss mapping with every value converted to a Python float.
    """
    model.train()

    # Move tensors to the target device; raw-text entries are left untouched
    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
             for k, v in batch.items()}

    # Get teacher scores if available
    teacher_scores = None
    if teacher is not None:
        queries = batch['queries']
        positive_docs = batch['positive_docs']
        negative_docs = batch['negative_docs']

        # Per query: the positive document first, followed by its negatives
        all_docs = [[pos] + negs for pos, negs in zip(positive_docs, negative_docs)]

        # The optimizer only holds student parameters, so the teacher never
        # needs gradients; skip autograd bookkeeping for its forward pass.
        with torch.no_grad():
            teacher_scores = teacher.get_scores(queries, all_docs)

    # Forward pass with autocast (bfloat16) whenever a scaler is supplied
    with autocast(enabled=scaler is not None, dtype=torch.bfloat16):
        # Encode query
        query_outputs = model(
            input_ids=batch['query_input_ids'],
            attention_mask=batch['query_attention_mask'],
        )
        query_rep = query_outputs['sparse_rep']

        # Encode positive document
        pos_outputs = model(
            input_ids=batch['pos_doc_input_ids'],
            attention_mask=batch['pos_doc_attention_mask'],
        )
        pos_rep = pos_outputs['sparse_rep']

        # Encode negative documents: fold the negatives dimension into the
        # batch dimension for one forward pass, then unfold the result.
        # reshape (rather than view) also accepts non-contiguous tensors.
        batch_size, num_neg, seq_len = batch['neg_doc_input_ids'].shape
        neg_input_ids = batch['neg_doc_input_ids'].reshape(batch_size * num_neg, seq_len)
        neg_attention_mask = batch['neg_doc_attention_mask'].reshape(batch_size * num_neg, seq_len)

        neg_outputs = model(
            input_ids=neg_input_ids,
            attention_mask=neg_attention_mask,
        )
        neg_rep = neg_outputs['sparse_rep'].reshape(batch_size, num_neg, -1)

        # Compute loss
        losses = loss_fn(
            query_rep=query_rep,
            pos_doc_rep=pos_rep,
            neg_doc_reps=neg_rep,
            teacher_scores=teacher_scores,
        )

        total_loss = losses['total_loss']

    # Backward pass
    optimizer.zero_grad()

    if scaler is not None:
        scaler.scale(total_loss).backward()
        # Unscale first so clipping sees the true gradient magnitudes
        scaler.unscale_(optimizer)
    else:
        total_loss.backward()

    torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        CONFIG['finetuning']['max_grad_norm'],
    )

    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    # Return losses as floats; tolerate entries that are already plain numbers
    return {
        k: v.item() if isinstance(v, torch.Tensor) else float(v)
        for k, v in losses.items()
    }

print("Training step function defined")

In [None]:
# Training loop
print("Starting training...")

output_dir = Path(CONFIG['logging']['output_dir'])
output_dir.mkdir(parents=True, exist_ok=True)

# With no batches the inner `for` never advances global_step and the outer
# `while` would spin forever, so fail fast with a clear message instead.
if num_training_steps > 0 and len(train_dataloader) == 0:
    raise ValueError(
        "train_dataloader yields no batches; cannot run the requested training steps"
    )

global_step = 0
training_losses = []  # one loss dictionary per completed step

# The context manager closes the progress bar even if a step raises
with tqdm(total=num_training_steps, desc="Training") as pbar:
    # Re-iterate the dataloader until the step budget is used up
    while global_step < num_training_steps:
        for batch in train_dataloader:
            # Training step
            losses = train_step(
                batch=batch,
                model=model,
                loss_fn=loss_fn,
                optimizer=optimizer,
                scaler=scaler,
                teacher=teacher,
            )

            scheduler.step()
            global_step += 1

            # Log losses
            training_losses.append(losses)

            # Update progress bar
            pbar.update(1)
            pbar.set_postfix({
                'loss': f"{losses['total_loss']:.4f}",
                'rank': f"{losses['ranking_loss']:.4f}",
                'flops': f"{losses['flops_loss']:.6f}",
            })

            # Periodic logging: mean total loss over the most recent 100 steps
            if global_step % CONFIG['logging']['logging_steps'] == 0:
                avg_loss = np.mean([l['total_loss'] for l in training_losses[-100:]])
                print(f"\nStep {global_step}: Avg loss = {avg_loss:.4f}")

            # Save checkpoint
            if global_step % CONFIG['logging']['save_steps'] == 0:
                checkpoint_dir = output_dir / f"checkpoint-{global_step}"
                model.save_pretrained(str(checkpoint_dir))
                print(f"Checkpoint saved to {checkpoint_dir}")

            # Stop mid-epoch once the step budget is reached
            if global_step >= num_training_steps:
                break

print("Training completed!")

## 15. Save Final Model

In [None]:
# Write the trained model into a "final_model" folder inside the output directory
final_model_dir = output_dir.joinpath("final_model")
model.save_pretrained(str(final_model_dir))
print(f"Final model saved to {final_model_dir}")

## 16. Training Analysis and Visualization

In [None]:
# Plot training losses: three side-by-side panels
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Pull each tracked loss out of the per-step loss dictionaries
total_losses = [step['total_loss'] for step in training_losses]
ranking_losses = [step['ranking_loss'] for step in training_losses]
flops_losses = [step['flops_loss'] for step in training_losses]

# Each panel shows the raw curve plus a 100-step rolling mean
panels = (
    (axes[0], total_losses, 'Total Loss'),
    (axes[1], ranking_losses, 'Ranking Loss'),
    (axes[2], flops_losses, 'FLOPS Loss'),
)
for ax, values, title in panels:
    ax.plot(values, alpha=0.3, label='Raw')
    ax.plot(pd.Series(values).rolling(100).mean(), label='Smoothed (100)')
    ax.set_title(title)
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

# Save the figure into the output directory, then display it
plt.tight_layout()
plot_path = output_dir / 'training_losses.png'
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"Training plots saved to {output_dir / 'training_losses.png'}")

## 17. Model Inference and Testing

In [None]:
# Smoke-test the trained model on a few hand-written examples
model.eval()  # switch the model to evaluation mode

# Queries to score
test_queries = [
    "What is machine learning?",
    "How to train neural networks?",
    "Best practices for deep learning",
]

# Candidate documents scored against every query
test_documents = [
    "Machine learning is a subset of artificial intelligence.",
    "Neural networks are trained using backpropagation.",
    "Deep learning requires large datasets and GPU computing.",
]

print("Testing model inference...\n")

with torch.no_grad():
    # Encode both sides once up front
    query_reps = model.encode(test_queries, device=device)
    doc_reps = model.encode(test_documents, device=device)

    for i, query in enumerate(test_queries):
        print(f"Query: {query}")
        print("Similarities:")

        # Dot product of this query with every document representation
        query_rep = query_reps[i]
        similarities = (query_rep.unsqueeze(0) * doc_reps).sum(dim=-1)

        for doc, sim in zip(test_documents, similarities):
            print(f"  [{sim:.4f}] {doc}")

        # Ten highest-weighted terms of the query representation
        print("\nTop activated terms:")
        top_terms = model.get_top_k_terms(query_rep, k=10)
        for term, weight in top_terms:
            print(f"  {term:20s}: {weight:.4f}")
        print("\n" + "="*80 + "\n")

    # Sparsity statistics reported by the model for the document representations
    print("Sparsity Statistics:")
    stats = model.get_sparsity_stats(doc_reps)
    for key, value in stats.items():
        print(f"  {key:30s}: {value:.4f}")

## 18. Summary and Next Steps

### Training Summary

In [None]:
# Horizontal rule reused by both section banners
rule = "=" * 80

# Headline numbers from the last recorded training step
print(rule, "TRAINING SUMMARY", rule, sep="\n")
print(f"\nModel: {CONFIG['model']['base_model']}")
print(f"Total training steps: {global_step}")
print(f"Final loss: {training_losses[-1]['total_loss']:.4f}")
print(f"Final ranking loss: {training_losses[-1]['ranking_loss']:.4f}")
print(f"Final FLOPS loss: {training_losses[-1]['flops_loss']:.6f}")
print(f"\nModel saved to: {final_model_dir}")

print("\n" + rule)
print("NEXT STEPS")
print(rule)
print("""
1. Evaluate on BEIR benchmark datasets
2. Measure retrieval efficiency (FLOPS, latency)
3. Compare with baseline models (BM25, SPLADE-v3)
4. Deploy to OpenSearch for production testing
5. Fine-tune hyperparameters if needed
""")

# Store the configuration used for this run as JSON in the output directory
config_path = output_dir / "training_config.json"
with config_path.open('w') as f:
    json.dump(CONFIG, f, indent=2)
print(f"\nConfiguration saved to: {config_path}")