In [None]:
2 %pip install malariagen_data==7.15 kaleido -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m135.0/135.0 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.9/79.9 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m61.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m65.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.5/302.5 kB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.7/138.7 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.9/20.9 MB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━

In [None]:
import numpy as np
import pandas as pd
import malariagen_data
import allel
import numba
import kaleido
import scipy
import plotly.express as px

from itertools import product, combinations
from scipy.special import comb
from scipy.spatial.distance import squareform

### Diplotype clustering and amino acid variation

When we are interested in a locus that is under selection, we often want to know how many sweeps there are and how large each sweep is, and also which amino acid variants are present on each swept haplotype.

Usually we would use haplotype clustering or haplotype networks for this purpose, but the use of haplotypes means that multi-allelic sites are omitted from the data. Performing additional phasing of multi-allelic sites is a pain, so we want to avoid this where possible.

In this notebook, I present code to perform clustering on diplotypes from unphased genomic data. Along with the diplotype dendrogram, we plot the frequency of amino acid variants for a given transcript. It should also be straightforward to include CNVs in this as well as amino acid mutations, though this is not implemented yet.

**Work in progress! (WIP)**  

<br>
  
TODO   

- Split into two steps
  - dendrogram support functions
  - advanced one

- Bug when two amino acid changes occur at the same position (e.g. L324K and L324Q) with the same genotypes; needs fixing.
- Tidy and rethink code.
- Figure out a neat way to cut the tree in `diplotype_pairwise_distances` function.

- Hover data
  - amino acid, position and effect data
  - heterozygosity - quantile ranking
  - cnvs - copy number
  - diplotype cluster


 <br>

We will first define all the functions we are going to need!

## dist functions

Functions from alimanfoo to calculate genetic distance from diplotypes.

## dip pairwise distances

Here are the functions for pairwise distance and diplotype clustering, which have been partially malariagen_data-ized.

In [None]:
from malariagen_data.anopheles import base_params, hap_params
from malariagen_data.util import CacheMiss, Region
from typing import Optional, Tuple

In [None]:

 def diplotype_pairwise_distances(
  self,
  region:base_params.regions,
  site_mask:base_params.site_mask,
  sample_sets: Optional[base_params.sample_sets] = None,
  sample_query: Optional[base_params.sample_query] = None,
  cohort_size: Optional[base_params.cohort_size] = None,
  random_seed: base_params.random_seed = 42,
  heterozygosity=False,
  ) -> dict:
  """Return pairwise diplotype distances, using the results cache when possible.

  Parameters
  ----------
  self :
      API object providing ``_prep_sample_sets_param``,
      ``_prep_region_cache_param``, ``results_cache_get`` and
      ``results_cache_set`` (e.g. an ``Ag3`` instance).
  region, site_mask, sample_sets, sample_query, cohort_size, random_seed :
      Selection parameters forwarded to ``_diplotype_pairwise_distances``.
  heterozygosity : bool
      If True, per-sample heterozygosity is computed as well.

  Returns
  -------
  dict
      On a cache miss, the mapping built by ``_diplotype_pairwise_distances``
      (keys ``dist``, ``samples``, ``n_snps`` and, if requested, ``het``).
      On a cache hit, whatever ``results_cache_get`` returns for the same
      parameters.
  """

  # Change this name if you ever change the behaviour of this function, to
  # invalidate any previously cached data.
  name = "diplotype_pairwise_distances"

  # Normalize params for consistent hash value.
  sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets)
  region_prepped = self._prep_region_cache_param(region=region)

  # N.B., `self` is deliberately NOT part of `params`: these values form the
  # cache key, and the API object is not a parameter value (presumably it
  # cannot be hashed/serialised consistently - confirm against the cache).
  params = dict(
      region=region_prepped,
      site_mask=site_mask,
      sample_sets=sample_sets_prepped,
      sample_query=sample_query,
      cohort_size=cohort_size,
      random_seed=random_seed,
      heterozygosity=heterozygosity,
  )

  # Try to retrieve results from the cache.
  try:
      results = self.results_cache_get(name=name, params=params)

  except CacheMiss:
      results = _diplotype_pairwise_distances(self, **params)
      self.results_cache_set(name=name, params=params, results=results)

  return results


def _diplotype_pairwise_distances(
        self,
        *,
        region,
        site_mask,
        sample_sets,
        sample_query,
        cohort_size,
        random_seed,
        heterozygosity=False,
    ):
    """Compute pairwise diplotype distances for the selected samples.

    Genotypes are loaded with ``self.snp_calls``, non-segregating sites are
    dropped, genotypes are converted to per-sample allele counts and the
    condensed pairwise distance vector is computed with ``diplotype_pdist``
    using the ``diplotype_mean_cityblock`` metric.

    Parameters
    ----------
    self :
        API object providing ``snp_calls``.
    region, site_mask, sample_sets, sample_query :
        Forwarded to ``self.snp_calls``.
    cohort_size, random_seed :
        Down-sampling parameters, forwarded to ``self.snp_calls``.
    heterozygosity : bool
        If True, add per-sample heterozygosity from ``calc_heterozygosity``.

    Returns
    -------
    dict
        ``dist`` (condensed distance vector), ``samples`` (sample IDs as a
        unicode array), ``n_snps`` (number of sites loaded, before removing
        non-segregating sites) and optionally ``het``.
    """
    ds_snps = self.snp_calls(
        region=region,
        sample_query=sample_query,
        sample_sets=sample_sets,
        site_mask=site_mask,
        # Previously accepted but silently ignored; random_seed only makes
        # sense together with the cohort size it down-samples to.
        cohort_size=cohort_size,
        random_seed=random_seed,
    )

    gt = allel.GenotypeDaskArray(ds_snps["call_genotype"].data).compute()

    # Compute allele count, remove non-segregating sites.
    ac = allel.GenotypeArray(gt).count_alleles(max_allele=3)
    gt_seg = gt.compress(ac.is_segregating(), axis=0)
    ac_seg = gt_seg.to_allele_counts(max_allele=3)

    # Put samples on the first axis, as expected by diplotype_pdist, and make
    # the array C-contiguous because the distance kernel is faster that way.
    X = np.ascontiguousarray(np.swapaxes(ac_seg.values, 0, 1))

    # Calculate distance (condensed form, i.e. not passed through squareform).
    dist = diplotype_pdist(X, metric=diplotype_mean_cityblock)

    # Extract IDs of samples. Convert to "U" dtype here
    # to allow these to be saved to the results cache.
    samples = ds_snps["sample_id"].values.astype("U")

    results = dict(
        dist=dist,
        samples=samples,
        n_snps=gt.shape[0],
    )
    if heterozygosity:
        results["het"] = calc_heterozygosity(gt=gt, gt_samples=samples)
    return results

