In [None]:
import allel
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go  # type: ignore

import warnings
# NOTE(review): this silences every warning for the whole notebook (including
# pandas SettingWithCopy/FutureWarning messages); consider narrowing it to
# specific categories so real problems stay visible.
warnings.filterwarnings('ignore')

In [None]:
# Notebook parameters (relative paths).
# NOTE(review): these look like workflow-injected parameters that may be
# overridden at run time — confirm against the workflow rule.
metadata_path = '../../../results/config/metadata.qcpass.tsv'
bed_targets_path = "../../../config/ag-vampir.bed"
vcf_path = "../../..//results/vcfs/targets/lab-strains.annot.vcf"
wkdir = "../../.."
# Comma-separated metadata columns; the first one is used to group/colour samples in plots.
cohort_cols = 'location,taxon'
# Sequencing platform, passed through to amp.load_vcf.
platform = 'illumina'

In [None]:
import sys
import os
# Make the workflow helper module (ampseekertools) importable from <wkdir>/workflow.
sys.path.append(os.path.join(wkdir, 'workflow'))
import ampseekertools as amp

### Taxon classification

In this analysis, we use three separate methods (AIMs, PCA and XGBoost) to assign a taxon to each amplicon sample. For the latter two methods, we integrate data from over 7,000 publicly accessible whole-genome sequencing (WGS) samples from the [Vector Observatory](https://www.malariagen.net/vobs/).

Accurate species identification is essential for mosquito surveillance as different members of the Anopheles gambiae complex have distinct ecological niches and vectorial capacities (Coetzee et al., 2013).

#### Ancestry informative marker heatmap

In [None]:
# Grouping column for plots: the first entry of cohort_cols.
cohort_col = cohort_cols.split(',')[0]

metadata = pd.read_csv(metadata_path , sep="\t")

import json
# Colour mappings for metadata values, keyed by column name (used for plot colours below).
with open(f"{wkdir}/results/config/metadata_colours.json", 'r') as f:
    color_mapping = json.load(f)
    
# Target BED file; the 'mutation' column is used below to select AIM sites.
targets = pd.read_csv(bed_targets_path, sep="\t", header=None)
targets.columns = ['contig', 'start', 'end', 'amplicon', 'mutation', 'ref', 'alt']

# Load amplicon genotypes, positions, contigs, REF/ALT alleles and annotations.
# NOTE(review): load_vcf also returns a metadata frame that replaces the one read
# above — presumably restricted/ordered to the VCF samples; confirm.
gn_amp, pos, gn_contigs, metadata, refs, alts_amp, ann = amp.load_vcf(vcf_path=vcf_path, metadata=metadata, platform=platform)
samples = metadata['sample_id'].to_list()

# Prepend the REF allele as column 0, so (assuming VCF allele numbering, 0 = REF)
# column k of alts_amp is the allele for genotype index k.
alts_amp = np.concatenate([refs.reshape(refs.shape[0], -1), alts_amp], axis=1)

In [None]:
def _aims_n_alt(gt, aim_alts, data_alts):
    n_sites = gt.shape[0]
    n_samples = gt.shape[1]
    # create empty array
    aim_n_alt = np.empty((n_sites, n_samples), dtype=np.int8)

    # for every site
    for i in range(n_sites):
        # find the index of the correct tag snp allele
        tagsnp_index = np.where(aim_alts[i] == data_alts[i])[0]
        for j in range(n_samples):
            if tagsnp_index.shape[0] == 0:
                aim_n_alt[i, j] = 0
                continue

            n_tag_alleles = np.sum(gt[i, j] == tagsnp_index[0])
            n_missing = np.sum(gt[i, j] == -1)
            if n_missing != 0:
                aim_n_alt[i,j] = -1
            else:
                aim_n_alt[i, j] = n_tag_alleles

    return aim_n_alt

# Chromosome arms shown in the heatmap, in plotting order.
contigs = ['2R', '2L', '3R', '3L', 'X']
# AIM sites are the BED rows whose 'mutation' label contains "AIM".
df_aims = targets.query("mutation.str.contains('AIM')", engine='python')

# Subset the amplicon data to the AIM positions.
# NOTE(review): matched on position only (not contig) — assumes AIM positions
# are unique across contigs.
aim_mask = np.isin(pos, df_aims.end.to_list())
aim_gn = gn_amp.compress(aim_mask, axis=0)
aim_pos = pos[aim_mask]
aim_contigs = gn_contigs[aim_mask]
aim_alts = alts_amp[aim_mask]

# Index AIM rows as "aim_<contig>:<pos>" and reorder them to the VCF site order.
aim_loc = ["aim_" + c + ":" + str(aim_pos[i]) for i, c in enumerate(aim_contigs)]
df_aims = df_aims.assign(loc=lambda x: "aim_" + x.contig + ":" + x.end.astype(str)).set_index('loc')
df_aims = df_aims.loc[aim_loc]

# Append one column per sample with its tag-allele count (-1 = missing), then
# group the rows by contig in plotting order. The first 7 columns are the BED fields.
aim_gn_alt = _aims_n_alt(aim_gn, aim_alts=df_aims.alt.to_list(), data_alts=aim_alts)
df_aims = pd.concat([df_aims, pd.DataFrame(aim_gn_alt, columns=samples, index=aim_loc)], axis=1)
df_aims = pd.concat([df_aims.query(f"contig == '{contig}'") for contig in contigs])

# Sort samples by cohort (order of first appearance) and, within each cohort,
# by mean AIM genotype. groupby(dropna=False) keeps samples whose cohort value
# is missing: a query(f"col == '{value}'") never matches NaN (so those samples
# were silently left out of the plot) and breaks on values containing quotes.
aimplot_sample_order = []
for _, coh_metadata in metadata.groupby(cohort_col, sort=False, dropna=False):
    coh_samples = coh_metadata.sample_id.to_list()
    coh_samples_aim_order = df_aims.iloc[:, 7:].loc[:, coh_samples].replace({-1: np.nan}).mean().sort_values(ascending=True).index.to_list()
    aimplot_sample_order.extend(coh_samples_aim_order)

from plotly.subplots import make_subplots
# Subplot widths proportional to the number of AIMs on each contig.
col_widths = [
    np.count_nonzero(aim_contigs == contig)
    for contig in contigs
]

fig = make_subplots(
    rows=1,
    cols=len(contigs),
    shared_yaxes=True,
    column_titles=contigs,
    row_titles=None,
    column_widths=col_widths,
    x_title=None,
    y_title=None,
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
)

species = "gamb_vs_colu".split("_vs_")
# Define a colorbar: one tick per genotype value (-1 = missing, 0..2 = tag-allele copies).
colorbar = dict(
    title="AIM genotype",
    tickmode="array",
    tickvals=[-1, 0, 1, 2],
    ticktext=[
        "missing",
        f"{species[0]}/{species[0]}",
        f"{species[0]}/{species[1]}",
        f"{species[1]}/{species[1]}",
    ],
    len=100,
    lenmode="pixels",
    y=1,
    yanchor="top",
    outlinewidth=1,
    outlinecolor="black",
)

# Set up default AIMs color palettes.
colors = px.colors.qualitative.T10
color_gambcolu = colors[6]
color_gamb = colors[0]
color_gamb_colu_het = colors[5]
color_colu = colors[2]
color_missing = "white"
palette = (
        color_missing,
        color_gamb,
        color_gamb_colu_het,
        color_colu,
    )

# Piecewise-constant colorscale: with zmin=-1.5 and zmax=2.5 each of the four
# values -1, 0, 1, 2 falls in its own quarter of the scale.
colorscale = [
    [0 / 4, palette[0]],
    [1 / 4, palette[0]],
    [1 / 4, palette[1]],
    [2 / 4, palette[1]],
    [2 / 4, palette[2]],
    [3 / 4, palette[2]],
    [3 / 4, palette[3]],
    [4 / 4, palette[3]],
]

# Create the subplots, one for each contig (samples as rows, AIMs as columns).
for j, contig in enumerate(contigs):

    df_aims_contig = df_aims.filter(like=contig, axis=0)
    df_aims_contig = df_aims_contig.iloc[:, 7:]  
    df_aims_contig = df_aims_contig.loc[:, aimplot_sample_order]
    df_aims_contig = df_aims_contig.T

    fig.add_trace(
        go.Heatmap(
            y=df_aims_contig.index,
            z=df_aims_contig,
            x=df_aims_contig.columns,
            colorscale=colorscale,
            zmin=-1.5,
            zmax=2.5,
            xgap=0,
            ygap=0.5,  # this creates faint lines between rows
            colorbar=colorbar,
        ),
        row=1,
        col=j + 1,
    )

fig.update_layout(
    title="AIMs - gambiae vs coluzzii",
    height=max(600, 1.2 * len(samples) + 300),
    width=800,
)
fig.write_image(f"{wkdir}/results/aims_gamb_vs_colu.png", scale=2)

fig.show()

#### By cohort

Ancestry Informative Markers (AIMs) are genomic variants with large allele frequency differences between populations, making them useful for species identification and detecting hybridization (Rosenberg et al., 2003).

In [None]:
# Per-sample AIM genotype matrix (rows: AIM sites, columns: samples).
# .copy() so renaming the columns acts on an independent frame, not a slice of df_aims.
df = df_aims.iloc[:, 7:].copy()
df.columns = samples

## Use only chrom 3 and X aims for taxon 
x_3_mask = df.index.str.contains("aim_3|aim_X")
df = df[x_3_mask]

# Mean AIM genotype per sample ignoring missing calls (-1); samples with more
# than `max_missing_aims` missing AIMs get NaN.
mean_aims = df.replace(-1, float('nan')).apply(np.nanmean, axis=0)
max_missing_aims = 12
mean_aims[df.replace(-1, float('nan')).isna().sum(axis=0) > max_missing_aims] = np.nan
aims = mean_aims.loc[metadata.set_index('sample_id').index]
metadata = metadata.assign(mean_aim_genotype=aims.values)


def _aim_taxon(mean_genotype):
    """Map a sample's mean AIM genotype to an AIM taxon call.

    NaN -> 'uncertain'; < 0.5 -> 'gambiae'; [0.5, 1.5) -> 'unassigned';
    >= 1.5 -> 'coluzzii' (values 0 and 2 match the gambiae/gambiae and
    coluzzii/coluzzii colorbar labels of the heatmap).
    """
    # Missing values must be detected with pd.isna: `value == np.nan` is always
    # False, which previously sent them to a NaN label instead of 'uncertain'.
    if pd.isna(mean_genotype):
        return 'uncertain'
    if mean_genotype < 0.5:
        return 'gambiae'
    if mean_genotype < 1.5:
        return 'unassigned'
    return 'coluzzii'


taxon = [_aim_taxon(value) for value in metadata['mean_aim_genotype']]

new_metadata = metadata.assign(aim_taxon=taxon)

fig = px.histogram(
    new_metadata,
    nbins=100, 
    x='mean_aim_genotype', 
    color=cohort_col, 
    color_discrete_map=color_mapping[cohort_col],
    width=750, 
    height=400, 
    template='plotly_white', 
    title='AIM genotype distribution'
    )
fig.show()

In [None]:
from IPython.display import display, Markdown
# Write the per-sample AIM results and show a link to the file.
aims_out_path = f"{wkdir}/results/ag-vampir/aims/taxon_aims.tsv"
# to_csv does not create parent directories.
os.makedirs(os.path.dirname(aims_out_path), exist_ok=True)
new_metadata[['sample_id', 'aim_taxon', 'mean_aim_genotype']].to_csv(aims_out_path, sep="\t", index=False)
# Quote the href so paths containing spaces still produce a valid link.
display(Markdown(f'<a href="{aims_out_path}">Sample gambiae vs coluzzii AIMs and AIM taxon assignment (.tsv)</a>'))

### PCA-based taxon classification

In this method, we perform PCA on the amplicon samples together with the VObs reference samples, then apply a classifier trained on the reference samples to predict the taxon of each amplicon sample.

Principal Component Analysis reduces the dimensionality of genetic data while preserving patterns that differentiate species, enabling classification through machine learning algorithms (Jombart et al., 2010).

In [None]:
### Load VObs PCA data
# VObs WGS reference data prepared for the ag-vampir targets.
ampseq_loci = np.load(f"{wkdir}/resources/ag-vampir/agvampir-ampseq-loci.npy")
alts_wgs = np.load(f"{wkdir}/resources/ag-vampir/agvampir-alts-wgs.npy")
gn_wgs = np.load(f"{wkdir}/resources/ag-vampir/agvampir-gn-wgs.npy")

df_wgs_samples = pd.read_csv(f"{wkdir}/resources/ag-vampir/agvampir-df-samples.csv", index_col=0)

# Reload the amplicon VCF, this time carrying the AIM results (new_metadata).
gn_amp, pos, gn_contigs, metadata, refs, alts_amp, ann = amp.load_vcf(vcf_path=vcf_path, metadata=new_metadata, platform=platform)
samples = metadata['sample_id'].to_list()

# Allele table with the REF allele prepended as column 0.
alts = np.concatenate([refs.reshape(refs.shape[0], -1), alts_amp], axis=1)

# Positions parsed from the locus strings (the integer between ':' and '-').
wgs_pos  = np.array([int(s.split(":")[1].split("-")[0]) for s in ampseq_loci])

# Keep the WGS sites that match an amplicon position, in amplicon (`pos`) order.
wgs_dict = {value: idx for idx, value in enumerate(wgs_pos)}
indices = [wgs_dict.get(value, -1) for value in pos]
valid_indices = [idx for idx in indices if idx != -1]
wgs_pos = np.array([wgs_pos[i] for i in valid_indices])
gn_wgs = gn_wgs.take(valid_indices, axis=0)
alts_wgs = alts_wgs.take(valid_indices, axis=0)

# Later steps pair rows of the amplicon and WGS arrays positionally, so every
# amplicon site needs a WGS counterpart in the same order.
n_unmatched = len(pos) - len(valid_indices)
assert n_unmatched == 0, f"{n_unmatched} amplicon positions are missing from the VObs WGS loci"
assert np.array_equal(wgs_pos, pos), "WGS and amplicon positions are not aligned"

In [None]:
import allel
import numpy as np
import pandas as pd
import plotly.express as px
from sklearn.neighbors import KNeighborsClassifier

def pca_all_samples(gn_wgs, gn_amp, df_wgs_samples, df_amp_samples, n_components=6, cohort_columns=None):
    """
    Perform PCA on both WGS and amplicon samples combined.

    Parameters:
    -----------
    gn_wgs : GenotypeArray or array-like
        Genotypes of samples with known taxon; axis 1 indexes samples
        (one per row of df_wgs_samples).
    gn_amp : GenotypeArray or array-like
        Genotypes of samples to be assigned, on the same sites (rows) as
        gn_wgs; axis 1 indexes samples (one per row of df_amp_samples).
    df_wgs_samples : DataFrame
        Metadata for samples with known taxon, including 'taxon' column.
    df_amp_samples : DataFrame
        Metadata for the amplicon samples; must contain 'sample_id' and the
        cohort columns.
    n_components : int
        Number of principal components to compute.
    cohort_columns : list of str, optional
        Amplicon metadata columns carried into the output. Defaults to the
        notebook-level ``cohort_cols`` setting split on commas.

    Returns:
    --------
    pca_df : DataFrame
        One row per sample (WGS first, then amplicon) with metadata, a
        'sample_type' column ('WGS' or 'amplicon') and PC1..PC<n_components>.

    Raises:
    -------
    ValueError
        If a genotype array and its metadata disagree on the number of samples
        (PCA coordinates are joined to the metadata by position).
    """
    if cohort_columns is None:
        cohort_columns = cohort_cols.split(",")

    n_wgs, n_amp = gn_wgs.shape[1], gn_amp.shape[1]
    if n_wgs != len(df_wgs_samples) or n_amp != len(df_amp_samples):
        raise ValueError(
            f"Sample count mismatch: gn_wgs has {n_wgs} samples vs {len(df_wgs_samples)} metadata rows; "
            f"gn_amp has {n_amp} samples vs {len(df_amp_samples)} metadata rows"
        )

    # Create metadata for amplicon samples
    df_amp_samples = df_amp_samples.assign(sample_type='amplicon')
    
    # Add sample type to WGS metadata
    df_wgs_samples = df_wgs_samples.copy()
    df_wgs_samples = df_wgs_samples.assign(sample_type='WGS')
    # Combine metadata (WGS rows first, matching the genotype concatenation below)
    df_all_samples = pd.concat([df_wgs_samples.reset_index(), df_amp_samples[['sample_id', 'sample_type'] + list(cohort_columns)]], ignore_index=True)
    
    # Combine genotype data along the sample axis
    print(f"Combining {n_wgs} WGS samples and {n_amp} amplicon samples")
    combined_geno = allel.GenotypeArray(np.concatenate([gn_wgs[:], gn_amp[:]], axis=1))
    
    # Perform PCA on alternate-allele counts
    print("Performing PCA on combined data")
    gn_alt = combined_geno.to_n_alt()
    print("Removing any invariant sites")
    # A site is variable if any sample differs from the first sample.
    loc_var = np.any(gn_alt != gn_alt[:, 0, np.newaxis], axis=1)
    gn_var = np.compress(loc_var, gn_alt, axis=0)
    
    print(f"Running PCA with {gn_var.shape[0]} variable sites and {gn_var.shape[1]} samples")
    coords, model = allel.pca(gn_var, n_components=n_components)
    
    # PC signs are arbitrary: flip each component so its largest-magnitude
    # coordinate is positive, giving a stable orientation between runs.
    for i in range(n_components):
        c = coords[:, i]
        if np.abs(c.min()) > np.abs(c.max()):
            coords[:, i] = c * -1
    
    # Create PCA DataFrame
    pca_df = pd.DataFrame(coords)
    pca_df.columns = [f"PC{pc+1}" for pc in range(n_components)]
    pca_df = pd.concat([df_all_samples.reset_index(drop=True), pca_df], axis=1)
    
    print("PCA completed successfully")
    return pca_df

def assign_taxa(pca_df, method='knn', n_neighbors=5, probability_threshold=0.8, **kwargs):
    """
    Assign taxa to amplicon samples based on PC1-4 coordinates.

    A classifier is trained on the WGS rows of ``pca_df`` (labels from the
    'taxon' column) and applied to the amplicon rows. The reported taxon is
    the class with the highest predicted probability, so the label always
    matches the probability compared against the threshold (for SVC,
    ``predict`` and ``argmax(predict_proba)`` can disagree).

    Parameters:
    -----------
    pca_df : DataFrame
        DataFrame with PC1-PC4, 'sample_type' ('WGS' or 'amplicon'),
        'sample_id' and, for WGS rows, 'taxon'.
    method : str
        Classification method to use ('knn' or 'svm', case-insensitive).
    n_neighbors : int
        Number of neighbors to use for KNN classification (ignored if method='svm').
    probability_threshold : float
        Minimum probability required for taxon assignment.
    **kwargs : dict
        Additional parameters for the classifier. For 'svm', ``probability``
        defaults to True and ``random_state`` to 0 (SVC's probability
        calibration is randomised); both can be overridden.

    Returns:
    --------
    assignment_df : DataFrame
        One row per amplicon sample with 'sample_id', 'predicted_taxon'
        ('unassigned' below the threshold), 'probability' (0.0 when
        unassigned) and 'classifier'. If any sample is below the threshold,
        'low_confidence_prediction' and 'low_confidence_probability' hold its
        best guess (NaN for confidently assigned samples).

    Raises:
    -------
    ValueError
        If ``method`` is not 'knn' or 'svm'.
    """
    from sklearn.svm import SVC
    
    # Separate training (WGS) and test (amplicon) data
    wgs_samples = pca_df[pca_df['sample_type'] == 'WGS']
    amp_samples = pca_df[pca_df['sample_type'] == 'amplicon']
    
    print(f"Training data: {len(wgs_samples)} WGS samples")
    print(f"Test data: {len(amp_samples)} amplicon samples")
    
    # Features: PC1-4
    pc_features = ['PC1', 'PC2', 'PC3', 'PC4']
    X_train = wgs_samples[pc_features].values
    y_train = wgs_samples['taxon'].values
    X_test = amp_samples[pc_features].values
    
    method_name = method.lower()
    if method_name == 'knn':
        print(f"Training KNN classifier with {n_neighbors} neighbors")
        classifier = KNeighborsClassifier(n_neighbors=n_neighbors, **kwargs)
    elif method_name == 'svm':
        print("Training SVM classifier")
        # probability=True is required for predict_proba; seeded for reproducibility.
        svm_kwargs = {'probability': True, 'random_state': 0}
        svm_kwargs.update(kwargs)
        classifier = SVC(**svm_kwargs)
    else:
        raise ValueError(f"Unsupported method: {method}. Choose 'knn' or 'svm'")
    
    classifier.fit(X_train, y_train)
    
    # Take label and probability from the same argmax so they always agree.
    probabilities = classifier.predict_proba(X_test)
    best = probabilities.argmax(axis=1)
    predictions = classifier.classes_[best]
    max_probs = probabilities[np.arange(len(best)), best]
    confident = max_probs >= probability_threshold
    
    assignment_df = amp_samples[['sample_id']].copy()
    assignment_df['predicted_taxon'] = np.where(confident, predictions, 'unassigned')
    assignment_df['probability'] = np.where(confident, max_probs, 0.0)
    assignment_df['classifier'] = method_name
    
    if not confident.all():
        # Keep the best guess for low-confidence samples (marked 'unassigned' above).
        low = ~confident
        assignment_df['low_confidence_prediction'] = np.where(low, predictions.astype(object), np.nan)
        assignment_df['low_confidence_probability'] = np.where(low, max_probs, np.nan)
    
    assigned_count = int((assignment_df['predicted_taxon'] != 'unassigned').sum())
    print(f"Assigned {assigned_count} out of {len(amp_samples)} samples with confidence ≥ {probability_threshold}")
    
    if assigned_count < len(amp_samples):
        print(f"{len(amp_samples) - assigned_count} samples were below the confidence threshold")
    
    print("\nTaxon assignment summary:")
    print(assignment_df['predicted_taxon'].value_counts())
    
    return assignment_df

def plot_pca_3d_with_assignments(pca_df, assignment_df, title="PCA with Taxon Assignments", 
                              height=800, width=800):
    """
    Create a 3D plot (PC1-PC3) of PCA results with taxon assignments.

    WGS reference samples are drawn as circles labelled "<taxon> (VObs reference)",
    confidently assigned amplicon samples as diamonds labelled "<taxon> (predicted)",
    and unassigned amplicon samples as grey crosses whose hover text shows the
    low-confidence prediction when one is available.
    
    Parameters:
    -----------
    pca_df : DataFrame
        DataFrame with PCA coordinates and metadata ('sample_id', 'sample_type',
        'taxon' for WGS rows, PC1-PC3).
    assignment_df : DataFrame
        DataFrame with taxon assignments for amplicon samples (as returned by
        assign_taxa).
    title : str
        Plot title
    height, width : int
        Plot dimensions
    
    Returns:
    --------
    plotly.graph_objects.Figure
    """
    import plotly.graph_objects as go
    
    plot_df = pca_df.copy()
    
    # Amplicon labels by sample_id. A dict keeps the last row for a duplicated
    # sample_id, as the previous row-by-row update did.
    amp_labels = {
        sample_id: ('unassigned' if predicted == 'unassigned' else predicted + " (predicted)")
        for sample_id, predicted in zip(assignment_df['sample_id'], assignment_df['predicted_taxon'])
    }
    # object dtype so string labels can be written into the column below.
    plot_df['display_taxon'] = plot_df['sample_id'].map(amp_labels).astype(object)
    
    # For WGS samples, use the known taxon with a reference label.
    is_wgs = plot_df['sample_type'] == 'WGS'
    plot_df.loc[is_wgs, 'display_taxon'] = plot_df.loc[is_wgs, 'taxon'] + " (VObs reference)"
    
    wgs_samples = plot_df[is_wgs]
    amp_samples = plot_df[plot_df['sample_type'] == 'amplicon']
    
    # First assignment row per sample, used for hover lookups.
    first_rows = assignment_df.drop_duplicates('sample_id').set_index('sample_id')
    
    fig = go.Figure()
    
    # Reference (WGS) samples - circles
    for taxon in wgs_samples['display_taxon'].unique():
        subset = wgs_samples[wgs_samples['display_taxon'] == taxon]
        fig.add_trace(go.Scatter3d(
            x=subset['PC1'],
            y=subset['PC2'],
            z=subset['PC3'],
            mode='markers',
            marker=dict(
                size=5,
                symbol='circle',
                opacity=0.7
            ),
            name=taxon,
            # Strip the exact suffix added above so the hover shows the bare taxon.
            hovertemplate="<b>%{text}</b><br>PC1: %{x:.2f}<br>PC2: %{y:.2f}<br>PC3: %{z:.2f}<br>Sample Type: Reference<br>Taxon: " + taxon.replace(" (VObs reference)", "") + "<extra></extra>",
            text=subset['sample_id']
        ))
    
    # Assigned amplicon samples - diamonds for greater visibility
    for taxon in amp_samples['display_taxon'].unique():
        if taxon == 'unassigned':
            continue
        subset = amp_samples[amp_samples['display_taxon'] == taxon]
        
        probs = [f"{first_rows.at[sid, 'probability']:.2f}" for sid in subset['sample_id']]
        
        fig.add_trace(go.Scatter3d(
            x=subset['PC1'],
            y=subset['PC2'],
            z=subset['PC3'],
            mode='markers',
            marker=dict(
                size=7,
                symbol='diamond',
                opacity=0.9
            ),
            name=taxon,
            hovertemplate="<b>%{text}</b><br>PC1: %{x:.2f}<br>PC2: %{y:.2f}<br>PC3: %{z:.2f}<br>Sample Type: Amplicon<br>Assigned Taxon: " + taxon.replace(" (predicted)", "") + "<br>Probability: %{customdata}<extra></extra>",
            text=subset['sample_id'],
            customdata=probs
        ))
    
    # Unassigned amplicon samples - grey crosses
    unassigned = amp_samples[amp_samples['display_taxon'] == 'unassigned']
    if len(unassigned) > 0:
        # Show the low-confidence prediction when assign_taxa recorded one.
        hover_data = []
        for sid in unassigned['sample_id']:
            row = first_rows.loc[sid]
            if 'low_confidence_prediction' in row and not pd.isna(row['low_confidence_prediction']):
                hover_data.append(f"{row['low_confidence_prediction']} ({row['low_confidence_probability']:.2f})")
            else:
                hover_data.append("No prediction")
        
        fig.add_trace(go.Scatter3d(
            x=unassigned['PC1'],
            y=unassigned['PC2'],
            z=unassigned['PC3'],
            mode='markers',
            marker=dict(
                size=7,
                symbol='x',
                color='gray',
                opacity=0.7
            ),
            name='unassigned',
            hovertemplate="<b>%{text}</b><br>PC1: %{x:.2f}<br>PC2: %{y:.2f}<br>PC3: %{z:.2f}<br>Sample Type: Amplicon<br>Status: Unassigned<br>Low confidence: %{customdata}<extra></extra>",
            text=unassigned['sample_id'],
            customdata=hover_data
        ))
    
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title='PC1',
            yaxis_title='PC2',
            zaxis_title='PC3'
        ),
        height=height,
        width=width,
        margin=dict(l=0, r=0, b=0, t=40),
        legend_title_text='Taxa'
    )
    
    return fig

