In [None]:
import sys
import os
from pathlib import Path
import json
import numpy as np
from typing import Dict, Tuple, List
from joblib import Parallel, delayed
import logging
import matplotlib.pyplot as plt
import seaborn as sns
logging.getLogger('matplotlib').setLevel(logging.ERROR)
sys.path.append("..")
from src.assay_calibration.fit_utils.fit import (calculate_score_ranges,thresholds_from_prior)  # noqa: E402
from src.assay_calibration.fit_utils.two_sample import density_utils  # noqa: E402
from src.assay_calibration.data_utils.dataset import Scoreset  # noqa: E402
from src.assay_calibration.fit_utils.utils import serialize_dict  # noqa: E402



def summarize_scoreset(fits,scoreset,save_filepath,use_median_prior,use_2c_equation,n_c,benign_method, **kwargs):
    """
    Compute point ranges for one scoreset and write its two diagnostic figures.

    Args:
        fits (List[Dict]): Bootstrap fit results, passed to `process_fits` and
            to the plotting helpers.
        scoreset (Scoreset): The scoreset the fits were computed on.
        save_filepath (str | Path): Base output path. Only its parent directory
            and stem are used here: the figures are written to
            `<stem>_figure_fits.png` and `<stem>_figure_summary.png` in that
            directory. `process_fits` receives the same path.
        use_median_prior (bool): Forwarded to `process_fits`; selects the
            median-prior thresholds instead of the 5-percentile ones.
        use_2c_equation (bool): Forwarded to `process_fits`; selects the prior
            equation instead of EM.
        n_c: Component-count label. Only used in figure titles and returned
            unchanged.
        benign_method (str): Forwarded to `process_fits` (it distinguishes
            'synonymous', 'avg' and everything else).
    Optional Keyword Args:
        Forwarded unchanged to `process_fits`, which reads `point_values`
        (default [1,2,3,4,5,6,7,8]) and hands the keywords on to the prior
        estimator when EM is used.

    Returns:
        Tuple of (scoreset, results, fits, score_range, config_label, n_c),
        where `results` is the serialized summary dict and `config_label` is a
        string of the form "(em|equation, median|5-percentile, benign_method)".
    """
    save_filepath = Path(save_filepath)
    # Create the output directory first: process_fits writes its priors cache
    # next to save_filepath, so the directory must exist before it runs.
    save_filepath.parent.mkdir(exist_ok=True,parents=True)

    (priors, prior, point_ranges, score_range, log_fp, log_fb,
     all_path_ranges, all_ben_ranges, C) = process_fits(
        fits, scoreset, use_median_prior, use_2c_equation, benign_method,
        save_filepath, **kwargs)

    results = serialize_dict(dict(prior=prior,
                                  point_ranges=point_ranges,
                                  priors=priors,
                                  score_range=score_range,
                                  log_lr_plus=log_fp - log_fb,
                                  C=C,
                                  all_path_ranges=all_path_ranges,
                                  all_ben_ranges=all_ben_ranges))
    # NOTE: writing `results` to JSON is currently disabled; only figures are saved.

    scoreset_fit_figure = plot_scoreset(scoreset, results, fits,score_range, use_median_prior, use_2c_equation, n_c, benign_method, C)
    figure_filepath = save_filepath.parent / f"{save_filepath.stem}_figure_fits.png"
    scoreset_fit_figure.savefig(figure_filepath,bbox_inches='tight',dpi=300)
    plt.close(scoreset_fit_figure)

    summary_fig = plot_summary(scoreset, fits, results, score_range, log_fp, log_fb, use_median_prior, use_2c_equation, n_c, benign_method, C)
    summary_figure_filepath = save_filepath.parent / f"{save_filepath.stem}_figure_summary.png"
    summary_fig.savefig(summary_figure_filepath,bbox_inches='tight',dpi=300)
    plt.close(summary_fig)

    return scoreset, results, fits, score_range, f"({'equation' if use_2c_equation else 'em'}, {'median' if use_median_prior else '5-percentile'}, {benign_method})", n_c

def enforce_monotonicity_point_ranges(point_ranges, point_values):
    """
    Collapse fragmented point ranges in place so each side stays monotonic.

    The pathogenic side (keys `+v`) and the benign side (keys `-v`) are handled
    independently, walking `point_values` in the order given. The first point
    value whose range is split into more than one interval is flattened into a
    single interval spanning from the start of its first interval to the end
    of its last one; every later point value on that side is then emptied.

    Args:
        point_ranges (dict): Maps signed point value -> list of intervals.
            Modified in place. Keys that are absent are treated as having no
            intervals (an empty mapping is left untouched) instead of raising
            KeyError.
        point_values (Iterable[int]): Positive point magnitudes to visit.

    Returns:
        None.
    """
    for sign in (1, -1):  # +1: pathogenic keys, -1: benign keys
        capped = False
        for magnitude in point_values:
            point = sign * magnitude
            if capped:
                # A weaker point value was already fragmented; drop stronger ones.
                point_ranges[point] = []
                continue
            intervals = point_ranges.get(point)
            if intervals is None or len(intervals) <= 1:
                continue
            # e.g. --_-  -> one interval covering the whole span
            point_ranges[point] = [[intervals[0][0], intervals[-1][-1]]]
            capped = True
        

def prior_equation_2c(w_p, w_b, w_g):
    """
    Closed-form prior from the weight at index 1 of three weight vectors.

    Returns the value `prior` that satisfies
    `w_g[1] == prior * w_p[1] + (1 - prior) * w_b[1]`.
    The result is not clipped to [0, 1].
    """
    benign_weight = w_b[1]
    numerator = w_g[1] - benign_weight
    denominator = w_p[1] - benign_weight
    return numerator / denominator

def prior_invalid(prior):
    """
    Return True unless `prior` lies strictly between 0 and 1.

    NaN is reported as invalid: every comparison against NaN is False, so the
    check is written as the negation of the valid range rather than as
    `prior <= 0 or prior >= 1`, which would let NaN through as "valid".
    """
    return not (0 < prior < 1)

def get_bootstrap_score_ranges(fitIdx, fit, fp, fb, score_range, fit_priors, point_values):
    """
    Compute the point score ranges for a single bootstrap fit.

    Args:
        fitIdx (int): Position of this bootstrap; used to pick its prior from
            `fit_priors` and returned so the caller can match results up.
        fit (dict): Fit record; only `fit['fit']['xlims']` is read here.
        fp, fb (np.ndarray): Pathogenic / benign density values evaluated on
            `score_range` (presumably log densities, since their difference is
            used as the log LR+ -- confirm against `density_utils`).
        score_range (np.ndarray): Score grid shared by all bootstraps.
        fit_priors (Sequence[float]): One prior per bootstrap.
        point_values (List[int]): Point magnitudes passed to
            `calculate_score_ranges`.

    Returns:
        (fitIdx, log_fp_local, log_fb_local, ranges_p, ranges_b, int(C)), where
        the two density arrays are NaN outside the fit's x-limits and every
        range is replaced by `[nan]` when the bootstrap prior is invalid.
    """
    lower, upper = fit['fit']['xlims']
    in_support = (score_range >= lower) & (score_range <= upper)

    # Grid points outside this bootstrap's fitted interval must stay NaN (not
    # 0): zeros would distort the LR+ curves at the extremes when the
    # bootstraps are aggregated.
    log_fp_local = np.full_like(fp, np.nan, dtype=float)
    log_fb_local = np.full_like(fb, np.nan, dtype=float)
    log_fp_local[in_support] = fp[in_support]
    log_fb_local[in_support] = fb[in_support]

    boot_prior = fit_priors[fitIdx]
    lr_plus = log_fp_local[in_support] - log_fb_local[in_support]
    supported_scores = score_range[in_support]

    # The same curve is supplied for both LR+ arguments of calculate_score_ranges.
    ranges_p, ranges_b, C = calculate_score_ranges(
        lr_plus, lr_plus, boot_prior, supported_scores, point_values
    )

    if prior_invalid(boot_prior):
        # Mark every range of this bootstrap as missing.
        for ranges in (ranges_p, ranges_b):
            for key in ranges:
                ranges[key] = [np.nan]

    return fitIdx, log_fp_local, log_fb_local, ranges_p, ranges_b, int(C)


