# v21.4 Training with Curriculum Learning

## Key Improvements over v21.3

1. **Dynamic Lambda Self**: 8.0 for single terms → 4.0 for sentences
2. **Minimum Activation Loss**: Prevents garbage outputs
3. **Curriculum Learning**: 3 phases (single-term focus → balanced → full)
4. **Adjusted FLOPS weight**: `lambda_flops` is raised phase by phase (3e-3 → 4e-3 → 5e-3)

## Target Metrics

| Metric | v21.3 | v21.4 Target |
|--------|-------|--------------|
| Single-term Recall | 63.1% | 80%+ |
| Garbage Ratio | ~15% | < 5% |
| Sparsity | 95.55% | 95%+ |

In [1]:
import sys
from pathlib import Path

def find_project_root():
    """Return the nearest ancestor of the working directory that looks like the project root.

    The working directory itself and then each of its parents are checked,
    closest first; the first one containing either a ``pyproject.toml`` or a
    ``src`` entry is returned. If none matches, the grandparent of the working
    directory is returned as a fallback.
    """
    start = Path.cwd()
    markers = ("pyproject.toml", "src")
    for candidate in (start, *start.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return Path.cwd().parent.parent


# Resolve the root once and make project-local packages importable.
PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM, get_linear_schedule_with_warmup
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from tqdm.auto import tqdm
from collections import defaultdict

# Set seeds
# Seed Python, NumPy and PyTorch with the same fixed value (42).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    # Also seed every visible CUDA device.
    torch.cuda.manual_seed_all(42)

# Report the resolved project root and the runtime environment.
print(f"Project root: {PROJECT_ROOT}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Project root: /home/west/Documents/cursor-workspace/opensearch-neural-pre-train
PyTorch version: 2.10.0.dev20251109+cu130
CUDA available: True
GPU: NVIDIA GB10


    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    
  queued_call()


## 1. Configuration

In [2]:
@dataclass
class CurriculumPhase:
    """Configuration for a curriculum learning phase."""
    # Label used in logs, progress bars and the per-phase checkpoint filename.
    name: str
    # First and last epoch of the phase; the training loop treats both as inclusive.
    start_epoch: int
    end_epoch: int
    # Weight applied to the FLOPS (sparsity) loss during this phase.
    lambda_flops: float
    # Weight applied to the minimum-activation loss during this phase.
    lambda_min_activation: float
    # Factor multiplied into the base learning rate for this phase's optimizer.
    lr_multiplier: float
    # Training-data filename, resolved relative to the configured data directory.
    data_file: str


@dataclass
class TrainingConfig:
    """Hyperparameters, curriculum schedule and paths for a training run.

    When ``phases`` is left empty, ``__post_init__`` installs the default
    three-phase curriculum (epochs 1-10, 11-20 and 21-30).
    """
    # Model
    model_name: str = "skt/A.X-Encoder-base"
    # Padding/truncation length passed to the triplet datasets.
    max_length: int = 64

    # Training
    # NOTE: informational only in this file; the training loop iterates over
    # the epoch ranges declared by ``phases``.
    total_epochs: int = 30
    batch_size: int = 64
    # Base learning rate; each phase scales it by its ``lr_multiplier``.
    learning_rate: float = 3e-6
    # Fraction of a phase's optimizer steps used for linear warmup.
    warmup_ratio: float = 0.1
    # Gradient-norm clipping threshold.
    max_grad_norm: float = 1.0

    # Loss weights (static)
    # Weight of the positive (synonym) activation loss.
    lambda_synonym: float = 10.0
    # Weight of the triplet margin loss.
    lambda_margin: float = 2.5
    # Margin passed to the triplet margin loss.
    target_margin: float = 1.5

    # Dynamic lambda_self: per-sample weight of the self-reconstruction loss,
    # interpolated between these bounds according to input length.
    lambda_self_min: float = 4.0
    lambda_self_max: float = 8.0
    lambda_self_decay_tokens: int = 10

    # Minimum activation: number of top activations averaged per sample.
    min_activation_k: int = 5

    # Curriculum phases
    phases: List[CurriculumPhase] = field(default_factory=list)

    # Paths (None until supplied by the caller)
    data_dir: Optional[Path] = None
    output_dir: Optional[Path] = None

    def __post_init__(self):
        """Install the default three-phase curriculum unless phases were given."""
        if self.phases:
            return

        self.phases = [
            CurriculumPhase(
                name="phase1_single_term",
                start_epoch=1,
                end_epoch=10,
                lambda_flops=3e-3,
                lambda_min_activation=2.0,
                lr_multiplier=1.0,
                data_file="phase1_single_term_focus_triplets.jsonl",
            ),
            CurriculumPhase(
                name="phase2_balanced",
                start_epoch=11,
                end_epoch=20,
                lambda_flops=4e-3,
                lambda_min_activation=1.5,
                lr_multiplier=0.5,
                data_file="phase2_balanced_triplets.jsonl",
            ),
            CurriculumPhase(
                name="phase3_full",
                start_epoch=21,
                end_epoch=30,
                lambda_flops=5e-3,
                lambda_min_activation=1.0,
                lr_multiplier=0.25,
                data_file="phase3_full_triplets.jsonl",
            ),
        ]


# Create config
# Data is read from <project root>/data/v21.4 and artifacts are written to
# <project root>/outputs/v21.4_korean_enhanced.
config = TrainingConfig(
    data_dir=PROJECT_ROOT / "data" / "v21.4",
    output_dir=PROJECT_ROOT / "outputs" / "v21.4_korean_enhanced",
)
# Create the output directory (and any missing parents) up front.
config.output_dir.mkdir(parents=True, exist_ok=True)

# Echo the effective configuration and the curriculum schedule.
print("Configuration:")
print(f"  Model: {config.model_name}")
print(f"  Epochs: {config.total_epochs}")
print(f"  Batch size: {config.batch_size}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Lambda self range: {config.lambda_self_min} - {config.lambda_self_max}")
print(f"\nCurriculum Phases:")
for phase in config.phases:
    print(f"  {phase.name}: epochs {phase.start_epoch}-{phase.end_epoch}, "
          f"lambda_flops={phase.lambda_flops}, lr_mult={phase.lr_multiplier}")

Configuration:
  Model: skt/A.X-Encoder-base
  Epochs: 30
  Batch size: 64
  Learning rate: 3e-06
  Lambda self range: 4.0 - 8.0

Curriculum Phases:
  phase1_single_term: epochs 1-10, lambda_flops=0.003, lr_mult=1.0
  phase2_balanced: epochs 11-20, lambda_flops=0.004, lr_mult=0.5
  phase3_full: epochs 21-30, lambda_flops=0.005, lr_mult=0.25


## 2. Model Definition

In [3]:
class SPLADEModel(nn.Module):
    """Masked-LM backbone whose logits are turned into a SPLADE sparse vector."""

    def __init__(self, model_name: str = "skt/A.X-Encoder-base"):
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.config = self.model.config
        self.relu = nn.ReLU()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch into sparse vocabulary-space representations.

        Returns:
            sparse_repr: [batch_size, vocab_size] - scores max-pooled over the sequence
            token_weights: [batch_size, seq_len] - largest score at each position
        """
        mlm_logits = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).logits

        # SPLADE activation log(1 + ReLU(x)); positions with attention_mask == 0
        # are zeroed so padding never contributes to either pooled output.
        keep = attention_mask.unsqueeze(-1).float()
        masked_scores = torch.log1p(self.relu(mlm_logits)) * keep

        # Pool over the sequence axis for the document/query vector ...
        sparse_repr = masked_scores.max(dim=1).values
        # ... and over the vocabulary axis for per-position analysis weights.
        token_weights, _ = masked_scores.max(dim=-1)

        return sparse_repr, token_weights


# Initialize model
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SPLADEModel(config.model_name)
model = model.to(device)

# Tokenizer is loaded from the same checkpoint name as the model.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

print(f"Model loaded: {config.model_name}")
print(f"Device: {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

Model loaded: skt/A.X-Encoder-base
Device: cuda
Parameters: 149,372,240


## 3. Dataset

In [4]:
class TripletDataset(Dataset):
    """Dataset of (anchor, positive, negative) text triplets read from a JSONL file.

    Each non-blank line of the file must be a JSON object with the string
    fields ``"anchor"``, ``"positive"`` and ``"negative"``. Texts are tokenized
    lazily in ``__getitem__``.
    """

    _ROLES = ("anchor", "positive", "negative")

    def __init__(self, data_path: Path, tokenizer, max_length: int = 64):
        """Load all triplets from ``data_path`` into memory.

        Args:
            data_path: Path (or path string) of the JSONL file.
            tokenizer: Callable tokenizer; invoked with the text plus
                ``padding``, ``truncation``, ``max_length`` and
                ``return_tensors`` keyword arguments.
            max_length: Fixed length every text is padded/truncated to.
        """
        data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(data_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing newline at the end of the file)
            # instead of letting json.loads fail on them.
            self.data = [json.loads(line) for line in f if line.strip()]

        print(f"Loaded {len(self.data)} triplets from {data_path.name}")

    def __len__(self) -> int:
        return len(self.data)

    def _encode(self, text: str):
        """Tokenize one text to fixed length and drop the leading batch axis."""
        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return encoded["input_ids"].squeeze(0), encoded["attention_mask"].squeeze(0)

    def __getitem__(self, idx: int) -> Dict:
        """Return ``{role}_input_ids`` / ``{role}_attention_mask`` for each triplet role."""
        item = self.data[idx]

        sample = {}
        for role in self._ROLES:
            input_ids, attention_mask = self._encode(item[role])
            sample[f"{role}_input_ids"] = input_ids
            sample[f"{role}_attention_mask"] = attention_mask
        return sample


# Load validation data
# The same validation set is reused after every epoch of every phase.
val_dataset = TripletDataset(
    config.data_dir / "validation_triplets.jsonl",
    tokenizer,
    config.max_length,
)
# No shuffling: evaluation order is kept fixed.
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

Loaded 47093 triplets from validation_triplets.jsonl


## 4. Loss Functions

In [5]:
def compute_dynamic_lambda_self(
    attention_mask: torch.Tensor,
    lambda_min: float = 4.0,
    lambda_max: float = 8.0,
    decay_tokens: int = 10,
) -> torch.Tensor:
    """
    Compute per-sample lambda_self based on input length.

    The weight is ``lambda_max`` for a single-token input and decreases
    linearly with the token count until it reaches ``lambda_min`` at
    ``decay_tokens`` tokens; longer inputs stay at ``lambda_min``.

    Args:
        attention_mask: [batch_size, seq_len] mask whose row sums give the
            number of attended positions. Two positions per row are assumed
            to be the CLS and SEP tokens and are not counted.
        lambda_min: Weight for inputs with ``decay_tokens`` or more tokens.
        lambda_max: Weight for single-token inputs.
        decay_tokens: Token count at which the weight reaches ``lambda_min``.
            Values <= 1 degenerate to a step: ``lambda_max`` for single-token
            inputs and ``lambda_min`` for everything longer.

    Returns:
        Float tensor of shape [batch_size] with one weight per sample.
    """
    # Count actual tokens (excluding padding, CLS, SEP); never below one.
    token_counts = (attention_mask.sum(dim=1) - 2).clamp(min=1).float()

    if decay_tokens <= 1:
        # The linear slope is undefined here (division by zero); use its limit.
        lambda_weights = torch.full_like(token_counts, lambda_min)
        lambda_weights[token_counts <= 1] = lambda_max
        return lambda_weights

    # Linear interpolation
    # lambda(n) = max(lambda_min, lambda_max - slope * (n - 1))
    slope = (lambda_max - lambda_min) / (decay_tokens - 1)
    lambda_weights = lambda_max - slope * (token_counts - 1)
    return lambda_weights.clamp(min=lambda_min, max=lambda_max)


def compute_self_reconstruction_loss(
    sparse_repr: torch.Tensor,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    lambda_weights: torch.Tensor,
    pad_token_id: int,
    special_token_ids: set,
) -> torch.Tensor:
    """
    Compute weighted self-reconstruction loss.

    Encourages the model to activate tokens that appear in the input: for each
    sample the loss is the negative mean activation at the vocabulary entries
    of its own input tokens (repeated tokens count once per occurrence),
    scaled by that sample's ``lambda_weights`` entry. Samples with no eligible
    token contribute zero. The result is the mean over the batch.

    Args:
        sparse_repr: [batch_size, vocab_size] sparse representations.
        input_ids: [batch_size, seq_len] token IDs of the same inputs.
        attention_mask: [batch_size, seq_len]; only positions equal to 1 count.
        lambda_weights: [batch_size] per-sample loss weights.
        pad_token_id: Padding token ID to ignore (may be None).
        special_token_ids: Token IDs to ignore in addition to padding.

    Returns:
        Scalar loss tensor.
    """
    # Computed with batched tensor ops: a per-token Python loop with .item()
    # would force a device synchronisation for every token of every sample.
    eligible = attention_mask == 1

    excluded = {tid for tid in special_token_ids if tid is not None}
    if pad_token_id is not None:
        excluded.add(pad_token_id)
    if excluded:
        excluded_ids = torch.tensor(
            sorted(excluded), device=input_ids.device, dtype=input_ids.dtype
        )
        eligible = eligible & ~torch.isin(input_ids, excluded_ids)

    # Activation of each input position's own vocabulary entry: [batch, seq_len].
    activations = sparse_repr.gather(1, input_ids.long())
    weights = eligible.to(activations.dtype)

    # Masked mean per sample; clamp avoids 0/0 for samples without eligible tokens.
    token_counts = weights.sum(dim=1).clamp(min=1)
    mean_activation = (activations * weights).sum(dim=1) / token_counts

    # Negative mean activation (we want to maximize), weighted per sample.
    return (-mean_activation * lambda_weights).mean()


def compute_positive_activation_loss(
    anchor_repr: torch.Tensor,
    positive_input_ids: torch.Tensor,
    positive_attention_mask: torch.Tensor,
    pad_token_id: int,
    special_token_ids: set,
) -> torch.Tensor:
    """
    Encourage anchor to activate positive synonym tokens.

    For each sample the loss is the negative mean of the anchor's activations
    at the vocabulary entries of the positive text's tokens (repeated tokens
    count once per occurrence). Samples whose positive text has no eligible
    token contribute zero. The result is the mean over the batch.

    Args:
        anchor_repr: [batch_size, vocab_size] sparse anchor representations.
        positive_input_ids: [batch_size, seq_len] token IDs of the positives.
        positive_attention_mask: [batch_size, seq_len]; only positions equal
            to 1 count.
        pad_token_id: Padding token ID to ignore (may be None).
        special_token_ids: Token IDs to ignore in addition to padding.

    Returns:
        Scalar loss tensor.
    """
    # Computed with batched tensor ops: a per-token Python loop with .item()
    # would force a device synchronisation for every token of every sample.
    eligible = positive_attention_mask == 1

    excluded = {tid for tid in special_token_ids if tid is not None}
    if pad_token_id is not None:
        excluded.add(pad_token_id)
    if excluded:
        excluded_ids = torch.tensor(
            sorted(excluded),
            device=positive_input_ids.device,
            dtype=positive_input_ids.dtype,
        )
        eligible = eligible & ~torch.isin(positive_input_ids, excluded_ids)

    # Anchor activation at each positive token's vocabulary entry: [batch, seq_len].
    activations = anchor_repr.gather(1, positive_input_ids.long())
    weights = eligible.to(activations.dtype)

    # Masked mean per sample; clamp avoids 0/0 for samples without eligible tokens.
    token_counts = weights.sum(dim=1).clamp(min=1)
    mean_activation = (activations * weights).sum(dim=1) / token_counts

    return (-mean_activation).mean()


def compute_triplet_margin_loss(
    anchor_repr: torch.Tensor,
    positive_repr: torch.Tensor,
    negative_repr: torch.Tensor,
    margin: float = 1.5,
) -> torch.Tensor:
    """
    Hinge loss on cosine similarities, averaged over the batch.

    Per sample: ``relu(margin - cos(anchor, positive) + cos(anchor, negative))``.

    NOTE: cosine similarity lies in [-1, 1] (and in [0, 1] when all inputs are
    non-negative), so the hinge can only reach zero if ``margin`` does not
    exceed the largest achievable similarity gap.
    """
    sim_to_positive = F.cosine_similarity(anchor_repr, positive_repr, dim=-1)
    sim_to_negative = F.cosine_similarity(anchor_repr, negative_repr, dim=-1)

    per_sample_hinge = F.relu(margin - sim_to_positive + sim_to_negative)
    return torch.mean(per_sample_hinge)


def compute_flops_loss(sparse_repr: torch.Tensor) -> torch.Tensor:
    """
    FLOPS sparsity regularizer: sum over the vocabulary of the squared
    batch-mean activation.
    """
    # Average each vocabulary term's activation over the batch -> [vocab_size]
    per_term_mean = sparse_repr.mean(dim=0)

    # Square and accumulate into a single scalar
    return per_term_mean.pow(2).sum()


def compute_minimum_activation_loss(
    sparse_repr: torch.Tensor,
    k: int = 5,
    epsilon: float = 1e-8,
) -> torch.Tensor:
    """
    Minimum activation loss to prevent garbage outputs.

    Ensures top-k activations maintain sufficient strength: the loss is the
    batch mean of ``-log(mean(top-k activations) + epsilon)``.

    Args:
        sparse_repr: [batch_size, vocab_size] sparse representations.
        k: Number of largest activations averaged per sample. Values larger
            than the last dimension are reduced to that dimension's size.
        epsilon: Added before the logarithm so an all-zero sample gives a
            large finite loss instead of infinity.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    # torch.topk raises if k exceeds the size of the selected dimension.
    k = min(k, sparse_repr.size(-1))

    # Get top-k activations for each sample
    topk_values = torch.topk(sparse_repr, k=k, dim=-1).values

    # Mean of top-k per sample
    mean_topk = topk_values.mean(dim=-1)

    # Negative log (encourages higher activations)
    return -torch.log(mean_topk + epsilon).mean()


# Collect the ID of every special token the tokenizer defines; attributes
# that are missing or unset (None) are skipped.
special_token_ids = set()
for attr in ('pad_token_id', 'cls_token_id', 'sep_token_id',
             'unk_token_id', 'mask_token_id', 'bos_token_id', 'eos_token_id'):
    token_id = getattr(tokenizer, attr, None)
    if token_id is None:
        continue
    special_token_ids.add(token_id)

print(f"Special token IDs: {special_token_ids}")

Special token IDs: {0, 1, 2, 3, 4, 5, 49999}


## 5. Training Loop

In [6]:
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    config: TrainingConfig,
    phase: CurriculumPhase,
    device: torch.device,
    epoch: int,
) -> Dict[str, float]:
    """Run one training epoch and return the per-batch average of each loss.

    Reads the module-level ``tokenizer`` and ``special_token_ids``.

    Returns:
        Dict with keys ``total``, ``self``, ``positive``, ``triplet``,
        ``flops`` and ``min_act``. Component values are the raw losses,
        before the lambda weights of the total are applied.
    """
    model.train()

    running_total = 0.0
    running_parts = defaultdict(float)

    progress = tqdm(dataloader, desc=f"Epoch {epoch} ({phase.name})")

    for raw_batch in progress:
        # Move every tensor of the batch to the training device
        batch = {name: tensor.to(device) for name, tensor in raw_batch.items()}
        anchor_ids = batch["anchor_input_ids"]
        anchor_mask = batch["anchor_attention_mask"]
        positive_ids = batch["positive_input_ids"]
        positive_mask = batch["positive_attention_mask"]

        # Encode the three sides of the triplet
        anchor_repr, _ = model(anchor_ids, anchor_mask)
        positive_repr, _ = model(positive_ids, positive_mask)
        negative_repr, _ = model(
            batch["negative_input_ids"], batch["negative_attention_mask"]
        )

        # Per-sample self-reconstruction weight, driven by anchor length
        lambda_weights = compute_dynamic_lambda_self(
            anchor_mask,
            config.lambda_self_min,
            config.lambda_self_max,
            config.lambda_self_decay_tokens,
        )

        # Individual loss terms (regularizers are applied to the anchor only)
        loss_self = compute_self_reconstruction_loss(
            anchor_repr, anchor_ids, anchor_mask,
            lambda_weights, tokenizer.pad_token_id, special_token_ids,
        )
        loss_positive = compute_positive_activation_loss(
            anchor_repr, positive_ids, positive_mask,
            tokenizer.pad_token_id, special_token_ids,
        )
        loss_triplet = compute_triplet_margin_loss(
            anchor_repr, positive_repr, negative_repr, config.target_margin,
        )
        loss_flops = compute_flops_loss(anchor_repr)
        loss_min_act = compute_minimum_activation_loss(
            anchor_repr, k=config.min_activation_k,
        )

        # Weighted sum; loss_self already carries its per-sample lambda
        loss = (
            loss_self +
            config.lambda_synonym * loss_positive +
            config.lambda_margin * loss_triplet +
            phase.lambda_flops * loss_flops +
            phase.lambda_min_activation * loss_min_act
        )

        # Optimisation step with gradient clipping, then advance the LR schedule
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        optimizer.step()
        scheduler.step()

        # Accumulate scalars for the epoch averages
        step_loss = loss.item()
        running_total += step_loss
        step_parts = {
            "self": loss_self,
            "positive": loss_positive,
            "triplet": loss_triplet,
            "flops": loss_flops,
            "min_act": loss_min_act,
        }
        for part_name, part_value in step_parts.items():
            running_parts[part_name] += part_value.item()

        progress.set_postfix({
            "loss": f"{step_loss:.4f}",
            "lr": f"{scheduler.get_last_lr()[0]:.2e}",
        })

    n_batches = len(dataloader)
    averages = {"total": running_total / n_batches}
    for part_name, part_sum in running_parts.items():
        averages[part_name] = part_sum / n_batches
    return averages

