# Stitching WFM data

Wavelength-frame multiplication (WFM) is a technique commonly used at long-pulse facilities to improve the wavelength resolution of the measurements recorded at the neutron detectors.
See for example the article by [Schmakat et al. (2020)](https://www.sciencedirect.com/science/article/pii/S0168900220308640) for a description of how WFM works.

In this notebook, we show how `tof` can be used to find the boundaries of the WFM frames, and apply a time correction to each frame,
in order to obtain more accurate wavelengths.
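The estimate below is only a rough illustration of why WFM helps, not a result from `tof`. It assumes the relative wavelength uncertainty scales approximately as Δλ/λ ≈ Δt/t, where Δt is the spread in departure times and t is the flight time.

- Without WFM, Δt is the full ESS pulse length (about 2.86 ms), and the flight path is the whole source-detector distance.
- With WFM, Δt is the much shorter burst defined by the WFM choppers, and the flight path is the shorter WFM-detector distance.

The 400 μs burst and the 4 Å wavelength are hypothetical values. The distances match those used later in this notebook.

```python
# Physical constants (CODATA 2018), SI units
PLANCK = 6.62607015e-34  # J s
NEUTRON_MASS = 1.67492749804e-27  # kg


def flight_time_us(wavelength_aa: float, distance_m: float) -> float:
    """Time in microseconds for a neutron of the given wavelength to travel distance_m."""
    speed = PLANCK / (NEUTRON_MASS * wavelength_aa * 1e-10)  # m/s
    return distance_m / speed * 1e6


def relative_resolution(time_spread_us: float, wavelength_aa: float, distance_m: float) -> float:
    """Rough estimate of delta_lambda / lambda as delta_t / t (an assumption)."""
    return time_spread_us / flight_time_us(wavelength_aa, distance_m)


# Full ESS pulse, measured from the source to the 32 m detector
naive = relative_resolution(2860.0, 4.0, 32.0)
# Hypothetical 400 us WFM burst, measured from the WFM choppers (~6.85 m) to the detector
wfm = relative_resolution(400.0, 4.0, 32.0 - 6.85)
print(f"without WFM: {naive:.1%}, with WFM: {wfm:.1%}")
```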

In [None]:
import numpy as np
import scipp as sc
from scippneutron.conversion.graph.beamline import beamline
from scippneutron.conversion.graph.tof import elastic
import plopp as pp
import tof

Hz = sc.Unit("Hz")
deg = sc.Unit("deg")
meter = sc.Unit("m")

## Create a source pulse

We first create a source with one pulse containing 1 million neutrons whose distribution follows the ESS time and wavelength profiles (both thermal and cold neutrons are included).

In [None]:
source = tof.Source(facility="ess", neutrons=1_000_000)
source.plot()

## Chopper set-up

We create a list of choppers that will be included in our beamline.
In our case, we make two WFM choppers and two frame-overlap choppers,
each with 6 openings.

Finally, we also add a pulse-overlap chopper with a single opening.
The chopper parameters are modelled on the [V20 ESS beamline at HZB](https://www.sciencedirect.com/science/article/pii/S0168900216309597).

In [None]:
choppers = [
    tof.Chopper(
        frequency=70.0 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[98.71, 155.49, 208.26, 257.32, 302.91, 345.3],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[109.7, 170.79, 227.56, 280.33, 329.37, 375.0],
            unit="deg",
        ),
        phase=47.10 * deg,
        distance=6.6 * meter,
        name="WFM1",
    ),
    tof.Chopper(
        frequency=70 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[80.04, 141.1, 197.88, 250.67, 299.73, 345.0],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[91.03, 156.4, 217.18, 269.97, 322.74, 375.0],
            unit="deg",
        ),
        phase=76.76 * deg,
        distance=7.1 * meter,
        name="WFM2",
    ),
    tof.Chopper(
        frequency=56 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[74.6, 139.6, 194.3, 245.3, 294.8, 347.2],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[95.2, 162.8, 216.1, 263.1, 310.5, 371.6],
            unit="deg",
        ),
        phase=62.40 * deg,
        distance=8.8 * meter,
        name="Frame-overlap 1",
    ),
    tof.Chopper(
        frequency=28 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[98.0, 154.0, 206.8, 254.0, 299.0, 344.65],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[134.6, 190.06, 237.01, 280.88, 323.56, 373.76],
            unit="deg",
        ),
        phase=12.27 * deg,
        distance=15.9 * meter,
        name="Frame-overlap 2",
    ),
    tof.Chopper(
        frequency=7 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[30.0],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[140.0],
            unit="deg",
        ),
        phase=0 * deg,
        distance=22 * meter,
        name="Pulse-overlap",
    ),
]

## Detector set-up

We add a detector 32 meters from the source.

In [None]:
detectors = [
    tof.Detector(distance=32.0 * meter, name="detector"),
]

## Find WFM frame edges

To compute the frame edges, we keep only one opening of each WFM chopper at a time and run the `tof` simulation.
We then find the minimum and maximum time-of-flight of the neutrons that make it through the chopper cascade.

In [None]:
nframes = len(choppers[0].open)
results = []

print("Computing frames")

# Run the model with a single frame at a time
for i in range(nframes):
    wfm_choppers = [
        tof.Chopper(
            frequency=choppers[j].frequency,
            open=choppers[j].open[i : i + 1],
            close=choppers[j].close[i : i + 1],
            phase=choppers[j].phase,
            distance=choppers[j].distance,
            name=choppers[j].name,
        )
        for j in (0, 1)
    ]
    new_choppers = wfm_choppers + choppers[2:]
    model = tof.Model(source=source, choppers=new_choppers, detectors=detectors)
    res = model.run()
    results.append(res)

invalid_frames = True
fact = 0.01
while invalid_frames:
    print(f"Searching for bounds: threshold={fact:.2f}")
    frame_bounds = []
    for res in results:
        tofs = res.detectors["detector"].tofs.data["visible"]["pulse:0"].hist(tof=500)
        tofs.coords["tof"] = sc.midpoints(tofs.coords["tof"])
        # We need to filter out the outliers because some stray rays from other frames make it through
        filtered = tofs[tofs.data > fact * tofs.data.max()].coords["tof"]
        frame_bounds.append((filtered.min(), filtered.max()))
    if all(
        frame_bounds[k][1] < frame_bounds[k + 1][0]
        for k in range(len(frame_bounds) - 1)
    ):
        invalid_frames = False
    else:
        fact += 0.01

The edges of the frames are:

In [None]:
frame_bounds

In [None]:
frames = sc.DataGroup(
    {
        "time_min": sc.concat([b[0] for b in frame_bounds], dim="frame"),
        "time_max": sc.concat([b[1] for b in frame_bounds], dim="frame"),
    }
)
frames
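The threshold search used above to find the frame bounds can be mirrored in plain Python. The sketch below is a hypothetical, stand-alone version that works on lists of bin centres and counts rather than on Scipp objects. Unlike the notebook loop, it stops with an error if no threshold below 1 separates the frames.

```python
def bounds_above_threshold(centers, counts, threshold):
    """Min and max bin centre whose count exceeds threshold * peak count."""
    peak = max(counts)
    kept = [c for c, n in zip(centers, counts) if n > threshold * peak]
    return min(kept), max(kept)


def find_frame_bounds(histograms, start=0.01, step=0.01):
    """Raise the threshold until consecutive frames no longer overlap.

    histograms: list of (centers, counts) pairs, one per frame, in time order.
    """
    fact = start
    while fact < 1.0:
        bounds = [bounds_above_threshold(c, n, fact) for c, n in histograms]
        if all(bounds[k][1] < bounds[k + 1][0] for k in range(len(bounds) - 1)):
            return fact, bounds
        fact += step
    raise RuntimeError("No threshold below 1 separates the frames")


# Two toy frames, each with a few stray counts leaking towards its neighbour
frame_a = ([0, 1, 2, 3, 4, 5], [0, 10, 100, 90, 5, 2.5])
frame_b = ([4, 5, 6, 7, 8], [1.5, 80, 100, 20, 0])
threshold, bounds = find_frame_bounds([frame_a, frame_b])
print(threshold, bounds)
```

At low thresholds the stray counts make the two toy frames overlap. Raising the threshold removes them until the frames are separated.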

### Inspecting the frames

As a consistency check, we can run the model with all of the chopper openings, and overlay the frame bounds,
to verify that there is no overlap between the frames.

In [None]:
model = tof.Model(source=source, choppers=choppers, detectors=detectors)
res = model.run()
f = res.detectors["detector"].tofs.plot(legend=False)
for i, bound in enumerate(frame_bounds):
    col = f"C{i + 1}"
    f.ax.axvspan(bound[0].value, bound[1].value, alpha=0.1, color=col)
    f.ax.axvline(bound[0].value, color=col)
    f.ax.axvline(bound[1].value, color=col)
f

### Time-distance diagram

Another way of verifying the computed frames is to find the fastest and slowest neutrons in each frame,
propagate them through the choppers, and show the resulting time-distance diagram.

In [None]:
frame_min = []
frame_max = []

for i in range(nframes):
    print("frame", i)
    da = results[i]["detector"].data.squeeze()
    da = da[~da.masks["blocked_by_others"]]
    ts = da.coords["tof"]
    sel = (ts > frames["time_min"][i]) & (ts < frames["time_max"][i])
    filtered = da[sel]
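    # Keep the events with the earliest and latest arrival times in this frame
    tof_values = filtered.coords["tof"].values
    frame_min.append(filtered[int(np.argmin(tof_values))])
    frame_max.append(filtered[int(np.argmax(tof_values))])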

# Create a source by manually setting neutron birth times and wavelengths
birth_times = sc.concat(
    [f.coords["time"] for f in frame_min] + [f.coords["time"] for f in frame_max],
    dim="event",
)
wavelengths = sc.concat(
    [f.coords["wavelength"] for f in frame_min]
    + [f.coords["wavelength"] for f in frame_max],
    dim="event",
)
source_min_max = tof.Source.from_neutrons(birth_times=birth_times, wavelengths=wavelengths)
source_min_max

The source contains 12 neutrons, i.e. 2 per frame.
Re-running the model with these neutrons yields:

In [None]:
model = tof.Model(source=source_min_max, choppers=choppers, detectors=detectors)
model.run().plot()
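In a time-distance diagram, each neutron follows a straight line in the time-distance plane. Its constant speed is v = h / (m_n λ).

The plain-Python sketch below is independent of `tof`. It computes where one neutron crosses the source, chopper and detector positions defined above. The birth time of 1000 μs and the 4 Å wavelength are hypothetical values.

```python
# Physical constants (CODATA 2018), SI units
PLANCK = 6.62607015e-34  # J s
NEUTRON_MASS = 1.67492749804e-27  # kg


def speed_m_per_s(wavelength_aa: float) -> float:
    """Neutron speed from its de Broglie wavelength in angstrom."""
    return PLANCK / (NEUTRON_MASS * wavelength_aa * 1e-10)


def trajectory(birth_time_us: float, wavelength_aa: float, distances_m: list[float]):
    """(distance, time) points where a neutron crosses each given distance."""
    v = speed_m_per_s(wavelength_aa)
    return [(d, birth_time_us + d / v * 1e6) for d in distances_m]


# Source, the five choppers defined above, and the detector
distances = [0.0, 6.6, 7.1, 8.8, 15.9, 22.0, 32.0]
for d, t in trajectory(birth_time_us=1000.0, wavelength_aa=4.0, distances_m=distances):
    print(f"{d:5.1f} m at {t:8.1f} us")
```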

Interestingly, and by design of WFM, the wavelength of the slowest neutron in one frame is very close to that of the fastest neutron in the next frame, so that consecutive frames cover adjacent wavelength bands.

## Stitching the data: computing a new time-of-flight

Using WFM choppers allows us to redefine the burst time of the neutrons and hence compute more accurate wavelengths.

In the following, we use the frame boundaries to select the neutrons in each frame.
We then correct their time-of-flight by subtracting the time at which the WFM choppers are open for that frame.

### Computing wavelengths from the naive time-of-flight

We begin by computing the neutron wavelengths as if there were no WFM choppers.
We take the distance from the source to the detector, and use the neutron arrival time at the detector to compute a speed and hence a wavelength.
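The naive conversion amounts to λ = h t / (m_n L). Here t is the arrival time measured from the start of the pulse, and L is the source-detector distance.

Below is a stand-alone plain-Python sketch of that formula. The helper name is hypothetical; in the notebook, Scipp's `transform_coords` performs the conversion.

```python
# Physical constants (CODATA 2018), SI units
PLANCK = 6.62607015e-34  # J s
NEUTRON_MASS = 1.67492749804e-27  # kg


def naive_wavelength_aa(arrival_time_us: float, distance_m: float) -> float:
    """Wavelength in angstrom, assuming the neutron left the source at t = 0."""
    speed = distance_m / (arrival_time_us * 1e-6)  # m/s
    return PLANCK / (NEUTRON_MASS * speed) * 1e10


# A neutron reaching the 32 m detector after 32 000 us travelled at 1000 m/s
print(f"{naive_wavelength_aa(32_000.0, 32.0):.3f} A")
```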

In [None]:
# Extract the tof data of the events that make it through to the detector
tofs = res["detector"].tofs.data["visible"]["pulse:0"].copy()
tofs.coords["source_position"] = sc.vector([0.0, 0.0, 0.0], unit="m")
tofs.coords["position"] = sc.vector(
    [0.0, 0.0, detectors[0].distance.value], unit=detectors[0].distance.unit
)
tofs

Converting the time-of-flight to wavelength is done using Scipp's `transform_coords`:

In [None]:
# Make a coordinate transformation graph to compute wavelength from tof
graph = {**beamline(scatter=False), **elastic("tof")}
wav_naive = tofs.transform_coords("wavelength", graph=graph)
wav_naive.hist(wavelength=300).plot()

### Computing time-of-flight from WFM choppers to detector

Instead of using the source as the departure point of the neutrons, we use the WFM choppers.
This means that the distance used for the flight is from the WFM choppers to the detector,
and the flight time runs from the moment the choppers open to the arrival time at the detector.
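In formula form, the WFM approach replaces λ = h t / (m_n L) with λ = h (t - t_WFM) / (m_n (L - L_WFM)). Here t_WFM is the time at which the WFM choppers are open for the neutron's frame, and L_WFM is their distance from the source.

The sketch below compares both formulas for a hypothetical 4 Å neutron emitted 2 ms into the pulse. For illustration, it assumes the WFM time estimate is 100 μs away from the instant the neutron actually passed the choppers.

```python
# Physical constants (CODATA 2018), SI units
PLANCK = 6.62607015e-34  # J s
NEUTRON_MASS = 1.67492749804e-27  # kg


def wavelength_aa(flight_time_us: float, flight_distance_m: float) -> float:
    """Wavelength (angstrom) of a neutron covering flight_distance_m in flight_time_us."""
    speed = flight_distance_m / (flight_time_us * 1e-6)  # m/s
    return PLANCK / (NEUTRON_MASS * speed) * 1e10


# Hypothetical neutron: 4 angstrom, emitted 2000 us after the start of the pulse
true_wavelength = 4.0
birth_us = 2000.0
speed = PLANCK / (NEUTRON_MASS * true_wavelength * 1e-10)

detector_m = 32.0
wfm_m = 0.5 * (6.6 + 7.1)  # mid-point between the two WFM choppers
arrival_us = birth_us + detector_m / speed * 1e6

# Assumed WFM time estimate: 100 us later than the neutron really passed the choppers
wfm_time_us = birth_us + wfm_m / speed * 1e6 + 100.0

naive = wavelength_aa(arrival_us, detector_m)  # departure: source, at t = 0
wfm = wavelength_aa(arrival_us - wfm_time_us, detector_m - wfm_m)  # departure: WFM choppers
print(f"true: {true_wavelength:.3f} A, naive: {naive:.3f} A, WFM: {wfm:.3f} A")
```

The naive estimate is biased by the full 2 ms emission delay. The WFM estimate is affected only by the much smaller error in t_WFM.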

We first compute, for each frame, the time at which the WFM choppers are open. We take this as the mean of the opening and closing times of the two WFM choppers:

In [None]:
times_wfm1 = choppers[0].open_close_times()
times_wfm2 = choppers[1].open_close_times()

corrections = [
    sc.concat(
        [
            times_wfm1[0][i + nframes],  # open wfm1
            times_wfm1[1][i + nframes],  # close wfm1
            times_wfm2[0][i + nframes],  # open wfm2
            times_wfm2[1][i + nframes],  # close wfm2
        ],
        dim="x",
    ).mean()
    for i in range(nframes)
]

frames["time_correction"] = sc.concat(corrections, dim="frame")
frames
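The same averaging can be sketched from the chopper angles alone. The sketch assumes a simple convention: an edge at angle θ, on a chopper with phase φ and frequency f, passes the beam at t = (θ + φ) / (360° · f), during the first rotation after t = 0.

`tof` may use a different direction or rotation-offset convention. The numbers here are therefore illustrative, not a reproduction of `open_close_times()`. Both helper names are hypothetical.

```python
def edge_time_us(angle_deg: float, phase_deg: float, frequency_hz: float) -> float:
    """Time (us) at which a chopper edge passes the beam, under the assumed convention."""
    return (angle_deg + phase_deg) / 360.0 * 1e6 / frequency_hz


def wfm_time_corrections(wfm1: dict, wfm2: dict) -> list[float]:
    """Per-frame mean of the open/close times of both WFM choppers."""
    corrections = []
    for k in range(len(wfm1["open"])):
        times = [
            edge_time_us(chopper[edge][k], chopper["phase"], chopper["frequency"])
            for chopper in (wfm1, wfm2)
            for edge in ("open", "close")
        ]
        corrections.append(sum(times) / len(times))
    return corrections


# Angles, phases and frequencies of the WFM choppers defined in this notebook
wfm1 = {
    "open": [98.71, 155.49, 208.26, 257.32, 302.91, 345.3],
    "close": [109.7, 170.79, 227.56, 280.33, 329.37, 375.0],
    "phase": 47.10,
    "frequency": 70.0,
}
wfm2 = {
    "open": [80.04, 141.1, 197.88, 250.67, 299.73, 345.0],
    "close": [91.03, 156.4, 217.18, 269.97, 322.74, 375.0],
    "phase": 76.76,
    "frequency": 70.0,
}
print([round(t, 1) for t in wfm_time_corrections(wfm1, wfm2)])
```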

We then subtract this correction from the time-of-flight of the neutrons in each frame, which effectively 'stitches' the frames back together:

In [None]:
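def stitch(
    data: sc.DataArray,
    frames: sc.DataGroup,
    dim: str,
) -> sc.DataArray:
    """Shift the time coordinate of the events in each WFM frame by its time correction."""
    # Interleave the frame bounds: [min_0, max_0, min_1, max_1, ...]
    edges = sc.flatten(
        sc.transpose(
            sc.concat([frames["time_min"], frames["time_max"]], "dummy"),
            dims=["frame", "dummy"],
        ),
        to=dim,
    )

    # Even bins hold the frames, odd bins hold the gaps between frames
    binned = data.bin({dim: edges})

    # Subtract each frame's time correction from the events in that frame
    for i in range(frames.sizes["frame"]):
        binned[dim, i * 2].bins.coords[dim] -= frames["time_correction"]["frame", i]

    # Mask the gap bins so that their events are excluded when the bins are concatenated
    binned.masks["frame_gaps"] = (
        sc.arange(dim, 2 * frames.sizes["frame"] - 1) % 2
    ).astype(bool)
    binned.masks["frame_gaps"].unit = None
    return binned.bins.concat()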


wfm_tofs = stitch(data=tofs, frames=frames, dim="tof")
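Conceptually, `stitch` assigns each event to the frame whose [time_min, time_max) interval contains it. It then subtracts that frame's correction and discards events that fall outside every frame.

A minimal pure-Python equivalent on plain lists makes this explicit. The `stitch_events` helper is hypothetical and uses no Scipp.

```python
def stitch_events(tofs, time_min, time_max, corrections):
    """Shift each event by the correction of the frame it falls in; drop the rest."""
    stitched = []
    for t in tofs:
        for lo, hi, correction in zip(time_min, time_max, corrections):
            # Half-open interval, like Scipp bin edges
            if lo <= t < hi:
                stitched.append(t - correction)
                break
        # Events matching no frame (gaps, or outside all frames) are not kept
    return stitched


print(stitch_events([5.0, 15.0, 25.0, 35.0], [0.0, 20.0], [10.0, 30.0], [1.0, 21.0]))
```

Here the event at 15 falls in the gap between the two frames, and the event at 35 lies beyond the last frame, so only two stitched events remain.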

Finally, we set the `source_position` to the mid-point between the two WFM choppers:

In [None]:
wfm_tofs.coords["source_position"] = sc.vector(
    [0.0, 0.0, 0.5 * (choppers[0].distance.value + choppers[1].distance.value)],
    unit=choppers[0].distance.unit,
)
wfm_tofs

We can now compute wavelengths using `transform_coords`, as before:

In [None]:
wav_wfm = wfm_tofs.transform_coords("wavelength", graph=graph)

### Comparison between naive and WFM computations

We compare the wavelengths computed using the naive approach, the WFM approach, and also with the true wavelengths of the neutrons:

In [None]:
pp.plot(
    {
        "naive": wav_naive.hist(wavelength=300),
        "wfm": wav_wfm.hist(wavelength=300),
        "truth": res["detector"]
        .wavelengths.data["visible"]["pulse:0"]
        .hist(wavelength=300),
    }
)

The wavelength distribution obtained with the WFM approach matches the true distribution much more closely than the naive one.