In [None]:
import allel
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go  # type: ignore

import warnings
warnings.filterwarnings('ignore')


In [None]:
# Notebook parameters.
# NOTE(review): these look like defaults meant to be overridden by the workflow runner — confirm.
metadata_path = '../../../results/config/metadata.qcpass.tsv'  # sample metadata, read below as a tab-separated table
bed_targets_path = "../../../config/ag-vampir.bed"  # headerless tab-separated target table (7 columns are named after loading)
vcf_path = "../../..//results/vcfs/targets/lab-strains.annot.vcf"  # VCF passed to amp.load_variants
wkdir = "../../.."  # root used to build the workflow/lib, results/ and resources/ paths
cohort_cols = 'location,taxon'  # comma-separated metadata column names; the first entry is used as the cohort column
platform = 'illumina'  # forwarded unchanged to amp.load_variants


In [None]:
import os
import sys

# Add <wkdir>/workflow/lib to the module search path before importing `ampseeker`
# (presumably where the project library lives — confirm).
sys.path.append(os.path.join(wkdir, 'workflow/lib'))
import ampseeker as amp


### Taxon classification

In this analysis, we use three separate methods (AIMs, PCA, XGBoost) to assign a taxon to amplicon samples. For the latter two methods, we integrate data from over 7000 publicly accessible WGS samples from the [Vector Observatory](https://www.malariagen.net/vobs/).

Accurate species identification is essential for mosquito surveillance as different members of the Anopheles gambiae complex have distinct ecological niches and vectorial capacities (Coetzee et al., 2013).

#### Ancestry informative marker heatmap

In [None]:
# Parse the comma-separated cohort column names (whitespace stripped, empty entries dropped).
# The first entry is the cohort column used for ordering and colouring the plots below.
cohort_cols_list = [c.strip() for c in cohort_cols.split(',') if c.strip()]
cohort_col = cohort_cols_list[0]

metadata = pd.read_csv(metadata_path , sep="\t")

# Avoid grouping and counting on the same taxon column in summary tables.
# Falls back to the first other cohort column present in the metadata, else to 'sample_id'.
summary_cohort_col = cohort_col
if summary_cohort_col == 'taxon':
    summary_cohort_col = next((c for c in cohort_cols_list if c != 'taxon' and c in metadata.columns), 'sample_id')

# Colour lookup loaded from JSON; it is indexed by cohort column name when plotting.
import json
with open(f"{wkdir}/results/config/metadata_colours.json", 'r') as f:
    color_mapping = json.load(f)
    
# Target table: the file has no header row, so the seven column names are assigned here.
targets = pd.read_csv(bed_targets_path, sep="\t", header=None)
targets.columns = ['contig', 'start', 'end', 'amplicon', 'mutation', 'ref', 'alt']

# NOTE(review): `metadata` is re-bound to the value returned by load_variants; judging by the
# names, the other outputs are genotypes, positions, contigs, ref alleles, alt alleles and
# annotations — confirm against the ampseeker library.
gn_amp, pos, gn_contigs, metadata, refs, alts_amp, ann = amp.load_variants(vcf_path=vcf_path, metadata=metadata, platform=platform, filter_indel=True)
samples = metadata['sample_id'].to_list()

# Prepend the reference allele as column 0, so each row is [ref, alt alleles...].
alts_amp = np.concatenate([refs.reshape(refs.shape[0], -1), alts_amp], axis=1)


In [None]:
# Contigs shown in the AIM heatmap, in left-to-right panel order.
contigs = ['2R', '2L', '3R', '3L', 'X']
df_aims = targets.query("mutation.str.contains('AIM')", engine='python')

# Select the called variants that sit on an AIM target. Matching is done on the
# (contig, position) pair: matching on position alone would also pick up a non-AIM
# variant that happens to share a coordinate with an AIM on a different contig, and
# the label lookup further down would then fail with a KeyError.
aim_keys = set(zip(df_aims.contig, df_aims.end))
aim_mask = np.fromiter(
    ((c, p) in aim_keys for c, p in zip(gn_contigs, pos)),
    dtype=bool,
    count=len(pos),
)
aim_gn = gn_amp.compress(aim_mask, axis=0)
aim_pos = pos[aim_mask]
aim_contigs = gn_contigs[aim_mask]
aim_alts = alts_amp[aim_mask]

# Label every selected variant as "aim_<contig>:<pos>" and put the target table
# into the same order as the genotype rows.
aim_loc = [f"aim_{c}:{p}" for c, p in zip(aim_contigs, aim_pos)]
df_aims = df_aims.assign(loc=lambda x: "aim_" + x.contig + ":" + x.end.astype(str)).set_index('loc')
df_aims = df_aims.loc[aim_loc]

# Per-sample AIM genotype codes, appended as one column per sample after the
# 7 target-annotation columns.
aim_gn_alt = amp._aims_n_alt(aim_gn, aim_alts=df_aims.alt.to_list(), data_alts=aim_alts)
df_aims = pd.concat([df_aims, pd.DataFrame(aim_gn_alt, columns=samples, index=aim_loc)], axis=1)
# Re-order rows contig by contig, following the plotting order above.
df_aims = pd.concat([df_aims[df_aims['contig'] == contig] for contig in contigs])

# sort by cohort_col and then within that, by aim fraction
# Genotype columns only, with the missing-call code (-1) turned into NaN so it is
# ignored by the mean.
aim_genotypes = df_aims.iloc[:, 7:].replace({-1: np.nan})
cohort_values = metadata[cohort_col]
aimplot_sample_order = []
for coh in cohort_values.unique():
    # Compare values directly instead of building a query string: a string query
    # silently matches nothing for numeric or missing cohort values and breaks on
    # values containing quotes, which would drop those samples from the plot.
    in_cohort = cohort_values.isna() if pd.isna(coh) else cohort_values == coh
    coh_samples = metadata.loc[in_cohort, 'sample_id'].to_list()
    coh_samples_aim_order = aim_genotypes.loc[:, coh_samples].mean().sort_values(ascending=True).index.to_list()
    aimplot_sample_order.extend(coh_samples_aim_order)

# exclude samples with missing data
# n_missing = df_aims.replace({-1: np.nan}).iloc[:, 7:].isna().sum(axis=0).sort_values(ascending=False)
# missing_samples = n_missing[n_missing > 20].index.to_list()
# aimplot_sample_order = [s for s in aimplot_sample_order if s not in missing_samples]
from plotly.subplots import make_subplots

# Panel widths are proportional to the number of AIM variants on each contig.
col_widths = [np.count_nonzero(aim_contigs == name) for name in contigs]

# One row, one panel per contig, all panels sharing the sample (y) axis.
# Row titles and the global x/y titles are left at their default of None.
subplot_options = dict(
    rows=1,
    cols=len(contigs),
    column_titles=contigs,
    column_widths=col_widths,
    shared_yaxes=True,
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
)
fig = make_subplots(**subplot_options)

species = "gamb_vs_colu".split("_vs_")

# Colour bar: one tick per genotype code (-1 missing, 0/1/2 = hom / het / hom).
colorbar = {
    "title": "AIM genotype",
    "tickmode": "array",
    "tickvals": [-1, 0, 1, 2],
    "ticktext": [
        "missing",
        f"{species[0]}/{species[0]}",
        f"{species[0]}/{species[1]}",
        f"{species[1]}/{species[1]}",
    ],
    "len": 100,
    "lenmode": "pixels",
    "y": 1,
    "yanchor": "top",
    "outlinewidth": 1,
    "outlinecolor": "black",
}

# Default AIM colours, taken from the qualitative T10 palette.
colors = px.colors.qualitative.T10
color_missing = "white"
color_gamb = colors[0]
color_colu = colors[2]
color_gamb_colu_het = colors[5]
color_gambcolu = colors[6]
palette = (color_missing, color_gamb, color_gamb_colu_het, color_colu)

# Discrete colour scale: each of the four palette entries covers one quarter of
# the [0, 1] range, with both edges of the band pinned to the same colour.
colorscale = [
    [edge / 4, palette[band]]
    for band in range(4)
    for edge in (band, band + 1)
]

# Draw one heatmap panel per contig: rows are samples, columns are AIM loci.
for j, contig in enumerate(contigs):

    # Loci whose index label contains this contig name, genotype columns only
    # (the first 7 columns are target annotations), samples in plotting order,
    # then transposed so that samples run down the y axis.
    df_aims_contig = (
        df_aims.filter(like=contig, axis=0)
        .iloc[:, 7:]
        .loc[:, aimplot_sample_order]
        .T
    )

    heatmap = go.Heatmap(
        x=df_aims_contig.columns,
        y=df_aims_contig.index,
        z=df_aims_contig,
        colorscale=colorscale,
        zmin=-1.5,
        zmax=2.5,
        xgap=0,
        ygap=0.5,  # this creates faint lines between rows
        colorbar=colorbar,
    )
    fig.add_trace(heatmap, row=1, col=j + 1)

# Figure height grows with the number of samples, with a 600 px floor.
fig.update_layout(
    title=f"AIMs - gambiae vs coluzzii",
    width=800,
    height=max(600, 1.2 * len(samples) + 300),
)
fig.write_image(f"{wkdir}/results/aims_gamb_vs_colu.png", scale=2)

fig.show()


#### By cohort

Ancestry Informative Markers (AIMs) are genomic variants with large allele frequency differences between populations, making them useful for species identification and detecting hybridization (Rosenberg et al., 2003).

In [None]:
def _aim_taxon(mean_genotype):
    """Map a sample's mean AIM genotype (0 = gambiae-like, 2 = coluzzii-like) to a label.

    Returns 'uncertain' when the mean is missing, 'gambiae' below 0.5,
    'coluzzii' at or above 1.5, and 'unassigned' in between.
    """
    # NaN never compares equal to anything (including itself), so missing values
    # have to be detected with pd.isna rather than `== np.nan`.
    if pd.isna(mean_genotype):
        return 'uncertain'
    if mean_genotype < 0.5:
        return 'gambiae'
    if mean_genotype < 1.5:
        return 'unassigned'
    return 'coluzzii'


# Genotype columns only (the first 7 columns of df_aims are target annotations).
df = df_aims.iloc[:, 7:].copy()
df.columns = samples

## Use only chrom 3 and X aims for taxon 
x_3_mask = df.index.str.contains("aim_3|aim_X")
df = df[x_3_mask]

# Mean genotype per sample, ignoring missing calls (coded as -1).
aim_calls = df.replace(-1, float('nan'))
mean_aims = aim_calls.apply(np.nanmean, axis=0)
# Samples with more than this many missing AIM calls get no mean at all.
max_missing_aims = 12
mean_aims[aim_calls.isna().sum(axis=0) > max_missing_aims] = np.nan
# Align to the metadata row order before attaching as a column.
aims = mean_aims.loc[metadata.set_index('sample_id').index]
metadata = metadata.assign(mean_aim_genotype=aims.values)

taxon = [_aim_taxon(value) for value in metadata['mean_aim_genotype']]

new_metadata = metadata.assign(aim_taxon=taxon)

fig = px.histogram(
    new_metadata,
    nbins=100, 
    x='mean_aim_genotype', 
    color=cohort_col, 
    color_discrete_map=color_mapping[cohort_col],
    width=750, 
    height=400, 
    template='plotly_white', 
    title='AIM genotype distribution'
    )
fig.show()


In [None]:
from IPython.display import display, Markdown

# Save the per-sample AIM summary as a tab-separated file ...
aim_taxon_table = new_metadata[['sample_id', 'aim_taxon', 'mean_aim_genotype']]
aim_taxon_table.to_csv(
    f"{wkdir}/results/ag-vampir/aims/taxon_aims.tsv",
    sep="\t",
    index=False,
)

# ... and render a link to it in the notebook output.
download_link = Markdown(f'<a href={wkdir}/results/ag-vampir/aims/taxon_aims.tsv>Sample gambiae vs coluzzii AIMs and AIM taxon assignment (.tsv)</a>')
display(download_link)


### PCA-based taxon classification

In this method, we perform PCA and then apply a classifier to predict taxon based on VObs reference data. 

Principal Component Analysis reduces the dimensionality of genetic data while preserving patterns that differentiate species, enabling classification through machine learning algorithms (Jombart et al., 2010).

In [None]:
### Load VObs PCA data
# Reference arrays: locus labels, alleles and genotypes for the WGS samples.
ampseq_loci = np.load(f"{wkdir}/resources/ag-vampir/agvampir-ampseq-loci.npy")
alts_wgs = np.load(f"{wkdir}/resources/ag-vampir/agvampir-alts-wgs.npy")
gn_wgs = np.load(f"{wkdir}//resources/ag-vampir/agvampir-gn-wgs.npy")

df_wgs_samples = pd.read_csv(f"{wkdir}/resources/ag-vampir/agvampir-df-samples.csv", index_col=0)

targets = pd.read_csv(bed_targets_path, sep="\t", header=None)
targets.columns = ['contig', 'start', 'end', 'amplicon', 'mutation', 'ref', 'alt']

# Reload the amplicon variants, this time carrying the AIM columns along in the metadata.
gn_amp, pos, gn_contigs, metadata, refs, alts_amp, ann = amp.load_variants(vcf_path=vcf_path, metadata=new_metadata, platform=platform, filter_indel=True)
samples = metadata['sample_id'].to_list()

# Reference allele first, then the alternate alleles.
alts = np.concatenate([refs.reshape(refs.shape[0], -1), alts_amp], axis=1)

# Locus labels are split as "<contig>:<start>-<end>"; the start coordinate is used.
wgs_pos  = np.array([int(s.split(":")[1].split("-")[0]) for s in ampseq_loci])

# Re-order the WGS arrays so that row i corresponds to amplicon variant i.
# NOTE(review): the lookup is keyed on position only, so it assumes positions are
# unique across contigs — confirm.
wgs_dict = {value: idx for idx, value in enumerate(wgs_pos)}
indices = [wgs_dict.get(value, -1) for value in pos]
valid_indices = [idx for idx in indices if idx != -1]
wgs_pos = np.array([wgs_pos[i] for i in valid_indices])
gn_wgs = gn_wgs.take(valid_indices, axis=0)
alts_wgs = alts_wgs.take(valid_indices, axis=0)

# np.array_equal also handles a length mismatch (some amplicon positions missing
# from the reference), which an element-wise `==` would turn into an unrelated
# broadcasting/TypeError instead of a readable assertion failure.
n_unmatched = indices.count(-1)
assert np.array_equal(wgs_pos, pos), (
    f"WGS reference positions do not line up with the amplicon variants: "
    f"{n_unmatched} of {len(pos)} amplicon positions are absent from the reference loci"
)


In [None]:
import allel
import numpy as np
import pandas as pd
import plotly.express as px


In [None]:
# NOTE(review): `gn_wgs_flt`, `gn_amp_flt` and `df_wgs_samples_flt` are not assigned anywhere in
# the visible notebook (the cells above define `gn_wgs`, `gn_amp` and `df_wgs_samples`). Unless a
# filtering cell exists outside this view, this cell raises NameError on Restart & Run All — confirm.
pca_df = amp.pca_all_samples(gn_wgs_flt, gn_amp_flt, df_wgs_samples_flt, metadata, n_components=6, cohort_cols=cohort_cols)

# Predict a taxon per sample from the PCA table.
# NOTE(review): `n_neighbors` is passed together with method="svm"; presumably it is only
# relevant to a nearest-neighbour method — verify against amp.assign_taxa.
df_assignments = amp.assign_taxa(pca_df, method="svm", n_neighbors=5, 
                           probability_threshold=0.8)

# Attached positionally (to_numpy), so this assumes df_assignments rows are in metadata order — TODO confirm.
metadata = metadata.assign(pca_taxon=df_assignments.predicted_taxon.to_numpy())

pca_3d = amp.plot_pca_3d_with_assignments(
        pca_df, 
        df_assignments, 
        title="3D PCA with Taxon Assignments"
)

# Last expression of the cell: renders the figure.
pca_3d


### Machine learning based taxon classification

In this method, we use an XGBoost algorithm to predict species based on the VObs reference data. 

XGBoost is a gradient boosting framework that uses decision trees to identify the most informative features for classification, offering high accuracy for species assignment (Chen & Guestrin, 2016).

In [None]:
import pickle
import xgboost 

# Load the pickled classifier bundle. It is indexed below and in the prediction cell with the keys
# 'feature_importance', 'per_taxon_results', 'model' and 'label_encoder'.
# NOTE: unpickling executes code embedded in the file, so this must only ever point at a trusted
# project resource.
with open(f"{wkdir}/resources/ag-vampir/xgboost_taxon.pickle", 'rb') as file:
        tree = pickle.load(file)
# Feature names in index order; they are later used as row labels of the genotype-count table.
taxon_tree_features = tree['feature_importance'].sort_index().SNP.to_list()
# Unique integer coordinate taken from the part after the '-' in "<contig>:<start>-<end>...".
taxon_tree_feat_pos = np.unique([int(t.split(":")[1].split("-")[1]) for t in taxon_tree_features])

# Re-binds `aim_gn` / `aim_pos` (previously the AIM-only subsets) to copies of the full
# genotype and position arrays.
aim_gn = gn_amp.copy()
aim_pos = pos.copy()


In [None]:
# All nucleotides that should be present in each row of the allele table.
all_nucleotides = set(['A', 'C', 'G', 'T'])

# Work on a copy so the original allele array stays untouched.
arr = alts.copy()

# Fill every empty allele slot with a nucleotide not yet present at that site.
for i in range(arr.shape[0]):
    # Nucleotides already present in this row (empty strings mark unused slots).
    present = {val for val in arr[i] if val != ''}

    # Absent nucleotides, in alphabetical order. Sorting makes the fill order
    # reproducible: iterating a set of strings directly depends on the
    # interpreter's per-process hash seed and can differ between kernel restarts.
    missing = sorted(all_nucleotides - present)

    # Pair the empty slots (left to right) with the absent nucleotides; if there
    # are more empty slots than absent nucleotides the surplus slots stay empty.
    empty_slots = [k for k in range(arr.shape[1]) if arr[i, k] == '']
    for k, nucleotide in zip(empty_slots, missing):
        arr[i, k] = nucleotide


In [None]:
# Number of alternate alleles counted per site (together with the reference
# allele this covers the four nucleotides).
n_alt = 3
gn_counts = amp._melt_gt_counts(gn_amp.to_allele_counts(max_allele=n_alt).values)

# One entry per (site, alternate allele): column 0 of `arr` is the reference allele.
alts_flat = arr[:, 1:].flatten()
pos_flat = np.repeat(pos, n_alt)
# Repeat the contig names once up front instead of once per label.
contigs_flat = np.repeat(gn_contigs, n_alt)

# Row labels in the "<contig>:<pos>-<pos>:<alt>" form used by the model's feature names.
labels = np.array([
    f"{contig_i}:{pos_i}-{pos_i}:{alt_i}"
    for contig_i, pos_i, alt_i in zip(contigs_flat, pos_flat, alts_flat)
])

df_vampir = pd.DataFrame(gn_counts, index=labels, columns=samples)

# Features the model expects but the amplicon data lacks are added as all-zero rows,
# so the prediction can still run; a warning is printed because accuracy may suffer.
missing_features = [feature for feature in taxon_tree_features if feature not in df_vampir.index]
if missing_features:
    missing_df = pd.DataFrame(0, index=missing_features, columns=df_vampir.columns)
    df_vampir = pd.concat([df_vampir, missing_df])
    print(f"WARNING: The following features were missing and artificially added with zeros: {missing_features}")
    print("Taxon assignment using XGBoost may not be accurate due to these missing features.")

# Put the rows into exactly the feature order stored with the model.
df_vampir = df_vampir.loc[taxon_tree_features, :]

taxon_labels = list(tree['per_taxon_results'].keys())
# The model is given samples as rows and features as columns; encoded predictions
# are mapped back to taxon names with the stored label encoder.
pred = tree['model'].predict(df_vampir.T)
metadata = metadata.assign(tree_taxon=tree['label_encoder'].inverse_transform(pred))

# Sample counts per cohort and predicted taxon.
pivot_table = metadata.pivot_table(
    index=summary_cohort_col,
    columns='tree_taxon',
    values='sample_id',  
    aggfunc='count',
    fill_value=0
)
pivot_table


### Consensus taxon calls summary

Consensus taxonomic assignment combines predictions from multiple methods, reducing errors and improving confidence in species identification.

In [None]:
# Apply the function to create the new column
metadata['taxon'] = metadata.apply(amp.get_consensus_taxon, axis=1)

metadata.to_csv(f"{wkdir}/results/config/metadata.qcpass.tsv", sep="\t", index=False)

pivot_table = metadata.pivot_table(
    index=summary_cohort_col,
    columns='taxon',
    values='sample_id',  
    aggfunc='count',
    fill_value=0
)
pivot_table


In [None]:
# Allow the full per-sample table to be displayed.
pd.set_option("display.max_rows", 500)

# Columns that are always shown after the cohort columns.
taxon_cols = ['mean_aim_genotype', 'aim_taxon', 'pca_taxon', 'tree_taxon', 'taxon']

# Parse the cohort columns the same way as when the metadata was loaded (strip
# whitespace, drop empty entries), so that a value such as 'location, taxon' does
# not leave a stray ' taxon' entry behind. Duplicates and columns that are shown
# anyway ('sample_id', 'taxon', ...) are dropped to avoid repeated columns.
parsed_cohort_cols = [c.strip() for c in cohort_cols.split(',') if c.strip()]
cohort_cols_header = [
    c for c in dict.fromkeys(parsed_cohort_cols)
    if c != 'sample_id' and c not in taxon_cols
]

show_cols = ['sample_id'] + cohort_cols_header + taxon_cols

metadata[show_cols]
