In [None]:
import itertools as it
import operator
import re
import subprocess
import tempfile
import time
from collections import Counter
from functools import partial
from pathlib import Path

import awkward as ak
import bottleneck as bn
import duckdb
import gfapy
import holoviews as hv
import ibis
import matplotlib.pyplot as plt
import numba
import numpy as np
import pandas as pd
import polars as pl
import pyabpoa
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyfastx
import pysam
import spoa
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.consensus as con
import paulssonlab.sequencing.io as sio
import paulssonlab.sequencing.processing as processing
from paulssonlab.util.sequence import reverse_complement

In [None]:
hv.extension("bokeh")

In [None]:
%load_ext pyinstrument
import line_profiler
import pyinstrument

%load_ext line_profiler

In [None]:
# Turn on polars' global string cache for the rest of the session.
pl.enable_string_cache()

# Config

In [None]:
# Graph (GFA) file that is passed to `align_reads` in the GraphAligner sections below.
# NOTE(review): absolute, user-specific path — the notebook only runs where this file exists.
gfa_filename = "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/pLIB419.gfa"

# Consensus

## Filtering by hash

In [None]:
%%time
# Lazily scan the prepared Arrow/IPC files, then materialize only the rows whose
# `path` hash is congruent to 4 modulo 10.
df = pl.scan_ipc(
    "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/dorado_0.4.0/uncompressed/prepared/*.arrow"
)
res = df.filter(pl.col("path").hash() % 10 == 4).collect()

In [None]:
%%time
# Hash every `path` value; keep both the polars Series and a NumPy copy of it.
hashes = df.select(pl.col("path").hash()).collect().get_column("path")
hashesn = hashes.to_numpy()

In [None]:
# Hash values reduced modulo 1000 (NumPy version).
hashesn % 1000

In [None]:
# How many rows fall into each of the residues modulo 100.
(hashes % 100).value_counts()

In [None]:
# Histogram of bucket sizes when the hashes are split into `num` buckets.
num = 10
# NOTE(review): the name of the count column ("counts") depends on the installed
# polars version — confirm it matches.
counts = (hashes % num).value_counts()["counts"]
plt.hist(counts, bins=50, range=(0, counts.max()));

## POA

In [None]:
# Two whitespace-separated CPU feature-flag lists, stored as sets so they can be
# compared with set arithmetic (presumably taken from two different hosts — TODO confirm).
cpu_f = {
    *"fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm ida arat pln pts dtherm tpr_shadow vnmi flexpriority ept vpid fsgsbase smep erms xsaveopt".split()
}
cpu_e = {
    *"fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm ida arat epb pln pts dtherm tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm xsaveopt cqm_llc cqm_occup_llc".split()
}

In [None]:
# Flags present in `cpu_e` but absent from `cpu_f`.
cpu_e - cpu_f

In [None]:
# Try the consensus helper on four toy reads, with return_phreds=True.
con.poa(["aaaat", "aaagt", "gaaat", "gaaag"], return_phreds=True)

In [None]:
# Format one toy read together with a uint8 quality array via `sio.format_fastx`
# and print the first formatted record.
print(
    list(sio.format_fastx(["AAAAT"], [np.array([0, 10, 20, 30, 40], dtype=np.uint8)]))[
        0
    ]
)

## Group by path

