In [None]:
%run "../head.py"

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from hic_basic.binnify import GenomeIdeograph
from hic_basic.coolstuff import cli_balance, cli_compartment
from hic_basic.sequence import count_CpG
from hic_basic.hicio import read_meta, load_json, dump_json
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from lib.plot import plot_2_scAB

import pandas as pd
import plotly.express as px
from hic_basic.hicio import load_json
from lib.metrics import stack_contour
from lib.plot import plot_heatmap_with_bars, plot_figure_canvas

In [None]:
def sort_scAB(file, region, kregion, keyfunc):
    """
    Order the samples of a scAB table by a key computed on a key region.

    Input:
        file: str or pd.DataFrame, path to the parquet file, or the already
            loaded table (it is left unmodified). Rows are addressed as
            (chrom, position) and columns are treated as samples
            -- presumably a 2-level MultiIndex; confirm against the writer.
        region: str or tuple, region to return. A bare string selects the
            whole chromosome, a tuple is (chrom, start, end).
        kregion: tuple, (chrom, start, end) region the sort key is computed on.
        keyfunc: callable, applied to each sample's kregion values
            (a pd.Series); its results are sorted in ascending order.
    Output:
        data: pd.DataFrame, one row per sample and one column per bin of
            region, rows ordered by ascending key.
    Note:
      Will drop rows with all NaNs in kregion.
    """
    if isinstance(file, pd.DataFrame):
        # sort_index() returns a new frame, so the caller's table keeps its order
        table = file.sort_index()
    else:
        # read and sort once; both selections below share this table
        table = pd.read_parquet(file)
        table.sort_index(inplace=True)

    def select(lower, upper):
        # label slice (both ends inclusive), then transpose to samples x bins
        return table.loc[lower:upper].droplevel(0).T

    if isinstance(region, str):
        # whole chrom
        data = select((region,), (region,))
    else:
        chrom, start, end = region
        data = select((chrom, start), (chrom, end))

    kchrom, kstart, kend = kregion
    target_data = select((kchrom, kstart), (kchrom, kend))
    sorted_samples = (
        target_data.dropna(how="all")  # samples without any value in kregion
        .apply(keyfunc, axis=1)
        .sort_values()
        .index.tolist()
    )
    return data.loc[sorted_samples]
def chunk_sum(chunk, cols, sample_dict):
    """
    Sum several feature columns, each restricted to its own set of samples.

    For every column in cols, only the rows whose first index level is
    listed in sample_dict[col] are kept; they are summed per combination of
    index levels 1 and 2, and the per-column results are added together.
    Input:
        chunk: pd.DataFrame, the chunk of the bfs file; the first index level
            holds the sample, levels 1 and 2 are the grouping keys.
        cols: list, columns to be aggregated.
        sample_dict: dict, {agg_feature_name: [samples]}.
    Output:
        pd.Series, summed value per (level 1, level 2) combination.
    """
    sample_of_row = chunk.index.get_level_values(0)
    per_feature = [
        # only part of the samples will appear in the chunk
        chunk.loc[sample_of_row.isin(sample_dict[name]), name]
        .groupby(level=[1, 2])  # groupby ht and dv
        .sum()
        for name in cols
    ]
    combined = pd.concat(per_feature, axis=1)
    return combined.sum(axis=1)
#chunk_sum(chunk, list(sample_dict.keys()), sample_dict).sum()
def scAB_feature_agg(bfs_file, sample_dict, chunksize=50000, report=False):
    """
    Aggregate scAB feature columns of a bfs file per (ht, dv).

    The "main" table is streamed in chunks (rows with lr == 0 only); each
    chunk is reduced with chunk_sum and the partial results are summed.
    Input:
        bfs_file: str, path to the bfs h5 file.
        sample_dict: dict, {agg_feature_name: [samples]}; the keys are the
            columns read from the file.
        chunksize: int, chunk size for reading the bfs file.
        report: bool, print the total added by each chunk.
    Output:
        dist: pd.Series named "count", indexed by (ht, dv).
    """
    feature_names = list(sample_dict.keys())
    wanted_columns = ["ht", "dv", "sample_name"] + feature_names
    partial_counts = []
    with pd.HDFStore(bfs_file, "r") as store:
        chunk_iter = store.select(
            "main",
            where='lr == 0',
            columns=wanted_columns,
            iterator=True,
            chunksize=chunksize,
        )
        for raw_chunk in chunk_iter:
            indexed = raw_chunk.set_index(["sample_name", "ht", "dv"])
            # calculate complex feature like "Astrong" in the chunk
            summed = chunk_sum(indexed, feature_names, sample_dict)
            partial = summed.rename("count").reset_index(["ht", "dv"])
            partial_counts.append(partial)
            if report:
                print("add", partial["count"].sum().sum())
    # the same (ht, dv) can show up in several chunks, so sum once more
    stacked = pd.concat(partial_counts, axis=0)
    dist = stacked.groupby(["ht", "dv"])["count"].sum()
    return dist
# data = scAB_feature_agg(
#     arg_bfs_vx_files["Sperm_scAB"],
#     Bstrong
#     )
def get_density(bfs_file):
    """
    Mean non-zero density per (ht, dv), taken from the lr == 0 rows.

    Input:
        bfs_file: str, path to the bfs h5 file.
    Output:
        density: pd.Series, mean of the "density" column indexed by
            (ht, dv). Rows whose density is exactly 0 take no part in the
            mean, so a (ht, dv) pair with only zero densities is absent.
    """
    # just use density of lr==0 to normalize the data
    # won't give a meaningful frequency by dividing the density
    with pd.HDFStore(bfs_file, "r") as store:
        density = store.select(
            "main",
            where="lr == 0",
            columns=["ht", "dv", "density"]
        )
    # Filter the zero-density rows out instead of overwriting whole rows with
    # NA: that assignment also blanked the ht/dv key columns (changing their
    # dtype) and only worked because groupby drops NA keys by default.
    nonzero = density.loc[density["density"] != 0]
    return nonzero.groupby(["ht", "dv"])["density"].mean()
# density = get_density(arg_bfs_vx_files["Sperm"])
# density[density.rank(method="dense").astype(int) == 3]