In [1]:
import hashlib
import json
import pickle
import allel
import seaborn as sns
import zarr
import plotly.express as px
import dask.array as da
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats
from numpy.random import binomial
from Bio import SeqIO
from pathlib import Path
import numpy as np
from scipy import stats
from sklearn.linear_model import LogisticRegression

In [2]:
# Define metadata and qc bool globally to start
#define useful variables
zarr_base_path = f"/scratch/user/uqtdenni/afar_production_bunya/curation/uq-beebe-001/combined_zarr/{{contig}}.zarr"

# Let's start by converting zarrs for the 5 largest contigs - we can do QC on them...
ref_path = '/scratch/user/uqtdenni/afar_production_bunya/reference/VectorBase-54_AfarautiFAR1_Genome.fasta'
# now let's get a list of the contigs that we are going to call over
contig_lengths = {}
for record in SeqIO.parse(ref_path, "fasta"):
    seq_id = record.id
    seq_length = len(record.seq)
    contig_lengths[seq_id] = seq_length
filtered_contigs = {k: v for k, v in sorted(contig_lengths.items(), key=lambda item: item[1], reverse=True) if v > 100000}

# Because these data are unstaged, we need to faff about a bit more and load the unstaged metadata to exclude extra dud samples
df_samples_dirty = pd.read_csv('/scratch/user/uqtdenni/far_hin_1.x/work/metadata_development_20250702/sample_metadata_interim_seq_qc_pass.txt', index_col = 'derived_sample_id')
# And load the final (cleaned) metadata
df_samples = pd.read_csv('/scratch/user/uqtdenni/far_hin_1.x/work/metadata_development_20250702/sample_metadata_pass_qc.txt', index_col = 'derived_sample_id')

# Mask removing samples we removed before the staging step of QC (that I haven't done yet)
qc_bool = df_samples_dirty.index.isin(df_samples.index)

# Zarr location
zarr_base_path = f"/scratch/user/uqtdenni/afar_production_bunya/curation/uq-beebe-001/combined_zarr/{{contig}}.zarr"


In [None]:
# Define helper functions
def load_genotype_array(contig, qc_bool=qc_bool, df_samples=df_samples, sample_query = None):
    # Load gts and remove failed qc samples
    z = zarr.open(zarr_base_path.format(contig=contig))
    
    # Variant-level mask: punctulatus_group_filter_pass
    filter_mask = z[f"{contig}/punctulatus_group_filter_pass"][:]

    gt = allel.GenotypeChunkedArray(z[f"{contig}/calldata/GT"])
    
    # Apply combined variant mask
    gt = allel.GenotypeChunkedArray(z[f"{contig}/calldata/GT"])
    gt = gt.compress(qc_bool, axis=1)          # Filter samples by QC
    gt = gt.compress(filter_mask, axis=0)    # Filter variants
    
    # If an additional mask is supplied to subset the data from the finished metadata, apply, else return all samples
    if sample_query is not None:
        bool_query = np.array(df_samples.eval(sample_query))
        return gt.compress(bool_query, axis=1)
    else:
        return gt

def collapse_ac_array(biallelic_ac_array):
    #ONLY EVER USE THIS IF YOU KNOW YOUR AC ARRAY CONTAINS 2 ELEMENTS PER ROW AND MULTIALLELICS HAVE BEEN FILTERED OUT
        # Preallocate 2-column array
    collapsed_fixed = np.zeros((biallelic_ac_array.shape[0], 2), dtype=biallelic_ac_array.dtype)
    
    # Fill it with nonzero values
    for i, row in enumerate(biallelic_ac_array):
        nz = row[row > 0]
        collapsed_fixed[i, :len(nz)] = nz
    return(collapsed_fixed)