def plot_diplotype_clustering_dendro(
    self,
    region,
    sample_sets=None,
    sample_query=None,
    site_mask=None,
    color="location",
    symbol=None,
    linkage_method='complete',
    count_sort=True,
    distance_sort=False,
    random_seed=42,
    width=800,
    height=600,
    heterozygosity=False,
    **kwargs,
):
    """Plot a diplotype clustering dendrogram with one marker per sample.

    Parameters
    ----------
    self :
        API object providing ``sample_metadata`` and ``_log``.
    region :
        Genomic region, resolved with ``malariagen_data.util.parse_single_region``.
    sample_sets, sample_query, site_mask, random_seed, heterozygosity :
        Forwarded to ``diplotype_pairwise_distances``.
    color, symbol : str or None
        Sample metadata columns used to colour / shape the markers; they are
        appended to the hover data when not already listed.
    linkage_method : str
        ``method`` passed to ``scipy.cluster.hierarchy.linkage``.
    count_sort, distance_sort :
        Forwarded to ``create_dendrogram``.
    width, height : int
        Figure size passed to ``fig.update_layout``.
    **kwargs :
        Accepted but not used in this function.

    Returns
    -------
    fig :
        The dendrogram figure with the sample scatter traces added.
    out_dict : dict
        The result of ``diplotype_pairwise_distances`` with ``order`` added
        and ``samples`` re-ordered by that order.
    """
    import plotly.express as px
    from scipy.cluster.hierarchy import linkage
    from malariagen_data.plotly_dendrogram import create_dendrogram
    debug = self._log.debug

    resolved_region = malariagen_data.util.parse_single_region(self, region)
    del region

    debug("load sample metadata")
    df_samples = self.sample_metadata(
        sample_sets=sample_sets, sample_query=sample_query
    )

    debug("set up plotting options")
    hover_data = [
        "sample_id",
        "partner_sample_id",
        "sample_set",
        "taxon",
        "country",
        "admin1_iso",
        "admin1_name",
        "admin2_name",
        "location",
        "year",
        "month",
    ]

    # Make sure the columns used for colour / symbol also show up on hover.
    if color and color not in hover_data:
        hover_data.append(color)
    if symbol and symbol not in hover_data:
        hover_data.append(symbol)

    plot_kwargs = dict(
        template="simple_white",
        hover_name="sample_id",
        hover_data=hover_data,
        render_mode="svg",
    )

    out_dict = diplotype_pairwise_distances(
          self=self,
          region=resolved_region,
          site_mask=site_mask,
          sample_sets=sample_sets,
          sample_query=sample_query,
          random_seed=random_seed,
          heterozygosity=heterozygosity
    )

    dist = out_dict['dist']
    gt_samples = out_dict['samples']
    # NOTE(review): n_snps is not used again in this function.
    n_snps = out_dict['n_snps']

    debug("align sample metadata with diplotypes")
    df_samples = (
        df_samples.set_index("sample_id").loc[gt_samples].reset_index()
    )

    # set labels as the index which we extract to reorder metadata
    leaf_labels = np.arange(len(gt_samples))

    fig = create_dendrogram(
        dist,
        linkagefun=lambda x: linkage(x, method=linkage_method),
        labels=leaf_labels,
        color_threshold=0,
        count_sort=count_sort,
        distance_sort=distance_sort,
    )

    fig.update_traces(
        hoverinfo="y",
        line=dict(width=0.5, color="black"),
    )

    fig.update_layout(
        width=width,
        height=height,
        autosize=True,
        margin_autoexpand=False,
        hovermode="closest",
        plot_bgcolor="white",
        yaxis_title="Distance (manhattan)",
        xaxis_title="Diplotypes",
        showlegend=True,
    )
    # select only columns in hover_data
    df_samples = df_samples[hover_data]
    # presumably the x tick text holds the integer leaf labels in dendrogram
    # order (as set by create_dendrogram) - TODO confirm
    diplotype_order = fig.layout.xaxis["ticktext"]
    out_dict['order'] = diplotype_order
    out_dict['samples'] = gt_samples[diplotype_order]
    debug("Reorder diplotype metadata to align with diplotype clustering")
    df_samples = df_samples.loc[
        diplotype_order
    ]

    fig.update_xaxes(mirror=False, showgrid=True, showticklabels=False, ticks="")
    fig.update_yaxes(
        mirror=False, showgrid=True, showline=True)#, range=[-2, fullfig.layout.yaxis.range[-1]]

    # Fixed colours for these four category values; applied whatever column
    # `color` names.
    color_discrete_map = {'gambiae':'dodgerblue', 'coluzzii':'darkorange', 'arabiensis':'limegreen', 'unassigned':'grey'}

    # One marker per sample at the dendrogram tick positions, drawn at
    # y = -0.00001.
    fig.add_traces(
        list(
            px.scatter(
                df_samples,
                x=fig.layout.xaxis["tickvals"],
                y=np.repeat(-0.00001, df_samples.shape[0]),
                color=color,
                color_discrete_map=color_discrete_map,
                symbol=symbol,
                **plot_kwargs,
            ).select_traces()
        )
    )
    return fig, out_dict


