In [102]:
from scipy.integrate import quad, trapz, fixed_quad
import theano
import theano.tensor as T
from theano.tests.unittest_tools import verify_grad
from theano.compile.ops import as_op

import numpy as np
import pymc3 as pm

import matplotlib.pyplot as plt

import astropy.units as u
from astropy.coordinates import SkyCoord

from tqdm import tqdm

from gammapy.spectrum import CountsPredictor, CountsSpectrum
from gammapy.data import DataStore
from gammapy.maps import Map
from gammapy.background import ReflectedRegionsBackgroundEstimator
from gammapy.spectrum import SpectrumObservationList, SpectrumExtraction

from regions import CircleSkyRegion

from utils import load_spectrum_observations, plot_spectra, Log10Parabola, integrate_spectrum, wstat_profile
from forward_fold_model import model_probability, ln_profile_likelihood

%matplotlib inline

In [103]:
class Integrate(theano.Op):
    """Theano Op that evaluates a definite integral of a symbolic scalar expression.

    The integral of ``expr`` over ``var`` from ``lower`` to ``upper`` is
    approximated with the trapezoidal rule on three equally spaced sample
    points. Every other symbolic variable ``expr`` may depend on is passed via
    ``*inputs``; the Op must later be applied to exactly that many arguments.

    Parameters
    ----------
    expr : theano scalar expression
        The integrand.
    var : theano scalar variable
        The integration variable appearing in ``expr``.
    lower, upper : float
        Fixed integration limits.
    *inputs : theano scalar variables
        Remaining free variables of ``expr``.
    """

    def __init__(self, expr, var, lower, upper, *inputs):
        super().__init__()
        self._expr = expr
        self._var = var
        self._extra_vars = inputs
        self.lower = lower
        self.upper = upper
        # Compile the integrand once. Integrands created in `grad` may not
        # depend on every extra variable, hence on_unused_input='ignore'.
        self._func = theano.function(
            [var] + list(self._extra_vars),
            self._expr,
            on_unused_input='ignore'
        )

    def make_node(self, *inputs):
        """Create the Apply node; one argument per extra variable is required."""
        assert len(self._extra_vars) == len(inputs)
        # Coerce to symbolic variables so that plain numbers / arrays are
        # accepted in addition to existing theano variables.
        inputs = [T.as_tensor_variable(i) for i in inputs]
        return theano.Apply(self, inputs, [T.dscalar().type()])

    def perform(self, node, inputs, out):
        """Numerically integrate the compiled integrand for concrete ``inputs``."""
        # Fixed three-point trapezoid. An adaptive alternative would be
        # quad(self._func, self.lower, self.upper, args=tuple(inputs))[0].
        x = np.linspace(self.lower, self.upper, num=3)
        y = np.array([self._func(i, *inputs) for i in x])
        out[0][0] = np.array(trapz(y, x))

    def grad(self, inputs, grads):
        """Return the gradient w.r.t. each input.

        The integration limits are constants, so the derivative of the
        integral is the integral of the derivative of the integrand; each
        partial derivative is wrapped in its own ``Integrate`` Op.
        """
        out_grad, = grads
        # 'ignore' makes T.grad return a zero gradient for variables the
        # integrand does not depend on instead of raising, which is needed
        # when this Op is itself a gradient integrand that dropped a variable.
        integrand_grads = T.grad(
            self._expr,
            list(self._extra_vars),
            disconnected_inputs='ignore',
        )
        return [
            out_grad * Integrate(
                partial, self._var, self.lower, self.upper, *self._extra_vars
            )(*inputs)
            for partial in integrand_grads
        ]

In [104]:
def apply_range(*arr, fit_range, bins):
    """Cut every array in ``arr`` down to the bins selected by ``fit_range``.

    Both ``bins`` and ``fit_range`` are converted to TeV; the positions at
    which the two ``fit_range`` values would be inserted into ``bins``
    define the start and stop of the slice applied to each array.

    Returns a list with one sliced array per positional argument.
    """
    edges_tev = bins.to_value(u.TeV)
    limits_tev = fit_range.to_value(u.TeV)
    bounds = np.searchsorted(edges_tev, limits_tev)
    start, stop = bounds[0], bounds[1]
    return [values[start:stop] for values in arr]

In [105]:
# Telescope whose data set is analysed; selects the data directory, the fit
# range and the ON-region radius below.
telescope = 'hess'

# Number of logarithmically spaced bins for the true / reconstructed energy axes.
N_TRUE_BINS = 10
N_RECO_BINS = 14

crab_position = SkyCoord(ra='83d37m59.0988s', dec='22d00m52.2s')
exclusion_map = Map.read("./data/exclusion_mask.fits.gz")

