In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

#PDFs in BDT and sindec?
import os

# set env flags to catch BLAS used for scipy/numpy
# to only use 1 cpu, n_cpus will be totally controlled by csky
# NOTE: these variables only take effect if they are set before numpy/scipy
# are imported (below), so keep this block above those imports.
# Change `if False` to `if True` to enable the single-thread setting.
if False:
    os.environ['MKL_NUM_THREADS'] = "1"
    os.environ['NUMEXPR_NUM_THREADS'] = "1"
    os.environ['OMP_NUM_THREADS'] = "1"
    os.environ['OPENBLAS_NUM_THREADS'] = "1"
    os.environ['VECLIB_MAXIMUM_THREADS'] = "1"

import matplotlib as mpl
# white figure/saved-figure background (also set via plt.rc below)
mpl.rcParams['figure.facecolor'] = 'w'
mpl.rcParams['savefig.facecolor'] = 'w'
import matplotlib.pyplot as plt
# NOTE: `colors` is rebound to a list of color strings in the Helpers cell;
# use `mpl.colors` when the matplotlib module is needed.
from matplotlib import colors, cm
import csky as cy
from csky import cext
import numpy as np
import astropy
#from icecube import astro
import histlite as hl
# both names are used throughout the notebook (`healpy.` and `hp.`)
import healpy
import healpy as hp
import socket
import pickle
import copy
healpy.disable_warnings()
plt.rc('figure', facecolor = 'w')
plt.rc('figure', dpi=100)

## Define Settings

In [None]:
selection_version = 'version-001-p00'

host_name = socket.gethostname()

# Data and output paths are only configured for the cobalt machines; fail
# loudly on any other host instead of silently using wrong paths.
if 'cobalt' in host_name:
    print('Working on Cobalts')
    data_prefix = '/data/user/ssclafani/data/cscd/final'
    ana_dir = '/data/user/ssclafani/data/analyses/'
    plot_dir = '/data/user/mhuennefeld/data/analyses/DNNCascadeCodeReview/unblinding_checks/plots/followup/template_stacking_feasibility'

else:
    # format the host into the message (passing it as a second argument
    # would render the exception message as a tuple)
    raise ValueError('Unknown host: {}'.format(host_name))

In [None]:
# Create the output directories. `exist_ok=True` removes the race between
# the existence check and the creation (e.g. parallel notebook runs).
for dir_path in [plot_dir]:
    if not os.path.exists(dir_path):
        print('Creating directory:', dir_path)
        os.makedirs(dir_path, exist_ok=True)

## Load Data

In [None]:
# Data-sample specification (DNN cascades, 10 years) and the csky
# repository object used to load it.
specs = cy.selections.DNNCascadeDataSpecs.DNNC_10yr
repo = cy.selections.Repository()

In [None]:
%%time

ana = cy.get_analysis(
    repo, selection_version, specs, 
    #gammas=np.r_[0.1:6.01:0.125],
)

In [None]:
# Inspect the signal simulation (`sig`) of the first sub-analysis.
# NOTE(review): presumably the DNNC sample has a single sub-analysis - confirm.
a = ana.anas[0]
a.sig

In [None]:
# Display `bg_data` of the first sub-analysis (presumably the experimental
# data used for background estimation - confirm).
a.bg_data

## Helpers

In [None]:
# `cycle` is the standard-library itertools.cycle. The `cycler` package does
# not document it as public API, so import it from its actual home.
from itertools import cycle
from copy import deepcopy

soft_colors = cy.plotting.soft_colors
# NOTE: this list of color strings shadows the `matplotlib.colors` module
# imported at the top of the notebook; use `mpl.colors` for the module.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']


def get_bias_allt(tr, ntrials=200, n_sigs=np.r_[:101:10], quiet=False):
    """Run signal-injection trials for a range of injected signal strengths.

    Parameters
    ----------
    tr : csky trial runner
        Trial runner providing `get_many_fits`.
    ntrials : int, optional
        Number of trials per injected signal strength.
    n_sigs : array_like, optional
        Injected numbers of signal events. Each value is also used as the
        random seed for its trials.
    quiet : bool, optional
        If False, print a progress counter.

    Returns
    -------
    Arrays
        All trials concatenated, with an additional `ntrue` field holding
        the injected n_sig of each trial.
    """
    batches = []
    for n_inj in n_sigs:
        if not quiet:
            print(f'\r{n_inj:4d} ...', end='', flush=True)
        batch = tr.get_many_fits(ntrials, n_sig=n_inj, logging=False, seed=n_inj)
        # remember the injected strength of every trial in this batch
        batch['ntrue'] = np.repeat(n_inj, len(batch))
        batches.append(batch)
    if not quiet:
        print()
    return cy.utils.Arrays.concatenate(batches)

def get_color_cycler(color_list=None):
    """Return an endless iterator over plot colors.

    Parameters
    ----------
    color_list : sequence, optional
        Colors to cycle through. Defaults to the module-level `colors`
        (matplotlib's default property-cycle colors).

    Returns
    -------
    iterator
        `itertools.cycle` over the colors.
    """
    # use the stdlib directly rather than a name re-exported by `cycler`
    import itertools

    if color_list is None:
        color_list = colors
    return itertools.cycle(color_list)

def plot_bias(ax, x_fit, y_true, label=''):
    """Plot the distribution of fitted values per true (injected) value.

    One bin is created per unique true value; the fitted values are
    summarized with histlite's `contain_project(1)` and drawn as an error
    band, together with the unbiased expectation x_fit == y_true.

    Parameters
    ----------
    ax : matplotlib axis
        Axis to draw into.
    x_fit : array_like
        Fitted values.
    y_true : array_like
        True values, same length as `x_fit`.
    label : str, optional
        Legend label for the band.

    Returns
    -------
    histlite histogram
        The 2D histogram of (y_true, x_fit).
    """
    y_unique = np.unique(y_true)
    # Bin width = mean spacing of the true values. With a single true value
    # np.diff is empty and its mean is NaN, so fall back to a width of 1.
    dy = np.mean(np.diff(y_unique)) if len(y_unique) > 1 else 1.
    y_bins = np.r_[y_unique - 0.5*dy, y_unique[-1] + 0.5*dy]
    expect_kw = dict(color='C0', ls='--', lw=1, zorder=-10)

    h = hl.hist((y_true, x_fit), bins=(y_bins, 100))
    hl.plot1d(ax, h.contain_project(1), errorbands=True,
              drawstyle='default', label=label)
    lim = y_bins[[0, -1]]
    # set_ylim returns the new limits; reuse them for x to get a square plot
    ax.set_xlim(ax.set_ylim(lim))
    ax.plot(lim, lim, **expect_kw)
    ax.set_aspect('equal')

    ax.grid()
    return h

def plot_ns_bias(ax, tr, allt, label=''):
    """Plot fitted n_s against injected n_s for signal-injection trials.

    Parameters
    ----------
    ax : matplotlib axis
        Axis to draw into.
    tr : csky trial runner
        Unused; kept so the signature matches `plot_gamma_bias`.
    allt : Arrays
        Trials with fields `ntrue` (injected) and `ns` (fitted), e.g. as
        produced by `get_bias_allt`.
    label : str, optional
        Legend label for the band.

    Returns
    -------
    histlite histogram
        The 2D histogram of (ntrue, ns), consistent with `plot_bias`.
    """
    n_sigs = np.unique(allt.ntrue)
    # Bin width from the injection grid. With a single injected value
    # np.diff is empty (mean -> NaN), so fall back to a width of 1.
    dns = np.mean(np.diff(n_sigs)) if len(n_sigs) > 1 else 1.
    ns_bins = np.r_[n_sigs - 0.5*dns, n_sigs[-1] + 0.5*dns]
    expect_kw = dict(color='C0', ls='--', lw=1, zorder=-10)

    h = hl.hist((allt.ntrue, allt.ns), bins=(ns_bins, 100))
    hl.plot1d(ax, h.contain_project(1), errorbands=True,
              drawstyle='default', label=label)
    lim = ns_bins[[0, -1]]
    # set_ylim returns the new limits; reuse them for x to get a square plot
    ax.set_xlim(ax.set_ylim(lim))
    # unbiased expectation: n_s == n_inj
    ax.plot(lim, lim, **expect_kw)
    ax.set_aspect('equal')

    ax.set_xlabel(r'$n_{inj}$')
    ax.set_ylabel(r'$n_s$')
    ax.grid()
    return h

def plot_gamma_bias(ax, tr, allt, label=''):
    """Plot the fitted spectral index against injected n_s.

    Parameters
    ----------
    ax : matplotlib axis
        Axis to draw into.
    tr : csky trial runner
        Used to read the injected spectral index
        (`tr.sig_injs[0].flux[0].gamma`), drawn as a horizontal line.
    allt : Arrays
        Trials with fields `ntrue` (injected n_s) and `gamma` (fitted).
    label : str, optional
        Legend label for the band.

    Returns
    -------
    histlite histogram
        The 2D histogram of (ntrue, gamma), consistent with `plot_bias`.
    """
    n_sigs = np.unique(allt.ntrue)
    # Bin width from the injection grid. With a single injected value
    # np.diff is empty (mean -> NaN), so fall back to a width of 1.
    dns = np.mean(np.diff(n_sigs)) if len(n_sigs) > 1 else 1.
    ns_bins = np.r_[n_sigs - 0.5*dns, n_sigs[-1] + 0.5*dns]
    expect_kw = dict(color='C0', ls='--', lw=1, zorder=-10)
    expect_gamma = tr.sig_injs[0].flux[0].gamma

    h = hl.hist((allt.ntrue, allt.gamma), bins=(ns_bins, 100))
    hl.plot1d(ax, h.contain_project(1), errorbands=True,
              drawstyle='default', label=label)
    lim = ns_bins[[0, -1]]
    ax.set_xlim(lim)
    ax.set_ylim(1, 4)
    ax.axhline(expect_gamma, **expect_kw)

    ax.set_xlabel(r'$n_{inj}$')
    ax.set_ylabel(r'$\gamma$')
    ax.grid()
    return h

def plot_bkg_trials(
            bg, fig=None, ax=None,
            label='{} bg trials',
            label_fit=r'$\chi^2[{:.2f}\mathrm{{dof}},\ \eta={:.3f}]$',
            color=colors[0],
            density=False,
            bins=50,
        ):
    """Plot a background TS distribution and, if available, its fitted pdf.

    Parameters
    ----------
    bg : csky TSD
        Background test-statistic distribution (provides `get_hist` and
        `n_total`; `pdf`, `ndof` and `eta` are used when `pdf` exists).
    fig, ax : matplotlib figure / axis, optional
        Where to draw. A new figure is created if `ax` is None. If only `ax`
        is given, its own figure is returned.
    label : str or None, optional
        Histogram label, formatted with `bg.n_total`.
    label_fit : str or None, optional
        Label of the fitted pdf, formatted with `bg.ndof` and `bg.eta`.
    color : str, optional
        Color for histogram and pdf.
    density : bool, optional
        If True, normalize the histogram and plot the pdf as is; otherwise
        plot counts and scale the pdf by the histogram integral.
    bins : int or array_like, optional
        Passed to `bg.get_hist`.

    Returns
    -------
    fig, ax
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 4))
    elif fig is None:
        # an axis was passed without its figure: return the axis' figure
        # instead of None
        fig = ax.figure

    h = bg.get_hist(bins=bins)
    if density:
        h = h.normalize()
    if label is not None:
        label = label.format(bg.n_total)
    hl.plot1d(ax, h, crosses=True, color=color, label=label)

    # compare with the fitted pdf, if the TSD provides one
    if hasattr(bg, 'pdf'):
        x = h.centers[0]
        if label_fit is not None:
            label_fit = label_fit.format(bg.ndof, bg.eta)
        # a count histogram needs the pdf scaled by the histogram integral
        scale = 1. if density else h.integrate().values
        ax.semilogy(x, scale * bg.pdf(x), lw=1, ls='--', label=label_fit, color=color)

    ax.set_xlabel(r'TS')
    if density:
        ax.set_ylabel(r'Density')
    else:
        ax.set_ylabel(r'number of trials')
    ax.legend()

    return fig, ax

## Setup Analysis

In [None]:
import sys
# Make the repository root importable (for `config`). Only add it once so
# that re-running this cell does not keep growing sys.path.
if '../..' not in sys.path:
    sys.path.insert(0, '../..')

import config as cg

cg.base_dir = '/data/user/mhuennefeld/data/analyses/unblinding_v1.0.0/'

In [None]:
def get_gp_tr(template_str, cutoff=np.inf, gamma=None, cpus=20):
    """Build a csky trial runner for one of the Galactic-plane templates.

    Parameters
    ----------
    template_str : str
        Template key understood by `cg.get_gp_conf`.
    cutoff : float, optional
        Energy cutoff in TeV; converted to GeV for the configuration.
    gamma : float, optional
        Spectral index forwarded to `cg.get_gp_conf`.
    cpus : int, optional
        Number of processes used by the trial runner.

    Returns
    -------
    csky trial runner
    """
    return cy.get_trial_runner(
        cg.get_gp_conf(
            template_str=template_str,
            gamma=gamma,
            cutoff_GeV=cutoff * 1e3,  # TeV -> GeV
            base_dir=cg.base_dir,
        ),
        ana=ana,
        mp_cpus=cpus,
    )

def get_template_tr(template, gamma=2.7, cutoff_tev=np.inf, cpus=20, **kwargs):
    """Build a csky trial runner for an arbitrary template map.

    The signal is a power law with fixed spectral index `gamma` (also fixed
    in the fit) and an optional energy cutoff.

    Parameters
    ----------
    template : array_like
        Template map passed to csky.
    gamma : float, optional
        Spectral index of injection and fit.
    cutoff_tev : float, optional
        Energy cutoff in TeV (converted to GeV).
    cpus : int, optional
        Number of processes used by the trial runner.
    **kwargs
        Additional configuration entries; these override the defaults below.

    Returns
    -------
    csky trial runner
    """
    conf = dict(
        template=template,
        flux=cy.hyp.PowerLawFlux(gamma, energy_cutoff=cutoff_tev * 1000.),
        randomize=['ra'],
        fitter_args=dict(gamma=gamma),
        sigsub=True,
        update_bg=True,
        fast_weight=False,
    )
    # caller-provided options take precedence over the defaults
    conf.update(kwargs)
    return cy.get_trial_runner(conf, ana=ana, mp_cpus=cpus)


#### Get TrialRunners

In [None]:
# One trial runner per Galactic-plane template (cutoff in TeV).
tr_dict = dict(
    fermibubbles_50TeV=get_gp_tr('fermibubbles', cutoff=50),
    pi0=get_gp_tr('pi0'),
    kra5=get_gp_tr('kra5'),
    kra50=get_gp_tr('kra50'),
)

#### Get Results for each template

In [None]:
# Unblinded results per template; missing result files are reported and skipped.
res_dict = {}
for key in tr_dict:
    f_path = os.path.join(
        cg.base_dir, 'gp/results/{}/{}_unblinded.npy'.format(key, key))
    if not os.path.exists(f_path):
        print('File does not exist: {}'.format(f_path))
        continue
    res_dict[key] = np.load(f_path)

#### Print best fit fluxes

In [None]:

# Convert the unblinded best-fit n_s of the pi0 template (index 1 of the
# result array) into flux normalizations. E0 = 1e5 GeV for dN/dE; the E^2 dN/dE
# call presumably gives E0 in TeV via unit=1e3 (i.e. also 100 TeV) - confirm.
pi0_tr = tr_dict['pi0']
pi0_ns = res_dict['pi0'][1]
dNdE = pi0_tr.to_dNdE(ns=pi0_ns, E0=1e5)
E2dNdE = pi0_tr.to_E2dNdE(ns=pi0_ns, E0=100, unit=1e3)
print(dNdE, E2dNdE)


#### Get bkg fits for each template

In [None]:
# Background trials per template: load pre-computed trials where available,
# otherwise run them here. Only the TS values are stored, so every entry of
# `bkg_dict` has the same layout (needed by `cy.dists.TSD(bg)` and
# `bg_tsd.sf(bg)` further below).
bkg_file_dict = {
    key: '{}/gp/trials/{}/{}/trials.dict'.format(cg.base_dir, 'DNNC', key)
    for key in ['pi0', 'kra5', 'kra50']
}
n_bkg_trials = 20000
seed = 1337

bkg_dict = {}
for key, tr in tr_dict.items():
    if 'fermibubbles' in key:
        continue
    if key in bkg_file_dict:
        print('Loading background trials for template {}'.format(key))
        # NOTE: allow_pickle may execute pickled code - only load trusted files
        bkg_trials = np.load(bkg_file_dict[key], allow_pickle=True)
        bkg_dict[key] = bkg_trials['poisson']['nsig'][0.0]['ts']

    else:
        print('Running background trials for template {}'.format(key))
        bkg_trials = tr.get_many_fits(
            n_trials=n_bkg_trials, seed=seed, mp_cpus=20)
        # keep only TS, matching the loaded trials above
        bkg_dict[key] = bkg_trials.ts
        

In [None]:
# Number of background trials available per template.
for key in bkg_dict:
    print(key, len(bkg_dict[key]))

#### Plot ts distribution

In [None]:
# Background TS distribution per template, with the unblinded TS and the TS
# corresponding to 5 sigma.
for key, bg in bkg_dict.items():
    bg_tsd = cy.dists.TSD(bg)
    ts_5sig = bg_tsd.isf_nsigma(5)
    ts, ns = res_dict[key][0], res_dict[key][1]

    fig, ax = plot_bkg_trials(bg_tsd)
    ax.axvline(ts, color='0.8', ls='--', lw=2,
               label='TS: {:3.3f} | ns: {:3.1f}'.format(ts, ns))
    ax.axvline(ts_5sig, ls='--', lw=1,
               label='5-sigma TS: {:3.3f}'.format(ts_5sig))
    ax.set_title('Template: {}'.format(key))
    ax.set_yscale('log')
    ax.legend()
    fig.savefig('{}/ts_dist_{}.png'.format(plot_dir, key))

#### Compute Significance

In [None]:
# Convert every background trial into a p-value / n-sigma with respect to the
# background distribution of its own template.
p_val_dict, sigma_dict = {}, {}
max_n = 300000000  # cap on the number of trials used per template
for key, bg in bkg_dict.items():
    print(key)
    bg_subset = bg[:max_n]
    bg_tsd = cy.dists.TSD(bg_subset)
    p_val_dict[key] = bg_tsd.sf(bg_subset)
    sigma_dict[key] = bg_tsd.sf_nsigma(bg_subset)


#### Plot Trial Correlation

In [None]:
# Boolean mask of background trials that exceed `sigma_threshold` in at least
# one template. NOTE: the following plotting cell resets `mask` to None, so
# this selection is currently not applied there.
sigma_threshold = 0.5
mask = np.zeros(len(bkg_dict['pi0'][:max_n]), dtype=bool)
# iterate the values directly so the global trial-runner name `tr` is not
# clobbered by a sigma array
for trial_sigmas in sigma_dict.values():
    mask = np.logical_or(mask, trial_sigmas > sigma_threshold)
    

In [None]:
import matplotlib as mpl

def plot_corr_ax(ax, key1, key2, mask=None, norm=None, bins=None):
    """2D histogram of the background-trial significances of two templates.

    Parameters
    ----------
    ax : matplotlib axis
        Axis to draw into.
    key1, key2 : str
        Keys into `sigma_dict` for the x and the y axis.
    mask : array of bool, optional
        Subset of trials to show; all trials if None.
    norm : matplotlib.colors.Normalize, optional
        Color normalization passed to `hist2d`.
    bins : tuple of arrays, optional
        (x_bins, y_bins). Defaults to 49 equal bins in [0, 6] on both axes.
    """
    # explicit default instead of silently reading a global `bins`
    if bins is None:
        bins = (np.linspace(0, 6, 50), np.linspace(0, 6, 50))
    if mask is None:
        mask = np.ones_like(sigma_dict[key1], dtype=bool)

    ax.hist2d(
        sigma_dict[key1][mask], sigma_dict[key2][mask],
        bins=bins, norm=norm, cmin=1,
    )
    # diagonal: equal significance in both templates
    ax.plot(
        (bins[0][0], bins[0][-1]), (bins[0][0], bins[0][-1]),
        ls='--', color='0.7', lw=3,
    )
    # raw strings: '\c' and '\s' are invalid escape sequences otherwise
    ax.set_xlabel(r'$n\cdot \sigma$ of {}'.format(key1))
    ax.set_ylabel(r'$n\cdot \sigma$ of {}'.format(key2))

fig, axes = plt.subplots(3, 1, figsize=(9, 9))

bins = (np.linspace(0, 6, 50), np.linspace(0, 6, 50))
norm = mpl.colors.LogNorm(vmin=1, vmax=1e5)
mask = None
plot_corr_ax(axes[0], 'pi0', 'kra5', mask=mask, norm=norm, bins=bins)
plot_corr_ax(axes[1], 'kra5', 'kra50', mask=mask, norm=norm, bins=bins)
plot_corr_ax(axes[2], 'pi0', 'kra50', mask=mask, norm=norm, bins=bins)
fig.tight_layout()
fig.savefig('{}/gp_trial_correlation.png'.format(plot_dir))


In [None]:
# Per background trial: the largest significance among the correlated
# templates. Its distribution is used for the correlated post-trial correction.
corr_keys = ['pi0', 'kra5', 'kra50']

max_nsigma = np.stack([sigma_dict[k] for k in corr_keys]).max(axis=0)
bg_max = cy.dists.TSD(max_nsigma)

In [None]:
from scipy import stats

# Pre-trial significance of the unblinded pi0 result (index 3 of the result
# array; the inline note records the value seen at unblinding).
nsigma_chosen = res_dict['pi0'][3]  # 4.705
pval_chosen = stats.norm.sf(nsigma_chosen)
# Correlated correction: compare with the distribution of the maximum
# n-sigma over `corr_keys` from the same background trials.
nsigma_corrected = bg_max.sf_nsigma(nsigma_chosen)
pval_corrected = bg_max.sf(nsigma_chosen)
if pval_corrected == 0:
    # No background trial reaches the unblinded significance, so the
    # correlated correction is limited by the number of available trials.
    print('WARNING: no background trial exceeds {:.3f} sigma; the correlated '
          'post-trial values are not reliable.'.format(nsigma_chosen))

fig, ax = plt.subplots()
ax.hist(max_nsigma, bins=np.linspace(0, 6, 200), label='Correlated bkg trials')
ax.set_xlabel('Max n-sigma')
ax.set_ylabel('Number of trials')
ax.set_yscale('log')
ax.axvline(
    nsigma_chosen, ls='--', color='0.7',
    # raw string: '\s' is an invalid escape sequence otherwise
    label=r'Unblinded: {:3.3f}$\sigma$ | Corrected: {:3.3f}$\sigma$'.format(
        nsigma_chosen, nsigma_corrected),
)
ax.legend(loc='upper right')
fig.savefig('{}/gp_trial_correction_hist.png'.format(plot_dir))

# Conservative (Bonferroni) correction: multiply the p-value by the number of
# templates tested.
n_templates = len(corr_keys)
pval_conservative = pval_chosen * n_templates
print('Correcting for: {}'.format(corr_keys))
print('Pre-trial N-sigma of: {}'.format(nsigma_chosen))
print('Pre-trial p-value of: {}'.format(pval_chosen))
print('Post-trial correlated n-sigma: {} | factor: {}'.format(nsigma_corrected, pval_corrected/pval_chosen))
print('Post-trial correlated p-value: {} | factor: {}'.format(pval_corrected, pval_corrected/pval_chosen))
print('Post-trial conservative n-sigma: {} | factor: {}'.format(stats.norm.isf(pval_conservative), n_templates))
print('Post-trial conservative p-value: {} | factor: {}'.format(pval_conservative, pval_conservative/pval_chosen))



In [None]:
# import here as well so this cell does not depend on the previous one
from scipy import stats

# Conservative (Bonferroni) trial correction for the pi0 result: multiply the
# pre-trial p-value by the number of templates tested (`corr_keys`) instead of
# a hard-coded 3 that could go out of sync with that list.
nsigma_chosen = res_dict['pi0'][3]  # 4.705
pval_chosen = stats.norm.sf(nsigma_chosen)
pval_conservative = pval_chosen * len(corr_keys)
nsigma_conservative = stats.norm.isf(pval_conservative)

print('Pre-trial N-sigma of: {}'.format(nsigma_chosen))
print('Post-trial conservative: {} | factor: {}'.format(nsigma_conservative, pval_conservative/pval_chosen))


## Load and Plot Skymaps

In [None]:
# Unblinded all-sky scan results. `[()]` unwraps the object stored in the
# .npy file (presumably a dict saved as a 0-d object array - confirm).
# NOTE: allow_pickle may execute pickled code - only load trusted files.
ss_file = os.path.join(cg.base_dir, 'skyscan/results/unblinded_skyscan.npy')
ss_results = np.load(ss_file, allow_pickle=True)[()]
ss_trial = ss_results['ss_trial']



In [None]:
# Display the complete unblinded sky-scan result object.
ss_results


In [None]:
# Fit values at the hottest pixel of each hemisphere; the rows of `ss_trial`
# are ordered as in `names`.
names = ['mlog10p', 'ts', 'ns', 'gamma']
for loc in ['south', 'north']:
    print('Hottest spot in {}:'.format(loc))
    hot_ipix = ss_results['ipix_max_{}'.format(loc)]
    for row_idx, name in enumerate(names):
        value = ss_trial[row_idx, hot_ipix]
        print('  {}: {}'.format(name, value))


In [None]:
# Bright Fermi sources used as reference markers on the skymaps.
# https://www.nasa.gov/mission_pages/GLAST/news/gammaray_best.html
# https://www.nasa.gov/images/content/317870main_Fermi_3_month_labeled_new.jpg
fermi_sources = {
    # name: (ra_deg, dec_deg)
    'NGC 1275': (049.9506656698585, +41.5116983765094),
    '3C 454.3': (343.49061658, +16.14821142),
    '47 Tuc': (006.022329, -72.081444),
    '0FGL J1813.5-1248': (273.349033, -12.766842),
    '0FGL J0614.3-3330': (093.5431162, -33.4983656),
    'PKS 0727-115': (112.57963530917, -11.68683347528),
    'Vela': (128.5000, -45.8333),
    'Geminga': (098.475638, +17.770253),
    'Crab': (083.63308, +22.01450),
    'LSI +61 303': (040.1319341179735, +61.2293308716971),
    'PSR J1836+5925': (279.056921, +59.424936),
    'PKS 1502+106': (226.10408242258, +10.49422183753),
    #'Cygnus X-3': (308.10742, +40.95775), # not one of the top Fermi sources
}

# Stacking catalogs (pickled; only load trusted files).
cat_dict = {
    cat_str: np.load(
        os.path.join(cg.catalog_dir, '{}_ESTES_12.pickle'.format(cat_str)),
        allow_pickle=True,
    )
    for cat_str in ['pwn', 'snr', 'unid']
}

sourcelist = np.load(
    os.path.join(cg.catalog_dir, 'Source_List_DNNC.pickle'), allow_pickle=True)
# expose the same coordinate column names used for the other catalogs
sourcelist['ra_deg'] = sourcelist['RA']
sourcelist['dec_deg'] = sourcelist['DEC']
cat_dict['Source List'] = sourcelist

#### Plot template contours

In [None]:
class ContourSkymap:
    """Cumulative-probability view of a HEALPix map, used to draw contours.

    The map is normalized to unit sum and its pixels are sorted by
    decreasing probability, so `cdf_values_s[i]` is the probability
    contained in the `i + 1` most probable pixels.
    """

    def __init__(self, skymap, nside=None):
        """
        Parameters
        ----------
        skymap : array_like
            HEALPix map; it is normalized here, so it need not sum to one.
        nside : int, optional
            If given, the map is first re-gridded to this nside via
            `hp.ud_grade`.
        """
        if nside is not None:
            skymap = hp.ud_grade(skymap, nside_out=nside)

        # normalize such that sum of pixel values equals one
        self.prob_values = skymap / np.sum(skymap)
        self.neg_llh_values = -np.log10(self.prob_values)
        self.nside = hp.get_nside(self.prob_values)
        self.npix = hp.nside2npix(self.nside)

        self.theta, self.phi = self.get_healpix_grid()

        # sort healpix points according to neg llh (most probable first)
        sorted_indices = np.argsort(self.neg_llh_values)
        self.theta_s = self.theta[sorted_indices]
        self.phi_s = self.phi[sorted_indices]
        self.neg_llh_values_s = self.neg_llh_values[sorted_indices]
        self.prob_values_s = self.prob_values[sorted_indices]

        self.cdf_values_s = np.cumsum(self.prob_values_s)

    def get_healpix_grid(self):
        """Return (theta, phi) of all pixel centers at `self.nside`."""
        npix = hp.nside2npix(self.nside)
        theta, phi = hp.pix2ang(self.nside, np.r_[:npix])
        return theta, phi

    def quantile_to_pdf_value(self, quantile):
        """Get the pixel probability value at which the enclosed probability
        reaches `quantile`.

        Parameters
        ----------
        quantile : float
            Enclosed probability in [0, 1].

        Returns
        -------
        float
            Probability of the least probable pixel inside the region.
        """
        assert quantile >= 0., quantile
        assert quantile <= 1., quantile

        index = np.searchsorted(self.cdf_values_s, quantile)
        # Float round-off can leave the last cdf value slightly below 1, which
        # would push quantiles close to 1 past the end of the array.
        index = min(index, len(self.cdf_values_s) - 1)
        return self.prob_values_s[index]

    def _get_level_indices(self, level=0.5, delta=0.01):
        """Get indices of healpix map, which belong to the specified
        contour as defined by: level +- delta.

        Parameters
        ----------
        level : float, optional
            The contour level. Example: a level of 0.7 means that 70% of events
            are within this contour.
        delta : float, optional
            The contour is provided by selecting directions from the sampled
            ones which have cdf values within [level - delta, level + delta].
            The smaller delta, the more accurate the contour will be. However,
            the number of available sample points for the contour will also
            decrease.

        Returns
        -------
        int, int
            The starting and stopping index for a slice of sampled events
            that lie within the contour [level - delta, level + delta].

        Raises
        ------
        ValueError
            If number of resulting samples is too low.
        """
        assert level >= 0., level
        assert level <= 1., level

        index_min = np.searchsorted(self.cdf_values_s, level - delta)
        index_max = min(self.npix,
                        np.searchsorted(self.cdf_values_s, level + delta))

        if index_max - index_min <= 10:
            raise ValueError('Number of samples is too low!')

        return index_min, index_max

    def contour(self, level=0.5, delta=0.01):
        """Get theta/phi pairs of points that lie within the specified
        contour [level - delta, level + delta].

        Parameters
        ----------
        level : float, optional
            The contour level. Example: a level of 0.7 means that 70% of events
            are within this contour.
        delta : float, optional
            The contour is provided by selecting directions from the sampled
            ones which have cdf values within [level - delta, level + delta].
            The smaller delta, the more accurate the contour will be. However,
            the number of available sample points for the contour will also
            decrease.

        Returns
        -------
        np.array, np.array
            The theta/phi pairs that lie within the contour
            [level - delta, level + delta].
        """
        index_min, index_max = self._get_level_indices(level, delta)
        return (self.theta_s[index_min:index_max],
                self.phi_s[index_min:index_max])

In [None]:
def get_smeared_template(key_or_tr, smearing=5):
    """Return the smeared signal space template stored in a trial runner.

    The map is taken from the space-pdf model's `pdf_space_sig`.
    NOTE(review): this map presumably already includes the detector
    acceptance (it is compared to raw-template mixtures elsewhere) - confirm.

    Parameters
    ----------
    key_or_tr : str or csky trial runner
        A key of `tr_dict` or a trial runner.
    smearing : float, optional
        Requested smearing in degrees. The smallest available smearing that
        is >= this value is used.

    Returns
    -------
    np.ndarray
        The smeared template.

    Raises
    ------
    ValueError
        If `smearing` exceeds the largest smearing available in the model.
    """
    if isinstance(key_or_tr, str):
        tr = tr_dict[key_or_tr]
    else:
        tr = key_or_tr
    space_pdf = tr.llh_models[0].pdf_ratio_model.models[0]
    sigmas_deg = np.rad2deg(space_pdf.sigmas)
    sigma_idx = np.searchsorted(sigmas_deg, smearing)
    if sigma_idx >= len(sigmas_deg):
        raise ValueError(
            'Requested smearing of {} deg exceeds the largest available '
            'smearing of {} deg.'.format(smearing, sigmas_deg[-1]))
    template_smeared = space_pdf.pdf_space_sig[sigma_idx]
    return np.array(template_smeared)


#### Skymap Plotting Class

In [None]:
import utils

# Make the helpers in ../unblinding (e.g. `contour_compute`) importable.
# Only add the path once so that re-running the cell does not grow sys.path.
if '../unblinding' not in sys.path:
    sys.path.insert(0, '../unblinding')
import contour_compute


class SkymapPlotter:
    """Aitoff skymap figure with helpers to overlay sources, hotspots and
    template contours.

    The skymap is drawn on construction via `plot_skymap(**kwargs)`. The
    resulting objects are stored as `fig`, `ax`, `sp` (csky SkyPlotter), `cb`
    (colorbar) and `pc` (returned by `sp.plot_map`). When the plot is in
    galactic coordinates (`sp.coord == 'G'`), overlays given in equatorial
    coordinates are rotated before drawing.
    """

    # NOTE: the defaults below are the notebook globals captured at class
    # definition time; re-run this cell after changing them.
    def __init__(self, fermi_sources=fermi_sources, cat_dict=cat_dict, ss_results=ss_results, **kwargs):
        """
        Parameters
        ----------
        fermi_sources : dict
            Source name -> (ra_deg, dec_deg), used by `plot_fermi_sources`.
        cat_dict : dict
            Catalog name -> catalog with `ra_deg` / `dec_deg`, used by
            `plot_catalog`.
        ss_results : dict
            Sky-scan results, used by `plot_hotspots`.
        **kwargs
            Passed on to `plot_skymap` (must include `skymap`).
        """
        self.cat_dict = cat_dict
        self.fermi_sources = fermi_sources
        self.ss_results = ss_results
        self.coord = None

        self.fig, self.ax, self.sp, self.cb, self.pc = self.plot_skymap(**kwargs)

    def add_skymap_layer(self, m, ax=None, input_coord='C', **kw):
        """Draw an additional HEALPix map `m` with `pcolormesh`.

        Returns
        -------
        lat, lon, Z, pc
            The mesh grid, values and the QuadMesh.
        """
        if input_coord == 'C' and self.sp.coord == 'G':
            m = SkymapPlotter.equatorial_to_galactic(m)

        if ax is None:
            ax = self.ax
        lat, lon, Z = self.sp.map_to_latlonz(m)
        pc = ax.pcolormesh(lon, lat, Z, **kw)
        return lat, lon, Z, pc

    @staticmethod
    def plot_skymap(
                skymap, fig=None, ax=None, outfile=None, figsize=(9, 6),
                vmin=None, vmax=None, label=None, norm=None,
                cmap=cy.plotting.skymap_cmap,
                input_coord='C',
                n_cb_ticks=4,
                gp_kw=dict(color='.3', alpha=0.5), gp_lw=1.,
                plot_gp=True, annotate=True,
                **kwargs
            ):
        """Plot a skymap

        Parameters
        ----------
        skymap : array_like
            The skymap to plot.
        fig, ax : matplotlib figure / axis, optional
            Where to draw. If `ax` is None, an aitoff axis is created (in
            `fig` if given, else in a new figure). If only `ax` is given, its
            own figure is used.
        outfile : str, optional
            The output file path to which to plot if provided.
        figsize : tuple, optional
            The figure size to use for a new figure.
        vmin : float, optional
            The minimum value for the colorbar.
        vmax : float, optional
            The maximum value for the colorbar.
        label : str, optional
            The label for the colorbar.
        norm, cmap : optional
            Color normalization / map for the skymap.
        input_coord : str, optional
            Coordinate system of `skymap`; 'C' maps are rotated when plotting
            with `coord='G'`.
        n_cb_ticks : int, optional
            Number of colorbar ticks.
        gp_kw, gp_lw, plot_gp, annotate : optional
            Passed to `annotate_skymap`.
        **kwargs
            Passed to `cy.plotting.SkyPlotter` (e.g. `coord`).

        Returns
        -------
        fig, ax, sp, cb, pc
            The matplotlib figure and axis, the SkyPlotter, the colorbar and
            the object returned by `sp.plot_map`.
        """
        if ax is None:
            if fig is None:
                fig, ax = plt.subplots(
                    subplot_kw=dict(projection='aitoff'), figsize=figsize)
            else:
                ax = fig.add_subplot(projection='aitoff')
        elif fig is None:
            # reuse the figure the given axis belongs to
            fig = ax.figure

        if kwargs.get('coord') == 'G':
            nohr = True
            if input_coord == 'C':
                skymap = SkymapPlotter.equatorial_to_galactic(skymap)
        else:
            nohr = False

        sp = cy.plotting.SkyPlotter(
            pc_kw=dict(cmap=cmap, vmin=vmin, vmax=vmax, norm=norm),
            **kwargs
        )
        pc, cb = sp.plot_map(ax, skymap, n_ticks=n_cb_ticks, nohr=nohr)

        SkymapPlotter.annotate_skymap(
            ax=ax, sp=sp, annotate=annotate, plot_gp=plot_gp, gp_kw=gp_kw, gp_lw=gp_lw,
        )
        cb.set_label(label)
        fig.tight_layout()
        if outfile is not None:
            fig.savefig(outfile)

        return fig, ax, sp, cb, pc

    @staticmethod
    def annotate_skymap(ax, sp, annotate=True, plot_gp=True, gp_kw=dict(color='.3', alpha=0.5), gp_lw=1.):
        """Add longitude labels (galactic), galactic plane/center
        (non-galactic) and a grid to a skymap axis."""
        if sp.coord == 'G' and annotate:
            kw = dict(xycoords='axes fraction', textcoords='offset pixels', verticalalignment='center')
            ax.annotate(r'l = -180°', xy=(1, .5), xytext=(10, 0), horizontalalignment='left', **kw)
            ax.annotate(r'l = 180°', xy=(0, .5), xytext=(-10, 0), horizontalalignment='right', **kw)

        if sp.coord != 'G' and plot_gp:
            sp.plot_gp(ax, lw=gp_lw, **gp_kw)
            sp.plot_gc(ax, **gp_kw)
        kw = dict(color='.5', alpha=.5)
        ax.grid(**kw)

    @staticmethod
    def equatorial_to_galactic(m, rot=180):
        """Rotate a HEALPix map from equatorial to galactic coordinates."""
        r = hp.Rotator(rot=rot, coord='CG')
        return r.rotate_map_pixel(m)

    @staticmethod
    def equatorial_to_galactic_coords(theta, phi, rot=180):
        """Rotate (theta, phi) from equatorial to galactic coordinates."""
        r = hp.Rotator(rot=rot, coord='CG')
        return r(theta, phi)

    @staticmethod
    def galactic_to_equatorial(m, rot=180):
        """Inverse of `equatorial_to_galactic`."""
        r = hp.Rotator(rot=rot, coord='CG', inv=True)
        return r.rotate_map_pixel(m)

    @staticmethod
    def galactic_to_equatorial_coords(theta, phi, rot=180):
        """Inverse of `equatorial_to_galactic_coords`."""
        r = hp.Rotator(rot=rot, coord='CG', inv=True)
        return r(theta, phi)

    def convert_theta_phi_to_mpl_coords(self, theta, phi, convert=True):
        """Convert equatorial (theta, phi) to matplotlib axis coordinates,
        rotating to galactic first if the plot is galactic and `convert`."""
        if self.sp.coord == 'G' and convert:
            theta, phi = self.equatorial_to_galactic_coords(theta, phi)
        x, y = self.sp.thetaphi_to_mpl(theta, phi)
        return x, y

    def convert_ra_dec_to_mpl_coords(self, ra, dec):
        """Convert equatorial (ra, dec) in radians to axis coordinates."""
        theta = np.pi/2. - dec
        phi = ra
        return self.convert_theta_phi_to_mpl_coords(theta=theta, phi=phi)

    def draw_equator(self, ax=None, color='0.6', s=1, **kwargs):
        """Draw the celestial equator as a dotted line; extra keyword
        arguments are passed to `ax.scatter`."""
        if ax is None:
            ax = self.ax
        phi = np.linspace(0., 2*np.pi, 10000)
        dec = np.zeros_like(phi)
        x, y = self.convert_ra_dec_to_mpl_coords(ra=phi, dec=dec)
        return ax.scatter(x, y, marker='.', color=color, s=s, **kwargs)

    def plot_catalog(self, ax=None, sp=None, marker='x', color='red', keys=None, labels=None, cat_dict=None, **kwargs):
        """Scatter the sources of the selected catalogs (all by default)."""
        if ax is None:
            ax = self.ax
        if sp is None:
            sp = self.sp

        if cat_dict is None:
            cat_dict = self.cat_dict

        if keys is None:
            keys = list(cat_dict.keys())
        if labels is None:
            labels = keys

        for cat_str, label in zip(keys, labels):
            cat = cat_dict[cat_str]
            x, y = self.convert_ra_dec_to_mpl_coords(
                ra=np.deg2rad(cat.ra_deg), dec=np.deg2rad(cat.dec_deg))
            ax.scatter(x, y, marker=marker, color=color, label=label, **kwargs)

    def plot_hotspots(self, ax=None, sp=None, marker='x', color='0.8', nside=None, **kwargs):
        """Mark the hottest sky-scan pixel in each hemisphere.

        Parameters
        ----------
        nside : int, optional
            HEALPix resolution of the sky scan. Derived from
            `ss_results['ss_trial']` (pixels on its second axis) when
            available, else 128.
        """
        if ax is None:
            ax = self.ax
        if sp is None:
            sp = self.sp

        if nside is None:
            if 'ss_trial' in self.ss_results:
                nside = hp.npix2nside(np.shape(self.ss_results['ss_trial'])[1])
            else:
                nside = 128

        # plot hottest spots
        for res_str in ['ipix_max_north', 'ipix_max_south']:
            theta, phi = hp.pix2ang(nside, self.ss_results[res_str])
            x, y = self.convert_theta_phi_to_mpl_coords(theta=theta, phi=phi)
            ax.scatter(x, y, marker=marker, color=color, **kwargs)

    def plot_fermi_sources(self, ax=None, sp=None, marker='+', color='1.0'):
        """Mark the reference Fermi sources."""
        if ax is None:
            ax = self.ax
        if sp is None:
            sp = self.sp

        for key, (ra_deg, dec_deg) in self.fermi_sources.items():
            x, y = self.convert_ra_dec_to_mpl_coords(
                ra=np.deg2rad(ra_deg), dec=np.deg2rad(dec_deg))
            ax.scatter(x, y, marker=marker, color=color)

    def get_contour(self, skymap, quantiles=[0.5], geodesic='planar'):
        """Get contours

        Parameters
        ----------
        skymap : array_like
            Equatorial HEALPix map (rotated if the plot is galactic).
        quantiles : list of float, optional
            Enclosed-probability levels of the contours.
        geodesic : {'planar', 'spherical'}, optional
            Compute contours in axis coordinates or on the sphere.

        Returns
        -------
        contours_by_level : list(list(list(point)))
            The contours for each level
            Outermost list indexes by level
            Second list indexes by contours at a particular level
            Third list indexes by points in each contour
            Points are of the same form as sample_points
        """
        nside = hp.get_nside(skymap)
        theta, phi = hp.pix2ang(nside=nside, ipix=np.arange(hp.nside2npix(nside)))

        if self.sp.coord == 'G':
            skymap = SkymapPlotter.equatorial_to_galactic(skymap)

        # compute PDF levels for provided quantiles
        contour_map = ContourSkymap(skymap=skymap)

        levels = []
        for quantile in quantiles:
            levels.append(contour_map.quantile_to_pdf_value(quantile))

        # compute sample points in which to compute the contours
        if geodesic == 'spherical':
            sample_points = np.stack((theta, phi), axis=1)
        elif geodesic == 'planar':
            x, y = self.convert_theta_phi_to_mpl_coords(theta=theta, phi=phi, convert=False)
            sample_points = np.stack((x, y), axis=1)
        else:
            raise ValueError('Unknown geodesic: {}'.format(geodesic))

        contours = contour_compute.compute_contours(
            sample_points=sample_points, samples=contour_map.prob_values, levels=levels, geodesic=geodesic)
        return contours

    def plot_template_contour(
                self, template_str, smearing_deg=5., quantiles=[0.5], ls=['-'], color=[None], geodesic='planar',
        ):
        """Draw contours of a smeared template, one per quantile, with the
        matching entries of `ls` and `color`."""
        assert len(quantiles) == len(ls)
        assert len(quantiles) == len(color)

        template = get_smeared_template(template_str, smearing=smearing_deg)
        contours = self.get_contour(skymap=template, quantiles=quantiles, geodesic=geodesic)
        for ls_i, color_i, contour in zip(ls, color, contours):
            for contour_i in contour:
                if geodesic == 'spherical':
                    x, y = self.convert_theta_phi_to_mpl_coords(theta=contour_i[:, 0], phi=contour_i[:, 1], convert=False)
                elif geodesic == 'planar':
                    x, y = contour_i[:, 0], contour_i[:, 1]
                else:
                    raise ValueError('Unknown geodesic: {}'.format(geodesic))
                if len(x) > 2:
                    self.ax.plot(x, y, ls=ls_i, color=color_i)
        return contours

    def plot_template_contour_points(
                self, template_str, level,
                smearing_deg=5, color='0.2', delta=0.01,
                marker='.', s=1,
                ax=None, sp=None,
            ):
        """Scatter the pixels of a smeared template whose enclosed
        probability lies within level +- delta."""
        # honour the `ax` argument (it was previously ignored)
        if ax is None:
            ax = self.ax
        contour_map = ContourSkymap(get_smeared_template(template_str, smearing=smearing_deg))
        theta, phi = contour_map.contour(level=level, delta=delta)
        x, y = self.convert_theta_phi_to_mpl_coords(theta=theta, phi=phi)
        ax.scatter(x, y, marker=marker, s=s, color=color)


## Test Template Mixture Model

In [None]:
# Raw (unsmeared) templates. The KRA-gamma 5 PeV template comes per energy
# bin (presumably shape (npix, n_energy) - confirm) and is summed over energy
# to obtain a spatial map.
template_pi0_raw_ = cg.template_repo.get_template('Fermi-LAT_pi0_map')
template_kra5_raw_, energy_bins = cg.template_repo.get_template(
    'KRA-gamma_5PeV_maps_energies', per_pixel_flux=True)
template_kra5_raw = np.sum(template_kra5_raw_, axis=1)

# Normalize to unit integral over the sphere (sum * pixel area == 1). The
# kra5 map must be normalized from its energy-summed version; previously the
# per-energy array was used, which gave a 2D result. Each map uses its own
# pixel area in case the resolutions differ.
pix_area = hp.nside2pixarea(hp.get_nside(template_pi0_raw_))
template_pi0_raw = template_pi0_raw_ / np.sum(template_pi0_raw_) / pix_area
template_kra5_raw = (
    template_kra5_raw / np.sum(template_kra5_raw)
    / hp.nside2pixarea(hp.get_nside(template_kra5_raw)))

# a second, spatially different template: the pi0 map rotated with the
# equatorial -> galactic rotation
template_pi0_raw_rotated_ = SkymapPlotter.equatorial_to_galactic(template_pi0_raw_)

##### Create trial runners with same flux

In [None]:
# Trial runners for the pi0 map and its rotated copy with identical flux settings.
tr_pi0, tr_pi0_rotated = (
    get_template_tr(template=raw_map)
    for raw_map in (template_pi0_raw_, template_pi0_raw_rotated_)
)


In [None]:
# Sanity check: the trial runners store the templates they were given.
for tr_check, template_in in [(tr_pi0, template_pi0_raw_),
                              (tr_pi0_rotated, template_pi0_raw_rotated_)]:
    stored_template = tr_check.llh_models[0].pdf_ratio_model.models[0].template
    print(np.allclose(stored_template, template_in))


In [None]:
w1 = 0.5  # mixture weight of the (unrotated) pi0 template

template_pi0_tr, template_pi0_rotated_tr = (
    trial_runner.llh_models[0].pdf_ratio_model.models[0].template
    for trial_runner in (tr_pi0, tr_pi0_rotated)
)

# mixture of the raw templates, i.e. built before acceptance and smearing
template_combined_raw = w1 * template_pi0_tr + (1 - w1) * template_pi0_rotated_tr
tr_combined = get_template_tr(template=template_combined_raw)


In [None]:
smearing_deg = 7
template_pi0, template_pi0_rotated, template_combined = (
    get_smeared_template(trial_runner, smearing=smearing_deg)
    for trial_runner in (tr_pi0, tr_pi0_rotated, tr_combined)
)
# mixture built from the already processed (smeared) individual templates
template_combined_post1 = w1 * template_pi0 + (1 - w1) * template_pi0_rotated



In [None]:
# The two mixture components after smearing.
for map_title, component_map, out_fmt in [
        ('Map 1', template_pi0, '{}/template_stacking_map1.png'),
        ('Map 2', template_pi0_rotated, '{}/template_stacking_map2.png'),
]:
    skymap_plotter = SkymapPlotter(skymap=component_map, cmap='viridis')
    skymap_plotter.ax.set_title(map_title)
    skymap_plotter.fig.savefig(out_fmt.format(plot_dir))


In [None]:
# Mixing after vs. before the acceptance/smearing step, and their log ratio.
for map_title, combined_map, out_fmt in [
        ('Combination after AC', template_combined_post1,
         '{}/combination_after_accceptance.png'),
        ('Combination before AC', template_combined,
         '{}/combination_before_accceptance.png'),
]:
    skymap_plotter = SkymapPlotter(skymap=combined_map, cmap='viridis')
    skymap_plotter.ax.set_title(map_title)
    skymap_plotter.fig.savefig(out_fmt.format(plot_dir))

acceptance_log_ratio = np.log10(template_combined / template_combined_post1)
skymap_plotter = SkymapPlotter(skymap=acceptance_log_ratio, cmap='viridis')
skymap_plotter.ax.set_title('Acceptance Comparison Ratio')
skymap_plotter.fig.savefig('{}/acceptance_order.png'.format(plot_dir))

## Check effect of smearing before/after combination

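As demonstrated below, smearing can be applied before or after building the mixture model. Because smoothing is a linear operation, smoothing a weighted sum of two maps equals the weighted sum of the individually smoothed maps. The two approaches therefore agree apart from numerical differences.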

In [None]:
sigma = np.deg2rad(20)
# unsmeared (smearing=0) templates from the GP trial runners
t1, t2 = (get_smeared_template(tmpl_key, smearing=0) for tmpl_key in ('pi0', 'kra5'))

# smooth each component first, then mix ...
t1_s, t2_s = (hp.smoothing(single_map, sigma=sigma) for single_map in (t1, t2))
t_combined_s_pre = t1_s * w1 + (1. - w1) * t2_s
# ... versus mix first, then smooth
t_combined_s_post = hp.smoothing(t1 * w1 + (1. - w1) * t2, sigma=sigma)


In [None]:
# smoothed-then-mixed, mixed-then-smoothed, and their log ratio; only the
# last plot is titled and saved
for map_to_plot in (t_combined_s_pre, t_combined_s_post):
    skymap_plotter = SkymapPlotter(skymap=map_to_plot, cmap='viridis')
skymap_plotter = SkymapPlotter(
    skymap=np.log10(t_combined_s_pre / t_combined_s_post), cmap='viridis')
skymap_plotter.ax.set_title('Smearing Comparison')
skymap_plotter.fig.savefig('{}/smearing_order.png'.format(plot_dir))

## Test StackedTemplatePDFRatioModel

#### Define test injection template

In [None]:
def catalog_as_template(catalog, cat_dict=None, nside=64):
    """Build a normalized HEALPix template with one hot pixel per catalog source.

    Parameters
    ----------
    catalog : str
        Key into `cat_dict` (e.g. 'snr', 'pwn', 'unid').
    cat_dict : dict, optional
        Mapping from catalog name to an object exposing `ra_deg` and `dec_deg`
        sequences (degrees). When None, the notebook-global `cat_dict` is looked
        up at call time. The former default `cat_dict=cat_dict` was bound once,
        at definition time, so a later re-assignment of the global was silently
        ignored.
    nside : int
        HEALPix resolution of the returned map.

    Returns
    -------
    np.ndarray
        Map of length `hp.nside2npix(nside)`. It is constant on every pixel that
        contains at least one source and zero elsewhere. It is normalized so that
        ``template.sum() * hp.nside2pixarea(nside) == 1``.

    Raises
    ------
    ValueError
        If the catalog contains no sources; the normalization would otherwise
        produce an all-NaN map.
    """
    if cat_dict is None:
        # resolve the notebook global at call time (the parameter shadows it)
        cat_dict = globals()['cat_dict']

    sources = cat_dict[catalog]
    ra_deg = np.asarray(sources.ra_deg, dtype=float)
    dec_deg = np.asarray(sources.dec_deg, dtype=float)
    if ra_deg.size == 0:
        raise ValueError('Catalog {!r} contains no sources'.format(catalog))

    # HEALPix uses colatitude theta and longitude phi, both in radians
    theta = np.pi / 2. - np.deg2rad(dec_deg)
    phi = np.deg2rad(ra_deg)
    ipix = hp.ang2pix(nside, theta, phi)

    template = np.zeros(hp.nside2npix(nside))
    # several sources falling into the same pixel still set it to 1 only once
    template[ipix] = 1

    # normalize
    return template / np.sum(template) / hp.nside2pixarea(nside)

def get_template_component(component, sigma=None):
    """Return a normalized HEALPix template for one mixture component.

    Parameters
    ----------
    component : str
        'pi0'                  -> template 'Fermi-LAT_pi0_map' from `cg.template_repo`
        'snr' | 'pwn' | 'unid' -> one-hot-pixel catalog map from `catalog_as_template`
        'random'               -> the pi0 map rotated with
                                  ``hp.Rotator(rot=180, coord='CG')``
    sigma : float, optional
        Gaussian smoothing width in radians, passed to `hp.smoothing`. No
        smoothing is applied when None.

    Returns
    -------
    np.ndarray
        Template normalized so that ``template.sum() * pix_area == 1``.

    Raises
    ------
    ValueError
        If `component` is not one of the names listed above.
    """
    if component == 'pi0':
        template = cg.template_repo.get_template('Fermi-LAT_pi0_map')

    elif component in ['snr', 'pwn', 'unid']:
        template = catalog_as_template(component)

    elif component == 'random':
        # NOTE(review): presumably a control template, i.e. the pi0 map moved off
        # its true position by a 180 deg rotation combined with a C->G frame
        # conversion. Confirm this is the intended "random" map.
        template_pi0_raw_ = cg.template_repo.get_template('Fermi-LAT_pi0_map')
        rotator = hp.Rotator(rot=180, coord='CG')
        template = rotator.rotate_map_pixel(template_pi0_raw_)

    else:
        raise ValueError(
            "Unknown template component {!r}; expected one of "
            "'pi0', 'snr', 'pwn', 'unid', 'random'".format(component))

    # smear template (sigma in radians)
    if sigma is not None:
        template = hp.smoothing(template, sigma=sigma)

    # normalize template: sum over pixels times pixel area equals one
    pix_area = hp.nside2pixarea(hp.get_nside(template))
    return template / np.sum(template) / pix_area

def get_test_injection_template(w, sigma=None):
    """Mixture of the 'unid' and 'snr' catalog templates used for injection.

    Returns ``w * unid + (1 - w) * snr``. Both components come from
    `get_template_component` (each normalized on its own, smoothed with
    `sigma`), so `w` is the weight of the UNID component.
    """
    unid_map = get_template_component('unid', sigma=sigma)
    snr_map = get_template_component('snr', sigma=sigma)
    return w * unid_map + (1. - w) * snr_map


In [None]:
# Smoothing width (3 deg in radians) and true UNID weight for a single sanity
# injection. `sigma` is also reused by the bias-scan cell further down.
sigma = np.deg2rad(3)
inj_weight = 0.8

# Fit components in this order: index 0 = 'snr', index 1 = 'unid'.
templates = [get_template_component(c, sigma=sigma) for c in ['snr', 'unid']]
# Injected map: inj_weight * unid + (1 - inj_weight) * snr
injection_template = get_test_injection_template(inj_weight, sigma=sigma)

skymap_plotter = SkymapPlotter(
    skymap=injection_template, cmap='viridis', 
)

skymap_plotter.ax.set_title('Injected | w = {:3.2f}'.format(inj_weight))
skymap_plotter.fig.savefig('{}/injected_template.png'.format(plot_dir))


In [None]:
# One figure per fit component; index i follows the order of `templates`.
for i, template_i in enumerate(templates):
    skymap_plotter = SkymapPlotter(
        skymap=template_i, cmap='viridis', 
    )

    skymap_plotter.ax.set_title('Component {}'.format(i))
    # zero-padded index keeps the files sorted: component_000.png, component_001.png, ...
    skymap_plotter.fig.savefig('{}/component_{:03d}.png'.format(plot_dir, i))


#### Define trial runner

In [None]:

def get_stacked_template_tr(templates, injection_template, gamma=2.7, cutoff_tev=np.inf, cpus=20, **kwargs):
    """Build a csky trial runner that fits a stack of template components.

    Every map in `templates` becomes one template model sharing the same
    power-law flux. Signal is injected from `injection_template`, whose kwargs
    carry ``sigmas=[0]``. Extra keyword arguments are merged into the
    trial-runner config and override the defaults set here.

    Parameters
    ----------
    templates : list
        HEALPix maps, one per fit component.
    injection_template : array
        HEALPix map used for signal injection.
    gamma : float
        Spectral index of the power-law flux; also passed as `fitter_args`.
    cutoff_tev : float
        Energy cutoff in TeV, converted to GeV for csky. `np.inf` means no cutoff.
    cpus : int
        Forwarded to `cy.get_trial_runner` as `mp_cpus`.

    Returns
    -------
    The trial runner returned by `cy.get_trial_runner`, built on the global `ana`.
    """
    power_law = cy.hyp.PowerLawFlux(gamma, energy_cutoff=cutoff_tev * 1000.)

    conf = dict(
        template_kwarg_list=[dict(template=t, flux=power_law) for t in templates],
        flux=power_law,
        randomize=['ra'],
        fitter_args=dict(gamma=gamma),
        sigsub=True,
        update_bg=True,
        fast_weight=False,
        injection_template_kwargs=dict(
            template=injection_template, flux=power_law, sigmas=[0]),
    )
    # caller-supplied options win, exactly like a trailing ``**kwargs`` in a dict literal
    conf.update(kwargs)

    return cy.get_trial_runner(conf, ana=ana, mp_cpus=cpus)

#### Get Trial Runner

In [None]:
# Trial runner used for every fit below. It fits the components in `templates`
# and injects from the inj_weight mixture built above.
tr_stacked = get_stacked_template_tr(
    templates=templates, injection_template=injection_template)

#### Test one fit

In [None]:
def prior_func(n_models=2, **params):
    """Penalty that vanishes when the exponentiated weights sum to one.

    Reads the fit parameters ``weight_0000 ... weight_{n_models-1:04d}`` from
    `params` and returns ``(sum_i exp(weight_i) - 1) ** 2``. The debug print
    of the weight sum was removed so repeated evaluation does not flood the
    output.

    Parameters
    ----------
    n_models : int
        Number of ``weight_XXXX`` parameters to read (default 2, the previously
        hard-coded value).
    **params
        Fit parameters; must contain every ``weight_XXXX`` key below `n_models`.

    Returns
    -------
    float
        Squared deviation of the exponentiated-weight sum from one.

    Raises
    ------
    KeyError
        If a required ``weight_XXXX`` parameter is missing.
    """
    weights = np.array(
        [params['weight_{:04d}'.format(i)] for i in range(n_models)], dtype=float)
    weight_norm = np.sum(np.exp(weights))
    return (weight_norm - 1)**2

#prior_func(**fit[1])

In [None]:
# Single injection + fit with a large signal (n_sig=1000). It checks that the
# fitted component weights recover the injected UNID weight `inj_weight`.
# NOTE: seed=None draws a different random trial on every run.
trial = tr_stacked.get_one_trial(n_sig=1000, seed=None)
fit = tr_stacked.get_one_fit_from_trial(trial, flat=False)

# Convert the raw fit parameters into one weight per fit component.
# The count comes from `templates` instead of a hard-coded 2.
# NOTE(review): presumably ordered like `templates` (index 1 = 'unid').
weights = cy.pdf.StackedTemplateSpacePDFRatioModel.compute_weights(len(templates), **fit[1])
weights, fit

## Test bias in fitted weights

In [None]:
from functools import partial
from multiprocessing import Pool

from tqdm.notebook import tqdm_notebook as tqdm

n_trials = 100
n_weights = 10
n_sigs = np.linspace(0, 300, 7)
seed = 1337
n_cpus = 20
inj_weights = np.linspace(0, 1, n_weights)

# results[i, j, k]: fitted UNID weight for injected weight inj_weights[i],
# signal strength n_sigs[j] and trial k
results = np.zeros((n_weights, len(n_sigs), n_trials))


def compute_trial(trial_idx, n_sig):
    """Inject `n_sig` events with `tr_inj`, fit them with `tr_stacked`, return the UNID weight.

    This runs inside Pool workers and reads `tr_inj` from the module globals.
    The pool therefore has to be created (forked, the Linux default start
    method) after `tr_inj` is set for the current injection weight. Seeds
    depend only on `trial_idx`, so every (weight, n_sig) combination reuses
    the same seed sequence, as before.
    """
    inj_trial = tr_inj.get_one_trial(n_sig=n_sig, seed=seed + trial_idx)
    fit = tr_stacked.get_one_fit_from_trial(inj_trial, flat=False)
    weights = cy.pdf.StackedTemplateSpacePDFRatioModel.compute_weights(len(templates), **fit[1])
    # index 1 is the 'unid' entry of `templates`
    return weights[1]


for i, inj_weight_i in tqdm(enumerate(inj_weights), total=n_weights):

    # injection template with UNID weight inj_weight_i
    injection_template_i = get_test_injection_template(w=inj_weight_i, sigma=sigma)

    # get injection trial runner
    tr_inj = get_stacked_template_tr(
        templates=templates, injection_template=injection_template_i, sigmas=[0])

    # One pool per injection weight instead of one per (weight, n_sig) pair.
    # It is created after `tr_inj` so the forked workers see the current
    # runner; n_sig is passed explicitly via partial instead of a closure.
    print('Running pool with {} cpus'.format(n_cpus))
    with Pool(n_cpus) as p:
        for j, n_sig in enumerate(n_sigs):
            weights_i = list(tqdm(
                p.imap(partial(compute_trial, n_sig=n_sig), range(n_trials)),
                total=n_trials))
            results[i, j, :] = weights_i


In [None]:

# One bias figure per signal strength: fitted vs. injected UNID weight.
for j, nsig in enumerate(n_sigs):
    fig, ax = plt.subplots(figsize=(9, 6))
    # repeat each injected weight once per trial so it lines up with results[:, j]
    inj_weights_ext = np.ones_like(results[:, j]) * inj_weights[:, np.newaxis]

    plot_bias(
        ax=ax, x_fit=results[:, j].flatten(), y_true=inj_weights_ext.flatten(),
        label=r'$n_\mathrm{inj}$ ' + '= {}'.format(nsig),
    )
    # The labels previously read "True" (x) and "Injected" (y), which both name
    # the truth. They now match the combined bias figure.
    # NOTE(review): confirm against `plot_bias` which argument goes on which axis.
    ax.set_xlabel('Injected UNID/SNR Ratio')
    ax.set_ylabel('Fitted UNID/SNR Ratio')
    ax.legend()
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    fig.savefig('{}/template_weight_bias_nsig_{}.png'.format(plot_dir, nsig))


In [None]:
fig, ax = plt.subplots(figsize=(9, 6))

# Overlay only the odd-indexed signal strengths to keep the figure readable.
for j, nsig in enumerate(n_sigs):
    if j % 2 != 0:
        # injected weight repeated once per trial, aligned with results[:, j]
        inj_weights_ext = np.ones_like(results[:, j]) * inj_weights[:, np.newaxis]
        plot_bias(
            ax=ax,
            x_fit=results[:, j].flatten(),
            y_true=inj_weights_ext.flatten(),
            label=r'$n_\mathrm{inj}$ ' + '= {}'.format(nsig),
        )

ax.set(xlabel='Injected UNID/SNR Ratio', ylabel='Fitted UNID/SNR Ratio',
       xlim=(0, 1), ylim=(0, 1))
ax.legend()
fig.savefig('{}/template_weight_bias.png'.format(plot_dir))


## Can Stacking Catalogs explain GP results?

#### Create stacking catalog templates

In [None]:
# The stacking templates use the resolution of the pi0 map.
# NOTE(review): this reads `template_pi0_raw_` (trailing underscore), while the
# plotting cell below uses `template_pi0_raw`. Confirm that a global with this
# exact name exists on a fresh kernel.
nside = hp.get_nside(template_pi0_raw_)

# One normalized "hot pixel per source" map per catalog. This is built with the
# shared `catalog_as_template` helper instead of a verbatim copy of its loop.
stacking_templates_dict = {
    catalog: catalog_as_template(catalog, cat_dict=cat_dict, nside=nside)
    for catalog in ['snr', 'pwn', 'unid']
}


In [None]:
# Catalog templates and the pi0 template, all smoothed with the same Gaussian width.
smoothing = 7.  # degrees; converted to radians for hp.smoothing
for catalog in ['snr', 'pwn', 'unid']:
    skymap_plotter = SkymapPlotter(
        skymap=hp.smoothing(stacking_templates_dict[catalog], sigma=np.deg2rad(smoothing)), cmap='viridis',
    )
    skymap_plotter.ax.set_title('Catalog {} | {:2.0f}°'.format(catalog, smoothing))
    # '{:.0f}' has no field width. '{:2.0f}' padded 7 to ' 7' and put a space
    # into the file name.
    skymap_plotter.fig.savefig('{}/catalog_{}_{:.0f}deg.png'.format(plot_dir, catalog, smoothing))

skymap_plotter = SkymapPlotter(
    skymap=hp.smoothing(template_pi0_raw, sigma=np.deg2rad(smoothing)), cmap='viridis',
)
# raw string: '\p' is an invalid escape sequence in a regular string literal
skymap_plotter.ax.set_title(r'$\pi^0$ Template | {:2.0f}°'.format(smoothing))
skymap_plotter.fig.savefig('{}/catalog_pi0_{:.0f}deg.png'.format(plot_dir, smoothing))
    

## Scratch Space

In [None]:
# Scratch: take the first sub-model of the pi0 trial runner's pdf_ratio_model
# (presumably its spatial PDF) and divide its stored template by
# `template_pi0_0deg`.
# NOTE(review): `tr_dict` and `template_pi0_0deg` are defined outside this view.
space_pdf = tr_dict['pi0'].llh_models[0].pdf_ratio_model.models[0]
space_pdf.template / template_pi0_0deg

In [None]:
# Scratch: evaluate the PDF's private `_acc_model` on its own stored template.
# NOTE(review): `_acc_model` is a private csky attribute and presumably applies
# the detector acceptance. Confirm before relying on its output.
space_pdf._acc_model(space_pdf.template)