def compute_ac(
    contig, 
    qc_bool=qc_bool, 
    df_samples=df_samples, 
    segregating=False, 
    biallelics_only=False, 
    remove_singletons=False,
    sample_query=None, 
    seg_mask = None, # Provide precomputed seg mask
):
    # Load genotypes
    gt = load_genotype_array(contig, qc_bool, df_samples, sample_query)
    ac = gt.count_alleles()

    # Filter for segregating sites if requested
    if segregating:
        ac = ac.compress(ac.is_segregating())
    
    # Or use a presupplied mask
    if seg_mask:
        ac = ac.compress(seg_mask)

    # Filter for biallelic sites if requested
    if biallelics_only:
        ac = ac.compress(ac.is_biallelic())

    if remove_singletons:
        mask_singleton = np.asarray(ac.is_singleton())
        ac = ac.compress(~mask_singleton)

    return ac

# Set up causal snp architecture - is it a single snp with a massive effect size, like in Aedes, or multiple snps with modest effect sizes, or polygenic (aggrgrhgh)
def setup_causal_architecture(architecture, maf):
    """Setup causal SNPs and effects based on architecture type"""
    
    # Get SNPs with intermediate frequencies
    intermediate_freq_mask = (maf > 0.2) & (maf < 0.5)
    available_snps = np.where(intermediate_freq_mask)[0]
    
    if architecture == "monogenic":
        # Single SNP with massive effect
        n_causal = 1
        causal_snp_indices = np.random.choice(available_snps, size=n_causal, replace=False)
        causal_effects = np.array([np.log(7)])  # OR = 7, large effect
        
    elif architecture == "oligogenic":
        # Few SNPs with moderate effects
        n_causal = 4
        causal_snp_indices = np.random.choice(available_snps, size=n_causal, replace=False)
        causal_effects = np.array([np.log(2.5), np.log(2.0), np.log(1.8)])  # OR = 1.8-2.5
        
    elif architecture == "polygenic":
        # Many SNPs with small effects
        n_causal = 15
        causal_snp_indices = np.random.choice(available_snps, size=n_causal, replace=False)
        causal_effects = np.random.normal(np.log(1.2), 0.1, n_causal)  # OR ~1.1-1.3
        
    return causal_snp_indices, causal_effects

# Simulate larger population (assuming HWE)
def simulate_genotypes(allele_freqs, n_inds):
    n_snps = len(allele_freqs)
    genotypes = np.zeros((n_inds, n_snps))
    
    for i, freq in enumerate(allele_freqs):
        # Binomial sampling: 2 alleles per individual
        genotypes[:, i] = np.random.binomial(2, freq, n_inds)
    
    return genotypes

def run_gwas(genotypes, phenotype):
    """Run GWAS with robust handling of sparse contingency tables"""
    n_snps = genotypes.shape[1]
    p_values = np.ones(n_snps)  # Default to p=1 (no association)
    
    for i in range(n_snps):
        # Create contingency table
        contingency_table = np.array([
            [np.sum((genotypes[:, i] == 0) & (phenotype == 0)),  # AA controls
             np.sum((genotypes[:, i] == 0) & (phenotype == 1))], # AA cases
            [np.sum((genotypes[:, i] == 1) & (phenotype == 0)),  # AB controls  
             np.sum((genotypes[:, i] == 1) & (phenotype == 1))], # AB cases
            [np.sum((genotypes[:, i] == 2) & (phenotype == 0)),  # BB controls
             np.sum((genotypes[:, i] == 2) & (phenotype == 1))]  # BB cases
        ])
        
        # Check for sparse table (zeros in expected frequencies)
        row_sums = contingency_table.sum(axis=1)
        col_sums = contingency_table.sum(axis=0)
        
        # Remove rows/cols with all zeros
        non_zero_rows = row_sums > 0
        non_zero_cols = col_sums > 0
        
        if np.sum(non_zero_rows) < 2 or np.sum(non_zero_cols) < 2:
            # Not enough data for test
            continue
            
        filtered_table = contingency_table[non_zero_rows][:, non_zero_cols]
        
        try:
            chi2, p_val = stats.chi2_contingency(filtered_table)[:2]
            p_values[i] = p_val
        except ValueError:
            # Still problematic - skip this SNP
            continue
    
    return p_values

