# v21.4 Data Augmentation

## Problem Analysis from v21.3

In the v21.3 training synonym pairs, the problem terms (추천, 데이터베이스, 증상, 질환, 인슐린):
- Have **ZERO** exact matches as source or target
- Appear only embedded in compound words
- Sometimes appear as negatives (counterproductive)

## Solution

1. Add explicit single-term synonym pairs
2. Add identity pairs for self-reconstruction
3. Apply domain-specific filtering thresholds

In [None]:
import sys
from pathlib import Path

def find_project_root():
    """Locate the project root by walking upward from the working directory.

    Starting with the current working directory itself and then each of its
    ancestors, return the first directory that contains either a
    ``pyproject.toml`` entry or a ``src`` entry. When no directory on that
    chain qualifies, fall back to the grandparent of the working directory.
    """
    here = Path.cwd()
    markers = ("pyproject.toml", "src")
    match = next(
        (
            folder
            for folder in (here, *here.parents)
            if any((folder / marker).exists() for marker in markers)
        ),
        None,
    )
    if match is not None:
        return match
    # Nothing on the ancestor chain looks like a project root.
    return Path.cwd().parent.parent

PROJECT_ROOT = find_project_root()
# Put the project root first on the import path so project packages resolve.
# Skip the insert when it is already first, so re-running this cell does not
# keep growing sys.path with duplicate entries.
if sys.path[:1] != [str(PROJECT_ROOT)]:
    sys.path.insert(0, str(PROJECT_ROOT))

import json
import os
from collections import defaultdict
from typing import Dict, List, Set, Tuple
from dataclasses import dataclass, asdict

# Echo the resolved root so a wrong guess is visible immediately.
print(f"Project root: {PROJECT_ROOT}")

## 1. Load v21.3 Filtered Data

In [None]:
# Paths
# NOTE(review): the v21.3 input is read from "dataset/" while the v21.4 output
# goes to "data/" — confirm this split is intentional.
V21_3_DATA_DIR = PROJECT_ROOT / "dataset" / "v21.3_filtered_enhanced"
V21_4_DATA_DIR = PROJECT_ROOT / "data" / "v21.4"
V21_4_DATA_DIR.mkdir(parents=True, exist_ok=True)

# Load v21.3 filtered pairs (JSON Lines: one JSON object per line)
filtered_pairs_path = V21_3_DATA_DIR / "filtered_synonym_pairs.jsonl"

v21_3_pairs = []
if filtered_pairs_path.exists():
    with open(filtered_pairs_path, "r", encoding="utf-8") as f:
        for line in f:
            # Skip blank / whitespace-only lines (e.g. a stray empty line at
            # the end of the file) instead of letting json.loads raise.
            if not line.strip():
                continue
            v21_3_pairs.append(json.loads(line))
    print(f"Loaded {len(v21_3_pairs)} filtered pairs from v21.3")
else:
    # Best-effort: continue with an empty list so the rest of the notebook
    # still runs on the curated pairs alone.
    print(f"Warning: {filtered_pairs_path} not found")

# Show sample
if v21_3_pairs:
    print("\nSample pair:")
    print(json.dumps(v21_3_pairs[0], ensure_ascii=False, indent=2))

## 2. Analyze Problem Terms

In [None]:
# Terms flagged as problematic by the v21.3 inference test
# (roughly: recommendation, database, symptom, disease, insulin).
PROBLEM_TERMS = ["추천", "데이터베이스", "증상", "질환", "인슐린"]

# Coverage check used on both the v21.3 data and the augmented data
def check_term_coverage(pairs: List[Dict], terms: List[str]) -> Dict[str, Dict]:
    """Count exact and embedded occurrences of each term across the pairs.

    Args:
        pairs: Pair records. Only the "source" and "target" entries are read;
            each defaults to "" when absent.
        terms: Terms to look for.

    Returns:
        A mapping ``term -> counters``. "as_source" / "as_target" count pairs
        whose field equals the term; "partial_source" / "partial_target"
        count pairs whose field contains the term without being equal to it.
    """
    coverage = {}
    for term in terms:
        coverage[term] = {
            "as_source": 0,
            "as_target": 0,
            "partial_source": 0,
            "partial_target": 0,
        }

    for record in pairs:
        # (field text, counter for an exact hit, counter for an embedded hit)
        sides = (
            (record.get("source", ""), "as_source", "partial_source"),
            (record.get("target", ""), "as_target", "partial_target"),
        )
        for term in terms:
            counters = coverage[term]
            for text, exact_key, partial_key in sides:
                if text == term:
                    counters[exact_key] += 1
                elif term in text:
                    counters[partial_key] += 1

    return coverage

