diff --git a/src/ess/beer/io.py b/src/ess/beer/io.py index 20bb2e3d..53e3f0b8 100644 --- a/src/ess/beer/io.py +++ b/src/ess/beer/io.py @@ -10,7 +10,7 @@ Filename, ModulationPeriod, SampleRun, - TwoThetaMaskFunction, + TwoThetaLimits, WavelengthDefinitionChopperDelay, ) @@ -122,19 +122,27 @@ def load_beer_mcstas(f: str | Path | h5py.File) -> sc.DataGroup: ) +def _not_between(x, a, b): + return (x < a) | (b < x) + + def load_beer_mcstas_provider( - fname: Filename[SampleRun], two_theta_mask: TwoThetaMaskFunction + fname: Filename[SampleRun], two_theta_limits: TwoThetaLimits ) -> DetectorData[SampleRun]: da = load_beer_mcstas(fname) da = ( sc.DataGroup( { - k: v.assign_masks(two_theta=two_theta_mask(v.coords['two_theta'])) + k: v.assign_masks( + two_theta=_not_between(v.coords['two_theta'], *two_theta_limits) + ) for k, v in da.items() } ) if isinstance(da, sc.DataGroup) - else da.assign_masks(two_theta=two_theta_mask(da.coords['two_theta'])) + else da.assign_masks( + two_theta=_not_between(da.coords['two_theta'], *two_theta_limits) + ) ) return DetectorData[SampleRun](da) diff --git a/src/ess/beer/types.py b/src/ess/beer/types.py index 592cc00e..e729e10a 100644 --- a/src/ess/beer/types.py +++ b/src/ess/beer/types.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from typing import NewType import sciline @@ -18,9 +17,7 @@ class StreakClusteredData(sciline.Scope[RunType, sc.DataArray], sc.DataArray): DetectorTofData = DetectorTofData -TwoThetaMaskFunction = NewType( - 'TwoThetaMaskFunction', Callable[[sc.Variable], sc.Variable] -) +TwoThetaLimits = NewType('TwoThetaLimits', tuple[sc.Variable, sc.Variable]) TofCoordTransformGraph = NewType("TofCoordTransformGraph", dict) diff --git a/src/ess/beer/workflow.py b/src/ess/beer/workflow.py index e9428cf8..f47fdcb0 100644 --- a/src/ess/beer/workflow.py +++ b/src/ess/beer/workflow.py @@ -11,14 +11,14 @@ PulseLength, RunType, SampleRun, - TwoThetaMaskFunction, + TwoThetaLimits, ) default_parameters = { PulseLength: sc.scalar(0.003, unit='s'), - TwoThetaMaskFunction: lambda two_theta: ( - (two_theta >= sc.scalar(105, unit='deg').to(unit='rad', dtype='float64')) - | (two_theta <= sc.scalar(75, unit='deg').to(unit='rad', dtype='float64')) + TwoThetaLimits: ( + sc.scalar(75, unit='deg').to(unit='rad', dtype='float64'), + sc.scalar(105, unit='deg').to(unit='rad', dtype='float64'), ), }