In [None]:
import itertools as it
import operator
import re
import subprocess
import tempfile
import time
from collections import Counter
from functools import partial
from pathlib import Path

import awkward as ak
import bottleneck as bn
import duckdb
import gfapy
import holoviews as hv
import ibis
import matplotlib.pyplot as plt
import numba
import numpy as np
import pandas as pd
import parasail
import polars as pl
import pyabpoa
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyfastx
import pysam
import spoa
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from pyarrow import csv
from pywfa import WavefrontAligner
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.align as align
import paulssonlab.sequencing.consensus as con
import paulssonlab.sequencing.gfa as sgfa
import paulssonlab.sequencing.io as sio
import paulssonlab.sequencing.processing as processing
from paulssonlab.util.sequence import reverse_complement

In [None]:
# Select Bokeh as the HoloViews plotting backend.
hv.extension("bokeh")

In [None]:
%load_ext pyinstrument
import line_profiler
import pyinstrument

%load_ext line_profiler

In [None]:
# Turn on Polars' global string cache.
pl.enable_string_cache()

# Config

In [None]:
# GFA file loaded below.
# NOTE(review): hard-coded absolute path; the notebook only runs where this path exists.
gfa_filename = "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/pLIB419.gfa"

In [None]:
# Parse the GFA file with gfapy.
gfa = gfapy.Gfa.from_file(gfa_filename)

# Realign

In [None]:
# Load combined.arrow (Arrow IPC) into a Polars DataFrame.
# NOTE(review): hard-coded absolute path.
df = pl.read_ipc(
    "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/dorado_0.4.0/uncompressed/prepared/consensus_spoa2/align2/combined.arrow"
)

In [None]:
# Estimated size of the whole frame, in megabytes.
df.estimated_size(unit="mb")

In [None]:
# The same estimate, broken down per column.
{col: df.get_column(col).estimated_size(unit="mb") for col in df.columns}

In [None]:
# Use row 0 as a worked example: its path (converted to a Python list), its
# consensus sequence, and its "cg" value (presumably the stored CIGAR string).
path = df[0, "path_consensus"].to_list()
consensus_seq = df[0, "consensus_seq"]
cigar = df[0, "cg"]

In [None]:
# Build the reference sequence for that path from the GFA.
ref_seq = sgfa.assemble_seq_from_path(gfa, path)

In [None]:
# Trial 1: explicit scoring parameters plus parasail_algorithm="nw"
# (presumably parasail's global Needleman-Wunsch mode -- confirm in align.pairwise_align).
# Note that this overwrites the `cigar` read from the data frame in an earlier cell.
score, cigar = align.pairwise_align(
    consensus_seq,
    ref_seq,
    degenerate=True,
    gap_opening=12,
    gap_extension=3,
    match=2,
    mismatch=-1,
    parasail_algorithm="nw",
)
(score, cigar)

In [None]:
# Trial 2: same scoring parameters, but parasail_algorithm is left at
# align.pairwise_align's default.
score, cigar = align.pairwise_align(
    consensus_seq,
    ref_seq,
    degenerate=True,
    gap_opening=12,
    gap_extension=3,
    match=2,
    mismatch=-1,
)
(score, cigar)

In [None]:
# Trial 3: only degenerate=True; every scoring parameter is left at its default.
score, cigar = align.pairwise_align(consensus_seq, ref_seq, degenerate=True)
(score, cigar)

In [None]:
def func(row, name_to_seq=None, **kwargs):
    """Realign one ``(path, sequence)`` row against the reference built from its path.

    Args:
        row: Two-element iterable ``(path, seq)``.
        name_to_seq: Passed as the first argument to
            ``sgfa.assemble_seq_from_path`` together with ``path``.
        **kwargs: Forwarded unchanged to ``align.pairwise_align``.

    Returns:
        The ``(score, cigar)`` pair produced by ``align.pairwise_align``.
    """
    segment_path, query_seq = row
    reference = sgfa.assemble_seq_from_path(name_to_seq, segment_path)
    realign_score, realign_cigar = align.pairwise_align(query_seq, reference, **kwargs)
    return (realign_score, realign_cigar)


