In [129]:
import os
import sys

import numpy as np
import pandas as pd

from bokeh.plotting import figure, show
from bokeh.io import output_notebook
from bokeh.models import ColumnDataSource, Range1d, NumeralTickFormatter, Legend
from bokeh.models.widgets import RangeSlider
from bokeh.palettes import brewer
from bokeh.transform import linear_cmap
from bokeh.layouts import gridplot, column
from bokeh.application.handlers import FunctionHandler
from bokeh.application import Application

import msprime

# Make the project's local helper modules (e.g. `util`) importable; the path is
# resolved relative to the notebook's current working directory.
sys.path.append("../src")
import util

# Render Bokeh output inline in the notebook.
output_notebook()


In [130]:
# Allow the Bokeh server's websocket to accept connections from the origin the
# notebook is served from. The literal below appears to be specific to one
# hosting environment, so a value already present in the environment wins.
os.environ.setdefault(
    "BOKEH_ALLOW_WS_ORIGIN",
    '0aaf0agotd3etfja916liv2etcl4ul9j3fk8kav1m1a16m18da6b',
)


In [131]:
# Simulate a small test dataset: ancestry for diploid individuals with
# recombination, then mutations on top. Both steps use the same seed so the
# notebook is reproducible.
# NOTE(review): no population_size / demography is passed to sim_ancestry, so
# msprime's default population size applies — confirm this is intended.
NUM_INDIVIDUALS = 100
SEQUENCE_LENGTH = 1e7
PLOIDY = 2
RECOMBINATION_RATE = 1e-8
MUTATION_RATE = 1e-7
SEED = 1234

ancestry_ts = msprime.sim_ancestry(
    NUM_INDIVIDUALS,
    sequence_length=SEQUENCE_LENGTH,
    ploidy=PLOIDY,
    recombination_rate=RECOMBINATION_RATE,
    random_seed=SEED,
)
ts = msprime.sim_mutations(ancestry_ts, rate=MUTATION_RATE, random_seed=SEED)
ts


Tree Sequence,Unnamed: 1
Trees,5
Sequence Length,10000000.0
Time Units,generations
Sample Nodes,200
Total Size,33.8 KiB
Metadata,No Metadata

Table,Rows,Size,Has Metadata
Edges,414,12.9 KiB,
Individuals,100,2.8 KiB,
Migrations,0,8 Bytes,
Mutations,31,1.1 KiB,
Nodes,403,11.0 KiB,
Populations,1,224 Bytes,✅
Provenances,2,1.7 KiB,
Sites,31,791 Bytes,


In [132]:
# A hand-built copying path for exercising the plot. Each pair means
# "copy from this parent node for this many consecutive sites"; the run
# lengths must add up to one parent node per site of `ts`.
COPYING_RUNS = [  # (parent node id, number of consecutive sites)
    (80, 5),
    (100, 6),
    (240, 6),
    (350, 5),
    (70, 4),
    (260, 5),
]
run_nodes, run_lengths = zip(*COPYING_RUNS)
path_nodes = np.repeat(run_nodes, run_lengths)

# Guard the hard-coded runs against a re-simulated `ts` with a different
# number of sites or nodes.
assert len(path_nodes) == ts.num_sites, (
    f"copying path covers {len(path_nodes)} sites but ts has {ts.num_sites}"
)
assert path_nodes.max() < ts.num_nodes, (
    f"node id {path_nodes.max()} is out of range for ts with {ts.num_nodes} nodes"
)

path = util.SamplePath(
    individual="test",
    nodes=path_nodes,
    site_positions=ts.sites_position,
)
path


SamplePath(individual='test', nodes=array([ 80,  80,  80,  80,  80, 100, 100, 100, 100, 100, 100, 240, 240,
       240, 240, 240, 240, 350, 350, 350, 350, 350,  70,  70,  70,  70,
       260, 260, 260, 260, 260]), site_positions=array([ 378546.,  762087., 1121394., 1252798., 1794978., 2144190.,
       3137811., 3220575., 3517468., 3575760., 3578269., 3811182.,
       3943932., 4080679., 4477739., 5174622., 5390107., 5823286.,
       6892356., 6905343., 7247899., 7462474., 7913459., 7974395.,
       8187035., 8338579., 8988312., 9046731., 9588017., 9763050.,
       9790388.]), metadata=None, is_valid=True)

In [134]:
# Random stand-in for a forward-probability matrix: one uniform value per
# (node, site), seeded for reproducibility.
np.random.seed(1234)
fwd_prob_mat_ar = np.random.uniform(size=(ts.num_nodes, ts.num_sites))

# Long ("tidy") layout in node-major order: row k holds the k-th element of
# the flattened (num_nodes, num_sites) probability array.
node_column = np.repeat(np.arange(ts.num_nodes), ts.num_sites)
site_column = np.tile(np.arange(ts.num_sites), ts.num_nodes)
position_column = np.tile(ts.sites_position, ts.num_nodes)

fwd_prob_mat_df = pd.DataFrame({
    'node_id' : node_column,
    'site_id' : site_column,
    'site_pos': position_column,
    'prob'    : fwd_prob_mat_ar.flatten(),
})
fwd_prob_mat_df


