In [None]:
from Bio import SeqIO
from bokeh.io import output_notebook, push_notebook
from bokeh.layouts import column, gridplot
from bokeh.plotting import figure, show
from bokeh.models.tools import BoxZoomTool
from bokeh.models.tickers import FixedTicker
from bokeh.models import Range1d, LinearAxis, BoxAnnotation
from bokeh.models.callbacks import CustomJS
from bokeh.resources import INLINE, CDN
import numpy as np
from util import grouper

In [None]:
# Configure Bokeh so that `show()` renders into notebook output cells, using
# the INLINE resources object instead of the (also imported) CDN one --
# presumably so BokehJS is embedded rather than fetched from the network.
output_notebook(resources=INLINE)

In [None]:
def read_ab1(ab1_file):
    """Parse one ABI trace file into the plain dict used by the plotting code.

    Parameters
    ----------
    ab1_file
        Path or handle passed straight to ``SeqIO.read(..., "abi")``.

    Returns
    -------
    dict
        ``name``       -- the record id,
        ``sequence``   -- the base calls as a ``str``,
        ``peaks``      -- mapping of base letter -> trace channel, built by
                          pairing the ``FWO_1`` base order with the
                          ``DATA9``..``DATA12`` raw tags,
        ``peak_calls`` -- the ``PLOC2`` raw tag,
        ``quality``    -- the ``phred_quality`` letter annotation.
    """
    record = SeqIO.read(ab1_file, "abi")
    raw = record.annotations["abif_raw"]

    # FWO_1 holds the base order of the four trace channels. Depending on the
    # Biopython / Python version it may arrive as bytes; iterating bytes
    # yields ints, which would produce integer keys and break lookups such as
    # peaks["T"]. Decode so the keys are always one-letter strings.
    base_order = raw["FWO_1"]
    if isinstance(base_order, (bytes, bytearray)):
        base_order = base_order.decode("ascii")

    # Extract color data: one trace channel per base, in FWO_1 order.
    peaks = {
        base: raw["DATA{}".format(channel)]
        for base, channel in zip(base_order, (9, 10, 11, 12))
    }

    return {
        "name": record.id,
        "sequence": str(record.seq),
        "peaks": peaks,
        "peak_calls": raw["PLOC2"],
        "quality": record.letter_annotations["phred_quality"],
    }

