In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os

# Optional: pin the BLAS / OpenMP / numexpr thread pools used by numpy/scipy
# to a single thread, so that CPU usage is controlled only through the csky
# `cpus` / `mp_cpus` arguments used further below. Such variables generally
# only take effect if set before numpy/scipy load their BLAS, which is why this
# block sits before the other imports. Disabled; change `False` to `True`.
if False:
    os.environ['MKL_NUM_THREADS'] = "1"
    os.environ['NUMEXPR_NUM_THREADS'] = "1"
    os.environ['OMP_NUM_THREADS'] = "1"
    os.environ['OPENBLAS_NUM_THREADS'] = "1"
    os.environ['VECLIB_MAXIMUM_THREADS'] = "1"

# white figure background both on screen and in saved files
import matplotlib as mpl
mpl.rcParams['figure.facecolor'] = 'w'
mpl.rcParams['savefig.facecolor'] = 'w'
import matplotlib.pyplot as plt
# NOTE: `colors` is rebound in the Helpers section to the list of
# property-cycle colors, shadowing this `matplotlib.colors` module import.
from matplotlib import colors, cm
import csky as cy
from csky import cext
import numpy as np
import astropy
#from icecube import astro
from tqdm.notebook import tqdm_notebook as tqdm
import histlite as hl
# healpy is imported under both its full name and the `hp` alias
import healpy
import healpy as hp
import socket
import pickle
from scipy import stats
import copy
# silence healpy warnings; set white facecolor (again) and dpi via plt.rc
healpy.disable_warnings()
plt.rc('figure', facecolor = 'w')
plt.rc('figure', dpi=100)

## Define Settings

In [None]:
selection_version = 'version-001-p00'

host_name = socket.gethostname()

# Output locations depend on the machine; only hosts whose name contains
# 'cobalt' are configured.
if 'cobalt' in host_name:
    print('Working on Cobalts')
    plot_dir = cy.utils.ensure_dir('/data/user/mhuennefeld/data/analyses/DNNCascadeCodeReview/unblinding_checks/plots/unblinding/confidence_intervals')

else:
    # Build a single message string: ValueError('Unknown host:', host_name)
    # would render as a tuple "('Unknown host:', '<name>')".
    raise ValueError('Unknown host: {}'.format(host_name))

In [None]:
# Make sure every output directory exists before anything is written to it.
for dir_path in (plot_dir,):
    if os.path.exists(dir_path):
        continue
    print('Creating directory:', dir_path)
    os.makedirs(dir_path)

## Load Data

In [None]:
# Data specification for the 10-year DNNCascade sample and the csky dataset
# repository it is loaded from.
specs = cy.selections.DNNCascadeDataSpecs.DNNC_10yr
repo = cy.selections.Repository()

In [None]:
%%time

ana = cy.get_analysis(
    repo, selection_version, specs, 
    #gammas=np.r_[0.1:6.01:0.125],
)

In [None]:
# First sub-analysis of `ana`. Kept as `a` because later cells compare the
# SnowStorm analysis against `a.sig`. Displays its signal MC.
a = ana.anas[0]
a.sig

In [None]:
# Display the background data (`bg_data`) of the first sub-analysis.
a.bg_data

## Helpers

In [None]:
# `cycle` is itertools.cycle; importing it from `cycler` only worked through
# an internal, non-public re-export of that package.
from itertools import cycle
from copy import deepcopy

soft_colors = cy.plotting.soft_colors
# NOTE: rebinds `colors`, shadowing the `matplotlib.colors` module imported at
# the top of the notebook. From here on `colors` is the list of colors of the
# active matplotlib property cycle.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']


def get_bias_allt(tr, n_trials=200, n_sigs=np.r_[:101:10], quiet=False):
    """Run signal-injection fits for several n_sig values and merge them.

    For each value in ``n_sigs``, ``tr.get_many_fits`` is called with
    ``n_trials`` trials, seeded with that n_sig. An ``ntrue`` column holding
    the injected n_sig is then attached to each result.

    Parameters
    ----------
    tr : trial runner providing ``get_many_fits``.
    n_trials : int
        Fits per injected n_sig.
    n_sigs : sequence of int
        Injected signal strengths (default 0, 10, ..., 100).
    quiet : bool
        If False, print a running progress counter.

    Returns
    -------
    The per-n_sig trials concatenated with ``cy.utils.Arrays.concatenate``.
    """
    per_nsig_trials = []
    for n_sig in n_sigs:
        if not quiet:
            print(f'\r{n_sig:4d} ...', end='', flush=True)
        per_nsig_trials.append(
            tr.get_many_fits(n_trials, n_sig=n_sig, logging=False, seed=n_sig))
    if not quiet:
        print()

    for n_sig, trials in zip(n_sigs, per_nsig_trials):
        trials['ntrue'] = np.repeat(n_sig, len(trials))
    return cy.utils.Arrays.concatenate(per_nsig_trials)

def get_color_cycler():
    """Return an endless iterator that repeats the notebook's ``colors`` list."""
    palette = colors
    return cycle(palette)

def plot_ns_bias(ax, tr, allt, label=''):
    """Plot fitted n_s against injected n_inj on ``ax``.

    One bin is placed around each distinct injected value in ``allt.ntrue``;
    bin widths are the mean spacing of those values. The fitted ``allt.ns`` in
    each bin is summarized with histlite's ``contain_project(1)`` (presumably
    median plus containment bands — TODO confirm) and drawn with error bands.
    A dashed diagonal marks the unbiased expectation n_s == n_inj.

    ``tr`` is not used; it is kept for a signature parallel to plot_gamma_bias.
    """
    injected = np.unique(allt.ntrue)
    step = np.mean(np.diff(injected))
    edges = np.r_[injected - 0.5*step, injected[-1] + 0.5*step]
    limits = edges[[0, -1]]

    hist2d = hl.hist((allt.ntrue, allt.ns), bins=(edges, 100))
    hl.plot1d(ax, hist2d.contain_project(1), errorbands=True,
              drawstyle='default', label=label)
    # same range on both axes (set_ylim returns the new limits)
    ax.set_xlim(ax.set_ylim(limits))
    ax.plot(limits, limits, color='C0', ls='--', lw=1, zorder=-10)
    ax.set_aspect('equal')

    ax.set_xlabel(r'$n_{inj}$')
    ax.set_ylabel(r'$n_s$')
    ax.grid()

def plot_gamma_bias(ax, tr, allt, label=''):
    """Plot fitted spectral index against injected n_inj on ``ax``.

    Binning in n_inj matches plot_ns_bias (one bin per distinct
    ``allt.ntrue`` value). The fitted ``allt.gamma`` per bin is summarized with
    histlite's ``contain_project(1)``. The injected index, read from
    ``tr.sig_injs[0].flux[0].gamma``, is drawn as a dashed horizontal line.
    The y-range is fixed to [1, 4].
    """
    injected = np.unique(allt.ntrue)
    step = np.mean(np.diff(injected))
    edges = np.r_[injected - 0.5*step, injected[-1] + 0.5*step]
    true_gamma = tr.sig_injs[0].flux[0].gamma

    hist2d = hl.hist((allt.ntrue, allt.gamma), bins=(edges, 100))
    hl.plot1d(ax, hist2d.contain_project(1), errorbands=True,
              drawstyle='default', label=label)
    ax.set_xlim(edges[[0, -1]])
    ax.set_ylim(1, 4)
    ax.axhline(true_gamma, color='C0', ls='--', lw=1, zorder=-10)

    ax.set_xlabel(r'$n_{inj}$')
    ax.set_ylabel(r'$\gamma$')
    ax.grid()

