In [None]:
import itertools as it
import operator
import re
import subprocess
import tempfile
import time
from collections import Counter
from pathlib import Path

import gfapy
import holoviews as hv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyfastx
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.cloning.design as design
import paulssonlab.cloning.sequence as sequence
import paulssonlab.sequencing.gaf as gaf_

In [None]:
# Load the bokeh plotting backend for holoviews.
hv.extension("bokeh")

# Functions

In [None]:
def mutagenize_seq(seq, q=0, error=0, letters="ATCG", rng=None):
    """Introduce random substitution errors into a sequence.

    The number of substitutions is drawn from Binomial(len(seq), error). That
    many distinct positions are chosen, and each is replaced by a letter drawn
    uniformly from ``letters``. Because ``letters`` defaults to upper case,
    substitutions stand out in a lower-case ``seq``, which makes debugging
    easier (GraphAligner doesn't care about case). A substitution may leave a
    position unchanged if the drawn letter equals the character already there.

    Args:
        seq: sequence to mutate.
        q: Phred quality; if non-zero, ``error = 10 ** (-q / 10)``.
        error: per-base substitution probability.
        letters: alphabet that substituted bases are drawn from.
        rng: numpy ``Generator``; a fresh unseeded one is created if None.

    Returns:
        The mutated sequence as a ``str`` (``seq`` itself when no
        substitutions are drawn).

    Raises:
        ValueError: if both ``q`` and ``error`` are non-zero.
    """
    if q and error:
        raise ValueError("at most one of q and error can be specified")
    if rng is None:
        rng = np.random.default_rng()
    if q:
        error = 10 ** (-q / 10)
    num_errors = rng.binomial(len(seq), error)
    if num_errors == 0:
        return seq
    # sample positions WITHOUT replacement so exactly num_errors distinct sites
    # are substituted (sampling with replacement silently undercounts errors)
    error_indices = rng.choice(len(seq), size=num_errors, replace=False)
    # edit a list and join once instead of re-slicing the string per error
    bases = list(seq)
    substitutions = list(letters)
    for idx in error_indices:
        bases[idx] = rng.choice(substitutions)
    return "".join(bases)


def generate_reads(segments, num_reads=100, q=0, error=0, rng=None):
    """Simulate reads from random paths through a linear variant graph.

    For each read, one variant is chosen uniformly at random per segment and the
    variants are concatenated. Substitution errors are applied to every base
    except the first and last (via ``mutagenize_seq``). The read is
    reverse-complemented with probability 1/2, and finally its last base is
    overwritten with "G".

    Args:
        segments: list of segments; each segment is a list of variant sequences.
        num_reads: number of reads to simulate.
        q: Phred quality used to derive the per-base error rate (mutually
            exclusive with ``error``).
        error: per-base substitution probability.
        rng: numpy ``Generator``; a fresh unseeded one is created if None.

    Returns:
        ``(formatted_reads, ground_truth)``. ``formatted_reads`` is a FASTA
        string with records named ``r0``, ``r1``, ... and a trailing newline.
        ``ground_truth`` is a dict with two keys. ``true_path`` is an array of
        shape ``(num_reads, len(segments))`` holding the chosen variant indices.
        ``reversed`` is a length-``num_reads`` array that is 1 where the read
        was reverse-complemented.
    """
    if rng is None:
        rng = np.random.default_rng()
    num_choices = np.array([len(variants) for variants in segments])
    # one variant index per (read, segment); the number of segments was
    # previously hard-coded to 3, which broke any other graph shape
    true_path = rng.integers(
        num_choices[np.newaxis, :], size=(num_reads, len(segments))
    )
    is_reversed = rng.integers(2, size=num_reads)
    reads = []
    for read_idx in range(num_reads):
        read = "".join(
            variants[variant_idx]
            for variants, variant_idx in zip(segments, true_path[read_idx])
        )
        # mutate only the interior so the terminal bases stay exact
        read = (
            read[0] + mutagenize_seq(read[1:-1], q=q, error=error, rng=rng) + read[-1]
        )
        if is_reversed[read_idx]:
            read = str(sequence.reverse_complement(read))
        # NOTE(review): unconditionally overwrites the final base with "G" (after
        # any reverse-complement); ground_truth does not record this. Looks like a
        # deliberate terminal-mismatch experiment -- confirm it is still wanted.
        read = read[:-1] + "G"
        # read = "G" + read[1:]
        reads.append(read)
    # add trailing newline
    formatted_reads = (
        "\n".join([f">r{idx}\n{read}" for idx, read in enumerate(reads)]) + "\n"
    )
    ground_truth = dict(true_path=true_path, reversed=is_reversed)
    return formatted_reads, ground_truth

