In [43]:
import allel
import numpy as np
import pandas as pd 
import plotly.express as px

def load_vcf(vcf_path, metadata, query=None):
    """
    Load a VCF and return SNP genotypes for the samples listed in ``metadata``.

    Parameters
    ----------
    vcf_path : str
        Path to the VCF file, passed to ``allel.read_vcf``.
    metadata : pandas.DataFrame
        Sample metadata with a ``sample_id`` column. Only VCF samples whose
        name appears in this column are kept.
    query : str, optional
        Expression evaluated with ``DataFrame.eval`` on the metadata (indexed
        by ``sample_id``); only samples for which it is true are kept.

    Returns
    -------
    geno : allel.GenotypeArray
        Genotypes with indel variants removed and samples filtered.
    pos, contig : numpy.ndarray
        ``variants/POS`` and ``variants/CHROM`` for the retained variants.
    metadata : pandas.DataFrame
        Metadata rows for the retained samples, in the same order as the
        genotype columns, with ``sample_id`` restored as a column.

    Raises
    ------
    ValueError
        If ``allel.read_vcf`` returns no data for ``vcf_path``.
    """
    # load vcf and get genotypes and positions
    vcf = allel.read_vcf(vcf_path, fields="*")
    if vcf is None:
        raise ValueError(f"No variants could be read from VCF: {vcf_path}")

    # keep only samples in qcpass metadata
    samples = vcf['samples']
    sample_mask = np.isin(samples, metadata['sample_id'])
    samples = samples[sample_mask]

    geno = allel.GenotypeArray(vcf['calldata/GT']).compress(sample_mask, axis=1)
    pos = vcf['variants/POS']
    contig = vcf['variants/CHROM']

    # remove any indels
    is_snp = ~vcf['variants/INDEL']
    geno = geno.compress(is_snp, axis=0)
    pos = pos[is_snp]
    contig = contig[is_snp]

    # Put the metadata into VCF sample order *before* evaluating the query, so
    # the boolean mask lines up one-to-one with the genotype columns even when
    # the metadata has extra rows or a different row order than the VCF.
    metadata = metadata.set_index('sample_id').loc[samples, :]

    if query:
        mask = np.asarray(metadata.eval(query), dtype=bool)
        metadata = metadata[mask]
        geno = geno.compress(mask, axis=1)

    return geno, pos, contig, metadata.reset_index()

def pca(metadata, vcf_path, n_components = 6, query=None):
    """
    Load genotype data and run PCA.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Sample metadata, forwarded to ``load_vcf``.
    vcf_path : str
        Path to the VCF file, forwarded to ``load_vcf``.
    n_components : int, default 6
        Number of principal components requested from ``allel.pca``.
    query : str, optional
        Sample filter expression, forwarded to ``load_vcf``.

    Returns
    -------
    pca_df : pandas.DataFrame
        The metadata returned by ``load_vcf`` with ``PC1`` .. ``PCn`` columns
        appended.
    model
        The second value returned by ``allel.pca``.
    """

    geno, pos, contig, metadata = load_vcf(vcf_path, metadata, query)

    gn_alt = geno.to_n_alt()

    print("removing any invariant sites")
    # A site is variable if any sample differs from the first sample.
    loc_var = np.any(gn_alt != gn_alt[:, 0, np.newaxis], axis=1)
    gn_var = np.compress(loc_var, gn_alt, axis=0)

    coords, model = allel.pca(gn_var, n_components=n_components)

    # The sign of a principal component is arbitrary. Orient EVERY component
    # so that its largest-magnitude coordinate is positive, which keeps the
    # orientation consistent between runs/windows. (Previously the check sat
    # outside the loop and only ever affected the last component.)
    flip = np.abs(coords.min(axis=0)) > np.abs(coords.max(axis=0))
    coords[:, flip] *= -1

    pca_df = pd.DataFrame(coords)
    pca_df.columns = [f"PC{pc+1}" for pc in range(n_components)]
    pca_df = pd.concat([metadata, pca_df], axis=1)

    return pca_df, model

def plot_pca(pca_df, colour_column, cohort_columns, dataset, n_components=6, height=500, width=750):
    """
    Build PCA scatter plots coloured by a metadata column.

    Returns a ``(fig1, fig2)`` tuple: ``fig1`` plots PC1 against PC2 and
    ``fig2`` plots PC3 against PC4. ``fig2`` is ``None`` when
    ``n_components`` is below 4.
    """

    def scatter_pair(pc_x, pc_y):
        # One scatter of two principal components with the shared styling.
        return px.scatter(
            pca_df,
            x=pc_x,
            y=pc_y,
            title=f"PCA {dataset} | {pc_x} vs {pc_y} | coloured by {colour_column}",
            color=colour_column,
            hover_data=cohort_columns,
            template='simple_white',
            height=height,
            width=width,
        )

    first = scatter_pair('PC1', 'PC2')
    second = scatter_pair('PC3', 'PC4') if n_components >= 4 else None
    return first, second
    