# More realistic approach - simulate from logistic model
def simulate_phenotype_logistic(risk_scores, baseline_prob=0.5):
    """Simulate phenotype using logistic regression"""
    # Adjust intercept to get desired baseline probability
    intercept = np.log(baseline_prob / (1 - baseline_prob)) - np.mean(risk_scores)
    
    # Calculate probabilities
    log_odds = intercept + risk_scores
    probabilities = 1 / (1 + np.exp(-log_odds))
    
    # Sample binary outcomes
    binary_trait = np.random.binomial(1, probabilities)
    
    return binary_trait

def create_balanced_cases_controls(risk_scores, desired_cases):
    sorted_indices = np.argsort(risk_scores)[::-1]
    binary_trait = np.zeros(len(risk_scores), dtype=int)
    binary_trait[sorted_indices[:desired_cases]] = 1
    return binary_trait

### Simulating GWAS using allele frequencies from a population of interest

We are going to simulate our GWAS using simulated genotypes, generated from the SFS of a smaller sequenced population (that we will return to to study)

Let's start by setting up the genotype and allele frequency data. Compute minor allele frequencies.

In [4]:
# Compute allele_counts                
wp_ac = compute_ac(contig = "KI915040", 
           qc_bool = qc_bool, 
           df_samples=df_samples, 
           sample_query = 'species == "hinesorum" & admin1_iso == "SB-WE"',
           segregating=True,
           biallelics_only=True,
           remove_singletons=True)

# convert to frequencies, select maf (eg column 1)
wp_maf = wp_ac.to_frequencies()[:,1]

Taking the mafs from the population (Solomon Islands Western Province), let's run a simulated GWAS. For simplicity's sake we'll start with a single SNP with a massive odds ratio.

In [5]:
# Set up - small sample size for now, but massive OR of the SNP so should be ok?
n_individuals = 100
arch = "monogenic"
desired_cases = 30 #30 % trait frequency  

# Let's extract causal SNPs of an intermediate frequency for a monogenic trait
causal_snp_indices, causal_effects = setup_causal_architecture(arch, wp_maf)
simulated_genos = simulate_genotypes(wp_maf, n_individuals)

# Calculate risk scores 
risk_scores = np.zeros(n_individuals)
for i, snp_idx in enumerate(causal_snp_indices):
    risk_scores += causal_effects[i] * simulated_genos[:, snp_idx]

# Create balanced case/control split
binary_trait = create_balanced_cases_controls(risk_scores, desired_cases)

print("=== Experiment Setup & Debug===")
print(f"Cases: {binary_trait.sum()}, Controls: {(1-binary_trait).sum()}")
# 3. Verify immediately after assignment
cases_indices = np.where(binary_trait == 1)[0]
print(f"Cases indices: {cases_indices}")
print(f"Cases risk scores: {risk_scores[cases_indices]}")
print(f"All cases have high risk scores: {np.all(risk_scores[cases_indices] >= 2.0)}")

=== Experiment Setup & Debug===
Cases: 30, Controls: 70
Cases indices: [ 6  9 10 11 15 19 20 25 27 28 30 32 35 37 38 42 43 44 46 49 60 66 67 69
 72 76 77 79 80 83]
Cases risk scores: [1.94591015 1.94591015 1.94591015 3.8918203  1.94591015 1.94591015
 3.8918203  1.94591015 1.94591015 1.94591015 1.94591015 1.94591015
 1.94591015 3.8918203  3.8918203  1.94591015 1.94591015 1.94591015
 3.8918203  1.94591015 1.94591015 1.94591015 1.94591015 1.94591015
 1.94591015 1.94591015 1.94591015 1.94591015 1.94591015 1.94591015]
All cases have high risk scores: False


In [6]:
# Run GWAS! fly my pretties!
print("=== Running GWAS, FLY MY PRETTIES, FLY!==")
p_values = run_gwas(simulated_genos, binary_trait)
print(f"P-value at causal SNP: {p_values[causal_snp_indices[0]]:.2e}")

=== Running GWAS, FLY MY PRETTIES, FLY!==


KeyboardInterrupt: 

In [None]:
detected = np.sum(p_values[causal_snp_indices] < 5e-8)
power = detected / len(causal_snp_indices)
power

