In [None]:
import itertools as it
from collections import Counter, defaultdict
from functools import partial
from glob import glob
from pathlib import Path

import gfapy
import holoviews as hv
import hvplot.pandas
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
from tqdm.auto import tqdm, trange

In [None]:
# Load IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import paulssonlab.sequencing.align as align
import paulssonlab.sequencing.cigar as scigar
import paulssonlab.sequencing.consensus as con
import paulssonlab.sequencing.gfa as sgfa
import paulssonlab.sequencing.io as sio
import paulssonlab.sequencing.processing as processing
from paulssonlab.util.sequence import reverse_complement

In [None]:
# Select bokeh as the HoloViews plotting backend.
hv.extension("bokeh")

In [None]:
# Turn on polars' global string cache.
# NOTE(review): presumably required so categorical columns coming from separately
# scanned files can be concatenated — confirm.
pl.enable_string_cache()

# Functions

In [None]:
def concat_glob(filename):
    """Lazily scan every Arrow IPC file matching a glob pattern and concatenate them.

    Parameters
    ----------
    filename : str
        Glob pattern (passed to ``glob.glob``) selecting the IPC files.

    Returns
    -------
    polars.LazyFrame
        Diagonal concatenation (``how="diagonal"``) of the scanned files.

    Raises
    ------
    ValueError
        If the pattern matches no files.
    """
    # glob() yields files in arbitrary filesystem order. Sort so that row order,
    # which order-dependent steps such as ``is_first_distinct`` rely on, is
    # reproducible between runs and machines.
    paths = sorted(glob(filename))
    if not paths:
        # Fail with the offending pattern instead of an opaque concat error.
        raise ValueError(f"no files match pattern: {filename!r}")
    return pl.concat([pl.scan_ipc(path) for path in paths], how="diagonal")

In [None]:
def label_columns(cols, func=None):
    """Build a polars when/then chain that labels rows by their non-null column.

    For each name in ``cols`` (in order) a branch "when that column is not
    null, then <label>" is appended to the chain. The label is the column name
    itself, or ``func(name)`` when ``func`` is given.

    Parameters
    ----------
    cols : iterable of str
        Column names to test, in priority order.
    func : callable, optional
        Maps a column name to the literal used as its label.

    Returns
    -------
    The chained when/then expression (no ``otherwise`` branch is added), or
    ``None`` when ``cols`` is empty.
    """
    chain = None
    for name in cols:
        present = pl.col(name).is_not_null()
        # The first column opens the chain with pl.when; later ones extend it.
        branch = pl.when(present) if chain is None else chain.when(present)
        label = func(name) if func is not None else name
        chain = branch.then(pl.lit(label))
    return chain

# 240612_pLIB476_isolates

In [None]:
%%time
# Read every prepare_reads Arrow file of the 240612 isolates run into one eager DataFrame.
arrow_filename = "/home/jqs1/scratch/sequencing/240612_pLIB476_isolates/output/vg/prepare_reads/*.arrow"
df = concat_glob(arrow_filename).collect()

In [None]:
%%time
# Add boolean flags to every row:
#   dup               - the read name occurs on more than one row
#   primary_alignment - first row encountered for each read name
#   e2e               - full_path contains exactly two distinct entries among the
#                       oriented UNS9 / UNS3 segment names
#   bc_e2e            - same test for the BC:T7_prom / BC:spacer2 segment names
# NOTE(review): treating the first row per name as the primary alignment assumes the
# input row order puts the primary alignment first — confirm.
df2 = df.with_columns(
    dup=pl.col("name").is_duplicated(),
    primary_alignment=pl.col("name").is_first_distinct(),
    e2e=pl.col("full_path")
    .list.set_intersection(["<UNS9", ">UNS9", "<UNS3", ">UNS3"])
    .list.len()
    == 2,
    bc_e2e=pl.col("full_path")
    .list.set_intersection(["<BC:T7_prom", ">BC:T7_prom", "<BC:spacer2", ">BC:spacer2"])
    .list.len()
    == 2,
)