# Mapping built from the GFA; handed to func as name_to_seq.
name_to_seq = sgfa.gfa_name_mapping(gfa)
# Try the realignment on the first two rows only. The (score, cigar) tuples come
# back as columns "column_0"/"column_1", which are renamed here; the score column
# is then cast to Int32.
df.head(2).select(pl.col("path_consensus", "consensus_seq")).map_rows(
    partial(func, name_to_seq=name_to_seq),
    # return_dtype=pl.Struct(dict(a=pl.Int32, b=pl.Int32)),
    # return_dtype=pl.Struct(dict(score_realign=pl.Int32, cg_realign=pl.Utf8)),
).rename({"column_0": "score_realign", "column_1": "cg_realign"}).with_columns(
    pl.col("score_realign").cast(pl.Int32)
)

# Identify variants

In [None]:
# Load combined_realigned.arrow (Arrow IPC) into a second Polars DataFrame.
# NOTE(review): hard-coded absolute path.
df2 = pl.read_ipc(
    "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/dorado_0.4.0/uncompressed/prepared/consensus_spoa2/align2/combined_realigned.arrow"
)

In [None]:
# Number of rows.
len(df2)

In [None]:
# Estimated size of the whole frame, in megabytes.
df2.estimated_size(unit="mb")

In [None]:
# The same estimate, broken down per column.
{col: df2.get_column(col).estimated_size(unit="mb") for col in df2.columns}

In [None]:
# First ten values of the "cg_realign" column as a Python list.
df2[:10, "cg_realign"].to_list()

In [None]:
# Use row 3 of df2 as the worked example for the cells below.
# This rebinds `cigar` and `path` from the earlier cells.
idx = 3
seq = df2[idx, "consensus_seq"]
cigar = df2[idx, "cg_realign"]
path = df2[idx, "path_consensus"].to_list()

In [None]:
# Show the raw "cg_realign" value for this row.
cigar

In [None]:
# Show the path entries concatenated into one string.
"".join(path)

In [None]:
# Decode the CIGAR value with align.decode_cigar.
cigar_d = align.decode_cigar(cigar)

In [None]:
# Show the decoded CIGAR.
cigar_d

In [None]:
# if isinstance(name_to_seq, Gfa):
#     name_to_seq = gfa_name_mapping(name_to_seq)
name_to_seq = sgfa.gfa_name_mapping(gfa)
# One sequence per path entry, looked up by the full path entry.
segments = [name_to_seq[name] for name in path]
# Each path entry is split into a name (everything after its first character)
# and a flag that is True when the first character is "<".
segment_names = [name[1:] for name in path]
segment_rc = [name[0] == "<" for name in path]
# segment_lengths = [len(s) for s in segments]
# Decode the CIGAR again and split each element into its op (c[0]) and
# its length (c[1]).
cigar_d = align.decode_cigar(cigar)
ops = [c[0] for c in cigar_d]
op_lengths = [c[1] for c in cigar_d]

In [None]:
# TODO: trim insertions from each flank? or associate those with first/last segment?
# option to do either?