# Recode amplicon genotype allele indices onto the WGS allele ordering so the
# two genotype sets can be concatenated for PCA.
# NOTE(review): alts_wgs[:, 0] is passed as the reference allele — assumes
# column 0 of alts_wgs is REF and equals the amplicon REF; confirm.
mapping = allel.create_allele_mapping(alts_wgs[:, 0], alt=alts_amp, alleles=alts_wgs)
gn_amp_remap = gn_amp.map_alleles(mapping)
# NOTE(review): bare mid-cell expressions (here and below) are not displayed.
gn_amp_remap.shape

# Drop SNPs with more than 40 missing amplicon calls, from both datasets so
# their rows stay aligned.
snp_miss_mask = gn_amp_remap.is_missing().sum(axis=1) > 40
snp_miss_mask.sum()

gn_amp_flt = gn_amp_remap.compress(~snp_miss_mask, axis=0)
gn_wgs_flt = gn_wgs.compress(~snp_miss_mask, axis=0)

# Exclude the listed taxa from the WGS reference panel, subsetting genotype
# columns to match the remaining metadata rows.
mask = df_wgs_samples.eval("taxon not in ['unassigned', 'fontenillei', 'gcx4', 'quadriannulatus', 'merus', 'melas']")
df_wgs_samples_flt = df_wgs_samples[mask]
gn_wgs_flt = gn_wgs_flt.compress(mask, axis=1)

