# Association study 

In [None]:
import pandas as pd
import numpy as np
import allel
import plotly.express as px


def _print_filter_stats(stage_label, total_snps, miss_mask, maf_mask, max_missing_filter, min_maf_filter, extra_removed=None):
    """Consistent SNP filtering report for any stage."""
    final_mask = miss_mask & maf_mask
    removed_missing = int((~miss_mask).sum())
    removed_maf = int((~maf_mask).sum())
    removed_overall = int((~final_mask).sum())
    max_missing_pct = max_missing_filter * 100
    min_maf_threshold = min_maf_filter

    print(f"SNP filtering stats ({stage_label}):")
    print(f"Total SNPs evaluated: {int(total_snps)}")
    if extra_removed is not None:
        print(f"Removed before MAF/missingness (non-segregating): {int(extra_removed)}")
    print(f"Removed for missingness (>{max_missing_pct}%): {removed_missing}")
    print(f"Removed for low MAF (<={min_maf_threshold}): {removed_maf}")
    print(f"Removed overall (failed either filter): {removed_overall}")
    print(f"SNPs retained after filtering: {int(final_mask.sum())}")


def vcf_to_glm_data(vcf_path, df_samples, sample_query='sample_id.str.contains("Siaya")', max_missing_filter=0.20, min_maf_filter=0.02, split_multiallelic=True, convert_genotypes=True):
    """
    Process a VCF file into a samples-by-SNPs table ready for per-SNP GLMs.

    Parameters:
    vcf_path : str
        Path to the VCF file.
    df_samples : pd.DataFrame
        Sample metadata; must contain a 'sample_id' column.
    sample_query : str, optional
        Expression evaluated on df_samples (engine='python') selecting which
        samples to keep; also used to select the metadata rows joined on at the end.
    max_missing_filter : float, optional
        Maximum allowed proportion of missing genotypes per SNP (default 0.20).
    min_maf_filter : float, optional
        SNPs with minor allele frequency <= this value are dropped (default 0.02).
    split_multiallelic : bool, optional
        Whether to split multi-allelic SNPs into one row per ALT allele (default True).
    convert_genotypes : bool, optional
        Whether to convert genotype strings to alternate-allele counts (default True).

    Returns:
    pd.DataFrame
        One row per sample: the selected metadata columns followed by one
        'snp_<CHROM>:<POS>_<REF>><ALT>' column per retained SNP.
    """
    genotypes = vcf_to_df(
        vcf_path,
        df_samples,
        sample_query=sample_query,
        max_missing_filter=max_missing_filter,
        min_maf_filter=min_maf_filter,
    )
    # vcf_to_df emits six variant-level columns (CHROM, POS, FILTER_PASS, REF,
    # ALT, ANN) before the per-sample genotype columns.
    sample_cols = genotypes.columns[6:]

    if split_multiallelic:
        genotypes = split_rows_with_multiple_alleles(genotypes, sample_cols)
    if convert_genotypes:
        genotypes = convert_genotype_to_alt_allele_count(genotypes, sample_cols)

    snp_ids = (
        "snp_"
        + genotypes.CHROM.astype(str)
        + ":"
        + genotypes.POS.astype(str)
        + "_"
        + genotypes.REF.astype(str)
        + ">"
        + genotypes.ALT.astype(str)
    )
    genotypes = genotypes.assign(snp_id=snp_ids)

    # Thresholds are re-applied after splitting/encoding, since allele
    # frequencies change once multi-allelic rows are split.
    genotypes = apply_final_snp_filters(
        genotypes,
        samples=sample_cols,
        max_missing_filter=max_missing_filter,
        min_maf_filter=min_maf_filter,
        genotypes_are_alt_counts=convert_genotypes,
    )

    variant_cols = ['CHROM', 'POS', 'FILTER_PASS', 'REF', 'ALT', 'ANN']
    per_sample = genotypes.set_index('snp_id').drop(columns=variant_cols).T
    sample_meta = df_samples.set_index('sample_id').query(sample_query, engine='python')
    return pd.concat([sample_meta, per_sample], axis=1)


