In [None]:
import itertools as it
import operator
import re
import subprocess
import tempfile
import time
from collections import Counter, defaultdict
from functools import partial
from pathlib import Path

import awkward as ak
import duckdb
import gfapy
import holoviews as hv
import ibis
import matplotlib.pyplot as plt
import numba
import numpy as np
import pandas as pd
import parasail
import pod5
import polars as pl
import pyabpoa
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyfastx
import pysam
import spoa
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from pyarrow import csv
from pywfa import WavefrontAligner
from tqdm.auto import tqdm, trange

In [None]:
# Reload imported modules automatically so edits to project code are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.align as align
import paulssonlab.sequencing.cigar as scigar
import paulssonlab.sequencing.consensus as con
import paulssonlab.sequencing.gfa as sgfa
import paulssonlab.sequencing.io as sio
import paulssonlab.sequencing.processing as processing
import paulssonlab.sequencing.uuid as uuid
from paulssonlab.util.sequence import reverse_complement

In [None]:
# Use the bokeh backend for HoloViews output.
hv.extension("bokeh")

In [None]:
# Profiling tools: load the pyinstrument and line_profiler IPython extensions.
%load_ext pyinstrument
import line_profiler
import pyinstrument

%load_ext line_profiler

In [None]:
# Enable polars' global string cache for the rest of the session.
pl.enable_string_cache()

# Config

In [None]:
# Path to the GFA file that is parsed with gfapy in the next cell.
gfa_filename = "/home/jqs1/scratch3/jqs1/sequencing/230930_alignment_test/barcode.gfa"

In [None]:
# Parse the GFA file into a gfapy object.
gfa = gfapy.Gfa.from_file(gfa_filename)

# Duplex pairing

## UUID parsing

In [None]:
# Inputs for the pairing experiment: a POD5 file (opened with pod5.Reader below)
# and the GAF file with the same basename (read with sio.iter_gaf below).
pod5_filename = "/home/jqs1/scratch3/jqs1/sequencing/230930_alignment_test/230707_repressilators/channel-135_merged.pod5"
gaf_filename = "/home/jqs1/scratch3/jqs1/sequencing/230930_alignment_test/230707_repressilators/channel-135_merged.gaf"

In [None]:
%%time
# Load every record batch yielded by sio.iter_gaf into a single Arrow table.
# The commented-out variants load only the first one or two batches for quick tests.
# gaf = pa.Table.from_batches([next(sio.iter_gaf(gaf_filename))])
# gaf = pa.Table.from_batches(list(it.islice(sio.iter_gaf(gaf_filename), 2)))
gaf = pa.Table.from_batches(sio.iter_gaf(gaf_filename))

In [None]:
%%time
# Drop rows whose name contains ";" (presumably duplex reads, whose names
# combine two read IDs -- confirm against the basecaller's naming scheme).
gaf_simplex = gaf.filter(pc.invert(pc.match_substring(gaf["name"], ";")))

In [None]:
# Name column to parse; the trailing "[:10]" is a leftover for testing on a small slice.
name_col = gaf_simplex["name"]  # [:10]

In [None]:
%%time
# Remove hyphens from the names and parse them with the project's uuid helpers.
uuids = uuid.parse_uuids(uuid.remove_hyphens(ak.from_arrow(name_col)))

In [None]:
# Benchmark: tile the parsed UUIDs 1000x and time the
# NumPy -> awkward ("bytes" type) -> Arrow conversion on the larger array.
x = np.concatenate([uuids] * 1000)
%timeit ak.to_arrow(ak.enforce_type(ak.from_numpy(x), "bytes"))

In [None]:
# Size of the tiled array in units of 10^6 bytes.
x.nbytes / 1e6

In [None]:
# Number of entries along the first axis of the tiled array.
len(x)

In [None]:
# Same conversion timed on the untiled UUID array.
%timeit ak.to_arrow(ak.enforce_type(ak.from_numpy(uuids), "bytes"))

In [None]:
# Inspect the converted Arrow array.
ak.to_arrow(ak.enforce_type(ak.from_numpy(uuids), "bytes"))

