# v19 Inference Test - XLM-RoBERTa-large with High-Quality Data

This notebook evaluates the v19 model, trained on high-quality Korean–English term pairs from MUSE (Wikidata pairs excluded).

## v19 Key Features:
- **Dataset**: v19_high_quality (~18K pairs, MUSE only, no wikidata)
- **Model**: xlm-roberta-large
- **Learning rate**: 2e-6 (same as v17)
- **Epochs**: 10
- **Loss weights**: self=2.0, target=5.0, margin=3.0, negative=0.5, sparsity=0.005

In [None]:
import sys
from pathlib import Path

# Locate the repository root so project modules can be imported
def find_project_root():
    """Return the closest directory, starting at the cwd and walking up, that
    contains a ``pyproject.toml`` file or a ``src`` entry.

    If no ancestor qualifies, fall back to the cwd's grandparent.
    """
    start = Path.cwd()
    markers = ("pyproject.toml", "src")
    for candidate in (start, *start.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return Path.cwd().parent.parent

# Put the project root on sys.path so `src.model.splade_model` resolves.
# This must run BEFORE the `from src...` import below, which is why these
# imports are not at the top of the cell.
PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

import torch
import json
from transformers import AutoTokenizer
from src.model.splade_model import create_splade_model

# Environment summary for reproducibility
print(f"Project root: {PROJECT_ROOT}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Load v19 Model

In [None]:
# Load v19 model
model_path = PROJECT_ROOT / "outputs" / "v19_xlm_large" / "best_model.pt"
print(f"Loading model from: {model_path}")
print(f"Model exists: {model_path.exists()}")

# NOTE: weights_only=False unpickles arbitrary Python objects (the checkpoint
# carries a "config" mapping next to the weights). Only load trusted files.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=False)
config = checkpoint["config"]


def _format_metric(value, spec=".1f", suffix=""):
    """Format an optional checkpoint metric for display.

    Returns "N/A" when the metric is missing (``None``). A value that cannot
    be formatted with *spec* is shown via ``str()`` instead of raising. The
    original ``f"{checkpoint.get(k, 'N/A'):.1f}"`` raised ValueError for
    missing keys, because ``'.1f'`` is not a valid format spec for a str.
    """
    if value is None:
        return "N/A"
    try:
        return f"{format(value, spec)}{suffix}"
    except (TypeError, ValueError):
        return str(value)


print("\nModel Configuration:")
for key, value in config.items():
    # Path entries are skipped (as in the original listing)
    if not isinstance(value, Path):
        print(f"  {key}: {value}")

print("\nBest Model Info:")
print(f"  Epoch: {checkpoint.get('epoch', 'N/A')}")
print(f"  Korean Rate: {_format_metric(checkpoint.get('ko_rate'), suffix='%')}")
print(f"  English Rate: {_format_metric(checkpoint.get('en_rate'), suffix='%')}")
print(f"  Combined Score: {_format_metric(checkpoint.get('combined_score'))}")

In [None]:
# Tokenizer: use the backbone recorded in the checkpoint config,
# defaulting to xlm-roberta-large when the key is absent.
model_name = config.get("model_name", "xlm-roberta-large")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print(f"Tokenizer: {model_name}")
print(f"Vocab size: {tokenizer.vocab_size:,}")

# Rebuild the SPLADE architecture, restore the trained weights (still on
# CPU, matching map_location above) and switch to evaluation mode.
model = create_splade_model(model_name=model_name, use_expansion=True, expansion_mode="mlm")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Run on the GPU when one is visible, otherwise stay on CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
print(f"Device: {device}")

## 2. Define Inference Helper Functions

In [None]:
def encode_term(term: str, top_k: int = 20) -> dict:
    """Encode *term* with the SPLADE model and return its strongest vocabulary weights.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        term: Text to encode (truncated to 64 tokens).
        top_k: Number of highest-weighted entries to return. Values larger
            than the weight vector are clamped to its length instead of
            making ``topk`` raise.

    Returns:
        dict with keys:
            "term": the input text,
            "tokens": list of ``(decoded_token, weight)`` pairs, highest first,
            "top_indices": the vocabulary ids matching "tokens".
    """
    inputs = tokenizer(
        term,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        weights, _ = model(inputs["input_ids"], inputs["attention_mask"])

    # First (only) item of the batch. Presumably one weight per vocabulary id;
    # TODO confirm the output layout against create_splade_model.
    term_weights = weights[0]
    k = min(top_k, term_weights.shape[-1])
    # One topk call yields both values and indices (previously computed twice)
    top = term_weights.topk(k)
    top_indices = top.indices.tolist()
    top_values = top.values.tolist()
    top_tokens = [tokenizer.decode([idx]).strip() for idx in top_indices]

    return {
        "term": term,
        "tokens": list(zip(top_tokens, top_values)),
        "top_indices": top_indices,
    }


def display_result(result: dict, expected_en: list = None):
    """Print an ``encode_term`` result and flag Korean/English matches.

    Args:
        result: Dict shaped like ``encode_term``'s return value: ``"term"`` is
            the input text and ``"tokens"`` a list of ``(token, weight)`` pairs.
        expected_en: Optional English words the term should expand to. When
            falsy, the English summary line and ``[EN]`` markers are skipped.

    Only the first 10 tokens are inspected and printed.
    - A token is a Korean hit (``[KO]``) when it and the input term are
      substrings of each other.
    - It is an English hit (``[EN]``) when it and an expected word are
      case-insensitive substrings of each other.
    - Empty tokens never match. ``""`` is a substring of every string, so it
      would otherwise mark every expected word as activated.
    """
    term = result['term']
    top_pairs = result['tokens'][:10]
    top_tokens = [tok for tok, _ in top_pairs]
    expected = expected_en or []

    def is_ko_hit(tok):
        return bool(tok) and (term in tok or tok in term)

    def is_en_hit(tok, en):
        tok_l, en_l = tok.lower(), en.lower()
        return bool(tok_l) and bool(en_l) and (en_l in tok_l or tok_l in en_l)

    print(f"\n{'='*60}")
    print(f"Input: {term}")
    print(f"{'='*60}")

    # Check if Korean is preserved
    ko_preserved = any(is_ko_hit(tok) for tok in top_tokens)

    # Check which expected English words are activated (each reported once);
    # the same predicate drives the [EN] markers below so both agree.
    en_found = [en for en in expected if any(is_en_hit(tok, en) for tok in top_tokens)]

    print(f"Korean preserved: {'Yes' if ko_preserved else 'No'}")
    if expected_en:
        print(f"English activated: {en_found if en_found else 'None'}")

    print("\nTop 10 tokens:")
    for rank, (token, weight) in enumerate(top_pairs, start=1):
        marker = ""
        if is_ko_hit(token):
            marker = " [KO]"
        elif any(is_en_hit(token, en) for en in expected):
            marker = " [EN]"
        print(f"  {rank:2}. {token:20} {weight:.4f}{marker}")

## 3. Test Korean-English Term Expansion

In [None]:
# Quick smoke test. Each entry pairs a Korean term with English tokens we
# hope to see among its top expansions.
test_cases = [
    ("추천", ["recommend", "recommendation", "suggest"]),
    ("검색", ["search", "retrieval", "query"]),
    ("인공지능", ["artificial", "intelligence", "AI"]),
    ("신경망", ["neural", "network", "deep"]),
    ("기계학습", ["machine", "learning", "ML"]),
    ("강화학습", ["reinforcement", "learning", "RL"]),
]

for case in test_cases:
    ko_term, en_terms = case
    result = encode_term(ko_term)
    display_result(result, en_terms)

## 4. Comprehensive Test Suite

In [None]:
# Extended test cases
comprehensive_tests = [
    # NLP/ML terms
    ("자연어처리", ["natural", "language", "NLP", "processing"]),
    ("딥러닝", ["deep", "learning"]),
    ("트랜스포머", ["transformer", "attention"]),
    ("임베딩", ["embedding", "vector"]),
    
    # Software Engineering terms
    ("데이터베이스", ["database", "DB", "data"]),
    ("클라우드", ["cloud", "computing"]),
    ("서버", ["server", "servers"]),
    ("클라이언트", ["client", "clients"]),
    ("프레임워크", ["framework", "frameworks"]),
    ("라이브러리", ["library", "libraries"]),
    
    # DevOps terms
    ("컨테이너", ["container", "docker", "kubernetes"]),
    ("마이크로서비스", ["microservice", "micro", "service"]),
    ("모니터링", ["monitoring", "monitor"]),
    ("배포", ["deployment", "deploy"]),
    
    # Development terms
    ("테스트", ["test", "testing"]),
    ("디버깅", ["debug", "debugging"]),
    ("리팩토링", ["refactoring", "refactor"]),
    ("아키텍처", ["architecture", "architect"]),
    
    # System terms
    ("네트워크", ["network", "networking"]),
    ("운영체제", ["operating", "system", "OS"]),
    ("컴파일러", ["compiler", "compile"]),
    ("알고리즘", ["algorithm", "algorithms"]),
    ("최적화", ["optimization", "optimize"]),
    
    # Security terms
    ("보안", ["security", "secure", "protection"]),
    ("암호화", ["encryption", "encrypt", "crypto"]),
    ("인증", ["authentication", "auth"]),
    
    # Data terms
    ("분석", ["analysis", "analytics", "analyze"]),
    ("인덱싱", ["indexing", "index"]),
    ("쿼리", ["query", "queries"]),
    ("캐싱", ["caching", "cache"]),
    ("스케일링", ["scaling", "scale"]),
]

In [None]:
# Run comprehensive tests and calculate metrics.
# Matching rules are substring heuristics over the top-10 decoded tokens:
#   - Korean preserved: some token and the Korean term are substrings of each other.
#   - English activated: some token and an expected English word are
#     case-insensitive substrings of each other.
# Empty decoded tokens are ignored in both checks. "" is a substring of every
# string, so one blank token would otherwise count as an English hit for
# every term.
print("=" * 80)
print("Comprehensive Test Results")
print("=" * 80)

ko_preserved_count = 0
en_activated_count = 0
total = len(comprehensive_tests)

results_table = []

for ko_term, en_terms in comprehensive_tests:
    result = encode_term(ko_term, top_k=10)
    top_tokens = [t[0] for t in result['tokens']]
    top_values = [t[1] for t in result['tokens']]
    candidate_tokens = [tok for tok in top_tokens if tok]

    # Check Korean preservation
    ko_preserved = any(ko_term in tok or tok in ko_term for tok in candidate_tokens)
    if ko_preserved:
        ko_preserved_count += 1

    # Check English activation (each expected word reported at most once)
    en_found = [
        en
        for en in en_terms
        if any(
            en.lower() in tok.lower() or tok.lower() in en.lower()
            for tok in candidate_tokens
        )
    ]
    en_activated = bool(en_found)
    if en_activated:
        en_activated_count += 1

    # Store result
    ko_mark = "o" if ko_preserved else "x"
    en_mark = "o" if en_activated else "x"
    top_3 = [f"{t}({v:.2f})" for t, v in zip(top_tokens[:3], top_values[:3])]

    results_table.append({
        "term": ko_term,
        "ko": ko_mark,
        "en": en_mark,
        "en_found": en_found,
        "top_3": top_3,
    })

    print(f"{ko_term:12} | KO:{ko_mark} EN:{en_mark} | {', '.join(top_3)}")

# Summary
print("\n" + "=" * 80)
print("Summary")
print("=" * 80)
# Guard against an empty test list so the summary never divides by zero
ko_rate = ko_preserved_count / total * 100 if total else 0.0
en_rate = en_activated_count / total * 100 if total else 0.0
combined = ko_rate + en_rate  # ranges 0-200

print(f"Korean Preservation: {ko_preserved_count}/{total} ({ko_rate:.1f}%)")
print(f"English Activation:  {en_activated_count}/{total} ({en_rate:.1f}%)")
print(f"Combined Score:      {combined:.1f}")

## 5. Training History Analysis

In [None]:
import matplotlib.pyplot as plt

# Load training history: a JSON list with one dict of loss components per
# epoch (keys read below: total, self, target, margin, negative).
history_path = PROJECT_ROOT / "outputs" / "v19_xlm_large" / "training_history.json"

if history_path.exists():
    # JSON text is UTF-8 by specification. Decode it explicitly rather than
    # with the platform's locale-dependent default encoding.
    with open(history_path, "r", encoding="utf-8") as f:
        history = json.load(f)

    print(f"Training epochs: {len(history)}")
    print("\nLoss components per epoch:")
    for i, epoch in enumerate(history):
        print(f"  Epoch {i+1}: total={epoch['total']:.4f}, self={epoch['self']:.4f}, "
              f"target={epoch['target']:.4f}, margin={epoch['margin']:.6f}, "
              f"negative={epoch['negative']:.4f}")
else:
    print(f"Training history not found at: {history_path}")
    history = None

In [None]:
# Plot training curves on a 2x2 grid, one loss component per panel.
# The total, self and target components are plotted negated; the negative
# component is plotted as-is.
if history:
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    epochs = range(1, len(history) + 1)

    # (row, col, history key, negate?, title, line colour)
    panel_specs = [
        (0, 0, 'total', True, 'Total Loss (negated for visualization)', '#3498db'),
        (0, 1, 'self', True, 'Self Loss (Korean Preservation)', '#2ecc71'),
        (1, 0, 'target', True, 'Target Loss (English Activation)', '#e74c3c'),
        (1, 1, 'negative', False, 'Negative Loss', '#9b59b6'),
    ]

    for r, c, metric, negate, panel_title, panel_color in panel_specs:
        ax = axes[r, c]
        series = [-h[metric] if negate else h[metric] for h in history]
        ax.plot(epochs, series, '-o', color=panel_color)
        ax.set_title(panel_title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

## 6. Detailed Analysis of Specific Terms

In [None]:
# Detailed look at a handful of terms. encode_term fetches 15 tokens per
# term; only the 10 strongest are printed.
banner = "=" * 80
print(banner)
print("Detailed Analysis: Sample Terms")
print(banner)

detail_terms = ["추천", "검색", "인공지능", "데이터베이스", "보안"]

for term in detail_terms:
    result = encode_term(term, top_k=15)
    print(f"\n{term}:")
    for rank, (token, weight) in enumerate(result["tokens"][:10], start=1):
        print(f"  {rank:2}. {token:20} {weight:.4f}")

## 7. Dataset Analysis

In [None]:
# Analyze v19 dataset: count pairs per "source" field and show a few samples.
from collections import Counter

dataset_path = PROJECT_ROOT / "dataset" / "v19_high_quality" / "term_pairs.jsonl"

if dataset_path.exists():
    data = []
    sources = Counter()

    with open(dataset_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Skip blank lines (e.g. an extra empty line at the end of the
                # file) instead of crashing in json.loads("")
                continue
            item = json.loads(line)
            data.append(item)
            sources[item.get("source", "unknown")] += 1

    print(f"Total pairs: {len(data):,}")
    print("\nData sources:")
    # most_common(): highest count first; ties keep first-seen order
    for source, count in sources.most_common():
        print(f"  {source}: {count:,} ({count/len(data)*100:.1f}%)")

    print("\nSample pairs:")
    for item in data[:10]:
        print(f"  {item['ko']} -> {item['en']} ({item.get('source', 'unknown')})")
else:
    print(f"Dataset not found: {dataset_path}")

## 8. Conclusion

### v19 Results Summary

v19 was trained with the following goals:

1. **Remove noisy Wikidata pairs** from the v18 dataset
2. **Favor quality over quantity** (~18K pairs vs. ~33K previously)
3. **Reuse the hyperparameters** of the successful v17 model

### Key Observations

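- Korean preservation rate: share of test terms whose top-10 expansion tokens still contain the input Korean term (substring match)
- English activation rate: share of test terms whose top-10 tokens include at least one expected English word (case-insensitive substring match)
- Combined score = Korean rate + English rate (range 0–200)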

### Next Steps

1. **Analyze data quality**: Investigate what makes certain translation pairs effective
2. **Loss function tuning**: Consider adjusting target loss weight for better English activation
3. **Data filtering**: Apply stricter quality filters to identify high-impact pairs