In [None]:
def generate_gfa(segments):
    """Render a linear variant graph as a GFA 1.0 string.

    Variant ``v`` of segment ``s`` becomes an S-line node named ``s{s}={v}``.
    Every variant of segment ``s`` is then joined by a forward-strand,
    zero-overlap L-line to every variant of segment ``s + 1``. The returned
    text ends with a newline.
    """
    records = ["H\tVN:Z:1.0"]
    for seg_idx, variants in enumerate(segments):
        records.extend(
            f"S\ts{seg_idx}={var_idx}\t{seq}" for var_idx, seq in enumerate(variants)
        )
    for seg_idx in range(len(segments) - 1):
        n_left = len(segments[seg_idx])
        n_right = len(segments[seg_idx + 1])
        for left in range(n_left):
            for right in range(n_right):
                records.append(
                    f"L\ts{seg_idx}={left}\t+\ts{seg_idx + 1}={right}\t+\t0M"
                )
    # an empty final element makes join() emit the trailing newline
    records.append("")
    return "\n".join(records)

In [None]:
def run_aligner(
    gfa_filename,
    reads_filename,
    args=(),
    executable="/home/jqs1/paulsson-home/bin/GraphAligner",
):
    """Run GraphAligner on a GFA graph and a reads file and parse its GAF output.

    Args:
        gfa_filename: path to the GFA graph (passed as ``-g``).
        reads_filename: path to the reads file (passed as ``-f``).
        args: extra command-line arguments appended after the required ones.
        executable: GraphAligner binary to invoke (defaults to the path that
            was previously hard-coded).

    Returns:
        ``(gaf, runtime)``. ``gaf`` is the alignment output parsed by
        ``gaf_.read_gaf``. ``runtime`` is the elapsed time of the GraphAligner
        call in seconds.

    Raises:
        RuntimeError: if GraphAligner exits with a non-zero status; its stdout
            and stderr are printed first.
    """
    # alternative build: "/home/jqs1/micromamba/envs/seqtest/bin/GraphAligner"
    with tempfile.NamedTemporaryFile(mode="w+", suffix=".gaf") as gaf_file:
        cmd = [
            executable,
            "-g",
            gfa_filename,
            "-f",
            reads_filename,
            "-a",
            gaf_file.name,
            *args,
        ]
        # perf_counter is monotonic and meant for measuring durations
        start = time.perf_counter()
        out = subprocess.run(cmd, capture_output=True)
        runtime = time.perf_counter() - start
        stdout = out.stdout.decode()
        stderr = out.stderr.decode()
        if out.returncode != 0:
            print("STDOUT:")
            print(stdout)
            print()
            print("STDERR:")
            print(stderr)
            print()
            raise RuntimeError("GraphAligner returned non-zero exit status")
        print("STDOUT")
        print(stdout)
        print("STDERR")
        print(stderr)
        # GraphAligner wrote the alignments to the temp file by name; parse
        # them before the file is deleted on leaving the `with` block
        gaf = gaf_.read_gaf(gaf_file.name)
        return gaf, runtime


def run_aligner_synthetic(
    segments, args=(("-x", "vg"),), num_reads=4, q=0, rng=None, error=0
):
    """Simulate reads from a synthetic variant graph and align them with GraphAligner.

    Writes the graph from ``generate_gfa(segments)`` and the reads from
    ``generate_reads`` to temporary files, echoing both to stdout. It then runs
    ``run_aligner`` once per argument list in ``args`` against the same files.

    Args:
        segments: list of segments, each a list of variant sequences.
        args: iterable of GraphAligner argument lists; one run per entry.
        num_reads: number of reads to simulate.
        q: Phred quality for simulated substitution errors.
        rng: numpy ``Generator``; a fresh unseeded one is created if None.
        error: per-base substitution probability (alternative to ``q``).

    Returns:
        ``(res, ground_truth)``. ``res`` holds one ``(gaf, runtime)`` tuple per
        entry of ``args``, in order. ``ground_truth`` is the dict returned by
        ``generate_reads``.
    """
    if rng is None:
        rng = np.random.default_rng()
    with (
        tempfile.NamedTemporaryFile(mode="w+", suffix=".gfa") as gfa_file,
        tempfile.NamedTemporaryFile(mode="w+", suffix=".fasta") as reads_file,
    ):
        gfa = generate_gfa(segments)
        print(gfa)
        gfa_file.write(gfa)
        # the aligner opens these files by name, so flush buffered writes first
        gfa_file.flush()
        reads, ground_truth = generate_reads(
            segments, num_reads=num_reads, q=q, error=error, rng=rng
        )
        print(reads)
        reads_file.write(reads)
        reads_file.flush()
        res = [
            run_aligner(gfa_file.name, reads_file.name, args=cmd_args)
            for cmd_args in args
        ]
    return res, ground_truth