OK, now we've run our simulated GWAS for a monogenic trait, with a large effect. It looks like we have good power with 100 samples to detect this, so let's run some more simulations with some more complex trait architectures, a range of sample and effect sizes

## Full power analysis

In [7]:
import hashlib
import json
import pickle
import pandas as pd
from pathlib import Path
import numpy as np

def create_parameter_hash(architecture, trait_freq, effect_size, sample_size, n_replicates):
    """Create a hash from simulation parameters for unique file naming"""
    params = f"{architecture}_{trait_freq}_{effect_size}_{sample_size}_{n_replicates}"
    return hashlib.md5(params.encode()).hexdigest()[:8]

def check_simulation_exists(output_path, param_hash):
    """Check if simulation already exists"""
    pickle_file = output_path / f"sim_{param_hash}.pkl"
    return pickle_file.exists()

def load_existing_simulation(output_path, param_hash):
    """Load existing simulation result"""
    pickle_file = output_path / f"sim_{param_hash}.pkl"
    with open(pickle_file, 'rb') as f:
        return pickle.load(f)

def save_simulation_result(result, output_path, param_hash):
    """Save simulation result"""
    pickle_file = output_path / f"sim_{param_hash}.pkl"
    with open(pickle_file, 'wb') as f:
        pickle.dump(result, f)
    return pickle_file

def setup_causal_architecture(architecture_type, allele_freqs, effect_size=None):
    """Setup causal SNPs and effects based on architecture type"""
    intermediate_freq_mask = (allele_freqs > 0.2) & (allele_freqs < 0.5)
    available_snps = np.where(intermediate_freq_mask)[0]
    
    if architecture_type == "monogenic":
        n_causal = 1
        causal_snp_indices = np.random.choice(available_snps, size=n_causal, replace=False)
        or_effect = effect_size if effect_size is not None else 7
        causal_effects = np.array([np.log(or_effect)])
        
    elif architecture_type == "oligogenic":
        n_causal = 3
        causal_snp_indices = np.random.choice(available_snps, size=n_causal, replace=False)
        base_effects = [2.5, 2.0, 1.8]
        if effect_size is not None:
            scale_factor = effect_size / 2.0
            base_effects = [e * scale_factor for e in base_effects]
        causal_effects = np.array([np.log(e) for e in base_effects])
        
    return causal_snp_indices, causal_effects

