# v21.4 Comprehensive Evaluation

## Evaluation Metrics

1. **Single-term Recall**: Focus on previously problematic terms
2. **Garbage Detection**: Identify invalid token outputs
3. **Sentence-level Performance**: MRR, Recall@K
4. **Version Comparison**: v21.2 vs v21.3 vs v21.4

In [1]:
import sys
from pathlib import Path

def find_project_root():
    """Locate the project root by walking upwards from the working directory.

    The working directory itself is checked first, then each ancestor in
    turn. The first directory containing either a ``pyproject.toml`` entry
    or a ``src`` entry is returned. When no directory qualifies, the
    grandparent of the working directory is returned as a fallback.
    """
    start = Path.cwd()
    markers = ("pyproject.toml", "src")
    for candidate in (start, *start.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return start.parent.parent

# Resolve the project root once and put it at the front of sys.path.
# NOTE(review): presumably so project-local modules are importable from this
# notebook; no such import is visible in this file — confirm.
PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

import os
# Set before `transformers` is imported further down; keep this ordering.
# NOTE(review): presumably silences the tokenizers parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import AutoTokenizer, AutoModelForMaskedLM
from typing import Dict, List, Tuple, Set
from collections import defaultdict

# Echo the resolved environment so the captured output records where/how this ran.
print(f"Project root: {PROJECT_ROOT}")
print(f"CUDA available: {torch.cuda.is_available()}")

Project root: /home/west/Documents/cursor-workspace/opensearch-neural-pre-train
CUDA available: True


## 1. Load Models

In [2]:
# Checkpoint locations for the three model versions under comparison
V21_4_MODEL_PATH = PROJECT_ROOT.joinpath("outputs", "v21.4_korean_enhanced", "best_model.pt")
V21_3_MODEL_PATH = PROJECT_ROOT.joinpath("outputs", "v21.3_korean_enhanced", "best_model.pt")
V21_2_MODEL_PATH = PROJECT_ROOT.joinpath("outputs", "v21.2_korean_legal_medical", "best_model.pt")

# Base encoder identifier, and the device to run on (CUDA when available)
MODEL_NAME = "skt/A.X-Encoder-base"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report which checkpoints are present on disk
print(f"v21.4: {V21_4_MODEL_PATH.exists()}")
print(f"v21.3: {V21_3_MODEL_PATH.exists()}")
print(f"v21.2: {V21_2_MODEL_PATH.exists()}")

v21.4: True
v21.3: True
v21.2: True


In [3]:
class SPLADEModel(nn.Module):
    """Sparse encoder built on top of a pretrained masked-language-model head.

    The wrapped model's logits are passed through ``log(1 + ReLU(x))``,
    positions excluded by the attention mask are zeroed, and the result is
    max-pooled to produce the outputs described in ``forward``.
    """

    def __init__(self, model_name: str = "skt/A.X-Encoder-base"):
        super().__init__()
        backbone = AutoModelForMaskedLM.from_pretrained(model_name)
        self.model = backbone
        self.config = backbone.config
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask):
        """Return ``(sparse_repr, token_weights)`` for a batch.

        ``sparse_repr`` is the maximum of the masked scores over dim 1 and
        ``token_weights`` is the maximum over the last dim.
        Assumes logits are (batch, seq_len, vocab) and ``attention_mask`` is
        (batch, seq_len) — TODO confirm against the backbone.
        """
        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
        # Scores are non-negative after log1p(ReLU(.)), so multiplying by the
        # mask leaves masked positions at zero and they cannot win a max.
        keep = attention_mask.unsqueeze(-1).float()
        token_scores = torch.log1p(self.relu(logits)) * keep
        sparse_repr = token_scores.max(dim=1).values
        token_weights, _ = token_scores.max(dim=-1)
        return sparse_repr, token_weights


