In [None]:
import plotly.express as px
import allel
import numpy as np
import pandas as pd

def load_vcf(vcf_path, metadata):
    """
    Load a VCF, keep only the samples listed in ``metadata`` and drop indels.

    Parameters
    ----------
    vcf_path : str
        Path to the VCF file. The ANN INFO field is read separately by
        ``read_ANN_field``.
    metadata : pandas.DataFrame
        Sample metadata with a ``sampleID`` column (QC-passed samples).

    Returns
    -------
    tuple
        ``(geno, pos, contig, samples, ref, alt, ann)``.
        ``geno`` is an ``allel.GenotypeArray`` whose sample axis follows the row
        order of ``metadata``, restricted to samples present in the VCF.
        ``samples`` holds the matching sample IDs in the same order. ``pos``,
        ``contig``, ``ref``, ``alt`` and ``ann`` are per-variant arrays aligned
        with the variant axis of ``geno``.

    Raises
    ------
    ValueError
        If scikit-allel finds no variant records in the file.
    """
    import warnings

    vcf = allel.read_vcf(vcf_path, fields='*')
    if vcf is None:
        # scikit-allel returns None instead of empty arrays when nothing is read.
        raise ValueError(f"No variants found in VCF: {vcf_path}")

    # Downstream cohort lookups index genotypes by metadata *row position*, so
    # the sample axis must follow metadata order rather than VCF order.
    vcf_samples = vcf['samples']
    column_of = {sample: i for i, sample in enumerate(vcf_samples)}
    wanted = metadata['sampleID'].to_list()
    keep = np.array([column_of[s] for s in wanted if s in column_of], dtype=int)
    n_missing = len(wanted) - len(keep)
    if n_missing:
        warnings.warn(
            f"{n_missing} metadata sample(s) not found in {vcf_path}; genotype "
            "columns will not line up one-to-one with metadata rows."
        )

    geno = allel.GenotypeArray(vcf['calldata/GT']).take(keep, axis=1)

    # Remove indels: every per-variant array gets the same mask.
    is_snp = ~vcf['variants/INDEL']
    geno = geno.compress(is_snp, axis=0)
    pos = vcf['variants/POS'][is_snp]
    contig = vcf['variants/CHROM'][is_snp]
    ref = vcf['variants/REF'][is_snp]
    alt = vcf['variants/ALT'][is_snp]
    ann = read_ANN_field(vcf_path)[is_snp]

    return geno, pos, contig, vcf_samples[keep], ref, alt, ann

def read_ANN_field(vcf_file):
    """
    Extract the raw SnpEff ``ANN`` INFO value of every variant record.

    Records are read in file order so the result can be indexed with the same
    masks as the variant arrays from ``allel.read_vcf``. This assumes no
    region or other record filtering is applied there.

    Plain-text and gzip/bgzip-compressed (``.gz``) files are supported, and
    blank lines are ignored.

    Parameters
    ----------
    vcf_file : str or path-like
        Path to the VCF.

    Returns
    -------
    numpy.ndarray
        One element per record: everything after ``ANN=`` in the INFO column,
        or ``None`` when the record has no ANN entry.
    """
    import gzip

    opener = gzip.open if str(vcf_file).endswith('.gz') else open
    anns = []
    with opener(vcf_file, 'rt') as f:
        for line in f:
            # Skip header lines and any blank lines (e.g. a trailing newline).
            if line.startswith('#') or not line.strip():
                continue
            info_field = line.strip().split('\t')[7]
            ann_value = None
            for pair in info_field.split(';'):
                if pair.startswith('ANN='):
                    # Slice rather than split on '=': ANN text itself may contain
                    # '=' (e.g. synonymous HGVS such as 'p.Leu1=').
                    ann_value = pair[len('ANN='):]
                    break
            anns.append(ann_value)

    return np.array(anns)

def _select_annotation(ann, alt):
    """
    Pick the SnpEff ANN entries of one variant that describe allele ``alt``.

    Parameters
    ----------
    ann : str or missing
        The raw comma-separated ANN value of the variant.
    alt : str or missing
        One alternate allele. This is missing when the variant has no
        alternate allele.

    Returns
    -------
    str
        The matching entries re-joined with commas, or ``""`` when either
        input is missing or nothing matches.
    """
    if not isinstance(ann, str) or not isinstance(alt, str):
        return ""
    entries = ann.split(",")
    # For Vgsc (AGAP004707) keep only annotations against the RD transcript.
    if 'AGAP004707' in ann:
        entries = [e for e in entries if "AGAP004707-RD" in e]
    # The first '|'-separated sub-field of an ANN entry is its allele. Compare it
    # exactly so that, for example, allele 'A' does not also match an 'AT' entry.
    return ','.join(e for e in entries if e.split('|', 1)[0] == alt)


def vcf_to_snp_dataframe(vcf_path, metadata, bed_df=None):
    """
    Build a per-alternate-allele SNP table restricted to the targets in a BED table.

    Parameters
    ----------
    vcf_path : str
        Annotated VCF passed to ``load_vcf``.
    metadata : pandas.DataFrame
        Sample metadata passed to ``load_vcf``.
    bed_df : pandas.DataFrame, optional
        Target table with at least ``contig``, ``pos`` and ``target_id``
        columns. Defaults to the notebook-level ``bed_df`` for backward
        compatibility.

    Returns
    -------
    snp_df : pandas.DataFrame
        One row per (variant, alternate allele), with these extra columns:
        ``variant_index`` (row of ``geno``), ``alt_index`` (1-based allele
        index within the variant), ``label`` (``"<target_id> | <alt>"``), and
        ``ann`` (the ANN entries for that allele).
    geno : allel.GenotypeArray
        Genotypes whose variant axis is aligned with ``variant_index``.
    """
    if bed_df is None:
        bed_df = globals()['bed_df']

    geno, pos, contig, samples, ref, alt, ann = load_vcf(vcf_path=vcf_path, metadata=metadata)

    snp_df = pd.DataFrame({
        'contig': contig,
        'pos': pos,
        'ref': ref,
        'alt': [list(a[a != ""]) for a in alt],
        'ann': ann,
        # Remember each variant's row in ``geno``: the merge below may drop,
        # reorder or duplicate rows, and ``geno`` must follow it exactly.
        '_vcf_row': np.arange(len(pos)),
    })
    snp_df = snp_df.merge(bed_df, on=['contig', 'pos'])

    # Filter out AIM targets (will only do anything if ag-vampir).
    not_aim = ~snp_df['target_id'].str.contains('AIM', na=False).to_numpy(dtype=bool)
    geno = geno.take(snp_df['_vcf_row'].to_numpy()[not_aim], axis=0)
    snp_df = snp_df.loc[not_aim].drop(columns='_vcf_row').reset_index(drop=True)

    # Split multiallelic rows, keeping the variant row index for genotype lookup.
    snp_df = (
        snp_df.explode('alt')
        .reset_index()
        .rename(columns={'index': 'variant_index'})
    )
    # Number alternates within each *variant*, not within each target (a target
    # may hold several variants), so alt_index matches the allele code in geno.
    snp_df = snp_df.assign(alt_index=snp_df.groupby('variant_index').cumcount() + 1)
    snp_df = snp_df.assign(label=lambda x: x.target_id + " | " + x.alt.fillna('NA'))

    # Keep only the annotations that belong to each row's alternate allele.
    snp_df = snp_df.assign(
        ann=[_select_annotation(a, allele) for a, allele in zip(snp_df['ann'], snp_df['alt'])]
    )

    return snp_df, geno

def calculate_frequencies_cohort(snp_df, metadata, geno, cohort_col):
    """
    Compute the frequency of each alternate allele in every cohort.

    Parameters
    ----------
    snp_df : pandas.DataFrame
        One row per alternate allele. Must have ``variant_index`` (row in
        ``geno``), ``alt_index`` (allele index) and ``label`` columns. It is
        not modified.
    metadata : pandas.DataFrame
        Sample metadata. Row *i* must describe sample column *i* of ``geno``,
        because cohort membership is resolved by row position.
    geno : allel.GenotypeArray
        Genotype calls with samples on axis 1.
    cohort_col : str
        Metadata column defining the cohorts. Missing values are ignored.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``label``, with one column per cohort (named ``str(cohort)``).
        Each value is the allele frequency rounded to 3 decimals, or NaN where
        the cohort has no calls at that variant.
    """
    n_samples = geno.shape[1]
    if n_samples != len(metadata):
        import warnings
        warnings.warn(
            f"geno has {n_samples} samples but metadata has {len(metadata)} rows; "
            "cohorts are assigned by metadata row position and may be misaligned."
        )

    cohorts = metadata[cohort_col].unique()
    cohorts = cohorts[~pd.isnull(cohorts)]
    subpops = {
        coh: np.flatnonzero((metadata[cohort_col] == coh).to_numpy())
        for coh in cohorts
    }
    allele_counts = geno.count_alleles_subpops(subpops)

    var_idx = snp_df['variant_index'].to_numpy()
    alt_idx = snp_df['alt_index'].to_numpy()

    freqs = {}
    # Cohorts with no calls give 0/0; scope the warning suppression to this
    # block instead of changing numpy's global error state.
    with np.errstate(divide='ignore', invalid='ignore'):
        for coh in cohorts:
            counts = np.asarray(allele_counts[coh])
            alt_counts = counts[var_idx, alt_idx]
            total_counts = counts[var_idx, :].sum(axis=1)
            freqs[str(coh)] = np.round(alt_counts / total_counts, 3)

    return pd.DataFrame(freqs, index=pd.Index(snp_df['label'].to_numpy(), name='label'))

def plot_allele_frequencies(df, cohort_col):
    """
    Display an allele-frequency heatmap for one cohort grouping.

    Parameters
    ----------
    df : pandas.DataFrame
        Frequencies in [0, 1]; rows are drawn as heatmap rows and columns as
        heatmap columns (as produced by ``calculate_frequencies_cohort``:
        SNP labels by cohorts).
    cohort_col : str
        Name of the grouping, used only in the title.
    """
    n_rows, n_cols = df.shape
    fig = px.imshow(
        img=df,
        zmin=0,
        zmax=1,
        width=max(400, n_cols * 100),
        # 20 px per row, but never so short that plotly's title and margins
        # leave no room for the heatmap (or an invalid 0 height when df is empty).
        height=max(300, n_rows * 20),
        text_auto=True,
        # Let cells stretch to fill the figure rather than forcing square cells.
        aspect='auto',
        color_continuous_scale="Reds",
        title=f"Allele frequencies | by {cohort_col}",
        template='simple_white',
    )
    fig.update(layout_coloraxis_showscale=False)

    fig.show()

In [None]:
# Notebook parameters. These literals look like defaults that an external
# runner may override (e.g. papermill) -- NOTE(review): confirm.
dataset = 'ampseq-vigg-01'  # not referenced in the visible cells of this notebook
metadata_path = "../../results/config/metadata.qcpass.tsv"  # .xlsx, .tsv or .csv are accepted below
cohort_cols = 'taxon,location'  # comma-separated metadata columns to group samples by
bed_path = "../../config/ag-vampir.bed"  # target BED: contig, start, pos, amplicon_id, target_id
# NOTE(review): file name uses 'ampseq-vigg01' while dataset is 'ampseq-vigg-01' -- confirm intended.
vcf_path = "../../results/vcfs/targets/ampseq-vigg01.annot.vcf"  # annotated VCF of target SNPs
wkdir = "../.."  # base directory; the summary table is written to {wkdir}/results/

### Plotting allele frequencies

This page shows, for each cohort, the allele frequencies of the SNPs genotyped by the amplicon sequencing protocol.

In [None]:
# The parameter cell supplies cohort columns as one comma-separated string.
# Only split it if that has not happened yet, so this cell can be re-run.
if isinstance(cohort_cols, str):
    cohort_cols = [col.strip() for col in cohort_cols.split(",") if col.strip()]

# The BED end coordinate is named 'pos' so it merges with VCF POS. This equals
# the 1-based position only for one-base target intervals -- assumed here.
bed_df = pd.read_csv(bed_path, sep="\t", header=None, names=['contig', 'start', 'pos', 'amplicon_id', 'target_id'])

# load metadata
if metadata_path.endswith('.xlsx'):
    metadata = pd.read_excel(metadata_path, engine='openpyxl')
elif metadata_path.endswith('.tsv'):
    metadata = pd.read_csv(metadata_path, sep="\t")
elif metadata_path.endswith('.csv'):
    metadata = pd.read_csv(metadata_path, sep=",")
else:
    raise ValueError("Metadata file must be .xlsx, .tsv or .csv")

# Fail early with a clear message rather than a KeyError deep in the analysis.
missing_cols = [col for col in cohort_cols if col not in metadata.columns]
if missing_cols:
    raise ValueError(f"Cohort column(s) not found in metadata: {missing_cols}")

In [None]:
# The variant table and genotypes are shared by every cohort grouping below.
snp_df, geno = vcf_to_snp_dataframe(vcf_path, metadata)

# One frequency table (and heatmap) per requested metadata column.
frq_dfs = []
for cohort_col in cohort_cols:
    freq_df = calculate_frequencies_cohort(snp_df, metadata, geno, cohort_col)
    # Drop the label index so the tables concatenate positionally with the
    # per-allele rows of snp_df in the summary table.
    frq_dfs.append(freq_df.reset_index(drop=True))
    plot_allele_frequencies(freq_df, cohort_col)

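### SNP frequency summary table

One row per SNP alternate allele, with its SnpEff annotation and its frequency in each cohort. The table is also saved to `results/snp_frequencies_summary.tsv`.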

In [None]:
pd.set_option("display.max_rows", 200)
pd.set_option('display.width', 1000)

# Split the selected SnpEff ANN string into its '|'-separated sub-fields.
# Only the first annotation's fields (0-10) are used. Reindexing guarantees
# 11 columns even when no row has an annotation, so the positional drop/rename
# below cannot fail.
ann_df = (
    snp_df['ann']
    .str.split("|", expand=True)
    .reindex(columns=range(11))
    .drop(columns=[0, 7, 8])
)
ann_df.columns = ['type', 'effect', 'gene', 'geneID', 'modifier', 'transcript', 'base_change', 'aa_change']

# Build the summary under a new name so snp_df is left intact and the cell
# can be re-run.
summary_df = pd.concat(
    [snp_df[['contig', 'pos', 'ref', 'alt', 'target_id']], ann_df] + frq_dfs,
    axis=1,
)
summary_df.to_csv(f"{wkdir}/results/snp_frequencies_summary.tsv", sep="\t")
summary_df