In [29]:
import jax
import jax.numpy as jnp

from jax.random import PRNGKey, multivariate_normal

import pyhf
pyhf.set_backend("jax")


def make_model(pars: jnp.array) -> pyhf.Model:
    bounded_pars = jnp.where(pars > 10., 10., jnp.abs(pars))
    u1, u2, d1, d2 = bounded_pars
    u = jnp.array([u1, u2, -u1 - u2])
    d = jnp.array([d1, d2, -d1 - d2])
    
    sig = jnp.array([2,5,10])
    nominal = jnp.array([50, 50, 50])
    up = jnp.array([50, 50, 50]) + u
    down = jnp.array([50, 50, 50]) + d
    
    
    m = {
        'channels': [{'name': 'singlechannel',
        'samples': [{'name': 'signal',
            'data': sig,
            'modifiers': [{'name': 'mu', 'type': 'normfactor', 'data': None}]},
            {'name': 'background',
            'data': nominal,
            'modifiers': [
                {'name': 'bkguncrt',
            'type': 'histosys',
            'data': {'hi_data': up, 'lo_data': down}
            }]}]}]
        }
    return pyhf.Model(m, validate=False)
    
def fisher_info_covariance(bestfit_pars: jnp.array, m: pyhf.Model, fitted_data: jnp.array) -> jnp.array:
    return jnp.linalg.inv(jax.hessian(lambda lhood_pars: -m.logpdf(lhood_pars, fitted_data)[0])(bestfit_pars))

def gaussian_logpdf(data: jnp.array, bestfit_pars: jnp.array, cov: jnp.array) -> jnp.array:
    return jax.scipy.stats.multivariate_normal.logpdf(data, bestfit_pars, cov)

def gaussianity():
    gaussian_parspace_samples = multivariate_normal(key=PRNGKey(1), mean=bestfit_pars, cov=cov_approx, shape=(100,))
    nll_model = jax.vmap(
        lambda data, pars: -(m.logpdf(pars, data)[0] - m.logpdf(bestfit_pars, data)[0]), # scale origin to bestfit pars
        in_axes=(0, None)
    )(gaussian_samples, )

def metrics(bestfit_pars: jnp.array, m: pyhf.Model, fitted_data: jnp.array) -> jnp.array:
    cov_approx = fisher_info_covariance(bestfit_pars, m, fitted_data)
    mu_uncert2 = cov_approx[m.config.par_order.index('mu')]
    pull_width_metric2 = (1-cov_approx[m.config.par_order.index('bkguncrt')])**2
    
    #
    
    
    return dict(mu_uncert2=mu_uncert2, pull_width_metric2=pull_width_metric2)


In [27]:
make_model(jnp.array([3,4,5,6])).

1

In [10]:
pyhf.__version__

'0.6.4.dev30'

In [6]:
x = jnp.array([-1,11])
jnp.where(x>10, 10, jnp.abs(x))

DeviceArray([ 1, 10], dtype=int64)

In [3]:
from jax import value_and_grad

value_and_grad(t)(jnp.ones(4))

Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray([51. 51. 48.])>with<JVPTrace(level=1/0)>
  with primal = DeviceArray([51., 51., 48.], dtype=float64)
       tangent = Traced<ShapedArray(float64[3]):JaxprTrace(level=0/0)>.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.

In [13]:
pyhf.__version__

'0.5.0'

In [21]:
import numpy as np
import jax

def op(tensor):
    t = jnp.asarray(tensor)
    return 2*t

jax.value_and_grad(op)(1.)

(DeviceArray(2., dtype=float64), DeviceArray(2., dtype=float64))