In [50]:
# Notebook configuration: dataset name, input paths and the metadata columns
# used for colouring / hover text. Paths are relative to the notebook location.
dataset = "ag-vampir-002"
vcf_path = f"../../results/vcfs/amplicons/{dataset}.annot.vcf"
metadata_path = "../../results/config/metadata.qcpass.tsv"
# Comma-separated list; it is split into a Python list in a later cell.
cohort_cols = "taxon,location"

## Population structure

In this notebook, we run a principal components analysis (PCA) and build a neighbour-joining tree (NJT) from the amplicon sequencing variant data. For the PCA, we plot PC1 vs PC2 and PC3 vs PC4, along with the proportion of variance explained by each principal component.

In [None]:
# `cohort_cols` arrives as a comma-separated string; only split it when it is
# still a string so that re-running this cell does not fail on a list.
if isinstance(cohort_cols, str):
    cohort_cols = [col.strip() for col in cohort_cols.split(",") if col.strip()]

# Pick the reader from the metadata file extension.
if metadata_path.endswith('.xlsx'):
    metadata = pd.read_excel(metadata_path, engine='openpyxl')
elif metadata_path.endswith('.tsv'):
    metadata = pd.read_csv(metadata_path, sep="\t")
elif metadata_path.endswith('.csv'):
    metadata = pd.read_csv(metadata_path, sep=",")
else:
    raise ValueError(f"Metadata file must be .xlsx, .tsv or .csv, got: {metadata_path}")

# Four components are enough for the PC1/PC2 and PC3/PC4 plots below.
df_pca, model = pca(metadata, vcf_path, n_components=4)

#### Variance explained

As a rule of thumb, once the variance explained by successive PCs begins to flatten out, the remaining PCs are unlikely to be informative.

In [None]:
# Label the bars PC1..PCn (rather than the 0-based array index) so the x axis
# matches the component names used in the PCA plots.
variance_ratio = model.explained_variance_ratio_
pc_labels = [f"PC{i + 1}" for i in range(len(variance_ratio))]

fig = px.bar(
    x=pc_labels,
    y=variance_ratio,
    labels={
        "x": "Principal Component",
        "y": "Variance Explained",
    },
    template='simple_white',
    height=250,
    width=600,
)
fig.update_layout(showlegend=False)

fig.show()

### PCA

In [None]:
# One pair of PCA figures per cohort column, coloured by that column.
for coh in cohort_cols:
    fig1, fig2 = plot_pca(df_pca, colour_column=coh, cohort_columns=cohort_cols, dataset=dataset, n_components=4)
    fig1.show()
    # plot_pca returns None for the PC3/PC4 figure when fewer than four
    # components are requested, so only show it when it exists.
    if fig2 is not None:
        fig2.show()

## Neighbour-joining tree (NJT)

In [54]:
import numba
from scipy.spatial.distance import squareform  # type: ignore

