Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ess/reduce/time_of_flight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
PulseStrideOffset,
TimeOfFlightLookupTable,
TimeOfFlightLookupTableFilename,
ToaDetector,
TofDetector,
TofMonitor,
)
Expand All @@ -51,6 +52,7 @@
"TimeOfFlightLookupTable",
"TimeOfFlightLookupTableFilename",
"TimeResolution",
"ToaDetector",
"TofDetector",
"TofLookupTableWorkflow",
"TofMonitor",
Expand Down
107 changes: 99 additions & 8 deletions src/ess/reduce/time_of_flight/eto_to_tof.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
MonitorLtotal,
PulseStrideOffset,
TimeOfFlightLookupTable,
ToaDetector,
TofDetector,
TofMonitor,
)
Expand Down Expand Up @@ -196,12 +197,32 @@ def _guess_pulse_stride_offset(
return sorted(tofs, key=lambda x: sc.isnan(tofs[x]).sum())[0]


def _time_of_flight_data_events(
def _prepare_tof_interpolation_inputs(
da: sc.DataArray,
lookup: sc.DataArray,
ltotal: sc.Variable,
pulse_stride_offset: int,
) -> sc.DataArray:
pulse_stride_offset: int | None,
) -> dict:
"""
Prepare the inputs required for the time-of-flight interpolation.
This function is used when computing the time-of-flight for event data, and for
computing the time-of-arrival for event data (as they both require guessing the
pulse_stride_offset if not provided).

Parameters
----------
da:
Data array with event data.
lookup:
Lookup table giving time-of-flight as a function of distance and time of
arrival.
ltotal:
Total length of the flight path from the source to the detector.
pulse_stride_offset:
When pulse-skipping, the offset of the first pulse in the stride. This is
typically zero but can be a small integer < pulse_stride.
If None, a guess is made.
"""
etos = da.bins.coords["event_time_offset"].to(dtype=float, copy=False)
eto_unit = elem_unit(etos)

Expand Down Expand Up @@ -259,12 +280,34 @@ def _time_of_flight_data_events(
pulse_index += pulse_stride_offset
pulse_index %= pulse_stride

# Compute time-of-flight for all neutrons using the interpolator
tofs = interp(
return {
"eto": etos,
"pulse_index": pulse_index,
"pulse_period": pulse_period,
"interp": interp,
"ltotal": ltotal,
}


def _time_of_flight_data_events(
da: sc.DataArray,
lookup: sc.DataArray,
ltotal: sc.Variable,
pulse_stride_offset: int | None,
) -> sc.DataArray:
inputs = _prepare_tof_interpolation_inputs(
da=da,
lookup=lookup,
ltotal=ltotal,
event_time_offset=etos,
pulse_index=pulse_index,
pulse_period=pulse_period,
pulse_stride_offset=pulse_stride_offset,
)

# Compute time-of-flight for all neutrons using the interpolator
tofs = inputs["interp"](
ltotal=inputs["ltotal"],
event_time_offset=inputs["eto"],
pulse_index=inputs["pulse_index"],
pulse_period=inputs["pulse_period"],
)

parts = da.bins.constituents
Expand Down Expand Up @@ -416,6 +459,53 @@ def monitor_time_of_flight_data(
)


def detector_time_of_arrival_data(
detector_data: RawDetector[RunType],
lookup: TimeOfFlightLookupTable,
ltotal: DetectorLtotal[RunType],
pulse_stride_offset: PulseStrideOffset,
) -> ToaDetector[RunType]:
"""
Convert the time-of-flight data to time-of-arrival data using a lookup table.
The output data will have a time-of-arrival coordinate.
The time-of-arrival is the time since the neutron was emitted from the source.
It is basically equal to event_time_offset + pulse_index * pulse_period.

Parameters
----------
da:
Raw detector data loaded from a NeXus file, e.g., NXdetector containing
NXevent_data.
lookup:
Lookup table giving time-of-flight as a function of distance and time of
arrival.
ltotal:
Total length of the flight path from the source to the detector.
pulse_stride_offset:
When pulse-skipping, the offset of the first pulse in the stride. This is
typically zero but can be a small integer < pulse_stride.
"""
if detector_data.bins is None:
raise NotImplementedError(
"Computing time-of-arrival in histogram mode is not implemented yet."
)
inputs = _prepare_tof_interpolation_inputs(
da=detector_data,
lookup=lookup,
ltotal=ltotal,
pulse_stride_offset=pulse_stride_offset,
)
parts = detector_data.bins.constituents
parts["data"] = inputs["eto"]
# The pulse index is None if pulse_stride == 1 (i.e., no pulse skipping)
if inputs["pulse_index"] is not None:
parts["data"] = parts["data"] + inputs["pulse_index"] * inputs["pulse_period"]
Comment on lines +501 to +502
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When would it be None? What is the meaning of that case? No pulse-skipping?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, when there is no pulse-skipping, the pulse_index is None. It is this way to avoid allocating a large array of zeros (size Nevents) if it is not needed.

I added a comment to clarify.

result = detector_data.bins.assign_coords(
toa=sc.bins(**parts, validate_indices=False)
)
return result


def providers() -> tuple[Callable]:
"""
Providers of the time-of-flight workflow.
Expand All @@ -425,4 +515,5 @@ def providers() -> tuple[Callable]:
monitor_time_of_flight_data,
detector_ltotal_from_straight_line_approximation,
monitor_ltotal_from_straight_line_approximation,
detector_time_of_arrival_data,
)
12 changes: 12 additions & 0 deletions src/ess/reduce/time_of_flight/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,17 @@ class TofDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray):
"""Detector data with time-of-flight coordinate."""


class ToaDetector(sl.Scope[RunType, sc.DataArray], sc.DataArray):
"""Detector data with time-of-arrival coordinate.

When the pulse stride is 1 (i.e., no pulse skipping), the time-of-arrival is the
same as the event_time_offset. When pulse skipping is used, the time-of-arrival is
the event_time_offset + pulse_offset * pulse_period.
This means that the time-of-arrival is basically the event_time_offset wrapped
over the frame period instead of the pulse period
(where frame_period = pulse_stride * pulse_period).
"""


class TofMonitor(sl.Scope[RunType, MonitorType, sc.DataArray], sc.DataArray):
"""Monitor data with time-of-flight coordinate."""
67 changes: 66 additions & 1 deletion tests/time_of_flight/unwrap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

from ess.reduce import time_of_flight
from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun
from ess.reduce.time_of_flight import GenericTofWorkflow, TofLookupTableWorkflow, fakes
from ess.reduce.time_of_flight import (
GenericTofWorkflow,
PulsePeriod,
TofLookupTableWorkflow,
fakes,
)

sl = pytest.importorskip("sciline")

Expand Down Expand Up @@ -441,3 +446,63 @@ def test_unwrap_int(dtype, lut_workflow_psc_choppers) -> None:
_validate_result_events(
tofs=tofs, ref=ref, percentile=100, diff_threshold=0.02, rtol=0.05
)


def test_compute_toa():
distance = sc.scalar(80.0, unit="m")
choppers = fakes.psc_choppers()

lut_wf = make_lut_workflow(
choppers=choppers, neutrons=500_000, seed=1234, pulse_stride=1
)

pl, _ = _make_workflow_event_mode(
distance=distance,
choppers=choppers,
lut_workflow=lut_wf,
seed=2,
pulse_stride_offset=0,
error_threshold=0.1,
)

toas = pl.compute(time_of_flight.ToaDetector[SampleRun])

assert "toa" in toas.bins.coords
raw = pl.compute(RawDetector[SampleRun])
assert sc.allclose(toas.bins.coords["toa"], raw.bins.coords["event_time_offset"])


def test_compute_toa_pulse_skipping():
distance = sc.scalar(100.0, unit="m")
choppers = fakes.pulse_skipping_choppers()

lut_wf = make_lut_workflow(
choppers=choppers, neutrons=500_000, seed=1234, pulse_stride=2
)

pl, _ = _make_workflow_event_mode(
distance=distance,
choppers=choppers,
lut_workflow=lut_wf,
seed=2,
pulse_stride_offset=1,
error_threshold=0.1,
)

raw = pl.compute(RawDetector[SampleRun])

toas = pl.compute(time_of_flight.ToaDetector[SampleRun])

assert "toa" in toas.bins.coords
pulse_period = lut_wf.compute(PulsePeriod)
hist = toas.bins.concat().hist(
toa=sc.array(
dims=["toa"],
values=[0, pulse_period.value, pulse_period.value * 2],
unit=pulse_period.unit,
).to(unit=toas.bins.coords["toa"].unit)
)
# There should be counts in both bins
n = raw.sum().value
assert hist.data[0].value > n / 5
assert hist.data[1].value > n / 5
Loading