In [None]:
# Keep first-per-name rows flagged bc_e2e, flatten the extract_segments struct into
# top-level columns, and parse the sample number from the third "_"-separated field
# of the read name.
df3 = (
    df2.filter(pl.col("bc_e2e"), pl.col("primary_alignment"))
    .unnest("extract_segments")
    .with_columns(
        sample=pl.col("name").str.split_exact("_", 2).struct[2].cast(pl.Int32)
    )
)

In [None]:
# Per-sample mean of each per-bit error count (30 barcode bits x 3 error types),
# ordered by sample number.
error_types = ("mismatches", "insertions", "deletions")
bit_error_means = [
    pl.col(f"BC:bit{bit_idx}|{error_type}").mean()
    for bit_idx in range(30)
    for error_type in error_types
]
df4 = df3.group_by("sample").agg(*bit_error_means).sort("sample")

In [None]:
# Display the per-sample mean error table.
df4

In [None]:
# Mean per-bit error counts for sample 1 only.
df4.filter(pl.col("sample") == 1)

In [None]:
# Number of barcode bits summarized in df4 (one "|mismatches" column per bit).
# Derived from df4 directly: the previous `len(mismatches)` read a variable that is
# only created by the plotting loop in a later cell, so it raised NameError on a
# fresh kernel.
len([col for col in df4.columns if col.endswith("|mismatches")])

In [None]:
# One stacked step plot per sample: mean mismatches / insertions / deletions per barcode bit.
bits = np.arange(30)
# Step edges: a left edge for every bit plus one final right edge; the matching
# trailing 0 in each series gives the last bit a full-width step under step="post".
bit_edges = [-0.5, *(bits + 0.5)]
for row in df4.to_dicts():
    mismatches, insertions, deletions = (
        [row[f"BC:bit{bit}|{kind}"] for bit in bits]
        for kind in ("mismatches", "insertions", "deletions")
    )
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.stackplot(
        bit_edges,
        [*mismatches, 0],
        [*insertions, 0],
        [*deletions, 0],
        labels=["mismatches", "insertions", "deletions"],
        step="post",
    )
    ax.set_ylim([0, 6])
    ax.set_title(f"sample {row['sample']}")
    ax.set_xticks(bits)
    ax.legend()

In [None]:
# Histogram of per-read error counts for a single sample and barcode bit.
sample = 4
bit = 17
bins = np.arange(10)
bin_centers = (bins[:-1] + bins[1:]) / 2
hists = {}
for type_ in ("mismatches", "insertions", "deletions"):
    values = df3.filter(pl.col("sample") == sample)[f"BC:bit{bit}|{type_}"]
    counts, _edges = np.histogram(values, bins=bins)
    hists[type_] = counts
fig, ax = plt.subplots(figsize=(6, 3))
ax.stackplot(
    # Centers shifted by 1 put the step for integer count k at x = k - 0.5.
    bin_centers - 1,
    hists["mismatches"],
    hists["insertions"],
    hists["deletions"],
    labels=["mismatches", "insertions", "deletions"],
    step="post",
)
ax.legend();

In [None]:
%%time
# Rebind df to the extract_segments output of the primary_max_divergence=1 run.
arrow_filename = "/home/jqs1/scratch/sequencing/240612_pLIB476_isolates/output/primary_max_divergence=1/extract_segments/*.arrow"
df = concat_glob(arrow_filename).collect()

In [None]:
# One "1"/"0" flag per barcode bit: "1" when variants_path contains that bit's
# "=1" segment in either orientation (">" or "<").
bit_flags = [
    pl.when(
        pl.col("variants_path").list.contains(f">BC:bit{bit_idx}=1")
        | pl.col("variants_path").list.contains(f"<BC:bit{bit_idx}=1")
    )
    .then(pl.lit("1"))
    .otherwise(pl.lit("0"))
    for bit_idx in range(30)
]
# First row per read name, with the 30 flags joined into one string; sorted by
# that string and converted to pandas.
df2 = (
    df.filter(pl.col("name").is_first_distinct())
    .with_columns(barcode_str=pl.concat_str(bit_flags))
    .sort("barcode_str")
    .to_pandas()
)

In [None]:
# Barcode string, read name and grouping depth, in barcode order.
df2[["barcode_str", "name", "grouping_depth"]]

