# Stitching WFM data

Wavelength-frame multiplication (WFM) is a technique commonly used at long-pulse facilities to improve the wavelength resolution of the measurements at the neutron detectors.
See for example the article by [Schmakat et al. (2020)](https://www.sciencedirect.com/science/article/pii/S0168900220308640) for a description of how WFM works.

In this notebook, we show how `tof` can be used to find the boundaries of the WFM frames, and apply a time correction to each frame,
in order to obtain more accurate wavelengths.

In [None]:
import numpy as np
import scipp as sc
from scippneutron.conversion.graph.beamline import beamline
from scippneutron.conversion.graph.tof import elastic
import plopp as pp
import tof

Hz = sc.Unit("Hz")
deg = sc.Unit("deg")
meter = sc.Unit("m")

## Create a source pulse

We first create a source with one pulse containing 500,000 neutrons whose distribution follows the ESS time and wavelength profiles (both thermal and cold neutrons are included).

In [None]:
source = tof.Source(facility="ess", neutrons=500_000)
source.plot()

In [None]:
source.data

## Chopper set-up

We create a list of choppers to include in our beamline:
two WFM choppers and two frame-overlap choppers, each with 6 openings.

Finally, we also add a pulse-overlap chopper with a single opening.
These choppers are modelled after the [V20 ESS beamline at HZB](https://www.sciencedirect.com/science/article/pii/S0168900216309597).

In [None]:
choppers = [
    tof.Chopper(
        frequency=70.0 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[98.71, 155.49, 208.26, 257.32, 302.91, 345.3],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[109.7, 170.79, 227.56, 280.33, 329.37, 375.0],
            unit="deg",
        ),
        phase=47.10 * deg,
        distance=6.6 * meter,
        name="WFM1",
    ),
    tof.Chopper(
        frequency=70 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[80.04, 141.1, 197.88, 250.67, 299.73, 345.0],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[91.03, 156.4, 217.18, 269.97, 322.74, 375.0],
            unit="deg",
        ),
        phase=76.76 * deg,
        distance=7.1 * meter,
        name="WFM2",
    ),
    tof.Chopper(
        frequency=56 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[74.6, 139.6, 194.3, 245.3, 294.8, 347.2],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[95.2, 162.8, 216.1, 263.1, 310.5, 371.6],
            unit="deg",
        ),
        phase=62.40 * deg,
        distance=8.8 * meter,
        name="Frame-overlap 1",
    ),
    tof.Chopper(
        frequency=28 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[98.0, 154.0, 206.8, 254.0, 299.0, 344.65],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[134.6, 190.06, 237.01, 280.88, 323.56, 373.76],
            unit="deg",
        ),
        phase=12.27 * deg,
        distance=15.9 * meter,
        name="Frame-overlap 2",
    ),
    tof.Chopper(
        frequency=7 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[30.0],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[140.0],
            unit="deg",
        ),
        phase=0 * deg,
        distance=22 * meter,
        name="Pulse-overlap",
    ),
]
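
Each chopper above is described by the angles at which its windows open and close, a phase, and a rotation frequency. The sketch below shows roughly how such angles translate into times. It assumes a simplified convention: on a chopper with phase $\phi$ and frequency $f$, an angle $\theta$ is reached at $t = (\theta + \phi) / (360^\circ f)$, repeating every period $1/f$. The helper `window_times` is hypothetical, and this convention may differ from the one `tof` uses internally (for instance, it ignores the rotation direction).

```python
def window_times(open_deg, close_deg, phase_deg, frequency_hz, n_rotations=1):
    """Return (t_open, t_close) pairs in microseconds for each window and rotation.

    Simplified, hypothetical convention: an angle theta (in degrees) is reached at
    t = (theta + phase) / (360 * f), and the pattern repeats every period 1 / f.
    """
    period_us = 1.0e6 / frequency_hz
    times = []
    for n in range(n_rotations):
        for a_open, a_close in zip(open_deg, close_deg):
            t_open = (a_open + phase_deg) / 360.0 * period_us + n * period_us
            t_close = (a_close + phase_deg) / 360.0 * period_us + n * period_us
            times.append((t_open, t_close))
    return times


# First window of the WFM1 chopper defined above (70 Hz, phase 47.10 deg), two rotations
wfm1 = window_times([98.71], [109.7], phase_deg=47.10, frequency_hz=70.0, n_rotations=2)
for t_open, t_close in wfm1:
    print(f"open {t_open:.1f} us, close {t_close:.1f} us")
```

Because the pattern repeats every rotation, a chopper's open/close times span several periods. This is why only the rotation relevant to the pulse is kept when computing the WFM time corrections.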

## Detector set-up

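We add a detector 32 meters from the source, which is used to find the frames and stitch the data.
We also add a series of 100 detectors between 33 and 37 meters from the source,
which are used at the end of this notebook to study how the wavelength spread varies with the flight path length.

In [None]:
# Main detector, used to find the WFM frames and stitch the data
detectors = [tof.Detector(distance=32.0 * meter, name="detector")]
# Additional detectors between 33 and 37 m, for the wavelength-spread study at the end
detectors += [
    tof.Detector(distance=float(L) * meter, name=f"detector{i}")
    for i, L in enumerate(np.linspace(33, 37, 100))
]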

## Run the simulation

We propagate our pulse of neutrons through the chopper cascade and inspect the results.

In [None]:
model = tof.Model(source=source, choppers=choppers, detectors=detectors)
full_beamline_results = model.run()
full_beamline_results.plot(blocked_rays=5000)

In [None]:
full_beamline_results['detector'].plot()

## Find WFM frame edges

To compute the frame edges, we use one of the openings in the WFM choppers at a time,
run the `tof` simulation, and find the min and max tof of the neutrons that make it through the chopper cascade.

### Finding the fastest and slowest neutrons

To find the range of chopper open and close times relevant to this pulse,
we find the times at which the slowest and fastest neutrons went through the two WFM choppers.

These times are then used to select the chopper openings (i.e. the chopper rotation) relevant for this pulse when computing the WFM time corrections, which are applied before converting flight times to wavelengths.

In [None]:
# Get the neutrons that make it to the detector
da = full_beamline_results['detector'].data.squeeze()
visible = da[~da.masks['blocked_by_others']]

# Find fastest and slowest neutrons
id_fast = visible[int(np.argmin(visible.coords['tof'].values))].coords['id']
id_slow = visible[int(np.argmax(visible.coords['tof'].values))].coords['id']

# Compute the times when those neutrons crossed the WFM choppers
tfast = {}
tslow = {}
for key in ("WFM1", "WFM2"):
    data = full_beamline_results[key].data.squeeze()
    # The data is indexed along 'event', so select the neutrons by their 'id' coordinate
    tfast[key] = data[data.coords['id'] == id_fast].coords['tof'][0]
    tslow[key] = data[data.coords['id'] == id_slow].coords['tof'][0]

tfast, tslow
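
The crossing times above are read from the simulation. For a neutron with a known birth time and wavelength, they also follow from its speed $v = h / (m_n \lambda)$. The following is a minimal sketch that assumes straight-line flight at constant speed. `arrival_time_us` is a hypothetical helper, not part of `tof`:

```python
# CODATA values of the Planck constant (J s) and the neutron mass (kg)
PLANCK = 6.62607015e-34
NEUTRON_MASS = 1.67492749804e-27


def arrival_time_us(birth_time_us, wavelength_angstrom, distance_m):
    """Time (in microseconds) at which a neutron reaches a component at distance_m."""
    speed = PLANCK / (NEUTRON_MASS * wavelength_angstrom * 1.0e-10)  # m/s
    return birth_time_us + distance_m / speed * 1.0e6


# A 3.956 Angstrom neutron travels at ~1000 m/s, so it reaches the WFM1
# chopper (6.6 m) about 6600 us after its birth
t = arrival_time_us(birth_time_us=100.0, wavelength_angstrom=3.956, distance_m=6.6)
print(f"{t:.1f} us")
```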

### Using one chopper opening at a time

For each WFM chopper, we propagate the neutrons using one opening at a time,
to find which combinations of openings lead to a frame on the detector.

In [None]:
ncutouts = len(choppers[0].open)
results = []
time_offsets = []

print("Computing frames")
new_choppers = {c.name: c for c in choppers}

# Run the model with a single frame at a time
for i in range(ncutouts):
    new_choppers['WFM1'] = tof.Chopper(
        **choppers[0].as_dict()
    )  # make a copy of the chopper
    new_choppers['WFM1'].open = choppers[0].open[i : i + 1]
    new_choppers['WFM1'].close = choppers[0].close[i : i + 1]
    for j in range(ncutouts):
        new_choppers['WFM2'] = tof.Chopper(
            **choppers[1].as_dict()
        )  # make a copy of the chopper
        new_choppers['WFM2'].open = choppers[1].open[j : j + 1]
        new_choppers['WFM2'].close = choppers[1].close[j : j + 1]
        # Run the simulation using one opening from each WFM chopper
        res = tof.Model(
            source=source, choppers=list(new_choppers.values()), detectors=detectors
        ).run()
        if res.detectors["detector"].tofs.data["visible"].sizes['event'] != 0:
            print(f'Frame found for chopper WFM1:cutout{i}, WFM2:cutout{j}')
            # Append to result list
            results.append(res)
            # Record the chopper opening and close time for this frame
            times = []
            for key in ("WFM1", "WFM2"):
                topen, tclose = new_choppers[key].open_close_times()
                # Chopper open/close times may contain multiple rotations, and we
                # thus need to select the rotation that is relevant for the pulse of interest
                sel = (tclose > tfast[key]) & (topen < tslow[key])
                times.extend([topen[sel], tclose[sel]])
            time_offsets.append(sc.concat(times, dim='x').mean())
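
For each frame, the recorded time offset is the mean of the opening and closing times of both WFM choppers. Only the windows that overlap the range between the fastest and slowest neutrons are kept. The sketch below reproduces that selection and averaging in plain Python. `frame_time_offset` and the times in the example are illustrative, not values from the simulation.

```python
def frame_time_offset(windows, t_fast, t_slow):
    """Mean of the open/close times of the chopper windows relevant for a pulse.

    windows: dict mapping chopper name -> list of (t_open, t_close) tuples,
             possibly spanning several rotations.
    t_fast, t_slow: dict mapping chopper name -> times at which the fastest and
                    slowest neutrons reach that chopper.
    """
    times = []
    for name, pairs in windows.items():
        for t_open, t_close in pairs:
            # Keep windows that close after the fastest neutron arrives
            # and open before the slowest neutron arrives
            if t_close > t_fast[name] and t_open < t_slow[name]:
                times.extend([t_open, t_close])
    if not times:
        raise ValueError("No chopper window overlaps the range of neutron times")
    return sum(times) / len(times)


# Illustrative times in microseconds: the second rotation of each chopper lies
# outside the [fastest, slowest] range and is therefore ignored
windows = {
    "WFM1": [(5786.1, 6222.2), (20071.8, 20507.9)],
    "WFM2": [(6100.0, 6500.0), (20385.7, 20785.7)],
}
offset = frame_time_offset(
    windows,
    t_fast={"WFM1": 2000.0, "WFM2": 2150.0},
    t_slow={"WFM1": 15000.0, "WFM2": 16100.0},
)
print(f"time offset: {offset:.2f} us")
```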

In [None]:
# The offsets are not guaranteed to be sorted, but the overlap removal below
# compares consecutive frames, so we sort the frames by time offset
time_offsets, results = zip(*sorted(zip(time_offsets, results)))
time_offsets

## Reducing overlap between frames

There is a small amount of time overlap between some of the frames, as illustrated by this figure
(for example, some of the neutrons that belong to <span style="color: #2ca02c;">frame 2</span> are bleeding into the range that belongs to <span style="color: #ff7f0e;">frame 1</span>).

In [None]:
# Define a common binning that covers the entire range
full_tofs = (
    full_beamline_results.detectors["detector"]
    .tofs.visible.data['pulse:0']
    .coords['tof']
)
bins = sc.linspace('tof', full_tofs.min(), full_tofs.max(), num=301)

pp.plot(
    {
        f"Frame-{i}": res.detectors["detector"]
        .tofs.visible.data['pulse:0']
        .hist(tof=bins)
        for i, res in enumerate(results)
    }
)

In regions where the time-of-flight ranges of two frames overlap, the wavelength of a neutron cannot be determined accurately:
neutrons in such a region have the same arrival time at the detector,
but may belong to different frames and therefore have different start times at the source end of the beamline.

Such overlaps are usually a sign of a flaw in the chopper cascade design,
and we thus need to discard neutrons from any overlapping region in our final result.

One way to achieve this is to remove outliers from the edges of the frames until the frames no longer overlap.
Below, this is done by gradually increasing a percentile threshold on the `tof` values at the detector until the overlap is gone.

In [None]:
percentile_step = 0.1  # Step by which to increase percentile thresholds

non_overlapping = []
for res in results:
    data = res.detectors["detector"].data.squeeze()
    non_overlapping.append(data[~data.masks['blocked_by_others']])

for i in range(len(results) - 1):
    # Trim the upper edge of frame i and the lower edge of frame i + 1 until they no longer overlap
    tofs1 = non_overlapping[i].coords['tof']
    tofs2 = non_overlapping[i + 1].coords['tof']
    p = 0.0
    overlap = True
    while overlap:
        tofmin = np.percentile(tofs2.values, p)
        tofmax = np.percentile(tofs1.values, 100 - p)
        overlap = tofmin < tofmax
        p += percentile_step
    # Filter on tofs
    non_overlapping[i] = non_overlapping[i][tofs1 <= sc.scalar(tofmax, unit=tofs1.unit)]
    non_overlapping[i + 1] = non_overlapping[i + 1][
        tofs2 >= sc.scalar(tofmin, unit=tofs2.unit)
    ]

for frame in non_overlapping:
    w = frame.coords["wavelength"]
    t = frame.coords["tof"]
    print(
        f"Neutrons={len(frame)}, wmin={w.min():c}, wmax={w.max():c}, tofmin={t.min():c}, tofmax={t.max():c}"
    )
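
The trimming loop above can be summarised in a standalone sketch that operates on two plain lists of arrival times. It uses a simple nearest-rank percentile rather than `np.percentile`, which interpolates linearly by default, so its thresholds can differ slightly from the ones computed above. `percentile`, `trim_overlap` and the sample times are hypothetical.

```python
import math


def percentile(values, p):
    """Nearest-rank percentile of a list of numbers, for 0 <= p <= 100."""
    ordered = sorted(values)
    rank = max(1, math.ceil(p / 100.0 * len(ordered)))
    return ordered[rank - 1]


def trim_overlap(early, late, step=0.1):
    """Trim the upper tail of `early` and the lower tail of `late` until they stop overlapping."""
    p = 0.0
    while p <= 50.0:
        t_min = percentile(late, p)  # lower edge of the later frame
        t_max = percentile(early, 100.0 - p)  # upper edge of the earlier frame
        if t_min >= t_max:
            # Same filters as in the notebook: keep early <= t_max and late >= t_min
            return [t for t in early if t <= t_max], [t for t in late if t >= t_min]
        p += step
    raise ValueError("Frames still overlap at the median; are they sorted in time?")


# Illustrative arrival times (us): the last event of `early` and the first
# event of `late` fall inside each other's range
early = [10.0, 11.0, 12.0, 13.0, 14.0, 15.5]
late = [14.5, 16.0, 17.0, 18.0, 19.0, 20.0]
trimmed_early, trimmed_late = trim_overlap(early, late)
print(max(trimmed_early), min(trimmed_late))
```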

In [None]:
# Store the frame bounds to a format used by the stitching
frames = sc.DataGroup(
    {
        "time_min": sc.concat(
            [data.coords['tof'].min() for data in non_overlapping], dim="frame"
        ),
        "time_max": sc.concat(
            [data.coords['tof'].max() for data in non_overlapping], dim="frame"
        ),
        "time_correction": sc.concat(time_offsets, dim='frame'),
    }
)
frames

## Inspecting the frames

### Tof frame boundaries

As a consistency check,
we can overlay the frame bounds onto the detector reading to verify that each frame spans the expected `tof` range,
and that there is no overlap between the frames.

In [None]:
f = full_beamline_results.detectors["detector"].tofs.plot(legend=False, color='k')
for i, (left, right) in enumerate(zip(frames["time_min"], frames["time_max"])):
    col = f"C{i}"
    f.ax.axvspan(left.value, right.value, alpha=0.1, color=col)
    f.ax.axvline(left.value, color=col)
    f.ax.axvline(right.value, color=col)
f

### Time-distance diagram

Another way of verifying the computed frames is to find the fastest and slowest neutrons in each frame,
propagate them through the choppers, and inspect the time-distance diagram.

In [None]:
frame_min = []
frame_max = []

for data in non_overlapping:
    ind_min = np.argmin(data.coords['wavelength'].values)
    ind_max = np.argmax(data.coords['wavelength'].values)
    frame_min.append(data[ind_min])
    frame_max.append(data[ind_max])

# Create a source by manually setting neutron birth times and wavelengths
birth_times = sc.concat(
    [f.coords["time"] for f in frame_min] + [f.coords["time"] for f in frame_max],
    dim="event",
)
wavelengths = sc.concat(
    [f.coords["wavelength"] for f in frame_min]
    + [f.coords["wavelength"] for f in frame_max],
    dim="event",
)
source_min_max = tof.Source.from_neutrons(
    birth_times=birth_times, wavelengths=wavelengths
)
source_min_max

We can see that the source has 12 neutrons, which is 2 per frame.
Re-running the model with those neutrons yields:

In [None]:
model = tof.Model(source=source_min_max, choppers=choppers, detectors=detectors)
model.run().plot()

It is important to note here that the maximum wavelength of frame $i$ should be close to the minimum wavelength of frame $i+1$.

## Stitching the data: computing a new time-of-flight

Using WFM choppers allows us to re-define the burst time of the neutrons, and compute a more accurate wavelength.

In the following, we use the boundaries of the frames to select neutrons in each frame,
and apply a correction to the time-of-flight of those neutrons which corresponds to the time when the WFM choppers are open.

### Computing wavelengths from the naive time-of-flight

We begin by computing the neutron wavelengths as if there were no WFM choppers.
We take the distance from the source to the detector, and use the neutron arrival time at the detector to compute a speed and hence a wavelength.
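
In other words, the naive wavelength follows from the de Broglie relation $\lambda = h / (m_n v)$ with $v = L / t$. Here $L$ is the source-detector distance and $t$ the arrival time, and the neutron is assumed to have left the source at $t = 0$. Below is a minimal stand-alone sketch of this conversion, using a hypothetical helper name:

```python
PLANCK = 6.62607015e-34  # J s
NEUTRON_MASS = 1.67492749804e-27  # kg


def naive_wavelength_angstrom(tof_us, distance_m):
    """Wavelength assuming the neutron left the source (distance 0) at time 0."""
    speed = distance_m / (tof_us * 1.0e-6)  # m/s
    return PLANCK / (NEUTRON_MASS * speed) * 1.0e10


# A neutron reaching a detector at 32 m after 32 ms travels at 1000 m/s
wavelength = naive_wavelength_angstrom(tof_us=32_000.0, distance_m=32.0)
print(f"{wavelength:.4f} Angstrom")
```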

In [None]:
# Extract the tof data of the events that make it through to the detector
tofs = full_beamline_results["detector"].tofs.data["visible"]["pulse:0"].copy()
tofs.coords["source_position"] = sc.vector([0.0, 0.0, 0.0], unit="m")
tofs.coords["position"] = sc.vector(
    [0.0, 0.0, detectors[0].distance.value], unit=detectors[0].distance.unit
)
tofs

Converting the time-of-flight to wavelength is done using Scipp's `transform_coords`:

In [None]:
# Make a coordinate transformation graph to compute wavelength from tof
graph = {**beamline(scatter=False), **elastic("tof")}
wav_naive = tofs.transform_coords("wavelength", graph=graph)
wav_naive.hist(wavelength=300).plot()

### Computing time-of-flight from WFM choppers to detector

Instead of using the source as the departure point of the neutrons, we use the WFM choppers.
This means that the distance used for the flight is from the WFM choppers to the detector,
and the flight time runs from the time when the WFM choppers are open (the mean of their opening and closing times) to the arrival time at the detector.

Those were computed above, and stored as `time_correction` in the `frames` data group.
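
Written out, let $T$ be the arrival time at the detector and $t_\mathrm{WFM}$ the time correction of the frame the neutron belongs to. Let $L$ be the source-detector distance and $z_\mathrm{WFM}$ the distance of the WFM choppers from the source. The wavelength computed in this way is then

$$\lambda = \frac{h}{m_n} \, \frac{T - t_\mathrm{WFM}}{L - z_\mathrm{WFM}},$$

compared with $\lambda = \frac{h}{m_n} \frac{T}{L}$ in the naive approach, which assumes the neutrons left the source at $t = 0$.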

We apply the corrections, which effectively 'stitch' the data back together:

In [None]:
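def stitch(
    data: sc.DataArray,
    frames: sc.DataGroup,
    dim: str,
) -> sc.DataArray:
    # Interleave the frame bounds into a single list of bin edges:
    # [min_0, max_0, min_1, max_1, ...]
    edges = sc.flatten(
        sc.transpose(
            sc.concat([frames["time_min"], frames["time_max"]], "dummy"),
            dims=["frame", "dummy"],
        ),
        to=dim,
    )

    # Even bins contain the frames, odd bins contain the gaps between frames
    binned = data.bin({dim: edges})

    # Shift the time coordinate of the events in each frame by its WFM time correction
    for i in range(frames.sizes["frame"]):
        binned[dim, i * 2].bins.coords[dim] -= frames["time_correction"]["frame", i]

    # Mask the gaps between frames so that they are dropped when concatenating
    binned.masks["frame_gaps"] = (
        sc.arange(dim, 2 * frames.sizes["frame"] - 1) % 2
    ).astype(bool)
    binned.masks["frame_gaps"].unit = None
    # Merge the events of all (unmasked) frames back into a single bin
    return binned.bins.concat()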


wfm_tofs = stitch(data=tofs, frames=frames, dim="tof")
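
The same operation can be sketched without Scipp. Each event is assigned to the frame whose half-open interval $[t_\mathrm{min}, t_\mathrm{max})$ contains it, matching the half-open bins created by `data.bin`. Its time is then shifted by that frame's correction, and events between frames are dropped. `stitch_times` is a hypothetical helper working on plain lists of times:

```python
from bisect import bisect_right


def stitch_times(times, time_min, time_max, time_correction):
    """Shift each event time by the time correction of the frame containing it.

    Frames are half-open intervals [time_min[i], time_max[i]) that must be sorted and
    non-overlapping. Events outside every frame (e.g. in the gaps) are dropped.
    """
    stitched = []
    for t in times:
        # Index of the last frame that starts at or before t
        i = bisect_right(time_min, t) - 1
        if i >= 0 and t < time_max[i]:
            stitched.append(t - time_correction[i])
    return stitched


# Two illustrative frames (times in us): 5.0 lies before the first frame and
# 25.0 lies in the gap between the frames, so both are dropped
events = [5.0, 12.0, 18.0, 25.0, 31.0]
stitched = stitch_times(
    events, time_min=[10.0, 30.0], time_max=[20.0, 40.0], time_correction=[8.0, 26.0]
)
print(stitched)
```

After this shift, the times would be converted to wavelengths using the distance from the WFM choppers to the detector, rather than the full source-detector distance.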

Finally, we change the `source_position` to now be the mid-point between the WFM choppers:

In [None]:
wfm_tofs.coords["source_position"] = sc.vector(
    [0.0, 0.0, 0.5 * (choppers[0].distance.value + choppers[1].distance.value)],
    unit=choppers[0].distance.unit,
)
wfm_tofs

We can now compute wavelengths using `transform_coords`, as before:

In [None]:
wav_wfm = wfm_tofs.transform_coords("wavelength", graph=graph)

### Comparison between naive and WFM computations

We compare the wavelengths computed using the naive approach, the WFM approach, and also with the true wavelengths of the neutrons:

In [None]:
pp.plot(
    {
        "naive": wav_naive.hist(wavelength=300),
        "wfm": wav_wfm.hist(wavelength=300),
        "truth": full_beamline_results["detector"]
        .wavelengths.data["visible"]["pulse:0"]
        .hist(wavelength=300),
    }
)

As we can see, the wavelength distribution obtained with the WFM approach is much closer to the true distribution than the one obtained with the naive approach.

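## Wavelength spread versus flight path length

Using the additional detectors between 33 and 37 meters, we compute, for each detector distance and time-of-flight bin,
the mean wavelength $\mu$ and the relative wavelength spread $\sigma / \mu$ of the neutrons that reach the detector.

In [None]:
# Combine the readings of the additional detectors, labelling each with its distance
readings = []
for det in detectors:
    if det.name == "detector":
        continue  # skip the main detector used for the stitching above
    reading = full_beamline_results[det.name].data['pulse', 0].copy()
    reading.coords['Ltot'] = det.distance
    readings.append(reading)

b = sc.concat(readings, dim='Ltot')
# Keep only the neutrons that were not blocked by the choppers
b = b[~b.masks['blocked_by_others']]

# Group the events by detector distance and bin them in time-of-flight
b = b.group('Ltot').bins.concat('event').bin(tof=500)
# Mean and variance of the wavelength in each (Ltot, tof) bin
mu = (b * b.bins.coords['wavelength']).bins.mean() / b.bins.mean()
var = (b * (b.bins.coords['wavelength'] - mu) ** 2).bins.mean() / b.bins.mean()
# Relative wavelength spread
(sc.sqrt(var) / mu).plot()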
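
The quantity plotted above is the relative standard deviation $\sigma_\lambda / \mu$ of the wavelengths of the neutrons in each (distance, time-of-flight) bin. For a single detector, the computation can be sketched in pure Python as follows. The helper and the events are made up for illustration:

```python
import math
from collections import defaultdict


def relative_wavelength_spread(tofs, wavelengths, bin_width):
    """Return {bin index: sigma / mean} of the wavelengths in each time-of-flight bin."""
    groups = defaultdict(list)
    for t, wl in zip(tofs, wavelengths):
        groups[int(t // bin_width)].append(wl)
    spread = {}
    for index, values in sorted(groups.items()):
        mean = sum(values) / len(values)
        # Population variance, like the mean-based expression in the cell above
        var = sum((v - mean) ** 2 for v in values) / len(values)
        spread[index] = math.sqrt(var) / mean
    return spread


tofs = [10.0, 12.0, 14.0, 25.0, 27.0]
wavelengths = [2.0, 2.2, 1.8, 4.0, 4.0]
result = relative_wavelength_spread(tofs, wavelengths, bin_width=10.0)
print(result)
```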

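A neutron with wavelength $\lambda$ that arrives at time $T$ at a detector a distance $L$ from the source has speed $v = h / (m_n \lambda)$.
The time $t_0$ at which it effectively left the source therefore satisfies

$$\frac{T - t_0}{L} = \frac{m_n \lambda}{h} \quad \Longrightarrow \quad t_0 = T - \frac{m_n \lambda L}{h}.$$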

In [None]:
mu.plot()

In [None]:
mu.plot().save('wavelength.png')

In [None]:
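# Effective emission time t0 = T - m_n * lambda * L / h for each (Ltot, tof) bin,
# using the bin centres as T and the mean wavelength mu as lambda
t0 = sc.midpoints(mu.coords['tof']) - (
    mu * mu.coords['Ltot'] / sc.constants.h * sc.constants.m_n
).to(unit='us')
t0.plot()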
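
The last cell evaluates $t_0 = T - m_n \lambda L / h$ for every (distance, time-of-flight) bin, using the bin centres as $T$ and the mean wavelength $\mu$ as $\lambda$. For a single neutron, the arithmetic looks like this (hypothetical helper, SI constants):

```python
PLANCK = 6.62607015e-34  # J s
NEUTRON_MASS = 1.67492749804e-27  # kg


def emission_time_us(arrival_us, wavelength_angstrom, distance_m):
    """Effective emission time t0 = T - m_n * lambda * L / h, in microseconds."""
    flight_time_s = NEUTRON_MASS * wavelength_angstrom * 1.0e-10 * distance_m / PLANCK
    return arrival_us - flight_time_s * 1.0e6


# A 3.956 Angstrom neutron (~1000 m/s) detected 34 ms after the pulse start at 32 m
# left the source about 2 ms after the start of the pulse
t0 = emission_time_us(arrival_us=34_000.0, wavelength_angstrom=3.956, distance_m=32.0)
print(f"t0 = {t0:.1f} us")
```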