# DREAM WFM choppers with tof

Stitching wavelength-frame-multiplication (WFM) data using the chopper system of the DREAM instrument.

In [None]:
import numpy as np
import scipp as sc
import plopp as pp
import tof
%matplotlib widget

# Shorthand units used to build the physical quantities in the cells below.
Hz, deg, meter, AA = (sc.Unit(u) for u in ('Hz', 'deg', 'm', 'angstrom'))

## Chopper and detector setup

In [None]:
def _deg_array(values):
    """Return *values* as a 1-D array along ``cutout``, in degrees."""
    return sc.array(dims=['cutout'], values=values, unit='deg')


# Chopper cascade, ordered by distance from the source. Cutout centres and
# widths are angles in degrees; distances are in metres.
# NOTE(review): the names presumably stand for pulse-shaping (PSC1/PSC2),
# overlap (OC), band (BC) and T0 choppers, and the `-180`/`-90` phase offsets
# look like a conversion from the instrument's phase convention -- confirm.
choppers = [
    tof.Chopper(
        name="PSC1",
        frequency=14 * Hz,
        direction=tof.AntiClockwise,
        centers=_deg_array([0, 72, 86.4, 115.2, 172.8, 273.6, 288.0, 302.4]),
        widths=_deg_array([2.46, 3.02, 3.27, 3.27, 5.02, 3.93, 3.93, 2.46]),
        phase=(286 - 180) * deg,
        distance=6.145 * meter,
    ),
    tof.Chopper(
        name="PSC2",
        frequency=14 * Hz,
        direction=tof.Clockwise,
        centers=_deg_array([0, 28.8, 57.6, 144, 158.4, 216, 259.2, 316.8]),
        widths=_deg_array([2.46, 3.60, 3.60, 3.23, 3.27, 3.77, 3.94, 2.62]),
        phase=236 * deg,
        distance=6.155 * meter,
    ),
    tof.Chopper(
        name="OC",
        frequency=14 * Hz,
        direction=tof.AntiClockwise,
        centers=_deg_array([0.]),
        widths=_deg_array([27.6]),
        phase=(297 - 180 - 90) * deg,
        distance=6.174 * meter,
    ),
    tof.Chopper(
        name="BC",
        frequency=112 * Hz,
        direction=tof.AntiClockwise,
        centers=_deg_array([0., 180.]),
        widths=_deg_array([73.75, 73.75]),
        phase=(215 - 180) * deg,
        distance=9.78 * meter,
    ),
    tof.Chopper(
        name="T0",
        frequency=28 * Hz,
        direction=tof.AntiClockwise,
        centers=_deg_array([0.]),
        widths=_deg_array([314.9]),
        phase=(280 - 180) * deg,
        distance=13.05 * meter,
    ),
]

# A single detector; the distance is given as the sum of two path segments.
detectors = [tof.Detector(distance=(76.55 + 1.125) * meter, name='detector')]

## Create an ESS source and run the model

In [None]:
# ESS source with 1,000,000 neutrons; show its time/wavelength distribution.
source = tof.Source(neutrons=1_000_000, facility='ess')
source.plot()

In [None]:
# Combine the source, the chopper cascade and the detector into one model.
model = tof.Model(detectors=detectors, choppers=choppers, source=source)
model

In [None]:
# Run the simulation and display the summary of the results.
results = model.run()
results

In [None]:
# Time-distance diagram, also drawing a sample of blocked rays.
results.plot(blocked_rays=5_000)

In [None]:
# Squeeze the pulse dimension since we only have one pulse
events = results['detector'].data.squeeze()
# Remove the events that don't make it to the detector
events = events[~events.masks['blocked_by_others']]
# Histogram and plot
events.hist(wavelength=500, tof=500).plot(norm='log', grid=True)

In [None]:
binned = events.bin(tof=500)

# Total event weight in each time-of-flight bin (the weighting denominator).
bin_weight = binned.bins.sum()

# Weighted mean wavelength per tof bin; empty bins give 0/0 -> NaN.
mu = (binned.bins.data * binned.bins.coords['wavelength']).bins.sum() / bin_weight

