# Create an interactive scatterplot using plotly

In [1]:
import textwrap

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import scipy.sparse

In [2]:
# Development switch: reduce dataset for less resource-intensive development.
# When True, embedding_df is subsampled after loading and the HTML export of
# the scatterplot is skipped.
development = True

In [3]:
# Retrieve symbols for an MSigDB Gene Set
# https://www.gsea-msigdb.org/gsea/msigdb/search.jsp
url = "https://www.gsea-msigdb.org/gsea/msigdb/download_geneset.jsp"
params = dict(geneSetName="GO_CARDIOVASCULAR_SYSTEM_DEVELOPMENT", fileType="gmx")
# a timeout keeps Restart & Run All from hanging on an unresponsive server
response = requests.get(url, params=params, timeout=60)
# fail loudly on an HTTP error rather than parsing an error page as gene symbols
response.raise_for_status()
# the first two lines are skipped as header lines
# (presumably the set name and description rows of the gmx format)
query_genes = set(response.text.splitlines()[2:])
len(query_genes)

810

In [4]:
def wrap_text(text, width=55):
    """
    Wrap text onto multiple lines for display in plotly.

    Plotly does not wrap long text by default
    https://github.com/plotly/plotly.js/issues/1964

    Parameters
    ----------
    text : str
        Text to wrap.
    width : int
        Maximum line length, passed to textwrap.wrap.

    Returns
    -------
    str
        The wrapped lines joined with "<br>" (plotly's line break).
        An empty string is returned when there is nothing to wrap.
    """
    return "<br>".join(textwrap.wrap(text, width=width))

In [5]:
# Load the UMAP coordinates and derive the columns used for plotting.
embedding_df = pd.read_csv("data/embeddings/node2vec-128d-to-umap-2d.tsv.xz", sep="\t")
# zero-padded strings, so that cluster labels sort in numeric order
embedding_df["cluster"] = embedding_df["cluster"].map("{:02d}".format)
# pre-wrap annotations, since plotly does not wrap long hover text itself
embedding_df["annotation"] = embedding_df["annotation"].map(wrap_text)
# 1 when the gene symbol belongs to the queried gene set, else 0
embedding_df["query_gene"] = embedding_df["preferred_name"].isin(query_genes).astype(int)
# relative marker size: query genes are drawn larger than the other genes
embedding_df["size"] = embedding_df["query_gene"].map({0: 1, 1: 15})
if development:
    # cap n at the number of rows: sampling without replacement raises
    # ValueError when n exceeds the size of the frame
    embedding_df = embedding_df.sample(n=min(2000, len(embedding_df)), random_state=0)
embedding_df.head(2)

Unnamed: 0,index,protein_external_id,preferred_name,protein_size,annotation,umap_0,umap_1,cluster,query_gene,size
13767,13767,9606.ENSP00000378207,AKAP5,427,A-kinase anchor protein 5; May anchor the PKA ...,6.982815,3.326482,27,0,1
17899,17899,9606.ENSP00000454014,SHISA9,424,Protein shisa-9; Regulator of short-term neuro...,5.655531,0.999013,54,0,1


## Interactive scatterplot

In [6]:
# Build the scatterplot with plotly express and wrap it in a FigureWidget
# (plotly.graph_objs._figurewidget.FigureWidget instead of
# plotly.graph_objs._figure.Figure) for cross-cell interactivity.
fig = go.FigureWidget(
    px.scatter(
        embedding_df,
        x="umap_0",
        y="umap_1",
        color="cluster",
        category_orders={"cluster": sorted(embedding_df["cluster"].unique())},
        size="size",
        size_max=10,
        hover_name="preferred_name",
        hover_data=["annotation"],
        width=900,
        height=900,
        render_mode="webgl",  # https://plotly.com/python/webgl-vs-svg/
    )
)
fig.update_traces(opacity=0.85)

# Optional dark theme:
# https://medium.com/plotly/introducing-plotly-py-theming-b644109ac9c7
# fig.layout.template = "plotly_dark"

# The standalone HTML export is only written when not in development mode.
if not development:
    fig.write_html("data/viz/interactive-scatterplot.html", auto_open=False)

