# v22.1 Training with SPLADELossV23

## Key Improvements over v22.0

1. **IDF-Aware FLOPS Penalty**: Preserves informative high-IDF tokens while penalizing stopwords
2. **Knowledge Distillation**: Dense teacher model guides sparse student learning
3. **Curriculum Learning**: 3 phases that anneal the InfoNCE temperature and decay both the KD weight and the learning rate, so the student gradually becomes independent of the teacher
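4. **Rebalanced Hyperparameters**: Loss weights retuned on expert recommendations (e.g. InfoNCE 2.5, positive activation 3.0, FLOPS 0.003; margin loss disabled as redundant with InfoNCE)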

## Curriculum Phases

| Phase | Epochs | Temperature | lambda_kd | Focus |
|-------|--------|-------------|-----------|-------|
| 1 | 1-7 | 0.08 | 1.5 | Foundation with teacher guidance |
| 2 | 8-14 | 0.05 | 1.0 | Balanced multi-type training |
| 3 | 15-20 | 0.04 | 0.5 | Hard negative refinement |

## Target Metrics

| Metric | v22.0 | v22.1 Target |
|--------|-------|-------------|
| Recall@1 | ~70% | 80%+ |
| MRR | ~0.75 | 0.85+ |

In [None]:
import sys
from pathlib import Path


def find_project_root() -> Path:
    """Return the nearest directory (cwd or an ancestor) that looks like the repo root.

    A directory qualifies when it contains either a ``pyproject.toml`` file
    or a ``src`` entry. The search starts at the current working directory
    and walks upward; if nothing matches, the grandparent of the current
    working directory is returned as a best guess.
    """
    cwd = Path.cwd()
    markers = ("pyproject.toml", "src")
    for candidate in (cwd, *cwd.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return Path.cwd().parent.parent


# Resolve the repository root and put it first on sys.path so that the
# `from src.model import ...` further down resolves against this checkout.
# Order matters: this must execute before that import.
PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

import os
# Turn off Hugging Face tokenizers' internal parallelism via its env var
# (presumably to avoid fork-related warnings with DataLoader workers — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    get_cosine_schedule_with_warmup,
)
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from tqdm.auto import tqdm
from collections import defaultdict
from datetime import datetime

# Import v23 loss functions
from src.model import (
    SPLADEDoc,
    SPLADEDocExpansion,
    SPLADELossV23,
    IDFAwareFLOPSLoss,
    DenseTeacherScorer,
)

# Seed Python, NumPy and PyTorch RNGs so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Disable cuDNN autotuning and request deterministic cuDNN algorithms so
    # repeated runs pick the same kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Report the runtime environment in the notebook output.
print("\n".join([
    f"Project root: {PROJECT_ROOT}",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
]))
if torch.cuda.is_available():
    print("\n".join([
        f"GPU: {torch.cuda.get_device_name(0)}",
        f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB",
    ]))

## 1. Configuration

In [None]:
@dataclass
class CurriculumPhaseV23:
    """One stage of the v22.1 curriculum.

    A phase covers an inclusive range of epochs (e.g. 1-7) and carries the
    per-phase overrides used while training in that range: the InfoNCE
    temperature, the knowledge-distillation weight, a learning-rate
    multiplier and the triplet file to train on.
    """
    name: str
    start_epoch: int  # first epoch of the phase (inclusive)
    end_epoch: int  # last epoch of the phase (inclusive)
    # Temperature for InfoNCE
    temperature: float
    # Knowledge distillation weight
    lambda_kd: float
    # Learning rate multiplier (presumably scales TrainingConfigV23.learning_rate)
    lr_multiplier: float
    # Triplet JSONL file name, resolved relative to TrainingConfigV23.data_dir
    data_file: str
    # Human-readable summary printed in logs
    description: str = ""


@dataclass
class TrainingConfigV23:
    """Training configuration for v22.1 with SPLADELossV23.

    Path fields (``data_dir``, ``output_dir``, ``idf_weights_path``) may be
    given as ``str`` or any ``os.PathLike``; ``__post_init__`` normalises
    them to ``pathlib.Path`` so downstream code can use ``/``, ``.exists()``
    and ``.mkdir()`` on them. If ``phases`` is empty, the default three-phase
    v22.1 curriculum (temperature annealing + KD weight decay) is installed.
    """
    # Model
    model_name: str = "skt/kobert-base-v1"
    max_length: int = 128  # tokenizer padding/truncation length
    use_expansion: bool = True
    expansion_mode: str = "mlm"

    # Training
    num_epochs: int = 20
    batch_size: int = 64
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0
    gradient_accumulation_steps: int = 1

    # Mixed precision
    mixed_precision: str = "bf16"  # "bf16", "fp16", or "no"

    # Loss weights (expert recommended)
    lambda_infonce: float = 2.5
    lambda_self: float = 1.0
    lambda_positive: float = 3.0
    lambda_margin: float = 0.0  # Disabled (redundant with InfoNCE)
    lambda_flops: float = 0.003
    lambda_min_act: float = 1.0
    lambda_kd: float = 1.0  # Base KD weight (adjusted per phase)

    # Loss hyperparameters
    temperature: float = 0.08  # initial InfoNCE temperature; phases override it
    margin: float = 0.3
    top_k: int = 5
    min_activation: float = 0.5
    kd_temperature: float = 2.0

    # IDF configuration
    use_idf_weighting: bool = True
    idf_alpha: float = 2.0

    # Knowledge distillation
    teacher_model_name: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
    kd_warmup_epochs: int = 0  # Start KD from epoch 1

    # Checkpointing
    save_every_n_epochs: int = 5
    keep_last_n_checkpoints: int = 3

    # Early stopping
    early_stopping_patience: int = 5
    early_stopping_min_delta: float = 0.001

    # Logging
    log_every_n_steps: int = 50

    # Curriculum phases
    phases: List[CurriculumPhaseV23] = field(default_factory=list)

    # Paths (None until supplied by the caller)
    data_dir: Optional[Path] = None
    output_dir: Optional[Path] = None
    idf_weights_path: Optional[Path] = None

    def __post_init__(self):
        # Normalise path-like inputs so callers may pass plain strings.
        for attr in ("data_dir", "output_dir", "idf_weights_path"):
            value = getattr(self, attr)
            if value is not None and not isinstance(value, Path):
                setattr(self, attr, Path(value))

        if not self.phases:
            # v22.1 Curriculum: Temperature annealing + KD weight decay
            self.phases = [
                CurriculumPhaseV23(
                    name="phase1_foundation",
                    start_epoch=1,
                    end_epoch=7,
                    temperature=0.08,
                    lambda_kd=1.5,  # Strong teacher guidance
                    lr_multiplier=1.0,
                    data_file="phase1_single_term_focus_triplets.jsonl",
                    description="Foundation with strong teacher guidance",
                ),
                CurriculumPhaseV23(
                    name="phase2_balanced",
                    start_epoch=8,
                    end_epoch=14,
                    temperature=0.05,
                    lambda_kd=1.0,  # Balanced teacher guidance
                    lr_multiplier=0.5,
                    data_file="phase2_balanced_triplets.jsonl",
                    description="Balanced multi-type training",
                ),
                CurriculumPhaseV23(
                    name="phase3_refinement",
                    start_epoch=15,
                    end_epoch=20,
                    temperature=0.04,
                    lambda_kd=0.5,  # Reduced teacher guidance
                    lr_multiplier=0.25,
                    data_file="phase3_full_triplets.jsonl",
                    description="Hard negative refinement with student independence",
                ),
            ]


# Build the run configuration pointing at the v22.1 data/output layout.
config = TrainingConfigV23(
    data_dir=PROJECT_ROOT / "data" / "v22.1",
    output_dir=PROJECT_ROOT / "outputs" / "v22.1",
    idf_weights_path=PROJECT_ROOT / "data" / "v22.1" / "idf_weights.pt",
)
config.output_dir.mkdir(parents=True, exist_ok=True)

# Echo the effective settings so the notebook output records this run.
print("\n".join([
    "v22.1 Configuration (SPLADELossV23):",
    f"  Model: {config.model_name}",
    f"  Max length: {config.max_length}",
    f"  Epochs: {config.num_epochs}",
    f"  Batch size: {config.batch_size}",
    f"  Learning rate: {config.learning_rate}",
    f"  Mixed precision: {config.mixed_precision}",
    "\nLoss Weights:",
    f"  lambda_infonce: {config.lambda_infonce}",
    f"  lambda_self: {config.lambda_self}",
    f"  lambda_positive: {config.lambda_positive}",
    f"  lambda_flops: {config.lambda_flops}",
    f"  lambda_min_act: {config.lambda_min_act}",
    f"  lambda_kd (base): {config.lambda_kd}",
    "\nIDF Configuration:",
    f"  use_idf_weighting: {config.use_idf_weighting}",
    f"  idf_alpha: {config.idf_alpha}",
    f"\nTeacher Model: {config.teacher_model_name}",
    "\nCurriculum Phases:",
]))
for phase in config.phases:
    print("\n".join([
        f"  {phase.name}: epochs {phase.start_epoch}-{phase.end_epoch}",
        f"    temp={phase.temperature}, lambda_kd={phase.lambda_kd}, lr_mult={phase.lr_multiplier}",
        f"    {phase.description}",
    ]))

## 2. Model Definition

In [None]:
class SPLADEModelV23(nn.Module):
    """
    SPLADE encoder for v22.1 that maps text to vocabulary-sized sparse vectors.

    With ``use_expansion=True`` and ``expansion_mode="mlm"`` the masked-LM
    head's logits are used, so any vocabulary token can be activated, not
    just tokens present in the input. Otherwise a small MLP scores each input
    token and only the input token ids receive weight.

    Args:
        model_name: Hugging Face checkpoint to load.
        use_expansion: Use the MLM head for vocabulary expansion.
        expansion_mode: Expansion strategy; only ``"mlm"`` enables expansion.
        dropout: Dropout inside the token-importance MLP (non-expansion mode).
    """

    def __init__(
        self,
        model_name: str = "skt/kobert-base-v1",
        use_expansion: bool = True,
        expansion_mode: str = "mlm",
        dropout: float = 0.1,
    ):
        super().__init__()
        self.use_expansion = use_expansion
        self.expansion_mode = expansion_mode

        if use_expansion and expansion_mode == "mlm":
            self.model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.config = self.model.config
        else:
            from transformers import AutoModel
            self.model = AutoModel.from_pretrained(model_name)
            self.config = self.model.config
            # Token importance predictor for non-expansion mode
            self.token_importance = nn.Sequential(
                nn.Linear(self.config.hidden_size, self.config.hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(self.config.hidden_size, 1),
            )

        self.relu = nn.ReLU()

        # Gradient checkpointing trades extra forward recomputation for lower
        # activation memory during backprop.
        if hasattr(self.model, "gradient_checkpointing_enable"):
            self.model.gradient_checkpointing_enable()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode a batch into sparse vocabulary vectors.

        Args:
            input_ids: Token ids, [batch_size, seq_len].
            attention_mask: 1 for real tokens, 0 for padding, [batch_size, seq_len].

        Returns:
            sparse_repr: [batch_size, vocab_size], non-negative term weights
                (max over sequence positions).
            token_weights: [batch_size, seq_len], per-position weight (max over
                the vocabulary in MLM mode, importance score otherwise); zero
                at padding positions.
        """
        if self.use_expansion and self.expansion_mode == "mlm":
            # MLM-based expansion: logits over full vocabulary
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            logits = outputs.logits  # [batch, seq_len, vocab_size]

            # SPLADE transformation: log(1 + ReLU(x))
            token_scores = torch.log1p(self.relu(logits))

            # Zero padding positions so they cannot win the max-pool below.
            mask = attention_mask.unsqueeze(-1).float()
            token_scores = token_scores * mask

            # Max pooling over sequence length
            sparse_repr, _ = token_scores.max(dim=1)

            # Token weights for analysis
            token_weights = token_scores.max(dim=-1).values
        else:
            # Standard SPLADE: only input tokens activated
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            hidden_states = outputs.last_hidden_state

            # Predict token importance, then zero out padding positions.
            importance_scores = self.token_importance(hidden_states).squeeze(-1)
            token_weights = torch.log1p(self.relu(importance_scores))
            token_weights = token_weights * attention_mask.float()

            # Scatter each position's weight into its token-id column, keeping
            # the per-column maximum. This equals max-pooling the
            # [batch, seq_len, vocab] one-hot product, but without allocating
            # that O(batch * seq_len * vocab) tensor. Weights are >= 0 and
            # padding is already zeroed, so a zero-initialised target with
            # include_self=True yields identical values.
            batch_size = input_ids.shape[0]
            sparse_repr = torch.zeros(
                batch_size,
                self.config.vocab_size,
                device=input_ids.device,
                dtype=token_weights.dtype,
            ).scatter_reduce(
                1, input_ids, token_weights, reduce="amax", include_self=True
            )

        return sparse_repr, token_weights

    def encode_documents(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode documents to sparse [batch_size, vocab_size] representations."""
        sparse_repr, _ = self.forward(input_ids, attention_mask)
        return sparse_repr


# Build the student encoder on the best available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SPLADEModelV23(
    model_name=config.model_name,
    use_expansion=config.use_expansion,
    expansion_mode=config.expansion_mode,
).to(device)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

# Summarise the loaded model.
print("\n".join([
    f"Model loaded: {config.model_name}",
    f"Device: {device}",
    f"Parameters: {sum(p.numel() for p in model.parameters()):,}",
    f"Vocab size: {model.config.vocab_size:,}",
    f"Hidden size: {model.config.hidden_size}",
    f"Use expansion: {config.use_expansion}",
]))

## 3. Load IDF Weights

In [None]:
def _read_idf_corpus(corpus_path: Path) -> List[str]:
    """Collect document texts from a JSONL corpus for IDF computation.

    A record with a ``text`` field contributes that text; otherwise a record
    with ``anchor`` contributes the anchor and, if present, its ``positive``.
    Blank lines are skipped.
    """
    corpus: List[str] = []
    with open(corpus_path, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # letting json.loads raise on them.
            if not line.strip():
                continue
            data = json.loads(line)
            if "text" in data:
                corpus.append(data["text"])
            elif "anchor" in data:
                corpus.append(data["anchor"])
                if "positive" in data:
                    corpus.append(data["positive"])
    return corpus


def load_or_create_idf_weights(
    idf_path: Path,
    vocab_size: int,
    tokenizer: AutoTokenizer,
    corpus_path: Optional[Path] = None,
) -> torch.Tensor:
    """
    Load pre-computed IDF weights, compute them from a corpus, or fall back to uniform.

    Resolution order:
    1. ``idf_path`` exists -> load it with ``torch.load`` (onto CPU).
    2. ``corpus_path`` exists -> compute IDF with
       ``IDFAwareFLOPSLoss.compute_idf_from_corpus`` (BM25 smoothing) and
       cache the result at ``idf_path`` (creating its directory if needed).
    3. Otherwise -> uniform weights (all ones).

    For cases 1 and 2, if the first dimension of the result is not
    ``vocab_size`` a warning is printed and uniform weights are returned.

    Args:
        idf_path: Path to saved IDF weights (also the cache location).
        vocab_size: Expected vocabulary size.
        tokenizer: Tokenizer for corpus processing.
        corpus_path: Optional JSONL corpus for IDF computation.

    Returns:
        IDF weights tensor [vocab_size].
    """
    if idf_path.exists():
        print(f"Loading IDF weights from {idf_path}")
        # NOTE(review): torch.load unpickles arbitrary objects; only point
        # idf_path at trusted files (consider weights_only=True on torch>=1.13).
        idf_weights = torch.load(idf_path, map_location="cpu")
    elif corpus_path is not None and corpus_path.exists():
        print(f"Computing IDF weights from corpus: {corpus_path}")
        corpus = _read_idf_corpus(corpus_path)
        print(f"Corpus size: {len(corpus):,} documents")

        idf_weights = IDFAwareFLOPSLoss.compute_idf_from_corpus(
            corpus=corpus,
            tokenizer=tokenizer,
            smoothing="bm25",
        )

        # Save for future use; the cache directory may not exist yet.
        idf_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(idf_weights, idf_path)
        print(f"Saved IDF weights to {idf_path}")
    else:
        print("No IDF weights found. Creating default uniform weights...")
        print("Note: For better results, pre-compute IDF weights from your corpus.")
        return torch.ones(vocab_size)

    # Same size policy for loaded and freshly computed weights.
    if idf_weights.shape[0] != vocab_size:
        print(f"Warning: IDF weights size mismatch. Expected {vocab_size}, got {idf_weights.shape[0]}")
        print("Creating default uniform weights...")
        idf_weights = torch.ones(vocab_size)

    return idf_weights


# Resolve IDF weights (or disable them) according to the config.
if not config.use_idf_weighting:
    idf_weights = None
    print("IDF weighting disabled.")
else:
    idf_weights = load_or_create_idf_weights(
        idf_path=config.idf_weights_path,
        vocab_size=model.config.vocab_size,
        tokenizer=tokenizer,
        corpus_path=(config.data_dir / "corpus.jsonl") if config.data_dir.exists() else None,
    ).to(device)

    # Summarise the IDF distribution as a sanity check.
    print("\n".join([
        "\nIDF Statistics:",
        f"  Min: {idf_weights.min().item():.4f}",
        f"  Max: {idf_weights.max().item():.4f}",
        f"  Mean: {idf_weights.mean().item():.4f}",
        f"  Std: {idf_weights.std().item():.4f}",
    ]))

## 4. Initialize Dense Teacher for Knowledge Distillation

In [None]:
# Dense teacher whose scores supervise the sparse student via KD.
print(f"Initializing dense teacher model: {config.teacher_model_name}")
teacher_model = DenseTeacherScorer(model_name=config.teacher_model_name, device=str(device))

# Smoke test: score the first two texts against all three.
test_texts = ["machine learning", "deep learning", "natural language processing"]
with torch.no_grad():
    test_scores = teacher_model.compute_scores(test_texts[:2], test_texts)
print("\nTeacher model test (similarity matrix):")
print(test_scores)

print("\nDense teacher model ready.")

## 5. Initialize SPLADELossV23

In [None]:
# Combined SPLADE objective: weighted loss terms + IDF-aware FLOPS + KD.
criterion = SPLADELossV23(
    # Teacher model and IDF plumbing
    teacher_model=teacher_model,
    vocab_size=model.config.vocab_size,
    idf_weights=idf_weights,
    idf_alpha=config.idf_alpha,
    # Per-term loss weights
    lambda_infonce=config.lambda_infonce,
    lambda_self=config.lambda_self,
    lambda_positive=config.lambda_positive,
    lambda_margin=config.lambda_margin,
    lambda_flops=config.lambda_flops,
    lambda_min_act=config.lambda_min_act,
    lambda_kd=config.lambda_kd,
    # Term-specific hyperparameters
    temperature=config.temperature,
    margin=config.margin,
    top_k=config.top_k,
    min_activation=config.min_activation,
    kd_temperature=config.kd_temperature,
)

print("\n".join([
    "SPLADELossV23 initialized with:",
    f"  InfoNCE (lambda={config.lambda_infonce}, temp={config.temperature})",
    f"  Self-reconstruction (lambda={config.lambda_self})",
    f"  Positive activation (lambda={config.lambda_positive})",
    f"  Triplet margin (lambda={config.lambda_margin}, margin={config.margin})",
    f"  IDF-aware FLOPS (lambda={config.lambda_flops}, alpha={config.idf_alpha})",
    f"  Minimum activation (lambda={config.lambda_min_act})",
    f"  Knowledge distillation (lambda={config.lambda_kd}, temp={config.kd_temperature})",
]))

## 6. Dataset

In [None]:
class TripletDatasetV23(Dataset):
    """
    Triplet dataset that also keeps the raw strings for knowledge distillation.

    Reads a JSONL file in which each non-blank line is an object with
    ``anchor``, ``positive`` and ``negative`` text fields. Items are
    tokenized on access into fixed-length (``max_length``) tensors.
    """

    def __init__(
        self,
        data_path: Path,
        tokenizer: AutoTokenizer,
        max_length: int = 128,
        return_raw_text: bool = True,
    ):
        """
        Args:
            data_path: JSONL triplet file (plain ``str`` paths are accepted too).
            tokenizer: Tokenizer applied to every text.
            max_length: Padding/truncation length for every text.
            return_raw_text: Also return the untokenized strings (used for KD).
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.return_raw_text = return_raw_text
        self.data = []

        data_path = Path(data_path)
        if not data_path.exists():
            print(f"Warning: Data file not found: {data_path}")
            return

        with open(data_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
            self.data = [json.loads(line) for line in f if line.strip()]

        print(f"Loaded {len(self.data):,} triplets from {data_path.name}")

    def __len__(self) -> int:
        return len(self.data)

    def _encode(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize one text into 1-D ``input_ids`` and ``attention_mask`` tensors."""
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dim of 1; drop it.
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.data[idx]
        roles = ("anchor", "positive", "negative")
        texts = {role: item[role] for role in roles}

        result: Dict[str, Any] = {}
        for role in roles:
            input_ids, attention_mask = self._encode(texts[role])
            result[f"{role}_input_ids"] = input_ids
            result[f"{role}_attention_mask"] = attention_mask

        if self.return_raw_text:
            for role in roles:
                result[f"{role}_text"] = texts[role]

        return result


def collate_fn_with_text(batch: List[Dict]) -> Dict[str, Any]:
    """Collate triplet samples: stack token tensors, keep texts as plain lists.

    The six ``*_input_ids`` / ``*_attention_mask`` entries are stacked along a
    new leading batch dimension. Each ``*_text`` entry is gathered into a list
    only if the first sample carries it.
    """
    tensor_fields = (
        "anchor_input_ids", "anchor_attention_mask",
        "positive_input_ids", "positive_attention_mask",
        "negative_input_ids", "negative_attention_mask",
    )
    collated = {
        name: torch.stack([sample[name] for sample in batch])
        for name in tensor_fields
    }

    first_sample = batch[0]
    for name in ("anchor_text", "positive_text", "negative_text"):
        if name not in first_sample:
            continue
        collated[name] = [sample[name] for sample in batch]

    return collated


# Validation triplets are optional; without them evaluation is skipped.
val_data_path = config.data_dir / "validation_triplets.jsonl"
if not val_data_path.exists():
    val_loader = None
    print("Warning: Validation data not found. Skipping validation.")
else:
    val_dataset = TripletDatasetV23(
        data_path=val_data_path,
        tokenizer=tokenizer,
        max_length=config.max_length,
        return_raw_text=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate_fn_with_text,
    )
    print(f"Validation loader: {len(val_loader)} batches")

## 7. Training Functions

In [None]:
def train_epoch_v23(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    criterion: SPLADELossV23,
    config: TrainingConfigV23,
    phase: CurriculumPhaseV23,
    device: torch.device,
    epoch: int,
    scaler: Optional[GradScaler] = None,
) -> Dict[str, float]:
    """
    Train for one epoch with SPLADELossV23.

    The phase's temperature and KD weight are pushed into ``criterion`` first.
    With ``config.gradient_accumulation_steps = N`` every micro-batch loss is
    divided by N before backward, so each optimizer step applies the mean
    gradient of N micro-batches. A trailing partial group at the end of the
    epoch is also stepped rather than carried into the next epoch.

    Args:
        model: SPLADE model
        dataloader: Training data loader
        optimizer: Optimizer
        scheduler: Learning rate scheduler (stepped once per optimizer step)
        criterion: SPLADELossV23 instance
        config: Training configuration
        phase: Current curriculum phase
        device: Device to use
        epoch: Current epoch number
        scaler: Gradient scaler for fp16 mixed precision (None otherwise)

    Returns:
        Per-epoch mean of every entry in the criterion's loss dict (not
        scaled by accumulation), including ``"total"``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    # Start from clean gradients regardless of what ran before this call.
    optimizer.zero_grad()

    total_loss = 0.0
    loss_components = defaultdict(float)
    num_batches = 0

    # Update criterion with phase-specific settings
    criterion.update_temperature(phase.temperature)
    criterion.update_weights(lambda_kd=phase.lambda_kd)

    # Determine dtype for mixed precision
    use_amp = config.mixed_precision in ["bf16", "fp16"]
    amp_dtype = torch.bfloat16 if config.mixed_precision == "bf16" else torch.float16

    # Guard against a 0 setting, which would otherwise divide by zero.
    accum_steps = max(1, config.gradient_accumulation_steps)
    # NOTE(review): flushing on the last batch means each epoch performs
    # ceil(len(dataloader) / accum_steps) optimizer/scheduler steps; size the
    # scheduler's total step count accordingly.
    last_step = len(dataloader) - 1

    pbar = tqdm(dataloader, desc=f"Epoch {epoch} ({phase.name})")

    for step, batch in enumerate(pbar):
        # Move tensors to device
        batch_tensors = {
            k: v.to(device) for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }

        # Raw texts feed the dense teacher for KD
        anchor_texts = batch.get("anchor_text")
        positive_texts = batch.get("positive_text")

        # Forward pass with mixed precision
        with autocast(enabled=use_amp, dtype=amp_dtype):
            anchor_repr, _ = model(
                batch_tensors["anchor_input_ids"],
                batch_tensors["anchor_attention_mask"],
            )
            positive_repr, _ = model(
                batch_tensors["positive_input_ids"],
                batch_tensors["positive_attention_mask"],
            )
            negative_repr, _ = model(
                batch_tensors["negative_input_ids"],
                batch_tensors["negative_attention_mask"],
            )

            loss, loss_dict = criterion(
                anchor_repr=anchor_repr,
                positive_repr=positive_repr,
                negative_repr=negative_repr,
                anchor_input_ids=batch_tensors["anchor_input_ids"],
                anchor_attention_mask=batch_tensors["anchor_attention_mask"],
                positive_input_ids=batch_tensors["positive_input_ids"],
                positive_attention_mask=batch_tensors["positive_attention_mask"],
                anchor_texts=anchor_texts,
                positive_texts=positive_texts,
            )

        # Average (rather than sum) gradients across accumulated micro-batches.
        scaled_loss = loss / accum_steps
        is_update_step = (step + 1) % accum_steps == 0 or step == last_step

        # Backward pass
        if scaler is not None:
            scaler.scale(scaled_loss).backward()
            if is_update_step:
                # Unscale first so clipping sees true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
        else:
            scaled_loss.backward()
            if is_update_step:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

        # Track (unscaled) losses
        total_loss += loss_dict["total"]
        for key, value in loss_dict.items():
            loss_components[key] += value
        num_batches += 1

        # Update progress bar
        if step % config.log_every_n_steps == 0:
            pbar.set_postfix({
                "loss": f"{loss_dict['total']:.4f}",
                "infonce": f"{loss_dict['infonce']:.4f}",
                "kd": f"{loss_dict['kd']:.4f}",
                "lr": f"{scheduler.get_last_lr()[0]:.2e}",
            })

    if num_batches == 0:
        raise ValueError(
            f"No training batches for {phase.name}: the dataloader is empty "
            "(e.g. dataset smaller than batch_size with drop_last=True)."
        )

    # Compute averages
    avg_losses = {k: v / num_batches for k, v in loss_components.items()}
    avg_losses["total"] = total_loss / num_batches

    return avg_losses


@torch.no_grad()
def evaluate_v23(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: SPLADELossV23,
    device: torch.device,
) -> Dict[str, float]:
    """
    Evaluate the model on triplet validation data.

    Returns the mean of every loss-dict entry plus two pairwise ranking
    metrics. Each anchor has exactly one positive and one negative, ranked
    by cosine similarity of their sparse vectors:

    * ``recall``: percentage of anchors whose positive strictly outranks the
      negative.
    * ``mrr``: mean reciprocal rank of the positive, i.e. 1.0 when it
      strictly wins and 0.5 otherwise (ties count as losses).

    Args:
        model: SPLADE model
        dataloader: Validation data loader
        criterion: SPLADELossV23 instance
        device: Device to use

    Returns:
        Dictionary of averaged losses plus ``recall`` and ``mrr``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()

    total_loss = 0.0
    loss_components = defaultdict(float)
    all_recalls = []
    all_mrrs = []
    num_batches = 0

    for batch in tqdm(dataloader, desc="Evaluating"):
        # Move tensors to device
        batch_tensors = {
            k: v.to(device) for k, v in batch.items()
            if isinstance(v, torch.Tensor)
        }

        anchor_texts = batch.get("anchor_text")
        positive_texts = batch.get("positive_text")

        # Forward pass
        anchor_repr, _ = model(
            batch_tensors["anchor_input_ids"],
            batch_tensors["anchor_attention_mask"],
        )
        positive_repr, _ = model(
            batch_tensors["positive_input_ids"],
            batch_tensors["positive_attention_mask"],
        )
        negative_repr, _ = model(
            batch_tensors["negative_input_ids"],
            batch_tensors["negative_attention_mask"],
        )

        # Only the loss breakdown is needed here, not the graph-bearing total.
        _, loss_dict = criterion(
            anchor_repr=anchor_repr,
            positive_repr=positive_repr,
            negative_repr=negative_repr,
            anchor_input_ids=batch_tensors["anchor_input_ids"],
            anchor_attention_mask=batch_tensors["anchor_attention_mask"],
            positive_input_ids=batch_tensors["positive_input_ids"],
            positive_attention_mask=batch_tensors["positive_attention_mask"],
            anchor_texts=anchor_texts,
            positive_texts=positive_texts,
        )

        total_loss += loss_dict["total"]
        for key, value in loss_dict.items():
            loss_components[key] += value
        num_batches += 1

        # NOTE(review): ranking uses cosine similarity, which may differ from
        # the scoring used inside SPLADELossV23 — confirm this is intended.
        pos_sim = F.cosine_similarity(anchor_repr, positive_repr, dim=-1)
        neg_sim = F.cosine_similarity(anchor_repr, negative_repr, dim=-1)
        wins = (pos_sim > neg_sim).float()

        # Recall: positive should rank strictly higher than negative
        all_recalls.extend(wins.cpu().tolist())
        # Reciprocal rank, vectorized: 1.0 if the positive ranks first, else
        # 0.5. Avoids a per-element device->host sync in a Python loop.
        all_mrrs.extend((0.5 + 0.5 * wins).cpu().tolist())

    if num_batches == 0:
        raise ValueError("No validation batches: the evaluation dataloader is empty.")

    # Compute averages
    avg_losses = {k: v / num_batches for k, v in loss_components.items()}
    avg_losses["total"] = total_loss / num_batches

    return {
        **avg_losses,
        "recall": np.mean(all_recalls) * 100,
        "mrr": np.mean(all_mrrs),
    }

## 8. Curriculum Training Loop

In [None]:
class EarlyStopping:
    """Signal a stop once the validation loss stops improving.

    A new loss counts as an improvement unless it is greater than the best
    loss so far minus ``min_delta``. After ``patience`` consecutive
    non-improving calls, ``should_stop`` becomes True and stays True.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0
        self.should_stop = False

    def __call__(self, val_loss: float) -> bool:
        # The first observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return self.should_stop

        threshold = self.best_loss - self.min_delta
        if val_loss > threshold:
            # Not a meaningful improvement: spend one unit of patience.
            self.counter += 1
            self.should_stop = self.should_stop or self.counter >= self.patience
        else:
            self.best_loss = val_loss
            self.counter = 0
        return self.should_stop


def run_curriculum_training_v23(
    model: nn.Module,
    config: TrainingConfigV23,
    criterion: SPLADELossV23,
    device: torch.device,
) -> List[Dict[str, Any]]:
    """
    Run full curriculum training with SPLADELossV23.

    Phases in ``config.phases`` are trained in order. Each phase:
    - loads its own training file;
    - builds a fresh AdamW optimizer, with the learning rate scaled by
      ``phase.lr_multiplier``;
    - uses a cosine-with-warmup schedule spanning the phase;
    - runs early stopping, which is reset between phases.

    Phases whose data file is missing or empty are skipped.

    This function reads the notebook-level globals ``tokenizer`` and
    ``val_loader``. ``val_loader`` may be ``None`` to skip validation.

    Args:
        model: SPLADE model
        config: Training configuration
        criterion: SPLADELossV23 instance
        device: Device to use

    Returns:
        Training history: one dict per trained epoch with the keys "epoch",
        "phase", "temperature", "lambda_kd", "lr", "train" and "eval".
    """
    training_history = []
    best_recall = 0.0

    # All checkpoints below are written directly into output_dir.
    config.output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize gradient scaler for mixed precision (only fp16 needs loss scaling)
    use_amp = config.mixed_precision in ["bf16", "fp16"]
    scaler = GradScaler() if use_amp and config.mixed_precision == "fp16" else None

    def _new_early_stopping() -> EarlyStopping:
        """Build a fresh early-stopping tracker from the config settings."""
        return EarlyStopping(
            patience=config.early_stopping_patience,
            min_delta=config.early_stopping_min_delta,
        )

    early_stopping = _new_early_stopping()

    for phase in config.phases:
        print(f"\n{'=' * 70}")
        print(f"Starting {phase.name}: Epochs {phase.start_epoch}-{phase.end_epoch}")
        print(f"Temperature: {phase.temperature}, lambda_kd: {phase.lambda_kd}")
        print(f"LR multiplier: {phase.lr_multiplier}")
        print(f"Description: {phase.description}")
        print(f"{'=' * 70}")

        # Load phase-specific data
        train_data_path = config.data_dir / phase.data_file
        if not train_data_path.exists():
            print(f"Warning: Training data not found: {train_data_path}")
            print("Skipping this phase...")
            continue

        train_dataset = TripletDatasetV23(
            train_data_path,
            tokenizer,
            config.max_length,
            return_raw_text=True,
        )

        if len(train_dataset) == 0:
            print(f"Warning: Empty training dataset for {phase.name}")
            continue

        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            drop_last=True,
            collate_fn=collate_fn_with_text,
            num_workers=4,
            pin_memory=True,
        )

        # Setup optimizer and scheduler for this phase
        phase_epochs = phase.end_epoch - phase.start_epoch + 1
        total_steps = len(train_loader) * phase_epochs
        warmup_steps = int(total_steps * config.warmup_ratio)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate * phase.lr_multiplier,
            weight_decay=config.weight_decay,
        )
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

        print("\nPhase training configuration:")
        print(f"  Training samples: {len(train_dataset):,}")
        print(f"  Batches per epoch: {len(train_loader):,}")
        print(f"  Total steps: {total_steps:,}")
        print(f"  Warmup steps: {warmup_steps:,}")
        print(f"  Learning rate: {config.learning_rate * phase.lr_multiplier:.2e}")

        # Last epoch actually trained in this phase. Early stopping can end the
        # phase before phase.end_epoch, and the phase checkpoint must say so.
        last_epoch = phase.start_epoch - 1

        # Train for phase epochs
        for epoch in range(phase.start_epoch, phase.end_epoch + 1):
            # Training
            train_metrics = train_epoch_v23(
                model=model,
                dataloader=train_loader,
                optimizer=optimizer,
                scheduler=scheduler,
                criterion=criterion,
                config=config,
                phase=phase,
                device=device,
                epoch=epoch,
                scaler=scaler,
            )
            last_epoch = epoch

            # Evaluation
            if val_loader is not None:
                eval_metrics = evaluate_v23(model, val_loader, criterion, device)
            else:
                eval_metrics = {"recall": 0.0, "mrr": 0.0, "val_total": 0.0}

            # Log results
            print(f"\nEpoch {epoch}:")
            print(f"  Train - Loss: {train_metrics['total']:.4f}, "
                  f"InfoNCE: {train_metrics['infonce']:.4f}, "
                  f"KD: {train_metrics['kd']:.4f}")
            print(f"  Train - Self: {train_metrics['self']:.4f}, "
                  f"Positive: {train_metrics['positive']:.4f}, "
                  f"FLOPS: {train_metrics['flops']:.4f}")
            if val_loader is not None:
                print(f"  Val   - Loss: {eval_metrics['total']:.4f}, "
                      f"Recall: {eval_metrics['recall']:.1f}%, "
                      f"MRR: {eval_metrics['mrr']:.4f}")

            # Save history
            history_entry = {
                "epoch": epoch,
                "phase": phase.name,
                "temperature": phase.temperature,
                "lambda_kd": phase.lambda_kd,
                "lr": scheduler.get_last_lr()[0],
                "train": train_metrics,
                "eval": eval_metrics,
            }
            training_history.append(history_entry)

            # Save best model.
            # NOTE: without val_loader, recall stays 0.0, so this strict ">"
            # never fires and best_model.pt is not written.
            current_recall = eval_metrics.get("recall", 0.0)
            if current_recall > best_recall:
                best_recall = current_recall
                torch.save({
                    "epoch": epoch,
                    "phase": phase.name,
                    "model_state_dict": model.state_dict(),
                    "eval_results": eval_metrics,
                    "config": {
                        "model_name": config.model_name,
                        "max_length": config.max_length,
                        "vocab_size": model.config.vocab_size,
                        "use_expansion": config.use_expansion,
                        "version": "v22.1",
                    },
                }, config.output_dir / "best_model.pt")
                print(f"  -> New best model saved! Recall: {best_recall:.1f}%")

            # Save periodic checkpoint (a non-positive interval disables it)
            if config.save_every_n_epochs > 0 and epoch % config.save_every_n_epochs == 0:
                checkpoint_path = config.output_dir / f"checkpoint_epoch_{epoch}.pt"
                torch.save({
                    "epoch": epoch,
                    "phase": phase.name,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "train_metrics": train_metrics,
                    "eval_metrics": eval_metrics,
                }, checkpoint_path)
                print(f"  -> Checkpoint saved: {checkpoint_path.name}")

            # Early stopping check. Without validation, eval_metrics has no
            # "total" key, so the training loss drives early stopping.
            val_loss = eval_metrics.get("total", train_metrics["total"])
            if early_stopping(val_loss):
                print(f"\nEarly stopping triggered at epoch {epoch}")
                break

        # Save phase checkpoint, recording the epoch training actually reached
        phase_checkpoint_path = config.output_dir / f"{phase.name}_checkpoint.pt"
        torch.save({
            "epoch": last_epoch,
            "phase": phase.name,
            "model_state_dict": model.state_dict(),
        }, phase_checkpoint_path)
        print(f"\n{phase.name} checkpoint saved: {phase_checkpoint_path.name}")

        # Reset early stopping for next phase
        early_stopping = _new_early_stopping()

    return training_history

## 9. Run Training

In [None]:
# Print training summary before starting
print("v22.1 Training Summary")
print("=" * 70)
print(f"Model: {config.model_name}")
print(f"Output directory: {config.output_dir}")
print(f"Mixed precision: {config.mixed_precision}")
print("\nData distribution:")
for phase in config.phases:
    data_path = config.data_dir / phase.data_file
    if data_path.exists():
        # Read as UTF-8 explicitly: the data contains Korean text, and the
        # platform default encoding is not guaranteed to decode it.
        # Counts non-blank lines, assuming one triplet per line (TODO confirm
        # against TripletDatasetV23's file format).
        with open(data_path, encoding="utf-8") as f:
            count = sum(1 for line in f if line.strip())
        print(f"  {phase.name}: {count:,} triplets")
    else:
        print(f"  {phase.name}: FILE NOT FOUND")

print("\nStarting training...")
print("=" * 70)

# Run training
start_time = datetime.now()
history = run_curriculum_training_v23(model, config, criterion, device)
end_time = datetime.now()

print("\n" + "=" * 70)
print("Training Complete!")
print(f"Total time: {end_time - start_time}")
print("=" * 70)

## 10. Training Summary and Visualization

In [None]:
import matplotlib.pyplot as plt

if len(history) > 0:
    # Extract data for plotting
    epochs = [h["epoch"] for h in history]
    train_losses = [h["train"]["total"] for h in history]
    infonce_losses = [h["train"]["infonce"] for h in history]
    kd_losses = [h["train"]["kd"] for h in history]
    flops_losses = [h["train"]["flops"] for h in history]
    recalls = [h["eval"]["recall"] for h in history]
    mrrs = [h["eval"]["mrr"] for h in history]
    temperatures = [h["temperature"] for h in history]
    lambda_kds = [h["lambda_kd"] for h in history]

    # Phase boundaries sit halfway between the last epoch of one phase and the
    # first epoch of the next. They are derived from the configured schedule
    # rather than hard-coded, so they stay correct if the curriculum changes.
    phase_boundaries = [p.end_epoch + 0.5 for p in list(config.phases)[:-1]]

    # One entry per subplot:
    # (row, col, values, matplotlib fmt, extra plot kwargs, y label, title)
    circle_marker = {"marker": 'o', "markersize": 4}
    square_marker = {"marker": 's', "markersize": 4}
    panels = [
        (0, 0, train_losses, 'b-', {}, "Total Loss", "Total Training Loss"),
        (0, 1, infonce_losses, 'purple', {}, "InfoNCE Loss", "InfoNCE Contrastive Loss"),
        (0, 2, kd_losses, 'orange', {}, "KD Loss", "Knowledge Distillation Loss"),
        (0, 3, recalls, 'g-', circle_marker, "Recall (%)", "Validation Recall"),
        (1, 0, mrrs, 'c-', circle_marker, "MRR", "Validation MRR"),
        (1, 1, flops_losses, 'brown', {}, "FLOPS Loss", "IDF-Aware FLOPS Loss"),
        (1, 2, temperatures, 'red', square_marker, "Temperature", "Temperature Annealing"),
        (1, 3, lambda_kds, 'magenta', square_marker, "lambda_kd", "Knowledge Distillation Weight"),
    ]

    # Create figure with subplots
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    for row, col, values, fmt, extra, ylabel, title in panels:
        ax = axes[row, col]
        ax.plot(epochs, values, fmt, linewidth=2, **extra)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        for boundary in phase_boundaries:
            ax.axvline(x=boundary, color='r', linestyle='--', alpha=0.5)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(config.output_dir / "training_curves_v22.1.png", dpi=150)
    plt.show()

    # Print final summary
    print("\nv22.1 Final Results:")
    print(f"  Best Recall: {max(recalls):.1f}%")
    print(f"  Best MRR: {max(mrrs):.4f}")
    print(f"  Final Loss: {train_losses[-1]:.4f}")
    print(f"  Final InfoNCE: {infonce_losses[-1]:.4f}")
    print(f"  Final KD Loss: {kd_losses[-1]:.4f}")
else:
    print("No training history to visualize.")

## 11. Save Final Model and Artifacts

In [None]:
# Persist the final model checkpoint, the tokenizer, the training history and
# (when IDF weighting is enabled) the IDF weights. Nothing is written if
# training produced no history.
if not history:
    print("No training completed. Skipping model save.")
else:
    model_state = model.state_dict()
    model_config = {
        "model_name": config.model_name,
        "max_length": config.max_length,
        "vocab_size": model.config.vocab_size,
        "hidden_size": model.config.hidden_size,
        "use_expansion": config.use_expansion,
        "expansion_mode": config.expansion_mode,
        "version": "v22.1",
    }
    loss_config = {
        "lambda_infonce": config.lambda_infonce,
        "lambda_self": config.lambda_self,
        "lambda_positive": config.lambda_positive,
        "lambda_margin": config.lambda_margin,
        "lambda_flops": config.lambda_flops,
        "lambda_min_act": config.lambda_min_act,
        "lambda_kd": config.lambda_kd,
        "use_idf_weighting": config.use_idf_weighting,
    }
    eval_history = [entry["eval"] for entry in history]
    last_eval = eval_history[-1]
    training_info = {
        "total_epochs": len(history),
        "final_recall": last_eval["recall"],
        "final_mrr": last_eval["mrr"],
        "best_recall": max(ev["recall"] for ev in eval_history),
        "best_mrr": max(ev["mrr"] for ev in eval_history),
    }
    final_checkpoint = {
        "model_state_dict": model_state,
        "config": model_config,
        "loss_config": loss_config,
        "training_info": training_info,
    }

    checkpoint_path = config.output_dir / "checkpoint.pt"
    torch.save(final_checkpoint, checkpoint_path)
    print(f"Final checkpoint saved: {checkpoint_path}")

    # Tokenizer goes into its own sub-directory next to the checkpoint.
    tokenizer_path = config.output_dir / "tokenizer"
    tokenizer.save_pretrained(tokenizer_path)
    print(f"Tokenizer saved: {tokenizer_path}")

    # Per-epoch metrics as human-readable JSON (keeps non-ASCII text as-is).
    history_path = config.output_dir / "training_history.json"
    with open(history_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2, ensure_ascii=False)
    print(f"Training history saved: {history_path}")

    # IDF weights are only meaningful when IDF weighting was enabled.
    if config.use_idf_weighting and idf_weights is not None:
        idf_path = config.output_dir / "idf_weights.pt"
        torch.save(idf_weights.cpu(), idf_path)
        print(f"IDF weights saved: {idf_path}")

## 12. Upload to S3 (Optional)

In [None]:
# Optional: Upload to S3
UPLOAD_TO_S3 = False  # Set to True to enable S3 upload
S3_BUCKET = "your-bucket-name"
S3_PREFIX = "models/opensearch-neural-sparse/v22.1"
S3_REGION = "us-east-1"


def s3_key_for(file_path: Path, root_dir: Path, prefix: str) -> str:
    """Return the S3 object key for ``file_path`` under ``prefix``.

    The path relative to ``root_dir`` is preserved, so e.g. tokenizer files
    land under ``<prefix>/tokenizer/`` and can be downloaded back into a
    directory that ``from_pretrained`` can load. Files outside ``root_dir``
    fall back to their bare file name.
    """
    try:
        relative = file_path.relative_to(root_dir).as_posix()
    except ValueError:
        relative = file_path.name
    return f"{prefix.rstrip('/')}/{relative}"


if UPLOAD_TO_S3:
    import boto3
    from boto3.exceptions import S3UploadFailedError
    from botocore.exceptions import BotoCoreError, ClientError

    s3_client = boto3.client("s3", region_name=S3_REGION)

    files_to_upload = [
        config.output_dir / "checkpoint.pt",
        config.output_dir / "best_model.pt",
        config.output_dir / "training_history.json",
        config.output_dir / "training_curves_v22.1.png",
    ]

    # Add tokenizer files (regular files only; upload_file cannot send a directory)
    tokenizer_dir = config.output_dir / "tokenizer"
    if tokenizer_dir.exists():
        files_to_upload.extend(
            sorted(path for path in tokenizer_dir.iterdir() if path.is_file())
        )

    print(f"Uploading to s3://{S3_BUCKET}/{S3_PREFIX}/")

    failed_uploads = []
    for file_path in files_to_upload:
        if not file_path.exists():
            print(f"  Skipped (not found): {file_path.name}")
            continue
        s3_key = s3_key_for(file_path, config.output_dir, S3_PREFIX)
        try:
            s3_client.upload_file(str(file_path), S3_BUCKET, s3_key)
            print(f"  Uploaded: {s3_key}")
        # upload_file re-raises ClientError as S3UploadFailedError, and
        # credential/network problems surface as BotoCoreError.
        except (S3UploadFailedError, ClientError, BotoCoreError) as e:
            failed_uploads.append(file_path.name)
            print(f"  Failed to upload {file_path.name}: {e}")

    if failed_uploads:
        print(f"\nS3 upload finished with {len(failed_uploads)} failure(s): "
              f"{', '.join(failed_uploads)}")
    else:
        print("\nS3 upload complete.")
else:
    print("S3 upload disabled. Set UPLOAD_TO_S3 = True to enable.")

## Next Steps

1. **Run Benchmark**: Test the trained model on benchmark datasets
   ```bash
   python benchmark/run_benchmark.py --model-path outputs/v22.1/checkpoint.pt
   ```

2. **Evaluate on Problem Terms**: Re-test the specific Korean terms that retrieved poorly with v22.0

3. **Compare with v22.0**: Analyze improvements from IDF-aware FLOPS and knowledge distillation

4. **Index to OpenSearch**: Deploy the model to OpenSearch cluster for production testing