In [None]:
import itertools as it
import operator
import re
from collections import Counter, defaultdict
from pathlib import Path

import gfapy
import holoviews as hv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.compute as pc
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
# Activate the Bokeh plotting backend for HoloViews.
hv.extension("bokeh")

# Config

In [None]:
# Root directory of the sequencing run; the input paths below are built from it.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
data_dir = Path(
    "/home/jqs1/scratch/jqs1/sequencing/230818_bcd_rbses/20230818_1343_1A_PAQ97606_f49ab41c"
)
# Alignment file streamed by parse_gaf below.
gaf_filename = data_dir / "temp/mapped_t4.gaf"
# Reference graph, loaded eagerly with gfapy when this cell runs.
gfa = gfapy.Gfa.from_file(data_dir / "references/bcd_rbses.gfa")

# GAF

In [None]:
# SEE: http://samtools.github.io/hts-specs/SAMv1.pdf
# and https://samtools.github.io/hts-specs/SAMtags.pdf
# Arrow type that each one-letter SAM tag type code is cast to.
# pyarrow CSV parser only supports pa.dictionary with int32 indices
SAM_TAG_TYPES = dict(
    A=pa.dictionary(pa.int32(), pa.string()),
    f=pa.float32(),
    i=pa.int32(),
    Z=pa.string(),
)
# Arrow types of the fixed GAF columns, in file order. The read-name column
# is not listed here; parse_gaf_types adds it as "name". The insertion order
# matters because the schema is assembled positionally from this mapping.
GAF_COLUMN_TYPES = dict(
    query_length=pa.uint64(),
    query_start=pa.uint64(),
    query_end=pa.uint64(),
    strand=pa.dictionary(pa.int32(), pa.string()),
    path=pa.string(),
    path_length=pa.uint64(),
    path_start=pa.uint64(),
    path_end=pa.uint64(),
    residue_matches=pa.uint64(),
    block_length=pa.uint64(),
    mapping_quality=pa.uint8(),
)
# Matches one optional SAM-style field of the form TAG:TYPE:VALUE.
# Named groups: "tag" is the tag name, "tag_value" is "TYPE:VALUE".
# Integer and float values follow the SAM spec grammar, so signed integers
# ("AS:i:-12") and signed / exponent floats ("dv:f:1.5e-3") are recognised.
SAM_TAG_REGEX = re.compile(
    r"^(?P<tag>[a-zA-Z0-9]+):"
    r"(?P<tag_value>A:.|f:[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?|i:[-+]?\d+|Z:.*)$"
)


def split_gaf_path(path):
    """Split a GAF path string into its oriented steps.

    ``">s1<s2"`` becomes ``[">s1", "<s2"]``. A path that does not start with
    an orientation character is kept as a single step instead of being
    dropped, and ``None`` is passed through so null cells stay null.
    """
    if path is None:
        return None
    steps = re.split(r"(?=<|>)", path)
    # Splitting in front of every "<"/">" leaves an empty first piece whenever
    # the path starts with an orientation character; only that piece is noise.
    if steps[0] == "":
        steps = steps[1:]
    return steps


def _parse_gaf_schema(gaf_filename):
    """Infer the column layout of a GAF file from its first row.

    The expected layout is: read name, optional SAM-style tags, the fixed
    columns of ``GAF_COLUMN_TYPES``, then optional trailing SAM-style tags.

    Returns:
        ``(column_types, columns_to_parse)``. ``column_types`` maps column
        name to the Arrow type used when reading the file (tag columns are
        read as strings). ``columns_to_parse`` maps tag name to its
        one-letter SAM type code, for conversion in ``parse_gaf_table``.

    Raises:
        ValueError: if the row is too short, if a column between the read
            name and the fixed columns is not a SAM tag, or if two columns
            would get the same name.

    NOTE(review): only the first row is inspected; this assumes every row
    carries the same tags in the same order — confirm for your aligner.
    """
    with open(gaf_filename, "r") as f:
        first_row = f.readline().rstrip("\r\n").split("\t")

    def parse_tag(field):
        # Returns (tag, type_code), or None when the field is not a SAM tag.
        match = SAM_TAG_REGEX.match(field)
        if match is None:
            return None
        return match.group("tag"), match.group("tag_value").partition(":")[0]

    # Peel the optional tags off the right-hand end of the row.
    trailing = []
    last_fixed = len(first_row) - 1
    while last_fixed >= 0 and (parsed := parse_tag(first_row[last_fixed])):
        trailing.append(parsed)
        last_fixed -= 1
    trailing.reverse()

    # Everything left of the fixed columns: the read name plus optional tags.
    num_leading = last_fixed + 1 - len(GAF_COLUMN_TYPES)
    if num_leading < 1:
        raise ValueError(
            f"expecting a read name and {len(GAF_COLUMN_TYPES)} GAF columns "
            "before any trailing SAM tags"
        )
    leading = []
    for field in first_row[1:num_leading]:
        parsed = parse_tag(field)
        if parsed is None:
            raise ValueError("expecting SAM tags following FASTQ read name")
        leading.append(parsed)

    column_types = {"name": pa.string()}
    column_types.update((tag, pa.string()) for tag, _ in leading)
    column_types.update(GAF_COLUMN_TYPES)
    column_types.update((tag, pa.string()) for tag, _ in trailing)
    # A repeated name collapses in the dict and would misalign every column.
    if len(column_types) != len(first_row):
        raise ValueError("duplicate column names in first GAF row")
    columns_to_parse = dict(leading + trailing)
    return column_types, columns_to_parse


