In [None]:
import plotly.express as px
import allel
import numpy as np
import pandas as pd

def sample_diversity(geno, samples, pos):
    """Compute nucleotide diversity (pi) separately for every sample.

    Parameters
    ----------
    geno : genotype array (presumably a scikit-allel GenotypeArray) whose
        axis 1 indexes samples; must support ``.take(idx, axis=1).count_alleles()``.
    samples : sequence of sample IDs, assumed to be in the same order as
        axis 1 of ``geno`` (column i is reported as ``samples[i]``).
    pos : sequence of sites; only its length is used.

    Returns
    -------
    pandas.DataFrame with columns ``sample_id`` and ``pi``.
    """
    # Sites are re-indexed 0..n-1, so allel treats them as contiguous and pi is
    # averaged over the n sites rather than over the genomic span they cover.
    site_index = np.arange(len(pos))
    per_sample_pi = [
        allel.sequence_diversity(
            ac=geno.take([col_idx], axis=1).count_alleles(),
            pos=site_index,
        )
        for col_idx, _ in enumerate(samples)
    ]
    return pd.DataFrame({'sample_id': samples, 'pi': per_sample_pi})

def cohort_diversity(geno, pos, samples, metadata, cohort_col):
    """Compute nucleotide diversity (pi) for each cohort in a metadata column.

    Parameters
    ----------
    geno : genotype array (presumably a scikit-allel GenotypeArray) whose
        axis 1 indexes samples, assumed to follow the row order of ``metadata``.
    pos : sequence of sites; only its length is used.
    samples : not used; kept for interface compatibility.
    metadata : pandas.DataFrame with one row per sample.
    cohort_col : column of ``metadata`` that assigns samples to cohorts.

    Returns
    -------
    pandas.DataFrame with columns ``cohort`` and ``pi``: one row per non-null
    cohort value, in order of first appearance.
    """
    cohort_labels = metadata[cohort_col]
    cohorts = cohort_labels.unique()
    cohorts = cohorts[~pd.isnull(cohorts)]  # samples with a missing cohort are skipped

    # Same 0..n-1 re-indexing of sites for every cohort.
    site_index = np.arange(len(pos))
    cohort_pis = []
    for cohort in cohorts:
        # Positional (not label-based) indices of this cohort's samples.
        members = np.flatnonzero(cohort_labels == cohort)
        allele_counts = geno.take(members, axis=1).count_alleles()
        cohort_pis.append(allel.sequence_diversity(ac=allele_counts, pos=site_index))

    return pd.DataFrame({'cohort': cohorts, 'pi': cohort_pis})

In [None]:
# Notebook parameters: defaults for standalone runs (presumably overridden when
# the workflow executes this notebook).
dataset = 'vigg-01'
metadata_path = "../../results/config/metadata.qcpass.tsv"
cohort_cols = 'location,taxon'  # comma-separated metadata columns to group samples by
vcf_path = "../../results/vcfs/amplicons/ampseq-vigg-01.annot.vcf"
wkdir = "../.."
# Sequencing platform passed to amp.load_vcf when loading genotypes.
# NOTE(review): 'illumina' is an assumed default — confirm it matches the run.
platform = 'illumina'

In [None]:
import sys
import os
# Put the workflow directory on the import path so that its helper module
# `ampseekertools` can be imported; `wkdir` is set in the parameters cell.
sys.path.append(os.path.join(wkdir, 'workflow'))
import ampseekertools as amp

### Genetic diversity

This page calculates genetic diversity in individuals and cohorts. Genetic diversity is measured here as nucleotide diversity (π): the average number of nucleotide differences per site between two randomly chosen sequences (Nei & Li, 1979). It provides insight into population history and the evolutionary forces acting on a population.

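*Note*: Estimating genetic diversity from Ag-vampIR amplicons is tricky. Because so many of the amplicons target insecticide-resistance (IR) loci, results will be biased by selective sweeps, and the ancestry-informative markers (AIMs) may not be neutral either. π is also averaged over the genotyped sites only, not the full amplicon length, so absolute values are not directly comparable with genome-wide estimates.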

In [None]:
# Parse the comma-separated cohort columns. The isinstance guard keeps the cell
# re-runnable: after the first run `cohort_cols` is already a list. Whitespace
# is stripped so that e.g. "location, taxon" still names real columns.
if isinstance(cohort_cols, str):
    cohort_cols = [col.strip() for col in cohort_cols.split(",") if col.strip()]

# Load metadata, choosing the reader from the (case-insensitive) file extension.
metadata_ext = os.path.splitext(metadata_path)[1].lower()
if metadata_ext == '.xlsx':
    metadata = pd.read_excel(metadata_path, engine='openpyxl')
elif metadata_ext == '.tsv':
    metadata = pd.read_csv(metadata_path, sep="\t")
elif metadata_ext == '.csv':
    metadata = pd.read_csv(metadata_path, sep=",")
else:
    raise ValueError(f"Metadata file must be .xlsx, .tsv or .csv, got: {metadata_path}")

# Load genotypes. load_vcf also returns a metadata frame that replaces the one
# read above (presumably filtered/reordered to match the VCF samples — confirm
# in ampseekertools), so `samples` follows the genotype column order.
geno, pos, contig, metadata, ref, alt, ann = amp.load_vcf(vcf_path, metadata, platform=platform)
samples = metadata['sample_id'].values

#### By cohort

In [None]:
# Per-cohort diversity: one TSV table and one bar chart per cohort column.
out_dir = os.path.join(wkdir, "results", "genetic-diversity")
os.makedirs(out_dir, exist_ok=True)  # to_csv fails if the directory does not exist yet

for coh in cohort_cols:
    df_cohort_pi = cohort_diversity(
        geno=geno,
        pos=pos,
        samples=samples,
        metadata=metadata,
        cohort_col=coh,
    )
    df_cohort_pi.to_csv(os.path.join(out_dir, f"{coh}.pi.tsv"), sep="\t")

    # Title and label each figure with its cohort column so the charts produced
    # by this loop can be told apart.
    fig = px.bar(
        df_cohort_pi, x='cohort', y='pi', template='simple_white', width=600, height=400,
        title=f"Genetic diversity (π) by {coh}",
        labels={'cohort': coh, 'pi': 'π'},
    )
    fig.show()

In [None]:
# Per-sample diversity, saved next to the per-cohort tables.
sample_out_dir = os.path.join(wkdir, "results", "genetic-diversity")
os.makedirs(sample_out_dir, exist_ok=True)  # to_csv fails if the directory does not exist yet

sample_pi_df = sample_diversity(geno=geno, samples=samples, pos=pos)
sample_pi_df.to_csv(os.path.join(sample_out_dir, "samples.pi.tsv"), sep="\t")

# Show the distribution across samples, mirroring the plots in the cohort section.
fig = px.histogram(
    sample_pi_df, x='pi', template='simple_white', width=600, height=400,
    title="Distribution of per-sample genetic diversity (π)",
    labels={'pi': 'π'},
)
fig.show()