def load_model(path: Path, model_name: str, device):
    """Build a ``SPLADEModel`` and restore its weights from a checkpoint.

    Args:
        path: Checkpoint file; it must contain a ``"model_state_dict"`` entry.
        model_name: Identifier passed to the ``SPLADEModel`` constructor.
        device: Target device, used both as ``map_location`` and for ``.to()``.

    Returns:
        The model in eval mode on ``device``, or ``None`` when ``path`` does
        not exist.
    """
    if path.exists():
        splade = SPLADEModel(model_name)
        # NOTE: weights_only=False goes through pickle and can execute code
        # embedded in the file; only load checkpoints from a trusted source.
        state = torch.load(path, map_location=device, weights_only=False)
        splade.load_state_dict(state["model_state_dict"])
        return splade.to(device).eval()

    return None


# Load tokenizer
# One tokenizer is shared by all three models; each model below is built from
# the same MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load models
# load_model returns None for a missing checkpoint, so any of these may be None.
model_v21_4 = load_model(V21_4_MODEL_PATH, MODEL_NAME, device)
model_v21_3 = load_model(V21_3_MODEL_PATH, MODEL_NAME, device)
model_v21_2 = load_model(V21_2_MODEL_PATH, MODEL_NAME, device)

# Report which versions were actually loaded.
print(f"v21.4 loaded: {model_v21_4 is not None}")
print(f"v21.3 loaded: {model_v21_3 is not None}")
print(f"v21.2 loaded: {model_v21_2 is not None}")

    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    
  queued_call()


v21.4 loaded: True
v21.3 loaded: True
v21.2 loaded: True


## 2. Token Validation Utilities

In [4]:
def is_valid_korean_token(token: str) -> bool:
    """
    Decide whether a vocabulary token is made up of accepted characters only.

    Every ``##`` marker is removed and surrounding whitespace stripped before
    the characters are inspected. The token is rejected when it is empty,
    when nothing remains after cleaning, when the raw token contains any of
    ``<``, ``>``, ``[`` or ``]`` (special-token markup), or when any remaining
    character lies outside these code-point ranges:

    - Hangul syllables, Hangul Jamo, Hangul Compatibility Jamo
    - ASCII letters and digits
    - Basic CJK (U+4E00 to U+9FFF)
    """
    allowed_ranges = (
        (0xAC00, 0xD7A3),  # Hangul syllables (가-힣)
        (0x1100, 0x11FF),  # Hangul Jamo
        (0x3130, 0x318F),  # Hangul Compatibility Jamo
        (0x0041, 0x005A),  # ASCII upper-case letters
        (0x0061, 0x007A),  # ASCII lower-case letters
        (0x0030, 0x0039),  # ASCII digits
        (0x4E00, 0x9FFF),  # Basic CJK
    )

    if not token:
        return False

    body = token.replace('##', '').strip()
    if not body:
        return False

    # Special-token markup is checked on the raw token, not the cleaned one
    if any(marker in token for marker in ('<', '>', '[', ']')):
        return False

    return all(
        any(low <= ord(ch) <= high for low, high in allowed_ranges)
        for ch in body
    )


# Ids of whichever special tokens the tokenizer defines; attributes that are
# missing or set to None are left out.
_SPECIAL_TOKEN_ATTRS = (
    'pad_token_id', 'cls_token_id', 'sep_token_id',
    'unk_token_id', 'mask_token_id', 'bos_token_id', 'eos_token_id',
)
special_token_ids = {
    tid
    for tid in (getattr(tokenizer, attr_name, None) for attr_name in _SPECIAL_TOKEN_ATTRS)
    if tid is not None
}

## 3. Expansion Function