def parse_gaf_types(gaf_filename):
    """Return the ordered ``{column name: Arrow type}`` mapping for a GAF file.

    The mapping is inferred from the first row of ``gaf_filename``; SAM tag
    columns are typed as strings here and converted later.
    """
    return _parse_gaf_schema(gaf_filename)[0]


def parse_gaf_table(table, columns_to_parse):
    """Convert the raw string columns of one GAF table to typed columns.

    Args:
        table: pyarrow Table with the columns produced by ``parse_gaf_types``.
        columns_to_parse: mapping of tag column name to SAM type code
            (a key of ``SAM_TAG_TYPES``).

    Returns:
        A new Table where each tag column has its ``TAG:TYPE:`` prefix removed
        and is cast to the matching Arrow type, and ``path`` is a list of
        oriented steps.
    """
    for tag, type_ in columns_to_parse.items():
        col_idx = table.column_names.index(tag)
        # Anchored so that only the leading prefix is removed, never a later
        # occurrence of the same text inside a Z (string) value.
        new_column = pc.replace_substring_regex(
            table[tag], f"^{tag}:{type_}:", ""
        ).cast(SAM_TAG_TYPES[type_])
        table = table.set_column(col_idx, tag, new_column)
    path = pa.array(
        [split_gaf_path(s.as_py()) for s in table.column("path")],
        type=pa.list_(pa.dictionary(pa.int16(), pa.string())),
    )
    table = table.set_column(table.column_names.index("path"), "path", path)
    return table


def parse_gaf(gaf_filename):
    """Stream a GAF file as typed pyarrow Tables, one per CSV record batch.

    The schema is inferred once from the first row, then the file is read
    incrementally as tab-separated text and each batch is passed through
    ``parse_gaf_table``.
    """
    column_types, columns_to_parse = _parse_gaf_schema(gaf_filename)
    read_options = csv.ReadOptions(column_names=list(column_types))
    parse_options = csv.ParseOptions(delimiter="\t")
    convert_options = csv.ConvertOptions(column_types=column_types)
    with csv.open_csv(
        gaf_filename,
        read_options=read_options,
        parse_options=parse_options,
        convert_options=convert_options,
    ) as reader:
        while True:
            # Only the read itself may signal exhaustion; keeping the
            # conversion outside the try avoids masking errors raised there.
            try:
                batch = reader.read_next_batch()
            except StopIteration:
                break
            yield parse_gaf_table(pa.Table.from_batches([batch]), columns_to_parse)

In [None]:
%%time
qs = 0
for table in tqdm(parse_gaf(gaf_filename)):
    qs += 1

In [None]:
# Inspect the last table yielded by the parse_gaf loop above.
table

In [None]:
# NOTE(review): `count_dist` is never assigned in the visible notebook; this
# cell relies on leftover kernel state — define it above or remove the cell.
count_dist

In [None]:
# NOTE(review): `barcode_msgs_cluster` is never assigned in the visible
# notebook; `msgs` is read again by the Counter cells at the end.
msgs = barcode_msgs_cluster[5][0]

