# For each constraint score, assess its performance on a stringent truth set

In [1]:
import matplotlib.pyplot as plt 
# Notebook-wide plotting defaults: larger text, wide figures, Arial font family.
plt.rcParams['font.size'] = 20
plt.rcParams['figure.figsize'] = (12, 6)
plt.rc('font', family='arial')

In [9]:
import importlib
import matplotlib.cm as cm

import sys
import os
# Put the parent directory on the import path (presumably where the local `util` module lives — verify).
sys.path.append(os.path.abspath(os.path.join('..')))

import util
# Reload util every time this cell runs so that edits to it are picked up without restarting the kernel.
importlib.reload(util)
from util import (
  length_to_string, 
  compute_limits as _compute_limits, 
  slice_feature_space,
)

In [2]:
# NOTE(review): hard-coded absolute filesystem paths; this notebook only runs where they exist.
CONSTRAINT_TOOLS = '/scratch/ucgd/lustre-labs/quinlan/u6018199/constraint-tools'
CONSTRAINT_TOOLS_DATA = '/scratch/ucgd/lustre-labs/quinlan/data-shared/constraint-tools'

import sys
# Make the modules under {CONSTRAINT_TOOLS}/utilities importable.
sys.path.append(f'{CONSTRAINT_TOOLS}/utilities')

In [10]:
# Minimum number of rows a feature bin must contain; plot_curves_all_bins skips smaller bins.
THRESHOLD_BIN_SIZE = 500 # 4000, previously 

## Gnocchi

In [4]:
import polars as pl

def get_gnocchi_windows():
    """Read the Gnocchi stringent truth set (tab-separated) and return it as a pandas DataFrame."""
    truth_set_path = f'{CONSTRAINT_TOOLS_DATA}/stringent_truth_sets/gnocchi-truth-set.bed'
    return pl.read_csv(truth_set_path, separator='\t').to_pandas()
    
# Load the truth set once at module scope; the bare name on the last line displays the frame in the notebook.
GNOCCHI_WINDOWS = get_gnocchi_windows()
GNOCCHI_WINDOWS

Unnamed: 0,chromosome,start,end,gnocchi,truly constrained,B,B_M1star.EUR,GC_content_1000bp
0,chr1,1554620,1555020,4.059724,True,0.652,0.108103,0.606394
1,chr1,2128961,2129161,6.530123,True,0.841,0.347981,0.585415
2,chr1,2268561,2268761,5.007183,True,0.847,0.347981,0.602398
3,chr1,2545161,2545361,2.775673,True,0.840,0.347981,0.640360
4,chr1,3208836,3209036,6.070480,True,0.966,0.788536,0.525475
...,...,...,...,...,...,...,...,...
7085,chr10,90560000,90561000,1.536373,False,0.803,0.136005,0.375624
7086,chr11,9337000,9338000,-1.182434,False,0.793,0.394411,0.432567
7087,chr9,103067000,103068000,0.887998,False,0.936,0.116710,0.395604
7088,chr7,34770000,34771000,2.594133,False,0.866,0.524225,0.390609


## lambda_s (Dukler et al.)

In [5]:
def get_lambda_s_windows():
    """Read the lambda_s stringent truth set (tab-separated) and return it as a pandas DataFrame."""
    truth_set_path = f'{CONSTRAINT_TOOLS_DATA}/stringent_truth_sets/lambda_s-truth-set.bed'
    return pl.read_csv(truth_set_path, separator='\t').to_pandas()
  
# Load the truth set once at module scope; the bare name on the last line displays the frame in the notebook.
LAMBDA_S_WINDOWS = get_lambda_s_windows()
LAMBDA_S_WINDOWS

