<a href="https://colab.research.google.com/github/sanjaynagi/diplotype-clustering/blob/main/Diplotype_clustering.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install malariagen_data -q

### Diplotype clustering and amino acid variation

When we are interested in a locus that is under selection, we often would like to know how many sweeps there are and how large each sweep is, but also we would like to know which amino acid variants are present on each swept haplotype.

Usually we would use haplotype clustering or haplotype networks for this purpose, but the use of haplotypes means that multi-allelic sites are omitted from the data. Performing additional phasing of multi-allelic sites is a pain, so we want to avoid this where possible.

In this notebook, I present code to perform clustering on diplotypes from unphased genomic data. Along with the diplotype dendrogram, we plot the frequency of amino acid variants for a given transcript. It should also be straightforward to include CNVs in this as well as amino acid mutations, though this is not implemented yet.

**Work in progress! (WIP)**  

<br></br>
  
TODO   

- Figure out and fix why the plots are not in perfect alignment: the amino acid (aa) plot is wider than the dendrogram.
- Do I have to use the number of alternate allele calls, or is there a way to use the genotypes directly in clustering? The problem is that I don't know how to cluster a 3D array unless I flatten it somehow.
- Check that the distance metrics etc. are appropriate/correct.
- Site filters are still an issue for some genes.
- Make the hover data useful for the aa plot.


 <br></br>

We will first define all the functions we are going to need!

In [None]:
import numpy as np
import pandas as pd
import malariagen_data
import allel

In [None]:
def distinct(gt):
    """Return sets of indices for each distinct diplotype.

    Parameters
    ----------
    gt : object with ``.shape`` and ``.values``
        ``gt.values`` must be a NumPy array whose second axis indexes
        diplotypes; ``gt.values[:, i]`` is treated as diplotype ``i``.

    Returns
    -------
    list of set of int
        One set of column indices per distinct diplotype, ordered from the
        most common diplotype to the least common.
    """
    from collections import defaultdict

    # Key on the raw bytes of each column rather than on ``hash()`` of those
    # bytes: two different diplotypes can share a hash value, which would
    # silently merge them into a single group.
    groups = defaultdict(set)

    # Look ``.values`` up once; it does not change inside the loop.
    columns = gt.values
    for i in range(gt.shape[1]):
        groups[columns[:, i].tobytes()].add(i)

    # sorted() is stable, so equally common diplotypes keep first-seen order.
    return sorted(groups.values(), key=len, reverse=True)

def _get_max_distance(h, metric="euclidean", linkage_method="single"):
    """
    Return the largest merge distance in a hierarchical clustering of ``h``.

    ``h`` is passed straight to ``scipy.cluster.hierarchy.linkage`` together
    with ``metric`` and ``linkage_method``; the result is the maximum of the
    distance column of the returned linkage matrix, left in the units of
    ``metric`` (no rescaling by the number of sites is applied).
    """
    from scipy.cluster.hierarchy import linkage

    tree = linkage(h, method=linkage_method, metric=metric)

    # Column 2 of the linkage matrix holds the distance of each merge.
    merge_heights = tree[:, 2]
    return merge_heights.max()