In [None]:
# Original (unparsed) name column, for comparison.
gaf_simplex["name"]

## Join

In [None]:
# Parse the simplex read names with the project's uuid helpers, convert the
# result to an Arrow array of bytes, and swap it in for the "name" column.
simplex_names = ak.from_arrow(gaf_simplex["name"])
uuids = uuid.parse_uuids(uuid.remove_hyphens(simplex_names))
uuids_as_bytes = ak.enforce_type(ak.from_numpy(uuids), "bytes")
uuids_bytes = ak.to_arrow(uuids_as_bytes, extensionarray=False)
# waiting on https://github.com/apache/arrow/issues/39232
# not sure if it's necessary/helps performance
# uuids_bytes = uuids_bytes.cast(pa.binary(16))
name_position = gaf_simplex.column_names.index("name")
gaf_simplex_parsed = gaf_simplex.set_column(name_position, "name", uuids_bytes)

In [None]:
# Open the POD5 file; the cells below read its run-info and read tables in full.
reader = pod5.Reader(pod5_filename)

In [None]:
run_info_table = reader.run_info_table.read_all()

In [None]:
read_table = reader.read_table.read_all()

In [None]:
# Keep only the read-metadata columns used by the cells below.
read_table2 = read_table.select(["read_id", "start", "well", "channel", "run_info"])

In [None]:
# Lazy polars view of the POD5 read metadata.
read_table_df = pl.from_arrow(read_table2).lazy()

In [None]:
# Lazy polars view of the simplex alignments (with the parsed "name" column).
gaf_df = pl.from_arrow(gaf_simplex_parsed).lazy()

In [None]:
# Keep only the first GAF row seen for each distinct name.
gaf_unique_df = gaf_df.filter(pl.col("name").is_first_distinct())

In [None]:
# Left join on read ID: every POD5 read is kept, and the alignment columns
# are null for reads that have no matching GAF row.
joined_df = read_table_df.join(
    gaf_unique_df, left_on="read_id", right_on="name", how="left"
)

In [None]:
# Materialize the lazy join. Going through .lazy() keeps this cell re-runnable:
# after the first run joined_df is already an eager DataFrame, which has no
# .collect() of its own, while .lazy() is available on both frame types.
joined_df = joined_df.lazy().collect()

In [None]:
# Columns available after the join.
joined_df.columns

In [None]:
# Notes: criteria for pairing reads as duplex candidates.
# first unique, full-length alignment
# same well, channel
# start within (1s * 5000 Hz)

In [None]:
# For each run-info column of interest, build {acquisition_id: value}.
acquisition_ids = run_info_table["acquisition_id"].to_pylist()
run_info = {}
for col in ["sample_rate"]:
    column_values = run_info_table[col].to_pylist()
    run_info[col] = dict(zip(acquisition_ids, column_values))

In [None]:
run_info

In [None]:
# Sample rate for one hard-coded acquisition ID.
# NOTE(review): this ID must be a key of run_info["sample_rate"] for the
# POD5 file currently loaded -- confirm when switching input files.
sample_rate = run_info["sample_rate"]["255dd505a1ce8175032f42d363ea427604712e08"]

In [None]:
# Rolling window over the integer "start" column (window length = sample_rate
# index units, via the "i" suffix), without grouping by channel/well; every
# other column is gathered into a per-window list.
joined_df.cast({"start": pl.Int64}).sort("start").rolling(
    "start", period=f"{sample_rate}i"
).agg(pl.all().exclude("start"))

