# PRECISE-QC — Basecalling Confusion Matrix and Base-specific Error Rates by Type Plot

### The code was adapted from: KleistLab/nanopore_dRNAseq (https://github.com/KleistLab/nanopore_dRNAseq)

This notebook builds and visualizes a **confusion matrix** of basecalling outcomes for sgRNA reads:

- **Columns**: reference bases (**A, C, G, U**)
- **Rows**: basecalled outcomes (**A, C, G, U, deleted**)
- Matrix is **column-normalized** (per reference base), **log-transformed** for contrast, and annotated with **percentages**.

Inputs:
- `SAM` files (aligned with MD tags),
- corresponding reference `FASTA` files.

## 0) Setup

Required packages:

- `numpy`, `pandas`: array and table handling
- `matplotlib`, `seaborn`: plotting
- `pysam`: SAM/FASTA parsing
- `tqdm`: progress bar
- `collections` (standard library): `defaultdict` for motif tallies

If running on **Colab**, you may need to install them first:
```bash
!pip install numpy pandas matplotlib seaborn pysam tqdm
```


In [13]:
!pip install numpy pandas matplotlib seaborn pysam tqdm
import numpy as np
import pandas as pd  # used for the error-rate DataFrame in the last section
import matplotlib.pyplot as plt
import seaborn as sns
import pysam
from tqdm import tqdm
from collections import defaultdict



## 1) Inputs & Parameters

- **`refs`**: maps each sample key to its **reference FASTA** (opened with `pysam.Fastafile`)
- **`samfiles`**: maps each sample key to its **SAM** file path (aligned reads; include `--MD` during alignment)
- **`samples`**: list of sample keys available for analysis

Define each parameter in the cells below.


In [9]:
#@title EDIT and Run the paths to the references
ref_1 = "path/to/first_fasta_file"
ref_2 = "path/to/second_fasta_file"
ref_3 = "path/to/third_fasta_file"

In [10]:
#@title EDIT and Run the paths to SAM files
sam_1 = "path/to/first_sam_file"
sam_2 = "path/to/second_sam_file"
sam_3 = "path/to/third_sam_file"

In [None]:
#@title Run this cell to define the dictionary with references and sam files
refs = {
    "first": pysam.Fastafile(ref_1),
    "second": pysam.Fastafile(ref_2),
    "third": pysam.Fastafile(ref_3),

}
samfiles = {
    "first": sam_1,
    "second": sam_2,
    "third": sam_3
}
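
Every later cell looks samples up by key, so a mismatch between `refs`, `samfiles`, and `samples` would only surface as a `KeyError` at plotting time. The sketch below reports such mismatches up front. The helper name `check_sample_keys` is hypothetical and not part of the original notebook; it works on any dictionaries and does not open the files.

```python
def check_sample_keys(refs, samfiles, samples):
    """Return a dict of problems found; empty when all keys line up.

    Hypothetical helper, not part of the original notebook.
    """
    problems = {}
    missing_ref = sorted(set(samples) - set(refs))
    missing_sam = sorted(set(samples) - set(samfiles))
    unused = sorted((set(refs) | set(samfiles)) - set(samples))
    if missing_ref:
        problems["no reference for"] = missing_ref
    if missing_sam:
        problems["no SAM file for"] = missing_sam
    if unused:
        problems["defined but not in samples"] = unused
    return problems


# Toy stand-ins for the notebook's dictionaries (plain strings instead of open files)
demo_refs = {"first": "a.fa", "second": "b.fa"}
demo_sams = {"first": "a.sam", "third": "c.sam"}
demo_problems = check_sample_keys(demo_refs, demo_sams, ["first", "second"])
print(demo_problems)
```

In the notebook itself, the call would be `check_sample_keys(refs, samfiles, samples)` once all three objects are defined; an empty result means the keys agree.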


## 2) Helper Functions

- **`rev_compl`**: reverse-complement (DNA→RNA letters; `T`→`U`)
- **`parse_samfile`**: load aligned reads from a SAM (skip unmapped)
- **`get_homo_positions`**: find homopolymer positions
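- **`process_sample`**: keep reads with mapping quality ≥ 50 that are not flagged as duplicates
- **`analyze_reads`**: expand alignments to ref/query strings and build:
  - the confusion matrix (rows: A/C/G/U/deleted; cols: A/C/G/U ref)
  - `err_dist`, error counts split into positions inside (`ho_*`) and outside (`he_*`) homopolymers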
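
The expansion step in `analyze_reads` can be illustrated without pysam. The sketch below walks a list of `(operation, length)` tuples in the same encoding pysam uses for `cigartuples` (0 = match/mismatch, 1 = insertion, 2 = deletion) and produces the two equal-length strings from which the confusion matrix is counted. The function name `expand_alignment` and the toy sequences are illustrative assumptions; soft clipping and the site-specific skips of the real helper are left out.

```python
def expand_alignment(ref_seq, query_seq, cigar):
    """Return (expanded_ref, expanded_query); deleted bases appear as "_" in the query."""
    ref_idx = query_pos = 0
    exp_ref, exp_query = [], []
    for op, length in cigar:
        if op == 0:  # aligned columns: match or mismatch
            exp_ref.append(ref_seq[ref_idx:ref_idx + length])
            exp_query.append(query_seq[query_pos:query_pos + length])
            ref_idx += length
            query_pos += length
        elif op == 1:  # insertion: consumes query only, no reference column
            query_pos += length
        elif op == 2:  # deletion: consumes reference only
            exp_ref.append(ref_seq[ref_idx:ref_idx + length])
            exp_query.append("_" * length)
            ref_idx += length
    return "".join(exp_ref), "".join(exp_query)


base2index = {"A": 0, "C": 1, "G": 2, "U": 3, "_": 4}

# Toy alignment: 3 aligned, 1 deleted, 2 aligned, 1 inserted, 1 aligned
ref = "ACGUACG"
query = "ACGCCUG"
cigar = [(0, 3), (2, 1), (0, 2), (1, 1), (0, 1)]
exp_ref, exp_query = expand_alignment(ref, query, cigar)

# Rows: basecalled A/C/G/U/deleted; columns: reference A/C/G/U
confusion = [[0] * 4 for _ in range(5)]
for r, q in zip(exp_ref, exp_query):
    confusion[base2index[q]][base2index[r]] += 1

print(exp_ref)
print(exp_query)
```

The inserted base never enters the expanded strings, which is why insertions are tracked separately from the confusion matrix, whose rows only cover A/C/G/U/deleted.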

In [16]:
#@title RUN this cell to define helper functions
motifs = ['GA', 'AG', 'GU', 'UG', 'GC', 'CG', 'AC', 'CA', 'AU','UA', 'UC', 'CU']


motif_acc = {motif: defaultdict(int) for motif in motifs}
base2index = {"A": 0, "C": 1, "G": 2, "U": 3, "_": 4}

def rev_compl(seq):
    trans = str.maketrans("ACGTUacgtu", "UGCAAugcaa")
    return seq.translate(trans)[::-1]

def parse_samfile(samfile):
    """Parse a SAM file into a list of aligned reads"""
    with pysam.AlignmentFile(samfile, "r") as sam:
        return [aln for aln in sam if not aln.is_unmapped]

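def get_homo_positions(seq, min_len=3):
    """Return the set of 0-based positions inside homopolymer runs of at least min_len bases"""
    homo_pos = set()
    last_base = ''
    count = 0
    for i, base in enumerate(seq):
        if base == last_base:
            count += 1
        else:
            if count >= min_len:
                for j in range(i - count, i):
                    homo_pos.add(j)
            count = 1
        last_base = base
    # Flush a run that extends to the end of the sequence
    if count >= min_len:
        homo_pos.update(range(len(seq) - count, len(seq)))
    return homo_pos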


def process_sample(sample, sam_path, ref_fa, motif_acc, kmer_size=2, max_reads=100000):
    """Return the reads among the first max_reads mapped reads that pass the filters.

    ref_fa, motif_acc and kmer_size are not used here; they are kept so existing calls still work.
    """
    reads = parse_samfile(sam_path)
    filtered_reads = []
    skipped = 0

    for aln in reads[:max_reads]:
        # Drop low-confidence alignments (MAPQ < 50) and flagged duplicates
        if aln.mapping_quality < 50 or aln.is_duplicate:
            skipped += 1
            continue
        filtered_reads.append(aln)

    print(f"{sample}: kept {len(filtered_reads)} reads, skipped {skipped}")
    return filtered_reads


def analyze_reads(sample, reads, ref_fa, motif_acc):
    confusion = np.zeros((5, 4), dtype=int)
    err_dist = {"ho_mis": 0, "ho_del": 0, "ho_ins": 0, "he_mis": 0, "he_del": 0, "he_ins": 0}
    skip_sites = {23, 24, 52, 53, 76, 77}

    for aln in tqdm(reads, desc=f"Processing {sample}"):
        query_seq = aln.query_sequence
        if aln.is_reverse:
            query_seq = rev_compl(query_seq)

        cigar = aln.cigartuples
        soft_clip_start = cigar[0][1] if cigar[0][0] == 4 else 0
        soft_clip_end = cigar[-1][1] if cigar[-1][0] == 4 else 0
        query_seq = query_seq[soft_clip_start: len(query_seq) - soft_clip_end]

        ref_seq = ref_fa.fetch(aln.reference_name, aln.reference_start, aln.reference_end).upper().replace("T", "U")
        query_seq = query_seq.upper().replace("T", "U")

        ref_pos = aln.reference_start
        query_pos = 0
        ref_idx = 0
        last_error = None
        expanded_ref = []
        expanded_query = []
        insert_pos = {}

        homo_pos = get_homo_positions(ref_seq)

        for op, length in cigar:
            if op == 0:  # Match or mismatch
                for _ in range(length):
                    if ref_pos in skip_sites:
                        query_pos += 1
                        ref_idx += 1
                        ref_pos += 1
                        continue

                    r_base = ref_seq[ref_idx]
                    q_base = query_seq[query_pos]

                    if r_base != q_base:
                        if not (r_base == "U" and q_base == "C" and ref_pos == 86):
                            last_error = "mismatch"
                    else:
                        last_error = None

                    expanded_ref.append(r_base)
                    expanded_query.append(q_base)

                    query_pos += 1
                    ref_idx += 1
                    ref_pos += 1

            elif op == 1:  # Insertion relative to reference
                insert_pos.setdefault(ref_pos - aln.reference_start, [])
                insert_pos[ref_pos - aln.reference_start].append(query_seq[query_pos:query_pos + length])
                query_pos += length
                last_error = "insertion"

            elif op == 2:  # Deletion
                for _ in range(length):
                    expanded_ref.append(ref_seq[ref_idx])
                    expanded_query.append("_")
                    ref_idx += 1
                    ref_pos += 1
                last_error = "deletion"

        expanded_ref = ''.join(expanded_ref)
        expanded_query = ''.join(expanded_query)

        for i in range(len(expanded_ref)):
            r = expanded_ref[i]
            q = expanded_query[i]
            if r in base2index:
                col = base2index[r]
                row = base2index[q] if q in base2index else 4
                confusion[row, col] += 1

        for i in range(len(expanded_ref) - 1):
            kmer = expanded_ref[i:i+2]
            if kmer not in motif_acc:
                continue
            sample_kmer = expanded_query[i:i+2]
            if "_" in sample_kmer:
                err_dist["he_del" if i not in homo_pos else "ho_del"] += 1
            elif len(sample_kmer) == 2 and kmer != sample_kmer:
                err_dist["he_mis" if i not in homo_pos else "ho_mis"] += 1
            motif_acc[kmer][sample_kmer] += 1

        for ipos, ins_bases in insert_pos.items():
            if 0 <= ipos - 1 < len(ref_seq):
                motif = ref_seq[ipos - 1:ipos + 1]
                if motif in motif_acc:
                    err_dist["he_ins" if ipos not in homo_pos else "ho_ins"] += 1
                    for ins_seq in ins_bases:
                        mutated = motif[0] + ins_seq + motif[1]
                        motif_acc[motif][mutated] += 1

    return confusion, err_dist
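
Homopolymer detection is a run-length problem, so it can also be phrased with `itertools.groupby`. The following is an alternative sketch, not the notebook's implementation; `homopolymer_positions` is a hypothetical name. A run that touches the end of the sequence is handled the same way as any other run.

```python
from itertools import groupby


def homopolymer_positions(seq, min_len=3):
    """Return 0-based positions that lie in runs of at least min_len identical bases."""
    positions = set()
    start = 0
    for _base, run in groupby(seq):
        run_len = len(list(run))  # length of the current run of identical bases
        if run_len >= min_len:
            positions.update(range(start, start + run_len))
        start += run_len
    return positions


# "AAA" occupies positions 0-2 and "UUUU" occupies positions 5-8
demo = homopolymer_positions("AAAGCUUUU")
print(sorted(demo))
```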

## 3) Run Analysis & Plot Heatmap

1. Set the `samples`.
2. Filter reads with `process_sample(...)`.
3. Analyze with `analyze_reads(...)`.
4. Normalize matrix by column, then flip vertically so that **deleted** becomes the first row (drawn at the top of the heatmap).
5. Log-transform for visualization; annotate with percentages.
6. Save PNG (transparent background) and/or display inline.
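
Steps 4 and 5 can be tried on a small made-up matrix before running them on real data. The sketch below uses toy counts (not results from any sample) and assumes `np.char.add` for building the percent labels, which concatenates strings element-wise.

```python
import numpy as np

# Toy counts: rows = basecalled A, C, G, U, deleted; columns = reference A, C, G, U
counts = np.array([
    [90,  2,  1,  0],
    [ 4, 95,  0,  5],
    [ 1,  0, 97,  0],
    [ 0,  1,  0, 85],
    [ 5,  2,  2, 10],
])

# Column-normalize: each column is divided by its own total, so it sums to 1
col_norm = counts / counts.sum(axis=0, keepdims=True)

# Flip the row order so that "deleted" becomes the first row of the array
flipped = np.flip(col_norm, axis=0)

# Negative log: frequent outcomes get values near 0, rare ones get large values;
# the small offset avoids log(0) for empty cells
log_conf = -np.log(flipped + 1e-10)

# Percent labels for the heatmap cells
labels = np.char.add(np.round(flipped * 100, 1).astype(str), "%")
print(labels)
```

Because the colour scale is driven by `log_conf` while the annotations come from `flipped`, small off-diagonal percentages stay visually distinguishable even when the diagonal dominates.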

In [11]:
#@title EDIT and Run the names of the samples
samples = ["first","second", "third"]

In [15]:
#@title Pick the sample to plot: set `s` to its zero-based index in `samples`


s = 0 # change number here

In [None]:
#@title EDIT and Run the cell to define the names:

# Title of the plot
title_name = "EDIT HERE"

# Plot name for saving
plot_name = "EDITHERE.png"



In [None]:
#@title RUN this cell to plot the matrix
sample = samples[s]
filtered_reads = process_sample(sample, samfiles[sample], refs[sample], motif_acc)

confusion, err_dist = analyze_reads(sample, filtered_reads, refs[sample], motif_acc)

conf = confusion
conf_norm = conf / np.sum(conf, axis=0, keepdims=True)
conf_norm = np.flip(conf_norm, axis=0)

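# Log transform (the small offset avoids log(0) for empty cells)
log_conf = -np.log(conf_norm + 1e-10)
# np.char.add concatenates element-wise; plain "+" on string arrays fails on NumPy 1.x
annot = np.char.add(np.round(conf_norm * 100, 1).astype(str), "%")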

sns.set(rc={'figure.figsize': (10, 9)})
sns.set_context("paper", font_scale=3)

sns.heatmap(log_conf, annot=annot, fmt='',
            xticklabels=["A", "C", "G", "U"],
            yticklabels=np.flip(["A", "C", "G", "U", "deleted"]),
            cmap="OrRd", cbar=False, linewidths=0.5, linecolor='gray')

plt.xlabel("Reference base")
plt.ylabel("Basecalled base")
plt.yticks(rotation=0)
plt.title(title_name)
plt.tight_layout()
# plt.show()
plt.savefig(plot_name, bbox_inches="tight", transparent=True)

## 4) Base-specific Error Rates by Type Plot

In [None]:
#@title EDIT and Run the cell to define the names:

# Title of the plot
title_name = "EDIT HERE"

# Plot name for saving
plot_name = "EDITHERE.png"

In [None]:
#@title RUN this cell to plot Base-specific Error Rates by Type


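# NOTE: this redefines analyze_reads from section 2. This version returns per-base
# error counts and an all-zero confusion matrix, so rerun the helper cell before
# plotting the confusion matrix again.
def analyze_reads(sample, reads, ref_fa, motif_acc):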
    confusion = np.zeros((5, 4), dtype=int)

    error_tuple = {"A": {"insertion": 0, "deletion": 0, "mismatch": 0},
                   "C": {"insertion": 0, "deletion": 0, "mismatch": 0},
                   "U": {"insertion": 0, "deletion": 0, "mismatch": 0},
                   "G": {"insertion": 0, "deletion": 0, "mismatch": 0}}

    err_dist = {"ho_mis": 0, "ho_del": 0, "ho_ins": 0, "he_mis": 0, "he_del": 0, "he_ins": 0}

    for aln in tqdm(reads, desc=f"Processing {sample}"):
        query_seq = aln.query_sequence
        if aln.is_reverse:
            query_seq = rev_compl(query_seq)

        cigar = aln.cigartuples
        soft_clip_start = cigar[0][1] if cigar[0][0] == 4 else 0
        soft_clip_end = cigar[-1][1] if cigar[-1][0] == 4 else 0

        query_seq = query_seq[soft_clip_start: len(query_seq) - soft_clip_end]
        ref_seq = ref_fa.fetch(aln.reference_name, aln.reference_start, aln.reference_end).upper().replace("T", "U")
        query_seq = query_seq.upper().replace("T", "U")
        # print(ref_seq)

        ref_pos = aln.reference_start
        query_pos = 0
        ref_idx = 0  # local index for ref_seq slicing
        last_error = None  # to skip consecutive errors

        for op, length in cigar:

            if op == 0:  # Match or mismatch
                for _ in range(length):
                    r_base = ref_seq[ref_idx]
                    q_base = query_seq[query_pos]
                    if r_base != q_base:
                        # print("Ref", r_base, "Query", q_base)
                        # print(query_seq[query_pos-1])
                        # Skip mismatch: U -> C at reference position 86 (0-based, as tested below)
                        if r_base == "U" and q_base == "C" and ref_pos == 86:
                            pass

                        elif last_error != "mismatch":

                            prev_base = query_seq[query_pos - 1] if query_pos > 0 else None
                            if prev_base in error_tuple:
                                error_tuple[prev_base]["mismatch"] += 1
                            last_error = "mismatch"
                    else:
                        last_error = None


                    query_pos += 1
                    ref_idx += 1
                    ref_pos += 1


            elif op == 1:  # Insertion
                if last_error != "insertion":
                    prev_base = query_seq[query_pos - 1] if query_pos > 0 else None
                    if prev_base in error_tuple:
                        error_tuple[prev_base]["insertion"] += 1
                    last_error = "insertion"
                query_pos += length  # skip all inserted bases

            elif op == 2:  # Deletion
                if last_error != "deletion":
                    prev_base = query_seq[query_pos - 1] if query_pos > 0 else None
                    if prev_base in error_tuple:
                        error_tuple[prev_base]["deletion"] += 1
                    last_error = "deletion"
                ref_idx += length
                ref_pos += length

            else:
                continue

    return confusion, error_tuple


# Argument order follows the process_sample signature: (sample, sam_path, ref_fa, motif_acc)
reads = process_sample(sample, samfiles[sample], refs[sample], motif_acc)
confusion, error_tuple = analyze_reads(sample, reads, refs[sample], motif_acc)

bases = ['A', 'C', 'U', 'G']
error_types = ['insertion', 'deletion', 'mismatch']

total_errors = sum(
    error_tuple[base][etype] for base in bases for etype in error_types
)
data = []
for base in bases:
    for etype in error_types:
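        # Share of all counted errors (in %); guarded against an empty error count
        rate = 100 * error_tuple[base][etype] / total_errors if total_errors else 0.0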
        data.append({
            "base": base,
            "error_rate": rate,
            "error_type": etype
        })

df = pd.DataFrame(data)

sns.set(rc={'figure.figsize': (10, 6)})
sns.set_context("talk", font_scale=1.1)
sns.set_style("ticks")

sns.barplot(data=df, y="base", x="error_rate", hue="error_type")
# sns.despine()
# plt.legend().remove()
plt.xlabel("Error Rate (%)")
plt.ylabel("Base (before error)")
plt.title(title_name)
plt.tight_layout()
# plt.show()
plt.savefig(plot_name, bbox_inches="tight", transparent=True)
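
The value plotted above is each category's share of all counted errors, so the bars sum to 100% across bases and error types; it is not normalized by how often each base occurs. A minimal sketch of that calculation with made-up counts follows. The helper name `error_shares` is hypothetical, and the guard for an empty error count is an added assumption.

```python
def error_shares(error_counts):
    """Convert nested {base: {error_type: count}} into percentages of all errors."""
    total = sum(sum(by_type.values()) for by_type in error_counts.values())
    if total == 0:
        # No errors counted: return zeros instead of dividing by zero
        return {b: {t: 0.0 for t in by_type} for b, by_type in error_counts.items()}
    return {
        b: {t: 100 * n / total for t, n in by_type.items()}
        for b, by_type in error_counts.items()
    }


# Made-up counts for illustration only
toy_counts = {
    "A": {"insertion": 1, "deletion": 4, "mismatch": 5},
    "C": {"insertion": 0, "deletion": 2, "mismatch": 3},
    "U": {"insertion": 2, "deletion": 6, "mismatch": 2},
    "G": {"insertion": 1, "deletion": 3, "mismatch": 11},
}
shares = error_shares(toy_counts)
print(shares["G"]["mismatch"])
```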