In [None]:
import itertools as it
import operator
import re
from collections import Counter
from pathlib import Path

import duckdb
import holoviews as hv
import ibis
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars
import pyabpoa
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyfastx
import pysam
import spoa
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.io as sio

In [None]:
hv.extension("bokeh")

In [None]:
%load_ext pyinstrument
import line_profiler
import pyinstrument

%load_ext line_profiler

# Config

In [None]:
# Directory whose *.arrow and *.parquet files are globbed into datasets in the next cell.
# NOTE(review): absolute, user-specific scratch path -- anyone else has to edit it.
# `data_dir` is rebound to the parent directory in the "False pairs summary"
# section, so re-running the cells below after that section reads a different place.
data_dir = Path(
    "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/"
)

In [None]:
# One pyarrow dataset per on-disk format, each spanning every matching file in data_dir.
# The glob results are sorted so the list of files handed to pyarrow is the same on
# every run rather than following directory-listing order.
arrow_ds = ds.dataset(sorted(data_dir.glob("*.arrow")), format="arrow")
parquet_ds = ds.dataset(sorted(data_dir.glob("*.parquet")), format="parquet")

# Arrow

In [None]:
%%time
# Fetch the `qs` column as Arrow tables, split on whether the read name contains ';'.
# Later cells split such names on ';' into two parent read ids, so ';' is used here
# as the duplex marker. The SQL refers to the Python variable `arrow_ds` by name.
qs_duplex = duckdb.sql("SELECT qs FROM arrow_ds WHERE contains(name, ';')").arrow()
qs_simplex = duckdb.sql("SELECT qs FROM arrow_ds WHERE NOT contains(name, ';')").arrow()

In [None]:
# Overlay the normalised (density=True) `qs` histograms of the duplex and simplex
# query results on the same axes.
plt.hist(qs_duplex, bins=100, density=True)
plt.hist(qs_simplex, bins=100, density=True);

In [None]:
%%time
# Same duplex/simplex split as above, but the value is the arithmetic mean of each
# read's `read_phred` list, computed inside DuckDB with list_aggregate.
qs_duplex2 = duckdb.sql(
    "SELECT list_aggregate(read_phred, 'mean') FROM arrow_ds WHERE contains(name, ';')"
).arrow()
qs_simplex2 = duckdb.sql(
    "SELECT list_aggregate(read_phred, 'mean') FROM arrow_ds WHERE NOT contains(name, ';')"
).arrow()

In [None]:
%%time
# Fetch the stored `qs` and the mean of `read_phred` side by side so the two
# quality summaries can be compared per read.
qs_duplex3 = duckdb.sql(
    "SELECT qs, list_aggregate(read_phred, 'mean') FROM arrow_ds WHERE contains(name, ';')"
).arrow()
qs_simplex3 = duckdb.sql(
    "SELECT qs, list_aggregate(read_phred, 'mean') FROM arrow_ds WHERE NOT contains(name, ';')"
).arrow()

In [None]:
# Pull a single record batch from the Arrow dataset to inspect one read by hand.
b = next(arrow_ds.to_batches())

In [None]:
# `qs` value stored for the first read of that batch.
b.column("qs")[0]

In [None]:
# `read_phred` values of the same read, converted to a NumPy array.
x = np.asarray(b.column("read_phred")[0].values)

In [None]:
# Plain arithmetic mean of the phred values.
x.mean()

In [None]:
# Raw values, for eyeballing.
x

In [None]:
# Distribution of the phred values within this one read.
plt.hist(x, bins=10);

In [None]:
# Phred value by position in the array.
plt.plot(x);

In [None]:
# Mean quality computed in probability space: map each value q to 10**(-q/10),
# average those, then map the average back with -10*log10.
-10 * np.log10((10 ** -(x / 10)).mean())

In [None]:
# Same inner expression but without the minus sign in the exponent.
(10 ** (x / 10)).mean()