In [None]:
# Read up to 10M alignments from the GAM file, recording every read length in
# `lens` and keeping only alignments whose read is 800-900 long in `msgs3`.
# NOTE(review): `stream` and `vg_pb2` are not imported in the visible notebook.
# gam_filename2 = "duplex_sup1_vg2.gam"
gam_filename2 = "duplex_sup1_subsample_vg2.gam"
msgs3, lens = [], []
for msg in tqdm(it.islice(stream.parse(gam_filename2, vg_pb2.Alignment), 10_000_000)):
    lens.append(len(msg.sequence))
    if not (800 <= lens[-1] <= 900):
        continue
    msgs3.append(msg)

In [None]:
# Overwrites `msgs3` from the previous cell with every alignment in the file
# (no length filter and no 10M cap).
msgs3 = list(tqdm(stream.parse("duplex_sup1_subsample_vg2.gam", vg_pb2.Alignment)))

In [None]:
# Print, for the first 100 alignments, the read length and the names of the
# graph nodes along the mapped path, followed by a blank line.
for i in range(100):
    msg = msgs3[i]
    print(len(msg.sequence))
    node_names = [m.position.name for m in msg.path.mapping]
    print(" ".join(node_names), end="\n\n")

In [None]:
# Scratch check: length of a pasted, hard-coded sequence.
len(
    "TCCTCAATCGCACTGGAAACATCAAGGTCGACGAAAGACCGCTGAGGAGCCAGATACATAGATTACCACAACTCCGAGCCCTTCCACCAAAAAAAACAGATAGCCGCGCGAACGCGGCTAACTGTTGAAAAAAAACAGATAACAGATACCGAAGTATCTGTTATCTTTCCCAAAAAACCCCTCAAGACCCGTTTAGAGGCCCCAAGGGGTTATTACTGATGGCAATGTGATGTCCTCATCTTACTCCCTCTAGTCTATCATTACCCTCCTCCTGCTCTTAACTACCCTCATTCCGACCCTTACTACTACATCATCGACCTTTCTCCATACCCAACTGTCCTAACAACCAACTACTCCGCCTCTTCATCCTCTTTCAACGTTCTCCCTCTATCAACTCAGCAACCACACTCAACTACCATGACATTACACCTCATTCTCCCGACTTTCCACATACTTCCCAGTTTACTCCCTACACCTCCAAGATTCCATACCCACTCTCTTCGCTCTCTACACCCACCAATAAGTTCCTAACAAATCACATCCCGTATCTGTTATGTAATTGCTAGTTAAACAACCCATCCCACCAGATAAATCATTCCCACTACCCGTCAATCCACCATTCCTCAACGAAACTTCATCACTCTCCTCCGCACCCTAACATACAACTCTCGAATACTCTCCCACCTCAACTGCTTCTTCTCTTACACCCTCTGTCTATCATCTCCAAACCACAGACATCTTCTCTCCAACCTTCGCCCTCTTACTTATCTACCCAGACTCCACTACTACTCACTCTGTCACCATAATTCCTCCTCCTGATCCTCCTTCAATACATCCCGAAACACACACTAAACCACCCGTCACCTTTCTCCTTTCCTCTGAGGCTAGCTAACGTTACTGTACGGTATTGTAGAAAAAGGCATAGTGCTGCTAACGTTCGTCCCTATAGTGAGTCGTATTATGTAGTTCCTTATCATCTGC"
)

In [None]:
# Scratch check: length of a second pasted, hard-coded sequence.
len(
    "CTGAGGAGCCAGATACATAGATTACCACAACTCCGAGCCCTTCCACCAAAAAAAACAGATAGCCGCGCGAACGCGGCTAACTGTTGAAAAAAAACAGATAACAGATACCGAAGTATCTGTTATCTTTCCCAAAAAACCCCTCAAGACCCGTTTAGAGGCCCCAAGGGGTTATTACTGATGGCAATGTGATGTCCTCATCTTACTCCCTCTAGTCTATCATTACCCTCCTCCTGCTCTTAACTACCCTCATTCCGACCCTTACTACTACATCATCGACCTTTCTCCATACCCAACTGTCCTAACAACCAACTACTCCGCCTCTTCATCCTCTTTCAACGTTCTCCCTCTATCAACTCAGCAACCACACTCAACTACCATGACATTACACCTCATTCTCCCGACTTTCCACATACTTCCCAGTTTACTCCCTACACCTCCAAGATTCCATACCCACTCTCTTCGCTCTCTACACCCACCAATAAGTTCCTAACAAATCACATCCCGTATCTGTTATGTAATTGCTAGTTAAACAACCCATCCCACCAGATAAATCATTCCCACTACCCGTCAATCCACCATTCCTCAACGAAACTTCATCACTCTCCTCCGCACCCTAACATACAACTCTCGAATACTCTCCCACCTCAACTGCTTCTTCTCTTACACCCTCTGTCTATCATCTCCAAACCACAGACATCTTCTCTCCAACCTTCGCCCTCTTACTTATCTACCCAGACTCCACTACTACTCACTCTGTCACCATAATTCCTCCTCCTGATCCTCCTTCAATACATCCCGAAACACACACTAAACCACCCGTCACCTTTCTCCTTTCCTCT"
)

