In [None]:
import plotly.express as px
import allel
import numpy as np
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

def vcf_to_snp_dataframe(vcf_path, metadata, platform):
    """Load an annotated VCF and return one row per (variant, alternate allele).

    Parameters
    ----------
    vcf_path : str
        Path to the annotated VCF; forwarded to ``amp.load_vcf``.
    metadata : pandas.DataFrame
        Sample metadata; forwarded to ``amp.load_vcf``.
    platform : str
        Sequencing platform; forwarded to ``amp.load_vcf``.

    Returns
    -------
    snp_df : pandas.DataFrame
        Columns ``variant_index`` (row of the variant in the loaded VCF arrays),
        ``contig``, ``pos``, ``ref``, ``alt``, ``ann``, ``alt_index`` (1-based
        index of the alt allele within its own VCF record) and ``label``.
        ``ann`` holds only the comma-joined annotation entries whose allele
        field equals ``alt`` ("" when there is none).
    geno
        The genotype object returned by ``amp.load_vcf``, unchanged.
    """
    # Imported at call time: the notebook only adds the workflow directory to
    # sys.path in a later cell than the one defining this function.
    import ampseekertools as amp

    geno, pos, contig, metadata, ref, alt, ann = amp.load_vcf(vcf_path=vcf_path, metadata=metadata, platform=platform)

    # One row per VCF record; empty strings in the alt array are padding.
    snp_df = pd.DataFrame({
        'contig': contig,
        'pos': pos,
        'ref': ref,
        'alt': [list(a[a != ""]) for a in alt],
        'ann': ann,
    })

    # One row per alternate allele; keep the original record number.
    snp_df = snp_df.explode('alt').reset_index().rename(columns={'index': 'variant_index'})
    # Number alt alleles within each VCF record. Grouping by record (not by
    # contig/pos) keeps the index valid when two records share a position,
    # because alt_index is later used to pick an allele-count column.
    snp_df = snp_df.assign(alt_index=snp_df.groupby('variant_index').cumcount() + 1)
    snp_df = snp_df.assign(label=lambda x: x.pos.astype(str) + " | " + x.alt.fillna('NA'))

    # Keep, for every row, only the annotation entries that describe its alt allele.
    anns = []
    for alt_allele, raw_ann in zip(snp_df['alt'], snp_df['ann']):
        # Missing annotation (None/NaN) or a record without alt alleles (NaN
        # after explode): nothing to match.
        if not isinstance(raw_ann, str) or not isinstance(alt_allele, str):
            anns.append("")
            continue

        candidates = raw_ann.split(",")
        # keep only RD Vgsc annotations
        if 'AGAP004707' in raw_ann:
            candidates = [a for a in candidates if "AGAP004707-RD" in a]

        # Match the whole allele field ("T|...") so that alt "T" does not also
        # pick up annotations belonging to a longer allele such as "TA".
        allele_prefix = alt_allele + "|"
        anns.append(','.join(a for a in candidates if a.startswith(allele_prefix)))

    snp_df = snp_df.assign(ann=anns)

    return snp_df, geno