In [None]:
# Walk the CIGAR ops and the path's segments in lockstep, tallying per segment
# how many positions each op covers. The result is `op_counts`:
# {segment_name: Counter({str(op): count})}.
segment_idx = 0
cigar_idx = 0
# NOTE: ref_idx and query_idx are initialised but never used in this cell.
ref_idx = 0
query_idx = 0
# seg_length = len(segments[0])
# op_counts = defaultdict(lambda: ))
op_counts = {}
# Remaining (unconsumed) length of the current segment and of the current op.
segment_length = len(segments[segment_idx])
segment_name = segment_names[segment_idx]
op = ops[cigar_idx]
op_length = op_lengths[cigar_idx]
while True:
    # Step by whichever runs out first: the current op or the current segment.
    # segment_length is None once every segment has been consumed; then only
    # the op length limits the step.
    advance = min(x for x in (op_length, segment_length) if x is not None)
    # Debug trace: one line per step.
    print(f"op {op} {op_length} seg {segment_name} {segment_length} advance {advance}")
    # if op in [align.CigarOp.I, align.CigarOp["="], align.CigarOp.X]:
    op_length -= advance
    # Only D, = and X shrink the remaining segment length; any other op leaves
    # it untouched. If segment_length is already None here, the subtraction
    # raises TypeError.
    if op in [align.CigarOp.D, align.CigarOp["="], align.CigarOp.X]:
        segment_length -= advance
    # Tally this step under the current segment, keyed by str(op).
    op_counts.setdefault(segment_name, Counter())
    op_counts[segment_name][str(op)] += advance
    # optionally append to segment_cigar
    if segment_length == 0:
        # Current segment fully consumed: move on to the next one.
        segment_idx += 1
        if segment_idx == len(segments):
            # No segments left. segment_name is not updated, so any remaining
            # ops are tallied under the last segment.
            segment_length = None
        else:
            segment_length = len(segments[segment_idx])
            segment_name = segment_names[segment_idx]
    if op_length == 0:
        # Current op fully consumed: move on to the next one.
        cigar_idx += 1
        if cigar_idx == len(ops):
            # NOTE(review): if segments remain at this point the loop does not
            # break; on the next pass op_length is still 0, cigar_idx is
            # incremented past len(ops), and ops[cigar_idx] raises IndexError.
            pass  # can we ever get here without immediately breaking below?
        else:
            op = ops[cigar_idx]
            op_length = op_lengths[cigar_idx]
    # TODO: need to wrap up
    # Done only when both the ops and the segments are exhausted.
    if cigar_idx == len(ops) and segment_idx == len(segments):
        break

In [None]:
# Show the per-segment op tallies built by the loop above.
op_counts

In [None]:
# Compare the str(), int(), repr() and format() representations of a CigarOp
# member (str(op) is the key used in op_counts).
(
    str(align.CigarOp.I),
    int(align.CigarOp.I),
    repr(align.CigarOp.I),
    "{}".format(align.CigarOp.I),
)

In [None]:
# NOTE(review): this cell is identical to the previous one.
(
    str(align.CigarOp.I),
    int(align.CigarOp.I),
    repr(align.CigarOp.I),
    "{}".format(align.CigarOp.I),
)

In [None]:
# Concatenate the sequences of every path entry (looked up in name_to_seq).
# A bare `return` is a SyntaxError at cell level, so this is a plain expression
# whose value the notebook displays.
"".join(name_to_seq[segment] for segment in path)

In [None]:
from Bio.Seq import Seq


def local_index(index, length, is_reverse, extra=0):
    """Return ``index``, flipped to ``length - index`` when ``is_reverse`` is truthy.

    ``extra`` is accepted for call-site compatibility but is not used in the
    computation.
    """
    return length - index if is_reverse else index

def reverse_complement(seq):
    """Return the reverse complement of ``seq`` as a plain ``str``.

    Implemented with Biopython's ``Seq.reverse_complement``. Defining this
    rebinds the ``reverse_complement`` name imported at the top of the notebook.
    """
    rc = Seq(seq).reverse_complement()
    return str(rc)


def reversed_seq(seq, is_reverse):
    """Return ``seq`` reverse-complemented when ``is_reverse`` is truthy, else unchanged."""
    return reverse_complement(seq) if is_reverse else seq