def power_analysis(output_dir="gwas_results", trait_frequencies=None, effect_sizes=None, n_replicates=10):
    """Run power analysis with caching"""
    
    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True)
    
    # Default parameters
    if trait_frequencies is None:
        trait_frequencies = [0.1, 0.2, 0.3, 0.4, 0.5]
    
    if effect_sizes is None:
        effect_sizes = {
            "monogenic": [3, 5, 7, 10],
            "oligogenic": [1.5, 2.0, 2.5, 3.0]
        }
    
    all_results = []
    
    for architecture in ["monogenic", "oligogenic"]:
        for trait_freq in trait_frequencies:
            for effect_size in effect_sizes[architecture]:
                for sample_size in [50, 100, 200, 500, 1000]:
                    n_cases = int(sample_size * trait_freq)
                    n_controls = sample_size - n_cases
                    
                    # Skip if too few cases or controls
                    if n_cases < 5 or n_controls < 5:
                        continue
                    
                    # Create parameter hash
                    param_hash = create_parameter_hash(
                        architecture, trait_freq, effect_size, sample_size, n_replicates
                    )
                    
                    # Check if simulation already exists
                    if check_simulation_exists(output_path, param_hash):
                        print(f"Loading cached {architecture}, OR={effect_size}, trait_freq={trait_freq:.1f}, N={sample_size} ({param_hash})")
                        cached_result = load_existing_simulation(output_path, param_hash)
                        
                        # Extract summary for combined results
                        summary_result = {
                            'architecture': architecture,
                            'trait_frequency': trait_freq,
                            'effect_size_OR': effect_size,
                            'sample_size': sample_size,
                            'n_cases': n_cases,
                            'n_controls': n_controls,
                            'mean_power': cached_result['mean_power'],
                            'se_power': cached_result['se_power'],
                            'n_replicates': cached_result['n_replicates'],
                            'param_hash': param_hash
                        }
                        all_results.append(summary_result)
                        continue
                    
                    print(f"Running {architecture}, OR={effect_size}, trait_freq={trait_freq:.1f}, N={sample_size} ({n_cases} cases, {n_controls} controls)...")
                    
                    powers = []
                    
                    # Run replicates
                    for replicate in range(n_replicates):
                        print(f"  Replicate {replicate}")

                        # Generate new population each time
                        simulated_genos = simulate_genotypes(wp_maf, sample_size)
                        causal_snp_indices, causal_effects = setup_causal_architecture(
                            architecture, wp_maf, effect_size
                        )
                        
                        # Calculate risk scores
                        risk_scores = np.zeros(sample_size)
                        for i, snp_idx in enumerate(causal_snp_indices):
                            risk_scores += causal_effects[i] * simulated_genos[:, snp_idx]
                        
                        # Create cases/controls based on trait frequency
                        binary_trait = create_balanced_cases_controls(risk_scores, n_cases)
                        
                        # Run GWAS
                        p_values = run_gwas(simulated_genos, binary_trait)
                        
                        # Check power
                        detected = np.sum(p_values[causal_snp_indices] < 5e-8)
                        power = detected / len(causal_snp_indices)
                        powers.append(power)
                    
                    mean_power = np.mean(powers)
                    se_power = np.std(powers) / np.sqrt(len(powers))
                    
                    # Create result object
                    simulation_result = {
                        'architecture': architecture,
                        'trait_frequency': trait_freq,
                        'effect_size_OR': effect_size,
                        'sample_size': sample_size,
                        'n_cases': n_cases,
                        'n_controls': n_controls,
                        'mean_power': mean_power,
                        'se_power': se_power,
                        'n_replicates': len(powers),
                        'power_values': powers,
                        'param_hash': param_hash
                    }
                    
                    # Save simulation result
                    save_simulation_result(simulation_result, output_path, param_hash)
                    
                    # Add to summary results
                    all_results.append(simulation_result)
                    print(f"  Power = {mean_power:.3f} ± {se_power:.3f} (saved as {param_hash})")
    
    # Save combined results
    global_hash = hashlib.md5(str(len(all_results)).encode()).hexdigest()[:8]
    df = pd.DataFrame(all_results)
    csv_file = output_path / f"power_results_{global_hash}.csv"
    df.to_csv(csv_file, index=False)
    print(f"Combined results saved to {csv_file}")
    
    return all_results, df

def load_simulation_result(output_dir, param_hash):
    """Load a specific simulation result by parameter hash"""
    output_path = Path(output_dir)
    pickle_file = output_path / f"sim_{param_hash}.pkl"
    
    with open(pickle_file, 'rb') as f:
        return pickle.load(f)

# Plotting functions
def plot_power_results(df, save_path=None):
    """Plot power results"""
    import matplotlib.pyplot as plt
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    
    for i, arch in enumerate(["monogenic", "oligogenic"]):
        arch_data = df[df['architecture'] == arch]
        
        for effect_size in sorted(arch_data['effect_size_OR'].unique()):
            effect_data = arch_data[arch_data['effect_size_OR'] == effect_size]
            
            for trait_freq in sorted(effect_data['trait_frequency'].unique()):
                freq_data = effect_data[effect_data['trait_frequency'] == trait_freq]
                
                axes[i].plot(
                    freq_data['sample_size'], 
                    freq_data['mean_power'],
                    marker='o', 
                    label=f'OR={effect_size}, freq={trait_freq}',
                    alpha=0.7
                )
        
        axes[i].set_title(f'{arch.title()} Architecture')
        axes[i].set_xlabel('Sample Size')
        axes[i].set_ylabel('Power')
        axes[i].legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        axes[i].grid(True, alpha=0.3)
        axes[i].set_ylim(0, 1.05)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")
    
    plt.show()
    return fig