In [None]:
%%time
# Per-path consensus on a hash-based subsample of the prepared reads.
# Note: this rebinds `df` and `res` from the "Filtering by hash" cells.
df = pl.scan_ipc(
    "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/dorado_0.4.0/uncompressed/prepared/*.arrow"
)
res = (
    (
        # Keep rows whose `path` hash is congruent to 4 modulo 1000, run
        # `processing.compute_depth` on them (presumably this is what adds the
        # depth columns used below — confirm), then keep duplex_depth > 10.
        processing.compute_depth(df.filter(pl.col("path").hash() % 1000 == 4))
        .filter(pl.col("duplex_depth") > 10)
        .select(
            pl.col(
                "name",
                "is_duplex",
                "read_seq",
                "read_phred",
                "reverse_complement",
                "path",
                "depth",
                "simplex_depth",
                "duplex_depth",
            )
        )
        .group_by("path")
        .agg(
            # Hand each group's reads, packed as one struct column, to the
            # consensus routine; the result is stored under "consensus".
            pl.map_groups(
                pl.struct(
                    pl.col("name", "read_seq", "read_phred", "reverse_complement")
                ),
                # lambda df: con.get_consensus(df[0], return_phreds=True),
                partial(con.get_consensus_group_by, return_phreds=True),
                # return_dtype=pl.Struct(dict(consensus_seq=pl.Utf8, consensus_phred=pl.UInt8)),
                returns_scalar=True,
            ).alias("consensus"),
            # NOTE(review): "path" is both the group key and an aggregated column
            # here — confirm the installed polars version accepts that without a
            # duplicate-column error.
            pl.col("path", "depth", "simplex_depth", "duplex_depth").first(),
        )
    )
    # Expand the "consensus" struct into top-level columns, then run the lazy query.
    .unnest("consensus")
    .collect()
)

In [None]:
res

# Old consensus

In [None]:
# Number of reads per group (length of each "name" list), restricted to groups
# with more than one read and sorted from largest to smallest, as a NumPy array.
# NOTE(review): `test_groups` is not assigned in any cell visible in this notebook,
# so this section fails on a fresh kernel unless it is defined elsewhere.
group_depths = (
    test_groups.select(pl.col("name"))
    .with_columns(pl.col("name").list.len())
    .filter(pl.col("name") > 1)
    .sort("name", descending=True)
    .to_numpy()
)

In [None]:
# Group sizes in descending order. The trailing semicolon keeps the returned
# list of Line2D objects out of the cell output, as in the other plotting cells.
plt.plot(group_depths[:, 0]);

In [None]:
%%time
# Restrict to groups whose "name" list holds between 100 and 120 reads.
test_groups_subset = test_groups.filter(pl.col("name").list.len().is_between(100, 120))

In [None]:
test_groups_subset.select("depth", "duplex_depth")

In [None]:
# Explode the first group of the subset into one row per read, then add the read
# length in bytes and duplex/simplex flags (a read is flagged duplex when its
# name contains ";"). Duplex reads are sorted first.
test_reads = (
    test_groups_subset[0]
    .select(
        pl.col("name", "read_seq", "read_phred", "reverse_complement").list.explode(),
        pl.col("duplex_depth", "simplex_depth"),
    )
    .with_columns(
        pl.col("read_seq").str.len_bytes().alias("read_len"),
        pl.col("name").str.contains(";").alias("is_duplex"),
        pl.col("name").str.contains(";").not_().alias("is_simplex"),
    )
    .sort("is_duplex", descending=True)
)
test_reads

In [None]:
# Distribution of read lengths within the group.
plt.hist(test_reads.get_column("read_len"), bins=100);

In [None]:
# Keep only reads whose length lies between 3300 and 3700.
# Note: this rebinds `test_reads`, so the unfiltered table is no longer available
# after this cell has run.
test_reads = test_reads.filter(pl.col("read_len").is_between(3300, 3700))

In [None]:
%%time
# Turn the read table into the `seqs` / `phreds` inputs used by the MSA and
# consensus helpers below.
seqs, phreds = con.prepare_reads(
    test_reads.get_column("read_seq").to_list(),
    test_reads.get_column("reverse_complement").to_arrow(),
    test_reads.get_column("read_phred").to_arrow(),
)

In [None]:
%%time
# Multiple sequence alignment of all reads with abPOA, aln_mode="l"
# (presumably local alignment — confirm against `con.msa`).
msa_seqs = con.msa(seqs, method="abpoa", aln_mode="l")

In [None]:
# NOTE(review): `consensus_global` and `consensus_local` are not assigned in any
# cell visible in this notebook; this cell relies on state from elsewhere.
seq_global = con.chars_to_str(consensus_global[0])
seq_local = con.chars_to_str(consensus_local[0])
(len(seq_global), len(seq_local))

In [None]:
print(seq_global)

In [None]:
print(seq_local)

In [None]:
# Print the alignment together with the per-read phred scores.
con.print_msa(msa_seqs, phreds)

