In [None]:
import plotly.express as px
import allel
import numpy as np
import pandas as pd

def load_vcf(vcf_path, metadata):
    """
    Load genotypes from a VCF, keeping only metadata samples and non-indel sites.

    Parameters
    ----------
    vcf_path : str
        Path to the VCF file, read with ``allel.read_vcf``.
    metadata : pandas.DataFrame
        Sample metadata; only VCF samples listed in its ``sampleID`` column
        are retained (e.g. the QC-pass samples).

    Returns
    -------
    geno : allel.GenotypeArray
        Genotypes of the retained variants (axis 0) and samples (axis 1).
    pos : numpy.ndarray
        POS of each retained variant.
    contig : numpy.ndarray
        CHROM of each retained variant.
    samples : numpy.ndarray
        IDs of the retained samples, in VCF order, i.e. the same order as
        the sample axis of ``geno``.

    Raises
    ------
    ValueError
        If the VCF contains no variant records.
    """
    vcf = allel.read_vcf(vcf_path, fields='*')
    if vcf is None:
        # read_vcf returns None rather than a dict when there are no variant records
        raise ValueError(f"No variant records found in {vcf_path}")

    samples = vcf['samples']
    # keep only samples present in the (QC-pass) metadata
    sample_mask = np.isin(samples, metadata['sampleID'])
    # drop indels, flagged by the INDEL field of the VCF
    snp_mask = ~vcf['variants/INDEL']

    geno = allel.GenotypeArray(vcf['calldata/GT'])
    geno = geno.compress(sample_mask, axis=1).compress(snp_mask, axis=0)

    return (
        geno,
        vcf['variants/POS'][snp_mask],
        vcf['variants/CHROM'][snp_mask],
        samples[sample_mask],
    )


def vcf_data_to_frequencies(metadata, bed_df, geno, contig, pos, cohort_col):
    """
    Compute per-cohort allele frequencies for every (non-AIM) BED target.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Sample metadata. Row ``i`` must describe sample ``i`` of ``geno``
        (cohort members are selected by row position).
    bed_df : pandas.DataFrame
        Target table with at least ``contig``, ``pos`` and ``target_id`` columns.
    geno : allel.GenotypeArray
        Genotypes with one variant row per entry of ``contig`` / ``pos``.
    contig, pos : array-like
        CHROM and POS of each variant in ``geno``.
    cohort_col : str
        Metadata column used to group samples into cohorts.

    Returns
    -------
    pandas.DataFrame
        Long-format table with columns ``cohort``, ``mutation`` (target_id),
        ``ref`` (frequency of allele 0) and ``alt`` (frequency of allele 1).

    Raises
    ------
    ValueError
        If ``bed_df`` has duplicate (contig, pos) entries, which would make the
        variant-to-target mapping ambiguous.
    """
    # Left-merge so every variant keeps exactly one row, in VCF order. An inner
    # merge would drop/duplicate rows and desynchronise the table from ``geno``.
    var_df = pd.DataFrame({'contig': contig, 'pos': pos}).merge(
        bed_df[['contig', 'pos', 'target_id']], on=['contig', 'pos'], how='left'
    )
    if len(var_df) != len(pos):
        raise ValueError(
            "bed_df contains duplicate (contig, pos) entries; "
            "cannot map variants to targets uniquely"
        )

    # keep only variants that are BED targets, and filter out AIMs
    # (will only do anything if ag-vampir)
    target_ids = var_df['target_id']
    keep = (
        target_ids.notna()
        & ~target_ids.astype(str).str.contains('AIM', regex=False)
    ).to_numpy()
    geno = geno.compress(keep, axis=0)
    var_df = var_df.loc[keep].reset_index(drop=True)

    # metadata row positions of each cohort (samples with no cohort are skipped)
    cohorts = metadata[cohort_col].dropna().unique()
    if len(cohorts) == 0:
        return pd.DataFrame(columns=['cohort', 'mutation', 'ref', 'alt'])
    subpops = {
        cohort: np.flatnonzero((metadata[cohort_col] == cohort).to_numpy())
        for cohort in cohorts
    }

    allele_counts = geno.count_alleles_subpops(subpops)
    mutations = var_df['target_id'].to_numpy()

    freq_dfs = []
    for cohort in cohorts:
        freqs = allele_counts[cohort].to_frequencies()
        freq_dfs.append(pd.DataFrame({
            'cohort': cohort,
            'mutation': mutations,
            'ref': freqs[:, 0],
            'alt': freqs[:, 1],
        }))

    return pd.concat(freq_dfs, ignore_index=True)

def plot_allele_frequencies(metadata, vcf_path, bed_df, cohort_col, dataset=None):
    """
    Plot a heatmap of alternate-allele frequency per target mutation and cohort.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Sample metadata with a ``sampleID`` column and a ``cohort_col`` column.
    vcf_path : str
        Path to the VCF of target genotypes.
    bed_df : pandas.DataFrame
        Target table with ``contig``, ``pos`` and ``target_id`` columns.
    cohort_col : str
        Metadata column used to group samples into cohorts (heatmap columns).
    dataset : str, optional
        Dataset name shown in the plot title. When omitted, a notebook-level
        ``dataset`` variable is used if one exists (the previous behaviour).
    """
    geno, pos, contig, samples = load_vcf(vcf_path=vcf_path, metadata=metadata)

    # Genotype columns follow VCF sample order, but vcf_data_to_frequencies
    # selects cohort members by metadata row position; reorder the metadata so
    # that row i describes genotype column i.
    metadata_aligned = (
        metadata
        .drop_duplicates(subset='sampleID')
        .set_index('sampleID')
        .loc[list(samples)]
        .reset_index()
    )

    freq_df = vcf_data_to_frequencies(
        metadata=metadata_aligned,
        bed_df=bed_df,
        geno=geno,
        contig=contig,
        pos=pos,
        cohort_col=cohort_col,
    )

    # rows = mutations, columns = cohorts, values = alt-allele frequency
    df = freq_df.pivot(index='mutation', columns='cohort', values='alt').round(2)

    if dataset is None:
        # backward compatibility: the title used to read the notebook global directly
        dataset = globals().get('dataset', '')
    title = f"{dataset} allele frequencies | by {cohort_col}".strip()

    fig = px.imshow(
            img=df,
            zmin=0,
            zmax=1,
            width=df.shape[1] * 100,
            height=df.shape[0] * 20,
            text_auto=True,
            aspect=1,
            color_continuous_scale="Reds",
            title=title,
        )
    fig.update(layout_coloraxis_showscale=False)

    fig.show()

In [None]:
# --- notebook parameters ---
dataset = "ampseq-vigg-01"
cohort_cols = "taxon,location"  # comma-separated metadata columns to group samples by

# input files
metadata_path = "../../results/config/metadata.qcpass.tsv"
bed_path = "../../config/ag-vampir.bed"
vcf_path = "../../results/vcfs/targets/ampseq-vigg-01.annot.vcf"

### Plotting allele frequencies

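This page shows the allele frequency, in each cohort, of every SNP genotyped by the amplicon sequencing protocol. Indels and AIM targets (target IDs containing "AIM") are excluded, and one heatmap is drawn for each cohort column.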

In [None]:
# `cohort_cols` is a comma-separated string; split it only once so that
# re-running this cell does not fail on an already-split list.
if isinstance(cohort_cols, str):
    cohort_cols = [col.strip() for col in cohort_cols.split(",") if col.strip()]

# Target BED file. The third column (end coordinate) is named `pos` so that it
# merges directly against the VCF POS values.
bed_df = pd.read_csv(bed_path, sep="\t", header=None)
bed_df.columns = ['contig', 'start', 'pos', 'amplicon_id', 'target_id']

# load metadata, choosing the reader from the (case-insensitive) file extension
metadata_path_lower = metadata_path.lower()
if metadata_path_lower.endswith('.xlsx'):
    metadata = pd.read_excel(metadata_path, engine='openpyxl')
elif metadata_path_lower.endswith('.tsv'):
    metadata = pd.read_csv(metadata_path, sep="\t")
elif metadata_path_lower.endswith('.csv'):
    metadata = pd.read_csv(metadata_path, sep=",")
else:
    raise ValueError("Metadata file must be .xlsx, .tsv or .csv")

In [None]:
# draw one allele-frequency heatmap per requested cohort grouping
for cohort_col in cohort_cols:
    plot_allele_frequencies(metadata=metadata, vcf_path=vcf_path,
                            bed_df=bed_df, cohort_col=cohort_col)