In [7]:
@torch.no_grad()
def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    tokenizer,
) -> Dict[str, float]:
    """Evaluate pairwise ranking quality on triplet batches.

    Puts the model in eval mode and, for every triplet, compares the cosine
    similarity of anchor/positive with that of anchor/negative.

    Args:
        model: Callable returning ``(sparse_repr, token_weights)``.
        dataloader: Yields dicts of tensors with ``{anchor,positive,negative}``
            ``_input_ids`` / ``_attention_mask`` keys.
        device: Device the batch tensors are moved to.
        tokenizer: Unused; kept for interface compatibility.

    Returns:
        ``recall``: percentage of triplets where the positive scores strictly
        higher than the negative.
        ``mrr``: mean reciprocal rank over the two candidates (1.0 when the
        positive ranks first, 0.5 otherwise).
    """
    model.eval()

    hits: List[float] = []

    for batch in tqdm(dataloader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}

        anchor_repr, _ = model(batch["anchor_input_ids"], batch["anchor_attention_mask"])
        positive_repr, _ = model(batch["positive_input_ids"], batch["positive_attention_mask"])
        negative_repr, _ = model(batch["negative_input_ids"], batch["negative_attention_mask"])

        # Compute similarities
        pos_sim = F.cosine_similarity(anchor_repr, positive_repr, dim=-1)
        neg_sim = F.cosine_similarity(anchor_repr, negative_repr, dim=-1)

        # One device-to-host transfer per batch; 1.0 marks a correctly ranked triplet.
        hits.extend((pos_sim > neg_sim).float().cpu().tolist())

    # With exactly two candidates the positive is ranked 1st on a hit, else 2nd.
    reciprocal_ranks = [1.0 if hit else 0.5 for hit in hits]

    return {
        "recall": np.mean(hits) * 100,
        "mrr": np.mean(reciprocal_ranks),
    }