def process_fits(fits, scoreset, use_median_prior, use_2c_equation, benign_method, save_filepath, **kwargs)->Tuple[np.ndarray, float, Dict[int,List[Tuple[float,float]]], np.ndarray, np.ndarray, np.ndarray, List[Dict[int,np.ndarray]], List[Dict[int,np.ndarray]], object]:
    """
    Turn a set of bootstrap fits into per-bootstrap priors and point ranges.

    Steps, in order:
      1. Priors: loaded from `<stem>_priors.npy` next to `save_filepath` when
         that file exists (the stem has '_median' / '_5-percentile' removed);
         otherwise estimated per bootstrap, either with `get_fit_prior` in
         parallel (EM) or with `prior_equation_2c`, and then saved there.
      2. A 10000-point score grid is built between the minimum and maximum of
         the scores that belong to at least one sample.
      3. Pathogenic and benign densities are evaluated on the grid for every
         bootstrap, and `get_bootstrap_score_ranges` is run in parallel.
      4. Unified point ranges are derived either from the median prior and the
         5th/95th percentile LR+ curves (`use_median_prior`), or from a
         per-point-value percentile of the bootstrap thresholds.
      5. `enforce_monotonicity_point_ranges` is applied to the result.

    Args:
        fits (List[Dict]): Bootstrap fits; `fit['fit']` must provide
            'weights', 'component_params' and 'xlims'.
        scoreset: Object exposing `scores` and `_sample_assignments`.
        use_median_prior (bool): Selects branch 4a (True) or 4b (False).
        use_2c_equation (bool): Use the prior equation instead of EM.
        benign_method (str): 'synonymous' uses weights[3] as the benign
            sample, 'avg' averages weights[1] and weights[3], anything else
            uses weights[1].
        save_filepath (Path): Only `.parent` and `.stem` are used.
        **kwargs: `point_values` is read here; all keywords are forwarded to
            `get_fit_prior` when EM is used.

    Returns:
        (fit_priors, prior, point_ranges, score_range, log_fp, log_fb,
         ranges_pathogenic, ranges_benign, C). The three grid-aligned arrays
        are restricted to `range_subset`. `C` is a 2-element array of the 5th
        and 95th percentile of the per-bootstrap values unless the median-prior
        branch runs, in which case it is whatever `calculate_score_ranges`
        returns.

    NOTE(review): the cache file name does not depend on `use_2c_equation` or
    `benign_method`; unless the caller encodes them in the stem, priors from a
    different configuration may be loaded -- confirm against the caller.
    """
    n_cores = os.cpu_count() or 1

    priors_filepath = save_filepath.parent / f"{save_filepath.stem.replace('_median','').replace('_5-percentile','')}_priors.npy"
    if not priors_filepath.exists():

        if not use_2c_equation:
            print('estimating priors...')
            fit_priors = np.array(Parallel(n_jobs=min(len(fits), n_cores), verbose=10)(delayed(get_fit_prior)(fit, scoreset, benign_method, **kwargs)
                                       for fit in fits))
        else:
            print('computing priors with equation...')
            fit_priors = []
            for fit in fits:
                if len(fit['fit']['weights']) == 3:
                    w_p, w_b, w_g = fit['fit']['weights']
                elif len(fit['fit']['weights']) == 4:
                    w_p, w_b, w_g, w_s = fit['fit']['weights']
                else:
                    raise ValueError(f"Number of samples != 3 or 4: {len(fit['fit']['weights'])}")
                # NOTE(review): with only 3 weight vectors `w_s` is never bound,
                # so 'synonymous' and 'avg' fail with a NameError below.
                if benign_method == 'synonymous':
                    fit_priors.append(prior_equation_2c(w_p, w_s, w_g))
                elif benign_method == 'avg':
                    w_bs = (np.array(w_b)+np.array(w_s))/2
                    fit_priors.append(prior_equation_2c(w_p, w_bs, w_g))
                else:
                    fit_priors.append(prior_equation_2c(w_p, w_b, w_g))
            fit_priors = np.array(fit_priors)

        np.save(priors_filepath, fit_priors)

    else:
        print(f"loading priors from cached {'equation' if use_2c_equation else 'em'}")
        fit_priors = np.load(priors_filepath)


    point_values = kwargs.get('point_values',[1,2,3,4,5,6,7,8])
    # The median over bootstraps (ignoring NaN) is always computed; it gates
    # both threshold modes below and is returned as `prior`.
    prior = np.nanmedian(fit_priors)
    # Only scores assigned to at least one sample define the grid limits.
    observed_scores = scoreset.scores[scoreset._sample_assignments.any(1)]
    score_range = np.linspace(*np.percentile(observed_scores,[0,100]),10000) # type: ignore

    # Pathogenic density uses the first weight vector of every fit.
    _log_fp = np.stack([density_utils.mixture_pdf(score_range, _fit['fit']['component_params'],_fit['fit']['weights'][0])
                        for _fit in fits])
    benign_idx = 3 if benign_method == 'synonymous' else 1
    if benign_method != 'avg':
        _log_fb = np.stack([density_utils.mixture_pdf(score_range, _fit['fit']['component_params'],_fit['fit']['weights'][benign_idx])
                           for _fit in fits])
    else:
        # 'avg': element-wise mean of weight vectors 1 and 3.
        _log_fb = np.stack([density_utils.mixture_pdf(score_range, _fit['fit']['component_params'],(np.array(_fit['fit']['weights'][1])+np.array(_fit['fit']['weights'][3]))/2)
                           for _fit in fits])
    log_fp = np.full((len(fits),len(score_range)),np.nan)
    log_fb = np.full((len(fits),len(score_range)),np.nan)

    # Caching of the per-bootstrap results is currently disabled, so this path
    # is computed but not used.
    boot_points_filepath = save_filepath.parent / f"{save_filepath.stem.replace('_median','').replace('_5-percentile','')}_boot_points.npz"

    print('getting point ranges for each bootstrap...')
    results = Parallel(
        n_jobs=min(len(fits), n_cores),
        verbose=10
    )(
        delayed(get_bootstrap_score_ranges)(fitIdx, fit, fp, fb, score_range, fit_priors, point_values)
        for fitIdx, (fit, fp, fb) in enumerate(zip(fits, _log_fp, _log_fb))
    )

    # Update parent arrays in main process
    ranges_pathogenic, ranges_benign = [], []
    Cs = []

    for fitIdx, log_fp_local, log_fb_local, ranges_p, ranges_b, C in results:
        log_fp[fitIdx] = log_fp_local
        log_fb[fitIdx] = log_fb_local
        # Each point value's ranges are flattened to a 1-D array of scores.
        ranges_pathogenic.append({key: np.array(value).reshape(-1) for key, value in ranges_p.items()})
        ranges_benign.append({key: np.array(value).reshape(-1) for key, value in ranges_b.items()})
        Cs.append(C)

    log_lr_plus = log_fp - log_fb
    # Number of bootstraps that are NaN at each grid point.
    nan_counts = np.isnan(log_lr_plus).sum(0)
    # NOTE(review): nan_counts is at most the number of bootstraps
    # (log_lr_plus.shape[0]) but is compared with the number of grid points
    # (shape[1]); with fewer bootstraps than grid points every grid point is
    # kept. Confirm whether shape[0] was intended.
    range_subset = nan_counts < log_lr_plus.shape[1] # changed from 0. 1000/10000 is arbitrary
    point_ranges = {}

    C = np.array([np.nanpercentile(Cs, 5), np.nanpercentile(Cs, 95)])
    # NOTE(review): the literal keys 1 / -1 below assume 1 is in point_values.
    print('ranges_p bootstrap 0, score +1:',ranges_pathogenic[0][1])

    # if points 1 and -1 are out of order, consider scoreset flipped
    # (decided from bootstrap 0 only)
    scoreset_flipped = len(ranges_pathogenic[0][1]) != 0 and len(ranges_benign[0][-1]) != 0 and ranges_pathogenic[0][1][-1] >= ranges_benign[0][-1][0]

    if prior > 0 and prior < 1:

        if use_median_prior:
            print('using median prior to get unified thresholds...')
            # point_ranges = {point_value : [score1, score2, ...]}
            # The 5th / 95th percentile LR+ curves over bootstraps are passed
            # as the two LR+ arguments; this call also overwrites C.
            point_ranges_pathogenic, point_ranges_benign, C = calculate_score_ranges(np.nanpercentile(log_lr_plus[:,range_subset],
                                                                                                5,axis=0),
                                                                                    np.nanpercentile(log_lr_plus[:,range_subset],
                                                                                                    95,axis=0),
                                                                                    prior,
                                                                                    score_range[range_subset],
                                                                                    point_values,)
            point_ranges = {**point_ranges_pathogenic,**point_ranges_benign}
        else:
            # use 5-percentile of most conservative bootstrap thresholds for each point assignment
            print('using 5-percentile to get conservative thresholds...')

            # For a flipped scoreset every direction-dependent choice is mirrored.
            p_5percentile_conservative = 5 if not scoreset_flipped else 95
            b_5percentile_conservative = 95 if not scoreset_flipped else 5
            p_max = max if not scoreset_flipped else min
            b_min = min if not scoreset_flipped else max
            p_inf = -np.inf if not scoreset_flipped else np.inf
            b_inf = np.inf if not scoreset_flipped else -np.inf

            conservative_thresholds = {}
            print('boot prior:',np.nanmin(fit_priors), '-', np.nanmax(fit_priors))
            for point_value in point_values: # 1,2,...,8

                # One boundary score per bootstrap (an infinity when the
                # bootstrap never reaches this point value), then a percentile
                # over bootstraps; NaN bootstraps are ignored by nanpercentile.
                conservative_thresholds[point_value] = np.nanpercentile([p_max(ranges_p[point_value]) if len(ranges_p[point_value]) > 0 else p_inf for ranges_p in ranges_pathogenic], p_5percentile_conservative)
                print(point_value,'nan bootstrap points:',np.isnan([p_max(ranges_p[point_value]) if len(ranges_p[point_value]) > 0 else p_inf for ranges_p in ranges_pathogenic]).sum())
                conservative_thresholds[-1*point_value] = np.nanpercentile([b_min(ranges_b[-1*point_value]) if len(ranges_b[-1*point_value]) > 0 else b_inf for ranges_b in ranges_benign], b_5percentile_conservative)

            print('conservative_thresholds',conservative_thresholds)
            for point_value, threshold in conservative_thresholds.items():
                assert point_value != 0
                if np.isnan(threshold) or np.isinf(threshold):
                    # No usable threshold: this point value gets no range.
                    point_ranges[point_value] = []
                    continue
                valid_scores = score_range[range_subset]
                if (point_value > 0 and not scoreset_flipped) or (point_value < 0 and scoreset_flipped): # pathogenic or flipped benign
                    # Range extends from the next stronger point value's
                    # threshold (or the grid start) up to this threshold.
                    if abs(point_value) == max(point_values):
                        point_ranges[point_value] = [valid_scores[0], threshold]
                    else:
                        lower_lim = conservative_thresholds[point_value+1] if point_value > 0 else conservative_thresholds[point_value-1]
                        if np.isnan(lower_lim):
                            point_ranges[point_value] = [valid_scores[0], threshold]
                        else:
                            point_ranges[point_value] = [lower_lim, threshold]
                else: # benign
                    # Mirror image: from this threshold up to the next stronger
                    # point value's threshold (or the grid end).
                    if abs(point_value) == max(point_values):
                        point_ranges[point_value] = [threshold, valid_scores[-1]]
                    else:
                        upper_lim = conservative_thresholds[point_value-1] if point_value < 0 else conservative_thresholds[point_value+1]
                        if np.isnan(upper_lim):
                            point_ranges[point_value] = [threshold, valid_scores[-1]]
                        else:
                            point_ranges[point_value] = [threshold, upper_lim]

            # Wrap each [low, high] pair so every value is a list of intervals.
            for point_value in point_ranges:
                if len(point_ranges[point_value]) != 0:
                    point_ranges[point_value] = [point_ranges[point_value]]
            print('point_ranges',point_ranges)

    # enforce point range monotonicity before returning
    # NOTE(review): point_ranges is still {} here when the median prior is
    # outside (0, 1).
    enforce_monotonicity_point_ranges(point_ranges, point_values)

    return fit_priors, prior, point_ranges,score_range[range_subset],log_fp[:,range_subset], log_fb[:,range_subset], ranges_pathogenic, ranges_benign, C
    