def vcf_to_df(vcf_path, df_samples, sample_query, max_missing_filter=0.20, min_maf_filter=0.02, query2=None):
    """Read a VCF, subset samples, and apply an initial (pre-split) SNP filter for performance.

    Parameters
    ----------
    vcf_path : str
        Path to the VCF, read with ``allel.read_vcf(fields='*')``.
    df_samples : pd.DataFrame
        Sample metadata with a ``sample_id`` column.
    sample_query : str
        Expression evaluated with ``df_samples.eval(..., engine='python')``.
        VCF samples whose ID belongs to a matching row are kept. Rows where the
        expression evaluates to NaN count as non-matching.
    max_missing_filter : float
        SNPs whose missing-genotype fraction exceeds this value are dropped.
    min_maf_filter : float
        SNPs with MAF <= this value are dropped. MAF here is
        min(sum of ALT frequencies, 1 - that sum).
    query2 : str, optional
        Second expression on ``df_samples``. When given, only samples matching
        both queries are kept.

    Returns
    -------
    pd.DataFrame
        One row per retained SNP with columns CHROM, POS, FILTER_PASS, REF,
        ALT (comma-joined alternate alleles), ANN, followed by one column per
        retained sample. Sample columns hold genotype strings produced by
        ``GenotypeArray.to_gt()`` (e.g. '0/1').
    """
    vcf_dict = allel.read_vcf(vcf_path, fields='*')
    samples = vcf_dict['samples']
    contig = vcf_dict['variants/CHROM']
    pos = vcf_dict['variants/POS']
    filter_pass = vcf_dict['variants/FILTER_PASS']
    ref = vcf_dict['variants/REF']
    # ALT comes back as a fixed-width 2-D array padded with ''; collapse each row to "A,C".
    alt = np.array(
        [','.join(a for a in row if a != '') for row in vcf_dict['variants/ALT']],
        dtype=object,
    )
    ann = vcf_dict['variants/ANN']
    geno = allel.GenotypeArray(vcf_dict['calldata/GT'])

    print(f"Initial number of samples: {len(samples)}")
    print(f"Initial number of SNPs: {geno.shape[0]}")

    # Build ONE mask over the VCF's own sample axis. Both queries are evaluated on
    # df_samples and mapped onto VCF samples by ID. A raw df_samples-length mask
    # would not line up with the (already subset) VCF samples.
    sample_mask = np.isin(samples, _query_sample_ids(df_samples, sample_query))
    if query2 is not None:
        sample_mask &= np.isin(samples, _query_sample_ids(df_samples, query2))

    geno = geno.compress(sample_mask, axis=1)
    samples = samples[sample_mask]

    print(f"Final number of samples after sample filtering: {len(samples)}")

    # Drop non-segregating sites first; they are reported separately below.
    seg_mask = geno.count_alleles(max_allele=3).is_segregating()
    geno_seg = geno.compress(seg_mask, axis=0)
    ac_seg = geno_seg.count_alleles(max_allele=3)

    miss_mask = geno_seg.is_missing().mean(axis=1) <= max_missing_filter
    alt_freq = ac_seg.to_frequencies()[:, 1:].sum(axis=1)
    maf = np.minimum(alt_freq, 1 - alt_freq)
    maf_mask = maf > min_maf_filter

    _print_filter_stats(
        stage_label='initial pre-split',
        total_snps=len(geno_seg),
        miss_mask=miss_mask,
        maf_mask=maf_mask,
        max_missing_filter=max_missing_filter,
        min_maf_filter=min_maf_filter,
        extra_removed=int((~seg_mask).sum()),
    )

    keep_mask_seg = miss_mask & maf_mask
    geno_final = geno_seg.compress(keep_mask_seg, axis=0)

    def _subset(values):
        # Apply the same two-stage (segregating, then QC) selection to per-variant arrays.
        return values[seg_mask][keep_mask_seg]

    vcf_df = pd.DataFrame(
        {
            'CHROM': _subset(contig),
            'POS': _subset(pos),
            'FILTER_PASS': _subset(filter_pass),
            'REF': _subset(ref),
            'ALT': _subset(alt),
            'ANN': _subset(ann),
        }
    )
    geno_df = pd.DataFrame(geno_final.to_gt().astype(str), columns=samples)
    vcf = pd.concat([vcf_df, geno_df], axis=1)

    print(f"Final DataFrame shape after initial filtering: {vcf.shape}")

    return vcf