## 6. Run Training

In [8]:
def run_curriculum_training(model, config, device):
    """Run full curriculum training.

    For each phase in ``config.phases``: loads the phase's training file,
    builds a fresh AdamW optimizer and linear warmup/decay schedule, trains
    for the phase's epoch range and evaluates after every epoch. Reads the
    module-level ``tokenizer`` and ``val_loader``.

    Files written to ``config.output_dir``:
        * ``best_model.pt`` - whenever validation recall improves.
        * ``<phase name>_checkpoint.pt`` - at the end of each phase.
        * ``training_history.json`` - rewritten after every epoch so that an
          interrupted run keeps the metrics collected so far.

    Returns:
        List with one metrics dict per completed epoch.
    """
    training_history = []
    best_recall = 0.0
    history_path = config.output_dir / "training_history.json"

    for phase in config.phases:
        print(f"\n{'=' * 60}")
        print(f"Starting {phase.name}: Epochs {phase.start_epoch}-{phase.end_epoch}")
        print(f"lambda_flops={phase.lambda_flops}, lambda_min_act={phase.lambda_min_activation}")
        print(f"lr_multiplier={phase.lr_multiplier}")
        print(f"{'=' * 60}")

        # Load phase-specific data
        train_dataset = TripletDataset(
            config.data_dir / phase.data_file,
            tokenizer,
            config.max_length,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            drop_last=True,
        )

        # Setup optimizer and scheduler for this phase
        # (both start from scratch each phase; warmup is repeated per phase)
        phase_epochs = phase.end_epoch - phase.start_epoch + 1
        total_steps = len(train_loader) * phase_epochs
        warmup_steps = int(total_steps * config.warmup_ratio)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate * phase.lr_multiplier,
            weight_decay=0.01,
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

        # Train for phase epochs
        for epoch in range(phase.start_epoch, phase.end_epoch + 1):
            train_metrics = train_epoch(
                model, train_loader, optimizer, scheduler,
                config, phase, device, epoch,
            )

            # Evaluate
            eval_metrics = evaluate(model, val_loader, device, tokenizer)

            print(f"\nEpoch {epoch}: Loss={train_metrics['total']:.4f}, "
                  f"Recall={eval_metrics['recall']:.1f}%, MRR={eval_metrics['mrr']:.4f}")

            # Save history
            history_entry = {
                "epoch": epoch,
                "phase": phase.name,
                **train_metrics,
                **eval_metrics,
            }
            training_history.append(history_entry)

            # Persist after every epoch: a crash or manual interrupt must not
            # discard the metrics of the epochs that already finished.
            with open(history_path, "w", encoding="utf-8") as f:
                json.dump(training_history, f, indent=2)

            # Save best model
            if eval_metrics["recall"] > best_recall:
                best_recall = eval_metrics["recall"]
                torch.save({
                    "epoch": epoch,
                    "phase": phase.name,
                    "model_state_dict": model.state_dict(),
                    "eval_results": eval_metrics,
                    "config": {
                        "model_name": config.model_name,
                        "max_length": config.max_length,
                    },
                }, config.output_dir / "best_model.pt")
                print(f"  -> New best model saved! Recall: {best_recall:.1f}%")

        # Save phase checkpoint
        torch.save({
            "epoch": phase.end_epoch,
            "phase": phase.name,
            "model_state_dict": model.state_dict(),
        }, config.output_dir / f"{phase.name}_checkpoint.pt")
        print(f"\n{phase.name} checkpoint saved.")

    # Save training history (also covers the case of no phases / no epochs)
    with open(history_path, "w", encoding="utf-8") as f:
        json.dump(training_history, f, indent=2)

    return training_history