def plot_diplotype_clustering(
    self,
    region,
    sites,
    sample_sets=None,
    sample_query=None,
    site_mask=None,
    color="location",
    symbol=None,
    distance_metric='euclidean',
    linkage_method='single',
    count_sort=True,
    distance_sort=False,
    random_seed=42,
    width=800,
    height=600,
    **kwargs,
):
    """Plot a dendrogram of diplotypes (per-sample alternate allele counts).

    SNP calls for ``region`` are loaded, subset to the haplotype sites of
    ``sites``, converted to the number of alternate alleles per genotype call
    and hierarchically clustered. A row of sample markers, coloured and
    shaped by sample metadata, is drawn just below the dendrogram leaves.

    Parameters
    ----------
    self
        Data-access object (called here with ``malariagen_data.Ag3`` /
        ``Af1`` instances) providing ``snp_calls``, ``sample_metadata``,
        ``_haplotype_sites_for_contig`` and ``_log``.
    region
        Genomic region, resolved with
        ``malariagen_data.util.parse_single_region``.
    sites
        Passed as ``analysis`` to ``self._haplotype_sites_for_contig`` to
        choose the set of haplotype sites to keep.
    sample_sets, sample_query
        Forwarded to ``self.snp_calls`` and ``self.sample_metadata``.
    site_mask
        Forwarded to ``self.snp_calls``.
    color, symbol
        Sample metadata columns used to colour / shape the sample markers;
        added to the hover data when not already present.
    distance_metric
        Metric name handed to ``scipy.spatial.distance.pdist`` (and to
        ``linkage`` when computing the y-axis range).
    linkage_method
        Method name handed to ``scipy.cluster.hierarchy.linkage``.
    count_sort, distance_sort
        Forwarded to ``create_dendrogram``.
    random_seed
        Forwarded to ``self.snp_calls``.
    width, height
        Figure size in pixels.
    **kwargs
        Forwarded to ``plotly.express.scatter`` for the sample markers,
        overriding the defaults set in this function.

    Returns
    -------
    fig
        The plotly figure.
    diplotype_order
        ``fig.layout.xaxis["ticktext"]``: the leaf labels (sample column
        indices into the genotype array) in dendrogram order.
    """
    import plotly.express as px
    from scipy.cluster.hierarchy import linkage
    from scipy.spatial.distance import pdist
    from malariagen_data.plotly_dendrogram import create_dendrogram
    debug = self._log.debug

    resolved_region = malariagen_data.util.parse_single_region(self, region)
    del region

    ds_snps = self.snp_calls(
        region=resolved_region,
        sample_query=sample_query,
        sample_sets=sample_sets,
        site_mask=site_mask,
        random_seed=random_seed,
    )
    gt = allel.GenotypeDaskArray(ds_snps["call_genotype"].data)
    pos = ds_snps["variant_position"].values

    debug("subsetting to haplotype positions")
    haplotype_pos = self._haplotype_sites_for_contig(
        contig=resolved_region.contig,
        analysis=sites,
        field="POS",
        inline_array=True,
        chunks="native",
    ).compute()
    # np.isin supersedes the deprecated np.in1d; pos is 1-D so the result is
    # the same boolean mask over SNP positions.
    hap_site_mask = np.isin(pos, haplotype_pos, assume_unique=True)
    gt = gt.compress(hap_site_mask, axis=0)

    # collapse (sites, samples, ploidy) genotype calls to a 2-D
    # (sites, samples) array of alternate allele counts
    gt = gt.to_n_alt().compute()

    debug("load sample metadata")
    df_samples = self.sample_metadata(
        sample_sets=sample_sets, sample_query=sample_query
    )
    debug("align sample metadata with diplotypes")
    gt_samples = ds_snps["sample_id"].values.tolist()
    df_samples = (
        df_samples.set_index("sample_id").loc[gt_samples].reset_index()
    )

    debug("set up plotting options")
    hover_data = [
        "sample_id",
        "partner_sample_id",
        "sample_set",
        "taxon",
        "country",
        "admin1_iso",
        "admin1_name",
        "admin2_name",
        "location",
        "year",
        "month",
    ]

    if color and color not in hover_data:
        hover_data.append(color)
    if symbol and symbol not in hover_data:
        hover_data.append(symbol)

    plot_kwargs = dict(
        template="simple_white",
        hover_name="sample_id",
        hover_data=hover_data,
        render_mode="svg",
    )
    # apply any user overrides (previously accepted but silently discarded)
    plot_kwargs.update(kwargs)

    # set labels as the index which we extract to reorder metadata
    leaf_labels = np.arange(gt.shape[1])
    # the tallest merge sets the upper y-axis limit; it must be computed
    # with the same metric as the dendrogram itself
    max_dist = _get_max_distance(
        gt.T, metric=distance_metric, linkage_method=linkage_method
    )

    # noinspection PyTypeChecker
    fig = create_dendrogram(
        gt.T,
        distfun=lambda x: pdist(x, metric=distance_metric),
        linkagefun=lambda x: linkage(x, method=linkage_method),
        labels=leaf_labels,
        color_threshold=0,
        count_sort=count_sort,
        distance_sort=distance_sort,
    )

    fig.update_traces(
        hoverinfo="y",
        line=dict(width=0.5, color="black"),
    )

    title_lines = []
    if sample_sets is not None:
        title_lines.append(f"sample sets: {sample_sets}")
    if sample_query is not None:
        title_lines.append(f"sample query: {sample_query}")
    title_lines.append(f"genomic region: {resolved_region} ({gt.shape[0]} SNPs)")
    title = "<br>".join(title_lines)

    fig.update_layout(
        width=width,
        height=height,
        title=title,
        autosize=True,
        hovermode="closest",
        plot_bgcolor="white",
        yaxis_title="Distance (no. SNPs)",
        xaxis_title="Diplotypes",
        showlegend=True,
    )
    # select only columns in hover_data
    df_samples = df_samples[hover_data]
    diplotype_order = fig.layout.xaxis["ticktext"]
    debug("Reorder haplotype metadata to align with haplotype clustering")
    # leaf labels are positional sample indices, which match the default
    # integer index left by reset_index() above
    df_samples = df_samples.loc[
        diplotype_order
    ]

    fig.update_xaxes(mirror=False, showgrid=True, showticklabels=False, ticks="")
    # the lower limit sits below zero to leave room for the sample markers
    fig.update_yaxes(
        mirror=False, showgrid=True, showline=True, range=[-2, max_dist + 1]
    )

    # one marker per sample, placed under its leaf at y = -0.2
    fig.add_traces(
        list(
            px.scatter(
                df_samples,
                x=fig.layout.xaxis["tickvals"],
                y=np.repeat(-0.2, len(gt.T)),
                color=color,
                symbol=symbol,
                **plot_kwargs,
            ).select_traces()
        )
    )
    return fig, diplotype_order