def normalize_alignment(msg):
    """Collect per-segment edits and read spans from an alignment message.

    Walks ``msg.path.mapping``. For every mapping, its ``edit`` entries are
    turned into ``Edit`` records positioned relative to the mapped segment
    (via ``local_index``), while a running ``read_index`` tracks how far into
    the read the alignment has progressed.

    Relies on names that are not defined in this function: ``segments``
    (must support ``.keys()`` and lookup by segment name),
    ``segment_degenerate_bases``, ``Edit``, ``Op``, ``local_index`` and
    ``reversed_seq``.

    Args:
        msg: Object exposing ``path.mapping``; each mapping exposes
            ``position.name``, ``position.offset``, ``position.is_reverse``
            and ``edit``; each edit exposes ``from_length``, ``to_length``
            and ``sequence``. (Presumably a vg ``Alignment`` message --
            TODO confirm.)

    Returns:
        ``(segment_edits, segment_read_indices)`` where ``segment_edits`` maps
        segment name -> list of ``Edit`` and ``segment_read_indices`` maps
        segment name -> ``(read_start, read_end)`` for that mapping.
    """
    segment_edits = {segment_name: [] for segment_name in segments.keys()}
    segment_read_indices = {}
    # for mapping in it.islice(msg.path.mapping, 7):
    read_index = 0
    for mapping in msg.path.mapping:
        segment_index = 0
        segment_name = mapping.position.name
        offset = mapping.position.offset
        if offset:
            # The mapping starts part-way into the segment.
            segment_index += offset
        segment_read_start = read_index
        is_reverse = mapping.position.is_reverse
        edits = segment_edits[segment_name]
        segment_length = len(segments[segment_name])
        for edit in mapping.edit:
            if edit.from_length == edit.to_length:
                if edit.sequence:
                    # snp
                    # TODO: eat matching bases
                    edits.append(
                        Edit(
                            Op.SUBSTITUTION,
                            local_index(
                                segment_index, segment_length, is_reverse, len(edit.sequence)
                            ),
                            reversed_seq(edit.sequence, is_reverse),
                            None,
                        )
                    )
                    segment_index += edit.from_length
                    read_index += edit.to_length
                else:
                    # match
                    degenerate_base_indices = segment_degenerate_bases[segment_name]
                    if degenerate_base_indices:
                        for base_index in degenerate_base_indices:
                            # TODO: handle is_reverse -- base_index is compared
                            # against segment_index as-is, so both are assumed
                            # to be in the same coordinate frame (unverified).
                            if segment_index <= base_index < segment_index + edit.from_length:
                                edits.append(
                                    Edit(
                                        Op.SUBSTITUTION,
                                        # Position of the degenerate base itself
                                        # (this argument was missing originally).
                                        local_index(base_index, segment_length, is_reverse, 1),
                                        # NOTE(review): edit.sequence is empty on
                                        # this branch, so no read base is recorded;
                                        # it likely needs to come from the read.
                                        reversed_seq(edit.sequence, is_reverse),
                                        None,
                                    )
                                )
                    # A match consumes segment and read bases whether or not the
                    # segment has degenerate positions.
                    segment_index += edit.from_length
                    read_index += edit.to_length
            elif not edit.to_length:
                # deletion
                edits.append(
                    Edit(
                        Op.DELETION,
                        local_index(segment_index, segment_length, is_reverse, edit.from_length),
                        None,
                        edit.from_length,
                    )
                )
                segment_index += edit.from_length
            elif edit.from_length < edit.to_length:
                # insertion
                # if from_length > 0: need to remove matching bases from both sides
                assert edit.from_length == 0
                edits.append(
                    Edit(
                        Op.INSERTION,
                        local_index(segment_index, segment_length, is_reverse, edit.to_length),
                        reversed_seq(edit.sequence, is_reverse),
                        None,
                    )
                )
                # TODO: shouldn't increment segment_index, right?
                # segment_index += edit.to_length
                read_index += edit.to_length
            # NOTE(review): edits with from_length > to_length > 0 match none of
            # the branches above and are skipped without advancing either index.
        segment_read_end = read_index
        segment_read_indices[segment_name] = (segment_read_start, segment_read_end)
        if is_reverse:
            # Flip the edit order for reverse-oriented mappings.
            edits.reverse()
        # TODO: merge like edits (i.e., degen base insertions)
    return segment_edits, segment_read_indices