In [None]:
import itertools as it
import operator
import re
import subprocess
import tempfile
import time
from collections import Counter
from pathlib import Path

import gfapy
import holoviews as hv
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import pyfastx
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.gfa as sgfa
import paulssonlab.sequencing.io as sio
import paulssonlab.sequencing.processing as processing

In [None]:
hv.extension("bokeh")

In [None]:
pl.enable_string_cache()

# prepare_reads.py for non-dorado duplex

In [None]:
# Read the barcode reference GFA, pass it through sgfa.filter_gfa with the
# exclude list below, and derive the DAG plus the forward-segment and endpoint
# lookups that the read-processing cells take as arguments.
gfa_filename = "/home/jqs1/scratch/sequencing/230707_repressilators/20230707_2040_MN35044_FAS94231_25542e0d/references/barcode.gfa"
gfa = sgfa.filter_gfa(
    gfapy.Gfa.from_file(gfa_filename),
    exclude=[
        "UNS9",
        "BC:UPSTREAM",
        "BC:JUNCTION",
        "BC:T7_TERM",
        "BC:SPACER2",
    ],
)
graph = sgfa.gfa_to_dag(gfa)
# nx.weakly_connected_components yields lazily; materialise it once so that
# both helpers below can be handed the same components.
wccs = [*nx.weakly_connected_components(graph)]
forward_segments = sgfa.dag_forward_segments(graph, wccs=wccs)
endpoints = sgfa.dag_endpoints(graph, wccs=wccs)

In [None]:
def barcode_path(bits_before_spacer, bits_after_spacer):
    """Return the forward (">"-oriented) barcode path for the given bit values.

    Args:
        bits_before_spacer: string of "0"/"1" characters, one per bit segment
            that precedes ``BC:SPACER1``.
        bits_after_spacer: string of "0"/"1" characters for the bit segments
            that follow ``BC:SPACER1``; bit numbering continues from the end
            of ``bits_before_spacer``.

    Returns:
        List of segment names of the form ``">BC:BIT<i>=<bit>"`` with
        ``">BC:SPACER1"`` inserted between the two groups of bits.
    """
    bits = bits_before_spacer + bits_after_spacer
    segments = [f">BC:BIT{idx}={bit}" for idx, bit in enumerate(bits)]
    segments.insert(len(bits_before_spacer), ">BC:SPACER1")
    return segments


def reverse_path(path):
    """Return ``path`` traversed from the opposite end.

    The segment order is reversed and each segment's leading orientation
    character is flipped (">" <-> "<").
    """
    flipped = {">": "<", "<": ">"}
    return [flipped[segment[0]] + segment[1:] for segment in reversed(path)]


# Three distinct barcodes: BIT0..BIT14, then BC:SPACER1, then BIT15..BIT29.
p1 = barcode_path("111010000111001", "011101011011100")
p2 = barcode_path("100111101100011", "100010011101011")
p3 = barcode_path("101000110011101", "001000010001100")

# This cell used to look each segment up in `reverse_path_mapping`, a name that
# is not defined anywhere in this notebook (NameError on a fresh kernel).
# NOTE(review): assumes the reverse-strand form of a path is the reversed
# segment list with "<" orientation, as in GAF path strings — confirm against
# processing.normalize_paths.
p1_r = reverse_path(p1)
p2_r = reverse_path(p2)
p3_r = reverse_path(p3)

# Synthetic reads as (name, path) pairs. The valid/invalid/null comments are
# the author's annotations for each group (presumably the expected `is_valid`
# outcome — confirm against processing.prepare_reads).
_test_cases = [
    # valid: a;b
    # invalid: a, b
    ("a", p1),
    ("b", p1_r),
    ("a;b", p1),
    # valid: c, d
    # invalid: c;d
    ("c", p2),
    ("d", p1_r),
    ("c;d", p1),
    # valid: e, f
    # invalid: e;f
    ("e", p2),
    ("f", p1_r),
    ("e;f", p3),
    # valid: g
    # invalid: g;h
    ("g", p1),
    ("g;h", p1),
    # valid: i
    # invalid: i;j
    ("i", p1),
    ("i;j", p2),
    # valid: k
    # invalid: k;l
    # null: l
    ("k", p1),
    ("l", p1_r),
    ("l", p2_r),
    ("k;l", p1),
    # valid: m, n
    # null: m;n
    ("m", p1),
    ("n", p1_r),
    ("m;n", p1),
    ("m;n", p2),
    # valid: o, p
    ("o", p1),
    ("p", p2),
    # valid: q, r
    # invalid: q;r
    ("q", p1),
    ("r", p1),
    ("q;r", p1),
    # valid: t;u
    # invalid: t, u
    ("t", p1),
    ("u", p1_r),
    ("t;u", p1_r),
]
test_data = [
    dict(name=read_name, path=read_path) for read_name, read_path in _test_cases
]
# dx is 1 for names that contain ";" and 0 otherwise.
test_df = pl.from_dicts(
    [{"dx": int(";" in row["name"]), **row} for row in test_data],
    schema={
        "name": pl.String,
        "path": pl.List(pl.Categorical),
        "dx": pl.Int8,
    },
)