def _query_sample_ids(df_samples, query):
    """Return ``sample_id`` values of rows in ``df_samples`` where ``query`` evaluates True.

    NaN/NA results (e.g. ``str.contains`` on a missing ID) are treated as False
    rather than raising on boolean indexing.
    """
    hits = pd.Series(df_samples.eval(query, engine='python'), index=df_samples.index)
    keep = hits.eq(True).to_numpy(dtype=bool, na_value=False)
    return df_samples.loc[keep, 'sample_id'].to_numpy()


def apply_final_snp_filters(vcf_df, samples, max_missing_filter=0.20, min_maf_filter=0.02, genotypes_are_alt_counts=True):
    """Apply post-split/post-encoding SNP-level missingness and MAF filters.

    Parameters
    ----------
    vcf_df : pd.DataFrame
        One row per SNP. The columns listed in ``samples`` hold either
        alternate-allele counts or genotype strings such as '0/1'.
    samples : sequence of str
        Genotype column names.
    max_missing_filter : float
        Rows whose fraction of missing genotypes exceeds this are dropped.
    min_maf_filter : float
        Rows with MAF <= this are dropped. MAF is computed from called
        genotypes as min(p, 1 - p), where p = alt count / (2 * n called).
        Rows with no called genotypes have NaN MAF and are dropped.
    genotypes_are_alt_counts : bool
        If False, genotype strings are first converted to alternate-allele counts.
        Any call containing a '.' allele is treated as missing.

    Returns
    -------
    pd.DataFrame
        Copy of the rows of ``vcf_df`` passing both filters (original index kept).
    """
    geno_df = vcf_df[samples].copy()

    if not genotypes_are_alt_counts:
        def _alt_count(gt):
            # Fully or partially missing calls -> NaN; otherwise count non-reference alleles.
            if not isinstance(gt, str) and pd.isna(gt):
                return np.nan
            alleles = str(gt).split('/')
            if '.' in alleles:
                return np.nan
            return sum(allele != '0' for allele in alleles)

        # DataFrame.applymap was deprecated in pandas 2.1 (renamed DataFrame.map) and
        # removed in 3.0. Use whichever this pandas version provides.
        elementwise = geno_df.map if hasattr(pd.DataFrame, 'map') else geno_df.applymap
        geno_df = elementwise(_alt_count)

    geno_df = geno_df.apply(pd.to_numeric, errors='coerce')

    miss_mask = geno_df.isna().mean(axis=1) <= max_missing_filter

    called_n = geno_df.notna().sum(axis=1)
    alt_count = geno_df.sum(axis=1, skipna=True)
    denom = 2 * called_n

    # Leave NaN where nothing was called, so those rows fail the MAF comparison.
    valid = denom > 0
    alt_freq = pd.Series(np.nan, index=geno_df.index, dtype=float)
    alt_freq.loc[valid] = alt_count.loc[valid] / denom.loc[valid]

    maf = pd.Series(np.nan, index=geno_df.index, dtype=float)
    maf.loc[valid] = np.minimum(alt_freq.loc[valid], 1 - alt_freq.loc[valid])
    maf_mask = maf > min_maf_filter

    _print_filter_stats(
        stage_label='final post-split',
        total_snps=len(vcf_df),
        miss_mask=miss_mask,
        maf_mask=maf_mask,
        max_missing_filter=max_missing_filter,
        min_maf_filter=min_maf_filter,
        extra_removed=None,
    )

    return vcf_df.loc[miss_mask & maf_mask].copy()


def split_rows_with_multiple_alleles(df, samples):
    """Split multi-allelic variant rows into one row per ALT allele.

    For the row produced for the k-th ALT allele (1-based), each genotype keeps
    alleles '0', '.' and 'k'. Any other ALT allele index is recoded to '0'.
    For example, with ALT 'T,A' the genotype '1/2' becomes '1/0' on the T row
    and '0/2' on the A row. Genotypes equal to './.' are left untouched, and
    partially-missing calls (e.g. '0/.') keep their '.'.

    Parameters
    ----------
    df : pd.DataFrame
        Variant rows with a comma-separated 'ALT' column and genotype-string columns.
    samples : sequence of str
        Names of the genotype columns.

    Returns
    -------
    pd.DataFrame
        New frame with a fresh RangeIndex and the same columns as ``df``
        (columns are preserved even when ``df`` is empty).
    """
    new_rows = []
    for _, row in df.iterrows():
        alt_alleles = row['ALT'].split(',')
        if len(alt_alleles) <= 1:
            new_rows.append(row)
            continue
        for allele_idx, allele in enumerate(alt_alleles, start=1):
            new_row = row.copy()
            new_row['ALT'] = allele
            # '.' must be kept as-is; previously int('.') raised ValueError.
            keep = {'0', '.', str(allele_idx)}
            for col in samples:
                genotype = row[col]
                if genotype != './.':
                    new_row[col] = '/'.join(gt if gt in keep else '0' for gt in genotype.split('/'))
            new_rows.append(new_row)

    return pd.DataFrame(new_rows, columns=df.columns).reset_index(drop=True)


