In [7]:
import allel
import numpy as np
import pandas as pd 
import plotly.express as px

def load_vcf(vcf_path, metadata):
    """
    Load a VCF, keep only samples listed in the metadata, and drop indels.

    Parameters
    ----------
    vcf_path : str
        Path to the VCF file; passed straight to ``allel.read_vcf``.
    metadata : pandas.DataFrame
        Sample metadata exposing a ``sampleID`` column. Only VCF samples whose
        name appears in that column are retained.

    Returns
    -------
    geno : allel.GenotypeArray
        Genotypes for the retained samples at non-indel sites.
    pos : numpy.ndarray
        ``variants/POS`` values of the retained sites.
    contig : numpy.ndarray
        ``variants/CHROM`` values of the retained sites.
    samples : numpy.ndarray
        Names of the retained samples, in VCF order (i.e. matching the sample
        axis of ``geno``).

    Raises
    ------
    ValueError
        If no variant records could be read from ``vcf_path``.
    """
    # load every field of the vcf
    vcf = allel.read_vcf(vcf_path, fields='*')
    if vcf is None:
        # allel.read_vcf returns None when the file yields no records
        raise ValueError(f"No variant records could be read from {vcf_path}")

    samples = vcf['samples']
    # keep only samples present in the (qcpass) metadata
    sample_mask = np.isin(samples, metadata.sampleID)

    geno = allel.GenotypeArray(vcf['calldata/GT'])
    geno = geno.compress(sample_mask, axis=1)
    pos = vcf['variants/POS']
    contig = vcf['variants/CHROM']
    # NOTE(review): assumes the VCF carries a boolean INDEL INFO flag - confirm
    indel = vcf['variants/INDEL']

    # remove indels
    is_snp = ~indel
    geno = geno.compress(is_snp, axis=0)
    pos = pos[is_snp]
    contig = contig[is_snp]

    return geno, pos, contig, samples[sample_mask]

def pca(metadata_path, vcf_path, n_components = 6):
    """
    Load metadata and genotype data, then run a PCA on the variant sites.

    Parameters
    ----------
    metadata_path : str
        Path to the sample metadata (.xlsx, .tsv or .csv). The table must
        contain a ``sampleID`` column.
    vcf_path : str
        Path to the VCF file, forwarded to ``load_vcf``.
    n_components : int, optional
        Number of principal components requested from ``allel.pca``.

    Returns
    -------
    pca_df : pandas.DataFrame
        One row per sample retained from the VCF (in VCF order): the metadata
        columns followed by ``PC1`` ... ``PCn`` coordinates.
    model : object
        The fitted model returned by ``allel.pca``.

    Raises
    ------
    ValueError
        If the metadata file extension is not supported.
    """
    if metadata_path.endswith('.xlsx'):
        metadata = pd.read_excel(metadata_path, engine='openpyxl')
    elif metadata_path.endswith('.tsv'):
        metadata = pd.read_csv(metadata_path, sep="\t")
    elif metadata_path.endswith('.csv'):
        metadata = pd.read_csv(metadata_path, sep=",")
    else:
        raise ValueError("Metadata file must be .xlsx, .tsv or .csv")

    geno, pos, contig, samples = load_vcf(vcf_path, metadata)

    # PCA coordinates come out in VCF sample order and only for samples found
    # in the VCF, so re-order/subset the metadata to match before joining.
    # drop=False keeps the sampleID column (and the column order) intact;
    # duplicate IDs are collapsed to their first row so the lookup is 1:1.
    metadata = (
        metadata
        .drop_duplicates(subset="sampleID")
        .set_index("sampleID", drop=False)
        .loc[samples]
        .reset_index(drop=True)
    )

    gn_alt = geno.to_n_alt()

    print("removing any invariant sites")
    # a site is variable if any sample differs from the first sample
    loc_var = np.any(gn_alt != gn_alt[:, 0, np.newaxis], axis=1)
    gn_var = np.compress(loc_var, gn_alt, axis=0)

    coords, model = allel.pca(gn_var, n_components=n_components)

    # The sign of each PC is arbitrary; orient every component so that its
    # largest-magnitude coordinate is positive, giving a stable orientation.
    flip = np.abs(coords.min(axis=0)) > np.abs(coords.max(axis=0))
    coords[:, flip] *= -1

    pca_df = pd.DataFrame(
        coords,
        columns=[f"PC{pc+1}" for pc in range(coords.shape[1])],
    )
    pca_df = pd.concat([metadata, pca_df], axis=1)

    return pca_df, model