In [None]:
test_df

In [None]:
test_df2 = processing.prepare_reads(test_df, forward_segments, endpoints)

In [None]:
test_df2

In [None]:
test_df2.select("name", "is_valid").to_numpy()

In [None]:
# NOTE(review): `df` is not assigned in any cell above this one (its first
# visible assignment is the `pl.scan_ipc(...)` cell further down), so on a
# fresh kernel this line raises NameError. That later cell repeats the same
# call, so this one looks like a leftover — confirm and remove.
df2 = processing.prepare_reads(df, forward_segments, endpoints)

In [None]:
s = pl.Series(["a", "b"], dtype=pl.Categorical)
s.replace("a", "c", default=None)

In [None]:
%%time
df = pl.scan_ipc(
    "/n/scratch/users/j/jqs1/sequencing/230707_repressilators/20230707_2040_MN35044_FAS94231_25542e0d/test_output/join_gaf_test/channel-1_merged.arrow"
)
df2 = processing.prepare_reads(df, forward_segments, endpoints)
df2 = df2.collect()

In [None]:
df2

In [None]:
df2["path"][:10].to_numpy()

In [None]:
# Rows whose is_valid flag is False (null flags are dropped by filter).
df2.filter(pl.col("is_valid").not_())

In [None]:
%%time
df = pl.scan_ipc(
    "/n/scratch/users/j/jqs1/sequencing/230707_repressilators/20230707_2040_MN35044_FAS94231_25542e0d/test_output/join_gaf_test/channel-1_merged.arrow"
)
df = processing.normalize_paths(df, forward_segments)
df = df.with_columns(pl.col("path").hash().alias("path_hash"))
df = processing.flag_end_to_end(df, endpoints)
df = df.with_columns(
    is_duplicate_alignment=pl.col("name").is_duplicated(),
).with_columns(_candidate=(~pl.col("is_duplicate_alignment") & pl.col("end_to_end")))
df = df.collect()

In [None]:
%%time
df_candidates = df.filter(pl.col("_candidate"))
df_input = df_candidates.with_columns(
    pl.col("name").str.split(";").alias("_parent_names"),
)
df_with_parents = (
    df_input.filter(pl.col("dx") == 1)
    .select(
        pl.col("name"),
        pl.col("path").alias("_duplex_path"),
        pl.col("_parent_names").alias("_parent_name"),
    )
    .explode("_parent_name")
    .join(
        df_candidates.select(pl.col("name"), pl.col("path").alias("_parent_path")),
        how="left",
        left_on=pl.col("_parent_name"),
        right_on=pl.col("name"),
    )
    .with_columns(
        (pl.col("_duplex_path") == pl.col("_parent_path")).alias("_paths_match_duplex")
    )
)
df_duplex_paths_match = df_with_parents.group_by("name").agg(
    pl.col("_paths_match_duplex").all(),
    pl.col("_parent_name"),
)

In [None]:
df_duplex_paths_match

In [None]:
df.filter(pl.col("name") == "23d73de1-49d0-478f-a95b-3d9a8287fea8")[
    "read_seq"
].to_pandas().values

In [None]:
df_with_parents.filter(~pl.col("_paths_match_duplex"))[["_parent_name"]].to_numpy()

In [None]:
# NOTE(review): iterating a polars DataFrame walks its columns, so this appears
# to print only the first mismatching row's duplex path and parent path (each
# joined without a separator) — confirm that is the intent.
mismatched_paths = df_with_parents.filter(~pl.col("_paths_match_duplex"))[
    ["_duplex_path", "_parent_path"]
]
print("\n\n".join(["".join(column[0]) for column in mismatched_paths]))

In [None]:
df_duplex_paths_match.filter(~pl.col("_paths_match_duplex"))

In [None]:
df_with_parents.with_columns(
    (pl.col("_duplex_path") == pl.col("_parent_path")).alias("_path_matches")
).filter(~pl.col("_path_matches"))

In [None]:
# NOTE(review): `df_paths_match` is not defined in any visible cell of this
# notebook, so this raises NameError on a fresh kernel. The cell just above
# builds the same `_path_matches` filter from `df_with_parents`.
df_paths_match.filter(~pl.col("_path_matches"))

In [None]:
df_with_parents

In [None]:
df_duplex_paths_match

