# v22.1 Data Preparation

Generate training triplets with curriculum learning phases optimized for SPLADELossV23.

## Key Features

1. **Length Classification** (by character count of the source text): single_term (<= 3 characters), short_phrase (4-8 characters), sentence (> 8 characters)
2. **Hard Negative Mining**: Using BM25Scorer for quality hard negatives
3. **Curriculum Learning Splits**: 3 phases with progressive difficulty
4. **IDF Weight Computation**: For IDFAwareFLOPSLoss in SPLADELossV23

## Curriculum Phases

| Phase | Epochs | Data Focus | Description |
|-------|--------|------------|-------------|
| 1 | 1-7 | 50% single-term, 30% short, 20% sentence | Single-term focus |
| 2 | 8-14 | 33% each (balanced) | Balanced learning |
| 3 | 15-20 | Full data + hard negatives + MS MARCO triplets | Final refinement |

In [None]:
import sys
from pathlib import Path


def find_project_root() -> Path:
    """Locate the project root by walking upwards from the working directory.

    The first directory (starting with the working directory itself) that
    contains a ``pyproject.toml`` or a ``src`` entry is returned. When no
    ancestor qualifies, the directory two levels above the working directory
    is used as a fallback.
    """
    start = Path.cwd()
    markers = ("pyproject.toml", "src")
    for candidate in (start, *start.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return Path.cwd().parent.parent


# Resolve the project root once and put it first on the import path so that
# the `src.` package import further down resolves against this checkout.
PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

import json
import random
import math
import numpy as np
import torch
from collections import defaultdict, Counter
from typing import Dict, List, Set, Tuple, Optional
from dataclasses import dataclass, asdict
from tqdm import tqdm

# Set random seeds for reproducibility
# The same seed is applied to Python's `random`, NumPy and PyTorch.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

print(f"Project root: {PROJECT_ROOT}")

## 1. Configuration and Paths

In [None]:
# Paths
# Inputs are read from data/v22.0; everything this notebook writes goes to
# data/v22.1 (created below if missing).
# NOTE(review): HF_DATA_DIR is not referenced by any cell shown here.
V22_0_DATA_DIR = PROJECT_ROOT / "data" / "v22.0"
V22_1_DATA_DIR = PROJECT_ROOT / "data" / "v22.1"
HF_DATA_DIR = PROJECT_ROOT / "data" / "huggingface_korean"

V22_1_DATA_DIR.mkdir(parents=True, exist_ok=True)

# Configuration
# NOTE(review): "hard_negative_top_k" and "jaccard_filter_threshold" are not
# read by any cell shown here; the BM25 miner falls back to its own default
# top_k. Confirm whether they should be wired in.
CONFIG = {
    "val_ratio": 0.1,
    "n_negatives_per_pair": 3,
    "hard_negative_top_k": 100,
    "jaccard_filter_threshold": 0.9,
    # Curriculum phase ratios
    "phase1_ratios": {"single_term": 0.50, "short_phrase": 0.30, "sentence": 0.20},
    "phase2_ratios": {"single_term": 0.33, "short_phrase": 0.33, "sentence": 0.34},
}

# Echo the effective settings so they are recorded in the notebook output.
print(f"Input directory: {V22_0_DATA_DIR}")
print(f"Output directory: {V22_1_DATA_DIR}")
print(f"\nConfiguration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## 2. Load Source Data

In [None]:
def load_jsonl(path: Path) -> List[Dict]:
    """Load a JSONL file into a list of dictionaries.

    Blank or whitespace-only lines (for example a trailing empty line) are
    skipped instead of being handed to the JSON parser.

    Args:
        path: Location of the JSONL file

    Returns:
        One parsed object per non-blank line, in file order. An empty list
        when the file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON
    """
    # A missing input is treated as "no data" rather than an error.
    if not path.exists():
        return []
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


# Load v22.0 augmented pairs
# load_jsonl returns an empty list for a missing file, so each count printed
# below is 0 when its input is absent.
augmented_pairs_path = V22_0_DATA_DIR / "augmented_synonym_pairs.jsonl"
pairs = load_jsonl(augmented_pairs_path)
print(f"Loaded {len(pairs):,} augmented pairs from v22.0")

# Load single-term expanded triplets
single_term_path = V22_0_DATA_DIR / "single_term_expanded.jsonl"
single_term_triplets = load_jsonl(single_term_path)
print(f"Loaded {len(single_term_triplets):,} single-term expanded triplets")

# Load MS MARCO triplets for Phase 3
msmarco_path = V22_0_DATA_DIR / "msmarco_direct_triplets.jsonl"
msmarco_triplets = load_jsonl(msmarco_path)
print(f"Loaded {len(msmarco_triplets):,} MS MARCO direct triplets")

# Build source -> targets mapping for negative filtering
# Every known target of a source is collected so none of them is later picked
# as a negative for that source. Direct indexing means a pair without a
# "source" or "target" key raises KeyError here.
source_to_targets: Dict[str, Set[str]] = defaultdict(set)
for pair in pairs:
    source_to_targets[pair["source"]].add(pair["target"])

print(f"\nUnique sources: {len(source_to_targets):,}")

In [None]:
# Preview the first record of each non-empty input collection
for heading, records in (
    ("Sample augmented pair:", pairs),
    ("\nSample single-term triplet:", single_term_triplets),
):
    if records:
        print(heading)
        print(json.dumps(records[0], ensure_ascii=False, indent=2))

## 3. Length Classification

In [None]:
def get_length_class(text: str) -> str:
    """
    Map a text to a length bucket based on its character count.

    Args:
        text: Input text string

    Returns:
        'single_term' for up to 3 characters, 'short_phrase' for 4 to 8
        characters, and 'sentence' for anything longer
    """
    n_chars = len(text)
    # Buckets are listed in ascending order; the first bound that fits wins.
    for upper_bound, label in ((3, "single_term"), (8, "short_phrase")):
        if n_chars <= upper_bound:
            return label
    return "sentence"


# Classify all pairs by source length
# Only the source text decides the bucket; the target length is not considered.
pairs_by_length: Dict[str, List[Dict]] = defaultdict(list)
for pair in pairs:
    source_text = pair.get("source", "")
    length_class = get_length_class(source_text)
    pairs_by_length[length_class].append(pair)

print("Pairs by length class:")
print("=" * 50)
for length_class in ["single_term", "short_phrase", "sentence"]:
    count = len(pairs_by_length[length_class])
    # Guard against division by zero when no pairs were loaded.
    pct = count / len(pairs) * 100 if pairs else 0
    print(f"  {length_class:<15}: {count:>8,} ({pct:.1f}%)")

In [None]:
# Print up to three source -> target examples for every length class
print("\nExamples by length class:")
print("=" * 60)

for length_class in ("single_term", "short_phrase", "sentence"):
    print(f"\n{length_class}:")
    for ex in pairs_by_length[length_class][:3]:
        source, target = ex.get("source", ""), ex.get("target", "")
        print(f"  '{source}' -> '{target}'")

## 4. Build Corpus for Negative Mining

In [None]:
# Build vocabulary from all available sources
all_terms: Set[str] = set()

# Add terms from augmented pairs
for pair in pairs:
    all_terms.add(pair.get("source", ""))
    all_terms.add(pair.get("target", ""))

# Add terms from single-term triplets
for triplet in single_term_triplets:
    all_terms.add(triplet.get("anchor", ""))
    all_terms.add(triplet.get("positive", ""))
    negative = triplet.get("negative", "")
    if negative:
        all_terms.add(negative)

# Remove empty strings (and JSON nulls, which cannot be scored or sorted)
all_terms.discard("")
all_terms.discard(None)

# Sort instead of list(set): set iteration order for strings depends on the
# per-process hash seed, so an unsorted corpus would give a different BM25
# document order and different random.sample() draws on every run, defeating
# the fixed random seeds set at the top of the notebook.
all_terms_list = sorted(all_terms)

print(f"Total unique terms in corpus: {len(all_terms_list):,}")

## 5. Hard Negative Mining with BM25

In [None]:
from src.data.hard_negative_miner import BM25Scorer, HardNegativeMiner

# Initialize BM25 scorer for hard negative mining
# The scorer is fitted on the term corpus built above; candidate indices it
# returns later are looked up in `all_terms_list`, so that list must not be
# reordered after fitting.
print("Fitting BM25 scorer on corpus...")
bm25_scorer = BM25Scorer(k1=1.5, b=0.75, epsilon=0.25)
bm25_scorer.fit(all_terms_list)

print(f"BM25 fitted on {bm25_scorer.corpus_size:,} documents")
print(f"Average document length: {bm25_scorer.avgdl:.2f}")
print(f"Vocabulary size: {len(bm25_scorer.idf):,}")

In [None]:
@dataclass
class TrainingTriplet:
    """Training triplet for contrastive learning.

    Holds one (anchor, positive, negative) record together with the metadata
    this notebook uses for statistics and curriculum splitting. Records are
    serialized field by field via ``dataclasses.asdict`` when saved as JSONL.
    """
    anchor: str  # query / source text
    positive: str  # text paired with the anchor as its match
    negative: str  # text the anchor should not match
    difficulty: str  # "easy", "medium", "hard"
    length_class: str  # "single_term", "short_phrase", "sentence"
    pair_type: str  # origin label of the pair, e.g. "original"
    source: str  # dataset source


def compute_char_overlap(text1: str, text2: str) -> float:
    """Return the Jaccard similarity of the lowercase character sets of two texts.

    The result is 0.0 when either text contains no characters.
    """
    left, right = set(text1.lower()), set(text2.lower())
    if not (left and right):
        return 0.0
    combined = left | right
    if not combined:
        return 0.0
    shared = left & right
    return len(shared) / len(combined)


def classify_difficulty(source: str, negative: str) -> str:
    """
    Rate how hard a negative is for a given anchor.

    The rating combines the character-set overlap of the two texts with the
    absolute difference of their lengths.

    Args:
        source: Anchor text
        negative: Negative text

    Returns:
        'hard' when overlap >= 0.4 and the lengths differ by at most 2,
        otherwise 'medium' when overlap >= 0.2 or the lengths differ by at
        most 3, otherwise 'easy'
    """
    overlap = compute_char_overlap(source, negative)
    length_gap = abs(len(source) - len(negative))

    # Guard clauses, ordered from the strictest rule to the fallback.
    if overlap >= 0.4 and length_gap <= 2:
        return "hard"
    if overlap >= 0.2 or length_gap <= 3:
        return "medium"
    return "easy"

In [None]:
def find_hard_negatives_bm25(
    source: str,
    positives: Set[str],
    bm25: BM25Scorer,
    corpus: List[str],
    n: int = 10,
    top_k: int = 100,
) -> List[Tuple[str, str]]:
    """
    Mine negatives for an anchor from its top BM25 matches.

    Candidates are visited in the order returned by ``bm25.get_top_k``; the
    anchor itself and its known positives are skipped.

    Args:
        source: Query/anchor text
        positives: Set of positive texts to exclude
        bm25: Fitted BM25Scorer
        corpus: Full corpus list, indexed by the ids the scorer returns
        n: Number of negatives to return
        top_k: Number of candidates to consider

    Returns:
        List of (negative, difficulty) tuples
    """
    ranked = bm25.get_top_k(source, k=top_k)

    mined: List[Tuple[str, str]] = []
    for doc_index, _score in ranked:
        text = corpus[doc_index]
        excluded = text == source or text in positives
        if not excluded:
            mined.append((text, classify_difficulty(source, text)))
            # Stop as soon as the requested number has been collected.
            if len(mined) >= n:
                break

    return mined


def find_random_negatives(
    source: str,
    positives: Set[str],
    corpus: List[str],
    n: int = 10,
) -> List[Tuple[str, str]]:
    """
    Draw random negatives for an anchor and label their difficulty.

    Up to ``n * 10`` corpus entries are sampled without replacement; the
    anchor itself and its known positives are skipped.

    Args:
        source: Query/anchor text
        positives: Set of positive texts to exclude
        corpus: Full corpus list
        n: Number of negatives to return

    Returns:
        List of (negative, difficulty) tuples
    """
    # Oversample so that enough candidates survive the exclusion filter.
    pool = random.sample(corpus, min(len(corpus), n * 10))

    drawn: List[Tuple[str, str]] = []
    for text in pool:
        excluded = text == source or text in positives
        if not excluded:
            drawn.append((text, classify_difficulty(source, text)))
            if len(drawn) >= n:
                break

    return drawn

In [None]:
def generate_triplets_for_pair(
    pair: Dict,
    bm25: BM25Scorer,
    corpus: List[str],
    source_to_targets: Dict[str, Set[str]],
    n_negatives: int = 3,
    use_bm25: bool = True,
    top_k: int = 100,
) -> List[TrainingTriplet]:
    """
    Generate training triplets for a synonym pair.

    Negatives are first picked evenly across the difficulty levels. If that
    yields fewer than ``n_negatives`` (for example because all mined
    negatives share one difficulty), the remaining mined negatives are used
    to fill up, hardest first.

    Args:
        pair: Dictionary with source and target
        bm25: Fitted BM25Scorer
        corpus: Full corpus list
        source_to_targets: Mapping from source to all its targets
        n_negatives: Number of negatives per pair
        use_bm25: Use BM25 for hard negative mining
        top_k: Number of BM25 candidates to consider (BM25 mining only)

    Returns:
        List of TrainingTriplet objects; empty when the pair has no source
        or no target
    """
    source = pair.get("source", "")
    target = pair.get("target", "")
    # A pair without both sides cannot form a meaningful triplet.
    if not source or not target:
        return []

    # Always exclude the pair's own target, even if the mapping lacks it.
    positives = set(source_to_targets.get(source, ())) | {target}
    length_class = get_length_class(source)
    pair_type = pair.get("pair_type", "original")
    data_source = pair.get("category", "unknown")

    # Mine twice as many candidates as needed to leave room for balancing.
    if use_bm25:
        negatives = find_hard_negatives_bm25(
            source, positives, bm25, corpus, n=n_negatives * 2, top_k=top_k
        )
    else:
        negatives = find_random_negatives(
            source, positives, corpus, n=n_negatives * 2
        )

    # Balance by difficulty
    by_difficulty: Dict[str, List[str]] = defaultdict(list)
    for neg, diff in negatives:
        by_difficulty[diff].append(neg)

    per_level = max(1, n_negatives // 3)
    chosen: List[Tuple[str, str]] = []
    leftovers: Dict[str, List[str]] = {}
    for difficulty in ["easy", "medium", "hard"]:
        candidates = by_difficulty[difficulty]
        n_select = min(len(candidates), per_level)
        selected = random.sample(candidates, n_select) if candidates else []
        chosen.extend((neg, difficulty) for neg in selected)
        taken = set(selected)
        leftovers[difficulty] = [neg for neg in candidates if neg not in taken]

    # Top up from unused candidates so a pair is not short-changed just
    # because its negatives cluster in one difficulty bucket.
    for difficulty in ["hard", "medium", "easy"]:
        for neg in leftovers[difficulty]:
            if len(chosen) >= n_negatives:
                break
            chosen.append((neg, difficulty))

    return [
        TrainingTriplet(
            anchor=source,
            positive=target,
            negative=neg,
            difficulty=difficulty,
            length_class=length_class,
            pair_type=pair_type,
            source=data_source,
        )
        for neg, difficulty in chosen
    ]

In [None]:
# Generate triplets from augmented pairs
# Every pair is mined against the full term corpus with BM25; the number of
# negatives per pair comes from CONFIG["n_negatives_per_pair"].
print("Generating triplets from augmented pairs...")
augmented_triplets: List[TrainingTriplet] = []

for pair in tqdm(pairs, desc="Generating triplets"):
    triplets = generate_triplets_for_pair(
        pair=pair,
        bm25=bm25_scorer,
        corpus=all_terms_list,
        source_to_targets=source_to_targets,
        n_negatives=CONFIG["n_negatives_per_pair"],
        use_bm25=True,
    )
    augmented_triplets.extend(triplets)

print(f"\nGenerated {len(augmented_triplets):,} triplets from augmented pairs")

In [None]:
# Convert existing single-term expanded triplets
single_term_training: List[TrainingTriplet] = []
skipped_missing_negative = 0

for triplet in single_term_triplets:
    # A triplet without a negative carries no contrastive signal; skip it,
    # matching how the MS MARCO triplets are converted for Phase 3.
    negative = triplet.get("negative", "")
    if not negative:
        skipped_missing_negative += 1
        continue

    single_term_training.append(TrainingTriplet(
        anchor=triplet["anchor"],
        positive=triplet["positive"],
        negative=negative,
        difficulty=triplet.get("difficulty", "medium"),
        length_class="single_term",
        pair_type="single_term_expanded",
        source="single_term_expanded",
    ))

print(f"Converted {len(single_term_training):,} single-term expanded triplets")
if skipped_missing_negative:
    print(f"Skipped {skipped_missing_negative:,} single-term triplets without a negative")

# Merge all triplets
all_triplets = augmented_triplets + single_term_training
print(f"\nTotal triplets: {len(all_triplets):,}")

## 6. Triplet Statistics

In [None]:
def print_triplet_statistics(triplets: List[TrainingTriplet], title: str = "Triplet Statistics"):
    """Print counts and percentages of triplets by difficulty, length class and pair type.

    Args:
        triplets: Triplets to summarize
        title: Heading printed above the report
    """
    total = len(triplets)

    def share(count):
        # Percentage of all triplets; 0 when there are no triplets at all.
        return count / total * 100 if triplets else 0

    difficulty_counts = Counter(t.difficulty for t in triplets)
    length_counts = Counter(t.length_class for t in triplets)
    type_counts = Counter(t.pair_type for t in triplets)

    print(f"\n{title}")
    print("=" * 60)

    print("\nBy Difficulty:")
    for level in ("easy", "medium", "hard"):
        n = difficulty_counts[level]
        print(f"  {level:<10}: {n:>8,} ({share(n):.1f}%)")

    print("\nBy Length Class:")
    for bucket in ("single_term", "short_phrase", "sentence"):
        n = length_counts[bucket]
        print(f"  {bucket:<15}: {n:>8,} ({share(n):.1f}%)")

    # Stable descending sort: ties keep their first-seen order.
    print("\nBy Pair Type (top 10):")
    ranked = sorted(type_counts.items(), key=lambda item: item[1], reverse=True)
    for kind, n in ranked[:10]:
        print(f"  {kind:<25}: {n:>8,} ({share(n):.1f}%)")


# Report the composition of the merged (augmented + single-term) triplet set.
print_triplet_statistics(all_triplets, "All Triplets Statistics")

## 7. Create Curriculum Learning Splits

In [None]:
def _phase_sizes_from_ratios(
    available: Dict[str, int],
    ratios: Dict[str, float],
) -> Dict[str, int]:
    """Return per-class sample sizes that follow ``ratios`` without exceeding ``available``."""
    active = {name: ratio for name, ratio in ratios.items() if ratio > 0}
    if not active:
        return {name: 0 for name in available}
    # Largest total for which every class can still supply its share.
    total = min(available.get(name, 0) / ratio for name, ratio in active.items())
    return {
        name: min(count, round(total * active.get(name, 0.0)))
        for name, count in available.items()
    }


def create_curriculum_splits(
    triplets: List["TrainingTriplet"],
    phase1_ratios: Dict[str, float],
    phase2_ratios: Dict[str, float],
) -> Dict[str, List["TrainingTriplet"]]:
    """
    Create curriculum learning splits for SPLADELossV23.

    - Phase 1 (epochs 1-7): Single-term focus. All single-term triplets are
      kept; the other classes are sampled so that they stand in the
      ``phase1_ratios`` proportion to them, capped by what is available.
    - Phase 2 (epochs 8-14): Balanced. Classes are sampled in the
      ``phase2_ratios`` proportion, sized by the scarcest class.
    - Phase 3 (epochs 15-20): Full data (a shuffled copy of all triplets).

    Args:
        triplets: All training triplets
        phase1_ratios: Ratios for phase 1, keyed by length class
        phase2_ratios: Ratios for phase 2, keyed by length class

    Returns:
        Dictionary of phase name to triplet list

    Raises:
        ValueError: If ``phase1_ratios['single_term']`` is not positive
    """
    length_classes = ("single_term", "short_phrase", "sentence")

    # Group by length class
    by_length: Dict[str, list] = defaultdict(list)
    for t in triplets:
        by_length[t.length_class].append(t)

    single_term = by_length["single_term"]
    short_phrase = by_length["short_phrase"]
    sentence = by_length["sentence"]

    print(f"Available data:")
    print(f"  single_term: {len(single_term):,}")
    print(f"  short_phrase: {len(short_phrase):,}")
    print(f"  sentence: {len(sentence):,}")

    # Phase 1: Single-term focus
    # The other classes are sized relative to the single-term class, so its
    # ratio is the divisor and must be positive.
    single_ratio = phase1_ratios["single_term"]
    if single_ratio <= 0:
        raise ValueError("phase1_ratios['single_term'] must be positive")

    phase1_base = len(single_term)
    phase1_short_target = int(phase1_base * phase1_ratios["short_phrase"] / single_ratio)
    phase1_sentence_target = int(phase1_base * phase1_ratios["sentence"] / single_ratio)

    phase1 = single_term.copy()
    phase1 += random.sample(short_phrase, min(len(short_phrase), phase1_short_target))
    phase1 += random.sample(sentence, min(len(sentence), phase1_sentence_target))
    random.shuffle(phase1)

    # Phase 2: Balanced learning, sized by the configured ratios
    available = {name: len(by_length[name]) for name in length_classes}
    phase2_sizes = _phase_sizes_from_ratios(available, phase2_ratios)
    phase2 = []
    for name in length_classes:
        phase2 += random.sample(by_length[name], phase2_sizes[name])
    random.shuffle(phase2)

    # Phase 3: Full data (all triplets)
    phase3 = triplets.copy()
    random.shuffle(phase3)

    return {
        "phase1_single_term_focus": phase1,
        "phase2_balanced": phase2,
        "phase3_full": phase3,
    }


# Build the three curriculum phases from the merged triplet set using the
# ratios configured in CONFIG, then report the size of each phase.
curriculum_splits = create_curriculum_splits(
    all_triplets,
    CONFIG["phase1_ratios"],
    CONFIG["phase2_ratios"],
)

print("\nCurriculum Splits:")
for phase, data in curriculum_splits.items():
    print(f"  {phase}: {len(data):,} triplets")

In [None]:
# Report how each phase is composed in terms of length classes
print("\nLength Class Distribution by Phase:")
print("=" * 70)

for phase_name, phase_data in curriculum_splits.items():
    length_dist = Counter(t.length_class for t in phase_data)
    phase_size = len(phase_data)

    print(f"\n{phase_name}:")
    for lc in ("single_term", "short_phrase", "sentence"):
        count = length_dist[lc]
        # An empty phase reports 0% instead of dividing by zero.
        pct = count / phase_size * 100 if phase_data else 0
        print(f"  {lc:<15}: {count:>8,} ({pct:.1f}%)")

## 8. Add MS MARCO Triplets to Phase 3

In [None]:
# Turn the MS MARCO records into TrainingTriplet objects (records without a
# negative are dropped) and append them to Phase 3
msmarco_training: List[TrainingTriplet] = [
    TrainingTriplet(
        anchor=triplet["anchor"],
        positive=triplet["positive"],
        negative=triplet["negative"],
        difficulty=triplet.get("difficulty", "medium"),
        length_class=triplet.get("length_class", "sentence"),
        pair_type=triplet.get("pair_type", "msmarco_direct"),
        source="msmarco",
    )
    for triplet in msmarco_triplets
    if triplet.get("negative", "")
]

print(f"Converted {len(msmarco_training):,} MS MARCO triplets")

# Extend Phase 3 in place and reshuffle so the new records are interleaved
phase3_full = curriculum_splits["phase3_full"]
original_phase3_size = len(phase3_full)
phase3_full.extend(msmarco_training)
random.shuffle(phase3_full)

print(f"Phase 3 size: {original_phase3_size:,} -> {len(phase3_full):,}")

## 9. Train/Validation Split

In [None]:
def train_val_split(
    triplets: List[TrainingTriplet],
    val_ratio: float = 0.1,
) -> Tuple[List[TrainingTriplet], List[TrainingTriplet]]:
    """
    Partition triplets into train and validation sets at anchor level.

    All triplets sharing an anchor land in the same set, so no anchor is
    seen in both train and validation.

    Args:
        triplets: List of training triplets
        val_ratio: Ratio of anchors for validation

    Returns:
        Tuple of (train_triplets, val_triplets)
    """
    # Bucket triplets per anchor, keeping first-seen anchor order.
    groups: Dict[str, List[TrainingTriplet]] = {}
    for t in triplets:
        groups.setdefault(t.anchor, []).append(t)

    # Pick the validation anchors from a shuffled copy of the anchor list.
    shuffled = list(groups)
    random.shuffle(shuffled)
    n_val = int(len(shuffled) * val_ratio)
    held_out = set(shuffled[:n_val])

    train_part: List[TrainingTriplet] = []
    val_part: List[TrainingTriplet] = []
    for anchor, members in groups.items():
        bucket = val_part if anchor in held_out else train_part
        bucket.extend(members)

    return train_part, val_part


# Split full dataset
train_triplets, val_triplets = train_val_split(
    all_triplets,
    val_ratio=CONFIG["val_ratio"],
)

# The curriculum phases were sampled from the full triplet set before this
# split, so they still contain validation anchors. Drop those anchors from
# every phase; otherwise the phase files used for training leak validation data.
val_anchor_set = {t.anchor for t in val_triplets}
for phase_name in list(curriculum_splits):
    phase_before = len(curriculum_splits[phase_name])
    curriculum_splits[phase_name] = [
        t for t in curriculum_splits[phase_name] if t.anchor not in val_anchor_set
    ]
    phase_removed = phase_before - len(curriculum_splits[phase_name])
    print(f"Removed {phase_removed:,} validation-anchor triplets from {phase_name}")

print(f"Train triplets: {len(train_triplets):,}")
print(f"Validation triplets: {len(val_triplets):,}")
# Guard the ratio against an empty dataset.
split_total = len(train_triplets) + len(val_triplets)
val_pct = len(val_triplets) / split_total * 100 if split_total else 0.0
print(f"Validation ratio: {val_pct:.1f}%")

## 10. Compute IDF Weights for SPLADELossV23

IDF weights are used in IDFAwareFLOPSLoss to weight token importance.

Formula: `idf[token] = log(1 + (N - df + 0.5) / (df + 0.5))`

In [None]:
def compute_idf_weights(
    triplets: List[TrainingTriplet],
    tokenizer_name: str = "opensearch-project/opensearch-neural-sparse-encoding-multilingual-v1",
) -> torch.Tensor:
    """
    Compute IDF weights for vocabulary based on training data.

    Uses BM25-style IDF: log(1 + (N - df + 0.5) / (df + 0.5)), where N is the
    number of unique anchor/positive texts and df is the number of those
    texts containing the token. Negatives are not part of the document set.

    Because of the ``1 +`` inside the logarithm every weight is positive;
    tokens that never occur (df = 0) receive the largest weight.

    Args:
        triplets: Training triplets
        tokenizer_name: Name of tokenizer to use

    Returns:
        IDF weights tensor of shape (vocab_size,), dtype float32, where
        vocab_size is ``tokenizer.vocab_size``
    """
    from transformers import AutoTokenizer

    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    vocab_size = tokenizer.vocab_size

    print(f"Vocabulary size: {vocab_size:,}")

    # Each distinct anchor/positive text counts as one document.
    unique_texts = list({text for t in triplets for text in (t.anchor, t.positive)})
    N = len(unique_texts)
    print(f"Number of unique documents: {N:,}")

    # Compute document frequencies
    df = Counter()

    print("Computing document frequencies...")
    for text in tqdm(unique_texts, desc="Tokenizing"):
        tokens = tokenizer.encode(text, add_special_tokens=False)
        # A token counts once per document, however often it repeats.
        df.update(set(tokens))

    # Compute IDF for all tokens at once instead of looping over the whole
    # vocabulary in Python. float64 is used for the arithmetic and the result
    # is cast to float32 for storage.
    print("Computing IDF weights...")
    doc_freq = torch.zeros(vocab_size, dtype=torch.float64)
    for token_id, count in df.items():
        # Ids outside range(vocab_size) have no slot in the output tensor.
        if 0 <= token_id < vocab_size:
            doc_freq[token_id] = count
    idf_weights = torch.log(1 + (N - doc_freq + 0.5) / (doc_freq + 0.5)).to(torch.float32)

    # Statistics
    nonzero_count = (idf_weights > 0).sum().item()
    print(f"\nIDF Statistics:")
    print(f"  Non-zero IDF tokens: {nonzero_count:,}")
    print(f"  Min IDF: {idf_weights.min().item():.4f}")
    print(f"  Max IDF: {idf_weights.max().item():.4f}")
    print(f"  Mean IDF: {idf_weights.mean().item():.4f}")

    return idf_weights

In [None]:
# Compute IDF weights using training triplets
# Only the training split is used, so validation texts do not influence the
# document frequencies.
idf_weights = compute_idf_weights(train_triplets)

In [None]:
# Plot and summarize the distribution of the positive IDF values
import matplotlib.pyplot as plt

nonzero_idf = idf_weights[idf_weights > 0].numpy()
mean_idf = nonzero_idf.mean()
hist_kwargs = {"bins": 50, "edgecolor": "black", "alpha": 0.7}

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
linear_ax, log_ax = axes[0], axes[1]

# Left panel: linear counts with the mean marked
linear_ax.hist(nonzero_idf, **hist_kwargs)
linear_ax.set_xlabel("IDF Value")
linear_ax.set_ylabel("Frequency")
linear_ax.set_title("IDF Distribution (Non-zero values)")
linear_ax.axvline(mean_idf, color="r", linestyle="--", label=f"Mean: {mean_idf:.2f}")
linear_ax.legend()

# Right panel: same histogram with a logarithmic count axis
log_ax.hist(nonzero_idf, **hist_kwargs)
log_ax.set_xlabel("IDF Value")
log_ax.set_ylabel("Frequency (log scale)")
log_ax.set_title("IDF Distribution (Log scale)")
log_ax.set_yscale("log")

plt.tight_layout()
plt.show()

# Percentile statistics
percentiles = [10, 25, 50, 75, 90, 95, 99]
print("\nIDF Percentiles:")
for p, value in zip(percentiles, np.percentile(nonzero_idf, percentiles)):
    print(f"  {p}th percentile: {value:.4f}")

## 11. Save Training Data

In [None]:
def save_triplets(triplets: List[TrainingTriplet], path: Path) -> None:
    """
    Write triplets to a JSONL file, one JSON object per line.

    An existing file at ``path`` is overwritten. Non-ASCII text is written
    as-is (UTF-8) rather than escaped.

    Args:
        triplets: List of training triplets
        path: Output file path
    """
    rows = (json.dumps(asdict(t), ensure_ascii=False) + "\n" for t in triplets)
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(rows)
    print(f"Saved {len(triplets):,} triplets to {path}")


# Save main train/val splits
save_triplets(train_triplets, V22_1_DATA_DIR / "training_triplets.jsonl")
save_triplets(val_triplets, V22_1_DATA_DIR / "validation_triplets.jsonl")

# Save curriculum learning splits
# One file per phase, named "<phase key>_triplets.jsonl".
for phase, data in curriculum_splits.items():
    save_triplets(data, V22_1_DATA_DIR / f"{phase}_triplets.jsonl")

# Save IDF weights
# The tensor is stored with torch.save and reloaded in the final check cell.
idf_path = V22_1_DATA_DIR / "idf_weights.pt"
torch.save(idf_weights, idf_path)
print(f"Saved IDF weights to {idf_path}")

## 12. Verify Problem Term Coverage

In [None]:
PROBLEM_TERMS = ["\ucd94\ucc9c", "\ub370\uc774\ud130\ubca0\uc774\uc2a4", "\uc99d\uc0c1", "\uc9c8\ud658", "\uc778\uc290\ub9b0"]

# Count every anchor and positive once up front, then look the terms up.
anchor_counts = Counter(t.anchor for t in train_triplets)
positive_counts = Counter(t.positive for t in train_triplets)

print("Problem Term Coverage in Training Triplets:")
print("=" * 70)
print(f"{'Term':<15} {'As Anchor':>12} {'As Positive':>12} {'Total':>12}")
print("-" * 70)

for term in PROBLEM_TERMS:
    as_anchor = anchor_counts[term]
    as_positive = positive_counts[term]
    total = as_anchor + as_positive
    print(f"{term:<15} {as_anchor:>12} {as_positive:>12} {total:>12}")

## 13. Summary

In [None]:
# Final report: inputs, outputs, phase composition, IDF tensor and files
banner = "=" * 70
print("\n" + banner)
print("v22.1 Data Preparation Summary")
print(banner)

print("\nInput Data:")
print(f"  Augmented pairs: {len(pairs):,}")
print(f"  Single-term expanded: {len(single_term_triplets):,}")
print(f"  MS MARCO triplets: {len(msmarco_triplets):,}")

print("\nOutput Data:")
print(f"  Total triplets: {len(all_triplets):,}")
print(f"  Training triplets: {len(train_triplets):,}")
print(f"  Validation triplets: {len(val_triplets):,}")

print("\nCurriculum Phases (optimized for SPLADELossV23):")
for phase, data in curriculum_splits.items():
    length_dist = Counter(t.length_class for t in data)
    phase_size = len(data)

    print(f"\n  {phase} ({phase_size:,} triplets):")
    for lc in ("single_term", "short_phrase", "sentence"):
        count = length_dist[lc]
        # An empty phase reports 0% instead of dividing by zero.
        pct = count / phase_size * 100 if data else 0
        print(f"    {lc}: {count:,} ({pct:.1f}%)")

print("\nIDF Weights:")
print(f"  Shape: {idf_weights.shape}")
print(f"  Non-zero tokens: {(idf_weights > 0).sum().item():,}")
print(f"  Mean IDF: {idf_weights.mean().item():.4f}")

# List everything in the output directory with its size in MB
print("\nOutput Files:")
for f in sorted(V22_1_DATA_DIR.glob("*")):
    size_mb = f.stat().st_size / 1024 / 1024
    print(f"  {f.name}: {size_mb:.2f} MB")

## 14. Sample Output Verification

In [None]:
# Show one serialized training triplet and the first triplet of each phase
print("Sample Training Triplet:")
print("=" * 60)
if train_triplets:
    sample = train_triplets[0]
    print(json.dumps(asdict(sample), ensure_ascii=False, indent=2))

print("\nSample from each phase:")
print("=" * 60)
sample_fields = (
    ("Anchor", "anchor"),
    ("Positive", "positive"),
    ("Negative", "negative"),
    ("Difficulty", "difficulty"),
    ("Length class", "length_class"),
)
for phase_name, phase_data in curriculum_splits.items():
    if not phase_data:
        continue
    sample = phase_data[0]
    print(f"\n{phase_name}:")
    for label, attr in sample_fields:
        print(f"  {label}: {getattr(sample, attr)}")

In [None]:
# Verify IDF weights can be loaded
# weights_only=True limits unpickling to tensors and plain containers, which
# is all this file holds, and avoids executing arbitrary pickled code.
loaded_idf = torch.load(V22_1_DATA_DIR / "idf_weights.pt", weights_only=True)
print(f"IDF weights loaded successfully")
print(f"  Shape: {loaded_idf.shape}")
print(f"  Dtype: {loaded_idf.dtype}")
print(f"  Device: {loaded_idf.device}")

# Verify values match
assert torch.allclose(idf_weights, loaded_idf), "IDF weights mismatch!"
print("  Verification: PASSED")

## Next Steps

1. Run `03_training.ipynb` with SPLADELossV23
2. Use phase-specific data for curriculum learning:
   - Phase 1 (epochs 1-7): `phase1_single_term_focus_triplets.jsonl`
   - Phase 2 (epochs 8-14): `phase2_balanced_triplets.jsonl`
   - Phase 3 (epochs 15-20): `phase3_full_triplets.jsonl`
3. Load `idf_weights.pt` for IDFAwareFLOPSLoss in SPLADELossV23