def plot_pca(pca_df, colour_column, cohort_columns, dataset=None):
    """
    Build PC1-vs-PC2 and PC3-vs-PC4 scatter plots of the PCA coordinates.

    Parameters
    ----------
    pca_df : pandas.DataFrame
        Table with ``PC1`` .. ``PC4`` columns plus the metadata columns named
        in ``colour_column`` and ``cohort_columns``.
    colour_column : str
        Column used to colour the points.
    cohort_columns : list of str
        Columns shown in the hover tooltip.
    dataset : str, optional
        Dataset name shown in the figure titles. When omitted, the
        notebook-level ``dataset`` variable is used (previous behaviour).

    Returns
    -------
    fig1, fig2 : plotly figures
        The PC1/PC2 and PC3/PC4 scatter plots, in that order.

    Raises
    ------
    NameError
        If ``dataset`` is not given and no global ``dataset`` is defined.
    """
    if dataset is None:
        # Backward compatibility: earlier callers relied on a global `dataset`.
        try:
            dataset = globals()["dataset"]
        except KeyError:
            raise NameError("name 'dataset' is not defined") from None

    def _scatter(x_pc, y_pc):
        # One scatter of two principal components with the shared styling.
        return px.scatter(
            pca_df,
            x=x_pc,
            y=y_pc,
            title=f"PCA {dataset} | {x_pc} vs {y_pc} | coloured by {colour_column}",
            color=colour_column,
            hover_data=cohort_columns,
            template='simple_white'
        )

    fig1 = _scatter('PC1', 'PC2')
    fig2 = _scatter('PC3', 'PC4')
    return fig1, fig2
    

In [None]:
# Notebook parameters: dataset name, input paths and the metadata columns
# (comma-separated string) used to colour the PCA plots.
dataset = "ampseq-vigg-01"
vcf_path = "../../results/vcfs/targets/{}.annot.vcf".format(dataset)
metadata_path = "../../results/config/metadata.qcpass.tsv"
cohort_cols = "taxon,location"

## PCA

In this notebook, we run a principal components analysis (PCA) on the amplicon sequencing variant data. We plot PC1 vs PC2 and PC3 vs PC4, along with the proportion of variance explained by each principal component.

In [None]:
# `cohort_cols` arrives as a comma-separated string; turn it into a list.
# The isinstance guard keeps this cell safe to re-run (a list has no .split).
if isinstance(cohort_cols, str):
    cohort_cols = cohort_cols.split(",")
pca_df, model = pca(metadata_path, vcf_path)

### Variance explained

As a rule of thumb, once the variance explained per PC begins to flatten out, the remaining PCs are no longer informative.

In [None]:
# Label bars PC1..PCn (1-based) so the axis matches the PC column names used
# elsewhere, rather than showing the 0-based array index.
variance_ratio = model.explained_variance_ratio_
fig = px.bar(
    x=[f"PC{i + 1}" for i in range(len(variance_ratio))],
    y=variance_ratio,
    labels={
        "x": "Principal Component",
        "y": "Variance Explained",
    },
    template='simple_white',
)
fig.update_layout(showlegend=False)

fig.show()

### PCA

In [None]:
# Draw both PCA scatter plots once per cohort column, colouring by that column.
for cohort_col in cohort_cols:
    figures = plot_pca(pca_df, colour_column=cohort_col, cohort_columns=cohort_cols)
    for figure in figures:
        figure.show()