In [None]:
# Duplex candidates among aligned reads: within each (channel, well), gather the
# reads whose "start" falls in the rolling window of length sample_rate (index
# units) ending at each read, keep windows holding at least two reads, and flag
# the reads whose path equals the path of the first read in the window.
joined_df.filter(pl.col("path").is_not_null()).cast({"start": pl.Int64}).sort(
    "start"
).rolling(
    "start", by=["channel", "well"], period=f"{sample_rate}i", check_sorted=False
).agg(
    pl.col("start").alias("_start"),
    pl.col("read_id"),
    pl.col("path"),
    pl.col("path").first().alias("_path_first"),
    # pl.col("path").list.eval(pl.element() == pl.col("path_first")),
).rename(
    # With the default offset (-period) each window ends at the index value, so
    # the rolling index column is the window end. This matches the "window_end"
    # name used by the later version of this query.
    {"_start": "start", "start": "window_end"}
).filter(
    pl.col("read_id").list.len() >= 2
).explode(
    "start", "read_id", "path"
).with_columns(
    (pl.col("path") == pl.col("_path_first")).alias("_duplex_match")
)

In [None]:
# Same rolling aggregation per (channel, well), but stopping at the per-window
# lists: only windows with at least two reads are kept, and the explode /
# path-comparison step is left commented out at the end.
joined_df.filter(pl.col("path").is_not_null()).cast({"start": pl.Int64}).sort(
    "start"
).rolling("start", by=["channel", "well"], period=f"{sample_rate}i").agg(
    pl.col("start").alias("_start"),
    pl.col("read_id"),
    pl.col("path"),
    pl.col("path").first().alias("_path_first"),
    # pl.col("path").list.eval(pl.element() == pl.col("path_first")),
).filter(
    pl.col("read_id").list.len() >= 2
)  # .explode("_start", "read_id", "path").with_columns((pl.col("path") == pl.col("_path_first")).alias("_duplex_match"))

In [None]:
# Largest "start" value among the joined reads.
joined_df["start"].max()

In [None]:
# Convert a start value to hours by dividing by 5000 and then by 3600.
# NOTE(review): 1295997446 is presumably the maximum printed above and 5000
# the sample rate in Hz -- confirm.
1295997446 / 5000 / 60 / 60

In [None]:
# Switch to a different POD5 file and rebuild the read-metadata objects.
# Note: this rebinds pod5_filename, reader, read_table, read_table2 and
# read_table_df, so earlier cells re-run after this one see the new file.
pod5_filename = "/home/jqs1/scratch/sequencing/230707_repressilators/pod5_pass_split/channel-251_merged.pod5"
reader = pod5.Reader(pod5_filename)
read_table = reader.read_table.read_all()
read_table2 = read_table.select(["read_id", "start", "well", "channel", "run_info"])
read_table_df = pl.from_arrow(read_table2).lazy()

In [None]:
# pandas copy of the read metadata for the time-based experiments below.
read_table_pdf = read_table2.to_pandas()

In [None]:
# Add a "time" column: start / 5000, interpreted as seconds.
# NOTE(review): 5000 is hard-coded here rather than taken from sample_rate -- confirm it matches this run.
read_table_pdf["time"] = pd.to_datetime(read_table_pdf["start"] / 5000, unit="s")

In [None]:
# Order reads by time (rebinds read_table_pdf to the sorted frame).
read_table_pdf = read_table_pdf.sort_values("time")

In [None]:
# Time-based 1 s rolling window over the reads of channel 268, taking the first
# row of each window.
# NOTE(review): confirm the installed pandas provides Rolling.first() and
# accepts the read_id column's dtype here.
read_table_pdf[read_table_pdf["channel"] == 268][["time", "read_id"]].rolling(
    "1s", on="time"
).first()

In [None]:
# Reads of channel 306 ordered by start.
read_table_subset = read_table_pdf[read_table_pdf["channel"] == 306].sort_values(
    "start"
)

In [None]:
# Start values as a NumPy array (note: rebinds the scratch name x).
x = read_table_subset["start"].values

In [None]:
# Number of consecutive reads whose start values differ by less than 5000.
((x[1:] - x[:-1]) < 5000).sum()

In [None]:
# For every channel, count consecutive reads (ordered by start) whose start
# values are fewer than 5000 apart; print each non-zero count and the total.
# read_table_subset and x are left bound to the last channel visited.
total = 0
for channel in set(read_table_pdf["channel"]):
    channel_mask = read_table_pdf["channel"] == channel
    read_table_subset = read_table_pdf[channel_mask].sort_values("start")
    x = read_table_subset["start"].values
    gaps = x[1:] - x[:-1]
    num_duplex = (gaps < 5000).sum()
    total += num_duplex
    if num_duplex > 0:
        print(f"{channel}: {num_duplex}")
