In [None]:
import itertools as it
import re

import gfapy
import holoviews as hv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
from pyarrow import csv
from tqdm.auto import tqdm

In [None]:
hv.extension("bokeh")

In [None]:
!micromamba list|grep protobuf

In [None]:
!micromamba install -y protobuf=4.21.7 async_generator pyarrow

In [None]:
!pip install --no-deps pystream-protobuf

# GAF

In [None]:
import re
from collections import Counter

In [None]:
def segment_frequences(table, segment_names):
    """Count how often each graph segment occurs in every alignment path.

    Parameters
    ----------
    table
        Object exposing ``column("path")`` (iterable of path values whose
        ``str()`` is a run of segment names separated by ``>`` / ``<``) and
        ``column("name")`` (the row labels).
    segment_names
        Segment names to tally; they become the columns of the result.

    Returns
    -------
    pandas.DataFrame
        One row per path, one column per entry of ``segment_names``, indexed
        by ``table.column("name")``. Cells hold occurrence counts.
    """
    per_read_counts = []
    for path_value in table.column("path"):
        # Splitting on either orientation marker leaves just the segment
        # names (plus an empty leading token, which is never looked up).
        tally = Counter(re.split(r">|<", str(path_value)))
        per_read_counts.append([tally[name] for name in segment_names])
    read_names = table.column("name")
    return pd.DataFrame(per_read_counts, columns=segment_names, index=read_names)

In [None]:
gfa = gfapy.Gfa.from_file("nao745bc.gfa")

In [None]:
!du -hs *.gaf

In [None]:
# gaf_filename = "duplex_hac1_subsample_dbg.gaf"
# gaf_filename = "duplex_hac1_subsample_dbg2.gaf"
# gaf_filename = "duplex_hac1_subsample_vg2.gaf"
gaf_filename = "duplex_sup1_vg2.gaf"

In [None]:
%%time
# Stream the GAF file in record batches and accumulate per-read segment
# counts, keeping only reads that pass the filter defined inside the loop.
#
# SEE: http://samtools.github.io/hts-specs/SAMv1.pdf
# and https://samtools.github.io/hts-specs/SAMtags.pdf
# Maps a SAM optional-field type letter to an Arrow type.
SAM_TAG_TYPES = {
    "A": pa.dictionary(pa.int32(), pa.string()),
    "f": pa.float32(),
    "i": pa.int32(),
    "Z": pa.string(),
}
# Peek at the first record to discover which optional tag columns follow the
# 12 fixed columns named below.
with open(gaf_filename, "r") as f:
    first = f.readline().split("\t")
tags = first[12:]
# Keyed by tag name (text before the first ":"). The lookup raises KeyError for
# a type letter missing from SAM_TAG_TYPES.
# NOTE(review): only the keys are used below -- every tag column is read as a
# plain string, so the Arrow types computed here are currently unused.
tag_column_types = {(t := tag.split(":"))[0]: SAM_TAG_TYPES[t[1]] for tag in tags}
column_types = {
    "name": pa.string(),
    "query_length": pa.uint64(),
    "query_start": pa.uint64(),
    "query_end": pa.uint64(),
    "strand": pa.dictionary(pa.int32(), pa.string()),
    "path": pa.string(),
    "path_length": pa.uint64(),
    "path_start": pa.uint64(),
    "path_end": pa.uint64(),
    "residue_matches": pa.uint64(),
    "block_length": pa.uint64(),
    "mapping_quality": pa.uint8(),
    **{tag: pa.string() for tag in tag_column_types.keys()},
}
# The file has no header row, so column names are supplied explicitly.
read_options = csv.ReadOptions(column_names=column_types.keys())
parse_options = csv.ParseOptions(delimiter="\t")
convert_options = csv.ConvertOptions(column_types=column_types)
with csv.open_csv(
    gaf_filename,
    read_options=read_options,
    parse_options=parse_options,
    convert_options=convert_options,
) as f:
    # tt = f.read_next_batch()
    # segment_counts: keyed by the full tuple of per-segment counts of a read.
    # barcode_counts: keyed by the tuple of the BIT* columns only.
    segment_counts = Counter()
    barcode_counts = Counter()
    total = 0  # reads seen
    duplex = 0  # reads whose name contains ";"
    singleton = 0  # NOTE(review): never updated in this cell
    filtered = 0  # reads passing the filter below
    # while True:
    # for _ in tqdm(it.islice(it.count(), 10)):
    # Unbounded counter drives the progress bar; the loop ends when the reader
    # signals exhaustion via StopIteration.
    for _ in tqdm(it.count()):
        try:
            table = f.read_next_batch()
        except StopIteration:
            break
        # break
        freqs = segment_frequences(table, gfa.segment_names)
        # presumably duplex reads are named "<id>;<id>" -- TODO confirm
        duplex += freqs.index.str.contains(";").sum()
        # filtered_freqs = freqs[(freqs.max(axis=1) == 1) & ((freqs["BIT0OFF"] == 1) | (freqs["BIT0ON"] == 1)) & (freqs["pPhlF"] == 1)]
        # Keep reads that: have ";" in the name, visit no segment more than
        # once, contain a BIT0 segment (ON or OFF), and contain RBS1.
        filtered_freqs = freqs[
            freqs.index.str.contains(";")
            & (freqs.max(axis=1) == 1)
            & ((freqs["BIT0OFF"] == 1) | (freqs["BIT0ON"] == 1))
            & (freqs["RBS1"] == 1)
        ]
        # filtered_freqs = freqs
        filtered += len(filtered_freqs)
        total += len(freqs)
        segment_counts.update(list(filtered_freqs.itertuples(index=False)))
        barcode_counts.update(
            list(
                filtered_freqs.loc[
                    :, filtered_freqs.columns.str.startswith("BIT")
                ].itertuples(index=False)
            )
        )

In [None]:
(filtered, total, filtered / total, duplex, duplex / total)

In [None]:
def bit_sums(freqs):
    """Add the ON and OFF counts of each of the 30 barcode bits.

    ``freqs`` must provide columns ``BIT{i}ON`` and ``BIT{i}OFF`` for
    ``i`` in ``0..29``. The result has one column ``BIT{i}`` per bit holding
    their element-wise sum.
    """
    totals = {}
    for bit in range(30):
        on_counts = freqs[f"BIT{bit}ON"]
        off_counts = freqs[f"BIT{bit}OFF"]
        totals[f"BIT{bit}"] = on_counts + off_counts
    return pd.DataFrame(totals)

In [None]:
filtered_freqs.mean(axis=0)

In [None]:
filtered_freqs.loc[:, filtered_freqs.columns.str.endswith("ON")].mean(axis=0).plot.bar()

In [None]:
filtered_freqs.mean(axis=0).plot.bar()

In [None]:
segment_counts.most_common(3)

In [None]:
barcode_counts.most_common(3)

In [None]:
plt.hist(barcode_counts.values(), bins=100, log=True);

In [None]:
n, bins, patches = plt.hist(
    barcode_counts.values(),
    100,
    histtype="step",
    density=False,
    cumulative=-1,
    log=True,
)

In [None]:
n, bins, patches = plt.hist(
    barcode_counts.values(),
    100,
    range=(0, 10),
    histtype="step",
    density=False,
    cumulative=-1,
    log=True,
)

In [None]:
len(barcode_counts)

In [None]:
sum(1 for v in barcode_counts.values() if v == 1)

In [None]:
sum(1 for v in barcode_counts.values() if v == 2)

In [None]:
sum(1 for v in barcode_counts.values() if v == 3)

In [None]:
sum(1 for v in barcode_counts.values() if v == 4)

In [None]:
sum(1 for v in barcode_counts.values() if v >= 5)

In [None]:
sum(1 for v in barcode_counts.values() if 5 <= v < 20)

In [None]:
sum(1 for v in barcode_counts.values() if 20 <= v < 100)

In [None]:
sum(1 for v in barcode_counts.values() if v >= 100)

In [None]:
sum(1 for v in barcode_counts.values() if v >= 100)

In [None]:
sum(1 for v in barcode_counts.values() if v >= 2)

In [None]:
max(barcode_counts.values())

In [None]:
sum(v for v in barcode_counts.values() if v == 1) / total

In [None]:
sum(1 for v in barcode_counts.values() if v >= 2)

In [None]:
sum(v for v in barcode_counts.values() if v >= 10)

In [None]:
sum(v for v in barcode_counts.values() if v >= 10) / filtered

In [None]:
filtered

In [None]:
sum(v for v in barcode_counts.values())

In [None]:
n, bins, patches = plt.hist(
    barcode_counts.values(),
    100,
    range=(0, 20),
    histtype="step",
    density=True,
    cumulative=-1,
    log=True,
)

In [None]:
gfa.try_get_segment("BIT1OFF")

In [None]:
import uuid

In [None]:
u = uuid.UUID("6e507a8a-c271-4561-8768-0f9bf9d4c301")

In [None]:
import sys

In [None]:
sys.getsizeof(u.int)

In [None]:
sys.getsizeof("6e507a8a-c271-4561-8768-0f9bf9d4c301")

In [None]:
u.int

In [None]:
t["cg"][0]

In [None]:
t["name"].str.split(";")

# GAM

In [None]:
# 1) segment_cigars df (segment coördinates, normalize orientation) [numba]
# 2) segment_mismatches df (cellwise apply, get insertions/deletions/mismatches/equal)
# 3) filter on barcode mismatches (?)
# 4) group segment_cigars by barcode, run cigar_aggregation on non-barcode (or all!) segments
# 5)

In [None]:
import google.protobuf

In [None]:
google.protobuf.__version__

In [None]:
import stream
import vg_pb2

In [None]:
gfa_filename = "nao745bc.gfa"
gfa = gfapy.Gfa.from_file(gfa_filename)

In [None]:
# gam_filename = "duplex_sup1_subsample_vg2.gam"
# gam_filename = "reverse_test_duplex.gam"
gam_filename = "reverse_test_simplex.gam"
# gam_filename = "reverse_test_duplex_t1.gam"

In [None]:
msgs = list(stream.parse(gam_filename, vg_pb2.Alignment))

In [None]:
!head -n 4 duplex_sup1_subsample.fastq

In [None]:
for msg in stream.parse(gam_filename, vg_pb2.Alignment):
    if ";" in msg.name:
        continue
    # print(msg)
    # print()
    # print("************")
    # print()
    break

In [None]:
for msg in it.islice(stream.parse(gam_filename, vg_pb2.Alignment), 1):
    pass
    # print(msg)
    # print()
    # print("************")
    # print()

In [None]:
msgs[0].name

In [None]:
# TODO: use paulssonlab.cloning.workflow.DEGENERATE_BASES_REGEX
DEGENERATE_BASES = "RYMKSWHBVDN".lower()
DEGENERATE_BASES_REGEX = re.compile(f"[{DEGENERATE_BASES}]", re.IGNORECASE)

In [None]:
# Segment name -> segment sequence, taken from the loaded GFA graph.
segments = {s.name: s.sequence for s in gfa.segments}
# Segment name -> 0-based positions of degenerate (IUPAC ambiguity) bases in
# that segment's OWN sequence. Previously every segment was scanned against
# segments["RBS1"], so all entries were identical and a graph without an RBS1
# segment raised KeyError.
# str() guards against a non-string sequence value (e.g. a placeholder).
segment_degenerate_bases = {
    name: [m.start(0) for m in DEGENERATE_BASES_REGEX.finditer(str(seq))]
    for name, seq in segments.items()
}

In [None]:
from enum import Enum
from typing import NamedTuple


class Op(Enum):
    """Kind of difference between a read and the segment it aligns to."""

    INSERTION = 1
    DELETION = 2
    SUBSTITUTION = 3


class Edit(NamedTuple):
    """A single edit: operation, position, and either a sequence or a length."""

    op: Op
    index: int
    seq: str
    length: int

    def __repr__(self):
        """Render compactly as ``I:<index>:<seq>``, ``D:<index>:<length>`` or ``S:<index>:<seq>``."""
        if self.op == Op.DELETION:
            return f"D:{self.index}:{self.length}"
        if self.op == Op.INSERTION:
            prefix = "I"
        elif self.op == Op.SUBSTITUTION:
            prefix = "S"
        else:
            # Unknown operation: show whatever payload is available.
            return f"{self.op}:{self.index}:{self.seq or self.length or ''}"
        return f"{prefix}:{self.index}:{self.seq or ''}"

    __str__ = __repr__

In [None]:
segment_degenerate_bases["RBS1"]

In [None]:
from Bio.Seq import Seq


def local_index(index, length, is_reverse):
    """Return ``index`` unchanged, or mirrored as ``length - index`` when ``is_reverse`` is truthy."""
    return length - index if is_reverse else index


def reverse_complement(seq):
    """Return the reverse complement of ``seq`` as a plain ``str`` (via ``Seq``)."""
    complemented = Seq(seq).reverse_complement()
    return str(complemented)


def reversed_seq(seq, is_reverse):
    """Return ``seq`` as-is, or its reverse complement when ``is_reverse`` is truthy."""
    if not is_reverse:
        return seq
    return reverse_complement(seq)


def normalize_alignment(msg):
    """Collect per-segment edits and read spans from one alignment message.

    Walks ``msg.path.mapping`` in order. For every mapping it reads the target
    segment name, starting offset and orientation from ``mapping.position``,
    then classifies each entry of ``mapping.edit`` by comparing
    ``from_length`` (bases consumed on the segment) and ``to_length`` (bases
    consumed on the read):

    * equal lengths with a ``sequence``  -> ``Op.SUBSTITUTION``
    * equal lengths without a sequence   -> match (no ``Edit`` recorded)
    * ``to_length`` of zero              -> ``Op.DELETION``
    * ``from_length < to_length``        -> ``Op.INSERTION``

    Relies on the module-level ``segments`` and ``segment_degenerate_bases``
    dictionaries.

    Returns
    -------
    tuple
        ``(segment_edits, segment_read_indices)``: a dict mapping every
        segment name in ``segments`` to its list of ``Edit`` objects, and a
        dict mapping each visited segment name to ``(read_start, read_end)``.

    Raises
    ------
    AssertionError
        If an insertion edit has a non-zero ``from_length``.
    KeyError
        If a mapping names a segment that is not in ``segments``.

    Notes
    -----
    NOTE(review): an edit with ``0 < to_length < from_length`` matches none of
    the branches and is skipped without advancing either index.
    NOTE(review): ``segment_read_indices`` is keyed by segment name, so a
    segment visited by more than one mapping keeps only the last span, and its
    shared edit list is reversed once per reverse-strand visit.
    """
    # Every known segment gets an entry, even if the read never touches it.
    segment_edits = {segment_name: [] for segment_name in segments.keys()}
    segment_read_indices = {}
    # for mapping in it.islice(msg.path.mapping, 7):
    read_index = 0  # running position along the read, across all mappings
    for mapping in msg.path.mapping:
        segment_index = 0  # running position along the current segment
        segment_name = mapping.position.name
        offset = mapping.position.offset
        if offset:
            segment_index += offset
        segment_read_start = read_index
        is_reverse = mapping.position.is_reverse
        edits = segment_edits[segment_name]
        segment_length = len(segments[segment_name])
        # TODO: separate edit handling for
        for edit in mapping.edit:
            if edit.from_length == edit.to_length:
                if edit.sequence:
                    # snp
                    # TODO: eat matching bases
                    # For reverse mappings the index is mirrored using
                    # (segment_length - len(edit.sequence)) as the length.
                    edits.append(
                        Edit(
                            Op.SUBSTITUTION,
                            local_index(
                                segment_index,
                                segment_length - len(edit.sequence),
                                is_reverse,
                            ),
                            reversed_seq(edit.sequence, is_reverse),
                            None,
                        )
                    )
                    segment_index += edit.from_length
                    read_index += edit.to_length
                else:
                    # match
                    # NOTE(review): looked up but not used yet; the handling
                    # of degenerate bases below is still commented out.
                    degenerate_base_indices = segment_degenerate_bases[segment_name]
                    segment_index += edit.from_length
                    read_index += edit.to_length
                    # if degenerate_base_indices:
                    #     for base_index in degenerate_base_indices:
                    #         # TODO: handle is_reverse
                    #         local_index(
                    #             segment_index, segment_length, is_reverse, 1
                    #         )
                    #         if segment_index <= base_index < segment_index + edit.from_length:
                    #             edits.append(
                    #                 Edit(
                    #                     Op.SUBSTITUTION,
                    #                     local_index(
                    #                         segment_index, segment_length - len(edit.sequence), is_reverse,
                    #                     ),
                    #                     reversed_seq(edit.sequence, is_reverse),
                    #                     None,
                    #                 )
                    #             )
                    # else:
                    #     segment_index += edit.from_length
                    #     read_index += edit.to_length
            elif not edit.to_length:
                # deletion
                # Consumes segment bases only; the read position is unchanged.
                edits.append(
                    Edit(
                        Op.DELETION,
                        local_index(
                            segment_index, segment_length - edit.from_length, is_reverse
                        ),
                        None,
                        edit.from_length,
                    )
                )
                segment_index += edit.from_length
            elif edit.from_length < edit.to_length:
                # insertion
                # if from_length > 0: need to remove matching bases from both sides
                assert edit.from_length == 0
                edits.append(
                    Edit(
                        Op.INSERTION,
                        local_index(
                            segment_index, segment_length - edit.to_length, is_reverse
                        ),
                        reversed_seq(edit.sequence, is_reverse),
                        None,
                    )
                )
                # TODO: shouldn't increment segment_index, right?
                # segment_index += edit.to_length
                # Consumes read bases only; the segment position is unchanged.
                read_index += edit.to_length
        segment_read_end = read_index
        segment_read_indices[segment_name] = (segment_read_start, segment_read_end)
        if is_reverse:
            # Flip the edit order for reverse-strand mappings (in place, on
            # the list stored in segment_edits).
            edits.reverse()
        # TODO: merge like edits (i.e., degen base insertions)
        # TODO: normalize consecutive unlike edits:
        # insertion/sub with like bases
        # CNNNNANNNNCC
        # caaaataaaatt
    return segment_edits, segment_read_indices

In [None]:
normalize_alignment(msgs[0])[0]  # ["pPhlF"]

In [None]:
normalize_alignment(msgs[1])[0]  # ["pPhlF"]

In [None]:
reverse_complement(msgs[0].sequence[3951:4136])

In [None]:
segments["RBS1"]

In [None]:
msgs[1].sequence[185:206]

In [None]:
msgs[1].sequence[slice(*normalize_alignment(msgs[1])[1]["RBS1"])]

In [None]:
msgs[1].sequence[slice(*normalize_alignment(msgs[1])[1]["RBS2"])]

In [None]:
msgs[1].sequence[slice(*normalize_alignment(msgs[1])[1]["RBS3"])]

In [None]:
(2506, 2527)

In [None]:
normalize_alignment(msgs[1])[0]["RBS1"]

In [None]:
msgs[1].path.mapping[1]

In [None]:
msgs[0].path.mapping[-2]

In [None]:
normalize_alignment(msgs[1])[0]

In [None]:
normalize_alignment(msgs[0])[1]

In [None]:
normalize_alignment(msgs[0])

In [None]:
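# Pasted protobuf text of a mapping position, kept for reference
# (not executable Python):
# position {
#     node_id: 66
#     is_reverse: true
#     name: "BIT29OFF"
#   }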

# Group by barcode

In [None]:
# gam_filename = "duplex_sup1_subsample_vg2.gam"
gam_filename = "duplex_sup1_vg2.gam"

In [None]:
from collections import Counter, defaultdict

In [None]:
%%time
# Group alignments by their 30-bit barcode: bit i is 1 when the read's path
# visits segment BIT{i}ON, otherwise 0.
barcode_msgs = defaultdict(list)
# To try this on a subset, wrap the parser in it.islice(..., 100_000).
for msg in tqdm(stream.parse(gam_filename, vg_pb2.Alignment)):
    path = {m.position.name for m in msg.path.mapping}
    # Skip reads that lack a BIT0 segment (either state) or the pBetI segment.
    if ("BIT0ON" not in path and "BIT0OFF" not in path) or "pBetI" not in path:
        continue
    barcode = tuple(int(f"BIT{i}ON" in path) for i in range(30))
    barcode_msgs[barcode].append(msg)

In [None]:
len(barcode_msgs)

In [None]:
!du -hs *.fastq

In [None]:
list(sorted(Counter(len(v) for k, v in barcode_msgs.items()).items()))

In [None]:
# Re-index the barcode groups by size: maps a group size to the list of
# message lists (one per barcode) that have exactly that many messages.
barcode_msgs_cluster = defaultdict(list)
for k, v in barcode_msgs.items():
    barcode_msgs_cluster[len(v)].append(v)

In [None]:
msgs = barcode_msgs_cluster[30]

In [None]:
del barcode_msgs_cluster, barcode_msgs

In [None]:
msgs[29]