In [None]:
import itertools as it
import operator
import re
import subprocess
import tempfile
import time
from collections import Counter
from pathlib import Path

import gfapy
import holoviews as hv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyfastx
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.cloning.design as design
import paulssonlab.cloning.sequence as sequence
import paulssonlab.sequencing.io as sio

In [None]:
# Activate HoloViews with the Bokeh plotting backend for this notebook.
hv.extension("bokeh")

# Functions

In [None]:
def mutagenize_seq(seq, q=0, error=0, letters="ATCG", rng=None):
    """Return a copy of ``seq`` with random substitution errors.

    Each position is mutated independently with probability ``error``, or
    ``10 ** (-q / 10)`` when a Phred-style quality ``q`` is given instead.
    A mutated position is replaced by a letter from ``letters`` that differs
    (ignoring case) from the original base, so every drawn error really
    changes the base. Substituted letters are inserted exactly as written in
    ``letters`` (upper-case by default), which marks errors in lower-case
    input and makes debugging easier (GraphAligner doesn't care about case).

    Args:
        seq: sequence to mutate.
        q: quality score; mutually exclusive with ``error``.
        error: per-base substitution probability, in [0, 1].
        letters: alphabet that substitutions are drawn from.
        rng: ``numpy.random.Generator``; a fresh unseeded one is used if None.

    Returns:
        A string with the same length as ``seq``.

    Raises:
        ValueError: if both ``q`` and ``error`` are given, or if the resulting
            error probability lies outside [0, 1].
    """
    if q and error:
        raise ValueError("at most one of q and error can be specified")
    if rng is None:
        rng = np.random.default_rng()
    if q:
        error = 10 ** (-q / 10)
    if not 0 <= error <= 1:
        raise ValueError(f"error probability must be in [0, 1], got {error}")
    letters = list(letters)
    bases = list(seq)
    num_errors = rng.binomial(len(bases), error)
    # Sample positions WITHOUT replacement: under the binomial model each
    # position is mutated at most once, so num_errors distinct sites change.
    error_indices = rng.choice(len(bases), size=num_errors, replace=False)
    for idx in error_indices:
        original = bases[idx].upper()
        # exclude the original base so a substitution is never a silent no-op
        candidates = [letter for letter in letters if letter.upper() != original]
        if not candidates:
            # alphabet offers no different base; leave this position unchanged
            continue
        bases[idx] = rng.choice(candidates)
    return "".join(bases)


def generate_reads(segments, num_reads=100, q=0, error=0, rng=None):
    """Simulate reads from a linear variant graph and return them as FASTA text.

    For every read, one variant is drawn uniformly at random from each segment
    and the chosen variants are concatenated. Substitution errors are then
    added with ``mutagenize_seq`` (controlled by ``q`` or ``error``), and each
    read is reverse-complemented with probability 1/2.

    Returns:
        ``(formatted_reads, ground_truth)``: ``formatted_reads`` is FASTA text
        with records named ``r0``, ``r1``, ... and a trailing newline;
        ``ground_truth`` is a dict with ``true_path`` (a
        ``(num_reads, len(segments))`` array of chosen variant indices) and
        ``reversed`` (a length-``num_reads`` array of 0/1 flags).
    """
    rng = np.random.default_rng() if rng is None else rng
    variants_per_segment = np.array([len(variants) for variants in segments])
    # row i holds the variant index chosen in each segment for read i
    true_path = rng.integers(
        variants_per_segment[np.newaxis, :], size=(num_reads, len(segments))
    )
    is_reversed = rng.integers(2, size=num_reads)
    records = []
    for read_idx, (choices, flip) in enumerate(zip(true_path, is_reversed)):
        read = "".join(
            variants[choice] for variants, choice in zip(segments, choices)
        )
        read = mutagenize_seq(read, q=q, error=error, rng=rng)
        # TODO (experiment): keep the read's first and last base error-free:
        #   read = read[0] + mutagenize_seq(read[1:-1], q=q, error=error, rng=rng) + read[-1]
        if flip:
            read = str(sequence.reverse_complement(read))
        # TODO (experiment): perturb the read ends after orientation, e.g.
        #   read = "N" + read
        #   read = read[:-1] + "G"
        #   read = "G" + read[1:]
        records.append(f">r{read_idx}\n{read}")
    # join records and add the trailing newline
    formatted_reads = "\n".join(records) + "\n"
    return formatted_reads, dict(true_path=true_path, reversed=is_reversed)