In [None]:
# NOTE(review): identical to the preceding cell — candidate for removal.
df2[["barcode_str", "name", "grouping_depth"]]

In [None]:
# variants_path of the last row (largest barcode string after the sort).
df2.iloc[-1].loc["variants_path"]

In [None]:
# variants_path of the second-to-last row.
df2.iloc[-2].loc["variants_path"]

In [None]:
%%time
# Rebind df to the prepare_consensus output of the primary_max_divergence=1 run.
arrow_filename = "/home/jqs1/scratch/sequencing/240612_pLIB476_isolates/output/primary_max_divergence=1/prepare_consensus/*.arrow"
df = concat_glob(arrow_filename).collect()

In [None]:
# Compute divergences with the project helpers, then parse the sample number from
# the third "_"-separated field of the read name.
# NOTE(review): compute_divergences / unique_segments are defined in
# paulssonlab.sequencing.processing and not visible here; the cells below rely on a
# "max_divergence" column being present in the result — confirm.
df2 = processing.compute_divergences(
    df, processing.unique_segments(df, "path"), struct_name="extract_segments"
)
df2 = df2.with_columns(
    sample=pl.col("name").str.split_exact("_", 2).struct[2].cast(pl.Int32)
)

In [None]:
# Rows of sample 11, ordered by increasing max_divergence.
df2.filter(pl.col("sample") == 11).sort("max_divergence")

In [None]:
# Distribution of max_divergence within sample 11 (100 bins).
plt.hist(df2.filter(pl.col("sample") == 11)["max_divergence"], bins=100);

In [None]:
# Peek at the first 100 rows of the loaded table.
df.head(100)

# 240610_pLIB476

In [None]:
%%time
# Rebind df to the extract_segments output of the 240610 pLIB476 bottleneck run.
arrow_filename = "/home/jqs1/scratch/sequencing/240610_pLIB476_bottleneck/pLIB476_bottleneck/pLIB476/20240607_1433_MN35044_FAX60316_7d690112/output/default/extract_segments/*.arrow"
df = concat_glob(arrow_filename).collect()

In [None]:
# Number of rows whose grouping depth is at least 10.
df.filter(pl.col("grouping_depth") >= 10).height

In [None]:
# variants_path of the first row, as a plain Python list.
df[0, "variants_path"].to_list()

In [None]:
%%time
# Add boolean flags to every row (computed on variants_path for this dataset):
#   dup               - the read name occurs on more than one row
#   primary_alignment - first row encountered for each read name
#   e2e               - variants_path contains exactly two distinct entries among the
#                       oriented UNS9 / UNS3 segment names
#   bc_e2e            - same test for the BC:T7_prom / BC:spacer2 segment names
df2 = df.with_columns(
    dup=pl.col("name").is_duplicated(),
    primary_alignment=pl.col("name").is_first_distinct(),
    e2e=pl.col("variants_path")
    .list.set_intersection(["<UNS9", ">UNS9", "<UNS3", ">UNS3"])
    .list.len()
    == 2,
    bc_e2e=pl.col("variants_path")
    .list.set_intersection(["<BC:T7_prom", ">BC:T7_prom", "<BC:spacer2", ">BC:spacer2"])
    .list.len()
    == 2,
)

In [None]:
# Occurrence counts of each SD2_variant sequence among primary rows, sorted by
# count, drawn as a step plot with a log-scaled y axis.
df2.filter(pl.col("primary_alignment"))["SD2_variant|seq"].value_counts(
    sort=True
).to_pandas().hvplot.step(
    logy=True,
)

In [None]:
# Number of distinct SD2_variant sequences seen at least 3 times among primary,
# end-to-end rows.
df2.filter(pl.col("primary_alignment"), pl.col("e2e"))["SD2_variant|seq"].value_counts(
    sort=True
).filter(pl.col("count") >= 3)["count"].len()

In [None]:
# Same count curve as above restricted to sequences seen more than once,
# plotted through pandas/matplotlib with a log-scaled y axis.
df2.filter(pl.col("primary_alignment"))["SD2_variant|seq"].value_counts(
    sort=True
).filter(pl.col("count") > 1).to_pandas().plot(drawstyle="steps", logy=True)