def convert_genotype_to_alt_allele_count(df, samples):
    """
    Convert genotype strings to alternate-allele counts (0, 1, or 2 for diploid calls).

    Parameters:
    df : pd.DataFrame
        DataFrame containing genotype data, with genotypes in the format '0/0', '0/1', etc.
    samples : list
        List of sample IDs corresponding to genotype columns in the DataFrame.

    Returns:
    pd.DataFrame
        A copy of ``df`` in which each sample column holds the number of
        non-'0' alleles per genotype. Missing values become NaN: NaN cells, './.',
        and partially-missing calls such as '0/.'. The input is not modified.
    """
    nalt_df = df.copy()

    def _alt_count(genotype):
        if not isinstance(genotype, str) and pd.isna(genotype):
            return np.nan
        alleles = genotype.split('/')
        # A '.' allele is a missing call, not an alternate allele.
        if '.' in alleles:
            return np.nan
        return sum(allele != '0' for allele in alleles)

    for col in samples:
        nalt_df[col] = df[col].map(_alt_count)

    return nalt_df



#### Code - run GLM and process results

In [None]:
def calculate_pseudo_r2(results):
    """
    Compute pseudo R² measures for a fitted binomial GLM.

    Log-likelihoods are recovered from deviances as LL = deviance / -2. This
    treats the saturated log-likelihood as 0, which holds for ungrouped binary
    outcomes.

    Parameters:
    results : statsmodels GLMResults-like object
        Must expose ``null_deviance``, ``deviance`` and ``nobs``.

    Returns:
    dict : {'McFadden R²': float, 'Nagelkerke R²': float}
        Nagelkerke is Cox-Snell R² rescaled by its maximum attainable value.
    """
    n_obs = results.nobs
    loglik_null = results.null_deviance / -2
    loglik_fit = results.deviance / -2

    mcfadden = 1 - (loglik_fit / loglik_null)
    cox_snell = 1 - np.exp((2 / n_obs) * (loglik_null - loglik_fit))
    cox_snell_max = 1 - np.exp((2 / n_obs) * loglik_null)

    return {'McFadden R²': mcfadden, 'Nagelkerke R²': cox_snell / cox_snell_max}

def glm_all_snps(snp_df):
    """Fit one binomial GLM ``phenotype ~ snp`` per SNP column and collect effect sizes.

    Parameters
    ----------
    snp_df : pd.DataFrame
        One row per sample with a 'phenotype' column and genotype columns whose
        names contain "snp_". Rows missing either value are dropped per SNP.

    Returns
    -------
    pd.DataFrame
        Output of ``process_effect_sizes``: the per-SNP odds ratio, p-value,
        confidence interval and pseudo R² values, plus Benjamini-Hochberg
        'fdr' and 'fdr_sig' columns (alpha=0.05). SNPs whose p-value is NaN
        (e.g. a genotype column that is constant after dropping missing
        phenotypes) get fdr=NaN and fdr_sig=False. They are excluded from the
        correction instead of turning every FDR value into NaN.
    """
    import statsmodels.formula.api as smf
    import statsmodels.api as sm
    from statsmodels.stats.multitest import fdrcorrection

    dfs = []
    for snp in snp_df.filter(like="snp_").columns:
        # Explicit copy: we assign into this frame right after selecting from snp_df.
        glm_data = snp_df[['phenotype', snp]].dropna().copy()
        glm_data[snp] = glm_data[snp].astype(int)
        glm_data.columns = ['phenotype', 'snp']

        logit_model = smf.glm(formula='phenotype ~ snp', data=glm_data, family=sm.families.Binomial())
        log_results = logit_model.fit()

        pval_data = results_summary_to_dataframe(log_results)
        for r2_name, r2_value in calculate_pseudo_r2(log_results).items():
            pval_data[r2_name] = r2_value

        dfs.append(pval_data.assign(snp=snp))

    df_eff = pd.concat(dfs).query("index != 'Intercept'").copy()

    # fdrcorrection propagates a single NaN through its cumulative minimum, so
    # correct only the finite p-values.
    pvals = df_eff['pvals'].to_numpy(dtype=float)
    finite = np.isfinite(pvals)
    fdr = np.full(len(df_eff), np.nan)
    fdr_sig = np.zeros(len(df_eff), dtype=bool)
    if finite.any():
        reject, corrected = fdrcorrection(pvals[finite], alpha=0.05)
        fdr[finite] = corrected
        fdr_sig[finite] = reject
    df_eff['fdr'] = fdr
    df_eff['fdr_sig'] = fdr_sig
    return process_effect_sizes(df_eff)


