In [None]:
import itertools as it
import operator
import re
from collections import Counter
from pathlib import Path

import gfapy
import holoviews as hv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
import pyfastx
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.gaf as gaf

In [None]:
hv.extension("bokeh")

# Config

In [None]:
# Root directory under which the experiment GAF files used below are located.
# NOTE(review): absolute, user-specific path; adjust before running elsewhere.
data_dir = Path("/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/")
# gaf_filename = data_dir / "barcode.gfa"
# gfa = gfapy.Gfa.from_file(data_dir / "references/bcd_rbses.gfa")

# Completeness

In [None]:
# Select the single GAF file analysed by the cells of this section:
# leave exactly one of the assignments below uncommented.
# gaf_filename = data_dir / "230707_repressilators/channel-135_merged.gaf"
# gaf_filename = data_dir / "230726_carlos/channel-100_merged.gaf"
# gaf_filename = data_dir / "230818_bcd_rbses/channel-100_merged.gaf"
gaf_filename = data_dir / "230818_repressilators/channel-1032_merged.gaf"
# gaf_filename = data_dir / "230922_bcd_rbses_constitutive/channel-100_merged.gaf"

In [None]:
%%time
# Open the FASTQ stored next to the selected GAF (same stem, ".fastq" suffix).
# with_suffix swaps only the final extension, so a ".gaf" substring elsewhere
# in the path (e.g. in a directory name) is left untouched.
reads = pyfastx.Fastq(str(gaf_filename.with_suffix(".fastq")))

In [None]:
%%time
# Partition the read names by whether they contain ";" (a duplex name is later
# split on ";" into its parent read names).
duplex_ids = {read_id for read_id in reads.keys() if ";" in read_id}
simplex_ids = {read_id for read_id in reads.keys() if not (";" in read_id)}

In [None]:
%%time
# Per-read lookups filled from every alignment record of the selected GAF.
name_to_path = {}  # read name -> set of segment names on its path
name_to_barcode = {}  # read name -> 30-tuple of bools, True where BC:BIT<i>=1 is on the path
total_reads = 0  # number of alignment records seen
complete_barcodes = 0  # records whose path contains both BIT0 and BIT29
for table in tqdm(gaf.iter_gaf(gaf_filename)):
    names = table.column("name").to_pylist()
    oriented_paths = table.column("path").to_pylist()
    for name, oriented_path in zip(names, oriented_paths):
        # Drop the first character of every path step
        # (presumably the orientation marker; confirm against gaf.iter_gaf).
        path = {step[1:] for step in oriented_path}
        name_to_path[name] = path
        total_reads += 1
        has_first_bit = "BC:BIT0=0" in path or "BC:BIT0=1" in path
        has_last_bit = "BC:BIT29=0" in path or "BC:BIT29=1" in path
        if not (has_first_bit and has_last_bit):
            continue
        # Both barcode ends are covered: record the read as fully barcoded.
        complete_barcodes += 1
        name_to_barcode[name] = tuple(f"BC:BIT{bit}=1" in path for bit in range(30))

In [None]:
def partial_barcode_mismatches(a, b):
    """Return the segment prefixes on which two paths disagree.

    ``a`` and ``b`` are sets of segment names. Segments present in exactly one
    of the two sets are grouped by the text before the first "="; a prefix
    that occurs more than once in that group (e.g. "X=0" on one side and
    "X=1" on the other) is reported.

    Returns a set of prefixes; empty when no prefix occurs more than once.
    """
    one_sided_segments = a ^ b
    prefix_counts = Counter(segment.split("=")[0] for segment in one_sided_segments)
    return {prefix for prefix, count in prefix_counts.items() if count > 1}

In [None]:
def mapping_status(id_, name_to_barcode, name_to_path):
    """Classify a read id by the lookup tables it appears in.

    Returns 0 when ``id_`` is a key of ``name_to_barcode``, 1 when it is only
    a key of ``name_to_path``, and 2 when it is in neither.
    """
    if id_ in name_to_barcode:
        return 0
    return 1 if id_ in name_to_path else 2


# Label for each mapping_status() code.
STATUS_TO_NAME = {0: "Complete", 1: "Partial", 2: "Unmapped"}