def concat_subplots(figures, x_range, width, height, row_heights, sample_sets, sample_query, resolved_region, n_snps):
    """Stack several figures vertically into one figure with a shared x axis.

    Parameters
    ----------
    figures : list
        Figures whose traces are copied, one figure per subplot row.
    x_range : sequence
        x positions; the first and last values set the x axis range.
    width, height : int
        Size of the combined figure.
    row_heights : list of float
        Relative row heights, passed to ``make_subplots``.
    sample_sets, sample_query :
        Shown in the title when not None.
    resolved_region, n_snps :
        Shown in the title.

    Returns
    -------
    The combined plotly figure.
    """
    from plotly.subplots import make_subplots

    title_lines = []
    if sample_sets is not None:
        title_lines.append(f"sample sets: {sample_sets}")
    if sample_query is not None:
        title_lines.append(f"sample query: {sample_query}")
    title_lines.append(f"genomic region: {resolved_region} ({n_snps} SNPs)")
    title = "<br>".join(title_lines)

    # make subplots
    fig = make_subplots(rows=len(figures),
                        cols=1,
                        shared_xaxes=True,
                        vertical_spacing=0.02,
                        row_heights=row_heights)

    # Copy every trace of figure k into row k (rows are 1-based).
    # add_trace replaces the deprecated append_trace.
    for row, figure in enumerate(figures, start=1):
        for trace in figure["data"]:
            fig.add_trace(trace, row=row, col=1)

    fig.update_xaxes(visible=False)
    fig.update_layout(
        title=title,
        width=width,
        height=height,
        hovermode="closest",
        plot_bgcolor="white",
        xaxis_range=(x_range[0], x_range[-1])
    )

    return fig



In [None]:
def plot_diplotype_clustering_advanced(
    self,
    transcript,
    region,
    sample_sets=None,
    sample_query=None,
    site_mask=None,
    heterozygosity=False,
    karyo=False,
    color='country',
    symbol=None,
    linkage_method='single',
    cut_height=None,
    min_cluster_size=None,
    count_sort=True,
    distance_sort=False,
    reverse_strand=False,
    width=1000,
    height=1000,
  ):
    """Plot a diplotype dendrogram with aligned per-sample annotation tracks.

    The dendrogram from ``plot_diplotype_clustering_dendro`` is the top row.
    Below it, in dendrogram order, optional heat-map tracks are stacked:
    sample heterozygosity, diplotype cluster, copy number and amino acid
    variant frequencies for ``transcript``.

    Parameters
    ----------
    self :
        API object (e.g. an ``Ag3`` instance).
    transcript : str
        Transcript passed to ``calc_copynumber`` and ``calc_aa_genotypes``.
        When truthy, the CNV and amino acid tracks are added.
    region :
        Genomic region used for clustering.
    sample_sets, sample_query, site_mask :
        Sample and site selection.
    heterozygosity : bool
        If True, compute per-sample heterozygosity and add it as a track.
    karyo : bool
        Currently ignored (karyotype track not implemented).
    color, symbol, linkage_method, count_sort, distance_sort :
        Forwarded to ``plot_diplotype_clustering_dendro``.
    cut_height, min_cluster_size :
        Provide both to add a cluster track via ``cut_dist_tree``, or neither.
    reverse_strand : bool
        Forwarded to ``calc_aa_genotypes``.
    width, height : int
        Size of the combined figure.

    Returns
    -------
    The combined plotly figure.
    """
    resolved_region = malariagen_data.util.parse_single_region(self, region)
    del region

    assert cut_height and min_cluster_size or not cut_height and not min_cluster_size, "To cut the tree and label clusters, values for both cut_height and min_cluster_size must be provided"

    fig_dendro, out_dict = plot_diplotype_clustering_dendro(
        self=self,
        region=resolved_region,
        sample_sets=sample_sets,
        sample_query=sample_query,
        site_mask=site_mask,
        count_sort=count_sort,
        distance_sort=distance_sort,
        linkage_method=linkage_method,
        color=color,
        symbol=symbol,
        # Only compute heterozygosity when the track is requested; it was
        # previously always computed, even when the result was discarded.
        heterozygosity=heterozygosity,
    )
    x_range = fig_dendro.layout.xaxis['tickvals']
    n_snps = out_dict['n_snps']

    figures = [fig_dendro]
    row_heights = [0.2]

    if heterozygosity:
        het_df = out_dict['het'].iloc[out_dict['order'], :]
        het_df = het_df.rename(columns={'het':'Sample Heterozygosity'}).T
        het_df.columns = x_range
        fig_het = plotly_track(het_df, "Greys")
        figures.append(fig_het)
        row_heights.append(0.012)

    # TODO: karyotype track (`karyo`) is not implemented yet.

    if cut_height and min_cluster_size:
        # Prepend white so the lowest cluster label is drawn in white.
        cluster_col_list = px.colors.sequential.Turbo.copy()
        cluster_col_list.insert(0, 'white')

        clusters_df = cut_dist_tree(dist=out_dict['dist'], linkage_method=linkage_method, cut_height=cut_height, min_cluster_size=min_cluster_size, count_sort=count_sort, distance_sort=distance_sort)
        clusters_df = clusters_df.rename(columns={'cluster_labels':'Diplotype Cluster'}).T
        clusters_df.columns = x_range
        fig_cluster = plotly_track(df=clusters_df, colour=cluster_col_list)
        figures.append(fig_cluster)
        row_heights.append(0.012)

    if transcript:
        cnv_df = calc_copynumber(self=self, transcript=transcript, sample_sets=sample_sets, sample_query=sample_query)

        # Samples without a CNV row get cn_mode 0. A set makes the
        # membership test O(1) per sample.
        samples_with_cnv = set(cnv_df.index)
        has_cnv = np.array([s in samples_with_cnv for s in out_dict['samples']], dtype=bool)
        missing_samples = out_dict['samples'][~has_cnv]
        extra_cnvs = pd.DataFrame({'sample_id':missing_samples, 'cn_mode':0}).set_index('sample_id')
        cnv_df = pd.concat([cnv_df, extra_cnvs])

        # out_dict['samples'] is already in dendrogram order.
        cnv_df = cnv_df.loc[out_dict['samples']]
        cnv_df = cnv_df.rename(columns={'cn_mode':'CNV'})
        # Show copies above 2, floored at zero.
        cnv_df.loc[:, 'CNV'] = (cnv_df.loc[:, 'CNV'] - 2).apply(lambda cn: np.max([cn, 0]))
        cnv_df = cnv_df.T
        cnv_df.columns = x_range
        fig_cnv = plotly_track(df=cnv_df, colour="Greys")
        figures.append(fig_cnv)
        row_heights.append(0.012)

        # load genotypes at amino acid variants for each sample
        df_snps = calc_aa_genotypes(self, transcript=transcript, sample_query=sample_query, sample_sets=sample_sets, site_mask=site_mask, reverse_strand=reverse_strand)
        # set to diplotype cluster order
        df_snps = df_snps.filter(like='frq').fillna(0).iloc[:, out_dict['order']]
        df_snps.columns = x_range
        # Row height grows with the number of amino acid variants shown.
        aa_height = df_snps.shape[0]/100
        fig_aa = plotly_track(df_snps)
        figures.append(fig_aa)
        row_heights.append(aa_height)

    fig = concat_subplots(
        figures=figures,
        x_range=x_range,
        width=width,
        height=height,
        row_heights=row_heights,
        sample_sets=sample_sets,
        sample_query=sample_query,
        resolved_region=resolved_region,
        n_snps=n_snps,
    )

    if transcript:
        # The amino acid track is the last row; draw a separator line around
        # every variant row.
        aa_idx = len(figures)
        fig.add_hline(y=-0.5, line_width=1, line_color="grey", row=aa_idx, col=1)
        for i in range(df_snps.shape[0]):
            fig.add_hline(y=i+0.5, line_width=1, line_color="grey", row=aa_idx, col=1)

        fig['layout'][f'yaxis{aa_idx}']['title']=f'{transcript} amino acids'
        fig.update_xaxes(showline = True, linecolor = 'grey', linewidth = 1, row = aa_idx, col = 1, mirror = True)
        fig.update_yaxes(showline = True, linecolor = 'grey', linewidth = 1, row = aa_idx, col = 1, mirror = True)

    fig['layout']['yaxis']['title']='Distance (manhattan)'

    return fig