In [None]:
# Joint PCA of the filtered WGS reference panel and the amplicon samples.
pca_df = pca_all_samples(gn_wgs_flt, gn_amp_flt, df_wgs_samples_flt, metadata, n_components=6)

# Classify amplicon samples from PC1-4 (n_neighbors only applies to method="knn").
df_assignments = assign_taxa(pca_df, method="svm", n_neighbors=5, 
                           probability_threshold=0.8)

# Assignments are attached positionally, so confirm the sample order first.
assert df_assignments['sample_id'].tolist() == metadata['sample_id'].tolist(), \
    "PCA assignments are not in the same sample order as metadata"
metadata = metadata.assign(pca_taxon=df_assignments.predicted_taxon.to_numpy())

pca_3d = plot_pca_3d_with_assignments(
        pca_df, 
        df_assignments, 
        title="3D PCA with Taxon Assignments"
)

pca_3d

### Machine learning based taxon classification

In this method, we use a pre-trained XGBoost model, built from the VObs reference data, to predict the species of each amplicon sample.

XGBoost is a gradient boosting framework that uses decision trees to identify the most informative features for classification, offering high accuracy for species assignment (Chen & Guestrin, 2016).

In [None]:
import numba
@numba.jit(nopython=True)
def _melt_gt_counts(gt_counts):
    """Flatten per-allele genotype counts into one row per (SNP, non-first allele).

    Row ``snp * (n_alleles - 1) + k`` holds, for every sample, the count of
    allele ``k + 1`` at that SNP. Calls whose allele counts do not sum to 2
    are left as NaN, which is why the output is float64.
    """
    n_snps, n_samples, n_alleles = gt_counts.shape
    n_alt = n_alleles - 1
    melted = np.full((n_snps * n_alt, n_samples), np.nan, dtype=np.float64)

    for snp in range(n_snps):
        base_row = snp * n_alt
        for sample in range(n_samples):
            call = gt_counts[snp, sample]
            # Only calls with exactly two allele copies are recorded.
            if call.sum() != 2:
                continue
            for k in range(n_alt):
                melted[base_row + k, sample] = call[k + 1]

    return melted

