In [14]:
import allel
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go  # type: ignore

def load_vcf(vcf_path, metadata):
    """
    Read a VCF and return its SNP calls for the samples listed in the metadata.

    Two filters are applied: samples whose name does not appear in
    ``metadata.sampleID`` are dropped, and variants flagged in the
    ``variants/INDEL`` field are dropped.

    Parameters
    ----------
    vcf_path
        Path handed to ``allel.read_vcf``.
    metadata
        Object exposing a ``sampleID`` attribute holding the sample names to
        keep (a pandas DataFrame with a ``sampleID`` column in this notebook).

    Returns
    -------
    tuple
        ``(genotypes, positions, contigs, samples, refs, alts)`` where
        ``genotypes`` is an ``allel.GenotypeArray`` and the remaining items are
        the matching ``variants/POS``, ``variants/CHROM``, sample-name,
        ``variants/REF`` and ``variants/ALT`` arrays after filtering.
    """
    callset = allel.read_vcf(vcf_path, fields='*')

    # Samples: keep the ones named in the metadata.
    all_samples = callset['samples']
    keep_samples = np.isin(all_samples, metadata.sampleID)

    # Variants: keep everything that is not flagged as an indel.
    # NOTE(review): assumes 'variants/INDEL' is a boolean per-variant flag.
    keep_variants = ~callset['variants/INDEL']

    genotypes = (
        allel.GenotypeArray(callset['calldata/GT'])
        .compress(keep_samples, axis=1)
        .compress(keep_variants, axis=0)
    )

    return (
        genotypes,
        callset['variants/POS'][keep_variants],
        callset['variants/CHROM'][keep_variants],
        all_samples[keep_samples],
        callset['variants/REF'][keep_variants],
        callset['variants/ALT'][keep_variants],
    )

In [3]:
# --- Notebook parameters -------------------------------------------------
# Root of the workflow directory, relative to this notebook.
wkdir = "../../.."

# Inputs.
metadata_path = "../../../results/config/metadata.qcpass.tsv"
bed_targets_path = "../../../config/ag-vampir.bed"
vcf_path = "../../../results/vcfs/targets/ag-vampir-002.annot.vcf"

# Comma-separated metadata columns; the first one is used as the cohort column.
cohort_cols = "taxon,location"

### Species ID

In [None]:
# The first entry of the comma-separated list is the cohort column.
cohort_col = cohort_cols.partition(',')[0]

# Sample metadata (first column of the file becomes the index).
metadata = pd.read_csv(metadata_path, sep="\t", index_col=0)

# Target BED file: no header row, so the column names are supplied here.
target_columns = ['contig', 'start', 'end', 'amplicon', 'mutation', 'ref', 'alt']
targets = pd.read_csv(bed_targets_path, sep="\t", header=None)
targets.columns = target_columns

# Genotypes and variant annotations for the samples listed in the metadata.
vcf_data = load_vcf(vcf_path=vcf_path, metadata=metadata)
geno, pos, gn_contigs, samples, refs, alts = vcf_data

# Prepend REF as the first column so that column k holds allele index k
# (0 = REF, 1.. = ALT), i.e. the indices used in the genotype calls.
ref_column = refs.reshape(refs.shape[0], -1)
alts = np.concatenate([ref_column, alts], axis=1)

In [21]:
def _aims_n_alt(gt, aim_alts, data_alts):
    n_sites = gt.shape[0]
    n_samples = gt.shape[1]
    # create empty array
    aim_n_alt = np.empty((n_sites, n_samples), dtype=np.int8)

    # for every site
    for i in range(n_sites):
        # find the index of the correct tag snp allele
        tagsnp_index = np.where(aim_alts[i] == data_alts[i])[0]
        for j in range(n_samples):
            n_tag_alleles = np.sum(gt[i, j] == tagsnp_index[0])
            n_missing = np.sum(gt[i, j] == -1)
            if n_missing != 0:
                aim_n_alt[i,j] = -1
            else:
                aim_n_alt[i, j] = n_tag_alleles

    return aim_n_alt

In [None]:
from plotly.subplots import make_subplots

contigs = ['2L', '2R', '3L', '3R', 'X']

# AIM targets from the BED table, indexed by an "aim_<contig>:<end>" label.
df_aims = (
    targets
    .query("mutation.str.contains('AIM')", engine='python')
    .assign(loc=lambda x: "aim_" + x.contig + ":" + x.end.astype(str))
    .set_index('loc')
)
# Number of annotation columns that will precede the per-sample genotype
# columns once the genotypes are appended to df_aims below.
n_annot_cols = df_aims.shape[1]

# Label every variant in the VCF the same way and select the AIM sites by
# contig AND position. Matching on position alone would also pick up a
# variant on a different contig that happens to share an AIM position.
variant_loc = pd.Index(
    ["aim_" + str(c) + ":" + str(p) for c, p in zip(gn_contigs, pos)],
    dtype=object,
)
aim_mask = variant_loc.isin(df_aims.index)
aim_gn = geno.compress(aim_mask, axis=0)
aim_pos = pos[aim_mask]
aim_contigs = gn_contigs[aim_mask]
aim_alts = alts[aim_mask]
aim_loc = variant_loc[aim_mask].to_list()

# Put the AIM annotations in the same order as the genotype rows.
df_aims = df_aims.loc[aim_loc]