In [None]:
# Same formula as two cells up, except `x` is negated before the division.
# NOTE(review): if `x` has an unsigned integer dtype, NumPy's unary minus wraps
# around instead of producing negative numbers, so this would not agree with the
# `-(x / 10)` form above -- confirm x.dtype before trusting this value.
-10 * np.log10((10 ** (-x / 10)).mean())

In [None]:
# Peek at the first entry of the combined (qs, mean read_phred) duplex result.
qs_duplex3[0]

In [None]:
# Size of the duplex result.
len(qs_duplex2)

In [None]:
# Size of the simplex result.
len(qs_simplex2)

In [None]:
# Normalised histogram of the mean read_phred for duplex reads; the simplex
# overlay is left commented out.
plt.hist(qs_duplex2, bins=100, density=True);
# plt.hist(qs_simplex2, bins=100, density=True);

In [None]:
%%time
# Timing probe: materialise the full `read_phred` lists of all reads whose name
# has no ';', reading from the Arrow-format dataset.
phred_simplex = duckdb.sql(
    "SELECT read_phred FROM arrow_ds WHERE NOT contains(name, ';')"
).arrow()

In [None]:
%%time
# Same query as the previous cell but against the Parquet-format dataset, so the
# two %%time readings compare the file formats.
phred_simplex2 = duckdb.sql(
    "SELECT read_phred FROM parquet_ds WHERE NOT contains(name, ';')"
).arrow()

In [None]:
# Small two-file subset of the Arrow files for quick experiments. The glob result
# is sorted before slicing so the same two files are selected on every run
# (unsorted glob order depends on the directory listing).
arrow_ds2 = ds.dataset(sorted(data_dir.glob("*.arrow"))[:2], format="arrow")

In [None]:
%%time
# Read `read_phred` for every row of the two-file subset.
# NOTE(review): this overwrites `phred_simplex2` from the Parquet timing cell, and
# the query has no name filter, so despite the variable name the result is not
# restricted to names without ';'.
phred_simplex2 = duckdb.sql("SELECT read_phred FROM arrow_ds2").arrow()

In [None]:
# Length of the first `read_phred` entry in the subset result.
len(phred_simplex2[0][0])

In [None]:
# Size of the earlier full-dataset simplex result.
len(phred_simplex)

In [None]:
%%time
# Number of distinct read names that have a non-NULL `path`, from the Arrow dataset.
duckdb.sql("SELECT count(distinct(name)) FROM arrow_ds WHERE path IS NOT NULL;").arrow()

In [None]:
# DuckDB relation over all Parquet files in data_dir.
# NOTE(review): `table` is not used by the following cells in view and is rebound
# as a loop variable in the GAF-parsing cells further down.
table = duckdb.read_parquet(str(data_dir / "*.parquet"))

In [None]:
# Expose every Parquet file in data_dir as a DuckDB view named `reads`, which the
# SQL in the following cells queries.
# CREATE OR REPLACE keeps this cell re-runnable: a plain CREATE VIEW raises on the
# second execution because the view already exists in the connection.
duckdb.execute(
    f"CREATE OR REPLACE VIEW reads AS SELECT * FROM read_parquet('{data_dir / '*.parquet'}');"
)

In [None]:
%%time
# Same distinct-name count as the Arrow-dataset query above, but through the
# Parquet-backed `reads` view, to compare timings.
duckdb.sql("SELECT count(distinct(name)) FROM reads WHERE path IS NOT NULL;").arrow()

In [None]:
# Count of non-NULL names in the `reads` view.
duckdb.sql("SELECT COUNT(name) FROM READS;")

In [None]:
# Length of `read_seq` for every read with a non-NULL `path`, returned as a dict of
# NumPy arrays keyed by the result column name.
# The predicate is spelled `IS NOT NULL` (standard SQL) to match the other queries
# in this notebook instead of the non-standard postfix `NOT NULL`.
lengths = duckdb.sql(
    "SELECT LENGTH(read_seq) FROM READS WHERE path IS NOT NULL;"
).fetchnumpy()

In [None]:
# Read-length histogram with a log-scaled count axis.
plt.hist(lengths["length(read_seq)"], bins=100, log=True);