Unnamed: 0,node_id,site_id,site_pos,prob
0,0,0,378546.0,0.191519
1,0,1,762087.0,0.622109
2,0,2,1121394.0,0.437728
3,0,3,1252798.0,0.785359
4,0,4,1794978.0,0.779976
...,...,...,...,...
12488,402,26,8988312.0,0.659749
12489,402,27,9046731.0,0.648086
12490,402,28,9588017.0,0.148121
12491,402,29,9763050.0,0.119325


In [143]:
def get_data(interval, sample_path=None):
    """Return the part of a copying path covered by a site-index interval.

    Args:
        interval: ``(start, stop)`` site indices, half-open like Python
            slicing. Bounds may be floats (e.g. the value of a Bokeh
            ``RangeSlider``); they are rounded to the nearest integer.
            ``None`` leaves that side of the slice open.
        sample_path: Object with ``nodes`` and ``site_positions`` arrays.
            Defaults to the notebook-level ``path`` for backward
            compatibility.

    Returns:
        Tuple ``(nodes, site_positions)`` restricted to the interval.
    """
    if sample_path is None:
        sample_path = path  # notebook-global SamplePath

    def _as_index(bound):
        # numpy rejects float slice bounds, so normalise them to int.
        return None if bound is None else int(round(bound))

    start, stop = _as_index(interval[0]), _as_index(interval[1])
    return (sample_path.nodes[start:stop],
            sample_path.site_positions[start:stop])


def _path_columns(path, is_sample, start=None, stop=None):
    """Slice a copying path into column dicts for Bokeh ColumnDataSources.

    Args:
        path: Object with ``nodes`` and ``site_positions`` arrays of equal
            length that supports ``len()`` (e.g. ``util.SamplePath``).
        is_sample: Boolean array aligned with ``path.nodes``; True where the
            node is a sample node.
        start, stop: Site-index slice bounds; ``None`` means open-ended.

    Returns:
        Tuple ``(all_cols, sample_cols, nonsample_cols)`` of dicts, each with
        keys ``node_id``, ``site_id`` and ``site_pos``.
    """
    window = slice(start, stop)
    node_ids = path.nodes[window]
    site_ids = np.arange(len(path))[window]
    site_pos = path.site_positions[window]
    mask = is_sample[window]

    def columns(selector):
        return dict(
            node_id=node_ids[selector],
            site_id=site_ids[selector],
            site_pos=site_pos[selector],
        )

    all_cols = dict(node_id=node_ids, site_id=site_ids, site_pos=site_pos)
    return all_cols, columns(mask), columns(~mask)