def plot_bkg_trials(
            bg, fig=None, ax=None, 
            label='{} bg trials', 
            label_fit=r'$\chi^2[{:.2f}\mathrm{{dof}},\ \eta={:.3f}]$', 
            color=colors[0],
            density=False,
            bins=50,
        ):
    """Histogram background-trial TS values and overlay a fitted PDF if present.

    Parameters
    ----------
    bg : object
        Must provide ``get_hist(bins=...)`` and ``n_total``. If it also has a
        ``pdf`` method, the PDF is overlaid as a dashed line and labelled with
        ``bg.ndof`` and ``bg.eta``. Presumably a csky background-trial
        distribution (TODO confirm).
    fig, ax : matplotlib Figure / Axes, optional
        If ``ax`` is None a new 6x4 figure is created. If only ``ax`` is
        given, ``fig`` is taken from ``ax.figure``.
    label : str or None
        Format string for the histogram label, filled with ``bg.n_total``.
    label_fit : str or None
        Format string for the PDF label, filled with ``bg.ndof`` and ``bg.eta``.
    color : matplotlib color
        Used for both histogram and PDF.
    density : bool
        Normalize the histogram and draw the bare PDF (otherwise the PDF is
        scaled by the histogram integral).
    bins : passed to ``bg.get_hist``.

    Returns
    -------
    (fig, ax)
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 4))
    elif fig is None:
        # previously fig=None was returned when only ax was supplied
        fig = ax.figure

    h = bg.get_hist(bins=bins)
    if density:
        h = h.normalize()
    if label is not None:
        label = label.format(bg.n_total)
    hl.plot1d(ax, h, crosses=True, color=color, label=label)

    # compare with the fitted distribution, if bg provides one
    if hasattr(bg, 'pdf'):
        x = h.centers[0]
        if label_fit is not None:
            label_fit = label_fit.format(bg.ndof, bg.eta)
        if density:
            ax.semilogy(x, bg.pdf(x), lw=1, ls='--', label=label_fit, color=color)
        else:
            # scale the PDF to the number of trials in the histogram
            norm = h.integrate().values
            ax.semilogy(x, norm * bg.pdf(x), lw=1, ls='--', label=label_fit, color=color)

    ax.set_xlabel(r'TS')
    if density:
        ax.set_ylabel(r'Density')
    else:
        ax.set_ylabel(r'number of trials')
    ax.legend()

    return fig, ax

## Setup Analysis

In [None]:
import sys
# Make the `config` module two directories above the current working
# directory importable. This must run before `import config`.
sys.path.insert(0, '../..')

import config as cg

# Directory holding the unblinding results and trials used below.
cg.base_dir = '/data/user/mhuennefeld/data/analyses/unblinding_v1.0.0/'

In [None]:
def get_gp_tr(template_str, cutoff=np.inf, gamma=None, cpus=20):
    """Return a csky trial runner for a template configured by ``cg.get_gp_conf``.

    Parameters
    ----------
    template_str : str
        Template name passed to ``cg.get_gp_conf`` (e.g. 'pi0', 'kra5').
    cutoff : float
        Energy cutoff in TeV; converted to GeV for the config.
        ``np.inf`` means no cutoff.
    gamma : float or None
        Forwarded unchanged to ``cg.get_gp_conf``.
    cpus : int
        Passed to the trial runner as ``mp_cpus``.

    Uses the global analysis object ``ana``.
    """
    conf = cg.get_gp_conf(
        template_str=template_str,
        gamma=gamma,
        cutoff_GeV=1e3 * cutoff,
        base_dir=cg.base_dir,
    )
    return cy.get_trial_runner(conf, ana=ana, mp_cpus=cpus)

def get_template_tr(template, gamma=2.7, cutoff_tev=np.inf, cpus=20):
    """Return a csky trial runner for an explicit template with a power-law flux.

    Parameters
    ----------
    template : template object passed to csky as ``'template'``.
    gamma : float
        Power-law index for the injected flux. Also fixed in ``fitter_args``.
    cutoff_tev : float
        Energy cutoff in TeV (converted to GeV for ``PowerLawFlux``).
    cpus : int
        Passed to the trial runner as ``mp_cpus``.

    The config randomizes 'ra' and sets sigsub/update_bg on, fast_weight off.
    Uses the global analysis object ``ana``.
    """
    flux = cy.hyp.PowerLawFlux(gamma, energy_cutoff=cutoff_tev * 1000.)
    conf = dict(
        template=template,
        flux=flux,
        randomize=['ra'],
        fitter_args=dict(gamma=gamma),
        sigsub=True,
        update_bg=True,
        fast_weight=False,
    )
    return cy.get_trial_runner(conf, ana=ana, mp_cpus=cpus)

def get_gp_tr_sys(ana, template_str, cutoff=np.inf, gamma=None, cpus=20, sigmas=[0]):
    """Return a trial runner for a template config on a caller-supplied analysis.

    Same config as get_gp_tr (from ``cg.get_gp_conf``), except that:
      * ``ana`` is passed in (e.g. a SnowStorm-reduced analysis) instead of
        using the global one;
      * the config's ``'dir'`` entry is removed (KeyError if it is missing);
      * ``'sigmas'`` is set to ``sigmas``. The object is stored as-is, so the
        shared default list ``[0]`` ends up in every default config.

    ``cutoff`` is in TeV (converted to GeV).
    """
    conf = cg.get_gp_conf(
        template_str=template_str,
        gamma=gamma,
        cutoff_GeV=1e3 * cutoff,
        base_dir=cg.base_dir,
    )
    conf.pop('dir')
    conf['sigmas'] = sigmas
    return cy.get_trial_runner(conf, ana=ana, mp_cpus=cpus)

#### SnowStorm Systematics

In [None]:
import pandas as pd

# SnowStorm systematics MC for this selection version. Keep only the six
# per-event SnowStorm parameters and the columns later used to check that
# the rows line up with the signal MC (run, energy, ow).
df_dir = '/data/ana/PointSource/DNNCascade/analysis/{}/'.format(selection_version)
snowstorm_columns = ['SnowstormParameters_{:05d}'.format(i) for i in range(6)]
df = pd.read_hdf(
    df_dir + '/systematics/SnowStorm_Spice321/MC_NuGen_snowstorm_214xx.hdf',
    key='df',
)[snowstorm_columns + ['run', 'energy', 'ow']]

In [None]:
import pandas as pd
from copy import deepcopy
from IPython.utils import io

# Range over which each SnowStorm systematic parameter is varied in the
# simulation. Used as the reference width when computing what fraction of the
# simulated range a (sub-)range covers.
sim_ranges = {
    'Scattering': [0.9, 1.1],
    'Absorption': [0.9, 1.1],
    'AnisotropyScale': [0., 2.],
    'DOMEfficiency': [0.9, 1.1],
    'HoleIceForward_Unified_00': [-1.0, 1.0],
    'HoleIceForward_Unified_01': [-0.2, 0.2],
}

# Ranges within which the randomly sampled, reduced systematic ranges are
# placed (see sample_snowstorm_ranges). Identical to `sim_ranges` except for
# the two hole-ice parameters, which are restricted.
allowed_ranges = {
    'Scattering': [0.9, 1.1],
    'Absorption': [0.9, 1.1],
    'AnisotropyScale': [0., 2.],
    'DOMEfficiency': [0.9, 1.1],
    
    # slightly increase range from recommendation to not have too little stats
    # (the recommended ranges are kept in the trailing comments)
    'HoleIceForward_Unified_00': [-0.75, 0.45], #[-0.5, 0.3],
    'HoleIceForward_Unified_01': [-0.15, 0.075], #[-0.1, 0.05],
}


def get_snowstorm_ana(sys_ranges, sim_ranges=sim_ranges):
    """Build a csky analysis on SnowStorm MC restricted to ``sys_ranges``.

    A subclass of ``DNNCascade_10yr_snowstorm_fullrange`` is defined whose
    ``dataset_modifications(ds)`` (presumably a csky hook applied to the loaded
    dataset — TODO confirm) does the following:
      1. asserts that the module-level ``df`` lines up with ``ds.sig`` (run,
         energy, oneweight), then copies the six SnowStorm parameters from
         ``df`` into new columns of ``ds.sig``;
      2. for every parameter whose range in ``sys_ranges`` differs from
         ``sim_ranges``, keeps only signal events with a value in ``[lo, hi)``;
      3. divides ``oneweight`` by the product of the kept range fractions
         ``(hi - lo) / (sim_hi - sim_lo)`` to compensate for the removed MC.
         This rescaling presumes the parameters are sampled uniformly and
         independently over ``sim_ranges`` — TODO confirm.

    Parameters
    ----------
    sys_ranges : dict
        ``{parameter_name: [lo, hi]}``; every key must also appear in
        ``sim_ranges``. An AssertionError is raised for a modified range with
        ``lo >= hi``.
    sim_ranges : dict
        Simulated range per parameter (module-level ``sim_ranges`` by default).

    Returns
    -------
    The analysis returned by ``cy.get_analysis`` for the reduced dataset.
    """
    
    # define SnowStorm dataset with reduced range
    class DNNCascade_10yr_sys_reduced(cy.selections.DNNCascadeDataSpecs.DNNCascade_10yr_snowstorm_fullrange):
        def dataset_modifications(self, ds):
            # Attach SnowStorm parameters to ds.sig, cut to sys_ranges and
            # rescale oneweight (see get_snowstorm_ana docstring).
            print('Adding SnowStorm Parameters to MC')
            # NOTE: path_sig_df is only referenced by the commented-out
            # fallback below; the module-level `df` is used instead.
            path_sig_df = (
                '/data/ana/PointSource/DNNCascade/analysis/' + 
                self._path_sig.format(version=self._version).replace('dnn_cascades/', '').replace('.npy', '.hdf')
            )
            # (use global df to avoid loading multiple times)
            #if df is None:
            #    df = pd.read_hdf(path_sig_df, key='df')
            # rows of df must correspond one-to-one to the signal MC events
            assert np.allclose(df['run'], ds.sig.run)
            assert np.allclose(df['energy'], ds.sig.energy)
            assert np.allclose(df['ow'], ds.sig.oneweight)

            # load and rename SnowStorm parameters
            # (SnowstormParameters_0000i -> parameter_names[i])
            parameter_names=[
                'Scattering', 'Absorption', 'AnisotropyScale', 
                'DOMEfficiency', 'HoleIceForward_Unified_00', 
                'HoleIceForward_Unified_01',
            ]
            for i, param in enumerate(parameter_names):
                ds.sig[param] = np.array(df['SnowstormParameters_{:05d}'.format(i)])
            
            print('Reducing Dataset')
            mask = np.ones(len(ds.sig), dtype=bool)
            # product of kept fractions of the simulated ranges
            factor = 1.
            for param, sys_range in sys_ranges.items():
                # unmodified parameters neither cut events nor change the factor
                if sys_range != sim_ranges[param]:
                    assert sys_range[0] < sys_range[1], sys_range
                    factor *= (sys_range[1] - sys_range[0]) / (sim_ranges[param][1] - sim_ranges[param][0])
                    # half-open interval [lo, hi)
                    mask_i = np.logical_and(
                        ds.sig[param] >= sys_range[0],
                        ds.sig[param] < sys_range[1],
                    )
                    mask = np.logical_and(mask, mask_i)
            
            print('Reduction factor: {:3.3f}'.format(factor))
            ds.sig = ds.sig._subsample(mask)
            # up-weight the remaining events by 1/factor
            ds.sig.oneweight[:] = ds.sig.oneweight/factor
            
    ana_sys = cy.get_analysis(
        cy.selections.Repository(), selection_version, [DNNCascade_10yr_sys_reduced], 
        #_quiet=True,
    )
    return ana_sys

def sample_snowstorm_ranges(
            seed=None, 
            sim_ranges=sim_ranges, 
            allowed_ranges=allowed_ranges, 
            min_red_factor=0.05,
            max_k=3,
        ):
    """Randomly shrink the ranges of 1..max_k SnowStorm parameters.

    The number of perturbed parameters ``k`` is drawn uniformly from
    ``[1, max_k]``, and ``k`` distinct parameters are chosen. Each chosen
    parameter's allowed range is shrunk by the same relative factor
    ``rel_fr``, chosen such that the product over *all* parameters of
    (final range width / simulated range width) equals ``min_red_factor``:

        allowed_fraction * rel_fr**k == min_red_factor

    Here ``allowed_fraction`` is that product for ``allowed_ranges`` itself.
    Each shrunk interval is placed uniformly at random inside its allowed range.

    Parameters
    ----------
    seed : int or None
        Seed for ``np.random.RandomState``.
    sim_ranges, allowed_ranges : dict
        ``{parameter_name: [lo, hi]}``.
    min_red_factor : float
        Target overall fraction of the simulated parameter volume.
    max_k : int
        Maximum number of parameters to perturb (at most 6).

    Returns
    -------
    sys_range : dict
        Copy of ``allowed_ranges`` with each chosen parameter's range replaced
        by ``[mid - half_width, mid + half_width]``.
    params : numpy array of str
        Names of the perturbed parameters.

    Raises
    ------
    AssertionError
        If the shrunk interval does not fit inside the allowed range, i.e. if
        ``rel_fr >= 1`` (``min_red_factor`` too large).
    """
    rng = np.random.RandomState(seed)

    # sample number of parameters to perturb, uniform in [1, max_k]
    k = rng.randint(1, 1 + max_k)

    # sample which parameters to perturb
    parameter_names = [
        'Scattering', 'Absorption', 'AnisotropyScale',
        'DOMEfficiency', 'HoleIceForward_Unified_00',
        'HoleIceForward_Unified_01',
    ]
    params = rng.choice(parameter_names, size=k, replace=False)

    # fraction of the simulated parameter volume covered by the allowed ranges
    allowed_fraction = 1.
    for param, allowed_range in allowed_ranges.items():
        if allowed_range != sim_ranges[param]:
            allowed_fraction *= (
                (allowed_range[1] - allowed_range[0])
                / (sim_ranges[param][1] - sim_ranges[param][0]))

    # per-parameter shrink factor relative to the allowed range
    rel_fr = np.power(min_red_factor / allowed_fraction, 1./k)

    # sample intervals (RNG draws happen in the order of `params`)
    sys_range = deepcopy(allowed_ranges)
    for param in params:
        allowed_range = allowed_ranges[param]

        half_width = (allowed_range[1] - allowed_range[0]) * rel_fr / 2.
        # the mid point must keep the whole interval inside the allowed range
        sample_range = [allowed_range[0] + half_width, allowed_range[1] - half_width]
        assert sample_range[1] > sample_range[0], sample_range

        mid_point = rng.uniform(*sample_range)
        sys_range[str(param)] = [mid_point - half_width, mid_point + half_width]

    return sys_range, params

def get_snowstorm_tr(
            template_str,
            seed=None, 
            sim_ranges=sim_ranges, 
            allowed_ranges=allowed_ranges, 
            min_red_factor=0.05,
            max_k=3,
            sigmas=[0],
        ):
    """Build a trial runner on a randomly reduced SnowStorm MC sample.

    Samples reduced parameter ranges with sample_snowstorm_ranges, builds the
    corresponding analysis with get_snowstorm_ana, and returns a template
    trial runner from get_gp_tr_sys. Console output of the last two steps is
    captured (suppressed).

    Parameters
    ----------
    template_str : str
        Template name, e.g. 'pi0', 'kra5', 'kra50'.
    seed : int or None
        Seed for the range sampling.
    sim_ranges, allowed_ranges : dict
        Forwarded to the range sampling and to the analysis construction.
        Previously these were accepted but silently ignored in favour of the
        module-level dicts.
    min_red_factor, max_k
        Forwarded to sample_snowstorm_ranges.
    sigmas
        Forwarded to get_gp_tr_sys.
    """
    # sample SnowStorm parameter ranges
    sys_ranges, params = sample_snowstorm_ranges(
        seed=seed,
        sim_ranges=sim_ranges,
        allowed_ranges=allowed_ranges,
        min_red_factor=min_red_factor,
        max_k=max_k,
    )

    # build the reduced analysis and its trial runner without console output
    with io.capture_output():
        ana_sys = get_snowstorm_ana(sys_ranges=sys_ranges, sim_ranges=sim_ranges)
        tr_sys = get_gp_tr_sys(ana=ana_sys, template_str=template_str, sigmas=sigmas)

    return tr_sys

##### Test Sampling

In [None]:
n_samples = 10000
min_red_factor = 0.02
max_k = 3

# Draw many reduced-range configurations and record the mid-point of every
# parameter range. Each draw is seeded with its index so the check plot is
# reproducible. `mids` only collects parameters that were actually perturbed
# in a draw; `mids_all` collects every parameter.
mids = {k: [] for k in sim_ranges.keys()}
mids_all = {k: [] for k in sim_ranges.keys()}
for i in tqdm(range(n_samples), total=n_samples):
    sys_ranges, params = sample_snowstorm_ranges(
        seed=i, min_red_factor=min_red_factor, max_k=max_k)
    for k, sys_range in sys_ranges.items():
        if k in params:
            mids[k].append(np.mean(sys_range))
        mids_all[k].append(np.mean(sys_range))

# One panel per parameter: distribution of perturbed mid-points together with
# the simulated and allowed ranges.
param_keys = sorted(sim_ranges.keys())
fig, axes = plt.subplots(2, 3, figsize=(9, 6))
for key, ax in zip(param_keys, axes.flatten()):
    ax.set_xlabel(key)
    ax.set_ylabel('Number of samples')
    ax.hist(mids[key], bins=30)
    ax.axvline(sim_ranges[key][0], color='0.3', ls='--', label='Simulation Range')
    ax.axvline(sim_ranges[key][1], color='0.3', ls='--')
    ax.axvline(allowed_ranges[key][0], color='0.7', ls='-', label='Allowed Range')
    ax.axvline(allowed_ranges[key][1], color='0.7', ls='-')
axes[0, 0].legend()
fig.suptitle('Min Reduction: {:1.3f} | Max k: {}'.format(min_red_factor, max_k))
fig.tight_layout()
fig.savefig('{}/snowstorm_sampling_check.png'.format(plot_dir))


In [None]:
# Draw one set of reduced SnowStorm ranges and display it; the next cell
# builds an analysis from it. NOTE: unseeded, so a different set is drawn on
# every execution (pass seed=... for reproducibility).
sys_ranges, params = sample_snowstorm_ranges(min_red_factor=0.02)
sys_ranges, params

In [None]:
%%time

ana_sys = get_snowstorm_ana(
    #sys_ranges={
    #    'Scattering': [1.0, 1.1],
    #    'Absorption': [0.9, 1.0],
    #    'AnisotropyScale': [0., 1.],
    #},
    sys_ranges=sys_ranges,
    #sys_ranges=sample_snowstorm_ranges(),
)
tr_sys = get_gp_tr_sys(
    ana=ana_sys, template_str='pi0', 
    #sigmas=np.radians(np.r_[3:20, 20:40:2, 40:60:4, 60:91:5]),
)
print(len(ana_sys.anas[0].sig)/len(df), len(ana_sys.anas[0].sig))
print('ana', np.sum(a.sig.oneweight * a.sig.true_energy**-2.5))
print('sys', np.sum(ana_sys.anas[0].sig.oneweight * ana_sys.anas[0].sig.true_energy**-2.5))


In [None]:
# Number of signal MC events in the reduced SnowStorm analysis relative to
# the signal MC of `ana`.
len(ana_sys.anas[0].sig) / len(ana.anas[0].sig)

#### Get TrialRunners

In [None]:
# One trial runner (from get_gp_tr with default settings) per template.
tr_dict = {
    template: get_gp_tr(template)
    for template in ('pi0', 'kra5', 'kra50')
}

#### Get Results for each template

In [None]:
# Unblinded result file for every template that has a trial runner.
res_dict = {
    key: np.load(os.path.join(
        cg.base_dir,
        'gp/results/{}/{}_unblinded.npy'.format(key, key),
    ))
    for key in tr_dict
}

#### Get ns-bias correction

In [None]:
use_poisson = False
add_sys = False

# Cache file for the ns-bias trials. It gets a '_sys' suffix when SnowStorm
# systematics trials are used, and a separate base name when `use_poisson` is
# set (passed as `poisson=` to the trial generation below).
sys_suffix = '_sys' if add_sys else ''
ns_bias_file = os.path.join(
    plot_dir,
    ('ns_bias_poisson{}.pkl' if use_poisson else 'ns_bias{}.pkl').format(sys_suffix),
)


def get_ns_bias_dict_from_sys(ns_bias_dict_sys):
    """Merge per-systematics-trial ns-bias trials into one entry per template.

    Parameters
    ----------
    ns_bias_dict_sys : dict
        ``{key: {trial_i: {ns: trials}}}`` where each ``trials`` is accepted by
        ``cy.utils.Arrays.concatenate``.

    Returns
    -------
    dict
        ``{key: {ns: trials concatenated over trial_i}}``, concatenated in
        ascending ``trial_i`` order and with ``ns`` keys in ascending order.
        Systematics trials without any ns entries (e.g. interrupted before
        completion) are ignored. Keys with no remaining trials are omitted.

    Raises
    ------
    ValueError
        If the systematics trials of one key were run on different ns grids.
    """
    ns_bias_dict = {}
    for key, ns_trials_dict in ns_bias_dict_sys.items():
        trial_ids = sorted(t for t, per_ns in ns_trials_dict.items() if per_ns)
        if not trial_ids:
            continue

        # all systematics trials of a key must share the same ns grid
        ns_values = sorted(ns_trials_dict[trial_ids[0]].keys())
        for trial_i in trial_ids[1:]:
            if sorted(ns_trials_dict[trial_i].keys()) != ns_values:
                raise ValueError(
                    'ns values of trial {} for {!r} differ from those of trial {}'.format(
                        trial_i, key, trial_ids[0]))

        ns_bias_dict[key] = {
            ns: cy.utils.Arrays.concatenate(
                [ns_trials_dict[trial_i][ns] for trial_i in trial_ids])
            for ns in ns_values
        }
    return ns_bias_dict

# Start fresh, or resume from the cached ns-bias trials.
# NOTE: pickle.load executes arbitrary code; only load trusted cache files.
if not os.path.exists(ns_bias_file):
    print('Creating new dict')
    ns_bias_dict = {}
    ns_bias_dict_sys = {}
else:
    print('Loading from file')
    with open(ns_bias_file, 'rb') as handle:
        _ns_bias_loaded = pickle.load(handle)

    if add_sys:
        # the file holds {key: {trial_i: {ns: trials}}}; also build the
        # merged {key: {ns: trials}} view
        ns_bias_dict_sys = _ns_bias_loaded
        ns_bias_dict = get_ns_bias_dict_from_sys(ns_bias_dict_sys)
    else:
        ns_bias_dict = _ns_bias_loaded
        ns_bias_dict_sys = {}
    

In [None]:
# `Pool` (the class) is needed below; the old `import pool` imported the
# submodule and failed with NameError on a fresh kernel.
from multiprocessing import Pool

# injected-ns range per template: range(*ns_bias_range[key])
ns_bias_range = {
    'kra5': [0, 600],
    'kra50': [0, 600],
    'pi0': [0, 1500],
}
n_trials = 100
recalculate = False
cpus = 20

# Trial runner used by the Pool worker below. It is assigned before each
# Pool is created, so forked workers inherit it (assumes the 'fork' start
# method, the Linux default). A module-level worker can be pickled by
# multiprocessing; a function defined inside the loop cannot.
_ns_bias_tr = None


def _compute_ns_bias_trial(ns):
    """Pool worker: one fit with `ns` injected events on `_ns_bias_tr`, tagged with ntrue."""
    trials = _ns_bias_tr.get_many_fits(1, n_sig=ns, logging=False, seed=ns, poisson=use_poisson, cpus=1)
    trials['ntrue'] = np.repeat(ns, len(trials))
    return trials


for key, ns_range in ns_bias_range.items():

    for dictionary in [ns_bias_dict, ns_bias_dict_sys]:
        dictionary.setdefault(key, {})

    print('Submitting values for {} from {} to {}'.format(key, *ns_range))
    ns_list = list(range(*ns_range))

    if add_sys:
        # one SnowStorm variation (seeded by trial_i) per systematics trial,
        # with one fit per injected ns
        for trial_i in tqdm(range(n_trials), total=n_trials):

            if trial_i not in ns_bias_dict_sys[key] or recalculate:
                print('Getting trial runner for {}'.format(key))
                _ns_bias_tr = get_snowstorm_tr(
                    key, seed=trial_i,
                    sigmas=np.radians(np.r_[3:20, 20:40:2, 40:60:4, 60:91:5]),
                )

                print('Starting pool with {} cpus'.format(cpus))
                with Pool(cpus) as p:
                    trials = list(tqdm(p.imap(_compute_ns_bias_trial, ns_list), total=len(ns_list)))

                # Store only after the trial is complete: previously an empty
                # entry was stored up front, so an interrupted trial was
                # never recomputed.
                ns_bias_dict_sys[key][trial_i] = dict(zip(ns_list, trials))

                # checkpoint after every completed trial
                with open(ns_bias_file, 'wb') as f:
                    pickle.dump(ns_bias_dict_sys, f, protocol=-1)

    else:
        tr = tr_dict[key]

        for ns in tqdm(ns_list, total=len(ns_list)):

            if ns not in ns_bias_dict[key] or recalculate:
                trials = tr.get_many_fits(n_trials, n_sig=ns, logging=False, seed=ns, cpus=cpus, poisson=use_poisson)
                trials['ntrue'] = np.repeat(ns, len(trials))

                ns_bias_dict[key][ns] = trials

                # checkpoint after every injected ns value
                with open(ns_bias_file, 'wb') as f:
                    pickle.dump(ns_bias_dict, f, protocol=-1)

if add_sys:
    ns_bias_dict = get_ns_bias_dict_from_sys(ns_bias_dict_sys)

In [None]:
from scipy.interpolate import UnivariateSpline

bias_corr_funcs = {}

for key, ns_bias_dict_i in ns_bias_dict.items():

    # Bias correction: smoothing spline through the median fitted n_s versus
    # injected n_inj. Smoothing factor s = 500 per n_inj point.
    x = sorted(ns_bias_dict_i.keys())
    y = [np.median(ns_bias_dict_i[n_inj].ns) for n_inj in x]
    bias_corr_funcs[key] = UnivariateSpline(x=x, y=y, s=len(ns_bias_dict_i)*500)

    fig, ax = plt.subplots(figsize=(9, 6))
    tr = tr_dict[key]

    allt = cy.utils.Arrays.concatenate(list(ns_bias_dict_i.values()))
    plot_ns_bias(ax=ax, tr=tr, allt=allt)
    ax.plot(x, bias_corr_funcs[key](x), label='Spline Fit')
    ax.set_title('Model: {}'.format(key))
    # raw string: '\m' is an invalid escape sequence in a plain string literal
    ax.set_xlabel(r'$n_\mathrm{inj}$')
    ax.set_ylabel('$n_s$')
    ax.legend()
    fig.savefig('{}/ns_bias_{}{}.png'.format(plot_dir, key, sys_suffix))
    

In [None]:
# Compare the bias-correction splines of all templates with the unbiased
# diagonal n_s = n_inj. The x-range may extrapolate beyond the n_inj range
# each spline was fitted on.
fig, ax = plt.subplots()
x = np.linspace(0, 1500, 1000)
for key, func in bias_corr_funcs.items():
    ax.plot(x, func(x), label=key)
ax.plot(x, x, color='0.7', ls='--')
# raw string: '\m' is an invalid escape sequence in a plain string literal
ax.set_xlabel(r'$n_\mathrm{inj}$')
ax.set_ylabel('$n_s$')
ax.legend()
fig.savefig('{}/ns_bias_comparison{}.png'.format(plot_dir, sys_suffix))


#### Get Critical Values

In [None]:
from multiprocessing import Pool
from tqdm.notebook import tqdm_notebook as tqdm


def model_norm_to_ns(tr, model_norm, correction_factor=1.5):
    """Convert a template model normalization to a number of signal events.

    ns = model_norm * tr.sig_inj_acc_total / correction_factor

    NOTE(review): the origin of the default ``correction_factor`` of 1.5 is
    not documented here — confirm and document.
    """
    acceptance = tr.sig_inj_acc_total
    return model_norm * acceptance / correction_factor
    
def get_critical_value_trial(
            E2dNdE_or_modelnorm, tr, tr_inj, bias_corr_func=None,
            E0=100, unit=1e3, seed=None, TRUTH=False, is_model_norm=False,
        ):
    """Compute the test statistic tau for one pseudo-experiment.

    The tested signal strength is converted to a number of signal events
    ``n_sig``: via model_norm_to_ns for model normalizations, otherwise via
    ``tr_inj.to_ns(flux, E0=E0, unit=unit)``. A trial with ``n_sig`` injected
    events (Poisson-fluctuated; nothing injected if ``TRUTH``, which is passed
    through to ``get_one_trial``) is generated from ``tr_inj`` and analysed with
    ``tr``.

    With TS(ns) as returned by csky (assumed TS = 2 log[L(ns) / L(0)]):

        tau = TS(ns_fit) - TS(ns_test) = -2 log[L(ns_test) / L(ns_fit)]

    where ``ns_test`` is ``n_sig``, optionally mapped through
    ``bias_corr_func``.

    Returns
    -------
    tau : float (type as returned by the likelihood objects)
    """
    n_sig = (
        model_norm_to_ns(tr=tr_inj, model_norm=E2dNdE_or_modelnorm)
        if is_model_norm
        else tr_inj.to_ns(E2dNdE_or_modelnorm, E0=E0, unit=unit)
    )

    trial = tr_inj.get_one_trial(
        n_sig=0 if TRUTH else n_sig, poisson=True, seed=seed, TRUTH=TRUTH)

    # best-fit TS on this trial; the fitted ns itself is not needed
    ts_fit, _ns_fit = tr.get_one_fit_from_trial(trial)

    # tested hypothesis, optionally bias-corrected
    ns_test = n_sig if bias_corr_func is None else bias_corr_func(n_sig)

    llh = tr.get_one_llh_from_trial(trial)
    ts_test = llh.get_ts(ns=ns_test, **tr.fitter_args)

    return ts_fit - ts_test

# ------------------------------------------------------------------------------------
# define global functions for multiprocessing (pickle has issues with local functions)
# ------------------------------------------------------------------------------------
def _run_trial_for_template(key, args, is_model_norm):
    """Compute tau for one pseudo-experiment at every tested normalization of template ``key``.

    ``args`` is ``(i, norms, E0, unit, bias_corr_func, sys_seed, min_red_factor, max_k)``:
    ``i`` seeds the pseudo-experiment (the same seed is used for every
    normalization). If ``sys_seed`` is None, signal is injected with
    ``tr_dict[key]``. Otherwise it is injected with a SnowStorm trial runner
    from get_snowstorm_tr(key, seed=sys_seed, ...). The fit always uses
    ``tr_dict[key]``.

    Returns the list of tau values, one per entry of ``norms``.
    """
    i, norms, E0, unit, bias_corr_func, sys_seed, min_red_factor, max_k = args
    if sys_seed is None:
        tr_inj = tr_dict[key]
    else:
        tr_inj = get_snowstorm_tr(key, seed=sys_seed, min_red_factor=min_red_factor, max_k=max_k)

    tr = tr_dict[key]
    return [
        get_critical_value_trial(
            E2dNdE_or_modelnorm=norm, tr=tr, tr_inj=tr_inj,
            bias_corr_func=bias_corr_func, E0=E0, unit=unit, seed=i,
            is_model_norm=is_model_norm,
        )
        for norm in norms
    ]


# Thin module-level wrappers (picklable by multiprocessing) around the shared
# helper; previously three copy-pasted implementations.
def run_trial_pi0(args):
    """Pool worker for 'pi0'; normalizations are E2dNdE flux values."""
    return _run_trial_for_template('pi0', args, is_model_norm=False)

def run_trial_kra5(args):
    """Pool worker for 'kra5'; normalizations are model-normalization factors."""
    return _run_trial_for_template('kra5', args, is_model_norm=True)

def run_trial_kra50(args):
    """Pool worker for 'kra50'; normalizations are model-normalization factors."""
    return _run_trial_for_template('kra50', args, is_model_norm=True)

# template key -> Pool worker used by run_critical_value_trials
function_dict = {
    'pi0': run_trial_pi0,
    'kra5': run_trial_kra5,
    'kra50': run_trial_kra50,
}
# ------------------------------------------------------------------------------------

def run_critical_value_trials(
            n_trials, E2dNdE_or_modelnorm_list, key, add_systematics=False, 
            bias_corr_funcs=None, min_red_factor=0.02, max_k=3,
            E0=100, unit=1e3, seed=0, cpus=20,
        ):
    """Run ``n_trials`` pseudo-experiments for template ``key``.

    Each pseudo-experiment ``i`` (seeds ``seed .. seed + n_trials - 1``) is
    handled by ``function_dict[key]``. It returns one tau value per entry of
    ``E2dNdE_or_modelnorm_list``. With ``add_systematics`` the pseudo-experiment
    seed is also used as the SnowStorm systematics seed. If ``bias_corr_funcs``
    is given, ``bias_corr_funcs[key]`` is applied.
    Runs in a multiprocessing Pool when ``cpus > 1``, serially otherwise.

    Returns
    -------
    numpy array of shape (len(E2dNdE_or_modelnorm_list), n_trials).
    """
    bias_corr_func = None if bias_corr_funcs is None else bias_corr_funcs[key]

    arg_list = [
        (trial_seed, E2dNdE_or_modelnorm_list, E0, unit, bias_corr_func,
         trial_seed if add_systematics else None, min_red_factor, max_k)
        for trial_seed in range(seed, seed + n_trials)
    ]

    worker = function_dict[key]

    if cpus > 1:
        print('Running pool with {} cpus'.format(cpus))

        with Pool(cpus) as p:
            per_trial = list(tqdm(p.imap(worker, arg_list), total=n_trials))
        print('tau_values_map.shape', np.array(per_trial).shape)
    else:
        per_trial = [worker(args) for args in tqdm(arg_list, total=n_trials)]

    # regroup from [trial][normalization] to [normalization][trial]
    tau_values = [[] for _ in E2dNdE_or_modelnorm_list]
    for trial_taus in per_trial:
        for norm_idx, value in enumerate(trial_taus):
            tau_values[norm_idx].append(value)

    return np.array(tau_values)



In [None]:
%%time

seed = 4000
n_trials = 1000
apply_correction = True
add_systematics = False
min_red_factor = 0.02
max_k = 3

cpus = 20
recalculate = False

E2dNdE_or_modelnorm_dict = {
    'pi0': np.linspace(1e-11, 4e-11, 25),
    'kra5': np.linspace(0.3, 1.7, 25),
    'kra50': np.linspace(0.2, 1.2, 25),
}

    
tau_dict = {}
for key in ['pi0', 'kra5', 'kra50']:
#for key in ['pi0']:
        
    print('Running {} trials for {} with {} different normalizations'.format(n_trials, key, len(E2dNdE_or_modelnorm_dict[key])))
    if apply_correction:
        print('Applying correction')
        bias_corr_funcs_kw = bias_corr_funcs
    else:
        bias_corr_funcs_kw = None
    print('Adding Systematic:', add_systematics)
    
    if add_systematics:
        sys_str = '{}_red_{:0.3f}_k_{}'.format(add_systematics, min_red_factor, max_k)
    else:
        sys_str = '{}'.format(add_systematics)
        
    file_path = os.path.join(plot_dir, 'trials_{}_corr_{}_sys_{}_seeds_{}_{}.pkl'.format(
        key, apply_correction, sys_str, seed, seed+n_trials))
    
    if not os.path.exists(file_path) or recalculate:
        tau_values = run_critical_value_trials(
            n_trials, seed=seed, E2dNdE_or_modelnorm_list=E2dNdE_or_modelnorm_dict[key], key=key, add_systematics=add_systematics, cpus=cpus, bias_corr_funcs=bias_corr_funcs_kw)
        
        # save trials
        with open(file_path, 'wb') as f:
            seeds = list(range(seed, seed+n_trials))
            pickle.dump((E2dNdE_or_modelnorm_dict, tau_values, seeds), f, protocol=-1)
    else:
        print('Skipping because file already exists...')
        


#### Load trials

In [None]:
import glob

# Collect all trial files for the current settings into
# tau_dict[(key, apply_correction)][normalization] -> array of tau values.
tau_dict = {}
for key in ['pi0', 'kra5', 'kra50']:
    
    print('Loading trials for {} with {} different normalizations'.format(key, len(E2dNdE_or_modelnorm_dict[key])))
    key_s = (key, apply_correction)
    
    if key_s not in tau_dict:
        tau_dict[key_s] = {norm: [] for norm in E2dNdE_or_modelnorm_dict[key]}
    
    # find a list of files (naming must match the cell that wrote them)
    if add_systematics:
        sys_str = '{}_red_{:0.3f}_k_{}'.format(add_systematics, min_red_factor, max_k)
    else:
        sys_str = '{}'.format(add_systematics)
        
    file_pattern = os.path.join(plot_dir, 'trials_{}_corr_{}_sys_{}_seeds_*_*.pkl'.format(
        key, apply_correction, sys_str))
    file_list = sorted(glob.glob(file_pattern))
    print('Found {} files...'.format(len(file_list)))
    if not file_list:
        # np.concatenate below would otherwise fail with the unhelpful
        # "need at least one array to concatenate"
        raise FileNotFoundError('No trial files match: {}'.format(file_pattern))
    
    # load files and check for overlapping seeds
    seed_values = set()
    for file_i in file_list:
        with open(file_i, 'rb') as handle:
            E2dNdE_or_modelnorm_dict_loaded, tau_values, seeds = pickle.load(handle)
        
        # make sure model norms match
        assert sorted(E2dNdE_or_modelnorm_dict_loaded.keys()) == sorted(E2dNdE_or_modelnorm_dict.keys())
        for k, norms in E2dNdE_or_modelnorm_dict_loaded.items():
            assert np.allclose(norms, E2dNdE_or_modelnorm_dict[k]), (norms, E2dNdE_or_modelnorm_dict[k])

        # make sure seeds do not overlap (otherwise trials would be duplicated)
        overlapping_seeds = seed_values.intersection(seeds)
        if overlapping_seeds:
            raise ValueError('Found overlapping seeds: {}!'.format(overlapping_seeds))
        seed_values.update(seeds)
        
        # append tau values from this file (row i belongs to normalization i)
        for i, E2dNdE_or_modelnorm in enumerate(E2dNdE_or_modelnorm_dict[key]):
            tau_dict[key_s][E2dNdE_or_modelnorm].append(tau_values[i])
    
    # concatenate into single array
    for E2dNdE_or_modelnorm in E2dNdE_or_modelnorm_dict[key]:
        tau_dict[key_s][E2dNdE_or_modelnorm] = np.concatenate(tau_dict[key_s][E2dNdE_or_modelnorm])

#### Make critical value plot

In [None]:
import matplotlib
from itertools import cycle
from scipy.interpolate import UnivariateSpline
from scipy import optimize

    
def make_critical_value_plot(
            key, tau_dict, confidence_levels=[0.68, 0.9, 0.95], 
            norm_bins=np.linspace(1e-11, 4e-11, 25), 
            tau_bins=np.linspace(0., 7, 30),
            bias_corr_funcs=None,
            E0=100, unit=1e3,
            ls_list=['-', '--', '-.'],
            plot_splines=False,
            n_eval_points=100,
            spline_s=None,
        ):
    """Plot test-statistic distributions vs. normalization and derive confidence intervals.

    For every tested normalization in ``tau_dict`` the trial distribution of the
    test statistic tau is histogrammed and its ``c``-quantile is used as the
    critical value for confidence level ``c``. The observed tau (data,
    ``TRUTH=True``) is evaluated as a function of the normalization; a
    normalization is accepted when its observed tau lies below the critical
    value. Intervals are reported both from the tested grid ("binned") and from
    the intersections of spline fits to the observed tau and the critical
    values ("fitted").

    Parameters
    ----------
    key : str
        Model key. ``'kra5'`` and ``'kra50'`` are treated as model
        normalizations, every other key as an E^2 dN/dE flux.
    tau_dict : dict
        Maps ``(key, apply_correction)`` to ``{norm: array of trial tau values}``.
    confidence_levels : list of float
        Quantiles of the trial distributions used as critical values.
    norm_bins : array-like or None
        Bin edges along the normalization axis of the histogram (a float copy
        is made, the outermost edges are widened by 1% of the first bin).
        If None, edges are placed halfway between neighbouring tested norms,
        i.e. bins centred on the norms when they are equally spaced.
    tau_bins : array-like
        Bin edges along the tau axis.
    bias_corr_funcs : dict or None
        None: no bias correction, trials ``(key, False)`` are used.
        Otherwise ``bias_corr_funcs[key]`` is forwarded to
        ``get_critical_value_trial`` and trials ``(key, True)`` are used.
    E0, unit : float
        Forwarded to ``get_critical_value_trial``; ``E0`` also appears in the x-label.
    ls_list : list of str
        Line styles cycled over the confidence levels.
    plot_splines : bool
        If True, additionally draw the spline fits.
    n_eval_points : int
        Number of normalizations at which the observed tau curve is evaluated.
    spline_s : float or None
        Smoothing factor ``s`` passed to ``UnivariateSpline``. None keeps
        scipy's default (which smooths relative to unit weights).

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    norm_obs_best : float
        Normalization minimizing the spline fit of the observed tau.
    intervalls_binned : dict
        ``{c: [low, high]}`` from the tested norms; ``[nan, nan]`` if no
        tested norm is accepted.
    intervalls_fitted : dict
        ``{c: array([low, high])}`` from the spline intersections.
    """
    if bias_corr_funcs is None:
        print('Not applying correction')
        apply_correction = False
        bias_corr_func_kw = None
    else:
        print('Applying correction')
        apply_correction = True
        bias_corr_func_kw = bias_corr_funcs[key]

    key_s = (key, apply_correction)
    is_model_norm = key in ['kra5', 'kra50']

    # Keep the per-norm trial arrays in a list: trial counts may differ between
    # normalizations, which np.array() cannot stack into a regular array.
    sorted_norms = sorted(tau_dict[key_s].keys())
    tau_values = [tau_dict[key_s][norm] for norm in sorted_norms]
    norm_values = np.array(sorted_norms)

    if norm_bins is None:
        # Edges halfway between neighbouring norms, extended symmetrically at
        # both ends (identical to +-half a step for equally spaced norms).
        mids = 0.5 * (norm_values[1:] + norm_values[:-1])
        norm_bins = np.r_[2 * norm_values[0] - mids[0], mids, 2 * norm_values[-1] - mids[-1]]
    else:
        # Float copy: never modify the caller's array (or the default argument)
        # and allow the in-place widening below even for integer inputs.
        norm_bins = np.array(norm_bins, dtype=float)

    # Widen the outermost edges so the first/last norm values lie inside the range.
    eps = 0.01 * np.diff(norm_bins)[0]
    norm_bins[0] -= eps
    norm_bins[-1] += eps

    n_bins = len(norm_bins) - 1
    n_taus = len(tau_bins) - 1

    # The optimizers below work on norm / x_scale: L-BFGS-B uses an absolute
    # finite-difference step (~1e-8), which is meaningless for fluxes ~1e-11.
    x_scale = norm_bins[-1] - norm_bins[0]

    def _observed_tau(norm):
        """Test statistic of the observed data (TRUTH=True) at normalization `norm`."""
        return get_critical_value_trial(
            E2dNdE_or_modelnorm=norm, tr=tr_dict[key], tr_inj=tr_dict[key], E0=E0, unit=unit, seed=42, TRUTH=True,
            is_model_norm=is_model_norm, bias_corr_func=bias_corr_func_kw,
        )

    # get tau values for observed data on a fine grid
    norm_values_obs = np.linspace(norm_bins[0], norm_bins[-1], n_eval_points)
    tau_observed = [_observed_tau(norm) for norm in tqdm(norm_values_obs, total=len(norm_values_obs))]

    # histogram of trial taus per tested norm + critical values (quantiles)
    hist = np.full((n_bins, n_taus), np.nan)
    critical_value_dict = {c: [] for c in confidence_levels}
    filled_bins = set()
    for norm, taus in zip(norm_values, tau_values):
        idx = np.searchsorted(norm_bins, norm) - 1
        assert 0 <= idx <= n_bins - 1, idx
        if idx in filled_bins:
            print('Warning: norm {} falls into an already filled norm bin; '
                  'the previous histogram column is overwritten'.format(norm))
        filled_bins.add(idx)

        hist[idx] = np.histogram(taus, bins=tau_bins, density=True)[0]

        for c in confidence_levels:
            critical_value_dict[c].append(np.quantile(taus, q=c))

    # ---------------------
    # Compute best fit flux
    # ---------------------
    spl_observed = UnivariateSpline(norm_values_obs, tau_observed, s=spline_s)
    norm_obs_best = optimize.minimize(
        lambda x_scaled: spl_observed(x_scaled * x_scale),
        x0=norm_values_obs[np.argmin(tau_observed)] / x_scale,
        bounds=[(np.min(norm_values_obs) / x_scale, np.max(norm_values_obs) / x_scale)],
    ).x[0] * x_scale

    # ---------------------------
    # compute intersection points
    # ---------------------------
    tau_observed_tested = np.array(
        [_observed_tau(norm) for norm in tqdm(norm_values, total=len(norm_values))])

    intervalls_binned = {}
    for c, critical_values in critical_value_dict.items():
        accepted = norm_values[np.array(critical_values) > tau_observed_tested]
        if accepted.size == 0:
            print('Warning: no tested norm accepted at {:3.1f}% CL'.format(c * 100))
            intervalls_binned[c] = [np.nan, np.nan]
        else:
            intervalls_binned[c] = [np.min(accepted), np.max(accepted)]

    # compute intersection points based on fitted splines
    c_splines_dict = {
        c: UnivariateSpline(norm_values, c_values, s=spline_s)
        for c, c_values in critical_value_dict.items()
    }

    def find_intersection(c, x0):
        """Find intersection points of observed-tau and critical-value splines near x0."""
        def fun(x_scaled):
            x = np.asarray(x_scaled) * x_scale
            return spl_observed(x) - c_splines_dict[c](x)
        sol = optimize.root(fun, x0=np.asarray(x0, dtype=float) / x_scale)
        if not sol.success:
            print('Warning: intersection search for {:3.1f}% CL did not converge: {}'.format(
                c * 100, sol.message))
        return np.sort(sol.x * x_scale)

    intervalls_fitted = {}
    for c, x0 in intervalls_binned.items():
        if np.any(np.isnan(x0)):
            # nothing accepted on the grid -> no starting point for the root search
            intervalls_fitted[c] = np.array([np.nan, np.nan])
            continue
        interval = find_intersection(c, x0=x0)
        assert len(interval) == 2, interval
        intervalls_fitted[c] = interval
    # ---------------------------

    # make plot
    units = 1 if is_model_norm else 1e-11
    fig, ax = plt.subplots(figsize=(9, 6))
    X, Y = np.meshgrid(norm_bins / units, tau_bins)
    im = ax.pcolor(X, Y, hist.T, norm=matplotlib.colors.LogNorm())

    ax.plot(
        norm_values_obs / units, tau_observed, color='1.0',
        label=r'Observed $\tau(\mathrm{flux})$' + ' [Minimum at: {:3.2e}]'.format(norm_obs_best),
    )

    if plot_splines:
        ax.plot(
            norm_values_obs / units, spl_observed(norm_values_obs), ls='--', lw=2., color='0.3',
            label=r'Observed $\tau(\mathrm{flux})$ [Spline-Fit]',
        )
        ls_cycler = cycle(ls_list)
        for c in critical_value_dict:
            ax.plot(
                norm_values_obs / units, c_splines_dict[c](norm_values_obs),
                color='1.0', lw=2., ls=next(ls_cycler),
                label='Critical values ({:3.1f}% CL) [Spline-Fit]'.format(c * 100),
            )

    ls_cycler = cycle(ls_list)
    for c, critical_values in critical_value_dict.items():
        ax.plot(
            norm_values / units, critical_values, color='r', ls=next(ls_cycler),
            label='Critical values ({:3.1f}% CL [{:3.2e}, {:3.2e}])'.format(c * 100, *intervalls_fitted[c]),
        )

    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label(r'$\log_{10} \, P(\tau | \mathrm{flux})$')
    # raw strings: LaTeX backslashes (\cdot, \mathrm, \,) are not Python escapes
    if units == 1e-11:
        units_label = r' $\cdot 10^{-11}$'
    elif units == 1.:
        units_label = ''
    else:
        units_label = r' $\cdot {:.0e}$'.format(units)

    if is_model_norm:
        ax.set_xlabel('Model Normalization')
    else:
        ax.set_xlabel(
            r'$\mathrm{E}^2 \cdot \mathrm{dN/dE}$' + units_label + ' at {:.0f} TeV'.format(E0) +
            r' [$\mathrm{TeV} \, \mathrm{s}^{-1} \, \mathrm{cm}^{-2}$]')
    ax.set_ylabel(r'$P(\tau | \mathrm{flux})$ (test-statistic $\tau$)')
    ax.legend()
    ax.set_ylim(np.min(tau_bins), np.max(tau_bins))

    return fig, ax, norm_obs_best, intervalls_binned, intervalls_fitted
    


In [None]:
# Pick the bias correction / output-file suffix matching the current settings.
if apply_correction:
    bias_corr_funcs_kw = bias_corr_funcs
    file_suffix = '_corrected'
else:
    bias_corr_funcs_kw = None
    file_suffix = ''

if add_systematics:
    file_suffix += '_sys_red_{:1.3f}_k_{}'.format(min_red_factor, max_k)

# tau_dict is keyed by (key, apply_correction) and only ever grows, so it can
# hold entries for both correction settings. Only plot the entries matching the
# current setting; otherwise identical plots are redone (and files overwritten),
# or make_critical_value_plot raises a KeyError for the missing setting.
keys_to_plot = [key for key, corr in tau_dict.keys() if corr == apply_correction]

for key in keys_to_plot:
    # NOTE(review): the tested norms themselves are passed as bin *edges* here;
    # norm_bins=None builds bins centred on the tested norms instead.
    for plot_splines, name_suffix in [(False, ''), (True, '_splines')]:
        fig, ax, norm_obs_best, intervalls_binned, intervalls_fitted = make_critical_value_plot(
            key, tau_dict=tau_dict, plot_splines=plot_splines, norm_bins=E2dNdE_or_modelnorm_dict[key],
            bias_corr_funcs=bias_corr_funcs_kw)
        ax.set_title('Model: {}'.format(key))
        fig.savefig('{}/confidence_intervals_{}{}{}.png'.format(plot_dir, key, file_suffix, name_suffix))

    print(key, norm_obs_best)

In [None]:
# NOTE(review): E2dNdE_obs_best is not assigned in this section; the plotting
# loop above stores its best fit in `norm_obs_best` (last plotted key).
# Confirm this name is defined by an earlier cell, otherwise this cell raises
# NameError on a fresh kernel.
E2dNdE_obs_best

In [None]:
# NOTE(review): 748.113 is hard-coded — presumably a fitted number of signal
# events converted to E^2 dN/dE; document where it comes from.
tr.to_E2dNdE(748.113, unit=1e3, E0=100)

In [None]:
# NOTE(review): 678 is hard-coded — presumably a number of signal events
# converted to E^2 dN/dE; document where it comes from.
tr.to_E2dNdE(678, unit=1e3, E0=100)