In [5]:
@torch.no_grad()
def get_synonym_expansion(
    model: nn.Module,
    tokenizer,
    text: str,
    device: torch.device,
    top_k: int = 20,
    min_weight: float = 0.5,
) -> List[Tuple[str, float]]:
    """
    Return the strongest valid vocabulary activations for ``text``.

    The text is tokenized (truncated to 64 tokens) and run through ``model``;
    the first element of the model's first output is used as the per-vocabulary
    weight vector. Ids in the module-level ``special_token_ids`` are excluded.
    Candidates are drawn from the ``top_k * 10`` highest weights, and a
    candidate is kept only if its weight is at least ``min_weight`` and its
    decoded text passes ``is_valid_korean_token``.

    Returns:
        Up to ``top_k`` ``(token, weight)`` pairs in descending weight order.
    """
    model.eval()

    encoded = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=64)
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    sparse_repr, _ = model(batch["input_ids"], batch["attention_mask"])
    vocab_weights = sparse_repr[0].cpu()

    # Push special tokens to the bottom of the ranking
    for special_id in special_token_ids:
        if special_id < len(vocab_weights):
            vocab_weights[special_id] = -float('inf')

    # Over-fetch so that filtering below still has enough candidates
    pool_size = min(top_k * 10, len(vocab_weights))
    pool_weights, pool_ids = vocab_weights.topk(pool_size)

    kept = []
    for token_id, score in zip(pool_ids.tolist(), pool_weights.tolist()):
        if score >= min_weight:
            piece = tokenizer.decode([token_id]).strip()
            if is_valid_korean_token(piece):
                kept.append((piece, score))
                if len(kept) >= top_k:
                    break

    return kept

## 4. Problem Terms Evaluation

In [6]:
# Problem terms from v21.3
# Maps each query term to the expected synonyms that evaluate_problem_terms
# looks for among the model's expansions of that term.
PROBLEM_TERMS = {
    "추천": ["권장", "권유", "제안", "소개"],
    "데이터베이스": ["DB", "디비", "저장소", "데이터"],
    "증상": ["증세", "징후", "현상", "이상"],
    "질환": ["질병", "병", "병증", "이환"],
    "인슐린": ["insulin", "호르몬", "혈당", "당뇨"],
}


def evaluate_problem_terms(model, name: str) -> Dict:
    """Measure synonym recall and garbage outputs for every entry in PROBLEM_TERMS.

    For each term, the top 15 expansions from ``get_synonym_expansion`` are
    compared with the expected synonyms. An expected synonym counts as
    recalled when it and some expanded token contain one another
    (case-insensitive substring match in either direction). A term counts as
    garbage when no expansion survives filtering at all.

    A per-term report and a summary are printed to stdout.

    Args:
        model: Model handed to ``get_synonym_expansion``; ``None`` skips the
            evaluation entirely.
        name: Label printed in the report header.

    Returns:
        ``None`` when ``model`` is ``None``. Otherwise a dict with:
        ``"terms"`` (per-term dict of ``"recall"``, ``"top_tokens"`` and
        ``"is_garbage"``), ``"avg_recall"`` (mean recall in percent) and
        ``"garbage_count"`` (number of terms with no expansions).
    """
    if model is None:
        return None

    results = {}
    total_recall = 0
    garbage_count = 0

    print(f"\n{name} - Problem Terms:")
    print("=" * 60)

    for term, expected in PROBLEM_TERMS.items():
        expansions = get_synonym_expansion(model, tokenizer, term, device, top_k=15)
        is_garbage = not expansions
        if is_garbage:
            garbage_count += 1

        # Recall: each expected synonym scores at most one hit
        top_tokens = [t for t, _ in expansions]
        lowered_tokens = [t.lower() for t in top_tokens]
        hits = sum(
            1
            for exp in expected
            if any(exp.lower() in tok or tok in exp.lower() for tok in lowered_tokens)
        )
        # An empty expectation list yields 0% instead of dividing by zero
        recall = hits / len(expected) * 100 if expected else 0.0
        total_recall += recall

        print(f"\n{term}:")
        print(f"  Expected: {expected}")
        print(f"  Recall: {recall:.0f}%")
        if expansions:
            print(f"  Top expansions: {[(t, f'{w:.2f}') for t, w in expansions[:5]]}")
        else:
            print("  Top expansions: (GARBAGE OUTPUT)")

        results[term] = {
            "recall": recall,
            "top_tokens": top_tokens[:5],
            "is_garbage": is_garbage,
        }

    # An empty term table yields a 0% average instead of dividing by zero
    avg_recall = total_recall / len(PROBLEM_TERMS) if PROBLEM_TERMS else 0.0

    print(f"\nSummary:")
    print(f"  Average Recall: {avg_recall:.1f}%")
    print(f"  Garbage outputs: {garbage_count}/{len(PROBLEM_TERMS)}")

    return {
        "terms": results,
        "avg_recall": avg_recall,
        "garbage_count": garbage_count,
    }