# Baseline coverage of the problem terms in the v21.3 pairs; kept in
# `coverage` so the post-augmentation cell can report the change.
coverage = check_term_coverage(v21_3_pairs, PROBLEM_TERMS)

# Render the counters as a fixed-width table.
header_cells = ("As Source", "As Target", "Partial Src", "Partial Tgt")
count_keys = ("as_source", "as_target", "partial_source", "partial_target")

print("Problem Term Coverage in v21.3 Data:")
print("=" * 70)
print(f"{'Term':<15} " + " ".join(f"{cell:>12}" for cell in header_cells))
print("-" * 70)
for term, stats in coverage.items():
    print(f"{term:<15} " + " ".join(f"{stats[key]:>12}" for key in count_keys))

## 3. Define Single-term Synonym Pairs

Manually curated synonym pairs for problem terms and other common single terms.

In [None]:
# Single-term synonym pairs - manually curated
# Each key is a head term; its list holds the terms to pair with it.
# generate_single_term_pairs() later emits every (key, value) pair in both
# directions with a high similarity score.
# NOTE(review): a few entries look broader or narrower than true synonyms
# (e.g. "추천서" ~ recommendation letter, "판례" ~ precedent, "편두통" ~ migraine,
# and the single-character "환") — confirm they are intended.
SINGLE_TERM_SYNONYMS = {
    # Problem terms from v21.3
    "추천": ["권장", "권유", "제안", "소개", "추천서"],
    "데이터베이스": ["DB", "디비", "저장소", "데이터저장소"],
    "증상": ["증세", "징후", "양상", "현상", "이상"],
    "질환": ["질병", "병", "병증", "환", "이환"],
    "인슐린": ["insulin", "인슐린호르몬", "혈당조절호르몬"],
    
    # General terms
    "검색": ["탐색", "조회", "찾기", "서치", "search"],
    "컴퓨터": ["PC", "computer", "전산", "컴", "피씨"],
    "인공지능": ["AI", "에이아이", "기계지능", "artificial intelligence"],
    "스마트폰": ["휴대폰", "핸드폰", "모바일폰", "smartphone"],
    "프로그래밍": ["코딩", "개발", "programming", "프로그램작성"],
    
    # Legal terms
    "손해배상": ["배상", "보상", "손해보전", "피해배상"],
    "판결": ["판례", "선고", "결정", "심판", "재결"],
    "소송": ["재판", "법적분쟁", "송사", "쟁송"],
    "계약": ["약정", "협약", "협정", "계약서"],
    "위반": ["위법", "불법", "법위반", "규정위반"],
    "피고": ["피고인", "피소인", "소송상대방"],
    "원고": ["소송인", "제소자", "소제기자"],
    "변호사": ["법률가", "법조인", "변호인", "lawyer"],
    
    # Medical terms
    "진단": ["진찰", "검진", "판단", "diagnosis"],
    "치료": ["처치", "요법", "치유", "시술", "therapy"],
    "처방": ["투약", "처방전", "약처방", "prescription"],
    "당뇨병": ["당뇨", "diabetes", "혈당질환"],
    "고혈압": ["혈압높음", "hypertension", "고혈압증"],
    "두통": ["머리아픔", "headache", "편두통"],
    "감기": ["cold", "감기증상", "코감기"],
    "독감": ["인플루엔자", "flu", "influenza"],
}

# Number of head terms (dictionary keys).
print(f"Defined {len(SINGLE_TERM_SYNONYMS)} single-term synonym groups")
# Counts each (key, value) entry once, i.e. one direction only; the generated
# bidirectional pair list is twice this size.
print(f"Total pairs: {sum(len(v) for v in SINGLE_TERM_SYNONYMS.values())}")

## 4. Generate Augmented Synonym Pairs

