# AB1 Chromatogram Viewer (Jupyter)

Interactive chromatogram viewer for **Sanger `.ab1`** files:

- Reads an `.ab1` with **Biopython**
- Automatically maps **ABI channels (DATA9–DATA12)** to **A/C/G/T**
- Displays an interactive chromatogram (scroll by bases), base calls, and per-base quality bars
- Hover:
  - Over a peak trace → shows which base trace (A/C/G/T) and intensity
  - Over a base letter → shows base and its Phred quality
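- Click a base letter → toggles it to **`N`** (and back to the original base) — *planned: the viewer cell currently renders the figure with `fig.show()` and registers no click handler, so this is not wired up yet*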
- Buttons:
  - **TRIM** using a Chromas-like sliding-window mean (`trim_by_quality`)
  - **MASK** low-quality bases to `N` (`mask_low_quality`)
  - Export **trimmed+edited** as **FASTA** or **FASTQ** (with qualities)

**Minimal dependencies:** `biopython`, `numpy`, `plotly`, `ipywidgets`.


In [None]:
import sys
# Colab only: install the viewer's dependencies. `!pip` is an IPython shell escape,
# so this cell only runs inside a notebook kernel, not as a plain Python script.
if 'google.colab' in sys.modules:
    !pip install biopython plotly ipywidgets

In [2]:
# Cell 1 — Load an AB1 file and extract traces, basecalls, peak positions, and qualities

import sys

import numpy as np
from Bio import SeqIO

# -----------------------------
# INPUT: set your AB1 path here
# -----------------------------
if 'google.colab' in sys.modules:
    from google.colab import files
    uploaded = files.upload()
    if not uploaded:
        raise ValueError("No file was uploaded.")
    # If you upload 'sample1.ab1', the path is simply the file name.
    AB1_PATH = next(iter(uploaded))
else:
    AB1_PATH = "sample1.ab1"

rec = SeqIO.read(AB1_PATH, "abi")
raw = rec.annotations["abif_raw"]

# Basecalls as produced by the basecaller (PBAS2).
pbas = raw.get("PBAS2")
if pbas is None:
    raise ValueError(f"{AB1_PATH}: no PBAS2 (basecalls) tag found.")
if isinstance(pbas, bytes):
    pbas = pbas.decode(errors="ignore")
seq = str(pbas)

# Peak locations (trace indices), one per basecall. PLOC2 is the basecaller's
# array and pairs with PBAS2; PLOC1 holds user-edited locations and is only a
# fallback for files that lack PLOC2.
ploc = raw.get("PLOC2")
if ploc is None:
    ploc = raw.get("PLOC1")
if ploc is None:
    raise ValueError(f"{AB1_PATH}: no PLOC2/PLOC1 (peak locations) tag found.")
ploc = [int(p) for p in ploc]
if len(ploc) != len(seq):
    # Every later cell indexes ploc by basecall position, so the two must align.
    raise ValueError(
        f"{AB1_PATH}: {len(ploc)} peak locations for {len(seq)} basecalls."
    )

# Qualities: prefer Biopython's phred_quality; fall back to PCON2, and finally
# to "unknown" (None) so that quals always has one entry per basecall.
quals = rec.letter_annotations.get("phred_quality", None)
if quals is None or len(quals) != len(seq):
    pcon = raw.get("PCON2")
    if pcon is not None and len(pcon) == len(seq):
        # bytes iterate as ints; a str would yield 1-char strings, so ord() those.
        quals = [q if isinstance(q, int) else ord(q) for q in pcon]
    else:
        quals = [None] * len(seq)
quals = list(quals)

# Raw channels: typically DATA9..DATA12 exist
data_keys = ["DATA9", "DATA10", "DATA11", "DATA12"]
channels = {k: np.array(raw[k], dtype=float) for k in data_keys if k in raw}

if len(channels) != 4:
    raise ValueError(f"Expected 4 trace channels (DATA9..DATA12). Found: {list(channels.keys())}")

trace_len = len(next(iter(channels.values())))

print("AB1 loaded:", AB1_PATH)
print("Basecalls length:", len(seq))
print("Trace length:", trace_len)
print("Available channels:", list(channels.keys()))


Saving sample1.ab1 to sample1.ab1
AB1 loaded: sample1.ab1
Basecalls length: 570
Trace length: 12648
Available channels: ['DATA9', 'DATA10', 'DATA11', 'DATA12']


In [3]:
# Cell 2 — Automatically assign DATA channels to A/C/G/T and build trace_by_base

import numpy as np

def auto_assign_channels_to_bases(seq, ploc, channels, bases=("A","C","G","T"), sample_cap=5000):
    """Infer which ABI trace channel corresponds to each base (A/C/G/T by default).

    Strategy:
      - At each basecall peak position (``ploc``), measure every channel's intensity.
      - For the positions called as each base, average the intensity per channel.
      - Greedily build a one-to-one base->channel mapping, taking the highest
        remaining mean intensity first.

    Args:
        seq: Basecalled sequence (string).
        ploc: Peak locations (trace indices), one per basecall.
        channels: Dict like {"DATA9": np.array([...]), ...}; all arrays are
            assumed to share the length of the first one.
        bases: Bases to map (default A, C, G, T). Every one of them must end up
            with its own channel.
        sample_cap: Maximum number of basecalls to use for mapping (speed/stability).

    Returns:
        base_to_chan: dict base -> channel key (e.g. {"A": "DATA10", ...})
        chan_to_base: dict channel key -> base
        scores: numpy array of shape (len(bases), len(channels)) with the mean
            intensities [base, channel]; -1e9 marks a base with no usable calls.
        chan_keys: list of channel keys in column order for ``scores``

    Raises:
        ValueError: if ``channels`` is empty, no basecall belongs to ``bases``,
            or not every base could be given a distinct channel.
    """
    if not channels:
        raise ValueError("No trace channels given to infer base mapping.")

    idx = [(i, b) for i, b in enumerate(seq) if b in bases]
    if not idx:
        raise ValueError("No A/C/G/T basecalls found to infer channels.")
    idx = idx[:sample_cap]

    chan_keys = list(channels.keys())
    # Hoisted: every channel is sampled at the same peak indices.
    n_trace = len(next(iter(channels.values())))
    scores = np.zeros((len(bases), len(chan_keys)), dtype=float)

    for bi, base in enumerate(bases):
        pos_list = [ploc[i] for i, b in idx if b == base and 0 <= ploc[i] < n_trace]
        if not pos_list:
            # Base never called (or all peaks out of range): rank it last.
            scores[bi, :] = -1e9
            continue
        pos_arr = np.array(pos_list, dtype=int)
        for cj, ck in enumerate(chan_keys):
            scores[bi, cj] = float(np.mean(channels[ck][pos_arr]))

    # Greedy one-to-one assignment, highest mean intensity first.
    remaining_bases = set(range(len(bases)))
    remaining_chans = set(range(len(chan_keys)))
    base_to_chan = {}
    chan_to_base = {}

    pairs = [(scores[bi, cj], bi, cj) for bi in range(len(bases)) for cj in range(len(chan_keys))]
    pairs.sort(reverse=True, key=lambda x: x[0])

    for _score, bi, cj in pairs:
        if bi in remaining_bases and cj in remaining_chans:
            base_to_chan[bases[bi]] = chan_keys[cj]
            chan_to_base[chan_keys[cj]] = bases[bi]
            remaining_bases.remove(bi)
            remaining_chans.remove(cj)
        if not remaining_bases or not remaining_chans:
            break

    if len(base_to_chan) < len(bases):
        raise ValueError(f"Could not assign all {len(bases)} bases to channels confidently.")

    return base_to_chan, chan_to_base, scores, chan_keys

base_to_chan, chan_to_base, score_matrix, chan_keys = auto_assign_channels_to_bases(
    seq, ploc, channels
)

print("Auto channel assignment (base -> DATA key):")
print("\n".join(f"  {b} -> {base_to_chan[b]}" for b in "ACGT"))

# Base-labelled traces (A/C/G/T) that the viewer plots, one array per base.
trace_by_base = {}
for b in "ACGT":
    trace_by_base[b] = channels[base_to_chan[b]]


Auto channel assignment (base -> DATA key):
  A -> DATA10
  C -> DATA12
  G -> DATA9
  T -> DATA11


In [11]:
# Cell 3 — Interactive chromatogram viewer (Colab, full width, trim shading, low-Q in red; click-to-edit is not implemented yet)

import sys

# The custom widget manager is a Colab-only requirement for ipywidgets output;
# guard it like Cell 1 does so the viewer also runs in plain Jupyter.
if 'google.colab' in sys.modules:
    from google.colab import output
    output.enable_custom_widget_manager()

import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display
from pathlib import Path
from typing import List
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
import numpy as np

# =============================
# Quality utilities
# =============================
def mask_low_quality(seq: str, quals: List[int], min_q: int) -> str:
    """Return *seq* with every base whose quality is missing or below *min_q* set to 'N'.

    Bases and qualities are paired position by position (``zip``), so the result
    is as long as the shorter of the two inputs.
    """
    def _passes(q) -> bool:
        # A missing quality (None) never passes the threshold.
        return q is not None and q >= min_q

    return "".join(base if _passes(q) else "N" for base, q in zip(seq, quals))

def trim_by_quality(seq: str, quals: List[int], min_q: int, window: int, min_len: int):
    """Trim low-quality ends using a sliding-window mean of Phred scores.

    The kept region runs from the first to the last window of ``window`` bases
    whose mean quality is >= ``min_q``. Its ends are then shrunk until both
    boundary bases individually reach ``min_q``. Missing qualities (None) count
    as 0 in the window means. For sequences shorter than ``window``, the kept
    span runs from the first to the last base with quality >= ``min_q``.

    Args:
        seq: Sequence to trim.
        quals: One Phred score (or None) per base of ``seq``.
        min_q: Quality threshold for windows and boundary bases.
        window: Sliding-window length in bases (must be >= 1).
        min_len: Minimum length of the kept region.

    Returns:
        ``(trimmed_seq, trimmed_quals, left, right)`` where ``[left, right)`` is
        the kept half-open range in ``seq``. The result is ``("", [], 0, 0)`` in
        any of these cases: the input is empty, the lengths do not match, or no
        region of at least ``min_len`` bases passes.

    Raises:
        ValueError: if ``window`` < 1 (the window mean would be undefined).
    """
    if not seq or not quals or len(seq) != len(quals):
        return "", [], 0, 0
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    n = len(seq)
    if n < window:
        # Too short for a single window: fall back to per-base thresholds.
        good = [i for i, q in enumerate(quals) if q is not None and q >= min_q]
        if not good:
            return "", [], 0, 0
        left, right = good[0], good[-1] + 1
        if right - left < min_len:
            return "", [], 0, 0
        return seq[left:right], quals[left:right], left, right

    # Prefix sums so each window mean costs O(1).
    ps = [0]
    for q in quals:
        ps.append(ps[-1] + (q if q is not None else 0))

    def win_mean(i: int) -> float:
        return (ps[i + window] - ps[i]) / window

    good_windows = [i for i in range(n - window + 1) if win_mean(i) >= min_q]
    if not good_windows:
        return "", [], 0, 0

    left = good_windows[0]
    right = good_windows[-1] + window

    # A passing window can still start/end on individually poor bases; drop them.
    while left < right and (quals[left] is None or quals[left] < min_q):
        left += 1
    while right > left and (quals[right - 1] is None or quals[right - 1] < min_q):
        right -= 1

    if right - left < min_len:
        return "", [], 0, 0

    return seq[left:right], quals[left:right], left, right

def export_trimmed_edited(out_path: str, record_id: str, seq_str: str, qual_list: List[int], fmt: str):
    """Write a single sequence record to *out_path* as FASTA or FASTQ.

    ``fmt`` is matched case-insensitively and handed to ``SeqIO.write``. For
    FASTQ, the qualities are attached as ``phred_quality``, with None written as 0.

    Raises:
        ValueError: FASTQ requested but ``qual_list`` and ``seq_str`` differ in length.
    """
    fmt_key = fmt.lower()
    record = SeqRecord(Seq(seq_str), id=record_id, description="trimmed+edited from AB1 viewer")
    if fmt_key == "fastq":
        if len(qual_list) != len(seq_str):
            raise ValueError("FASTQ export requires qualities with the same length as the sequence.")
        record.letter_annotations["phred_quality"] = [0 if q is None else int(q) for q in qual_list]
    SeqIO.write([record], out_path, fmt_key)

# =============================
# Viewer state
# =============================
# Pristine copies of the loaded read; MASK edits go to edited_seq (one mutable char per base).
orig_seq = seq
orig_quals = list(quals)
edited_seq = [*orig_seq]

# Absolute base range the view sliders navigate within (the whole read).
view_left, view_right = 0, len(orig_seq)

# Half-open [trim_left, trim_right) range used for shading and export; starts untrimmed.
trim_left, trim_right = 0, len(orig_seq)

# Absolute base indices drawn by the most recent render() call.
current_abs_base_indices = []

# =============================
# Widgets
# =============================
# Navigation sliders. "Start base" and "Bases/window" count bases relative to
# view_left; "Pad" is extra trace samples added on each side of the window
# (see get_window_by_bases_in_view).
# NOTE(review): max becomes -1 for an empty read — confirm ipywidgets tolerates that.
start_slider = widgets.IntSlider(value=0, min=0, max=len(orig_seq)-1, step=1, description="Start base")
win_size     = widgets.IntSlider(value=120, min=30, max=400, step=10, description="Bases/window")
pad_slider   = widgets.IntSlider(value=40, min=0, max=200, step=10, description="Pad")

# MASK: render() draws bases with Q below this threshold in red; on_mask replaces
# them with 'N' when the button is pressed.
qmask_slider       = widgets.IntSlider(value=20, min=0, max=40, step=1, description="Low-Q <")
btn_mask           = widgets.Button(description="Apply MASK (N)", button_style="warning")

# TRIM: min_q / window / min_len arguments passed to trim_by_quality() by on_trim.
trim_minq_slider   = widgets.IntSlider(value=20, min=0, max=40, step=1, description="Trim minQ")
trim_window_slider = widgets.IntSlider(value=10, min=3, max=60, step=1, description="Trim window",style={'description_width': '100px'})
trim_minlen_slider = widgets.IntSlider(value=80, min=20, max=2000, step=10, description="Trim minLen")
btn_trim           = widgets.Button(description="Apply TRIM", button_style="info")

# Restores the original basecalls and clears the trim range (on_reset).
btn_reset          = widgets.Button(description="Reset edits", button_style="")

# Export: on_export writes viewer_out/<Basename>.<fasta|fastq>; the folder is
# created up front so exports can write into it.
out_dir = Path("viewer_out")
out_dir.mkdir(exist_ok=True)
export_name        = widgets.Text(value="trimmed_edited", description="Basename")
btn_export_fasta   = widgets.Button(description="Export FASTA", button_style="success")
btn_export_fastq   = widgets.Button(description="Export FASTQ", button_style="success")

# status: HTML message line; out_plot: container the figure is re-shown in by render().
status   = widgets.HTML()
out_plot = widgets.Output(layout=widgets.Layout(width="100%"))

# =============================
# Plot (go.Figure)
# =============================
# One shared figure that render() clears and rebuilds; the Phred axis (y2)
# overlays the intensity axis on the right-hand side.
fig = go.Figure()
fig.update_layout(
    template="plotly_white",
    autosize=False,
    width=1300,
    height=560,
    margin={"l": 40, "r": 20, "t": 50, "b": 40},
    hovermode="closest",
    legend={"orientation": "h"},
    yaxis={"title": "Intensity"},
    yaxis2={
        "title": "Phred Q",
        "overlaying": "y",
        "side": "right",
        "range": [0, 45],
        "showgrid": False,
    },
)

def get_window_by_bases_in_view(base_start_view: int, base_end_view: int, pad: int):
    """Map a base window (relative to the current view) to a trace-sample window.

    Args:
        base_start_view: First base of the window, as an offset from ``view_left``.
        base_end_view: One past the last base of the window (exclusive), same offset.
        pad: Extra trace samples to include on each side of the outer peaks.

    Returns:
        ``(abs_start, abs_end, x0, x1)``: the absolute half-open base range
        ``[abs_start, abs_end)`` and the half-open trace-sample range ``[x0, x1)``
        covering its peak locations plus padding, clipped to ``[0, trace_len)``.
    """
    view_len = view_right - view_left
    # Clamp to a non-empty base range inside the view.
    base_start_view = max(0, min(int(base_start_view), max(0, view_len - 1)))
    base_end_view = max(base_start_view + 1, min(int(base_end_view), view_len))

    abs_start = view_left + base_start_view
    abs_end = view_left + base_end_view

    x0 = max(0, int(ploc[abs_start]) - pad)
    # x1 is exclusive: the +1 keeps the last peak itself inside [x0, x1) even
    # when pad == 0 (otherwise that base is dropped from the window).
    x1 = min(trace_len, int(ploc[abs_end - 1]) + pad + 1)
    return abs_start, abs_end, x0, x1

def _shade_trimmed_bases(first_base: int, last_base: int, x0: int, x1: int) -> None:
    """Shade absolute bases first_base..last_base (inclusive) as trimmed.

    The rectangle spans the two bases' peak positions, clipped to the visible
    trace window ``[x0, x1]``. Nothing is drawn when the base range is empty.
    """
    if first_base > last_base:
        return
    fig.add_shape(
        type="rect", xref="x", yref="paper",
        x0=max(x0, int(ploc[first_base])),
        x1=min(x1, int(ploc[last_base])),
        y0=0, y1=1,
        fillcolor="rgba(255,255,0,0.25)",
        line=dict(width=0),
        layer="below",
    )


def render():
    """Rebuild traces, shapes and axis ticks for the current window and redraw.

    Reads the slider values and the module-level viewer state (view/trim ranges,
    ``edited_seq``, ``orig_quals``). Updates ``current_abs_base_indices`` and
    ``status``, then re-displays ``fig`` inside ``out_plot``. When there is
    nothing to draw, the plot area is cleared instead of leaving a stale figure.
    """
    global current_abs_base_indices

    fig.data = []
    fig.layout.shapes = []

    view_len = view_right - view_left
    if view_len <= 0:
        status.value = "<b>No data to display.</b>"
        out_plot.clear_output()
        return

    base_start_view = int(start_slider.value)
    base_end_view   = min(view_len, base_start_view + int(win_size.value))

    abs_start, abs_end, x0, x1 = get_window_by_bases_in_view(
        base_start_view, base_end_view, pad=int(pad_slider.value)
    )

    x = np.arange(x0, x1)

    abs_inds = [i for i in range(abs_start, abs_end) if x0 <= ploc[i] < x1]
    current_abs_base_indices = abs_inds

    if not abs_inds:
        status.value = "<b>No peaks in this window.</b> Try increasing Pad or moving Start base."
        out_plot.clear_output()
        return

    # Chromatogram lines, one per base-labelled channel
    for b in "ACGT":
        fig.add_trace(go.Scatter(
            x=x,
            y=trace_by_base[b][x0:x1],
            mode="lines",
            name=b,
            hovertemplate="Trace: %{fullData.name}<br>Intensity: %{y}<extra></extra>",
        ))

    # TRIM shading: trimmed bases (before trim_left, from trim_right on) in this window
    if trim_left > 0:
        _shade_trimmed_bases(max(0, abs_start), min(trim_left - 1, abs_end - 1), x0, x1)
    if trim_right < len(orig_seq):
        _shade_trimmed_bases(max(trim_right, abs_start), min(len(orig_seq) - 1, abs_end - 1), x0, x1)

    # Base calls
    low_q = int(qmask_slider.value)
    window_quals = [orig_quals[i] for i in abs_inds]
    bases_txt = [edited_seq[i] for i in abs_inds]
    locs = [ploc[i] for i in abs_inds]
    # -1 marks "no quality available" in the hover text only.
    q_txt = [q if q is not None else -1 for q in window_quals]
    colors = ["red" if (q is not None and q < low_q) else "black" for q in window_quals]

    # Place the letters just above the tallest trace sample in the window.
    y_top = max(float(np.max(trace_by_base[b][x0:x1])) for b in "ACGT")
    y_text = y_top * 1.03 if y_top > 0 else 1.0

    custom = np.column_stack([q_txt, abs_inds])

    fig.add_trace(go.Scatter(
        x=locs,
        y=[y_text]*len(locs),
        mode="text+markers",
        marker=dict(size=10, opacity=0),
        text=bases_txt,
        textfont=dict(color=colors, size=10),
        name="Base calls",
        customdata=custom,
        hovertemplate="Base: %{text}<br>Q: %{customdata[0]}<br>Pos: %{customdata[1]:d}<extra></extra>",
    ))

    # Quality bars: a missing quality is drawn as 0, not as the -1 hover placeholder.
    fig.add_trace(go.Bar(
        x=locs,
        y=[q if q is not None else 0 for q in window_quals],
        name="Quality",
        opacity=0.5,
        yaxis="y2",
        hovertemplate="Q: %{y}<extra></extra>",
    ))

    # X-axis ticks: about n_ticks labels, showing 1-based sequence positions.
    n_ticks = 10
    step = max(1, (abs_end - abs_start) // n_ticks)
    tick_base_inds = list(range(abs_start, abs_end, step))
    tickvals = [ploc[i] for i in tick_base_inds]
    ticktext = [str(i + 1) for i in tick_base_inds]

    fig.update_layout(
        title=f"{AB1_PATH} | View abs {view_left}-{view_right-1} (len={view_right-view_left})",
        xaxis=dict(
            title="Sequence position (1-based)",
            tickmode="array",
            tickvals=tickvals,
            ticktext=ticktext,
        ),
        autosize=False,
        width=1300,
    )

    status.value = (
        f"<b>TRIM export:</b> abs {trim_left}-{trim_right-1} (len={trim_right-trim_left}) | "
        f"<b>Low-Q threshold (red):</b> {int(qmask_slider.value)}"
    )

    # Show the updated figure after render
    with out_plot:
        out_plot.clear_output(wait=True)
        fig.show(config={"responsive": True})


# =============================
# Callbacks
# =============================
def on_any_change(change):
    """Slider observer: any value change triggers a full redraw."""
    del change  # payload unused; render() re-reads every slider itself
    render()

def on_mask(_):
    """Replace every base whose quality is below the Low-Q slider with 'N', then redraw."""
    threshold = int(qmask_slider.value)
    masked = mask_low_quality("".join(edited_seq), orig_quals, threshold)
    edited_seq[:] = list(masked)
    render()

def on_trim(_):
    """Quality-trim the current (edited) read and store the kept range for shading/export."""
    global trim_left, trim_right
    kept_seq, _kept_quals, new_left, new_right = trim_by_quality(
        "".join(edited_seq),
        orig_quals,
        min_q=int(trim_minq_slider.value),
        window=int(trim_window_slider.value),
        min_len=int(trim_minlen_slider.value),
    )
    if kept_seq:
        trim_left, trim_right = new_left, new_right
        render()
    else:
        # Keep the previous trim range; just report the failure.
        status.value = "<b>TRIM failed:</b> no segment passed the trimming criteria."

def on_reset(_):
    """Undo all base edits and clear the trim range, then redraw."""
    global trim_left, trim_right
    edited_seq[:] = [base for base in orig_seq]
    trim_left = 0
    trim_right = len(orig_seq)
    render()

def on_export(fmt: str):
    """Export the trimmed + edited read to ``out_dir`` as FASTA or FASTQ.

    The exported slice is ``edited_seq[trim_left:trim_right]`` with the matching
    original qualities. Only the last path component of the Basename box is used,
    so the file always lands inside ``out_dir``. Export errors are shown in the
    status line instead of being lost inside the widget callback. On Colab the
    file is additionally offered as a browser download (best effort).

    Args:
        fmt: "fasta" or "fastq"; also used as the file extension.
    """
    s_trim = "".join(edited_seq[trim_left:trim_right])
    q_trim = orig_quals[trim_left:trim_right]
    basename = Path(export_name.value.strip()).name or "trimmed_edited"
    out = out_dir / f"{basename}.{fmt}"
    try:
        export_trimmed_edited(str(out), rec.id, s_trim, q_trim, fmt=fmt)
    except (ValueError, OSError) as err:
        status.value = f"<b>Export failed:</b> {err}"
        return
    status.value = f"<b>Exported:</b> {out}"

    try:
        from google.colab import files
    except ImportError:
        return  # not on Colab: the file written to disk is the result
    try:
        files.download(str(out))
    except Exception as err:  # browser download is a convenience; the file already exists
        status.value += f" (browser download failed: {err})"

# Wire observers
# Every navigation/threshold slider simply triggers a redraw.
for _slider in (start_slider, win_size, pad_slider, qmask_slider):
    _slider.observe(on_any_change, names="value")
del _slider

# Buttons: mask, trim, reset, and the two export formats.
btn_mask.on_click(on_mask)
btn_trim.on_click(on_trim)
btn_reset.on_click(on_reset)
btn_export_fasta.on_click(lambda _btn: on_export("fasta"))
btn_export_fastq.on_click(lambda _btn: on_export("fastq"))

# =============================
# UI + INITIALIZATION
# =============================
ui = widgets.VBox(
    children=[
        widgets.HBox(children=[start_slider, win_size, pad_slider]),
        widgets.HBox(children=[
            qmask_slider, btn_mask,
            trim_minq_slider, trim_window_slider, trim_minlen_slider, btn_trim,
        ]),
        widgets.HBox(children=[btn_reset, export_name, btn_export_fasta, btn_export_fastq]),
        status,
        out_plot,
    ],
    layout=widgets.Layout(width="100%"),
)

display(ui)

render()  # draw the first window right away instead of waiting for a slider change

VBox(children=(HBox(children=(IntSlider(value=0, description='Start base', max=569), IntSlider(value=120, desc…