In [None]:
import itertools as it
import operator
import re
import subprocess
import tempfile
import time
from collections import Counter
from functools import partial
from pathlib import Path

import awkward as ak
import bottleneck as bn
import duckdb
import gfapy
import holoviews as hv
import ibis
import matplotlib.pyplot as plt
import numba
import numpy as np
import pandas as pd
import parasail
import polars as pl
import pyabpoa
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyfastx
import pysam
import spoa
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from pyarrow import csv
from pywfa import WavefrontAligner
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.consensus as con
import paulssonlab.sequencing.io as sio
import paulssonlab.sequencing.processing as processing
from paulssonlab.util.sequence import reverse_complement

In [None]:
# Render HoloViews plots with the bokeh backend.
hv.extension("bokeh")

In [None]:
# Profiling helpers: load the pyinstrument and line_profiler IPython
# extensions for ad-hoc profiling of cells in this notebook.
%load_ext pyinstrument
import line_profiler
import pyinstrument

%load_ext line_profiler

In [None]:
# Turn on polars' global string cache, which lets Categorical columns created
# separately share one encoding (presumably needed for the data loaded below).
pl.enable_string_cache()

# Config

In [None]:
# GFA graph whose oriented segments are concatenated along each consensus path
# to rebuild reference sequences (see assemble_seq_from_path).
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
gfa_filename = "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/pLIB419.gfa"

In [None]:
gfa = gfapy.Gfa.from_file(gfa_filename)

# Realign consensus sequences to their GFA reference paths

In [None]:
# Per-row consensus results stored as Arrow IPC (columns used below:
# path_consensus, consensus_seq, cg).
# NOTE(review): absolute, user-specific path; consider a configurable DATA_DIR.
df = pl.read_ipc(
    "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/dorado_0.4.0/uncompressed/prepared/consensus_spoa2/align2/combined.arrow"
)

In [None]:
# Estimated in-memory size of the whole frame, in MB.
df.estimated_size(unit="mb")

In [None]:
# Estimated in-memory size of each column, in MB, to see which ones dominate.
{col: df.get_column(col).estimated_size(unit="mb") for col in df.columns}

In [None]:
# Working example for the cells below: the first row's consensus path (as a
# Python list), its consensus sequence, and the CIGAR stored in column `cg`.
path, consensus_seq, cigar = (
    df[0, "path_consensus"].to_list(),
    df[0, "consensus_seq"],
    df[0, "cg"],
)

In [None]:
def gfa_name_mapping(gfa):
    """Map every oriented GFA segment name to its sequence.

    For each segment ``name`` in ``gfa`` (read from gfapy's private
    ``_records["S"]`` table), two entries are produced, in this order:

    * ``">name"`` -> the segment's sequence as stored;
    * ``"<name"`` -> the reverse complement of that sequence.
    """
    mapping = {}
    for seg_name, segment in gfa._records["S"].items():
        forward_seq = segment.sequence
        mapping[f">{seg_name}"] = forward_seq
        mapping[f"<{seg_name}"] = reverse_complement(forward_seq)
    return mapping

In [None]:
def assemble_seq_from_path(name_to_seq, path):
    """Concatenate the sequences of the oriented segments listed in ``path``.

    ``name_to_seq`` is either a ``{oriented_name: sequence}`` mapping (as
    returned by ``gfa_name_mapping``) or a ``gfapy.Gfa``. In the latter case the
    mapping is built from it on every call, so prebuild it when calling in a
    loop. A segment name missing from the mapping raises ``KeyError``.
    """
    lookup = (
        gfa_name_mapping(name_to_seq)
        if isinstance(name_to_seq, gfapy.Gfa)
        else name_to_seq
    )
    return "".join(map(lookup.__getitem__, path))

In [None]:
# Reference sequence for the example row: the oriented GFA segments along its
# consensus path, concatenated.
ref_seq = assemble_seq_from_path(gfa, path)