In [None]:
@dataclass
class SynonymPair:
    """One directed synonym pair, serialized via ``asdict`` to the JSONL outputs."""

    # Term on the input side of the pair.
    source: str
    # Term on the output side of the pair (equal to `source` for identity pairs).
    target: str
    # Similarity score: 0.9 for curated pairs, 1.0 for identity pairs; v21.3
    # pairs keep their own value, defaulting to 0.8 when the record has none.
    similarity: float
    # Category label: "single_term", "identity", or the v21.3 record's own
    # category ("unknown" when the record has none).
    category: str
    pair_type: str  # "original", "single_term", "identity"


def generate_single_term_pairs(
    synonym_dict: Dict[str, List[str]],
    similarity: float = 0.9,
) -> List[SynonymPair]:
    """Generate bidirectional synonym pairs from a curated dictionary.

    For every ``key -> synonym`` entry, a forward pair (key, synonym) and a
    backward pair (synonym, key) are emitted, in that order, both tagged with
    category and pair_type ``"single_term"``.

    Args:
        synonym_dict: Mapping from a head term to its list of synonyms.
        similarity: Score assigned to every generated pair. Defaults to 0.9,
            a high similarity for curated pairs.

    Returns:
        The generated pairs in dictionary order. A synonym equal to its own
        key is skipped, because such a term -> term entry would be a
        mislabelled identity pair. Each directed (source, target) combination
        is emitted at most once, even if the dictionary lists it repeatedly
        or from both sides.
    """
    pairs = []
    emitted = set()  # directed (source, target) combinations already produced

    for source, targets in synonym_dict.items():
        for target in targets:
            if target == source:
                # Identity pairs are produced separately with similarity 1.0.
                continue
            # Forward pair first, then the backward pair.
            for src, tgt in ((source, target), (target, source)):
                if (src, tgt) in emitted:
                    continue
                emitted.add((src, tgt))
                pairs.append(SynonymPair(
                    source=src,
                    target=tgt,
                    similarity=similarity,
                    category="single_term",
                    pair_type="single_term",
                ))

    return pairs


def generate_identity_pairs(terms: List[str]) -> List[SynonymPair]:
    """Build one term -> term pair per input term, for self-reconstruction.

    Every pair has similarity 1.0 and is tagged with category and pair_type
    ``"identity"``. Output order follows the order of ``terms``.
    """
    identity = []
    for term in terms:
        identity.append(
            SynonymPair(
                source=term,
                target=term,
                similarity=1.0,
                category="identity",
                pair_type="identity",
            )
        )
    return identity


# Generate single-term pairs
single_term_pairs = generate_single_term_pairs(SINGLE_TERM_SYNONYMS)
print(f"Generated {len(single_term_pairs)} single-term synonym pairs")

# Generate identity pairs for all unique terms (every key and every synonym)
all_terms = set(SINGLE_TERM_SYNONYMS.keys())
for targets in SINGLE_TERM_SYNONYMS.values():
    all_terms.update(targets)

# Sort before generating: iteration order of a set of strings can differ
# between interpreter runs (string hash randomization), which would make the
# order of the saved JSONL files non-reproducible.
identity_pairs = generate_identity_pairs(sorted(all_terms))
print(f"Generated {len(identity_pairs)} identity pairs")

## 5. Merge with v21.3 Data

In [None]:
def convert_v21_3_pair(pair: Dict) -> SynonymPair:
    """Wrap a v21.3 pair record in a SynonymPair tagged as "original".

    "source" and "target" are required keys (a missing one raises KeyError);
    "similarity" falls back to 0.8 and "category" to "unknown" when absent.
    """
    fields = {
        "source": pair["source"],
        "target": pair["target"],
        "similarity": pair.get("similarity", 0.8),
        "category": pair.get("category", "unknown"),
        "pair_type": "original",
    }
    return SynonymPair(**fields)


# Wrap every v21.3 record as a SynonymPair
original_pairs = list(map(convert_v21_3_pair, v21_3_pairs))
print(f"Original v21.3 pairs: {len(original_pairs)}")

# Concatenate: originals first, then curated single-term pairs, then identity pairs
all_pairs = [*original_pairs, *single_term_pairs, *identity_pairs]

# De-duplicate on (source, target), keeping the first occurrence — so a v21.3
# original wins over a curated or identity pair with the same endpoints.
seen = set()
unique_pairs = []
for pair in all_pairs:
    key = (pair.source, pair.target)
    if key in seen:
        continue
    seen.add(key)
    unique_pairs.append(pair)