In [None]:
%%time
# NOTE(review): `df_with_parents` is derived from `df` after `df.collect()`,
# so it is presumably an eager DataFrame with no `.collect()` method. This cell
# looks like it dates from when the pipeline was still lazy — confirm.
df_with_parents.head(10).collect()

In [None]:
# Compare each row's path with the first path seen for the same read name.
# NOTE(review): the cell that builds `df_with_parents` selects `_duplex_path`
# and `_parent_path`, not `path`, so `pl.col("path")` presumably fails here.
# Looks written against an earlier layout of `df_with_parents` — confirm.
(
    df_with_parents.with_columns(
        pl.col("path").first().over("name").alias("_path_first")
    ).with_columns(
        (pl.col("path") == pl.col("_path_first")).alias("_path_matches_first")
    )
)

In [None]:
%%time
# Alternative definition of df_duplex_paths_match: within each read name,
# compare every row's path with the first one and require that all match.
# The output column is `_duplex_paths_match`, which the following cells read.
# NOTE(review): the cell that builds `df_with_parents` selects `_duplex_path`
# and `_parent_path`, not `path`, so this presumably targets an earlier layout
# of that frame — confirm before relying on it.
df_duplex_paths_match = (
    df_with_parents.with_columns(
        pl.col("path").first().over("name").alias("_path_first")
    )
    .with_columns(
        (pl.col("path") == pl.col("_path_first")).alias("_path_matches_first")
    )
    .group_by("name")
    .agg(
        pl.col("_path_matches_first").all().alias("_duplex_paths_match"),
        pl.col("_parent_name"),
    )
)

In [None]:
df_duplex_paths_match

In [None]:
%%time
df_simplex_paths_match = df_duplex_paths_match.select(
    pl.col("_duplex_paths_match").alias("_duplex_child_path_matches"),
    pl.col("_parent_name"),
).explode("_parent_name")

In [None]:
df_simplex_paths_match

In [None]:
%%time
df_usable = (
    df_input.join(
        df_duplex_paths_match.select(
            pl.col("name"),
            pl.col("_duplex_paths_match").alias("_duplex_parent_paths_match"),
        ),
        how="left",
        on="name",
    )
    .join(df_simplex_paths_match, how="left", left_on="name", right_on="_parent_name")
    .with_columns(
        pl.when(pl.col("is_duplex"))
        .then(pl.col("_duplex_parent_paths_match").fill_null(False))
        .otherwise(pl.col("_duplex_child_path_matches").not_().fill_null(True))
        .alias("is_usable_read")
    )
    .select(pl.all().exclude("_parent_names"))
)

In [None]:
df_usable

In [None]:
%%time
# Runs the packaged helpers: flag usable reads, flag end-to-end alignments,
# then mark read names that occur more than once.
# NOTE(review): this reassigns `df` from itself, so re-running the cell feeds
# the helpers their own output. The next cell calls `df.collect()`, so a lazy
# frame is expected here, whereas the pipeline cell above leaves `df`
# collected — re-run the `pl.scan_ipc(...)` line first, or confirm the
# intended order.
df = processing.flag_usable_reads(df)
df = processing.flag_end_to_end(df, endpoints)
df = df.with_columns(
    is_duplicate_alignment=pl.col("name").is_duplicated(),
)

In [None]:
%%time
df = df.collect()

In [None]:
df_usable.filter(pl.col("_duplex_sibling_paths_match") == False, ~pl.col("is_duplex"))

In [None]:
df.filter(pl.col("_duplex_sibling_paths_match").is_null(), ~pl.col("is_duplex"))

In [None]:
df.filter(~pl.col("_duplex_children_paths_match"), pl.col("is_duplex")).sort(
    "query_length"
)

In [None]:
# NOTE(review): `df4` is not defined in any visible cell of this notebook, so
# this raises NameError on a fresh kernel. It presumably predates `df_usable`
# and its `is_usable_read` column — confirm and update or remove.
df4.filter(pl.col("usable_read"))

In [None]:
# NOTE(review): `df4` is not defined in any visible cell of this notebook, so
# this raises NameError on a fresh kernel — confirm and update or remove.
df4.filter(~pl.col("_child_duplex_paths_match").is_null())

In [None]:
df.filter(pl.col("name").is_duplicated())

In [None]:
df.filter(pl.col("dx") == 0)

In [None]:
# Attach to every row the number of rows that share its read name.
# NOTE(review): the cells below filter on a column called `counts`; newer
# polars releases name the value_counts column `count` instead — confirm
# against the installed polars version.
df_joined = df.join(df["name"].value_counts(sort=True), on="name")

In [None]:
plt.hist(
    df_joined.filter(pl.col("counts") == 1)["query_length"].to_numpy(),
    bins=100,
    log=True,
);