@numba.njit(parallel=True)
def multiallelic_diplotype_pdist(X, metric):
    """Pairwise distances between diplotypes, returned in condensed form.

    X is expected to hold diplotypes as genotype allele counts with axes
    ordered (n_samples, n_sites, n_alleles). A contiguous (C order) array
    gives the fastest computation.

    `metric` is called as ``metric(x, y)`` on the two (n_sites, n_alleles)
    slices of a pair and must return their distance; it may be a numba
    jitted function.

    Returns a float32 vector with one entry per unordered pair of samples,
    indexed via `square_to_condensed`.
    """
    n_samples = X.shape[0]
    out = np.zeros((n_samples * (n_samples - 1)) // 2, dtype=np.float32)

    # First member of the pair.
    for first in range(n_samples):
        diplotype_a = X[first, :, :]

        # Second member of the pair; every (first, second) combination maps
        # to its own slot in the output.
        for second in numba.prange(first + 1, n_samples):
            diplotype_b = X[second, :, :]
            slot = square_to_condensed(first, second, n_samples)
            out[slot] = metric(diplotype_a, diplotype_b)

    return out


@numba.njit
def square_to_condensed(i, j, n):
    """Map square-form coordinates (i, j) of an n x n distance matrix to the
    corresponding index in the condensed distance vector."""

    assert i != j, "no diagonal elements in condensed matrix"
    # The result is symmetric in (i, j): work with the smaller and larger
    # coordinate regardless of the order they were passed in.
    if i < j:
        small, large = i, j
    else:
        small, large = j, i
    return n * small - small * (small + 1) // 2 + large - 1 - small


@numba.njit
def multiallelic_diplotype_mean_cityblock(x, y):
    """Mean cityblock distance between two diplotypes.

    x and y are genotype allele counts of identical shape
    (n_sites, n_alleles). The per-site distance is the sum of absolute
    allele-count differences. The mean is taken only over sites where both
    diplotypes have at least one allele observed, so sites with a missing
    call in either one do not contribute. Returns NaN when no site is called
    in both.
    """
    n_sites = x.shape[0]
    n_alleles = x.shape[1]
    total = np.float32(0)
    n_called = np.float32(0)

    for site in range(n_sites):
        x_called = False
        y_called = False
        site_distance = np.float32(0)

        for allele in range(n_alleles):
            x_count = np.float32(x[site, allele])
            y_count = np.float32(y[site, allele])

            # A diplotype counts as called once any allele is observed.
            if x_count > 0:
                x_called = True
            if y_count > 0:
                y_called = True

            # Cityblock distance: accumulate the absolute difference.
            site_distance += np.fabs(x_count - y_count)

        # Only sites called in both diplotypes contribute to the mean.
        if x_called and y_called:
            total += site_distance
            n_called += np.float32(1)

    # Average over the jointly-called sites, or NaN if there were none.
    if n_called > 0:
        mean_distance = total / n_called
    else:
        mean_distance = np.nan

    return mean_distance

In [None]:
# Reuse the `vcf_path` defined in the configuration cell instead of silently
# overwriting it with a path relative to a different directory.
geno, pos, contig, df_samples = load_vcf(vcf_path, metadata=metadata)

import anjl

# Allele counts per genotype call, re-ordered so samples are the first axis,
# as expected by multiallelic_diplotype_pdist.
ac = allel.GenotypeArray(geno).to_allele_counts(max_allele=3)
X = np.ascontiguousarray(np.swapaxes(ac.values, 0, 1))

dists = multiallelic_diplotype_pdist(X, metric=multiallelic_diplotype_mean_cityblock)
dists = squareform(dists)
df_samples = df_samples.set_index('sample_id')
df = df_samples[['taxon', 'location']]

sample_ids = df_samples.index.to_list()
df_dist_matrix = pd.DataFrame(dists, index=sample_ids, columns=sample_ids)
# pivot long
df_dists = df_dist_matrix.stack().reset_index().set_axis(['sample_id_x', 'sample_id_y', 'distance'], axis=1)
# The matrix is symmetric: keeping only rows with sample_id_x < sample_id_y
# drops self comparisons and keeps each unordered pair exactly once, without
# building a concatenated-id key (which can collide, e.g. "ab"+"c" vs "a"+"bc").
df_dists = df_dists[df_dists['sample_id_x'] < df_dists['sample_id_y']]
# merge with metadata
df_dists = df_dists.merge(df, left_on='sample_id_x', right_index=True).merge(df, left_on='sample_id_y', right_index=True, suffixes=('_x', '_y'))
# normalise distances: subtract the mean distance of the pair's location combination
# NOTE(review): the label depends on sample-id order, so "A | B" and "B | A"
# are treated as different groups - confirm whether that is intended.
df_dists = df_dists.assign(location=lambda x: x.location_x + " | " + x.location_y).drop(['location_x', 'location_y'], axis=1)
df_grp_dists = df_dists.groupby('location').agg({'distance': 'mean'}).sort_values('distance').rename(columns={'distance': 'mean_distance'}).reset_index()
df_dists = df_dists.merge(df_grp_dists, on='location').assign(normalised_dist=lambda x: x.distance - x.mean_distance).sort_values('normalised_dist')
pd.set_option('display.max_rows', 200)

# Take the most distant 0.5% of sample pairs and exclude samples that appear
# in them more often than 0.05% of the total number of pairs.
FAR_PAIR_FRACTION = 0.005
OUTLIER_COUNT_FRACTION = 0.0005
n_pairs = df_dists.shape[0]
far_samples = df_dists.sort_values('normalised_dist', ascending=False)[:int(n_pairs * FAR_PAIR_FRACTION)][['sample_id_x', 'sample_id_y']].values.flatten()
far_samples, far_counts = np.unique(far_samples, return_counts=True)
exclude_outliers = far_samples[far_counts > int(n_pairs * OUTLIER_COUNT_FRACTION)]
print("excluding extreme outliers from NJT", exclude_outliers)

# Distance matrix and leaf metadata with the outlier samples removed.
dists = df_dist_matrix.drop(exclude_outliers, axis=0).drop(exclude_outliers, axis=1).values
leaf_data = df_samples.query("sample_id not in @exclude_outliers").reset_index()

In [None]:
# Build the tree from the outlier-filtered distance matrix.
Z = anjl.dynamic_nj(dists)

# Leaves are coloured by location; the cohort columns appear in the hover text.
plot_options = dict(
    leaf_data=leaf_data,
    color="location",
    hover_name="sample_id",
    hover_data=cohort_cols,
    marker_size=8,
)
fig = anjl.plot(Z, **plot_options)

# Save a static copy of the figure, then display it as the cell output.
fig.write_image("../../njt.png", scale=2)
fig