In [None]:
# Display the promoter variant column.
df2["promoter|variant"]

In [None]:
# Count (promoter variant, SD2 sequence) pairs among primary end-to-end rows and
# keep the pairs seen more than once.
# NOTE(review): the ["promoter|variant"] lookup relies on pl.struct naming the
# struct column after its first field — confirm for the installed polars version.
df2.filter(pl.col("e2e"), pl.col("primary_alignment")).select(
    pl.struct(["promoter|variant", "SD2_variant|seq"])
)["promoter|variant"].value_counts(sort=True).filter(pl.col("count") > 1)

In [None]:
# Number of primary end-to-end rows per promoter variant.
df2.filter(pl.col("e2e"), pl.col("primary_alignment")).group_by("promoter|variant").agg(
    pl.len()
)

In [None]:
# SD2_variant sequence counts (only those > 1) among primary rows whose promoter
# variant is "J23100", as a log-y step plot.
df2.filter(pl.col("primary_alignment"), pl.col("promoter|variant") == "J23100")[
    "SD2_variant|seq"
].value_counts(sort=True).filter(pl.col("count") > 1).to_pandas().plot(
    drawstyle="steps", logy=True
)

In [None]:
%%time
# Add an "RBS" label to primary end-to-end rows:
#   - if one of the three listed "...|seq" columns is non-null, the label is that
#     column's name up to the "|" (the first matching column in list order wins);
#   - otherwise fall back to "pLIB431-432:RBS=" followed by the RBS variant value.
df_variants = df2.filter(pl.col("e2e"), pl.col("primary_alignment")).with_columns(
    pl.coalesce(
        label_columns(
            [
                "pLIB433:PhlF_pPhlF|seq",
                "pLIB434:LacI_pTac|seq",
                "pLIB435:BetI_pBetI|seq",
            ],
            lambda x: x.split("|")[0],
        ),
        pl.concat_str(pl.lit("pLIB431-432:RBS="), pl.col("pLIB431-432:RBS|variant")),
    ).alias("RBS")
)

In [None]:
# Number of primary, end-to-end rows with a consensus depth of at least 10.
df2.filter(
    pl.col("primary_alignment"), pl.col("e2e"), pl.col("consensus_depth") >= 10
).height

In [None]:
# Number of primary, end-to-end rows with a consensus depth of at least 5.
df2.filter(
    pl.col("primary_alignment"), pl.col("e2e"), pl.col("consensus_depth") >= 5
).height

In [None]:
# Number of primary, end-to-end rows with a consensus depth of at least 3.
df2.filter(
    pl.col("primary_alignment"), pl.col("e2e"), pl.col("consensus_depth") >= 3
).height

In [None]:
# Total grouping depth over primary, end-to-end rows.
# Filters df2 inline instead of reading `df3`: when run top to bottom, `df3` still
# holds the frame from the 240612 isolates section at this point (this section's
# `df3` is only assigned in the following cell).
df2.filter(pl.col("primary_alignment"), pl.col("e2e"))["grouping_depth"].sum()

In [None]:
# Rebind df3 to the primary, end-to-end rows of this dataset.
df3 = df2.filter(pl.col("primary_alignment"), pl.col("e2e"))

In [None]:
# Row count and fraction of all rows for each value of primary_alignment.
df2.group_by("primary_alignment").agg(pl.len()).with_columns(
    frac=pl.col("len") / pl.col("len").sum()
)

In [None]:
# Longest consensus sequence (length in bytes) among primary, end-to-end rows.
df2.filter(pl.col("primary_alignment"), pl.col("e2e")).select(
    pl.col("consensus_seq").str.len_bytes()
).max()

In [None]:
# Among primary rows: count and fraction by whether the read name is duplicated.
df2.filter("primary_alignment").group_by("dup").agg(pl.len()).with_columns(
    frac=pl.col("len") / pl.col("len").sum()
)

In [None]:
# Among primary rows: count and fraction by the e2e flag.
df2.filter("primary_alignment").group_by("e2e").agg(pl.len()).with_columns(
    frac=pl.col("len") / pl.col("len").sum()
)