def get_fit_prior(fit, scoreset, benign_method, **kwargs):
    """
    Estimate the prior of one fit by fixed-point (EM) iteration.

    The population sample (column 2 of `scoreset._sample_assignments`) is
    scored under the pathogenic and benign weight vectors of the fit. Starting
    from 0.5, the prior is repeatedly replaced by the NaN-ignoring mean of the
    per-variant posteriors.

    Args:
        fit (dict): Fit record providing `fit['fit']['component_params']` and
            `fit['fit']['weights']`.
        scoreset: Object exposing `scores` and `_sample_assignments`.
        benign_method (str): 'synonymous' forces weight index 3 as benign,
            'avg' averages weight vectors 1 and 3, anything else uses
            `benign_idx`.
    Optional Keyword Args:
        pathogenic_idx (int): Index of the pathogenic weights. Defaults to 0.
        benign_idx (int): Index of the benign weights. Defaults to 1.
        max_em_steps (int): Upper bound on iterations. Defaults to 10000.
        tolerance (float): Stop once the prior changes by at most this much
            between iterations. Defaults to 0.0, i.e. stop only at an exact
            fixed point, which yields the same value as running all
            `max_em_steps` iterations.

    Returns:
        The final prior estimate (NaN when no posterior is finite).

    Raises:
        ValueError: If an iteration yields a prior outside [0, 1].
    """
    pathogenic_idx = kwargs.get('pathogenic_idx',0)
    benign_idx = kwargs.get('benign_idx',1)
    if benign_method == 'synonymous':
        benign_idx = 3
    params = fit['fit']['component_params']
    weights = fit['fit']['weights']
    population = scoreset.scores[scoreset._sample_assignments[:,2]]
    pathogenic_density = density_utils.joint_densities(population,
                                                       params,
                                                       weights[pathogenic_idx]).sum(axis=0)
    if benign_method != 'avg':
        benign_weights = weights[benign_idx]
    else:
        benign_weights = (np.array(weights[1])+np.array(weights[3]))/2
    benign_density = density_utils.joint_densities(population,
                                                   params,
                                                   benign_weights).sum(axis=0)

    assert len(pathogenic_density) == len(population)
    assert len(benign_density) == len(population)

    max_em_steps = kwargs.get("max_em_steps",10000)
    tolerance = kwargs.get("tolerance", 0.0)
    prior_estimate = 0.5
    em_steps = 0
    while True:
        em_steps += 1

        with np.errstate(divide='ignore', invalid='ignore', over='ignore',under='ignore'):
            posteriors = 1 / (
                1
                + (1 - prior_estimate)
                / prior_estimate
                * benign_density # type: ignore
                / pathogenic_density
            )
        new_prior = np.nanmean(posteriors)
        if new_prior < 0 or new_prior > 1:
            raise ValueError(f"Invalid prior estimate obtained, {new_prior}")
        # The update is deterministic, so once it reproduces its input (or
        # turns NaN) every further iteration returns the same value.
        converged = bool(np.isnan(new_prior)) or abs(new_prior - prior_estimate) <= tolerance
        prior_estimate = new_prior
        if converged or em_steps >= max_em_steps:
            break
    return prior_estimate
    
