# Polars bigidx test

In [None]:
# Environment setup.
# NOTE(review): these cells look like alternatives for switching between polars
# builds rather than a sequence to run top to bottom (later cells uninstall or
# remove what earlier cells install) — confirm which ones are meant to be run.
# Install polars via micromamba.
!micromamba install -y polars

In [None]:
# Install a locally built polars 0.19.9 wheel from the py-polars build tree,
# replacing whatever polars is currently installed (--force-reinstall).
!pip install --force-reinstall -v /home/jqs1/projects/polars/py-polars/target/wheels/polars-0.19.9-cp38-abi3-linux_x86_64.whl

In [None]:
# Remove the pip-installed polars package.
!pip uninstall -y polars

In [None]:
# Install the polars-u64-idx ("bigidx") build, which uses a 64-bit row index
# so frames can exceed 2**32 rows.
!pip install --force-reinstall polars-u64-idx

In [None]:
# Remove the micromamba-installed polars.
# NOTE(review): presumably so it cannot shadow the pip-installed polars-u64-idx — confirm.
!micromamba remove -y polars

In [None]:
%%time
# check for polars-u64-idx (bigidx)
# Builds a single-column frame with 2**32 + 100 rows (more than a 32-bit row
# index can address), then adds a row-count column. Presumably this only
# succeeds on the polars-u64-idx build — confirm on the default build.
# NOTE(review): this materializes billions of rows plus a row-index column, so
# it needs a lot of RAM; expect it to be slow.
import polars as pl

pl.select(
    pl.repeat(False, n=(2**32) + 100, eager=True).alias("col1")
).with_row_count()

# Imports

In [None]:
import os

# os.environ["POLARS_MAX_THREADS"] = "8"
# os.environ["POLARS_VERBOSE"] = "1"

In [None]:
import itertools as it
import operator
import re
from collections import Counter
from pathlib import Path

import duckdb
import gfapy
import holoviews as hv
import ibis
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import polars as pl
import pyabpoa
import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pyfastx
import pysam
import spoa
from pyarrow import csv
from tqdm.auto import tqdm, trange

In [None]:
# Automatically reload edited modules (e.g. the project-local
# paulssonlab.sequencing.* modules imported below) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Project-local helpers: sgfa provides the GFA filtering/segment/endpoint
# functions and processing the read-processing functions used below.
# sio is not used in the cells visible here.
import paulssonlab.sequencing.gfa as sgfa
import paulssonlab.sequencing.io as sio
import paulssonlab.sequencing.processing as processing

In [None]:
# Use the Bokeh plotting backend for HoloViews.
hv.extension("bokeh")

In [None]:
# Profiling tools: register the pyinstrument and line_profiler IPython magics.
%load_ext pyinstrument
import line_profiler
import pyinstrument

%load_ext line_profiler

In [None]:
# Enable polars' global string cache so Categorical values created in different
# frames/queries share one encoding and can be compared or combined.
# NOTE(review): presumably needed for the Categorical path columns — confirm.
pl.enable_string_cache()

# Config

In [None]:
# Input directory for this alignment test; the cells below read its *.arrow
# files (presumably dorado 0.4.0 basecalls, per the path — confirm).
data_dir = Path(
    "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/230707_repressilators/dorado_0.4.0/uncompressed/"
)

# Group by path

## Setup

In [None]:
# Load the barcode graph (GFA format) from disk with gfapy.
gfa = gfapy.Gfa.from_file(
    "/home/jqs1/scratch/jqs1/sequencing/230930_alignment_test/barcode.gfa"
)

In [None]:
# Exclude the listed graph elements (presumably segment names/prefixes — see
# sgfa.filter_gfa) before deriving the forward-segment set and the endpoints
# that processing.normalize_path uses below.
gfa_filtered = sgfa.filter_gfa(
    gfa, exclude=["UNS9", "BC:UPSTREAM", "BC:JUNCTION", "BC:T7_TERM", "BC:SPACER2"]
)
forward_segments = sgfa.gfa_forward_segments(gfa_filtered)
endpoints = sgfa.gfa_endpoints(gfa_filtered)

