diff --git a/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb new file mode 100644 index 000000000..c5b37095e --- /dev/null +++ b/packages/essreduce/docs/user-guide/unwrap/analytical-unwrap.ipynb @@ -0,0 +1,1083 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "# Chopper cascade acceptance for unwrapping and wavelength frame multiplication\n", + "\n", + "In this notebook, we show how to use a new fast workflow from `essreduce`'s `unwrap` module to compute neutron wavelengths based on a chopper acceptance diagram for a pulse of neutrons travelling through two WFM beamlines:\n", + "the DREAM and ODIN instruments." + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## The DREAM chopper cascade\n", + "\n", + "The case of DREAM is interesting because the pulse-shaping choppers can be used in a number of different modes,\n", + "and the number of cutouts the choppers have typically does not equal the number of frames observed at the detectors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "import plopp as pp\n", + "import scipp as sc\n", + "import scippnexus as snx\n", + "from scippneutron.chopper import DiskChopper\n", + "from ess.reduce.nexus.types import AnyRun, RawDetector, SampleRun, NeXusDetectorName\n", + "from ess.reduce.unwrap import *" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "### Creating the beamline choppers\n", + "\n", + "We begin by defining the chopper settings for our beamline.\n", + "In principle, the chopper setting could simply be read from a NeXus file.\n", + "\n", + "The DREAM instrument has\n", + "\n", + "- 2 pulse-shaping choppers (PSC)\n", + "- 1 overlap chopper (OC)\n", + "- 1 band-control chopper (BCC)\n", + "- 1 T0 chopper" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "psc1 = DiskChopper(\n", + " frequency=sc.scalar(14.0, unit=\"Hz\"),\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=sc.scalar(286 - 180, unit=\"deg\"),\n", + " axle_position=sc.vector(value=[0, 0, -70.405], unit=\"m\"),\n", + " slit_begin=sc.array(\n", + " dims=[\"cutout\"],\n", + " values=[-1.23, 70.49, 84.765, 113.565, 170.29, 271.635, 286.035, 301.17],\n", + " unit=\"deg\",\n", + " ),\n", + " slit_end=sc.array(\n", + " dims=[\"cutout\"],\n", + " values=[1.23, 73.51, 88.035, 116.835, 175.31, 275.565, 289.965, 303.63],\n", + " unit=\"deg\",\n", + " ),\n", + " slit_height=sc.scalar(10.0, unit=\"cm\"),\n", + " radius=sc.scalar(30.0, unit=\"cm\"),\n", + ")\n", + "\n", + "psc2 = DiskChopper(\n", + " frequency=sc.scalar(-14.0, unit=\"Hz\"),\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=sc.scalar(-236, unit=\"deg\"),\n", + " axle_position=sc.vector(value=[0, 0, -70.395], unit=\"m\"),\n", + " slit_begin=sc.array(\n", + " dims=[\"cutout\"],\n", + " values=[-1.23, 27.0, 55.8, 142.385, 156.765, 214.115, 257.23, 315.49],\n", + " unit=\"deg\",\n", + " ),\n", + " slit_end=sc.array(\n", + " dims=[\"cutout\"],\n", + " values=[1.23, 30.6, 59.4, 145.615, 160.035, 217.885, 261.17, 318.11],\n", + " unit=\"deg\",\n", + " ),\n", + " slit_height=sc.scalar(10.0, unit=\"cm\"),\n", + " radius=sc.scalar(30.0, unit=\"cm\"),\n", + ")\n", + "\n", + "oc = DiskChopper(\n", + " frequency=sc.scalar(14.0, unit=\"Hz\"),\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=sc.scalar(297 - 180 - 90, unit=\"deg\"),\n", + " axle_position=sc.vector(value=[0, 0, -70.376], unit=\"m\"),\n", + " slit_begin=sc.array(dims=[\"cutout\"], values=[-27.6 * 0.5], unit=\"deg\"),\n", + " slit_end=sc.array(dims=[\"cutout\"], values=[27.6 * 0.5], unit=\"deg\"),\n", + " slit_height=sc.scalar(10.0, unit=\"cm\"),\n", + " radius=sc.scalar(30.0, unit=\"cm\"),\n", + ")\n", + "\n", + "bcc = DiskChopper(\n", + " frequency=sc.scalar(112.0, unit=\"Hz\"),\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=sc.scalar(240 - 180, unit=\"deg\"),\n", + " axle_position=sc.vector(value=[0, 0, -66.77], unit=\"m\"),\n", + " slit_begin=sc.array(dims=[\"cutout\"], values=[-36.875, 143.125], unit=\"deg\"),\n", + " slit_end=sc.array(dims=[\"cutout\"], values=[36.875, 216.875], unit=\"deg\"),\n", + " slit_height=sc.scalar(10.0, unit=\"cm\"),\n", + " radius=sc.scalar(30.0, unit=\"cm\"),\n", + ")\n", + "\n", + "t0 = DiskChopper(\n", + " frequency=sc.scalar(28.0, unit=\"Hz\"),\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=sc.scalar(280 - 180, unit=\"deg\"),\n", + " axle_position=sc.vector(value=[0, 0, -63.5], unit=\"m\"),\n", + " slit_begin=sc.array(dims=[\"cutout\"], values=[-314.9 * 0.5], unit=\"deg\"),\n", + " slit_end=sc.array(dims=[\"cutout\"], values=[314.9 * 0.5], unit=\"deg\"),\n", + " slit_height=sc.scalar(10.0, unit=\"cm\"),\n", + " radius=sc.scalar(30.0, unit=\"cm\"),\n", + ")\n", + "\n", + "disk_choppers = {\"psc1\": psc1, \"psc2\": psc2, \"oc\": oc, \"bcc\": bcc, \"t0\": t0}" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "It is possible to visualize the properties of the choppers by inspecting their `repr`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "psc2" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "Define the source position which is required to compute the distance that neutrons travelled." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "source_position = sc.vector([0, 0, -76.55], unit=\"m\")" + ] + }, + { + "cell_type": "markdown", + "id": "9", + "metadata": {}, + "source": [ + "### Adding a detector" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [ + "Ltotal = sc.scalar(76.55 + 1.125, unit=\"m\")" + ] + }, + { + "cell_type": "markdown", + "id": "11", + "metadata": {}, + "source": [ + "### Creating some neutron events\n", + "\n", + "We create a semi-realistic set of neutron events based on the ESS pulse." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12", + "metadata": {}, + "outputs": [], + "source": [ + "from ess.reduce.unwrap.fakes import FakeBeamline\n", + "\n", + "ess_beamline = FakeBeamline(\n", + " choppers=disk_choppers,\n", + " source_position=source_position,\n", + " monitors={\"detector\": Ltotal},\n", + " run_length=sc.scalar(1 / 14, unit=\"s\") * 4,\n", + " events_per_pulse=200_000,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "13", + "metadata": {}, + "source": [ + "The initial birth times and wavelengths of the generated neutrons can be visualized (for a single pulse):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14", + "metadata": {}, + "outputs": [], + "source": [ + "one_pulse = ess_beamline.source.data[\"pulse\", 0]\n", + "one_pulse.hist(birth_time=300).plot() + one_pulse.hist(wavelength=300).plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "ess_beamline.model_result.plot()" + ] + }, + { + "cell_type": "markdown", + "id": "16", + "metadata": {}, + "source": [ + "From this fake beamline, we extract the raw neutron signal at our detector:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "raw_data = ess_beamline.get_monitor(\"detector\")[0]\n", + "\n", + "# Visualize\n", + "raw_data.hist(event_time_offset=300).squeeze().plot()" + ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": {}, + "source": [ + "The total number of neutrons in our sample data that make it through to the detector is:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "raw_data.sum().value" + ] + }, + { + "cell_type": "markdown", + "id": "20", + "metadata": {}, + "source": [ + "### Computing neutron wavelengths\n", + "\n", + "Next, we use a workflow that provides an estimate of the neutron wavelength as a function of neutron time-of-arrival.\n", + "\n", + "#### Setting up the workflow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21", + "metadata": {}, + "outputs": [], + "source": [ + "wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[])\n", + "\n", + "wf[RawDetector[SampleRun]] = raw_data\n", + "wf[DetectorLtotal[SampleRun]] = Ltotal\n", + "wf[NeXusDetectorName] = 'dream_detector'\n", + "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': float(\"inf\")}\n", + "\n", + "wf.visualize(WavelengthDetector[SampleRun])" + ] + }, + { + "cell_type": "markdown", + "id": "22", + "metadata": {}, + "source": [ + "By default, the workflow tries to load a `LookupTable` from a file.\n", + "\n", + "In this notebook, instead of using such a pre-made file,\n", + "we will build our own lookup table from the chopper information and apply it to the workflow." + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "#### Building the wavelength lookup table\n", + "\n", + "We use [`scippneutron.tof.chopper_cascade`](https://scipp.github.io/scippneutron/user-guide/chopper/chopper-cascade.html) module to propagate a pulse of neutrons through the chopper system to the detectors,\n", + "and predict the most likely neutron wavelength for a given time-of-arrival and distance from source.\n", + "\n", + "From this,\n", + "we build a lookup table on which bilinear interpolation is used to compute a wavelength for every neutron event." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [], + "source": [ + "lut_wf = FastLookupTableWorkflow()\n", + "lut_wf[DiskChoppers[AnyRun]] = disk_choppers\n", + "lut_wf[SourcePosition] = source_position\n", + "lut_wf[LtotalRange] = (\n", + " sc.scalar(25.0, unit=\"m\"),\n", + " sc.scalar(80.0, unit=\"m\"),\n", + ")\n", + "lut_wf.visualize(LookupTable)" + ] + }, + { + "cell_type": "markdown", + "id": "25", + "metadata": {}, + "source": [ + "#### Inspecting the lookup table\n", + "\n", + "The workflow first runs a calculation propagating a pulse of neutrons (represented by a polygon in time and wavelength space),\n", + "through a chopper cascade defined by the chopper parameters above.\n", + "\n", + "This can be used to create a figure displaying the neutron wavelengths,\n", + "as a function of arrival time at the detector.\n", + "\n", + "This is the basis for creating our lookup table." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26", + "metadata": {}, + "outputs": [], + "source": [ + "dist = sc.scalar(60.0, unit='m')\n", + "\n", + "frames = lut_wf.compute(ChopperFrameSequence)\n", + "at_detector = frames.propagate_to(dist)\n", + "fig, ax = at_detector.draw()" + ] + }, + { + "cell_type": "markdown", + "id": "27", + "metadata": {}, + "source": [ + "The source pulse is defined as spanning 0-5 ms in time, and 0-15 Å in wavelength,\n", + "and is represented by the blue rectangle on the left hand side of the diagram.\n", + "\n", + "As the pulse propagates through the system,\n", + "it stretches (slow neutrons take longer to reach the same distance) and gets chopped by the chopper openings,\n", + "creating polygons from the rectangular pulse.\n", + "\n", + "Finally, at the detector distance of 60 m, we are left with two (pink) very thin polygons,\n", + "representing the two packets of neutrons that are allowed through the instrument.\n", + "\n", + "The idea is to approximate these thin polygons as a single line,\n", + "effectively giving us a function relating neutron wavelength as a function of arrival time.\n", + "This is precisely how the loop table is built,\n", + "and if we overlay the wavelength values given by the table at a distance of 60 m,\n", + "we see that the black lines pass right in the middle of the polygons." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28", + "metadata": {}, + "outputs": [], + "source": [ + "table = lut_wf.compute(LookupTable)\n", + "\n", + "# Overlay LUT prediction on the polygons figure\n", + "da = table.array[\"distance\", 352]\n", + "ax.plot(\n", + " da.coords['event_time_offset'].values / 1000,\n", + " da.values,\n", + " color=\"k\",\n", + " ls=\"-\",\n", + " marker=None,\n", + ")\n", + "fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29", + "metadata": {}, + "outputs": [], + "source": [ + "ax.set(xlim=[39.5, 41.5], ylim=[2.6, 2.72])\n", + "fig" + ] + }, + { + "cell_type": "markdown", + "id": "30", + "metadata": {}, + "source": [ + "The full table covers a range of distances, and looks like" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31", + "metadata": {}, + "outputs": [], + "source": [ + "table.plot()" + ] + }, + { + "cell_type": "markdown", + "id": "32", + "metadata": {}, + "source": [ + "#### Computing a wavelength coordinate\n", + "\n", + "We will now update our workflow, and use it to obtain our event data with a wavelength coordinate:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33", + "metadata": {}, + "outputs": [], + "source": [ + "# Set the computed lookup table onto the original workflow\n", + "wf[LookupTable] = table\n", + "\n", + "# Compute wavelength of neutron events\n", + "wavs = wf.compute(WavelengthDetector[SampleRun])\n", + "edges = sc.linspace(\"wavelength\", 0.8, 4.6, 201, unit=\"angstrom\")\n", + "\n", + "histogrammed = wavs.hist(wavelength=edges).squeeze()\n", + "histogrammed.plot()" + ] + }, + { + "cell_type": "markdown", + "id": "34", + "metadata": {}, + "source": [ + "#### Comparing to the ground truth\n", + "\n", + "As a consistency check, because we actually know the wavelengths of the neutrons we created,\n", + "we can compare the true neutron wavelengths to those we computed above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35", + "metadata": {}, + "outputs": [], + "source": [ + "ground_truth = ess_beamline.model_result[\"detector\"].data.flatten(to=\"event\")\n", + "ground_truth = ground_truth[~ground_truth.masks[\"blocked_by_others\"]]\n", + "\n", + "pp.plot(\n", + " {\n", + " \"wfm\": histogrammed,\n", + " \"ground_truth\": ground_truth.hist(wavelength=edges),\n", + " },\n", + " color={\"ground_truth\": \"k\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "36", + "metadata": {}, + "source": [ + "### Multiple detector pixels\n", + "\n", + "It is also possible to compute the neutron wavelength for multiple detector pixels at once,\n", + "where every pixel has different frame bounds\n", + "(because every pixel is at a different distance from the source).\n", + "\n", + "In our setup, we simply propagate the same neutrons to multiple detector pixels,\n", + "as if they were not absorbed by the first pixel they meet." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37", + "metadata": {}, + "outputs": [], + "source": [ + "Ltotal = sc.array(dims=[\"detector_number\"], values=[77.675, 76.0], unit=\"m\")\n", + "monitors = {f\"detector{i}\": ltot for i, ltot in enumerate(Ltotal)}\n", + "\n", + "ess_beamline = FakeBeamline(\n", + " choppers=disk_choppers,\n", + " source_position=source_position,\n", + " monitors=monitors,\n", + " run_length=sc.scalar(1 / 14, unit=\"s\") * 4,\n", + " events_per_pulse=200_000,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "38", + "metadata": {}, + "source": [ + "Our raw data has now a `detector_number` dimension of length 2.\n", + "\n", + "We can plot the neutron `event_time_offset` for the two detector pixels and see that the offsets are shifted to the left for the pixel that is closest to the source." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39", + "metadata": {}, + "outputs": [], + "source": [ + "raw_data = sc.concat(\n", + " [ess_beamline.get_monitor(key)[0].squeeze() for key in monitors.keys()],\n", + " dim=\"detector_number\",\n", + ")\n", + "\n", + "# Visualize\n", + "pp.plot(sc.collapse(raw_data.hist(event_time_offset=300), keep=\"event_time_offset\"))" + ] + }, + { + "cell_type": "markdown", + "id": "40", + "metadata": {}, + "source": [ + "Computing wavelength is done in the same way as above.\n", + "We need to remember to update our workflow:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41", + "metadata": {}, + "outputs": [], + "source": [ + "# Update workflow\n", + "wf[RawDetector[SampleRun]] = raw_data\n", + "wf[DetectorLtotal[SampleRun]] = Ltotal\n", + "\n", + "# Compute tofs and wavelengths\n", + "wav_wfm = wf.compute(WavelengthDetector[SampleRun])\n", + "\n", + "# Compare in plot\n", + "ground_truth = []\n", + "for det in ess_beamline.monitors:\n", + " data = ess_beamline.model_result[det.name].data.flatten(to=\"event\")\n", + " ground_truth.append(data[~data.masks[\"blocked_by_others\"]])\n", + "\n", + "figs = [\n", + " pp.plot(\n", + " {\n", + " \"wfm\": wav_wfm[\"detector_number\", i].bins.concat().hist(wavelength=edges),\n", + " \"ground_truth\": ground_truth[i].hist(wavelength=edges),\n", + " },\n", + " title=f\"detector_number {i}\",\n", + " color={\"ground_truth\": \"k\", \"wfm\": f\"C{i}\"},\n", + " )\n", + " for i in range(len(Ltotal))\n", + "]\n", + "\n", + "figs[0] + figs[1]" + ] + }, + { + "cell_type": "markdown", + "id": "42", + "metadata": {}, + "source": [ + "### Handling time overlap between subframes\n", + "\n", + "In some (relatively rare) cases, where a chopper cascade is slightly ill-defined,\n", + "it is sometimes possible for some subframes to overlap in time with other subframes.\n", + "\n", + "This is basically when neutrons passed through different pulse-shaping chopper openings,\n", + "but arrive at the same time at the detector.\n", + "\n", + "In this case, it is actually not possible to accurately determine the wavelength of the neutrons.\n", + "We handle this by masking the overlapping regions and throwing away any neutrons that lie within it.\n", + "\n", + "To simulate this, we modify slightly the phase and the cutouts of the band-control chopper:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43", + "metadata": {}, + "outputs": [], + "source": [ + "disk_choppers[\"bcc\"] = DiskChopper(\n", + " frequency=sc.scalar(112.0, unit=\"Hz\"),\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=sc.scalar(240 - 180, unit=\"deg\"),\n", + " axle_position=sc.vector(value=[0, 0, -66.77], unit=\"m\"),\n", + " slit_begin=sc.array(dims=[\"cutout\"], values=[-36.875, 143.125], unit=\"deg\"),\n", + " slit_end=sc.array(dims=[\"cutout\"], values=[46.875, 216.875], unit=\"deg\"),\n", + " slit_height=sc.scalar(10.0, unit=\"cm\"),\n", + " radius=sc.scalar(30.0, unit=\"cm\"),\n", + ")\n", + "\n", + "# Go back to a single detector pixel\n", + "Ltotal = sc.scalar(76.55 + 1.125, unit=\"m\")\n", + "\n", + "ess_beamline = FakeBeamline(\n", + " choppers=disk_choppers,\n", + " source_position=source_position,\n", + " monitors={\"detector\": Ltotal},\n", + " run_length=sc.scalar(1 / 14, unit=\"s\") * 4,\n", + " events_per_pulse=200_000,\n", + ")\n", + "\n", + "ess_beamline.model_result.plot()" + ] + }, + { + "cell_type": "markdown", + "id": "44", + "metadata": {}, + "source": [ + "We can now see that there is no longer a gap between the two frames at the center of each pulse (green region).\n", + "\n", + "Another way of looking at this is looking at the wavelength vs time-of-arrival plot,\n", + "which also shows overlap in time at the junction between the two frames:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45", + "metadata": {}, + "outputs": [], + "source": [ + "# Update workflow\n", + "lut_wf[DiskChoppers[AnyRun]] = disk_choppers\n", + "\n", + "frames = lut_wf.compute(ChopperFrameSequence)\n", + "at_detector = frames.propagate_to(dist)\n", + "fig, ax = at_detector.draw()\n", + "ax.set(xlim=(36, 44), ylim=(2, 3))" + ] + }, + { + "cell_type": "markdown", + "id": "46", + "metadata": {}, + "source": [ + "The data in the lookup table contains both the mean wavelength for each distance and time-of-arrival bin,\n", + "but also the variance inside each bin.\n", + "\n", + "In the regions where there is no time overlap,\n", + "the variance is small (the regions are close to a thin line).\n", + "However, in the central region where overlap occurs,\n", + "we are computing a mean between two regions which have similar 'brightness'.\n", + "\n", + "This leads to a large variance, and this is visible when plotting the relative standard deviations on a 2D figure\n", + "(we zoom in on the distances corresponding to the detector banks around 75m from the source)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47", + "metadata": {}, + "outputs": [], + "source": [ + "table = lut_wf.compute(LookupTable)\n", + "table.plot(ymin=65) / (sc.stddevs(table.array) / sc.values(table.array)).plot(\n", + " norm=\"linear\", ymin=55, vmax=0.05\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "48", + "metadata": {}, + "source": [ + "The workflow has a parameter which is used to mask out regions where the standard deviation is above a certain threshold.\n", + "\n", + "It is difficult to automatically detector this threshold,\n", + "as it can vary a lot depending on how much signal is received by the detectors,\n", + "and how far the detectors are from the source.\n", + "It is thus more robust to simply have a user tunable parameter on the workflow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49", + "metadata": {}, + "outputs": [], + "source": [ + "wf[LookupTable] = table\n", + "\n", + "wf[LookupTableRelativeErrorThreshold] = {'dream_detector': 0.02}\n", + "\n", + "masked_table = wf.compute(ErrorLimitedLookupTable[snx.NXdetector])\n", + "masked_table.plot(ymin=65)" + ] + }, + { + "cell_type": "markdown", + "id": "50", + "metadata": {}, + "source": [ + "We can now see that the central region is masked out.\n", + "\n", + "The neutrons in that region will be discarded in the wavelength calculation\n", + "(in practice, they are given a NaN value as a wavelength).\n", + "\n", + "This is visible when comparing to the true neutron wavelengths,\n", + "where we see that some counts were lost between the two frames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51", + "metadata": {}, + "outputs": [], + "source": [ + "wf[RawDetector[SampleRun]] = ess_beamline.get_monitor(\"detector\")[0]\n", + "wf[DetectorLtotal[SampleRun]] = Ltotal\n", + "\n", + "# Compute wavelength\n", + "wav_wfm = wf.compute(WavelengthDetector[SampleRun])\n", + "\n", + "# Compare to the true wavelengths\n", + "ground_truth = ess_beamline.model_result[\"detector\"].data.flatten(to=\"event\")\n", + "ground_truth = ground_truth[~ground_truth.masks[\"blocked_by_others\"]]\n", + "\n", + "pp.plot(\n", + " {\n", + " \"wfm\": wav_wfm.hist(wavelength=edges).squeeze(),\n", + " \"ground_truth\": ground_truth.hist(wavelength=edges),\n", + " },\n", + " color={\"ground_truth\": \"k\"},\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "52", + "metadata": {}, + "source": [ + "## The ODIN instrument\n", + "\n", + "The second example is the ODIN instrument;\n", + "it is a more classical case in the sense that the WFM choppers have 6 openings and produce 6 neutron frames at the detector.\n", + "\n", + "However, it uses a technique called 'pulse-skipping' where a chopper rotating at half the source frequency blocks (or 'skips') every other pulse.\n", + "This allows the range of wavelengths recorded at the detector to be much wider,\n", + "because overlap between consecutive pulses has been reduced by a factor of 2.\n", + "\n", + "### Setting up the beamline and data\n", + "\n", + "We begin by defining the chopper parameters and creating fake events, as we did previously for DREAM." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53", + "metadata": {}, + "outputs": [], + "source": [ + "parameters = {\n", + " \"WFMC_1\": {\n", + " \"frequency\": 56.0,\n", + " \"phase\": 93.244,\n", + " \"distance\": 6.85,\n", + " \"open\": [-1.9419, 49.5756, 98.9315, 146.2165, 191.5176, 234.9179],\n", + " \"close\": [1.9419, 55.7157, 107.2332, 156.5891, 203.8741, 249.1752],\n", + " },\n", + " \"WFMC_2\": {\n", + " \"frequency\": 56.0,\n", + " \"phase\": 97.128,\n", + " \"distance\": 7.15,\n", + " \"open\": [-1.9419, 51.8318, 103.3493, 152.7052, 199.9903, 245.2914],\n", + " \"close\": [1.9419, 57.9719, 111.6510, 163.0778, 212.3468, 259.5486],\n", + " },\n", + " \"FOC_1\": {\n", + " \"frequency\": 42.0,\n", + " \"phase\": 81.303297,\n", + " \"distance\": 8.4,\n", + " \"open\": [-5.1362, 42.5536, 88.2425, 132.0144, 173.9497, 216.7867],\n", + " \"close\": [5.1362, 54.2095, 101.2237, 146.2653, 189.417, 230.7582],\n", + " },\n", + " \"BP_1\": {\n", + " \"frequency\": 7.0,\n", + " \"phase\": 31.080,\n", + " \"distance\": 8.45,\n", + " \"open\": [-23.6029],\n", + " \"close\": [23.6029],\n", + " },\n", + " \"FOC_2\": {\n", + " \"frequency\": 42.0,\n", + " \"phase\": 107.013442,\n", + " \"distance\": 12.2,\n", + " \"open\": [-16.3227, 53.7401, 120.8633, 185.1701, 246.7787, 307.0165],\n", + " \"close\": [16.3227, 86.8303, 154.3794, 218.7551, 280.7508, 340.3188],\n", + " },\n", + " \"BP_2\": {\n", + " \"frequency\": 7.0,\n", + " \"phase\": 44.224,\n", + " \"distance\": 12.25,\n", + " \"open\": [-34.4663],\n", + " \"close\": [34.4663],\n", + " },\n", + " \"T0_alpha\": {\n", + " \"frequency\": 14.0,\n", + " \"phase\": 179.672,\n", + " \"distance\": 13.5,\n", + " \"open\": [-167.8986],\n", + " \"close\": [167.8986],\n", + " },\n", + " \"T0_beta\": {\n", + " \"frequency\": 14.0,\n", + " \"phase\": 179.672,\n", + " \"distance\": 13.7,\n", + " \"open\": [-167.8986],\n", + " \"close\": [167.8986],\n", + " },\n", + " \"FOC_3\": {\n", + " \"frequency\": 28.0,\n", + " \"phase\": 92.993,\n", + " \"distance\": 17.0,\n", + " \"open\": [-20.302, 45.247, 108.0457, 168.2095, 225.8489, 282.2199],\n", + " \"close\": [20.302, 85.357, 147.6824, 207.3927, 264.5977, 319.4024],\n", + " },\n", + " \"FOC_4\": {\n", + " \"frequency\": 14.0,\n", + " \"phase\": 61.584,\n", + " \"distance\": 23.69,\n", + " \"open\": [-16.7157, 29.1882, 73.1661, 115.2988, 155.6636, 195.5254],\n", + " \"close\": [16.7157, 61.8217, 105.0352, 146.4355, 186.0987, 224.0978],\n", + " },\n", + " \"FOC_5\": {\n", + " \"frequency\": 14.0,\n", + " \"phase\": 82.581,\n", + " \"distance\": 33.0,\n", + " \"open\": [-25.8514, 38.3239, 99.8064, 160.1254, 217.4321, 272.5426],\n", + " \"close\": [25.8514, 88.4621, 147.4729, 204.0245, 257.7603, 313.7139],\n", + " },\n", + "}\n", + "\n", + "odin_choppers = {\n", + " key: DiskChopper(\n", + " frequency=-ch[\"frequency\"] * sc.Unit(\"Hz\"),\n", + " beam_position=sc.scalar(0.0, unit=\"deg\"),\n", + " phase=-ch[\"phase\"] * sc.Unit(\"deg\"),\n", + " axle_position=sc.vector(value=[0, 0, ch[\"distance\"]], unit=\"m\"),\n", + " slit_begin=sc.array(dims=[\"cutout\"], values=ch[\"open\"], unit=\"deg\"),\n", + " slit_end=sc.array(dims=[\"cutout\"], values=ch[\"close\"], unit=\"deg\"),\n", + " )\n", + " for key, ch in parameters.items()\n", + "}\n", + "\n", + "Ltotal = sc.scalar(60.0, unit=\"m\")\n", + "source_position = sc.vector([0, 0, 0], unit='m')\n", + "\n", + "ess_beamline = FakeBeamline(\n", + " choppers=odin_choppers,\n", + " source_position=source_position,\n", + " monitors={\"detector\": Ltotal},\n", + " run_length=sc.scalar(1 / 14, unit=\"s\") * 4,\n", + " events_per_pulse=400_000,\n", + ")\n", + "\n", + "ess_beamline.model_result.plot()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54", + "metadata": {}, + "outputs": [], + "source": [ + "raw_data = ess_beamline.get_monitor(\"detector\")[0]\n", + "\n", + "# Visualize\n", + "raw_data.hist(event_time_offset=300).squeeze().plot()" + ] + }, + { + "cell_type": "markdown", + "id": "55", + "metadata": {}, + "source": [ + "### Creating the lookup table for ODIN\n", + "\n", + "We use once again the `LookupTableWorkflow` to compute the wavelength lookup table.\n", + "\n", + "Because ODIN uses a pulse-skipping chopper, we need to set `PulseStride = 2` on the workflow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56", + "metadata": {}, + "outputs": [], + "source": [ + "lut_wf = FastLookupTableWorkflow()\n", + "lut_wf[DiskChoppers[AnyRun]] = odin_choppers\n", + "lut_wf[SourcePosition] = source_position\n", + "lut_wf[LtotalRange] = (\n", + " sc.scalar(25.0, unit=\"m\"),\n", + " sc.scalar(65.0, unit=\"m\"),\n", + ")\n", + "lut_wf[PulseStride] = 2\n", + "\n", + "frames = lut_wf.compute(ChopperFrameSequence)\n", + "at_detector = frames.propagate_to(Ltotal)\n", + "fig, ax = at_detector.draw()\n", + "\n", + "table = lut_wf.compute(LookupTable)\n", + "\n", + "# Overlay LUT prediction on the polygons figure\n", + "da = table.array[\"distance\", 352]\n", + "ax.plot(\n", + " da.coords['event_time_offset'].values / 1000,\n", + " da.values,\n", + " color=\"k\",\n", + " ls=\"-\",\n", + " marker=None,\n", + ")\n", + "ax.legend(loc=(1.01, 0.25))" + ] + }, + { + "cell_type": "markdown", + "id": "57", + "metadata": {}, + "source": [ + "The final relation between time-of-arrival and wavelength at the detector is represented by the black lines that accurately trace the green polygons\n", + "(zooming in on the figure may be required to even see the polygons at 60 m).\n", + "\n", + "Also note that because of the pulse skipping, we consider two source pulses (blue rectangles) instead of one in the DREAM case.\n", + "Both pulses generate sets of polygons up to 8.4 m, but beyond that only the first pulses continues to travel down the beamline,\n", + "while the second pulse got blocked by the 7 Hz chopper.\n", + "\n", + "The full wavelength lookup table is plotted below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58", + "metadata": {}, + "outputs": [], + "source": [ + "table.plot()" + ] + }, + { + "cell_type": "markdown", + "id": "59", + "metadata": {}, + "source": [ + "### Computing wavelengths for ODIN\n", + "\n", + "Computing wavelengths is done in exactly the same way as for DREAM above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60", + "metadata": {}, + "outputs": [], + "source": [ + "wf = GenericUnwrapWorkflow(run_types=[SampleRun], monitor_types=[])\n", + "\n", + "wf[RawDetector[SampleRun]] = raw_data\n", + "wf[DetectorLtotal[SampleRun]] = Ltotal\n", + "wf[NeXusDetectorName] = 'odin_detector'\n", + "wf[LookupTableRelativeErrorThreshold] = {'odin_detector': float(\"inf\")}\n", + "\n", + "wf.visualize(WavelengthDetector[SampleRun])\n", + "wf[LookupTable] = table\n", + "\n", + "# Compute wavelength of neutron events\n", + "wavs = wf.compute(WavelengthDetector[SampleRun])\n", + "edges = sc.linspace(\"wavelength\", 0.8, 10.0, 401, unit=\"angstrom\")\n", + "\n", + "histogrammed = wavs.hist(wavelength=edges).squeeze()\n", + "\n", + "ground_truth = ess_beamline.model_result[\"detector\"].data.flatten(to=\"event\")\n", + "ground_truth = ground_truth[~ground_truth.masks[\"blocked_by_others\"]]\n", + "\n", + "pp.plot(\n", + " {\n", + " \"wfm\": histogrammed,\n", + " \"ground_truth\": ground_truth.hist(wavelength=edges),\n", + " },\n", + " color={\"ground_truth\": \"k\"},\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/packages/essreduce/docs/user-guide/unwrap/index.md b/packages/essreduce/docs/user-guide/unwrap/index.md index 137d5dbda..9b37adf30 100644 --- a/packages/essreduce/docs/user-guide/unwrap/index.md +++ b/packages/essreduce/docs/user-guide/unwrap/index.md @@ -8,4 +8,5 @@ maxdepth: 1 frame-unwrapping wfm dream +analytical-unwrap ``` diff --git a/packages/essreduce/pyproject.toml b/packages/essreduce/pyproject.toml index 483d66ac7..e36a5cf3f 100644 --- a/packages/essreduce/pyproject.toml +++ b/packages/essreduce/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "graphviz>=0.20", "sciline>=25.11.0", "scipp>=26.3.1", - "scippneutron>=25.11.1", + "scippneutron>=26.5.0", "scippnexus>=25.06.0", "scipy>=1.14", ] diff --git a/packages/essreduce/src/ess/reduce/unwrap/__init__.py b/packages/essreduce/src/ess/reduce/unwrap/__init__.py index c0d21027c..27bbbcabf 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/__init__.py +++ b/packages/essreduce/src/ess/reduce/unwrap/__init__.py @@ -9,7 +9,9 @@ from ..nexus.types import DiskChoppers from .lut import ( BeamlineComponentReading, + ChopperFrameSequence, DistanceResolution, + FastLookupTableWorkflow, LookupTableWorkflow, LtotalRange, NumberOfSimulatedNeutrons, @@ -17,6 +19,7 @@ PulseStride, SimulationResults, SimulationSeed, + SourceBounds, SourcePosition, TimeResolution, simulate_chopper_cascade_using_tof, @@ -37,10 +40,12 @@ __all__ = [ "BeamlineComponentReading", + "ChopperFrameSequence", "DetectorLtotal", "DiskChoppers", "DistanceResolution", "ErrorLimitedLookupTable", + "FastLookupTableWorkflow", "GenericUnwrapWorkflow", "LookupTable", "LookupTableFilename", @@ -54,6 +59,7 @@ "PulseStrideOffset", "SimulationResults", "SimulationSeed", + "SourceBounds", "SourcePosition", "TimeResolution", "WavelengthDetector", diff --git a/packages/essreduce/src/ess/reduce/unwrap/fakes.py b/packages/essreduce/src/ess/reduce/unwrap/fakes.py index cc93023ff..117a686dc 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/fakes.py +++ b/packages/essreduce/src/ess/reduce/unwrap/fakes.py @@ -80,7 +80,7 @@ def psc_choppers(): "chopper": DiskChopper( frequency=sc.scalar(-14.0, unit="Hz"), beam_position=sc.scalar(0.0, unit="deg"), - phase=sc.scalar(-85.0, unit="deg"), + phase=sc.scalar(-105.0, unit="deg"), axle_position=sc.vector(value=[0, 0, 8.0], unit="m"), slit_begin=sc.array(dims=["cutout"], values=[0.0], unit="deg"), slit_end=sc.array(dims=["cutout"], values=[3.0], unit="deg"), diff --git a/packages/essreduce/src/ess/reduce/unwrap/lut.py b/packages/essreduce/src/ess/reduce/unwrap/lut.py index c76e031fb..c3a4a14aa 100644 --- a/packages/essreduce/src/ess/reduce/unwrap/lut.py +++ b/packages/essreduce/src/ess/reduce/unwrap/lut.py @@ -1,14 +1,17 @@ # SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2025 Scipp contributors (https://github.com/scipp) +# Copyright (c) 2026 Scipp contributors (https://github.com/scipp) """ -Utilities for computing time-of-flight lookup tables from neutron simulations. +Utilities for computing wavelength lookup tables. """ +import warnings from dataclasses import dataclass from typing import NewType +import numpy as np import sciline as sl import scipp as sc +from scippneutron.tof import chopper_cascade from ..nexus.types import AnyRun, DiskChoppers from .types import LookupTable @@ -136,6 +139,25 @@ class SimulationResults: """ +@dataclass +class SourceBounds: + """Time and wavelength bounds of the neutrons in the source pulse that encompass + all possible neutrons that can be generated by the source. + """ + + time: tuple[sc.Variable, sc.Variable] + """Time range (start, end) of the source pulse.""" + wavelength: tuple[sc.Variable, sc.Variable] + """Wavelength range (min, max) of the neutrons in the source pulse.""" + + +ChopperFrameSequence = NewType("ChopperFrameSequence", chopper_cascade.FrameSequence) +""" +Sequence of chopper frames used to compute the wavelength as a function of distance and +event_time_offset in the lookup table. +""" + + def _compute_mean_wavelength( simulation: BeamlineComponentReading, distance: sc.Variable, @@ -461,3 +483,329 @@ def LookupTableWorkflow(): }, ) return wf + + +def _polygon_edges(polygons: list[np.ndarray]) -> np.ndarray: + """ + Convert a list of polygons (N_i, 2) arrays to a single array of edges (E, 2, 2). + """ + # polygons: list of (N_i, 2) arrays + edges = [] + for poly in polygons: + p1 = poly + p2 = np.roll(poly, -1, axis=0) + edges.append(np.stack([p1, p2], axis=1)) # (N, 2, 2) + return np.concatenate(edges, axis=0) # (E, 2, 2) + + +def _polygon_intersections(polygons: list[np.ndarray], xs: np.ndarray) -> np.ndarray: + """ + Find the intersections of a list of polygons with vertical lines at specified x + coordinates. + We then take the mean of the minimum and maximum intersection points as an estimate + of the mean wavelength in each bin. This handles the case where there are multiple + subframes overlapping in a single time bin. + + Parameters + ---------- + polygons: + List of polygons, each represented as an (N_i, 2) array of vertices. + xs: + Array of x coordinates where intersections should be computed. + + Returns + ------- + Array of intersection y coordinates, one for each x in `xs`. + """ + edges = _polygon_edges(polygons) + + x1 = edges[:, 0, 0][:, None] # (E, 1) + y1 = edges[:, 0, 1][:, None] + x2 = edges[:, 1, 0][:, None] + y2 = edges[:, 1, 1][:, None] + + xs = xs[None, :] # (1, N) + + # mask: edge crosses vertical line at x + mask = ((x1 <= xs) & (x2 > xs)) | ((x2 <= xs) & (x1 > xs)) + + # avoid division by zero (vertical edges won't pass mask anyway) + denom = x2 - x1 + denom = np.where(denom == 0, np.nan, denom) + + t = (xs - x1) / denom + y = y1 + t * (y2 - y1) + + # keep only valid intersections + y = np.where(mask, y, np.nan) + + # now reduce along edges axis + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=RuntimeWarning, message="All-NaN slice encountered" + ) + y_min = np.nanmin(y, axis=0) + y_max = np.nanmax(y, axis=0) + + # Median value and spread estimate + return 0.5 * (y_min + y_max), 0.5 * (y_max - y_min) + + +def _estimate_wavelength_by_polygon_centers( + subframes: list[chopper_cascade.Subframe], + time_edges: sc.Variable, + time_unit: str, + frame_period: sc.Variable, +) -> sc.DataArray: + """ + Compute the mean wavelength inside event_time_offset bins for a given range of + distances. + + This is done by finding the intersection of the edges of the subframe polygons + (generated by the ``chopper_cascade`` module) with vertical lines at the specified + time edges. + We then take the mean of the minimum and maximum intersection points as an estimate + of the mean wavelength in each bin. This handles the case where there are multiple + subframes overlapping in a single time bin. + + Parameters + ---------- + subframes: + List of subframes to consider. These should already be propagated to the + correct distance. + time_edges: + Edges of the time bins for which to compute the mean wavelength. Should be a + 1D variable with a unit of time. + time_unit: + Unit to use for all time quantities. + frame_period: + Period of the source pulses, used to handle the periodicity of the subframes. + """ + + # Here, the frame could be offset by more than one frame period (if the neutron + # flight path is very long). So we shift the frame back enough times so that + # the minimum time is between 0 and the frame period. + min_time = sc.reduce([f.time.min() for f in subframes]).min() + noffset = int(min_time.to(unit=time_unit).value / frame_period.value) + + # To handle the periodicity of the subframes, we need to consider not only the + # original subframes, but also copies of the subframes shifted by the frame period. + # This is because neutrons that arrive after the frame period will wrap around and + # appear in the next pulse, which is equivalent to the original pulse but shifted + # by the frame period. + polygons = [ + np.stack( + [ + (f.time.to(unit=time_unit) - (noffset + i) * frame_period).values, + f.wavelength.values, + ], + axis=1, + ) + for f in subframes + for i in (0, 1) + ] + + wavs, stddevs = _polygon_intersections(polygons, time_edges.values) + + return sc.array( + dims=time_edges.dims, + values=wavs, + variances=stddevs**2, + unit=subframes[0].wavelength.unit, + ) + + +def compute_frame_sequence( + pulse_period: PulsePeriod, + disk_choppers: DiskChoppers[AnyRun], + source_position: SourcePosition, + source_bounds: SourceBounds, + pulse_stride: PulseStride, +) -> ChopperFrameSequence: + """ + Compute the chopper frame sequence for a given set of disk choppers and source pulse + parameters. + + Parameters + ---------- + pulse_period: + Period of the source pulses, i.e., time between consecutive pulse starts. + disk_choppers: + Disk chopper parameters. + source_position: + Position of the neutron source. + source_bounds: + Time and wavelength range of the source pulse. + pulse_stride: + Stride of used pulses. Usually 1, but may be a small integer when + pulse-skipping. + """ + + # The `pulse_frequency` parameter in time_offset_open and time_offset_close below + # decides how many rotations the chopper will perform when computing the open and + # close times. Because we want to cover a number of pulses equal to `pulse_stride`, + # we need to set the pulse frequency to be `pulse_stride` times smaller than the + # actual pulse frequency. + # + # In addition, the time_offset_open and time_offset_close below require the + # pulse_frequency to be an integer multiple of the pulse frequency or vice versa. + # A simple trick is to make sure that the requested pulse frequency is divided by + # an even number. We need to rotate the chopper for long enough to cover wrapping + # around the frame period, so we cover two pulses strides. + frequency_for_chopper_rotation = (1.0 / pulse_period.to(unit='s')) / ( + pulse_stride * 2 + ) + + chops = { + key: chopper_cascade.Chopper( + distance=sc.norm( + ch.axle_position - source_position.to(unit=ch.axle_position.unit) + ), + time_open=ch.time_offset_open( + pulse_frequency=frequency_for_chopper_rotation + ), + time_close=ch.time_offset_close( + pulse_frequency=frequency_for_chopper_rotation + ), + ) + for key, ch in disk_choppers.items() + } + + frames = chopper_cascade.FrameSequence.from_source_pulse( + time_min=source_bounds.time[0], + time_max=source_bounds.time[1], + wavelength_min=source_bounds.wavelength[0], + wavelength_max=source_bounds.wavelength[1], + pulse_period=pulse_period, + npulses=pulse_stride, + ) + frames = frames.chop(chops.values()) + return ChopperFrameSequence(frames) + + +def make_wavelength_lut_from_polygons( + ltotal_range: LtotalRange, + distance_resolution: DistanceResolution, + time_resolution: TimeResolution, + pulse_period: PulsePeriod, + pulse_stride: PulseStride, + frames: ChopperFrameSequence, +) -> LookupTable: + """ + Compute a lookup table for wavelength as a function of distance and + time-of-arrival. + + Parameters + ---------- + ltotal_range: + Range of total flight path lengths from the source to the detector. + distance_resolution: + Resolution of the distance axis in the lookup table. + time_resolution: + Resolution of the time-of-arrival axis in the lookup table. Must be an integer. + pulse_period: + Period of the source pulses, i.e., time between consecutive pulse starts. + pulse_stride: + Stride of used pulses. Usually 1, but may be a small integer when + pulse-skipping. + frames: + Chopper frame sequence used to compute the wavelength as a function of distance + and event_time_offset in the lookup table. + """ + distance_unit = "m" + time_unit = "us" + res = distance_resolution.to(unit=distance_unit) + pulse_period = pulse_period.to(unit=time_unit) + frame_period = pulse_period * pulse_stride + + min_dist = ltotal_range[0].to(unit=distance_unit) + max_dist = ltotal_range[1].to(unit=distance_unit) + + # We want to give the 2d interpolator a table that covers the requested range, + # hence we need to extend the range by at least half a resolution in each direction. + # Then, we make the choice that the resolution in distance is the quantity that + # should be preserved. Because the difference between min and max distance is + # not necessarily an integer multiple of the resolution, we need to add a pad to + # ensure that the last bin is not cut off. We want the upper edge to be higher than + # the maximum distance, hence we pad with an additional 1.5 x resolution. + pad = 2.0 * res + distances = sc.arange('distance', min_dist - pad, max_dist + pad, res) + + # Create some time bins for event_time_offset. + # We want our final table to strictly cover the range [0, frame_period]. + nbins = int(frame_period / time_resolution.to(unit=time_unit)) + 1 + time_edges = sc.linspace( + 'event_time_offset', 0.0, frame_period.value, nbins + 1, unit=pulse_period.unit + ) + + # Sort frames by reverse distance + sorted_frames = sorted(frames, key=lambda x: x.distance.value, reverse=True) + + pieces = [] + # To avoid large RAM usage, and having to split the distances into chunks + # according to which frame to use, we simply loop over distances one + # by one here. + for dist in distances: + # Find the correct simulation reading + selected_frame = None + for frame in sorted_frames: + if dist.value >= frame.distance.to(unit=dist.unit).value: + selected_frame = frame + break + if selected_frame is None: + raise ValueError( + "Building the lookup table failed: the requested position " + f"{dist:c} is before the component with the lowest " + "distance in the simulation. The first component in the beamline " + f"has distance {sorted_frames[0].distance:c}." + ) + + subframes = selected_frame.propagate_to(dist).subframes + + pieces.append( + _estimate_wavelength_by_polygon_centers( + subframes=subframes, + time_edges=time_edges, + time_unit=time_unit, + frame_period=frame_period, + ) + ) + + table = sc.DataArray( + data=sc.concat(pieces, 'distance'), + coords={"distance": distances, "event_time_offset": time_edges}, + ) + + return LookupTable( + array=table, + pulse_period=pulse_period, + pulse_stride=pulse_stride, + distance_resolution=table.coords["distance"][1] - table.coords["distance"][0], + time_resolution=table.coords["event_time_offset"][1] + - table.coords["event_time_offset"][0], + # TODO: Do we still want to store the chopper information in the lookup table? + ) + + +def FastLookupTableWorkflow(): + """ + Create a workflow for computing a wavelength lookup table from computing an + acceptance diagram for a pulse propagating through a chopper cascade. + """ + wf = sl.Pipeline( + (make_wavelength_lut_from_polygons, compute_frame_sequence), + params={ + PulsePeriod: 1.0 / sc.scalar(14.0, unit="Hz"), + PulseStride: 1, + DistanceResolution: sc.scalar(0.1, unit="m"), + TimeResolution: sc.scalar(250.0, unit='us'), + SourceBounds: SourceBounds( + time=(sc.scalar(0.0, unit='ms'), sc.scalar(5.0, unit='ms')), + wavelength=( + sc.scalar(0.0, unit='angstrom'), + sc.scalar(15.0, unit='angstrom'), + ), + ), + }, + ) + return wf diff --git a/packages/essreduce/tests/unwrap/lut_test.py b/packages/essreduce/tests/unwrap/lut_test.py index 495e57dd2..a2f73d45d 100644 --- a/packages/essreduce/tests/unwrap/lut_test.py +++ b/packages/essreduce/tests/unwrap/lut_test.py @@ -6,19 +6,22 @@ from ess.reduce import unwrap from ess.reduce.nexus.types import AnyRun -from ess.reduce.unwrap import LookupTableWorkflow +from ess.reduce.unwrap import FastLookupTableWorkflow, LookupTableWorkflow sl = pytest.importorskip("sciline") -def test_lut_workflow_computes_table(): - wf = LookupTableWorkflow() +@pytest.mark.parametrize("engine", ["analytical", "tof"]) +def test_lut_workflow_computes_table(engine): + wf = FastLookupTableWorkflow() if engine == "analytical" else LookupTableWorkflow() wf[unwrap.DiskChoppers[AnyRun]] = {} wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 - wf[unwrap.SimulationSeed] = 60 wf[unwrap.PulseStride] = 1 + if engine == "tof": + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 60 + lmin, lmax = sc.scalar(25.0, unit='m'), sc.scalar(35.0, unit='m') dres = sc.scalar(0.1, unit='m') tres = sc.scalar(333.0, unit='us') @@ -40,12 +43,14 @@ def test_lut_workflow_computes_table(): assert sc.isclose(table.time_resolution, tres, rtol=sc.scalar(0.01)) -def test_lut_workflow_pulse_skipping(): - wf = LookupTableWorkflow() +@pytest.mark.parametrize("engine", ["analytical", "tof"]) +def test_lut_workflow_pulse_skipping(engine): + wf = FastLookupTableWorkflow() if engine == "analytical" else LookupTableWorkflow() wf[unwrap.DiskChoppers[AnyRun]] = {} wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 - wf[unwrap.SimulationSeed] = 62 + if engine == "tof": + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 62 wf[unwrap.PulseStride] = 2 lmin, lmax = sc.scalar(55.0, unit='m'), sc.scalar(65.0, unit='m') @@ -63,12 +68,14 @@ def test_lut_workflow_pulse_skipping(): ).to(unit=table.array.coords['event_time_offset'].unit) -def test_lut_workflow_non_exact_distance_range(): - wf = LookupTableWorkflow() +@pytest.mark.parametrize("engine", ["analytical", "tof"]) +def test_lut_workflow_non_exact_distance_range(engine): + wf = FastLookupTableWorkflow() if engine == "analytical" else LookupTableWorkflow() wf[unwrap.DiskChoppers[AnyRun]] = {} wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 - wf[unwrap.SimulationSeed] = 63 + if engine == "tof": + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 63 wf[unwrap.PulseStride] = 1 lmin, lmax = sc.scalar(25.0, unit='m'), sc.scalar(35.0, unit='m') @@ -145,12 +152,14 @@ def _make_choppers(): } -def test_lut_workflow_computes_table_with_choppers(): - wf = LookupTableWorkflow() +@pytest.mark.parametrize("engine", ["analytical", "tof"]) +def test_lut_workflow_computes_table_with_choppers(engine): + wf = FastLookupTableWorkflow() if engine == "analytical" else LookupTableWorkflow() wf[unwrap.DiskChoppers[AnyRun]] = _make_choppers() wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 - wf[unwrap.SimulationSeed] = 64 + if engine == "tof": + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 64 wf[unwrap.PulseStride] = 1 wf[unwrap.LtotalRange] = ( @@ -179,12 +188,14 @@ def test_lut_workflow_computes_table_with_choppers(): assert eto.max() < sc.scalar(6.9e4, unit="us").to(unit=eto.unit) -def test_lut_workflow_computes_table_with_choppers_full_beamline_range(): - wf = LookupTableWorkflow() +@pytest.mark.parametrize("engine", ["analytical", "tof"]) +def test_lut_workflow_computes_table_with_choppers_full_beamline_range(engine): + wf = FastLookupTableWorkflow() if engine == "analytical" else LookupTableWorkflow() wf[unwrap.DiskChoppers[AnyRun]] = _make_choppers() wf[unwrap.SourcePosition] = sc.vector([0, 0, 0], unit='m') - wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 - wf[unwrap.SimulationSeed] = 64 + if engine == "tof": + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 64 wf[unwrap.PulseStride] = 1 wf[unwrap.LtotalRange] = ( @@ -199,9 +210,9 @@ def test_lut_workflow_computes_table_with_choppers_full_beamline_range(): # Close to source: early times and large spread da = table.array['distance', 2] eto = da.coords['event_time_offset'][sc.isfinite(da.data)] - assert eto.min() > sc.scalar(0.0, unit="us").to(unit=eto.unit) + assert eto.min() >= sc.scalar(0.0, unit="us").to(unit=eto.unit) assert eto.min() < sc.scalar(1.0e3, unit="us").to(unit=eto.unit) - assert eto.max() > sc.scalar(2.5e4, unit="us").to(unit=eto.unit) + assert eto.max() > sc.scalar(2.0e4, unit="us").to(unit=eto.unit) assert eto.max() < sc.scalar(3.0e4, unit="us").to(unit=eto.unit) # Just after WFM choppers, very small range @@ -229,12 +240,14 @@ def test_lut_workflow_computes_table_with_choppers_full_beamline_range(): assert eto.max() < sc.scalar(6.9e4, unit="us").to(unit=eto.unit) -def test_lut_workflow_raises_for_distance_before_source(): - wf = LookupTableWorkflow() +@pytest.mark.parametrize("engine", ["analytical", "tof"]) +def test_lut_workflow_raises_for_distance_before_source(engine): + wf = FastLookupTableWorkflow() if engine == "analytical" else LookupTableWorkflow() wf[unwrap.DiskChoppers[AnyRun]] = {} wf[unwrap.SourcePosition] = sc.vector([0, 0, 10], unit='m') - wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 - wf[unwrap.SimulationSeed] = 65 + if engine == "tof": + wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 + wf[unwrap.SimulationSeed] = 65 wf[unwrap.PulseStride] = 1 # Setting the starting point at zero will make a table that would cover a range diff --git a/packages/essreduce/tests/unwrap/unwrap_test.py b/packages/essreduce/tests/unwrap/unwrap_test.py index 708d3dc35..4fded09bc 100644 --- a/packages/essreduce/tests/unwrap/unwrap_test.py +++ b/packages/essreduce/tests/unwrap/unwrap_test.py @@ -15,37 +15,52 @@ RawMonitor, SampleRun, ) -from ess.reduce.unwrap import GenericUnwrapWorkflow, LookupTableWorkflow, fakes +from ess.reduce.unwrap import ( + FastLookupTableWorkflow, + GenericUnwrapWorkflow, + LookupTableWorkflow, + fakes, +) sl = pytest.importorskip("sciline") -def make_lut_workflow(choppers, neutrons, seed, pulse_stride): - lut_wf = LookupTableWorkflow() +def make_lut_workflow(engine, choppers, pulse_stride, neutrons=None, seed=None): + lut_wf = LookupTableWorkflow() if engine == "tof" else FastLookupTableWorkflow() lut_wf[unwrap.DiskChoppers[AnyRun]] = choppers lut_wf[unwrap.SourcePosition] = fakes.source_position() lut_wf[unwrap.NumberOfSimulatedNeutrons] = neutrons - lut_wf[unwrap.SimulationSeed] = seed lut_wf[unwrap.PulseStride] = pulse_stride - lut_wf[unwrap.SimulationResults] = lut_wf.compute(unwrap.SimulationResults) + if engine == "tof": + lut_wf[unwrap.SimulationSeed] = seed + lut_wf[unwrap.SimulationResults] = lut_wf.compute(unwrap.SimulationResults) return lut_wf @pytest.fixture(scope="module") def lut_workflow_psc_choppers(): - return make_lut_workflow( - choppers=fakes.psc_choppers(), neutrons=500_000, seed=1234, pulse_stride=1 - ) + choppers = fakes.psc_choppers() + return { + 'tof': make_lut_workflow( + engine='tof', choppers=choppers, neutrons=1e6, seed=1234, pulse_stride=1 + ), + 'analytical': make_lut_workflow( + engine='analytical', choppers=choppers, pulse_stride=1 + ), + } @pytest.fixture(scope="module") def lut_workflow_pulse_skipping(): - return make_lut_workflow( - choppers=fakes.pulse_skipping_choppers(), - neutrons=500_000, - seed=112, - pulse_stride=2, - ) + choppers = fakes.pulse_skipping_choppers() + return { + 'tof': make_lut_workflow( + engine='tof', choppers=choppers, neutrons=1e6, seed=112, pulse_stride=2 + ), + 'analytical': make_lut_workflow( + engine='analytical', choppers=choppers, pulse_stride=2 + ), + } def _make_workflow_event_mode( @@ -154,8 +169,6 @@ def _validate_result_histogram_mode(wavs, ref, percentile, diff_threshold, rtol) assert "time_of_flight" not in wavs.coords assert "frame_time" not in wavs.coords - # graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} - # wavs = tofs.transform_coords("wavelength", graph=graph) ref = ref.hist(wavelength=wavs.coords["wavelength"]) # We divide by the maximum to avoid large relative differences at the edges of the # frames where the counts are low. @@ -166,15 +179,16 @@ def _validate_result_histogram_mode(wavs, ref, percentile, diff_threshold, rtol) assert sc.isclose(ref.data.nansum(), wavs.data.nansum(), rtol=sc.scalar(rtol)) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) -def test_unwrap_with_no_choppers(detector_or_monitor) -> None: +def test_unwrap_with_no_choppers(engine, detector_or_monitor) -> None: # At this small distance the frames are not overlapping (with the given wavelength # range), despite not using any choppers. distance = sc.scalar(10.0, unit="m") choppers = {} lut_wf = make_lut_workflow( - choppers=choppers, neutrons=300_000, seed=1234, pulse_stride=1 + engine=engine, choppers=choppers, neutrons=300_000, seed=1234, pulse_stride=1 ) pl, ref = _make_workflow_event_mode( @@ -197,19 +211,21 @@ def test_unwrap_with_no_choppers(detector_or_monitor) -> None: ) -# At 30m, event_time_offset does not wrap around (all events within the first pulse). -# At 60m, all events are within the second pulse. -# At 80m, events are split between the second and third pulse. -# At 108m, events are split between the third and fourth pulse. -@pytest.mark.parametrize("dist", [30.0, 60.0, 80.0, 108.0]) +# At 25m, event_time_offset does not wrap around (all events within the first pulse). +# At 50m, all events are within the second pulse. +# At 62m, events are split between the second and third pulse. +# At 90m, events are split between the third and fourth pulse. +@pytest.mark.parametrize("dist", [25.0, 50.0, 62.0, 90.0]) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) -def test_standard_unwrap(dist, detector_or_monitor, lut_workflow_psc_choppers) -> None: +def test_standard_unwrap( + dist, engine, detector_or_monitor, lut_workflow_psc_choppers +) -> None: pl, ref = _make_workflow_event_mode( distance=sc.scalar(dist, unit="m"), choppers=fakes.psc_choppers(), - lut_workflow=lut_workflow_psc_choppers, - seed=2, - # pulse_stride=1, + lut_workflow=lut_workflow_psc_choppers[engine], + seed=7, pulse_stride_offset=0, error_threshold=0.1, detector_or_monitor=detector_or_monitor, @@ -221,25 +237,30 @@ def test_standard_unwrap(dist, detector_or_monitor, lut_workflow_psc_choppers) - wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - wavs=wavs, ref=ref, percentile=100, diff_threshold=0.02, rtol=0.05 + wavs=wavs, + ref=ref, + percentile=100, + diff_threshold=0.02, + rtol=0.06 if engine == "tof" else 0.01, ) -# At 30m, event_time_offset does not wrap around (all events within the first pulse). -# At 60m, all events are within the second pulse. -# At 80m, events are split between the second and third pulse. -# At 108m, events are split between the third and fourth pulse. -@pytest.mark.parametrize("dist", [30.0, 60.0, 80.0, 108.0]) +# At 25m, event_time_offset does not wrap around (all events within the first pulse). +# At 50m, all events are within the second pulse. +# At 62m, events are split between the second and third pulse. +# At 90m, events are split between the third and fourth pulse. +@pytest.mark.parametrize("dist", [25.0, 50.0, 62.0, 90.0]) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("dim", ["time_of_flight", "tof", "frame_time"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) def test_standard_unwrap_histogram_mode( - dist, dim, detector_or_monitor, lut_workflow_psc_choppers + dist, engine, dim, detector_or_monitor, lut_workflow_psc_choppers ) -> None: pl, ref = _make_workflow_histogram_mode( dim=dim, distance=sc.scalar(dist, unit="m"), choppers=fakes.psc_choppers(), - lut_workflow=lut_workflow_psc_choppers, + lut_workflow=lut_workflow_psc_choppers[engine], seed=37, error_threshold=np.inf, detector_or_monitor=detector_or_monitor, @@ -251,21 +272,25 @@ def test_standard_unwrap_histogram_mode( wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_histogram_mode( - wavs=wavs, ref=ref, percentile=96, diff_threshold=0.4, rtol=0.05 + wavs=wavs, + ref=ref, + percentile=96, + diff_threshold=0.4, + rtol=0.06 if engine == "tof" else 0.01, ) @pytest.mark.parametrize("dist", [60.0, 100.0]) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) def test_pulse_skipping_unwrap( - dist, detector_or_monitor, lut_workflow_pulse_skipping + dist, engine, detector_or_monitor, lut_workflow_pulse_skipping ) -> None: pl, ref = _make_workflow_event_mode( distance=sc.scalar(dist, unit="m"), choppers=fakes.pulse_skipping_choppers(), - lut_workflow=lut_workflow_pulse_skipping, + lut_workflow=lut_workflow_pulse_skipping[engine], seed=432, - # pulse_stride=2, pulse_stride_offset=1, error_threshold=0.1, detector_or_monitor=detector_or_monitor, @@ -277,17 +302,22 @@ def test_pulse_skipping_unwrap( wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, + ref=ref, + percentile=100, + diff_threshold=0.1, + rtol=0.05 if engine == "tof" else 0.01, ) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) -def test_pulse_skipping_unwrap_180_phase_shift(detector_or_monitor) -> None: +@pytest.mark.parametrize("engine", ["tof", "analytical"]) +def test_pulse_skipping_unwrap_180_phase_shift(engine, detector_or_monitor) -> None: choppers = fakes.pulse_skipping_choppers() choppers["pulse_skipping"].phase.value += 180.0 lut_wf = make_lut_workflow( - choppers=choppers, neutrons=500_000, seed=111, pulse_stride=2 + engine=engine, choppers=choppers, neutrons=500_000, seed=111, pulse_stride=2 ) pl, ref = _make_workflow_event_mode( @@ -306,19 +336,24 @@ def test_pulse_skipping_unwrap_180_phase_shift(detector_or_monitor) -> None: wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, + ref=ref, + percentile=100, + diff_threshold=0.1, + rtol=0.05 if engine == "tof" else 0.01, ) @pytest.mark.parametrize("dist", [60.0, 100.0]) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) def test_pulse_skipping_stride_offset_guess_gives_expected_result( - dist, detector_or_monitor, lut_workflow_pulse_skipping + dist, engine, detector_or_monitor, lut_workflow_pulse_skipping ) -> None: pl, ref = _make_workflow_event_mode( distance=sc.scalar(dist, unit="m"), choppers=fakes.pulse_skipping_choppers(), - lut_workflow=lut_workflow_pulse_skipping, + lut_workflow=lut_workflow_pulse_skipping[engine], seed=97, pulse_stride_offset=None, error_threshold=0.1, @@ -331,13 +366,18 @@ def test_pulse_skipping_stride_offset_guess_gives_expected_result( wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, + ref=ref, + percentile=100, + diff_threshold=0.1, + rtol=0.05 if engine == "tof" else 0.01, ) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) def test_pulse_skipping_unwrap_when_all_neutrons_arrive_after_second_pulse( - detector_or_monitor, + engine, detector_or_monitor ) -> None: choppers = fakes.pulse_skipping_choppers() choppers['chopper'] = DiskChopper( @@ -352,11 +392,11 @@ def test_pulse_skipping_unwrap_when_all_neutrons_arrive_after_second_pulse( ) lut_wf = make_lut_workflow( - choppers=choppers, neutrons=500_000, seed=222, pulse_stride=2 + engine=engine, choppers=choppers, neutrons=500_000, seed=222, pulse_stride=2 ) pl, ref = _make_workflow_event_mode( - distance=sc.scalar(150.0, unit="m"), + distance=sc.scalar(130.0, unit="m"), choppers=choppers, lut_workflow=lut_wf, seed=6, @@ -371,13 +411,18 @@ def test_pulse_skipping_unwrap_when_all_neutrons_arrive_after_second_pulse( wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, + ref=ref, + percentile=100, + diff_threshold=0.1, + rtol=0.05 if engine == "tof" else 0.01, ) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing( - detector_or_monitor, + engine, detector_or_monitor ) -> None: distance = sc.scalar(100.0, unit="m") choppers = fakes.pulse_skipping_choppers() @@ -392,7 +437,7 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing( mon, ref = beamline.get_monitor("detector") lut_wf = make_lut_workflow( - choppers=choppers, neutrons=300_000, seed=1234, pulse_stride=2 + engine=engine, choppers=choppers, neutrons=300_000, seed=1234, pulse_stride=2 ) lut_wf[unwrap.LtotalRange] = distance, distance @@ -421,8 +466,6 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing( pl[unwrap.MonitorLtotal[SampleRun, FrameMonitor0]] = distance wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) - # Convert to wavelength - # graph = {**beamline_graph(scatter=False), **elastic_graph("tof")} wavs = wavs.bins.concat().value # Bin the events in toa starting from the pulse period to skip the first pulse. ref = ( @@ -448,7 +491,7 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing( / ref.coords["wavelength"] ) # All errors should be small - assert np.nanpercentile(diff.values, 100) < 0.05 + assert np.nanpercentile(diff.values, 100) < 0.06 # Make sure that we have not lost too many events (we lose some because they may be # given a NaN wavelength from the lookup). if detector_or_monitor == "detector": @@ -462,13 +505,14 @@ def test_pulse_skipping_unwrap_when_first_half_of_first_pulse_is_missing( ) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) -def test_pulse_skipping_stride_3(detector_or_monitor) -> None: +def test_pulse_skipping_stride_3(engine, detector_or_monitor) -> None: choppers = fakes.pulse_skipping_choppers() choppers["pulse_skipping"].frequency.value = -14.0 / 3.0 lut_wf = make_lut_workflow( - choppers=choppers, neutrons=500_000, seed=111, pulse_stride=3 + engine=engine, choppers=choppers, neutrons=500_000, seed=111, pulse_stride=3 ) pl, ref = _make_workflow_event_mode( @@ -487,19 +531,24 @@ def test_pulse_skipping_stride_3(detector_or_monitor) -> None: wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - wavs=wavs, ref=ref, percentile=100, diff_threshold=0.1, rtol=0.05 + wavs=wavs, + ref=ref, + percentile=100, + diff_threshold=0.1, + rtol=0.05 if engine == "tof" else 0.01, ) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) def test_pulse_skipping_unwrap_histogram_mode( - detector_or_monitor, lut_workflow_pulse_skipping + engine, detector_or_monitor, lut_workflow_pulse_skipping ) -> None: pl, ref = _make_workflow_histogram_mode( dim='time_of_flight', distance=sc.scalar(50.0, unit="m"), choppers=fakes.pulse_skipping_choppers(), - lut_workflow=lut_workflow_pulse_skipping, + lut_workflow=lut_workflow_pulse_skipping[engine], seed=9, error_threshold=np.inf, detector_or_monitor=detector_or_monitor, @@ -511,17 +560,24 @@ def test_pulse_skipping_unwrap_histogram_mode( wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_histogram_mode( - wavs=wavs, ref=ref, percentile=96, diff_threshold=0.4, rtol=0.05 + wavs=wavs, + ref=ref, + percentile=96, + diff_threshold=0.4, + rtol=0.05 if engine == "tof" else 0.01, ) @pytest.mark.parametrize("dtype", ["int32", "int64"]) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize("detector_or_monitor", ["detector", "monitor"]) -def test_unwrap_int(dtype, detector_or_monitor, lut_workflow_psc_choppers) -> None: +def test_unwrap_int( + dtype, engine, detector_or_monitor, lut_workflow_psc_choppers +) -> None: pl, ref = _make_workflow_event_mode( - distance=sc.scalar(80.0, unit="m"), + distance=sc.scalar(62.0, unit="m"), choppers=fakes.psc_choppers(), - lut_workflow=lut_workflow_psc_choppers, + lut_workflow=lut_workflow_psc_choppers[engine], seed=2, pulse_stride_offset=0, error_threshold=0.1, @@ -544,5 +600,9 @@ def test_unwrap_int(dtype, detector_or_monitor, lut_workflow_psc_choppers) -> No wavs = pl.compute(unwrap.WavelengthMonitor[SampleRun, FrameMonitor0]) _validate_result_events( - wavs=wavs, ref=ref, percentile=100, diff_threshold=0.02, rtol=0.05 + wavs=wavs, + ref=ref, + percentile=100, + diff_threshold=0.02, + rtol=0.05 if engine == "tof" else 0.01, ) diff --git a/packages/essreduce/tests/unwrap/wfm_test.py b/packages/essreduce/tests/unwrap/wfm_test.py index d150640f5..ef4728dc6 100644 --- a/packages/essreduce/tests/unwrap/wfm_test.py +++ b/packages/essreduce/tests/unwrap/wfm_test.py @@ -8,7 +8,12 @@ from ess.reduce import unwrap from ess.reduce.nexus.types import AnyRun, NeXusDetectorName, RawDetector, SampleRun -from ess.reduce.unwrap import GenericUnwrapWorkflow, LookupTableWorkflow, fakes +from ess.reduce.unwrap import ( + FastLookupTableWorkflow, + GenericUnwrapWorkflow, + LookupTableWorkflow, + fakes, +) sl = pytest.importorskip("sciline") @@ -107,16 +112,28 @@ def dream_source_position() -> sc.Variable: return sc.vector(value=[0, 0, -76.55], unit="m") -@pytest.fixture(scope="module") -def lut_workflow_dream_choppers() -> sl.Pipeline: - lut_wf = LookupTableWorkflow() - lut_wf[unwrap.DiskChoppers[AnyRun]] = dream_choppers() - lut_wf[unwrap.SourcePosition] = dream_source_position() - lut_wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 - lut_wf[unwrap.SimulationSeed] = 432 +def make_workflows(choppers, source_position) -> dict[str, sl.Pipeline]: + lut_wf = FastLookupTableWorkflow() + lut_wf[unwrap.DiskChoppers[AnyRun]] = choppers + lut_wf[unwrap.SourcePosition] = source_position lut_wf[unwrap.PulseStride] = 1 - lut_wf[unwrap.SimulationResults] = lut_wf.compute(unwrap.SimulationResults) - return lut_wf + + tof_wf = LookupTableWorkflow() + tof_wf[unwrap.DiskChoppers[AnyRun]] = choppers + tof_wf[unwrap.SourcePosition] = source_position + tof_wf[unwrap.NumberOfSimulatedNeutrons] = 300_000 + tof_wf[unwrap.SimulationSeed] = 432 + tof_wf[unwrap.PulseStride] = 1 + tof_wf[unwrap.SimulationResults] = tof_wf.compute(unwrap.SimulationResults) + return {'analytical': lut_wf, 'tof': tof_wf} + + +@pytest.fixture(scope="module") +def lut_workflow_dream_choppers() -> dict[str, sl.Pipeline]: + return make_workflows( + choppers=dream_choppers(), + source_position=dream_source_position(), + ) def setup_workflow( @@ -138,6 +155,7 @@ def setup_workflow( return pl +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize( "ltotal", [ @@ -153,7 +171,7 @@ def setup_workflow( @pytest.mark.parametrize("time_offset_unit", ["s", "ms", "us", "ns"]) @pytest.mark.parametrize("distance_unit", ["m", "mm"]) def test_dream_wfm( - lut_workflow_dream_choppers, ltotal, time_offset_unit, distance_unit + lut_workflow_dream_choppers, engine, ltotal, time_offset_unit, distance_unit ): monitors = { f"detector{i}": ltot for i, ltot in enumerate(ltotal.flatten(to="detector")) @@ -186,7 +204,7 @@ def test_dream_wfm( ref = sc.sort(ref, key='id') pl = setup_workflow( - raw_data=raw, ltotal=ltotal, lut_workflow=lut_workflow_dream_choppers + raw_data=raw, ltotal=ltotal, lut_workflow=lut_workflow_dream_choppers[engine] ) wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) @@ -202,17 +220,14 @@ def test_dream_wfm( @pytest.fixture(scope="module") -def lut_workflow_dream_choppers_time_overlap(): - lut_wf = LookupTableWorkflow() - lut_wf[unwrap.DiskChoppers[AnyRun]] = dream_choppers_with_frame_overlap() - lut_wf[unwrap.SourcePosition] = dream_source_position() - lut_wf[unwrap.NumberOfSimulatedNeutrons] = 100_000 - lut_wf[unwrap.SimulationSeed] = 432 - lut_wf[unwrap.PulseStride] = 1 - lut_wf[unwrap.SimulationResults] = lut_wf.compute(unwrap.SimulationResults) - return lut_wf +def lut_workflow_dream_choppers_time_overlap() -> dict[str, sl.Pipeline]: + return make_workflows( + choppers=dream_choppers_with_frame_overlap(), + source_position=dream_source_position(), + ) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize( "ltotal", [ @@ -229,6 +244,7 @@ def lut_workflow_dream_choppers_time_overlap(): @pytest.mark.parametrize("distance_unit", ["m", "mm"]) def test_dream_wfm_with_subframe_time_overlap( lut_workflow_dream_choppers_time_overlap, + engine, ltotal, time_offset_unit, distance_unit, @@ -266,7 +282,7 @@ def test_dream_wfm_with_subframe_time_overlap( pl = setup_workflow( raw_data=raw, ltotal=ltotal, - lut_workflow=lut_workflow_dream_choppers_time_overlap, + lut_workflow=lut_workflow_dream_choppers_time_overlap[engine], error_threshold=0.01, ) @@ -282,7 +298,7 @@ def test_dream_wfm_with_subframe_time_overlap( sum_ref = ref.hist(wavelength=100).data.sum() # Verify that we lost some neutrons that were in the overlapping region assert sum_wfm < sum_ref - assert sum_wfm > sum_ref * 0.9 + assert sum_wfm > sum_ref * 0.8 def v20_choppers(): @@ -389,16 +405,12 @@ def v20_source_position(): @pytest.fixture(scope="module") def lut_workflow_v20_choppers(): - lut_wf = LookupTableWorkflow() - lut_wf[unwrap.DiskChoppers[AnyRun]] = v20_choppers() - lut_wf[unwrap.SourcePosition] = v20_source_position() - lut_wf[unwrap.NumberOfSimulatedNeutrons] = 300_000 - lut_wf[unwrap.SimulationSeed] = 431 - lut_wf[unwrap.PulseStride] = 1 - lut_wf[unwrap.SimulationResults] = lut_wf.compute(unwrap.SimulationResults) - return lut_wf + return make_workflows( + choppers=v20_choppers(), source_position=v20_source_position() + ) +@pytest.mark.parametrize("engine", ["tof", "analytical"]) @pytest.mark.parametrize( "ltotal", [ @@ -412,7 +424,7 @@ def lut_workflow_v20_choppers(): @pytest.mark.parametrize("time_offset_unit", ["s", "ms", "us", "ns"]) @pytest.mark.parametrize("distance_unit", ["m", "mm"]) def test_v20_compute_wavelengths_from_wfm( - lut_workflow_v20_choppers, ltotal, time_offset_unit, distance_unit + lut_workflow_v20_choppers, engine, ltotal, time_offset_unit, distance_unit ): monitors = { f"detector{i}": ltot for i, ltot in enumerate(ltotal.flatten(to="detector")) @@ -444,7 +456,7 @@ def test_v20_compute_wavelengths_from_wfm( ref = sc.sort(ref, key='id') pl = setup_workflow( - raw_data=raw, ltotal=ltotal, lut_workflow=lut_workflow_v20_choppers + raw_data=raw, ltotal=ltotal, lut_workflow=lut_workflow_v20_choppers[engine] ) wavs = pl.compute(unwrap.WavelengthDetector[SampleRun]) @@ -455,5 +467,8 @@ def test_v20_compute_wavelengths_from_wfm( (x.coords["wavelength"] - ref.coords["wavelength"]) / ref.coords["wavelength"] ) - assert np.nanpercentile(diff.values, 99) < 0.02 + if engine == "tof": + assert np.nanpercentile(diff.values, 99) < 0.02 + else: + assert np.nanpercentile(diff.values, 90) < 0.05 assert sc.isclose(ref.data.sum(), da.data.sum(), rtol=sc.scalar(1.0e-3)) diff --git a/pixi.lock b/pixi.lock index 4b54394e8..1119025ac 100644 --- a/pixi.lock +++ b/pixi.lock @@ -21753,7 +21753,7 @@ packages: - graphviz>=0.20 - sciline>=25.11.0 - scipp>=26.3.1 - - scippneutron>=25.11.1 + - scippneutron>=26.5.0 - scippnexus>=25.6.0 - scipy>=1.14 - ipywidgets>=8.1 ; extra == 'test'