In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

#PDFs in BDT and sindec?
import os

# Optionally pin the BLAS/OpenMP backends used by numpy/scipy to a single
# thread each, so that the number of CPUs in use is controlled entirely by
# csky. The variables are set before numpy is imported further down.
# The block is currently disabled via `if False`.
if False:
    os.environ['MKL_NUM_THREADS'] = "1"
    os.environ['NUMEXPR_NUM_THREADS'] = "1"
    os.environ['OMP_NUM_THREADS'] = "1"
    os.environ['OPENBLAS_NUM_THREADS'] = "1"
    os.environ['VECLIB_MAXIMUM_THREADS'] = "1"

import matplotlib as mpl
mpl.rcParams['figure.facecolor'] = 'w'
mpl.rcParams['savefig.facecolor'] = 'w'
import matplotlib.pyplot as plt
from matplotlib import colors, cm
import csky as cy
from csky import cext
import numpy as np
import astropy
#from icecube import astro
import histlite as hl
import healpy
import healpy as hp
import socket
import pickle
import copy
# Turn off healpy's warnings for the rest of the session.
healpy.disable_warnings()
# Figure defaults: white background and 100 dpi.
plt.rc('figure', facecolor = 'w')
plt.rc('figure', dpi=100)

## Define Settings

In [None]:
# Version tag of the selection that is passed to cy.get_analysis.
selection_version = 'version-001-p01'

host_name = socket.gethostname()

# All paths are absolute and only valid on the cobalt machines, so any other
# host is rejected right away instead of failing later on a missing path.
if 'cobalt' in host_name:
    print('Working on Cobalts')
    data_prefix = '/data/user/ssclafani/data/cscd/final'
    ana_dir = '/data/user/ssclafani/data/analyses/'
    plot_dir = '/data/user/mhuennefeld/data/analyses/DNNCascadeCodeReview/unblinding_checks/plots/unblinding/sens_comparison_ps_cascade'

else:
    # Format the host name into the message: passing it as a second
    # positional argument makes the exception render as a tuple.
    raise ValueError('Unknown host: {}'.format(host_name))

In [None]:
# Make sure every output directory exists before anything is written to it.
for dir_path in [plot_dir]:
    if not os.path.exists(dir_path):
        print('Creating directory:', dir_path)
        # exist_ok=True: another process may create the directory between
        # the existence check above and this call; that must not be an error.
        os.makedirs(dir_path, exist_ok=True)

## Load Data

In [None]:
# Repository handle and the dataset specification (`DNNC_10yr`) that are
# passed to cy.get_analysis.
repo = cy.selections.Repository()
specs = cy.selections.DNNCascadeDataSpecs.DNNC_10yr

In [None]:
%%time

# Build the analysis object from the repository, the selection version and
# the data specs. The custom gamma grid below is left disabled.
ana = cy.get_analysis(
    repo, selection_version, specs, 
    #gammas=np.r_[0.1:6.01:0.125],
)

In [None]:
# Take the first sub-analysis and display its `sig` attribute
# (presumably the signal simulation sample -- confirm against csky).
a = ana.anas[0]
a.sig

In [None]:
# Display the `bg_data` attribute of sub-analysis `a`
# (presumably the data used for the background estimate -- confirm against csky).
a.bg_data

#### Load PSTracksV4

In [None]:
%%time

# Second analysis, built from the `ps_v4` data specs with its own repository
# instance; it backs the 'psv4' trial runners created by get_tr_tracks.
track_version = 'version-004-p00'
ana_tracks = cy.get_analysis(
    cy.selections.Repository(), track_version, cy.selections.PSDataSpecs.ps_v4, 
)

## Helpers

In [None]:
# `cycle` is defined in itertools. Importing it from `cycler` only worked
# because cycler imports it internally, which is an implementation detail.
from itertools import cycle
from copy import deepcopy

soft_colors = cy.plotting.soft_colors
# Colors of the current matplotlib property cycle.
# NOTE: this rebinds the name `colors`, which shadows the `matplotlib.colors`
# module imported in the first cell.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']


