# Practical Guide to Genome Simulation with stdpopsim

This notebook provides a hands-on guide to using stdpopsim for genome data simulation, focusing on practical applications for imputation research.

## Learning Objectives
- Understand stdpopsim workflow and key concepts
- Generate realistic genomic datasets for different scenarios
- Analyze population genetic patterns in simulated data
- Export data for downstream imputation analysis

## Setup and Configuration

In [None]:
# Essential imports
import stdpopsim
import tskit
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import warnings
# NOTE(review): this silences *every* warning for the rest of the session,
# including any deprecation notices raised by the libraries imported above.
warnings.filterwarnings('ignore')

# Configuration
# Seed NumPy's global RNG; the simulation calls below pass their own
# explicit seed to the engine, so this only affects direct np.random use.
np.random.seed(42)
# Default figure size for every matplotlib figure created afterwards.
plt.rcParams['figure.figsize'] = (12, 8)

# Record library versions alongside the results for reproducibility.
print(f"stdpopsim version: {stdpopsim.__version__}")
print(f"tskit version: {tskit.__version__}")

## Understanding stdpopsim Components

stdpopsim is organized around four main components:
1. **Species**: Organism-specific parameters (generation time, genome structure)
2. **Demographic Models**: Population history and migration patterns
3. **Genetic Maps**: Recombination rates across the genome
4. **Simulation Engine**: The backend that runs the simulation. This notebook uses msprime (coalescent); stdpopsim also provides a SLiM (forward-time) engine

In [None]:
# Get human species and explore available models
species = stdpopsim.get_species("HomSap")
print(f"Species: {species.name} ({species.common_name})")
print(f"Generation time: {species.generation_time} years")

print("\nAvailable demographic models:")
# Use a dedicated loop variable here: `model` is the module-level name that
# the simulation functions below read, so the preview loop must not rebind it.
for rank, demo_model in enumerate(species.demographic_models[:5], start=1):  # Show first 5
    print(f"{rank}. {demo_model.id}: {demo_model.description[:80]}...")

# Select the Out-of-Africa model
model = species.get_demographic_model("OutOfAfrica_3G09")
print(f"\nSelected model: {model.id}")
print(f"Populations: {[pop.name for pop in model.populations]}")

## Chromosome Configuration

We'll use chromosome 22 with reduced length for faster simulation.

In [None]:
# Configure chromosome 22 (10% length for speed)
chrom_id = "chr22"
contig = species.get_contig(chrom_id, length_multiplier=0.1)

# NOTE(review): the documented Contig attributes do not include `chromosome`,
# so the requested chromosome ID is printed directly instead of reading
# `contig.chromosome.id` — confirm against the installed stdpopsim version.
print(f"Chromosome: {chrom_id}")
print(f"Length: {contig.length:,} bp ({contig.length/1e6:.1f} Mb)")
print(f"Mutation rate: {contig.mutation_rate:.2e} per bp per generation")
print(f"Mean recombination rate: {contig.recombination_map.mean_rate:.2e}")

## Simulation Workflow

### Step 1: Basic Single Population Simulation

In [None]:
def run_basic_simulation(population="CHB", sample_size=1000, seed=42):
    """
    Run a basic single-population simulation.

    Simulates with the msprime engine, using the module-level ``model``
    (demographic model) and ``contig`` defined in earlier cells, then prints
    a short summary of the result.

    Args:
        population: Name of the population in ``model`` to sample from.
        sample_size: Sample count requested for that population; passed to
            the engine as ``{population: sample_size}``.
            NOTE(review): recent stdpopsim releases appear to interpret this
            as diploid individuals, so ``ts.num_samples`` may be twice this
            value — confirm for the installed version.
        seed: Random seed forwarded to the engine. The default (42) keeps
            the previously hard-coded behaviour; pass other values to
            generate independent replicates.

    Returns:
        The simulated tree sequence.
    """
    print(f"Simulating {sample_size} {population} individuals...")

    # Define samples
    samples = {population: sample_size}

    # Run simulation
    engine = stdpopsim.get_engine("msprime")
    ts = engine.simulate(
        demographic_model=model,
        contig=contig,
        samples=samples,
        seed=seed,
    )

    # Basic statistics
    print("✓ Simulation complete!")
    print(f"  Samples: {ts.num_samples:,}")
    print(f"  Trees: {ts.num_trees:,}")
    print(f"  Mutations: {ts.num_mutations:,}")
    print(f"  Diversity (π): {ts.diversity():.6f}")

    return ts