def create_step_chart_app(path, ts, markers, tracks, legend_labels, colors, matrix, controls):
    """Build a Bokeh server ``Application`` visualising a copying path.

    The document stacks two panels that share one x axis (site index):

    * top: one row of vertical bars per entry of ``tracks`` (coloured by
      ``colors``, labelled by ``legend_labels``; first track at the bottom);
    * bottom: a heatmap of ``matrix['prob']`` over (site_id, node_id), the
      copying path as a red step line, and its nodes marked as black squares
      (sample nodes) or orange circles (non-sample nodes).

    Args:
        path: Object with ``nodes`` and ``site_positions`` arrays that
            supports ``len()`` (e.g. ``util.SamplePath``).
        ts: Tree sequence; ``num_nodes`` and ``samples()`` are used.
        markers: Currently unused (reserved for annotation markers).
        tracks: Sequence of arrays of site indices, one per track.
        legend_labels: Legend label per track.
        colors: Bar colour per track.
        matrix: DataFrame with ``node_id``, ``site_id``, ``site_pos`` and
            ``prob`` columns.
        controls: Mapping of name -> Bokeh widget. Only an ``"interval"``
            widget whose value is a ``(start, stop)`` site-index pair is
            understood; any control change redraws the path for that range.

    Returns:
        ``bokeh.application.Application`` wrapping a ``FunctionHandler``.
    """
    # Classify path nodes via the tree sequence's sample list rather than
    # casting raw node flags to bool (flags may carry other bits).
    is_sample = np.isin(path.nodes, ts.samples())
    num_sites = len(path)

    def modify_doc(doc):
        all_cols, sample_cols, nonsample_cols = _path_columns(path, is_sample)
        source = ColumnDataSource(data=all_cols)
        source_sample = ColumnDataSource(data=sample_cols)
        source_nonsample = ColumnDataSource(data=nonsample_cols)
        source_matrix = ColumnDataSource(data=dict(
            node_id=matrix['node_id'].values,
            site_id=matrix['site_id'].values,
            site_pos=matrix['site_pos'].values,
            prob=matrix['prob'].values,
        ))

        # TODO: Add more info about the parent nodes.
        TOOLTIPS = [
            ("Parent node id", "@node_id"),
            ("Site id", "@site_id"),
            ("Site position", "@site_pos"),
        ]

        # --- Main panel: probability heatmap + copying path ---------------
        p1 = figure(
            height=400, width=800,
            x_axis_label='Site id',
            y_axis_label='Parent node id',
            tooltips=TOOLTIPS,
        )
        p1.y_range = Range1d(
            0, ts.num_nodes,
            bounds=(0, ts.num_nodes)
        )
        # x is the site index; pad by one site on each side so the first and
        # last markers are not clipped.
        site_id_offset = 1
        p1.x_range = Range1d(
            -site_id_offset, num_sites + site_id_offset,
            bounds=(-site_id_offset, num_sites + site_id_offset),
        )
        p1.xaxis.axis_label_text_font_style = 'normal'
        p1.yaxis.axis_label_text_font_style = 'normal'
        p1.xaxis.axis_label_text_font_size = '14pt'
        p1.yaxis.axis_label_text_font_size = '14pt'
        # Site ids are integers, so show them without decimal places.
        p1.xaxis.formatter = NumeralTickFormatter(format="0a")
        p1.grid.visible = False

        # Probability matrix as a heatmap.
        p1.rect(
            x='site_id',
            y='node_id',
            source=source_matrix,
            # NOTE(review): height=5 spans five node rows per cell, so
            # neighbouring rows overlap and blend — confirm this is intended.
            width=1, height=5,
            fill_color=linear_cmap(
                'prob',
                palette=brewer['Purples'][9],
                low=0,
                high=1,
            ),
            line_color=None,
            fill_alpha=0.25,
        )

        r1 = p1.step(
            x='site_id',
            y='node_id',
            source=source,
            line_width=2, line_color='red', mode='after',
        )
        r2 = p1.square(
            x='site_id',
            y='node_id',
            source=source_sample,
            fill_color='black', size=8, line_width=0,
        )
        r3 = p1.circle(
            x='site_id',
            y='node_id',
            source=source_nonsample,
            fill_color='orange', size=8, line_width=0,
        )
        legend1 = Legend(items=[
            ('Copying path', [r1]),
            ('Sample nodes', [r2]),
            ('Non-sample nodes', [r3]),
        ], location='center')
        p1.add_layout(legend1, 'right')

        # TODO: Show `markers` as annotation lines (e.g. vbar/ray).

        # --- Tracks panel: one row of bars per track ----------------------
        p2 = figure(
            height=100, width=800,
            x_axis_label='', y_axis_label='',
            x_range=p1.x_range,
            y_range=Range1d(0, len(tracks), bounds=(0, len(tracks))),
        )
        for row, (track, color) in enumerate(zip(tracks, colors)):
            p2.vbar(
                x=track,
                top=np.repeat(row + 1, len(track)),
                bottom=np.repeat(row, len(track)),
                color=color,
                width=0.5,
            )
        p2.xaxis.visible = False
        p2.yaxis.major_tick_line_color = None
        p2.yaxis.minor_tick_line_color = None
        p2.yaxis.major_label_text_color = None
        p2.grid.visible = False
        # Reverse so the legend lists tracks in the same top-to-bottom order
        # as the rows are drawn.
        renderers = [(label, [r]) for label, r in zip(legend_labels, p2.renderers)]
        renderers.reverse()
        legend2 = Legend(items=renderers, location='center')
        legend2.click_policy = 'mute'
        p2.add_layout(legend2, 'right')

        # --- Interactivity -------------------------------------------------
        # Redraw the path (line and both marker sets) for the selected range.
        # NOTE(review): the widgets in `controls` are created outside this
        # document and are not added to the layout below, so they are not
        # rendered here; and because they are shared, every new session
        # registers another callback on the same widgets. Confirm before
        # enabling the controls column.
        def update(attr, old, new):
            if "interval" in controls:
                start, stop = (int(round(bound)) for bound in controls["interval"].value)
            else:
                start, stop = 0, num_sites
            all_cols, sample_cols, nonsample_cols = _path_columns(
                path, is_sample, start, stop)
            source.data = all_cols
            source_sample.data = sample_cols
            source_nonsample.data = nonsample_cols

        for ctrl in controls.values():
            ctrl.on_change('value', update)

        doc.add_root(gridplot([
            # [column(*controls.values())],  # disabled: see NOTE above
            [p2],
            [p1],
        ]))

    return Application(FunctionHandler(modify_doc))


In [145]:
# Widgets passed to the app; the app understands an "interval" control whose
# value is a (start, stop) pair of site indices.
slider = RangeSlider(
    start=0, end=len(path), step=1,
    value=(0, len(path)),
    title="Genomic interval",
)
controls = {"interval": slider}

# Annotation tracks drawn above the main panel, one row per entry (the first
# entry is the bottom row). For now every track marks every site.
track_colors = {
    'BEAGLE': 'blue',
    'tskit': 'orange',
    'truth': 'green',
    'chip': 'grey',
}

app = create_step_chart_app(
    path,
    ts,
    markers=path.site_positions[::2],
    tracks=[np.arange(len(path)) for _ in track_colors],
    legend_labels=list(track_colors.keys()),
    colors=list(track_colors.values()),
    matrix=fwd_prob_mat_df,
    controls=controls,
)
show(app)