In [None]:
def generate_gfa(segments):
    """Render a linear variant graph as GFA 1.0 text.

    ``segments[i][j]`` is the sequence of variant ``j`` of segment ``i``; it
    becomes node ``s{i}={j}``. Every variant of segment ``i`` is linked
    (forward strand on both ends, ``0M`` overlap) to every variant of segment
    ``i + 1``. All ``S`` lines precede all ``L`` lines, and the returned text
    ends with a newline.
    """
    gfa_lines = ["H\tVN:Z:1.0"]
    for seg_idx, variants in enumerate(segments):
        for var_idx, seq in enumerate(variants):
            gfa_lines.append(f"S\ts{seg_idx}={var_idx}\t{seq}")
    for seg_idx, (left, right) in enumerate(zip(segments, segments[1:])):
        for left_idx, right_idx in it.product(range(len(left)), range(len(right))):
            gfa_lines.append(
                f"L\ts{seg_idx}={left_idx}\t+\ts{seg_idx + 1}={right_idx}\t+\t0M"
            )
    return "\n".join(gfa_lines) + "\n"

In [None]:
def run_aligner(
    gfa_filename,
    reads_filename,
    args=(),
    executable="/home/jqs1/micromamba/envs/graphaligner/bin/GraphAligner",
    verbose=True,
):
    """Align reads to a GFA graph with GraphAligner and load the resulting GAF.

    Args:
        gfa_filename: path to the graph (passed as ``-g``).
        reads_filename: path to the reads (passed as ``-f``).
        args: extra command-line arguments appended after the required ones.
        executable: GraphAligner binary to run. Defaults to the path that used
            to be hard-coded here; override to use another installation
            (e.g. "/home/jqs1/paulsson-home/bin/GraphAligner").
        verbose: if True (the default), print GraphAligner's stdout/stderr
            after a successful run.

    Returns:
        ``(gaf, runtime)``: the alignments parsed by ``sio.read_gaf`` and the
        wall-clock runtime of the GraphAligner process in seconds.

    Raises:
        RuntimeError: if GraphAligner exits with a non-zero status; its
            stdout/stderr are printed first.
    """
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".gaf") as gaf_file:
        cmd = [
            executable,
            "-g",
            gfa_filename,
            "-f",
            reads_filename,
            "-a",
            gaf_file.name,
            *args,
        ]
        # perf_counter is monotonic, unlike time.time(), so it suits durations
        start = time.perf_counter()
        out = subprocess.run(cmd, capture_output=True)
        runtime = time.perf_counter() - start
        # tolerate non-UTF-8 bytes in tool output instead of failing to report it
        stdout = out.stdout.decode(errors="replace")
        stderr = out.stderr.decode(errors="replace")
        if out.returncode != 0:
            print("STDOUT:")
            print(stdout)
            print()
            print("STDERR:")
            print(stderr)
            print()
            raise RuntimeError(
                f"GraphAligner returned non-zero exit status {out.returncode}"
            )
        if verbose:
            print("STDOUT")
            print(stdout)
            print("STDERR")
            print(stderr)
        # read before leaving the with-block: the temporary GAF is deleted on exit
        gaf = sio.read_gaf(gaf_file.name)
    return gaf, runtime