print(f"\nTotal pairs after merge: {len(all_pairs)}")
print(f"Unique pairs: {len(unique_pairs)}")

# Tally the surviving pairs per pair_type
type_counts = defaultdict(int)
for pair in unique_pairs:
    type_counts[pair.pair_type] += 1

print(f"\nPairs by type:")
for pair_type, count in sorted(type_counts.items()):
    print(f"  {pair_type}: {count}")

## 6. Verify Problem Term Coverage

In [None]:
# Re-run the coverage check on the augmented pairs and show the gain in exact
# matches relative to the baseline stored in `coverage`.
pair_dicts = list(map(asdict, unique_pairs))
new_coverage = check_term_coverage(pair_dicts, PROBLEM_TERMS)

print("Problem Term Coverage After Augmentation:")
print("=" * 70)
print(f"{'Term':<15} {'As Source':>12} {'As Target':>12} {'Partial Src':>12} {'Partial Tgt':>12}")
print("-" * 70)
for term, stats in new_coverage.items():
    src_change = stats['as_source'] - coverage[term]['as_source']
    tgt_change = stats['as_target'] - coverage[term]['as_target']

    # Append "(+N)" only when the exact-match count actually went up.
    src_str = str(stats['as_source'])
    if src_change > 0:
        src_str = f"{stats['as_source']} (+{src_change})"
    tgt_str = str(stats['as_target'])
    if tgt_change > 0:
        tgt_str = f"{stats['as_target']} (+{tgt_change})"

    print(f"{term:<15} {src_str:>12} {tgt_str:>12} {stats['partial_source']:>12} {stats['partial_target']:>12}")

## 7. Save Augmented Data

In [None]:
def _dump_jsonl(path, pairs):
    """Write each pair as one JSON object per line (UTF-8, non-ASCII kept as-is)."""
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(
            json.dumps(asdict(item), ensure_ascii=False) + "\n" for item in pairs
        )


# Full de-duplicated training set
output_path = V21_4_DATA_DIR / "augmented_synonym_pairs.jsonl"
_dump_jsonl(output_path, unique_pairs)

print(f"Saved {len(unique_pairs)} pairs to {output_path}")

# Curated single-term pairs plus identity pairs, saved separately for
# reference (this list is not passed through the merge de-duplication).
single_term_path = V21_4_DATA_DIR / "single_term_pairs.jsonl"
reference_pairs = single_term_pairs + identity_pairs
_dump_jsonl(single_term_path, reference_pairs)

print(f"Saved {len(reference_pairs)} single-term pairs to {single_term_path}")

## 8. Summary Statistics

In [None]:
total_pairs = len(unique_pairs)

# How many unique pairs fall into each category
category_counts = defaultdict(int)
for pair in unique_pairs:
    category_counts[pair.category] += 1

print("Category Distribution:")
print("=" * 40)
# Largest category first; ties keep their first-seen order.
for category, count in sorted(category_counts.items(), key=lambda x: x[1], reverse=True):
    pct = count / total_pairs * 100
    print(f"  {category:<20}: {count:>6} ({pct:.1f}%)")

# Source length distribution
# NOTE(review): the bins are labelled in "tokens" but are assigned from the
# character count of `source` (<=2, <=6, <=10, else) — confirm the intended unit.
length_bins = {"1 token": 0, "2-3 tokens": 0, "4-5 tokens": 0, "6+ tokens": 0}
bin_limits = ((2, "1 token"), (6, "2-3 tokens"), (10, "4-5 tokens"))
for pair in unique_pairs:
    src_len = len(pair.source)
    # First bin whose upper limit fits; anything longer lands in the last bin.
    label = next((name for limit, name in bin_limits if src_len <= limit), "6+ tokens")
    length_bins[label] += 1

print("\nSource Length Distribution:")
print("=" * 40)
for length, count in length_bins.items():
    pct = count / total_pairs * 100
    print(f"  {length:<15}: {count:>6} ({pct:.1f}%)")

## Summary

### Data Augmentation Results

| Metric | v21.3 | v21.4 |
|--------|-------|-------|
| Original pairs | - | From v21.3 |
| Single-term pairs | 0 | Added |
| Identity pairs | 0 | Added |
| Problem terms covered | 0 | All 5 |

### Next Steps

1. Run `02_data_preparation.ipynb` to generate training triplets
2. Apply length-balanced sampling for curriculum learning