In [None]:
import itertools as it
from collections import Counter, defaultdict
from functools import partial
from glob import glob
from pathlib import Path

import gfapy
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
from cytoolz import dissoc
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from tqdm.auto import tqdm, trange

In [None]:
# Automatically re-import edited modules (e.g. paulssonlab.sequencing.*) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.align as align
import paulssonlab.sequencing.cigar as scigar
import paulssonlab.sequencing.consensus as con
import paulssonlab.sequencing.gfa as sgfa
import paulssonlab.sequencing.io as sio
import paulssonlab.sequencing.processing as processing
from paulssonlab.util.sequence import reverse_complement

In [None]:
# Render HoloViews / hvplot figures with the Bokeh backend.
hv.extension("bokeh")

In [None]:
# Enable polars' global string cache so Categorical values share one encoding
# across separately loaded frames (needed to join/compare categorical columns
# such as `path` coming from different files).
pl.enable_string_cache()

# Functions

In [None]:
def concat_glob(filename):
    """Lazily concatenate every Arrow IPC file matching a glob pattern.

    Parameters
    ----------
    filename : str
        Glob pattern (e.g. ``".../consensus/*.arrow"``) expanded with ``glob.glob``.

    Returns
    -------
    polars.LazyFrame
        ``pl.scan_ipc`` of each matched file, combined with a diagonal concat so
        files whose column sets differ are unioned rather than rejected.

    Raises
    ------
    FileNotFoundError
        If the pattern matches no files. Without this check, ``pl.concat``
        would receive an empty list and fail with a less informative error.
    """
    # glob() yields paths in arbitrary, filesystem-dependent order; sort them so
    # the row order of the concatenated frame is reproducible across runs.
    paths = sorted(glob(filename))
    if not paths:
        raise FileNotFoundError(f"no files match glob pattern {filename!r}")
    return pl.concat([pl.scan_ipc(path) for path in paths], how="diagonal")

# Config

In [None]:
# Dask cluster backed by SLURM jobs. Each job runs 1 worker process with 1 core
# and 1GB of memory, with a 30-minute walltime on the "short" queue; job logs
# are written to /home/jqs1/log.
cluster = SLURMCluster(
    queue="short",
    walltime="00:30:00",
    memory="1GB",
    local_directory="/tmp",  # worker-local scratch directory
    log_directory="/home/jqs1/log",
    cores=1,
    processes=1,
)
client = Client(cluster)

In [None]:
# Display cluster status.
cluster

In [None]:
# Scale down to zero workers (releases any running SLURM jobs).
# NOTE(review): tasks submitted through `client` later in the notebook will not
# run until the cluster is scaled back up (e.g. cluster.scale(n)). This is
# presumably done interactively; under Restart & Run All, client.gather would block.
cluster.scale(0)

# 240111_bcd_rbses_revio

In [None]:
# Revio consensus table: every matching Arrow shard, collected eagerly.
arrow_filename = (
    "/home/jqs1/scratch/sequencing/240111_pLIB442-447_revio/consensus/*.arrow"
)
df_revio = concat_glob(arrow_filename).collect()

In [None]:
# grouping_depth sorted largest-first, drawn as a step curve (rank plot).
df_revio["grouping_depth"].sort(descending=True).to_pandas().hvplot.step()

In [None]:
# Display only: shows df_revio with an extra `len` column (consensus_seq length
# in bytes). The result is not assigned, so df_revio itself is unchanged.
df_revio.with_columns(len=pl.col("consensus_seq").str.len_bytes())

In [None]:
# Total grouping_depth summed over all rows, as a one-element Series
# ("foo" is just a placeholder column name).
df_revio.select(
    pl.col("grouping_depth").sum().alias("foo")
).to_series()  # .to_dict(as_series=False)

# 231207_pLIB442-447

In [None]:
%%time
# Nanopore consensus table (`df_ont`). The directory name suggests the same
# pLIB442-447 library as the Revio run above.
arrow_filename = "/home/jqs1/scratch/sequencing/231207_pLIB442-447/20231207_1151_3C_PAU07761_c6097b3e/consensus/*.arrow"
df_ont = concat_glob(arrow_filename).collect()

