In [None]:
import itertools as it
import operator
import re
from collections import Counter
from pathlib import Path

import awkward as ak
import bottleneck as bn
import duckdb
import gfapy
import holoviews as hv
import ibis
import matplotlib.pyplot as plt
import numba
import numpy as np
import pandas as pd
import polars as pl
import pyabpoa
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyfastx
import pysam
import spoa
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.io as sio

In [None]:
# Load the Bokeh plotting backend for HoloViews before any hv objects are displayed.
hv.extension("bokeh")

In [None]:
%load_ext pyinstrument
import line_profiler
import pyinstrument

%load_ext line_profiler

# Consensus

In [None]:
# Read the Arrow IPC file of test read groups into a polars DataFrame.
test_groups_path = (
    "/home/jqs1/scratch/jqs1/sequencing/scratch/test_read_groups_100.arrow"
)
test_groups = pl.read_ipc(test_groups_path)

In [None]:
# Length of each group's `name` list, restricted to groups with more than one
# entry and ordered largest first; the result is a single-column NumPy array.
entries_per_group = pl.col("name").list.len()
group_depths = (
    test_groups.select(entries_per_group)
    .filter(pl.col("name") > 1)
    .sort("name", descending=True)
    .to_numpy()
)

In [None]:
# Depth of every multi-read group, ranked from deepest to shallowest.
# Explicit axes plus labels let the figure stand alone; the trailing `;`
# suppresses the matplotlib repr in the cell output.
fig, ax = plt.subplots()
ax.plot(group_depths[:, 0])
ax.set(
    xlabel="group rank (sorted by depth, descending)",
    ylabel="reads per group",
    title="Read-group depth distribution",
);

In [None]:
%%time
# Keep only the groups whose `name` list holds between 20 and 25 entries.
depth_in_range = pl.col("name").list.len().is_between(20, 25)
test_groups_subset = test_groups.filter(depth_in_range)

In [None]:
# Display the `depth` and `duplex_depth` columns of the filtered groups.
test_groups_subset.select("depth", "duplex_depth")

In [None]:
# Take row 20 of the filtered groups and unnest its list columns so that each
# read gets its own row; reads whose name contains ";" are flagged as duplex
# and sorted to the top.
selected_group = test_groups_subset[20]
per_read_columns = pl.col(
    "name", "read_seq", "read_phred", "reverse_complement"
).list.explode()
per_group_columns = pl.col("duplex_depth", "simplex_depth")
name_has_semicolon = pl.col("name").str.contains(";")

test_reads = (
    selected_group.select(per_read_columns, per_group_columns)
    .with_columns(
        pl.col("read_seq").str.len_bytes().alias("read_len"),
        name_has_semicolon.alias("is_duplex"),
        name_has_semicolon.not_().alias("is_simplex"),
    )
    .sort("is_duplex", descending=True)
)
test_reads

In [None]:
_RC_BASES = ("ACRMBH", "TGYKVD")
RC_MAP = {
    k: v
    for lower in (False, True)
    for a, b in zip(*[bases.lower() if lower else bases for bases in _RC_BASES])
    for k, v in [(a, b), (b, a)]
}


# @numba.jit(nopython=True)
def reverse_complement(seq):
    return "".join(RC_MAP.get(base, base) for base in reversed(seq))

In [None]:
# Pair each read's quality array with its reverse-complement flag, reading
# both straight from `test_reads` so this cell does not depend on variables
# that are only defined further down the notebook.
list(
    zip(
        test_reads.get_column("read_phred").to_arrow(),
        test_reads.get_column("reverse_complement").to_arrow(),
    )
)

In [None]:
def _column_as_awkward(frame, column_name):
    """Convert one column of a polars frame to an awkward array via Arrow."""
    return ak.from_arrow(frame.get_column(column_name).to_arrow())


# Sequences as a plain Python list; qualities and strand flags as awkward arrays.
read_seq = test_reads.get_column("read_seq").to_list()
read_phred = _column_as_awkward(test_reads, "read_phred")
read_rc = _column_as_awkward(test_reads, "reverse_complement")

In [None]:
def _orient(values, flip_flags, flip):
    """Apply ``flip`` to every value whose flag is truthy; pass the rest through."""
    return [flip(value) if flag else value for value, flag in zip(values, flip_flags)]


