# v11 Training: Term-Level KO-EN Neural Sparse Model

## Key Changes from v9/v10
- **Term-level data**: Direct term mappings (MUSE + Wikidata + IT terminology)
- **74K+ training pairs**: Much larger than v9's 10 samples
- **Expected output**: 추천 시스템 → [추천, 시스템, recommend, system, recommendation]

## Data Sources
1. MUSE bilingual dictionary (~20K pairs)
2. Wikidata entity labels (~53K pairs)
3. IT/Tech terminology (~150 pairs)

## 1. Setup

In [None]:
import sys
import json
import re
from pathlib import Path

def find_project_root():
    """Locate the repository root directory.

    Tries the current working directory, its parent and its grandparent,
    then a fixed workstation checkout path, and returns the first one that
    contains a ``CLAUDE.md`` file or a ``.git`` entry. When none of them
    qualifies, the fixed checkout path is returned regardless.
    """
    here = Path.cwd()
    fallback = Path("/home/west/Documents/cursor-workspace/opensearch-neural-pre-train")
    markers = ("CLAUDE.md", ".git")

    for directory in (here, here.parent, here.parent.parent, fallback):
        if any((directory / marker).exists() for marker in markers):
            return directory
    return fallback

# Resolve the root once and put it first on sys.path so that the
# project-local `src.*` import in the next cell can be resolved.
PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from src.model.splade_model import create_splade_model

# Report the runtime environment.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Configuration

In [None]:
# v11 Training configuration
# Single dict holding every hyperparameter and path read by the cells below.
CONFIG = {
    # Model
    'model_name': 'bert-base-multilingual-cased',
    'max_length': 64,  # length the dataset pads/truncates each Korean term to
    
    # Data - Term-level pairs
    'data_path': PROJECT_ROOT / 'dataset' / 'v11_term_pairs' / 'term_pairs.jsonl',
    
    # Training
    'batch_size': 128,
    'num_epochs': 5,
    'learning_rate': 2e-5,
    'warmup_ratio': 0.1,   # fraction of the total optimizer steps spent in LR warmup
    'max_grad_norm': 1.0,  # gradient-norm clipping threshold
    
    # Loss weights (optimized for term-level)
    'lambda_self': 1.0,       # Korean preservation
    'lambda_target': 2.0,     # English target activation
    'lambda_margin': 1.5,     # Ensure minimum English activation
    'lambda_negative': 0.5,   # Suppress non-target languages
    'lambda_sparsity': 0.01,  # Keep representations sparse
    
    # Margin threshold
    'target_margin': 1.0,  # passed to TermLevelLoss as the minimum English activation
    
    # Output
    'output_dir': PROJECT_ROOT / 'outputs' / 'v11_term_level',
}

# Echo the configuration so the run log records the exact settings used.
print("v11 Configuration:")
print("="*60)
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## 3. Non-Target Language Token Detection

In [None]:
def is_korean_char(c: str) -> bool:
    """Return True when *c* falls inside one of the Hangul code-point ranges.

    Ranges checked: U+AC00-U+D7A3 (Hangul syllables), U+1100-U+11FF
    (Hangul Jamo) and U+3130-U+318F (Hangul Compatibility Jamo).
    """
    hangul_ranges = (
        ('\uac00', '\ud7a3'),
        ('\u1100', '\u11ff'),
        ('\u3130', '\u318f'),
    )
    return any(low <= c <= high for low, high in hangul_ranges)

def is_english_char(c: str) -> bool:
    """Return True when *c* is an ASCII alphabetic character."""
    if not c.isalpha():
        return False
    return c.isascii()

def is_non_target_token(token: str) -> bool:
    """Decide whether a vocabulary token belongs to a script to suppress.

    The WordPiece continuation marker ``##`` is removed first. A token is
    never flagged when it is empty or contains at least one Korean or
    English character. Otherwise it is flagged when any character falls in
    the Hiragana, Katakana, CJK ideograph, Cyrillic, Arabic, Thai or Greek
    ranges listed below. Tokens in any other script (digits, punctuation,
    accented Latin, ...) are not flagged.
    """
    text = token.replace('##', '')
    if not text:
        return False

    # Any Korean or English letter makes this a target-language token.
    for ch in text:
        if is_korean_char(ch) or is_english_char(ch):
            return False

    suppressed_ranges = (
        ('\u3040', '\u309f'),  # Hiragana
        ('\u30a0', '\u30ff'),  # Katakana
        ('\u4e00', '\u9fff'),  # CJK unified ideographs
        ('\u0400', '\u04ff'),  # Cyrillic
        ('\u0600', '\u06ff'),  # Arabic
        ('\u0e00', '\u0e7f'),  # Thai
        ('\u0370', '\u03ff'),  # Greek
    )
    return any(
        low <= ch <= high
        for ch in text
        for low, high in suppressed_ranges
    )