In [None]:
def check_path_equality(path, true_path):
    """Return True if an alignment path visits exactly the variants in ``true_path``.

    Args:
        path: list of oriented node names such as ``">s0=1"`` or ``"<s2=0"``,
            using the ``s{segment}={variant}`` naming written by ``generate_gfa``.
        true_path: sequence giving, for each segment index, the expected
            variant index.

    If the first node is reverse-oriented ("<"), the path is reversed before
    comparison so it is checked in segment order. An empty path only equals an
    empty ``true_path``. Node names that do not fully match the naming scheme
    count as a mismatch instead of raising.
    """
    if not path:
        return len(true_path) == 0
    if path[0].startswith("<"):
        path = path[::-1]
    if len(path) != len(true_path):
        return False
    for segment_idx, (node, expected_variant) in enumerate(zip(path, true_path)):
        match = re.fullmatch(r"[<>]s(\d+)=(\d+)", node)
        if match is None:
            return False
        if int(match.group(1)) != segment_idx:
            return False
        if int(match.group(2)) != expected_variant:
            return False
    return True


def check_alignment(gaf, ground_truth):
    """Return the set of synthetic read indices that were not aligned correctly.

    Args:
        gaf: alignment table with ``name`` and ``path`` columns (as returned by
            ``gaf_.read_gaf``). Read names are expected to be ``r{idx}``, as
            written by ``generate_reads``.
        ground_truth: dict from ``generate_reads`` with ``true_path`` and
            ``reversed`` entries.

    A read counts as an error in any of these cases:
    - one of its alignments has an empty path;
    - an alignment visits the wrong variants (``check_path_equality``);
    - an alignment has the wrong orientation;
    - the read has no alignment at all.

    Rows are matched to reads by name rather than by row position, so this
    does not assume the GAF holds exactly one row per read in input order.

    Raises:
        ValueError: if a read name does not have the ``r{idx}`` form.
    """
    true_paths = ground_truth["true_path"]
    reversed_flags = ground_truth["reversed"]
    name_col = gaf.column("name")
    path_col = gaf.column("path")
    errors = set()
    aligned = set()
    for row in range(len(gaf)):
        name = name_col[row].as_py()
        match = re.fullmatch(r"r(\d+)", name)
        if match is None:
            raise ValueError(f"unexpected read name in GAF: {name!r}")
        read_idx = int(match.group(1))
        aligned.add(read_idx)
        path = path_col[row].as_py()
        if not path:
            errors.add(read_idx)
            continue
        if not check_path_equality(path, true_paths[read_idx]):
            errors.add(read_idx)
        # a reverse-strand read should align with "<"-oriented nodes
        if (path[0][0] == "<") != bool(reversed_flags[read_idx]):
            errors.add(read_idx)
    # reads that produced no alignment at all are failures too
    errors.update(set(range(len(true_paths))) - aligned)
    return errors

# Config

In [None]:
# Input locations for this analysis.
# NOTE(review): absolute, user-specific path; consider making it configurable.
data_dir = Path("/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test")
carlos_key_fasta = data_dir / "230726_carlos/Savinov_Fragment_key.fasta"
# fragment key opened with pyfastx; presumably one record per variant (TODO confirm)
carlos_variants = pyfastx.Fasta(carlos_key_fasta)

# Test

In [None]:
# SEE: https://github.com/maickrau/GraphAligner/issues/18#issuecomment-635793833

In [None]:
# TODO: replace design.random_bases with sampling from an explicit np.random.Generator (rng) so the synthetic segments are reproducible