# NOTE: FigureWidget does not use the renderers framework, so you should not
# use the show() figure method or the plotly.io.show function on FigureWidget
# objects; display the widget as the last expression of the cell instead.
# https://nbviewer.jupyter.org/github/jonmmease/plotly_ipywidget_notebooks/tree/master/notebooks/
fig

FigureWidget({
    'data': [{'customdata': array([['Retinol dehydrogenase 11; Exhibits an oxidoreductive<br>ca…

## Table of selected points

This table requires a running Python kernel to update (it won't work in an HTML export).


#### Implementation notes

Plotly has a demo showing how to [Populate a Table Using a Plotly Mouse Selection Event](https://plotly.com/python/v3/selection-events/).
Updating a table based on selected points is simple when all points are part of the same trace.
But when there are multiple traces (like for each cluster/color in our scatterplot), the implementation becomes convoluted and fragile. See https://stackoverflow.com/a/53863831/4651668.

In [7]:
# embedding_df columns shown in the table of selected points, in display order
table_columns = ["preferred_name", "cluster"]

def get_table_values(gene_symbols):
    """
    Build the cell values for a plotly Table from embedding_df.

    Keeps the rows whose preferred_name is in gene_symbols and returns a list
    holding one list of values per column in table_columns.
    """
    is_selected = embedding_df["preferred_name"].isin(gene_symbols)
    selected_df = embedding_df.loc[is_selected].reindex(columns=table_columns)
    column_to_values = selected_df.to_dict(orient="list")
    return list(column_to_values.values())

def update_table(gene_symbols):
    """Refresh the cells of the table widget so it lists only the rows for gene_symbols."""
    cells = table.data[0].cells
    cells.values = get_table_values(gene_symbols)

# State carried across successive selection_callback calls:
#   up_to_trace: index of the trace handled by the most recent call (-1 before any call)
#   genes: hovertext values of the selected points accumulated so far
selection = {"up_to_trace": -1, "genes": []}

def selection_callback(trace, points, selector):
    """
    This function is called on every trace to determine which genes are selected.

    Accumulates the hovertext of the selected points of each trace in
    selection["genes"], and updates the table once the callback for the last
    trace in fig.data has run.

    trace: the scatter trace the callback is running for
    points: selection information; only trace_index and point_inds are read
    selector: unused
    """
    # A trace index that did not advance past the previous call means a new
    # selection event has started, so discard the genes of the previous event.
    # NOTE(review): assumes the callbacks of one selection event fire in
    # increasing trace order -- confirm against plotly's behavior.
    if points.trace_index <= selection["up_to_trace"]:
        selection["genes"] = []
    selection["up_to_trace"] = points.trace_index
    # NOTE(review): indexing hovertext with the list of point indices
    # presumably relies on hovertext being a numpy array -- confirm.
    genes = trace.hovertext[points.point_inds]
    selection["genes"].extend(genes)
    # last trace reached: publish the accumulated genes to the table
    if points.trace_index + 1 == len(fig.data):
        update_table(selection["genes"])

# register selection callback
# the same callback is attached to every trace in the figure (one trace per cluster/color)
for scatter in fig.data:
    # on_selection callback runs on all traces, even those that have no markers selected
    scatter.on_selection(selection_callback)

# Create the table, initially without rows, as a FigureWidget so that
# update_table can change its cells after it has been displayed.
table = go.FigureWidget(
    go.Table(
        header={"values": table_columns},
        cells={"values": get_table_values([])},
    )
)
table

FigureWidget({
    'data': [{'cells': {'values': [[], []]},
              'header': {'values': ['preferred_nam…

In [8]:
# # snippets for scaling points whose radius stays invariant upon zoom
# # https://stackoverflow.com/questions/47086547/set-marker-size-based-on-coordinate-values-not-pixels-in-plotly-r
# # https://plotly.com/python/click-events/
# scatter = fig.data[0]
# def callback(xaxis, yaxis):
# scatter.on_change(callback, ('xaxis', 'range'), ('yaxis', 'range'))