In [None]:
# Scan the tokenizer vocabulary once and keep the ids of every token that
# is_non_target_token() flags for suppression.
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'])

print("Building non-target language token ID list...")
non_target_ids = [
    vocab_id
    for vocab_id in tqdm(range(tokenizer.vocab_size))
    if is_non_target_token(tokenizer.convert_ids_to_tokens(vocab_id))
]

# Long tensor form, used by the loss to index the vocabulary dimension.
non_target_ids_tensor = torch.tensor(non_target_ids, dtype=torch.long)
print(f"Found {len(non_target_ids):,} non-target language tokens")

## 4. Term-Level Dataset

In [None]:
class TermPairDataset(Dataset):
    """Dataset of term-level KO-EN pairs read from a JSON Lines file.

    Every non-blank line of the file must be a JSON object with a ``'ko'``
    and an ``'en'`` string. Both terms are tokenized once at load time (the
    English term after lowercasing) and the resulting vocabulary ids, minus
    the tokenizer's unknown-token id, are stored as supervision targets. A
    pair is kept only when both id lists are non-empty.

    Args:
        data_path: Path of the JSONL file to read (UTF-8).
        tokenizer: Tokenizer providing ``tokenize``, ``convert_tokens_to_ids``,
            ``unk_token_id`` and the ``__call__`` encoding interface.
        max_length: Length every Korean input is padded/truncated to in
            ``__getitem__``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks the ``'ko'`` or ``'en'`` key.
    """

    def __init__(self, data_path: Path, tokenizer, max_length: int = 64):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        print(f"Loading dataset from {data_path}...")

        with open(data_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Loading data"):
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline at end of file)
                    # are not records; json.loads('') would raise.
                    continue
                item = json.loads(line)

                ko_term = item['ko']
                en_term = item['en']

                ko_token_ids = self._known_token_ids(ko_term)
                # Lowercase the English term for better matching.
                en_token_ids = self._known_token_ids(en_term.lower())

                if ko_token_ids and en_token_ids:
                    self.data.append({
                        'ko_term': ko_term,
                        'en_term': en_term,
                        'ko_token_ids': ko_token_ids,
                        'en_token_ids': en_token_ids,
                    })

        print(f"Loaded {len(self.data):,} valid term pairs")

    def _known_token_ids(self, text: str) -> list:
        """Tokenize *text* and return its vocabulary ids without the UNK id."""
        tokens = self.tokenizer.tokenize(text)
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        unk_id = self.tokenizer.unk_token_id
        return [tid for tid in token_ids if tid != unk_id]

    def __len__(self):
        """Return the number of retained term pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode the Korean term of pair *idx* and attach its target id lists.

        Returns a dict with ``input_ids`` and ``attention_mask`` tensors (the
        leading batch dimension added by ``return_tensors='pt'`` is squeezed
        away) plus the plain-list ``ko_token_ids`` and ``en_token_ids``.
        """
        item = self.data[idx]

        encoding = self.tokenizer(
            item['ko_term'],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'ko_token_ids': item['ko_token_ids'],
            'en_token_ids': item['en_token_ids'],
        }