import pickle
# NOTE(review): presumably imported so the pickled model object can be unpickled; confirm.
import xgboost 

# Load the taxon classifier bundle; the keys used in this notebook are
# 'feature_importance', 'per_taxon_results', 'model' and 'label_encoder'.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(f"{wkdir}/resources/ag-vampir/xgboost_taxon.pickle", 'rb') as file:
        tree = pickle.load(file)
# Model feature names, in the order of the feature-importance table's sorted index.
taxon_tree_features = tree['feature_importance'].sort_index().SNP.to_list()
# Unique positions parsed from the feature names (the integer after '-').
taxon_tree_feat_pos = np.unique([int(t.split(":")[1].split("-")[1]) for t in taxon_tree_features])

# NOTE(review): aim_gn/aim_pos are overwritten with copies of the full amplicon
# data; these and taxon_tree_feat_pos are not used by any later cell shown.
aim_gn = gn_amp.copy()
aim_pos = pos.copy()

In [None]:
# Fill empty allele slots ('') in each row of `alts` with nucleotides (A/C/G/T)
# not already present at that site, so every allele column has a name for the
# feature labels built in the next cell. Slots stay empty once all missing
# nucleotides are used.
# Missing nucleotides are filled in sorted order: iterating a set of str depends
# on per-process hash randomisation, which made the labels non-reproducible.
all_nucleotides = set(['A', 'C', 'G', 'T'])

