In [None]:
import itertools as it
import operator
import re
from collections import Counter
from pathlib import Path

import duckdb
import gfapy
import holoviews as hv
import ibis
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pyabpoa
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyfastx
import pysam
import spoa
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.io as sio

In [None]:
# Use Bokeh as the HoloViews plotting backend for this notebook.
hv.extension("bokeh")

In [None]:
%load_ext pyinstrument
import line_profiler
import pyinstrument

%load_ext line_profiler

In [None]:
# Turn on Polars' global string cache so Categorical values created in separate
# expressions/frames share one encoding (presumably needed by the
# List(Categorical) cast and the joins further down — confirm).
pl.enable_string_cache()

# Config

In [None]:
# Directory that is globbed below for the *.arrow / *.parquet tables.
data_dir = Path("/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/")

In [None]:
# Directory listings come back in arbitrary (filesystem) order, so sort once to
# make the fragment order — and the single file picked for `arrow_ds2` —
# reproducible across runs and machines.
arrow_files = sorted(data_dir.glob("*.arrow"))
parquet_files = sorted(data_dir.glob("*.parquet"))

# Every Arrow IPC file in data_dir as one dataset.
arrow_ds = ds.dataset(arrow_files, format="arrow")
# Single-file dataset (first file in sorted order) for quick experiments.
arrow_ds2 = ds.dataset(arrow_files[:1], format="arrow")
# Every Parquet file in data_dir as one dataset.
parquet_ds = ds.dataset(parquet_files, format="parquet")

# Group by path

## Setup

In [None]:
# Parse the barcode graph from its GFA file; the segment names of this graph
# are what the barcode segment lists in the following cells are built from.
gfa_path = "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/barcode.gfa"
gfa = gfapy.Gfa.from_file(gfa_path)

In [None]:
# Oriented names of every segment whose name starts with "BC:".
# Index 0 holds the ">"-prefixed names, index 1 the "<"-prefixed names; the
# pipeline below uses them for its `_is_forward` / `_is_reverse` flags.
bc_names = [seg for seg in gfa.segment_names if seg.startswith("BC:")]
bc_segments_oriented = [[f"{strand}{seg}" for seg in bc_names] for strand in "><"]

In [None]:
# Both orientations of every "BC:BIT" segment, interleaved per segment
# (">name", "<name", ">next", "<next", ...).
bit_names = [seg for seg in gfa.segment_names if seg.startswith("BC:BIT")]
bc_segments = [f"{strand}{seg}" for seg, strand in it.product(bit_names, "><")]

In [None]:
# Maps every oriented segment name to the same segment with the opposite
# orientation prefix ("<name" -> ">name" and ">name" -> "<name").
reverse_path_mapping = {
    f"{src}{seg}": f"{dst}{seg}"
    for seg in gfa.segment_names
    for src, dst in (("<", ">"), (">", "<"))
}

## Polars

In [None]:
def consensus_func(df):
    """Collapse one group of reads into a single summary row.

    Reads whose ``name`` contains ";" are counted as duplex, all others as
    simplex. The result is the first row of ``df`` restricted to
    ``path_subset``, ``read_seq`` and ``read_phred``, plus the two group-wide
    counts ``depth_simplex`` and ``depth_duplex``.

    NOTE(review): no consensus is computed here — the first read's sequence
    and qualities are passed through unchanged.
    """
    name_has_separator = pl.col("name").str.contains(";")

    # Per-read boolean flags.
    flagged = df.with_columns(
        name_has_separator.alias("is_duplex"),
        name_has_separator.not_().alias("is_simplex"),
    )
    # Summing the flags yields the group-wide depths, broadcast to every row.
    counted = flagged.with_columns(
        pl.sum("is_simplex").alias("depth_simplex"),
        pl.sum("is_duplex").alias("depth_duplex"),
    )

    summary_columns = [
        "path_subset",
        "read_seq",
        "read_phred",
        "depth_simplex",
        "depth_duplex",
    ]
    return counted.select(pl.col(summary_columns)).head(1)

