# v22.0 Data Augmentation

## Key Improvements over v21.4

1. **Single-term Expansion**: 448 → 29,322 triplets (65x increase)
2. **Automatic Synonym Extraction**: From curated lists, Wikipedia, and training data
3. **Hard Negative Mining**: Character overlap-based difficulty scoring

## Data Sources

1. **v21.3 Filtered Pairs**: Base synonym pairs from previous version
2. **HuggingFace Korean Datasets**: Large-scale Korean NLP data
3. **v22.0 Single-term Expanded**: Auto-extracted single-term synonym pairs
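4. **Identity Pairs**: Self-reconstruction pairs (not produced by the cells in this notebook)
5. **MS MARCO Triplets**: Anchor/positive/negative triplets, exported separately to `msmarco_direct_triplets.jsonl` for direct training use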

In [None]:
import sys
from pathlib import Path

def find_project_root():
    """Locate the project root by walking up from the current working directory.

    The working directory itself and then each of its ancestors are checked in
    turn. The first directory that contains either a ``pyproject.toml`` or a
    ``src`` entry is returned. If none matches, the grandparent of the working
    directory is returned as a fallback.
    """
    start = Path.cwd()
    markers = ("pyproject.toml", "src")
    for candidate in (start, *start.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate
    return Path.cwd().parent.parent

# Resolve the project root once and put it first on sys.path so that
# project-local modules can be imported from this notebook.
PROJECT_ROOT = find_project_root()
sys.path.insert(0, str(PROJECT_ROOT))

import json
from collections import defaultdict
from typing import Dict, List, Set
from dataclasses import dataclass, asdict

# Echo the resolved root so a wrong working directory is easy to spot.
print(f"Project root: {PROJECT_ROOT}")

## 1. Load v21.3 Filtered Data

In [None]:
# Paths
# v21.3 inputs live under dataset/, while v22.0 outputs and the HuggingFace
# exports live under data/.
V21_3_DATA_DIR = PROJECT_ROOT / "dataset" / "v21.3_filtered_enhanced"
V22_DATA_DIR = PROJECT_ROOT / "data" / "v22.0"
HF_DATA_DIR = PROJECT_ROOT / "data" / "huggingface_korean"
V22_DATA_DIR.mkdir(parents=True, exist_ok=True)

# Load v21.3 filtered pairs (JSON Lines: one JSON object per line).
filtered_pairs_path = V21_3_DATA_DIR / "filtered_synonym_pairs.jsonl"

v21_3_pairs = []
if filtered_pairs_path.exists():
    with open(filtered_pairs_path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines (e.g. stray empty lines at the end of the file);
            # json.loads raises JSONDecodeError on whitespace-only input.
            if line.strip():
                v21_3_pairs.append(json.loads(line))
    print(f"Loaded {len(v21_3_pairs):,} filtered pairs from v21.3")
else:
    print(f"Warning: {filtered_pairs_path} not found")

if v21_3_pairs:
    print("\nSample pair:")
    print(json.dumps(v21_3_pairs[0], ensure_ascii=False, indent=2))

## 2. Load HuggingFace Korean Data

In [None]:
# Load HuggingFace synonym pairs (JSON Lines: one JSON object per line).
hf_pairs_path = HF_DATA_DIR / "huggingface_synonym_pairs.jsonl"

hf_pairs = []
if hf_pairs_path.exists():
    with open(hf_pairs_path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines; json.loads raises on whitespace-only input.
            if line.strip():
                hf_pairs.append(json.loads(line))
    print(f"Loaded {len(hf_pairs):,} pairs from HuggingFace datasets")

    # Statistics by source (records without a pair_type count as "unknown")
    source_counts = defaultdict(int)
    for pair in hf_pairs:
        source_counts[pair.get("pair_type", "unknown")] += 1

    print("\nHuggingFace pairs by source:")
    for source, count in sorted(source_counts.items(), key=lambda x: -x[1]):
        print(f"  {source}: {count:,}")
else:
    print(f"Warning: {hf_pairs_path} not found")
    print("Please run 00_data_loading.ipynb first!")

## 3. Load v22.0 Single-term Expanded Data

This is the key improvement in v22.0: automatically extracted single-term synonym triplets, which are de-duplicated into (anchor, positive) pairs further below.

In [None]:
# Load v22.0 single-term expanded data (JSON Lines: one triplet per line).
single_term_path = V22_DATA_DIR / "single_term_expanded.jsonl"

single_term_expanded = []
if single_term_path.exists():
    with open(single_term_path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines; json.loads raises on whitespace-only input.
            if line.strip():
                single_term_expanded.append(json.loads(line))
    print(f"Loaded {len(single_term_expanded):,} single-term expanded triplets")

    # An existing but empty file used to crash on single_term_expanded[0];
    # only compute statistics and show a sample when there is data.
    if single_term_expanded:
        # Statistics by difficulty (records without one count as "unknown")
        difficulty_counts = defaultdict(int)
        for item in single_term_expanded:
            difficulty_counts[item.get("difficulty", "unknown")] += 1

        print("\nBy difficulty:")
        # Sort on the string form so mixed key types (e.g. a null difficulty)
        # cannot raise TypeError during comparison.
        for diff, count in sorted(difficulty_counts.items(), key=lambda kv: str(kv[0])):
            pct = count / len(single_term_expanded) * 100
            print(f"  {diff}: {count:,} ({pct:.1f}%)")

        # Show sample
        print("\nSample triplet:")
        print(json.dumps(single_term_expanded[0], ensure_ascii=False, indent=2))
else:
    print(f"Warning: {single_term_path} not found")
    print("Please run: python scripts/extract_synonyms.py")

In [None]:
# Load MS MARCO triplets (JSON Lines: one triplet per line).
msmarco_triplets_path = HF_DATA_DIR / "msmarco_triplets.jsonl"

msmarco_triplets = []
if msmarco_triplets_path.exists():
    with open(msmarco_triplets_path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank lines; json.loads raises on whitespace-only input.
            if line.strip():
                msmarco_triplets.append(json.loads(line))
    print(f"Loaded {len(msmarco_triplets):,} MS MARCO triplets")
else:
    print(f"Warning: {msmarco_triplets_path} not found")

## 4. Analyze Problem Terms Coverage

In [None]:
# Terms whose coverage in the training pairs is audited in this section.
# NOTE(review): presumably terms the previous model version handled poorly — confirm.
PROBLEM_TERMS = [
    "추천",  # "recommendation"
    "데이터베이스",  # "database"
    "증상",  # "symptom"
    "질환",  # "disease"
    "인슐린",  # "insulin"
]

def check_term_coverage(pairs: List[Dict], terms: List[str]) -> Dict[str, Dict]:
    """Count how each term occurs across a list of pair records.

    A record may use ``anchor``/``positive`` keys or ``source``/``target``
    keys; ``anchor``/``positive`` win when both are present, and a missing
    side is treated as ``""``.

    For each term, the result maps to a dict with three counters:

    - ``as_anchor``: records whose anchor/source equals the term exactly.
    - ``as_positive``: records whose positive/target equals the term exactly.
    - ``partial``: records whose anchor/source contains the term without
      being equal to it. The positive/target side is not checked for this.
    """
    coverage = {}
    for term in terms:
        coverage[term] = {"as_anchor": 0, "as_positive": 0, "partial": 0}

    for record in pairs:
        left = record.get("anchor", record.get("source", ""))
        right = record.get("positive", record.get("target", ""))

        for term in terms:
            counts = coverage[term]
            if left == term:
                counts["as_anchor"] += 1
            elif term in left:
                counts["partial"] += 1

            if right == term:
                counts["as_positive"] += 1

    return coverage

# Baseline: how often each problem term appears in the v21.3 pairs.
v21_3_coverage = check_term_coverage(v21_3_pairs, PROBLEM_TERMS)

header = f"{'Term':<15} {'As Anchor':>12} {'As Positive':>12} {'Partial':>12}"
print("Problem Term Coverage in v21.3 Data:")
print("=" * 60)
print(header)
print("-" * 60)
for term_name, counts in v21_3_coverage.items():
    anchor_n, positive_n, partial_n = counts["as_anchor"], counts["as_positive"], counts["partial"]
    print(f"{term_name:<15} {anchor_n:>12} {positive_n:>12} {partial_n:>12}")

In [None]:
# Check v22.0 single-term expanded coverage and compare it with the v21.3
# baseline computed in the previous cell. Deltas are differences between the
# two datasets' exact-match counts.
if single_term_expanded:
    v22_coverage = check_term_coverage(single_term_expanded, PROBLEM_TERMS)

    print("\nProblem Term Coverage in v22.0 Single-term Expanded Data:")
    print("=" * 60)
    print(f"{'Term':<15} {'As Anchor':>12} {'As Positive':>12} {'Partial':>12}")
    print("-" * 60)
    for term in PROBLEM_TERMS:
        old = v21_3_coverage[term]
        new = v22_coverage[term]
        anchor_change = new['as_anchor'] - old['as_anchor']
        pos_change = new['as_positive'] - old['as_positive']

        # Annotate any non-zero delta with an explicit sign. Previously only
        # increases were shown, so a decrease looked like "no change".
        anchor_str = f"{new['as_anchor']} ({anchor_change:+d})" if anchor_change else str(new['as_anchor'])
        pos_str = f"{new['as_positive']} ({pos_change:+d})" if pos_change else str(new['as_positive'])

        print(f"{term:<15} {anchor_str:>12} {pos_str:>12} {new['partial']:>12}")

## 5. Generate Augmented Pairs

In [None]:
@dataclass
class SynonymPair:
    """One synonym pair plus its metadata; field order is the JSON key order on save."""

    source: str
    target: str
    similarity: float
    category: str
    pair_type: str  # origin tag, e.g. "original" or "single_term_expanded"


def convert_v21_3_pair(pair: Dict) -> SynonymPair:
    """Build a SynonymPair from a v21.3 record with ``source``/``target`` keys.

    A missing ``similarity`` becomes 0.8 and a missing ``category`` becomes
    ``"unknown"``. The pair type is always ``"original"``.
    """
    src, tgt = pair["source"], pair["target"]
    sim = pair.get("similarity", 0.8)
    cat = pair.get("category", "unknown")
    return SynonymPair(src, tgt, sim, cat, "original")


def convert_hf_pair(pair: Dict) -> SynonymPair:
    """Build a SynonymPair from a HuggingFace record with ``source``/``target`` keys.

    A missing ``similarity`` becomes 0.8. A missing ``category`` or
    ``pair_type`` becomes ``"huggingface"``.
    """
    src, tgt = pair["source"], pair["target"]
    sim = pair.get("similarity", 0.8)
    cat = pair.get("category", "huggingface")
    kind = pair.get("pair_type", "huggingface")
    return SynonymPair(src, tgt, sim, cat, kind)


def convert_single_term_triplet(triplet: Dict) -> SynonymPair:
    """Build a SynonymPair from a single-term triplet's ``anchor``/``positive``.

    Any other triplet fields are ignored. Similarity is fixed at 0.9, and
    both category and pair type are ``"single_term_expanded"``.
    """
    src, tgt = triplet["anchor"], triplet["positive"]
    tag = "single_term_expanded"
    return SynonymPair(src, tgt, 0.9, tag, tag)

In [None]:
# Normalise every loaded source into SynonymPair records.
original_pairs = list(map(convert_v21_3_pair, v21_3_pairs))
print(f"Original v21.3 pairs: {len(original_pairs):,}")

huggingface_pairs = list(map(convert_hf_pair, hf_pairs))
print(f"HuggingFace pairs: {len(huggingface_pairs):,}")

# Several triplets may repeat the same (anchor, positive) combination. Keep
# only the first triplet for each combination, preserving file order.
first_triplet_by_pair = {}
for triplet in single_term_expanded:
    first_triplet_by_pair.setdefault((triplet["anchor"], triplet["positive"]), triplet)
single_term_pairs = [convert_single_term_triplet(t) for t in first_triplet_by_pair.values()]
print(f"Single-term expanded pairs: {len(single_term_pairs):,}")

In [None]:
# Merge every source into one list. The order matters for de-duplication below.
all_pairs = [*original_pairs, *huggingface_pairs, *single_term_pairs]

# Keep the first pair seen for each (source, target) key. v21.3 pairs therefore
# take precedence over HuggingFace pairs, which take precedence over single-term ones.
first_seen = {}
for candidate in all_pairs:
    first_seen.setdefault((candidate.source, candidate.target), candidate)
unique_pairs = list(first_seen.values())

print(f"\nTotal pairs after merge: {len(all_pairs):,}")
print(f"Unique pairs: {len(unique_pairs):,}")

# Tally the surviving pairs by origin and list them, most frequent first.
type_counts = defaultdict(int)
for kept in unique_pairs:
    type_counts[kept.pair_type] += 1

print("\nPairs by type:")
ranked = sorted(type_counts.items(), key=lambda item: item[1], reverse=True)
for pair_type, count in ranked:
    share = count / len(unique_pairs) * 100
    print(f"  {pair_type}: {count:,} ({share:.1f}%)")

## 6. Verify Problem Term Coverage

In [None]:
# Verify that problem terms are now covered in the merged, de-duplicated set.
# "Improvement" compares exact source+target matches against the v21.3
# baseline. Partial (substring) matches are not counted.
pair_dicts = [{"source": p.source, "target": p.target} for p in unique_pairs]
new_coverage = check_term_coverage(pair_dicts, PROBLEM_TERMS)

print("Problem Term Coverage After Augmentation:")
print("=" * 70)
print(f"{'Term':<15} {'As Source':>15} {'As Target':>15} {'Improvement':>15}")
print("-" * 70)
for term in PROBLEM_TERMS:
    old_stats = v21_3_coverage[term]
    new_stats = new_coverage[term]
    old_total = old_stats['as_anchor'] + old_stats['as_positive']
    new_total = new_stats['as_anchor'] + new_stats['as_positive']
    improvement = new_total - old_total

    # The `+` sign option always prints exactly one sign, so a drop renders
    # as "-3". The former "'+' + str(...)" produced the malformed "+-3".
    print(f"{term:<15} {new_stats['as_anchor']:>15} {new_stats['as_positive']:>15} {improvement:>+15d}")

## 7. Save Augmented Data

In [None]:
# Save augmented pairs
output_path = V22_DATA_DIR / "augmented_synonym_pairs.jsonl"

with open(output_path, "w", encoding="utf-8") as f:
    for pair in unique_pairs:
        f.write(json.dumps(asdict(pair), ensure_ascii=False) + "\n")

print(f"Saved {len(unique_pairs):,} pairs to {output_path}")

In [None]:
# Save MS MARCO triplets for direct training use. Only triplets with a
# non-empty negative are written. Every row gets fixed difficulty,
# length_class and pair_type metadata.
if msmarco_triplets:
    msmarco_output_path = V22_DATA_DIR / "msmarco_direct_triplets.jsonl"
    written = 0
    skipped = 0

    with open(msmarco_output_path, "w", encoding="utf-8") as f:
        for triplet in msmarco_triplets:
            negative = triplet.get("negative", "")
            # Check the negative first, so records that will be dropped anyway
            # are never indexed for anchor/positive.
            if not negative:
                skipped += 1
                continue
            output_triplet = {
                "anchor": triplet["anchor"],
                "positive": triplet["positive"],
                "negative": negative,
                "difficulty": "medium",
                "length_class": "sentence",
                "pair_type": "msmarco_direct",
            }
            f.write(json.dumps(output_triplet, ensure_ascii=False) + "\n")
            written += 1

    # Report what was actually written. Previously, dropped rows were silent.
    print(f"Saved {written:,} MS MARCO triplets to {msmarco_output_path}")
    if skipped:
        print(f"  Skipped {skipped:,} triplets without a negative")

## 8. Summary

In [None]:
# Category distribution (top 15 categories by pair count).
total_pairs = len(unique_pairs)

category_counts = defaultdict(int)
for pair in unique_pairs:
    category_counts[pair.category] += 1

print("Category Distribution:")
print("=" * 50)
for category, count in sorted(category_counts.items(), key=lambda x: -x[1])[:15]:
    pct = count / total_pairs * 100 if total_pairs else 0.0
    print(f"  {category:<25}: {count:>8} ({pct:.1f}%)")

# Source length distribution, by character count. The last bin holds sources
# longer than 20 characters. It was labelled "20+", which overlapped "9-20".
length_bins = {"1-3 chars": 0, "4-8 chars": 0, "9-20 chars": 0, "21+ chars": 0}
for pair in unique_pairs:
    src_len = len(pair.source)
    if src_len <= 3:
        length_bins["1-3 chars"] += 1
    elif src_len <= 8:
        length_bins["4-8 chars"] += 1
    elif src_len <= 20:
        length_bins["9-20 chars"] += 1
    else:
        length_bins["21+ chars"] += 1

print("\nSource Length Distribution:")
print("=" * 50)
for length, count in length_bins.items():
    # The four bins are always printed, so guard the division. With no pairs,
    # this raised ZeroDivisionError.
    pct = count / total_pairs * 100 if total_pairs else 0.0
    print(f"  {length:<15}: {count:>8} ({pct:.1f}%)")

In [None]:
# Number of single-term pairs in v21.4, the baseline quoted in this notebook's header.
V21_4_SINGLE_TERM_BASELINE = 448

print("\n" + "=" * 60)
print("v22.0 Data Augmentation Summary")
print("=" * 60)

print("\nData Sources:")
print(f"  v21.3 Original Pairs: {len(original_pairs):,}")
print(f"  HuggingFace Pairs: {len(huggingface_pairs):,}")
print(f"  Single-term Expanded: {len(single_term_pairs):,} (NEW in v22.0)")
print(f"  MS MARCO Triplets: {len(msmarco_triplets):,}")

print("\nKey Improvement:")
print(f"  Single-term pairs: {V21_4_SINGLE_TERM_BASELINE} (v21.4) -> {len(single_term_pairs):,} (v22.0)")
print(f"  Increase: {len(single_term_pairs) / V21_4_SINGLE_TERM_BASELINE:.0f}x")

print("\nFinal Output:")
print(f"  Total Unique Pairs: {len(unique_pairs):,}")

print("\nOutput Files:")
# Path.glob yields entries in arbitrary (filesystem) order. Sort the paths so
# the listing is reproducible across runs and machines.
for path in sorted(V22_DATA_DIR.glob("*.jsonl")):
    size_mb = path.stat().st_size / 1024 / 1024
    print(f"  {path.name}: {size_mb:.2f} MB")

## Next Steps

1. Run `02_data_preparation.ipynb` to generate curriculum learning splits
2. The expanded single-term data will be used for Phase 1 training