In [None]:
def show_chromatogram(
    ab1_files, margin=10, max_interval=1000, plot_width=900, base_width=10
):
    """Display an interactive chromatogram viewer for several ABI trace files.

    Two stacked grids are shown:

    * an *overview*: the base calls of every file written as text, wrapped
      into rows of ``plot_width // base_width`` bases, with a highlight box
      on each row;
    * a *detail* view: for every file one plot with the four trace channels
      over per-base quality bars, followed by a thin plot with the called
      base letters at their peak positions. All detail plots share one
      pannable x-range.

    Parameters
    ----------
    ab1_files : iterable
        Paths/handles accepted by ``read_ab1``. Must not be empty.
    margin : int
        Padding added on both sides of the first file's peak-call extent to
        form the bounds of the shared detail x-range.
    max_interval : int
        Initial width, and maximum width, of the shared detail x-range.
    plot_width : int
        Width passed to every figure; together with ``base_width`` it sets
        the number of bases per overview row.
    base_width : int
        Divisor of ``plot_width`` that gives the bases per overview row; also
        used as padding on both ends of the overview x-range.

    Returns
    -------
    None
        The layout is rendered with ``show``.

    Raises
    ------
    ValueError
        If ``ab1_files`` is empty.
    """
    ab1_files = list(ab1_files)
    if not ab1_files:
        # Previously this surfaced as an opaque "max() arg is an empty
        # sequence" ValueError a few lines further down.
        raise ValueError("show_chromatogram needs at least one ab1 file")
    ab1s = [read_ab1(ab1_file) for ab1_file in ab1_files]

    # ---- dimensions -------------------------------------------------------
    max_length = max(len(ab1["sequence"]) for ab1 in ab1s)
    bases_per_row = plot_width // base_width
    # Pad shorter reads with "-" so every read wraps into the same number of
    # overview rows.
    padded_sequences = [ab1["sequence"].ljust(max_length, "-") for ab1 in ab1s]
    grouped_sequences = [
        list(grouper(padded, bases_per_row)) for padded in padded_sequences
    ]
    num_rows = max(len(rows) for rows in grouped_sequences)
    num_seqs = len(ab1s)
    line_pitch = 17  # vertical distance between text lines, in overview data units

    # ---- overview: one figure per wrapped row -----------------------------
    overview_plots = []
    overview_boxes = []
    overview_x_range = Range1d(-base_width, bases_per_row + base_width)
    overview_y_range = Range1d(0, line_pitch * (num_seqs + 1))
    for row in range(num_rows):
        overview = figure(
            plot_width=plot_width,
            plot_height=80,
            tools="tap,save",
            x_range=overview_x_range,
            y_range=overview_y_range,
        )
        overview.axis.visible = False
        overview.min_border = 0
        overview.min_border_bottom = 0
        overview.grid.visible = False
        overview.outline_line_color = None
        for seq_index, rows in enumerate(grouped_sequences):
            chunk = rows[row]
            overview.text(
                x=list(range(len(chunk))),
                y=line_pitch * seq_index,
                text=list(chunk),
                text_align="center",
                text_baseline="bottom",
                text_font="Courier",
                text_font_size="12px",
            )
        box = BoxAnnotation(left=10, right=40, fill_color="blue", fill_alpha=0.5)
        overview.add_layout(box)
        overview_plots.append(overview)
        overview_boxes.append(box)

    # ---- detail: traces + base letters per file ---------------------------
    # The shared x-range is derived from the first file's peak calls only.
    first_calls = ab1s[0]["peak_calls"]
    x_min = np.min(first_calls)
    x_max = np.max(first_calls)
    detail_x_range = Range1d(
        x_min,
        x_min + max_interval,
        bounds=(x_min - margin, x_max + margin),
        max_interval=max_interval,
    )
    trace_colors = (("red", "T"), ("green", "A"), ("blue", "C"), ("black", "G"))

    detail_plots = []
    for ab1 in ab1s:
        peaks = ab1["peaks"]
        peak_calls = ab1["peak_calls"]
        quality = ab1["quality"]

        traces = figure(
            plot_width=plot_width,
            plot_height=50,
            tools=["xpan"],
            active_drag="xpan",
            x_range=detail_x_range,
        )
        # Quality bars are drawn against their own 0-100 axis on the right.
        traces.extra_y_ranges = {"q": Range1d(start=0, end=100)}
        traces.add_layout(LinearAxis(y_range_name="q"), "right")
        traces.rect(
            x=peak_calls,
            y=np.array(quality) / 2,  # rect is centre-anchored: centre at half height
            width=14,
            height=quality,
            y_range_name="q",
            color="#e7ecf6",
        )
        for color, base in trace_colors:
            channel = peaks[base]
            traces.line(list(range(len(channel))), channel, line_color=color)

        letters = figure(
            plot_width=plot_width, plot_height=30, x_range=detail_x_range
        )
        # Use the *unpadded* base calls here: the "-"-padded string is longer
        # than peak_calls for every read shorter than the longest one, which
        # hands Bokeh data columns of different lengths.
        letters.text(
            x=peak_calls,
            y=0.1,
            text=list(ab1["sequence"]),
            text_align="center",
            text_baseline="middle",
        )

        for plot in (traces, letters):
            plot.min_border = 0
            plot.grid.visible = False
            plot.outline_line_color = None
            detail_plots.append(plot)

    # Only the very last detail plot keeps an x axis; give it room for labels.
    for plot in detail_plots:
        plot.xaxis.visible = False
        plot.yaxis.visible = False
    detail_plots[-1].height = 50
    detail_plots[-1].xaxis.visible = True

    # ---- tap-to-navigate callback -----------------------------------------
    # One callback object is shared by every overview row (its arguments do
    # not depend on the row). The loop variable is declared with `var`: an
    # undeclared `i` throws a ReferenceError when the code runs in strict
    # mode, which newer BokehJS versions apply to CustomJS code.
    # NOTE(review): the tapped x is in overview units (base index within a
    # row) but is written unchanged into the detail x-range, which is built
    # from peak-call positions -- this looks like a unit mismatch; confirm
    # the intended mapping before relying on tap navigation.
    tap_callback = CustomJS(
        args=dict(boxes=overview_boxes, plot=detail_plots[0]),
        code="""
            var tapped_x = cb_obj.x;
            for (var i = 0; i < boxes.length; i++) {
                boxes[i].left = tapped_x - 10;
                boxes[i].right = tapped_x + 10;
            }
            plot.x_range.start = tapped_x - 10;
            plot.x_range.end = tapped_x + 10;
        """,
    )
    for overview in overview_plots:
        overview.js_on_event("tap", tap_callback)

    show(
        column(
            gridplot([[plot] for plot in overview_plots]),
            gridplot([[plot] for plot in detail_plots]),
        )
    )


# Example run: render the viewer for four ITS4 traces. The paths are relative,
# so the `testseq/` directory must be reachable from the notebook's working
# directory.
show_chromatogram(
    [
        "testseq/242a7-ITS4.ab1",
        "testseq/243a7-ITS4.ab1",
        "testseq/252a7-ITS4.ab1",
        "testseq/253a7-ITS4.ab1",
    ]
)