def plot_scoreset(scoreset:Scoreset, summary: Dict, scoreset_fits: List[Dict], score_range, use_median_prior,use_2c_equation, n_c, benign_method, C):
    """
    Draw the per-sample fit figure for one scoreset.

    The figure has two rows and one column per sample. The bottom row shows
    the score histogram of the sample, the bootstrap-median density of every
    mixture component and the median / 5-95 percentile band of the summed
    density. The top row shows the assigned point ranges (red for positive
    point values, blue for negative ones) and is identical in every column.

    `benign_method` and `C` are accepted but not used here. The title of the
    third column is replaced by a scoreset-level title, so at least three
    samples are required.

    Returns:
        The matplotlib figure.
    """
    n_samples = scoreset.n_samples
    fig, ax = plt.subplots(2, n_samples, figsize=(5*n_samples,10), sharex=True, sharey=False)

    for sample_num in range(n_samples):
        fit_ax = ax[1,sample_num]
        members = scoreset.sample_assignments[:,sample_num]
        sns.histplot(scoreset.scores[members], stat='density', ax=fit_ax, alpha=.5, color='pink')

        density = sample_density(score_range, scoreset_fits, sample_num)
        for comp_num in range(density.shape[1]):
            # Row 1 of the percentile stack is the median over bootstraps.
            comp_percentiles = np.nanpercentile(density[:,comp_num,:],[5,50,95],axis=0)
            fit_ax.plot(score_range, comp_percentiles[1], color=f"C{comp_num}", linestyle='--', label=f"Comp {comp_num+1}")
            fit_ax.legend()

        # Sum over components, then summarize over bootstraps.
        band_low, band_mid, band_high = np.percentile(np.nansum(density,axis=1),[5,50,95],axis=0)
        fit_ax.plot(score_range, band_mid, color='black', alpha=.5)
        fit_ax.fill_between(score_range, band_low, band_high, color='gray', alpha=0.3)
        fit_ax.set_xlabel("Score")
        ax[0,sample_num].set_title(f"{scoreset.sample_names[sample_num]} (n={scoreset.sample_assignments[:,sample_num].sum():,d})")

    point_ranges = sorted([(int(key), ranges) for key, ranges in summary['point_ranges'].items()])
    point_values = [point_value for point_value, _ in point_ranges]
    print(point_ranges)

    tick_labels = [f"{i:+d}" if i!=0 else "0" for i in point_values]
    for points_ax in ax[0]:
        for row, (point_value, score_ranges) in enumerate(point_ranges):
            colour = 'red' if point_value > 0 else 'blue'
            for sr in score_ranges:
                points_ax.plot([sr[0], sr[1]], [row,row], color=colour, linestyle='-', alpha=0.7)
        points_ax.set_ylim(-1,len(point_values))
        points_ax.set_ylabel("Points")
        points_ax.set_yticks(range(len(point_values)),labels=tick_labels)

    ax[0,2].set_title(f"{scoreset.scoreset_name} ({n_c}, median:{use_median_prior},em:{not use_2c_equation}): (gnomAD pop, n={scoreset.sample_assignments[:,2].sum():,d})\nprior {summary['prior']:.3f}, C: {summary['C']}")
    return fig

def _compare_is_avg(key):
    """Return True when a summary key belongs to an 'avg' configuration."""
    label = key[0]
    # Besides the key itself, look inside the configuration label: a label
    # such as "(em, median, avg)" is never equal to the bare string 'avg'.
    return 'avg' in key or (isinstance(label, (str, tuple, list)) and 'avg' in label)


def _compare_configs(summary, n_c):
    """Return the summary keys for one component count, 'avg' configurations last."""
    keys = [k for k in summary.keys() if k[1] == n_c]
    return sorted(k for k in keys if not _compare_is_avg(k)) + \
           sorted(k for k in keys if _compare_is_avg(k))


def _compare_first_last(mapping):
    """Return the values stored under the first and the last key of a dict."""
    keys = list(mapping.keys())
    return mapping[keys[0]], mapping[keys[-1]]


def _compare_plot_fits(axes_row, scoreset, fits, score_range, label, max_samples):
    """Draw histogram and bootstrap fit per sample; return the x-limits of the first axis."""
    n_samples = scoreset.n_samples
    for sample_num in range(n_samples):
        ax_fit = axes_row[sample_num]
        members = scoreset.sample_assignments[:,sample_num]
        sns.histplot(scoreset.scores[members], stat='density', ax=ax_fit, alpha=.5, color='pink')

        density = sample_density(score_range, fits, sample_num)
        for comp_num in range(density.shape[1]):
            d = np.nanpercentile(density[:,comp_num,:],[5,50,95],axis=0)
            ax_fit.plot(score_range, d[1], color=f"C{comp_num}", linestyle='--', label=f"Comp {comp_num+1}")
        ax_fit.legend(fontsize=8)

        # Sum over components, then summarize over bootstraps.
        d_perc = np.percentile(np.nansum(density, axis=1), [5,50,95], axis=0)
        ax_fit.plot(score_range, d_perc[1], color='black', alpha=.5)
        ax_fit.fill_between(score_range, d_perc[0], d_perc[2], color='gray', alpha=0.3)
        ax_fit.set_title(f"{label}: {scoreset.sample_names[sample_num]}\n(n={members.sum():,d})")
        ax_fit.set_xlabel("Score")
        ax_fit.set_ylabel("Density")
        ax_fit.grid(linewidth=0.5, alpha=0.3)

    # Hide sample columns this scoreset does not use.
    for unused in axes_row[n_samples:max_samples]:
        unused.axis('off')
    return axes_row[0].get_xlim()