# Put every read on the same strand: flagged reads get their sequence
# reverse-complemented and their quality array reversed to match.
read_seq_oriented = _orient(read_seq, read_rc, reverse_complement)
read_phred_oriented = _orient(read_phred, read_rc, lambda quals: quals[::-1])

In [None]:
%%time
# Align the strand-oriented reads with abPOA, requesting only the MSA rows
# (out_msa=True) and no consensus sequence (out_cons=False).
# NOTE(review): aln_mode="l" presumably selects local alignment — confirm
# against the pyabpoa documentation.
aligner = pyabpoa.msa_aligner(aln_mode="l")
res = aligner.msa(read_seq_oriented, out_cons=False, out_msa=True)

In [None]:
# Stack the gapped alignment rows into a 2-D uint8 matrix of ASCII codes
# (one row per read, one column per alignment position).
msa_seq = np.array([np.frombuffer(seq.encode(), dtype=np.uint8) for seq in res.msa_seq])
# The cells below loop over these dimensions; defining them here keeps the
# notebook runnable from a fresh kernel. Unpacking also fails loudly if the
# rows did not form a rectangular matrix.
num_seqs, msa_length = msa_seq.shape

In [None]:
%%time
# For every alignment column, tally how many rows carry each symbol code.
base_votes = []
for column in range(msa_length):
    tally = {}
    for row in range(num_seqs):
        symbol = msa_seq[row, column]
        if symbol in tally:
            tally[symbol] += 1
        else:
            tally[symbol] = 1
    base_votes.append(tally)

In [None]:
# Print one row of per-symbol vote counts for every alignment column;
# symbols that received no votes in a column are shown as 0.
row_template = " {A: >3d} {T: >3d} {C: >3d} {G: >3d} {-: >3d}"
print("   A   T   C   G   -")
for base_idx in range(msa_length):
    counts = dict.fromkeys("ATCG-", 0)
    counts.update((chr(code), n) for code, n in base_votes[base_idx].items())
    print(row_template.format(**counts))

In [None]:
%%time
# Unfinished draft of a per-column quality pass over the gapped alignment
# strings. The non-gap branch was never written; it is left as a TODO so the
# cell at least parses. The following cell holds the working implementation.
phred = np.empty(msa_length, dtype=np.uint16)
for idx in range(num_seqs):
    seq = res.msa_seq[idx]
    for b in range(msa_length):
        base = seq[b]
        if base == "-":
            # Gap column: this read contributes no quality value here.
            continue
        # TODO: accumulate this read's quality for column `b`.

In [None]:
%%time
# Spread each read's per-base qualities over the alignment columns: non-gap
# cells receive the read's next quality value, gap cells are marked NaN.
gap_code = ord("-")
phred_matrix = np.empty((num_seqs, msa_length), dtype=np.float32)
for seq_idx in range(num_seqs):
    # The MSA was built from `read_seq_oriented`, so the qualities must come
    # from the matching oriented list; otherwise reverse-complemented reads
    # would have their qualities laid out back to front.
    phred = read_phred_oriented[seq_idx]
    offset = 0  # number of gap columns seen so far in this row
    for base_idx in range(msa_length):
        if msa_seq[seq_idx, base_idx] == gap_code:
            phred_matrix[seq_idx, base_idx] = np.nan
            offset += 1
        else:
            # Raw Phred score; use 10 ** (-q / 10) to get an error probability.
            phred_matrix[seq_idx, base_idx] = phred[base_idx - offset]

In [None]:
# Fill NaN (gap) cells from the nearest non-NaN neighbours along each row:
# `forwards` carries the last value seen to the left, `backwards` the next
# value to the right.
forwards = bn.push(phred_matrix)
right_to_left = phred_matrix[:, ::-1]
backwards = bn.push(right_to_left)[:, ::-1]

# Zero out whatever is still NaN (no neighbour on that side) so it adds
# nothing to the sum, and count how many sides actually supplied a value.
forwards_nonnull = bn.replace(forwards.copy(), np.nan, 0)
backwards_nonnull = bn.replace(backwards.copy(), np.nan, 0)
sides_available = (~np.isnan(forwards)).astype(np.uint8) + (
    ~np.isnan(backwards)
).astype(np.uint8)

# Average over the available sides and truncate to integer scores.
neighbour_sum = forwards_nonnull + backwards_nonnull
imputed = (neighbour_sum / sides_available).astype(np.int32)