In [None]:
# Run training
print("Starting v21.4 Curriculum Training...")
print(f"Output directory: {config.output_dir}")

# Runs every curriculum phase; `history` holds one metrics dict per epoch and
# is consumed by the plotting cell below.
history = run_curriculum_training(model, config, device)

print("\n" + "=" * 60)
print("Training Complete!")
print("=" * 60)

Starting v21.4 Curriculum Training...
Output directory: /home/west/Documents/cursor-workspace/opensearch-neural-pre-train/outputs/v21.4_korean_enhanced

Starting phase1_single_term: Epochs 1-10
lambda_flops=0.003, lambda_min_act=2.0
lr_multiplier=1.0
Loaded 136117 triplets from phase1_single_term_focus_triplets.jsonl


Epoch 1 (phase1_single_term):   0%|          | 0/2126 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/736 [00:00<?, ?it/s]


Epoch 1: Loss=-32.5524, Recall=97.8%, MRR=0.9888
  -> New best model saved! Recall: 97.8%


Epoch 2 (phase1_single_term):   0%|          | 0/2126 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/736 [00:00<?, ?it/s]


Epoch 2: Loss=-45.0634, Recall=98.3%, MRR=0.9917
  -> New best model saved! Recall: 98.3%


Epoch 3 (phase1_single_term):   0%|          | 0/2126 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/736 [00:00<?, ?it/s]