In [None]:
# Same histogram restricted to lengths below 20000.
plt.hist(
    lengths["length(read_seq)"][lengths["length(read_seq)"] < 20000], bins=100, log=True
);

In [None]:
# Group reads by `path`: for each distinct path, the list of read names and how
# many there are.
read_groups = duckdb.sql(
    "SELECT path, LIST(name), COUNT(name) FROM READS GROUP BY path;"
).arrow()

In [None]:
read_groups  # [0][1000]

In [None]:
# Histogram of group sizes (log-scaled counts), limited to sizes 0-250. The column
# is looked up by the name DuckDB generated for COUNT(name).
plt.hist(read_groups.column('count("name")'), bins=100, log=True, range=(0, 250));

In [None]:
# Size of the grouped result in megabytes (10**6 bytes).
read_groups.nbytes / 1e6

# False pairs vs. partial barcodes

In [None]:
%%time
# Split the read names into duplex (name contains ';') and simplex (no ';').
# NOTE(review): in the visible part of the notebook the Python name `reads` is only
# assigned further down (pyfastx.Fastq inside the summary loop); the DuckDB view
# called `reads` is not a Python object. On a fresh kernel this cell raises
# NameError unless that later cell has already run.
duplex_ids = {read_name for read_name in reads.keys() if ";" in read_name}
simplex_ids = {read_name for read_name in reads.keys() if ";" not in read_name}

In [None]:
%%time
# Walk every table yielded by sio.iter_gaf and build two lookups keyed by read name:
#   name_to_path    -> set of path segments (each with its first character removed)
#   name_to_barcode -> 30-tuple of booleans, one per BC:BIT<n>, for reads whose path
#                      contains both a BIT0 and a BIT29 segment ("complete" barcodes)
# NOTE(review): `gaf_filename` is only assigned by the loop in the summary section
# further down in the visible notebook; on a fresh kernel this cell needs it
# defined first.
total_reads = 0
complete_barcodes = 0
name_to_barcode = {}
name_to_path = {}
for table in tqdm(sio.iter_gaf(gaf_filename)):
    name_col = table.column("name")
    path_col = table.column("path")
    for idx in range(len(table)):
        name = name_col[idx].as_py()
        # Strip the leading character of each segment (presumably an orientation
        # marker -- confirm against sio.iter_gaf) so segments compare as plain labels.
        path = {segment[1:] for segment in path_col[idx].as_py()}
        name_to_path[name] = path
        total_reads += 1
        has_first_bit = "BC:BIT0=0" in path or "BC:BIT0=1" in path
        has_last_bit = "BC:BIT29=0" in path or "BC:BIT29=1" in path
        if not (has_first_bit and has_last_bit):
            continue
        complete_barcodes += 1
        # A bit is True when its "=1" segment is present, False otherwise.
        barcode = tuple(f"BC:BIT{bit}=1" in path for bit in range(30))
        name_to_barcode[name] = barcode

In [None]:
def partial_barcode_mismatches(a, b):
    """Return the keys on which two sets of ``key=value`` strings conflict.

    Only elements present in exactly one of ``a`` and ``b`` are considered. Each is
    reduced to the text before its first ``=``; a key that shows up more than once
    among those elements (i.e. both sets carry it, with different values) is
    reported.

    Args:
        a: set of strings such as ``"BC:BIT3=0"``.
        b: set of strings in the same form.

    Returns:
        Set of conflicting keys; empty when the two sets never disagree.
    """
    key_counts = Counter()
    for entry in a ^ b:
        key_counts[entry.partition("=")[0]] += 1
    return {key for key, count in key_counts.items() if count > 1}

In [None]:
def mapping_status(id_, name_to_barcode, name_to_path):
    """Classify a read id by which lookup tables contain it.

    Args:
        id_: read identifier to look up.
        name_to_barcode: container of ids that have a barcode entry.
        name_to_path: container of ids that have a path entry.

    Returns:
        0 if ``id_`` is in ``name_to_barcode``, otherwise 1 if it is in
        ``name_to_path``, otherwise 2.
    """
    if id_ in name_to_barcode:
        return 0
    if id_ in name_to_path:
        return 1
    return 2


