# Cross-lingual Neural Sparse Training with Vocabulary Expansion

사전학습된 다국어 BERT(`bert-base-multilingual-cased`)로 SPLADEDocExpansion 모델을 새로 생성한 뒤 Cross-lingual Training을 적용합니다.

## 핵심 변경: SPLADEDocExpansion 사용
기존 SPLADEDoc은 입력 토큰만 활성화할 수 있는 구조적 한계가 있습니다.
SPLADEDocExpansion은 MLM 헤드를 사용해 전체 어휘 공간으로 투영하여 
입력에 없는 토큰도 활성화할 수 있습니다.

## 목표
- 한국어 용어 입력 시 영어 동의어 토큰도 활성화
- "머신러닝" → [머신, ##닝, machine, learning, ML]
- "학습" → [학, ##습, training, learning]

## 1. Setup

In [None]:
import sys
import os
import json
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler
import numpy as np
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel

# Project root
project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

# IMPORTANT: Use SPLADEDocExpansion instead of SPLADEDoc
# SPLADEDocExpansion uses MLM head to project to full vocabulary space,
# enabling activation of ANY token (not just input tokens)
from src.model.splade_model import SPLADEDocExpansion, create_splade_model
from src.data.synonym_dataset import SynonymDataset, SynonymCollator
from src.training.losses import (
    SynonymAlignmentLoss, 
    CrossLingualKDLoss, 
    TokenExpansionLoss,
    ExplicitNoiseTokenLoss,  # v4: Explicit noise token penalty (safer than frequency-based)
)

# Report the runtime environment, then pick the compute device.
cuda_available = torch.cuda.is_available()
print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")

# Every tensor in this notebook is moved to `device`.
device = torch.device('cuda') if cuda_available else torch.device('cpu')
print(f"Using device: {device}")

## 2. Configuration

In [None]:
@dataclass
class CrossLingualConfig:
    """Configuration for cross-lingual training with vocabulary expansion.

    Bundles the model/teacher names, data paths, optimization settings,
    loss weights and the explicit noise-token list used by this notebook.
    Path fields are relative to the project root; the notebook rewrites
    them to absolute paths right after instantiation.
    """
    # Model - Use SPLADEDocExpansion (MLM-based); both values are passed to
    # create_splade_model().
    base_model: str = "bert-base-multilingual-cased"
    expansion_mode: str = "mlm"  # Use MLM head for vocabulary expansion
    
    # Teacher model for KD guidance (loaded with sentence_transformers; if
    # loading fails, training continues without the KD loss).
    teacher_model: str = "intfloat/multilingual-e5-large"
    
    # Data - use LARGE-SCALE dataset (3.2M+ pairs from XLEnt, CCMatrix, Wikidata)
    # synonym_data is the training file read by SynonymDataset.
    synonym_data: str = "dataset/large_scale/ko_en_terms_merged.jsonl"
    # NOTE(review): parallel_data is converted to an absolute path but is not
    # read anywhere in this notebook.
    parallel_data: str = "dataset/synonyms/ko_en_parallel.jsonl"
    
    # Training - adjusted for large-scale
    batch_size: int = 128  # Increased for large dataset
    learning_rate: float = 2e-5  # AdamW base LR (scaled by the warmup/decay schedule)
    num_epochs: int = 3  # Fewer epochs with more data
    warmup_steps: int = 1000  # Linear LR warmup length, in optimizer steps
    max_length: int = 64  # Max token length used by SynonymCollator
    
    # Loss weights (multipliers applied in train_step's combined loss)
    lambda_expansion: float = 1.0    # Token expansion loss (main loss)
    lambda_kd: float = 0.3           # Teacher KD loss (0 contribution when no teacher)
    lambda_sparsity: float = 0.001   # L1 sparsity regularization (mean |x| of KO sparse vector)
    lambda_noise: float = 0.3        # v4: Explicit noise token penalty
    
    # Loss types
    expansion_loss_type: str = "additive"  # passed to TokenExpansionLoss(expansion_type=...)
    expansion_top_k: int = 10              # passed to TokenExpansionLoss(top_k=...)
    kd_loss_type: str = "relation"         # passed to CrossLingualKDLoss(loss_type=...)
    
    # v4: Explicit noise token config (safer than frequency-based)
    noise_penalty_type: str = "sum"  # 'sum', 'max', 'softmax'
    # Custom noise tokens (observed in Top-10). Converted to a list and handed
    # to ExplicitNoiseTokenLoss, which penalizes their activation.
    custom_noise_tokens: Tuple[str, ...] = (
        # Programming terms (high activation, low relevance)
        "function", "operator", "operation", "operations",
        "programming", "integration", "organization",
        "implementation", "configuration", "application",
        # Generic terms
        "system", "systems", "process", "processing",
        "method", "methods", "type", "types",
        # Subword noise
        "##ing", "##tion", "##ation", "##ment",
        # Common but uninformative
        "the", "and", "for", "with", "from",
    )
    
    # Output
    output_dir: str = "outputs/cross_lingual_expansion_v5_largescale"
    log_steps: int = 100  # Progress-bar postfix refresh interval (in steps)
    # NOTE(review): save_steps is not referenced in this notebook; checkpoints
    # are written once per epoch whenever the epoch loss improves.
    save_steps: int = 5000


config = CrossLingualConfig()

# Anchor every data/output path at the project root so the notebook does not
# depend on the current working directory.
for path_attr in ("synonym_data", "parallel_data", "output_dir"):
    setattr(config, path_attr, str(project_root / getattr(config, path_attr)))

# Checkpoints and JSON reports are written under output_dir.
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

print("Configuration (v5 - LARGE-SCALE with 3.2M+ pairs):")
print("\n".join(f"  {name}: {setting}" for name, setting in vars(config).items()))

## 3. Create SPLADEDocExpansion Model

SPLADEDocExpansion은 MLM 헤드를 사용하여 입력 토큰 외의 토큰도 활성화할 수 있습니다.

**핵심 차이점:**
- SPLADEDoc: `sparse_repr` 생성 시 `input_ids`에 있는 토큰만 활성화 가능 (scatter 연산)
- SPLADEDocExpansion: MLM logits를 사용하여 전체 어휘(vocab_size)에 대해 활성화 점수 계산

이를 통해 "머신러닝" 입력 시 "machine", "learning" 토큰도 활성화될 수 있습니다.

In [None]:
# Create SPLADEDocExpansion model from scratch (using MLM head) on top of the
# pretrained multilingual backbone. This model can activate ANY token in the
# vocabulary, not just the tokens present in the input.

print(f"Creating SPLADEDocExpansion model...")
print(f"  Base model: {config.base_model}")
print(f"  Expansion mode: {config.expansion_mode}")

# Use factory function with use_expansion=True
model = create_splade_model(
    model_name=config.base_model,
    use_idf=False,
    use_expansion=True,
    expansion_mode=config.expansion_mode,
    dropout=0.1,
)

model = model.to(device)

print(f"\nModel created: SPLADEDocExpansion")
print(f"  Vocab size: {model.config.vocab_size}")
print(f"  Hidden size: {model.config.hidden_size}")
print(f"  Expansion mode: {model.expansion_mode}")

# Sanity check before training: how many vocabulary entries are active, and
# how many of those are NOT among the (non-padding) input tokens?
# eval() disables dropout so the counts are deterministic; train_epoch()
# switches the model back to train mode.
tokenizer = AutoTokenizer.from_pretrained(config.base_model)
test_input = tokenizer(
    "머신러닝",
    return_tensors="pt",
    max_length=config.max_length,
    padding="max_length",
    truncation=True,
)
model.eval()
with torch.no_grad():
    test_sparse, _ = model(
        test_input['input_ids'].to(device),
        test_input['attention_mask'].to(device),
    )

# Boolean vocab mask of active entries vs. entries that occur in the input.
active = (test_sparse[0] > 0).cpu()
real_input_ids = test_input['input_ids'][0][test_input['attention_mask'][0].bool()]
in_input = torch.zeros_like(active)
in_input[real_input_ids] = True

non_zero_count = active.sum().item()
non_input_count = (active & ~in_input).sum().item()
print(f"\nInitial activation test:")
print(f"  Non-zero tokens in sparse repr: {non_zero_count}")
print(f"  Non-zero tokens not in the input: {non_input_count}")
print(f"  (Pre-training: random activations from MLM head)")

## 4. Load Teacher Model (Optional)

Teacher 모델은 cross-lingual semantic similarity를 가이드하는 역할을 합니다.
주요 기능: KO-EN 쌍의 semantic similarity를 relation-based KD로 전달

In [None]:
print(f"Loading teacher model: {config.teacher_model}")

# The teacher is optional: on any failure we fall back to training without
# the KD loss (USE_TEACHER=False).
try:
    from sentence_transformers import SentenceTransformer

    teacher = SentenceTransformer(
        config.teacher_model,
        device=str(device),
        trust_remote_code=True,
    )
    # Inference only: eval mode and frozen weights.
    teacher.eval()
    teacher.requires_grad_(False)

    print(f"Teacher loaded successfully")
    print(f"  Embedding dimension: {teacher.get_sentence_embedding_dimension()}")
    USE_TEACHER = True

except Exception as e:
    print(f"Warning: Could not load teacher model: {e}")
    print("Will use synonym alignment only (no KD)")
    teacher, USE_TEACHER = None, False

## 5. Load Synonym Data

In [None]:
# Fail fast, with a pointer to the data-synthesis notebook, if the synonym
# file has not been generated yet.
synonym_path = Path(config.synonym_data)

if not synonym_path.exists():
    for message in (
        f"Synonym data not found at: {synonym_path}",
        "\nPlease run notebook 04_cross_lingual_data_synthesis.ipynb first.",
        "Or provide synonym data in the expected format.",
    ):
        print(message)
    raise FileNotFoundError(f"Synonym data not found: {synonym_path}")

synonym_dataset = SynonymDataset(str(synonym_path))
num_pairs = len(synonym_dataset)
print(f"Loaded {num_pairs} synonym pairs")

# Preview up to five KO → EN pairs.
print("\nSample pairs:")
for idx in range(min(5, num_pairs)):
    pair = synonym_dataset[idx]
    print(f"  {pair['ko_term']} → {pair['en_term']}")

## 6. Create DataLoader

In [None]:
# Reuse the tokenizer created in the model-creation cell (section 3).
# If running from this cell on a fresh kernel, uncomment the line below:
# tokenizer = AutoTokenizer.from_pretrained(config.base_model)

# Create collator
collator = SynonymCollator(
    tokenizer=tokenizer,
    max_length=config.max_length,
)

# Create dataloader. Pinned host memory only speeds up host-to-GPU copies, so
# enable it only when training on CUDA (on CPU-only runs PyTorch just warns).
train_loader = DataLoader(
    synonym_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collator,
    num_workers=2,
    pin_memory=device.type == 'cuda',
)

print(f"DataLoader created:")
print(f"  Batch size: {config.batch_size}")
print(f"  Num batches: {len(train_loader)}")

# Inspect one batch: tensor shapes, or type/length for non-tensor fields.
sample_batch = next(iter(train_loader))
print(f"\nSample batch:")
for key, value in sample_batch.items():
    if isinstance(value, torch.Tensor):
        print(f"  {key}: {value.shape}")
    else:
        print(f"  {key}: {type(value).__name__} (len={len(value)})")

## 7. Initialize Loss Functions

In [None]:
# Main objective: make the Korean sparse vector activate the top tokens of
# its English counterpart.
expansion_loss_fn = TokenExpansionLoss(
    expansion_type=config.expansion_loss_type,
    top_k=config.expansion_top_k,
)

# Relation-based KD is only possible when the teacher loaded successfully.
kd_loss_fn = (
    CrossLingualKDLoss(loss_type=config.kd_loss_type) if USE_TEACHER else None
)

# Explicit penalty on a fixed list of known noise tokens. Its internal weight
# stays at 1.0; config.lambda_noise is applied in train_step().
noise_loss_fn = ExplicitNoiseTokenLoss(
    tokenizer=tokenizer,
    noise_tokens=list(config.custom_noise_tokens),
    lambda_noise=1.0,
    penalty_type=config.noise_penalty_type,
)

kd_description = config.kd_loss_type if USE_TEACHER else 'disabled'
print(f"\nLoss functions initialized:")
print(f"  Token expansion: {config.expansion_loss_type} (top_k={config.expansion_top_k})")
print(f"  KD loss: {kd_description}")
print(f"  Noise penalty: ExplicitNoiseTokenLoss (type={config.noise_penalty_type}, lambda={config.lambda_noise})")

## 8. Initialize Optimizer

In [None]:
# Full fine-tuning: every model parameter is optimized. To train only part of
# the model (e.g. freeze the transformer), pass a filtered parameter list.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=0.01,
)

# Learning rate scheduler: linear warmup, then linear decay floored at 10%.
total_steps = len(train_loader) * config.num_epochs
warmup_steps = config.warmup_steps


def lr_lambda(step: int) -> float:
    """Return the LR multiplier for scheduler step ``step``.

    Rises linearly from 0 to 1 over ``warmup_steps`` steps, then decays
    linearly over the remaining ``total_steps - warmup_steps`` steps, never
    dropping below 0.1.
    """
    if step < warmup_steps:
        return step / warmup_steps
    # Guard the denominator: if warmup spans the whole run
    # (total_steps == warmup_steps) the final scheduler.step() would
    # otherwise raise ZeroDivisionError.
    decay_steps = max(1, total_steps - warmup_steps)
    return max(0.1, 1.0 - (step - warmup_steps) / decay_steps)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Mixed precision scaler
scaler = GradScaler('cuda')

print(f"Optimizer initialized:")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Total steps: {total_steps}")
print(f"  Warmup steps: {warmup_steps}")

## 9. Training Loop

In [None]:
def train_step(
    batch: Dict[str, torch.Tensor],
    model: 'SPLADEDocExpansion',
    teacher: Optional['SentenceTransformer'],
    expansion_loss_fn: 'TokenExpansionLoss',
    kd_loss_fn: Optional['CrossLingualKDLoss'],
    noise_loss_fn: 'ExplicitNoiseTokenLoss',
    config: 'CrossLingualConfig',
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Single training step for cross-lingual token expansion with noise suppression.

    Both the Korean and the English term are encoded by the same model; the
    Korean sparse vector is trained to also activate the English vector's top
    tokens. ExplicitNoiseTokenLoss directly penalizes a fixed list of known
    noise tokens instead of relying on token frequencies.

    Inputs are moved to the notebook-level ``device``.

    Args:
        batch: Collated batch. Reads 'ko_input_ids', 'ko_attention_mask',
            'en_input_ids', 'en_attention_mask' and, when KD is active,
            'ko_terms' / 'en_terms' (raw strings for the teacher).
        model: Student model, called as ``model(input_ids, attention_mask)``
            and unpacked as ``(sparse_rep, _)``.
        teacher: Optional SentenceTransformer used for relation-based KD.
        expansion_loss_fn: Main loss, called as ``fn(ko_sparse, en_sparse)``.
        kd_loss_fn: Optional KD loss,
            ``fn(ko_sparse, en_sparse, ko_teacher_emb, en_teacher_emb)``.
        noise_loss_fn: Noise-token penalty, called as ``fn(ko_sparse)``.
        config: Supplies the ``lambda_*`` loss weights.

    Returns:
        ``(total_loss, losses)``: the weighted scalar loss tensor to
        backpropagate, and a dict with the float value of each component
        ('expansion_loss', 'kd_loss', 'sparsity_loss', 'noise_loss',
        'total_loss') for logging.
    """
    # Move inputs to device
    ko_input_ids = batch['ko_input_ids'].to(device)
    ko_attention_mask = batch['ko_attention_mask'].to(device)
    en_input_ids = batch['en_input_ids'].to(device)
    en_attention_mask = batch['en_attention_mask'].to(device)

    use_kd = teacher is not None and kd_loss_fn is not None

    # Korean branch: always receives gradients.
    ko_sparse, _ = model(ko_input_ids, ko_attention_mask)

    # English branch. The expansion loss uses EN only as a detached target,
    # but the KD loss backpropagates through it. Encode EN exactly once:
    # with grad when KD needs it (reusing a detached view as the expansion
    # target), otherwise under no_grad. This avoids a second full EN forward
    # pass per step when KD is enabled.
    if use_kd:
        en_sparse_grad, _ = model(en_input_ids, en_attention_mask)
        en_sparse = en_sparse_grad.detach()
    else:
        en_sparse_grad = None
        with torch.no_grad():
            en_sparse, _ = model(en_input_ids, en_attention_mask)

    losses: Dict[str, float] = {}

    # 1. Token expansion loss (MAIN LOSS)
    expansion_loss = expansion_loss_fn(ko_sparse, en_sparse)
    losses['expansion_loss'] = expansion_loss.item()

    # 2. Teacher KD loss (optional)
    if use_kd:
        with torch.no_grad():
            ko_teacher = teacher.encode(
                batch['ko_terms'],
                convert_to_tensor=True,
                normalize_embeddings=True,
            )
            en_teacher = teacher.encode(
                batch['en_terms'],
                convert_to_tensor=True,
                normalize_embeddings=True,
            )

        kd_loss = kd_loss_fn(ko_sparse, en_sparse_grad, ko_teacher, en_teacher)
        losses['kd_loss'] = kd_loss.item()
    else:
        kd_loss = torch.tensor(0.0, device=device)
        losses['kd_loss'] = 0.0

    # 3. Sparsity regularization (L1): mean absolute activation of KO vector
    sparsity_loss = ko_sparse.abs().mean()
    losses['sparsity_loss'] = sparsity_loss.item()

    # 4. Explicit noise token penalty (suppress known noise tokens such as
    # 'function', 'operator', ...)
    noise_loss = noise_loss_fn(ko_sparse)
    losses['noise_loss'] = noise_loss.item()

    # Combined loss
    total_loss = (
        config.lambda_expansion * expansion_loss
        + config.lambda_kd * kd_loss
        + config.lambda_sparsity * sparsity_loss
        + config.lambda_noise * noise_loss
    )
    losses['total_loss'] = total_loss.item()

    return total_loss, losses


print("Training step function defined (v4 with ExplicitNoiseTokenLoss)")

In [None]:
def train_epoch(
    model: 'SPLADEDocExpansion',
    teacher: Optional['SentenceTransformer'],
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LRScheduler,
    scaler: GradScaler,
    expansion_loss_fn: 'TokenExpansionLoss',
    kd_loss_fn: Optional['CrossLingualKDLoss'],
    noise_loss_fn: 'ExplicitNoiseTokenLoss',
    config: 'CrossLingualConfig',
    epoch: int,
) -> Dict[str, float]:
    """
    Train for one epoch with SPLADEDocExpansion and explicit noise suppression.

    Each batch runs ``train_step`` under bfloat16 autocast, backpropagates
    through the GradScaler, clips the global gradient norm to 1.0, and steps
    the optimizer and LR scheduler.

    Args:
        model: Model being trained (switched to train mode here).
        teacher: Optional KD teacher forwarded to ``train_step``.
        train_loader: Batches produced by SynonymCollator.
        optimizer: Optimizer over ``model.parameters()``.
        scheduler: Per-step LR scheduler.
        scaler: GradScaler used for backward/step.
        expansion_loss_fn, kd_loss_fn, noise_loss_fn: Forwarded to ``train_step``.
        config: Provides ``num_epochs`` (progress label) and ``log_steps``.
        epoch: Zero-based epoch index (for the progress label).

    Returns:
        Mean of each loss component over the batches processed this epoch
        (all zeros if the loader yielded no batches).
    """
    model.train()

    epoch_losses = {
        'total_loss': 0.0,
        'expansion_loss': 0.0,
        'kd_loss': 0.0,
        'sparsity_loss': 0.0,
        'noise_loss': 0.0,
    }
    num_steps = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.num_epochs}")

    for step, batch in enumerate(pbar):
        optimizer.zero_grad()

        # Mixed precision forward
        with autocast('cuda', dtype=torch.bfloat16):
            total_loss, losses = train_step(
                batch, model, teacher,
                expansion_loss_fn, kd_loss_fn, noise_loss_fn, config
            )

        # Backward; unscale before clipping so the norm is computed on the
        # true gradients.
        scaler.scale(total_loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Accumulate losses
        for key in epoch_losses:
            epoch_losses[key] += losses.get(key, 0.0)
        num_steps += 1

        # Update progress bar
        if step % config.log_steps == 0:
            pbar.set_postfix({
                'loss': f"{losses['total_loss']:.4f}",
                'exp': f"{losses['expansion_loss']:.4f}",
                'noise': f"{losses['noise_loss']:.4f}",
                'lr': f"{scheduler.get_last_lr()[0]:.2e}",
            })

    # Average over the batches actually processed; an empty loader leaves
    # the zeros in place instead of raising ZeroDivisionError.
    if num_steps:
        for key in epoch_losses:
            epoch_losses[key] /= num_steps

    return epoch_losses


print("Training epoch function defined (v4 with ExplicitNoiseTokenLoss)")

## 10. Run Training

In [None]:
separator = "=" * 60
print(separator)
print("STARTING CROSS-LINGUAL TRAINING (v4 with ExplicitNoiseTokenLoss)")
print(separator)

history = []
best_loss = float('inf')

for epoch in range(config.num_epochs):
    print(f"\n--- Epoch {epoch+1}/{config.num_epochs} ---")

    epoch_losses = train_epoch(
        model=model,
        teacher=teacher,
        train_loader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        expansion_loss_fn=expansion_loss_fn,
        kd_loss_fn=kd_loss_fn,
        noise_loss_fn=noise_loss_fn,
        config=config,
        epoch=epoch,
    )
    history.append(epoch_losses)

    print(f"\nEpoch {epoch+1} Results:")
    for loss_name, loss_value in epoch_losses.items():
        print(f"  {loss_name}: {loss_value:.4f}")

    # Checkpoint only when the mean training loss improved. Written as
    # `not (<)` so a NaN loss is never treated as an improvement.
    if not (epoch_losses['total_loss'] < best_loss):
        continue

    best_loss = epoch_losses['total_loss']
    save_path = Path(config.output_dir) / 'best_model' / 'checkpoint.pt'
    save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_loss': best_loss,
            'config': {
                'model': {'name': config.base_model},
                'cross_lingual': config.__dict__,
            },
        },
        save_path,
    )
    print(f"  Saved best model to: {save_path}")

print("\n" + separator)
print("TRAINING COMPLETE")
print(separator)

## 11. Evaluation - Cross-lingual Activation Test

In [None]:
def test_cross_lingual_activation(
    model: 'SPLADEDocExpansion',
    tokenizer: 'AutoTokenizer',
    test_pairs: List[Tuple[str, List[str]]],
    top_k: int = 30,
    device: Optional[torch.device] = None,
) -> List[Dict]:
    """
    Test if Korean terms activate English synonym tokens.

    Each Korean term is encoded on its own. Every English synonym is tokenized
    with the same tokenizer, as-is. A synonym sub-token counts as activated
    when it is among the top-k tokens (case-insensitive match) or has a
    positive score in the sparse vector. A *synonym* counts as activated when
    at least one of its sub-tokens is activated. Sub-tokens that map to the
    tokenizer's UNK id are ignored.

    Args:
        model: Trained SPLADEDocExpansion model (called as
            ``model(input_ids, attention_mask)`` and unpacked as
            ``(sparse_rep, _)``). Switched to eval mode.
        tokenizer: Tokenizer matching ``model``.
        test_pairs: List of (korean_term, [english_synonyms]).
        top_k: Number of top tokens to check (clamped to the vocab size).
        device: Device for the encoded inputs. Defaults to the device of the
            model's first parameter, or CPU if that cannot be determined.

    Returns:
        One dict per pair with 'ko_term', 'en_synonyms', 'top_10_tokens',
        'top_10_values', 'en_activated' (activated sub-tokens, de-duplicated),
        'en_scores' (sub-token -> score) and 'activation_rate' (fraction of
        synonyms with at least one activated sub-token, always in [0, 1]).
    """
    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('cpu')

    # convert_tokens_to_ids maps unknown tokens to the UNK id (not None), so
    # UNK has to be excluded explicitly or its score counts as a "hit".
    unk_id = getattr(tokenizer, 'unk_token_id', None)

    model.eval()
    results = []

    for ko_term, en_synonyms in test_pairs:
        # Encode Korean term
        encoding = tokenizer(
            ko_term,
            max_length=64,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        with torch.no_grad():
            sparse_rep, _ = model(
                encoding['input_ids'].to(device),
                encoding['attention_mask'].to(device),
            )

        sparse_rep = sparse_rep[0].cpu()

        # Get top-k activated tokens
        k = min(top_k, sparse_rep.numel())
        top_k_values, top_k_indices = torch.topk(sparse_rep, k=k)
        top_k_tokens = tokenizer.convert_ids_to_tokens(top_k_indices.tolist())
        # token -> first position in the top-k list (same as list.index)
        top_k_pos: Dict[str, int] = {}
        for pos, tok in enumerate(top_k_tokens):
            top_k_pos.setdefault(tok, pos)
        top_k_lower = {tok.lower() for tok in top_k_tokens}

        en_activated: List[str] = []
        en_scores: Dict[str, float] = {}
        num_synonyms_hit = 0

        for en_syn in en_synonyms:
            # No .lower(): the default backbone is *cased*, and lowercasing
            # turns abbreviations such as "NLP"/"RNN" into different (often
            # UNK) sub-tokens. Uncased tokenizers lowercase internally anyway.
            synonym_hit = False
            for en_tok in tokenizer.tokenize(en_syn):
                tok_id = tokenizer.convert_tokens_to_ids([en_tok])[0]
                if tok_id is None or (unk_id is not None and tok_id == unk_id):
                    continue

                token_hit = False
                # Check in top-k (exact match records the score)
                if en_tok in top_k_pos:
                    token_hit = True
                    en_scores[en_tok] = top_k_values[top_k_pos[en_tok]].item()
                elif en_tok.lower() in top_k_lower:
                    token_hit = True

                # Also check the direct token-ID score
                score = sparse_rep[tok_id].item()
                if score > 0:
                    token_hit = True
                    en_scores.setdefault(en_tok, score)

                if token_hit:
                    synonym_hit = True
                    if en_tok not in en_activated:
                        en_activated.append(en_tok)

            if synonym_hit:
                num_synonyms_hit += 1

        results.append({
            'ko_term': ko_term,
            'en_synonyms': en_synonyms,
            'top_10_tokens': top_k_tokens[:10],
            'top_10_values': top_k_values[:10].tolist(),
            'en_activated': en_activated,
            'en_scores': en_scores,
            # Rate over synonyms, not sub-tokens: multi-piece synonyms
            # previously pushed this above 100%.
            'activation_rate': num_synonyms_hit / len(en_synonyms) if en_synonyms else 0,
        })

    return results


# This is an evaluation helper, not a pytest test; stop pytest from collecting
# it because of its ``test_`` prefix.
test_cross_lingual_activation.__test__ = False


# Evaluation probes: (korean_term, [expected English synonyms]).
# Several entries include abbreviations (ML, DL, NLP, NN, RL, CV, NER, RNN)
# to check whether the model also activates short forms.
TEST_PAIRS = [
    ("머신러닝", ["machine", "learning", "ML"]),
    ("딥러닝", ["deep", "learning", "DL"]),
    ("자연어처리", ["natural", "language", "processing", "NLP"]),
    ("학습", ["training", "learning"]),
    ("모델", ["model"]),
    ("데이터", ["data"]),
    ("알고리즘", ["algorithm"]),
    ("신경망", ["neural", "network", "NN"]),
    ("분류", ["classification", "classify"]),
    ("회귀", ["regression"]),
    # Abbreviation-focused probes
    ("강화학습", ["reinforcement", "learning", "RL"]),
    ("컴퓨터비전", ["computer", "vision", "CV"]),
    ("개체명인식", ["named", "entity", "recognition", "NER"]),
    ("순환신경망", ["recurrent", "neural", "network", "RNN"]),
]

print("Test pairs defined (with abbreviations):")
print("\n".join(f"  {ko} → {en_list}" for ko, en_list in TEST_PAIRS))

In [None]:
# Probe the trained model with TEST_PAIRS and report per-term activations.
rule = "=" * 60
print(rule)
print("CROSS-LINGUAL ACTIVATION TEST (SPLADEDocExpansion)")
print(rule)

results = test_cross_lingual_activation(model, tokenizer, TEST_PAIRS, top_k=30)

for result in results:
    print(f"\n[{result['ko_term']}]")
    print(f"  Expected EN: {result['en_synonyms']}")
    print(f"  Top-10 activated: {result['top_10_tokens']}")

    token_scores = result['en_scores']
    if not token_scores:
        print(f"  EN activated: None (need more training)")
    else:
        print(f"  EN token scores:")
        # Highest-scoring English tokens first.
        for tok, score in sorted(token_scores.items(), key=lambda item: -item[1]):
            print(f"    ✓ {tok}: {score:.4f}")

    print(f"  Activation rate: {result['activation_rate']:.1%}")

total_activation_rate = sum(r['activation_rate'] for r in results)
avg_rate = total_activation_rate / len(results)
print("\n" + rule)
print(f"Average English Activation Rate: {avg_rate:.1%}")
print(rule)

if avg_rate < 0.3:
    print("\n⚠️  Low activation rate - model needs training!")
    print("    Run the training loop above to improve cross-lingual activation.")

## 12. Save Final Results

In [None]:
# Persist the per-epoch loss history and the activation-test results next to
# the checkpoint. Both files are written as UTF-8 explicitly so the output does
# not depend on the platform's default encoding.
history_path = Path(config.output_dir) / 'training_history.json'
with open(history_path, 'w', encoding='utf-8') as f:
    json.dump(history, f, indent=2)
print(f"Saved training history to: {history_path}")

# Save evaluation results
eval_path = Path(config.output_dir) / 'evaluation_results.json'
eval_results = {
    'model_type': 'SPLADEDocExpansion',
    'expansion_mode': config.expansion_mode,
    'noise_suppression': 'ExplicitNoiseTokenLoss',  # v4
    'noise_tokens': list(config.custom_noise_tokens),
    'test_pairs': TEST_PAIRS,
    'results': results,
    'average_activation_rate': avg_rate,
}
with open(eval_path, 'w', encoding='utf-8') as f:
    # ensure_ascii=False keeps the Korean terms human-readable in the file.
    json.dump(eval_results, f, ensure_ascii=False, indent=2)
print(f"Saved evaluation results to: {eval_path}")

# Save tokenizer next to the checkpoint so the model can be reloaded alone.
tokenizer.save_pretrained(Path(config.output_dir) / 'best_model')
print(f"Saved tokenizer to: {config.output_dir}/best_model")

print("\n" + "="*60)
print("CROSS-LINGUAL TRAINING COMPLETE (v4)")
print("="*60)
print(f"\nModel: SPLADEDocExpansion (MLM-based vocabulary expansion)")
print(f"Noise suppression: ExplicitNoiseTokenLoss")
print(f"Output directory: {config.output_dir}")
print(f"Best model: {config.output_dir}/best_model/checkpoint.pt")
print(f"Average EN activation: {avg_rate:.1%}")