In [None]:
import pyfastx

In [None]:
# Open the FASTQ so reads can be looked up by name (`fq[name].seq` below).
# The path is relative to the kernel's working directory.
fq = pyfastx.Fastq("duplex_sup1_subsample.fastq")

In [None]:
# Show the sequence of one read by name. Names containing ";" are the ones
# this notebook files under "duplex" further down.
fq["e8a89209-b020-4f67-ab33-6a97fa35366e;6ce12155-a33f-4657-bb71-eebc07dd1ff7"].seq

In [None]:
%%time
# offsets = []
# ids = set()
for msg in msgs3:
    if 800 <= len(msg.sequence) <= 900 and msg.path.mapping[0].position.name == "BetI":
        if 760 <= msg.path.mapping[0].position.offset <= 780 and ";" in msg.name:
            # ids.add(msg.name)
            print(f">{msg.name}")
            print(fq[msg.name].seq)
            # print(msg.sequence)
        # print(msg);0/0
        # offsets.append(msg.path.mapping[0].position.offset)

In [None]:
# NOTE(review): `ids` is only assigned in a later cell (as an empty set) and
# every `ids.add(...)` is commented out, so this fails on a fresh kernel.
len(ids)

In [None]:
# Histogram of `offsets` restricted to 720-820.
# NOTE(review): `offsets` only appears in commented-out code in the visible
# notebook, so this relies on leftover kernel state.
plt.hist(offsets, range=(720, 820), bins=100);

In [None]:
# Name of whichever alignment `msg` was last bound to by a loop above.
msg.name

In [None]:
# Distribution of the read lengths collected in `lens` (100 bins).
plt.hist(lens, bins=100);

In [None]:
# NOTE(review): exact duplicate of the previous cell.
plt.hist(lens, bins=100);