def aa_genotypes(self, transcript, site_mask=None, sample_set=None, sample_query=None, drop_invariant=True):
    """Build a per-sample table of alternate allele frequencies for a transcript.

    Parameters
    ----------
    self
        Data-access object providing ``snp_genotypes``, ``_snp_df``,
        ``_annotator``, ``_pandas_apply`` and ``_make_snp_label_effect``.
    transcript
        Transcript identifier; used both as the genotype ``region`` and as
        the transcript for ``_snp_df`` and effect annotation.
    site_mask
        If given, rows are kept only where the ``pass_{site_mask}`` column
        of the SNP table is true.
    sample_set, sample_query
        Forwarded to ``self.snp_genotypes`` (``sample_set`` as
        ``sample_sets``).
    drop_invariant
        If true, drop rows where no sample carries the alternate allele
        (``max_af`` is 0).

    Returns
    -------
    pandas.DataFrame
        Indexed by (contig, position, ref_allele, alt_allele, aa_change),
        with one ``frq_<i>`` column per sample (``i`` is the sample's
        position in the genotype array), ``max_af`` and ``label`` columns,
        plus whatever columns ``_snp_df`` and the annotator supply. Rows
        with a null ``aa_change`` are removed.
    """
    # Load genotypes WITHOUT the site mask. The frequency rows built below
    # are concatenated positionally onto the SNP table from ``_snp_df`` and
    # only then filtered through ``pass_{site_mask}``, so both tables must
    # cover the same set of sites or the frequencies end up on the wrong
    # rows. NOTE(review): this assumes ``_snp_df`` returns every site in the
    # transcript, unfiltered, one row per alternate allele — confirm.
    gt = self.snp_genotypes(
        region=transcript,
        sample_sets=sample_set,
        sample_query=sample_query,
        site_mask=None,
        field="GT",
    ).compute()

    _, df_snps = self._snp_df(transcript=transcript)

    # For each sample, compute the frequency of alleles 1..3 at every site
    # and flatten to one value per (site, alt allele) pair.
    freq_cols = {}
    for idx in range(gt.shape[1]):
        ac_sample = allel.GenotypeArray(gt[:, [idx], :]).count_alleles(max_allele=3)
        af_sample = ac_sample.to_frequencies()
        freq_cols[f"frq_{idx}"] = af_sample[:, 1:].flatten()

    df_freqs = pd.DataFrame(freq_cols)
    df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)})

    # concat aligns on the index, so give the SNP table a fresh 0..n-1 index
    df_snps.reset_index(drop=True, inplace=True)
    df_snps = pd.concat([df_snps, df_freqs, df_max_af], axis=1)

    if site_mask is not None:
        loc_sites = df_snps[f"pass_{site_mask}"]
        df_snps = df_snps.loc[loc_sites]

    df_snps.reset_index(inplace=True, drop=True)

    # NOTE(review): get_effects is relied on to add the effect columns
    # (e.g. ``aa_change``) to df_snps in place — its return value is unused.
    ann = self._annotator()
    ann.get_effects(
        transcript=transcript, variants=df_snps
    )
    df_snps["label"] = self._pandas_apply(
        self._make_snp_label_effect,
        df_snps,
        columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"],
    )
    df_snps.set_index(
        ["contig", "position", "ref_allele", "alt_allele", "aa_change"],
        inplace=True)

    # keep only variants that have an amino acid change annotation
    df_snps = df_snps.query("~aa_change.isnull()", engine='python')
    if drop_invariant:
        df_snps = df_snps.query("max_af > 0")
    return df_snps