arr = alts.copy()

for i in range(arr.shape[0]):
    present = {val for val in arr[i] if val != ''}
    fill_values = iter(sorted(all_nucleotides - present))
    for k in range(arr.shape[1]):
        if arr[i, k] == '':
            arr[i, k] = next(fill_values, '')

In [None]:
# Allele-count features: one row per (SNP, allele slot 1..n_alt_slots), one
# column per sample, NaN where the call does not have exactly two allele copies.
n_alt_slots = 3
gn_counts = _melt_gt_counts(gn_amp.to_allele_counts(max_allele=n_alt_slots).values)

# Row labels "<contig>:<pos>-<pos>:<allele>" in the same (SNP, slot) order as gn_counts.
assert arr.shape[1] - 1 == n_alt_slots, (
    f"expected {n_alt_slots} alternate-allele columns per site, got {arr.shape[1] - 1}"
)
alts_flat = arr[:, 1:].flatten()
pos_flat = np.repeat(pos, n_alt_slots)
# Built once; it was previously recomputed for every label (quadratic).
contigs_flat = np.repeat(gn_contigs, n_alt_slots)

labels = np.array([f"{contigs_flat[i]}:{pos_flat[i]}-{pos_flat[i]}:{alts_flat[i]}" for i in range(len(alts_flat))])