def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict.

    ``input_ids`` and ``attention_mask`` tensors are stacked along a new
    leading dimension. ``ko_token_ids`` and ``en_token_ids`` stay as a list
    with one (variable-length) id list per sample.
    """
    collated = {}
    for key in ('input_ids', 'attention_mask'):
        collated[key] = torch.stack([sample[key] for sample in batch])
    for key in ('ko_token_ids', 'en_token_ids'):
        collated[key] = [sample[key] for sample in batch]
    return collated

In [None]:
# Read the term pairs and wrap them in a shuffling DataLoader.
dataset = TermPairDataset(CONFIG['data_path'], tokenizer, CONFIG['max_length'])

dataloader = DataLoader(
    dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
    pin_memory=True,
)

print(f"\nDataset size: {len(dataset):,}")
print(f"Batches per epoch: {len(dataloader):,}")

# Print the first few pairs as a quick sanity check of the loaded data.
print("\nSample data:")
for sample in dataset.data[:5]:
    print(f"  {sample['ko_term']} -> {sample['en_term']}")

## 5. Loss Function

In [None]:
class TermLevelLoss(nn.Module):
    """
    Loss function for term-level KO-EN training.

    Components (each returned separately, averaged over the batch):
    1. Self-preservation: negative log activation of the Korean input tokens
    2. Target activation: negative log activation of the English tokens
    3. Margin loss: hinge penalty for English activations below the margin
    4. Negative sampling: hinge penalty for non-target-language activations
       above ``negative_threshold``

    Args:
        target_margin: Minimum activation demanded of English target tokens.
        non_target_ids: 1-D long tensor of vocabulary ids to suppress, or
            ``None`` to disable the negative term.
        negative_threshold: Activation level below which non-target tokens
            are not penalised.
        eps: Constant added inside the logarithms for numerical stability.
    """

    def __init__(
        self,
        target_margin: float = 1.0,
        non_target_ids: "torch.Tensor | None" = None,
        negative_threshold: float = 0.1,
        eps: float = 1e-8,
    ):
        super().__init__()
        self.target_margin = target_margin
        self.non_target_ids = non_target_ids
        self.negative_threshold = negative_threshold
        self.eps = eps

    def forward(
        self,
        sparse_rep: torch.Tensor,
        ko_token_ids: list,
        en_token_ids: list,
    ) -> dict:
        """Compute the four loss components for one batch.

        Args:
            sparse_rep: Activations indexed as ``[sample, vocab_id]``.
            ko_token_ids: Per-sample lists of Korean vocabulary ids.
            en_token_ids: Per-sample lists of English vocabulary ids.

        Returns:
            Dict with scalar tensors under ``'self'``, ``'target'``,
            ``'margin'`` and ``'negative'``. All are zero for an empty batch.
        """
        batch_size = sparse_rep.shape[0]
        device = sparse_rep.device

        self_loss = torch.tensor(0.0, device=device)
        target_loss = torch.tensor(0.0, device=device)
        margin_loss = torch.tensor(0.0, device=device)
        negative_loss = torch.tensor(0.0, device=device)

        # The id lists have different lengths per sample, so these three
        # terms are accumulated sample by sample.
        for i, rep in enumerate(sparse_rep):
            # 1. Self-preservation loss (Korean tokens)
            if ko_token_ids[i]:
                ko_ids = torch.tensor(ko_token_ids[i], device=device)
                self_loss = self_loss - torch.log(rep[ko_ids] + self.eps).mean()

            if en_token_ids[i]:
                en_ids = torch.tensor(en_token_ids[i], device=device)
                en_activations = rep[en_ids]

                # 2. English target loss
                target_loss = target_loss - torch.log(en_activations + self.eps).mean()

                # 3. Margin loss - ensure minimum activation
                margin_loss = margin_loss + F.relu(self.target_margin - en_activations).mean()

        if batch_size > 0:
            # Samples with an empty id list still count in the denominator.
            self_loss = self_loss / batch_size
            target_loss = target_loss / batch_size
            margin_loss = margin_loss / batch_size

            # 4. Negative sampling loss. The id set is the same for every
            # sample, so it is moved to the device once and gathered for the
            # whole batch; the mean over all rows equals the batch average of
            # the per-row means. An empty id tensor is skipped because the
            # mean of an empty tensor is NaN.
            if self.non_target_ids is not None and self.non_target_ids.numel() > 0:
                suppressed_ids = self.non_target_ids.to(device)
                negative_loss = negative_loss + F.relu(
                    sparse_rep[:, suppressed_ids] - self.negative_threshold
                ).mean()

        return {
            'self': self_loss,
            'target': target_loss,
            'margin': margin_loss,
            'negative': negative_loss,
        }

## 6. Model and Training Setup

In [None]:
# Create model
# Built by the project-local factory with IDF weighting off and MLM-mode
# expansion on, then moved to the selected device.
model = create_splade_model(
    model_name=CONFIG['model_name'],
    use_idf=False,
    use_expansion=True,
    expansion_mode='mlm',
)
model = model.to(device)

print(f"Model created: {CONFIG['model_name']}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Loss function
# Receives the non-target id tensor built earlier so it can penalise those ids.
loss_fn = TermLevelLoss(
    target_margin=CONFIG['target_margin'],
    non_target_ids=non_target_ids_tensor
)

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=0.01
)

# Scheduler
# One scheduler step is taken per batch, so the schedule length is
# batches-per-epoch times the number of epochs.
total_steps = len(dataloader) * CONFIG['num_epochs']
warmup_steps = int(total_steps * CONFIG['warmup_ratio'])

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

print(f"Total steps: {total_steps:,}")
print(f"Warmup steps: {warmup_steps:,}")

## 7. Evaluation Function

In [None]:
# Fixed probe set used by evaluate_model(). Each entry is a tuple of
# (Korean input term, expected English words, expected Korean sub-terms).
TEST_PAIRS = [
    ("머신러닝", ["machine", "learning"], ["머신", "러닝"]),
    ("딥러닝", ["deep", "learning"], ["딥", "러닝"]),
    ("자연어처리", ["natural", "language", "processing"], ["자연어", "처리"]),
    ("인공지능", ["artificial", "intelligence"], ["인공", "지능"]),
    ("데이터베이스", ["database"], ["데이터베이스"]),
    ("추천시스템", ["recommend", "system"], ["추천", "시스템"]),
    ("검색엔진", ["search", "engine"], ["검색", "엔진"]),
    ("클라우드", ["cloud"], ["클라우드"]),
    ("서버", ["server"], ["서버"]),
    ("네트워크", ["network"], ["네트워크"]),
]

def _collect_activated(tokenizer, terms, top_tokens_set):
    """Tokenize each term; return (expected token count, tokens found in the set)."""
    expected = 0
    found = []
    for term in terms:
        for tok in tokenizer.tokenize(term):
            expected += 1
            if tok in top_tokens_set:
                found.append(tok)
    return expected, found


def evaluate_model(model, tokenizer, device, top_k=50, test_pairs=None):
    """Evaluate model on test pairs.

    For every pair the Korean term is encoded, the ``top_k`` most activated
    vocabulary tokens are taken, and three things are measured: how many of
    the expected Korean tokens appear among them, how many of the expected
    English tokens (tokenized after lowercasing) appear among them, and how
    many of the top tokens are flagged by ``is_non_target_token``.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` returning a
            pair whose first element is the sparse representation.
        tokenizer: Tokenizer used for encoding and for id/token conversion.
        device: Device the encoded inputs are moved to.
        top_k: Number of top activations inspected; clamped to the size of
            the representation.
        test_pairs: Iterable of ``(ko_term, en_expected, ko_expected)``
            tuples. Defaults to the module-level ``TEST_PAIRS``.

    Returns:
        Dict with ``ko_rate`` and ``en_rate`` (percentages), ``avg_non_target``
        (mean flagged tokens per pair) and ``details`` (one dict per pair).
        All rates are 0 when there is nothing to evaluate.
    """
    pairs = TEST_PAIRS if test_pairs is None else test_pairs

    # Remember the current mode so evaluation does not silently switch an
    # eval-mode model to training mode. Objects without a `training` flag
    # are put back into training mode.
    was_training = getattr(model, 'training', True)
    model.eval()

    results = []
    ko_activated_total = 0
    en_activated_total = 0
    ko_expected_total = 0
    en_expected_total = 0
    non_target_in_top_total = 0

    try:
        with torch.no_grad():
            for ko_term, en_expected, ko_expected in pairs:
                encoding = tokenizer(
                    ko_term,
                    max_length=64,
                    padding='max_length',
                    truncation=True,
                    return_tensors='pt'
                )

                sparse_rep, _ = model(
                    encoding['input_ids'].to(device),
                    encoding['attention_mask'].to(device)
                )

                sparse_rep = sparse_rep[0].cpu()
                # torch.topk raises when k exceeds the dimension size.
                k = min(top_k, sparse_rep.shape[-1])
                top_indices = torch.topk(sparse_rep, k=k).indices.tolist()
                top_tokens = tokenizer.convert_ids_to_tokens(top_indices)
                top_tokens_set = set(top_tokens)

                # Count non-target tokens
                non_target_count = sum(1 for t in top_tokens if is_non_target_token(t))
                non_target_in_top_total += non_target_count

                # Check Korean preservation
                ko_expected_count, ko_activated = _collect_activated(
                    tokenizer, ko_expected, top_tokens_set
                )
                ko_expected_total += ko_expected_count
                ko_activated_total += len(ko_activated)

                # Check English activation (lowercased before tokenization)
                en_expected_count, en_activated = _collect_activated(
                    tokenizer, [en.lower() for en in en_expected], top_tokens_set
                )
                en_expected_total += en_expected_count
                en_activated_total += len(en_activated)

                results.append({
                    'input': ko_term,
                    'ko_activated': ko_activated,
                    'en_activated': en_activated,
                    'top_10': top_tokens[:10],
                    'non_target_count': non_target_count,
                })
    finally:
        if was_training:
            model.train()

    ko_rate = ko_activated_total / ko_expected_total * 100 if ko_expected_total > 0 else 0
    en_rate = en_activated_total / en_expected_total * 100 if en_expected_total > 0 else 0
    # `results` holds one entry per evaluated pair; guard against an empty set.
    avg_non_target = non_target_in_top_total / len(results) if results else 0

    return {
        'ko_rate': ko_rate,
        'en_rate': en_rate,
        'avg_non_target': avg_non_target,
        'details': results
    }

In [None]:
# Initial evaluation
# Baseline metrics of the freshly created model, for comparison with the
# per-epoch numbers printed during training.
print("Initial evaluation (before training):")
init_eval = evaluate_model(model, tokenizer, device)
print(f"  Korean preservation rate: {init_eval['ko_rate']:.1f}%")
print(f"  English activation rate: {init_eval['en_rate']:.1f}%")
print(f"  Avg non-target tokens in top-50: {init_eval['avg_non_target']:.1f}")

## 8. Training Loop

In [None]:
# Create output directory
# (checkpoints, the final model, the history JSON and the plot are written here)
CONFIG['output_dir'].mkdir(parents=True, exist_ok=True)

# Training history
# One dict of batch-averaged losses is appended per epoch.
history = []

print("=" * 70)
print("STARTING TRAINING (v11 - Term Level)")
print(f"Dataset: {len(dataset):,} term pairs")
print(f"Epochs: {CONFIG['num_epochs']}")
print(f"Batch size: {CONFIG['batch_size']}")
print("=" * 70)

In [None]:
# Main training loop: for every epoch, run one pass over the dataloader,
# average the losses, evaluate on the fixed test pairs and save a checkpoint.
for epoch in range(CONFIG['num_epochs']):
    print(f"\n--- Epoch {epoch + 1}/{CONFIG['num_epochs']} ---")
    model.train()
    
    # Running sums over the batches of this epoch; divided by the batch
    # count after the inner loop.
    epoch_losses = {
        'total': 0.0,
        'self': 0.0,
        'target': 0.0,
        'margin': 0.0,
        'negative': 0.0,
        'sparsity': 0.0,
    }
    
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")
    
    for batch_idx, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        
        # Forward pass
        # The model returns a pair; only the first element is used here.
        sparse_rep, _ = model(input_ids, attention_mask)
        
        # Compute losses
        losses = loss_fn(
            sparse_rep,
            batch['ko_token_ids'],
            batch['en_token_ids'],
        )
        
        # Sparsity loss
        # Mean over every element of the batch representation.
        sparsity_loss = sparse_rep.mean()
        
        # Total loss
        # Weighted sum of the five components using the CONFIG lambdas.
        total_loss = (
            CONFIG['lambda_self'] * losses['self'] +
            CONFIG['lambda_target'] * losses['target'] +
            CONFIG['lambda_margin'] * losses['margin'] +
            CONFIG['lambda_negative'] * losses['negative'] +
            CONFIG['lambda_sparsity'] * sparsity_loss
        )
        
        # Backward pass
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['max_grad_norm'])
        optimizer.step()
        scheduler.step()
        # Clear gradients after the update so the next batch starts clean.
        optimizer.zero_grad()
        
        # Accumulate losses
        epoch_losses['total'] += total_loss.item()
        epoch_losses['self'] += losses['self'].item()
        epoch_losses['target'] += losses['target'].item()
        epoch_losses['margin'] += losses['margin'].item()
        epoch_losses['negative'] += losses['negative'].item()
        epoch_losses['sparsity'] += sparsity_loss.item()
        
        # Update progress bar
        # Every 50 batches, show the running means of total and target loss.
        if (batch_idx + 1) % 50 == 0:
            progress_bar.set_postfix({
                'loss': f"{epoch_losses['total'] / (batch_idx + 1):.4f}",
                'tgt': f"{epoch_losses['target'] / (batch_idx + 1):.4f}",
            })
    
    # Average losses
    n_batches = len(dataloader)
    for key in epoch_losses:
        epoch_losses[key] /= n_batches
    
    history.append(epoch_losses)
    
    # Evaluate
    eval_result = evaluate_model(model, tokenizer, device)
    
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Total Loss: {epoch_losses['total']:.4f}")
    print(f"  Self Loss: {epoch_losses['self']:.4f}")
    print(f"  Target Loss: {epoch_losses['target']:.4f}")
    print(f"  Korean Preservation: {eval_result['ko_rate']:.1f}%")
    print(f"  English Activation: {eval_result['en_rate']:.1f}%")
    print(f"  Avg Non-target in Top-50: {eval_result['avg_non_target']:.1f}")
    
    # Save checkpoint
    # Path values in CONFIG are stored as strings; the scheduler state is
    # not part of the checkpoint.
    checkpoint_path = CONFIG['output_dir'] / f'checkpoint_epoch{epoch + 1}.pt'
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'losses': epoch_losses,
        'eval': eval_result,
        'config': {k: str(v) if isinstance(v, Path) else v for k, v in CONFIG.items()},
    }, checkpoint_path)
    print(f"  Saved: {checkpoint_path}")

## 9. Save Final Model

In [None]:
# Save final model
# Stores the weights, the configuration (Path values converted to strings)
# and the per-epoch loss history in a single file.
final_path = CONFIG['output_dir'] / 'final_model.pt'
torch.save({
    'model_state_dict': model.state_dict(),
    'config': {k: str(v) if isinstance(v, Path) else v for k, v in CONFIG.items()},
    'history': history,
}, final_path)

print(f"Final model saved: {final_path}")

# Save training history
# `history` is a list of dicts of plain floats, so it serialises directly.
with open(CONFIG['output_dir'] / 'training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print("Training history saved")

## 10. Final Evaluation

In [None]:
# Evaluate the trained model on the fixed test pairs and print both the
# aggregate rates and the per-term details.
print("\n" + "=" * 70)
print("FINAL EVALUATION")
print("=" * 70)

final_eval = evaluate_model(model, tokenizer, device)

print(f"\nKorean Preservation Rate: {final_eval['ko_rate']:.1f}%")
print(f"English Activation Rate: {final_eval['en_rate']:.1f}%")
print(f"Avg Non-target in Top-50: {final_eval['avg_non_target']:.1f}")

# One entry per test term: matched Korean/English tokens and the ten most
# activated tokens.
print("\nDetailed Results:")
for result in final_eval['details']:
    print(f"\n  Input: {result['input']}")
    print(f"    Korean activated: {result['ko_activated']}")
    print(f"    English activated: {result['en_activated']}")
    print(f"    Top-10: {result['top_10']}")

In [None]:
# Plot training history
# Skipped when no epoch was recorded.
if len(history) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Epochs are numbered from 1 on the x axis.
    epochs = range(1, len(history) + 1)

    # Main losses
    # Left panel: weighted total loss with the self and target components.
    axes[0].plot(epochs, [h['total'] for h in history], 'b-', label='Total', linewidth=2)
    axes[0].plot(epochs, [h['self'] for h in history], 'g--', label='Self')
    axes[0].plot(epochs, [h['target'] for h in history], 'r--', label='Target')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('v11 Training Loss')
    axes[0].legend()
    axes[0].grid(True)

    # Component losses
    # Right panel: margin and negative components (sparsity is not plotted).
    axes[1].plot(epochs, [h['margin'] for h in history], 'c-', label='Margin')
    axes[1].plot(epochs, [h['negative'] for h in history], 'm-', label='Negative')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Loss')
    axes[1].set_title('v11 Component Losses')
    axes[1].legend()
    axes[1].grid(True)

    # Save the figure next to the checkpoints before displaying it.
    plt.tight_layout()
    plt.savefig(CONFIG['output_dir'] / 'training_curves.png', dpi=150)
    plt.show()

In [None]:
# Closing banner: where the artifacts were written and what to run next.
print("\n" + "=" * 70)
print("v11 TRAINING COMPLETE")
print("=" * 70)
print(f"\nOutput directory: {CONFIG['output_dir']}")
# Plain string: this message has no placeholders, so no f-prefix is needed.
print("\nNext step: Run 02_inference_test.ipynb for detailed evaluation")