In [None]:
def cluster_summary(
    self,
    region,
    sample_sets=None,
    sample_query=None,
    site_mask=None,
    linkage_method='single',
    cut_height=None,
    min_cluster_size=None,
    count_sort=True,
    distance_sort=False
):
    """Cut the diplotype clustering tree and return cluster labels.

    Parameters
    ----------
    self :
        API object providing ``sample_metadata`` (e.g. an ``Ag3`` instance).
    region :
        Genomic region used for clustering.
    sample_sets, sample_query, site_mask :
        Sample and site selection.
    linkage_method, cut_height, min_cluster_size, count_sort, distance_sort :
        Forwarded to ``cut_dist_tree``. ``cut_dist_tree`` compares cluster
        sizes against ``min_cluster_size``, so in practice both ``cut_height``
        and ``min_cluster_size`` need to be given.

    Returns
    -------
    df_samples : pandas.DataFrame
        Sample metadata as returned by ``self.sample_metadata`` (not
        re-ordered).
    clusters_df : pandas.DataFrame
        Columns ``index`` (sample position) and ``cluster_labels``.
    """
    resolved_region = malariagen_data.util.parse_single_region(self, region)
    del region

    assert cut_height and min_cluster_size or not cut_height and not min_cluster_size, "To cut the tree and label clusters, values for both cut_height and min_cluster_size must be provided"

    # Only the distance vector is needed, so call the distance function
    # directly instead of building (and discarding) a dendrogram figure.
    out_dict = diplotype_pairwise_distances(
        self=self,
        region=resolved_region,
        site_mask=site_mask,
        sample_sets=sample_sets,
        sample_query=sample_query,
        heterozygosity=False,
    )

    df_samples = self.sample_metadata(sample_sets=sample_sets, sample_query=sample_query)

    clusters_df = cut_dist_tree(
        dist=out_dict['dist'],
        linkage_method=linkage_method,
        cut_height=cut_height,
        min_cluster_size=min_cluster_size,
        count_sort=count_sort,
        distance_sort=distance_sort,
    ).reset_index()

    return df_samples, clusters_df

In [None]:
def calc_aa_genotypes(self, transcript, site_mask=None, sample_sets=None, sample_query=None, maf=0.05, reverse_strand=False):
    """Build a table of per-sample amino acid variant frequencies for a transcript.

    Parameters
    ----------
    self :
        API object providing ``_snp_df``, ``snp_genotypes``, ``_annotator``,
        ``_pandas_apply`` and the label helpers used below.
    transcript : str
        Transcript identifier.
    site_mask : str or None
        If given, rows are kept only where the ``pass_{site_mask}`` column is
        true. Genotypes themselves are always loaded without a site mask.
    sample_sets, sample_query :
        Sample selection, forwarded to ``self.snp_genotypes``.
    maf : float
        Variants are kept only if their ``af`` column exceeds this value.
    reverse_strand : bool
        If True, the row order of the result is reversed.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``aa_change``, with one ``frq_<i>`` column per sample,
        restricted to ``NON_SYNONYMOUS_CODING`` effects.
    """
    region, df_snps = self._snp_df(transcript=transcript)

    gt = self.snp_genotypes(
        region=region,
        sample_sets=sample_sets,
        sample_query=sample_query,
        site_mask=None,
        field="GT",
    ).compute()

    # One frequency column per sample: count alleles for that sample alone,
    # convert to frequencies and keep the non-reference alleles (columns 1:),
    # flattened. Presumably this matches a df_snps layout of one row per site
    # and alternate allele - TODO confirm.
    freq_cols = {}
    for idx in range(gt.shape[1]):
        gt_coh = gt[:, [idx], :]
        ac_coh = allel.GenotypeArray(gt_coh).count_alleles(max_allele=3)
        af_coh = ac_coh.to_frequencies()
        freq_cols["frq_" + str(idx)] = af_coh[:, 1:].flatten()

    df_freqs = pd.DataFrame(freq_cols)
    df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)})

    # Reset the index so the column-wise concat below aligns by position.
    df_snps.reset_index(drop=True, inplace=True)
    df_snps = pd.concat([df_snps, df_freqs, df_max_af], axis=1)

    if site_mask is not None:
        loc_sites = df_snps[f"pass_{site_mask}"]
        df_snps = df_snps.loc[loc_sites]

    # NOTE(review): the return value is discarded, so get_effects presumably
    # annotates df_snps in place (aa_change, effect, ...) - confirm.
    ann = self._annotator()
    ann.get_effects(
        transcript=transcript, variants=df_snps
    )
    df_snps["label"] = self._pandas_apply(
        self._make_snp_label_effect,
        df_snps,
        columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"],
    )
    df_snps.set_index(
        ["contig", "position", "ref_allele", "alt_allele", "aa_change"],
        inplace=True)

    # Keep only variants with an amino acid change; `af` is the sum of the
    # per-sample frequency columns divided by the number of those columns.
    df_snps = df_snps.query("~aa_change.isnull()", engine='python')
    df_snps = df_snps.assign(af=df_snps.filter(like="frq").sum(axis=1) / df_snps.filter(like='frq').shape[1])
    df_snps = df_snps.reset_index()

    # Aggregate SNPs giving the same amino acid change at the same position:
    # frequencies are summed (ignoring NaN), descriptive columns take the
    # first value, and multiple alt alleles are joined as "{A,C}".
    freq_cols = [col for col in df_snps if col.startswith("frq")] + ['af']
    agg = {c: np.nansum for c in freq_cols}
    keep_cols = (
          "contig",
          "transcript",
          "aa_pos",
          "ref_allele",
          "ref_aa",
          "alt_aa",
          "effect",
          "impact",
      )
    for c in keep_cols:
        agg[c] = "first"
    agg["alt_allele"] = lambda v: "{" + ",".join(v) + "}" if len(v) > 1 else v
    df_aaf = df_snps.groupby(["position", "aa_change"]).agg(agg).reset_index()

    df_aaf["label"] = self._pandas_apply(
        self._make_snp_label_aa,
        df_aaf,
        columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"],
    )

    df_aaf = df_aaf.sort_values(["position", "aa_change"])
    df_aaf.set_index(["aa_change", "contig", "position"], inplace=True)

    # Keep non-synonymous changes above the frequency threshold.
    df_aaf = df_aaf.query(f"effect == 'NON_SYNONYMOUS_CODING' and af > {maf}").reset_index().set_index('aa_change')
    if reverse_strand == True:
      df_aaf = df_aaf[::-1] #flip df cos reverse strand so aa are backwards

    return df_aaf

