Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

docs: Document Channel Summary Mixin #1972

Merged
merged 9 commits into from
Sep 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Making Models from PDFs

~pdf.Model
~pdf._ModelConfig
~mixins._ChannelSummaryMixin
~workspace.Workspace
~patchset.PatchSet
~patchset.Patch
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ module = [
'pyhf.compat',
'pyhf.events',
'pyhf.utils',
'pyhf.mixins',
'pyhf.constraints',
'pyhf.pdf',
'pyhf.simplemodels',
Expand Down
76 changes: 58 additions & 18 deletions src/pyhf/mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from __future__ import annotations

import logging
from typing import Any, Sequence

from pyhf.typing import Channel

log = logging.getLogger(__name__)

Expand All @@ -13,39 +18,74 @@ class _ChannelSummaryMixin:
**channels: A list of channels to provide summary information about. Follows the `defs.json#/definitions/channel` schema.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Sequence[Channel]):
channels = kwargs.pop('channels')
super().__init__(*args, **kwargs)
self.channels = []
self.samples = []
self.modifiers = []
self._channels: list[str] = []
self._samples: list[str] = []
self._modifiers: list[tuple[str, str]] = []
# keep track of the width of each channel (how many bins)
self.channel_nbins = {}
self._channel_nbins: dict[str, int] = {}
# need to keep track of the order in which we added the constraints
# so that we can generate correctly-ordered data
for channel in channels:
self.channels.append(channel['name'])
self.channel_nbins[channel['name']] = len(channel['samples'][0]['data'])
self._channels.append(channel['name'])
self._channel_nbins[channel['name']] = len(channel['samples'][0]['data'])
for sample in channel['samples']:
self.samples.append(sample['name'])
self._samples.append(sample['name'])
for modifier_def in sample['modifiers']:
self.modifiers.append(
self._modifiers.append(
(
modifier_def['name'], # mod name
modifier_def['type'], # mod type
)
)

self.channels = sorted(list(set(self.channels)))
self.samples = sorted(list(set(self.samples)))
self.modifiers = sorted(list(set(self.modifiers)))
self.channel_nbins = {
channel: self.channel_nbins[channel] for channel in self.channels
self._channels = sorted(list(set(self._channels)))
self._samples = sorted(list(set(self._samples)))
self._modifiers = sorted(list(set(self._modifiers)))
self._channel_nbins = {
channel: self._channel_nbins[channel] for channel in self._channels
}

self.channel_slices = {}
self._channel_slices = {}
begin = 0
for c in self.channels:
end = begin + self.channel_nbins[c]
self.channel_slices[c] = slice(begin, end)
for c in self._channels:
end = begin + self._channel_nbins[c]
self._channel_slices[c] = slice(begin, end)
begin = end

@property
def channels(self) -> list[str]:
"""
Ordered list of channel names in the model.
"""
return self._channels

@property
def samples(self) -> list[str]:
"""
Ordered list of sample names in the model.
"""
return self._samples

@property
def modifiers(self) -> list[tuple[str, str]]:
"""
Ordered list of pairs of modifier name/type in the model.
"""
return self._modifiers

@property
def channel_nbins(self) -> dict[str, int]:
"""
Dictionary mapping channel name to number of bins in the channel.
"""
return self._channel_nbins

@property
def channel_slices(self) -> dict[str, slice]:
"""
Dictionary mapping channel name to the bin slices in the model.
"""
return self._channel_slices
73 changes: 62 additions & 11 deletions src/pyhf/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,58 @@ def __init__(self, spec, **config_kwargs):
f"Unsupported options were passed in: {list(config_kwargs.keys())}."
)

# attributes prefixed with an underscore are documented via @property
self._par_order = []
self._poi_name = None
self._poi_index = None
self._nmaindata = sum(self.channel_nbins.values())
self._auxdata = []

# these are not documented properties
self.par_map = {}
self.par_order = []
self.poi_name = None
self.poi_index = None
self.auxdata = []
self.auxdata_order = []
self.nmaindata = sum(self.channel_nbins.values())

@property
def par_order(self):
"""
Return an ordered list of paramset names in the model.
"""
return self._par_order

@property
def poi_name(self):
"""
Return the name of the POI parameter in the model.
"""
return self._poi_name

@property
def poi_index(self):
"""
Return the index of the POI parameter in the model.
"""
return self._poi_index

@property
def auxdata(self):
"""
Return the auxiliary data in the model.
"""
return self._auxdata

@property
def nmaindata(self):
"""
Return the length of data in the main model.
"""
return self._nmaindata

@property
def nauxdata(self):
"""
Return the length of data in the constraint model.
"""
return len(self._auxdata)

def set_parameters(self, _required_paramsets):
"""
Expand All @@ -240,9 +285,8 @@ def set_auxinfo(self, auxdata, auxdata_order):
"""
Sets a group of configuration data for the constraint terms.
"""
self.auxdata = auxdata
self._auxdata = auxdata
self.auxdata_order = auxdata_order
self.nauxdata = len(self.auxdata)

def suggested_init(self):
"""
Expand Down Expand Up @@ -400,8 +444,8 @@ def set_poi(self, name):
)
s = self.par_slice(name)
assert s.stop - s.start == 1
self.poi_name = name
self.poi_index = s.start
self._poi_name = name
self._poi_index = s.start

def _create_and_register_paramsets(self, required_paramsets):
next_index = 0
Expand All @@ -415,7 +459,7 @@ def _create_and_register_paramsets(self, required_paramsets):
sl = slice(next_index, next_index + paramset.n_parameters)
next_index = next_index + paramset.n_parameters

self.par_order.append(param_name)
self._par_order.append(param_name)
self.par_map[param_name] = {'slice': sl, 'paramset': paramset}


Expand Down Expand Up @@ -700,7 +744,7 @@ def __init__(
schema.validate(self.spec, self.schema, version=self.version)
# build up our representation of the specification
poi_name = config_kwargs.pop('poi_name', 'mu')
self.config = _ModelConfig(self.spec, **config_kwargs)
self._config = _ModelConfig(self.spec, **config_kwargs)

modifiers, _nominal_rates = _nominal_and_modifiers_from_spec(
modifier_set, self.config, self.spec, self.batch_size
Expand Down Expand Up @@ -733,6 +777,13 @@ def __init__(
sizes, ['main', 'aux'], self.batch_size
)

@property
def config(self):
"""
The :class:`_ModelConfig` instance for the model.
"""
return self._config

def expected_auxdata(self, pars):
"""
Compute the expected value of the auxiliary measurements.
Expand Down
11 changes: 8 additions & 3 deletions tests/test_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,15 +170,20 @@ def test_get_workspace_data(workspace_factory, include_auxdata):
assert w.data(m, include_auxdata=include_auxdata)


def test_get_workspace_data_bad_model(workspace_factory, caplog):
def test_get_workspace_data_bad_model(workspace_factory, caplog, mocker):
w = workspace_factory()
m = w.model()
# the iconic fragrance of an expected failure
m.config.channels = [c.replace('channel', 'chanel') for c in m.config.channels]

mocker.patch(
"pyhf.mixins._ChannelSummaryMixin.channels",
new_callable=mocker.PropertyMock,
return_value=["fakechannel"],
)
with caplog.at_level(logging.INFO, 'pyhf.pdf'):
with pytest.raises(KeyError):
assert w.data(m)
assert 'Invalid channel' in caplog.text
assert "Invalid channel" in caplog.text


def test_json_serializable(workspace_factory):
Expand Down