# Energy interval per telescope, used as fit range and as limits of both energy axes.
energy_range = {
    'fact': [0.55, 17] * u.TeV,
    'magic': [0.04, 18] * u.TeV,
    'veritas': [0.11, 20] * u.TeV,
    'hess': [0.6, 20] * u.TeV,
}

# Radius of the circular ON region per telescope.
on_radius = {
    'fact': 0.17 * u.deg,
    'magic': 0.142 * u.deg,
    'veritas': 0.10 * u.deg,
    'hess': 0.11 * u.deg,
}

ds = DataStore.from_dir(f'./data/{telescope}/')
observations = ds.obs_list(ds.hdu_table['OBS_ID'].data)

fit_range = energy_range[telescope]

# N bins need N + 1 edges.
e_true_bins = np.logspace(*np.log10(fit_range.value), N_TRUE_BINS + 1) * u.TeV
e_reco_bins = np.logspace(*np.log10(fit_range.value), N_RECO_BINS + 1) * u.TeV

on_region = CircleSkyRegion(center=crab_position, radius=on_radius[telescope])

print('Estimating Background')
bkg_estimate = ReflectedRegionsBackgroundEstimator(
    obs_list=observations, on_region=on_region, exclusion_mask=exclusion_map
)
bkg_estimate.run()

print('Extracting Count Spectra')
extract = SpectrumExtraction(
    obs_list=observations,
    bkg_estimate=bkg_estimate.result,
    e_true=e_true_bins,
    e_reco=e_reco_bins,
    containment_correction=False,
    use_recommended_erange=False,
)

extract.run()

# From here on `observations` refers to the extracted spectrum observations,
# replacing the observation list read from the DataStore.
observations = extract.observations

obs_alpha = observations[0].alpha
energy_bins = observations[0].on_vector.energy.bins
print(len(energy_bins))
fit_range

Estimating Background
Extracting Count Spectra
15


<Quantity [ 0.6, 20. ] TeV>

In [106]:
def forward_fold_log_parabola(amplitude, alpha, beta, observations, fit_range=None):
    """Predict counts for a log-parabola spectrum, summed over observations.

    The flux model is
    ``amplitude * 1e-11 * E ** (-alpha - beta * log10(E))`` with ``E`` in TeV.
    For each observation the model is integrated over every true-energy bin
    (three-point trapezoid), multiplied by the effective area (in cm^2) and
    the livetime (in s), and folded with the energy dispersion matrix.

    Parameters
    ----------
    amplitude, alpha, beta : float
        Spectral parameters; ``amplitude`` is given in units of 1e-11.
    observations : iterable
        Observations providing ``on_vector``, ``aeff``, ``edisp`` and ``livetime``.
    fit_range : Quantity with two energies, optional
        If given, only the bins between the two energies are returned. The
        bin edges of the last observation are used to locate them.

    Returns
    -------
    numpy.ndarray
        Predicted counts per bin (``0.0`` if ``observations`` is empty and no
        ``fit_range`` is given).

    Raises
    ------
    ValueError
        If ``fit_range`` is given but ``observations`` is empty.
    """
    # Rebind instead of using `*=`, which would scale an array argument in
    # place and thereby modify the caller's data.
    amplitude = amplitude * 1e-11

    def flux(energy):
        return amplitude * energy ** (-alpha - beta * np.log10(energy))

    obs_bins = None
    predicted_signal_per_observation = []
    for observation in observations:
        obs_bins = observation.on_vector.energy.bins.to_value(u.TeV)

        e_true_bins = observation.edisp.e_true
        lower = e_true_bins.lo.to_value(u.TeV)
        upper = e_true_bins.hi.to_value(u.TeV)

        # Integrate the flux over each true-energy bin.
        counts = []
        for a, b in zip(lower, upper):
            x = np.linspace(a, b, num=3)
            y = np.array([flux(e) for e in x])
            counts.append(trapz(y, x))
        counts = np.array(counts)

        aeff = observation.aeff.data.data.to_value(u.cm**2).astype(np.float32)
        counts *= aeff
        counts *= observation.livetime.to_value(u.s)

        predicted_signal_per_observation.append(
            np.dot(counts, observation.edisp.pdf_matrix)
        )

    predicted_counts = np.sum(predicted_signal_per_observation, axis=0)
    if fit_range is not None:
        if obs_bins is None:
            raise ValueError('cannot apply fit_range: observations is empty')
        idx = np.searchsorted(obs_bins, fit_range.to_value(u.TeV))
        predicted_counts = predicted_counts[idx[0]:idx[1]]

    return predicted_counts