def plot_power_heatmap(df, architecture="monogenic", save_path=None):
    """Create heatmap of power by trait frequency vs effect size"""
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    arch_data = df[(df['architecture'] == architecture) & (df['sample_size'] == 200)]
    
    heatmap_data = arch_data.pivot(
        index='trait_frequency', 
        columns='effect_size_OR', 
        values='mean_power'
    )
    
    plt.figure(figsize=(10, 6))
    sns.heatmap(
        heatmap_data, 
        annot=True, 
        fmt='.2f', 
        cmap='viridis',
        cbar_kws={'label': 'Power'}
    )
    plt.title(f'{architecture.title()} Power (N=200)')
    plt.xlabel('Effect Size (OR)')
    plt.ylabel('Trait Frequency')
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Heatmap saved to {save_path}")
    
    plt.show()


In [8]:
# Example usage:
"""
# Run analysis with default parameters (will skip cached simulations)
results, df = power_analysis()

# Run analysis with custom parameters
results, df = power_analysis(
    trait_frequencies=[0.2, 0.5], 
    effect_sizes={"monogenic": [7], "oligogenic": [2.0]},
    n_replicates=5
)

# Run with different output directory
results, df = power_analysis(
    output_dir="my_gwas_results",
    trait_frequencies=[0.1, 0.3, 0.5],
    effect_sizes={"monogenic": [3, 5, 10], "oligogenic": [1.5, 2.5]},
    n_replicates=20
)

# Load a specific simulation by hash
result = load_simulation_result("gwas_results", "abc12345")

# Plot results
plot_power_results(df)
plot_power_heatmap(df, architecture="monogenic")
"""

'\n# Run analysis with default parameters (will skip cached simulations)\nresults, df = power_analysis()\n\n# Run analysis with custom parameters\nresults, df = power_analysis(\n    trait_frequencies=[0.2, 0.5], \n    effect_sizes={"monogenic": [7], "oligogenic": [2.0]},\n    n_replicates=5\n)\n\n# Run with different output directory\nresults, df = power_analysis(\n    output_dir="my_gwas_results",\n    trait_frequencies=[0.1, 0.3, 0.5],\n    effect_sizes={"monogenic": [3, 5, 10], "oligogenic": [1.5, 2.5]},\n    n_replicates=20\n)\n\n# Load a specific simulation by hash\nresult = load_simulation_result("gwas_results", "abc12345")\n\n# Plot results\nplot_power_results(df)\nplot_power_heatmap(df, architecture="monogenic")\n'

In [9]:
results, df = power_analysis(
    trait_frequencies=[0.2, 0.5], 
    effect_sizes={"monogenic": [1, 3, 5, 7, 10], "oligogenic": [1.5, 2, 2.5]},
    n_replicates=5
)

Loading cached monogenic, OR=1, trait_freq=0.2, N=50 (bd400ae4)
Loading cached monogenic, OR=1, trait_freq=0.2, N=100 (2f25a4b2)
Loading cached monogenic, OR=1, trait_freq=0.2, N=200 (f544545e)
Loading cached monogenic, OR=1, trait_freq=0.2, N=500 (3ab267be)
Loading cached monogenic, OR=1, trait_freq=0.2, N=1000 (d8fcf3ad)
Loading cached monogenic, OR=3, trait_freq=0.2, N=50 (aad1a2da)
Loading cached monogenic, OR=3, trait_freq=0.2, N=100 (688ee4e8)
Loading cached monogenic, OR=3, trait_freq=0.2, N=200 (a4856db7)
Loading cached monogenic, OR=3, trait_freq=0.2, N=500 (c425a95b)
Loading cached monogenic, OR=3, trait_freq=0.2, N=1000 (f09262ec)
Loading cached monogenic, OR=5, trait_freq=0.2, N=50 (ecaf991d)
Loading cached monogenic, OR=5, trait_freq=0.2, N=100 (ada34ec3)
Loading cached monogenic, OR=5, trait_freq=0.2, N=200 (0b61b5ef)
Loading cached monogenic, OR=5, trait_freq=0.2, N=500 (d3ae7489)
Loading cached monogenic, OR=5, trait_freq=0.2, N=1000 (57ce8f91)
Loading cached monogenic,

In [None]:
# TO-DO - get this to run in parallel or in dask