# 🧬 Genomic Data Augmentation with OmniGenBench

Welcome to this comprehensive tutorial where we'll explore how to perform **intelligent genomic data augmentation** using **OmniGenBench**. This guide walks you through generating synthetic genomic sequences that are designed to preserve biological patterns and can improve model training performance.

### 1. The Machine Learning Challenge: Why Genomic Data Augmentation?

**Genomic data augmentation** is a critical technique in computational biology that addresses several fundamental challenges in genomic machine learning:

- **Limited Training Data**: High-quality labeled genomic datasets are often small and expensive to generate
- **Class Imbalance**: Rare genomic variants and functions are underrepresented in datasets
- **Overfitting Prevention**: Augmentation increases dataset diversity and improves generalization
- **Domain Adaptation**: Bridging gaps between different experimental conditions or species

The power of genomic augmentation lies in its ability to:
- **Generate Realistic Sequences**: Create biologically plausible variants while preserving functional patterns
- **Expand Dataset Size**: Multiply available training data without additional experimental costs
- **Improve Model Robustness**: Enhance model performance on unseen genomic variations
- **Balance Datasets**: Address class imbalance issues in genomic classification tasks
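For the class-imbalance case, the number of extra variants needed per minority-class sequence follows from simple arithmetic. A minimal sketch, assuming the goal is to bring every class up to the size of the largest one; `variants_needed` is a hypothetical planning helper, not part of OmniGenBench:

```python
import math
from collections import Counter


def variants_needed(labels):
    """Return, per class, how many augmented variants to generate for each
    sequence so the class roughly matches the largest class.

    Hypothetical planning helper (not an OmniGenBench API).
    """
    counts = Counter(labels)
    target = max(counts.values())
    plan = {}
    for label, n in counts.items():
        # Extra sequences needed, spread evenly over the n originals.
        # Rounded up, so a class may slightly overshoot the target.
        plan[label] = math.ceil((target - n) / n)
    return plan


labels = ["promoter"] * 90 + ["enhancer"] * 30 + ["silencer"] * 8
print(variants_needed(labels))
```

Passing these per-class values as `k` to the `augment(seq, k)` call used in Step 3 would bring the enhancer class to 30 + 30 × 2 = 90 sequences and the silencer class to 8 + 8 × 11 = 96.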

Applications across computational biology:
- **Rare Variant Analysis**: Augment underrepresented mutation patterns for disease prediction
- **Cross-Species Learning**: Generate bridge sequences for evolutionary studies
- **Functional Annotation**: Create training data for poorly characterized genomic regions
- **Model Validation**: Generate test sequences for robustness evaluation

### 2. The Challenge: Biologically-Informed Sequence Generation

Unlike random mutations, intelligent genomic augmentation must preserve:

- **Functional Motifs**: Critical regulatory and coding sequences
- **Structural Constraints**: Secondary structures and folding patterns
- **Evolutionary Patterns**: Codon usage bias and phylogenetic relationships  
- **Statistical Properties**: Nucleotide composition and k-mer frequencies
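Statistical properties such as k-mer frequencies can be compared directly between an original and a variant. A minimal sketch in plain Python that builds normalized k-mer profiles and reports their total variation distance (0 means identical profiles, 1 means no shared k-mers); the function names are illustrative and not part of OmniGenBench:

```python
from collections import Counter


def kmer_profile(seq, k=3):
    """Return normalized k-mer frequencies for a single sequence."""
    seq = seq.upper()
    kmers = [seq[i:i + k] for i in range(len(seq) - k + 1)]
    if not kmers:
        return {}
    total = len(kmers)
    return {kmer: count / total for kmer, count in Counter(kmers).items()}


def kmer_distance(seq_a, seq_b, k=3):
    """Total variation distance between two k-mer profiles (0 to 1)."""
    p, q = kmer_profile(seq_a, k), kmer_profile(seq_b, k)
    keys = set(p) | set(q)
    return 0.5 * sum(abs(p.get(x, 0.0) - q.get(x, 0.0)) for x in keys)


print(kmer_distance("ATGCGATCG", "ATGCGATCC"))  # only TCG vs TCC differs
print(kmer_distance("ATGCGATCG", "ATGCGATCG"))  # identical sequences
```

For the first pair, six of the seven 3-mers are shared, so the distance is 1/7.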

**Random vs. Intelligent Augmentation:**

| Original Sequence | Random Mutation | Intelligent Augmentation |
|------------------|-----------------|-------------------------|
| `ATGCGATCG` (Met-Arg-Ser) | `ATGCTATCG` (Met-Leu-Ser) | `ATGCGATCC` (Met-Arg-Ser, codon-aware) |
| Functional | Missense change (CGA→CTA, Arg→Leu); may break function | Synonymous change (TCG→TCC, both Ser); protein unchanged |
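The table's claim can be checked by translating each codon. A minimal sketch using a deliberately partial codon table that contains only the codons in this example (it raises `KeyError` for anything else); a real analysis would use the full standard genetic code, for example via Biopython's `Seq.translate()`:

```python
# Partial standard codon table: only the codons appearing in the example above.
CODON_TABLE = {
    "ATG": "Met",
    "CGA": "Arg",
    "CTA": "Leu",
    "TCG": "Ser",
    "TCC": "Ser",
}


def translate(seq):
    """Translate an in-frame DNA sequence codon by codon (partial table)."""
    usable = len(seq) - len(seq) % 3  # ignore a trailing incomplete codon
    return [CODON_TABLE[seq[i:i + 3]] for i in range(0, usable, 3)]


def classify_change(original, variant):
    """Label a variant as synonymous or missense relative to the original."""
    return "synonymous" if translate(original) == translate(variant) else "missense"


for variant in ["ATGCTATCG", "ATGCGATCC"]:
    print(variant, translate(variant), classify_change("ATGCGATCG", variant))
```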

### 3. The Tool: Masked Language Models for Genomic Augmentation

#### Foundation Model Understanding
**OmniGenome** uses masked language modeling (MLM) for intelligent sequence augmentation. This approach:

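1. **Masks Random Positions**: Replaces a `noise_ratio` fraction of nucleotides with a mask token, drawing a new random pattern for each variant
2. **Predicts Biologically Plausible Alternatives**: Uses the model's pre-trained knowledge of sequence context to fill each masked position
3. **Maintains Sequence Integrity**: Changes only the masked positions, so most of the original sequence is carried over unchanged
4. **Preserves Functional Patterns**: Leaves motifs and regulatory elements intact wherever they fall outside the masked positions; Step 4 checks how well sequence properties are preserved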
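The mask-then-predict loop can be illustrated without a neural network. In this minimal sketch, a simple "most frequent successor of the left neighbor" table stands in for the pre-trained MLM, filling positions greedily much as an argmax over model logits would; all helper names and the `<mask>` placeholder are hypothetical, and OmniGenome's real predictions come from the transformer, not from this heuristic:

```python
import random
from collections import Counter, defaultdict

MASK = "<mask>"  # placeholder; the real tokenizer defines its own mask token


def mask_sequence(seq, noise_ratio, rng):
    """Replace a random subset of positions with MASK.

    Masks int(len(seq) * noise_ratio) positions, at least one for non-empty
    input. This count rule is a choice made for the sketch.
    """
    n_mask = min(len(seq), max(1, int(len(seq) * noise_ratio)))
    positions = sorted(rng.sample(range(len(seq)), n_mask))
    tokens = list(seq)
    for pos in positions:
        tokens[pos] = MASK
    return tokens, positions


def build_context_model(corpus):
    """Count which nucleotide follows each nucleotide (stand-in for an MLM)."""
    follows = defaultdict(Counter)
    for seq in corpus:
        for left, right in zip(seq, seq[1:]):
            follows[left][right] += 1
    return follows


def fill_masks(tokens, follows, fallback="A"):
    """Fill masks left to right with the most frequent successor of the left neighbor."""
    filled = list(tokens)
    for i, tok in enumerate(filled):
        if tok == MASK:
            left = filled[i - 1] if i > 0 else None  # already filled if it was masked
            counts = follows.get(left)
            filled[i] = counts.most_common(1)[0][0] if counts else fallback
    return "".join(filled)


rng = random.Random(42)
corpus = ["ATGCGATCGATCG", "ATGCGTTCGATCC"]
tokens, positions = mask_sequence(corpus[0], noise_ratio=0.2, rng=rng)
print("masked positions:", positions)
print("variant:", fill_masks(tokens, build_context_model(corpus)))
```

Only masked positions can change, and a masked position may be refilled with its original nucleotide, so the observed mutation rate can be lower than `noise_ratio`.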

### 4. The Workflow: A 4-Step Guide to Genomic Augmentation

```mermaid
flowchart TD
    subgraph "4-Step Workflow for Genomic Data Augmentation"
        A["üì• Step 1: Setup and Configuration<br/>Configure augmentation parameters and models"] --> B["üîß Step 2: Model Initialization<br/>Load pre-trained genomic foundation models"]
        B --> C["üéì Step 3: Sequence Augmentation<br/>Generate diverse, biologically-valid variants"]
        C --> D["üîÆ Step 4: Quality Assessment<br/>Validate and analyze augmented sequences"]
    end

    style A fill:#e1f5fe,stroke:#333,stroke-width:2px
    style B fill:#f3e5f5,stroke:#333,stroke-width:2px
    style C fill:#e8f5e8,stroke:#333,stroke-width:2px
    style D fill:#fff3e0,stroke:#333,stroke-width:2px
```

Let's start generating high-quality genomic training data!

## 🚀 Step 1: Setup and Configuration

This first step focuses on setting up our genomic data augmentation environment and understanding the key parameters that control sequence generation quality.

### 1.1: Environment Setup

First, let's install the required packages for intelligent genomic data augmentation.

In [None]:
# Install required packages
# Note: omnigenbench requires Python 3.8+ and PyTorch 1.12+
!pip install omnigenbench torch transformers tqdm scikit-learn matplotlib seaborn -U

# Verify installation
import sys
print(f"Python version: {sys.version.split()[0]}")
try:
    import omnigenbench
    print(f"✅ OmniGenBench version: {omnigenbench.__version__}")
except ImportError as e:
    print(f"❌ Failed to import omnigenbench: {e}")

### 1.2: Import Required Libraries

Next, we import the essential libraries for genomic data augmentation, including specialized tools for sequence analysis and quality assessment.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
from sklearn.metrics import jaccard_score
import json
from tqdm import tqdm
from pathlib import Path
import warnings

from omnigenbench import (
    OmniModelForAugmentation,
    ModelHub,
)

# Environment validation
print("=" * 60)
print("Environment Validation")
print("=" * 60)
print(f"Python version: {__import__('sys').version.split()[0]}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("Note: Running on CPU. Augmentation will be slower but functional.")
print("=" * 60)

# Set up plotting style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Verify data directory exists
data_dir = Path("toy_datasets")
if not data_dir.exists():
    raise FileNotFoundError(
        f"Data directory '{data_dir}' not found. "
        f"Please ensure you're running this notebook from the correct directory."
    )

### 1.3: Understanding Augmentation Parameters

Before we start augmentation, let's understand the key parameters that control the quality and diversity of generated sequences:

#### Critical Parameters
- **noise_ratio**: Proportion of tokens to mask and predict (0.15-0.25 is a common starting range)
  - Higher values = more variation but risk losing biological patterns
  - Lower values = safer but less diversity
- **instance_num**: Number of variants per original sequence
  - Each variant uses a different random masking pattern
- **max_length**: Maximum sequence length for processing
  - Longer sequences require more GPU memory
- **batch_size**: Number of sequences per MLM forward pass
  - Reduce it if you run out of GPU memory
- **model_name**: Choice of pre-trained genomic foundation model
  - Larger models (e.g., 186M parameters) may produce higher-quality variants but run slower
  - Smaller models (e.g., the 52M default) are faster but may be less accurate
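The interaction between these parameters can be estimated before running the model. A minimal sketch of that arithmetic, assuming one token per nucleotide, truncation at `max_length`, and no special tokens; these are simplifying assumptions rather than documented library behavior, and whether the library masks exactly this many positions or samples positions independently is not stated here. `augmentation_budget` is a hypothetical helper:

```python
def augmentation_budget(seq_lengths, noise_ratio=0.20, instance_num=3, max_length=512):
    """Estimate masked positions per variant and the total output size.

    Assumes one token per nucleotide, truncation at max_length, and ignores
    special tokens. These are simplifying assumptions, not library guarantees.
    """
    effective = [min(n, max_length) for n in seq_lengths]
    return {
        "total_variants": len(seq_lengths) * instance_num,
        "masked_per_variant": [int(n * noise_ratio) for n in effective],
        "truncated_sequences": sum(n > max_length for n in seq_lengths),
    }


print(augmentation_budget([100, 300, 800]))
```

Here the 800 nt sequence is counted at the 512-token limit, so roughly 102 positions would be masked per variant, and three inputs with `instance_num=3` yield 9 variants.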

**Design Philosophy:**
- **Reproducibility**: Set random seeds for deterministic results
- **Biological Validity**: Augmentation aims to preserve sequence properties, and Step 4 checks how well it does
- **Efficiency**: Batch processing for speed
- **Transparency**: All parameters are explicit and documented

In [None]:
# Single Source of Truth (SSoT) for all augmentation configuration
# Modify these values to customize augmentation behavior

AUGMENTATION_CONFIG = {
    # Model configuration
    "model_name": "yangheng/OmniGenome-52M",
    
    # Augmentation hyperparameters
    "noise_ratio": 0.20,         # Mask 20% of tokens per sequence
    "max_length": 512,           # Maximum sequence length (adjust based on your data)
    "instance_num": 3,           # Generate 3 variants per original sequence
    "batch_size": 8,             # Batch size for MLM forward pass (reduce if OOM)
    
    # Reproducibility
    "seed": 42,                  # Random seed for reproducible results
    
    # File paths (relative to notebook directory)
    "data_dir": "toy_datasets",
    "train_file": "train.json",
    "test_file": "test.json",
    "output_file": "augmented_sequences.json",
}

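# Set random seeds for reproducibility. Python's built-in `random` is seeded
# as well, in case the masking step draws positions from it.
import random
random.seed(AUGMENTATION_CONFIG["seed"])
torch.manual_seed(AUGMENTATION_CONFIG["seed"])
np.random.seed(AUGMENTATION_CONFIG["seed"])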

# Display configuration
print("üéØ Genomic Data Augmentation Configuration (SSoT):")
print("=" * 60)
for key, value in AUGMENTATION_CONFIG.items():
    print(f"  {key:20s}: {value}")
print("=" * 60)
print(f"\nExpected output per sequence: {AUGMENTATION_CONFIG['instance_num']} variants")
print(f"Approximate masking per variant: ~{int(AUGMENTATION_CONFIG['noise_ratio'] * 100)} nucleotides per 100nt")
print(f"Model parameter count: ~52 million (efficient for CPU/GPU)")

## 🚀 Step 2: Model Initialization

Now let's initialize the genomic augmentation model. The `OmniModelForAugmentation` leverages pre-trained genomic foundation models to generate biologically-informed sequence variants.

### Augmentation Model Features
- **Random Masking**: A `noise_ratio` fraction of positions is masked, with a new random pattern for each variant
- **Contextual Prediction**: Uses surrounding sequence context for realistic substitutions  
- **Batch Processing**: Efficient handling of multiple sequences
- **Quality Control**: Outputs can be validated with the property checks in Step 4

In [None]:
# Initialize the genomic augmentation model
print("üîß Initializing Genomic Data Augmentation Model...")
print(f"Loading: {AUGMENTATION_CONFIG['model_name']}")

try:
    augmentation_model = OmniModelForAugmentation(
        config_or_model=AUGMENTATION_CONFIG["model_name"],
        noise_ratio=AUGMENTATION_CONFIG["noise_ratio"],
        max_length=AUGMENTATION_CONFIG["max_length"],
        instance_num=AUGMENTATION_CONFIG["instance_num"],
        batch_size=AUGMENTATION_CONFIG["batch_size"]
    )
    
    print("✅ Augmentation model initialized successfully!")
    print("\n🎯 Model Capabilities:")
    print("  [x] Random masking at the configured noise_ratio")
    print("  [x] Context-aware nucleotide prediction via MLM")
    print("  [x] Batch processing for computational efficiency")
    print("  [x] Unmasked positions are left unchanged")
    print(f"  [x] Configured for {AUGMENTATION_CONFIG['instance_num']} variants per sequence")
    
    # Verify model is on correct device
    device = next(augmentation_model.model.parameters()).device
    print(f"\n✓ Model device: {device}")
    # getattr() so a missing attribute does not masquerade as an init failure
    print(f"✓ AMP enabled: {getattr(augmentation_model, 'use_amp', 'unknown')}")
    
except Exception as e:
    print(f"❌ Model initialization failed: {str(e)}")
    print("\nTroubleshooting steps:")
    print("  1. Check internet connection (model downloads from HuggingFace Hub)")
    print('  2. Verify transformers version: pip install "transformers>=4.25.0"')
    print("  3. If behind a proxy, set HTTPS_PROXY; to use a Hub mirror, set HF_ENDPOINT")
    raise

## 🚀 Step 3: Sequence Augmentation

Now comes the exciting part! We'll demonstrate different approaches to genomic data augmentation, from single sequences to batch processing of entire datasets.

### Our Augmentation Strategy

We'll explore multiple augmentation scenarios:

1. **Single Sequence Augmentation**: Generate variants for individual sequences
2. **Batch Augmentation**: Process multiple sequences efficiently
3. **File-based Augmentation**: Handle large datasets from files
4. **Quality-controlled Augmentation**: Ensure biological validity of outputs
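Quality-controlled augmentation can be as simple as rejecting variants that fail basic checks. A minimal sketch with illustrative thresholds that you should tune for your data; `passes_qc` is a hypothetical helper, not an OmniGenBench API:

```python
DNA, RNA = set("ACGT"), set("ACGU")


def gc_fraction(seq):
    """Fraction of G and C in a sequence (0.0 for an empty sequence)."""
    seq = seq.upper()
    return (seq.count("G") + seq.count("C")) / len(seq) if seq else 0.0


def passes_qc(original, variant, max_gc_shift=0.05, max_len_change=0):
    """Accept a variant only if it uses the original's alphabet (DNA or RNA),
    keeps its length within max_len_change, and shifts GC content by at most
    max_gc_shift. Thresholds are illustrative assumptions."""
    alphabet = RNA if "U" in original.upper() else DNA
    if not set(variant.upper()) <= alphabet:
        return False
    if abs(len(variant) - len(original)) > max_len_change:
        return False
    return abs(gc_fraction(variant) - gc_fraction(original)) <= max_gc_shift


original = "ATGCGATCGATCGGCA"
candidates = ["ATGCGATCCATCGGCA", "ATGCGAUCGATCGGCA", "ATATAATAGATCGGCA"]
print([v for v in candidates if passes_qc(original, v)])
```

The second candidate is rejected for mixing U into a DNA sequence, and the third for dropping GC content from about 56% to about 31%.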

Let's start with single sequence augmentation to understand the process:

In [None]:
# Demonstrate single sequence augmentation with proper API usage
test_sequences = {
    "Coding sequence": "ATGAAAGCCATTGAGAAGGCAAAACCCCGATGGTCCTTCGCGAA",
    "UTR region": "AUUGAGAUGUUUGCCAUUUUGACCAUCUGACCUUUGCCAUC",
    "Regulatory motif": "TATAAGCCGCGGTGACCTGCAG",
    "Repetitive sequence": "ATCGATCGATCGATCGATCG"
}

print("üéì Demonstrating Single Sequence Augmentation")
print("=" * 70)
print("Using augment() method - generates k variants per sequence")
print("Each variant uses different random masking patterns")
print("=" * 70)

for seq_name, sequence in test_sequences.items():
    print(f"\nüìä Augmenting: {seq_name}")
    print(f"  Original ({len(sequence)}nt): {sequence}")
    
    try:
        # augment(seq, k) handles masking and MLM prediction internally and
        # returns a list of k variants; k=1 keeps this demo compact
        augmented_sequences = augmentation_model.augment(seq=sequence, k=1)
        
        # augment() returns a list, even for k=1
        if augmented_sequences and len(augmented_sequences) > 0:
            augmented_seq = augmented_sequences[0]
            print(f"  Augmented:     {augmented_seq}")
            
            # Analyze differences
            min_len = min(len(sequence), len(augmented_seq))
            differences = sum(1 for a, b in zip(sequence[:min_len], augmented_seq[:min_len]) if a != b)
            
            # Handle length differences (rare but possible)
            len_diff = abs(len(sequence) - len(augmented_seq))
            total_diff = differences + len_diff
            
            similarity = 1 - (total_diff / max(len(sequence), len(augmented_seq)))
            
            print(f"\n  üìà Analysis:")
            print(f"    Changed positions: {differences}/{min_len}")
            if len_diff > 0:
                print(f"    Length difference: {len_diff} (‚ö†Ô∏è unusual)")
            print(f"    Sequence similarity: {similarity:.1%}")
            print(f"    Effective mutation rate: {differences/min_len:.1%}")
            
            # GC content analysis
            def gc_content(seq):
                seq_upper = seq.upper()
                gc_count = seq_upper.count('G') + seq_upper.count('C')
                return gc_count / len(seq) if len(seq) > 0 else 0
            
            orig_gc = gc_content(sequence)
            aug_gc = gc_content(augmented_seq)
            gc_diff = aug_gc - orig_gc
            
            print(f"    GC content: {orig_gc:.1%} → {aug_gc:.1%} (Δ{gc_diff:+.1%})")
            
            # Biological validity check
            valid_nucs = set('ATCGUatcgu')
            invalid_orig = sum(1 for c in sequence if c not in valid_nucs)
            invalid_aug = sum(1 for c in augmented_seq if c not in valid_nucs)
            
            if invalid_orig > 0 or invalid_aug > 0:
                print(f"    ⚠️  Invalid nucleotides: Orig={invalid_orig}, Aug={invalid_aug}")
            else:
                print(f"    ✓ All nucleotides valid (ATCGU only)")
        else:
            print("  ⚠️  Warning: No augmented sequence returned")
            print("     This may indicate an issue with the model or input sequence")
            
    except Exception as e:
        print(f"  ❌ Augmentation failed: {str(e)}")
        print(f"     Error type: {type(e).__name__}")
        # Uncomment for debugging:
        # import traceback
        # traceback.print_exc()
    
    print("─" * 70)

print("\n✅ Single sequence augmentation demonstration complete!")
print("\n💡 Key Takeaways:")
print("  1. Call augment(seq, k) to generate variants")
print("  2. The method handles masking and MLM prediction internally")
print("  3. It returns a list of k augmented sequences (even for k=1)")
print("  4. GC content should stay close to the original; check the Δ values above")
print("  5. Mutation rate is driven by noise_ratio, but can be lower because the model may restore masked nucleotides")
print(f"  6. Current noise_ratio: {AUGMENTATION_CONFIG['noise_ratio']:.1%}")

### Batch Augmentation for Dataset Expansion

Now let's demonstrate batch augmentation for processing multiple sequences efficiently. This is particularly useful for expanding training datasets.
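Batching is what the `batch_size` setting controls: sequences are grouped so the model processes several per forward pass. A minimal sketch of the grouping step alone, with a hypothetical `augment_batch` callable standing in for the model call; how OmniGenBench batches internally is not shown here. On Python 3.12+, `itertools.batched` provides the same chunking (yielding tuples):

```python
from typing import Callable, Iterator, List


def batched(items: List[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive chunks of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def augment_in_batches(sequences: List[str], batch_size: int,
                       augment_batch: Callable[[List[str]], List[str]]) -> List[str]:
    """Run a batch-level augmentation function over all sequences, in order."""
    results: List[str] = []
    for batch in batched(sequences, batch_size):
        results.extend(augment_batch(batch))
    return results


# Stand-in for a model call: lower-cases each sequence so the flow is visible.
demo = augment_in_batches(["ATCG", "GCTA", "CGAT"], batch_size=2,
                          augment_batch=lambda b: [s.lower() for s in b])
print(demo)
```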

In [None]:
# Demonstrate file-based batch augmentation
input_file = Path(AUGMENTATION_CONFIG["data_dir"]) / AUGMENTATION_CONFIG["train_file"]
output_file = Path(AUGMENTATION_CONFIG["data_dir"]) / AUGMENTATION_CONFIG["output_file"]

print("üèóÔ∏è Demonstrating File-Based Batch Augmentation")
print("=" * 70)
print(f"üìÇ Input:  {input_file}")
print(f"üìÇ Output: {output_file}")
print("=" * 70)

# Validation: Check input file exists
if not input_file.exists():
    raise FileNotFoundError(f"Input file not found: {input_file}")

try:
    # Load original dataset to understand structure
    print("\n1️⃣ Loading original dataset...")
    with open(input_file, 'r') as f:
        original_data = [json.loads(line.strip()) for line in f if line.strip()]
    
    print(f"   ✓ Loaded {len(original_data)} sequences")
    
    # Validate data format
    required_keys = {"seq"}  # Must have "seq" key
    if original_data:
        sample = original_data[0]
        if not required_keys.issubset(sample.keys()):
            raise ValueError(
                f"Invalid JSON format. Expected keys: {required_keys}, "
                f"Found: {set(sample.keys())}"
            )
        print(f"   ✓ Data format validated")
        
        # Show sample
        sample_seq = sample["seq"]
        print(f"\n   📝 Sample original sequence:")
        print(f"      seq: {sample_seq[:60]}..." if len(sample_seq) > 60 else f"      seq: {sample_seq}")
        if "label" in sample:
            # str() first so non-string labels (ints, lists) can be truncated safely
            label_str = str(sample["label"])
            print(f"      label: {label_str[:60]}..." if len(label_str) > 60 else f"      label: {label_str}")
    
    # Perform augmentation
    print(f"\n2️⃣ Starting batch augmentation...")
    print(f"   - Input sequences: {len(original_data)}")
    print(f"   - Variants per sequence: {AUGMENTATION_CONFIG['instance_num']}")
    print(f"   - Expected output: {len(original_data) * AUGMENTATION_CONFIG['instance_num']} sequences")
    print(f"   - Batch size: {AUGMENTATION_CONFIG['batch_size']} (for MLM forward pass)")
    
    # Call the augmentation method (with progress bar from tqdm)
    augmentation_model.augment_from_file(
        input_file=str(input_file),
        output_file=str(output_file)
    )
    
    # Verify output
    print(f"\n3️⃣ Verifying output...")
    with open(output_file, 'r') as f:
        augmented_data = [json.loads(line.strip()) for line in f if line.strip()]
    
    print(f"   ✓ Output file created: {output_file}")
    print(f"   ✓ Augmented sequences written: {len(augmented_data)}")
    
    # Results summary
    print(f"\n✅ Augmentation Completed Successfully!")
    print("=" * 70)
    print(f"📊 Results Summary:")
    print(f"   Original sequences:    {len(original_data):>6}")
    print(f"   Augmented sequences:   {len(augmented_data):>6}")
    print(f"   Expansion ratio:       {len(augmented_data)/len(original_data):>6.1f}x")
    print(f"   Output file size:      {output_file.stat().st_size / 1024:>6.1f} KB")
    print("=" * 70)
    
    # Show sample augmented sequence
    if augmented_data:
        sample_aug = augmented_data[0]
        print(f"\nüìù Sample augmented sequence:")
        aug_seq = sample_aug.get("aug_seq", sample_aug.get("seq", ""))
        print(f"   aug_seq: {aug_seq[:60]}..." if len(aug_seq) > 60 else f"   aug_seq: {aug_seq}")
    
    # Quick quality check
    print(f"\nüî¨ Quick Quality Check:")
    aug_lengths = [len(item.get("aug_seq", item.get("seq", ""))) for item in augmented_data[:100]]
    orig_lengths = [len(item["seq"]) for item in original_data[:100]]
    
    print(f"   Original avg length: {np.mean(orig_lengths):.1f} ¬± {np.std(orig_lengths):.1f} nt")
    print(f"   Augmented avg length: {np.mean(aug_lengths):.1f} ¬± {np.std(aug_lengths):.1f} nt")
    print(f"   Length preservation: {'‚úì Good' if abs(np.mean(aug_lengths) - np.mean(orig_lengths)) < 5 else '‚ö†Ô∏è Check'}")
    
except FileNotFoundError as e:
    print(f"❌ File not found: {str(e)}")
    print("   Ensure toy_datasets directory exists with train.json")
except json.JSONDecodeError as e:
    print(f"❌ JSON parsing error: {str(e)}")
    print("   Check input file format - expecting one JSON object per line")
except Exception as e:
    print(f"❌ Batch augmentation failed: {str(e)}")
    print(f"   Error type: {type(e).__name__}")
    import traceback
    traceback.print_exc()
    
print("\nüí° Usage Notes:")
print("  - augment_from_file() expects JSON with 'seq' key")
print("  - Output uses 'aug_seq' key to distinguish from originals")
print("  - Original labels/metadata are NOT preserved automatically")
print("  - For training, you may need to merge original + augmented data")
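Since labels are not carried over automatically by the file-based route, one way to keep them is to augment record by record and copy each record's metadata onto its variants. A minimal sketch that relies only on the `augment(seq=..., k=...)` call demonstrated earlier; `augment_records` is a hypothetical helper, and a dummy function replaces the model so the example runs offline:

```python
from typing import Callable, Dict, List


def augment_records(records: List[Dict], augment: Callable[..., List[str]],
                    k: int = 3, keep_original: bool = True) -> List[Dict]:
    """Return the originals (optionally) plus k variants per record.

    Each variant is a copy of its source record with "seq" replaced, so labels
    and other metadata travel with it. An "is_augmented" flag is added.
    """
    merged: List[Dict] = []
    for record in records:
        if keep_original:
            merged.append(dict(record, is_augmented=False))
        for variant in augment(seq=record["seq"], k=k):
            merged.append(dict(record, seq=variant, is_augmented=True))
    return merged


# Offline stand-in for augmentation_model.augment (reverses the sequence).
def fake_augment(seq, k):
    return [seq[::-1]] * k


records = [{"seq": "ATCG", "label": 1}, {"seq": "GGCA", "label": 0}]
merged = augment_records(records, fake_augment, k=2)
print(len(merged), merged[1])
```

In the notebook, `augment_records(original_data, augmentation_model.augment, k=AUGMENTATION_CONFIG["instance_num"])` would produce a label-preserving training set, assuming `augment` accepts the `seq` and `k` keywords as in Step 3. The result can be written back out in the same one-object-per-line JSON format as the input.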

## 🔮 Step 4: Quality Assessment and Analysis

The final step involves comprehensive analysis of our augmented sequences to ensure they maintain biological validity while providing useful diversity for training.

### Quality Assessment Pipeline

Our assessment includes:
1. **Sequence Diversity Analysis**: Measure how different augmented sequences are from originals
2. **Biological Property Conservation**: Check if important sequence characteristics are preserved
3. **Statistical Validation**: Ensure augmented sequences follow expected genomic patterns
4. **Functional Motif Preservation**: Verify that critical sequence elements remain intact
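The assessment code that follows covers length, composition, and diversity. For motif preservation, one simple approach is to find each motif's positions in the original and check whether the variant still contains the motif at those exact positions. A minimal sketch; the motifs below are illustrative, so substitute the ones relevant to your data:

```python
import re


def motif_positions(seq, motif):
    """Start positions of all (possibly overlapping) motif matches."""
    pattern = re.compile(f"(?=({re.escape(motif)}))")  # lookahead allows overlaps
    return [m.start() for m in pattern.finditer(seq.upper())]


def motif_retention(original, variant, motifs):
    """For each motif present in the original, the fraction of its occurrences
    still present, unchanged, at the same position in the variant."""
    report = {}
    variant = variant.upper()
    for motif in motifs:
        motif = motif.upper()
        hits = motif_positions(original, motif)
        if not hits:
            continue  # motif absent from the original; nothing to preserve
        kept = sum(variant[p:p + len(motif)] == motif for p in hits)
        report[motif] = kept / len(hits)
    return report


original = "GGTATAAAGGCCGCGGTATAAAC"
variant = "GGTATAAAGGCCGCAGTATCAAC"
print(motif_retention(original, variant, ["TATAAA", "CCGCGG", "GAATTC"]))
```

Here one of the two TATAAA copies survives (retention 0.5), the CCGCGG site is disrupted (0.0), and GAATTC is skipped because it never occurs in the original.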

In [None]:
# Comprehensive quality assessment of augmented sequences
def analyze_sequence_properties(sequences, labels=None):
    """Analyze statistical properties of genomic sequences"""
    
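    if not sequences:
        return {
            'num_sequences': 0,
            'avg_length': 0,
            'length_std': 0,
            'gc_content': [],
            'avg_gc_content': 0,
            'gc_content_std': 0,
            # Same per-nucleotide structure as the non-empty case, so callers
            # can index ['nucleotide_composition'][nuc]['mean'] safely
            'nucleotide_composition': {nuc: {'mean': 0, 'std': 0} for nuc in 'ATGCU'}
        }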
    
    analysis = {
        'num_sequences': len(sequences),
        'avg_length': np.mean([len(seq) for seq in sequences]),
        'length_std': np.std([len(seq) for seq in sequences]),
        'gc_content': [],
        'nucleotide_composition': {'A': [], 'T': [], 'G': [], 'C': [], 'U': []},
    }
    
    for seq in sequences:
        seq_upper = seq.upper()
        length = len(seq_upper)
        
        if length == 0:
            continue
            
        # GC content
        gc = (seq_upper.count('G') + seq_upper.count('C')) / length
        analysis['gc_content'].append(gc)
        
        # Nucleotide composition
        for nuc in ['A', 'T', 'G', 'C', 'U']:
            freq = seq_upper.count(nuc) / length
            analysis['nucleotide_composition'][nuc].append(freq)
    
    # Convert to means and stds
    if analysis['gc_content']:
        analysis['avg_gc_content'] = np.mean(analysis['gc_content'])
        analysis['gc_content_std'] = np.std(analysis['gc_content'])
    else:
        analysis['avg_gc_content'] = 0
        analysis['gc_content_std'] = 0
    
    for nuc in analysis['nucleotide_composition']:
        freqs = analysis['nucleotide_composition'][nuc]
        if freqs:
            analysis['nucleotide_composition'][nuc] = {
                'mean': np.mean(freqs),
                'std': np.std(freqs)
            }
        else:
            analysis['nucleotide_composition'][nuc] = {'mean': 0, 'std': 0}
    
    return analysis

print("üî¨ Performing comprehensive quality assessment...")
print("=" * 70)

# Load both original and augmented data for comparison
input_file_path = Path(AUGMENTATION_CONFIG["data_dir"]) / AUGMENTATION_CONFIG["train_file"]
output_file_path = Path(AUGMENTATION_CONFIG["data_dir"]) / AUGMENTATION_CONFIG["output_file"]

try:
    # Original sequences
    print("\n1️⃣ Loading original sequences...")
    with open(input_file_path, 'r') as f:
        original_data = [json.loads(line.strip()) for line in f if line.strip()]
    original_sequences = [item.get('seq', '') for item in original_data]
    print(f"   ✓ Loaded {len(original_sequences)} original sequences")
    
    # Augmented sequences
    print("\n2️⃣ Loading augmented sequences...")
    if output_file_path.exists():
        with open(output_file_path, 'r') as f:
            augmented_data = [json.loads(line.strip()) for line in f if line.strip()]
        augmented_sequences = [item.get('aug_seq', item.get('seq', '')) for item in augmented_data]
        print(f"   ✓ Loaded {len(augmented_sequences)} augmented sequences")
    else:
        print(f"   ⚠️ Augmented file not found: {output_file_path}")
        print("   Generating sample augmented sequences for demonstration...")
        augmented_sequences = []
        for seq in original_sequences[:5]:  # Augment first 5 sequences
            try:
                aug_seqs = augmentation_model.augment(seq, k=1)
                if aug_seqs:
                    augmented_sequences.extend(aug_seqs)
            except Exception as e:
                print(f"   ⚠️ Failed to augment sequence: {e}")
                augmented_sequences.append(seq)  # Fallback to original
    
    # Analyze both datasets
    print("\n3️⃣ Analyzing sequence properties...")
    original_analysis = analyze_sequence_properties(original_sequences)
    augmented_analysis = analyze_sequence_properties(augmented_sequences)
    
    # Comparative analysis
    print("\n" + "=" * 70)
    print("üéØ QUALITY ASSESSMENT RESULTS")
    print("=" * 70)
    
    print("\nüìà Dataset Size Comparison:")
    print(f"  Original sequences:    {original_analysis['num_sequences']:>6}")
    print(f"  Augmented sequences:   {augmented_analysis['num_sequences']:>6}")
    if original_analysis['num_sequences'] > 0:
        expansion_ratio = augmented_analysis['num_sequences'] / original_analysis['num_sequences']
        print(f"  Dataset expansion:     {expansion_ratio:>6.1f}x")
        print(f"  Expected expansion:    {AUGMENTATION_CONFIG['instance_num']:>6}x (from config)")
    
    print("\nüìè Sequence Length Statistics:")
    print(f"  Original:   {original_analysis['avg_length']:>6.1f} ¬± {original_analysis['length_std']:.1f} nt")
    print(f"  Augmented:  {augmented_analysis['avg_length']:>6.1f} ¬± {augmented_analysis['length_std']:.1f} nt")
    len_preservation = abs(original_analysis['avg_length'] - augmented_analysis['avg_length'])
    status = '‚úÖ Excellent' if len_preservation < 1.0 else '‚úì Good' if len_preservation < 5.0 else '‚ö†Ô∏è Check'
    print(f"  Preservation: {status} (Œî={len_preservation:.1f}nt)")
    
    print("\nüß¨ GC Content Analysis:")
    print(f"  Original:   {original_analysis['avg_gc_content']:>6.1%} ¬± {original_analysis['gc_content_std']:.1%}")
    print(f"  Augmented:  {augmented_analysis['avg_gc_content']:>6.1%} ¬± {augmented_analysis['gc_content_std']:.1%}")
    gc_diff = abs(original_analysis['avg_gc_content'] - augmented_analysis['avg_gc_content'])
    gc_status = '‚úÖ Excellent' if gc_diff < 0.02 else '‚úì Good' if gc_diff < 0.05 else '‚ö†Ô∏è Check' if gc_diff < 0.1 else '‚ùå Large'
    print(f"  Preservation: {gc_status} (Œî={gc_diff:.1%})")
    
    print("\nüî§ Nucleotide Composition Comparison:")
    print("  Nucleotide  Original    Augmented   Difference  Status")
    print("  " + "‚îÄ" * 60)
    for nuc in ['A', 'T', 'U', 'G', 'C']:
        orig_freq = original_analysis['nucleotide_composition'][nuc]['mean']
        aug_freq = augmented_analysis['nucleotide_composition'][nuc]['mean']
        diff = aug_freq - orig_freq
        if orig_freq > 0.01 or aug_freq > 0.01:  # Only show significant nucleotides
            status = '‚úÖ' if abs(diff) < 0.02 else '‚úì' if abs(diff) < 0.05 else '‚ö†Ô∏è' if abs(diff) < 0.10 else '‚ùå'
            print(f"      {nuc}       {orig_freq:>6.1%}      {aug_freq:>6.1%}      {diff:>+6.1%}     {status}")
    
    # Sequence diversity analysis
    if len(original_sequences) > 0 and len(augmented_sequences) > 0:
        print("\nüé≤ Sequence Diversity Assessment:")
        
        # Sample sequences for comparison
        sample_size = min(10, len(original_sequences), len(augmented_sequences))
        orig_sample = original_sequences[:sample_size]
        aug_sample = augmented_sequences[:sample_size]
        
        # Calculate pairwise similarities within each set
        def pairwise_similarity(sequences):
            similarities = []
            for i in range(len(sequences)):
                for j in range(i+1, len(sequences)):
                    seq1, seq2 = sequences[i], sequences[j]
                    min_len = min(len(seq1), len(seq2))
                    if min_len > 0:
                        matches = sum(1 for a, b in zip(seq1[:min_len], seq2[:min_len]) if a == b)
                        similarity = matches / min_len
                        similarities.append(similarity)
            return np.mean(similarities) if similarities else 0
        
        orig_diversity = 1 - pairwise_similarity(orig_sample)
        aug_diversity = 1 - pairwise_similarity(aug_sample)
        
        print(f"  Original set diversity:   {orig_diversity:>6.1%}")
        print(f"  Augmented set diversity:  {aug_diversity:>6.1%}")
        
        if aug_diversity > orig_diversity * 0.95:
            print("  ✅ Augmentation maintained or increased diversity")
        elif aug_diversity > orig_diversity * 0.8:
            print("  ✓ Augmentation maintained reasonable diversity")
        else:
            print("  ⚠️ Augmentation may have reduced diversity")
    
    print("\n" + "=" * 70)

except FileNotFoundError as e:
    print(f"❌ File not found: {str(e)}")
    print("   Ensure the data directory and files exist.")
except Exception as e:
    print(f"❌ Quality assessment failed: {str(e)}")
    print(f"   Error type: {type(e).__name__}")

print(f"\n🎉 Quality assessment completed!")
print("\n🚀 Your augmented dataset can now be used for:")
print("  • Training data expansion and class balancing")
print("  • Model robustness improvement and regularization")
print("  • Cross-validation and generalization testing")
print("  • Domain adaptation and transfer learning")
print("  • Rare variant analysis and representation")

## 🎉 Tutorial Summary and Next Steps

Congratulations! You have successfully completed this comprehensive tutorial on genomic data augmentation with OmniGenBench.

### What You've Learned

You've walked through a complete, end-to-end workflow for intelligent genomic data augmentation. Specifically, you have:

1. **Understood the "Why"**: Gained appreciation for the importance of data augmentation in genomic machine learning and how intelligent augmentation preserves biological patterns while increasing diversity.

2. **Mastered the 4-Step Workflow**:
   - **Step 1: Setup and Configuration**: You learned how to configure augmentation parameters and understand their impact on sequence generation quality.
   - **Step 2: Model Initialization**: You saw how to leverage pre-trained genomic foundation models for context-aware sequence augmentation.
   - **Step 3: Sequence Augmentation**: You implemented both single sequence and batch augmentation strategies for different use cases.
   - **Step 4: Quality Assessment**: You performed comprehensive analysis to validate the biological validity and diversity of augmented sequences.

3. **Advanced Capabilities**: You explored:
   - Random masking controlled by `noise_ratio`, which leaves unmasked positions untouched
   - Context-aware nucleotide prediction for realistic variations
   - Batch processing for efficient dataset expansion
   - Quality control metrics for validating augmented sequences
   - Statistical analysis of sequence properties and diversity

### Next Steps and Applications

Your augmented datasets can now be applied to:
- **Training Data Enhancement**: Expand small or imbalanced genomic datasets
- **Model Robustness**: Improve generalization through increased sequence diversity
- **Rare Variant Analysis**: Generate synthetic examples of underrepresented patterns
- **Cross-Domain Learning**: Bridge gaps between different genomic contexts
- **Validation Studies**: Create test sets for evaluating model robustness

### Best Practices for Genomic Augmentation

1. **Parameter Tuning**: Start with noise_ratio=0.15-0.25 and adjust based on your specific application
2. **Quality Validation**: Always assess biological property conservation in augmented sequences
3. **Diversity Balance**: Ensure augmentation increases diversity without breaking biological constraints
4. **Domain Specificity**: Consider the specific genomic context (coding, regulatory, etc.) when setting parameters
5. **Iterative Refinement**: Use validation metrics to fine-tune augmentation strategies
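Quality validation and diversity balance can be enforced with a post-hoc filter over the generated variants. The sketch below uses GC content and per-position identity as simple stand-ins for "biological property conservation", and it assumes position-aligned variants. `filter_variants` and its default thresholds are hypothetical and should be tuned per application.

```python
from typing import List

VALID_BASES = set("ACGTU")


def gc_content(seq: str) -> float:
    """Fraction of G and C bases (0.0 for an empty sequence)."""
    seq = seq.upper()
    return (seq.count("G") + seq.count("C")) / len(seq) if seq else 0.0


def filter_variants(original: str, variants: List[str],
                    max_gc_shift: float = 0.05, min_identity: float = 0.70) -> List[str]:
    """Drop variants that are invalid, duplicated, or drift too far from the source.

    The default thresholds are illustrative placeholders, not recommended values.
    """
    ref = original.upper()
    if not ref:
        return []
    ref_gc = gc_content(ref)
    seen = {ref}  # an unchanged copy of the original counts as a duplicate
    kept = []
    for variant in variants:
        v = variant.upper()
        if v in seen or len(v) != len(ref) or not set(v) <= VALID_BASES:
            continue  # duplicate, length change, or non-nucleotide character
        ident = sum(a == b for a, b in zip(ref, v)) / len(ref)
        if ident < min_identity or abs(gc_content(v) - ref_gc) > max_gc_shift:
            continue  # too divergent from the original
        seen.add(v)
        kept.append(variant)
    return kept
```

Removing exact duplicates and unchanged copies keeps the diversity gain real. The identity and GC bounds keep each variant close to its source.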

### Further Learning

Explore our other tutorials to expand your genomic AI toolkit:
- **[mRNA Degradation Prediction](../mRNA_degrad_rate_regression/)**: Apply augmented data to stability prediction
- **[RNA Secondary Structure Prediction](../rna_secondary_structure_prediction/)**: Use augmentation for structure modeling
- **[Translation Efficiency Prediction](../translation_efficiency_prediction/)**: Enhance training data for efficiency prediction

Thank you for following along. We hope this tutorial has provided you with the knowledge and tools to effectively augment genomic datasets for your machine learning research. Intelligent data augmentation is a powerful technique for advancing genomic AI!

**Happy augmenting and discovering! 🧬✨**

### Advanced Usage: Understanding the Augmentation API

The `OmniModelForAugmentation` provides four levels of API for different use cases:

#### High-Level API (Recommended)
```python
# Generate k augmented variants (automatic masking + prediction)
augmented_seqs = augmentation_model.augment(seq="AUGCGAUCGAG", k=3)
# Returns: list of 3 augmented sequences
```

#### Batch Processing API
```python
# Process multiple sequences efficiently
sequences = ["ATCGATCG", "GCTAGCTA", "CGATCGAT"]
all_augmented = augmentation_model.augment_sequences(sequences)
# Returns: list of len(sequences) * k augmented sequences (k = configured instance_num)
```
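The batch API returns one flat list, so it is often convenient to regroup the output by source sequence. This sketch assumes, without confirmation from the library, that the variants for each input are returned consecutively and in input order. `group_variants` is a hypothetical helper, and `k` must equal the number of variants generated per input.

```python
from typing import List, Sequence, Tuple


def group_variants(inputs: Sequence[str], flat_variants: Sequence[str],
                   k: int) -> List[Tuple[str, List[str]]]:
    """Pair each input sequence with its k augmented variants.

    Assumption: the variants for inputs[i] occupy flat_variants[i*k:(i+1)*k].
    """
    if len(flat_variants) != len(inputs) * k:
        raise ValueError(f"expected {len(inputs) * k} variants, got {len(flat_variants)}")
    return [(seq, list(flat_variants[i * k:(i + 1) * k])) for i, seq in enumerate(inputs)]
```

With the batch example above, `group_variants(sequences, all_augmented, k)` would pair each of the three inputs with its own variants.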

#### File-Based API (For Large Datasets)
```python
# Process entire datasets from JSON files
augmentation_model.augment_from_file(
    input_file="train.json",    # Each line: {"seq": "ATCG..."}
    output_file="augmented.json"  # Each line: {"aug_seq": "ATCG..."}
)
```

#### Low-Level API (Advanced Users Only)
```python
# Step 1: Manually apply masking
masked_seq = augmentation_model.apply_noise_to_sequence("AUGCGAUCGAG")
print(f"Masked: {masked_seq}")  # e.g., "AUG[MASK]GA[MASK]CGAG"

# Step 2: Predict masked positions (expects already-masked sequence)
augmented_seq = augmentation_model.augment_sequence(masked_seq)
print(f"Predicted: {augmented_seq}")
```

**⚠️ Important API Notes:**
- `augment(seq, k)` - **Recommended** high-level API (auto-masks + predicts)
- `augment_sequences(seqs)` - Batch processing with automatic buffering
- `augment_from_file(input, output)` - File-based processing with progress bar
- `apply_noise_to_sequence(seq)` - Low-level: masking only
- `augment_sequence(masked_seq)` - Low-level: prediction only (expects pre-masked input)

**Expected Input/Output Formats:**

Input JSON (one object per line):
```json
{"seq": "ATCTTGCATTGAAG"}
{"seq": "GGTTTACAGTCCAA"}
```

Output JSON (one object per line):
```json
{"aug_seq": "ATCTTGCATAGAAG"}
{"aug_seq": "GGTTTACAATCCAA"}
```
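You can produce the input file and read the output back using only the standard library. This is a minimal sketch. `to_input_lines` and `parse_augmented_lines` are hypothetical helpers, and they assume exactly the one-object-per-line formats shown above.

```python
import json
from typing import Iterable, Iterator, List


def to_input_lines(seqs: Iterable[str]) -> Iterator[str]:
    """Yield one {"seq": ...} JSON object per line, matching the input format."""
    for seq in seqs:
        yield json.dumps({"seq": seq}) + "\n"


def parse_augmented_lines(lines: Iterable[str]) -> List[str]:
    """Collect the "aug_seq" value from each non-blank output line."""
    return [json.loads(line)["aug_seq"] for line in lines if line.strip()]


# Usage with real files (paths are placeholders):
# with open("train.json", "w", encoding="utf-8") as fh:
#     fh.writelines(to_input_lines(["ATCTTGCATTGAAG", "GGTTTACAGTCCAA"]))
# with open("augmented.json", encoding="utf-8") as fh:
#     augmented = parse_augmented_lines(fh)
```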

**Note:** The low-level API (`apply_noise_to_sequence` + `augment_sequence`) is for advanced users who need custom masking strategies. Most users should use the high-level `augment(seq, k)` method.
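As one example of a custom masking strategy, the sketch below protects a motif from masking. Here the motif is a start codon. It then masks roughly `noise_ratio` of the remaining positions. This is a hypothetical illustration with two assumptions. First, it assumes the model's mask token is the literal string `[MASK]` seen in the low-level example. Second, it has not been verified that `augment_sequence` accepts sequences masked this way.

```python
import random
from typing import Iterable, Optional, Set

MASK_TOKEN = "[MASK]"  # assumption: same mask string as in the low-level example


def motif_positions(seq: str, motif: str) -> Set[int]:
    """Indices covered by every (possibly overlapping) occurrence of motif."""
    positions: Set[int] = set()
    start = seq.find(motif)
    while start != -1:
        positions.update(range(start, start + len(motif)))
        start = seq.find(motif, start + 1)
    return positions


def mask_with_protection(seq: str, noise_ratio: float = 0.15,
                         protected: Iterable[int] = (), seed: Optional[int] = None) -> str:
    """Mask about noise_ratio of the positions in seq, never touching protected ones."""
    rng = random.Random(seed)
    blocked = set(protected)
    candidates = [i for i in range(len(seq)) if i not in blocked]
    n_mask = min(len(candidates), round(len(seq) * noise_ratio))
    masked = set(rng.sample(candidates, n_mask))
    return "".join(MASK_TOKEN if i in masked else base for i, base in enumerate(seq))
```

You would then pass the masked string to `augmentation_model.augment_sequence(masked)` in place of the output of `apply_noise_to_sequence`.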

### Configuration Parameter Reference

**Key hyperparameters and their effects:**

| Parameter | Type | Default | Range | Effect |
|-----------|------|---------|-------|--------|
| `noise_ratio` | float | 0.15 | 0.05-0.30 | Proportion of tokens masked per sequence |
| `max_length` | int | 1026 | 50-2048 | Maximum sequence length (longer = more GPU memory) |
| `instance_num` | int | 1 | 1-10 | Variants generated per input sequence |
| `batch_size` | int | 32 | 1-128 | MLM forward pass batch size (reduce if OOM) |
| `use_amp` | bool | auto | - | Automatic mixed precision (auto-detects CUDA) |

**Tuning Guidelines:**
- **For diversity**: Increase `noise_ratio` to 0.25-0.30
- **For safety**: Decrease `noise_ratio` to 0.10-0.15  
- **For speed**: Increase `batch_size` (if GPU memory allows)
- **For rare variants**: Increase `instance_num` to 5-10

### Integrating Augmentation into Training Pipelines

**Option 1: Pre-augment dataset (recommended for static datasets)**
```python
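from omnigenbench import OmniDatasetForSequenceClassification

# Augment once, save to disk, reuse across training runs
augmentation_model.augment_from_file("train.json", "train_augmented.json")

# Later in training code. Note: augment_from_file writes only {"aug_seq": ...}
# records (see the output format above), so labels must be re-attached to
# train_augmented.json before it can serve as a classification dataset.
train_dataset = OmniDatasetForSequenceClassification(
    data_file="train_augmented.json",
    tokenizer=tokenizer  # tokenizer of the model being fine-tuned
)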
```
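`augment_from_file` writes only `aug_seq` records, so labels have to be carried over from the input file before the output can train a classifier. This is a minimal sketch with two assumptions. It assumes the input lines carry `seq` and `label` keys. It also assumes the output holds `k` consecutive variants per input line, in input order; this is an assumption about `augment_from_file`, not a documented guarantee. `attach_labels` is a hypothetical helper.

```python
import json
from typing import Dict, Iterable, List


def attach_labels(input_lines: Iterable[str], augmented_lines: Iterable[str],
                  k: int = 1) -> List[Dict]:
    """Give each augmented sequence the label of the input record it came from.

    Assumptions: inputs carry "seq" and "label" keys, and the output file holds
    k consecutive variants per input line, in input order.
    """
    sources = [json.loads(line) for line in input_lines if line.strip()]
    variants = [json.loads(line)["aug_seq"] for line in augmented_lines if line.strip()]
    if len(variants) != len(sources) * k:
        raise ValueError(f"expected {len(sources) * k} variants, got {len(variants)}")
    return [{"seq": v, "label": sources[i // k]["label"]} for i, v in enumerate(variants)]
```

To produce a labeled training file, write each returned record back with `json.dumps(record)`, one per line.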

**Option 2: On-the-fly augmentation (recommended for large datasets)**
```python
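import numpy as np
import torch


class AugmentedDataset(torch.utils.data.Dataset):
    """Returns an augmented variant of a sample with probability aug_ratio."""

    def __init__(self, base_data, augmentor, aug_ratio=0.5):
        self.data = base_data       # list of {"seq": ..., "label": ...} records
        self.augmentor = augmentor  # e.g. augmentation_model
        self.aug_ratio = aug_ratio  # 0.5 -> about half the samples are augmented per epoch

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if np.random.rand() < self.aug_ratio:
            # Model inference runs here; keep num_workers=0 if the augmentor is on
            # a GPU, since CUDA cannot be re-initialized in forked DataLoader workers.
            aug_seq = self.augmentor.augment(sample["seq"], k=1)[0]
            return {"seq": aug_seq, "label": sample["label"]}
        return sample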
```

**Option 3: Hybrid approach (balanced augmentation)**
```python
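import json

# Keep originals + add augmented variants
k = 3  # must equal the instance_num the augmentation model was configured with

with open("train.json", encoding="utf-8") as fh:  # one {"seq": ..., "label": ...} per line
    original_data = [json.loads(line) for line in fh if line.strip()]

augmented_seqs = augmentation_model.augment_sequences([d["seq"] for d in original_data])

# Assumes the k variants of each input are returned consecutively, in input order
combined_data = original_data + [
    {"seq": aug_seq, "label": original_data[i // k]["label"]}
    for i, aug_seq in enumerate(augmented_seqs)
]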
```

## üìö References and Further Reading

### Academic Foundation
1. **Masked Language Modeling**: Devlin et al. (2018). "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
2. **Genomic Foundation Models**: Ji et al. (2021). "DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome"
3. **Data Augmentation Theory**: Shorten & Khoshgoftaar (2019). "A survey on Image Data Augmentation for Deep Learning"

### OmniGenBench Documentation
- **API Reference**: [omnigenbench.readthedocs.io](https://omnigenbench.readthedocs.io)
- **Model Hub**: [huggingface.co/yangheng](https://huggingface.co/yangheng)
- **GitHub Repository**: [github.com/yangheng95/OmniGenBench](https://github.com/yangheng95/OmniGenBench)

### Related Tutorials
- `mRNA_degrad_rate_regression/` - Applying augmented data to regression tasks
- `rna_secondary_structure_prediction/` - Structure-aware augmentation
- `translation_efficiency_prediction/` - Functional sequence augmentation

---

**Reproducibility Checklist:**
- ✅ Environment validated (Python/PyTorch/CUDA versions)
- ✅ Random seeds set for deterministic results
- ✅ Configuration centralized in a single source of truth (`AUGMENTATION_CONFIG`)
- ✅ All file paths validated before processing
- ✅ Error handling with informative messages
- ✅ Output verification with statistical checks
- ✅ API usage documented with examples
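A small helper can cover the seeding item on the checklist. This sketch seeds Python's `random`, and also NumPy and PyTorch when they are installed. `set_seed` is a hypothetical name and may not match what the tutorial's own setup cell uses. Seeding alone may not make GPU runs bit-for-bit identical; PyTorch additionally offers `torch.use_deterministic_algorithms(True)` for that.

```python
import random


def set_seed(seed: int) -> None:
    """Seed Python's RNG and, when installed, NumPy and PyTorch (CPU and CUDA)."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed
    try:
        import torch
        torch.manual_seed(seed)           # seeds the default CPU generator
        torch.cuda.manual_seed_all(seed)  # all visible GPUs; silently ignored without CUDA
    except ImportError:
        pass  # PyTorch not installed
```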

**Troubleshooting Common Issues:**
- **"CUDA out of memory"**: Reduce `batch_size` or `max_length`
- **"Model download timeout"**: Check HuggingFace Hub connectivity
- **"Invalid nucleotides"**: Verify input sequences contain only ATCGU
- **"Empty output file"**: Check input JSON format (must have "seq" key)
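You can catch the "Invalid nucleotides" and "Empty output file" problems before augmentation by scanning the input file. This is a minimal sketch. `validate_input_lines` is a hypothetical helper that applies the rules listed above: a `seq` key on every line, and only A/T/C/G/U characters (checked case-insensitively).

```python
import json
from typing import Iterable, List, Tuple

VALID_BASES = set("ACGTU")


def validate_input_lines(lines: Iterable[str]) -> List[Tuple[int, str]]:
    """Return (line_number, problem) pairs for input records that look unusable."""
    problems = []
    for n, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # blank lines are skipped
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((n, f"not valid JSON: {exc.msg}"))
            continue
        seq = record.get("seq") if isinstance(record, dict) else None
        if not isinstance(seq, str) or not seq:
            problems.append((n, 'missing or empty "seq" field'))
            continue
        bad = sorted(set(seq.upper()) - VALID_BASES)
        if bad:
            problems.append((n, f"invalid characters: {''.join(bad)}"))
    return problems
```

For a file on disk, call `validate_input_lines(fh)` inside a `with open("train.json", encoding="utf-8") as fh:` block. An empty result means every line passed these checks.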

---

© 2025 OmniGenBench Project. Licensed under the MIT License.