# Barcode distribution

In [None]:
arrow_filename = "/home/jqs1/scratch/sequencing/231201_bcd_rbses_run3/20231201_1101_1F_PAU05823_773c75ee/consensus/*.arrow"
df_231201 = concat_glob(arrow_filename).collect()

In [None]:
arrow_filename = "/home/jqs1/scratch/sequencing/230818_repressilators/20230905_1132_1H_PAQ85679_c9d74ddb/consensus/*.arrow"
df_230818 = concat_glob(arrow_filename).collect()

In [None]:
%%time
df_joined = (
    df_231201.with_columns(path_hash=processing.categorical_list_hash(pl.col("path")))
    .select(pl.all().name.suffix("_231201"))
    .join(
        df_230818.with_columns(
            path_hash=processing.categorical_list_hash(pl.col("path"))
        ).select(pl.all().name.suffix("_230818")),
        left_on="path_hash_231201",
        right_on="path_hash_230818",
        how="outer_coalesce",
        validate="1:1",
    )
)

In [None]:
df_joined.group_by(
    pl.col("path_231201").is_not_null(), pl.col("path_230818").is_not_null()
).agg(pl.len()).with_columns(frac=pl.col("len") / pl.col("len").sum())

In [None]:
df_joined.with_columns(len=pl.col("consensus_seq_231201").str.len_bytes())[
    ["grouping_depth_231201", "grouping_depth_230818", "len"]
].to_pandas().hvplot.scatter(
    "grouping_depth_231201", "grouping_depth_230818", color="len", size=0.5
)

# Pairwise alignment

In [None]:
%%time
df_joined = (
    df_revio.with_columns(path_hash=processing.categorical_list_hash(pl.col("path")))
    .select(pl.all().name.suffix("_revio"))
    .join(
        df_ont.with_columns(
            path_hash=processing.categorical_list_hash(pl.col("path"))
        ).select(pl.all().name.suffix("_ont")),
        left_on="path_hash_revio",
        right_on="path_hash_ont",
        how="outer_coalesce",
        validate="1:1",
    )
)

In [None]:
df_joined.group_by(
    pl.col("path_revio").is_not_null(), pl.col("path_ont").is_not_null()
).agg(pl.len()).with_columns(frac=pl.col("len") / pl.col("len").sum()).rename(
    {"path_revio": "in_revio", "path_ont": "in_ont"}
)

In [None]:
df_joined["grouping_depth_ont"].to_pandas().hvplot.kde() * df_joined.filter(
    pl.col("path_revio").is_null()
)["grouping_depth_ont"].to_pandas().hvplot.kde()

In [None]:
df_joined_both = df_joined.filter(
    pl.col("path_revio").is_not_null(), pl.col("path_ont").is_not_null()
)

In [None]:
df_joined.with_columns(len=pl.col("consensus_seq_revio").str.len_bytes())[
    ["grouping_depth_revio", "grouping_depth_ont", "len"]
].to_pandas().hvplot.scatter(
    "grouping_depth_revio", "grouping_depth_ont", color="len", size=0.5
)

In [None]:
query

In [None]:
idx = 106
query = df_joined_both[idx, "consensus_seq_revio"]
ref = df_joined_both[idx, "consensus_seq_ont"]
align_kwargs = {
    "gap_opening": 10,
    "gap_extension": 1,
    "match": 1,
    "mismatch": -1,
    "parasail_algorithm": "sw",
}
res = align.pairwise_align(query, ref, **align_kwargs)
res