In [None]:
%%time
# with pl.StringCache():
# df = pl.scan_pyarrow_dataset(arrow_ds)
# Lazily scan every Arrow IPC file in data_dir; nothing is read until .collect().
df_input = pl.scan_ipc(str(data_dir / "*.arrow"))
# df_input = pl.scan_ipc(list(data_dir.glob("*.arrow"))[0])
# df_path: uniquely-named reads annotated with their barcode path (`path`),
# strand flag (`reverse_complement`) and simplex/duplex status (`is_duplex`).
df_path = (
    df_input.rename({"path": "full_path"})  # .limit(100_000)
    # Drop every read whose name occurs more than once (those rows are what the
    # "Secondary alignments" section inspects).
    .filter(pl.col("name").is_duplicated().not_())
    .with_columns(
        # Keep only the elements of the full path that are BC:BIT segments
        # (either orientation).
        pl.col("full_path").list.set_intersection(bc_segments)
        # TODO: waiting on https://github.com/pola-rs/polars/issues/11735
        # to keep path columns as categorical
        # .cast(pl.List(pl.Categorical))
        .alias("_path"),
    )
    .with_columns(
        # The barcode path as seen from the opposite strand: reversed order and
        # every segment's orientation prefix flipped.
        pl.col("_path")
        .list.reverse()
        .list.eval(pl.element().map_dict(reverse_path_mapping))
        .alias("_path_reversed"),
        # True when the full path contains at least one ">"-oriented BC: segment.
        (
            pl.col("full_path")
            .list.set_intersection(bc_segments_oriented[0])
            .list.len()
            > 0
        ).alias("_is_forward"),
        # True when the full path contains at least one "<"-oriented BC: segment.
        (
            pl.col("full_path")
            .list.set_intersection(bc_segments_oriented[1])
            .list.len()
            > 0
        ).alias("_is_reverse"),
    )
    .with_columns(
        # False = only forward segments, True = only reverse segments,
        # null = both orientations present or no BC: segment at all.
        pl.when(pl.col("_is_forward") & pl.col("_is_reverse").not_())
        .then(False)
        .when(pl.col("_is_forward").not_() & pl.col("_is_reverse"))
        .then(True)
        .otherwise(None)
        .alias("reverse_complement"),
    )
    .with_columns(
        # Orientation-normalised barcode path. There is no .otherwise(), so reads
        # with a null `reverse_complement` get a null `path`.
        pl.when(pl.col("reverse_complement") == False)
        .then(pl.col("_path"))
        .when(pl.col("reverse_complement") == True)
        .then(pl.col("_path_reversed"))
        .cast(pl.List(pl.Categorical))
        .alias("path")
    )
    # A name of the form "a;b" splits into two parent names; such reads are
    # flagged as duplex below.
    .with_columns(pl.col("name").str.split(";").alias("parent_names"))
    .with_columns(
        pl.when(pl.col("parent_names").list.len() == 2)
        .then(True)
        .otherwise(False)
        .alias("is_duplex")
    )
    # Discard reads without an unambiguous orientation ...
    .filter(pl.col("path").is_not_null())
    # ... and reads whose barcode path does not have exactly 30 segments
    # (hard-coded length; flagged TODO by the author).
    .filter(pl.col("path").list.len() == 30)  # TODO
)
# For each duplex read, emit one row per parent read name ...
duplex_parent_names = (
    df_path.filter(pl.col("is_duplex"))
    .select("name", pl.col("parent_names").alias("_parent_name"))
    .explode("_parent_name")
)
# ... and attach that parent's barcode path (null when the parent read is not
# present in df_path).
df_with_parents = duplex_parent_names.join(
    df_path.select("name", "path"),
    left_on="_parent_name",
    right_on="name",
    how="left",
)
# One row per duplex read: do all of its parent reads share the same barcode path?
# NOTE(review): only the parents are compared with each other; the duplex read's
# own `path` is not part of this check — confirm that is intended.
df_duplex_paths_match = (
    df_with_parents.with_columns(
        # Path of the first parent row within each duplex read, made available on
        # every parent row of that read.
        # NOTE(review): exact output shape under mapping_strategy="join" depends
        # on the Polars version — confirm.
        pl.col("path")
        .first()
        .over("name", mapping_strategy="join")
        .alias("_path_first")
    )
    .with_columns(
        # Per parent row: does this parent's path equal the first parent's path?
        (pl.col("path") == pl.col("_path_first")).alias("_path_matches_first")
    )
    .group_by("name")
    .agg(
        # NOTE(review): a parent missing from df_path has a null path, so its
        # comparison is null; how .all() treats nulls decides whether such a
        # duplex read counts as matching — confirm.
        pl.col("_path_matches_first").all().alias("_duplex_paths_match"),
        # Parent names gathered back into one list per duplex read.
        pl.col("_parent_name"),
    )
)
# The same flag keyed by parent (simplex) read name: one row per parent of each
# duplex read.
df_simplex_paths_match = df_duplex_paths_match.select(
    pl.col("_duplex_paths_match").alias("_child_duplex_paths_match"),
    pl.col("_parent_name"),
).explode("_parent_name")
# df_path plus a boolean `usable` column.
df_usable = (
    df_path.join(
        # Duplex reads pick up their own parents-match flag (null for simplex reads).
        df_duplex_paths_match.select(pl.col("name", "_duplex_paths_match")),
        how="left",
        on="name",
    )
    # Simplex reads pick up the flag of the duplex read they are a parent of
    # (null when they have no duplex child).
    # NOTE(review): a read that is a parent of more than one duplex read would be
    # duplicated by this join — confirm that cannot happen.
    .join(df_simplex_paths_match, how="left", left_on="name", right_on="_parent_name")
    .with_columns(
        # Duplex: usable only when its parents' paths match (null -> False).
        # Otherwise: usable unless the read is a parent of a duplex read whose
        # parents' paths match (null -> True).
        pl.when(pl.col("is_duplex"))
        .then(pl.col("_duplex_paths_match").fill_null(False))
        .otherwise(pl.col("_child_duplex_paths_match").not_().fill_null(True))
        .alias("usable")
    )
)
# res = df_usable.select(pl.all().exclude("^_.*$")).collect()