Unnamed: 0,chromosome,start,end,truly constrained,B,B_M1star.EUR,GC_content_1000bp,lambda_s
0,chr1,2128961,2129161,True,0.841,0.347981,0.585415,0.117883
1,chr1,2268561,2268761,True,0.847,0.347981,0.602398,0.115906
2,chr1,6240740,6241540,True,0.872,0.014875,0.548452,0.090827
3,chr1,6483340,6483540,True,0.837,0.014875,0.572428,0.088351
4,chr1,6697340,6697540,True,0.708,0.014875,0.461538,-0.008011
...,...,...,...,...,...,...,...,...
3345,chr2,133387000,133388000,False,0.835,0.281558,0.397602,-0.302237
3346,chr2,182977000,182978000,False,0.664,0.273451,0.383616,0.016217
3347,chr14,28374000,28375000,False,0.838,0.208852,0.323676,0.080208
3348,chr12,67961000,67962000,False,0.925,0.399876,0.369630,-0.026623


## Depletion Rank

In [6]:
def get_halldorsson_windows():
    """Read the depletion-rank stringent truth set (tab-separated) and return it as a pandas DataFrame."""
    truth_set_path = f'{CONSTRAINT_TOOLS_DATA}/stringent_truth_sets/depletion_rank-truth-set.bed'
    return pl.read_csv(truth_set_path, separator='\t').to_pandas()
  
# Load the truth set once at module scope; the bare name on the last line displays the frame in the notebook.
HALLDORSSON_WINDOWS = get_halldorsson_windows()
HALLDORSSON_WINDOWS

Unnamed: 0,chromosome,start,end,truly constrained,B,B_M1star.EUR,GC_content_1000bp,depletion_rank_constraint_score_complement
0,chr1,1554620,1555020,True,0.652,0.108103,0.606394,0.368898
1,chr1,2128961,2129161,True,0.841,0.347981,0.585415,0.899933
2,chr1,2268561,2268761,True,0.847,0.347981,0.602398,0.600846
3,chr1,2545161,2545361,True,0.840,0.347981,0.640360,0.650301
4,chr1,3208836,3209036,True,0.966,0.788536,0.525475,0.982930
...,...,...,...,...,...,...,...,...
7079,chr1,187935800,187936300,False,0.944,0.167226,0.381618,0.382341
7080,chr4,72266800,72267300,False,0.904,0.121811,0.328671,0.106622
7081,chr7,158780450,158780950,False,0.617,0.109041,0.436563,0.036848
7082,chr7,153212750,153213250,False,0.947,0.173963,0.359640,0.128205


## CDTS

In [7]:
def get_CDTS_windows():
    """Read the CDTS stringent truth set (tab-separated) and return it as a pandas DataFrame."""
    truth_set_path = f'{CONSTRAINT_TOOLS_DATA}/stringent_truth_sets/CDTS-truth-set.bed'
    return pl.read_csv(truth_set_path, separator='\t').to_pandas()
  
# Load the truth set once at module scope; the bare name on the last line displays the frame in the notebook.
CDTS_WINDOWS = get_CDTS_windows()
CDTS_WINDOWS

Unnamed: 0,chromosome,start,end,truly constrained,B,B_M1star.EUR,GC_content_1000bp,percentile_rank_of_observed_minus_expected_complement
0,chr1,1554620,1555020,True,0.652,0.108103,0.606394,72.416125
1,chr1,2128961,2129161,True,0.841,0.347981,0.585415,53.544438
2,chr1,2268561,2268761,True,0.847,0.347981,0.602398,47.909374
3,chr1,2545161,2545361,True,0.840,0.347981,0.640360,98.359012
4,chr1,3208836,3209036,True,0.966,0.788536,0.525475,97.164713
...,...,...,...,...,...,...,...,...
6251,chr11,98922329,98922880,False,0.944,0.246060,0.325674,34.542219
6252,chr16,60926505,60927056,False,0.783,0.517240,0.463536,8.388564
6253,chr14,101636322,101636873,False,0.857,0.200026,0.505495,70.191058
6254,chr2,71485009,71485560,False,0.956,0.522355,0.476523,39.155480


## Features that negatively impact a constraint-score-based classifier

In [13]:
import pandas as pd