def _compare_plot_lr(axes_row, summary, configs, score_range, label, xlim, first_col):
    """Draw one LR+ panel per configuration, starting at column `first_col`."""
    for col in range(first_col):
        axes_row[col].axis('off')

    for offset, key in enumerate(configs):
        ax_lr = axes_row[first_col + offset]
        entry = summary[key]

        llr_curves = np.nanpercentile(np.array(entry['log_lr_plus']),[5,50,95],axis=0)
        curve_styles = zip(llr_curves, ['red','black','blue'], ['5th percentile','Median','95th percentile'])
        for curve, colour, curve_label in curve_styles:
            ax_lr.plot(score_range, curve, color=colour, label=curve_label)

        point_values = sorted({abs(int(k)) for k in entry['point_ranges'].keys()})
        tauP, tauB, _ = list(map(np.log, thresholds_from_prior(entry['prior'], point_values)))
        # NaN-aware so that bootstraps without a prior do not blank the title.
        priors = np.nanpercentile(np.array(entry['priors']),[5,50,95])

        ax_lr.set_title(f"{label} LR+ {key[0]}\nprior: {priors[1]:.3f} ({priors[0]:.3f}-{priors[2]:.3f}), C: {entry['C']}", fontsize=10)
        add_thresholds(tauP, tauB, ax_lr)
        ax_lr.set_xlabel("Score")
        ax_lr.set_ylabel("Log LR+")
        ax_lr.legend(fontsize=6, loc='best')
        ax_lr.set_xlim(xlim)
        ax_lr.grid(linewidth=0.5, alpha=0.3)

    for unused in axes_row[first_col + len(configs):]:
        unused.axis('off')


def _compare_plot_points(axes_row, summary, configs, label, xlim, first_col):
    """Draw one point-assignment panel per configuration, starting at column `first_col`."""
    for offset, key in enumerate(configs):
        ax_points = axes_row[first_col + offset]

        point_ranges = sorted([(int(k), v) for k,v in summary[key]['point_ranges'].items()])
        point_values = [pr[0] for pr in point_ranges]

        # The ranges do not depend on the sample, so each one is drawn once.
        for row, (point_value, score_ranges) in enumerate(point_ranges):
            colour = 'red' if point_value > 0 else 'blue'
            for sr in score_ranges:
                ax_points.plot([sr[0], sr[1]], [row, row],
                               color=colour, linestyle='-', alpha=0.7, linewidth=2)

        ax_points.set_ylim(-1, len(point_values))
        ax_points.set_yticks(range(len(point_values)),
                             labels=[f"{i:+d}" if i!=0 else "0" for i in point_values])
        ax_points.set_xlabel("Score")
        ax_points.set_ylabel("Points")
        ax_points.set_title(f"{label} Points {key[0]}", fontsize=10)
        ax_points.set_xlim(xlim)
        ax_points.grid(linewidth=0.5, alpha=0.3)

    for unused in axes_row[first_col + len(configs):]:
        unused.axis('off')


def plot_scoreset_compare_point_assignments(dataset, scoresets, summary, scoreset_fits, score_ranges, n_samples):
    """
    Compare 2c and 3c fits, LR+ curves and point assignments in one figure.

    Layout (4 rows): 3c LR+ curves, 3c fits and point ranges, 2c fits and
    point ranges, 2c LR+ curves. The left-hand columns hold one panel per
    sample, the right-hand columns one panel per configuration.

    Args:
        dataset: Unused; kept for call compatibility.
        scoresets (dict): The first entry is taken as the 2c scoreset, the last
            entry as the 3c scoreset.
        summary (dict): Keyed by `(config, n_c)` with `n_c` equal to '2c' or
            '3c'; each value provides 'log_lr_plus', 'point_ranges', 'prior',
            'priors' and 'C'.
        scoreset_fits (dict): First entry = 2c fits, last entry = 3c fits.
        score_ranges (dict): First entry = 2c score grid, last = 3c score grid.
        n_samples: Unused; the sample counts are read from the scoresets.

    Returns:
        The matplotlib figure.
    """
    scoreset_2c, scoreset_3c = _compare_first_last(scoresets)
    score_range_2c, score_range_3c = _compare_first_last(score_ranges)
    fits_2c, fits_3c = _compare_first_last(scoreset_fits)

    # Sorted consistently, with the 'avg' configurations placed last.
    configs_2c = _compare_configs(summary, '2c')
    configs_3c = _compare_configs(summary, '3c')

    max_samples = max(scoreset_2c.n_samples, scoreset_3c.n_samples)
    max_configs = max(len(configs_2c), len(configs_3c))
    n_cols_total = max_samples + max_configs

    fig, ax = plt.subplots(4, n_cols_total, figsize=(5*n_cols_total, 20),
                           squeeze=False, gridspec_kw={'hspace': 0.3, 'wspace': 0.3})

    # Rows 0-1: 3c. The fits go first because their x-limits are reused by the
    # LR+ and point panels.
    xlim_3c = _compare_plot_fits(ax[1], scoreset_3c, fits_3c, score_range_3c, "3c", max_samples)
    _compare_plot_lr(ax[0], summary, configs_3c, score_range_3c, "3c", xlim_3c, max_samples)
    _compare_plot_points(ax[1], summary, configs_3c, "3c", xlim_3c, max_samples)

    # Rows 2-3: 2c.
    xlim_2c = _compare_plot_fits(ax[2], scoreset_2c, fits_2c, score_range_2c, "2c", max_samples)
    _compare_plot_points(ax[2], summary, configs_2c, "2c", xlim_2c, max_samples)
    _compare_plot_lr(ax[3], summary, configs_2c, score_range_2c, "2c", xlim_2c, max_samples)

    fig.suptitle(f"{scoreset_2c.scoreset_name}", fontsize=16, y=0.995)

    return fig

def sample_density(x, fits, sampleNum):
    """Evaluate each fit's joint densities for one sample, NaN outside the fit's range.

    Args:
        x: Score grid passed to ``density_utils.joint_densities``; it is also
            compared element-wise against each fit's ``xlims``.
        fits: Sequence of fit records. Each record is read at
            ``['fit']['component_params']``, ``['fit']['weights'][sampleNum]``
            and ``['fit']['xlims']``.
        sampleNum: Index selecting which sample's weights are used.

    Returns:
        Array with the same shape as the stacked per-fit densities (fit index
        first, score-grid positions last). Entries whose score lies outside
        ``[xmin, xmax]`` of the corresponding fit are NaN.
    """
    per_fit = [
        density_utils.joint_densities(
            x,
            entry['fit']['component_params'],
            entry['fit']['weights'][sampleNum],
        )
        for entry in fits
    ]
    stacked = np.stack(per_fit)

    # Start from all-NaN and copy over only the in-range columns of each fit.
    density = np.full(stacked.shape, np.nan)
    for row, entry in enumerate(fits):
        lower, upper = entry['fit']['xlims']
        in_range = (x >= lower) & (x <= upper)
        # density[row] is a view, so this writes through to `density`.
        density[row][:, in_range] = stacked[row][:, in_range]
    return density

def add_thresholds(tauP, tauB, ax):
    """Draw paired horizontal threshold lines on ``ax``.

    The two sequences are walked in lockstep (stopping at the shorter one).
    For each pair, the ``tauP`` value is drawn as a red dashed line and then
    the ``tauB`` value as a blue dashed line, both at 50% opacity.

    Args:
        tauP: Iterable of y-values drawn in red.
        tauB: Iterable of y-values drawn in blue.
        ax: Object exposing ``axhline`` (e.g. a matplotlib Axes).
    """
    dashed = dict(linestyle='--', alpha=0.5)
    for red_level, blue_level in zip(tauP, tauB):
        for level, colour in ((red_level, 'red'), (blue_level, 'blue')):
            ax.axhline(level, color=colour, **dashed)

