In [51]:
import jax
import jax.numpy as jnp

from jax.random import PRNGKey, multivariate_normal

import pyhf
pyhf.set_backend("jax")


def make_model(pars: jnp.array) -> pyhf.Model:
    """Build a single-channel pyhf model whose background variations are set by ``pars``.

    Args:
        pars: four values ``(u1, u2, d1, d2)``. Their magnitudes, capped at 10,
            are the shifts applied to the first two bins of the 'hi_data'
            (``u``) and 'lo_data' (``d``) background templates. The third bin
            shift is ``-(first + second)``, so each template's shifts sum to zero.

    Returns:
        A ``pyhf.Model`` (constructed with ``validate=False``) with a 'signal'
        sample carrying the normfactor 'mu' and a 'background' sample carrying
        the histosys modifier 'bkguncrt'.
    """
    # Cap the magnitude rather than the signed value: testing `pars > 10`
    # would let a large negative input such as -20 through as 20, which
    # defeats the bound and can drive the third bin's yield negative.
    magnitudes = jnp.abs(pars)
    bounded_pars = jnp.where(magnitudes > 10., 10., magnitudes)
    u1, u2, d1, d2 = bounded_pars
    u = jnp.array([u1, u2, -u1 - u2])
    d = jnp.array([d1, d2, -d1 - d2])

    sig = jnp.array([2, 5, 10])
    nominal = jnp.array([50, 50, 50])
    # Up/down templates are the nominal background shifted bin by bin.
    up = nominal + u
    down = nominal + d

    signal_sample = {
        'name': 'signal',
        'data': sig,
        'modifiers': [{'name': 'mu', 'type': 'normfactor', 'data': None}],
    }
    background_sample = {
        'name': 'background',
        'data': nominal,
        'modifiers': [
            {
                'name': 'bkguncrt',
                'type': 'histosys',
                'data': {'hi_data': up, 'lo_data': down},
            }
        ],
    }
    m = {
        'channels': [
            {'name': 'singlechannel', 'samples': [signal_sample, background_sample]}
        ]
    }
    # NOTE(review): validation is skipped, presumably because the spec holds
    # jax arrays rather than plain lists -- confirm.
    return pyhf.Model(m, validate=False)
    
def fisher_info_covariance(bestfit_pars: jnp.array, m: pyhf.Model, observed_data: jnp.array) -> jnp.array:
    """Return the inverse Hessian of the negative log-likelihood at ``bestfit_pars``.

    Args:
        bestfit_pars: parameter point at which the Hessian is evaluated.
        m: model whose ``logpdf(pars, data)`` is differentiated.
        observed_data: data passed unchanged to ``m.logpdf``.

    Returns:
        ``inv(H)``, where ``H`` is the Hessian of ``-m.logpdf(pars, observed_data)[0]``
        with respect to ``pars``.
    """
    def negative_log_likelihood(lhood_pars):
        # Element 0 of the logpdf result is used as the scalar to differentiate.
        return -m.logpdf(lhood_pars, observed_data)[0]

    hessian_at_bestfit = jax.hessian(negative_log_likelihood)(bestfit_pars)
    return jnp.linalg.inv(hessian_at_bestfit)

def gaussian_logpdf(bestfit_pars: jnp.array, data: jnp.array, cov: jnp.array) -> jnp.array:
    """Log-density of a multivariate normal evaluated at ``data``.

    Args:
        bestfit_pars: mean of the normal distribution.
        data: point at which the density is evaluated.
        cov: covariance matrix of the normal distribution.

    Returns:
        The log-pdf value from ``jax.scipy.stats.multivariate_normal.logpdf``.
    """
    mvn_logpdf = jax.scipy.stats.multivariate_normal.logpdf
    # Argument order of the underlying call is (x, mean, cov).
    return mvn_logpdf(data, bestfit_pars, cov)