def run_aligner_synthetic(segments, args=None, num_reads=4, q=0, rng=None, error=0):
    """Align simulated reads against the graph built from ``segments``.

    Writes the graph (``generate_gfa``) and ``num_reads`` simulated reads
    (``generate_reads``) to temporary files, then runs GraphAligner once per
    entry of ``args``.

    Args:
        segments: variant graph; ``segments[i][j]`` is the sequence of variant
            ``j`` of segment ``i``.
        args: list of extra-argument lists, one GraphAligner run each.
            Defaults to ``[["-x", "vg"]]``.
        num_reads: number of reads to simulate.
        q, error: read error model forwarded to ``generate_reads``; at most
            one of them may be non-zero.
        rng: ``numpy.random.Generator``; a fresh unseeded one is used if None.

    Returns:
        ``(res, ground_truth)``: ``res[k]`` is the ``(gaf, runtime)`` pair
        returned by ``run_aligner`` for ``args[k]``; ``ground_truth`` is the
        dict returned by ``generate_reads``.
    """
    if args is None:
        # avoid a shared mutable default argument
        args = [["-x", "vg"]]
    if rng is None:
        rng = np.random.default_rng()
    with (
        tempfile.NamedTemporaryFile(mode="w+", suffix=".gfa") as gfa_file,
        tempfile.NamedTemporaryFile(mode="w+", suffix=".fasta") as reads_file,
    ):
        gfa_file.write(generate_gfa(segments))
        # flush so the GraphAligner subprocess sees the complete file
        gfa_file.flush()
        reads, ground_truth = generate_reads(
            segments, num_reads=num_reads, q=q, error=error, rng=rng
        )
        reads_file.write(reads)
        reads_file.flush()
        res = [
            run_aligner(gfa_file.name, reads_file.name, args=cmd_args)
            for cmd_args in args
        ]
    return res, ground_truth

In [None]:
def check_path_equality(path, true_path):
    """Return True if an aligned graph path visits exactly the expected variants.

    ``path`` is a sequence of oriented node names such as ``">s0=1"`` or
    ``"<s2=0"`` (the ``s{segment}={variant}`` naming used by
    ``generate_gfa``); ``true_path[i]`` is the expected variant index for
    segment ``i``. If the first node is reverse-oriented (``<``) the path is
    reversed before comparing. An empty path matches only an empty
    ``true_path``. A node name that doesn't follow the naming scheme never
    matches. Both of these cases previously raised an exception.
    """
    if len(path) == 0:
        return len(true_path) == 0
    if path[0][0] == "<":
        path = path[::-1]
    if len(path) != len(true_path):
        return False
    for segment_idx, (node, expected_variant) in enumerate(zip(path, true_path)):
        match = re.match(r"(?:<|>)s(\d+)=(\d+)", node)
        if match is None:
            return False
        if int(match.group(1)) != segment_idx:
            return False
        if int(match.group(2)) != expected_variant:
            return False
    return True


def check_alignment(gaf, ground_truth):
    """Return the set of GAF row indices whose alignment disagrees with the truth.

    Row ``idx`` of ``gaf`` is compared with read ``idx`` of ``ground_truth``
    (the dict returned by ``generate_reads``). A row is flagged when its path
    does not visit the expected variants (``check_path_equality``) or when its
    orientation does not match ``ground_truth["reversed"][idx]``. The
    orientation counts as reverse if the first path node starts with ``<``.

    NOTE(review): pairing rows with reads by position assumes GraphAligner
    writes exactly one alignment per read, in input order; unaligned or
    multiply-aligned reads would shift the pairing. Confirm, or match rows to
    reads by query name instead.
    """
    flagged = set()
    for row_idx in range(len(gaf)):
        path = gaf.column("path")[row_idx].as_py()
        path_matches = check_path_equality(path, ground_truth["true_path"][row_idx])
        aligned_reverse = path[0][0] == "<"
        orientation_matches = aligned_reverse == ground_truth["reversed"][row_idx]
        if not (path_matches and orientation_matches):
            flagged.add(row_idx)
    return flagged

# Config

In [None]:
# data_dir = Path("/home/jqs1/scratch/sequencing/230707_repressilators/20230707_2040_MN35044_FAS94231_25542e0d/_temp/")
# carlos_variants = pyfastx.Fasta(data_dir / "230726_carlos/Savinov_Fragment_key.fasta")

# Test

In [None]:
# SEE: https://github.com/maickrau/GraphAligner/issues/18#issuecomment-635793833

In [None]:
# TODO: replace design.random_bases with draws from a seeded rng so the segments below are reproducible?

In [None]:
# Smaller graphs used while debugging:
# segments = [["aa", "cc"], ["tttat", "cccgc"], ["tccccccc"]]
# segments = [["aa"*20, "cc"*20], ["tttat"*5, "cccgc"*5], ["tccccccc"*5]]
# segments = [[design.random_bases(40) for i in range(2)] for _ in range(3)]

