# Patch summary (text before the first "diff --git" header is commentary only).
#
# Purpose: let gradients flow through pyhf model *construction* when the
# workspace spec itself contains jax arrays/tracers.
#
# What the hunks below do:
#   * constraints.py      - wrap auxdata and sigmas in default_backend.astensor(...)
#                           before appending them to the normal-constraint lists.
#   * modifiers/histosys  - collect() returns lo_data/hi_data/mask/nom_data as
#                           default_backend tensors (mask with dtype='bool').
#   * modifiers/shapefactor, shapesys, staterror
#                         - build global_concatenated_bin_indices via
#                           default_backend.astensor(...) and replace in-place
#                           item assignment (x[idx] = v) with the functional
#                           form x = x.at[idx].set(v).
#   * modifiers/staterror - finalize(): `fixed` is kept as the array `sigmas == 0`
#                           instead of being converted with tolist().
#   * parameters/utils.py - comment only: notes that the set-based logic cannot
#                           hold unhashable jax arrays/tracers.
#   * tensor/jax_backend  - sum() and concatenate() coerce their inputs to jax
#                           arrays before calling jnp.sum / jnp.concatenate.
#   * tests/test_differentiable_model_construction.py (new file)
#                         - builds a two-channel spec from jax arrays and calls
#                           jax.grad through Model construction and logpdf.
#
# NOTE(review): this text has lost its line breaks (hunk headers, context and
#   +/- lines are inline), so it will not apply with `git apply` / `patch` in
#   this form. Regenerate it from the branch rather than hand-editing.
# NOTE(review): `.at[...].set(...)` is the jax array update API; NumPy ndarrays
#   do not provide it. The shapefactor/shapesys/staterror hunks call it on
#   values produced by `default_backend`, so they presumably only work when the
#   default backend is jax (the new test passes default=True) - confirm this
#   does not break the other backends.
# NOTE(review): staterror.finalize() now passes the array `fixed` to
#   required_parset(...), while reduce_paramsets_requirements still does
#   `set().add(v)`; the comment added in utils.py acknowledges this is
#   unresolved - confirm no code path hashes these values.
# NOTE(review): histosys.collect() reads `pyhf.default_backend`, but no
#   `import pyhf` is visible in the hunk - verify it is in scope in that module.
# NOTE(review): shapefactor updates `_access_field` one element at a time inside
#   a triple nested loop; each `.at[].set()` yields a new array, which is likely
#   much slower than the in-place assignment it replaces - TODO measure.
# NOTE(review): jax_backend.concatenate uses jnp.array(x) (copies by default)
#   while sum uses jnp.asarray(...) - consider asarray in both for consistency.
# NOTE(review): the new test calls pyhf.set_backend('jax', default=True) and no
#   restore is visible here - presumably a fixture resets backend state; verify.
#
diff --git a/src/pyhf/constraints.py b/src/pyhf/constraints.py index 30ef835e79..73b45b05c5 100644 --- a/src/pyhf/constraints.py +++ b/src/pyhf/constraints.py @@ -45,7 +45,7 @@ def __init__(self, pdfconfig, batch_size=None): if not parset.pdf_type == 'normal': continue - normal_constraint_data.append(thisauxdata) + normal_constraint_data.append(default_backend.astensor(thisauxdata)) # many constraints are defined on a unit gaussian # but we reserved the possibility that a paramset @@ -53,9 +53,11 @@ def __init__(self, pdfconfig, batch_size=None): # by the paramset associated to staterror modifiers. # Such parsets define a 'sigmas' attribute try: - normal_constraint_sigmas.append(parset.sigmas) + normal_constraint_sigmas.append(default_backend.astensor(parset.sigmas)) except AttributeError: - normal_constraint_sigmas.append([1.0] * len(thisauxdata)) + normal_constraint_sigmas.append( + default_backend.astensor([1.0] * len(thisauxdata)) + ) self._normal_data = None self._sigmas = None diff --git a/src/pyhf/modifiers/histosys.py b/src/pyhf/modifiers/histosys.py index 78b0f20074..c406112c13 100644 --- a/src/pyhf/modifiers/histosys.py +++ b/src/pyhf/modifiers/histosys.py @@ -31,11 +31,17 @@ def __init__(self, config): self.required_parsets = {} def collect(self, thismod, nom): + default_backend = pyhf.default_backend lo_data = thismod['data']['lo_data'] if thismod else nom hi_data = thismod['data']['hi_data'] if thismod else nom maskval = bool(thismod) mask = [maskval] * len(nom) - return {'lo_data': lo_data, 'hi_data': hi_data, 'mask': mask, 'nom_data': nom} + return { + 'lo_data': default_backend.astensor(lo_data), + 'hi_data': default_backend.astensor(hi_data), + 'mask': default_backend.astensor(mask, dtype='bool'), + 'nom_data': default_backend.astensor(nom), + } def append(self, key, channel, sample, thismod, defined_samp): self.builder_data.setdefault(key, {}).setdefault(sample, {}).setdefault( diff --git a/src/pyhf/modifiers/shapefactor.py 
b/src/pyhf/modifiers/shapefactor.py index 6c65c2e8c9..9de62c1249 100644 --- a/src/pyhf/modifiers/shapefactor.py +++ b/src/pyhf/modifiers/shapefactor.py @@ -142,9 +142,17 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None): for m in keys ] - global_concatenated_bin_indices = [ - [[j for c in pdfconfig.channels for j in range(pdfconfig.channel_nbins[c])]] - ] + global_concatenated_bin_indices = default_backend.astensor( + [ + [ + [ + j + for c in pdfconfig.channels + for j in range(pdfconfig.channel_nbins[c]) + ] + ] + ] + ) self._access_field = default_backend.tile( global_concatenated_bin_indices, @@ -174,8 +182,8 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None): for t, batch_access in enumerate(syst_access): selection = self.param_viewer.index_selection[s][t] for b, bin_access in enumerate(batch_access): - self._access_field[s, t, b] = ( - selection[bin_access] if bin_access < len(selection) else 0 + self._access_field = self._access_field.at[s, t, b].set( + selection[int(bin_access)] if bin_access < len(selection) else 0 ) self._precompute() diff --git a/src/pyhf/modifiers/shapesys.py b/src/pyhf/modifiers/shapesys.py index 3d4d52e2b6..ea3e0067f1 100644 --- a/src/pyhf/modifiers/shapesys.py +++ b/src/pyhf/modifiers/shapesys.py @@ -128,9 +128,17 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None): for m in keys ] ) - global_concatenated_bin_indices = [ - [[j for c in pdfconfig.channels for j in range(pdfconfig.channel_nbins[c])]] - ] + global_concatenated_bin_indices = default_backend.astensor( + [ + [ + [ + j + for c in pdfconfig.channels + for j in range(pdfconfig.channel_nbins[c]) + ] + ] + ] + ) self._access_field = default_backend.tile( global_concatenated_bin_indices, @@ -163,10 +171,12 @@ def _reindex_access_field(self, pdfconfig): ) sample_mask = self._shapesys_mask[syst_index][singular_sample_index][0] - access_field_for_syst_and_batch[sample_mask] = selection - self._access_field[ - 
syst_index, batch_index - ] = access_field_for_syst_and_batch + access_field_for_syst_and_batch = access_field_for_syst_and_batch.at[ + sample_mask + ].set(selection) + self._access_field = self._access_field.at[syst_index, batch_index].set( + access_field_for_syst_and_batch + ) def _precompute(self): tensorlib, _ = get_backend() diff --git a/src/pyhf/modifiers/staterror.py b/src/pyhf/modifiers/staterror.py index 1e7573d729..3c06ab6aa0 100644 --- a/src/pyhf/modifiers/staterror.py +++ b/src/pyhf/modifiers/staterror.py @@ -120,12 +120,13 @@ def finalize(self): # extract sigmas using this modifiers mask sigmas = relerrs[masks[modname]] - # list of bools, consistent with other modifiers (no numpy.bool_) - fixed = default_backend.tolist(sigmas == 0) + # NOT a list of bools (jax indexing requires the mask to be an array) + fixed = sigmas == 0 # FIXME: sigmas that are zero will be fixed to 1.0 arbitrarily to ensure # non-Nan constraint term, but in a future PR need to remove constraints # for these - sigmas[fixed] = 1.0 + sigmas = sigmas.at[fixed].set(1.0) + self.required_parsets.setdefault(parname, [required_parset(sigmas, fixed)]) return self.builder_data @@ -151,9 +152,18 @@ def __init__(self, modifiers, pdfconfig, builder_data, batch_size=None): [[builder_data[m][s]['data']['mask']] for s in pdfconfig.samples] for m in keys ] - global_concatenated_bin_indices = [ - [[j for c in pdfconfig.channels for j in range(pdfconfig.channel_nbins[c])]] - ] + + global_concatenated_bin_indices = default_backend.astensor( + [ + [ + [ + j + for c in pdfconfig.channels + for j in range(pdfconfig.channel_nbins[c]) + ] + ] + ] + ) self._access_field = default_backend.tile( global_concatenated_bin_indices, @@ -183,10 +193,12 @@ def _reindex_access_field(self, pdfconfig): ) sample_mask = self._staterror_mask[syst_index][singular_sample_index][0] - access_field_for_syst_and_batch[sample_mask] = selection - self._access_field[ - syst_index, batch_index - ] = access_field_for_syst_and_batch 
+ access_field_for_syst_and_batch = access_field_for_syst_and_batch.at[ + sample_mask + ].set(selection) + self._access_field = self._access_field.at[syst_index, batch_index].set( + access_field_for_syst_and_batch + ) def _precompute(self): if not self.param_viewer.index_selection: diff --git a/src/pyhf/parameters/utils.py b/src/pyhf/parameters/utils.py index 9f8e66e647..1c6a5260b5 100644 --- a/src/pyhf/parameters/utils.py +++ b/src/pyhf/parameters/utils.py @@ -37,6 +37,8 @@ def reduce_paramsets_requirements(paramsets_requirements, paramsets_user_configs for paramset_requirement in paramset_requirements: # undefined: the modifier does not support configuring that property v = paramset_requirement.get(k, 'undefined') + # differentiable models mean that v could contain jax arrays/tracers + # neither of these are hashable, so the set-based logic here needs changing combined_paramset.setdefault(k, set()).add(v) if len(combined_paramset[k]) != 1: diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index 5e4a65bc80..ee15a5e0d2 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -230,7 +230,7 @@ def astensor(self, tensor_in, dtype="float"): return jnp.asarray(tensor_in, dtype=dtype) def sum(self, tensor_in, axis=None): - return jnp.sum(tensor_in, axis=axis) + return jnp.sum(jnp.asarray(tensor_in), axis=axis) def product(self, tensor_in, axis=None): return jnp.prod(tensor_in, axis=axis) @@ -334,7 +334,7 @@ def concatenate(self, sequence, axis=0): output: the concatenated tensor """ - return jnp.concatenate(sequence, axis=axis) + return jnp.concatenate([jnp.array(x) for x in sequence], axis=axis) def simple_broadcast(self, *args): """ diff --git a/tests/test_differentiable_model_construction.py b/tests/test_differentiable_model_construction.py new file mode 100644 index 0000000000..6922a32e76 --- /dev/null +++ b/tests/test_differentiable_model_construction.py @@ -0,0 +1,115 @@ +import pyhf +from jax import numpy as 
jnp +import jax + + +def test_model_building_grad(): + pyhf.set_backend('jax', default=True) + + def make_model( + nominal, + corrup_data, + corrdn_data, + stater_data, + normsys_up, + normsys_dn, + uncorr_data, + ): + return { + "channels": [ + { + "name": "achannel", + "samples": [ + { + "name": "background", + "data": nominal, + "modifiers": [ + {"name": "mu", "type": "normfactor", "data": None}, + {"name": "lumi", "type": "lumi", "data": None}, + { + "name": "mod_name", + "type": "shapefactor", + "data": None, + }, + { + "name": "corr_bkguncrt2", + "type": "histosys", + "data": { + 'hi_data': corrup_data, + 'lo_data': corrdn_data, + }, + }, + { + "name": "staterror2", + "type": "staterror", + "data": stater_data, + }, + { + "name": "norm", + "type": "normsys", + "data": {'hi': normsys_up, 'lo': normsys_dn}, + }, + ], + } + ], + }, + { + "name": "secondchannel", + "samples": [ + { + "name": "background", + "data": nominal, + "modifiers": [ + {"name": "mu", "type": "normfactor", "data": None}, + {"name": "lumi", "type": "lumi", "data": None}, + { + "name": "mod_name", + "type": "shapefactor", + "data": None, + }, + { + "name": "uncorr_bkguncrt2", + "type": "shapesys", + "data": uncorr_data, + }, + { + "name": "corr_bkguncrt2", + "type": "histosys", + "data": { + 'hi_data': corrup_data, + 'lo_data': corrdn_data, + }, + }, + { + "name": "staterror", + "type": "staterror", + "data": stater_data, + }, + { + "name": "norm", + "type": "normsys", + "data": {'hi': normsys_up, 'lo': normsys_dn}, + }, + ], + } + ], + }, + ], + } + + def pipe(x): + spec = make_model( + x * jnp.asarray([60.0, 62.0]), + x * jnp.asarray([60.0, 62.0]), + x * jnp.asarray([60.0, 62.0]), + x * jnp.asarray([5.0, 5.0]), + x * jnp.asarray(0.95), + x * jnp.asarray(1.05), + x * jnp.asarray([5.0, 5.0]), + ) + model = pyhf.Model(spec, validate=False) + nominal = jnp.array(model.config.suggested_init()) + data = model.expected_data(nominal) + return model.logpdf(nominal, data)[0] + + jax.grad(pipe)(3.4) # 
should work without error