# Per-site, per-sample count of the AIM allele (-1 = missing), appended to the
# annotation columns as one column per sample.
aim_gn_alt = _aims_n_alt(aim_gn, aim_alts=df_aims.alt.to_list(), data_alts=aim_alts)
df_aims = pd.concat([df_aims, pd.DataFrame(aim_gn_alt, columns=samples, index=aim_loc)], axis=1)
# Reorder the rows contig by contig, following `contigs`.
df_aims = pd.concat([df_aims.query(f"contig == '{contig}'") for contig in contigs])
# Genotypes only, transposed to samples x sites.
df_aims_gn = df_aims.iloc[:, n_annot_cols:].T

x_label = df_aims_gn.columns.to_list()

# One subplot column per contig, with width proportional to its number of AIMs.
col_widths = [
    np.count_nonzero(aim_contigs == contig)
    for contig in contigs
]

fig = make_subplots(
    rows=1,
    cols=len(contigs),
    shared_yaxes=True,
    column_titles=contigs,
    row_titles=None,
    column_widths=col_widths,
    x_title=None,
    y_title=None,
    horizontal_spacing=0.01,
    vertical_spacing=0.01,
)

species = "gamb_vs_colu".split("_vs_")
# Define a colorbar.
colorbar = dict(
    title="AIM genotype",
    tickmode="array",
    tickvals=[-1, 0, 1, 2],
    ticktext=[
        "missing",
        f"{species[0]}/{species[0]}",
        f"{species[0]}/{species[1]}",
        f"{species[1]}/{species[1]}",
    ],
    len=100,
    lenmode="pixels",
    y=1,
    yanchor="top",
    outlinewidth=1,
    outlinecolor="black",
)

# Set up default AIMs color palettes.
colors = px.colors.qualitative.T10
color_gambcolu = colors[6]
color_gamb = colors[0]
color_gamb_colu_het = colors[5]
color_colu = colors[2]
color_missing = "white"
palette = (
    color_missing,
    color_gamb,
    color_gamb_colu_het,
    color_colu,
)

# Discrete colorscale: with zmin=-1.5 and zmax=2.5 the four equal-width bands
# are centred on the genotype codes -1, 0, 1 and 2.
colorscale = [
    [0 / 4, palette[0]],
    [1 / 4, palette[0]],
    [1 / 4, palette[1]],
    [2 / 4, palette[1]],
    [2 / 4, palette[2]],
    [3 / 4, palette[2]],
    [3 / 4, palette[3]],
    [4 / 4, palette[3]],
]

# Create the subplots, one for each contig.
for j, contig in enumerate(contigs):

    # Exact match on the contig annotation, and genotype columns only: the
    # annotation columns of df_aims must not end up as rows of the heatmap,
    # otherwise z no longer lines up with y=samples.
    on_contig = (df_aims.contig == contig).to_numpy()
    gn_contig = df_aims_gn.loc[:, on_contig]

    fig.add_trace(
        go.Heatmap(
            y=samples,
            z=gn_contig.to_numpy(),
            x=gn_contig.columns,
            colorscale=colorscale,
            zmin=-1.5,
            zmax=2.5,
            xgap=0,
            ygap=0.5,  # this creates faint lines between rows
            colorbar=colorbar,
        ),
        row=1,
        col=j + 1,
    )

fig.update_layout(
    title="AIMs - gambiae vs coluzzii",
    height=max(300, 1 * len(samples) + 100),
)

fig.show()

#### Species assignments by cohorts

In [None]:
# solely based on X chromosome, all the others look unreliable
# (the first 7 columns of df_aims are target annotations, the rest are samples)
df = df_aims.query("contig == 'X'").iloc[:, 7:]
df.columns = samples

# Mean AIM genotype per sample over the X AIMs. Missing calls (-1) are masked
# first so they do not count towards the mean; a sample with no called X AIMs
# gets NaN (DataFrame.mean skips NaN and does not emit numpy's
# "Mean of empty slice" RuntimeWarning).
mean_aims = df.replace(-1, float('nan')).mean(axis=0)

# Put the means in metadata row order before attaching them positionally.
aims = mean_aims.loc[metadata['sampleID'].to_list()]
metadata = metadata.assign(mean_aim_genotype=aims.values)

# Left-closed bins: [-inf, 0.5) -> gambiae, [0.5, 1.5) -> uncertain,
# [1.5, inf) -> coluzzii. A NaN mean stays NaN.
taxon = pd.cut(
    metadata['mean_aim_genotype'],
    bins=[-np.inf, 0.5, 1.5, np.inf],
    labels=['gambiae', 'uncertain', 'coluzzii'],
    right=False,
).astype(object).tolist()

new_metadata = metadata.assign(taxon=taxon)

fig = px.histogram(
    new_metadata,
    nbins=100,
    x='mean_aim_genotype',
    color=cohort_col,
    width=750,
    height=400,
    template='plotly_white',
    title='AIM genotype distribution',
)
fig.show()

In [None]:
from IPython.display import Markdown, display

# Destination files, both located under the working directory.
# NOTE: with wkdir = "../../.." the first path is the same file the metadata
# was read from (metadata_path), so the input table is overwritten.
qcpass_metadata_out = f"{wkdir}/results/config/metadata.qcpass.tsv"
taxon_aims_out = f"{wkdir}/results/ag-vampir/aims/taxon_aims.tsv"

# Full metadata table, including the taxon and mean_aim_genotype columns.
new_metadata.to_csv(qcpass_metadata_out, sep="\t")

# Compact per-sample table with just the AIM-based assignment.
taxon_columns = ['sampleID', 'taxon', 'mean_aim_genotype']
new_metadata[taxon_columns].to_csv(taxon_aims_out, sep="\t")

# Render a link to the compact table in the notebook output.
display(Markdown(f'<a href={taxon_aims_out}>Sample aims and taxon assignment (.tsv)</a>'))