def plot_summary(scoreset: Scoreset, fits: List[Dict], summary:Dict, score_range, log_fp, log_fb, use_median_prior,use_2c_equation, n_c, benign_method, C):
    """Plot the 5th/50th/95th percentile log LR+ curves with point thresholds.

    Args:
        scoreset: Scoreset being summarised; its ``scoreset_name`` is shown in
            the title.
        fits: Not used by this function; kept for signature compatibility.
        summary: Dict read at ``'point_ranges'`` (keys are converted with
            ``abs(int(k))`` to get the point values), ``'prior'`` (passed to
            ``thresholds_from_prior``) and ``'priors'`` (percentiled for the
            title).
        score_range: x-values the curves are plotted against.
        log_fp: Log-density values for the numerator of LR+.
        log_fb: Log-density values for the denominator of LR+; must broadcast
            against ``log_fp``. Percentiles are taken along axis 0 of the
            difference (presumably one row per bootstrap -- TODO confirm).
        use_median_prior: Shown in the title as ``median:<value>``.
        use_2c_equation: Shown negated in the title as ``em:<value>``.
        n_c: Component-count label shown in the title.
        benign_method: Not used by this function; kept for signature
            compatibility.
        C: Value shown in the title as ``C: <value>``.

    Returns:
        The created matplotlib figure.
    """
    fig, ax = plt.subplots(1,1, figsize=(5,5))

    # asarray lets callers pass plain lists as well as arrays.
    log_lr_plus = np.asarray(log_fp) - np.asarray(log_fb)
    llr_curves = np.nanpercentile(log_lr_plus,[5,50,95],axis=0)
    labels = ['5th percentile','Median','95th percentile']
    for curve, colour, label in zip(llr_curves, ['red','black','blue'], labels):
        ax.plot(score_range,curve,color=colour,label=label)

    # Point magnitudes come from the (possibly string, signed) point_ranges keys.
    point_values = sorted({abs(int(k)) for k in summary['point_ranges'].keys()})
    tauP,tauB,_ = list(map(np.log, thresholds_from_prior(summary['prior'],point_values)) )
    priors = np.percentile(np.array(summary['priors']),[5,50,95])

    # Use the scoreset's own name rather than a notebook-level `dataset`
    # global, so the function works (and labels correctly) outside the loop.
    ax.set_title(f"{scoreset.scoreset_name} ({n_c}, median:{use_median_prior},em:{not use_2c_equation}): prior: {priors[1]:.3f} ({priors[0]:.3f}-{priors[2]:.3f}), C: {C}")
    add_thresholds(tauP, tauB, ax)
    ax.set_xlabel("Score")
    ax.set_ylabel("Log LR+")
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    return fig

In [2]:
# save_dir = f'/data/ross/assay_calibration/point_assignments/{"un" if not constrained else ""}constrained_pngs'

In [3]:
import gzip

constrained = True

# Archives to read for the selected mode. When more than one is listed they
# are merged in order, so keys from a later archive replace earlier ones.
if constrained:
    _archive_names = (
        'initial_datasets_results_1000bootstraps_100fits',
        'clinvar_circ_datasets_results_1000bootstraps_100fits',
    )
else:
    _archive_names = (
        'unconstrained_rerun_initial_datasets_results_1000bootstraps_100fits',
    )

results = {}
for results_name in _archive_names:
    with gzip.open(f'/data/ross/assay_calibration/{results_name}.json.gz', 'rt', encoding='utf-8') as f:
        results = {**results, **json.load(f)}

# Datasets flagged as urgent.
urgent_datasets = (
    'VHL_Buckley_2024',
    'XRCC2_unpublished',
    'BARD1_unpublished',
)

In [27]:
import pickle
import traceback

start = False
# Datasets are visited in the iteration order of `results`. NOTE(review):
# `urgent_datasets` is not consulted here, so no prioritisation happens.
for dataset in results:

    # Temporary filter: restrict the run to a single dataset.
    if dataset != 'ASPA_Grønbæk-Thygesen_2024_abundance':
        continue

    # Resume-from-dataset switch (disabled):
    # if dataset == 'BARD1_unpublished':
    #     start = True
    #     continue
    # if not start:
    #     continue

    dataset_f = f"/data/ross/assay_calibration/scoresets/{dataset}.json"
    scoreset = Scoreset.from_json(dataset_f)
    n_samples = len(list(scoreset.samples))

    save_dir = f'/data/ross/assay_calibration/point_assignment_comparison/{dataset}'
    os.makedirs(save_dir, exist_ok=True)

    # Per-configuration results, all keyed by (config, n_c).
    scoresets_dict, summary_dict, scoreset_fits_dict, score_ranges_dict = {},{},{},{}

    for n_c in ('2c','3c'):

        # One entry per bootstrap: the fit results for this component count.
        boot_results = [boot[n_c] for boot in results[dataset].values()]
        # boot_results = boot_results[:20] # for speed of dev

        for benign_method in ['benign', 'avg']:
            for use_median_prior, use_2c_equation in ((False, False), (True, False), (False, True)):

                # only do synonymous analysis if synonymous exists and only for one config
                if benign_method != 'benign' and (not use_median_prior or use_2c_equation or n_samples != 4):
                    continue

                if n_c == '3c' and use_2c_equation: # N/A
                    continue

                experiment_code = f'{dataset}_{n_c}_{"median" if use_median_prior else "5-percentile"}_{"equation" if use_2c_equation else "em"}{"_"+benign_method if benign_method != "benign" else ""}'

                save_filepath = f'{save_dir}/{experiment_code}.json'
                pkl_filepath = f'{save_dir}/{experiment_code}.pkl'

                if os.path.exists(pkl_filepath):
                    print(f'{pkl_filepath} exists')
                    # NOTE: pickle executes arbitrary code on load; only read
                    # cache files this notebook wrote itself.
                    with open(pkl_filepath,'rb') as f:
                        # The cached n_c is discarded so the loop variable is
                        # not rebound (mirrors the summarize_scoreset call below).
                        # NOTE(review): `scoreset` IS rebound here and by the
                        # call below, and is carried into later configurations
                        # -- confirm that is intended.
                        scoreset, indv_summary, fits, score_range, config, _ = pickle.load(f)

                    if not config.startswith('(e'): # old config, swap
                        config = f"({'equation' if use_2c_equation else 'em'}, {'median' if use_median_prior else '5-percentile'}, {benign_method})"

                    # Seed the dicts from the cache so a failed rerun below
                    # still leaves a usable entry for this configuration.
                    if 'C' in indv_summary:
                        scoresets_dict[(config,n_c)] = scoreset
                        summary_dict[(config,n_c)] = indv_summary
                        scoreset_fits_dict[(config,n_c)] = fits
                        score_ranges_dict[(config,n_c)] = score_range
                        # continue ### rerun everything

                print(f'Starting {experiment_code}...')
                try:

                    scoreset, indv_summary, fits, score_range, config, _ = summarize_scoreset(boot_results,scoreset,save_filepath,use_median_prior,use_2c_equation,n_c,benign_method)
                    # Drop the per-bootstrap range lists before storing/pickling.
                    del indv_summary['all_path_ranges']
                    del indv_summary['all_ben_ranges']

                    scoresets_dict[(config,n_c)] = scoreset
                    summary_dict[(config,n_c)] = indv_summary
                    scoreset_fits_dict[(config,n_c)] = fits
                    score_ranges_dict[(config,n_c)] = score_range

                    with open(pkl_filepath, 'wb') as f:
                        pickle.dump((scoreset, indv_summary, fits, score_range, config, n_c), f)

                except Exception as e:
                    # Best effort: one failing configuration must not abort
                    # the rest, but keep the traceback so it can be diagnosed.
                    print(f'{experiment_code} failed: {e}')
                    traceback.print_exc()


    point_comparison_figure = plot_scoreset_compare_point_assignments(dataset, scoresets_dict, summary_dict, scoreset_fits_dict, score_ranges_dict, n_samples)

    figure_filepath = f"{save_dir}/{dataset}_point_comparison.png"
    point_comparison_figure.savefig(figure_filepath,bbox_inches='tight',dpi=300)
    plt.close(point_comparison_figure)