# Run the basic simulation for the CHB population with a sample size of 1000.
# The resulting tree sequence is kept in `ts_basic` because the analysis,
# export and summary cells further down read it.
ts_basic = run_basic_simulation("CHB", 1000)

### Step 2: Multi-Population Simulation

In [None]:
def run_multipop_simulation(sample_sizes=None, seed=42):
    """
    Run a multi-population simulation.

    Simulates with the msprime engine, using the module-level ``model``
    (demographic model) and ``contig`` defined in earlier cells, then prints
    a short summary of the result.

    Args:
        sample_sizes: Mapping of population name to requested sample count.
            Defaults to 1000 each for YRI, CEU and CHB when omitted. A fresh
            dict is built on every call, so the default is never shared.
        seed: Random seed forwarded to the engine. The default (42) keeps
            the previously hard-coded behaviour; pass other values to
            generate independent replicates.

    Returns:
        The simulated tree sequence.
    """
    if sample_sizes is None:
        sample_sizes = {"YRI": 1000, "CEU": 1000, "CHB": 1000}

    print(f"Simulating multiple populations: {sample_sizes}")

    # Run simulation
    engine = stdpopsim.get_engine("msprime")
    ts = engine.simulate(
        demographic_model=model,
        contig=contig,
        samples=sample_sizes,
        seed=seed,
    )

    print("✓ Multi-population simulation complete!")
    print(f"  Total samples: {ts.num_samples:,}")
    print(f"  Trees: {ts.num_trees:,}")
    print(f"  Mutations: {ts.num_mutations:,}")

    return ts

# Run the multi-population simulation with its default sample sizes
# (YRI, CEU and CHB, 1000 each). The tree sequence is kept in `ts_multipop`
# for the analysis, export, missing-data and summary cells below.
ts_multipop = run_multipop_simulation()

## Data Analysis and Visualization

In [None]:
def analyze_tree_sequence(ts, title="Tree Sequence Analysis"):
    """
    Print and return summary statistics for a tree sequence.

    Args:
        ts: Tree sequence to summarise. All samples are treated as a single
            sample set.
        title: Heading printed above the statistics.

    Returns:
        dict with keys:
            'diversity'        - value of ``ts.diversity()``
            'tajimas_d'        - value of ``ts.Tajimas_D()``
            'wattersons_theta' - Watterson's estimator per base pair,
                                 S / (a_n * L), where S is the number of
                                 segregating sites, L the sequence length and
                                 a_n = sum_{i=1}^{n-1} 1/i for n samples
            'afs'              - polarised allele frequency spectrum
    """
    print(f"\n{title}")
    print("=" * len(title))

    # Basic statistics
    diversity = ts.diversity()
    tajimas_d = ts.Tajimas_D()

    # Ask for the raw count: by default tskit divides by the sequence length,
    # which previously led to normalising twice and printing a density as if
    # it were a count.
    num_segregating = float(ts.segregating_sites(span_normalise=False))

    # Watterson's estimator needs the harmonic number a_n of the sample size.
    harmonic_number = sum(1.0 / k for k in range(1, ts.num_samples))
    if harmonic_number > 0 and ts.sequence_length > 0:
        wattersons_theta = num_segregating / (harmonic_number * ts.sequence_length)
    else:
        # Undefined for fewer than two samples or an empty sequence.
        wattersons_theta = float("nan")

    print(f"Nucleotide diversity (π): {diversity:.6f}")
    print(f"Tajima's D: {tajimas_d:.4f}")
    print(f"Watterson's θ: {wattersons_theta:.6f}")
    print(f"Segregating sites: {int(round(num_segregating)):,}")

    # Site frequency spectrum
    afs = ts.allele_frequency_spectrum(polarised=True)
    print(f"AFS shape: {afs.shape}")

    return {
        'diversity': diversity,
        'tajimas_d': tajimas_d,
        'wattersons_theta': wattersons_theta,
        'afs': afs
    }

# Analyze both simulations. The returned statistics dicts are kept because
# the comparison plot below reads their 'diversity' and 'afs' entries.
stats_basic = analyze_tree_sequence(ts_basic, "Basic Simulation Analysis")
stats_multipop = analyze_tree_sequence(ts_multipop, "Multi-Population Analysis")

