Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: model building (and gradients thereof) with jax as the default backend #1912

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
8 changes: 5 additions & 3 deletions src/pyhf/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,19 @@ def __init__(self, pdfconfig, batch_size=None):
if not parset.pdf_type == 'normal':
continue

normal_constraint_data.append(thisauxdata)
normal_constraint_data.append(default_backend.astensor(thisauxdata))

# many constraints are defined on a unit gaussian
# but we reserved the possibility that a paramset
# can define non-standard uncertainties. This is used
# by the paramset associated to staterror modifiers.
# Such parsets define a 'sigmas' attribute
try:
normal_constraint_sigmas.append(parset.sigmas)
normal_constraint_sigmas.append(default_backend.astensor(parset.sigmas))
except AttributeError:
normal_constraint_sigmas.append([1.0] * len(thisauxdata))
normal_constraint_sigmas.append(
default_backend.astensor([1.0] * len(thisauxdata))
)

self._normal_data = None
self._sigmas = None
Expand Down
8 changes: 7 additions & 1 deletion src/pyhf/modifiers/histosys.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,17 @@ def __init__(self, config):
self.required_parsets = {}

def collect(self, thismod, nom):
    """Collect the histosys modifier data for a single sample.

    Args:
        thismod: the modifier spec dict for this sample, or a falsy value
            when the sample has no histosys modifier; when present it must
            carry ``data['lo_data']`` and ``data['hi_data']``.
        nom: nominal bin contents for the sample; used as the fallback for
            both variations when the modifier is absent.

    Returns:
        dict: ``lo_data``/``hi_data``/``nom_data`` tensors plus a boolean
        ``mask`` tensor, all built with the default backend so that model
        construction stays differentiable (e.g. under jax).
    """
    default_backend = pyhf.default_backend
    # Fall back to the nominal histogram when this sample has no histosys,
    # so downstream interpolation code sees consistent shapes either way.
    lo_data = thismod['data']['lo_data'] if thismod else nom
    hi_data = thismod['data']['hi_data'] if thismod else nom
    maskval = bool(thismod)
    mask = [maskval] * len(nom)
    # NOTE(review): the pre-change plain-list return was left above the
    # tensorized one in the diff; only the tensorized return is kept here.
    return {
        'lo_data': default_backend.astensor(lo_data),
        'hi_data': default_backend.astensor(hi_data),
        'mask': default_backend.astensor(mask, dtype='bool'),
        'nom_data': default_backend.astensor(nom),
    }

def append(self, key, channel, sample, thismod, defined_samp):
self.builder_data.setdefault(key, {}).setdefault(sample, {}).setdefault(
Expand Down
18 changes: 13 additions & 5 deletions src/pyhf/modifiers/shapefactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,17 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
for m in keys
]

global_concatenated_bin_indices = [
[[j for c in pdfconfig.channels for j in range(pdfconfig.channel_nbins[c])]]
]
global_concatenated_bin_indices = default_backend.astensor(
[
[
[
j
for c in pdfconfig.channels
for j in range(pdfconfig.channel_nbins[c])
]
]
]
)

self._access_field = default_backend.tile(
global_concatenated_bin_indices,
Expand Down Expand Up @@ -174,8 +182,8 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
for t, batch_access in enumerate(syst_access):
selection = self.param_viewer.index_selection[s][t]
for b, bin_access in enumerate(batch_access):
self._access_field[s, t, b] = (
selection[bin_access] if bin_access < len(selection) else 0
self._access_field = self._access_field.at[s, t, b].set(
selection[int(bin_access)] if bin_access < len(selection) else 0
)

self._precompute()
Expand Down
24 changes: 17 additions & 7 deletions src/pyhf/modifiers/shapesys.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,17 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
for m in keys
]
)
global_concatenated_bin_indices = [
[[j for c in pdfconfig.channels for j in range(pdfconfig.channel_nbins[c])]]
]
global_concatenated_bin_indices = default_backend.astensor(
[
[
[
j
for c in pdfconfig.channels
for j in range(pdfconfig.channel_nbins[c])
]
]
]
)