In [None]:
# Among primary rows: count and fraction by the bc_e2e flag.
df2.filter("primary_alignment").group_by("bc_e2e").agg(pl.len()).with_columns(
    frac=pl.col("len") / pl.col("len").sum()
)

In [None]:
# Distribution of consensus sequence lengths (bytes) over all rows, log-scaled counts.
plt.hist(
    df.select(pl.col("consensus_seq").str.len_bytes())["consensus_seq"],
    bins=100,
    log=True,
);

In [None]:
# Distribution of grouping_depth over all rows, log-scaled counts.
plt.hist(df["grouping_depth"], bins=100, log=True);

In [None]:
# Distribution of consensus_depth over all rows, log-scaled counts.
plt.hist(df["consensus_depth"], bins=100, log=True);

## Export to Eaton format

In [None]:
# Preview on the first 10 rows: a 30-character "0"/"1" barcode string, where a bit is
# "1" when its "=1" segment appears in variants_path in either orientation.
df2.head(10).with_columns(
    barcode=pl.concat_str(
        [
            pl.when(
                pl.col("variants_path")
                .list.contains(f">BC:bit{idx}=1")
                .or_(pl.col("variants_path").list.contains(f"<BC:bit{idx}=1"))
            )
            .then(pl.lit("1"))
            .otherwise(pl.lit("0"))
            for idx in range(30)
        ]
    )
)

In [None]:
# List the column names available for the export.
df2.columns

In [None]:
# This cell previously held only the incomplete expression `df2.` (a SyntaxError that
# stops Run All). Showing the schema complements the column listing above.
# NOTE(review): the originally intended attribute is unknown — adjust or delete.
df2.schema

In [None]:
%%time
# Build the export table from primary, end-to-end rows with consensus depth >= 5.
df_eaton = (
    df2.filter(
        pl.col("primary_alignment"), pl.col("e2e"), pl.col("consensus_depth") >= 5
    )
    .with_columns(
        # 30-character "0"/"1" string; a bit is "1" when its "=1" segment appears in
        # variants_path in either orientation (">" or "<").
        barcode=pl.concat_str(
            [
                pl.when(
                    pl.col("variants_path")
                    .list.contains(f">BC:bit{idx}=1")
                    .or_(pl.col("variants_path").list.contains(f"<BC:bit{idx}=1"))
                )
                .then(pl.lit("1"))
                .otherwise(pl.lit("0"))
                for idx in range(30)
            ]
        ),
        # Placeholder columns: empty strings and a constant alignment start of 1.
        reference=pl.lit(""),
        alignmentstart=1,
        cigar=pl.lit(""),
        subsample=pl.lit(""),
    )
    .rename({"consensus_seq": "consensus"})
    .select(
        "barcode",
        "consensus",
        "reference",
        "alignmentstart",
        "cigar",
        "subsample",
    )
    .sort("barcode")
    # Two row-index columns are added after sorting: "barcodeid", then an unnamed one.
    .with_row_index(name="barcodeid")
    .with_row_index(name="")
)

In [None]:
# Display the export table.
df_eaton

In [None]:
# Write the export as a tab-separated file; the path is relative, so it lands in the
# kernel's current working directory.
df_eaton.write_csv("240610_pLIB476_eaton_export.tsv", separator="\t")

In [None]:
# Print the current working directory via the shell.
!pwd

In [None]:
# variants_path of the first row of df2, as a plain Python list.
df2[0, "variants_path"].to_list()

In [None]:
# Number of rows with a non-null variant call for barcode bit 28.
df2["BC:bit28|variant"].is_not_null().sum()

# 240610_pLIB476 vs. 240513_pLIB473-476

In [None]:
# Scan the prepare_reads.all_segments output of the 240513 run. The frame is left
# lazy (the .collect() is commented out), so later cells collect explicitly.
# arrow_filename = "/home/jqs1/scratch/sequencing/240610_pLIB476_bottleneck/pLIB476_bottleneck/pLIB476/20240607_1433_MN35044_FAX60316_7d690112/"
arrow_filename = "/home/jqs1/scratch/sequencing/240513_pLIB473_476/20240513_1645_2C_PAW46239_b49d575f/prepare_reads.all_segments/*.arrow"
df = concat_glob(arrow_filename)  # .collect()