def downsample(df, group_columns, target, random_state=None):
  """Downsample the positive class so every group has the same positive:negative ratio.

  The smallest positive-to-negative ratio observed in any group is applied to
  all groups: each group keeps all of its negatives and a random sample of
  int(min_ratio * number_of_negatives) positives.

  Parameters
  ----------
  df : pandas.DataFrame
  group_columns : list of column names to group by (rows whose group key is
    NaN are dropped, as with a default groupby).
  target : name of the True/False label column.
  random_state : forwarded to DataFrame.sample for every group; pass a value
    to make the sampling reproducible. The default (None) keeps the sampling
    unseeded.

  Returns
  -------
  pandas.DataFrame with all original columns (including the grouping columns)
  and a fresh 0..n-1 index. Within each group, sampled positives precede negatives.

  Raises
  ------
  ValueError if no finite ratio can be computed (e.g. no group contains a
  negative example).
  """
  # Iterate over the groups explicitly instead of using DataFrameGroupBy.apply:
  # apply's handling of the grouping columns differs across pandas versions,
  # and the grouping columns must survive in the result.
  groups = [
    group
    for _, group in df.groupby(group_columns, observed=True)
    if len(group) > 0
  ]
  if not groups:
    return df.iloc[0:0].reset_index(drop=True)

  positive_class_sizes = pd.Series(
    [(group[target] == True).sum() for group in groups], dtype=float
  )
  negative_class_sizes = pd.Series(
    [(group[target] == False).sum() for group in groups], dtype=float
  )
  # Float division: x/0 gives inf and 0/0 gives NaN; min() skips NaN.
  min_positive_to_negative_ratio = (positive_class_sizes / negative_class_sizes).min()
  if pd.isna(min_positive_to_negative_ratio) or min_positive_to_negative_ratio == float('inf'):
    raise ValueError(
      'cannot downsample: no group has a finite positive-to-negative ratio'
    )

  pieces = []
  for group in groups:
    negative_class = group[group[target] == False]
    positive_class = group[group[target] == True]
    new_positive_class_size = int(min_positive_to_negative_ratio * len(negative_class))
    pieces.append(positive_class.sample(new_positive_class_size, random_state=random_state))
    pieces.append(negative_class)

  return pd.concat(pieces).reset_index(drop=True)

def preprocess(df, feature, target, number_bins=None, bins=None): 
    """Bin `feature` and balance the positive:negative ratio across the bins.

    A `{feature}_bin` column is added (via pandas.cut) to a copy of `df`, and
    the copy is then passed to `downsample`, grouped by that bin column.

    Parameters
    ----------
    df : pandas.DataFrame (not modified).
    feature : name of the column to bin.
    target : name of the True/False label column.
    number_bins : int, number of equal-width bins for pandas.cut.
    bins : explicit bin specification for pandas.cut (e.g. an IntervalIndex).
        Exactly one of `number_bins` and `bins` must be given.

    Raises
    ------
    ValueError if neither or both of `number_bins` and `bins` are given.
    """
    if number_bins is None and bins is None: 
        raise ValueError('must provide either bins or number_bins')
    if number_bins is not None and bins is not None: 
        raise ValueError('must provide only one of number_bins and bins')

    df = df.copy() 
    # https://pandas.pydata.org/docs/reference/api/pandas.cut.html
    binning = number_bins if number_bins is not None else bins
    df[f'{feature}_bin'] = pd.cut(df[f'{feature}'], bins=binning)

    return downsample(
        df, 
        group_columns=[f'{feature}_bin'], 
        target=target
    )

def get_GC_mean_factor(gc_window_size): 
  """Return the mean factor used for GC content at the given window size (1000 or 1000000)."""
  for supported_size, factor in ((1000, 1), (1000000, 0.975)):
    if gc_window_size == supported_size:
      return factor
  raise ValueError(f'invalid GC window size: {gc_window_size}')
                     