def calc_heterozygosity(gt, gt_samples):
    """Compute observed heterozygosity for every sample.

    For each sample this is the number of heterozygous calls divided by the
    number of called genotypes, which equals the NaN-ignoring mean of
    ``allel.heterozygosity_observed`` over that sample's sites. It is
    computed for all samples at once instead of looping over samples.

    Parameters
    ----------
    gt : array-like, shape (n_variants, n_samples, ploidy)
        Genotype calls.
    gt_samples : sequence
        Sample identifiers, in the same order as the sample axis of ``gt``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``sample_id`` with a single ``het`` column. Samples with
        no called genotype get NaN.
    """
    genotypes = allel.GenotypeArray(gt, copy=False)
    n_het = np.asarray(genotypes.count_het(axis=0))
    n_called = np.asarray(genotypes.count_called(axis=0))

    # 0/0 for samples without any call is mapped to NaN explicitly.
    with np.errstate(divide="ignore", invalid="ignore"):
        het_per_sample = np.where(n_called > 0, n_het / n_called, np.nan)

    het_df = pd.DataFrame({'sample_id': gt_samples, 'het': het_per_sample})
    return het_df.set_index("sample_id")

def calc_copynumber(self, transcript, sample_sets=None, sample_query=None):
    """Return the modal copy number of a transcript for each sample.

    Parameters
    ----------
    self :
        API object providing ``gene_cnv`` (e.g. an ``Ag3`` instance). This
        object is used for the query, not a notebook-level global.
    transcript : str
        Passed to ``gene_cnv`` as ``region``.
    sample_sets, sample_query :
        Sample selection, forwarded to ``gene_cnv``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``sample_id`` with a ``cn_mode`` column. Empty if
        ``gene_cnv`` raises ``ValueError``.
    """
    try:
        ds_cnv = self.gene_cnv(
            region=transcript,
            sample_sets=sample_sets,
            sample_query=sample_query,
            max_coverage_variance=1,
        )
        cnv_df = pd.DataFrame({
            'sample_id': ds_cnv['sample_id'].values,
            # first row of CN_mode: the (single) gene matching `transcript`
            'cn_mode': ds_cnv['CN_mode'].values[0],
        })
    except ValueError:
        cnv_df = pd.DataFrame(columns=['sample_id', 'cn_mode'])
    return cnv_df.set_index('sample_id')

def plotly_track(df, colour='greys', range_color=None):
    """Render a DataFrame as a heat-map track without legend, colour bar or x axis.

    Parameters
    ----------
    df : pandas.DataFrame
        Values to draw with ``plotly.express.imshow``.
    colour : str or list
        Colour scale applied to the heat-map trace.
    range_color : sequence or None
        Passed to ``plotly.express.imshow``.

    Returns
    -------
    The plotly figure.
    """
    import plotly.express as px

    # Each update_* call returns the figure, so the styling can be chained.
    track = (
        px.imshow(df, range_color=range_color)
        .update_layout(showlegend=False)
        .update_traces(showscale=False, coloraxis=None, colorscale=colour)
        .update_xaxes(visible=False)
    )
    return track

## scipy clust

The following functions are useful for cutting the tree to obtain diplotype clusters. Unfinished.

In [None]:
def cut_dist_tree(dist, linkage_method, cut_height, min_cluster_size, count_sort, distance_sort):
    """Cut a hierarchical clustering tree and label sufficiently large clusters.

    Parameters
    ----------
    dist : array-like
        Condensed pairwise distance vector.
    linkage_method : str
        ``method`` passed to ``scipy.cluster.hierarchy.linkage``.
    cut_height : float
        Height at which the tree is cut.
    min_cluster_size : int
        Clusters with fewer members than this are labelled 0 (unassigned).
    count_sort, distance_sort :
        Passed to ``scipy.cluster.hierarchy.dendrogram`` so that clusters are
        visited in the same leaf order as the plotted dendrogram.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``index`` (observation number) with a ``cluster_labels``
        column. Clusters appear in dendrogram order; within a cluster the
        observations are in ascending index order. Clusters that reach
        ``min_cluster_size`` are numbered 1, 2, ... in dendrogram order, so
        that label 0 is reserved for observations outside any such cluster.
    """
    import scipy.cluster.hierarchy as sch

    Z = sch.linkage(dist, method=linkage_method)
    leaves = sch.dendrogram(
        Z,
        color_threshold=cut_height,
        no_plot=True,
        count_sort=count_sort,
        distance_sort=distance_sort,
    )['leaves']

    # Cluster id of every observation, then of every leaf in dendrogram
    # order. The dendrogram is not truncated, so each leaf is exactly one
    # observation and can be looked up directly.
    cut = sch.cut_tree(Z, height=cut_height)[:, 0]
    leaf_clusters = cut[np.asarray(leaves)]

    index = []
    labels = []
    next_label = 1  # 0 is reserved for "not in a large enough cluster"
    previous = None
    for cluster_id in leaf_clusters:
        # Handle each run of consecutive leaves from the same cluster once.
        if previous is not None and cluster_id == previous:
            continue
        previous = cluster_id

        members = np.nonzero(cut == cluster_id)[0]
        index.extend(members)
        if len(members) >= min_cluster_size:
            labels.extend([next_label] * len(members))
            next_label += 1
        else:
            labels.extend([0] * len(members))

    return pd.DataFrame({'index':index, 'cluster_labels':labels}).set_index('index')