Epoch 3: Loss=-46.6957, Recall=98.6%, MRR=0.9929
  -> New best model saved! Recall: 98.6%


Epoch 4 (phase1_single_term):   0%|          | 0/2126 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/736 [00:00<?, ?it/s]


Epoch 4: Loss=-47.7580, Recall=98.7%, MRR=0.9935
  -> New best model saved! Recall: 98.7%


Epoch 5 (phase1_single_term):   0%|          | 0/2126 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/736 [00:00<?, ?it/s]


Epoch 5: Loss=-48.5300, Recall=98.7%, MRR=0.9936
  -> New best model saved! Recall: 98.7%


Epoch 6 (phase1_single_term):   0%|          | 0/2126 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/736 [00:00<?, ?it/s]

## 7. Training Summary

In [None]:
import matplotlib.pyplot as plt

# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

epochs = [h["epoch"] for h in history]

# A phase change sits halfway between the last epoch of one phase and the
# first epoch of the next; derive the positions from the configured phases so
# the plot stays correct if the curriculum changes.
phase_boundaries = [phase.end_epoch + 0.5 for phase in config.phases[:-1]]

# Total Loss
axes[0, 0].plot(epochs, [h["total"] for h in history], 'b-')
axes[0, 0].set_xlabel("Epoch")
axes[0, 0].set_ylabel("Total Loss")
axes[0, 0].set_title("Total Loss")
for idx, boundary in enumerate(phase_boundaries):
    # Only the first marker carries a label so the legend shows a single entry.
    axes[0, 0].axvline(x=boundary, color='r', linestyle='--', alpha=0.5,
                       label='Phase change' if idx == 0 else '_nolegend_')