In [None]:
# NOTE(review): this binds the un-imputed `phred_matrix` (gap cells are NaN);
# the `imputed` array computed in the previous cell is not used here —
# confirm whether `imputed` was intended.
phred_matrix_imputed = phred_matrix

In [None]:
# Majority-vote consensus over the alignment columns.
# `output_seq` / `output_phred` keep one entry per column (gaps included);
# the `*_nogap` arrays hold the same calls with gap columns squeezed out and
# are padded at the end with " " / -1.
gap_code = ord("-")
output_seq = np.empty(msa_length, dtype=np.uint8)
output_phred = np.empty(msa_length, dtype=np.int32)
output_seq_nogap = np.full(msa_length, ord(" "), dtype=np.uint8)
output_phred_nogap = np.full(msa_length, -1, dtype=np.int32)
offset = 0  # number of gap calls so far
for base_idx in range(msa_length):
    votes = {}
    for seq_idx in range(num_seqs):
        base = msa_seq[seq_idx, base_idx]
        votes[base] = votes.get(base, 0) + 1
    # numba doesn't support this in nopython mode:
    # max(votes, key=votes.get)
    # SEE: https://stackoverflow.com/questions/75139835/numba-dict-find-key-of-the-minimum-value-in-dict?noredirect=1#comment132598298_75139835
    # Sort descending so that entry 0 is the MOST voted symbol; an ascending
    # sort would select the least voted one.
    sorted_votes = sorted([(v, k) for k, v in votes.items()], reverse=True)
    # The "phred" reported here is simply the winning vote count.
    phred, base = sorted_votes[0]
    output_seq[base_idx] = base
    output_phred[base_idx] = phred
    if base == gap_code:
        offset += 1
    else:
        output_seq_nogap[base_idx - offset] = base
        output_phred_nogap[base_idx - offset] = phred

In [None]:
# Per-column consensus calls as ASCII codes (gap columns included).
output_seq

In [None]:
# Gap-free consensus as text; rstrip() removes the trailing space padding.
output_seq_nogap.tobytes().decode().rstrip()

## Numba

In [None]:
GAP_CHAR = ord("-")
SPACE_CHAR = ord(" ")


