# v19 Training - XLM-RoBERTa-large with Similarity-Weighted Loss

This notebook trains the v19 Korean-English cross-lingual SPLADE model.

## Key Features:
- **Model**: xlm-roberta-large (560M parameters)
- **Dataset**: v19_high_quality (1:N mixed Korean/English with similarity scores)
- **Format**: `{"ko": "프로그램", "terms": [{"term": "program", "sim": 0.95}, ...]}`
- **Loss**: Similarity-weighted target loss (higher weight for more similar terms)
- **Max targets**: 8 per Korean source term
- **Learning rate**: 3e-6
- **Epochs**: 15

In [3]:
import sys
from pathlib import Path

# Find project root
def find_project_root():
    """Find project root by looking for markers like pyproject.toml or src/"""
    current = Path.cwd()
    for parent in [current] + list(current.parents):
        if (parent / "pyproject.toml").exists() or (parent / "src").exists():
            return parent
    return Path.cwd().parent.parent

PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

Project root: /home/west/Documents/cursor-workspace/opensearch-neural-pre-train


In [4]:
import json
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler, autocast
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from tqdm.notebook import tqdm

from src.model.splade_model import create_splade_model

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

PyTorch version: 2.10.0.dev20251109+cu130
CUDA available: True
GPU: NVIDIA GB10
GPU Memory: 128.5 GB


Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.
Minimum and Maximum cuda capability supported by this version of PyTorch is
(8.0) - (12.0)


## 1. Configuration

In [5]:
# Training Configuration
CONFIG = {
    # Model
    "model_name": "xlm-roberta-large",
    "max_length": 64,

    # Data - 1:N mixed format (Korean -> [Korean + English terms])
    "data_path": PROJECT_ROOT / "dataset" / "v19_high_quality" / "term_mappings.jsonl",

    # Training - adjusted for larger batch size
    "batch_size": 64,                # 32 → 64 (2x)
    "gradient_accumulation_steps": 2,  # 4 → 2 (effective batch: 128 → 128)
    "num_epochs": 15,                # 10 → 15 (more epochs for better convergence)
    "learning_rate": 3e-6,           # 2e-6 → 3e-6 (1.5x, conservative scaling)
    "warmup_ratio": 0.1,             # 0.2 → 0.1 (shorter warmup)
    "max_grad_norm": 1.0,

    # Loss weights
    "lambda_self": 2.0,       # Korean source preservation
    "lambda_target": 5.0,     # Target term activation (Korean + English)
    "lambda_margin": 3.0,     # Margin loss
    "lambda_negative": 1.0,   # 0.5 → 1.0 (stronger suppression for special tokens)
    "lambda_sparsity": 0.01,  # 0.005 → 0.01 (stronger sparsity)
    "target_margin": 2.0,

    # Mixed precision
    "use_fp16": True,

    # Output
    "output_dir": PROJECT_ROOT / "outputs" / "v19_xlm_large",
}

print("Training Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

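# Calculate rough training stats
# NOTE: 1012 is a placeholder, not the real size; the dataset cell in
# section 6 reports 13,684 valid mappings.
dataset_size = 1012  # placeholder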
batches_per_epoch = dataset_size // CONFIG["batch_size"]
opt_steps_per_epoch = batches_per_epoch // CONFIG["gradient_accumulation_steps"]
total_opt_steps = opt_steps_per_epoch * CONFIG["num_epochs"]

print(f"\nTraining Stats:")
print(f"  Effective batch size: {CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']}")
print(f"  Batches per epoch: ~{batches_per_epoch}")
print(f"  Optimization steps per epoch: ~{opt_steps_per_epoch}")
print(f"  Total optimization steps: ~{total_opt_steps}")

Training Configuration:
  model_name: xlm-roberta-large
  max_length: 64
  data_path: /home/west/Documents/cursor-workspace/opensearch-neural-pre-train/dataset/v19_high_quality/term_mappings.jsonl
  batch_size: 64
  gradient_accumulation_steps: 2
  num_epochs: 15
  learning_rate: 3e-06
  warmup_ratio: 0.1
  max_grad_norm: 1.0
  lambda_self: 2.0
  lambda_target: 5.0
  lambda_margin: 3.0
  lambda_negative: 1.0
  lambda_sparsity: 0.01
  target_margin: 2.0
  use_fp16: True
  output_dir: /home/west/Documents/cursor-workspace/opensearch-neural-pre-train/outputs/v19_xlm_large

Training Stats:
  Effective batch size: 128
  Batches per epoch: ~15
  Optimization steps per epoch: ~7
  Total optimization steps: ~105
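
The statistics above are computed from the placeholder `dataset_size = 1012`. As a cross-check, the same arithmetic can be repeated with the 13,684 valid mappings reported when the dataset is loaded. The sketch below is illustrative only. It assumes the final partial batch is kept (the `DataLoader` default) and reuses the step formula from the scheduler setup later in the notebook.

```python
import math

# Values taken from CONFIG and from the dataset-loading output in this notebook.
dataset_size = 13_684
batch_size = 64
gradient_accumulation_steps = 2
num_epochs = 15
warmup_ratio = 0.1

# The last partial batch is kept, so round up.
batches_per_epoch = math.ceil(dataset_size / batch_size)

# Same formula as the scheduler setup: total batches divided by accumulation steps.
total_opt_steps = batches_per_epoch * num_epochs // gradient_accumulation_steps
warmup_steps = int(total_opt_steps * warmup_ratio)

print(f"Batches per epoch: {batches_per_epoch}")
print(f"Total optimization steps: {total_opt_steps}")
print(f"Warmup steps: {warmup_steps}")
```

These values (214, 1,605, and 160) agree with what the dataset and scheduler cells print.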


## 2. Helper Functions

In [6]:
def is_korean_char(c: str) -> bool:
    """Check if character is Korean."""
    return (
        "\uac00" <= c <= "\ud7a3"
        or "\u1100" <= c <= "\u11ff"
        or "\u3130" <= c <= "\u318f"
    )


def is_english_char(c: str) -> bool:
    """Check if character is English."""
    return c.isalpha() and c.isascii()


def is_non_target_token(token: str) -> bool:
    """Check if token is from non-target language (not Korean or English)."""
    clean = token.replace("\u2581", "").replace("##", "")  # Remove subword markers
    if not clean:
        return False

    has_korean = any(is_korean_char(c) for c in clean)
    has_english = any(is_english_char(c) for c in clean)

    if has_korean or has_english:
        return False

    # Check for other languages
    has_japanese = any(
        "\u3040" <= c <= "\u309f" or "\u30a0" <= c <= "\u30ff" for c in clean
    )
    has_cjk = any("\u4e00" <= c <= "\u9fff" for c in clean)
    has_cyrillic = any("\u0400" <= c <= "\u04ff" for c in clean)
    has_arabic = any("\u0600" <= c <= "\u06ff" for c in clean)
    has_thai = any("\u0e00" <= c <= "\u0e7f" for c in clean)
    has_greek = any("\u0370" <= c <= "\u03ff" for c in clean)

    return (
        has_japanese or has_cjk or has_cyrillic or has_arabic or has_thai or has_greek
    )
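
To make the behaviour of `is_non_target_token` concrete, the following is a condensed, table-driven restatement of the same checks applied to a few sample tokens. The function name `is_non_target_token_compact`, the range tables, and the sample tokens are illustrative assumptions and are not part of the training code.

```python
# Unicode ranges copied from the helper functions in this section.
KOREAN_RANGES = [("\uac00", "\ud7a3"), ("\u1100", "\u11ff"), ("\u3130", "\u318f")]
OTHER_SCRIPT_RANGES = [
    ("\u3040", "\u309f"), ("\u30a0", "\u30ff"),  # Japanese kana
    ("\u4e00", "\u9fff"),                        # CJK ideographs
    ("\u0400", "\u04ff"),                        # Cyrillic
    ("\u0600", "\u06ff"),                        # Arabic
    ("\u0e00", "\u0e7f"),                        # Thai
    ("\u0370", "\u03ff"),                        # Greek
]


def in_ranges(c, ranges):
    """Return True if character c falls inside any (low, high) range."""
    return any(lo <= c <= hi for lo, hi in ranges)


def is_non_target_token_compact(token):
    """Hypothetical compact equivalent of is_non_target_token."""
    clean = token.replace("\u2581", "").replace("##", "")
    if not clean:
        return False
    # Any Korean or ASCII letter makes the token a target-language token.
    if any(in_ranges(c, KOREAN_RANGES) or (c.isalpha() and c.isascii()) for c in clean):
        return False
    return any(in_ranges(c, OTHER_SCRIPT_RANGES) for c in clean)


samples = ["\u2581프로그램", "\u2581program", "\u2581Москва", "\u2581東京", "\u2581123", "\u2581"]
results = {tok: is_non_target_token_compact(tok) for tok in samples}
for tok, flagged in results.items():
    print(f"{tok!r}: {flagged}")
```

Note that digits and a bare subword marker are not flagged by this check. In the notebook, punctuation and symbols are handled separately when the suppression list is built.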

## 3. Dataset Class

Dataset for 1:N mixed Korean/English term mappings with similarity scores:
- Input format: `{"ko": "프로그램", "terms": [{"term": "program", "sim": 0.95}, {"term": "소프트웨어", "sim": 0.88}]}`
- Tokenizes Korean source term and all target terms (mixed Korean + English)
- Keeps, for each target token, the highest similarity score among the terms that contain it; this score is used as the loss weight
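
The per-token weighting can be sketched without the real tokenizer. In the example below, `toy_tokenize` is a hypothetical whitespace splitter standing in for the XLM-RoBERTa subword tokenizer, and the record is made up for illustration rather than taken from the dataset. When a token appears in several target terms, it keeps the maximum similarity.

```python
import json


def toy_tokenize(text):
    """Hypothetical stand-in for the subword tokenizer: lowercase and split on spaces."""
    return text.lower().split()


def aggregate_target_weights(record_line):
    """Map each target token to the highest similarity of any term containing it."""
    record = json.loads(record_line)
    weights = {}
    for entry in record["terms"]:
        for token in toy_tokenize(entry["term"]):
            weights[token] = max(weights.get(token, 0.0), entry["sim"])
    return record["ko"], weights


# Illustrative record in the same JSONL format as the dataset.
line = json.dumps(
    {
        "ko": "머신러닝",
        "terms": [
            {"term": "machine learning", "sim": 0.95},
            {"term": "deep learning", "sim": 0.82},
        ],
    },
    ensure_ascii=False,
)

ko, weights = aggregate_target_weights(line)
print(ko, weights)
```

Here `learning` occurs in both terms and ends up with weight 0.95, the larger of the two scores.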

In [7]:
class TermMappingDataset(Dataset):
    """Dataset for 1:N Korean to mixed Korean/English term mappings with similarity.
    
    Format: {"ko": "프로그램", "terms": [{"term": "program", "sim": 0.95}, ...]}
    """

    def __init__(self, data_path: Path, tokenizer, max_length: int = 64):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []
        
        # Build set of special token IDs to exclude
        self.special_ids = {
            tokenizer.pad_token_id,
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.unk_token_id,
            tokenizer.bos_token_id,
            tokenizer.eos_token_id,
        }
        self.special_ids = {t for t in self.special_ids if t is not None}
        
        # Add tokens by name
        for token_name in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
            tid = tokenizer.convert_tokens_to_ids(token_name)
            if tid != tokenizer.unk_token_id:
                self.special_ids.add(tid)

        print(f"Loading dataset from {data_path}...")
        print(f"Special token IDs to exclude: {len(self.special_ids)}")

        with open(data_path, "r", encoding="utf-8") as f:
            for line in tqdm(f, desc="Loading data"):
                item = json.loads(line.strip())

                ko_term = item.get("ko", "")
                terms_data = item.get("terms", [])

                if not ko_term or not terms_data:
                    continue

                # Tokenize Korean source term (exclude special tokens)
                ko_tokens = tokenizer.tokenize(ko_term)
                ko_token_ids = tokenizer.convert_tokens_to_ids(ko_tokens)
                ko_token_ids = [
                    tid for tid in ko_token_ids 
                    if tid != tokenizer.unk_token_id and tid not in self.special_ids
                ]

                # Tokenize all target terms with similarity weights
                # Format: {token_id: similarity_weight}
                target_weights: dict = {}
                for term_info in terms_data:
                    # Handle both old format (string) and new format (dict)
                    if isinstance(term_info, dict):
                        term = term_info.get("term", "")
                        sim = term_info.get("sim", 1.0)
                    else:
                        term = term_info
                        sim = 1.0
                    
                    # Lowercase for consistency (affects English only)
                    term_lower = term.lower() if term.isascii() else term
                    tokens = tokenizer.tokenize(term_lower)
                    token_ids = tokenizer.convert_tokens_to_ids(tokens)
                    
                    for tid in token_ids:
                        # Exclude unknown and special tokens
                        if tid != tokenizer.unk_token_id and tid not in self.special_ids:
                            # Keep maximum similarity for each token
                            target_weights[tid] = max(target_weights.get(tid, 0.0), sim)

                if ko_token_ids and target_weights:
                    self.data.append(
                        {
                            "ko_term": ko_term,
                            "ko_token_ids": ko_token_ids,
                            "target_token_ids": list(target_weights.keys()),
                            "target_weights": list(target_weights.values()),
                        }
                    )

        print(f"Loaded {len(self.data):,} valid term mappings")
        
        # Statistics
        n_targets = [len(d["target_token_ids"]) for d in self.data]
        all_weights = [w for d in self.data for w in d["target_weights"]]
        print(f"Average target tokens per source: {sum(n_targets)/len(n_targets):.2f}")
        print(f"Similarity weights: min={min(all_weights):.3f}, max={max(all_weights):.3f}, mean={sum(all_weights)/len(all_weights):.3f}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        encoding = self.tokenizer(
            item["ko_term"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "ko_token_ids": item["ko_token_ids"],
            "target_token_ids": item["target_token_ids"],
            "target_weights": item["target_weights"],
        }


def collate_fn(batch):
    """Custom collate function."""
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "ko_token_ids": [item["ko_token_ids"] for item in batch],
        "target_token_ids": [item["target_token_ids"] for item in batch],
        "target_weights": [item["target_weights"] for item in batch],
    }

## 4. Loss Function

In [8]:
class TermLevelLoss(nn.Module):
    """Loss function for term-level cross-lingual training with similarity weights.
    
    Components:
    - Self loss: Preserve Korean source term tokens
    - Target loss: Activate target tokens with SIMILARITY-WEIGHTED loss
    - Margin loss: Ensure minimum activation for target tokens
    - Negative loss: Suppress non-target language tokens
    
    The target loss uses similarity scores as weights:
    - Higher similarity = stronger loss signal for that token
    - This focuses learning on the most semantically similar terms
    """

    def __init__(self, target_margin: float = 2.0, non_target_ids: torch.Tensor = None):
        super().__init__()
        self.target_margin = target_margin
        self.non_target_ids = non_target_ids

    def forward(self, sparse_rep, ko_token_ids, target_token_ids, target_weights):
        """
        Args:
            sparse_rep: Sparse representations [batch_size, vocab_size]
            ko_token_ids: List of Korean source token IDs per sample
            target_token_ids: List of target token IDs (mixed KO/EN) per sample
            target_weights: List of similarity weights for each target token
        """
        batch_size = sparse_rep.shape[0]
        device = sparse_rep.device

        self_loss = torch.tensor(0.0, device=device)
        target_loss = torch.tensor(0.0, device=device)
        margin_loss = torch.tensor(0.0, device=device)
        negative_loss = torch.tensor(0.0, device=device)

        n_valid = 0

        for i in range(batch_size):
            rep = sparse_rep[i]

            # Self loss: maximize activation of Korean source tokens
            if ko_token_ids[i]:
                ko_ids = torch.tensor(ko_token_ids[i], device=device)
                ko_activations = rep[ko_ids]
                self_loss = self_loss - torch.log(ko_activations + 1e-8).mean()

            # Target loss: SIMILARITY-WEIGHTED activation of target tokens
            if target_token_ids[i]:
                tgt_ids = torch.tensor(target_token_ids[i], device=device)
                tgt_weights = torch.tensor(target_weights[i], device=device, dtype=torch.float32)
                tgt_activations = rep[tgt_ids]
                
                # Weighted log loss: higher similarity = stronger loss
                weighted_log_loss = -torch.log(tgt_activations + 1e-8) * tgt_weights
                target_loss = target_loss + weighted_log_loss.sum() / (tgt_weights.sum() + 1e-8)
                
                # Margin loss (weighted by similarity)
                margin_violations = F.relu(self.target_margin - tgt_activations) * tgt_weights
                margin_loss = margin_loss + margin_violations.sum() / (tgt_weights.sum() + 1e-8)

            # Negative loss: suppress non-target language tokens
            if self.non_target_ids is not None:
                non_target_ids_device = self.non_target_ids.to(device)
                non_target_activations = rep[non_target_ids_device]
                negative_loss = negative_loss + F.relu(
                    non_target_activations - 0.1
                ).mean()

            n_valid += 1

        if n_valid > 0:
            self_loss = self_loss / n_valid
            target_loss = target_loss / n_valid
            margin_loss = margin_loss / n_valid
            negative_loss = negative_loss / n_valid

        return {
            "self": self_loss,
            "target": target_loss,
            "margin": margin_loss,
            "negative": negative_loss,
        }
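
The weighted target and margin terms can be checked by hand on a single sample. The sketch below restates the two formulas in plain Python; the function names and the activation values are illustrative assumptions. It also shows that the log term becomes negative once activations exceed 1, which would be consistent with the negative self and target losses reported in the epoch summaries.

```python
import math


def weighted_target_loss(activations, weights, eps=1e-8):
    """Similarity-weighted mean of -log(activation), as in TermLevelLoss."""
    total = sum(-math.log(a + eps) * w for a, w in zip(activations, weights))
    return total / (sum(weights) + eps)


def weighted_margin_loss(activations, weights, margin=2.0, eps=1e-8):
    """Similarity-weighted mean of relu(margin - activation)."""
    total = sum(max(margin - a, 0.0) * w for a, w in zip(activations, weights))
    return total / (sum(weights) + eps)


weights = [0.95, 0.80]           # illustrative similarity scores
low_acts = [0.5, 0.5]            # activations below 1
high_acts = [4.0, 4.0]           # activations above 1 and above the margin

low_target = weighted_target_loss(low_acts, weights)
high_target = weighted_target_loss(high_acts, weights)
low_margin = weighted_margin_loss(low_acts, weights)
high_margin = weighted_margin_loss(high_acts, weights)

print(f"target loss, low activations:  {low_target:.4f}")
print(f"target loss, high activations: {high_target:.4f}")
print(f"margin loss, low activations:  {low_margin:.4f}")
print(f"margin loss, high activations: {high_margin:.4f}")
```

Because `-log(x)` has no lower bound as `x` grows, the target and self terms keep rewarding larger activations even after the margin term has reached zero.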

## 5. Evaluation Function

In [9]:
# Test pairs for evaluation (Korean source -> expected Korean synonyms + English translations)
TEST_PAIRS = [
    # (source_ko, expected_english, expected_korean_synonyms)
    ("머신러닝", ["machine", "learning"], ["머신", "러닝", "기계학습"]),
    ("딥러닝", ["deep", "learning"], ["딥", "러닝", "심층학습"]),
    ("자연어처리", ["natural", "language", "processing"], ["자연어", "처리"]),
    ("인공지능", ["artificial", "intelligence"], ["인공", "지능"]),
    ("검색엔진", ["search", "engine"], ["검색", "엔진"]),
    ("데이터베이스", ["database"], ["데이터", "베이스"]),
    ("클라우드", ["cloud"], ["클라우드"]),
    ("서버", ["server"], ["서버"]),
    ("네트워크", ["network"], ["네트워크"]),
    ("추천시스템", ["recommend", "system"], ["추천", "시스템"]),
    ("추천", ["recommend", "recommendation"], ["추천"]),
    ("신경망", ["neural", "network"], ["신경망", "신경"]),
    ("강화학습", ["reinforcement", "learning"], ["강화", "학습"]),
    ("컴퓨터비전", ["computer", "vision"], ["컴퓨터", "비전"]),
    ("음성인식", ["speech", "recognition"], ["음성", "인식"]),
]


def evaluate_model(model, tokenizer, device, top_k=50):
    """Evaluate model on test pairs - check both Korean and English activation."""
    model.eval()

    ko_activated_total = 0
    en_activated_total = 0
    ko_expected_total = 0
    en_expected_total = 0

    with torch.no_grad():
        for ko_term, en_expected, ko_expected in TEST_PAIRS:
            encoding = tokenizer(
                ko_term,
                max_length=CONFIG["max_length"],
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            with autocast("cuda", enabled=CONFIG["use_fp16"]):
                sparse_rep, _ = model(
                    encoding["input_ids"].to(device),
                    encoding["attention_mask"].to(device),
                )

            sparse_rep = sparse_rep[0].float().cpu()
            top_indices = torch.topk(sparse_rep, k=top_k).indices.tolist()
            top_tokens = tokenizer.convert_ids_to_tokens(top_indices)
            top_tokens_set = set(top_tokens)

            # Check Korean synonym/preservation activation
            for ko in ko_expected:
                ko_toks = tokenizer.tokenize(ko)
                for tok in ko_toks:
                    ko_expected_total += 1
                    if tok in top_tokens_set:
                        ko_activated_total += 1

            # Check English translation activation
            for en in en_expected:
                en_toks = tokenizer.tokenize(en.lower())
                for tok in en_toks:
                    en_expected_total += 1
                    if tok in top_tokens_set:
                        en_activated_total += 1

    model.train()

    ko_rate = (
        ko_activated_total / ko_expected_total * 100 if ko_expected_total > 0 else 0
    )
    en_rate = (
        en_activated_total / en_expected_total * 100 if en_expected_total > 0 else 0
    )

    return ko_rate, en_rate
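
The metric computed by `evaluate_model` is a top-k hit rate: the share of expected tokens that appear among the k highest-scoring vocabulary entries. A minimal sketch of that calculation follows, using a plain dictionary of made-up scores in place of the model's sparse representation. The helper name `topk_hit_rate` and all values are assumptions for illustration.

```python
def topk_hit_rate(scores, expected_tokens, k):
    """Percentage of expected tokens found among the k highest-scoring tokens."""
    top_tokens = set(sorted(scores, key=scores.get, reverse=True)[:k])
    if not expected_tokens:
        return 0.0
    hits = sum(1 for tok in expected_tokens if tok in top_tokens)
    return 100.0 * hits / len(expected_tokens)


# Made-up activations for one Korean input term.
scores = {
    "\u2581머신": 3.0,
    "\u2581machine": 2.1,
    "\u2581learning": 1.7,
    ",": 0.9,
    "\u2581러닝": 0.2,
}

en_rate = topk_hit_rate(scores, ["\u2581machine", "\u2581learning"], k=3)
ko_rate = topk_hit_rate(scores, ["\u2581머신", "\u2581러닝"], k=3)
print(f"KO: {ko_rate:.1f}%  EN: {en_rate:.1f}%  combined: {ko_rate + en_rate:.1f}")
```

The combined score is the sum of two percentages, so it ranges from 0 to 200 rather than 0 to 100.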

## 6. Initialize Components

In [10]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Load tokenizer
print(f"\nLoading tokenizer: {CONFIG['model_name']}...")
tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
print(f"Vocab size: {tokenizer.vocab_size:,}")

Device: cuda

Loading tokenizer: xlm-roberta-large...
Vocab size: 250,002


In [11]:
# Build non-target language token ID list AND special tokens to suppress
print("Building token suppression lists...")

non_target_ids = []
special_token_ids = []

# Get special token IDs
special_tokens = {
    tokenizer.pad_token_id,
    tokenizer.cls_token_id,
    tokenizer.sep_token_id,
    tokenizer.unk_token_id,
    tokenizer.bos_token_id,
    tokenizer.eos_token_id,
}
special_tokens = {t for t in special_tokens if t is not None}

# Punctuation and symbols to suppress
suppress_patterns = {
    '.', ',', '!', '?', ':', ';', '-', '_', '(', ')', '[', ']', '{', '}',
    '"', "'", '`', '/', '\\', '@', '#', '$', '%', '^', '&', '*', '+', '=',
    '<', '>', '~', '|', '▶', '►', '●', '○', '■', '□', '★', '☆', '→', '←',
    '...', '..', '--', '==', '##', '@@',
}

for token_id in tqdm(range(tokenizer.vocab_size), desc="Scanning vocab"):
    token = tokenizer.convert_ids_to_tokens(token_id)
    
    # Special tokens
    if token_id in special_tokens:
        special_token_ids.append(token_id)
        continue
    
    # Check for special token markers
    if token in ['<s>', '</s>', '<pad>', '<unk>', '<mask>', '[CLS]', '[SEP]', '[PAD]', '[UNK]', '[MASK]']:
        special_token_ids.append(token_id)
        continue
    
    # Punctuation and symbols
    clean_token = token.replace('▁', '').replace('##', '').strip()
    if clean_token in suppress_patterns or (clean_token and all(c in suppress_patterns or not c.isalnum() for c in clean_token)):
        special_token_ids.append(token_id)
        continue
    
    # Non-target language (not Korean, not English)
    if is_non_target_token(token):
        non_target_ids.append(token_id)

# Combine for suppression
all_suppress_ids = list(set(non_target_ids + special_token_ids))
suppress_ids_tensor = torch.tensor(all_suppress_ids, dtype=torch.long)

print(f"Non-target language tokens: {len(non_target_ids):,}")
print(f"Special/punctuation tokens: {len(special_token_ids):,}")
print(f"Total tokens to suppress: {len(all_suppress_ids):,}")

Building token suppression lists...


Non-target language tokens: 76,209
Special/punctuation tokens: 2,481
Total tokens to suppress: 78,690


In [12]:
# Load dataset (1:N mixed format)
dataset = TermMappingDataset(CONFIG["data_path"], tokenizer, CONFIG["max_length"])

dataloader = DataLoader(
    dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
    pin_memory=True,
)

print(f"\nDataset size: {len(dataset):,}")
print(f"Batches per epoch: {len(dataloader):,}")
print(f"Effective batch size: {CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']}")

Loading dataset from /home/west/Documents/cursor-workspace/opensearch-neural-pre-train/dataset/v19_high_quality/term_mappings.jsonl...
Special token IDs to exclude: 5


Loaded 13,684 valid term mappings
Average target tokens per source: 11.08
Similarity weights: min=0.800, max=0.995, mean=0.888

Dataset size: 13,684
Batches per epoch: 214
Effective batch size: 128


In [13]:
# Create model
print(f"\nCreating model: {CONFIG['model_name']}...")
model = create_splade_model(
    model_name=CONFIG["model_name"],
    use_idf=False,
    use_expansion=True,
    expansion_mode="mlm",
)
model = model.to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,} ({n_params / 1e6:.1f}M)")


Creating model: xlm-roberta-large...


Some weights of the model checkpoint at xlm-roberta-large were not used when initializing XLMRobertaForMaskedLM: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Parameters: 560,142,482 (560.1M)


In [14]:
# Loss function - use combined suppression list
loss_fn = TermLevelLoss(
    target_margin=CONFIG["target_margin"], 
    non_target_ids=suppress_ids_tensor  # Now includes special tokens + punctuation
)

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=CONFIG["learning_rate"], 
    weight_decay=0.01
)

# Scheduler
total_steps = (
    len(dataloader) * CONFIG["num_epochs"] // CONFIG["gradient_accumulation_steps"]
)
warmup_steps = int(total_steps * CONFIG["warmup_ratio"])

scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=warmup_steps, 
    num_training_steps=total_steps
)

print(f"Total optimization steps: {total_steps:,}")
print(f"Warmup steps: {warmup_steps:,}")
print(f"Tokens to suppress: {len(suppress_ids_tensor):,}")

# Mixed precision scaler
scaler = GradScaler("cuda", enabled=CONFIG["use_fp16"])

# Create output directory
CONFIG["output_dir"].mkdir(parents=True, exist_ok=True)
print(f"Output directory: {CONFIG['output_dir']}")

Total optimization steps: 1,605
Warmup steps: 160
Tokens to suppress: 78,690
Output directory: /home/west/Documents/cursor-workspace/opensearch-neural-pre-train/outputs/v19_xlm_large
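
For intuition about the schedule configured above, the following plain-Python sketch approximates the shape of a linear warmup followed by linear decay, using the step counts printed by the cell. It is a stand-in written for illustration, not the `transformers` implementation, and the function name `linear_warmup_decay` is hypothetical.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """Learning-rate multiplier: linear ramp to 1.0, then linear decay to 0.0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 3e-6                       # CONFIG["learning_rate"]
total_steps, warmup_steps = 1605, 160  # values printed by the scheduler cell

for step in (0, 80, 160, 800, 1605):
    lr = base_lr * linear_warmup_decay(step, warmup_steps, total_steps)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

Under this shape the peak learning rate of 3e-6 is reached at step 160, partway through the second epoch at about 107 optimizer steps per epoch.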


## 7. Initial Evaluation

In [15]:
# Evaluate before training
ko_rate, en_rate = evaluate_model(model, tokenizer, device)
print(f"Initial Performance:")
print(f"  Korean Preservation: {ko_rate:.1f}%")
print(f"  English Activation: {en_rate:.1f}%")
print(f"  Combined Score: {ko_rate + en_rate:.1f}")

Initial Performance:
  Korean Preservation: 48.8%
  English Activation: 9.1%
  Combined Score: 57.9


## 8. Training Loop

In [16]:
# Training variables
history = []
best_score = 0
global_step = 0

print("=" * 70)
print("STARTING TRAINING")
print("=" * 70)

STARTING TRAINING


In [None]:
for epoch in range(CONFIG["num_epochs"]):
    print(f"\n{'='*70}")
    print(f"Epoch {epoch + 1}/{CONFIG['num_epochs']}")
    print(f"{'='*70}")
    
    model.train()
    epoch_losses = defaultdict(float)
    optimizer.zero_grad()

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")

    for batch_idx, batch in enumerate(progress_bar):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        with autocast("cuda", enabled=CONFIG["use_fp16"]):
            sparse_rep, _ = model(input_ids, attention_mask)

            # Pass target_weights for similarity-weighted loss
            losses = loss_fn(
                sparse_rep,
                batch["ko_token_ids"],
                batch["target_token_ids"],
                batch["target_weights"],  # Similarity weights
            )

            sparsity_loss = sparse_rep.mean()

            total_loss = (
                CONFIG["lambda_self"] * losses["self"]
                + CONFIG["lambda_target"] * losses["target"]
                + CONFIG["lambda_margin"] * losses["margin"]
                + CONFIG["lambda_negative"] * losses["negative"]
                + CONFIG["lambda_sparsity"] * sparsity_loss
            )

            total_loss = total_loss / CONFIG["gradient_accumulation_steps"]

        scaler.scale(total_loss).backward()

        epoch_losses["total"] += total_loss.item() * CONFIG["gradient_accumulation_steps"]
        epoch_losses["self"] += losses["self"].item()
        epoch_losses["target"] += losses["target"].item()
        epoch_losses["margin"] += losses["margin"].item()
        epoch_losses["negative"] += losses["negative"].item()

        if (batch_idx + 1) % CONFIG["gradient_accumulation_steps"] == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), CONFIG["max_grad_norm"]
            )
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

        if (batch_idx + 1) % 100 == 0:
            progress_bar.set_postfix(
                {
                    "loss": f"{epoch_losses['total'] / (batch_idx + 1):.4f}",
                    "tgt": f"{epoch_losses['target'] / (batch_idx + 1):.4f}",
                    "step": global_step,
                }
            )

    # Calculate average losses
    n_batches = len(dataloader)
    for key in epoch_losses:
        epoch_losses[key] /= n_batches

    history.append(dict(epoch_losses))

    # Evaluate
    ko_rate, en_rate = evaluate_model(model, tokenizer, device)
    combined_score = ko_rate + en_rate

    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Total Loss: {epoch_losses['total']:.4f}")
    print(f"  Self Loss: {epoch_losses['self']:.4f}")
    print(f"  Target Loss (weighted): {epoch_losses['target']:.4f}")
    print(f"  Korean Activation: {ko_rate:.1f}%")
    print(f"  English Activation: {en_rate:.1f}%")
    print(f"  Combined Score: {combined_score:.1f}")

    # Save checkpoint
    checkpoint_path = CONFIG["output_dir"] / f"checkpoint_epoch{epoch + 1}.pt"
    torch.save(
        {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "losses": dict(epoch_losses),
            "ko_rate": ko_rate,
            "en_rate": en_rate,
            "config": {
                k: str(v) if isinstance(v, Path) else v for k, v in CONFIG.items()
            },
        },
        checkpoint_path,
    )
    print(f"  Saved: {checkpoint_path.name}")

    # Save best model
    if combined_score > best_score:
        best_score = combined_score
        best_path = CONFIG["output_dir"] / "best_model.pt"
        torch.save(
            {
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "ko_rate": ko_rate,
                "en_rate": en_rate,
                "combined_score": combined_score,
                "config": {
                    k: str(v) if isinstance(v, Path) else v
                    for k, v in CONFIG.items()
                },
            },
            best_path,
        )
        print(f"  *** New best model! Score: {combined_score:.1f} (KO:{ko_rate:.1f}% + EN:{en_rate:.1f}%) ***")


Epoch 1/15


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

Epoch 1 Summary:
  Total Loss: -6.4263
  Self Loss: -1.3848
  Target Loss (weighted): -1.3229
  Korean Activation: 62.8%
  English Activation: 0.0%
  Combined Score: 62.8
  Saved: checkpoint_epoch1.pt
  *** New best model! Score: 62.8 (KO:62.8% + EN:0.0%) ***

Epoch 2/15



Epoch 2 Summary:
  Total Loss: -9.0627
  Self Loss: -1.4000
  Target Loss (weighted): -1.3385
  Korean Activation: 51.2%
  English Activation: 0.0%
  Combined Score: 51.2
  Saved: checkpoint_epoch2.pt

Epoch 3/15



Epoch 3 Summary:
  Total Loss: -9.5651
  Self Loss: -1.4321
  Target Loss (weighted): -1.3803
  Korean Activation: 41.9%
  English Activation: 0.0%
  Combined Score: 41.9
  Saved: checkpoint_epoch3.pt

Epoch 4/15


Epoch 4:   0%|          | 0/214 [00:00<?, ?it/s]



Epoch 4 Summary:
  Total Loss: -9.6967
  Self Loss: -1.4433
  Target Loss (weighted): -1.3972
  Korean Activation: 41.9%
  English Activation: 0.0%
  Combined Score: 41.9
  Saved: checkpoint_epoch4.pt

Epoch 5/15


Epoch 5:   0%|          | 0/214 [00:00<?, ?it/s]



Epoch 5 Summary:
  Total Loss: -9.7549
  Self Loss: -1.4500
  Target Loss (weighted): -1.4040
  Korean Activation: 41.9%
  English Activation: 0.0%
  Combined Score: 41.9
  Saved: checkpoint_epoch5.pt

Epoch 6/15


Epoch 6:   0%|          | 0/214 [00:00<?, ?it/s]


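The `huggingface/tokenizers` fork warning in the training output names its own remedy: set `TOKENIZERS_PARALLELISM` explicitly. A minimal sketch, assuming the cell runs near the top of the notebook, before the tokenizer is first used and before any `DataLoader` worker processes are started:

```python
import os

# Set this before the tokenizer is first used in the process, so the
# setting is already in place when DataLoader workers are forked.
# "false" disables the tokenizer's internal thread pool and silences
# the fork warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

print(f"TOKENIZERS_PARALLELISM={os.environ['TOKENIZERS_PARALLELISM']}")
```

Whether `"false"` or `"true"` is the better choice depends on how the notebook's `DataLoader` is configured, which is not shown in this section. With worker processes in use, `"false"` is the conservative option.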
## 9. Save Final Model

In [None]:
# Save final model
final_path = CONFIG["output_dir"] / "final_model.pt"
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "config": {
            k: str(v) if isinstance(v, Path) else v for k, v in CONFIG.items()
        },
        "history": history,
    },
    final_path,
)

# Save training history
with open(CONFIG["output_dir"] / "training_history.json", "w") as f:
    json.dump(history, f, indent=2)

print("\n" + "=" * 70)
print("TRAINING COMPLETE")
print("=" * 70)
print(f"Final model saved: {final_path}")
print(f"Best combined score: {best_score:.1f}")
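The saved `training_history.json` can be inspected later without reloading the model. The sketch below assumes `history` is a list with one dict per epoch, holding the `total`, `self`, `target`, and `negative` keys used by the plotting cell. The helper names `load_history` and `best_epoch` are hypothetical and not part of the project.

```python
import json
from pathlib import Path


def load_history(path):
    """Read the JSON history file written by the save cell."""
    with open(Path(path), encoding="utf-8") as f:
        return json.load(f)


def best_epoch(history, key="total"):
    """Return (1-based epoch number, entry) with the lowest value of `key`."""
    if not history:
        raise ValueError("history is empty")
    index = min(range(len(history)), key=lambda i: history[i][key])
    return index + 1, history[index]


# Toy values for illustration only; these are NOT results from this run.
toy_history = [
    {"total": -1.0, "self": -0.4, "target": -0.3, "negative": 0.2},
    {"total": -2.5, "self": -0.6, "target": -0.5, "negative": 0.1},
    {"total": -2.0, "self": -0.5, "target": -0.4, "negative": 0.1},
]

epoch, entry = best_epoch(toy_history)
print(f"Lowest total loss at epoch {epoch}: {entry['total']}")
```

The lowest loss does not necessarily mark the best checkpoint. In the output above, Total Loss falls from -9.0627 (epoch 2) to -9.7549 (epoch 5) while Combined Score drops from 51.2 to 41.9, so both should be checked before choosing a checkpoint.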

## 10. Training Summary

In [None]:
import matplotlib.pyplot as plt

# Plot training curves
if history:
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    epochs = range(1, len(history) + 1)

    # Total loss
    axes[0, 0].plot(epochs, [-h['total'] for h in history], '-o', color='#3498db')
    axes[0, 0].set_title('Total Loss (negated)')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].grid(True, alpha=0.3)

    # Self loss
    axes[0, 1].plot(epochs, [-h['self'] for h in history], '-o', color='#2ecc71')
    axes[0, 1].set_title('Self Loss (Korean Preservation, negated)')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Loss')
    axes[0, 1].grid(True, alpha=0.3)

    # Target loss
    axes[1, 0].plot(epochs, [-h['target'] for h in history], '-o', color='#e74c3c')
    axes[1, 0].set_title('Target Loss (English Activation, negated)')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Loss')
    axes[1, 0].grid(True, alpha=0.3)

    # Negative loss
    axes[1, 1].plot(epochs, [h['negative'] for h in history], '-o', color='#9b59b6')
    axes[1, 1].set_title('Negative Loss')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Loss')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(CONFIG["output_dir"] / "training_curves.png", dpi=150)
    plt.show()
    
    print(f"Training curves saved to: {CONFIG['output_dir'] / 'training_curves.png'}")

## Next Steps

After training completes:

1. **Run inference tests** using `02_inference_test.ipynb`
2. **Analyze results** and compare with previous versions
3. **Fine-tune hyperparameters** if needed

### Data Format Used
This model was trained with **1:N mixed Korean/English mappings with similarity weights**:
```json
{"ko": "프로그램", "terms": [{"term": "program", "sim": 0.95}, {"term": "소프트웨어", "sim": 0.88}]}
```

### Key Improvements in v19
- **Max 8 targets**: Limits each Korean term to its 8 most similar targets
- **Similarity threshold 0.8**: Keeps only high-quality pairs (cosine similarity >= 0.8)
- **Similarity-weighted loss**: Higher weight for more similar terms
  - Focuses learning on the most semantically relevant pairs
  - Reduces noise from borderline matches

The model is trained to activate both Korean synonyms and English translations, with stronger activation for more similar terms.
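
The target selection and weighting described above can be sketched as follows. This illustrates only the filtering and weighting of one record in the data format shown earlier; it is not the notebook's loss function. Normalizing the weights so they sum to 1 is an assumption, and `select_weighted_targets` is a hypothetical helper name.

```python
def select_weighted_targets(terms, min_sim=0.8, max_targets=8):
    """Keep the most similar targets and weight each by its similarity.

    Returns a list of (term, weight) pairs, most similar first.
    Weights are normalized to sum to 1 (an assumption of this sketch).
    """
    # Drop borderline matches below the similarity threshold.
    kept = [t for t in terms if t["sim"] >= min_sim]
    # Keep only the top `max_targets` by similarity.
    kept.sort(key=lambda t: t["sim"], reverse=True)
    kept = kept[:max_targets]
    total = sum(t["sim"] for t in kept)
    if total == 0:
        return []
    return [(t["term"], t["sim"] / total) for t in kept]


# The first two terms come from the example record above; the third is a
# hypothetical below-threshold term added to show the filter.
record = {
    "ko": "프로그램",
    "terms": [
        {"term": "program", "sim": 0.95},
        {"term": "소프트웨어", "sim": 0.88},
        {"term": "app", "sim": 0.75},
    ],
}

for term, weight in select_weighted_targets(record["terms"]):
    print(f"{term}: {weight:.3f}")
```

In a weighted target loss, each pair's weight would multiply that target's per-term loss, so more similar terms contribute more to the gradient.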