def get_GC_std_factor(gc_window_size): 
  """Return the std factor used for GC content; it is 0.3 for both supported window sizes."""
  if gc_window_size in (1000, 1000000):
    return 0.3
  raise ValueError(f'invalid GC window size: {gc_window_size}')
  
def compute_center_limits(df, feature, mean_factor, std_factor):
  """Thin wrapper: forward all arguments, unchanged and positionally, to `_compute_limits`."""
  center_limits = _compute_limits(df, feature, mean_factor, std_factor)
  return center_limits

def compute_GC_tail_limits(gc_window_size):
  """Return the hard-coded (lower, upper) GC-content tail limits for the given window size."""
  for supported_size, tail_limits in ((1000, (0.5, 1.0)), (1000000, (0.440, 1.0))):
    if gc_window_size == supported_size:
      return tail_limits
  raise ValueError(f'invalid GC window size: {gc_window_size}')

def get_GC_feature_lims_label(df, gc_window_size): 
  """Return (column name, center limits, tail limits, display label) for GC content.

  The column name is `GC_content_{gc_window_size}bp`; the center limits are
  computed from `df`, the tail limits are fixed per window size.
  """
  column = f'GC_content_{gc_window_size}bp'
  center_limits = compute_center_limits(
    df, 
    column, 
    mean_factor=get_GC_mean_factor(gc_window_size), 
    std_factor=get_GC_std_factor(gc_window_size)
  )
  tail_limits = compute_GC_tail_limits(gc_window_size)
  label = f'GC_content ({length_to_string(gc_window_size)})'
  return column, center_limits, tail_limits, label

def get_features_and_lims_and_labels(df, gc_window_size, log=False): 
  """List (feature column, center limits, tail limits, label) for each stratifying feature.

  The entries are, in order: GC content at `gc_window_size`, 'B_M1star.EUR'
  (labelled 'gBGC') and 'B' (labelled 'BGS'). With `log=True`, each feature's
  center limits are printed.
  """
  # (column, fixed tail limits, label) for the features other than GC content
  fixed_specs = (
    ('B_M1star.EUR', (0.75, 1.5), 'gBGC'),
    ('B', (0.5, 0.55), 'BGS'),
  )

  entries = [get_GC_feature_lims_label(df, gc_window_size)]
  for column, tail_limits, label in fixed_specs:
    center_limits = compute_center_limits(df, column, mean_factor=1, std_factor=0.3)
    entries.append((column, center_limits, tail_limits, label))

  if log: 
    for entry in entries:
      print(f'{entry[0]}: {entry[1]}')

  return entries

In [15]:
# this is "r" in the baseline-classifier theory: 
def compute_positive_fraction(df, target):
  """Return the fraction of rows of `df` whose `target` value is True (among True/False rows)."""
  counts = df[target].value_counts()
  positives, negatives = counts.get(True, 0), counts.get(False, 0)
  return positives / (negatives + positives)

In [11]:
from sklearn.metrics import precision_recall_curve, roc_curve, auc

def plot_random_classifier(df, target, type_, color, ax):
    """Draw the random-classifier baseline as a dashed horizontal line on `ax`.

    The line sits at the positive-class fraction r for type_ == 'precision'
    and at 1 - r for type_ == 'FDR'; any other `type_` raises ValueError.
    """
    positive_fraction = compute_positive_fraction(df, target)
    if type_ not in ('precision', 'FDR'):
        raise ValueError(f'invalid type: {type_}')
    level = positive_fraction if type_ == 'precision' else 1-positive_fraction
    ax.plot([0, 1], [level, level], linestyle='--', linewidth=3, color=color, label='random classifier') 
  
def plot_pr_curve_single_bin(ax, type_, precision, recall, feature, feature_bin, color):
    """Plot one bin's curve against recall: precision, or 1 - precision when type_ == 'FDR'.

    Raises ValueError for any other `type_`.
    """
    false_discovery_rate = 1 - precision
    if type_ not in ('precision', 'FDR'):
        raise ValueError(f'invalid type: {type_}')
    y_values = precision if type_ == 'precision' else false_discovery_rate
    curve_label = f'{feature} in {feature_bin}'
    ax.plot(recall, y_values, color=color, linestyle='-', label=curve_label) 