def process_effect_sizes(df_eff, annotations=None):
    """Validate, annotate, and index per-SNP GLM effect sizes for downstream LD pruning/PGS.

    Parameters
    ----------
    df_eff : pd.DataFrame
        One row per SNP with a 'snp' column holding IDs of the form
        ``snp_<CHROM>:<POS>_<REF>><ALT>``, plus effect-size columns.
        The row index is ignored.
    annotations : pd.DataFrame, optional
        Annotation table with CHROM, POS, REF, ALT and ANN columns. Defaults to
        the module-level ``snp_data`` table, which preserves the previous behavior.

    Returns
    -------
    pd.DataFrame
        Indexed by the unique 'snp' ID. Contains the effect-size columns plus 'ANN'.
        The parsed CHROM/POS/REF/ALT helper columns are dropped.

    Raises
    ------
    AssertionError
        On duplicate or unparseable SNP IDs, missing annotation columns,
        duplicate annotation IDs, or SNPs without an annotation.
    """
    if annotations is None:
        annotations = snp_data

    # We expect one modeled row per SNP at this stage.
    assert 'snp' in df_eff.columns, "Expected 'snp' column in effect size dataframe."
    assert df_eff['snp'].is_unique, "Expected one GLM result per SNP; found duplicate SNP IDs before annotation merge."

    # When called from glm_all_snps the row labels are GLM term names (repeated 'snp').
    # Use a clean positional index so the column-wise concat below aligns row by row.
    df_eff = df_eff.reset_index(drop=True)

    # Parse SNP IDs (e.g. snp_2L:12345_A>G) back into merge keys.
    parsed = df_eff['snp'].str.extract(r"^snp_(?P<CHROM>[^:]+):(?P<POS>\d+)_(?P<REF>[^>]+)>(?P<ALT>.+)$")
    assert parsed.notna().all().all(), (
        "Unable to parse one or more SNP IDs. Expected format: snp_<CHROM>:<POS>_<REF>><ALT>."
    )

    df_eff = pd.concat([df_eff, parsed], axis=1)
    df_eff['POS'] = df_eff['POS'].astype(int)

    # Ensure the annotation table has the columns needed to reconstruct SNP IDs.
    required_ann_cols = {'CHROM', 'POS', 'REF', 'ALT', 'ANN'}
    missing_ann_cols = required_ann_cols.difference(annotations.columns)
    assert not missing_ann_cols, f"snp_data is missing required columns: {sorted(missing_ann_cols)}"

    # Build annotation SNP IDs using the same convention as GLM results.
    ann_df = annotations[list(required_ann_cols)].copy()
    ann_df = ann_df.assign(
        snp=lambda x: "snp_" + x.CHROM.astype(str) + ":" + x.POS.astype(int).astype(str) + "_" + x.REF.astype(str) + ">" + x.ALT.astype(str)
    )

    # Guard against ambiguous annotation lookups.
    duplicate_ann = ann_df['snp'].duplicated(keep=False)
    assert not duplicate_ann.any(), (
        "Annotation table contains duplicate SNP IDs even after using CHROM+POS+REF+ALT. "
        f"Example duplicates: {ann_df.loc[duplicate_ann, 'snp'].head(5).tolist()}"
    )

    # Merge ANN onto effect sizes and verify it is strictly one-to-one.
    pre_merge_rows = len(df_eff)
    df_eff = df_eff.merge(
        ann_df[['snp', 'ANN']],
        on='snp',
        how='left',
        validate='one_to_one'
    )
    assert len(df_eff) == pre_merge_rows, "Annotation merge changed row count unexpectedly."

    # Fail fast if any scored SNP lacks annotation.
    missing_ann = df_eff['ANN'].isna()
    assert not missing_ann.any(), (
        "Missing annotation for one or more SNPs after merge. "
        f"Example SNPs: {df_eff.loc[missing_ann, 'snp'].head(5).tolist()}"
    )

    # Final shape for downstream functions: SNP ID index + effect/annotation columns.
    cols = ['snp'] + [col for col in df_eff.columns if col != 'snp']
    df_eff = df_eff[cols].drop(columns=['CHROM', 'POS', 'REF', 'ALT'])
    df_eff = df_eff.set_index('snp')
    assert df_eff.index.is_unique, "Effect size index must be unique before LD pruning."
    return df_eff