In [None]:
def plot_comparison(ts_list, labels, stats_list, max_classes=20):
    """
    Draw a 2x2 grid of plots comparing several simulations.

    Panels: nucleotide diversity, number of samples, number of mutations and
    the normalised site frequency spectrum (log-scaled y axis).

    Args:
        ts_list: Tree sequences to compare.
        labels: One display label per tree sequence, in the same order.
        stats_list: One statistics dict per tree sequence, in the same order;
            each must provide 'diversity' and 'afs'.
        max_classes: Number of allele-count classes (starting at 1) shown in
            the frequency-spectrum panel.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    bar_colors = ['skyblue', 'lightcoral']

    # 1. Diversity comparison
    diversities = [stats['diversity'] for stats in stats_list]
    bars = axes[0,0].bar(labels, diversities, color=bar_colors)
    axes[0,0].set_ylabel('Nucleotide Diversity (π)')
    axes[0,0].set_title('Genetic Diversity Comparison')

    # Add value labels
    for bar, div in zip(bars, diversities):
        axes[0,0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.00001,
                      f'{div:.5f}', ha='center', va='bottom')

    # 2. Sample size comparison
    sample_sizes = [ts.num_samples for ts in ts_list]
    axes[0,1].bar(labels, sample_sizes, color=bar_colors)
    axes[0,1].set_ylabel('Number of Samples')
    axes[0,1].set_title('Sample Size Comparison')

    # 3. Mutation count comparison
    mutations = [ts.num_mutations for ts in ts_list]
    axes[1,0].bar(labels, mutations, color=bar_colors)
    axes[1,0].set_ylabel('Number of Mutations')
    axes[1,0].set_title('Mutation Count Comparison')

    # 4. Site frequency spectrum comparison
    for stats, label in zip(stats_list, labels):
        afs = np.asarray(stats['afs'], dtype=float)
        total = afs.sum()
        if total <= 0:
            # Nothing to normalise; skip rather than plot NaNs.
            continue
        # Plot classes 1..max_classes. Class 0 is skipped so the panel shows
        # exactly the first `max_classes` classes and avoids feeding a
        # (typically zero) entry to the log-scaled axis.
        afs_norm = afs[1:max_classes + 1] / total
        classes = np.arange(1, len(afs_norm) + 1)
        axes[1,1].plot(classes, afs_norm, 'o-', label=label, alpha=0.7)

    axes[1,1].set_xlabel('Allele Frequency Class')
    axes[1,1].set_ylabel('Proportion of Sites')
    axes[1,1].set_title('Site Frequency Spectrum')
    axes[1,1].legend()
    axes[1,1].set_yscale('log')

    plt.tight_layout()
    plt.show()

# Create comparison plots. The three lists are parallel: entry k of each
# refers to the same simulation (tree sequence, display label, statistics).
plot_comparison([ts_basic, ts_multipop], 
               ['Single Population', 'Multi-Population'],
               [stats_basic, stats_multipop])

## Data Export for Imputation Analysis

Export simulated data in formats suitable for imputation software.

In [None]:
def export_simulation_data(ts, output_prefix, sample_info=None):
    """
    Export a tree sequence to several on-disk formats.

    Four files sharing ``output_prefix`` are written:
    ``<prefix>.trees`` (native tree sequence), ``<prefix>.vcf``,
    ``<prefix>_samples.csv`` and ``<prefix>_statistics.csv``.

    Args:
        ts: Tree sequence to export.
        output_prefix: Path prefix for all output files. Its parent directory
            is created if it does not exist yet.
        sample_info: Optional list of per-sample dicts written to the sample
            CSV. When omitted, one placeholder row per sample is generated
            (``sample_<i>``, population 'unknown', time period 'modern').

    Returns:
        dict mapping 'trees', 'vcf', 'samples' and 'statistics' to the paths
        of the files written.
    """
    import os  # local import so this cell does not depend on the import cell

    print(f"Exporting data with prefix: {output_prefix}")

    # Make sure the destination exists; otherwise the very first write fails
    # with FileNotFoundError on a fresh checkout.
    output_dir = os.path.dirname(output_prefix)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # 1. Save tree sequence (native format)
    ts_file = f"{output_prefix}.trees"
    ts.dump(ts_file)
    print(f"✓ Tree sequence saved: {ts_file}")

    # 2. Export to VCF format
    vcf_file = f"{output_prefix}.vcf"
    with open(vcf_file, "w") as f:
        ts.write_vcf(f)
    print(f"✓ VCF file saved: {vcf_file}")

    # 3. Create sample information file
    if sample_info is None:
        # Create basic sample info
        sample_info = [
            {
                'sample_id': f'sample_{i}',
                'population': 'unknown',
                'time_period': 'modern',
            }
            for i in range(ts.num_samples)
        ]

    sample_df = pd.DataFrame(sample_info)
    sample_file = f"{output_prefix}_samples.csv"
    sample_df.to_csv(sample_file, index=False)
    print(f"✓ Sample info saved: {sample_file}")

    # 4. Basic statistics summary
    stats_data = {
        'Statistic': ['Total_Samples', 'Total_Trees', 'Total_Mutations',
                     'Sequence_Length', 'Nucleotide_Diversity', 'Tajimas_D'],
        'Value': [ts.num_samples, ts.num_trees, ts.num_mutations,
                 int(ts.sequence_length), ts.diversity(), ts.Tajimas_D()]
    }

    stats_df = pd.DataFrame(stats_data)
    stats_file = f"{output_prefix}_statistics.csv"
    stats_df.to_csv(stats_file, index=False)
    print(f"✓ Statistics saved: {stats_file}")

    return {
        'trees': ts_file,
        'vcf': vcf_file,
        'samples': sample_file,
        'statistics': stats_file
    }

# Export basic simulation
basic_files = export_simulation_data(ts_basic, "../data/basic_simulation")

# Create sample info for multi-population simulation.
# Population labels are read from the tree sequence itself rather than
# assuming three equally sized blocks in YRI/CEU/CHB order, so the metadata
# stays correct when sample sizes or the population set are changed.
population_names = {}
for population in ts_multipop.populations():
    metadata = population.metadata
    # NOTE(review): assumes population metadata decodes to a dict with a
    # 'name' entry; anything else falls back to a generic label.
    if isinstance(metadata, dict) and metadata.get("name"):
        population_names[population.id] = metadata["name"]
    else:
        population_names[population.id] = f"pop_{population.id}"

multipop_sample_info = []
samples_seen = {}  # running per-population counter used in the sample IDs
for node_id in ts_multipop.samples():
    pop_id = ts_multipop.node(int(node_id)).population
    pop = population_names.get(pop_id, "unknown")
    index_in_pop = samples_seen.get(pop, 0)
    samples_seen[pop] = index_in_pop + 1
    multipop_sample_info.append({
        'sample_id': f'{pop}_sample_{index_in_pop}',
        'population': pop,
        'time_period': 'modern'
    })

# Export multi-population simulation
multipop_files = export_simulation_data(ts_multipop, "../data/multipop_simulation",
                                       multipop_sample_info)

## Advanced Applications

### Creating Missing Data for Imputation Testing

In [None]:
def create_missing_data_scenario(ts, missing_rate=0.1, output_prefix="imputation_test", seed=42):
    """
    Create datasets with missing data for imputation testing.

    Every genotype is masked independently with probability ``missing_rate``.
    Three CSV files are written under ``../data/``, each transposed relative
    to ``ts.genotype_matrix()``: the complete genotypes (truth), the
    genotypes with masked entries set to -1, and the 0/1 mask itself.

    Args:
        ts: Tree sequence providing ``genotype_matrix()``.
        missing_rate: Probability in [0, 1] that a genotype is masked.
        output_prefix: File name prefix used inside ``../data/``.
        seed: Seed for the masking RNG. The default (42) yields the same
            mask as the previous fixed seeding.

    Returns:
        dict with the paths 'truth', 'incomplete' and 'mask', plus
        'missing_rate', the fraction of genotypes actually masked.

    Raises:
        ValueError: If ``missing_rate`` is outside [0, 1].
    """
    import os  # local import so this cell does not depend on the import cell

    if not 0.0 <= missing_rate <= 1.0:
        raise ValueError(f"missing_rate must be between 0 and 1, got {missing_rate!r}")

    # ':g' avoids float noise such as 15.000000000000002 in the message.
    print(f"Creating imputation test scenario with {missing_rate * 100:g}% missing data")

    # Get genotype matrix
    genotype_matrix = ts.genotype_matrix()
    print(f"Original data shape: {genotype_matrix.shape} (variants x samples)")

    # Create missing data mask with a private generator so the caller's
    # global NumPy random state is left untouched.
    rng = np.random.RandomState(seed)
    missing_mask = rng.random_sample(genotype_matrix.shape) < missing_rate

    # Create incomplete dataset (set missing to -1)
    incomplete_genotypes = genotype_matrix.copy()
    incomplete_genotypes[missing_mask] = -1

    truth_file = f"../data/{output_prefix}_truth.csv"
    incomplete_file = f"../data/{output_prefix}_incomplete.csv"
    mask_file = f"../data/{output_prefix}_missing_mask.csv"

    # Make sure the destination exists before the first write.
    os.makedirs(os.path.dirname(truth_file), exist_ok=True)

    # Save complete data (truth)
    pd.DataFrame(genotype_matrix.T).to_csv(truth_file, index=False)

    # Save incomplete data (for imputation)
    pd.DataFrame(incomplete_genotypes.T).to_csv(incomplete_file, index=False)

    # Save missing data mask
    pd.DataFrame(missing_mask.T.astype(int)).to_csv(mask_file, index=False)

    missing_count = int(np.sum(missing_mask))
    total_genotypes = genotype_matrix.size
    # Guard the empty-matrix case instead of dividing by zero.
    actual_missing_rate = missing_count / total_genotypes if total_genotypes else 0.0

    print("✓ Missing data created:")
    print(f"  Total genotypes: {total_genotypes:,}")
    print(f"  Missing genotypes: {missing_count:,}")
    print(f"  Actual missing rate: {actual_missing_rate:.3f}")
    print(f"  Truth data: {truth_file}")
    print(f"  Incomplete data: {incomplete_file}")
    print(f"  Missing mask: {mask_file}")

    return {
        'truth': truth_file,
        'incomplete': incomplete_file,
        'mask': mask_file,
        'missing_rate': actual_missing_rate
    }

# Create imputation test scenario from the multi-population tree sequence,
# masking genotypes at a 15% rate. The returned dict is kept in
# `imputation_files` because the session summary below lists its paths.
imputation_files = create_missing_data_scenario(ts_multipop, missing_rate=0.15)

## Summary and Next Steps

In [None]:
def print_summary():
    """
    Print a comprehensive summary of the simulation session.

    Reads the module-level results created by earlier cells: ``ts_basic``,
    ``ts_multipop``, ``basic_files``, ``multipop_files`` and
    ``imputation_files``. Takes no arguments and returns None.
    """
    banner = "=" * 60

    print("\n" + banner)
    print("STDPOPSIM SIMULATION SESSION SUMMARY")
    print(banner)

    print("\n📊 SIMULATIONS COMPLETED:")
    print(f"1. Basic simulation: {ts_basic.num_samples:,} CHB samples")
    print(f"   - Mutations: {ts_basic.num_mutations:,}")
    print(f"   - Diversity: {ts_basic.diversity():.6f}")

    print(f"\n2. Multi-population: {ts_multipop.num_samples:,} total samples")
    print(f"   - Mutations: {ts_multipop.num_mutations:,}")
    print(f"   - Diversity: {ts_multipop.diversity():.6f}")

    print("\n📁 FILES GENERATED:")
    # Keep only real paths *before* numbering: the imputation result also
    # carries a numeric 'missing_rate', and filtering after enumerate() would
    # leave gaps in the list whenever such a value is not the last entry.
    generated_files = []
    for outputs in (basic_files, multipop_files, imputation_files):
        for value in outputs.values():
            if isinstance(value, str):
                generated_files.append(value)

    for i, file_path in enumerate(generated_files, 1):
        print(f"{i:2d}. {file_path}")

    print("\n🔬 READY FOR ANALYSIS:")
    print("✓ VCF files for standard genomics tools")
    print("✓ CSV files for custom analysis")
    print("✓ Missing data scenarios for imputation testing")
    print("✓ Sample metadata for population structure analysis")

    print("\n🚀 NEXT STEPS:")
    print("1. Load data into imputation software (BEAGLE, IMPUTE, etc.)")
    print("2. Test different imputation algorithms")
    print("3. Evaluate imputation accuracy using truth data")
    print("4. Experiment with different missing data patterns")
    print("5. Try different demographic models or parameters")

    print("\n💡 TIPS:")
    print("- Increase length_multiplier for larger datasets")
    print("- Try different chromosomes for varied LD patterns")
    print("- Experiment with ancient DNA sampling times")
    print("- Use different demographic models for diverse scenarios")

    print("\n" + banner)

# Print session summary. This must run after the simulation, export and
# missing-data cells, since print_summary() reads their module-level results.
print_summary()