In [None]:
# Plain ATCG scoring matrix (match=1, mismatch=0) and a local (`sw`) alignment
# with gap open/extend penalties of 1.
# NOTE(review): argument order here is (ref_seq, consensus_seq); the later `sg`
# cells pass (consensus_seq, ref_seq).
matrix = parasail.matrix_create("ATCG", 1, 0)
result = parasail.sw_trace_striped_sat(ref_seq, consensus_seq, 1, 1, matrix)

In [None]:
# IUPAC degenerate nucleotide codes -> the concrete bases each one can stand for.
DEGENERATE_BASES = {
    "R": "AG",
    "Y": "CT",
    "M": "AC",
    "K": "GT",
    "S": "CG",
    "W": "AT",
    "B": "CGT",
    "D": "AGT",
    "H": "ACT",
    "V": "ACG",
    "N": "ACGT",
}


def degenerate_parasail_matrix(
    match,
    mismatch,
    deg_match=None,
    deg_mismatch=None,
    degenerate_bases=DEGENERATE_BASES,
):
    """Build a parasail scoring matrix that understands degenerate bases.

    The alphabet is ``"ATCG"`` followed by the keys of ``degenerate_bases``.
    Every symbol scores ``match`` against itself and ``mismatch`` against
    other symbols. The exception is each degenerate base paired with a concrete
    base: it scores ``deg_match`` against the bases it can stand for and
    ``deg_mismatch`` against the remaining concrete bases, set symmetrically.

    Parameters
    ----------
    match, mismatch : int
        Scores passed to ``parasail.matrix_create``.
    deg_match, deg_mismatch : int, optional
        Scores for degenerate-vs-concrete pairs; default to ``match`` and
        ``mismatch`` respectively.
    degenerate_bases : dict[str, str]
        Maps each degenerate code to the concrete ``ATCG`` bases it represents.

    Returns
    -------
    matrix
        The parasail scoring matrix.
    alphabet_aliases : str
        Character pairs ``base+deg`` and ``deg+base`` for every degenerate code
        and each base it represents, for ``get_cigar(alphabet_aliases=...)``.

    Notes
    -----
    Pairs of two different degenerate codes keep ``matrix_create``'s
    ``mismatch`` score, even when their base sets overlap.
    """
    if deg_match is None:
        deg_match = match
    if deg_mismatch is None:
        deg_mismatch = mismatch
    bases = "ATCG" + "".join(degenerate_bases.keys())
    base_to_idx = {base: idx for idx, base in enumerate(bases)}
    # Start from `match` on the diagonal and `mismatch` elsewhere, then
    # override the degenerate-vs-concrete entries below.
    matrix = parasail.matrix_create(bases, match, mismatch)
    for deg_base, matching_bases in degenerate_bases.items():
        idx = base_to_idx[deg_base]
        deg_match_idxs = [base_to_idx[base] for base in matching_bases]
        deg_mismatch_idxs = [
            base_to_idx[base] for base in set("ATCG") - set(matching_bases)
        ]
        # Assign both (deg, base) and (base, deg) so the matrix stays symmetric.
        for idx2 in deg_match_idxs:
            matrix[idx, idx2] = matrix[idx2, idx] = deg_match
        for idx2 in deg_mismatch_idxs:
            matrix[idx, idx2] = matrix[idx2, idx] = deg_mismatch
    # Aliases are consecutive character pairs; each (base, deg) pair is listed
    # in both orders so it counts regardless of which sequence holds which.
    alphabet_aliases = "".join(
        f"{base}{deg_base}{deg_base}{base}"
        for deg_base, matching_bases in degenerate_bases.items()
        for base in matching_bases
    )
    return matrix, alphabet_aliases

In [None]:
# Build the oriented-name -> sequence lookup once; assemble_seq_from_path would
# otherwise rebuild it from the Gfa on every call.
name_to_seq = gfa_name_mapping(gfa)
# Degenerate-base-aware scoring matrix (match=1, mismatch=0) plus the CIGAR
# aliases that let degenerate bases count as matches.
parasail_matrix, alphabet_aliases = degenerate_parasail_matrix(1, 0)