# Five segments given as (variant length, number of variants): two-variant
# 40-mers around a middle segment with ten 10-mer variants. Variants are
# lower-case so substitutions from mutagenize_seq stand out.
# NOTE: design.random_bases is called without an rng, so presumably these
# sequences change between runs even though the reads use a fixed seed.
segments = [
    [design.random_bases(variant_length, "atcg") for _ in range(num_variants)]
    for variant_length, num_variants in [(40, 2), (40, 2), (10, 10), (40, 2), (40, 2)]
]

In [None]:
%%time
# Flags shared by every GraphAligner run below. Earlier variants:
#   ["-b", "100000", "-C", "500000", "--verbose"]
#   ["--verbose", "--seeds-clustersize", "0"]
base_args = ["--verbose"]

# One GraphAligner invocation per entry; only the seedless-DP configuration
# is currently active. Flag meanings per GraphAligner --help (verify against
# the installed version): -b alignment bandwidth, -C tangle effort,
# --seedless-DP align without seeding, --DP-restart-stride restart interval
# for seedless DP.
graphaligner_args = [
    base_args
    + ["-b", "100000", "-C", "500000", "--seedless-DP", "--DP-restart-stride", "1000"],
]

# Configurations tried previously (add back into the list above to compare):
# [*base_args, "-x", "vg"],
# [*base_args, "-x", "dbg"],
# [*base_args, "-b", "15", "-C", "-1", "--seeds-minimizer-windowsize", "20",
#  "--seeds-minimizer-density", "0", "--seeds-minimizer-length", "10"],
# [*base_args, "--seeds-mxm-windowsize", "500", "--seeds-mxm-length", "30",
#  "--seeds-mem-count", "10000", "--bandwidth", "15",
#  "--min-alignment-score", "5000", "--max-trace-count", "5"],
#   (also tried with: "--multimap-score-fraction", "0.99",
#    "--precise-clipping", "0.85", "--clip-ambiguous-ends", "100",
#    "--overlap-incompatible-cutoff", "0.15")
# [*base_args, "-b", "15", "-C", "-1", "--seeds-mum-count", "-1"],
# [*base_args, "-b", "15", "-C", "-1", "--seeds-mem-count", "-1"],
# [*base_args, "-b", "15", "-C", "-1", "--seeds-mum-count", "-1",
#  "--max-trace-count", "-1", "--max-cluster-extend", "-1"],
# [*base_args, "-b", "15", "-C", "-1", "--seeds-mem-count", "-1",
#  "--max-trace-count", "-1", "--max-cluster-extend", "-1"],
# [*base_args],
#   (optionally with: "--seedless-DP", "--DP-restart-stride", "1000",
#    "--precise-clipping", "0.502", "--X-drop", "1000000")
# Simulate 100 reads at Q10 (error probability 10 ** (-10 / 10) = 0.1 per
# base) with a fixed seed, align them once per argument set, then score every
# run against the simulated ground truth.
res, ground_truth = run_aligner_synthetic(
    segments,
    args=graphaligner_args,
    num_reads=100,
    q=10,
    rng=np.random.default_rng(714),
)
errors = [check_alignment(gaf, ground_truth) for gaf, _ in res]
times = [runtime for _, runtime in res]

In [None]:
# "path" column of the alignments from the first GraphAligner run (res[k] is a (gaf, runtime) pair).
res[0][0]["path"]

In [None]:
# One line per GraphAligner configuration: its arguments (right-aligned), the
# number of GAF rows flagged by check_alignment, and the wall-clock runtime.
for run_args, run_errors, t in zip(graphaligner_args, errors, times):
    print(f"{' '.join(run_args).rjust(80)}: {len(run_errors)} ({t:.2f}s)")

In [None]:
# Raw (gaf, runtime) pairs, one per entry of graphaligner_args.
res

In [None]:
# Simulated truth: "true_path" (variant index per read and segment) and "reversed" (0/1 per read).
ground_truth

In [None]:
# Same "path" column via the .column() accessor (presumably sio.read_gaf returns a pyarrow Table — confirm).
res[0][0].column("path")