# Status code returned by mapping_status -> label used in the census tables.
STATUS_TO_NAME = {0: "Complete", 1: "Partial", 2: "Unmapped"}

# Columns are every unordered pair of parent statuses ("Complete/Complete",
# "Complete/Partial", ...); rows are the status of the duplex read itself.
status_labels = list(STATUS_TO_NAME.values())
parent_pair_labels = [
    "/".join(pair) for pair in it.combinations_with_replacement(status_labels, 2)
]
duplex_row_labels = [f"Duplex {label.lower()}" for label in status_labels]

# All-zero template table of unsigned 32-bit counters.
census = pd.DataFrame(
    np.zeros((len(duplex_row_labels), len(parent_pair_labels)), dtype=np.uint32),
    index=duplex_row_labels,
    columns=parent_pair_labels,
)

# Three independent counters with the same layout. Note that `total_reads` rebinds
# the integer counter of the same name set by the GAF-parsing cell.
total_reads, matches, mismatches = (census.copy() for _ in range(3))

# Tally every duplex read by its own mapping status (row) and the sorted pair of
# its two parents' statuses (column). When both parents have a path entry, also
# record whether their paths conflict on any key (mismatch) or not (match).
duplex_row_by_status = ("Duplex complete", "Duplex partial", "Duplex unmapped")

for duplex_id in tqdm(duplex_ids):
    # A duplex name is the two parent read ids joined by ';'.
    parent_ids = duplex_id.split(";")
    first_parent, second_parent = parent_ids[0], parent_ids[1]
    parent_statuses = sorted(
        mapping_status(parent, name_to_barcode, name_to_path)
        for parent in (first_parent, second_parent)
    )
    parent_col = "/".join(STATUS_TO_NAME[status] for status in parent_statuses)
    row = duplex_row_by_status[
        mapping_status(duplex_id, name_to_barcode, name_to_path)
    ]
    total_reads.loc[row, parent_col] += 1

    # Status 2 means the parent has no path entry, so there is nothing to compare.
    if 2 in parent_statuses:
        continue
    conflicting_keys = partial_barcode_mismatches(
        name_to_path[first_parent], name_to_path[second_parent]
    )
    if conflicting_keys:
        mismatches.loc[row, parent_col] += 1
    else:
        matches.loc[row, parent_col] += 1

In [None]:
# Duplex reads counted per (duplex status, parent status pair).
total_reads

In [None]:
# Of those with both parents mapped: parent paths that do not conflict.
matches

In [None]:
# Of those with both parents mapped: parent paths that conflict on at least one key.
mismatches

# False pairs summary

In [None]:
def summarize_filename(p):
    """Build a short ``<dir prefix>/<file name>`` label for a path.

    The prefix is the parent directory name cut off one character after its first
    underscore, e.g. ``.../230707_repressilators/x.gaf`` -> ``230707_r/x.gaf``.

    Args:
        p: a ``pathlib`` path with at least two components.

    Returns:
        The shortened label as a string.

    Raises:
        ValueError: if the parent directory name contains no underscore.
        IndexError: if ``p`` has fewer than two components.
    """
    parent_name = p.parts[-2]
    cut = parent_name.index("_") + 2
    return f"{parent_name[:cut]}/{p.parts[-1]}"

In [None]:
# Root directory containing one sub-directory per sequencing experiment.
# NOTE(review): this rebinds `data_dir`, which the cells near the top of the
# notebook use for a single experiment directory; re-running those cells after
# this one makes them read from here instead.
data_dir = Path("/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/")