def phred_weighted_consensus(seqs, phreds, gap_quality_method="mean"):
    """Call a quality-weighted consensus from the rows of a gapped MSA.

    Every row votes for the symbol it carries in each column, and the vote is
    weighted by that base's quality score. A gap has no score of its own, so
    it borrows one from the nearest non-gap bases of the same row: the value
    to its left and the value to its right are combined according to
    ``gap_quality_method``; if only one side exists that side is used, and a
    row consisting solely of gaps votes with weight 0.

    Parameters
    ----------
    seqs : sequence of equal-length rows of symbol codes (e.g. a 2-D uint8
        array of ASCII codes), with ``GAP_CHAR`` marking gaps.
    phreds : sequence of per-row quality arrays, one value per NON-gap symbol
        of the corresponding row, in row order.
    gap_quality_method : ``"mean"`` (floor of the average of both neighbours)
        or ``"min"`` (the smaller neighbour).

    Returns
    -------
    ``None`` when ``seqs`` is empty, otherwise a tuple
    ``(consensus, consensus_phred, nonconsensus, nonconsensus_phred)`` of
    arrays with one entry per column: the winning symbol, its summed weight,
    the runner-up symbol (``SPACE_CHAR`` if there is none) and the summed
    weight of all non-winning symbols. Ties go to the larger symbol code.

    Raises
    ------
    ValueError
        If ``gap_quality_method`` is neither ``"min"`` nor ``"mean"``.
    """
    if gap_quality_method not in ("min", "mean"):
        raise ValueError("gap_quality_method must be 'min' or 'mean'")
    num_seqs = len(seqs)
    if not num_seqs:
        return
    msa_length = len(seqs[0])
    # Scratch row: quality of the nearest non-gap base at or left of each column
    # (-1 where no such base exists).
    aligned_phred = np.empty(msa_length, dtype=np.int64)
    votes = [{} for _ in range(msa_length)]
    for seq_idx in range(num_seqs):
        aligned_seq = seqs[seq_idx]
        unaligned_phred = phreds[seq_idx]

        # Forward pass: fill `aligned_phred` left to right.
        aligned_phred[:] = -1
        last_nongap_phred = -1
        offset = 0  # gaps seen so far == shift between MSA column and read index
        for base_idx in range(msa_length):
            if aligned_seq[base_idx] == GAP_CHAR:
                offset += 1
                aligned_phred[base_idx] = last_nongap_phred
            else:
                # int() keeps the accumulation in Python integers so narrow
                # quality dtypes (e.g. uint8) cannot overflow when summed.
                last_nongap_phred = int(unaligned_phred[base_idx - offset])
                aligned_phred[base_idx] = last_nongap_phred

        # Backward pass: combine with the right-hand neighbour and cast votes.
        # `offset` starts at the row's total gap count and is unwound as gaps
        # are passed, so `base_idx - offset` again indexes the read.
        last_nongap_phred = -1
        for base_idx in reversed(range(msa_length)):
            base = aligned_seq[base_idx]
            if base == GAP_CHAR:
                offset -= 1
                left_phred = int(aligned_phred[base_idx])
                right_phred = last_nongap_phred
                if left_phred == -1 and right_phred == -1:
                    base_phred = 0
                elif left_phred == -1:
                    base_phred = right_phred
                elif right_phred == -1:
                    base_phred = left_phred
                elif gap_quality_method == "min":
                    base_phred = min(left_phred, right_phred)
                else:
                    base_phred = (left_phred + right_phred) // 2
            else:
                base_phred = int(unaligned_phred[base_idx - offset])
                last_nongap_phred = base_phred
            column_votes = votes[base_idx]
            column_votes[base] = column_votes.get(base, 0) + base_phred

    consensus = np.empty(msa_length, dtype=np.uint8)
    nonconsensus = np.full(msa_length, SPACE_CHAR, dtype=np.uint8)
    consensus_phred = np.empty(msa_length, dtype=np.int32)
    nonconsensus_phred = np.zeros(msa_length, dtype=np.int32)
    for base_idx in range(msa_length):
        # Descending order puts the heaviest symbol first.
        sorted_votes = sorted(
            [(v, k) for k, v in votes[base_idx].items()], reverse=True
        )
        consensus_phred[base_idx], consensus[base_idx] = sorted_votes[0]
        if len(sorted_votes) >= 2:
            nonconsensus[base_idx] = sorted_votes[1][1]
            nonconsensus_phred[base_idx] = sum(v[0] for v in sorted_votes[1:])
    return consensus, consensus_phred, nonconsensus, nonconsensus_phred

In [None]:
GAP_CHAR = ord("-")
SPACE_CHAR = ord(" ")