def plot_aas(df_snps, transcript):
    """Draw ``df_snps`` as a greyscale heatmap titled with ``transcript``.

    The table is handed to ``plotly.express.imshow`` as-is. The colour scale
    bar and legend are switched off and the x axis is hidden; the y axis is
    labelled 'amino acid change'. Returns the plotly figure.
    """
    import plotly.express as px

    heatmap = px.imshow(
        df_snps,
        color_continuous_scale='greys',
        height=600,
        aspect=0.5,
    )
    heatmap.update_layout(
        showlegend=False,
        yaxis_title='amino acid change',
        xaxis_title="Diplotypes",
        title=transcript,
    )

    # Detach the trace from the shared colour axis and give it its own
    # greyscale colours with no scale bar.
    trace_overrides = {
        "showscale": False,
        "coloraxis": None,
        "colorscale": 'greys',
    }
    heatmap.update_traces(trace_overrides)

    heatmap.update_xaxes(visible=False)
    return heatmap

def plot_aa_diplotype_clustering(
    self,
    transcript,
    region,
    sample_sets=None,
    sample_query=None,
    site_mask=None,
    sites='gamb_colu',
    color='country',
    distance_metric = 'euclidean',
    linkage_method='single',
    count_sort=True,
    distance_sort=False,
    reverse_strand=False,
    width=1000,
    height=1000
  ):
    """Plot a diplotype dendrogram above a heatmap of amino acid variants.

    The top panel is the figure from ``plot_diplotype_clustering`` for
    ``region``. The bottom panel is a heatmap, from ``plot_aas``, of
    per-sample alternate allele frequencies at non-synonymous variants in
    ``transcript``, with sample columns re-ordered to the dendrogram leaf
    order.

    Parameters
    ----------
    self
        Data-access object passed through to ``plot_diplotype_clustering``
        and ``aa_genotypes``.
    transcript
        Transcript whose amino acid variants are shown.
    region
        Genomic region used for the diplotype clustering.
    sample_sets, sample_query, site_mask
        Forwarded to both ``plot_diplotype_clustering`` and
        ``aa_genotypes``.
    sites, color, distance_metric, linkage_method, count_sort, distance_sort
        Forwarded to ``plot_diplotype_clustering``.
    reverse_strand
        If truthy, reverse the row order of the amino acid table so that
        variants on a reverse-strand transcript are not listed backwards.
    width, height
        Size of the combined figure in pixels.

    Returns
    -------
    The combined two-row plotly figure.
    """
    from plotly.subplots import make_subplots

    dendro_fig, order = plot_diplotype_clustering(
        self=self,
        region=region,
        sample_sets=sample_sets,
        sample_query=sample_query,
        site_mask=site_mask,
        sites=sites,
        count_sort=count_sort,
        distance_sort=distance_sort,
        distance_metric=distance_metric,
        linkage_method=linkage_method,
        color=color)

    # load genotypes at amino acid variants for each sample
    df_snps = aa_genotypes(self, transcript, sample_query=sample_query, sample_set=sample_sets, site_mask=site_mask)
    df_snps = df_snps.query("effect == 'NON_SYNONYMOUS_CODING'").reset_index().set_index('aa_change')
    if reverse_strand:
        # flip the table: on the reverse strand the amino acids are backwards
        df_snps = df_snps[::-1]

    # keep only the per-sample frequency columns and put them in
    # diplotype cluster order
    df_snps = df_snps.filter(like='frq').fillna(0).iloc[:, order.tolist()]

    aa_fig = plot_aas(df_snps, transcript)
    aa_fig.update_layout(
        margin=dict(l=20, r=20, t=0, b=0),
        paper_bgcolor="white",
        )

    # stack the two figures as rows of a single subplot grid
    figures = [dendro_fig, aa_fig]

    fig = make_subplots(rows=len(figures),
                        cols=1,
                        shared_xaxes=True,
                        vertical_spacing=0.05,
                        row_heights=[0.2, 0.4],
                        x_title='Diplotypes',
                        subplot_titles=(f"Diplotype clustering - {region}",
                                        ''))

    # copy every trace of each source figure into its row; add_trace with
    # row/col replaces the deprecated append_trace
    for row, figure in enumerate(figures, start=1):
        for trace in figure["data"]:
            fig.add_trace(trace, row=row, col=1)

    fig.update_xaxes(visible=False)
    fig.update_layout(
        width=width,
        height=height,
        autosize=True,
        hovermode="closest",
        plot_bgcolor="white",
    )

    fig['layout']['yaxis']['title']='Distance (No. SNPs)'
    fig['layout']['yaxis2']['title']=f'{transcript} amino acids'

    return fig


In [None]:
# Create the Ag3 and Af1 data-access clients used by the example cells below.
# NOTE(review): `pre=True` presumably opts in to pre-release sample sets —
# confirm against the malariagen_data documentation.
ag3 = malariagen_data.Ag3(pre=True)
af1 = malariagen_data.Af1()

In [None]:
# Example (Ag3): cluster diplotypes over 2L:2248521-2252500 for a single
# sample set and show amino acid variation in transcript AGAP004707-RD
# beneath the dendrogram, with samples coloured by taxon.
plot_aa_diplotype_clustering(
    ag3,
    transcript='AGAP004707-RD',
    sample_sets='1244-VO-GH-YAWSON-VMF00149',
    sites='gamb_colu',
    region='2L:2248521-2252500',
    color='taxon'
    )

In [None]:
# Example (Af1): cluster diplotypes over 2RL:8,685,464-8,690,407 using the
# 'funestus' site mask and haplotype sites, and show amino acid variation
# in the transcript below, with samples coloured by country.
transcript = 'LOC125764713_t1'

plot_aa_diplotype_clustering(
    af1,
    transcript=transcript,
    region='2RL:8,685,464-8,690,407',
    site_mask='funestus',
    sites='funestus',
    color='country',
    )