# v21.3 Inference Test - Korean Synonym Expansion

This notebook tests the Korean synonym-expansion capability of the trained v21.3 SPLADE model by inspecting which vocabulary tokens each query term activates.

## Tests

1. **General terms**: Common Korean words
2. **Legal terms**: 법률 용어
3. **Medical terms**: 의료 용어
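4. **Comparison with v21.2**: Synonym recall of v21.3 vs. v21.2 on a fixed set of expected synonyms
5. **Sparsity analysis**: Share of inactive vocabulary dimensions per query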

In [None]:
import sys
from pathlib import Path

def find_project_root():
    """Locate the repository root by walking up from the working directory.

    The working directory and then each of its ancestors is checked in order;
    the first one containing a ``pyproject.toml`` or a ``src`` entry wins.
    When no directory matches, the working directory's grandparent is
    returned as a fallback guess.
    """
    cwd = Path.cwd()

    def _looks_like_root(directory):
        return (directory / "pyproject.toml").exists() or (directory / "src").exists()

    return next(
        (directory for directory in (cwd, *cwd.parents) if _looks_like_root(directory)),
        cwd.parent.parent,
    )

# Resolve the repository root and put it first on sys.path so project-local
# packages can be imported from this notebook.
PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

import os
# Set before `transformers` is imported below so the Hugging Face tokenizers
# library sees it; presumably meant to silence its parallelism/fork warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import Dict, List, Tuple
import numpy as np
from collections import defaultdict
# NOTE(review): json, F, Dict and defaultdict are not referenced in the
# visible cells of this notebook — confirm before removing them.

# Quick environment sanity check.
print(f"Project root: {PROJECT_ROOT}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Checkpoint locations: the model under test (v21.3) and its baseline (v21.2).
V21_3_MODEL_PATH = PROJECT_ROOT.joinpath("outputs", "v21.3_korean_enhanced", "best_model.pt")
V21_2_MODEL_PATH = PROJECT_ROOT.joinpath("outputs", "v21.2_korean_legal_medical", "best_model.pt")

_checkpoints = (("v21.3", V21_3_MODEL_PATH), ("v21.2", V21_2_MODEL_PATH))
for _label, _path in _checkpoints:
    print(f"{_label} model: {_path}")
for _label, _path in _checkpoints:
    print(f"{_label} exists: {_path.exists()}")

## 1. Define SPLADE Model

In [None]:
class SPLADEModel(nn.Module):
    """SPLADE model for Korean sparse retrieval.

    Wraps a masked-language-model head and turns its per-position vocabulary
    logits into a non-negative, max-pooled vocabulary vector.
    """

    def __init__(self, model_name: str = "skt/A.X-Encoder-base"):
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.config = self.model.config
        self.relu = nn.ReLU()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the pooled sparse vector and per-position peak weights.

        Returns:
            A pair ``(sparse_repr, token_weights)``:
            - ``sparse_repr``: max over sequence positions (dim 1) of
              ``log(1 + relu(logits))`` after padded positions are zeroed.
            - ``token_weights``: max over the last (vocabulary) dimension at
              each position of the same masked scores.
        """
        mlm_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # log(1 + ReLU(x)) keeps every score non-negative and dampens large logits.
        activations = torch.log1p(self.relu(mlm_output.logits))
        # Zero padded positions so they cannot win the max-pool below.
        position_mask = attention_mask.unsqueeze(-1).float()
        activations = activations * position_mask
        sparse_repr = activations.max(dim=1).values
        token_weights = activations.max(dim=-1).values
        return sparse_repr, token_weights

## 2. Load Models

In [None]:
# Shared tokenizer and compute device used by every model in this notebook.
MODEL_NAME = "skt/A.X-Encoder-base"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Tokenizer loaded: {MODEL_NAME}")
print(f"Device: {device}")

In [None]:
# Load v21.3 model (left as None when the checkpoint is missing).
model_v21_3 = None
if V21_3_MODEL_PATH.exists():
    model_v21_3 = SPLADEModel(MODEL_NAME)
    # Deserialize onto the CPU. The weights are copied into the freshly built
    # (CPU-resident) model and the model is moved as a whole afterwards.
    # Mapping straight to `device` would leave a second GPU copy of every
    # weight alive in the global `checkpoint` dict.
    # NOTE(review): depending on the torch version / weights_only default,
    # torch.load may unpickle arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(V21_3_MODEL_PATH, map_location="cpu")
    model_v21_3.load_state_dict(checkpoint["model_state_dict"])
    model_v21_3 = model_v21_3.to(device)
    model_v21_3.eval()
    print("v21.3 model loaded")
    if "eval_results" in checkpoint:
        print(f"  Eval results: {checkpoint['eval_results']}")
else:
    print("v21.3 model not found - please run 03_training.ipynb first")

In [None]:
# Load v21.2 model for comparison (left as None when the checkpoint is missing).
model_v21_2 = None
if V21_2_MODEL_PATH.exists():
    model_v21_2 = SPLADEModel(MODEL_NAME)
    # Deserialize onto the CPU. The weights are copied into the CPU-resident
    # model, which is then moved as a whole. This avoids keeping a duplicate
    # GPU copy alive in the global `checkpoint` dict.
    # NOTE(review): depending on the torch version / weights_only default,
    # torch.load may unpickle arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(V21_2_MODEL_PATH, map_location="cpu")
    model_v21_2.load_state_dict(checkpoint["model_state_dict"])
    model_v21_2 = model_v21_2.to(device)
    model_v21_2.eval()
    print("v21.2 model loaded for comparison")
else:
    print("v21.2 model not found - comparison will be skipped")

## 3. Define Synonym Expansion Function

In [None]:
def get_synonym_expansion(
    model: "SPLADEModel",
    tokenizer,
    text: str,
    device: torch.device,
    top_k: int = 15,
) -> List[Tuple[str, float]]:
    """
    Get top-k activated tokens for a given text.

    Encodes ``text``, runs ``model`` without gradients, and masks every special
    token id. It then returns the strongest vocabulary activations of the
    pooled sparse vector.

    Args:
        model: SPLADE model; it is switched to eval mode as a side effect.
        tokenizer: Tokenizer used both to encode ``text`` and decode ids.
        text: Query term or sentence to expand.
        device: Device the encoded inputs are moved to.
        top_k: Maximum number of vocabulary entries to consider.

    Returns:
        List of (token, weight) tuples sorted by weight descending. Only
        strictly positive (actually activated) entries are included, and tokens
        that decode to an empty string are dropped. The list may therefore be
        shorter than ``top_k``.
    """
    model.eval()

    # Union of the explicitly named special tokens and everything the tokenizer
    # registers as special (e.g. mask/bos/eos), so none leak into the results.
    special_ids = {
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.unk_token_id,
        *getattr(tokenizer, "all_special_ids", ()),
    }
    special_ids.discard(None)

    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        weights, _ = model(inputs["input_ids"], inputs["attention_mask"])
        # Copy so masking does not write into the tensor the model returned.
        weights = weights[0].clone()

        vocab_size = weights.numel()
        mask_ids = [tid for tid in special_ids if 0 <= tid < vocab_size]
        if mask_ids:
            weights[mask_ids] = float("-inf")

        # topk raises when k exceeds the number of elements.
        top_weights, top_indices = weights.topk(min(top_k, vocab_size))

    results = []
    for idx, weight in zip(top_indices.tolist(), top_weights.tolist()):
        # Scores are sorted, so the first non-positive weight means every
        # remaining entry is inactive (0) or a masked special token (-inf).
        if weight <= 0:
            break
        token = tokenizer.decode([idx]).strip()
        if token:
            results.append((token, weight))

    return results

## 4. Test General Terms

In [None]:
# General Korean vocabulary (everyday and IT terms); also reused for the
# sparsity analysis later in the notebook.
general_terms = [
    "추천",
    "검색",
    "인공지능",
    "기계학습",
    "데이터베이스",
    "컴퓨터",
    "스마트폰",
    "프로그래밍",
]

if model_v21_3 is not None:
    print("=" * 80, "General Terms - v21.3", "=" * 80, sep="\n")

    for term in general_terms:
        expansions = get_synonym_expansion(model_v21_3, tokenizer, term, device)
        print(f"\n{term}:")
        # Only the ten strongest activations are shown per term.
        for token, weight in expansions[:10]:
            print(f"  {token}: {weight:.4f}")

## 5. Test Legal Terms (법률 용어)

In [None]:
# Legal-domain vocabulary; the first four are reused for the sparsity analysis.
legal_terms = [
    "손해배상",
    "판결",
    "소송",
    "계약",
    "위반",
    "피고",
    "원고",
    "변호사",
]

if model_v21_3 is not None:
    print("=" * 80, "Legal Terms (법률) - v21.3", "=" * 80, sep="\n")

    for term in legal_terms:
        expansions = get_synonym_expansion(model_v21_3, tokenizer, term, device)
        print(f"\n{term}:")
        # Only the ten strongest activations are shown per term.
        for token, weight in expansions[:10]:
            print(f"  {token}: {weight:.4f}")

## 6. Test Medical Terms (의료 용어)

In [None]:
# Medical-domain vocabulary; the first four are reused for the sparsity analysis.
medical_terms = [
    "진단",
    "치료",
    "처방",
    "증상",
    "질환",
    "당뇨병",
    "고혈압",
    "인슐린",
]

if model_v21_3 is not None:
    print("=" * 80, "Medical Terms (의료) - v21.3", "=" * 80, sep="\n")

    for term in medical_terms:
        expansions = get_synonym_expansion(model_v21_3, tokenizer, term, device)
        print(f"\n{term}:")
        # Only the ten strongest activations are shown per term.
        for token, weight in expansions[:10]:
            print(f"  {token}: {weight:.4f}")

## 7. Compare v21.3 vs v21.2

In [None]:
# Test cases with expected synonyms: (source term, synonyms we hope to see
# among its top activated tokens).
test_cases = [
    # General
    ("추천", ["권장", "제안", "권유", "소개"]),
    ("검색", ["탐색", "조회", "찾기", "서치"]),
    ("인공지능", ["AI", "에이아이", "기계지능"]),
    # Legal
    ("손해배상", ["배상", "보상", "책임", "손해"]),
    ("판결", ["판례", "선고", "결정", "심판"]),
    # Medical
    ("진단", ["진찰", "검진", "소견", "판단"]),
    ("치료", ["처치", "요법", "치유", "시술"]),
]


def _synonym_hit(synonym, token, min_len=2):
    """Return True when ``token`` plausibly surfaces ``synonym``.

    The comparison is case-insensitive and an exact match always counts.
    Otherwise one string must contain the other AND the contained part must be
    at least ``min_len`` characters. This keeps single-syllable subword pieces
    (e.g. "배") from claiming a hit on longer words such as "배상".
    """
    syn, tok = synonym.casefold(), token.casefold()
    if syn == tok:
        return True
    shorter, longer = (syn, tok) if len(syn) <= len(tok) else (tok, syn)
    return len(shorter) >= min_len and shorter in longer


def evaluate_model(model, model_name, cases=None, top_k=20):
    """Score ``model`` on synonym recall over ``cases``.

    Each source term is expanded with ``get_synonym_expansion`` (``top_k``
    entries). An expected synonym counts as a hit when any returned token
    matches it under ``_synonym_hit``.

    Args:
        model: Model to evaluate; ``None`` short-circuits to ``None``.
        model_name: Label for the model (not used in the computation).
        cases: ``(source, expected_synonyms)`` pairs; defaults to the
            module-level ``test_cases``.
        top_k: Number of activated tokens requested per source term.

    Returns:
        ``None`` when ``model`` is ``None``. Otherwise, one dict per case with
        ``source``, ``expected``, ``top_tokens`` (first 10), ``hits`` and
        ``recall``. ``recall`` is the percentage of expected synonyms found,
        or 0.0 when a case lists no synonyms.
    """
    if model is None:
        return None
    if cases is None:
        cases = test_cases

    results = []
    for source, expected in cases:
        expansions = get_synonym_expansion(model, tokenizer, source, device, top_k=top_k)
        top_tokens = [t for t, _ in expansions]

        # Each expected synonym is counted at most once.
        hits = sum(
            any(_synonym_hit(syn, tok) for tok in top_tokens) for syn in expected
        )

        recall = hits / len(expected) * 100 if expected else 0.0
        results.append({
            "source": source,
            "expected": expected,
            "top_tokens": top_tokens[:10],
            "hits": hits,
            "recall": recall,
        })

    return results

# Evaluate both models
results_v21_3 = evaluate_model(model_v21_3, "v21.3")
results_v21_2 = evaluate_model(model_v21_2, "v21.2")

In [None]:
# Compare results
import unicodedata


def _display_width(text):
    """Return how many terminal columns ``text`` occupies.

    East Asian wide/fullwidth characters (including Hangul syllables) take two
    columns, so padding by ``len`` alone misaligns tables of Korean terms.
    """
    return sum(2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1 for ch in text)


def _ljust_display(text, width):
    """Left-justify ``text`` to ``width`` terminal columns."""
    return text + " " * max(0, width - _display_width(text))


def _print_single_model_results(results, version, missing_version):
    """Print per-case recall for one model when the other one is unavailable."""
    print("=" * 80)
    print(f"{version} Results ({missing_version} not available for comparison)")
    print("=" * 80)

    for result in results:
        print(f"\n{result['source']}:")
        print(f"  Expected: {result['expected']}")
        print(f"  Top tokens: {result['top_tokens']}")
        print(f"  Recall: {result['recall']:.0f}%")


if results_v21_3 and results_v21_2:
    print("=" * 100)
    print("Comparison: v21.3 vs v21.2")
    print("=" * 100)

    # Column widths: Source 15, Expected 30, two recall columns 12, Diff 8.
    print(
        f"\n{_ljust_display('Source', 15)} {_ljust_display('Expected', 30)} "
        f"{'v21.3 Recall':>12} {'v21.2 Recall':>12} {'Diff':>8}"
    )
    print("-" * 100)

    total_v21_3 = 0
    total_v21_2 = 0
    pairs = list(zip(results_v21_3, results_v21_2))

    for r3, r2 in pairs:
        diff = r3["recall"] - r2["recall"]
        diff_str = f"+{diff:.0f}%" if diff > 0 else f"{diff:.0f}%"

        # Recall is 11 wide plus "%" so it lines up under the 12-wide headers.
        print(
            f"{_ljust_display(r3['source'], 15)} "
            f"{_ljust_display(str(r3['expected'][:3]), 30)} "
            f"{r3['recall']:>11.0f}% {r2['recall']:>11.0f}% {diff_str:>8}"
        )

        total_v21_3 += r3["recall"]
        total_v21_2 += r2["recall"]

    # Average over the pairs actually compared (zip stops at the shorter list).
    avg_v21_3 = total_v21_3 / len(pairs)
    avg_v21_2 = total_v21_2 / len(pairs)
    avg_diff = avg_v21_3 - avg_v21_2
    avg_diff_str = f"{'+' if avg_diff > 0 else ''}{avg_diff:.1f}%"

    print("-" * 100)
    # "AVERAGE" spans the Source + Expected columns (15 + 1 + 30 = 46).
    print(f"{'AVERAGE':<46} {avg_v21_3:>11.1f}% {avg_v21_2:>11.1f}% {avg_diff_str:>8}")

elif results_v21_3:
    _print_single_model_results(results_v21_3, "v21.3", "v21.2")
elif results_v21_2:
    _print_single_model_results(results_v21_2, "v21.2", "v21.3")
else:
    print("No evaluation results available - load at least one model first")

## 8. Sparsity Analysis

In [None]:
def analyze_sparsity(model, tokenizer, texts, device, threshold=0.01):
    """Measure how sparse the model's pooled vocabulary vector is per text.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` whose
            first output row is the pooled vector; ``None`` skips the analysis.
        tokenizer: Tokenizer called with ``return_tensors="pt"``.
        texts: Iterable of input strings, each encoded on its own.
        device: Device the encoded inputs are moved to.
        threshold: Activations strictly above this value count as active
            (default 0.01, the previously hard-coded cutoff).

    Returns:
        ``None`` if ``model`` is ``None``. Otherwise, one dict per text with
        ``text``, ``non_zero`` (number of active dimensions) and ``sparsity``
        (percentage of inactive dimensions).
    """
    if model is None:
        return None

    sparsity_stats = []

    model.eval()
    with torch.no_grad():
        for text in texts:
            inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            weights, _ = model(inputs["input_ids"], inputs["attention_mask"])
            # Upcast first: NumPy has no bfloat16, so .numpy() would fail on
            # half-precision (bf16) model outputs.
            weights = weights[0].float().cpu().numpy()

            # Count activations above the threshold ("non-zero" in practice).
            non_zero = int(np.count_nonzero(weights > threshold))
            total = weights.size
            sparsity = (total - non_zero) / total * 100

            sparsity_stats.append({
                "text": text,
                "non_zero": non_zero,
                "sparsity": sparsity,
            })

    return sparsity_stats

# Sparsity report over a mix of general, legal and medical terms.
test_texts = general_terms + legal_terms[:4] + medical_terms[:4]

sparsity_v21_3 = analyze_sparsity(model_v21_3, tokenizer, test_texts, device)
sparsity_v21_2 = analyze_sparsity(model_v21_2, tokenizer, test_texts, device)

for version_label, stats in (("v21.3", sparsity_v21_3), ("v21.2", sparsity_v21_2)):
    # Skip versions whose model was not loaded (None) or produced no stats.
    if not stats:
        continue
    avg_sparsity = np.mean([entry["sparsity"] for entry in stats])
    avg_nonzero = np.mean([entry["non_zero"] for entry in stats])
    print(f"\n{version_label} Sparsity Analysis:")
    print(f"  Average sparsity: {avg_sparsity:.2f}%")
    print(f"  Average non-zero activations: {avg_nonzero:.1f}")

## Summary

### Key Findings

| Aspect | v21.2 | v21.3 |
|--------|-------|-------|
| Data Quality | ~50% noise | < 10% noise |
| Evaluation | Saturated at 100% | Recall@K, MRR |
| Medical Domain | Load failed | All 4 configs |
| Hard Negatives | Random | Difficulty-balanced |

### Next Steps

1. Upload to HuggingFace Hub
2. Deploy to OpenSearch
3. Run end-to-end retrieval evaluation