In [None]:
# Inspect the computed endpoints.
endpoints

## Polars

In [None]:
# df = pl.scan_pyarrow_dataset(arrow_ds)
# df_input = pl.scan_ipc(str(data_dir / "*.arrow"))
# df_input = pl.scan_ipc(list(data_dir.glob("*.arrow"))[0])
# df_input = pl.scan_parquet(list(data_dir.glob("*.parquet"))[0])

In [None]:
%%time
df_input = pl.scan_ipc(list(data_dir.glob("*.arrow"))[4])
df = processing.normalize_path(df_input, forward_segments, endpoints=endpoints).collect(
    streaming=True
)
df.head()

In [None]:
# Display the normalized reads.
df

In [None]:
%%time
# Eager call: identify usable reads on the materialized DataFrame.
df_usable = processing.identify_usable_reads(df)

In [None]:
%%time
# Same call on a lazy frame, then collected. Presumably timed to compare with the
# eager cell; it overwrites df_usable with an (eager) DataFrame.
df_usable = processing.identify_usable_reads(df.lazy()).collect()

In [None]:
df_usable

In [None]:
# Inspect one prepared/merged channel file. The path is built from data_dir so
# the base directory is configured in one place; it resolves to
# <data_dir>/prepared/channel-1_merged.arrow.
pl.read_ipc(data_dir / "prepared" / "channel-1_merged.arrow")

In [None]:
%%time
df_read_groups = (
    df_usable.select(
        pl.col(
            "name",
            "is_duplex",
            "read_seq",
            "read_phred",
            "reverse_complement",
            "path",
            "full_path",
        )
    )
    .group_by("path")
    .agg(
        pl.col("name", "read_seq", "read_phred", "reverse_complement"),
        pl.col("is_duplex").sum().alias("duplex_depth"),
        pl.count().alias("depth"),
    )
    .with_columns((pl.col("depth") - pl.col("duplex_depth")).alias("simplex_depth"))
    .filter(pl.col("depth") > 5)
    # .map_groups(
    #     consensus_func,
    #     schema=dict(
    #         path_subset=pl.List(pl.Categorical),
    #         read_seq=pl.Utf8,
    #         read_phred=pl.Utf8,
    #         depth_simplex=pl.UInt64,
    #         depth_duplex=pl.UInt64,
    #     ),
    # )
)
res = df_read_groups.fetch(10_000)
# res.select(pl.all().exclude("name","read_seq", "read_phred", "reverse_complement"))

In [None]:
def consensus_func(df):
    """Collapse one frame of reads into a single placeholder "consensus" row.

    Reads whose ``name`` contains ``";"`` are counted as duplex and all others
    as simplex. The result has at most one row: the ``path_subset``,
    ``read_seq`` and ``read_phred`` of the first read, plus the frame-wide
    ``depth_simplex`` and ``depth_duplex`` counts. No real consensus is
    computed yet; the first read's sequence is passed through unchanged.

    Expects ``df`` to have ``name``, ``path_subset``, ``read_seq`` and
    ``read_phred`` columns (presumably one group of a ``group_by`` — confirm).
    """
    duplex_mask = pl.col("name").str.contains(";")
    flagged = df.with_columns(
        duplex_mask.alias("is_duplex"),
        duplex_mask.not_().alias("is_simplex"),
    )
    # The sums are scalars and broadcast across every row, so after head(1)
    # the single remaining row carries the frame-wide depth counts.
    with_depths = flagged.with_columns(
        pl.sum("is_simplex").alias("depth_simplex"),
        pl.sum("is_duplex").alias("depth_duplex"),
    )
    keep = (
        "path_subset",
        "read_seq",
        "read_phred",
        "depth_simplex",
        "depth_duplex",
    )
    return with_depths.select(pl.col(*keep)).head(1)