/data/ross/assay_calibration/point_assignment_comparison/ASPA_Grønbæk-Thygesen_2024_abundance/ASPA_Grønbæk-Thygesen_2024_abundance_2c_5-percentile_em.pkl exists
Starting ASPA_Grønbæk-Thygesen_2024_abundance_2c_5-percentile_em...
loading priors from cached em
getting point ranges for each bootstrap...


[Parallel(n_jobs=72)]: Using backend LokyBackend with 72 concurrent workers.
[Parallel(n_jobs=72)]: Done   1 tasks      | elapsed:    2.6s
[Parallel(n_jobs=72)]: Done  18 tasks      | elapsed:    2.7s
[Parallel(n_jobs=72)]: Done  37 tasks      | elapsed:    2.8s
[Parallel(n_jobs=72)]: Done  56 tasks      | elapsed:    3.0s
[Parallel(n_jobs=72)]: Done  77 tasks      | elapsed:    4.2s
[Parallel(n_jobs=72)]: Done  98 tasks      | elapsed:    4.3s
[Parallel(n_jobs=72)]: Done 121 tasks      | elapsed:    4.5s
[Parallel(n_jobs=72)]: Done 144 tasks      | elapsed:    5.5s
[Parallel(n_jobs=72)]: Done 169 tasks      | elapsed:    5.8s
[Parallel(n_jobs=72)]: Done 194 tasks      | elapsed:    6.1s
[Parallel(n_jobs=72)]: Done 221 tasks      | elapsed:    7.0s
[Parallel(n_jobs=72)]: Done 248 tasks      | elapsed:    7.3s
[Parallel(n_jobs=72)]: Done 277 tasks      | elapsed:    7.9s
[Parallel(n_jobs=72)]: Done 306 tasks      | elapsed:    8.6s
[Parallel(n_jobs=72)]: Done 337 tasks      | elapsed:  

ranges_p bootstrap 0, score +1: [nan]
using 5-percentile to get conservative thresholds...
boot prior: 0.0 - 1.0
1 nan bootstrap points: 128
2 nan bootstrap points: 128
3 nan bootstrap points: 128
4 nan bootstrap points: 128
5 nan bootstrap points: 128
6 nan bootstrap points: 128
7 nan bootstrap points: 128
8 nan bootstrap points: 128
conservative_thresholds {1: np.float64(nan), -1: np.float64(nan), 2: np.float64(nan), -2: np.float64(nan), 3: np.float64(nan), -3: np.float64(nan), 4: np.float64(nan), -4: np.float64(nan), 5: np.float64(nan), -5: np.float64(nan), 6: np.float64(nan), -6: np.float64(nan), 7: np.float64(nan), -7: np.float64(nan), 8: np.float64(nan), -8: np.float64(nan)}
point_ranges {1: [], -1: [], 2: [], -2: [], 3: [], -3: [], 4: [], -4: [], 5: [], -5: [], 6: [], -6: [], 7: [], -7: [], 8: [], -8: []}


  diff_b_a = subtract(b, a)


[(-8, []), (-7, []), (-6, []), (-5, []), (-4, []), (-3, []), (-2, []), (-1, []), (1, []), (2, []), (3, []), (4, []), (5, []), (6, []), (7, []), (8, [])]
/data/ross/assay_calibration/point_assignment_comparison/ASPA_Grønbæk-Thygesen_2024_abundance/ASPA_Grønbæk-Thygesen_2024_abundance_2c_median_em.pkl exists
Starting ASPA_Grønbæk-Thygesen_2024_abundance_2c_median_em...
loading priors from cached em
getting point ranges for each bootstrap...


[Parallel(n_jobs=72)]: Using backend LokyBackend with 72 concurrent workers.
[Parallel(n_jobs=72)]: Done   1 tasks      | elapsed:    1.7s
[Parallel(n_jobs=72)]: Done  18 tasks      | elapsed:    1.8s
[Parallel(n_jobs=72)]: Done  37 tasks      | elapsed:    1.8s
[Parallel(n_jobs=72)]: Done  56 tasks      | elapsed:    1.9s
[Parallel(n_jobs=72)]: Done  77 tasks      | elapsed:    3.1s
[Parallel(n_jobs=72)]: Done  98 tasks      | elapsed:    3.2s
[Parallel(n_jobs=72)]: Done 121 tasks      | elapsed:    3.4s
[Parallel(n_jobs=72)]: Done 144 tasks      | elapsed:    4.3s
[Parallel(n_jobs=72)]: Done 169 tasks      | elapsed:    4.6s
[Parallel(n_jobs=72)]: Done 194 tasks      | elapsed:    4.9s
[Parallel(n_jobs=72)]: Done 221 tasks      | elapsed:    5.9s
[Parallel(n_jobs=72)]: Done 248 tasks      | elapsed:    6.1s
[Parallel(n_jobs=72)]: Done 277 tasks      | elapsed:    6.6s
[Parallel(n_jobs=72)]: Done 306 tasks      | elapsed:    7.4s
[Parallel(n_jobs=72)]: Done 337 tasks      | elapsed:  

ranges_p bootstrap 0, score +1: [nan]
using median prior to get unified thresholds...
[(-8, []), (-7, []), (-6, []), (-5, []), (-4, []), (-3, []), (-2, []), (-1, []), (1, []), (2, []), (3, []), (4, []), (5, []), (6, []), (7, []), (8, [])]
/data/ross/assay_calibration/point_assignment_comparison/ASPA_Grønbæk-Thygesen_2024_abundance/ASPA_Grønbæk-Thygesen_2024_abundance_2c_5-percentile_equation.pkl exists
Starting ASPA_Grønbæk-Thygesen_2024_abundance_2c_5-percentile_equation...
loading priors from cached equation
getting point ranges for each bootstrap...


[Parallel(n_jobs=72)]: Using backend LokyBackend with 72 concurrent workers.
[Parallel(n_jobs=72)]: Done   1 tasks      | elapsed:    1.4s
[Parallel(n_jobs=72)]: Done  18 tasks      | elapsed:    1.5s
[Parallel(n_jobs=72)]: Done  37 tasks      | elapsed:    1.6s
[Parallel(n_jobs=72)]: Done  56 tasks      | elapsed:    1.7s
[Parallel(n_jobs=72)]: Done  77 tasks      | elapsed:    2.8s
[Parallel(n_jobs=72)]: Done  98 tasks      | elapsed:    3.0s
[Parallel(n_jobs=72)]: Done 121 tasks      | elapsed:    3.2s
[Parallel(n_jobs=72)]: Done 144 tasks      | elapsed:    3.7s
[Parallel(n_jobs=72)]: Done 169 tasks      | elapsed:    4.4s
[Parallel(n_jobs=72)]: Done 194 tasks      | elapsed:    4.7s
[Parallel(n_jobs=72)]: Done 221 tasks      | elapsed:    5.6s
[Parallel(n_jobs=72)]: Done 248 tasks      | elapsed:    5.9s
[Parallel(n_jobs=72)]: Done 277 tasks      | elapsed:    6.5s
[Parallel(n_jobs=72)]: Done 306 tasks      | elapsed:    7.2s
[Parallel(n_jobs=72)]: Done 337 tasks      | elapsed:  