In [None]:
%%time


def pairwise_align_row(row):
    """Realign one row's consensus sequence against its reference path.

    ``row`` is the ``(path, seq)`` tuple passed by ``DataFrame.map_rows``:
    ``path`` is a sequence of oriented GFA segment names and ``seq`` the
    consensus sequence. The reference is assembled from the module-level
    ``name_to_seq`` lookup and aligned with parasail's ``sg`` (semi-global)
    aligner, using gap open 12 / extend 3 and ``parasail_matrix``.

    Returns ``(score, cigar)``, where ``cigar`` is a str rendered with the
    module-level ``alphabet_aliases``.
    """
    path, seq = row
    ref_seq = assemble_seq_from_path(name_to_seq, path)
    result = parasail.sg_trace_striped_sat(seq, ref_seq, 12, 3, parasail_matrix)
    # `Cigar.decode` yields bytes, hence the second `.decode()` to get a str.
    cigar = result.get_cigar(alphabet_aliases=alphabet_aliases).decode.decode()
    return (result.score, cigar)


# map_rows only honours scalar return dtypes. A tuple-returning UDF yields a
# DataFrame with columns column_0/column_1, so name them explicitly instead of
# passing a Struct return_dtype (which was ignored).
df.head(20).select(pl.col("path_consensus", "consensus_seq")).map_rows(
    pairwise_align_row
).rename({"column_0": "score_realign", "column_1": "cg_realign"})

In [None]:
?pl.Struct

In [None]:
?df.map_rows

In [None]:
# Number of rows in the loaded table.
len(df)

In [None]:
# Raw score table of the plain ATCG matrix.
matrix.matrix

In [None]:
# Time one semi-global (`sg`) alignment of the example row (gap open 12,
# extend 3).
# NOTE(review): uses the plain ATCG `matrix`, not `parasail_matrix` as in
# pairwise_align_row.
%timeit parasail.sg_trace_striped_sat(consensus_seq, ref_seq, 12, 3, matrix)

In [None]:
# result = parasail.sw_trace_striped_sat(consensus_seq, ref_seq, 12, 3, matrix)
# NOTE(review): the (12, 3) alignment below is immediately overwritten by the
# (1, 1) one, so only gap open/extend = 1 affects the cells that follow.
result = parasail.sg_trace_striped_sat(consensus_seq, ref_seq, 12, 3, matrix)
result = parasail.sg_trace_striped_sat(consensus_seq, ref_seq, 1, 1, matrix)

In [None]:
# CIGAR of `result` using the degenerate-base aliases; `.decode` yields bytes,
# hence the second `.decode()` to get a str.
result.get_cigar(alphabet_aliases=alphabet_aliases).decode.decode()

In [None]:
# Same CIGAR with N aliased to each of A/T/C/G only (left as bytes).
result.get_cigar(alphabet_aliases="ANTNCNGN").decode

In [None]:
# CIGAR stored in the input table (column `cg`), for comparison.
cigar

In [None]:
# Start of the reference and of the consensus, for eyeballing the alignment.
# NOTE(review): 38 is an ad-hoc window size.
ref_seq[:38]

In [None]:
consensus_seq[:38]

In [None]:
# Cross-check with the WFA aligner (pywfa): ref_seq is passed to the
# constructor and consensus_seq to wavefront_align.
wfa = WavefrontAligner(ref_seq)
score = wfa.wavefront_align(consensus_seq)

In [None]:
# WFA CIGAR with "M" rewritten to "=", presumably so it can be compared
# against `cigar`; confirm pywfa's CIGAR op letters.
wfa.cigarstring.replace("M", "=")

In [None]:
# wfa.cigartuples

In [None]:
wfa.cigar_print_pretty()

In [None]:
cigar