def plot_roc_curve_single_bin(ax, fpr, tpr, color, feature, feature_bin):
    """Plot one bin's ROC curve (tpr against fpr) on `ax`, labelled with the feature and bin."""
    curve_label = f'{feature} in {feature_bin}'
    ax.plot(fpr, tpr, color=color, linestyle='-', label=curve_label)

def finish_pr_curve(ax, type_, constraint_score_alias, r, ylim_scale): 
    """Apply legend, axis labels, limits and title to a precision/FDR-vs-recall axis.

    The y-axis runs from 0 to `ylim_scale` times the random-classifier level
    (r for 'precision', 1 - r for 'FDR'). Any other `type_` raises ValueError.
    """
    ax.grid(False)
    ax.legend()
    ax.set_xlabel('Recall')
    if type_ not in ('precision', 'FDR'):
        raise ValueError(f'invalid type: {type_}')
    is_precision = type_ == 'precision'
    ax.set_ylabel('Precision' if is_precision else 'FDR') 
    ax.set_xlim(0, 1)
    upper = ylim_scale*r if is_precision else ylim_scale*(1-r)
    ax.set_ylim((0, upper)) # type: ignore
    ax.set_title(f'{constraint_score_alias}\n')

def finish_roc_curve(ax, constraint_score_alias):
    """Apply legend, FPR/TPR axis labels, unit limits and title to a ROC axis."""
    ax.grid(False)
    ax.legend()
    ax.set_xlabel('FPR')
    ax.set_ylabel('TPR') 
    unit_range = (0, 1)
    ax.set_xlim(*unit_range)
    ax.set_ylim(*unit_range)
    title = f'{constraint_score_alias}\n'
    ax.set_title(title)

def _area_records(constraint_score_alias, sliced, feature, feature_bin, color, precision, recall, fpr, tpr, positive_fraction):
    """Build the three summary rows (PRC, ROC, PRCnorm areas) for one curve."""
    shared = {
        'constraint_score': constraint_score_alias,
        'conditioned_on_complementary_features': sliced,
        'feature_to_stratify_by': feature,
        'feature_bin': feature_bin,
    }
    prc_area = auc(recall, precision)
    areas_by_type = (
        ('PRC', prc_area),
        ('ROC', auc(fpr, tpr)),
        # PRC area divided by the positive-class fraction (the random-classifier precision)
        ('PRCnorm', prc_area/positive_fraction),
    )
    return [
        {**shared, 'area': area, 'area_type': area_type, 'area_color': color}
        for area_type, area in areas_by_type
    ]