def results_summary_to_dataframe(results):
    '''Turn a fitted statsmodels results object into a per-term summary DataFrame.

    Columns: odds_ratio (exp of the coefficient), pvals, conf_lower/conf_higher
    (exp of the confidence bounds from ``results.conf_int()``), and sig
    (p-value <= 0.05). The index is the model term names.
    '''
    bounds = results.conf_int()
    summary = pd.DataFrame(
        {
            "odds_ratio": np.exp(results.params),
            "pvals": results.pvalues,
            "conf_lower": np.exp(bounds[0]),
            "conf_higher": np.exp(bounds[1]),
        }
    )
    summary["sig"] = summary["pvals"] <= 0.05
    return summary[["odds_ratio", "pvals", "conf_lower", "conf_higher", "sig"]]

In [None]:
# Load QC-passed sample metadata. Normalise the ID column name BEFORE using it;
# otherwise a file with a `sampleID` column fails on the phenotype extraction below.
df_amp_samples = pd.read_csv("../../../results/config/metadata.qcpass.tsv", sep="\t")
df_amp_samples = df_amp_samples.rename(columns={'sampleID': 'sample_id'})
# The phenotype label is embedded in the sample ID ("..._Dead_..." / "..._Alive_...").
# IDs without either label get NaN.
df_amp_samples = df_amp_samples.assign(
    phenotype=df_amp_samples.sample_id.str.extract(r'_(Dead|Alive)_', expand=False)
)

In [None]:
# Display the (rows, columns) dimensions of the loaded sample metadata.
df_amp_samples.shape

### Perform association tests with a separate binomial GLM per SNP

In [None]:
# Load the SNP annotation table. `process_effect_sizes` reads it as the global
# `snp_data` and requires CHROM, POS, REF, ALT and ANN columns.
# NOTE(review): this path uses `results/vcf/...` while the VCFs below are read from
# `results/vcfs/...` — confirm both directories are intended.
snp_data = pd.read_excel("../../../results/vcf/amplicons/natcomms-snps.xlsx")

In [None]:
# Preview the first two rows of the annotation table.
snp_data.head(2)

In [None]:
import warnings

# Silence noisy warnings raised inside pandas arithmetic and scikit-allel's VCF reader.
warnings.filterwarnings("ignore", category=RuntimeWarning, module="pandas.core.arraylike")
warnings.filterwarnings("ignore", category=UserWarning, module="allel.io")

# Per-method outputs: genotype/phenotype tables and GLM effect sizes.
df_genos = {}
df_effs = {}
# Another method, 'targets', can be added to this list.
methods = ['amplicons']

for method in methods:
    # Build the samples-by-SNPs table for Siaya samples from this method's annotated VCF.
    df_genos[method] = vcf_to_glm_data(
        vcf_path=f"../../../results/vcfs/{method}/natcomms.annot.vcf",
        df_samples=df_amp_samples,
        sample_query='sample_id.str.contains("Siaya")',
        min_maf_filter=0.02,
        max_missing_filter=0.20,
    )
    print("\n")

    # One binomial GLM per SNP; write the annotated effect sizes to CSV.
    df_effs[method] = glm_all_snps(df_genos[method])
    df_effs[method].to_csv(f"../../../results/glm-siaya-effect-size-{method}.csv")