In [32]:
import allel
import numpy as np
import pandas as pd 
import plotly.express as px

import warnings 
warnings.filterwarnings("ignore", category=RuntimeWarning, module="pandas.core.arraylike")
warnings.filterwarnings("ignore", category=UserWarning, module="allel.io")

def _dipclust_concat_subplots(
    figures,
    width,
    height,
    row_heights,
    title,
    xaxis_range,
):
    """Stack figures and/or single traces as rows of one shared-x plotly figure.

    Parameters
    ----------
    figures : list
        One entry per row, top to bottom. A ``go.Figure`` contributes all of
        its traces to its row; any other object is added as a single trace.
    width, height : int
        Overall figure size in pixels.
    row_heights : list of numbers
        Row heights forwarded to ``make_subplots`` (one per entry in
        ``figures``).
    title : str
        Figure title.
    xaxis_range : tuple
        Range applied to the first x-axis; x-axes are shared, so every row
        follows it.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    from plotly.subplots import make_subplots  # type: ignore
    import plotly.graph_objects as go  # type: ignore

    fig = make_subplots(
        rows=len(figures),
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.02,
        row_heights=row_heights,
    )

    for row, item in enumerate(figures, start=1):
        # A Figure contributes every trace it holds; anything else is assumed
        # to be a trace already.
        traces = item.data if isinstance(item, go.Figure) else [item]
        for trace in traces:
            # add_trace(row=, col=) is the supported replacement for the
            # deprecated append_trace.
            fig.add_trace(trace, row=row, col=1)

    fig.update_xaxes(visible=False)
    fig.update_layout(
        title=title,
        width=width,
        height=height,
        hovermode="closest",
        plot_bgcolor="white",
        xaxis_range=xaxis_range,
    )

    return fig

import numba
from scipy.spatial.distance import squareform  # type: ignore

@numba.njit(parallel=True)
def multiallelic_diplotype_pdist(X, metric):
    """Optimised implementation of pairwise distance between diplotypes.

    N.B., here we assume the array X provides diplotypes as genotype allele
    counts, with axes in the order (n_samples, n_sites, n_alleles).

    Computation will be faster if X is a contiguous (C order) array.

    The metric argument is the function to compute distance for a pair of
    diplotypes. Because this function is compiled in nopython mode, metric
    must itself be a numba-jitted function (e.g.
    multiallelic_diplotype_mean_cityblock) that takes two (n_sites, n_alleles)
    slices of X and returns a scalar.

    Returns
    -------
    out : float32 array of length n_samples * (n_samples - 1) // 2
        Condensed distance vector (the layout scipy's squareform expects):
        out[square_to_condensed(i, j, n_samples)] holds metric(X[i], X[j]).
        NaN values returned by metric are stored unchanged.

    """
    n_samples = X.shape[0]
    n_pairs = (n_samples * (n_samples - 1)) // 2
    out = np.zeros(n_pairs, dtype=np.float32)

    # Loop over samples, first in pair.
    for i in range(n_samples):
        x = X[i, :, :]

        # Loop over observations again, second in pair. Only j > i is visited,
        # so each unordered pair is computed once. prange parallelises this
        # inner loop; each (i, j) writes a different slot of out, so the
        # parallel iterations do not write to the same element.
        for j in numba.prange(i + 1, n_samples):
            y = X[j, :, :]

            # Compute distance for the current pair.
            d = metric(x, y)

            # Store result for the current pair. square_to_condensed is
            # defined further down; numba resolves it when this function is
            # first compiled (on its first call).
            k = square_to_condensed(i, j, n_samples)
            out[k] = d

    return out


@numba.njit
def square_to_condensed(i, j, n):
    """Map an off-diagonal square-matrix cell (i, j) to its condensed index.

    The condensed layout stores the upper triangle of an n x n symmetric
    matrix row by row, so (i, j) and (j, i) map to the same position.
    """
    assert i != j, "no diagonal elements in condensed matrix"

    # Put the smaller index in `row` and the larger in `col`.
    if i > j:
        row, col = j, i
    else:
        row, col = i, j

    # Entries before row `row`, plus the offset of `col` within that row.
    return n * row - row * (row + 1) // 2 + col - 1 - row


@numba.njit
def multiallelic_diplotype_mean_cityblock(x, y):
    """Compute the mean cityblock distance between two diplotypes x and y. The
    diplotype vectors are expected as genotype allele counts, i.e., x and y
    should have the same shape (n_sites, n_alleles).

    N.B., here we compute the mean value of the distance over sites where
    both individuals have a called genotype. This avoids computing distance
    at missing sites. A genotype counts as called when at least one of its
    allele counts is positive, so missing genotypes are expected to appear as
    all-zero rows (as allel's to_allele_counts produces -- confirm for other
    inputs).

    Returns
    -------
    float
        Sum over jointly-called sites of sum_j |x[i, j] - y[i, j]|, divided by
        the number of such sites (accumulated in float32); NaN when no site is
        called in both diplotypes.

    """
    n_sites = x.shape[0]
    n_alleles = x.shape[1]
    distance = np.float32(0)
    n_sites_called = np.float32(0)

    # Loop over sites.
    for i in range(n_sites):
        x_is_called = False
        y_is_called = False
        d = np.float32(0)

        # Loop over alleles.
        for j in range(n_alleles):
            # Access allele counts.
            xc = np.float32(x[i, j])
            yc = np.float32(y[i, j])

            # Check if any alleles observed.
            x_is_called = x_is_called or (xc > 0)
            y_is_called = y_is_called or (yc > 0)

            # Compute cityblock distance (absolute difference).
            d += np.fabs(xc - yc)

        # Accumulate distance for the current pair, but only if both samples
        # have a called genotype.
        if x_is_called and y_is_called:
            distance += d
            n_sites_called += np.float32(1)

    # Compute the mean distance over sites with called genotypes.
    if n_sites_called > 0:
        mean_distance = distance / n_sites_called
    else:
        # No shared called sites: distance is undefined. Callers such as
        # calculate_distances drop samples involved in NaN distances.
        mean_distance = np.nan

    return mean_distance

def plot_dendrogram(
    dist,
    linkage_method,
    count_sort,
    distance_sort,
    render_mode,
    width,
    height,
    title,
    line_width,
    line_color,
    marker_size,
    leaf_data,
    leaf_hover_name,
    leaf_hover_data,
    leaf_color,
    leaf_symbol,
    leaf_y,
    leaf_color_discrete_map,
    leaf_category_orders,
    template,
    y_axis_title,
    y_axis_buffer,
):
    """Cluster items hierarchically and draw the tree as a plotly figure.

    Parameters
    ----------
    dist : condensed pairwise distance vector passed to scipy's ``linkage``.
    linkage_method : linkage criterion for ``linkage`` (e.g. "complete").
    count_sort, distance_sort : leaf-ordering options for ``dendrogram``.
    render_mode, template : forwarded to the plotly express line/scatter calls.
    width, height, title : figure layout.
    line_width, line_color : applied to the ``line`` of every trace.
    marker_size : applied to the ``marker`` of every trace.
    leaf_data : DataFrame with one row per clustered item, in the same order
        as the items in ``dist`` (it is reordered positionally).
    leaf_hover_name, leaf_hover_data, leaf_color, leaf_symbol,
    leaf_color_discrete_map, leaf_category_orders : forwarded to
        ``px.scatter`` for the leaf markers.
    leaf_y : y position at which leaf markers are drawn.
    y_axis_title : y-axis label.
    y_axis_buffer : padding below ``leaf_y`` and above the tallest link.

    Returns
    -------
    (fig, leaf_data)
        The figure, and ``leaf_data`` reordered into left-to-right leaf order.
    """
    import scipy.cluster.hierarchy as sch

    linkage_matrix = sch.linkage(dist, method=linkage_method)

    # Let scipy lay out the tree, but skip its matplotlib rendering.
    tree = sch.dendrogram(
        linkage_matrix,
        count_sort=count_sort,
        distance_sort=distance_sort,
        no_plot=True,
    )

    link_xs = tree["icoord"]
    link_ys = tree["dcoord"]

    # Draw every link as part of one polyline; a None after each link lifts
    # the pen so neighbouring links are not joined together.
    xs, ys = [], []
    for link_x, link_y in zip(link_xs, link_ys):
        xs.extend(link_x)
        xs.append(None)
        ys.extend(link_y)
        ys.append(None)
    segments = pd.DataFrame({"x": xs, "y": ys})

    # scipy spaces leaves 10 units apart starting at 5; map back to leaf
    # positions 0..n-1 (one per clustered sample).
    segments["x"] = (segments["x"] - 5) / 10

    fig = px.line(
        segments,
        x="x",
        y="y",
        render_mode=render_mode,
        template=template,
    )

    # Put the leaf annotations into dendrogram (left-to-right) order.
    leaf_order = tree["leaves"]
    n_leaves = len(leaf_order)
    leaf_data = leaf_data.iloc[leaf_order]

    leaf_scatter = px.scatter(
        data_frame=leaf_data,
        x=np.arange(n_leaves),
        y=np.repeat(leaf_y, n_leaves),
        color=leaf_color,
        symbol=leaf_symbol,
        render_mode=render_mode,
        hover_name=leaf_hover_name,
        hover_data=leaf_hover_data,
        template=template,
        color_discrete_map=leaf_color_discrete_map,
        category_orders=leaf_category_orders,
    )
    fig.add_traces(list(leaf_scatter.select_traces()))

    fig.update_traces(
        line=dict(width=line_width, color=line_color),
        marker=dict(size=marker_size),
    )

    fig.update_layout(
        width=width,
        height=height,
        title=title,
        autosize=True,
        hovermode="closest",
        # The x-axis title would render above the plot and collide with the
        # main title, so it is suppressed.
        xaxis_title=None,
        yaxis_title=y_axis_title,
        showlegend=True,
    )

    fig.update_xaxes(
        mirror=False,
        showgrid=False,
        showline=False,
        showticklabels=False,
        ticks="",
        range=(-2, n_leaves + 2),
    )
    fig.update_yaxes(
        mirror=False,
        showgrid=False,
        showline=False,
        showticklabels=True,
        ticks="outside",
        range=(leaf_y - y_axis_buffer, np.max(link_ys) + y_axis_buffer),
    )

    return fig, leaf_data

### Run diplotype clustering on Darlingi

In [33]:
import pandas as pd
import numpy as np

def parse_snpeff_annotations(ann):
    """Parse snpEff ``ANN`` strings into an annotation table.

    Parameters
    ----------
    ann : numpy array of str
        ANN values, one "|"-separated annotation per element. The array is
        flattened in C order, so for an (n_variants, n_ann) array the output
        rows are variant-major.

    Returns
    -------
    pandas.DataFrame
        Columns allele, effect, gene_name, gene_id, hgvs_c, hgvs_p, plus
        ``label`` = "<gene_name>::<hgvs_p without its leading 'p.'>".
    """
    columns = [
        'allele', 'effect', 'impact', 'gene_name', 'gene_id',
        'feature_type', 'feature_id', 'transcript_biotype', 'rank_total',
        'hgvs_c', 'hgvs_p', 'cdna_pos', 'cds_pos', 'protein_pos',
        'distance', 'errors'
    ]
    parsed_data = [a.split('|') for a in ann.flatten()]
    df_ann = pd.DataFrame(parsed_data, columns=columns)[['allele', 'effect', 'gene_name', 'gene_id', 'hgvs_c', 'hgvs_p']]

    # str.lstrip("p.") would strip any leading run of 'p' / '.' characters;
    # remove exactly the literal "p." prefix instead.
    protein_change = df_ann['hgvs_p'].str.replace(r"^p\.", "", regex=True)
    return df_ann.assign(label=df_ann['gene_name'] + "::" + protein_change)

def load_genotypes(vcf_path, region, missingness_filter_proportion=0.1):
    """
    Load VCF genotypes and filter out uninformative sites and poor-quality samples.

    Steps: optionally restrict to ``region``; drop non-segregating sites; then,
    if ``missingness_filter_proportion`` is truthy, drop sites and then samples
    whose fraction of missing genotype calls exceeds it.

    Parameters
    ----------
    vcf_path : str
        Path to the VCF.
    region : str or None
        "contig:start-end" (both ends inclusive); None/empty keeps all sites.
    missingness_filter_proportion : float
        Maximum tolerated fraction of missing calls per site and per sample.

    Returns
    -------
    geno_locus : allel.GenotypeArray
        Filtered genotypes (variants x samples x ploidy).
    samples : numpy array
        Sample names matching the sample axis of ``geno_locus``.

    Raises
    ------
    AssertionError
        If ``region`` contains no variants.
    """
    # Only genotypes, positions, contigs and sample names are used below, so
    # avoid parsing every INFO/FORMAT field in the file.
    vcf = allel.read_vcf(
        vcf_path,
        fields=["samples", "calldata/GT", "variants/POS", "variants/CHROM"],
    )

    geno = allel.GenotypeArray(vcf['calldata/GT'])
    pos = vcf['variants/POS']
    contigs = vcf['variants/CHROM']
    samples = vcf['samples']

    if region:
        contig, start, end = [region.split(":")[0]] + region.split(":")[1].split("-")
        start, end = int(start), int(end)
        # subset to region of interest
        locus_mask = (contigs == contig) & (pos >= start) & (pos <= end)
        assert np.sum(locus_mask) != 0, "No SNPs found in specified region"
        geno = geno.compress(locus_mask, axis=0)

    # remove sites that are not segregating
    is_seg = geno.count_alleles().is_segregating()
    geno_locus = geno.compress(is_seg, axis=0)

    if missingness_filter_proportion:
        # remove highly missing sites (fraction missing across samples)
        site_missing = geno_locus.is_missing().mean(axis=1) > missingness_filter_proportion
        geno_locus = geno_locus.compress(~site_missing, axis=0)

        # remove highly missing samples (fraction missing across remaining sites)
        sample_missing = geno_locus.is_missing().mean(axis=0) > missingness_filter_proportion
        geno_locus = geno_locus.compress(~sample_missing, axis=1)
        samples = samples[~sample_missing]

    return geno_locus, samples

@numba.jit(nopython=True)
def _melt_gt_counts(gt_counts):
    """Reshape per-sample allele counts into one row per (SNP, alt allele).

    Parameters
    ----------
    gt_counts : array of shape (n_snps, n_samples, n_alleles)
        Allele counts per genotype with allele 0 first (e.g. the ``.values``
        of ``GenotypeArray.to_allele_counts()``).

    Returns
    -------
    int32 array of shape (n_snps * (n_alleles - 1), n_samples)
        Row ``i * (n_alleles - 1) + k`` holds each sample's count of allele
        ``k + 1`` at SNP ``i``; allele 0 (the reference) is dropped.
    """
    n_snps, n_samples, n_alleles = gt_counts.shape
    n_alts = n_alleles - 1
    melted_counts = np.zeros((n_snps * n_alts, n_samples), dtype=np.int32)

    for i in range(n_snps):
        for j in range(n_samples):
            for k in range(n_alts):
                # Stride by the real number of alt alleles. A hard-coded 3 is
                # only correct for 4-allele input; otherwise rows land in the
                # wrong place or out of bounds (numba does not bounds-check).
                melted_counts[i * n_alts + k, j] = gt_counts[i, j, k + 1]

    return melted_counts

def prepare_snp_allele_counts(vcf_path, vcf_annot_path, region, leaf_data, non_synonymous=True, snp_filter_min_maf=0.05):
    """Build a per-sample alt-allele count table for snpEff-annotated SNPs.

    Parameters
    ----------
    vcf_path : str
        VCF providing genotypes.
    vcf_annot_path : str
        snpEff-annotated VCF providing the ANN field; variants are matched to
        ``vcf_path`` on (contig, position).
    region : str or None
        "contig:start-end" (inclusive) to restrict to, or None/empty for all.
    leaf_data : pandas.DataFrame
        Must contain a ``sample_id`` column; output columns follow its order.
    non_synonymous : bool
        If True keep only annotations whose effect includes "missense_variant".
    snp_filter_min_maf : float
        If truthy, keep only rows whose alt-allele frequency
        (sum of counts / (2 * n_samples)) is strictly greater than this.

    Returns
    -------
    pandas.DataFrame
        One row per (variant, alt allele), indexed by
        "<gene_name>::<protein change>", one column per sample holding that
        sample's count of the alt allele.
    """
    vcf = allel.read_vcf(
        vcf_path,
        fields=["samples", "calldata/GT", "variants/POS", "variants/CHROM"],
    )
    gn = allel.GenotypeArray(vcf['calldata/GT'])
    contigs = vcf['variants/CHROM']
    pos_gn = vcf['variants/POS']
    samples = vcf['samples']

    # Load annotations, keeping up to 3 ANN entries per variant; these are
    # paired below with alt alleles 1..3.
    # NOTE(review): assumes the k-th ANN entry describes alt allele k+1 -- confirm.
    ann_vcf = allel.read_vcf(
        vcf_annot_path,
        fields=["variants/ANN", "variants/POS", "variants/CHROM"],
        numbers={'variants/ANN': 3},
    )
    ann = ann_vcf['variants/ANN']
    contigs_ann = ann_vcf['variants/CHROM']
    pos_ann = ann_vcf['variants/POS']

    # Intersect the two files on (contig, position); position alone would
    # conflate variants sharing a coordinate on different contigs.
    keys_gn = pd.Index([f"{c}:{p}" for c, p in zip(contigs, pos_gn)])
    keys_ann = pd.Index([f"{c}:{p}" for c, p in zip(contigs_ann, pos_ann)])
    gn_mask = keys_gn.isin(keys_ann)
    ann_mask = keys_ann.isin(keys_gn)

    gn = gn.compress(gn_mask, axis=0)
    pos_gn = pos_gn[gn_mask]
    contigs = contigs[gn_mask]
    ann = ann.compress(ann_mask, axis=0)
    assert gn.shape[0] == ann.shape[0]
    assert keys_gn[gn_mask].equals(keys_ann[ann_mask]), \
        "genotype and annotation VCFs list their shared variants in a different order"

    # restrict to region
    if region:
        contig, start, end = [region.split(":")[0]] + region.split(":")[1].split("-")
        start, end = int(start), int(end)
        locus_mask = (contigs == contig) & (pos_gn >= start) & (pos_gn <= end)
        assert np.sum(locus_mask) != 0, "No SNPs found in specified region"

        gn = gn.compress(locus_mask, axis=0)
        ann = ann.compress(locus_mask, axis=0)

    # max_allele=3 fixes the allele axis at REF + 3 ALT so the melted rows
    # (3 per variant) line up one-to-one with the 3 ANN entries per variant.
    gt_long = _melt_gt_counts(gn.to_allele_counts(max_allele=3).values)
    df_ann = parse_snpeff_annotations(ann)
    assert gt_long.shape[0] == df_ann.shape[0]

    df_snps = pd.concat([df_ann, pd.DataFrame(gt_long, columns=samples)], axis=1)

    if non_synonymous:
        # snpEff joins multiple effects of one annotation with '&'
        # (e.g. "missense_variant&splice_region_variant"), so match the term.
        df_snps = df_snps[df_snps['effect'].str.contains('missense_variant', regex=False, na=False)]

    df_snps = df_snps.drop(columns=['allele', 'effect', 'gene_name', 'gene_id', 'hgvs_c', 'hgvs_p'])
    df_snps = df_snps.set_index('label')

    df_snps = df_snps.loc[:, leaf_data.sample_id.tolist()]
    df_snps = df_snps.replace(-1, np.nan)

    if snp_filter_min_maf:
        df_snps = df_snps.assign(af=lambda x: x.sum(axis=1) / (x.shape[1] * 2))
        df_snps = df_snps.query("af > @snp_filter_min_maf").drop(columns="af")

    return df_snps


In [34]:
def calculate_distances(gn, samples):
    """Pairwise mean city-block distances between sample diplotypes.

    Parameters
    ----------
    gn : genotype array (variants x samples x ploidy), wrapped in
        ``allel.GenotypeArray``.
    samples : sequence of sample names, one per sample column of ``gn``.

    Returns
    -------
    pandas.DataFrame
        Square distance matrix labelled by sample on both axes. Any sample
        with an undefined (NaN) distance to some other sample -- i.e. no
        jointly-called sites -- is dropped from both axes.
    """
    from scipy.spatial.distance import squareform

    allele_counts = allel.GenotypeArray(gn).to_allele_counts(max_allele=3).values
    # The numba kernel wants (n_samples, n_sites, n_alleles) in C order.
    diplotypes = np.ascontiguousarray(allele_counts.transpose(1, 0, 2))

    condensed = multiallelic_diplotype_pdist(diplotypes, metric=multiallelic_diplotype_mean_cityblock)
    df_dists = pd.DataFrame(squareform(condensed), index=samples, columns=samples)

    # TODO: consider imputing rather than dropping NaN comparisons, and sorting.
    keep = ~df_dists.isna().any()
    return df_dists.loc[keep, keep]


def plot_diplotype_clustering(vcf_path, vcf_annot_path, region, df_samples, distance_metric='cityblock', leaf_color='location', non_synonymous=True, snp_filter_min_maf=0.05):
    """Plot diplotype clustering with heterozygosity and SNP panels.

    Rows of the returned figure, top to bottom:
    1. dendrogram (complete linkage) of samples clustered on mean city-block
       distance between diplotypes in ``region``;
    2. per-sample heterozygosity heatmap;
    3. per-sample alt-allele counts for annotated SNPs passing the filters
       (omitted when no SNP passes).

    Parameters
    ----------
    vcf_path : str
        VCF providing genotypes.
    vcf_annot_path : str
        snpEff-annotated VCF used for the SNP panel.
    region : str or None
        "contig:start-end"; None/empty uses the whole VCF.
    df_samples : pandas.DataFrame or None
        Sample metadata with a ``sample_id`` column matching VCF sample names;
        rows are matched by ``sample_id``. If None, ``location`` and ``taxon``
        are taken from the first two "_"-separated parts of each sample name.
    distance_metric : str
        Used only in the y-axis label; the distance is always mean city-block.
    leaf_color : str
        Column of ``df_samples`` used to colour dendrogram leaves.
    non_synonymous, snp_filter_min_maf :
        Forwarded to ``prepare_snp_allele_counts``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    import plotly.graph_objects as go

    # Genotypes after region, segregating-site and missingness filtering.
    gn, samples = load_genotypes(vcf_path, region=region)

    if df_samples is None:
        # .str[k] yields NaN (instead of raising) for names with fewer parts.
        name_parts = pd.Series(samples).str.split("_")
        df_samples = pd.DataFrame(
            {"location": name_parts.str[0], "taxon": name_parts.str[1], "sample_id": samples}
        )

    df_dists = calculate_distances(gn, samples)
    n_leaves = df_dists.shape[0]

    # load_genotypes can drop high-missingness samples and calculate_distances
    # drops samples with NaN distances, so align metadata to the distance
    # matrix by sample_id rather than by position.
    leaf_data = (
        df_samples.set_index("sample_id", drop=False)
        .loc[df_dists.index]
        .reset_index(drop=True)
    )

    fig_dendro, leaf_data = plot_dendrogram(
        dist=squareform(df_dists.values),
        linkage_method="complete",
        count_sort=True,
        distance_sort=False,
        render_mode="svg",
        width=800,
        height=500,
        title=f"{region} diplotype clustering",
        line_width=0.6,
        line_color='black',
        marker_size=5,
        leaf_data=leaf_data,
        leaf_hover_name="sample_id",
        leaf_hover_data=leaf_data.columns,
        leaf_color=leaf_color,
        leaf_symbol=None,
        leaf_y=-0.01,
        leaf_color_discrete_map=None,
        leaf_category_orders=None,
        template="simple_white",
        y_axis_title=f"Distance ({distance_metric})",
        y_axis_buffer=0.1,
    )

    figures = [fig_dendro]
    subplot_heights = [300]
    snp_row_height = 30
    width = 800

    # heterozygosity bar: fraction of called genotypes that are heterozygous
    df_het = pd.DataFrame(
        {"sample_id": samples, "Sample Heterozygosity": gn.is_het().sum(axis=0) / gn.is_called().sum(axis=0)}
    ).set_index("sample_id")

    # order according to dendrogram and transpose to a single row
    df_het = df_het.loc[leaf_data.sample_id, :].T
    het_trace = go.Heatmap(
        z=df_het,
        y=["Heterozygosity"],
        colorscale="Greys",
        showlegend=False,
        showscale=False,
    )

    figures.append(het_trace)
    subplot_heights.append(30)

    # prepare snp data, columns already in dendrogram order
    df_snps_ordered = prepare_snp_allele_counts(vcf_path, vcf_annot_path, region, leaf_data, non_synonymous=non_synonymous, snp_filter_min_maf=snp_filter_min_maf)
    n_snps = df_snps_ordered.shape[0]

    # Skip the SNP panel entirely when nothing passes the filters, rather than
    # adding a zero-height row.
    if n_snps > 0:
        snp_trace = go.Heatmap(
            z=df_snps_ordered[::-1].values,
            y=df_snps_ordered[::-1].index.to_list(),
            colorscale="Greys",
            showlegend=False,
            showscale=False,
        )
        figures.append(snp_trace)
        subplot_heights.append(snp_row_height * n_snps)

    height = sum(subplot_heights) + 50
    fig = _dipclust_concat_subplots(
        figures=figures,
        width=width,
        height=height,
        row_heights=subplot_heights,
        title=f"{region} | diplotype clustering",
        # Leaves and heatmap cells sit at x = 0..n-1, so pad half a cell on
        # each side to show every column in full.
        xaxis_range=(-0.5, n_leaves - 0.5),
    )

    fig["layout"]["yaxis"]["title"] = f"Distance ({distance_metric})"

    if n_snps > 0:
        snp_row = len(figures)
        # Grid lines separating the SNP rows.
        fig.add_hline(y=-0.5, line_width=1, line_color="grey", row=snp_row, col=1)
        for i in range(n_snps):
            fig.add_hline(y=i + 0.5, line_width=1, line_color="grey", row=snp_row, col=1)

        fig['layout'][f'yaxis{snp_row}']['title'] = 'mutations'
        fig.update_xaxes(showline=True, linecolor='grey', linewidth=1, row=snp_row, col=1, mirror=True)
        fig.update_yaxes(showline=True, linecolor='grey', linewidth=1, row=snp_row, col=1, mirror=True)

    return fig

## Run diplotype clustering

In [None]:
### specify parameters
region = '2R:28480000-28520000'  # contig:start-end (both ends inclusive); None uses the whole VCF
distance_metric = 'cityblock'  # y-axis label only; mean city-block distance is always computed
leaf_color = 'location'  # column in df_samples used to colour dendrogram leaves
vcf_path = "../amplicon-seq/ampseq-agvampir002/results/vcfs/amplicons/ag-vampir-002.annot.vcf"
# snpEff-annotated VCF (ANN field) for the SNP panel; it may differ from
# vcf_path, but here the same annotated VCF supplies genotypes and annotations
vcf_annot_path = "../amplicon-seq/ampseq-agvampir002/results/vcfs/amplicons/ag-vampir-002.annot.vcf"
snp_filter_min_maf = 0.05  # SNP panel keeps only alleles whose alt-allele frequency exceeds this
non_synonymous = True  # plot only missense (non-synonymous) variants

# df_samples: optional metadata with one row per VCF sample (in VCF order) and
# a `sample_id` column; its other columns can colour the dendrogram leaves and
# appear as hover data in the interactive plot. If None, location and taxon
# are parsed from the sample names.
df_samples = None
# df_samples = pd.read_csv("./my_metadata.csv", sep=",")

In [36]:
# Use the configured parameters rather than re-typing their values here, so
# edits to the config cell actually take effect.
plot_diplotype_clustering(
    vcf_path=vcf_path,
    vcf_annot_path=vcf_annot_path,
    region=region,
    df_samples=df_samples,
    distance_metric=distance_metric,
    leaf_color=leaf_color,
    non_synonymous=non_synonymous,
    snp_filter_min_maf=snp_filter_min_maf,
)

In [None]:
# Same analysis on a second region; all other settings come from the config cell.
plot_diplotype_clustering(
    vcf_path=vcf_path,
    vcf_annot_path=vcf_annot_path,
    region="2L:2000000-3000000",
    df_samples=df_samples,
    distance_metric=distance_metric,
    leaf_color=leaf_color,
    non_synonymous=non_synonymous,
    snp_filter_min_maf=snp_filter_min_maf,
)


invalid INFO header: '##INFO=<ID=VDB,Number=1,Type=Float,Description="Variant Distance Bias for filtering splice-site artefacts in RNA-seq data (bigger is better)",Version="3">\n'


invalid INFO header: '##INFO=<ID=VDB,Number=1,Type=Float,Description="Variant Distance Bias for filtering splice-site artefacts in RNA-seq data (bigger is better)",Version="3">\n'

