# v19 Training - Knowledge Distillation + FLOPS Sparse Retrieval

This notebook trains the v19 Korean-English cross-lingual SPLADE model by distilling a dense teacher into a sparse student and regularizing the student with a FLOPS penalty.

## Architecture: Dense Teacher → Sparse Student

```
┌─────────────────────┐     Distillation     ┌─────────────────────┐
│    Dense Teacher    │  ─────────────────▶  │   Sparse Student    │
│    (BAAI/bge-m3)    │    soft labels       │     (SPLADE)        │
└─────────────────────┘                      └─────────────────────┘
```

## Key Techniques

### 1. Knowledge Distillation
- Teacher: `BAAI/bge-m3` (dense embeddings, upgraded from `intfloat/multilingual-e5-base`)
- Student learns semantic similarity from teacher
- No manual rules needed - teacher guides what's important
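
To make the distillation objective concrete, here is a minimal sketch in plain Python (no PyTorch required). It follows the same idea as the loss defined later in this notebook: each row of a pairwise similarity matrix is turned into a softmax distribution, and the student is scored by its KL divergence from the teacher. The helper names and the toy similarity values are illustrative assumptions; the temperature of 0.05 is the value this notebook passes to its loss.

```python
import math

def softmax(row, temperature):
    """Softmax over one row of similarities, scaled by temperature."""
    scaled = [x / temperature for x in row]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(teacher_sim, student_sim, temperature=0.05):
    """Mean row-wise KL(teacher || student) over a batch of similarity rows."""
    losses = [
        kl_divergence(softmax(t, temperature), softmax(s, temperature))
        for t, s in zip(teacher_sim, student_sim)
    ]
    return sum(losses) / len(losses)

# Toy 3x3 cosine-similarity matrices (made-up values, for illustration only).
teacher_sim = [[1.0, 0.8, 0.1], [0.8, 1.0, 0.2], [0.1, 0.2, 1.0]]
student_good = [[1.0, 0.8, 0.1], [0.8, 1.0, 0.2], [0.1, 0.2, 1.0]]
student_bad = [[1.0, 0.1, 0.8], [0.1, 1.0, 0.2], [0.8, 0.2, 1.0]]

loss_good = distillation_loss(teacher_sim, student_good)
loss_bad = distillation_loss(teacher_sim, student_bad)
print(f"matching student:   {loss_good:.4f}")
print(f"mismatched student: {loss_bad:.4f}")
```

A student whose similarity structure matches the teacher's gets zero loss, while one that ranks the wrong neighbors as similar is penalized. The low temperature sharpens each row, so the loss mostly reflects which items are nearest neighbors.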

### 2. FLOPS Regularization (Automatic Noise Suppression)
- Penalizes tokens that activate frequently across the batch
- "s", "the", "a" naturally get suppressed (high avg activation → high penalty)
- End-to-end learned, no manual token lists
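
As a worked illustration of this penalty, the sketch below squares the batch-mean activation of each vocabulary entry and sums the results, which is the same formula as `compute_flops_loss` later in this notebook. The three-token vocabulary and the activation values are made up for the example, and `flops_penalty` is a hypothetical helper name.

```python
def flops_penalty(batch_activations):
    """Return (total, per_token), where per_token[j] = (batch mean of token j) ** 2."""
    batch_size = len(batch_activations)
    vocab_size = len(batch_activations[0])
    per_token = []
    for j in range(vocab_size):
        mean_j = sum(row[j] for row in batch_activations) / batch_size
        per_token.append(mean_j ** 2)
    return sum(per_token), per_token

# Columns: hypothetical tokens "the", "machine", "검색"; rows: 4 samples.
# "the" fires weakly on every sample; the other two fire strongly on one sample each.
activations = [
    [1.0, 2.0, 0.0],
    [1.0, 0.0, 0.0],
    [1.0, 0.0, 2.0],
    [1.0, 0.0, 0.0],
]
total, per_token = flops_penalty(activations)
print(per_token)  # [1.0, 0.25, 0.25]
print(total)      # 1.5
```

Although "the" never activates as strongly as the content tokens do, it contributes four times as much to the penalty because it is active on every sample.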

### 3. Cross-Lingual Term Mapping
- Korean source → Korean synonyms + English translations
- Separate loss weights for Korean/English targets

## Why This Approach?
- **Established**: SPLADE-v2 combines distillation with FLOPS regularization, and ColBERT-v2 and OpenSearch neural-sparse rely on related techniques
- **Scalable**: No language-specific rules to maintain
- **Effective**: Teacher provides rich semantic signal

In [None]:
import sys
from pathlib import Path

# Find project root
def find_project_root():
    """Find project root by looking for markers like pyproject.toml or src/"""
    current = Path.cwd()
    for parent in [current] + list(current.parents):
        if (parent / "pyproject.toml").exists() or (parent / "src").exists():
            return parent
    return Path.cwd().parent.parent

PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
import json
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler, autocast
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from sentence_transformers import SentenceTransformer
from tqdm.notebook import tqdm

from src.model.splade_model import create_splade_model

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Configuration

In [None]:
# Training Configuration
CONFIG = {
    # Model
    "model_name": "xlm-roberta-large",
    "max_length": 64,
    
    # Teacher model for Knowledge Distillation (upgraded to BGE-M3)
    "teacher_model": "BAAI/bge-m3",  # Upgraded from e5-base (dim=1024, best cross-lingual)

    # Data - 1:N mixed format (Korean -> [Korean + English terms])
    "data_path": PROJECT_ROOT / "dataset" / "v19_high_quality" / "term_mappings.jsonl",

    # Training
    "batch_size": 64,
    "gradient_accumulation_steps": 2,
    "num_epochs": 15,
    "learning_rate": 3e-6,
    "warmup_ratio": 0.1,
    "max_grad_norm": 1.0,

    # Loss weights (tuned for better English activation)
    "lambda_self": 3.0,         # Korean source preservation
    "lambda_ko_target": 2.0,    # Korean synonym activation
    "lambda_en_target": 12.0,   # English translation (increased from 8.0)
    "lambda_margin": 2.0,       # Margin loss for minimum activation
    "lambda_distill": 1.5,      # Knowledge distillation (increased from 1.0)
    "lambda_flops": 5e-5,       # FLOPS regularization (auto noise suppression)
    "target_margin": 2.0,

    # Mixed precision
    "use_fp16": True,

    # Output
    "output_dir": PROJECT_ROOT / "outputs" / "v19_xlm_large",
}

print("Training Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Calculate training stats
dataset_size = 13684  # hardcoded estimate; the exact size is printed after the dataset loads
batches_per_epoch = dataset_size // CONFIG["batch_size"]
opt_steps_per_epoch = batches_per_epoch // CONFIG["gradient_accumulation_steps"]
total_opt_steps = opt_steps_per_epoch * CONFIG["num_epochs"]

print(f"\nTraining Stats:")
print(f"  Effective batch size: {CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']}")
print(f"  Batches per epoch: ~{batches_per_epoch}")
print(f"  Total optimization steps: ~{total_opt_steps}")
print(f"\n*** Teacher: BGE-M3 (best cross-lingual), lambda_en_target=12.0 ***")

## 2. Helper Functions

In [None]:
def is_korean_char(c: str) -> bool:
    """Check if character is Korean."""
    return (
        "\uac00" <= c <= "\ud7a3"
        or "\u1100" <= c <= "\u11ff"
        or "\u3130" <= c <= "\u318f"
    )


def is_english_char(c: str) -> bool:
    """Check if character is English."""
    return c.isalpha() and c.isascii()


def is_non_target_token(token: str) -> bool:
    """Check if token is from non-target language (not Korean or English)."""
    clean = token.replace("\u2581", "").replace("##", "")  # Remove subword markers
    if not clean:
        return False

    has_korean = any(is_korean_char(c) for c in clean)
    has_english = any(is_english_char(c) for c in clean)

    if has_korean or has_english:
        return False

    # Check for other languages
    has_japanese = any(
        "\u3040" <= c <= "\u309f" or "\u30a0" <= c <= "\u30ff" for c in clean
    )
    has_cjk = any("\u4e00" <= c <= "\u9fff" for c in clean)
    has_cyrillic = any("\u0400" <= c <= "\u04ff" for c in clean)
    has_arabic = any("\u0600" <= c <= "\u06ff" for c in clean)
    has_thai = any("\u0e00" <= c <= "\u0e7f" for c in clean)
    has_greek = any("\u0370" <= c <= "\u03ff" for c in clean)

    return (
        has_japanese or has_cjk or has_cyrillic or has_arabic or has_thai or has_greek
    )

## 3. Dataset Class - Separate Korean/English Targets

Dataset for 1:N mixed Korean/English term mappings with similarity scores:
- Input format: `{"ko": "프로그램", "terms": [{"term": "program", "sim": 0.95}, {"term": "소프트웨어", "sim": 0.88}]}`

**Key Change**: Separates Korean and English targets into distinct lists:
- `ko_target_ids` / `ko_target_weights`: Korean synonym tokens
- `en_target_ids` / `en_target_weights`: English translation tokens

This enables separate loss computation for cross-lingual training.
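
The split can be sketched without a tokenizer. The example below works at the term level rather than on token IDs, which is a simplification of what the dataset class does, and `split_targets` and `contains_hangul` are hypothetical helper names. It uses the same Hangul-syllable range check as the notebook and keeps the highest similarity when a term appears more than once. The third entry (`"Program"`) is added here only to show lowercasing and deduplication.

```python
import json

def contains_hangul(text: str) -> bool:
    """True if any character is a precomposed Hangul syllable (U+AC00 to U+D7A3)."""
    return any("\uac00" <= ch <= "\ud7a3" for ch in text)

def split_targets(record: dict) -> tuple[dict, dict]:
    """Split a record's terms into Korean and English maps of term -> max sim."""
    ko_targets: dict[str, float] = {}
    en_targets: dict[str, float] = {}
    for entry in record.get("terms", []):
        term = entry["term"]
        sim = entry.get("sim", 1.0)
        if contains_hangul(term):
            ko_targets[term] = max(ko_targets.get(term, 0.0), sim)
        else:
            key = term.lower()  # English terms are lowercased, Korean is left as-is
            en_targets[key] = max(en_targets.get(key, 0.0), sim)
    return ko_targets, en_targets

line = (
    '{"ko": "프로그램", "terms": ['
    '{"term": "program", "sim": 0.95}, '
    '{"term": "소프트웨어", "sim": 0.88}, '
    '{"term": "Program", "sim": 0.90}]}'
)
ko_targets, en_targets = split_targets(json.loads(line))
print(ko_targets)  # {'소프트웨어': 0.88}
print(en_targets)  # {'program': 0.95}
```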

In [None]:
class TermMappingDataset(Dataset):
    """Dataset for 1:N Korean to mixed Korean/English term mappings with similarity.
    
    Format: {"ko": "프로그램", "terms": [{"term": "program", "sim": 0.95}, ...]}
    
    Separates Korean and English targets for explicit cross-lingual training.
    Returns text for knowledge distillation.
    """

    def __init__(self, data_path: Path, tokenizer, max_length: int = 64):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []
        
        # Build set of special token IDs to exclude
        self.special_ids = {
            tokenizer.pad_token_id,
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.unk_token_id,
            tokenizer.bos_token_id,
            tokenizer.eos_token_id,
        }
        self.special_ids = {t for t in self.special_ids if t is not None}
        
        # Add tokens by name
        for token_name in ['<s>', '</s>', '<pad>', '<unk>', '<mask>']:
            tid = tokenizer.convert_tokens_to_ids(token_name)
            if tid != tokenizer.unk_token_id:
                self.special_ids.add(tid)

        print(f"Loading dataset from {data_path}...")

        def is_korean_term(text: str) -> bool:
            """Check if term contains Korean characters."""
            return any('\uac00' <= c <= '\ud7a3' for c in text)

        total_ko_targets = 0
        total_en_targets = 0

        with open(data_path, "r", encoding="utf-8") as f:
            for line in tqdm(f, desc="Loading data"):
                item = json.loads(line.strip())

                ko_term = item.get("ko", "")
                terms_data = item.get("terms", [])

                if not ko_term or not terms_data:
                    continue

                # Tokenize Korean source term (exclude special tokens)
                ko_tokens = tokenizer.tokenize(ko_term)
                ko_token_ids = tokenizer.convert_tokens_to_ids(ko_tokens)
                ko_token_ids = [
                    tid for tid in ko_token_ids 
                    if tid != tokenizer.unk_token_id and tid not in self.special_ids
                ]

                # SEPARATE Korean and English targets
                ko_target_weights: dict = {}
                en_target_weights: dict = {}
                
                for term_info in terms_data:
                    if isinstance(term_info, dict):
                        term = term_info.get("term", "")
                        sim = term_info.get("sim", 1.0)
                    else:
                        term = term_info
                        sim = 1.0
                    
                    is_korean = is_korean_term(term)
                    term_lower = term if is_korean else term.lower()
                    tokens = tokenizer.tokenize(term_lower)
                    token_ids = tokenizer.convert_tokens_to_ids(tokens)
                    
                    for tid in token_ids:
                        if tid != tokenizer.unk_token_id and tid not in self.special_ids:
                            if is_korean:
                                ko_target_weights[tid] = max(ko_target_weights.get(tid, 0.0), sim)
                            else:
                                en_target_weights[tid] = max(en_target_weights.get(tid, 0.0), sim)

                if ko_token_ids and (ko_target_weights or en_target_weights):
                    total_ko_targets += len(ko_target_weights)
                    total_en_targets += len(en_target_weights)
                    
                    self.data.append({
                        "ko_term": ko_term,
                        "ko_token_ids": ko_token_ids,
                        "ko_target_ids": list(ko_target_weights.keys()),
                        "ko_target_weights": list(ko_target_weights.values()),
                        "en_target_ids": list(en_target_weights.keys()),
                        "en_target_weights": list(en_target_weights.values()),
                    })

        print(f"Loaded {len(self.data):,} valid term mappings")
        samples_with_english = sum(1 for d in self.data if d["en_target_ids"])
        print(f"Samples with English targets: {samples_with_english:,} ({samples_with_english/len(self.data)*100:.1f}%)")
        print(f"Total Korean target tokens: {total_ko_targets:,}")
        print(f"Total English target tokens: {total_en_targets:,}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        encoding = self.tokenizer(
            item["ko_term"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "text": item["ko_term"],  # For knowledge distillation
            "ko_token_ids": item["ko_token_ids"],
            "ko_target_ids": item["ko_target_ids"],
            "ko_target_weights": item["ko_target_weights"],
            "en_target_ids": item["en_target_ids"],
            "en_target_weights": item["en_target_weights"],
        }


def collate_fn(batch):
    """Custom collate function."""
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch]),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch]),
        "texts": [item["text"] for item in batch],  # List of texts for distillation
        "ko_token_ids": [item["ko_token_ids"] for item in batch],
        "ko_target_ids": [item["ko_target_ids"] for item in batch],
        "ko_target_weights": [item["ko_target_weights"] for item in batch],
        "en_target_ids": [item["en_target_ids"] for item in batch],
        "en_target_weights": [item["en_target_weights"] for item in batch],
    }

## 4. Loss Function - Cross-Lingual with Separate Korean/English

**Critical Design Decision**: Separate Korean and English target losses.

Why? When the input is Korean (e.g., "머신러닝"):
- Korean targets (e.g., "딥러닝") share subword tokens → naturally high activation
- English targets (e.g., "machine", "learning") have no overlap → need explicit training

Without separation, Korean dominates the loss and English is ignored.
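
One way to see the dilution is with a small numeric sketch. The activation values below are hypothetical: four Korean target tokens already fire strongly and one English target token barely fires. The per-token loss is `-log(activation + 1e-8)`, the form used for target tokens in this notebook.

```python
import math

def neg_log(activations):
    """Per-token loss -log(activation + eps)."""
    return [-math.log(a + 1e-8) for a in activations]

# Hypothetical activations for a single Korean input.
ko_activations = [2.0, 2.0, 2.0, 2.0]  # Korean targets: already strong
en_activations = [0.01]                # English target: almost inactive

ko_losses = neg_log(ko_activations)
en_losses = neg_log(en_activations)

# Pooled: one mean over all targets, so the English token gets 1/5 of the weight.
pooled = sum(ko_losses + en_losses) / (len(ko_losses) + len(en_losses))

# Separate: each language has its own mean and can receive its own loss weight.
ko_loss = sum(ko_losses) / len(ko_losses)
en_loss = sum(en_losses) / len(en_losses)

print(f"pooled loss:  {pooled:.3f}")
print(f"Korean loss:  {ko_loss:.3f}")
print(f"English loss: {en_loss:.3f}")
```

In the pooled version the large English loss is averaged away by the well-behaved Korean tokens. Keeping the two means apart preserves the English signal at full scale, where it can then be weighted independently.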

In [None]:
class DistillationFLOPSLoss(nn.Module):
    """
    Knowledge Distillation + FLOPS Loss for sparse retrieval.
    
    Key innovations:
    1. Knowledge Distillation: Learn semantic similarity from dense teacher
    2. FLOPS Regularization: Automatic noise suppression without manual rules
    3. Cross-lingual term mapping: Separate Korean/English target losses
    
    The FLOPS loss naturally suppresses high-frequency noise tokens (s, the, a, etc.)
    because they activate across many samples → high average activation → high penalty.
    """

    def __init__(
        self,
        teacher_model: SentenceTransformer,
        target_margin: float = 2.0,
        temperature: float = 0.05,
    ):
        super().__init__()
        self.teacher = teacher_model
        self.target_margin = target_margin
        self.temperature = temperature
        
        # Freeze teacher
        self.teacher.eval()
        for param in self.teacher.parameters():
            param.requires_grad = False

    def compute_flops_loss(self, sparse_rep: torch.Tensor) -> torch.Tensor:
        """
        FLOPS regularization - penalize frequently activated tokens.
        
        Logic: If a token activates for many samples in batch, it's likely noise.
        - "s", "the" activate for almost everything → high avg → high penalty
        - "machine", "검색" activate selectively → low avg → low penalty
        """
        # Average activation per token across batch [vocab_size]
        avg_activation = sparse_rep.mean(dim=0)
        
        # Squared mean activation summed over the vocabulary (the FLOPS regularizer)
        flops_loss = (avg_activation ** 2).sum()
        
        return flops_loss

    def compute_distillation_loss(
        self, 
        sparse_rep: torch.Tensor, 
        texts: list[str],
        device: torch.device,
    ) -> torch.Tensor:
        """
        Knowledge distillation from dense teacher.
        
        The student learns to reproduce the teacher's pairwise similarity distribution
        within the batch, which teaches it which tokens are semantically important.
        """
        # Get teacher embeddings (already normalized)
        with torch.no_grad():
            teacher_emb = self.teacher.encode(
                texts,
                convert_to_tensor=True,
                normalize_embeddings=True,
                device=device,
            )
        
        # Normalize student sparse representations for comparison
        student_emb = F.normalize(sparse_rep.float(), p=2, dim=-1)
        
        # Compute similarity matrices
        teacher_sim = teacher_emb @ teacher_emb.T  # [batch, batch]
        student_sim = student_emb @ student_emb.T  # [batch, batch]
        
        # Scale by temperature
        teacher_sim = teacher_sim / self.temperature
        student_sim = student_sim / self.temperature
        
        # KL divergence on similarity distributions
        teacher_probs = F.softmax(teacher_sim, dim=-1)
        student_log_probs = F.log_softmax(student_sim, dim=-1)
        
        distill_loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
        
        return distill_loss

    def compute_term_losses(
        self,
        sparse_rep: torch.Tensor,
        ko_token_ids: list,
        ko_target_ids: list,
        ko_target_weights: list,
        en_target_ids: list,
        en_target_weights: list,
    ) -> dict:
        """Compute self, Korean target, English target, and margin losses."""
        batch_size = sparse_rep.shape[0]
        device = sparse_rep.device

        self_loss = torch.tensor(0.0, device=device)
        ko_target_loss = torch.tensor(0.0, device=device)
        en_target_loss = torch.tensor(0.0, device=device)
        margin_loss = torch.tensor(0.0, device=device)

        n_valid = 0
        n_with_english = 0

        for i in range(batch_size):
            rep = sparse_rep[i]

            # Self loss: preserve Korean source tokens
            if ko_token_ids[i]:
                ko_ids = torch.tensor(ko_token_ids[i], device=device)
                ko_activations = rep[ko_ids]
                self_loss = self_loss - torch.log(ko_activations + 1e-8).mean()
                
                # Margin penalty for Korean source
                ko_margin = F.relu(self.target_margin - ko_activations).mean()
                self_loss = self_loss + ko_margin

            # Korean target loss
            if ko_target_ids[i]:
                tgt_ids = torch.tensor(ko_target_ids[i], device=device)
                tgt_weights = torch.tensor(ko_target_weights[i], device=device, dtype=torch.float32)
                tgt_activations = rep[tgt_ids]
                
                weighted_log = -torch.log(tgt_activations + 1e-8) * tgt_weights
                ko_target_loss = ko_target_loss + weighted_log.sum() / (tgt_weights.sum() + 1e-8)
                
                margin_violations = F.relu(self.target_margin - tgt_activations) * tgt_weights
                margin_loss = margin_loss + margin_violations.sum() / (tgt_weights.sum() + 1e-8)

            # English target loss
            if en_target_ids[i]:
                n_with_english += 1
                tgt_ids = torch.tensor(en_target_ids[i], device=device)
                tgt_weights = torch.tensor(en_target_weights[i], device=device, dtype=torch.float32)
                tgt_activations = rep[tgt_ids]
                
                weighted_log = -torch.log(tgt_activations + 1e-8) * tgt_weights
                en_target_loss = en_target_loss + weighted_log.sum() / (tgt_weights.sum() + 1e-8)
                
                en_margin = self.target_margin * 1.5
                margin_violations = F.relu(en_margin - tgt_activations) * tgt_weights
                margin_loss = margin_loss + margin_violations.sum() / (tgt_weights.sum() + 1e-8)

            n_valid += 1

        if n_valid > 0:
            self_loss = self_loss / n_valid
            ko_target_loss = ko_target_loss / n_valid
            margin_loss = margin_loss / n_valid
        
        if n_with_english > 0:
            en_target_loss = en_target_loss / n_with_english

        return {
            "self": self_loss,
            "ko_target": ko_target_loss,
            "en_target": en_target_loss,
            "margin": margin_loss,
        }

    def forward(
        self,
        sparse_rep: torch.Tensor,
        texts: list[str],
        ko_token_ids: list,
        ko_target_ids: list,
        ko_target_weights: list,
        en_target_ids: list,
        en_target_weights: list,
    ) -> dict:
        """
        Compute all losses.
        
        Returns dict with: self, ko_target, en_target, margin, distill, flops
        """
        device = sparse_rep.device
        
        # Term-based losses
        term_losses = self.compute_term_losses(
            sparse_rep,
            ko_token_ids,
            ko_target_ids,
            ko_target_weights,
            en_target_ids,
            en_target_weights,
        )
        
        # Knowledge distillation loss
        distill_loss = self.compute_distillation_loss(sparse_rep, texts, device)
        
        # FLOPS regularization (automatic noise suppression)
        flops_loss = self.compute_flops_loss(sparse_rep)
        
        return {
            **term_losses,
            "distill": distill_loss,
            "flops": flops_loss,
        }

## 5. Evaluation Function

In [None]:
# Test pairs for evaluation (Korean source -> expected Korean synonyms + English translations)
TEST_PAIRS = [
    # (source_ko, expected_english, expected_korean_synonyms)
    ("머신러닝", ["machine", "learning"], ["머신", "러닝", "기계학습"]),
    ("딥러닝", ["deep", "learning"], ["딥", "러닝", "심층학습"]),
    ("자연어처리", ["natural", "language", "processing"], ["자연어", "처리"]),
    ("인공지능", ["artificial", "intelligence"], ["인공", "지능"]),
    ("검색엔진", ["search", "engine"], ["검색", "엔진"]),
    ("데이터베이스", ["database"], ["데이터", "베이스"]),
    ("클라우드", ["cloud"], ["클라우드"]),
    ("서버", ["server"], ["서버"]),
    ("네트워크", ["network"], ["네트워크"]),
    ("추천시스템", ["recommend", "system"], ["추천", "시스템"]),
    ("추천", ["recommend", "recommendation"], ["추천"]),
    ("신경망", ["neural", "network"], ["신경망", "신경"]),
    ("강화학습", ["reinforcement", "learning"], ["강화", "학습"]),
    ("컴퓨터비전", ["computer", "vision"], ["컴퓨터", "비전"]),
    ("음성인식", ["speech", "recognition"], ["음성", "인식"]),
]


def evaluate_model(model, tokenizer, device, top_k=50):
    """Evaluate model on test pairs - check both Korean and English activation."""
    model.eval()

    ko_activated_total = 0
    en_activated_total = 0
    ko_expected_total = 0
    en_expected_total = 0

    with torch.no_grad():
        for ko_term, en_expected, ko_expected in TEST_PAIRS:
            encoding = tokenizer(
                ko_term,
                max_length=CONFIG["max_length"],
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

            with autocast("cuda", enabled=CONFIG["use_fp16"]):
                sparse_rep, _ = model(
                    encoding["input_ids"].to(device),
                    encoding["attention_mask"].to(device),
                )

            sparse_rep = sparse_rep[0].float().cpu()
            top_indices = torch.topk(sparse_rep, k=top_k).indices.tolist()
            top_tokens = tokenizer.convert_ids_to_tokens(top_indices)
            top_tokens_set = set(top_tokens)

            # Check Korean synonym/preservation activation
            for ko in ko_expected:
                ko_toks = tokenizer.tokenize(ko)
                for tok in ko_toks:
                    ko_expected_total += 1
                    if tok in top_tokens_set:
                        ko_activated_total += 1

            # Check English translation activation
            for en in en_expected:
                en_toks = tokenizer.tokenize(en.lower())
                for tok in en_toks:
                    en_expected_total += 1
                    if tok in top_tokens_set:
                        en_activated_total += 1

    model.train()

    ko_rate = (
        ko_activated_total / ko_expected_total * 100 if ko_expected_total > 0 else 0
    )
    en_rate = (
        en_activated_total / en_expected_total * 100 if en_expected_total > 0 else 0
    )

    return ko_rate, en_rate

## 6. Initialize Components

In [None]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Load tokenizer
print(f"\nLoading tokenizer: {CONFIG['model_name']}...")
tokenizer = AutoTokenizer.from_pretrained(CONFIG["model_name"])
print(f"Vocab size: {tokenizer.vocab_size:,}")

In [None]:
# Load Teacher Model for Knowledge Distillation
print(f"Loading teacher model: {CONFIG['teacher_model']}...")
teacher_model = SentenceTransformer(CONFIG["teacher_model"])
teacher_model = teacher_model.to(device)
teacher_model.eval()

print(f"Teacher model loaded: {CONFIG['teacher_model']}")
print(f"Teacher embedding dim: {teacher_model.get_sentence_embedding_dimension()}")

# Note: No manual noise token filtering needed!
# FLOPS regularization will automatically suppress frequently-activated noise tokens

In [None]:
# Load dataset (1:N mixed format)
dataset = TermMappingDataset(CONFIG["data_path"], tokenizer, CONFIG["max_length"])

dataloader = DataLoader(
    dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
    pin_memory=True,
)

print(f"\nDataset size: {len(dataset):,}")
print(f"Batches per epoch: {len(dataloader):,}")
print(f"Effective batch size: {CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']}")

In [None]:
# Create model
print(f"\nCreating model: {CONFIG['model_name']}...")
model = create_splade_model(
    model_name=CONFIG["model_name"],
    use_idf=False,
    use_expansion=True,
    expansion_mode="mlm",
)
model = model.to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,} ({n_params / 1e6:.1f}M)")

In [None]:
# Loss function - Knowledge Distillation + FLOPS
loss_fn = DistillationFLOPSLoss(
    teacher_model=teacher_model,
    target_margin=CONFIG["target_margin"],
    temperature=0.05,
)

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(), 
    lr=CONFIG["learning_rate"], 
    weight_decay=0.01
)

# Scheduler
total_steps = (
    len(dataloader) * CONFIG["num_epochs"] // CONFIG["gradient_accumulation_steps"]
)
warmup_steps = int(total_steps * CONFIG["warmup_ratio"])

scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=warmup_steps, 
    num_training_steps=total_steps
)

print(f"Total optimization steps: {total_steps:,}")
print(f"Warmup steps: {warmup_steps:,}")
print(f"\nLoss weights:")
print(f"  Self (Korean source): {CONFIG['lambda_self']}")
print(f"  Korean targets: {CONFIG['lambda_ko_target']}")
print(f"  English targets: {CONFIG['lambda_en_target']}")
print(f"  Margin: {CONFIG['lambda_margin']}")
print(f"  Distillation: {CONFIG['lambda_distill']} (from teacher)")
print(f"  FLOPS: {CONFIG['lambda_flops']} (auto noise suppression)")

# Mixed precision scaler
scaler = GradScaler("cuda", enabled=CONFIG["use_fp16"])

# Create output directory
CONFIG["output_dir"].mkdir(parents=True, exist_ok=True)
print(f"\nOutput directory: {CONFIG['output_dir']}")

## 7. Initial Evaluation

In [None]:
# Evaluate before training
ko_rate, en_rate = evaluate_model(model, tokenizer, device)
print(f"Initial Performance:")
print(f"  Korean Preservation: {ko_rate:.1f}%")
print(f"  English Activation: {en_rate:.1f}%")
print(f"  Combined Score: {ko_rate + en_rate:.1f}")

## 8. Training Loop

In [None]:
# Training variables
history = []
best_score = 0
global_step = 0

print("=" * 70)
print("STARTING TRAINING")
print("=" * 70)

In [None]:
for epoch in range(CONFIG["num_epochs"]):
    print(f"\n{'='*70}")
    print(f"Epoch {epoch + 1}/{CONFIG['num_epochs']}")
    print(f"{'='*70}")
    
    model.train()
    epoch_losses = defaultdict(float)
    optimizer.zero_grad()

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}")

    for batch_idx, batch in enumerate(progress_bar):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        texts = batch["texts"]  # For distillation

        with autocast("cuda", enabled=CONFIG["use_fp16"]):
            sparse_rep, _ = model(input_ids, attention_mask)

            # Compute all losses (including distillation and FLOPS)
            losses = loss_fn(
                sparse_rep,
                texts,
                batch["ko_token_ids"],
                batch["ko_target_ids"],
                batch["ko_target_weights"],
                batch["en_target_ids"],
                batch["en_target_weights"],
            )

            # Total loss with all components
            total_loss = (
                CONFIG["lambda_self"] * losses["self"]
                + CONFIG["lambda_ko_target"] * losses["ko_target"]
                + CONFIG["lambda_en_target"] * losses["en_target"]
                + CONFIG["lambda_margin"] * losses["margin"]
                + CONFIG["lambda_distill"] * losses["distill"]
                + CONFIG["lambda_flops"] * losses["flops"]
            )

            total_loss = total_loss / CONFIG["gradient_accumulation_steps"]

        scaler.scale(total_loss).backward()

        # Track losses
        epoch_losses["total"] += total_loss.item() * CONFIG["gradient_accumulation_steps"]
        epoch_losses["self"] += losses["self"].item()
        epoch_losses["ko_target"] += losses["ko_target"].item()
        epoch_losses["en_target"] += losses["en_target"].item()
        epoch_losses["margin"] += losses["margin"].item()
        epoch_losses["distill"] += losses["distill"].item()
        epoch_losses["flops"] += losses["flops"].item()

        if (batch_idx + 1) % CONFIG["gradient_accumulation_steps"] == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG["max_grad_norm"])
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

        if (batch_idx + 1) % 50 == 0:
            progress_bar.set_postfix({
                "loss": f"{epoch_losses['total'] / (batch_idx + 1):.4f}",
                "distill": f"{epoch_losses['distill'] / (batch_idx + 1):.4f}",
                "flops": f"{epoch_losses['flops'] / (batch_idx + 1):.2f}",
            })

    # Calculate average losses
    n_batches = len(dataloader)
    for key in epoch_losses:
        epoch_losses[key] /= n_batches

    history.append(dict(epoch_losses))

    # Evaluate
    ko_rate, en_rate = evaluate_model(model, tokenizer, device)
    combined_score = ko_rate + en_rate

    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Total Loss: {epoch_losses['total']:.4f}")
    print(f"  Self Loss: {epoch_losses['self']:.4f}")
    print(f"  KO Target: {epoch_losses['ko_target']:.4f}")
    print(f"  EN Target: {epoch_losses['en_target']:.4f}")
    print(f"  Distillation: {epoch_losses['distill']:.4f}")
    print(f"  FLOPS: {epoch_losses['flops']:.2f}")
    print(f"  Korean Activation: {ko_rate:.1f}%")
    print(f"  English Activation: {en_rate:.1f}%")

    # Save checkpoint
    checkpoint_path = CONFIG["output_dir"] / f"checkpoint_epoch{epoch + 1}.pt"
    torch.save({
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "losses": dict(epoch_losses),
        "ko_rate": ko_rate,
        "en_rate": en_rate,
        "config": {k: str(v) if isinstance(v, Path) else v for k, v in CONFIG.items()},
    }, checkpoint_path)
    print(f"  Saved: {checkpoint_path.name}")

    # Save best model (weighted: KO + 2*EN)
    weighted_score = ko_rate + 2 * en_rate
    if weighted_score > best_score:
        best_score = weighted_score
        best_path = CONFIG["output_dir"] / "best_model.pt"
        torch.save({
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "ko_rate": ko_rate,
            "en_rate": en_rate,
            "combined_score": combined_score,
            "config": {k: str(v) if isinstance(v, Path) else v for k, v in CONFIG.items()},
        }, best_path)
        print(f"  *** New best! Score: {weighted_score:.1f} (KO:{ko_rate:.1f}% + EN:{en_rate:.1f}%) ***")

## 9. Save Final Model

In [None]:
# Save final model
final_path = CONFIG["output_dir"] / "final_model.pt"
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "config": {
            k: str(v) if isinstance(v, Path) else v for k, v in CONFIG.items()
        },
        "history": history,
    },
    final_path,
)

# Save training history
with open(CONFIG["output_dir"] / "training_history.json", "w") as f:
    json.dump(history, f, indent=2)

print("\n" + "=" * 70)
print("TRAINING COMPLETE")
print("=" * 70)
print(f"Final model saved: {final_path}")
print(f"Best weighted score (KO + 2*EN): {best_score:.1f}")

## 10. Training Summary

In [None]:
import matplotlib.pyplot as plt

# Plot training curves
if history:
    fig, axes = plt.subplots(2, 4, figsize=(18, 8))

    epochs = range(1, len(history) + 1)

    # Total loss
    axes[0, 0].plot(epochs, [h['total'] for h in history], '-o', color='#3498db')
    axes[0, 0].set_title('Total Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].grid(True, alpha=0.3)

    # Self loss
    axes[0, 1].plot(epochs, [h['self'] for h in history], '-o', color='#2ecc71')
    axes[0, 1].set_title('Self Loss (Korean Source)')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].grid(True, alpha=0.3)

    # Korean target loss
    axes[0, 2].plot(epochs, [h['ko_target'] for h in history], '-o', color='#f39c12')
    axes[0, 2].set_title('Korean Target Loss')
    axes[0, 2].set_xlabel('Epoch')
    axes[0, 2].grid(True, alpha=0.3)

    # English target loss
    axes[0, 3].plot(epochs, [h['en_target'] for h in history], '-o', color='#e74c3c')
    axes[0, 3].set_title('English Target Loss')
    axes[0, 3].set_xlabel('Epoch')
    axes[0, 3].grid(True, alpha=0.3)

    # Distillation loss
    axes[1, 0].plot(epochs, [h['distill'] for h in history], '-o', color='#9b59b6')
    axes[1, 0].set_title('Distillation Loss (Teacher)')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].grid(True, alpha=0.3)

    # FLOPS loss
    axes[1, 1].plot(epochs, [h['flops'] for h in history], '-o', color='#1abc9c')
    axes[1, 1].set_title('FLOPS Loss (Noise Suppression)')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].grid(True, alpha=0.3)

    # Margin loss
    axes[1, 2].plot(epochs, [h['margin'] for h in history], '-o', color='#e67e22')
    axes[1, 2].set_title('Margin Loss')
    axes[1, 2].set_xlabel('Epoch')
    axes[1, 2].grid(True, alpha=0.3)

    # Hide empty subplot
    axes[1, 3].axis('off')

    plt.tight_layout()
    plt.savefig(CONFIG["output_dir"] / "training_curves.png", dpi=150)
    plt.show()
    
    print(f"Training curves saved to: {CONFIG['output_dir'] / 'training_curves.png'}")

## Summary: Knowledge Distillation + FLOPS Approach

### Architecture

```
                    ┌─────────────────┐
                    │  Teacher Model  │
                    │  (BAAI/bge-m3)  │
                    └────────┬────────┘
                             │ Distillation
                             ▼
┌──────────┐         ┌─────────────────┐         ┌──────────┐
│  Korean  │ ──────▶ │  Student Model  │ ──────▶ │  Sparse  │
│  Input   │         │ (xlm-roberta-   │         │  Output  │
│          │         │     large)      │         │          │
└──────────┘         └─────────────────┘         └──────────┘
                             │
                             ▼
                    ┌─────────────────┐
                    │   FLOPS Loss    │
                    │ (Auto Noise     │
                    │  Suppression)   │
                    └─────────────────┘
```

### Loss Components

| Loss | Weight | Purpose |
|------|--------|---------|
| `lambda_self` | 3.0 | Preserve Korean source tokens |
| `lambda_ko_target` | 2.0 | Activate Korean synonym tokens |
| `lambda_en_target` | 12.0 | Activate English translation tokens |
| `lambda_margin` | 2.0 | Ensure minimum activation |
| `lambda_distill` | 1.5 | Learn from teacher model |
| `lambda_flops` | 5e-5 | Automatic noise suppression |
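
The weights combine the individual losses into a single weighted sum. The sketch below uses the weights set in `CONFIG` earlier in this notebook, while the per-batch loss values are invented purely to show how each term contributes.

```python
# Weights as set in CONFIG earlier in this notebook.
weights = {
    "self": 3.0,
    "ko_target": 2.0,
    "en_target": 12.0,
    "margin": 2.0,
    "distill": 1.5,
    "flops": 5e-5,
}

# Made-up per-batch loss values, for illustration only.
losses = {
    "self": 0.5,
    "ko_target": 0.25,
    "en_target": 1.0,
    "margin": 0.5,
    "distill": 2.0,
    "flops": 1000.0,
}

# Each component's contribution is its weight times its raw loss value.
contributions = {name: weights[name] * losses[name] for name in weights}
total_loss = sum(contributions.values())

for name, value in contributions.items():
    print(f"{name:>10}: {value:.3f}")
print(f"{'total':>10}: {total_loss:.3f}")
```

The FLOPS weight is tiny because the raw FLOPS value is a sum over the whole vocabulary and can be orders of magnitude larger than the other terms.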

### Why This Works

1. **Knowledge Distillation**: Teacher model provides semantic similarity signal
   - Student learns what tokens are semantically important
   - No manual rules needed

2. **FLOPS Regularization**: Automatic noise token suppression
   - Tokens that activate frequently (s, the, a) get high penalty
   - Meaningful tokens (machine, 검색) activate selectively → low penalty

3. **Scalable**: Works across languages without language-specific rules

### Next Steps

1. Run inference tests with `03_inference_test.ipynb`
2. Check if "s" noise is suppressed
3. Verify Korean preservation rate improved
4. Compare with previous rule-based approach