In [7]:
# Evaluate all models
# Each call prints its own per-term report, so the call order here determines
# the order of the printed reports. A result is None when its model is None.
results_v21_4 = evaluate_problem_terms(model_v21_4, "v21.4")
results_v21_3 = evaluate_problem_terms(model_v21_3, "v21.3")
results_v21_2 = evaluate_problem_terms(model_v21_2, "v21.2")


v21.4 - Problem Terms:

추천:
  Expected: ['권장', '권유', '제안', '소개']
  Recall: 100%
  Top expansions: [('추천', '3.53'), ('권장', '2.99'), ('권유', '2.92'), ('##천', '2.87'), ('소개', '2.87')]

데이터베이스:
  Expected: ['DB', '디비', '저장소', '데이터']
  Recall: 50%
  Top expansions: [('데이터베이스', '3.45'), ('DB', '3.09'), ('데이터', '2.98'), ('##베이스', '2.75'), ('컴퓨터', '2.67')]

증상:
  Expected: ['증세', '징후', '현상', '이상']
  Recall: 75%
  Top expansions: [('증상', '3.54'), ('증세', '3.19'), ('증후', '3.03'), ('질환', '2.99'), ('##증', '2.97')]

질환:
  Expected: ['질병', '병', '병증', '이환']
  Recall: 75%
  Top expansions: [('질환', '3.49'), ('질병', '3.28'), ('증상', '3.07'), ('병', '3.07'), ('##병', '2.98')]

인슐린:
  Expected: ['insulin', '호르몬', '혈당', '당뇨']
  Recall: 75%
  Top expansions: [('인슐린', '3.50'), ('##슐린', '3.07'), ('호르몬', '3.01'), ('혈당', '2.98'), ('당뇨병', '2.83')]

Summary:
  Average Recall: 75.0%
  Garbage outputs: 0/5

v21.3 - Problem Terms:

추천:
  Expected: ['권장', '권유', '제안', '소개']
  Recall: 0%
  Top expansions: [('##헟', '1.58'), 

## 5. General Terms Evaluation

In [8]:
# General evaluation test cases
# Each entry is (source term, expected synonyms); evaluate_general_terms
# expands the source term and checks the expansions against the expected list.
GENERAL_TEST_CASES = [
    # General
    ("검색", ["탐색", "조회", "찾기", "서치"]),
    ("인공지능", ["AI", "에이아이", "기계지능"]),
    ("컴퓨터", ["PC", "전산", "컴퓨팅"]),
    # Legal
    ("손해배상", ["배상", "보상", "손해"]),
    ("판결", ["판례", "선고", "결정"]),
    ("소송", ["재판", "법정", "송사"]),
    # Medical
    ("진단", ["진찰", "검진", "판단"]),
    ("치료", ["처치", "요법", "치유"]),
    ("당뇨병", ["당뇨", "혈당", "diabetes"]),
]


def evaluate_general_terms(model, name: str) -> Dict:
    """Compute synonym recall for every entry in GENERAL_TEST_CASES.

    Each source term is expanded with ``get_synonym_expansion`` (top 20). An
    expected synonym counts as recalled when it and some expanded token
    contain one another (case-insensitive substring match in either
    direction). Nothing is printed; ``name`` is accepted but not used.

    Returns:
        ``None`` when ``model`` is ``None``; otherwise a dict with
        ``"results"`` (one ``{"source", "recall"}`` dict per test case, in
        order) and ``"avg_recall"`` (mean recall in percent).
    """
    if model is None:
        return None

    per_case = []
    for source, expected in GENERAL_TEST_CASES:
        expansions = get_synonym_expansion(model, tokenizer, source, device, top_k=20)
        candidates = [token for token, _ in expansions]

        matched = sum(
            1
            for synonym in expected
            if any(
                synonym.lower() in cand.lower() or cand.lower() in synonym.lower()
                for cand in candidates
            )
        )
        per_case.append({"source": source, "recall": matched / len(expected) * 100})

    recall_sum = sum(entry["recall"] for entry in per_case)
    return {
        "results": per_case,
        "avg_recall": recall_sum / len(GENERAL_TEST_CASES),
    }

In [9]:
# Run the general-term benchmark for each model version
general_v21_4 = evaluate_general_terms(model_v21_4, "v21.4")
general_v21_3 = evaluate_general_terms(model_v21_3, "v21.3")
general_v21_2 = evaluate_general_terms(model_v21_2, "v21.2")

# Table header
print("\nGeneral Terms Recall Comparison:")
print("=" * 80)
print(f"{'Source':<15} {'Expected':<25} {'v21.4':>10} {'v21.3':>10} {'v21.2':>10}")
print("-" * 80)

# One row per test case; a version that was not evaluated shows as 0
for i, (source, expected) in enumerate(GENERAL_TEST_CASES):
    r4, r3, r2 = (
        run["results"][i]["recall"] if run else 0
        for run in (general_v21_4, general_v21_3, general_v21_2)
    )

    print(f"{source:<15} {str(expected[:2]):<25} {r4:>9.0f}% {r3:>9.0f}% {r2:>9.0f}%")

# Footer row with the per-version averages
print("-" * 80)
avg_4, avg_3, avg_2 = (
    run["avg_recall"] if run else 0
    for run in (general_v21_4, general_v21_3, general_v21_2)
)
print(f"{'AVERAGE':<40} {avg_4:>9.1f}% {avg_3:>9.1f}% {avg_2:>9.1f}%")


General Terms Recall Comparison:
Source          Expected                       v21.4      v21.3      v21.2
--------------------------------------------------------------------------------
검색              ['탐색', '조회']                     75%        75%        75%
인공지능            ['AI', '에이아이']                   67%        67%        67%
컴퓨터             ['PC', '전산']                    100%       100%       100%
손해배상            ['배상', '보상']                    100%       100%       100%
판결              ['판례', '선고']                    100%       100%       100%
소송              ['재판', '법정']                     33%        67%        33%
진단              ['진찰', '검진']                     67%        67%       100%
치료              ['처치', '요법']                    100%        67%        67%
당뇨병             ['당뇨', '혈당']                     67%        67%        67%
--------------------------------------------------------------------------------
AVERAGE                                       78.7%   

## 6. Natural Language Queries

In [10]:
# Natural language queries
# Full-sentence inputs passed one by one to get_synonym_expansion by
# evaluate_nl_queries; no expected answers are attached to them.
NL_QUERIES = [
    "손해배상 청구 소송을 제기하려면 어떻게 해야 하나요?",
    "고혈압 환자의 식이요법에 대해 알려주세요.",
    "인공지능 기술의 최신 발전 동향",
    "서울에서 부산까지 KTX 소요 시간",
]


def evaluate_nl_queries(model, name: str):
    """Print the leading expansions for every query in NL_QUERIES.

    Each query is expanded with ``get_synonym_expansion`` (top 10) and the
    first six ``(token, weight)`` pairs are printed with weights formatted to
    two decimals. A query with no surviving expansions is reported as garbage
    output. Returns nothing; does nothing when ``model`` is ``None``.
    """
    if model is None:
        return

    print(f"\n{name} - Natural Language Queries:")
    print("=" * 70)

    for query in NL_QUERIES:
        expansions = get_synonym_expansion(model, tokenizer, query, device, top_k=10)
        print(f"\nQuery: {query}")
        if not expansions:
            print("Top expansions: (GARBAGE OUTPUT)")
            continue
        print(f"Top expansions: {[(t, f'{w:.2f}') for t, w in expansions[:6]]}")


# Print natural-language expansions for v21.4 and v21.3 (v21.2 is not included here).
evaluate_nl_queries(model_v21_4, "v21.4")
evaluate_nl_queries(model_v21_3, "v21.3")


v21.4 - Natural Language Queries:

Query: 손해배상 청구 소송을 제기하려면 어떻게 해야 하나요?
Top expansions: [('##을', '3.64'), ('##상', '3.48'), ('##하', '3.44'), ('##를', '3.42'), ('##는', '3.41'), ('##면', '3.34')]

Query: 고혈압 환자의 식이요법에 대해 알려주세요.
Top expansions: [('##에', '3.66'), ('##의', '3.62'), ('##는', '3.39'), ('##을', '3.36'), ('고', '3.35'), ('##법', '3.33')]

Query: 인공지능 기술의 최신 발전 동향
Top expansions: [('##의', '3.63'), ('##는', '3.41'), ('기술', '3.35'), ('인공지능', '3.34'), ('##을', '3.26'), ('##에', '3.26')]

Query: 서울에서 부산까지 KTX 소요 시간
Top expansions: [('##에', '3.64'), ('##서', '3.54'), ('##지', '3.54'), ('시간', '3.38'), ('##까', '3.35'), ('서울', '3.30')]

v21.3 - Natural Language Queries:

Query: 손해배상 청구 소송을 제기하려면 어떻게 해야 하나요?
Top expansions: [('##하', '3.27'), ('##상', '3.24'), ('##면', '3.23'), ('##요', '3.18'), ('소송', '3.14'), ('하나', '3.10')]

Query: 고혈압 환자의 식이요법에 대해 알려주세요.
Top expansions: [('##법', '3.41'), ('##주', '3.31'), ('고', '3.24'), ('##요', '3.14'), ('##의', '3.02'), ('식이', '2.99')]

Query: 인공지능 기술의 최신 발전 동향
Top e

## 7. Final Summary

In [11]:
# Final summary: a markdown-style table comparing the three versions, followed
# by the change in problem-term recall between v21.3 and v21.4.
print("\n" + "=" * 70)
print("v21.4 Evaluation Summary")
print("=" * 70)

print("\n| Metric | v21.2 | v21.3 | v21.4 | Target |")
print("|--------|-------|-------|-------|--------|")

# Problem terms recall (a version that was not evaluated shows as 0)
prob_4 = results_v21_4["avg_recall"] if results_v21_4 else 0
prob_3 = results_v21_3["avg_recall"] if results_v21_3 else 0
prob_2 = results_v21_2["avg_recall"] if results_v21_2 else 0
print(f"| Problem Terms Recall | {prob_2:.1f}% | {prob_3:.1f}% | {prob_4:.1f}% | 80%+ |")

# General terms recall
gen_4 = general_v21_4["avg_recall"] if general_v21_4 else 0
gen_3 = general_v21_3["avg_recall"] if general_v21_3 else 0
gen_2 = general_v21_2["avg_recall"] if general_v21_2 else 0
print(f"| General Terms Recall | {gen_2:.1f}% | {gen_3:.1f}% | {gen_4:.1f}% | 80%+ |")

# Garbage outputs. The denominator follows the size of PROBLEM_TERMS so the
# table stays correct if terms are added or removed.
n_terms = len(PROBLEM_TERMS)
garb_4 = results_v21_4["garbage_count"] if results_v21_4 else 0
garb_3 = results_v21_3["garbage_count"] if results_v21_3 else 0
garb_2 = results_v21_2["garbage_count"] if results_v21_2 else 0
print(f"| Garbage Outputs | {garb_2}/{n_terms} | {garb_3}/{n_terms} | {garb_4}/{n_terms} | 0/{n_terms} |")

print("\n" + "=" * 70)

# Improvement analysis (difference in average problem-term recall)
if results_v21_4 and results_v21_3:
    improvement = prob_4 - prob_3
    if improvement > 0:
        print(f"\nImprovement over v21.3: +{improvement:.1f}%")
    elif improvement < 0:
        print(f"\nRegression from v21.3: {improvement:.1f}%")
    else:
        # Equal recall is neither an improvement nor a regression
        print(f"\nNo change from v21.3: {improvement:.1f}%")


v21.4 Evaluation Summary

| Metric | v21.2 | v21.3 | v21.4 | Target |
|--------|-------|-------|-------|--------|
| Problem Terms Recall | 75.0% | 0.0% | 75.0% | 80%+ |
| General Terms Recall | 78.7% | 78.7% | 78.7% | 80%+ |
| Garbage Outputs | 0/5 | 0/5 | 0/5 | 0/5 |


Improvement over v21.3: +75.0%