# GAF files to summarise; commented-out entries are alternatives kept for reference.
gaf_filenames = [
    # data_dir / "230707_repressilators/channel-135_merged.gaf",
    data_dir / "230707_repressilators/dorado_0.4.0/channel-135_merged_barcodeonly.gaf",
    # data_dir / "230726_carlos/channel-100_merged.gaf",
    data_dir / "230726_carlos/dorado_0.4.0/channel-100_merged_barcodeonly.gaf",
    # data_dir / "230818_bcd_rbses/channel-100_merged.gaf",
    # data_dir / "230818_repressilators/channel-1032_merged.gaf",
    # data_dir / "230922_bcd_rbses_constitutive/channel-100_merged.gaf",
]

# One summary record per GAF file, turned into a DataFrame in the next cell.
rows = []

for gaf_filename in tqdm(gaf_filenames):
    # Build the read-name lookups for this file:
    #   name_to_path    -> set of path segments (first character removed)
    #   name_to_barcode -> 30-tuple of booleans for reads whose path contains both
    #                      a BIT0 and a BIT29 segment
    name_to_barcode = {}
    name_to_path = {}
    for table in tqdm(sio.iter_gaf(gaf_filename)):
        name_col = table.column("name")
        path_col = table.column("path")
        for idx in range(len(table)):
            name = name_col[idx].as_py()
            path = {s[1:] for s in path_col[idx].as_py()}
            name_to_path[name] = path
            if ("BC:BIT0=0" in path or "BC:BIT0=1" in path) and (
                "BC:BIT29=0" in path or "BC:BIT29=1" in path
            ):
                barcode = tuple(f"BC:BIT{bit}=1" in path for bit in range(30))
                name_to_barcode[name] = barcode

    # The FASTQ sits next to the GAF, with ".fastq.gz" in place of "_barcodeonly.gaf".
    reads = pyfastx.Fastq(str(gaf_filename).replace("_barcodeonly.gaf", ".fastq.gz"))
    # Duplex read names are the two parent read ids joined by ';'.
    duplex_ids = {k for k in reads.keys() if ";" in k}

    # Count duplex reads for which the duplex read and both parents have a complete
    # barcode, and how many of those have three identical barcodes.
    num_mapped_reads = 0
    num_barcodes_match = 0
    for duplex_id in tqdm(duplex_ids):
        parent_ids = duplex_id.split(";")
        if (
            duplex_id in name_to_barcode
            and parent_ids[0] in name_to_barcode
            and parent_ids[1] in name_to_barcode
        ):
            num_mapped_reads += 1
            if (
                name_to_barcode[duplex_id]
                == name_to_barcode[parent_ids[0]]
                == name_to_barcode[parent_ids[1]]
            ):
                num_barcodes_match += 1

    # Label the row with the experiment directory, i.e. the first path component
    # below data_dir, cut off one character after its first underscore. Taking the
    # GAF's immediate parent directory instead would yield "dorado_0" for every
    # file that lives in a "dorado_0.4.0" sub-directory.
    experiment_dir = gaf_filename.relative_to(data_dir).parts[0]
    experiment_name = experiment_dir[: experiment_dir.index("_") + 2]

    num_total_reads = len(reads)
    num_duplex_reads = len(duplex_ids)
    rows.append(
        {
            "Experiment": experiment_name,
            "Total Reads": num_total_reads,
            "Duplex Reads": num_duplex_reads,
            "Mapped Reads": num_mapped_reads,
            "Barcodes Match": num_barcodes_match,
            # Rates are NaN rather than a ZeroDivisionError when the denominator is
            # zero (empty FASTQ, or no fully mapped duplex/parent triple).
            "Duplex Rate": (
                num_duplex_reads / num_total_reads
                if num_total_reads
                else float("nan")
            ),
            "Pair Mismatch Rate": (
                1 - num_barcodes_match / num_mapped_reads
                if num_mapped_reads
                else float("nan")
            ),
        }
    )

In [None]:
# Render the per-experiment summary with the mismatch rate as a whole percentage.
pd.DataFrame(rows).style.format({"Pair Mismatch Rate": "{:.0%}"})

In [None]:
# Same table as GitHub-flavoured markdown text, for pasting elsewhere.
print(pd.DataFrame(rows).to_markdown(tablefmt="github", index=False))