# NOTE: this redefines, and therefore shadows, the pure-Python
# `phred_weighted_consensus` from the previous cell.
@numba.njit
def phred_weighted_consensus(seqs, phreds, gap_quality_method="mean"):
    """Call a quality-weighted consensus from the rows of a gapped MSA (numba).

    Every row votes for the symbol it carries in each column, weighted by that
    base's quality score. A gap borrows its weight from the nearest non-gap
    bases of the same row: the left and right values are combined according to
    ``gap_quality_method``; if only one side exists that side is used, and a
    row consisting solely of gaps votes with weight 0.

    Votes are accumulated in fixed-size arrays indexed by symbol code rather
    than in per-column dicts, so the body uses only NumPy arrays and explicit
    loops. This assumes symbol codes fit in 0..255 (e.g. uint8 ASCII codes).

    Parameters
    ----------
    seqs : sequence of equal-length rows of symbol codes, with ``GAP_CHAR``
        marking gaps.
    phreds : sequence of per-row quality arrays, one value per NON-gap symbol
        of the corresponding row.
    gap_quality_method : ``"mean"`` (floor of the average of both neighbours)
        or ``"min"`` (the smaller neighbour).

    Returns
    -------
    ``None`` when ``seqs`` is empty, otherwise
    ``(consensus, consensus_phred, nonconsensus, nonconsensus_phred)``: per
    column, the winning symbol, its summed weight, the runner-up symbol
    (``SPACE_CHAR`` if there is none) and the summed weight of all non-winning
    symbols. Ties go to the larger symbol code.

    Raises
    ------
    ValueError
        If ``gap_quality_method`` is neither ``"min"`` nor ``"mean"``.
    """
    if gap_quality_method != "min" and gap_quality_method != "mean":
        raise ValueError("gap_quality_method must be 'min' or 'mean'")
    num_seqs = len(seqs)
    if not num_seqs:
        return
    msa_length = len(seqs[0])
    # Scratch row: quality of the nearest non-gap base at or left of each column
    # (-1 where no such base exists).
    aligned_phred = np.empty(msa_length, dtype=np.int64)
    # Summed weight per (column, symbol code), plus a mask of which symbols
    # actually occurred (a symbol may legitimately have weight 0).
    vote_weight = np.zeros((msa_length, 256), dtype=np.int64)
    vote_seen = np.zeros((msa_length, 256), dtype=np.bool_)
    for seq_idx in range(num_seqs):
        aligned_seq = seqs[seq_idx]
        unaligned_phred = phreds[seq_idx]

        # Forward pass: fill `aligned_phred` left to right.
        aligned_phred[:] = -1
        last_nongap_phred = -1
        offset = 0  # gaps seen so far == shift between MSA column and read index
        for base_idx in range(msa_length):
            if aligned_seq[base_idx] == GAP_CHAR:
                offset += 1
                aligned_phred[base_idx] = last_nongap_phred
            else:
                last_nongap_phred = unaligned_phred[base_idx - offset]
                aligned_phred[base_idx] = last_nongap_phred

        # Backward pass: combine with the right-hand neighbour and cast votes.
        # `offset` starts at the row's total gap count and is unwound as gaps
        # are passed, so `base_idx - offset` again indexes the read.
        last_nongap_phred = -1
        # numba doesn't support reversed(range(msa_length))
        for base_idx in range(msa_length - 1, -1, -1):
            base = aligned_seq[base_idx]
            base_phred = 0
            if base == GAP_CHAR:
                offset -= 1
                left_phred = aligned_phred[base_idx]
                right_phred = last_nongap_phred
                if left_phred == -1 and right_phred == -1:
                    base_phred = 0
                elif left_phred == -1:
                    base_phred = right_phred
                elif right_phred == -1:
                    base_phred = left_phred
                elif gap_quality_method == "min":
                    base_phred = min(left_phred, right_phred)
                else:
                    base_phred = (left_phred + right_phred) // 2
            else:
                base_phred = unaligned_phred[base_idx - offset]
                last_nongap_phred = base_phred
            vote_weight[base_idx, base] += base_phred
            vote_seen[base_idx, base] = True

    consensus = np.empty(msa_length, dtype=np.uint8)
    nonconsensus = np.full(msa_length, SPACE_CHAR, dtype=np.uint8)
    consensus_phred = np.empty(msa_length, dtype=np.int32)
    nonconsensus_phred = np.zeros(msa_length, dtype=np.int32)
    for base_idx in range(msa_length):
        best_code = -1
        best_weight = -1
        second_code = -1
        second_weight = -1
        total_weight = 0
        # Scan codes in ascending order; `>=` lets the larger code win ties,
        # matching a descending sort of (weight, code) pairs.
        for code in range(256):
            if not vote_seen[base_idx, code]:
                continue
            weight = vote_weight[base_idx, code]
            total_weight += weight
            if weight >= best_weight:
                second_code = best_code
                second_weight = best_weight
                best_code = code
                best_weight = weight
            elif weight >= second_weight:
                second_code = code
                second_weight = weight
        consensus[base_idx] = best_code
        consensus_phred[base_idx] = best_weight
        if second_code != -1:
            nonconsensus[base_idx] = second_code
            nonconsensus_phred[base_idx] = total_weight - best_weight
    return consensus, consensus_phred, nonconsensus, nonconsensus_phred

In [None]:
%%time
# Run the consensus on the MSA matrix with the strand-oriented qualities.
# (The third target was misspelled `conconsensus`, which left `nonconsensus`
# undefined.)
consensus, consensus_phred, nonconsensus, nonconsensus_phred = phred_weighted_consensus(
    msa_seq, read_phred_oriented
)

In [None]:
# Consensus symbol codes rendered as text (gap columns included).
consensus.tobytes().decode()

In [None]:
# Number of per-read quality arrays passed to the consensus call.
len(read_phred_oriented)

In [None]:
# Number of MSA rows; expected to equal the number of quality arrays above.
len(msa_seq)

## SPOA

In [None]:
%%time
# Run SPOA on the strand-oriented reads, the same input abPOA received above.
# Feeding the raw `read_seq` column would mix forward and reverse-complement
# reads in one alignment.
# NOTE: this rebinds `consensus`, replacing the array returned by
# `phred_weighted_consensus`.
consensus, msa = spoa.poa(read_seq_oriented)