In [None]:
def count_ops(cigar):
    """Sum the lengths of each CIGAR operation, skipping the two end entries.

    Parameters
    ----------
    cigar : sequence of (op, length) pairs
        CIGAR as returned in position 1 of ``align.pairwise_align``'s result.

    Returns
    -------
    collections.Counter
        Maps each op to the total length of its interior occurrences. The first
        and last entries are always excluded (presumably to ignore terminal ops
        at the alignment ends; TODO confirm), so a CIGAR with fewer than three
        entries yields an empty Counter.
    """
    final_pos = len(cigar) - 1
    # Every entry is unpacked, but only those strictly between the ends are kept.
    interior = (
        (op, length)
        for pos, (op, length) in enumerate(cigar)
        if 0 < pos < final_pos
    )
    totals = Counter()
    for op, length in interior:
        totals[op] += length
    return totals


def count_op_classes(cigar):
    """Reduce a CIGAR to a ``(mismatches, indels)`` tuple.

    Built on ``count_ops``, so the first and last CIGAR entries are ignored.
    ``mismatches`` is the total length of ``X`` ops; ``indels`` is the combined
    length of ``I`` and ``D`` ops. Absent ops contribute 0.
    """
    op_totals = count_ops(cigar)
    n_mismatches = op_totals.get(align.CigarOp.X, 0)
    n_indels = sum(op_totals.get(op, 0) for op in (align.CigarOp.I, align.CigarOp.D))
    return n_mismatches, n_indels

In [None]:
%%time
align_kwargs = {
    "gap_opening": 10,
    "gap_extension": 1,
    "match": 1,
    "mismatch": -1,
    "parasail_algorithm": "sw",
}
for idx in range(200):
    query = df_joined_both[idx, "consensus_seq_revio"]
    ref = df_joined_both[idx, "consensus_seq_ont"]
    res = align.pairwise_align(query, ref, **align_kwargs)
    print(idx, count_op_classes(res[1]))

In [None]:
df_joined_both.columns

In [None]:
df_joined_both.columns

In [None]:
%%time
align_kwargs = {
    "gap_opening": 10,
    "gap_extension": 1,
    "match": 1,
    "mismatch": -1,
    "parasail_algorithm": "sw",
}
columns = [
    "name_ont",
    "consensus_depth_revio",
    "consensus_depth_ont",
    "consensus_duplex_depth_ont",
    "grouping_depth_ont",
]
rows = []
num = len(df_joined_both)
# num = 1000
for idx in trange(num):
    query = df_joined_both[idx, "consensus_seq_revio"]
    ref = df_joined_both[idx, "consensus_seq_ont"]
    row = df_joined_both[idx, columns].to_dicts()[0]
    row["alignment"] = client.submit(align.pairwise_align, query, ref, **align_kwargs)
    rows.append(row)

In [None]:
%%time
res = client.gather(rows)

In [None]:
res[0]

In [None]:
def process_alignments(input):
    """Flatten gathered alignment rows into per-row error summaries.

    Each element of ``input`` is a dict whose ``"alignment"`` value is the
    result of ``align.pairwise_align``: index 0 is taken as the score and
    index 1 as the CIGAR. Returns a new list of dicts in the same order. Each
    dict has the ``"alignment"`` key removed (via ``dissoc``, which builds a
    copy) and gains ``score``, ``mismatches`` and ``indels``.
    """

    def summarize(row):
        alignment = row["alignment"]
        summary = dissoc(row, "alignment")
        summary["score"] = alignment[0]
        summary["mismatches"], summary["indels"] = count_op_classes(alignment[1])
        return summary

    return [summarize(row) for row in input]

In [None]:
alignments = process_alignments(res)

In [None]:
alignments[0]

In [None]:
for idx, row in enumerate(alignments):
    row["name_ont"] = df_joined_both[idx, "name_ont"]

In [None]:
df = pd.DataFrame(alignments)

In [None]:
df

In [None]:
df.to_parquet("240304errors.parquet")

In [None]:
df.hvplot.scatter("consensus_depth_ont", "indels")

In [None]:
df.hvplot.bivariate("consensus_depth_ont", "indels")

In [None]:
df["mismatches"].value_counts() / len(df)

In [None]:
df["indels"].value_counts() / len(df)

In [None]:
df[(df["mismatches"] > 0) | (df["indels"] > 5)]