self._access_field = default_backend.tile(
global_concatenated_bin_indices,
Expand Down Expand Up @@ -163,10 +171,12 @@ def _reindex_access_field(self, pdfconfig):
)

sample_mask = self._shapesys_mask[syst_index][singular_sample_index][0]
access_field_for_syst_and_batch[sample_mask] = selection
self._access_field[
syst_index, batch_index
] = access_field_for_syst_and_batch
access_field_for_syst_and_batch = access_field_for_syst_and_batch.at[
sample_mask
].set(selection)
self._access_field = self._access_field.at[syst_index, batch_index].set(
access_field_for_syst_and_batch
)

def _precompute(self):
tensorlib, _ = get_backend()
Expand Down
32 changes: 22 additions & 10 deletions src/pyhf/modifiers/staterror.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,13 @@ def finalize(self):
# extract sigmas using this modifiers mask
sigmas = relerrs[masks[modname]]

# list of bools, consistent with other modifiers (no numpy.bool_)
fixed = default_backend.tolist(sigmas == 0)
# NOT a list of bools (jax indexing requires the mask to be an array)
fixed = sigmas == 0
# FIXME: sigmas that are zero will be fixed to 1.0 arbitrarily to ensure
# non-Nan constraint term, but in a future PR need to remove constraints
# for these
sigmas[fixed] = 1.0
sigmas = sigmas.at[fixed].set(1.0)

self.required_parsets.setdefault(parname, [required_parset(sigmas, fixed)])
return self.builder_data

Expand All @@ -151,9 +152,18 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None):
[[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples]
for m in keys
]
global_concatenated_bin_indices = [
[[j for c in pdfconfig.channels for j in range(pdfconfig.channel_nbins[c])]]
]

global_concatenated_bin_indices = default_backend.astensor(
[
[
[
j
for c in pdfconfig.channels
for j in range(pdfconfig.channel_nbins[c])
]
]
]
)

self._access_field = default_backend.tile(
global_concatenated_bin_indices,
Expand Down Expand Up @@ -183,10 +193,12 @@ def _reindex_access_field(self, pdfconfig):
)

sample_mask = self._staterror_mask[syst_index][singular_sample_index][0]
access_field_for_syst_and_batch[sample_mask] = selection
self._access_field[
syst_index, batch_index
] = access_field_for_syst_and_batch
access_field_for_syst_and_batch = access_field_for_syst_and_batch.at[
sample_mask
].set(selection)
self._access_field = self._access_field.at[syst_index, batch_index].set(
access_field_for_syst_and_batch
)