# GFA name mapping

In [None]:
# Load the pLIB419 reference graph to experiment with segment renaming.
# NOTE: absolute, user-specific path; consider a configurable data directory.
g = gfapy.Gfa.from_file(
    "/home/jqs1/scratch/sequencing/230707_repressilators/"
    "20230707_2040_MN35044_FAS94231_25542e0d/references/pLIB419.gfa"
)

In [None]:
# Rename the first segment in place (mutates `g`; re-run the loading cell to restore the original name).
g.segments[0].name = "foo"

In [None]:
# Serialize the (possibly renamed) graph back to GFA1 text for inspection.
print(g.to_gfa1_s())

# RecGraph mtx

In [None]:
# IUPAC nucleotide ambiguity codes, mapped to the standard bases each one stands for.
DEGENERATE_BASES = {
    "R": "AG",
    "Y": "CT",
    "M": "AC",
    "K": "GT",
    "S": "CG",
    "W": "AT",
    "B": "CGT",
    "D": "AGT",
    "H": "ACT",
    "V": "ACG",
    "N": "ACGT",
}


def degenerate_recgraph_matrix(
    match,
    mismatch,
    degenerate_match=None,
    degenerate_mismatch=None,
    degenerate_bases=DEGENERATE_BASES,
):
    """Build a symmetric substitution-score matrix over ATCG plus degenerate codes.

    Rows and columns are ordered ``"ATCG"`` followed by the keys of
    ``degenerate_bases``. Scores:

    * identical standard bases (A/A, T/T, ...): ``match``
    * different standard bases: ``mismatch``
    * a degenerate code vs. a standard base it includes: ``degenerate_match``
    * a degenerate code vs. a standard base it excludes: ``degenerate_mismatch``
    * any pair of degenerate codes (including a code with itself): ``mismatch``
      (NOTE(review): confirm this is intended for e.g. R/R or N/N)

    Args:
        match: score for identical standard bases.
        mismatch: score for differing standard bases; also the fill value for
            degenerate/degenerate pairs.
        degenerate_match: defaults to ``match``.
        degenerate_mismatch: defaults to ``mismatch``.
        degenerate_bases: mapping from degenerate code to the standard bases
            (drawn from "ATCG") it represents.

    Returns:
        ``(matrix, alphabet_aliases)``: ``matrix`` is an ``(n, n)`` numpy array
        with ``n = 4 + len(degenerate_bases)`` and a dtype that holds every
        score (float if any score is a float). ``alphabet_aliases``
        concatenates, for each degenerate code ``D`` and each base ``B`` it
        includes, the four characters ``B D D B``.
    """
    if degenerate_match is None:
        degenerate_match = match
    if degenerate_mismatch is None:
        degenerate_mismatch = mismatch
    standard_bases = "ATCG"
    bases = standard_bases + "".join(degenerate_bases.keys())
    num_bases = len(bases)
    base_to_idx = {base: idx for idx, base in enumerate(bases)}
    # One common dtype for all scores, so e.g. a fractional degenerate_match is
    # not truncated when written into a matrix filled from an integer mismatch.
    dtype = np.asarray([match, mismatch, degenerate_match, degenerate_mismatch]).dtype
    matrix = np.full((num_bases, num_bases), mismatch, dtype=dtype)
    # Identical standard bases are matches; otherwise the diagonal would stay at `mismatch`.
    for base in standard_bases:
        idx = base_to_idx[base]
        matrix[idx, idx] = match
    for deg_base, matching_bases in degenerate_bases.items():
        idx = base_to_idx[deg_base]
        excluded_bases = set(standard_bases) - set(matching_bases)
        for base in matching_bases:
            idx2 = base_to_idx[base]
            matrix[idx, idx2] = matrix[idx2, idx] = degenerate_match
        for base in excluded_bases:
            idx2 = base_to_idx[base]
            matrix[idx, idx2] = matrix[idx2, idx] = degenerate_mismatch
    alphabet_aliases = "".join(
        f"{base}{deg_base}{deg_base}{base}"
        for deg_base, matching_bases in degenerate_bases.items()
        for base in matching_bases
    )
    return matrix, alphabet_aliases