def calculate_frequencies_cohort(snp_df, metadata, geno, cohort_col, af_filter, missense_filter):
    """Compute per-cohort alternate-allele frequencies for every row of ``snp_df``.

    Parameters
    ----------
    snp_df : pandas.DataFrame
        One row per (variant, alt allele) with columns ``variant_index``,
        ``alt_index``, ``label``, ``contig``, ``pos``, ``ref``, ``alt`` and
        ``ann`` (pipe-separated annotation string).
    metadata : pandas.DataFrame
        Sample metadata containing ``cohort_col``. Row positions are used as
        sample indices into ``geno`` — assumes metadata rows are in the same
        order as the genotype samples (TODO confirm against ``amp.load_vcf``).
    geno
        Object exposing ``count_alleles_subpops(subpops, max_allele=...)`` that
        returns, per cohort, a (variants x alleles) count array.
    cohort_col : str
        Metadata column defining the cohorts; missing values are skipped.
    af_filter : bool or float
        Falsy: no frequency filter. ``True``: keep rows whose frequency exceeds
        0.05 in at least one cohort. A number: use it as the threshold instead.
    missense_filter : bool
        If true, keep only rows annotated as ``missense_variant``.

    Returns
    -------
    pandas.DataFrame
        Indexed by a descriptive ``label``; columns are ``contig``, ``pos``,
        ``ref``, ``alt``, the parsed annotation fields and one ``frq_<cohort>``
        column per cohort.
    """
    # Resolve the frequency threshold. Previously any truthy value meant a
    # hard-coded 0.05, silently ignoring numeric thresholds passed by callers.
    if not af_filter:
        af_threshold = None
    elif isinstance(af_filter, (bool, np.bool_)):
        af_threshold = 0.05
    else:
        af_threshold = float(af_filter)

    df = snp_df.copy()

    # get indices of each cohort
    cohort_values = metadata[cohort_col]
    cohs = cohort_values.unique()
    cohs = cohs[~pd.isnull(cohs)]
    coh_dict = {coh: np.where(cohort_values == coh)[0] for coh in cohs}

    # get allele counts for each population
    ac = geno.count_alleles_subpops(coh_dict, max_allele=3)

    var_idx = df['variant_index'].to_numpy(dtype=int)
    alt_idx = df['alt_index'].to_numpy(dtype=int)

    # Division by zero (a cohort with no called alleles at a site) yields NaN;
    # silence the numpy warnings locally instead of changing global state.
    with np.errstate(all="ignore"):
        for coh in cohs:
            counts = np.asarray(ac[coh])
            alt_counts = counts[var_idx, alt_idx]
            total_counts = counts[var_idx, :].sum(axis=1)
            df[f'count_{coh}'] = alt_counts
            df[f'frq_{coh}'] = np.round(alt_counts / total_counts, 3)

    freq_df = df.set_index('label').filter(like='frq')

    # Split the annotation into its pipe-separated fields. Reindexing to 11
    # columns pads short/empty annotations with NaN instead of raising on the
    # drop below; fields 0 (allele), 7 and 8 are not reported.
    ann_df = (
        snp_df['ann'].str.split("|", expand=True)
        .reindex(columns=range(11))
        .astype(object)
        .drop(columns=[0, 7, 8])
    )
    ann_df.columns = ['type', 'effect', 'gene', 'geneID', 'modifier', 'transcript', 'base_change', 'aa_change']
    snp_df = pd.concat([snp_df[['contig', 'pos', 'ref', 'alt']], ann_df], axis=1)
    snp_freq_df = pd.concat([snp_df, freq_df.reset_index()], axis=1)

    # regex=False: "p." must be removed literally. As a regex (the default in
    # pandas < 2.0) it would also strip e.g. the "p1" inside "Asp1874".
    snp_freq_df = snp_freq_df.assign(label=
                  lambda x: x.contig + " | " + x.gene + " | " + x.pos.astype(str) + " | " + x.aa_change.str.replace("p.", "", regex=False) + " | " + x.alt.fillna(" ")
                 )

    if af_threshold is not None:
        af_pass = (snp_freq_df.filter(like='frq') > af_threshold).any(axis=1)
        snp_freq_df = snp_freq_df[af_pass]

    if missense_filter:
        snp_freq_df = snp_freq_df.query("type == 'missense_variant'")

    return snp_freq_df.set_index('label')

def plot_allele_frequencies(df, cohort_col, colscale="Reds"):
    """Draw a table of allele frequencies as an annotated heatmap.

    Parameters
    ----------
    df : pandas.DataFrame
        Two-dimensional table of frequencies; rows and columns are plotted as given.
    cohort_col : str
        Name of the cohort column, shown in the figure title.
    colscale : str, default "Reds"
        Continuous colour scale passed to plotly.

    Returns
    -------
    The plotly figure, with the colour bar hidden.
    """
    n_rows, n_cols = df.shape

    # The figure grows with the table: 100 px per column (at least 800 px wide)
    # and 18 px per row on top of a 200 px base height.
    heatmap_options = dict(
        img=df,
        zmin=0,
        zmax=1,
        width=np.max([800, n_cols * 100]),
        height=200 + (n_rows * 18),
        text_auto=True,
        # NOTE(review): the value 1 is passed to px.imshow as-is — confirm intent
        aspect=1,
        color_continuous_scale=colscale,
        title=f"Allele frequencies | by {cohort_col}",
        template='simple_white',
    )
    heatmap = px.imshow(**heatmap_options)

    # Cell values are already printed on the heatmap, so hide the colour bar.
    heatmap.update_layout(coloraxis_showscale=False)

    return heatmap

In [None]:
# Notebook parameters.
# NOTE(review): presumably overridden at execution time by the workflow — confirm.
dataset = 'nomads16'  # dataset name; used to build the VCF file names
metadata_path = "../../results/config/metadata.qcpass.tsv"  # tab-separated sample metadata
cohort_cols = 'library_name'  # comma-separated metadata column name(s) that define cohorts
bed_path = "../../config/nomads16.bed"  # headerless, tab-separated table of targets
vcf_path = f"../../results/vcfs/targets/{dataset}.annot.vcf"  # annotated VCF of the targets
wkdir = "../.."  # base directory; workflow/ and results/ are resolved relative to it
platform = 'nanopore'  # passed through to amp.load_vcf

In [None]:
import sys
import os
# Make <wkdir>/workflow importable so that ampseekertools can be found.
sys.path.append(os.path.join(wkdir, 'workflow'))
import ampseekertools as amp

### Plotting allele frequencies

This page shows, for each cohort, the allele frequencies of the SNPs genotyped by the amplicon sequencing protocol. The allele frequency of a variant is the proportion of all allele copies sampled in a cohort that carry that variant.