def _precompute(self):
if not self.param_viewer.index_selection:
Expand Down
2 changes: 2 additions & 0 deletions src/pyhf/parameters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def reduce_paramsets_requirements(paramsets_requirements, paramsets_user_configs
for paramset_requirement in paramset_requirements:
# undefined: the modifier does not support configuring that property
v = paramset_requirement.get(k, 'undefined')
# differentiable models mean that v could contain jax arrays/tracers
# neither of these are hashable, so the set-based logic here needs changing
combined_paramset.setdefault(k, set()).add(v)

if len(combined_paramset[k]) != 1:
Expand Down
4 changes: 2 additions & 2 deletions src/pyhf/tensor/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def astensor(self, tensor_in, dtype="float"):
return jnp.asarray(tensor_in, dtype=dtype)

def sum(self, tensor_in, axis=None):
    """Sum of tensor elements over the given axis.

    Args:
        tensor_in: tensor-like object; coerced with ``jnp.asarray`` so
            plain Python lists (including lists of jax scalars/tracers)
            are accepted as well as jax arrays.
        axis (:obj:`int` or :obj:`None`): axis to reduce over; ``None``
            sums over all elements.

    Returns:
        JAX ndarray: the summed values.
    """
    # The stale pre-change return (no asarray coercion) that preceded this
    # line in the diff has been dropped; it made this fix unreachable.
    return jnp.sum(jnp.asarray(tensor_in), axis=axis)

def product(self, tensor_in, axis=None):
return jnp.prod(tensor_in, axis=axis)
Expand Down Expand Up @@ -334,7 +334,7 @@ def concatenate(self, sequence, axis=0):
output: the concatenated tensor

"""
return jnp.concatenate(sequence, axis=axis)
return jnp.concatenate([jnp.array(x) for x in sequence], axis=axis)

def simple_broadcast(self, *args):
"""
Expand Down
115 changes: 115 additions & 0 deletions tests/test_differentiable_model_construction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import pyhf
from jax import numpy as jnp
import jax


def test_model_building_grad():
    """End-to-end check that pyhf model building is jax-differentiable.

    Builds a two-channel spec whose nominal yields and every modifier's
    data depend on a single scalar ``x`` (as jax arrays), constructs the
    model, evaluates its logpdf at the suggested init against
    self-generated expected data, and differentiates the whole pipeline
    with ``jax.grad``.  The test passes if no step of model construction
    breaks jax tracing.
    """
    # default=True: model-building internals (not just likelihood
    # evaluation) must also run on the jax backend for grad to flow
    pyhf.set_backend('jax', default=True)

    def make_model(
        nominal,
        corrup_data,
        corrdn_data,
        stater_data,
        normsys_up,
        normsys_dn,
        uncorr_data,
    ):
        # Spec exercising normfactor, lumi, shapefactor, histosys,
        # staterror, normsys (channel 1) plus shapesys (channel 2).
        return {
            "channels": [
                {
                    "name": "achannel",
                    "samples": [
                        {
                            "name": "background",
                            "data": nominal,
                            "modifiers": [
                                {"name": "mu", "type": "normfactor", "data": None},
                                {"name": "lumi", "type": "lumi", "data": None},
                                {
                                    "name": "mod_name",
                                    "type": "shapefactor",
                                    "data": None,
                                },
                                {
                                    "name": "corr_bkguncrt2",
                                    "type": "histosys",
                                    "data": {
                                        'hi_data': corrup_data,
                                        'lo_data': corrdn_data,
                                    },
                                },
                                {
                                    "name": "staterror2",
                                    "type": "staterror",
                                    "data": stater_data,
                                },
                                {
                                    "name": "norm",
                                    "type": "normsys",
                                    "data": {'hi': normsys_up, 'lo': normsys_dn},
                                },
                            ],
                        }
                    ],
                },
                {
                    "name": "secondchannel",
                    "samples": [
                        {
                            "name": "background",
                            "data": nominal,
                            "modifiers": [
                                {"name": "mu", "type": "normfactor", "data": None},
                                {"name": "lumi", "type": "lumi", "data": None},
                                {
                                    "name": "mod_name",
                                    "type": "shapefactor",
                                    "data": None,
                                },
                                {
                                    "name": "uncorr_bkguncrt2",
                                    "type": "shapesys",
                                    "data": uncorr_data,
                                },
                                {
                                    "name": "corr_bkguncrt2",
                                    "type": "histosys",
                                    "data": {
                                        'hi_data': corrup_data,
                                        'lo_data': corrdn_data,
                                    },
                                },
                                {
                                    "name": "staterror",
                                    "type": "staterror",
                                    "data": stater_data,
                                },
                                {
                                    "name": "norm",
                                    "type": "normsys",
                                    "data": {'hi': normsys_up, 'lo': normsys_dn},
                                },
                            ],
                        }
                    ],
                },
            ],
        }

    def pipe(x):
        # Scale every input by x so the gradient must flow through each
        # modifier's data during model construction.
        spec = make_model(
            x * jnp.asarray([60.0, 62.0]),
            x * jnp.asarray([60.0, 62.0]),
            x * jnp.asarray([60.0, 62.0]),
            x * jnp.asarray([5.0, 5.0]),
            x * jnp.asarray(0.95),
            x * jnp.asarray(1.05),
            x * jnp.asarray([5.0, 5.0]),
        )
        # validate=False: presumably because jax arrays/tracers in the spec
        # would not pass JSON-schema validation — TODO confirm
        model = pyhf.Model(spec, validate=False)
        nominal = jnp.array(model.config.suggested_init())
        data = model.expected_data(nominal)
        # logpdf returns a length-1 result; take the scalar for grad
        return model.logpdf(nominal, data)[0]

    jax.grad(pipe)(3.4)  # should work without error