def plot_curves_all_bins(df, type_, feature, number_bins, bins, constraint_score, constraint_score_alias, target, ylim_scale, sliced): 
    """Plot precision (or FDR) vs recall for each bin of `feature`, and collect curve areas.

    `df` is first binned and class-balanced by `preprocess`. One curve is drawn
    per bin that has at least THRESHOLD_BIN_SIZE rows, plus a random-classifier
    baseline and a curve for all bins pooled. The figure is shown with plt.show().

    Parameters
    ----------
    df : pandas.DataFrame containing `feature`, `target` and `constraint_score` columns.
    type_ : 'precision' or 'FDR' (what to put on the y-axis).
    feature : column to stratify by.
    number_bins, bins : exactly one must be given; forwarded to `preprocess`.
    constraint_score : column holding the classifier score.
    constraint_score_alias : display name recorded in the output and used as the title.
    target : True/False label column.
    ylim_scale : y-axis upper limit, in units of the random-classifier level.
    sliced : recorded verbatim as 'conditioned_on_complementary_features'.

    Returns
    -------
    list of dicts, three per plotted curve (area_type 'PRC', 'ROC', 'PRCnorm').
    ROC curves are not drawn; only their areas are recorded.

    Raises
    ------
    ValueError if neither or both of `number_bins`/`bins` are given, or if
    `type_` is not 'precision' or 'FDR'.
    """
    # Validate up front, before any figure is created.
    if number_bins is None and bins is None: 
        raise ValueError('must provide either bins or number_bins')
    if number_bins is not None and bins is not None: 
        raise ValueError('must provide only one of number_bins and bins')
    if type_ not in ('precision', 'FDR'):
        raise ValueError(f'invalid type: {type_}')

    df = preprocess(df, feature, target, number_bins, bins)

    _, ax_pr = plt.subplots(figsize=(10, 10))
    facecolor = 0.75
    ax_pr.set_facecolor((facecolor, facecolor, facecolor))  # light-grey axis background

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which newer matplotlib releases no longer provide
    cmap = plt.get_cmap('Blues') # 'Greys', 'inferno' 

    if number_bins is not None: 
        bins = df[f'{feature}_bin'].unique()
    else: 
        number_bins = len(bins)

    list_of_dicts = []

    for i, feature_bin in enumerate(sorted(bins)): 
        df_bin = df[df[f'{feature}_bin'] == feature_bin]
        print(f'{feature_bin}: {len(df_bin)}')
        if len(df_bin) < THRESHOLD_BIN_SIZE: 
            continue
        color = cmap(i / number_bins)
        targets, scores = df_bin[target], df_bin[constraint_score]        
        precision, recall, _ = precision_recall_curve(targets, scores)
        fpr, tpr, _ = roc_curve(targets, scores)
        list_of_dicts.extend(_area_records(
            constraint_score_alias, sliced, feature, feature_bin, color,
            precision, recall, fpr, tpr,
            positive_fraction=compute_positive_fraction(df_bin, target),
        ))
        plot_pr_curve_single_bin(ax_pr, type_, precision, recall, feature, feature_bin, color)

    plot_random_classifier(
        df, 
        target, 
        type_=type_,
        color='black', 
        ax=ax_pr,
    )

    # without breaking down by feature
    targets, scores = df[target], df[constraint_score]
    precision, recall, _ = precision_recall_curve(targets, scores)
    fpr, tpr, _ = roc_curve(targets, scores)
    overall_positive_fraction = compute_positive_fraction(df, target)

    list_of_dicts.extend(_area_records(
        constraint_score_alias, sliced, feature, 'all', 'green',
        precision, recall, fpr, tpr,
        positive_fraction=overall_positive_fraction,
    ))

    ys = precision if type_ == 'precision' else 1 - precision
    ax_pr.plot(recall, ys, color='black', linewidth=3, linestyle='-', label=f'without breaking down by {feature}') 

    finish_pr_curve(ax_pr, type_, constraint_score_alias, overall_positive_fraction, ylim_scale)

    plt.show()

    return list_of_dicts

def plot_curves_with_and_without_slicing(
    df, 
    gc_window_size, 
    feature, 
    number_bins, 
    bins,
    constraint_score, 
    constraint_score_alias,
    ylim_scale,
    target, 
    type_='precision',
): 
    """Run plot_curves_all_bins twice: on `df` as given, then on `df` sliced on the other features.

    The second run restricts `df` (via slice_feature_space) to the center
    limits of every known feature except `feature`. Returns the concatenated
    area records of both runs, unsliced first.
    """
    assert feature in [f for f, _, _, _ in get_features_and_lims_and_labels(df, gc_window_size)]

    def _curves(frame, sliced):
        # Same arguments for both runs; only the frame and the `sliced` flag differ.
        return plot_curves_all_bins(frame, type_, feature, number_bins, bins, constraint_score, constraint_score_alias, target, ylim_scale, sliced=sliced)

    records = []
    records.extend(_curves(df, False))

    complementary_center_lims = [
        (name, center_lims)
        for name, center_lims, _, _ in get_features_and_lims_and_labels(df, gc_window_size)
        if name != feature
    ]
    records.extend(_curves(slice_feature_space(df, complementary_center_lims), True))

    return records