In [107]:
def forward_fold_log_parabola_symbolic(amplitude, alpha, beta, observations, fit_range=None, efficiency=1):
    """Build a theano expression for the predicted log-parabola counts.

    Symbolic counterpart of ``forward_fold_log_parabola``: the flux
    ``amplitude * 1e-11 * E ** (-alpha - beta * log10(E))`` (``E`` in TeV) is
    integrated over every true-energy bin with an ``Integrate`` Op,
    multiplied by ``efficiency``, the effective area (in cm^2) and the
    livetime (in s), folded with the energy dispersion matrix and summed over
    all observations.

    Parameters
    ----------
    amplitude, alpha, beta : theano scalar variables
        Spectral parameters; ``amplitude`` is given in units of 1e-11.
    observations : iterable
        Observations providing ``on_vector``, ``aeff``, ``edisp`` and ``livetime``.
    fit_range : Quantity with two energies, optional
        If given, only the bins between the two energies are returned. The
        bin edges of the last observation are used to locate them.
    efficiency : scalar or theano variable, optional
        Extra factor applied to the effective area (default 1).

    Returns
    -------
    theano tensor
        Symbolic vector of predicted counts per bin.

    Raises
    ------
    ValueError
        If ``observations`` is empty.
    """
    # Rebind instead of using `*=` so that an array argument is never scaled
    # in place.
    amplitude = amplitude * 1e-11

    # The integrand does not depend on the observation, so build it once.
    energy = T.dscalar('energy')
    amplitude_ = T.dscalar('amplitude_')
    alpha_ = T.dscalar('alpha_')
    beta_ = T.dscalar('beta_')
    func = amplitude_ * energy ** (-alpha_ - beta_ * T.log10(energy))

    obs_bins = None
    predicted_signal_per_observation = []
    for observation in observations:
        obs_bins = observation.on_vector.energy.bins.to_value(u.TeV)

        e_true_bins = observation.edisp.e_true
        lower = e_true_bins.lo.to_value(u.TeV)
        upper = e_true_bins.hi.to_value(u.TeV)

        # One Integrate Op per true-energy bin, applied to the model parameters.
        counts = T.stack([
            Integrate(func, energy, a, b, amplitude_, alpha_, beta_)(amplitude, alpha, beta)
            for a, b in zip(lower, upper)
        ])
        aeff = observation.aeff.data.data.to_value(u.cm**2).astype(np.float32)

        counts *= efficiency * aeff
        counts *= observation.livetime.to_value(u.s)

        predicted_signal_per_observation.append(
            T.dot(counts, observation.edisp.pdf_matrix)
        )

    if obs_bins is None:
        raise ValueError('observations must not be empty')

    # Slice bounds taken from the bin edges of the last observation.
    if fit_range is not None:
        idx = np.searchsorted(obs_bins, fit_range.to_value(u.TeV))
        lowest_bin = idx[0]
        highest_bin = idx[1]
    else:
        lowest_bin = 0
        highest_bin = len(obs_bins)

    predicted_counts = T.sum(predicted_signal_per_observation, axis=0)

    return predicted_counts[lowest_bin:highest_bin]

In [108]:
# Symbolic scalar inputs for the three spectral parameters.
amplitude = T.dscalar('amplitude')
alpha = T.dscalar('alpha')
beta = T.dscalar('beta')

# Build the symbolic forward-folded prediction restricted to fit_range and
# evaluate it at a single parameter point; the cell displays the values and
# the shape of the result.
cf_fast = forward_fold_log_parabola_symbolic(amplitude, alpha, beta, observations, fit_range=fit_range)
counts_symbolic = cf_fast.eval({amplitude: 4.0, alpha: 2.5, beta: 0.4})
counts_symbolic, counts_symbolic.shape

(array([593.90547298, 737.06338362, 549.14884918, 436.24041896,
        352.11629696, 255.01706264, 193.67177852, 137.25346841,
         88.45797208,  56.57436225,  37.78630591,  21.14003781,
         11.47621112,   6.84679695]), (14,))

In [109]:
# Evaluate the NumPy implementation at the same parameter point
# (amplitude=4, alpha=2.5, beta=0.4) so its output can be compared with
# counts_symbolic.
counts = forward_fold_log_parabola(4, 2.5, 0.4, observations, fit_range=fit_range)
counts, counts.shape

(array([593.90547298, 737.06338362, 549.14884918, 436.24041896,
        352.11629696, 255.01706264, 193.67177852, 137.25346841,
         88.45797208,  56.57436225,  37.78630591,  21.14003781,
         11.47621112,   6.84679695]), (14,))

In [110]:
# Time repeated evaluation of the symbolic prediction at a fixed parameter point.
%timeit cf_fast.eval({amplitude: 4.0, alpha: 2.5, beta: 0.4})

27.2 ms ± 399 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [111]:
# Time the NumPy implementation at the same parameter point for comparison.
%timeit forward_fold_log_parabola(4, 2.5, 0.4, observations, fit_range=fit_range)

9.26 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