def get_descendant(node, desc_id):
    """Search the descendants of the given node in a scipy tree.

    The search is a depth-first, left-before-right traversal using an
    explicit stack rather than recursion, so very deep (chain-like) trees do
    not hit Python's recursion limit.

    Parameters
    ----------
    node : scipy.cluster.hierarchy.ClusterNode
        The ancestor node to search from.
    desc_id : int
        The ID of the node to search for.

    Returns
    -------
    desc : scipy.cluster.hierarchy.ClusterNode
        If a node with the given ID is not found, returns None.

    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.id == desc_id:
            return current
        if not current.is_leaf():
            # Push right first so the left subtree is visited first.
            pending.append(current.right)
            pending.append(current.left)
    return None

# monkey-patch as a method
scipy.cluster.hierarchy.ClusterNode.get_descendant = get_descendant


In [None]:
@numba.njit(parallel=True)
def diplotype_pdist(X, metric):
    """Pairwise distances between diplotypes, returned in condensed form.

    X holds the diplotypes as genotype allele counts with axes
    (n_samples, n_sites, n_alleles); a contiguous (C order) array is
    fastest.

    metric is called as ``metric(x, y)`` on the allele counts of the two
    samples in a pair and may itself be a numba jitted function.

    The result is a float32 vector with one entry per unordered pair of
    samples, addressed via ``square_to_condensed``.
    """
    n = X.shape[0]
    result = np.zeros((n * (n - 1)) // 2, dtype=np.float32)

    # Outer loop: first member of the pair.
    for first in range(n):
        dip_first = X[first, :, :]

        # Inner (parallel) loop: second member, only pairs with second > first.
        for second in numba.prange(first + 1, n):
            value = metric(dip_first, X[second, :, :])
            result[square_to_condensed(first, second, n)] = value

    return result

@numba.njit
def square_to_condensed(i, j, n):
    """Map the square-matrix position (i, j), i != j, of an n x n distance
    matrix to its index in the condensed distance vector."""
    assert i != j, "no diagonal elements in condensed matrix"
    # Order the pair so that `small` < `large`; the position is symmetric.
    if i > j:
        small, large = j, i
    else:
        small, large = i, j
    return n*small - small*(small+1)//2 + large - 1 - small

@numba.njit
def diplotype_mean_cityblock(x, y):
    """Mean cityblock distance between two diplotypes.

    x and y are genotype allele counts of identical shape
    (n_sites, n_alleles).

    Only sites where both individuals have at least one observed allele
    contribute; the summed absolute differences are divided by the number of
    such sites. If there is no such site the result is NaN.
    """
    total = np.float32(0)
    n_called = np.float32(0)

    for site in range(x.shape[0]):
        x_seen = False
        y_seen = False
        site_dist = np.float32(0)

        for allele in range(x.shape[1]):
            x_count = np.float32(x[site, allele])
            y_count = np.float32(y[site, allele])

            # A genotype counts as called once any allele has been observed.
            if x_count > 0:
                x_seen = True
            if y_count > 0:
                y_seen = True

            # Cityblock contribution of this allele.
            site_dist += np.fabs(x_count - y_count)

        # Skip sites where either sample is missing.
        if x_seen and y_seen:
            total += site_dist
            n_called += np.float32(1)

    if n_called > 0:
        return total / n_called
    return np.nan


@numba.njit
def diplotype_mean_sqeuclidean(x, y):
    """Mean squared euclidean distance between two diplotypes.

    x and y are genotype allele counts of identical shape
    (n_sites, n_alleles).

    Only sites where both individuals have at least one observed allele
    contribute; the summed squared differences are divided by the number of
    such sites. If there is no such site the result is NaN.
    """
    total = np.float32(0)
    n_called = np.float32(0)

    for site in range(x.shape[0]):
        x_seen = False
        y_seen = False
        site_dist = np.float32(0)

        for allele in range(x.shape[1]):
            x_count = np.float32(x[site, allele])
            y_count = np.float32(y[site, allele])

            # A genotype counts as called once any allele has been observed.
            if x_count > 0:
                x_seen = True
            if y_count > 0:
                y_seen = True

            # Squared euclidean contribution of this allele.
            site_dist += (x_count - y_count)**2

        # Skip sites where either sample is missing.
        if x_seen and y_seen:
            total += site_dist
            n_called += np.float32(1)

    if n_called > 0:
        return total / n_called
    return np.nan


def compute_genetic_distances(geno_ac):
    """Run ``diplotype_pdist`` twice over ``geno_ac`` and return both results.

    NOTE(review): the metrics ``ma_dist_opt`` and ``ma_countable_opt`` are not
    defined in the visible part of this notebook - confirm they exist before
    calling this function.

    Returns
    -------
    D, N :
        The condensed outputs for ``ma_dist_opt`` and ``ma_countable_opt``
        respectively.
    """
    D = diplotype_pdist(geno_ac, metric=ma_dist_opt)
    N = diplotype_pdist(geno_ac, metric=ma_countable_opt)
    return D, N



# testing functions

In [None]:
# API handle used by all test cells below.
# NOTE(review): `pre=True` presumably enables access to pre-release data - confirm.
ag3 = malariagen_data.Ag3(pre=True)
#af1 = malariagen_data.Af1()

In [None]:
# Default analysis settings. `sample_sets` is rebound to a longer list in the
# following cell.
resolved_region = "2L:28,545,396-28,550,748"
sample_sets = ["AG1000G-GH"]
sample_query = 'taxon == "gambiae"'
site_mask = None
random_seed = 77

In [None]:
sample_sets = [
    # Ag1000G phase 3 sample sets in Ag3.0
    "AG1000G-GH",
    "AG1000G-ML-A",
    "AG1000G-BF-A",
    "AG1000G-BF-B",
    "AG1000G-GN-A",
    "AG1000G-GN-B",
    "AG1000G-TZ",
    # "1246-VO-TZ-KABULA-VMF00185",
    # GAARDIAN sample set in Ag3.4
    "1244-VO-GH-YAWSON-VMF00149",
    # GAARD Ghana sample set in Ag3.2
    "1244-VO-GH-YAWSON-VMF00051",
    "1245-VO-CI-CONSTANT-VMF00054",
    "1253-VO-TG-DJOGBENOU-VMF00052",
    "1237-VO-BJ-DJOGBENOU-VMF00050",
]

In [None]:
# Run `cluster_summary` over the resolved region. The second return value is
# used below through `clusters['index']` and `clusters.cluster_labels`.
df_samples, clusters = cluster_summary(
    ag3,
    region=resolved_region,
    sample_sets=sample_sets,
    site_mask="gamb_colu_arab",
    linkage_method="complete",
    cut_height=0.03,
    min_cluster_size=40,
)

In [None]:
# Load sample metadata for the selected sample sets. This rebinds
# `df_samples`, replacing the frame returned by `cluster_summary`.
df_samples = ag3.sample_metadata(sample_sets=sample_sets)

In [None]:
# Keep only the clustered samples and attach their cluster labels.
# The frame is rebuilt from freshly loaded metadata rather than from the
# current `df_samples`, so re-running this cell cannot positionally re-index
# a frame that has already been subset.
df_samples = (
    ag3.sample_metadata(sample_sets=sample_sets)
    .iloc[clusters['index'].to_list(), :]
    .assign(cluster=clusters.cluster_labels.to_list())
)

In [None]:
# Sample counts for every (cluster, taxon) combination.
pd.crosstab(index=df_samples["cluster"], columns=df_samples["taxon"])

taxon,arabiensis,coluzzii,gambiae,gcx3,unassigned
cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,157,1002,671,11,2
1,0,0,118,0,0
2,0,0,48,0,0
3,0,0,220,0,1
4,71,0,0,0,0
5,0,0,87,0,0
6,0,43,0,0,0


In [None]:
# Gene-level CNV calls for AGAP006227-RA across the selected sample sets.
ds_cnv = ag3.gene_cnv(
    region="AGAP006227-RA",
    sample_sets=sample_sets,
    sample_query=None,
    max_coverage_variance=1,
)

# Pair each sample identifier with the first row of `CN_mode`.
cnv_df = pd.DataFrame(
    {
        "sample_id": ds_cnv["sample_id"].values,
        "cn_mode": ds_cnv["CN_mode"].values[0],
    }
)

Load CNV HMM data:   0%|          | 0/258 [00:00<?, ?it/s]

Compute modal gene copy number:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Left join keeps every sample; the join keys are the columns the two frames
# have in common.
df_samples = pd.merge(df_samples, cnv_df, how="left")

In [None]:
# Sample counts per (country, year) broken down by taxon, saved as a
# tab-separated manifest.
pivot_country_year_taxon = pd.pivot_table(
    df_samples,
    values="sample_id",
    index=["country", "year"],
    columns=["taxon"],
    aggfunc="count",
    fill_value=0,
)
pivot_country_year_taxon.to_csv("sample_manifest.tsv", sep="\t")

In [None]:
# Sample counts per cluster, broken down by taxon.
pivot_cluster = (
    df_samples
    .pivot_table(
        index=["cluster"],
        columns=["taxon"],
        values="sample_id",
        aggfunc="count",
        fill_value=0
    )
)

# Sample counts per cluster, broken down by modal copy number.
pivot_cnv = (
    df_samples
    .pivot_table(
        index=["cluster"],
        columns=["cn_mode"],
        values="sample_id",
        aggfunc="count",
        fill_value=0
    )
)

# Pick the copy-number columns by label instead of by position, so the result
# does not depend on which copy numbers happen to be observed in the data.
# NOTE(review): a modal copy number above 2 is treated as "with CNV"
# (amplification relative to a diploid baseline) — confirm this is the
# intended definition.
amplified_cols = pivot_cnv.columns > 2
proportion_cnv = pivot_cnv.loc[:, amplified_cols].sum(axis=1) / pivot_cnv.sum(axis=1)

# `errors='ignore'` keeps the cell working when a taxon column is absent.
pivot_cluster = (
    pivot_cluster
    .assign(proportion_with_cnv=np.round(proportion_cnv, 2))
    .drop(columns=['gcx3', 'unassigned'], errors='ignore')
)
pivot_cluster.to_csv("cluster_manifest.tsv", sep="\t")
pivot_cluster

taxon,arabiensis,coluzzii,gambiae,proportion_with_cnv
cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,157,1002,671,0.02
1,0,0,118,0.01
2,0,0,48,0.96
3,0,0,220,0.0
4,71,0,0,0.13
5,0,0,87,0.01
6,0,43,0,0.0


In [None]:

# Proportion of samples per cluster whose modal copy number is above 2.
# Columns are selected by label rather than by position so the result does
# not depend on which copy numbers are present in `pivot_cnv`.
proportion_cnv = pivot_cnv.loc[:, pivot_cnv.columns > 2].sum(axis=1) / pivot_cnv.sum(axis=1)

In [None]:
# Number of samples for each (cluster, cn_mode) pair, ordered by cluster.
df_summary = (
    df_samples[["cluster", "cn_mode"]]
    .value_counts()
    .to_frame()
    .rename(columns={0: "count"})
    .sort_values("cluster")
    .reset_index()
)

In [None]:
# Number of samples for each (cluster, taxon) pair, ordered by cluster.
df_summary2 = (
    df_samples[["cluster", "taxon"]]
    .value_counts()
    .to_frame()
    .rename(columns={0: "count"})
    .sort_values("cluster")
    .reset_index()
)

# Bar chart of taxon counts per cluster, leaving out cluster 0 and
# unassigned samples.
px.bar(
    df_summary2.query("cluster != 0 and taxon != 'unassigned'"),
    x="cluster",
    y="count",
    color="taxon",
    template="simple_white",
    width=700,
    height=500,
)

In [None]:
# Bar chart of copy-number counts per cluster, leaving out cluster 0.
px.bar(
    df_summary.query("cluster != 0"),
    x="cluster",
    y="count",
    color="cn_mode",
    template="simple_white",
    width=700,
    height=500,
)

In [None]:
# Diplotype clustering figure for the whole resolved region, with transcript
# AGAP006227-RA.
fig = plot_diplotype_clustering_advanced(
    ag3,
    region=resolved_region,
    transcript="AGAP006227-RA",
    sample_sets=sample_sets,
    # sample_query="taxon not in ['unassigned', 'gcx3']",
    site_mask="gamb_colu_arab",
    linkage_method="complete",
    cut_height=0.03,
    min_cluster_size=10,
    heterozygosity=True,
    karyo=False,
    color="taxon",
)
fig

  0%|          | 0/2431 [00:00<?, ?it/s]

Load CNV HMM data:   0%|          | 0/258 [00:00<?, ?it/s]

Compute modal gene copy number:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Diplotype clustering figure for 2L:28,548,433-28,550,748 with transcript
# AGAP006228-RA.
fig = plot_diplotype_clustering_advanced(
    ag3,
    region="2L:28_548_433-28_550_748",
    heterozygosity=True,
    karyo=False,
    cut_height=0.03,
    min_cluster_size=40,
    transcript='AGAP006228-RA',
    sample_sets=sample_sets,
   # sample_query="taxon not in ['unassigned', 'gcx3']",
    site_mask='gamb_colu_arab',
    linkage_method='complete',
    color='taxon'
    )

# Save first, then leave `fig` as the last expression: only the final
# expression of a cell is displayed, so a bare `fig` placed before
# `write_image` was never shown.
fig.write_image("dipclust_coeae2f.png", scale=2)
fig

  0%|          | 0/2431 [00:00<?, ?it/s]

Load CNV HMM data:   0%|          | 0/258 [00:00<?, ?it/s]

Compute modal gene copy number:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Diplotype clustering figure for 2L:28,545,396-28,547,938 with transcript
# AGAP006227-RA.
fig = plot_diplotype_clustering_advanced(
    ag3,
    region="2L:28_545_396-28_547_938",
    heterozygosity=True,
    karyo=False,
    cut_height=0.03,
    min_cluster_size=40,
    transcript='AGAP006227-RA',
    sample_sets=sample_sets,
   # sample_query="taxon not in ['unassigned', 'gcx3']",
    site_mask='gamb_colu_arab',
    linkage_method='complete',
    color='taxon'
    )

# Save first, then leave `fig` as the last expression: only the final
# expression of a cell is displayed, so a bare `fig` placed before
# `write_image` was never shown.
fig.write_image("dipclust_coeae1f.png", scale=2)
fig

  0%|          | 0/2431 [00:00<?, ?it/s]

Load CNV HMM data:   0%|          | 0/258 [00:00<?, ?it/s]

Compute modal gene copy number:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# NOTE(review): `fig` is whichever figure was assigned most recently. When the
# cells run in order, that is the figure already written to
# "dipclust_coeae1f.png" — confirm this is the figure intended for
# "dipclust.png".
fig.write_image("dipclust.png", scale=2)

In [None]:
# Rebind `df_samples` to metadata requested without a sample-set filter. The
# `cluster` and `cn_mode` columns added in earlier cells are not carried over.
df_samples = ag3.sample_metadata()

In [None]:
# Diplotype clustering figure for 2L:2,358,000-2,431,000 with transcript
# AGAP004707-RD, using the AG1000G-GH sample set only.
fig = plot_diplotype_clustering_advanced(
    ag3,
    region="2L:2,358,000-2,431,000",
    transcript="AGAP004707-RD",
    sample_sets="AG1000G-GH",
    # sample_query=sample_query,
    site_mask="gamb_colu_arab",
    linkage_method="complete",
    cut_height=0.05,
    min_cluster_size=5,
    heterozygosity=True,
    color="taxon",
)
fig

  0%|          | 0/100 [00:00<?, ?it/s]

Load CNV HMM data:   0%|          | 0/12 [00:00<?, ?it/s]

Compute modal gene copy number:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Diplotype clustering figure for 2L:2,327,000-2,410,000 with transcript
# AGAP004707-RD, using the two 1288-VO-UG-DONNELLY sample sets.
fig = plot_diplotype_clustering_advanced(
    ag3,
    region="2L:2,327,000-2,410,000",
    transcript="AGAP004707-RD",
    sample_sets=[
        "1288-VO-UG-DONNELLY-VMF00168",
        "1288-VO-UG-DONNELLY-VMF00219",
    ],
    # sample_query=sample_query,
    site_mask="gamb_colu_arab",
    linkage_method="complete",
    cut_height=0.05,
    min_cluster_size=30,
    heterozygosity=True,
    color="taxon",
)
fig

  0%|          | 0/1222 [00:00<?, ?it/s]

Load CNV HMM data:   0%|          | 0/126 [00:00<?, ?it/s]

Compute modal gene copy number:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Diplotype clustering figure for 2R:28,480,000-28,490,000 with transcript
# AGAP002867-RA, using the two 1288-VO-UG-DONNELLY sample sets.
fig = plot_diplotype_clustering_advanced(
    ag3,
    region="2R:28,480,000-28,490,000",
    transcript="AGAP002867-RA",
    sample_sets=[
        "1288-VO-UG-DONNELLY-VMF00168",
        "1288-VO-UG-DONNELLY-VMF00219",
    ],
    # sample_query=sample_query,
    site_mask="gamb_colu_arab",
    linkage_method="complete",
    cut_height=0.05,
    min_cluster_size=20,
    heterozygosity=True,
    color="taxon",
)
fig

  0%|          | 0/1222 [00:00<?, ?it/s]

Load CNV HMM data:   0%|          | 0/168 [00:00<?, ?it/s]

Compute modal gene copy number:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Diplotype clustering figure for the resolved region with transcript
# AGAP006227-RA, using the AG1000G-GH sample set and no site mask.
fig = plot_diplotype_clustering_advanced(
    ag3,
    region=resolved_region,
    transcript="AGAP006227-RA",
    sample_sets="AG1000G-GH",
    # sample_query=sample_query,
    site_mask=None,
    linkage_method="complete",
    cut_height=0.05,
    min_cluster_size=20,
    heterozygosity=True,
    color="taxon",
)
fig

  0%|          | 0/100 [00:00<?, ?it/s]

Load CNV HMM data:   0%|          | 0/12 [00:00<?, ?it/s]

Compute modal gene copy number:   0%|          | 0/1 [00:00<?, ?it/s]