def compute_midpoints_and_enhancer_fraction(df, number_bins, bins, feature, target): 
    """Bin `feature` and return each bin's midpoint and its mean of `target`.

    Parameters
    ----------
    df : pandas.DataFrame (not modified).
    number_bins : int, number of equal-width bins for pandas.cut, or None.
    bins : explicit bin specification for pandas.cut (e.g. an IntervalIndex), or None.
        Exactly one of `number_bins` and `bins` must be given.
    feature : column to bin.
    target : True/False column; its per-bin mean is the fraction of True rows.

    Returns
    -------
    (midpoints, enhancer_fraction): a Series of bin midpoints and a Series of
    per-bin means indexed by bin, both in bin order. Bins without rows are
    kept, with a NaN fraction.

    Raises
    ------
    ValueError if neither or both of `number_bins` and `bins` are given.
    """
    if number_bins is None and bins is None: 
        raise ValueError('must provide either bins or number_bins')
    if number_bins is not None and bins is not None: 
        raise ValueError('must provide only one of number_bins and bins')

    df = df.copy() 
    # https://pandas.pydata.org/docs/reference/api/pandas.cut.html
    binning = number_bins if number_bins is not None else bins
    df[f'{feature}_bin'] = pd.cut(df[f'{feature}'], bins=binning)

    # observed=False is stated explicitly so that empty bins are kept regardless
    # of the pandas version's default for categorical groupers.
    enhancer_fraction = df.groupby(f'{feature}_bin', observed=False)[target].mean()
    midpoints = pd.Series(enhancer_fraction.index).apply(lambda interval: interval.mid)
    return midpoints, enhancer_fraction

def compute_midpoints_and_enhancer_fraction_wrapper(conditioned_on_complementary_features, df, number_bins, bins, feature, target, gc_window_size): 
    """Compute bin midpoints and per-bin target means, optionally after slicing `df`.

    When `conditioned_on_complementary_features` is truthy, `df` is first
    restricted (via slice_feature_space) to the center limits of every known
    feature other than `feature`.
    """
    if not conditioned_on_complementary_features: 
        return compute_midpoints_and_enhancer_fraction(df, number_bins, bins, feature, target)    

    complementary_center_lims = [
        (name, center_lims)
        for name, center_lims, _, _ in get_features_and_lims_and_labels(df, gc_window_size)
        if name != feature
    ]
    df_sliced = slice_feature_space(df, complementary_center_lims)
    return compute_midpoints_and_enhancer_fraction(df_sliced, number_bins, bins, feature, target)    