print(f"TOTAL: {total}")

In [None]:
# For each read in read_table_subset (whatever subset the previous cells left
# bound), count the reads whose start lies in [start - dt, start]; the bounds
# are inclusive, so every read counts itself. Entries >= 2 have at least one
# other read in the window. This scans the whole subset once per read.
dt = 5000
r = read_table_subset["start"].map(
    lambda x: read_table_subset["start"].between(x - dt, x).sum()
)
r[r >= 2]

In [None]:
# For the first 10 reads, list the read_ids falling in the same
# [start - dt, start] window instead of counting them.
dt = 5000
read_table_subset["start"][:10].map(
    lambda x: read_table_subset["read_id"][
        read_table_subset["start"].between(x - dt, x)
    ]
)

In [None]:
# polars version of the window search on the raw read table: per
# (channel, well), rolling window of 5000 index units over "start", keeping
# windows that contain at least two reads.
read_table_df.cast({"start": pl.Int64}).sort("start").rolling(
    "start", by=["channel", "well"], period="5000i"
).agg(pl.col("start").alias("_start"), pl.all().exclude("start")).filter(
    pl.col("read_id").list.len() >= 2
).collect()

In [None]:
# Duplex candidates among aligned reads with a fixed 5000-unit window: per
# (channel, well), gather reads in the rolling window over "start", keep
# windows with at least two reads, explode back to one row per read, and flag
# reads whose path equals the path of the first read in the window.
joined_df.filter(pl.col("path").is_not_null()).cast({"start": pl.Int64}).sort(
    "start"
).rolling("start", by=["channel", "well"], period=f"5000i").agg(
    pl.col("start").alias("_start"),
    pl.col("read_id"),
    pl.col("path"),
    pl.col("path").first().alias("_path_first"),
    # pl.col("path").list.eval(pl.element() == pl.col("path_first")),
).rename(
    {"_start": "start", "start": "window_end"}
).filter(
    pl.col("read_id").list.len() >= 2
).explode(
    "start", "read_id", "path"
).with_columns(
    (pl.col("path") == pl.col("_path_first")).alias("_duplex_match")
)

In [None]:
# Inspect the full read table (the trailing comment is a leftover column selection).
read_table  # ["run_info"]

# BAM

In [None]:
# BAM input; the commented-out paths are alternative files used earlier.
# bam_filename = "/home/jqs1/scratch3/jqs1/sequencing/230930_alignment_test/230707_repressilators/channel-135_merged.bam"
# bam_filename = "/home/jqs1/scratch3/jqs1/sequencing/230922_bcd_rbses_constitutive/20230922_1104_1A_PAQ83451_8d610a8c/bam_pass/channel-100_merged.bam"
bam_filename = "/home/jqs1/scratch/sequencing/230707_repressilators/20230707_2040_MN35044_FAS94231_25542e0d/_temp/channel-1_merged.bam"

In [None]:
# Open the BAM with check_sq=False. The handle is never closed in this notebook.
bam = pysam.AlignmentFile(bam_filename, check_sq=False)

In [None]:
# Rewind the BAM and start iterating over all of its reads.
bam.reset()
reads = bam.fetch(until_eof=True)

In [None]:
# Take the first read for inspection (note: rebinds the scratch name x).
x = next(reads)

In [None]:
# Count the reads.
# NOTE(review): this presumably advances the same file position used by
# `reads` -- confirm before continuing to iterate that object.
bam.count(until_eof=True)

In [None]:
# Value of the "st" tag of the inspected read.
dict(x.tags)["st"]

In [None]:
# All tags of the inspected read.
x.tags

In [None]:
x.tags

In [None]:
# Display the read fetched above. The name `read` is never bound at notebook
# scope (it only exists inside the list comprehension further down), so the
# original cell raised NameError on a fresh kernel.
x