In [None]:
# Reference sequences keyed by part name. The insertion order (LacI, PhlF,
# BetI) fixes the order of the boolean key tuples built in the
# classification cell below, which aligns each of these against every read.
parts = {
    "LacI": "ATGAAACCAGTAACGTTATACGATGTCGCAGAGTATGCCGGTGTCTCTTATATGACCGTTTCCCGCGTGGTGAACCAGGCCAGCCACGTTTCTGCGAAAACGCGGGAAAAAGTGGAAGCGGCGATGGTGGAGCTGAATTACATTCCCAACCGCGTGGCACAACAACTGGCGGGCAAACAGTCGTTGCTGATTGGCGTTGCCACCTCCAGTCTGGCCCTGCACGCGCCGTCGCAAATTGTCGCGGCGATTAAATCTCGCGCCGATCAACTGGGTGCCAGCGTGGTGGTGTCGATGGTAGAACGAAGCGGCGTCGAAGCCTGTAAAGCGGCGGTGCACAATCTTCTCGCGCAACGCGTCAGTGGGCTGATCATTAACTATCCGCTGGATGACCAGGATGCCATTGCTGTGGAAGCTGCCTGCACTAATGTTCCGGCGTTATTTCTTGATGTCTCTGACCAGACACCCATCAACAGTATTATTTACTCCCATGAGGACGGTACGCGACTGGGCGTGGAGCATCTGGTCGCATTGGGTCACCAGCAAATCGCGCTGTTAGCGGGCCCATTAAGTTCTGTCTCGGCGCGTCTGCGTCTGGCTGGCTGGCATAAATATCTCACTCGCAATCAAATTCAGCCGATAGCGGAACGGGAAGGCGACTGGAGTGCCATGTCCGGTTTTCAACAAACCATGCAAATGCTGAATGAGGGCATCGTTCCCACTGCGATGCTGGTTGCCAACGATCAGATGGCGCTGGGCGCAATGCGCGCCATTACCGAGTCCGGGCTGCGCGTTGGTGCGGATATCTCGGTAGTGGGATACGACGATACCGAAGATAGCTCATGTTATATCCCGCCGTTAACCACCATCAAACAGGATTTTCGCCTGCTGGGGCAAACCAGCGTGGACCGCTTGCTGCAACTCTCTCAGGGCCAGGCGGTGAAGGGCAATCAGCTGTTGCCAGTCTCACTGGTGAAAAGAAAAACCACCCTGGCGCCCAATACGCAAACCGCCTCTCCCCGCGCGTTGGCCGATTCATTAATGCAGCTGGCACGACAGGTTTCCCGACTGGAAAGCGGGCAGT",
    "PhlF": "ATGGCACGTACCCCGAGCCGTAGCAGCATTGGTAGCCTGCGTAGTCCGCATACCCATAAAGCAATTCTGACCAGCACCATTGAAATCCTGAAAGAATGTGGTTATAGCGGTCTGAGCATTGAAAGCGTGGCACGTCGCGCCGGTGCAGGCAAACCGACCATTTATCGTTGGTGGACCAACAAAGCAGCACTGATTGCCGAAGTGTATGAAAATGAAATCGAACAGGTACGTAAATTTCCGGATTTGGGTAGCTTTAAAGCCGATCTGGATTTTCTGCTGCATAATCTGTGGAAAGTTTGGCGTGAAACCATTTGTGGTGAAGCATTTCGTTGTGTTATTGCAGAAGCACAGTTGGACCCTGTAACCCTGACCCAACTGAAAGATCAGTTTATGGAACGTCGTCGTGAGATACCGAAAAAACTGGTTGAAGATGCCATTAGCAATGGTGAACTGCCGAAAGATATCAATCGTGAACTGCTGCTGGATATGATTTTTGGTTTTTGTTGGTATCGCCTGCTGACCGAACAGTTGACCGTTGAACAGGATATTGAAGAATTTACCTTCCTGCTGATTAATGGTGTTTGTCCGGGTACACAGTGTTAA",
    "BetI": "ATGCCGAAACTGGGTATGCAGAGCATTCGTCGTCGTCAGCTGATTGATGCAACCCTGGAAGCAATTAATGAAGTTGGTATGCATGATGCAACCATTGCACAGATTGCACGTCGTGCCGGTGTTAGCACCGGTATTATTAGCCATTATTTCCGCGATAAAAACGGTCTACTGGAAGCAACCATGCGTGATATTACCAGCCAGCTGCGTGATGCAGTTCTGAATCGTCTGCATGCACTGCCGCAGGGTAGCGCAGAACAGCGTCTGCAGGCAATTGTTGGTGGTAATTTTGATGAAACCCAGGTTAGCAGCGCAGCAATGAAAGCATGGCTGGCATTTTGGGCAATCAGCATGCATCAGCCGATGCTGTATCGTCTGCAGCAGGTTAGCAGTCGTCGTCTGCTGAGCAATCTGGTTAGCGAATTTCGTCGTGAACTGCCTCGTGAACAGGCACAAGAGGCAGGTTATGGTCTGGCAGCACTGATTGATGGTCTGTGGCTGCGTGCAGCACTGAGCGGTAAACCGCTGGATAAAACCCGTGCAAATAGCCTGACCCGTCATTTTATCACCCAGCATCTGCCGACCGATTAA",
}

In [None]:
# Part names, in the order used for the classification key tuples.
parts.keys()

In [None]:
# IPython help lookup for edlib.align (scratch).
# NOTE(review): `edlib` is not imported anywhere in the visible notebook.
?edlib.align

In [None]:
%%time
# offsets = []
snp_threshold = 50
ids = set()
misassembled_simplex = defaultdict(list)
misassembled_duplex = defaultdict(list)
for msg in tqdm(msgs3):
    if 800 <= len(msg.sequence) <= 900:
        read = fq[msg.name].seq
        # print(msg.name)
        d1 = tuple(
            edlib.align(part, read, mode="HW")["editDistance"]
            for part in parts.values()
        )
        d2 = tuple(
            edlib.align(reverse_complement(part), read, mode="HW")["editDistance"]
            for part in parts.values()
        )
        if min(d1) <= min(d2):
            d = d1
        else:
            d = d2
        # print(k)
        # print()
        # continue
        key = tuple(dd <= snp_threshold for dd in d)
        if ";" in msg.name:
            misassembled_duplex[key].append(read)
        else:
            misassembled_simplex[key].append(read)
        # ids.add(msg.name)
        # print(f">{msg.name}")
        # print(fq[msg.name].seq)
        # print(msg.sequence)
        # print(msg);0/0
        # offsets.append(msg.path.mapping[0].position.offset)