In [None]:
plt.hist(
    df_joined.filter(pl.col("counts") == 2)["query_length"].to_numpy(),
    bins=100,
    log=True,
);

In [None]:
plt.hist(
    df_joined.filter(pl.col("counts") > 2)["query_length"].to_numpy(),
    bins=100,
    log=True,
);

In [None]:
df_joined.filter(pl.col("counts") > 3)

In [None]:
df["name"].value_counts(sort=True).filter(pl.col("counts") > 2)

In [None]:
x = df.filter(pl.col("name").is_duplicated())["path"]

In [None]:
print("\n\n".join([" ".join(y) for y in x.to_pandas()[:20]]))

In [None]:
df.filter(pl.col("name").is_first_distinct())

In [None]:
len(df)

In [None]:
df.columns

In [None]:
df.tail(100)

In [None]:
df.filter(pl.col("name").is_duplicated())

In [None]:
df["name"].str.contains(";").sum()

In [None]:
df.filter(pl.col("name").str.contains(";"))

In [None]:
df2.filter(pl.col("dx") != 0)

# GFA name mapping

In [None]:
g = gfapy.Gfa.from_file(
    "/home/jqs1/scratch/sequencing/230707_repressilators/20230707_2040_MN35044_FAS94231_25542e0d/references/pLIB419.gfa"
)

In [None]:
g.segments[0].name = "foo"

In [None]:
print(g.to_gfa1_s())

# RecGraph mtx

In [None]:
# IUPAC nucleotide ambiguity codes mapped to the canonical bases each stands for.
DEGENERATE_BASES = {
    "R": "AG",
    "Y": "CT",
    "M": "AC",
    "K": "GT",
    "S": "CG",
    "W": "AT",
    "B": "CGT",
    "D": "AGT",
    "H": "ACT",
    "V": "ACG",
    "N": "ACGT",
}


def degenerate_recgraph_matrix(
    match,
    mismatch,
    degenerate_match=None,
    degenerate_mismatch=None,
    degenerate_bases=DEGENERATE_BASES,
):
    """Build a substitution matrix over canonical and degenerate bases.

    Rows and columns are ordered as ``"ATCG"`` followed by the keys of
    ``degenerate_bases`` in iteration order.

    Args:
        match: score for a canonical base aligned to itself.
        mismatch: score for two different canonical bases. It is also the fill
            value for every cell not set explicitly, which includes all
            degenerate-vs-degenerate pairs and the diagonal of the degenerate
            bases.
        degenerate_match: score for a degenerate base aligned to a canonical
            base it stands for; defaults to ``match``.
        degenerate_mismatch: score for a degenerate base aligned to a canonical
            base it does not stand for; defaults to ``mismatch``.
        degenerate_bases: mapping from degenerate base to the string of
            canonical bases it stands for. It is only read, never modified.

    Returns:
        Tuple ``(matrix, alphabet_aliases)``. ``matrix`` is a square, symmetric
        numpy array whose dtype is inferred from ``mismatch``, so scores of a
        wider type are cast on assignment. ``alphabet_aliases`` is the
        concatenation, over every degenerate base ``D`` and every base ``B`` it
        stands for, of the four characters ``B D D B``.
    """
    if degenerate_match is None:
        degenerate_match = match
    if degenerate_mismatch is None:
        degenerate_mismatch = mismatch
    canonical_bases = "ATCG"
    bases = canonical_bases + "".join(degenerate_bases.keys())
    num_bases = len(bases)
    base_to_idx = {base: idx for idx, base in enumerate(bases)}
    matrix = np.full((num_bases, num_bases), mismatch)
    # Fix: `match` was never written into the matrix, so a canonical base
    # aligned to itself scored as a mismatch.
    for base in canonical_bases:
        matrix[base_to_idx[base], base_to_idx[base]] = match
    for deg_base, matching_bases in degenerate_bases.items():
        idx = base_to_idx[deg_base]
        degenerate_match_idxs = [base_to_idx[base] for base in matching_bases]
        degenerate_mismatch_idxs = [
            base_to_idx[base] for base in set(canonical_bases) - set(matching_bases)
        ]
        # Set both (row, col) and (col, row) to keep the matrix symmetric.
        for idx2 in degenerate_match_idxs:
            matrix[idx, idx2] = matrix[idx2, idx] = degenerate_match
        for idx2 in degenerate_mismatch_idxs:
            matrix[idx, idx2] = matrix[idx2, idx] = degenerate_mismatch
    # NOTE(review): presumably the alias format expected by whatever consumes
    # this matrix — confirm against that tool.
    alphabet_aliases = "".join(
        f"{base}{deg_base}{deg_base}{base}"
        for deg_base, matching_bases in degenerate_bases.items()
        for base in matching_bases
    )
    return matrix, alphabet_aliases