In [None]:
# cohort_cols arrives as a comma-separated string. Only split it when it is
# still a string, so that re-running this cell does not fail on `list.split`;
# surrounding whitespace and empty entries are dropped.
if isinstance(cohort_cols, str):
    cohort_cols = [col.strip() for col in cohort_cols.split(",") if col.strip()]

# Headerless, tab-separated target table.
df_bed = pd.read_csv(bed_path, sep="\t", header=None, names=['contig', 'start', 'end', 'amplicon_id', 'mutation', 'ref', 'alt'])

# Tab-separated sample metadata.
metadata = pd.read_csv(metadata_path, sep="\t")

In [None]:
# Positions (BED 'end' column) of the targets whose mutation name does not
# contain 'AIM'. na=False treats a missing mutation name as "not AIM" instead
# of failing when negating NaN.
is_aim = df_bed['mutation'].str.contains('AIM', na=False)
non_aim_snps = df_bed.loc[~is_aim, 'end'].to_numpy()
snp_df, geno = vcf_to_snp_dataframe(vcf_path, metadata, platform=platform)

maf_threshold = 0.02

frq_dfs = []
vgsc_frq_dfs = []
for cohort_col in cohort_cols:

    freq_df = calculate_frequencies_cohort(
        snp_df=snp_df,
        metadata=metadata,
        geno=geno,
        cohort_col=cohort_col,
        af_filter=maf_threshold,
        missense_filter=False
    )

    # Order rows by chromosome arm, then by position.
    freq_df = freq_df.assign(
        contig=pd.Categorical(freq_df['contig'], categories=['2R', '2L', '3R', '3L', 'X'], ordered=True)
    ).sort_values(by=['contig', 'pos'])

    # Vgsc (gene 'para') variants are shown in their own heatmap.
    vgsc_freq_df = freq_df.query("gene == 'para'")
    para = not vgsc_freq_df.empty
    freq_df = freq_df.query("gene != 'para' and pos in @non_aim_snps")
    frq_dfs.append(freq_df.reset_index(drop=True))

    if freq_df.empty:
        print(f"No variants found after filtering for cohort {cohort_col} at maf > {maf_threshold}")
    else:
        fig = plot_allele_frequencies(
            df=freq_df.filter(like='frq_'),
            cohort_col=cohort_col,
            colscale="Reds"
        )
        fig.write_image(f"{wkdir}/results/allele_frequencies_{cohort_col}.png", scale=2)
        fig.show()

    # Plotted even when no other variant passed the filters above.
    if para:
        vgsc_frq_dfs.append(vgsc_freq_df.reset_index(drop=True))
        fig1 = plot_allele_frequencies(
            df=vgsc_freq_df.filter(like='frq_'),
            cohort_col=cohort_col,
            colscale="Oranges"
        )
        # Save the Vgsc figure itself (fig1), not the general heatmap.
        fig1.write_image(f"{wkdir}/results/allele_frequencies_{cohort_col}_vgsc.png", scale=2)
        fig1.show()

#### SNP frequency summary table

This table summarizes allele frequencies across all cohorts.

In [None]:
pd.set_option("display.max_rows", 200)
pd.set_option('display.max_columns', 100)

# Vgsc tables first, then the remaining targets. Concatenating whatever was
# collected avoids relying on the `para` flag left over from the last loop
# iteration, which could drop Vgsc tables from earlier cohort columns or try to
# concatenate an empty list.
summary_frames = [*vgsc_frq_dfs, *frq_dfs]
snp_df = pd.concat(summary_frames)

snp_df.to_csv(f"{wkdir}/results/snp_frequencies_summary.tsv", sep="\t")
snp_df

#### Allele frequencies of any SNPs across amplicons

This heatmap shows the allele frequencies of missense mutations found anywhere in the sequenced amplicons, not only at the targeted SNPs. It focuses on these variants because they change the amino acid sequence and may therefore affect protein function.

In [None]:
# Whole-amplicon VCF (all variants in the amplicons, not only the targets).
vcf_path = f"{wkdir}/results/vcfs/amplicons/{dataset}.annot.vcf"
# Only the first cohort column is used for this heatmap.
cohort_col = cohort_cols[0]
amplicon_af_threshold = 0.05

snp_df, geno = vcf_to_snp_dataframe(vcf_path, metadata, platform=platform)

snp_freq_df = calculate_frequencies_cohort(
    snp_df=snp_df,
    metadata=metadata,
    geno=geno,
    cohort_col=cohort_col,
    af_filter=amplicon_af_threshold,
    missense_filter=True
)

# Keep only the frequency columns and label them with the bare cohort name.
snp_freq_df = (
    snp_freq_df
    .filter(like='frq')
    .rename(columns=lambda col: col.replace("frq_", ""))
)

# Guard against an empty table, as the target-SNP loop does, instead of
# handing an empty frame to the plotting function.
if snp_freq_df.empty:
    print(f"No missense variants found after filtering for cohort {cohort_col} at allele frequency > {amplicon_af_threshold}")
else:
    plot_allele_frequencies(
        df=snp_freq_df,
        cohort_col=cohort_col
    ).show()