def get_bias_allt(tr, ntrials=200, n_sigs=np.r_[:101:10], quiet=False):
    """Run fits for several injected signal strengths and merge the results.

    Parameters
    ----------
    tr : object
        Trial runner providing ``get_many_fits``.
    ntrials : int
        Number of trials requested per injected signal strength.
    n_sigs : sequence of int
        Injected signal strengths. Each value is also used as the seed of
        its own batch of trials.
    quiet : bool
        If True, no progress output is printed.

    Returns
    -------
    The result of ``cy.utils.Arrays.concatenate`` over all batches, where
    each batch carries an extra ``ntrue`` column holding its injected value.
    """
    batches = []
    for n_sig in n_sigs:
        if not quiet:
            # carriage return keeps the progress on a single line
            print(f'\r{n_sig:4d} ...', end='', flush=True)
        batch = tr.get_many_fits(
            ntrials, n_sig=n_sig, logging=False, seed=n_sig)
        batches.append(batch)

    if not quiet:
        # terminate the progress line
        print()

    # tag every trial with the signal strength it was generated with
    for n_sig, batch in zip(n_sigs, batches):
        batch['ntrue'] = np.repeat(n_sig, len(batch))

    return cy.utils.Arrays.concatenate(batches)

def get_color_cycler():
    """Return a new, endlessly repeating iterator over the module-level ``colors``."""
    palette = cycle(colors)
    return palette

def plot_ns_bias(ax, tr, allt, label=''):
    """Plot the fitted ``ns`` against the injected ``ntrue`` on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    tr : object
        Not used by this function.
    allt : object
        Trials with ``ntrue`` and ``ns`` attributes, e.g. the output of
        ``get_bias_allt``.
    label : str
        Legend label of the plotted curve.
    """
    injected = np.unique(allt.ntrue)
    half_step = 0.5 * np.mean(np.diff(injected))
    # bin edges placed half a step below each injected value, plus one
    # closing edge half a step above the largest value
    edges = np.r_[injected - half_step, injected[-1] + half_step]

    hist2d = hl.hist((allt.ntrue, allt.ns), bins=(edges, 100))
    hl.plot1d(
        ax, hist2d.contain_project(1),
        errorbands=True, drawstyle='default', label=label)

    limits = edges[[0, -1]]
    # the limits reported back by set_ylim are applied to the x-axis too
    applied = ax.set_ylim(limits)
    ax.set_xlim(applied)

    # diagonal marking n_s == n_inj
    ax.plot(limits, limits, color='C0', ls='--', lw=1, zorder=-10)
    ax.set_aspect('equal')

    ax.set_xlabel(r'$n_{inj}$')
    ax.set_ylabel(r'$n_s$')
    ax.grid()

def plot_gamma_bias(ax, tr, allt, label=''):
    """Plot the fitted ``gamma`` against the injected ``ntrue`` on ``ax``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw on.
    tr : object
        Trial runner; the gamma of the first flux of its first signal
        injector is drawn as the reference line.
    allt : object
        Trials with ``ntrue`` and ``gamma`` attributes, e.g. the output of
        ``get_bias_allt``.
    label : str
        Legend label of the plotted curve.
    """
    injected = np.unique(allt.ntrue)
    half_step = 0.5 * np.mean(np.diff(injected))
    # bin edges placed half a step below each injected value, plus one
    # closing edge half a step above the largest value
    edges = np.r_[injected - half_step, injected[-1] + half_step]
    true_gamma = tr.sig_injs[0].flux[0].gamma

    hist2d = hl.hist((allt.ntrue, allt.gamma), bins=(edges, 100))
    hl.plot1d(
        ax, hist2d.contain_project(1),
        errorbands=True, drawstyle='default', label=label)

    ax.set_xlim(edges[[0, -1]])
    ax.set_ylim(1, 4)
    # horizontal reference line at the injected spectral index
    ax.axhline(true_gamma, color='C0', ls='--', lw=1, zorder=-10)

    ax.set_xlabel(r'$n_{inj}$')
    ax.set_ylabel(r'$\gamma$')
    ax.grid()