def model_gaussianity(m: pyhf.Model, bestfit_pars: jnp.array, cov_approx: jnp.array, observed_data: jnp.array) -> jnp.array:
    """Mean squared gap between the model's NLL and a Gaussian NLL around the best fit.

    Parameter points are sampled from a normal distribution centred on
    ``bestfit_pars`` with covariance ``cov_approx``. At each point the negative
    log-likelihood of the model and of the Gaussian approximation is taken
    relative to its value at ``bestfit_pars``, and the squared differences
    between the two are averaged over the points where the difference is finite.

    Args:
        m: model providing ``logpdf(pars, data)``.
        bestfit_pars: centre of the sampling distribution and reference point.
        cov_approx: covariance used for both sampling and the Gaussian approximation.
        observed_data: data passed to ``m.logpdf``.

    Returns:
        The mean of the squared, finite differences.
    """
    # The PRNG key and the number of samples (100) are fixed here.
    sampled_pars = multivariate_normal(key=PRNGKey(1), mean=bestfit_pars, cov=cov_approx, shape=(100,))

    def model_delta_nll(pars, data):
        # Shift the origin to the best-fit point so only the shape is compared.
        at_sample = m.logpdf(pars, data)[0]
        at_bestfit = m.logpdf(bestfit_pars, data)[0]
        return -(at_sample - at_bestfit)

    def gaussian_delta_nll(pars, data):
        # Here `data` is the best-fit point, which fixes the likelihood shape.
        at_sample = gaussian_logpdf(pars, data, cov_approx)
        at_bestfit = gaussian_logpdf(bestfit_pars, data, cov_approx)
        return -(at_sample - at_bestfit)

    # Map over the sampled points only; the second argument is shared.
    model_nlls = jax.vmap(model_delta_nll, in_axes=(0, None))(sampled_pars, observed_data)
    gaussian_nlls = jax.vmap(gaussian_delta_nll, in_axes=(0, None))(sampled_pars, bestfit_pars)

    residuals = model_nlls - gaussian_nlls
    finite_residuals = residuals[jnp.isfinite(residuals)]
    return jnp.mean(finite_residuals ** 2, axis=0)

def metrics(bestfit_pars: jnp.array, m: pyhf.Model, observed_data: jnp.array) -> dict:
    """Compute summary metrics of a fitted model.

    Args:
        bestfit_pars: fitted parameter values at which the covariance is approximated.
        m: model with parameters named 'mu' and 'bkguncrt'.
        observed_data: data passed to the likelihood and to the hypothesis test.

    Returns:
        A dict with the keys:
          * ``cls_obs``: result of ``pyhf.infer.hypotest`` for a tested value of 1.0.
          * ``mu_uncert2``: diagonal covariance entry for 'mu'.
          * ``pull_width_metric2``: ``(1 - cov[y, y]) ** 2`` for 'bkguncrt'.
          * ``gaussianity``: value returned by ``model_gaussianity``.
    """
    cov_approx = fisher_info_covariance(bestfit_pars, m, observed_data)

    # Positions of the two parameters in the covariance matrix.
    par_order = m.config.par_order
    mu_idx = par_order.index('mu')
    y_idx = par_order.index('bkguncrt')

    mu_uncert2 = cov_approx[mu_idx, mu_idx]
    pull_width_metric2 = (1 - cov_approx[y_idx, y_idx]) ** 2
    gaussianity = model_gaussianity(m, bestfit_pars, cov_approx, observed_data)
    cls_obs = pyhf.infer.hypotest(1.0, observed_data, m, init_pars=[0.0, 0.0])

    return dict(
        cls_obs=cls_obs,
        mu_uncert2=mu_uncert2,
        pull_width_metric2=pull_width_metric2,
        gaussianity=gaussianity,
    )

def pipeline(pars: jnp.array, observed_data: jnp.array) -> dict:
    """Build the model from ``pars``, fit it to ``observed_data`` and return its metrics.

    Args:
        pars: values forwarded to ``make_model``.
        observed_data: observed main-measurement data; the model's auxiliary
            data is appended before fitting.

    Returns:
        The dict produced by ``metrics`` for the fitted parameters.
    """
    m = make_model(pars)
    # The likelihood expects the main data followed by the auxiliary data.
    data = jnp.concatenate((observed_data, jnp.array(m.config.auxdata)))
    mle_pars = pyhf.infer.mle.fit(data, m, init_pars=[0., 0.])

    return metrics(mle_pars, m, data)


In [53]:
# Example run: evaluate the metrics for one choice of template shifts
# against a flat observation of 50 events in each of the three bins.
pipeline(
    pars=jnp.array([1., -2., -1., 2]),
    observed_data=jnp.array([50, 50, 50]),
)

{'cls_obs': DeviceArray(0.12815149, dtype=float64),
 'mu_uncert2': DeviceArray(0.3875969, dtype=float64),
 'pull_width_metric2': DeviceArray(4.93038066e-32, dtype=float64),
 'gaussianity': DeviceArray(0.24232135, dtype=float64)}