# Row label for each mapping_status() code of the duplex read itself.
DUPLEX_ROW_LABELS = ("Duplex complete", "Duplex partial", "Duplex unmapped")
# Column label for each sorted pair of parent status codes.
PARENT_PAIR_LABELS = (
    "Complete/Complete",
    "Complete/Partial",
    "Complete/Unmapped",
    "Partial/Partial",
    "Partial/Unmapped",
    "Unmapped/Unmapped",
)

# Zero-filled template; each tally below starts as an independent copy.
census = pd.DataFrame(
    np.zeros((len(DUPLEX_ROW_LABELS), len(PARENT_PAIR_LABELS)), dtype=np.uint32),
    columns=list(PARENT_PAIR_LABELS),
    index=list(DUPLEX_ROW_LABELS),
)
# NOTE(review): this rebinds total_reads, which the completeness cell uses as
# an integer counter of alignment records.
total_reads = census.copy()
matches = census.copy()
mismatches = census.copy()

for duplex_id in tqdm(duplex_ids):
    parent_ids = duplex_id.split(";")
    # Sorting makes the pair order-independent, e.g. always "Complete/Partial".
    parent_statuses = sorted(
        mapping_status(parent_id, name_to_barcode, name_to_path)
        for parent_id in parent_ids[:2]
    )
    parent_col = "/".join(STATUS_TO_NAME[status] for status in parent_statuses)
    row = DUPLEX_ROW_LABELS[mapping_status(duplex_id, name_to_barcode, name_to_path)]
    total_reads.loc[row, parent_col] += 1
    # Parents can only be compared when both of them have an alignment path.
    if 2 not in parent_statuses:
        conflicting_prefixes = partial_barcode_mismatches(
            name_to_path[parent_ids[0]], name_to_path[parent_ids[1]]
        )
        tally = mismatches if conflicting_prefixes else matches
        tally.loc[row, parent_col] += 1

In [None]:
# Census of all duplex reads: rows are the duplex read's own mapping status,
# columns the sorted mapping-status pair of its two parent reads.
total_reads

In [None]:
# Duplex reads whose parents both have a path and for which
# partial_barcode_mismatches of the two parent paths is empty.
matches

In [None]:
# Duplex reads whose parents both have a path and for which
# partial_barcode_mismatches of the two parent paths is non-empty.
mismatches

In [None]:
# Sample GAF path, passed to summarize_filename in a later cell.
a = data_dir / "230707_repressilators/channel-135_merged.gaf"

In [None]:
def summarize_filename(p):
    """Return a short "<directory prefix>/<file name>" label for a path.

    The prefix is the parent directory name cut one character past its first
    underscore, e.g. "230707_repressilators" -> "230707_r".

    Raises ValueError when the parent directory name contains no "_".
    """
    experiment_dir = p.parts[-2]
    prefix_end = experiment_dir.index("_") + 2
    return f"{experiment_dir[:prefix_end]}/{p.parts[-1]}"

In [None]:
# Try summarize_filename on the sample path `a`.
summarize_filename(a)

In [None]:
gaf_filenames = [
    data_dir / "230707_repressilators/channel-135_merged.gaf",
    data_dir / "230726_carlos/channel-100_merged.gaf",
    data_dir / "230818_bcd_rbses/channel-100_merged.gaf",
    data_dir / "230818_repressilators/channel-1032_merged.gaf",
    data_dir / "230922_bcd_rbses_constitutive/channel-100_merged.gaf",
]