ranges_p bootstrap 0, score +1: [nan]
using 5-percentile to get conservative thresholds...
boot prior: -579.1938730023005 - 48.12336582343781
1 nan bootstrap points: 187
2 nan bootstrap points: 187
3 nan bootstrap points: 187
4 nan bootstrap points: 187
5 nan bootstrap points: 187
6 nan bootstrap points: 187
7 nan bootstrap points: 187
8 nan bootstrap points: 187
conservative_thresholds {1: np.float64(0.1285883628362836), -1: np.float64(nan), 2: np.float64(0.11961321332133215), -2: np.float64(nan), 3: np.float64(0.10673923792379239), -3: np.float64(nan), 4: np.float64(nan), -4: np.float64(nan), 5: np.float64(nan), -5: np.float64(nan), 6: np.float64(nan), -6: np.float64(nan), 7: np.float64(nan), -7: np.float64(nan), 8: np.float64(nan), -8: np.float64(nan)}
point_ranges {1: [[np.float64(0.11961321332133215), np.float64(0.1285883628362836)]], -1: [], 2: [[np.float64(0.10673923792379239), np.float64(0.11961321332133215)]], -2: [], 3: [[np.float64(-0.0304), np.float64(0.10673923792379239)]]

  diff_b_a = subtract(b, a)


[(-8, []), (-7, []), (-6, []), (-5, []), (-4, []), (-3, []), (-2, []), (-1, []), (1, [[0.11961321332133215, 0.1285883628362836]]), (2, [[0.10673923792379239, 0.11961321332133215]]), (3, [[-0.0304, 0.10673923792379239]]), (4, []), (5, []), (6, []), (7, []), (8, [])]
/data/ross/assay_calibration/point_assignment_comparison/ASPA_Grønbæk-Thygesen_2024_abundance/ASPA_Grønbæk-Thygesen_2024_abundance_3c_5-percentile_em.pkl exists
Starting ASPA_Grønbæk-Thygesen_2024_abundance_3c_5-percentile_em...
loading priors from cached em
getting point ranges for each bootstrap...


[Parallel(n_jobs=72)]: Using backend LokyBackend with 72 concurrent workers.
[Parallel(n_jobs=72)]: Done   1 tasks      | elapsed:    1.3s
[Parallel(n_jobs=72)]: Done  18 tasks      | elapsed:    1.5s
[Parallel(n_jobs=72)]: Done  37 tasks      | elapsed:    1.5s
[Parallel(n_jobs=72)]: Done  56 tasks      | elapsed:    1.7s
[Parallel(n_jobs=72)]: Done  77 tasks      | elapsed:    2.8s
[Parallel(n_jobs=72)]: Done  98 tasks      | elapsed:    2.9s
[Parallel(n_jobs=72)]: Done 121 tasks      | elapsed:    3.1s
[Parallel(n_jobs=72)]: Done 144 tasks      | elapsed:    3.8s
[Parallel(n_jobs=72)]: Done 169 tasks      | elapsed:    4.3s
[Parallel(n_jobs=72)]: Done 194 tasks      | elapsed:    4.6s
[Parallel(n_jobs=72)]: Done 221 tasks      | elapsed:    5.5s
[Parallel(n_jobs=72)]: Done 248 tasks      | elapsed:    5.9s
[Parallel(n_jobs=72)]: Done 277 tasks      | elapsed:    6.3s
[Parallel(n_jobs=72)]: Done 306 tasks      | elapsed:    7.1s
[Parallel(n_jobs=72)]: Done 337 tasks      | elapsed:  

ranges_p bootstrap 0, score +1: [0.07779896 0.07884562]
using 5-percentile to get conservative thresholds...
boot prior: 0.0 - 1.0
1 nan bootstrap points: 16
2 nan bootstrap points: 16
3 nan bootstrap points: 16
4 nan bootstrap points: 16
5 nan bootstrap points: 16
6 nan bootstrap points: 16
7 nan bootstrap points: 16
8 nan bootstrap points: 16
conservative_thresholds {1: np.float64(0.07851200020002001), -1: np.float64(0.7877188948894889), 2: np.float64(0.07539817281728174), -2: np.float64(nan), 3: np.float64(0.07253946994699471), -3: np.float64(nan), 4: np.float64(0.06920322632263227), -4: np.float64(nan), 5: np.float64(0.0635577787778778), -5: np.float64(nan), 6: np.float64(0.05385650565056506), -6: np.float64(nan), 7: np.float64(0.033623169316931704), -7: np.float64(nan), 8: np.float64(nan), -8: np.float64(nan)}
point_ranges {1: [[np.float64(0.07539817281728174), np.float64(0.07851200020002001)]], -1: [[np.float64(0.7877188948894889), np.float64(1.2778)]], 2: [[np.float64(0.07253946

[Parallel(n_jobs=72)]: Using backend LokyBackend with 72 concurrent workers.
[Parallel(n_jobs=72)]: Done   1 tasks      | elapsed:    1.4s
[Parallel(n_jobs=72)]: Done  18 tasks      | elapsed:    1.5s
[Parallel(n_jobs=72)]: Done  37 tasks      | elapsed:    1.5s
[Parallel(n_jobs=72)]: Done  56 tasks      | elapsed:    1.7s
[Parallel(n_jobs=72)]: Done  77 tasks      | elapsed:    2.8s
[Parallel(n_jobs=72)]: Done  98 tasks      | elapsed:    2.9s
[Parallel(n_jobs=72)]: Done 121 tasks      | elapsed:    3.1s
[Parallel(n_jobs=72)]: Done 144 tasks      | elapsed:    3.9s
[Parallel(n_jobs=72)]: Done 169 tasks      | elapsed:    4.4s
[Parallel(n_jobs=72)]: Done 194 tasks      | elapsed:    4.6s
[Parallel(n_jobs=72)]: Done 221 tasks      | elapsed:    5.6s
[Parallel(n_jobs=72)]: Done 248 tasks      | elapsed:    5.9s
[Parallel(n_jobs=72)]: Done 277 tasks      | elapsed:    6.3s
[Parallel(n_jobs=72)]: Done 306 tasks      | elapsed:    7.2s
[Parallel(n_jobs=72)]: Done 337 tasks      | elapsed:  

ranges_p bootstrap 0, score +1: [0.07779896 0.07884562]
using median prior to get unified thresholds...
[(-8, []), (-7, []), (-6, []), (-5, []), (-4, []), (-3, []), (-2, []), (-1, []), (1, []), (2, []), (3, []), (4, []), (5, []), (6, []), (7, []), (8, [])]


In [None]:
'done'

In [None]:

# [(-8, []), (-7, []), (-6, []), (-5, []), (-4, []), (-3, []), (-2, [[-0.18066820274493, 1.1121633402781734]]), (-1, [[-0.2190387471559836, -0.18066820274493]]), (1, [[-0.9107317610725785, -0.8994764013786694]]), (2, [[-0.92352194254293, -0.9107317610725785]]), (3, [[-0.9409165893426077, -0.92352194254293]]), (4, [[-0.9961701732945247, -0.9409165893426077]]), (5, [[-3.3065885540987723, -0.9961701732945247]]), (6, []), (7, []), (8, [])]