In [None]:
%%time
# Group reads by barcode path and keep groups with more than 5 reads.
# NOTE(review): the `usable` flag computed on df_usable is dropped by the select
# below and never used as a filter — confirm whether `.filter(pl.col("usable"))`
# was intended before grouping.
read_group_inputs = df_usable.select(
    "name",
    "is_duplex",
    "read_seq",
    "read_phred",
    "reverse_complement",
    "path",
    "full_path",
)
reads_by_path = read_group_inputs.group_by("path").agg(
    # Per-read values are gathered into one list per group.
    pl.col("name", "read_seq", "read_phred", "reverse_complement"),
    pl.col("is_duplex").sum().alias("duplex_depth"),
    pl.count().alias("depth"),
)
simplex_depth = (pl.col("depth") - pl.col("duplex_depth")).alias("simplex_depth")
df_read_groups = reads_by_path.with_columns(simplex_depth).filter(pl.col("depth") > 5)
# Disabled per-group consensus step:
# .map_groups(
#     consensus_func,
#     schema=dict(
#         path_subset=pl.List(pl.Categorical),
#         read_seq=pl.Utf8,
#         read_phred=pl.Utf8,
#         depth_simplex=pl.UInt64,
#         depth_duplex=pl.UInt64,
#     ),
# )

# Execute the whole lazy pipeline.
res = df_read_groups.collect()
# res.select(pl.all().exclude("name","read_seq", "read_phred", "reverse_complement"))

In [None]:
%%time
# Write the collected read groups to an Arrow IPC file in the current working
# directory, with no compression requested (compression=None).
res.write_ipc("test_read_groups_full.arrow", compression=None)

In [None]:
# Number of rows in `res` (one row per read group when `res` is the collected
# df_read_groups).
len(res)

In [None]:
# Estimated in-memory size of `res`, scaled by 1e9 (GB, assuming
# estimated_size() reports bytes by default).
res.estimated_size() / 1e9

In [None]:
# Show total and duplex read depth for every group.
res.select(pl.col("depth", "duplex_depth"))