In [None]:
# segments = [["aa", "cc"], ["tttat", "cccgc"], ["tccccccc"]]
# segments = [["aa"*20, "cc"*20], ["tttat"*5, "cccgc"*5], ["tccccccc"*5]]
# segments = [
#     [design.random_bases(40) for i in range(2)],
#     [design.random_bases(40) for i in range(2)],
#     [design.random_bases(40) for i in range(2)],
# ]
# Synthetic graph for the alignment test: 3 segments x 2 variants, each variant
# from design.random_bases(40, "at") (presumably 40 random bases from "at").
segments = [[design.random_bases(40, "at") for _ in range(2)] for _ in range(3)]

In [None]:
%%time
base_args = ["-b", "100000", "-C", "500000", "--verbose"]
graphaligner_args = [
    # ["-x", "vg"],
    # [*base_args, "--seeds-mum-count", "-1"],
    # [*base_args, "--seeds-mem-count", "-1"],
    [
        *base_args,
        "--seedless-DP",
        # "--DP-restart-stride",
        # "1000",
        # "--precise-clipping",
        # "0.502",
        # "--X-drop",
        # "1000000",
    ],
]
res, ground_truth = run_aligner_synthetic(
    segments, num_reads=10, args=graphaligner_args, q=10, rng=np.random.default_rng(709)
)
errors = [check_alignment(run[0], ground_truth) for run in res]

In [None]:
# Number of misaligned reads for each GraphAligner argument set.
for run_errors, run_args in zip(errors, graphaligner_args):
    print(f"{' '.join(run_args).rjust(80)}: {len(run_errors)}")

In [None]:
[len(e) for e in errors]

In [None]:
# raw (gaf, runtime) results, one per argument set
res

In [None]:
ground_truth

In [None]:
# aligned paths from the first run
res[0][0].column("path")

# Completeness

In [None]:
%%time
segments = Counter()
ends = Counter()
total_reads = 0
for table in tqdm(gaf.iter_gaf(gaf_filename)):
    path_col = table.column("path")
    for idx in range(len(table)):
        path = [s[1:] for s in path_col[idx].as_py()]
        segments.update(path)
        ends[path[0]] += 1
        ends[path[-1]] += 1
        total_reads += 1

In [None]:
d = gaf.read_gaf(gaf_filename).to_pandas()

In [None]:
len(carlos_variants)

In [None]:
str(carlos_variants[0])

In [None]:
for k, v in sorted(
    {k: f"{v/total_reads*100:.0f}" for k, v in segments.items()}.items()
):
    print(f"{k}: {v}%")

In [None]:
for k, v in sorted({k: f"{v/total_reads*100:.0f}" for k, v in ends.items()}.items()):
    print(f"{k}: {v}%")

# Duplex barcode mismatches

In [None]:
%%time
total_reads = 0
complete_barcodes = 0
name_to_barcode = {}
for table in tqdm(parse_gaf(gaf_filename)):
    name_col = table.column("name")
    path_col = table.column("path")
    for idx in range(len(table)):
        name = name_col[idx].as_py()
        path = set([s[1:] for s in path_col[idx].as_py()])
        total_reads += 1
        if ("BIT0:0" in path or "BIT0:1" in path) and (
            "BIT29:0" in path or "BIT29:1" in path
        ):
            complete_barcodes += 1
            barcode = tuple(f"BIT{bit}:1" in path for bit in range(30))
            name_to_barcode[name] = barcode

In [None]:
(complete_barcodes, total_reads, complete_barcodes / total_reads)

In [None]:
duplex_matches = []
duplex_mismatches = []
duplex_missing = []
for name, barcode in tqdm(name_to_barcode.items()):
    reads = name.split(";")
    if len(reads) == 2:
        if reads[0] in name_to_barcode and reads[1] in name_to_barcode:
            if name_to_barcode[reads[0]] != name_to_barcode[reads[1]]:
                duplex_mismatches.append(name)
            else:
                duplex_matches.append(name)
        else:
            duplex_missing.append(name)

In [None]:
# Spot check: do these two specific reads decode to the same barcode?
# NOTE(review): hard-coded read IDs from one dataset; raises KeyError if either
# read is absent from name_to_barcode.
name_to_barcode["e7a0f1dc-d947-4265-9dd4-d4cda25a0928"] == name_to_barcode[
    "50815360-6914-41f9-8da8-1882c8db69e6"
]

In [None]:
# duplex reads where at least one constituent read has no complete barcode
len(duplex_missing)

In [None]:
len(duplex_matches)

In [None]:
len(duplex_mismatches)

In [None]:
# inspect one mismatching duplex read (IndexError if there are fewer than 11)
duplex_mismatches[10]