# GraphAligner

In [None]:
def run_aligner(gfa_filename, reads_filename, args=()):
    """Run GraphAligner on a reads file and load the alignments it writes.

    The alignments are written to a temporary ``.gaf`` file, which is parsed with
    ``sio.read_gaf`` before the file is removed on exit from the ``with`` block.

    Parameters
    ----------
    gfa_filename
        Path of the graph, passed to GraphAligner as ``-g``.
    reads_filename
        Path of the reads file, passed to GraphAligner as ``-f``.
    args
        Extra command-line arguments appended to the GraphAligner call.
        Any iterable of strings works; the default adds nothing.

    Returns
    -------
    tuple
        ``(gaf, runtime)``, where ``gaf`` is whatever ``sio.read_gaf`` returns and
        ``runtime`` is the elapsed time of the subprocess call in seconds.

    Raises
    ------
    RuntimeError
        If GraphAligner exits with a non-zero status. Its stdout and stderr are
        printed before the exception is raised.
    """
    cmd_base = ["/home/jqs1/paulsson-home/bin/GraphAligner"]
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".gaf") as gaf_file:
        cmd = [
            *cmd_base,
            "-g",
            gfa_filename,
            "-f",
            reads_filename,
            "-a",
            gaf_file.name,
            *args,
        ]
        # perf_counter is a monotonic clock, so the measured duration is not
        # affected by system clock adjustments while the aligner is running.
        start = time.perf_counter()
        out = subprocess.run(cmd, capture_output=True)
        runtime = time.perf_counter() - start
        if out.returncode != 0:
            # Decode leniently: a stray non-UTF-8 byte in the tool's output must
            # not replace the real failure with a UnicodeDecodeError.
            print("STDOUT:")
            print(out.stdout.decode(errors="replace"))
            print()
            print("STDERR:")
            print(out.stderr.decode(errors="replace"))
            print()
            raise RuntimeError("GraphAligner returned non-zero exit status")
        gaf = sio.read_gaf(gaf_file.name)
        return gaf, runtime


def align_reads(gfa_filename, reads, args=("-x", "vg")):
    """Align in-memory read sequences to a graph with GraphAligner.

    The reads are written to a temporary FASTA file, with records named ``r0``,
    ``r1``, ... in input order, and that file is handed to ``run_aligner``.

    Parameters
    ----------
    gfa_filename
        Path of the graph to align against.
    reads
        Iterable of read sequences (anything that formats as the sequence text).
        If it is empty, the FASTA file contains only a single newline.
    args
        Extra GraphAligner command-line arguments; defaults to ``-x vg``.
        An immutable tuple is used so the default cannot be modified by accident.

    Returns
    -------
    The parsed GAF returned by ``run_aligner``; the runtime is discarded.
    """
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".fasta") as reads_file:
        formatted_reads = (
            "\n".join(f">r{idx}\n{read}" for idx, read in enumerate(reads)) + "\n"
        )
        reads_file.write(formatted_reads)
        # Flush so the aligner subprocess sees the full file contents on disk.
        reads_file.flush()
        gaf, _runtime = run_aligner(gfa_filename, reads_file.name, args=args)
        return gaf

## Read group alignment

In [None]:
# Display the first group of the subset.
test_groups_subset[0]

In [None]:
# Explode one read group into a per-read table with the read length in bytes and
# duplex/simplex flags (a read is flagged duplex when its name contains ";").
# Duplex reads are sorted first.
# The group is pinned to index 0, the group displayed in the previous cell. The
# earlier version indexed with `idx`, which is only bound by the loop in the
# "Consensus alignment" section, so this cell failed on a fresh kernel.
test_reads = (
    test_groups_subset[0]
    .select(
        pl.col("name", "read_seq", "read_phred", "reverse_complement").list.explode(),
        pl.col("duplex_depth", "simplex_depth"),
    )
    .with_columns(
        pl.col("read_seq").str.len_bytes().alias("read_len"),
        pl.col("name").str.contains(";").alias("is_duplex"),
        pl.col("name").str.contains(";").not_().alias("is_simplex"),
    )
    .sort("is_duplex", descending=True)
)