if phase_boundaries:
    axes[0, 0].legend()

# Recall
axes[0, 1].plot(epochs, [h["recall"] for h in history], 'g-')
axes[0, 1].set_xlabel("Epoch")
axes[0, 1].set_ylabel("Recall (%)")
axes[0, 1].set_title("Validation Recall")
for boundary in phase_boundaries:
    axes[0, 1].axvline(x=boundary, color='r', linestyle='--', alpha=0.5)

# Loss Components
for component in ["self", "positive", "triplet"]:
    axes[1, 0].plot(epochs, [h[component] for h in history], label=component)
axes[1, 0].set_xlabel("Epoch")
axes[1, 0].set_ylabel("Loss")
axes[1, 0].set_title("Loss Components")
axes[1, 0].legend()

# FLOPS and Min Activation (separate y-axes because the scales differ)
ax1 = axes[1, 1]
ax1.plot(epochs, [h["flops"] for h in history], 'b-', label='FLOPS')
ax1.set_xlabel("Epoch")
ax1.set_ylabel("FLOPS Loss", color='b')
ax2 = ax1.twinx()
ax2.plot(epochs, [h["min_act"] for h in history], 'r-', label='Min Act')
ax2.set_ylabel("Min Activation Loss", color='r')
axes[1, 1].set_title("Regularization Losses")
# The two curves live on different axes objects; merge them into one legend.
twin_lines = ax1.get_lines() + ax2.get_lines()
ax1.legend(twin_lines, [line.get_label() for line in twin_lines])

plt.tight_layout()
plt.savefig(config.output_dir / "training_curves.png", dpi=150)
plt.show()

# Print final summary
print("\nFinal Results:")
print(f"  Best Recall: {max(h['recall'] for h in history):.1f}%")
print(f"  Best MRR: {max(h['mrr'] for h in history):.4f}")
print(f"  Final Loss: {history[-1]['total']:.4f}")

## Next Steps

1. Run `04_evaluation.ipynb` for comprehensive evaluation
2. Compare with v21.2 and v21.3 baselines
3. Test on problem terms (추천, 데이터베이스, 증상, 질환, 인슐린)