def summarize_duplex_pairs(gaf_filename, num_bits=30):
    """Summarize duplex/parent barcode agreement for one experiment.

    All intermediate state stays local to this function, so the notebook-level
    ``gaf_filename``, ``reads``, ``duplex_ids``, ``name_to_barcode`` and
    ``name_to_path`` used by other cells are not overwritten.

    Parameters
    ----------
    gaf_filename : pathlib.Path
        GAF alignment file. The reads are opened from the file with the same
        stem and a ".fastq" suffix. The parent directory name must contain an
        underscore (it is shortened to build the experiment label).
    num_bits : int
        Number of barcode bits. A read has a complete barcode when its path
        contains a segment for bit 0 and one for bit ``num_bits - 1``.

    Returns
    -------
    dict
        One row of the summary table. "Mapped Reads" counts duplex reads for
        which the duplex read and both parents have a complete barcode;
        "Barcodes Match" counts those where all three barcodes are equal.
        A rate is NaN when its denominator is zero.
    """
    last_bit = num_bits - 1
    name_to_barcode = {}
    for table in tqdm(gaf.iter_gaf(gaf_filename)):
        name_col = table.column("name")
        path_col = table.column("path")
        for idx in range(len(table)):
            # Drop the first character of every path step
            # (presumably the orientation marker; confirm against gaf.iter_gaf).
            path = {step[1:] for step in path_col[idx].as_py()}
            if ("BC:BIT0=0" in path or "BC:BIT0=1" in path) and (
                f"BC:BIT{last_bit}=0" in path or f"BC:BIT{last_bit}=1" in path
            ):
                name_to_barcode[name_col[idx].as_py()] = tuple(
                    f"BC:BIT{bit}=1" in path for bit in range(num_bits)
                )

    # with_suffix swaps only the final extension; a plain str.replace would
    # also rewrite ".gaf" occurring elsewhere in the path.
    reads = pyfastx.Fastq(str(gaf_filename.with_suffix(".fastq")))
    duplex_ids = {read_id for read_id in reads.keys() if ";" in read_id}

    num_mapped_reads = 0
    num_barcodes_match = 0
    for duplex_id in tqdm(duplex_ids):
        parent_ids = duplex_id.split(";")
        triple = (duplex_id, parent_ids[0], parent_ids[1])
        if all(read_id in name_to_barcode for read_id in triple):
            num_mapped_reads += 1
            if (
                name_to_barcode[triple[0]]
                == name_to_barcode[triple[1]]
                == name_to_barcode[triple[2]]
            ):
                num_barcodes_match += 1

    experiment_dir = gaf_filename.parts[-2]
    num_total_reads = len(reads)
    num_duplex_reads = len(duplex_ids)
    nan = float("nan")
    return {
        "Experiment": experiment_dir[: experiment_dir.index("_") + 2],
        "Total Reads": num_total_reads,
        "Duplex Reads": num_duplex_reads,
        "Mapped Reads": num_mapped_reads,
        "Barcodes Match": num_barcodes_match,
        # Guard the ratios: a file with no reads or no fully mapped triples
        # must not abort the whole summary with ZeroDivisionError.
        "Duplex Rate": num_duplex_reads / num_total_reads if num_total_reads else nan,
        "Pair Mismatch Rate": (
            1 - num_barcodes_match / num_mapped_reads if num_mapped_reads else nan
        ),
    }


rows = [summarize_duplex_pairs(path) for path in tqdm(gaf_filenames)]

In [None]:
# Summary table, with the pair mismatch rate rendered as a whole-number percentage.
pd.DataFrame(rows).style.format({"Pair Mismatch Rate": "{:.0%}"})

In [None]:
# The same rows printed as GitHub-format markdown text, without the index column.
print(pd.DataFrame(rows).to_markdown(tablefmt="github", index=False))

# Old

In [None]:
%%time
# Tally segment usage over every alignment record of gaf_filename.
segments = Counter()  # occurrences of each segment name anywhere on a path
ends = Counter()  # occurrences of each segment name as first or last path step
total_reads = 0  # number of alignment records seen
for table in tqdm(gaf.iter_gaf(gaf_filename)):
    for oriented_path in table.column("path").to_pylist():
        # Drop the first character of every path step
        # (presumably the orientation marker; confirm against gaf.iter_gaf).
        path = [step[1:] for step in oriented_path]
        segments.update(path)
        # A single-segment path counts that segment twice, once per end.
        ends.update((path[0], path[-1]))
        total_reads += 1

In [None]:
# Occurrences of each segment relative to the number of alignment records,
# printed as whole-number percentages in segment-name order.
for segment_name in sorted(segments):
    percent = segments[segment_name] / total_reads * 100
    print(f"{segment_name}: {percent:.0f}%")

In [None]:
# How often each segment is the first or last step of a path, relative to the
# number of alignment records, printed in segment-name order.
for segment_name in sorted(ends):
    percent = ends[segment_name] / total_reads * 100
    print(f"{segment_name}: {percent:.0f}%")

In [None]:
# Headline numbers: complete barcodes and duplex reads, each also as a
# fraction of total_reads.
# NOTE(review): total_reads is assigned by several cells (as an integer
# counter and as a census DataFrame) and complete_barcodes by only one;
# confirm both refer to the same GAF file before trusting these ratios.
(
    complete_barcodes,
    total_reads,
    complete_barcodes / total_reads,
    len(duplex_ids),
    len(duplex_ids) / total_reads,
)