def line_plot(areas, windows, number_bins, bins, feature, target, gc_window_size):
    """Draw, per conditioning flag and area type, curve area vs feature-bin midpoint.

    Four two-panel figures are shown (conditioned or not x 'PRCnorm'/'ROC').
    The top panel has one line per constraint score taken from `areas` (rows
    with feature_bin == 'all' are ignored) plus the random-classifier level;
    the bottom panel shows the per-bin mean of `target` computed from `windows`.
    """
    enhancer_curves = {
        conditioned: compute_midpoints_and_enhancer_fraction_wrapper(
            conditioned_on_complementary_features=conditioned, 
            df=windows, 
            number_bins=number_bins, 
            bins=bins, 
            feature=feature, 
            target=target, 
            gc_window_size=gc_window_size
        )
        for conditioned in [False, True]
    }

    per_bin_areas = areas[areas['feature_bin'] != 'all']
    for conditioned in [False, True]:
        is_conditioned = per_bin_areas['conditioned_on_complementary_features'] == conditioned
        for area_type in ['PRCnorm', 'ROC']:
            selected = per_bin_areas[is_conditioned & (per_bin_areas['area_type'] == area_type)]

            fig, (area_ax, enhancer_ax) = plt.subplots(2, 1, figsize=(10, 15), sharex=True)
            fig.subplots_adjust(hspace=0.1)

            for score_name in selected['constraint_score'].unique():
                score_rows = selected[selected['constraint_score'] == score_name]
                bin_midpoints = score_rows['feature_bin'].apply(lambda interval: interval.mid).astype(float)
                area_ax.plot(bin_midpoints, score_rows['area'], 'o-', label=score_name, linewidth=3)

            midpoints, enhancer_fraction = enhancer_curves[conditioned]
            enhancer_ax.plot(midpoints, enhancer_fraction, 'o-', linewidth=3)

            # Level of the dashed reference line and y-label for each area type.
            if area_type == 'PRCnorm':
                baseline, ylabel = 1, 'auPRC (normalized by\npositive-class fraction)'
            elif area_type == 'ROC':
                baseline, ylabel = 0.5, 'auROC'
            else: 
                raise ValueError(f'invalid area type: {area_type}')
            area_ax.axhline(y=baseline, color='black', linestyle='--', linewidth=3)
            area_ax.set_ylabel(ylabel)

            area_ax.set_title(
                'conditioned on complementary features'
                if conditioned
                else 'not conditioned on complementary features'
            )
            area_ax.legend(title='constraint score')
            
            enhancer_ax.set_ylabel('Fraction of windows\nthat overlap enhancers')
            enhancer_ax.set_ylim(0, 1)
            enhancer_ax.set_xlabel(feature)

            plt.show()

def plot_area_under_curve_wrapper(gc_window_size, feature, ylim_scale, number_bins=None, bins=None, target='truly constrained'): 
    """Plot curves for all four constraint scores, then summarise their areas with line_plot.

    Scores are processed in the order gnocchi, depletion rank, CDTS, lambda_s,
    each on its own module-level truth-set frame. The line plot's bottom panel
    is computed from HALLDORSSON_WINDOWS.
    """
    # (truth-set frame, score column, display alias)
    score_specs = (
        (GNOCCHI_WINDOWS, 'gnocchi', 'gnocchi'),
        (HALLDORSSON_WINDOWS, 'depletion_rank_constraint_score_complement', 'depletion rank'),
        (CDTS_WINDOWS, 'percentile_rank_of_observed_minus_expected_complement', 'CDTS'),
        (LAMBDA_S_WINDOWS, 'lambda_s', 'lambda_s'),
    )

    records = []
    for score_windows, score_column, score_alias in score_specs:
        records.extend(plot_curves_with_and_without_slicing(
            score_windows,
            gc_window_size, 
            feature, 
            number_bins, 
            bins,
            constraint_score=score_column, 
            constraint_score_alias=score_alias,
            ylim_scale=ylim_scale,
            target=target, 
        ))

    line_plot( 
        pd.DataFrame(records), 
        HALLDORSSON_WINDOWS,
        number_bins, 
        bins, 
        feature, 
        target,
        gc_window_size
    )

In [None]:
# TODO: [WED Aug 7th] try fewer and wider bins
# if that works, review code from just after "Features that negatively impact a constraint-score-based classifier" to here 
# Stratify by 1 kb GC content using explicit bins: width 0.1 below 0.5, width 0.05 from 0.50 to 0.70, then one bin for 0.7-0.8.
plot_area_under_curve_wrapper(
    gc_window_size=1000,
    feature='GC_content_1000bp',
    ylim_scale=3,
    bins=pd.IntervalIndex.from_tuples([(0.2, 0.3), (0.3, 0.4), (0.4, 0.5), (0.50, 0.55), (0.55, 0.60), (0.60, 0.65), (0.65, 0.70), (0.7, 0.8)]), 
)

## TODO 

In [None]:
# TODO 

# Note that I've already assessed performance for Gnocchi and BGS on a stringent truth set: 
# experiments/germline-model/chen-et-al-2022/assess-impact-of-BGS-on-Gnocchi-predictions-at-labeled-enhancers.ipynb

# So compare the results obtained in this notebook with those earlier results.