df_vampir = pd.DataFrame(gn_counts, index=labels, columns=samples)
# Features the model expects but this dataset lacks are added as all-zero rows.
missing_features = [feature for feature in taxon_tree_features if feature not in df_vampir.index]
if missing_features:
    missing_df = pd.DataFrame(0, index=missing_features, columns=df_vampir.columns)
    df_vampir = pd.concat([df_vampir, missing_df])
    print(f"WARNING: The following features were missing and artificially added with zeros: {missing_features}")
    print("Taxon assignment using XGBoost may not be accurate due to these missing features.")

# Reorder rows to the model's feature order.
df_vampir = df_vampir.loc[taxon_tree_features, :]

taxon_labels = list(tree['per_taxon_results'].keys())
# Predict with samples as rows, then decode the encoded labels back to taxon names.
pred = tree['model'].predict(df_vampir.T)
metadata = metadata.assign(tree_taxon=tree['label_encoder'].inverse_transform(pred))

pivot_table = metadata.pivot_table(
    index=cohort_col,
    columns='tree_taxon',
    values='sample_id',  
    aggfunc='count',
    fill_value=0
)
pivot_table

### Consensus taxon calls summary

Consensus taxonomic assignment combines predictions from multiple methods, reducing errors and improving confidence in species identification.

