# v21.2 Ranking-Based Evaluation

This notebook evaluates SPLADE synonym-expansion models with ranking-based metrics computed against graded relevance judgments.

## Problem Statement

The current metrics (source preservation rate and synonym activation rate) saturate at 100% from epoch 1,
so they cannot distinguish model-quality improvements across training.
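To see why a binary activation rate saturates, here is a toy sketch. It assumes "activation" means at least one relevant token appears in the top-K. The grade lists are hypothetical, not model output. Both rankings count as activated, but the reciprocal rank of the first relevant token tells them apart.

```python
# Illustrative grades for one query's top-20 tokens (0 = not relevant).
early_model = [3, 0, 2, 0, 0] + [0] * 15  # exact synonym at rank 1
late_model = [0] * 9 + [1] + [0] * 10     # only a related term, at rank 10

def activated(grades):
    """Binary activation (assumed definition): 1.0 if any relevant token is present."""
    return 1.0 if any(g > 0 for g in grades) else 0.0

def rr_from_grades(grades):
    """Reciprocal rank of the first relevant token, or 0.0 if there is none."""
    for rank, g in enumerate(grades, start=1):
        if g > 0:
            return 1.0 / rank
    return 0.0

for name, grades in [("early", early_model), ("late", late_model)]:
    print(name, activated(grades), rr_from_grades(grades))
# early 1.0 1.0
# late 1.0 0.1
```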

## Solution: Ranking Metrics

We use three ranking-aware metrics, computed against multi-grade relevance judgments:

1. **Recall@K** (K=5, 10, 20): Fraction of all relevant tokens that appear in the top-K
2. **MRR (Mean Reciprocal Rank)**: How early the first relevant token appears, averaged over queries
3. **nDCG@K**: Ranking quality that rewards placing higher-grade tokens nearer the top

### Relevance Grades

| Grade | Description | Example for "손해배상" (damages) |
|-------|-------------|----------------------|
| 3 | Exact synonym | 배상 (reparation), 보상 (compensation) |
| 2 | Partial match | 손해 (loss), 피해 (harm) |
| 1 | Related term | 사고 (accident), 책임 (liability) |
| 0 | Not relevant | (implicit: any unjudged token) |

In [None]:
import sys
from pathlib import Path

def find_project_root():
    current = Path.cwd()
    for parent in [current] + list(current.parents):
        if (parent / "pyproject.toml").exists() or (parent / "src").exists():
            return parent
    return Path.cwd().parent.parent

PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForMaskedLM
import matplotlib.pyplot as plt
import numpy as np

# Import our evaluation module
from src.evaluation import (
    RankingMetrics,
    EvaluationDataset,
    GradedRelevance,
    ModelComparison,
    create_korean_legal_medical_eval_dataset,
)

print(f"Project root: {PROJECT_ROOT}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Mathematical Formulations

### 1.1 Recall@K

Measures the fraction of all relevant items that appear in the top-K positions.

$$\text{Recall@K} = \frac{|\{\text{relevant items in top-K}\}|}{|\{\text{all relevant items}\}|}$$

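For graded relevance, an item counts as relevant only if its grade is at least a threshold $\theta$ (for example, $\theta = 3$ counts only exact synonyms):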

$$\text{Recall@K}(\theta) = \frac{|\{i \in \text{top-K} : \text{rel}_i \geq \theta\}|}{|\{j : \text{rel}_j \geq \theta\}|}$$
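A minimal pure-Python sketch of the thresholded Recall@K above. The function name `recall_at_k` and the token-string inputs are illustrative assumptions. This is not the `RankingMetrics.compute_recall_at_k` API from `src.evaluation`. The judgments reuse the "손해배상" example from the grades table.

```python
def recall_at_k(ranking, judgments, k, threshold=1):
    """Graded Recall@K: share of tokens with grade >= threshold found in the top-k.

    ranking:   token strings ordered by descending model weight
    judgments: dict token -> grade (0-3); unjudged tokens count as grade 0
    """
    relevant = {t for t, g in judgments.items() if g >= threshold}
    if not relevant:
        return 0.0  # assumption: report 0 when no token meets the threshold
    hits = relevant.intersection(ranking[:k])
    return len(hits) / len(relevant)

judgments = {"배상": 3, "보상": 3, "손해": 2, "피해": 2, "사고": 1, "책임": 1}
# The source term itself ("손해배상") is unjudged, so it counts as grade 0.
ranking = ["손해배상", "배상", "손해", "사고", "법원", "보상"]
print(recall_at_k(ranking, judgments, k=5))               # 3 of 6 relevant -> 0.5
print(recall_at_k(ranking, judgments, k=6, threshold=3))  # 2 of 2 exact -> 1.0
```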

### 1.2 Mean Reciprocal Rank (MRR)

Measures how early the first relevant item appears in the ranking.

$$\text{MRR} = \frac{1}{|Q|} \sum_{q \in Q} \frac{1}{\text{rank}_q}$$

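where $\text{rank}_q$ is the 1-based position of the first relevant item for query $q$. By the usual convention, a query with no relevant item in the ranking contributes $1/\text{rank}_q = 0$.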
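A matching sketch of MRR that follows the zero-for-no-hit convention. `reciprocal_rank` and `mean_reciprocal_rank` are hypothetical helper names. The toy queries are invented for illustration.

```python
def reciprocal_rank(ranking, judgments, threshold=1):
    """1 / (1-based rank of the first token with grade >= threshold); 0.0 if none."""
    for rank, token in enumerate(ranking, start=1):
        if judgments.get(token, 0) >= threshold:
            return 1.0 / rank
    return 0.0

def mean_reciprocal_rank(runs):
    """MRR over queries; `runs` is a list of (ranking, judgments) pairs."""
    if not runs:
        return 0.0
    return sum(reciprocal_rank(r, j) for r, j in runs) / len(runs)

runs = [
    (["손해배상", "배상", "손해"], {"배상": 3, "손해": 2}),  # first hit at rank 2 -> 0.5
    (["계약", "법원", "책임"], {"책임": 1}),                # first hit at rank 3 -> 1/3
    (["진단", "환자"], {"치료": 3}),                        # no hit -> 0.0
]
print(mean_reciprocal_rank(runs))  # (0.5 + 1/3 + 0) / 3 ≈ 0.2778
```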

### 1.3 Normalized Discounted Cumulative Gain (nDCG@K)

Measures ranking quality with graded relevance and position discount.

$$\text{DCG@K} = \sum_{i=1}^{K} \frac{2^{\text{rel}_i} - 1}{\log_2(i + 1)}$$

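$$\text{IDCG@K} = \text{DCG@K of the ideal ranking}$$

Here the ideal ranking lists all judged items in descending order of grade, so IDCG@K is the largest DCG@K any ranking can reach for that query.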

$$\text{nDCG@K} = \frac{\text{DCG@K}}{\text{IDCG@K}}$$

Key properties:
- Grade 3 (exact synonym) contributes $2^3 - 1 = 7$ to gain
- Grade 2 (partial match) contributes $2^2 - 1 = 3$ to gain
- Grade 1 (related term) contributes $2^1 - 1 = 1$ to gain
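- Lower-ranked items are discounted more heavily: the gain at rank 1 is divided by $\log_2 2 = 1$, and the gain at rank 10 by $\log_2 11 \approx 3.46$
- Because DCG@K is divided by IDCG@K, nDCG@K lies in $[0, 1]$, and 1 means the top-K is in ideal order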
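The formulas above translate almost line for line into code. This is a minimal sketch with the exponential gain $2^{\text{rel}} - 1$ and the $\log_2(i + 1)$ discount. `dcg_at_k` and `ndcg_at_k` are hypothetical names rather than the `src.evaluation` API. The judgments are the "손해배상" examples from the grades table.

```python
import math

def dcg_at_k(grades, k):
    """DCG@K with exponential gain (2^rel - 1) and log2(i + 1) discount, i 1-based."""
    return sum((2 ** rel - 1) / math.log2(i + 1)
               for i, rel in enumerate(grades[:k], start=1))

def ndcg_at_k(ranking, judgments, k):
    """nDCG@K: DCG of the model ranking divided by DCG of the ideal ranking."""
    grades = [judgments.get(token, 0) for token in ranking]  # unjudged -> grade 0
    ideal = sorted(judgments.values(), reverse=True)         # all judged items, best first
    idcg = dcg_at_k(ideal, k)
    return dcg_at_k(grades, k) / idcg if idcg > 0 else 0.0

judgments = {"배상": 3, "보상": 3, "손해": 2, "피해": 2, "사고": 1, "책임": 1}
good = ["배상", "보상", "손해", "피해", "사고"]     # grades 3,3,2,2,1: ideal order
weak = ["사고", "손해배상", "손해", "법원", "배상"]  # grades 1,0,2,0,3
print(round(ndcg_at_k(good, judgments, 5), 4))  # 1.0
print(round(ndcg_at_k(weak, judgments, 5), 4))  # ≈ 0.3568: the exact synonym sits at rank 5
```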

## 2. Load Model and Create Evaluation Dataset

In [None]:
# SPLADE Model definition (same as training)
class SPLADEModel(nn.Module):
    """SPLADE model for Korean sparse retrieval."""

    def __init__(self, model_name: str = "skt/A.X-Encoder-base"):
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.config = self.model.config
        self.relu = nn.ReLU()

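    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # (batch, seq_len, vocab_size)
        # SPLADE activation: log(1 + ReLU(logit)) keeps weights non-negative and dampens large logits
        token_scores = torch.log1p(self.relu(logits))
        # Zero out padded positions so they cannot win the max-pooling below
        mask = attention_mask.unsqueeze(-1).float()
        token_scores = token_scores * mask
        # Max-pool over the sequence: one weight per vocabulary entry, shape (batch, vocab_size)
        sparse_repr, _ = token_scores.max(dim=1)
        # Strongest vocabulary activation at each input position, shape (batch, seq_len)
        token_weights = token_scores.max(dim=-1).values
        return sparse_repr, token_weights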
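The forward pass above turns MLM logits into one weight per vocabulary entry. It applies $\log(1 + \mathrm{ReLU}(\text{logit}))$ at every position, zeroes out padding, and takes a max over positions. The pure-Python toy below does the same pooling on a 3-position, 4-token vocabulary so the role of the mask is visible. It uses no PyTorch, `splade_pool` is a hypothetical name, and the logits are made up. Starting the running max at 0.0 matches the masked `max(dim=1)` in `SPLADEModel.forward`, because every activation is non-negative and masked positions contribute exactly 0.

```python
import math

def splade_pool(logits, attention_mask):
    """Toy SPLADE pooling on nested lists.

    logits:         [seq_len][vocab_size] MLM logits for one sequence
    attention_mask: [seq_len] of 1 (real token) or 0 (padding)
    Returns a [vocab_size] list: max over positions of log(1 + ReLU(logit)).
    """
    pooled = [0.0] * len(logits[0])
    for position_logits, keep in zip(logits, attention_mask):
        if not keep:
            continue  # padded positions are zeroed, so they never raise the max
        for v, logit in enumerate(position_logits):
            pooled[v] = max(pooled[v], math.log1p(max(logit, 0.0)))
    return pooled

logits = [
    [2.0, -1.0, 0.5, 0.0],  # position 0
    [0.1, 3.0, -2.0, 0.0],  # position 1
    [9.0, 9.0, 9.0, 9.0],   # position 2 is padding and must be ignored
]
print([round(w, 3) for w in splade_pool(logits, [1, 1, 0])])  # [1.099, 1.386, 0.405, 0.0]
```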

In [None]:
# Configuration
MODEL_NAME = "skt/A.X-Encoder-base"
MODEL_PATH = PROJECT_ROOT / "outputs" / "v21.2_korean_legal_medical" / "best_model.pt"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"Tokenizer vocab size: {tokenizer.vocab_size:,}")

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SPLADEModel(MODEL_NAME)

# Load trained weights if available
if MODEL_PATH.exists():
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    print(f"Loaded model from: {MODEL_PATH}")
else:
    print(f"No trained checkpoint found at {MODEL_PATH}; evaluating the base pretrained weights")

model = model.to(device)
model.eval()
print(f"Model loaded on {device}")

In [None]:
# Create evaluation dataset
eval_dataset = create_korean_legal_medical_eval_dataset()

print("Evaluation Dataset Statistics:")
stats = eval_dataset.statistics()
for key, value in stats.items():
    print(f"  {key}: {value}")

In [None]:
# Show sample queries with graded relevance
print("\nSample Queries with Graded Relevance:")
print("=" * 70)

grade_names = {3: "Exact synonym", 2: "Partial match", 1: "Related term"}

for i, query in enumerate(eval_dataset[:3]):
    print(f"\nQuery {i+1}: '{query.query}' (domain: {query.domain})")
    print("-" * 50)

    for grade in [3, 2, 1]:
        tokens = query.get_tokens_by_grade(grade)
        if tokens:
            print(f"  Grade {grade} ({grade_names[grade]}): {', '.join(sorted(tokens))}")

## 3. Initialize Ranking Metrics and Evaluate

In [None]:
# Initialize ranking metrics calculator
ranking_metrics = RankingMetrics(
    tokenizer=tokenizer,
    k_values=[5, 10, 20, 50],
)

print(f"Configured K values: {ranking_metrics.k_values}")
print(f"Special token IDs excluded: {ranking_metrics.special_token_ids}")

In [None]:
# Evaluate model
print("Evaluating model...")
result = ranking_metrics.evaluate(
    model=model,
    eval_dataset=eval_dataset,
    batch_size=16,
    device=device,
)

print(result.summary())

In [None]:
# Detailed per-query analysis
print("\nPer-Query Metrics (first 5 queries):")
print("=" * 80)

for metrics in result.per_query_metrics[:5]:
    print(f"\nQuery: '{metrics['query']}' (domain: {metrics['domain']})")
    print(f"  MRR: {metrics['mrr']:.4f}")
    print(f"  First relevant rank: {metrics['first_relevant_rank']}")
    print(f"  Recall@5: {metrics['recall@5']:.4f}, Recall@10: {metrics['recall@10']:.4f}")
    print(f"  nDCG@5: {metrics['ndcg@5']:.4f}, nDCG@10: {metrics['ndcg@10']:.4f}")

## 4. Visualize Results

In [None]:
# Plot metrics across K values
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

k_values = sorted(result.recall_at_k.keys())

# Recall@K
recall_values = [result.recall_at_k[k] for k in k_values]
axes[0].bar(range(len(k_values)), recall_values, color='#3498db', alpha=0.8)
axes[0].set_xticks(range(len(k_values)))
axes[0].set_xticklabels([f'@{k}' for k in k_values])
axes[0].set_ylabel('Recall')
axes[0].set_title('Recall@K')
axes[0].set_ylim(0, 1)
for i, v in enumerate(recall_values):
    axes[0].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10)

# nDCG@K
ndcg_values = [result.ndcg_at_k[k] for k in k_values]
axes[1].bar(range(len(k_values)), ndcg_values, color='#2ecc71', alpha=0.8)
axes[1].set_xticks(range(len(k_values)))
axes[1].set_xticklabels([f'@{k}' for k in k_values])
axes[1].set_ylabel('nDCG')
axes[1].set_title('nDCG@K')
axes[1].set_ylim(0, 1)
for i, v in enumerate(ndcg_values):
    axes[1].text(i, v + 0.02, f'{v:.3f}', ha='center', fontsize=10)

# MRR and first-relevant-rank statistics. Ranks are divided by 100 so they share
# the [0, 1] y-axis with MRR; a rank above 100 would be clipped by set_ylim.
metrics_names = ['MRR', 'Mean 1st Rank / 100', 'Median 1st Rank / 100']
metrics_values = [
    result.mrr,
    result.mean_first_relevant_rank / 100 if result.mean_first_relevant_rank else 0,
    result.median_first_relevant_rank / 100 if result.median_first_relevant_rank else 0,
]
colors = ['#e74c3c', '#9b59b6', '#f39c12']
axes[2].bar(range(len(metrics_names)), metrics_values, color=colors, alpha=0.8)
axes[2].set_xticks(range(len(metrics_names)))
axes[2].set_xticklabels(metrics_names, rotation=15)
axes[2].set_title('MRR and First Relevant Rank')
axes[2].set_ylim(0, 1)

plt.tight_layout()
output_dir = PROJECT_ROOT / "outputs" / "v21.2_korean_legal_medical"
output_dir.mkdir(parents=True, exist_ok=True)  # absent if no model has been trained yet
plt.savefig(output_dir / "ranking_metrics.png", dpi=150)
plt.show()

In [None]:
# Per-domain comparison
if result.domain_metrics:
    domains = list(result.domain_metrics.keys())
    n_domains = len(domains)
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    x = np.arange(n_domains)
    width = 0.25
    
    mrr_values = [result.domain_metrics[d]['mrr'] for d in domains]
    ndcg10_values = [result.domain_metrics[d]['ndcg_at_k'].get(10, 0) for d in domains]
    recall10_values = [result.domain_metrics[d]['recall_at_k'].get(10, 0) for d in domains]
    
    ax.bar(x - width, mrr_values, width, label='MRR', color='#3498db')
    ax.bar(x, ndcg10_values, width, label='nDCG@10', color='#2ecc71')
    ax.bar(x + width, recall10_values, width, label='Recall@10', color='#e74c3c')
    
    ax.set_xlabel('Domain')
    ax.set_ylabel('Score')
    ax.set_title('Metrics by Domain')
    ax.set_xticks(x)
    ax.set_xticklabels(domains)
    ax.legend()
    ax.set_ylim(0, 1)
    
    plt.tight_layout()
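    out_dir = PROJECT_ROOT / "outputs" / "v21.2_korean_legal_medical"
    out_dir.mkdir(parents=True, exist_ok=True)  # absent if no model has been trained yet
    plt.savefig(out_dir / "domain_metrics.png", dpi=150)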
    plt.show()
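`result.domain_metrics` is produced inside `src.evaluation`, and this notebook does not show how it is aggregated. As a cross-check, the sketch below macro-averages the per-query dictionaries within each domain and weights every query equally. It uses the `domain`, `mrr`, `recall@10` and `ndcg@10` keys printed earlier. `macro_average_by_domain` is a hypothetical helper and the toy rows are invented. If the module aggregates differently, its numbers may not match this sketch.

```python
from collections import defaultdict
from statistics import mean

def macro_average_by_domain(per_query, metric_keys=("mrr", "ndcg@10", "recall@10")):
    """Average per-query metric dicts within each domain (each query weighted equally)."""
    grouped = defaultdict(list)
    for row in per_query:
        grouped[row["domain"]].append(row)
    return {
        domain: {key: mean(row[key] for row in rows) for key in metric_keys}
        for domain, rows in grouped.items()
    }

toy_rows = [
    {"domain": "legal", "mrr": 1.0, "ndcg@10": 0.8, "recall@10": 0.6},
    {"domain": "legal", "mrr": 0.5, "ndcg@10": 0.4, "recall@10": 0.2},
    {"domain": "medical", "mrr": 0.25, "ndcg@10": 0.3, "recall@10": 0.5},
]
print(macro_average_by_domain(toy_rows))
```

On real results the call would be `macro_average_by_domain(result.per_query_metrics)`.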

## 5. Analyze Individual Query Results

In [None]:
def analyze_query(
    query: str,
    model: nn.Module,
    tokenizer,
    ground_truth: GradedRelevance,
    top_k: int = 20,
    device=None,
):
    """
    Print the model's top-k vocabulary tokens for one query alongside its graded
    ground truth, then report MRR, Recall@K and nDCG@K for that query alone.
    """
    if device is None:
        device = next(model.parameters()).device
    
    # Encode query
    inputs = tokenizer(
        query,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        sparse_repr, _ = model(inputs["input_ids"], inputs["attention_mask"])
    
    # Exclude every special token (padding, CLS/SEP, unknown, mask, ...) from the ranking
    special_ids = set(tokenizer.all_special_ids)
    
    # Mask special tokens
    weights = sparse_repr[0].clone()
    for tid in special_ids:
        weights[tid] = -float('inf')
    
    # Get top-k
    top_values, top_indices = torch.topk(weights, k=top_k)
    
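    # Build relevance map. Only judged terms that encode to exactly one vocabulary ID
    # can be matched against the ranking; multi-token terms are skipped here.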
    relevance_map = {}
    for token, grade in ground_truth.relevance_judgments.items():
        ids = tokenizer.encode(token, add_special_tokens=False)
        if len(ids) == 1:
            relevance_map[ids[0]] = (token, grade)
    
    print(f"\nQuery: '{query}'")
    print(f"Domain: {ground_truth.domain}")
    print(f"Ground Truth: {len(ground_truth.relevance_judgments)} judged tokens")
    print("=" * 70)
    print(f"{'Rank':<6}{'Token':<15}{'Weight':<12}{'Relevance':<12}{'Grade'}")
    print("-" * 70)
    
    grade_labels = {3: "Exact", 2: "Partial", 1: "Related"}
    for rank, (idx, weight) in enumerate(zip(top_indices.tolist(), top_values.tolist()), 1):
        token = tokenizer.decode([idx]).strip()
        if idx in relevance_map:
            _, grade = relevance_map[idx]
            rel_label = grade_labels.get(grade, "?")
            marker = "***" if grade == 3 else ("**" if grade == 2 else "*")
        else:
            rel_label = "-"
            marker = ""

        print(f"{rank:<6}{token:<15}{weight:<12.4f}{rel_label:<12}{marker}")
    
    # Compute metrics for this query
    ranking = [(idx.item(), val.item()) for idx, val in zip(top_indices, top_values) if val > 0]
    metrics = RankingMetrics(tokenizer)
    
    print("\nMetrics:")
    print(f"  MRR: {metrics.compute_mrr(ranking, ground_truth):.4f}")
    for k in [5, 10, 20]:
        recall = metrics.compute_recall_at_k(ranking, ground_truth, k)
        ndcg = metrics.compute_ndcg(ranking, ground_truth, k)
        print(f"  Recall@{k}: {recall:.4f}, nDCG@{k}: {ndcg:.4f}")

In [None]:
# Analyze sample queries
for query_data in eval_dataset[:3]:
    analyze_query(
        query=query_data.query,
        model=model,
        tokenizer=tokenizer,
        ground_truth=query_data,
        top_k=20,
        device=device,
    )
    print()

## 6. Compare Models Across Epochs

This section demonstrates how to compare model quality across training epochs using statistical significance testing.
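`ModelComparison.compare_models` runs the comparison inside `src.evaluation`. To show what a significance check over paired queries does, here is a minimal paired percentile-bootstrap sketch. `paired_bootstrap_ci` is a hypothetical helper, not the module's API. It assumes both score lists are aligned query by query, and the scores in the usage line are invented.

```python
import random

def paired_bootstrap_ci(scores_a, scores_b, n_resamples=10_000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for mean(scores_b - scores_a) over paired queries.

    Returns (mean_diff, lower, upper). If the interval excludes 0, the difference
    is unlikely to be resampling noise at level alpha.
    """
    if len(scores_a) != len(scores_b) or not scores_a:
        raise ValueError("need two non-empty score lists of equal length")
    diffs = [b - a for a, b in zip(scores_a, scores_b)]  # keep the per-query pairing
    n = len(diffs)
    rng = random.Random(seed)
    # Resample queries with replacement and record the mean difference each time
    boot_means = sorted(sum(rng.choices(diffs, k=n)) / n for _ in range(n_resamples))
    lo_idx = int((alpha / 2) * n_resamples)
    hi_idx = int((1 - alpha / 2) * n_resamples) - 1
    return sum(diffs) / n, boot_means[lo_idx], boot_means[hi_idx]

# Invented per-query nDCG@10 scores for two checkpoints on the same five queries
print(paired_bootstrap_ci([0.42, 0.55, 0.31, 0.60, 0.48],
                          [0.51, 0.58, 0.45, 0.66, 0.47], n_resamples=2000))
```

With the results in this section, the inputs could be `[m["ndcg@10"] for m in result_epoch5.per_query_metrics]` and the same expression for `result_epoch25`. This assumes both evaluations iterate over `eval_dataset` in the same order. The sketch resamples per-query differences rather than each model's scores separately. That removes between-query variance from the comparison.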

In [None]:
# Function to load model from checkpoint
def load_model_from_checkpoint(checkpoint_path, model_name, device):
    """Load a SPLADEModel and its trained weights from a checkpoint file."""
    if not checkpoint_path.exists():
        # Fail loudly instead of silently evaluating the untrained base model
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    model = SPLADEModel(model_name)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(device)
    model.eval()
    return model

# Example: Compare epoch 5 vs epoch 25
output_dir = PROJECT_ROOT / "outputs" / "v21.2_korean_legal_medical"
epoch5_path = output_dir / "checkpoint_epoch5.pt"
epoch25_path = output_dir / "checkpoint_epoch25.pt"

if epoch5_path.exists() and epoch25_path.exists():
    print("Loading models from epochs 5 and 25...")
    
    model_epoch5 = load_model_from_checkpoint(epoch5_path, MODEL_NAME, device)
    model_epoch25 = load_model_from_checkpoint(epoch25_path, MODEL_NAME, device)
    
    # Evaluate both models
    print("Evaluating epoch 5 model...")
    result_epoch5 = ranking_metrics.evaluate(model_epoch5, eval_dataset, device=device)
    
    print("Evaluating epoch 25 model...")
    result_epoch25 = ranking_metrics.evaluate(model_epoch25, eval_dataset, device=device)
    
    # Statistical comparison
    comparison = ModelComparison.compare_models(
        result_epoch5,
        result_epoch25,
        model_a_name="Epoch 5",
        model_b_name="Epoch 25",
    )
    
    print(comparison.summary())
else:
    print("Checkpoint files not found. Expected:")
    print(f"  - {epoch5_path}")
    print(f"  - {epoch25_path}")
    print("\nRun training first to generate checkpoints.")

## 7. Creating Custom Evaluation Datasets

This section shows how to create your own evaluation dataset with graded relevance judgments.
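The `synonyms` buckets in Method 1 below line up with the relevance grades defined earlier: exact = 3, partial = 2, related = 1. The sketch assumes this is how `EvaluationDataset.from_synonym_pairs` maps them, which this notebook does not show. `entry_to_judgments` is a hypothetical helper that flattens one entry. When a token is listed in more than one bucket, it keeps the highest grade.

```python
GRADE_BY_BUCKET = {"exact": 3, "partial": 2, "related": 1}  # mirrors the Relevance Grades table

def entry_to_judgments(entry):
    """Flatten one {"source", "synonyms", "domain"} entry into a token -> grade dict."""
    judgments = {}
    for bucket, tokens in entry["synonyms"].items():
        grade = GRADE_BY_BUCKET[bucket]
        for token in tokens:
            # A token listed in several buckets keeps its highest grade
            judgments[token] = max(grade, judgments.get(token, 0))
    return judgments

entry = {
    "source": "기업",
    "synonyms": {"exact": ["회사", "법인"], "partial": ["업체"], "related": ["산업", "회사"]},
    "domain": "business",
}
print(entry_to_judgments(entry))  # {'회사': 3, '법인': 3, '업체': 2, '산업': 1}
```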

In [None]:
# Method 1: From synonym pairs (convenient format)
custom_synonym_data = [
    {
        "source": "기업",
        "synonyms": {
            "exact": ["회사", "법인", "사업체"],
            "partial": ["기업체", "업체", "상사"],
            "related": ["사업", "경영", "산업"],
        },
        "domain": "business",
    },
    {
        "source": "투자",
        "synonyms": {
            "exact": ["출자", "투입", "자본투자"],
            "partial": ["투자금", "자금", "투자액"],
            "related": ["수익", "이익", "자산"],
        },
        "domain": "finance",
    },
]

custom_dataset = EvaluationDataset.from_synonym_pairs(
    custom_synonym_data,
    name="custom_business_finance",
)

print("Custom Dataset Statistics:")
print(custom_dataset.statistics())

In [None]:
# Method 2: Direct GradedRelevance construction (full control)
manual_queries = [
    GradedRelevance(
        query="프로그래밍",
        relevance_judgments={
            "코딩": 3,
            "개발": 3,
            "프로그램": 2,
            "소프트웨어": 2,
            "컴퓨터": 1,
            "언어": 1,
        },
        domain="tech",
    ),
]

manual_dataset = EvaluationDataset(
    queries=manual_queries,
    name="manual_tech_dataset",
)

print("\nManual Dataset:")
print(f"  Queries: {len(manual_dataset)}")
print(f"  First query ideal ranking: {manual_dataset[0].ideal_ranking()}")

In [None]:
# Save and load datasets
save_path = PROJECT_ROOT / "dataset" / "eval_datasets" / "custom_eval.json"
save_path.parent.mkdir(parents=True, exist_ok=True)

custom_dataset.save(save_path)
print(f"Saved dataset to: {save_path}")

# Reload
reloaded = EvaluationDataset.load(save_path)
print(f"Reloaded dataset: {len(reloaded)} queries")

## 8. Summary and Recommendations

### Key Takeaways

1. **nDCG@K is the primary metric** for synonym expansion quality:
   - Captures graded relevance (exact > partial > related)
   - Rewards placing exact synonyms nearer the top of the ranking
   - Normalized to [0, 1] for easy interpretation

2. **MRR complements nDCG**:
   - Measures how early the first relevant synonym (of any grade) appears
   - Useful for applications where only the first result matters

3. **Recall@K for coverage**:
   - Measures completeness of synonym retrieval
   - Use multiple K values (5, 10, 20) to understand recall curve

### Recommended Evaluation Protocol

1. Create domain-specific evaluation datasets with 3-grade relevance
2. Report nDCG@10 as primary metric
3. Use MRR and Recall@10 as secondary metrics
4. Use a paired t-test or a paired bootstrap confidence interval over per-query scores to compare models
5. Report per-domain metrics for specialized applications
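For the paired t-test in step 4, the statistic is the mean per-query difference divided by its standard error. The sketch below computes it with the standard library only. `paired_t_statistic` is a hypothetical helper and the scores are invented. Converting $t$ into a p-value requires the Student-t distribution, which `scipy.stats.ttest_rel` provides when SciPy is available.

```python
import math
from statistics import mean, stdev

def paired_t_statistic(scores_a, scores_b):
    """Paired t statistic and degrees of freedom for mean(scores_b - scores_a)."""
    if len(scores_a) != len(scores_b) or len(scores_a) < 2:
        raise ValueError("need two paired score lists with at least 2 queries")
    diffs = [b - a for a, b in zip(scores_a, scores_b)]
    sd = stdev(diffs)  # sample standard deviation (n - 1 denominator)
    if sd == 0:
        raise ValueError("differences have zero variance; t is undefined")
    t = mean(diffs) / (sd / math.sqrt(len(diffs)))
    return t, len(diffs) - 1

# Invented per-query nDCG@10 scores for two models on the same four queries
print(paired_t_statistic([0.1, 0.2, 0.3, 0.4], [0.2, 0.4, 0.3, 0.6]))
```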

### Integration with Training

Use these ranking metrics during training validation to:
- Detect when model quality actually improves, not just when the loss decreases
- Enable early stopping based on nDCG@K plateau
- Compare different hyperparameter configurations objectively
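Plateau-based early stopping on validation nDCG@K can be as simple as a patience counter. This is a sketch under assumptions: `PlateauEarlyStopping`, `patience` and `min_delta` are hypothetical names, and the nDCG@10 values are invented. In a training loop the score might come from `ranking_metrics.evaluate(...).ndcg_at_k[10]`.

```python
class PlateauEarlyStopping:
    """Stop once the metric fails to beat its best by > min_delta for `patience` evaluations."""

    def __init__(self, patience=3, min_delta=1e-3):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.best_epoch = None
        self.bad_evals = 0

    def step(self, epoch, score):
        """Record a score (higher is better); return True when training should stop."""
        if score > self.best + self.min_delta:
            self.best, self.best_epoch, self.bad_evals = score, epoch, 0
        else:
            self.bad_evals += 1  # no meaningful improvement this evaluation
        return self.bad_evals >= self.patience

stopper = PlateauEarlyStopping(patience=2, min_delta=0.005)
val_ndcg10 = [0.41, 0.47, 0.50, 0.502, 0.503]  # illustrative values, not real results
for epoch, score in enumerate(val_ndcg10, start=1):
    if stopper.step(epoch, score):
        print(f"stop at epoch {epoch}; best nDCG@10 {stopper.best} at epoch {stopper.best_epoch}")
        break
# stop at epoch 5; best nDCG@10 0.5 at epoch 3
```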

In [None]:
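# Save final results
import json

output_dir = PROJECT_ROOT / "outputs" / "v21.2_korean_legal_medical"
output_dir.mkdir(parents=True, exist_ok=True)  # absent if no model has been trained yet

results_dict = result.to_dict()
results_path = output_dir / "ranking_evaluation_results.json"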

with open(results_path, "w", encoding="utf-8") as f:
    json.dump(results_dict, f, ensure_ascii=False, indent=2, default=str)

print(f"Results saved to: {results_path}")