In [None]:
# Number of duplex reads per part-presence key.
{k: len(v) for k, v in misassembled_duplex.items()}

In [None]:
# Number of simplex reads per part-presence key.
{k: len(v) for k, v in misassembled_simplex.items()}

In [None]:
# Pick one duplex read in which none of the three parts was found, and align
# a hard-coded sequence `end` (and its reverse complement) inside it to see
# whether, and on which strand, it occurs.
idx = 9
read = misassembled_duplex[(False, False, False)][idx]
end = "ATCACATTGCCATCAGTAATAACCCCTTGGGGCCTCTAAACGGGTCTTGAGGGGTTTTTTGGGAAAGATAACAGATACTTCGGTATCTGTTATCTGTTTTTTTTCAACAGATAGCCGCGTTCGCGCGGCTATCTGTTTTTTTTGGTGGAAGGGCTCGGAGTTGTGGTAATCTATGTATCCTGG"
print(edlib.align(end, read, mode="HW"))
print(edlib.align(reverse_complement(end), read, mode="HW"))

In [None]:
# Marker sequence searched for by trim_barcode.
BARCODE_END = "ATCACATTGCCATCAGTAATAACCCCTTGGGGCCTCTAAACGGGTCTTGAGGGGTTTTTTG"


def trim_barcode(read, end=BARCODE_END, threshold=10):
    """Split ``read`` at the position where the marker ``end`` occurs.

    The marker is searched inside the read with edlib's "HW" mode, first as
    given and then reverse-complemented. A hit counts when its edit distance
    is at most ``threshold``.

    Args:
        read: read sequence.
        end: marker sequence to locate.
        threshold: maximum edit distance for a hit.

    Returns:
        A 2-tuple of strings ``(from_marker, before_marker)``:

        * forward hit: the read from the estimated start of the marker
          onwards, and the part of the read before it;
        * reverse hit: the same two pieces, reverse-complemented so that the
          first one again begins at the marker;
        * no hit: ``(read, "")`` — nothing is split off. This keeps the
          return type uniform, so ``trim_barcode(read)[0]`` is always a
          sequence rather than the first character of an unmatched read.

    NOTE(review): the forward cut is estimated as hit end minus ``len(end)``,
    which is only exact when the alignment has no net insertions/deletions.
    """
    forward = edlib.align(end, read, mode="HW")
    if forward["editDistance"] <= threshold:
        # Clamp at 0: a hit ending near the start of the read would otherwise
        # give a negative index and slice from the wrong end of the read.
        cut = max(0, forward["locations"][0][1] - len(end) + 1)
        return read[cut:], read[:cut]
    reverse = edlib.align(reverse_complement(end), read, mode="HW")
    if reverse["editDistance"] <= threshold:
        cut = reverse["locations"][0][1] + 1
        return reverse_complement(read[:cut]), reverse_complement(read[cut:])
    return read, ""

In [None]:
# Show the read currently bound to `read`.
read

In [None]:
# Split that read at the barcode end marker.
trim_barcode(read)

In [None]:
# Same split on the reverse-complemented read, to exercise the reverse branch.
trim_barcode(reverse_complement(read))

In [None]:
# TTT (presumably the key (True, True, True), i.e. all three parts found): full circuit, no barcode (?)

In [None]:
# Number of duplex reads per part-presence key (repeat of an earlier cell).
{k: len(v) for k, v in misassembled_duplex.items()}

In [None]:
# Print, FASTA-style under the header ">FFF", the first piece returned by
# trim_barcode for every duplex read in which none of the parts was found.
for read in misassembled_duplex[(False, False, False)]:
    print(">FFF")
    print(trim_barcode(read)[0])

In [None]:
# Number of simplex reads per part-presence key (repeat of an earlier cell).
{k: len(v) for k, v in misassembled_simplex.items()}

In [None]:
# Tally the "RBS2" segment over the alignments in `msgs` whose name contains ";".
# NOTE(review): `extract_read_segments` is not defined in the visible notebook.
Counter([extract_read_segments(msg)["RBS2"] for msg in msgs if ";" in msg.name])

In [None]:
# Same tally over every alignment in `msgs`.
# NOTE(review): `extract_read_segments` is not defined in the visible notebook.
Counter([extract_read_segments(msg)["RBS2"] for msg in msgs])