In [None]:
# Use the in-memory result as the test input.
# NOTE(review): the next cell rebinds `test_groups` from a file on disk, so on a
# top-to-bottom run this assignment is immediately overwritten.
test_groups = res

In [None]:
%%time
# Load previously saved read groups from disk.
# NOTE(review): this reads "test_read_groups.arrow", while the write cell above
# saves "test_read_groups_full.arrow"; nothing in the cells shown here creates
# this file — confirm it exists and is the intended input.
test_groups = pl.read_ipc("test_read_groups.arrow")

In [None]:
# Per-group read counts (length of each `name` list), keeping only groups whose
# count differs from 1, sorted ascending.
group_sizes = test_groups.select(pl.col("name").list.len())
group_sizes.filter(pl.col("name") != 1).sort("name")

In [None]:
# Take the group at row index 10 and unnest its per-read list columns so that
# each read becomes one row.
example_group = test_groups[10]
per_read_columns = pl.col("name", "read_seq", "read_phred")
test_reads = example_group.select(per_read_columns.list.explode())

In [None]:
# Display the per-read rows of the selected group.
test_reads

In [None]:
# Multiple-sequence alignment and consensus of the selected group's reads with abPOA.
aligner = pyabpoa.msa_aligner()
seqs = test_reads.get_column("read_seq").to_list()
# Bind the result to its own name: reusing `res` here would clobber the
# read-group DataFrame that the earlier cells (write_ipc, len, select,
# `test_groups = res`) operate on and break them when re-run.
abpoa_res = aligner.msa(seqs, out_cons=True, out_msa=True)

for seq in abpoa_res.cons_seq:
    print(seq)  # print consensus sequence

abpoa_res.print_msa()  # print row-column multiple sequence alignment in PIR format

In [None]:
# Fresh abPOA aligner instance (same call as in the MSA cell above; not used
# again in the cells shown here).
aligner = pyabpoa.msa_aligner()

In [None]:
# Quick spoa run on three short toy sequences; the result is unpacked into the
# consensus and the MSA rows.
consensus, msa = spoa.poa(["AACTTATA", "AACTTATG", "AACTATA"])

In [None]:
%%time
# Consensus and MSA of the selected group's reads with spoa.
read_seqs = test_reads.get_column("read_seq").to_list()
consensus, msa = spoa.poa(read_seqs)

In [None]:
# Display the spoa consensus sequence.
consensus

In [None]:
# Print the spoa alignment, one MSA row per line.
print("\n".join(msa))

## Secondary alignments

In [None]:
%%time
# Rows among the first 10,000 whose read name appears more than once, as plain
# tuples of (name, residue_matches, path).
duplicated_reads = df_input.limit(10_000).filter(pl.col("name").is_duplicated())
dups = duplicated_reads.select("name", "residue_matches", "path").collect().rows()

In [None]:
# Print up to five consecutive row pairs from `dups`: residue_matches followed
# by the space-joined path of each row.
# The upper bound is clamped to the data so that fewer than 10 duplicated rows
# (or an odd count) no longer raises IndexError.
# NOTE(review): pairing rows idx / idx + 1 assumes the alignments of one read
# are adjacent in `dups` — confirm.
for idx in range(0, min(10, len(dups) - 1), 2):
    for row in (dups[idx], dups[idx + 1]):
        print(row[1], " ".join(row[2]))
        print()
    print()

## DuckDB

In [None]:
%%time
# Barcode-segment intersection expressed in DuckDB SQL, for comparison with the
# Polars `list.set_intersection` used above.
# `arrow_ds2` in the FROM clause is presumably resolved by DuckDB to the Python
# variable of that name (the single-file Arrow dataset); `$bc_segments` is bound
# from the dict passed as the second argument.
# NOTE(review): this rebinds `res`, which earlier cells use for the read-group
# DataFrame.
res = duckdb.execute(
    "SELECT path, list_intersect(path, $bc_segments) FROM arrow_ds2 LIMIT 2",
    dict(bc_segments=bc_segments),
).arrow()

In [None]:
# Display the Arrow result of the DuckDB query.
res