In [None]:
# Rewind and collect one dict per read: the query name plus every tag as its
# own key. Note: rebinds `reads` from the fetch iterator to a list of dicts.
bam.reset()
reads = [
    {"name": read.query_name, **dict(read.tags)}
    for read in tqdm(bam.fetch(until_eof=True))
]

In [None]:
%%time
df = pd.DataFrame.from_records(reads)
df["st"] = pd.to_datetime(df["st"], format="ISO8601")
df = df.sort_values("st").reset_index(drop=True)

In [None]:
# End timestamp of each read: "st" plus the "du" tag interpreted as seconds.
df["_endtime"] = df["st"] + pd.to_timedelta(df["du"], unit="s")

In [None]:
# polars copy of the pandas frame for the as-of joins below.
df2 = pl.from_pandas(df)

In [None]:
# Columns needed for pairing, in two orderings: by start time and by end time
# (the as-of joins below are keyed on "st" and "_endtime" respectively).
df3 = df2.select(["name", "st", "_endtime", "du", "ch", "mx"])
# df3 = df3.filter(pl.col("ch") == 134, pl.col("mx") == 1)
df3_st = df3.sort("st")
df3_endtime = df3.sort("_endtime")

In [None]:
%%time
# For each read, find the next read (over all channels) whose "st" is at or
# after this read's "_endtime", within a 10 s tolerance.
df4 = df3_endtime.join_asof(
    df3_st, left_on="_endtime", right_on="st", strategy="forward", tolerance="10s"
)

In [None]:
%%time
# Same forward as-of join, but only matching reads with equal "ch" and "mx".
# This overwrites the df4 produced by the previous cell.
df4 = df3_endtime.join_asof(
    df3_st,
    left_on="_endtime",
    right_on="st",
    by=["ch", "mx"],
    strategy="forward",
    tolerance="10s",
)

In [None]:
# Gap between a read's end and the start of its matched follower.
df4.with_columns((pl.col("st_right") - pl.col("_endtime")).alias("dt"))

In [None]:
# Reads for which a follower was found within the tolerance.
df4.filter(pl.col("name_right").is_not_null())

In [None]:
# Earliest start timestamp (df is sorted by "st" with a reset index).
df["st"][0]

In [None]:
# Reads of channel 282, with start times expressed as seconds since the first
# read of that subset (note: rebinds the scratch name x).
df_subset = df[df["ch"] == 282]
x = (df_subset["st"] - df_subset["st"].iloc[0]).dt.total_seconds().values

In [None]:
# NOTE(review): late import -- consider moving it to the import cell at the top.
import oxbow as ox

In [None]:
# Read the same BAM with oxbow; the result is not used elsewhere in the visible notebook.
f = ox.read_bam(bam_filename)

In [None]:
# Reads of channel 135 whose "dx" tag equals -1.
df_subset = df[(df["ch"] == 135) & (df["dx"] == -1)]

In [None]:
# Preview the first 20 rows.
df_subset[:20]

In [None]:
# For every channel, count consecutive reads with dx == -1 that share the same
# start time; print each non-zero count and the overall total.
total = 0
for channel in set(df["ch"]):
    # Each comparison needs its own parentheses: `&` binds tighter than `==`,
    # so `(a) & df["dx"] == -1` would be evaluated as `((a) & df["dx"]) == -1`.
    df_subset = df[(df["ch"] == channel) & (df["dx"] == -1)]
    if df_subset.empty:
        # No matching reads on this channel; skip it so .iloc[0] cannot raise.
        continue
    # Start times in seconds relative to the first read of the subset.
    x = (df_subset["st"] - df_subset["st"].iloc[0]).dt.total_seconds().values
    num_duplex = ((x[1:] - x[:-1]) == 0).sum()
    total += num_duplex
    if num_duplex > 0:
        print(f"{channel}: {num_duplex}")
print(f"TOTAL: {total}")

In [None]:
# Consecutive reads less than 1 s apart, using whatever x the previous cell
# left bound (the last channel subset processed by the loop).
((x[1:] - x[:-1]) < 1).sum()