# v12 Inference Test: Improved Term-Level KO-EN Model

Run the trained v12 model on a set of Korean queries and inspect which Korean and English vocabulary tokens it activates.

**Key Improvements from v11:**
- Filtered Wikidata (removed person names)
- Oversampled MUSE dictionary (3x)
- Added IT terminology and common vocabulary
- Increased loss weights for stronger English activation

In [1]:
import sys
from pathlib import Path

def find_project_root():
    candidates = [
        Path.cwd(),
        Path.cwd().parent,
        Path.cwd().parent.parent,
        Path("/home/west/Documents/cursor-workspace/opensearch-neural-pre-train"),
    ]
    for candidate in candidates:
        if (candidate / "CLAUDE.md").exists() or (candidate / ".git").exists():
            return candidate
    return Path("/home/west/Documents/cursor-workspace/opensearch-neural-pre-train")

PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

Project root: /home/west/Documents/cursor-workspace/opensearch-neural-pre-train
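`find_project_root` checks only the current directory and two fixed parent levels before falling back to a hard-coded absolute path. A minimal sketch that walks every ancestor directory and uses the same marker files (`CLAUDE.md`, `.git`) is shown below. `find_project_root_upward` is a hypothetical helper, not part of the project.

```python
from pathlib import Path
from typing import Iterable, Optional

# Marker files copied from the notebook's own check; the helper name is hypothetical.
MARKERS = ("CLAUDE.md", ".git")


def find_project_root_upward(start: Path, markers: Iterable[str] = MARKERS) -> Optional[Path]:
    """Return the nearest directory at or above `start` that contains any marker, else None."""
    markers = tuple(markers)
    start = start.resolve()
    # Check `start` itself first, then each parent up to the filesystem root
    for candidate in (start, *start.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return None
```

In the notebook it could replace the fixed candidate list while keeping the original fallback, e.g. `PROJECT_ROOT = find_project_root_upward(Path.cwd()) or Path("/home/west/Documents/cursor-workspace/opensearch-neural-pre-train")`.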


In [2]:
import torch
from transformers import AutoTokenizer
from src.model.splade_model import create_splade_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


## Load Model

In [3]:
# Load best model checkpoint
checkpoint_path = PROJECT_ROOT / 'outputs' / 'v12_improved' / 'best_model.pt'

if not checkpoint_path.exists():
    # Stop here: every later cell needs `checkpoint` and `config`
    raise FileNotFoundError(
        f"Checkpoint not found at {checkpoint_path}. Please run scripts/train_v12.py first!"
    )

checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
config = checkpoint['config']

print(f"Loaded checkpoint from: {checkpoint_path}")
print(f"Model: {config['model_name']}")
print(f"EN Rate: {checkpoint.get('en_rate', 'N/A')}%")
print(f"KO Rate: {checkpoint.get('ko_rate', 'N/A')}%")

Loaded checkpoint from: /home/west/Documents/cursor-workspace/opensearch-neural-pre-train/outputs/v12_improved/best_model.pt
Model: bert-base-multilingual-cased
EN Rate: 78.94736842105263%
KO Rate: 65.78947368421053%


    Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
    Minimum and Maximum cuda capability supported by this version of PyTorch is
    (8.0) - (12.0)
    


In [4]:
# Create and load model
tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

model = create_splade_model(
    model_name=config['model_name'],
    use_idf=False,
    use_expansion=True,
    expansion_mode='mlm',
)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()

print("Model loaded successfully!")

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model loaded successfully!
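For reference, SPLADE-family encoders turn MLM logits into one weight per vocabulary entry by applying log(1 + ReLU(logit)) at every input position and pooling over positions (max pooling in SPLADE v2 and later; the original SPLADE summed). How `create_splade_model(..., expansion_mode='mlm')` pools is not shown in this notebook, so the pure-Python sketch below assumes max pooling over non-padding positions; `splade_pool` is a hypothetical name. Under this activation every weight is non-negative, and a vocabulary entry whose logits are all ≤ 0 gets exactly 0.

```python
import math
from typing import List, Sequence


def splade_pool(logits: Sequence[Sequence[float]], attention_mask: Sequence[int]) -> List[float]:
    """Max-pool log(1 + ReLU(logit)) over unmasked positions (assumed SPLADE-max variant).

    logits: one row per input position, one column per vocabulary entry.
    attention_mask: 1 for real tokens, 0 for padding.
    """
    vocab_size = len(logits[0])
    weights = [0.0] * vocab_size  # valid start value because every term below is >= 0
    for row, keep in zip(logits, attention_mask):
        if not keep:
            continue  # padding positions never contribute
        for j, logit in enumerate(row):
            # log1p(relu(x)) is 0 for x <= 0 and grows slowly for large x
            weights[j] = max(weights[j], math.log1p(max(logit, 0.0)))
    return weights
```

For example, `splade_pool([[1.0, -2.0, 0.0], [3.0, 0.5, -1.0], [9.0, 9.0, 9.0]], [1, 1, 0])` ignores the padded third row and returns `[log 4, log 1.5, 0.0]`.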


## Inference Functions

In [5]:
def is_korean_char(c: str) -> bool:
    return '\uac00' <= c <= '\ud7a3' or '\u1100' <= c <= '\u11ff' or '\u3130' <= c <= '\u318f'

def is_english_char(c: str) -> bool:
    return c.isalpha() and c.isascii()

def classify_token(token: str) -> str:
    """Classify token as Korean, English, or Other."""
    clean = token.replace('##', '')
    if not clean:
        return 'other'
    
    has_korean = any(is_korean_char(c) for c in clean)
    has_english = any(is_english_char(c) for c in clean)
    
    if has_korean:
        return 'korean'
    elif has_english:
        return 'english'
    else:
        return 'other'

def analyze_query(text: str, top_k: int = 50) -> dict:
    """Analyze a query and return top activated tokens."""
    encoding = tokenizer(
        text,
        max_length=64,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    with torch.no_grad():
        sparse_rep, _ = model(
            encoding['input_ids'].to(device),
            encoding['attention_mask'].to(device)
        )
    
    sparse_rep = sparse_rep[0].cpu()
    
    # Get top-k tokens
    top_values, top_indices = torch.topk(sparse_rep, k=top_k)
    top_tokens = tokenizer.convert_ids_to_tokens(top_indices.tolist())
    
    # Classify tokens
    korean_tokens = []
    english_tokens = []
    other_tokens = []
    
    for token, value in zip(top_tokens, top_values.tolist()):
        token_type = classify_token(token)
        if token_type == 'korean':
            korean_tokens.append((token, value))
        elif token_type == 'english':
            english_tokens.append((token, value))
        else:
            other_tokens.append((token, value))
    
    return {
        'input': text,
        'korean': korean_tokens,
        'english': english_tokens,
        'other': other_tokens,
        'top_10': list(zip(top_tokens[:10], top_values[:10].tolist())),
    }
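Two details of `analyze_query` can skew the per-script counts. `torch.topk` always returns exactly `top_k` entries, so if fewer than `top_k` vocabulary weights are non-zero, the remaining slots are filled with zero-weight tokens that still get classified. And `classify_token` labels bracketed special tokens such as `[CLS]` as English, because they contain ASCII letters. Whether either case occurs for v12 is not visible in the outputs here. The sketch below is a stand-alone variant that drops entries at or below a weight threshold and routes special tokens to `other`; `split_by_script`, `SPECIAL_TOKENS`, and the default threshold of 0.0 are assumptions, with the special-token strings taken to be those of `bert-base-multilingual-cased`.

```python
from typing import Dict, Iterable, List, Tuple

# Assumed special-token strings for bert-base-multilingual-cased.
SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]"}


def is_korean_char(c: str) -> bool:
    # Hangul Syllables, Hangul Jamo, Hangul Compatibility Jamo (same ranges as the notebook)
    return '\uac00' <= c <= '\ud7a3' or '\u1100' <= c <= '\u11ff' or '\u3130' <= c <= '\u318f'


def split_by_script(
    pairs: Iterable[Tuple[str, float]], min_weight: float = 0.0
) -> Dict[str, List[Tuple[str, float]]]:
    """Bucket (token, weight) pairs by script, skipping inactive and special tokens."""
    buckets: Dict[str, List[Tuple[str, float]]] = {"korean": [], "english": [], "other": []}
    for token, weight in pairs:
        if weight <= min_weight:
            continue  # zero-weight slots returned by topk are not real activations
        clean = token.replace("##", "")
        if token in SPECIAL_TOKENS or not clean:
            bucket = "other"
        elif any(is_korean_char(c) for c in clean):
            bucket = "korean"
        elif any(c.isascii() and c.isalpha() for c in clean):
            bucket = "english"
        else:
            bucket = "other"
        buckets[bucket].append((token, weight))
    return buckets
```

In the notebook the input would be `zip(top_tokens, top_values.tolist())` from inside `analyze_query`.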

## Test Queries

In [6]:
test_queries = [
    # IT/Tech
    "머신러닝",
    "딥러닝",
    "자연어처리",
    "인공지능",
    "데이터베이스",
    "추천시스템",
    "검색엔진",
    "클라우드",
    "서버",
    "네트워크",
    
    # General
    "컴퓨터",
    "인터넷",
    "프로그래밍",
    "알고리즘",
    "데이터",
]

print("=" * 80)
print("v12 INFERENCE TEST RESULTS")
print("=" * 80)

for query in test_queries:
    result = analyze_query(query)
    
    print(f"\nInput: {result['input']}")
    
    # Korean tokens (top 5)
    ko_tokens = [t for t, _ in result['korean'][:5]]
    print(f"  Korean tokens: {ko_tokens}")
    
    # English tokens (top 10)
    en_tokens = [t for t, _ in result['english'][:10]]
    print(f"  English tokens: {en_tokens}")

v12 INFERENCE TEST RESULTS

Input: 머신러닝
  Korean tokens: ['##러', '머신']
  English tokens: ['##ing', '##ng', '##s', 're', 'c', '##er', 'bu', 'machine', 'mon', 'sp']

Input: 딥러닝
  Korean tokens: ['딥', '##러']
  English tokens: ['##ing', '##ng', '##s', 're', 'c', 'in', '##er', 'bu', 'sp', '##ning']

Input: 자연어처리
  Korean tokens: ['자', '##리', '##어']
  English tokens: ['##s', '##y', '##ing', 're', '##ry', 'c', '##er', '##e', '##es', '##r']

Input: 인공지능
  Korean tokens: ['인', '##지', '##공']
  English tokens: ['##s', 're', 'in', '##ing', 'c', '##tion', 'sp', 'con', 's', '##y']

Input: 데이터베이스
  Korean tokens: ['##스', '데', '##이터', '##베', '##이스']
  English tokens: ['##s', 'data', '##y', '##es', 're', '##ing', 'de', 'bu', '##e', 'c']

Input: 추천시스템
  Korean tokens: ['##스', '추', '##시', '##이']
  English tokens: ['##s', 're', 'c', 'sp', 'in', 's', 'con', 'bu', 'system', 'co']

Input: 검색엔진
  Korean tokens: ['검', '##스', '##색']
  English tokens: ['##s', 're', 'c', 'search', '##r', '##y', 'sp', '##ing', '##n', '##es']

Input: 클라우드
  Korean tokens: ['##라', '##드', '클', '##우', '##스']
  English tokens: ['##s', 'c', '##d', '##ing', '##ed', '##y', '##er', 'sp', '##r', '
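In the outputs above, the English top-10 lists are dominated by short subword pieces such as `##s`, `re` and `c`, which appear for nearly every query, while the whole-word expansions (`machine`, `data`, `system`, `search`) sit further down. A minimal filter that keeps only word-initial tokens (no `##` prefix) of some minimum length makes those expansions easier to spot. `whole_word_tokens` is a hypothetical helper and `min_len=4` is an arbitrary illustrative threshold, not a tuned value; it would also drop legitimate short words.

```python
from typing import List


def whole_word_tokens(tokens: List[str], min_len: int = 4) -> List[str]:
    """Keep tokens that start a word (no '##' prefix) and have at least `min_len` characters."""
    return [t for t in tokens if not t.startswith("##") and len(t) >= min_len]
```

Applied to the English list printed for 머신러닝, `['##ing', '##ng', '##s', 're', 'c', '##er', 'bu', 'machine', 'mon', 'sp']`, it keeps only `['machine']`.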

## Detailed Analysis

In [7]:
# Detailed analysis for key queries
key_queries = [
    ("추천시스템", ["추천", "시스템"], ["recommend", "system", "recommendation"]),
    ("검색엔진", ["검색", "엔진"], ["search", "engine"]),
    ("머신러닝", ["머신", "러닝"], ["machine", "learning"]),
    ("딥러닝", ["딥", "러닝"], ["deep", "learning"]),
    ("인공지능", ["인공", "지능"], ["artificial", "intelligence"]),
    ("자연어처리", ["자연어", "처리"], ["natural", "language", "processing"]),
]

print("\n" + "=" * 80)
print("DETAILED ANALYSIS")
print("=" * 80)

for query, expected_ko, expected_en in key_queries:
    result = analyze_query(query, top_k=100)
    
    print(f"\nInput: {query}")
    print(f"  Expected Korean: {expected_ko}")
    print(f"  Expected English: {expected_en}")
    
    # Check which wordpieces of each expected English term were activated
    all_en = [t for t, _ in result['english']]
    found_en = []
    for exp in expected_en:
        exp_tokens = tokenizer.tokenize(exp.lower())
        for tok in exp_tokens:
            if tok in all_en:
                found_en.append(tok)
    
    print(f"  Found English: {found_en}")
    print(f"  All Korean ({len(result['korean'])}): {[t for t, _ in result['korean'][:10]]}")
    print(f"  All English ({len(result['english'])}): {all_en[:15]}")


DETAILED ANALYSIS

Input: 추천시스템
  Expected Korean: ['추천', '시스템']
  Expected English: ['recommend', 'system', 'recommendation']
  Found English: ['re', '##com', '##mend', 'system', 're', '##com', '##mend', '##ation']
  All Korean (8): ['##스', '추', '##시', '##이', '##템', '##드', '##트', '##리']
  All English (92): ['##s', 're', 'c', 'sp', 'in', 's', 'con', 'bu', 'system', 'co', 'dis', '##es', 'ex', '##y', 'sh']

Input: 검색엔진
  Expected Korean: ['검색', '엔진']
  Expected English: ['search', 'engine']
  Found English: ['search']
  All Korean (4): ['검', '##스', '##색', '##이']
  All English (96): ['##s', 're', 'c', 'search', '##r', '##y', 'sp', '##ing', '##n', '##es', 'con', 'bu', 's', '##er', 'in']

Input: 머신러닝
  Expected Korean: ['머신', '러닝']
  Expected English: ['machine', 'learning']
  Found English: ['machine', 'learning']
  All Korean (6): ['##러', '머신', '##라', '##닝', '##리', '##레']
  All English (94): ['##ing', '##ng', '##s', 're', 'c', '##er', 'bu', 'machine', 'mon', 'sp', 'me', '##ting', '##ning', 'in', 'computing']

Input: 딥러닝
  Expected Korean: ['딥', '러닝']
  Expected English: ['deep', 'learning']
  Found English: ['deep', 'learning']
  All Korean (4): 
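The `found_en` list counts every activated wordpiece of every expected term, so `recommend` and `recommendation` both contribute `re`, `##com`, `##mend`, and a generic piece like `re` counts as a hit even though it is activated for almost every query. A stricter check, sketched below, marks an expected term as found only when all of its wordpieces are activated, and lists each matched piece once. `term_coverage` and `matched_pieces` are hypothetical helpers; in the notebook, `tokenize` would be `tokenizer.tokenize` and `activated` would be `all_en`.

```python
from typing import Callable, Dict, Iterable, List


def term_coverage(
    expected_terms: Iterable[str],
    activated: Iterable[str],
    tokenize: Callable[[str], List[str]],
) -> Dict[str, bool]:
    """Map each expected term to True only if every one of its wordpieces is activated."""
    active = set(activated)
    coverage = {}
    for term in expected_terms:
        pieces = tokenize(term.lower())
        coverage[term] = bool(pieces) and all(p in active for p in pieces)
    return coverage


def matched_pieces(
    expected_terms: Iterable[str],
    activated: Iterable[str],
    tokenize: Callable[[str], List[str]],
) -> List[str]:
    """Activated wordpieces of the expected terms, de-duplicated in first-seen order."""
    active = set(activated)
    pieces = (p for term in expected_terms for p in tokenize(term.lower()))
    return list(dict.fromkeys(p for p in pieces if p in active))
```

With the wordpieces shown above for 추천시스템, `matched_pieces` would report `['re', '##com', '##mend', 'system', '##ation']` instead of the eight-element list with repeats.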

## Summary Statistics

In [8]:
# Count activated tokens per script over the top-50 tokens of every test query
total_ko_activated = 0
total_en_activated = 0
total_other = 0

for query in test_queries:
    result = analyze_query(query, top_k=50)
    total_ko_activated += len(result['korean'])
    total_en_activated += len(result['english'])
    total_other += len(result['other'])

n_queries = len(test_queries)
total_tokens = total_ko_activated + total_en_activated + total_other

print("\n" + "=" * 80)
print("SUMMARY STATISTICS (Top-50 tokens per query)")
print("=" * 80)
print(f"\n  Queries tested: {n_queries}")
print(f"  Avg Korean tokens: {total_ko_activated / n_queries:.1f}")
print(f"  Avg English tokens: {total_en_activated / n_queries:.1f}")
print(f"  Avg Other tokens: {total_other / n_queries:.1f}")
print(f"\n  Korean ratio: {total_ko_activated / total_tokens * 100:.1f}%")
print(f"  English ratio: {total_en_activated / total_tokens * 100:.1f}%")


SUMMARY STATISTICS (Top-50 tokens per query)

  Queries tested: 15
  Avg Korean tokens: 4.2
  Avg English tokens: 45.8
  Avg Other tokens: 0.0

  Korean ratio: 8.4%
  English ratio: 91.6%
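These ratios count token slots: with a fixed top-50 per query, every query contributes exactly 50 slots, so Avg Korean + Avg English + Avg Other always sums to 50 (4.2 + 45.8 + 0.0 here), and the ratios ignore how strongly each token is activated. One alternative, not what this notebook computes, is each script's share of total activation weight. The sketch below assumes the `{'korean': [(token, weight), ...], 'english': [...], 'other': [...]}` shape returned by `analyze_query`; `weight_share` is a hypothetical helper.

```python
from typing import Dict, Iterable, Tuple


def weight_share(
    results: Iterable[dict],
    scripts: Tuple[str, ...] = ("korean", "english", "other"),
) -> Dict[str, float]:
    """Fraction of total activation weight per script, pooled over all results."""
    totals = {s: 0.0 for s in scripts}
    for result in results:
        for s in scripts:
            # Each bucket is a list of (token, weight) pairs; missing buckets count as empty
            totals[s] += sum(weight for _, weight in result.get(s, []))
    grand_total = sum(totals.values())
    if grand_total == 0:
        return {s: 0.0 for s in scripts}
    return {s: totals[s] / grand_total for s in scripts}
```

In the notebook it could be called as `weight_share(analyze_query(q, top_k=50) for q in test_queries)`.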


In [9]:
print("\n" + "=" * 80)
print("v12 INFERENCE TEST COMPLETE")
print("=" * 80)


v12 INFERENCE TEST COMPLETE
