Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Pruning with logical AND instead of OR #2391

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 77 additions & 12 deletions src/pyhf/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,16 +583,39 @@ def _prune_and_rename(
),
)
for modifier in sample['modifiers']
if modifier['name'] not in prune_modifiers
and modifier['type'] not in prune_modifier_types
# we want to remove modifiers only if channel is not in list of channels to keep,
# we want to remove modifiers only if sample is not in list of samples to keep
if (
prune_channels
and channel['name'] not in prune_channels
)
or (
prune_samples
and sample['name'] not in prune_samples
)
or (
modifier['name'] not in prune_modifiers
and modifier['type'] not in prune_modifier_types
)
# need to keep the modifier in case it is used in another measurement
or prune_measurements
],
}
for sample in channel['samples']
if sample['name'] not in prune_samples
# we want to remove samples only if channel is not in list of channels to keep,
# we want to remove samples only if no modifiers are to be pruned
if (prune_channels and channel['name'] not in prune_channels)
or sample['name'] not in prune_samples
or prune_modifiers
or prune_modifier_types
],
}
for channel in self['channels']
# we want to remove channels only if no samples or modifiers are to be pruned
if channel['name'] not in prune_channels
or prune_samples
or prune_modifiers
or prune_modifier_types
],
'measurements': [
{
Expand All @@ -607,24 +630,38 @@ def _prune_and_rename(
parameter['name'], parameter['name']
),
)
for parameter in measurement['config']['parameters']
if parameter['name'] not in prune_modifiers
for parameter in measurement['config'][
'parameters'
] # we only want to remove this parameter if measurement is in prune_measurements or if prune_measurements is empty
# we want to remove parameters from a measurement only
# if measurement is not in keep_measurements
if (
prune_measurements
and measurement['name'] not in prune_measurements
)
or parameter['name'] not in prune_modifiers
],
'poi': rename_modifiers.get(
measurement['config']['poi'], measurement['config']['poi']
),
},
}
for measurement in self['measurements']
if measurement['name'] not in prune_measurements
# we want to remove measurements only if no parameters are to be pruned
if measurement['name'] not in prune_measurements or prune_modifiers
],
'observations': [
dict(
copy.deepcopy(observation),
name=rename_channels.get(observation['name'], observation['name']),
)
for observation in self['observations']
# we want to remove this channels only
# if no samples or modifiers are to be pruned
if observation['name'] not in prune_channels
or prune_samples
or prune_modifiers
or prune_modifier_types
],
'version': self['version'],
}
Expand All @@ -637,6 +674,7 @@ def prune(
samples=None,
channels=None,
measurements=None,
mode="logical_or",
):
"""
Return a new, pruned workspace specification. This will not modify the original workspace.
Expand All @@ -649,6 +687,7 @@ def prune(
samples: A :obj:`list` of samples to prune.
channels: A :obj:`list` of channels to prune.
measurements: A :obj:`list` of measurements to prune.
mode (:obj: string): `logical_or` or `logical_and` to chain pruning with a logical OR or a logical AND, respectively. Default: `logical_or`.

Returns:
~pyhf.workspace.Workspace: A new workspace object with the specified components removed
Expand All @@ -657,19 +696,45 @@ def prune(
~pyhf.exceptions.InvalidWorkspaceOperation: An item name to prune does not exist in the workspace.

"""

if mode not in ["logical_and", "logical_or"]:
raise ValueError(
"Pruning mode must be either `logical_and` or `logical_or`."
)

# avoid mutable defaults
modifiers = [] if modifiers is None else modifiers
modifier_types = [] if modifier_types is None else modifier_types
samples = [] if samples is None else samples
channels = [] if channels is None else channels
measurements = [] if measurements is None else measurements

return self._prune_and_rename(
prune_modifiers=modifiers,
prune_modifier_types=modifier_types,
prune_samples=samples,
prune_channels=channels,
prune_measurements=measurements,
if mode == "logical_and":
if samples != [] and measurements != []:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure whether these sanity checks should be made here or in _prune_and_rename. Or if they are needed at all for that matter.

raise ValueError(
"Pruning of measurements and samples cannot be run with mode `logical_and`."
)
if channels != [] and measurements != []:
raise ValueError(
"Pruning of measurements and channels cannot be run with mode `logical_and`."
)
if modifier_types != [] and measurements != []:
raise ValueError(
"Pruning of measurements and modifier_types cannot be run with mode `logical_and`."
)
return self._prune_and_rename(
prune_modifiers=modifiers,
prune_modifier_types=modifier_types,
prune_samples=samples,
prune_channels=channels,
prune_measurements=measurements,
)
return (
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Behaviour of _prune_and_rename was changed to logical_and, so for logical_or we can just chain separate calls for individual keywords.

self._prune_and_rename(prune_modifiers=modifiers)
._prune_and_rename(prune_modifier_types=modifier_types)
._prune_and_rename(prune_samples=samples)
._prune_and_rename(prune_channels=channels)
._prune_and_rename(prune_measurements=measurements)
)

def rename(self, modifiers=None, samples=None, channels=None, measurements=None):
Expand Down