# Per-event contributions to the weighted wavelength variance of each bin.
# This is still a *binned* array: summing the contents of each bin (as the
# `var.hist()` call further down does) yields the per-bin variance.
var = binned.bins.data * (binned.bins.coords['wavelength'] - mu) ** 2 / bin_weight

In [None]:
# Mean wavelength as a function of time-of-flight (NaN for empty bins).
mu.plot(grid=True)

In [None]:
# Reduce the per-event variance contributions to one value per tof bin, then
# take the square root to get the wavelength spread in each bin.
variance_per_bin = var.hist()
stddev = sc.sqrt(variance_per_bin)
stddev.plot(grid=True)

In [None]:
# Discard the mean wavelength wherever the spread in a bin exceeds 0.02 Å:
# such bins mix neutrons of different wavelengths (e.g. overlapping frames),
# so a single mean is not meaningful there.
spread_too_large = stddev.data > 0.02 * AA
mu.data = sc.where(spread_too_large, np.nan * AA, mu.data)

In [None]:
# Time-of-flight bin centres whose wavelength spread exceeds 0.02 Å, which is
# the same criterion used to blank out `mu`. Every other bin becomes NaN.
# `>` is used rather than "not `<`": comparisons with NaN are always False, so
# testing `stddev < threshold` kept empty bins (NaN spread) and stretched the
# shaded span over them.
tof_centers = sc.midpoints(stddev.coords['tof'])
a = sc.where(
    stddev.data > 0.02 * AA,
    tof_centers,
    # Use the coordinate's own unit instead of assuming microseconds.
    sc.scalar(np.nan, unit=tof_centers.unit),
)
a

In [None]:
fig = events.hist(wavelength=500, tof=500).plot(norm='log', grid=True)
# Shade the time-of-flight interval covered by the non-NaN entries of `a`.
span_start = a.nanmin().value
span_end = a.nanmax().value
fig.ax.axvspan(span_start, span_end, alpha=0.3)
fig

In [None]:
from scipp.scipy.interpolate import interp1d

# Lookup table: (masked) mean wavelength against tof bin centre. Outside the
# tabulated range the interpolator returns NaN instead of raising.
y = mu.copy()
y.coords['tof'] = sc.midpoints(mu.coords['tof'])
f = interp1d(y, 'tof', bounds_error=False)

# Assign every detected event the wavelength looked up from its tof.
event_tof = events.coords['tof'].rename_dims(event='tof')
wavs = f(event_tof)

# One unit-weight entry per event, carrying the WFM-derived wavelength.
wavelengths = sc.DataArray(
    data=sc.ones(sizes=wavs.sizes, unit='counts'),
    coords={'wavelength': wavs.data},
)
wavelengths = wavelengths.rename_dims(tof='event')
wavelengths

In [None]:
# Naive wavelength, ignoring the choppers: treat the full source-to-detector
# distance as travelled at constant speed over the measured time-of-flight,
# then apply the de Broglie relation lambda = h / (m_n * v).
naive = events.copy()
flight_path = detectors[0].distance
speed = flight_path / naive.coords['tof']
inverse_wavelength = speed * sc.constants.m_n / sc.constants.h
naive.coords['wavelength'] = sc.reciprocal(inverse_wavelength).to(unit='angstrom')

In [None]:
# Compare the naive, WFM-stitched and true (simulated) wavelength spectra.
# All three share one set of bin edges. Histogramming each with
# `wavelength=300` would give every curve its own range and bin width, so the
# bar heights would not be directly comparable.
wfm = wavelengths[~sc.isnan(wavelengths.coords['wavelength'])]
samples = {'naive': naive, 'wfm': wfm, 'original': events}

# Common range covering all datasets, evaluated in Å.
lo = min(
    da.coords['wavelength'].min().to(unit='angstrom').value
    for da in samples.values()
)
hi = max(
    da.coords['wavelength'].max().to(unit='angstrom').value
    for da in samples.values()
)
# Bins are half-open, so nudge the last edge up to keep events sitting
# exactly on the maximum.
edges = sc.linspace('wavelength', lo, np.nextafter(hi, np.inf), 301, unit='angstrom')

pp.plot(
    {
        # Match each dataset's coordinate unit so `hist` accepts the edges.
        name: da.hist(wavelength=edges.to(unit=da.coords['wavelength'].unit))
        for name, da in samples.items()
    }
)