Skip to content

Commit

Permalink
docs: Document Channel Summary Mixin (#1972)
Browse files Browse the repository at this point in the history
* Make _ChannelSummaryMixin parameter related attributes properties.
* Add documentation for parameter related properties of _ModelConfig by documenting
  the _ChannelSummaryMixin.
* Add typehints to mixins.
* Update tests to mock the properties.
  • Loading branch information
kratsg committed Sep 4, 2022
1 parent b6e02ee commit 0fe3434
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 33 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Making Models from PDFs

~pdf.Model
~pdf._ModelConfig
~mixins._ChannelSummaryMixin
~workspace.Workspace
~patchset.PatchSet
~patchset.Patch
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ module = [
'pyhf.compat',
'pyhf.events',
'pyhf.utils',
'pyhf.mixins',
'pyhf.constraints',
'pyhf.pdf',
'pyhf.simplemodels',
Expand Down
76 changes: 58 additions & 18 deletions src/pyhf/mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from __future__ import annotations

import logging
from typing import Any, Sequence

from pyhf.typing import Channel

log = logging.getLogger(__name__)

Expand All @@ -13,39 +18,74 @@ class _ChannelSummaryMixin:
**channels: A list of channels to provide summary information about. Follows the `defs.json#/definitions/channel` schema.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Sequence[Channel]):
channels = kwargs.pop('channels')
super().__init__(*args, **kwargs)
self.channels = []
self.samples = []
self.modifiers = []
self._channels: list[str] = []
self._samples: list[str] = []
self._modifiers: list[tuple[str, str]] = []
# keep track of the width of each channel (how many bins)
self.channel_nbins = {}
self._channel_nbins: dict[str, int] = {}
# need to keep track in which order we added the constraints
# so that we can generate correctly-ordered data
for channel in channels:
self.channels.append(channel['name'])
self.channel_nbins[channel['name']] = len(channel['samples'][0]['data'])
self._channels.append(channel['name'])
self._channel_nbins[channel['name']] = len(channel['samples'][0]['data'])
for sample in channel['samples']:
self.samples.append(sample['name'])
self._samples.append(sample['name'])
for modifier_def in sample['modifiers']:
self.modifiers.append(
self._modifiers.append(
(
modifier_def['name'], # mod name
modifier_def['type'], # mod type
)
)

self.channels = sorted(list(set(self.channels)))
self.samples = sorted(list(set(self.samples)))
self.modifiers = sorted(list(set(self.modifiers)))
self.channel_nbins = {
channel: self.channel_nbins[channel] for channel in self.channels
self._channels = sorted(list(set(self._channels)))
self._samples = sorted(list(set(self._samples)))
self._modifiers = sorted(list(set(self._modifiers)))
self._channel_nbins = {
channel: self._channel_nbins[channel] for channel in self._channels
}

self.channel_slices = {}
self._channel_slices = {}
begin = 0
for c in self.channels:
end = begin + self.channel_nbins[c]
self.channel_slices[c] = slice(begin, end)
for c in self._channels:
end = begin + self._channel_nbins[c]
self._channel_slices[c] = slice(begin, end)
begin = end

@property
def channels(self) -> list[str]:
"""
Ordered list of channel names in the model.
"""
return self._channels

@property
def samples(self) -> list[str]:
"""
Ordered list of sample names in the model.
"""
return self._samples

@property
def modifiers(self) -> list[tuple[str, str]]:
"""
Ordered list of pairs of modifier name/type in the model.
"""
return self._modifiers

@property
def channel_nbins(self) -> dict[str, int]:
"""
Dictionary mapping channel name to number of bins in the channel.
"""
return self._channel_nbins

@property
def channel_slices(self) -> dict[str, slice]:
"""
Dictionary mapping channel name to the bin slices in the model.
"""
return self._channel_slices
73 changes: 62 additions & 11 deletions src/pyhf/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,58 @@ def __init__(self, spec, **config_kwargs):
f"Unsupported options were passed in: {list(config_kwargs.keys())}."
)

# prefixed with underscore are documented via @property
self._par_order = []
self._poi_name = None
self._poi_index = None
self._nmaindata = sum(self.channel_nbins.values())
self._auxdata = []

# these are not documented properties
self.par_map = {}
self.par_order = []
self.poi_name = None
self.poi_index = None
self.auxdata = []
self.auxdata_order = []
self.nmaindata = sum(self.channel_nbins.values())

@property
def par_order(self):
"""
Return an ordered list of paramset names in the model.
"""
return self._par_order

@property
def poi_name(self):
"""
Return the name of the POI parameter in the model.
"""
return self._poi_name

@property
def poi_index(self):
"""
Return the index of the POI parameter in the model.
"""
return self._poi_index

@property
def auxdata(self):
"""
Return the auxiliary data in the model.
"""
return self._auxdata

@property
def nmaindata(self):
"""
Return the length of data in the main model.
"""
return self._nmaindata

@property
def nauxdata(self):
"""
Return the length of data in the constraint model.
"""
return len(self._auxdata)

def set_parameters(self, _required_paramsets):
"""
Expand All @@ -240,9 +285,8 @@ def set_auxinfo(self, auxdata, auxdata_order):
"""
Sets a group of configuration data for the constraint terms.
"""
self.auxdata = auxdata
self._auxdata = auxdata
self.auxdata_order = auxdata_order
self.nauxdata = len(self.auxdata)

def suggested_init(self):
"""
Expand Down Expand Up @@ -400,8 +444,8 @@ def set_poi(self, name):
)
s = self.par_slice(name)
assert s.stop - s.start == 1
self.poi_name = name
self.poi_index = s.start
self._poi_name = name
self._poi_index = s.start

def _create_and_register_paramsets(self, required_paramsets):
next_index = 0
Expand All @@ -415,7 +459,7 @@ def _create_and_register_paramsets(self, required_paramsets):
sl = slice(next_index, next_index + paramset.n_parameters)
next_index = next_index + paramset.n_parameters

self.par_order.append(param_name)
self._par_order.append(param_name)
self.par_map[param_name] = {'slice': sl, 'paramset': paramset}


Expand Down Expand Up @@ -700,7 +744,7 @@ def __init__(
schema.validate(self.spec, self.schema, version=self.version)
# build up our representation of the specification
poi_name = config_kwargs.pop('poi_name', 'mu')
self.config = _ModelConfig(self.spec, **config_kwargs)
self._config = _ModelConfig(self.spec, **config_kwargs)

modifiers, _nominal_rates = _nominal_and_modifiers_from_spec(
modifier_set, self.config, self.spec, self.batch_size
Expand Down Expand Up @@ -733,6 +777,13 @@ def __init__(
sizes, ['main', 'aux'], self.batch_size
)

@property
def config(self):
"""
The :class:`_ModelConfig` instance for the model.
"""
return self._config

def expected_auxdata(self, pars):
"""
Compute the expected value of the auxiliary measurements.
Expand Down
11 changes: 8 additions & 3 deletions tests/test_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,20 @@ def test_get_workspace_data(workspace_factory, include_auxdata):
assert w.data(m, include_auxdata=include_auxdata)


def test_get_workspace_data_bad_model(workspace_factory, caplog):
def test_get_workspace_data_bad_model(workspace_factory, caplog, mocker):
w = workspace_factory()
m = w.model()
# the iconic fragrance of an expected failure
m.config.channels = [c.replace('channel', 'chanel') for c in m.config.channels]

mocker.patch(
"pyhf.mixins._ChannelSummaryMixin.channels",
new_callable=mocker.PropertyMock,
return_value=["fakechannel"],
)
with caplog.at_level(logging.INFO, 'pyhf.pdf'):
with pytest.raises(KeyError):
assert w.data(m)
assert 'Invalid channel' in caplog.text
assert "Invalid channel" in caplog.text


def test_json_serializable(workspace_factory):
Expand Down

0 comments on commit 0fe3434

Please sign in to comment.