def plot_bkg_trials(
        bg, fig=None, ax=None,
        label='{} bg trials',
        label_fit=r'$\chi^2[{:.2f}\mathrm{{dof}},\ \eta={:.3f}]$',
        color=colors[0],
        density=False,
        bins=50):
    """Plot a histogram of background TS values, with the fitted pdf if available.

    Parameters
    ----------
    bg : object
        TS distribution providing ``get_hist`` and ``n_total``. If it also
        has a ``pdf`` attribute, the pdf is drawn as a dashed line and
        ``ndof`` / ``eta`` are used for its label.
    fig, ax : matplotlib figure / axes, optional
        A new 6x4 figure is created when ``ax`` is None.
    label : str or None
        Format string for the histogram label; receives ``bg.n_total``.
    label_fit : str or None
        Format string for the pdf label; receives ``bg.ndof`` and ``bg.eta``.
    color : color
        Color shared by the histogram and the pdf curve.
    density : bool
        If True the histogram is normalized and the pdf drawn unscaled;
        otherwise the pdf is scaled by the histogram integral.
    bins : int
        Passed to ``bg.get_hist``.

    Returns
    -------
    (fig, ax). ``fig`` is None when an ``ax`` was passed without a ``fig``.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 4))

    hist = bg.get_hist(bins=bins)
    if density:
        hist = hist.normalize()

    trials_label = None if label is None else label.format(bg.n_total)
    hl.plot1d(ax, hist, crosses=True, color=color, label=trials_label)

    # overlay the fitted distribution when the object provides one
    if hasattr(bg, 'pdf'):
        centers = hist.centers[0]
        norm = hist.integrate().values
        fit_label = (
            None if label_fit is None
            else label_fit.format(bg.ndof, bg.eta))
        curve = bg.pdf(centers) if density else norm * bg.pdf(centers)
        ax.semilogy(
            centers, curve, lw=1, ls='--', label=fit_label, color=color)

    ax.set_xlabel(r'TS')
    ax.set_ylabel(r'Density' if density else r'number of trials')
    ax.legend()

    return fig, ax

## Setup Analysis

In [None]:
import sys
# Put the directory two levels up at the front of the import path so that
# the `config` module imported next is found there.
sys.path.insert(0, '../..')

import config as cg

# Override `base_dir` on the config module.
cg.base_dir = '/data/user/mhuennefeld/data/analyses/unblinding_v1.0.1/'

In [None]:
def get_tr(sindec, gamma, extension=0., cutoff=np.inf, cpus=20, sigsub=True):
    """Create a trial runner on the global ``ana`` for a source at RA = 0.

    Parameters
    ----------
    sindec : float
        Sine of the source declination.
    gamma : float
        Spectral index passed to ``cg.get_ps_conf``.
    extension : float
        Source extension in radians.
    cutoff : float
        Energy cutoff in TeV; converted to GeV before it is passed on.
    cpus : int
        Passed to the trial runner as ``mp_cpus``.
    sigsub : bool
        Passed to ``cg.get_ps_conf``.

    Returns
    -------
    (tr, src): the trial runner and the source it was built for.
    """
    dec = np.arcsin(sindec)
    src = cy.utils.sources(0, dec, extension=extension, deg=False)
    conf = cg.get_ps_conf(
        src=src,
        gamma=gamma,
        cutoff_GeV=cutoff * 1e3,  # TeV -> GeV
        sigsub=sigsub,
    )
    return cy.get_trial_runner(ana=ana, conf=conf, mp_cpus=cpus), src

def get_tr_tracks(sindec, gamma, extension=0., cutoff=np.inf, cpus=20, sigsub=False):
    """Create a trial runner on the global ``ana_tracks`` for a source at RA = 0.

    Parameters
    ----------
    sindec : float
        Sine of the source declination.
    gamma : float
        Spectral index of the power-law flux.
    extension : float
        Source extension in radians.
    cutoff : float
        Energy cutoff in TeV; converted to GeV for the flux object.
    cpus : int
        Passed to the trial runner as ``mp_cpus``.
    sigsub : bool
        Value of the ``sigsub`` entry of the configuration.

    Returns
    -------
    (tr, src): the trial runner and the source it was built for.
    """
    dec = np.arcsin(sindec)
    src = cy.utils.sources(0, dec, extension=extension, deg=False)
    conf = dict(
        src=src,
        flux=cy.hyp.PowerLawFlux(gamma, energy_cutoff=cutoff * 1e3),  # TeV -> GeV
        update_bg=True,
        sigsub=sigsub,
        randomize=['ra'],
    )
    return cy.get_trial_runner(ana=ana_tracks, conf=conf, mp_cpus=cpus), src


#### Get TrialRunners

In [None]:
# All trial runners are built for a source at a declination of 22 degrees.
sindec = np.sin(np.deg2rad(22))
tr_dict = {'dnnc': {}, 'psv4': {}}

# Every combination of spectral index, cutoff (TeV) and extension (radians);
# the tuple doubles as the dictionary key.
param_grid = [
    (gamma_i, cutoff_i, extension_i)
    for gamma_i in [2.7, 3.0]
    for cutoff_i in [15, np.inf]
    for extension_i in np.deg2rad([0., 3.])
]

for gamma, cutoff, extension in param_grid:
    settings = dict(
        sindec=sindec, gamma=gamma, cutoff=cutoff, extension=extension)
    # both helpers return (trial_runner, source); only the runner is kept
    tr_dict['dnnc'][(gamma, cutoff, extension)] = get_tr(**settings)[0]
    tr_dict['psv4'][(gamma, cutoff, extension)] = get_tr_tracks(**settings)[0]


#### Get background trials for each selection (using the first trial runner of each)

In [None]:
bkg_file_dict = {}
n_bkg_trials = 20000
seed = 1337
recalculate = False

# Cache file for the background trials.
# NOTE: the cache is identified by its file name only. After changing
# n_bkg_trials, seed or the trial runners, set recalculate = True, otherwise
# the old trials are silently reused.
bkg_file = os.path.join(plot_dir, 'trials_bkg.pkl')


if os.path.exists(bkg_file) and not recalculate:
    print('Reloading background trials')
    # Open the same path that was checked above instead of re-building it.
    # pickle can execute arbitrary code on load, so only cache files written
    # by this notebook should be placed here.
    with open(bkg_file, 'rb') as f:
        bkg_dict = pickle.load(f)
else:
    bkg_dict = {}
    for key, tr_dict_i in tr_dict.items():
        # only the first trial runner of each selection is used to
        # generate the background trials
        tr = list(tr_dict_i.values())[0]
        print('Running background trials for trial runner {}'.format(key))
        bkg_dict[key] = tr.get_many_fits(
            n_trials=n_bkg_trials, seed=seed, mp_cpus=20)

    with open(bkg_file, 'wb') as f:
        pickle.dump(bkg_dict, f, protocol=2)


#### Plot ts distribution

In [None]:
# For each selection: wrap the background trials in a Chi2TSD, plot the TS
# histogram together with the fitted pdf, and mark the 3-sigma TS value.
for key, bg in bkg_dict.items():
    bg_tsd = cy.dists.Chi2TSD(bg)
    fig, ax = plot_bkg_trials(bg_tsd)
    # TS value returned by the distribution for a significance of 3 sigma
    ts_3sig = bg_tsd.isf_nsigma(3)
    ax.axvline(
        ts_3sig, ls='--', lw=1,
        label='3-sigma TS: {:3.3f}'.format(ts_3sig), 
    )
    ax.set_title('Trial Runner: {}'.format(key))
    ax.set_yscale('log')
    # second legend call so that the axvline entry added above is included
    ax.legend()
    #fig.savefig('{}/ts_dist_{}.png'.format(plot_dir, key))

#### Compute Sensitivity

In [None]:
# Sensitivity per selection (key1) and per (gamma, cutoff, extension)
# configuration (key2); same nesting as tr_dict.
sens_dict = {k: {} for k in tr_dict.keys()}

for key1, tr_dict_i  in tr_dict.items():
    # NOTE(review): bkg_dict holds one set of background trials per
    # selection, so the same TS threshold is applied to every configuration
    # of that selection -- confirm this is intended.
    bg = cy.dists.Chi2TSD(bkg_dict[key1])
    
    for key2, tr in tr_dict_i.items():
        print('Computing Sensitivity for {}: {}'.format(key1, key2))
        sens = tr.find_n_sig(
                # TS threshold: median of the background TS distribution
                ts=bg.median(),
                # beta, fraction of trials which should exceed the threshold
                beta=0.9,
                # n_inj step size for initial scan
                n_sig_step=20,
                # this many trials at a time
                batch_size=500,
                # tolerance, as estimated relative error
                tol=.05,
                first_batch_size = 250,
                mp_cpus=20,
                seed=seed
            )
        # Convert the signal event count into a flux and store it alongside
        # the find_n_sig result.
        # NOTE(review): presumably E0=1 with unit=1e3 means E^2 dN/dE at
        # 1 TeV in TeV-based units -- confirm against the csky documentation.
        sens['flux'] = tr.to_E2dNdE(sens['n_sig'], E0=1, unit=1e3)
        print('flux:', sens['flux'])
        sens_dict[key1][key2] = sens

In [None]:
# Print, for every configuration, both sensitivity fluxes and their ratio.
header_fmt = 'Gamma: {:3.2f} | Cutoff: {:3.2f} TeV | Extension: {:3.3f}'
ratio_fmt = '  PS: {:3.3e} | DNNC: {:3.3e} | DNNC/PS: {:3.3f}'

for key2 in sens_dict['dnnc']:
    ps_sens = sens_dict['psv4'][key2]['flux']
    dnnc_sens = sens_dict['dnnc'][key2]['flux']

    # key2 is (gamma, cutoff, extension); the extension is stored in
    # radians and shown in degrees
    gamma_val, cutoff_val, extension_val = key2[0], key2[1], key2[2]
    print(header_fmt.format(gamma_val, cutoff_val, np.rad2deg(extension_val)))
    print(ratio_fmt.format(ps_sens, dnnc_sens, dnnc_sens/ps_sens))