In [None]:
%%time
# Turn the read table into the `seqs` / `phreds` inputs used by the cells below.
seqs, phreds = con.prepare_reads(
    test_reads.get_column("read_seq").to_list(),
    test_reads.get_column("reverse_complement").to_arrow(),
    test_reads.get_column("read_phred").to_arrow(),
)

In [None]:
%%time
msa_seqs = con.msa(seqs[:5], method="abpoa", aln_mode="g")
(
    consensus_seq,
    consensus_phred,
    nonconsensus_seq,
    nonconsensus_phred,
) = con.phred_weighted_consensus(msa_seqs, phreds)

In [None]:
%%time
# Align every read of the group to the graph, passing "-x dbg" to GraphAligner.
gaf = align_reads(gfa_filename, seqs, args=["-x", "dbg"])

In [None]:
gaf["path_length"]

In [None]:
# Distribution of the "NM" column across the alignments.
plt.hist(gaf["NM"], bins=100);

In [None]:
# Run pyabpoa directly on the first ten reads, requesting both consensus and MSA.
# Note: this rebinds `res`, which held the per-path consensus table earlier.
aligner = pyabpoa.msa_aligner(aln_mode="g")
res = aligner.msa(seqs[:10], out_cons=True, out_msa=True)
# msa_seqs = res.msa_seq

In [None]:
# Same ten reads through spoa, for comparison.
# Note: this rebinds `msa_seqs`, replacing the result of `con.msa`.
con_seq, msa_seqs = spoa.poa(seqs[:10])

In [None]:
print(con_seq)

In [None]:
print(res.cons_seq[0])

In [None]:
print(seqs[1])

In [None]:
res.msa_seq[0]

In [None]:
print(con.chars_to_str(consensus_seq))

In [None]:
# NOTE(review): when the cells are run top to bottom, `msa_seqs` here is spoa's
# output from the cell above, not the `con.msa` result — confirm that is intended.
print(con.chars_to_str(msa_seqs[3], True))

In [None]:
gaf["cg"]

In [None]:
gaf.column_names

In [None]:
gaf["name"]

In [None]:
print(seqs[0])

In [None]:
# Print the reads in the same FASTA layout that `align_reads` writes to its
# temporary file.
print("\n".join([f">r{idx}\n{seq}" for idx, seq in enumerate(seqs)]) + "\n")

## Consensus alignment

In [None]:
%%time
res = []
for idx in trange(len(test_groups_subset)):
    test_reads = (
        test_groups_subset[idx]
        .select(
            pl.col(
                "name", "read_seq", "read_phred", "reverse_complement"
            ).list.explode(),
            pl.col("duplex_depth", "simplex_depth"),
        )
        .with_columns(
            pl.col("read_seq").str.len_bytes().alias("read_len"),
            pl.col("name").str.contains(";").alias("is_duplex"),
            pl.col("name").str.contains(";").not_().alias("is_simplex"),
        )
        .sort("is_duplex", descending=True)
    )
    seqs, phreds = con.prepare_reads(
        test_reads.get_column("read_seq").to_list(),
        test_reads.get_column("reverse_complement").to_arrow(),
        test_reads.get_column("read_phred").to_arrow(),
    )
    msa_seqs = con.msa(seqs, method="abpoa", aln_mode="l")
    (
        consensus_seq,
        consensus_phred,
        nonconsensus_seq,
        nonconsensus_phred,
    ) = con.phred_weighted_consensus(msa_seqs, phreds)
    res.append(
        dict(
            consensus_seq=consensus_seq,
            consensus_phred=consensus_phred,
            nonconsensus_seq=nonconsensus_seq,
            nonconsensus_phred=nonconsensus_phred,
        )
    )

In [None]:
%%time
# Align each group's consensus sequence (converted to a string) to the graph,
# passing "-x vg" to GraphAligner.
gaf = align_reads(
    gfa_filename, [con.chars_to_str(r["consensus_seq"]) for r in res], args=["-x", "vg"]
)

In [None]:
gaf["name"]

In [None]:
gaf