In [None]:
# Load the pLIB476 reference graph (GFA file) with gfapy.
gfa_filename = "/home/jqs1/scratch/sequencing/240610_pLIB476_bottleneck/pLIB476_bottleneck/pLIB476/20240607_1433_MN35044_FAX60316_7d690112/references/pLIB476jqs.gfa"
gfa = gfapy.Gfa.from_file(gfa_filename)

In [None]:
# Column names and dtypes of the lazily scanned table.
df.schema

In [None]:
# Names of the segments in the reference graph.
gfa.segment_names

In [None]:
# Materialize the first 100 rows flagged end_to_end.
df2 = df.filter(pl.col("end_to_end")).head(100).collect()

In [None]:
# Barcode as a list column: the "BC:bit{idx}|variant" field of extract_segments for
# each of the 30 bits, in bit order.
df2.with_columns(
    barcode=pl.concat_list(
        [
            pl.col("extract_segments").struct.field(f"BC:bit{idx}|variant")
            for idx in range(30)
        ]
    )
)["barcode"]

In [None]:
# full_path of row 22 of the 100-row end-to-end sample collected above.
# Reads df2 (built from this section's data): when run top to bottom, `df3` still
# refers to a frame from an earlier dataset here and is only rebound to this
# section's rows in a later cell.
# NOTE(review): assumes this dataset has a "full_path" column — confirm against df.schema.
df2[22, "full_path"].to_list()

In [None]:
# Fields of the extract_segments struct column.
df.schema["extract_segments"].fields

In [None]:
# TODO: use exclude after release including https://github.com/pola-rs/polars/issues/16661
# Row-wise sum of every per-segment mismatch/insertion/deletion count, leaving out
# the upstream/downstream segments, on the first 10 rows.
# The struct column is "extract_segments" (the name used by every other cell on this
# frame); the previous "extract_columns" does not match it.
seg_col = pl.col("extract_segments").struct.field
df2 = (
    df.head(10)
    .with_columns(
        divergence=pl.sum_horizontal(
            seg_col(r"\|(mismatches|insertions|deletions)").exclude(
                r"upstream\|(mismatches|insertions|deletions)",
                r"downstream\|(mismatches|insertions|deletions)",
            )
        )
    )
    .collect()
)

In [None]:
# Row-wise sum of per-segment counts over an explicit segment list, nulls counted as
# zero, on the first 10 rows.
# The struct column is "extract_segments" (the name used by every other cell on this
# frame); the previous "extract_columns" does not match it.
# NOTE(review): `forward_segments` is not defined anywhere in the visible notebook; it
# must hold orientation-prefixed segment names (the first character is stripped below).
# NOTE(review): "matches" is summed into `divergence` here, unlike the regex-based
# cell above which only sums mismatches/insertions/deletions — confirm this is intended.
seg_col = pl.col("extract_segments").struct.field
df2 = (
    df.head(10)
    .with_columns(
        divergence=pl.sum_horizontal(
            [
                seg_col(f"{s[1:]}|{type_}").fill_null(strategy="zero")
                for type_ in ("matches", "mismatches", "insertions", "deletions")
                for s in forward_segments
            ]
        )
    )
    .collect()
)

In [None]:
# Display the computed divergence column.
df2["divergence"]

In [None]:
# Lazy frame of all end_to_end rows with a 30-element barcode list column built from
# the "BC:bit{idx}|variant" fields of extract_segments (nothing is collected here).
df2 = df.filter(pl.col("end_to_end")).with_columns(
    barcode=pl.concat_list(
        [
            pl.col("extract_segments").struct.field(f"BC:bit{idx}|variant")
            for idx in range(30)
        ]
    )
)

In [None]:
# Materialize the first 100 rows of the lazy frame.
df3 = df2.head(100).collect()

In [None]:
%%time
# Collect every end_to_end row whose barcode list equals that of row 22 of the sample.
# NOTE(review): this compares a list column with a plain Python list; verify that the
# installed polars version treats the right-hand side as a single list value.
df4 = df2.filter(pl.col("barcode") == df3[22, "barcode"].to_list()).collect()

In [None]:
# Display the rows sharing that barcode.
df4