In [None]:
# Sort every duplex read into one bucket according to the complete barcodes
# (entries of name_to_barcode) available for it and for its parent reads.
duplex_mismatches = []  # parents both have a complete barcode, and they differ
duplex_matches = []  # parents both have a complete barcode, and they agree
duplex_missingone = []  # exactly one parent has a complete barcode
duplex_missingboth = []  # no parent has a complete barcode
duplex_nobarcode = []  # the duplex read itself has no complete barcode
for duplex_id in tqdm(duplex_ids):
    if duplex_id not in name_to_barcode:
        duplex_nobarcode.append(duplex_id)
        continue
    # Named parent_ids rather than `reads`, so the notebook-level `reads`
    # FASTQ handle is not overwritten by a list of strings.
    parent_ids = duplex_id.split(";")
    num_alignments = sum(parent_id in name_to_barcode for parent_id in parent_ids)
    if num_alignments == 0:
        duplex_missingboth.append(duplex_id)
    elif num_alignments == 1:
        duplex_missingone.append(duplex_id)
    elif name_to_barcode[parent_ids[0]] == name_to_barcode[parent_ids[1]]:
        duplex_matches.append(duplex_id)
    else:
        duplex_mismatches.append(duplex_id)

In [None]:
# Bucket sizes, in the order: no barcode, mismatches, matches, missing one,
# missing both.
(
    len(duplex_nobarcode),
    len(duplex_mismatches),
    len(duplex_matches),
    len(duplex_missingone),
    len(duplex_missingboth),
)

In [None]:
# Path segments of one example duplex read that has no complete barcode.
name_to_path[duplex_nobarcode[8]]

In [None]:
# Print the paths of the two parents of one duplex read lacking a complete
# barcode, followed by the prefixes on which they disagree.
duplex_read = duplex_nobarcode[8]
parents = duplex_read.split(";")
print("1>", name_to_path[parents[0]])
print("2>", name_to_path[parents[1]])
# m: segment-name prefixes (text before "=") occurring more than once among
# the segments found in only one of the two parent paths.
m = partial_barcode_mismatches(name_to_path[parents[0]], name_to_path[parents[1]])
print()
print(m)

In [None]:
# Same parent-path comparison, looped over a slice of duplex_nobarcode
# (currently only its first entry).
# The loop leaves `parents` bound to the last read's parent ids, which the
# following cell reads.
for duplex_read in duplex_nobarcode[:1]:
    parents = duplex_read.split(";")
    print("1>", name_to_path[parents[0]])
    print("2>", name_to_path[parents[1]])
    m = partial_barcode_mismatches(name_to_path[parents[0]], name_to_path[parents[1]])
    print(m)

In [None]:
# Bind the two parent paths (sets of segment names) for interactive set
# operations in the next cells.
# NOTE(review): rebinds `a`, which an earlier cell uses for a sample GAF path.
a = name_to_path[parents[0]]
b = name_to_path[parents[1]]

In [None]:
# Path segments of the second parent.
b

In [None]:
# Segments present in both parent paths.
a & b

In [None]:
# Segments present in exactly one of the two parent paths.
a ^ b

In [None]:
# For duplex reads without a complete barcode: count the disagreeing prefixes
# between the two parent paths, or set the read aside when a parent has no path.
duplex_nobarcode_distances = []  # number of disagreeing prefixes per comparable read
no_parent_alignment = []  # duplex reads with at least one parent missing from name_to_path
for duplex_read in duplex_nobarcode:
    parents = duplex_read.split(";")
    if not (parents[0] in name_to_path and parents[1] in name_to_path):
        no_parent_alignment.append(duplex_read)
        continue
    first_path = name_to_path[parents[0]]
    second_path = name_to_path[parents[1]]
    duplex_nobarcode_distances.append(
        len(partial_barcode_mismatches(first_path, second_path))
    )

In [None]:
# Number of reads set aside for a missing parent path vs. number compared.
(len(no_parent_alignment), len(duplex_nobarcode_distances))

In [None]:
# Histogram as (number of disagreeing prefixes, number of duplex reads) pairs,
# in ascending order.
sorted(Counter(duplex_nobarcode_distances).items())

In [None]:
# Scratch check of DataFrame.to_markdown output on a toy one-row frame;
# it does not use any of the sequencing data.
print(pd.DataFrame([{"foo": 100, "bar": 200}]).to_markdown())