# v22.0 Training with InfoNCE Contrastive Loss

## Key Improvements over v21.4

1. **InfoNCE Contrastive Loss**: In-batch negatives for better representation learning
2. **Temperature Annealing**: 0.07 → 0.05 → 0.03 (sharper discrimination)
3. **Expanded Single-term Data**: 448 → 29,322 triplets (65x increase)
4. **Curriculum Learning**: 3 phases with an increasing InfoNCE weight (1.0 → 1.5 → 2.0) and a decreasing learning-rate multiplier (1.0 → 0.5 → 0.25)

## Target Metrics

| Metric | v21.4 | v22.0 Target |
|--------|-------|--------------|
| Recall@1 | 38.1% | 70%+ |
| MRR | 0.4412 | 0.75+ |
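| Single-term Recall | ~30% | 80%+ |

Note: the `recall` and `mrr` printed during training are pairwise metrics on validation triplets (the positive must outscore its single negative). They are not corpus-level Recall@1/MRR. Measure against these targets with the benchmark described in Next Steps.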

In [None]:
import sys
from pathlib import Path

def find_project_root():
    """Locate the repository root starting from the current working directory.

    The cwd and each of its ancestors are checked in order (nearest first).
    The first one that contains a ``pyproject.toml`` file or a ``src`` entry
    is returned. When none qualifies, the grandparent of the cwd is used as
    a fallback.
    """
    here = Path.cwd()
    looks_like_root = (
        candidate
        for candidate in (here, *here.parents)
        if (candidate / "pyproject.toml").exists() or (candidate / "src").exists()
    )
    return next(looks_like_root, Path.cwd().parent.parent)

PROJECT_ROOT = find_project_root()
# Prepend the project root to the import path so `src.*` resolves to this checkout.
sys.path[:0] = [str(PROJECT_ROOT)]

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForMaskedLM, get_linear_schedule_with_warmup
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from tqdm.auto import tqdm
from collections import defaultdict

# Import v22.0 loss functions
from src.model.losses import (
    InfoNCELoss,
    SelfReconstructionLoss,
    PositiveActivationLoss,
    TripletMarginLoss,
    FLOPSLoss,
    MinimumActivationLoss,
    SPLADELossV22,
)

# Seed every RNG the notebook uses so runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Report the runtime environment.
for _msg in (
    f"Project root: {PROJECT_ROOT}",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
):
    print(_msg)
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Configuration with InfoNCE

In [None]:
@dataclass
class CurriculumPhaseV22:
    """Configuration for one v22.0 curriculum-learning phase.

    A phase spans an inclusive epoch range. It carries its own InfoNCE
    weight and temperature, regularization weights, learning-rate multiplier
    and training-data file name.
    """
    # Human-readable phase id; also used in checkpoint file names.
    name: str
    # Inclusive epoch range (the training loop runs start_epoch..end_epoch).
    start_epoch: int
    end_epoch: int
    # InfoNCE parameters
    # Weight of the InfoNCE term in the total loss.
    lambda_infonce: float
    # Temperature passed to InfoNCELoss for this phase.
    temperature: float
    # Other loss weights
    # Weight of the FLOPS sparsity regularizer.
    lambda_flops: float
    # Weight of the minimum-activation penalty.
    lambda_min_activation: float
    # Multiplier applied to the base learning rate for this phase's optimizer.
    lr_multiplier: float
    # Triplet JSONL file name, resolved relative to the config's data_dir.
    data_file: str

@dataclass
class TrainingConfigV22:
    """Training configuration for v22.0.

    When ``phases`` is left empty, it is filled with the default three-phase
    curriculum: temperature annealing 0.07 -> 0.05 -> 0.03, a rising
    InfoNCE weight and a falling LR multiplier.

    ``data_dir`` and ``output_dir`` accept either ``str`` or ``Path``.
    Strings are converted to ``Path`` so callers can use ``/`` and
    ``.mkdir()`` on them.
    """
    # Model
    model_name: str = "skt/A.X-Encoder-base"
    max_length: int = 64

    # Training
    # NOTE(review): not derived from `phases`; keep it equal to the last
    # phase's end_epoch when editing the curriculum.
    total_epochs: int = 30
    batch_size: int = 64
    learning_rate: float = 3e-6
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0

    # Loss weights (static across phases)
    lambda_self: float = 4.0
    lambda_positive: float = 10.0
    lambda_margin: float = 2.5
    target_margin: float = 1.5

    # Minimum activation
    min_activation_k: int = 5
    min_activation_threshold: float = 0.5

    # Curriculum phases
    phases: List[CurriculumPhaseV22] = field(default_factory=list)

    # Paths
    data_dir: Optional[Path] = None
    output_dir: Optional[Path] = None

    def __post_init__(self):
        # Normalize path-like inputs so downstream `/` and `.mkdir()` work
        # even when plain strings are passed in.
        if self.data_dir is not None:
            self.data_dir = Path(self.data_dir)
        if self.output_dir is not None:
            self.output_dir = Path(self.output_dir)

        if not self.phases:
            # v22.0 Curriculum: Temperature annealing + InfoNCE weight increase
            self.phases = [
                CurriculumPhaseV22(
                    name="phase1_single_term",
                    start_epoch=1,
                    end_epoch=10,
                    lambda_infonce=1.0,
                    temperature=0.07,  # Warm start
                    lambda_flops=3e-3,
                    lambda_min_activation=2.0,
                    lr_multiplier=1.0,
                    data_file="phase1_single_term_focus_triplets.jsonl",
                ),
                CurriculumPhaseV22(
                    name="phase2_balanced",
                    start_epoch=11,
                    end_epoch=20,
                    lambda_infonce=1.5,
                    temperature=0.05,  # Sharper
                    lambda_flops=4e-3,
                    lambda_min_activation=1.5,
                    lr_multiplier=0.5,
                    data_file="phase2_balanced_triplets.jsonl",
                ),
                CurriculumPhaseV22(
                    name="phase3_full",
                    start_epoch=21,
                    end_epoch=30,
                    lambda_infonce=2.0,
                    temperature=0.03,  # Sharp discrimination
                    lambda_flops=5e-3,
                    lambda_min_activation=1.0,
                    lr_multiplier=0.25,
                    data_file="phase3_full_triplets.jsonl",
                ),
            ]


# Build the default v22.0 configuration rooted at the project directory.
config = TrainingConfigV22(
    data_dir=PROJECT_ROOT / "data" / "v22.0",
    output_dir=PROJECT_ROOT / "outputs" / "v22.0_infonce",
)
config.output_dir.mkdir(parents=True, exist_ok=True)

# Echo the key settings and the per-phase schedule.
print("v22.0 Configuration:")
for label, value in (
    ("Model", config.model_name),
    ("Epochs", config.total_epochs),
    ("Batch size", config.batch_size),
    ("Learning rate", config.learning_rate),
):
    print(f"  {label}: {value}")
print("\nCurriculum Phases (with InfoNCE):")
for phase in config.phases:
    print(
        f"  {phase.name}: epochs {phase.start_epoch}-{phase.end_epoch}\n"
        f"    lambda_infonce={phase.lambda_infonce}, temp={phase.temperature}\n"
        f"    lambda_flops={phase.lambda_flops}, lr_mult={phase.lr_multiplier}"
    )

## 2. Model Definition

In [None]:
class SPLADEModelV22(nn.Module):
    """SPLADE encoder for v22.0: masked-LM logits -> sparse vocabulary vectors.

    The per-token vocabulary logits are mapped through ``log(1 + ReLU(x))``.
    Padding positions are zeroed with the attention mask, and the result is
    max-pooled over the sequence.
    """

    def __init__(self, model_name: str = "skt/A.X-Encoder-base"):
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        # Exposed so callers can read e.g. vocab_size / hidden_size.
        self.config = self.model.config
        self.relu = nn.ReLU()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of token sequences.

        Args:
            input_ids: token ids, presumably [batch_size, seq_len].
            attention_mask: same shape as ``input_ids``; 1 for real tokens,
                0 for padding.

        Returns:
            sparse_repr: [batch_size, vocab_size] - max-pooled sparse representation
            token_weights: [batch_size, seq_len] - per-token max weights
        """
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

        # SPLADE transformation: log(1 + ReLU(x)), non-negative by construction.
        token_scores = torch.log1p(self.relu(logits))

        # Cast the mask to the score dtype. A float32 mask would silently
        # upcast half/bfloat16 scores and double the memory of this
        # [batch, seq, vocab] tensor.
        mask = attention_mask.unsqueeze(-1).to(token_scores.dtype)
        token_scores = token_scores * mask

        # Max pooling over sequence length. Masked positions are 0, which is
        # never above a real (non-negative) score.
        sparse_repr = token_scores.max(dim=1).values

        # Token weights for analysis: each token's strongest vocabulary activation.
        token_weights = token_scores.max(dim=-1).values

        return sparse_repr, token_weights


# Choose the accelerator, then load the encoder (moved onto it) and its tokenizer.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SPLADEModelV22(config.model_name).to(device)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)

print(f"Model loaded: {config.model_name}")
print(f"Device: {device}")
print(f"Parameters: {sum(param.numel() for param in model.parameters()):,}")

## 3. Dataset

In [None]:
class TripletDataset(Dataset):
    """Triplet dataset read from a JSON Lines file.

    Each non-blank line must be a JSON object with ``"anchor"``,
    ``"positive"`` and ``"negative"`` text fields. Records are tokenized
    lazily in ``__getitem__`` into tensors padded/truncated to
    ``max_length``.
    """

    # Text fields of each record, in the order their tensors are emitted.
    _FIELDS = ("anchor", "positive", "negative")

    def __init__(self, data_path: Path, tokenizer, max_length: int = 64):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Accept str as well as Path (`.name` is used below).
        data_path = Path(data_path)

        with open(data_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. an extra empty line at the end of the
            # file); json.loads would reject them.
            self.data = [json.loads(line) for line in f if line.strip()]

        print(f"Loaded {len(self.data):,} triplets from {data_path.name}")

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        """Return ``{field}_input_ids`` / ``{field}_attention_mask`` tensors.

        The batch dimension added by ``return_tensors="pt"`` is squeezed
        away, so the default collate function can stack items.
        """
        item = self.data[idx]
        features = {}
        for key in self._FIELDS:
            encoded = self.tokenizer(
                item[key],
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            features[f"{key}_input_ids"] = encoded["input_ids"].squeeze(0)
            features[f"{key}_attention_mask"] = encoded["attention_mask"].squeeze(0)
        return features


# Held-out triplets, iterated in fixed order for comparable per-epoch evaluation.
val_dataset = TripletDataset(
    data_path=config.data_dir / "validation_triplets.jsonl",
    tokenizer=tokenizer,
    max_length=config.max_length,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
)

## 4. v22.0 Loss Functions (InfoNCE + Combined)

In [None]:
# Ids of every special token the tokenizer defines. They are excluded from
# the self-reconstruction and positive-activation targets.
special_token_ids = {
    token_id
    for token_id in (
        getattr(tokenizer, attr_name, None)
        for attr_name in (
            'pad_token_id', 'cls_token_id', 'sep_token_id',
            'unk_token_id', 'mask_token_id', 'bos_token_id', 'eos_token_id',
        )
    )
    if token_id is not None
}

print(f"Special token IDs: {special_token_ids}")


def _token_activation_loss(
    sparse_repr: torch.Tensor,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    pad_token_id: int,
    special_token_ids: set,
) -> torch.Tensor:
    """Return the negated mean activation of ``sparse_repr`` at the ids in ``input_ids``.

    For row i, the target ids are the attended (mask == 1) entries of
    ``input_ids[i]`` that are neither in ``special_token_ids`` nor equal to
    ``pad_token_id``. Repeated ids count once per occurrence. The row loss is
    ``-mean(sparse_repr[i, targets])``, or 0 when a row has no targets. The
    result is the mean over rows, and 0 for an empty batch.

    Implemented with gather + masking. This avoids one ``.item()`` call per
    token, which on GPU is a host/device sync.
    """
    if sparse_repr.size(0) == 0:
        return sparse_repr.new_zeros(())

    valid = attention_mask == 1
    if pad_token_id is not None:
        valid = valid & (input_ids != pad_token_id)
    if special_token_ids:
        special = torch.tensor(
            sorted(special_token_ids), dtype=input_ids.dtype, device=input_ids.device
        )
        valid = valid & ~torch.isin(input_ids, special)

    weights = valid.to(sparse_repr.dtype)
    # activations[i, t] = sparse_repr[i, input_ids[i, t]], zeroed for non-targets.
    activations = sparse_repr.gather(1, input_ids) * weights
    counts = weights.sum(dim=1)
    # clamp keeps rows without targets at 0 / 1 = 0 instead of 0 / 0.
    row_loss = -activations.sum(dim=1) / counts.clamp(min=1.0)
    return row_loss.mean()


def compute_self_reconstruction_loss(
    sparse_repr: torch.Tensor,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    pad_token_id: int,
    special_token_ids: set,
) -> torch.Tensor:
    """Compute self-reconstruction loss (encourage input token activation).

    Rewards ``sparse_repr`` for activating the vocabulary entries of its own
    (non-special, non-pad) input tokens. The loss is lower when those
    activations are larger.
    """
    return _token_activation_loss(
        sparse_repr, input_ids, attention_mask, pad_token_id, special_token_ids,
    )


def compute_positive_activation_loss(
    anchor_repr: torch.Tensor,
    positive_input_ids: torch.Tensor,
    positive_attention_mask: torch.Tensor,
    pad_token_id: int,
    special_token_ids: set,
) -> torch.Tensor:
    """Encourage anchor to activate positive synonym tokens.

    Same scoring as self-reconstruction, but the anchor representation is
    evaluated at the tokens of its paired positive text.
    """
    return _token_activation_loss(
        anchor_repr, positive_input_ids, positive_attention_mask,
        pad_token_id, special_token_ids,
    )

## 5. Training Loop with InfoNCE

In [None]:
def train_epoch_v22(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    config: TrainingConfigV22,
    phase: CurriculumPhaseV22,
    device: torch.device,
    epoch: int,
    infonce_loss_fn: InfoNCELoss,
) -> Dict[str, float]:
    """Train for one epoch with InfoNCE plus the v22.0 auxiliary losses.

    Per batch, the total loss is::

        phase.lambda_infonce        * InfoNCE(anchor, positive, negative)
      + config.lambda_self          * self-reconstruction(anchor)
      + config.lambda_positive      * positive-activation(anchor -> positive tokens)
      + config.lambda_margin        * cosine triplet hinge
      + phase.lambda_flops          * FLOPS(anchor)
      + phase.lambda_min_activation * min-activation(anchor)

    The optimizer and scheduler are stepped once per batch, after gradient
    clipping to ``config.max_grad_norm``.

    Relies on the module-level ``tokenizer`` (for ``pad_token_id``) and
    ``special_token_ids``.

    Returns:
        Per-batch averages keyed "total", "infonce", "self", "positive",
        "triplet", "flops", "min_act".

    Raises:
        ValueError: if ``dataloader`` yields no batches (e.g. ``drop_last=True``
            with fewer examples than the batch size).
    """
    model.train()

    total_loss = 0.0
    loss_components = defaultdict(float)
    n_batches = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch} ({phase.name})")

    for batch in pbar:
        # Move to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass
        anchor_repr, _ = model(batch["anchor_input_ids"], batch["anchor_attention_mask"])
        positive_repr, _ = model(batch["positive_input_ids"], batch["positive_attention_mask"])
        negative_repr, _ = model(batch["negative_input_ids"], batch["negative_attention_mask"])

        # 1. InfoNCE Loss (NEW in v22.0)
        loss_infonce = infonce_loss_fn(anchor_repr, positive_repr, negative_repr)

        # 2. Self-reconstruction loss
        loss_self = compute_self_reconstruction_loss(
            anchor_repr, batch["anchor_input_ids"], batch["anchor_attention_mask"],
            tokenizer.pad_token_id, special_token_ids,
        )

        # 3. Positive activation loss
        loss_positive = compute_positive_activation_loss(
            anchor_repr, batch["positive_input_ids"], batch["positive_attention_mask"],
            tokenizer.pad_token_id, special_token_ids,
        )

        # 4. Triplet margin loss on cosine similarity.
        # NOTE(review): if the representations are non-negative (as a
        # log1p(ReLU) SPLADE output is), cosine similarities lie in [0, 1].
        # Then pos_sim - neg_sim <= 1, and a target_margin above 1 (the
        # default is 1.5) is never met, so the hinge acts as a constant
        # linear pull. Confirm this is intended.
        pos_sim = F.cosine_similarity(anchor_repr, positive_repr, dim=-1)
        neg_sim = F.cosine_similarity(anchor_repr, negative_repr, dim=-1)
        loss_triplet = F.relu(config.target_margin - pos_sim + neg_sim).mean()

        # 5. FLOPS regularization: squared batch-mean activation per vocab entry.
        mean_activations = anchor_repr.mean(dim=0)
        loss_flops = (mean_activations ** 2).sum()

        # 6. Minimum activation: penalize anchors whose top-k mean activation
        # falls below the threshold.
        topk_values, _ = torch.topk(anchor_repr, k=config.min_activation_k, dim=-1)
        mean_topk = topk_values.mean(dim=-1)
        loss_min_act = F.relu(config.min_activation_threshold - mean_topk).mean()

        # Total loss with phase-specific weights
        loss = (
            phase.lambda_infonce * loss_infonce +
            config.lambda_self * loss_self +
            config.lambda_positive * loss_positive +
            config.lambda_margin * loss_triplet +
            phase.lambda_flops * loss_flops +
            phase.lambda_min_activation * loss_min_act
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        optimizer.step()
        scheduler.step()

        # Read each scalar back exactly once (every .item() is a device sync).
        step_values = {
            "infonce": loss_infonce.item(),
            "self": loss_self.item(),
            "positive": loss_positive.item(),
            "triplet": loss_triplet.item(),
            "flops": loss_flops.item(),
            "min_act": loss_min_act.item(),
        }
        loss_value = loss.item()
        total_loss += loss_value
        for name, value in step_values.items():
            loss_components[name] += value
        n_batches += 1

        pbar.set_postfix({
            "loss": f"{loss_value:.4f}",
            "infonce": f"{step_values['infonce']:.4f}",
            "lr": f"{scheduler.get_last_lr()[0]:.2e}",
        })

    if n_batches == 0:
        raise ValueError(
            f"Epoch {epoch} ({phase.name}): dataloader yielded no batches; "
            "with drop_last=True the dataset must contain at least one full batch."
        )

    return {
        "total": total_loss / n_batches,
        **{k: v / n_batches for k, v in loss_components.items()},
    }

In [None]:
@torch.no_grad()
def evaluate_v22(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate pairwise ranking on (anchor, positive, negative) triplets.

    Each anchor is scored against its positive and its single negative with
    cosine similarity.

    Returns:
        "recall": percentage of triplets where the positive scores strictly
            higher than the negative. This is pairwise accuracy, not
            corpus-level Recall@k.
        "mrr": mean reciprocal rank of the positive among the two candidates:
            1.0 when it wins, 0.5 otherwise (ties count as a loss).
        Both are NaN when the dataloader yields no triplets.
    """
    model.eval()

    n_total = 0
    n_wins = 0

    for batch in tqdm(dataloader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}

        anchor_repr, _ = model(batch["anchor_input_ids"], batch["anchor_attention_mask"])
        positive_repr, _ = model(batch["positive_input_ids"], batch["positive_attention_mask"])
        negative_repr, _ = model(batch["negative_input_ids"], batch["negative_attention_mask"])

        pos_sim = F.cosine_similarity(anchor_repr, positive_repr, dim=-1)
        neg_sim = F.cosine_similarity(anchor_repr, negative_repr, dim=-1)

        # Vectorized win count: one device sync per batch instead of one
        # per triplet.
        wins = pos_sim > neg_sim
        n_wins += int(wins.sum().item())
        n_total += wins.numel()

    if n_total == 0:
        nan = float("nan")
        return {"recall": nan, "mrr": nan}

    n_losses = n_total - n_wins
    return {
        "recall": n_wins / n_total * 100,
        "mrr": (n_wins * 1.0 + n_losses * 0.5) / n_total,
    }

## 6. Run v22.0 Curriculum Training

In [None]:
def _save_training_history(history, path):
    """Write the per-epoch history list to ``path`` as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)


def run_curriculum_training_v22(model, config, device):
    """Run full v22.0 curriculum training with InfoNCE over ``config.phases``.

    For each phase, in order:
      - build an InfoNCE loss at the phase temperature;
      - load the phase's triplet file;
      - create a fresh AdamW optimizer (base LR x ``lr_multiplier``) with a
        linear warmup/decay schedule spanning the phase;
      - train and validate every epoch.

    Files written under ``config.output_dir``:
      - ``best_model.pt``: overwritten whenever validation recall improves.
        The first evaluated epoch always saves unless its recall is NaN.
      - ``{phase.name}_checkpoint.pt``: weights at the end of each phase.
      - ``training_history.json``: rewritten after every epoch, so the
        history survives an interrupted run.

    Uses the module-level ``tokenizer`` and ``val_loader``.

    Returns:
        List of per-epoch dicts: epoch, phase, temperature, averaged train
        losses and eval metrics.
    """
    training_history = []
    # -inf rather than 0.0, so a best model is saved even when the first
    # epochs score 0% recall.
    best_recall = float("-inf")
    history_path = config.output_dir / "training_history.json"

    for phase in config.phases:
        print(f"\n{'=' * 60}")
        print(f"Starting {phase.name}: Epochs {phase.start_epoch}-{phase.end_epoch}")
        print(f"InfoNCE: lambda={phase.lambda_infonce}, temp={phase.temperature}")
        print(f"lambda_flops={phase.lambda_flops}, lambda_min_act={phase.lambda_min_activation}")
        print(f"lr_multiplier={phase.lr_multiplier}")
        print(f"{'=' * 60}")

        # Create InfoNCE loss with phase-specific temperature
        infonce_loss_fn = InfoNCELoss(
            temperature=phase.temperature,
            similarity="cosine",
        )

        # Load phase-specific data
        train_dataset = TripletDataset(
            config.data_dir / phase.data_file,
            tokenizer,
            config.max_length,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            shuffle=True,
            drop_last=True,
        )

        # Setup optimizer and scheduler for this phase. Both are recreated per
        # phase, so AdamW moment estimates restart at each phase boundary.
        phase_epochs = phase.end_epoch - phase.start_epoch + 1
        total_steps = len(train_loader) * phase_epochs
        warmup_steps = int(total_steps * config.warmup_ratio)

        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate * phase.lr_multiplier,
            weight_decay=0.01,
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

        # Train for phase epochs (inclusive range)
        for epoch in range(phase.start_epoch, phase.end_epoch + 1):
            train_metrics = train_epoch_v22(
                model, train_loader, optimizer, scheduler,
                config, phase, device, epoch, infonce_loss_fn,
            )

            # Evaluate
            eval_metrics = evaluate_v22(model, val_loader, device)

            print(f"\nEpoch {epoch}: Loss={train_metrics['total']:.4f}, "
                  f"InfoNCE={train_metrics['infonce']:.4f}, "
                  f"Recall={eval_metrics['recall']:.1f}%, MRR={eval_metrics['mrr']:.4f}")

            # Save history (persisted every epoch for crash safety)
            history_entry = {
                "epoch": epoch,
                "phase": phase.name,
                "temperature": phase.temperature,
                **train_metrics,
                **eval_metrics,
            }
            training_history.append(history_entry)
            _save_training_history(training_history, history_path)

            # Save best model
            if eval_metrics["recall"] > best_recall:
                best_recall = eval_metrics["recall"]
                torch.save({
                    "epoch": epoch,
                    "phase": phase.name,
                    "model_state_dict": model.state_dict(),
                    "eval_results": eval_metrics,
                    "config": {
                        "model_name": config.model_name,
                        "max_length": config.max_length,
                        "version": "v22.0",
                    },
                }, config.output_dir / "best_model.pt")
                print(f"  -> New best model saved! Recall: {best_recall:.1f}%")

        # Save phase checkpoint
        torch.save({
            "epoch": phase.end_epoch,
            "phase": phase.name,
            "model_state_dict": model.state_dict(),
        }, config.output_dir / f"{phase.name}_checkpoint.pt")
        print(f"\n{phase.name} checkpoint saved.")

    # Final write (also covers the no-phase case).
    _save_training_history(training_history, history_path)

    return training_history

In [None]:
# Run training
print("Starting v22.0 Curriculum Training with InfoNCE...")
print(f"Output directory: {config.output_dir}")
print("\nData distribution:")
for phase in config.phases:
    data_path = config.data_dir / phase.data_file
    if not data_path.exists():
        # Flag missing files up front instead of failing mid-curriculum.
        print(f"  {phase.name}: MISSING ({data_path})")
        continue
    # Read explicitly as UTF-8: the triplets may contain non-ASCII (e.g.
    # Korean) text, and the platform default encoding is not guaranteed to
    # decode it. Count only non-blank lines, i.e. actual triplets.
    with open(data_path, encoding="utf-8") as f:
        count = sum(1 for line in f if line.strip())
    print(f"  {phase.name}: {count:,} triplets")

history = run_curriculum_training_v22(model, config, device)

print("\n" + "=" * 60)
print("Training Complete!")
print("=" * 60)

## 7. Training Summary

In [None]:
import matplotlib.pyplot as plt

# Plot training curves
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

epochs = [h["epoch"] for h in history]

# A phase boundary sits halfway between one phase's last epoch and the next
# phase's first epoch. Derive the boundaries from the config (rather than
# hard-coding 10.5 / 20.5) so the markers follow any curriculum change.
phase_boundaries = [p.end_epoch + 0.5 for p in config.phases[:-1]]


def _mark_phase_changes(ax, boundaries, label=None):
    """Draw a dashed red vertical line at each boundary; only the first gets ``label``."""
    for i, x in enumerate(boundaries):
        ax.axvline(x=x, color='r', linestyle='--', alpha=0.5,
                   label=label if i == 0 else None)


# Total Loss
axes[0, 0].plot(epochs, [h["total"] for h in history], 'b-')
axes[0, 0].set_xlabel("Epoch")
axes[0, 0].set_ylabel("Total Loss")
axes[0, 0].set_title("Total Loss")
_mark_phase_changes(axes[0, 0], phase_boundaries, label='Phase change')
if phase_boundaries:
    # Without a legend the 'Phase change' label is never displayed.
    axes[0, 0].legend()

# InfoNCE Loss (NEW)
axes[0, 1].plot(epochs, [h["infonce"] for h in history], 'purple')
axes[0, 1].set_xlabel("Epoch")
axes[0, 1].set_ylabel("InfoNCE Loss")
axes[0, 1].set_title("InfoNCE Contrastive Loss")
_mark_phase_changes(axes[0, 1], phase_boundaries)

# Recall
axes[0, 2].plot(epochs, [h["recall"] for h in history], 'g-')
axes[0, 2].set_xlabel("Epoch")
axes[0, 2].set_ylabel("Recall (%)")
axes[0, 2].set_title("Validation Recall")
_mark_phase_changes(axes[0, 2], phase_boundaries)

# Loss Components
for component in ["self", "positive", "triplet"]:
    axes[1, 0].plot(epochs, [h[component] for h in history], label=component)
axes[1, 0].set_xlabel("Epoch")
axes[1, 0].set_ylabel("Loss")
axes[1, 0].set_title("Loss Components")
axes[1, 0].legend()

# Temperature Schedule
temps = [h["temperature"] for h in history]
axes[1, 1].plot(epochs, temps, 'orange', marker='o', markersize=2)
axes[1, 1].set_xlabel("Epoch")
axes[1, 1].set_ylabel("Temperature")
axes[1, 1].set_title("Temperature Annealing")

# FLOPS and Min Activation (two y-axes: the losses differ in scale)
ax1 = axes[1, 2]
ax1.plot(epochs, [h["flops"] for h in history], 'b-', label='FLOPS')
ax1.set_xlabel("Epoch")
ax1.set_ylabel("FLOPS Loss", color='b')
ax2 = ax1.twinx()
ax2.plot(epochs, [h["min_act"] for h in history], 'r-', label='Min Act')
ax2.set_ylabel("Min Activation Loss", color='r')
axes[1, 2].set_title("Regularization Losses")

plt.tight_layout()
plt.savefig(config.output_dir / "training_curves_v22.png", dpi=150)
plt.show()

# Summarize the run: best values across all epochs, then last-epoch losses.
print("\nv22.0 Final Results:")
for line in (
    f"  Best Recall: {max(entry['recall'] for entry in history):.1f}%",
    f"  Best MRR: {max(entry['mrr'] for entry in history):.4f}",
    f"  Final InfoNCE: {history[-1]['infonce']:.4f}",
    f"  Final Loss: {history[-1]['total']:.4f}",
):
    print(line)

## 8. Save Final Model for Benchmark

In [None]:
# Bundle the final-epoch weights, architecture hints and training outcome
# into one checkpoint for the benchmark script. Note: these are the LAST
# epoch's weights, not necessarily the best-recall ones.
final_checkpoint = dict(
    model_state_dict=model.state_dict(),
    config=dict(
        model_name=config.model_name,
        max_length=config.max_length,
        vocab_size=model.config.vocab_size,
        hidden_size=model.config.hidden_size,
        version="v22.0",
    ),
    training_info=dict(
        total_epochs=config.total_epochs,
        final_recall=history[-1]["recall"],
        final_mrr=history[-1]["mrr"],
        best_recall=max(entry["recall"] for entry in history),
    ),
)

torch.save(final_checkpoint, config.output_dir / "checkpoint.pt")
print(f"Final model saved to {config.output_dir / 'checkpoint.pt'}")

# Store the tokenizer alongside the weights so inference can reload it.
tokenizer.save_pretrained(config.output_dir / "tokenizer")
print(f"Tokenizer saved to {config.output_dir / 'tokenizer'}")

## Next Steps

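1. Run the benchmark with the new v22.0 model. `checkpoint.pt` holds the final-epoch weights; `best_model.pt` holds the epoch with the best validation recall: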
   ```bash
   python benchmark/run_benchmark.py --model-path outputs/v22.0_infonce/checkpoint.pt
   ```

2. Compare against the v21.4 baseline and other methods

3. Test on known problem terms: 추천 (recommendation), 데이터베이스 (database), 증상 (symptom), 질환 (disease), 인슐린 (insulin)

4. If needed, adjust hyperparameters and retrain