In [None]:
def get_consensus_taxon(row):
    """Return the taxon that at least two of the three methods agree on.

    Votes are read from row['aim_taxon'], row['pca_taxon'] and
    row['tree_taxon']; missing (NaN/None) votes are ignored. With three votes
    at most one taxon can reach two, so the first vote seen twice is the
    consensus. Returns 'unassigned' when no taxon has two or more votes.
    """
    votes = [row[column] for column in ('aim_taxon', 'pca_taxon', 'tree_taxon')]
    valid_votes = [vote for vote in votes if pd.notna(vote)]
    for candidate in valid_votes:
        if valid_votes.count(candidate) >= 2:
            return candidate
    return 'unassigned'

# Apply the function to create the new column
# Consensus call per sample: a taxon is kept when at least two methods agree.
metadata['taxon'] = metadata.apply(get_consensus_taxon, axis=1)

# NOTE(review): this writes the metadata, with the consensus 'taxon' column,
# to the same path as the default metadata_path, overwriting the notebook's
# input; re-running reads the updated file. It also uses a hard-coded path
# rather than metadata_path — confirm both are intended.
metadata.to_csv(f"{wkdir}/results/config/metadata.qcpass.tsv", sep="\t", index=False)

# Sample counts per cohort and consensus taxon.
pivot_table = metadata.pivot_table(
    index=cohort_col,
    columns='taxon',
    values='sample_id',  
    aggfunc='count',
    fill_value=0
)
pivot_table

In [None]:
pd.set_option("display.max_rows", 500)

# Cohort columns to show, minus 'taxon' (the consensus call is shown as the last column instead).
cohort_cols_header = cohort_cols.split(',')
if 'taxon' in cohort_cols_header:
    cohort_cols_header.remove('taxon')

# Per-sample table: identifiers, AIM mean, each method's call and the consensus.
show_cols = ['sample_id'] + cohort_cols_header + ['mean_aim_genotype', 'aim_taxon', 'pca_taxon', 'tree_taxon', 'taxon']

metadata[show_cols]