In [None]:
%matplotlib widget
import scipp as sc

## Loading dataset

> The loader is not part of ``essimaging`` because the McStas dataset format has not been stabilized yet.

In [None]:
import scippnexus as snx
from typing import cast, NewType
from ess.reduce.nexus.types import FilePath


# Domain types (``NewType`` wrappers) and their default values used by the
# loader functions in this cell.

# Path of the event table inside the file. The trailing part of the group name
# (p_t_x_y_z_vx_vy_vz) lists the columns of the table.
_DataPath = NewType('_DataPath', str)
_DefaultDataPath = _DataPath(
    "entry1/data/transmission_event_signal_dat_list_p_t_x_y_z_vx_vy_vz/events"
)
# Passed as ``locking=`` to ``snx.File`` by the loaders.
_FileLock = NewType('_FileLock', bool)
"""Lock the file to prevent concurrent access."""
_DefaultFileLock = _FileLock(True)
# Return type of ``load_odin_simulation_data``.
OdinSimulationRawData = NewType('OdinSimulationRawData', sc.DataArray)
# Multiplied with the normalized probabilities (probability / max probability)
# in ``load_odin_simulation_data``.
ProbabilityToCountsScaleFactor = NewType('ProbabilityToCountsScaleFactor', sc.Variable)
"""Translate the probability to counts."""
DefaultProbabilityToCountsScaleFactor = ProbabilityToCountsScaleFactor(
    sc.scalar(1_000, unit='dimensionless')
)
# Detector extent: the defaults span -0.03 m to 0.03 m along both x and y.
DetectorStartX = NewType('DetectorStartX', sc.Variable)
"""Start of the detector in x direction."""
DefaultDetectorStartX = DetectorStartX(sc.scalar(-0.03, unit='m'))
DetectorStartY = NewType('DetectorStartY', sc.Variable)
"""Start of the detector in y direction."""
DefaultDetectorStartY = DetectorStartY(sc.scalar(-0.03, unit='m'))

DetectorEndX = NewType('DetectorEndX', sc.Variable)
"""End of the detector in x direction."""
DefaultDetectorEndX = DetectorEndX(sc.scalar(0.03, unit='m'))
DetectorEndY = NewType('DetectorEndY', sc.Variable)
"""End of the detector in y direction."""
DefaultDetectorEndY = DetectorEndY(sc.scalar(0.03, unit='m'))

# Index 0 is the number of pixels along x, index 1 the number along y
# (see ``_position_to_pixel_id``).
McStasManualResolution = NewType('McStasManualResolution', tuple)
"""Manual resolution for McStas data (how many pixels per axis x, y)"""
DefaultMcStasManualResolution = McStasManualResolution((1024, 1024))


def _nth_col_or_row_lookup(
    start: sc.Variable, stop: sc.Variable, resolution: int, dim: str
) -> sc.Lookup:
    """Build a lookup that maps a coordinate along ``dim`` to a column/row index.

    The range from ``start`` to ``stop`` is split into ``resolution`` equally
    sized bins (``resolution + 1`` edges); bin ``i`` carries the value ``i``.

    Parameters
    ----------
    start:
        Lower edge of the first bin. Its unit is used for the bin edges.
    stop:
        Upper edge of the last bin.
    resolution:
        Number of columns or rows (bins).
    dim:
        Name of the dimension and coordinate of the lookup table.

    Returns
    -------
    :
        Lookup that returns the bin index for a given coordinate value.
    """
    bin_edges = sc.linspace(
        dim, start=start, stop=stop, num=resolution + 1, unit=start.unit
    )
    bin_indices = sc.arange(dim=dim, start=0, stop=resolution, unit='dimensionless')
    return sc.lookup(sc.DataArray(data=bin_indices, coords={dim: bin_edges}), dim)


def _position_to_pixel_id(
    *,
    x_pos: sc.Variable,
    y_pos: sc.Variable,
    detector_start_x: DetectorStartX = DefaultDetectorStartX,
    detector_start_y: DetectorStartY = DefaultDetectorStartY,
    detector_end_x: DetectorEndX = DefaultDetectorEndX,
    detector_end_y: DetectorEndY = DefaultDetectorEndY,
    resolution: McStasManualResolution = DefaultMcStasManualResolution,
) -> sc.Variable:
    """Derive pixel ids from x/y positions on the detector.

    The detector area is divided into ``resolution[0]`` columns along x and
    ``resolution[1]`` rows along y. The pixel id is
    ``row * resolution[0] + column``, i.e. ids increase along x first.
    """
    n_columns = resolution[0]
    n_rows = resolution[1]
    column_lookup = _nth_col_or_row_lookup(
        detector_start_x, detector_end_x, n_columns, 'x'
    )
    row_lookup = _nth_col_or_row_lookup(detector_start_y, detector_end_y, n_rows, 'y')
    column_index = column_lookup[x_pos]
    row_index = row_lookup[y_pos]
    return row_index * n_columns + column_index


def load_odin_simulation_data(
    file_path: FilePath,
    _data_path: _DataPath = _DefaultDataPath,
    _file_lock: _FileLock = _DefaultFileLock,
    detector_start_x: DetectorStartX = DefaultDetectorStartX,
    detector_start_y: DetectorStartY = DefaultDetectorStartY,
    detector_end_x: DetectorEndX = DefaultDetectorEndX,
    detector_end_y: DetectorEndY = DefaultDetectorEndY,
    resolution: McStasManualResolution = DefaultMcStasManualResolution,
    probability_scale_factor: ProbabilityToCountsScaleFactor = (
        DefaultProbabilityToCountsScaleFactor
    ),
) -> OdinSimulationRawData:
    """Load the McStas event table and convert it into an event list with counts.

    Probabilities are normalized by their maximum, multiplied by
    ``probability_scale_factor`` and converted to ``int32`` counts. Each event
    gets a ``time_of_arrival`` (in microseconds) and a ``pixel_id`` derived from
    its x/y position and the given detector geometry.

    Parameters
    ----------
    file_path:
        Path of the file to open.
    _data_path:
        Path of the event table inside the file.
    _file_lock:
        Forwarded as ``locking`` to ``snx.File``.
    detector_start_x, detector_start_y, detector_end_x, detector_end_y:
        Extent of the detector used to compute pixel ids.
    resolution:
        Number of pixels along x and y.
    probability_scale_factor:
        Factor applied to the normalized probabilities.

    Returns
    -------
    :
        Data array of counts with event coordinates.
    """

    def _column(source, index: int, unit: str) -> sc.Variable:
        # Copy one column of the table and attach its (hardcoded) unit.
        column = cast(sc.Variable, source['dim_1', index].copy())
        column.unit = unit
        return column

    with snx.File(file_path, "r", locking=_file_lock) as f:
        # The name p_t_x_y_z_vx_vy_vz represents
        # probability, time of arrival, position(x, y, z) and velocity(vx, vy, vz).
        # The name also represents the order of each field in the table,
        # so column 0 along 'dim_1' is the probability.
        table = f[_data_path][()].rename_dims({'dim_0': 'event'})

        probabilities = _column(table, 0, 'dimensionless')
        time_of_arrival = _column(table, 1, 's')  # Hardcoded unit from the data.

        # Units of the positions are hardcoded from the data.
        positions = table['dim_1', 2:5]
        x_pos = _column(positions, 0, 'm')
        y_pos = _column(positions, 1, 'm')

        counts = (probabilities / probabilities.max()) * probability_scale_factor
        counts.unit = 'counts'

        pixel_id = _position_to_pixel_id(
            x_pos=x_pos,
            y_pos=y_pos,
            detector_start_x=detector_start_x,
            detector_start_y=detector_start_y,
            detector_end_x=detector_end_x,
            detector_end_y=detector_end_y,
            resolution=resolution,
        )

        # Sample and source positions are hardcoded from the data.
        event_coords = {
            'time_of_arrival': time_of_arrival.to(unit='us'),
            'sample_position': sc.vector([0.0, 0.0, 60.5], unit='m'),
            'source_position': sc.vector([0.0, 0.0, 0.0], unit="m"),
            'pixel_id': pixel_id,
        }
        events = sc.DataArray(
            data=counts.astype(sc.DType.int32),
            coords=event_coords,
        )

        return OdinSimulationRawData(events)


McStasVelocities = NewType('McStasVelocities', sc.DataGroup)


def load_velocities(
    file_path: FilePath,
    _data_path: _DataPath = _DefaultDataPath,
    _file_lock: _FileLock = _DefaultFileLock,
) -> McStasVelocities:
    """Load the velocity components of every event from the McStas event table.

    Columns 5 to 7 along ``dim_1`` are read as ``vx``, ``vy`` and ``vz`` and
    given the unit ``m/s``.

    Returns
    -------
    :
        Data group with the entries ``vx``, ``vy`` and ``vz``.
    """
    with snx.File(file_path, "r", locking=_file_lock) as f:
        table = f[_data_path][()].rename_dims({'dim_0': 'event'})
        velocity_columns = table['dim_1', 5:8]
        components = {}
        for offset, name in enumerate(('vx', 'vy', 'vz')):
            component = cast(sc.Variable, velocity_columns['dim_1', offset].copy())
            component.unit = 'm/s'
            components[name] = component
        # Add special tags if you want to use them as coordinates
        # for example, da.coords['vx_MC'] = vx
        # to distinguish them from the measurement
        return McStasVelocities(sc.DataGroup(**components))


In [None]:
from ess.imaging.data import get_mcstas_ob_images_path, get_mcstas_sample_images_path

# Paths come from the ``ess.imaging.data`` helpers
# ('ob' presumably stands for open beam — confirm).
ob_file_path = FilePath(get_mcstas_ob_images_path())
sample_file_path = FilePath(get_mcstas_sample_images_path())
# Load both files with the loader's default detector geometry and scale factor.
ob_da = load_odin_simulation_data(ob_file_path)
sample_da = load_odin_simulation_data(sample_file_path)
sample_da

In [None]:
def _pixel_ids_to_x(
    *,
    pixel_id: sc.Variable,
    resolution: McStasManualResolution = DefaultMcStasManualResolution,
    detector_start_x: DetectorStartX = DefaultDetectorStartX,
    detector_end_x: DetectorEndX = DefaultDetectorEndX,
) -> sc.Variable:
    """Return the x coordinate of the center of each pixel.

    The column index is ``pixel_id % resolution[0]`` and the pixel width is the
    detector extent along x divided by ``resolution[0]``.
    """
    n_columns = resolution[0]
    column_index = pixel_id % n_columns
    pixel_width = (detector_end_x - detector_start_x) / n_columns
    left_edge = detector_start_x + column_index * pixel_width
    # Shift by half a pixel to get the center of the pixel.
    return left_edge + pixel_width / 2


def _pixel_ids_to_y(
    *,
    pixel_id: sc.Variable,
    resolution: McStasManualResolution = DefaultMcStasManualResolution,
    detector_start_y: DetectorStartY = DefaultDetectorStartY,
    detector_end_y: DetectorEndY = DefaultDetectorEndY,
) -> sc.Variable:
    """Return the y coordinate of the center of each pixel.

    The row index is ``pixel_id // resolution[0]`` (ids increase along x first)
    and the pixel height is the detector extent along y divided by
    ``resolution[1]``.
    """
    row_index = pixel_id // resolution[0]
    pixel_height = (detector_end_y - detector_start_y) / resolution[1]
    lower_edge = detector_start_y + row_index * pixel_height
    # Shift by half a pixel to get the center of the pixel.
    return lower_edge + pixel_height / 2


def _pixel_ids_to_position(
    *, x: sc.Variable, y: sc.Variable, z_pos: sc.Variable
) -> sc.Variable:
    """Combine x/y coordinates and a constant z into position vectors.

    Parameters
    ----------
    x:
        x coordinate of every element.
    y:
        y coordinate of every element.
    z_pos:
        z coordinate shared by all elements. Must have the same unit as ``x``
        because it is added to a zero-filled copy of ``x``.

    Returns
    -------
    :
        Variable of 3-vectors ``(x, y, z_pos)`` with the dimensions and the unit
        of the inputs.
    """
    z = sc.zeros_like(x) + z_pos
    # ``as_vectors`` keeps the dimension names and the unit of the inputs, so
    # neither the dimension name nor the unit needs to be hardcoded, and there
    # is no round trip through raw numpy values.
    return sc.spatial.as_vectors(x, y, z)


In [None]:
import scipp as sc
from scippneutron.conversion import graph


# Coordinate transformation graph: the scippneutron beamline and kinematic
# graphs, extended with entries that derive the pixel geometry from the
# pixel ids.
plane_graph = {
    **graph.beamline.beamline(False),
    **graph.tof.kinematic("tof"),
    # TODO: Replace this with actual WFM stitching method
    'tof': lambda time_of_arrival: time_of_arrival,
    'x': lambda pixel_id: _pixel_ids_to_x(pixel_id=pixel_id),
    'y': lambda pixel_id: _pixel_ids_to_y(pixel_id=pixel_id),
    'position': lambda x, y: _pixel_ids_to_position(
        x=x, y=y, z_pos=sc.scalar(0.0, unit='m')
    ),
}

sc.show_graph(plane_graph, simplified=True)

In [None]:
# Compute wavelength and pixel positions (position, x, y) from the graph.
# NOTE(review): this comment used to say "We want to keep all time_of_flight,
# tof and wavelength", but ``keep_intermediate=False`` asks scipp to drop
# intermediate results — confirm which behaviour is intended.

sample_da = sample_da.transform_coords(
    ["wavelength", "position", "x", "y"], graph=plane_graph, keep_intermediate=False
)
ob_da = ob_da.transform_coords(
    ["wavelength", "position", "x", "y"], graph=plane_graph, keep_intermediate=False
)

sample_da

In [None]:
# Histogram all sample events into 100 time-of-arrival bins and plot the spectrum.
sample_da.hist(time_of_arrival=100).plot(
    grid=True, title="Time of Arrival Spectrum of Fe Sample"
)

In [None]:
# One group per detector pixel. Pixel ids are ``row * n_x + column``, so the
# number of ids is the product of the resolution along x and y. Derive it from
# the resolution the data was loaded with instead of repeating the numbers.
_n_pixels = DefaultMcStasManualResolution[0] * DefaultMcStasManualResolution[1]
all_pixel_ids = sc.arange('pixel_id', 0, _n_pixels, unit='dimensionless')
grouped_sample = sample_da.group(all_pixel_ids).bin(time_of_arrival=100)
# Reuse the time-of-arrival bin edges of the sample so that both arrays share
# the same binning.
grouped_ob = ob_da.group(all_pixel_ids).bin(
    time_of_arrival=grouped_sample.coords['time_of_arrival']
)

In [None]:
# Attach the pixel-center x/y coordinates to both grouped arrays.
for _grouped in (grouped_sample, grouped_ob):
    _ids = _grouped.coords['pixel_id']
    _grouped.coords['x'] = _pixel_ids_to_x(pixel_id=_ids)
    _grouped.coords['y'] = _pixel_ids_to_y(pixel_id=_ids)

# One count is added to the open-beam sum before dividing
# (presumably to avoid division by zero in empty bins — confirm).
_ob_counts = grouped_ob.bins.sum() + sc.scalar(1, unit='counts')
normalized_sample = grouped_sample.bins.sum() / _ob_counts
normalized_sample

## Region of Interest

In [None]:
import plopp as pp
from plopp.widgets import HBar
from plopp.widgets.tools import ToggleTool
from mpltoolbox import Rectangles


class RectInfo(sc.DataGroup):
    """Limits of a rectangle, stored as a data group.

    The group has the entries ``min_x``, ``max_x``, ``min_y`` and ``max_y``,
    each of which is also available as a read-only property.
    """

    def __init__(
        self,
        min_x: sc.Variable,
        max_x: sc.Variable,
        min_y: sc.Variable,
        max_y: sc.Variable,
    ):
        limits = {'min_x': min_x, 'max_x': max_x, 'min_y': min_y, 'max_y': max_y}
        super().__init__(limits)

    @property
    def min_x(self) -> sc.Variable:
        """Lower limit along x."""
        return self['min_x']

    @property
    def max_x(self) -> sc.Variable:
        """Upper limit along x."""
        return self['max_x']

    @property
    def min_y(self) -> sc.Variable:
        """Lower limit along y."""
        return self['min_y']

    @property
    def max_y(self) -> sc.Variable:
        """Upper limit along y."""
        return self['max_y']


def _get_rect_info(artist, figure) -> RectInfo:
    """
    Convert the raw rectangle info to a ``RectInfo`` object.

    The result holds the minimum and maximum of the rectangle along each axis,
    as values with the units of the figure canvas.

    NOTE(review): ``artist.xy[1]``/``artist.height`` are mapped to 'x' and
    ``artist.xy[0]``/``artist.width`` to 'y'. Presumably this is because the
    image is drawn with the 'y' dimension on the horizontal axis — confirm.
    """
    canvas_units = figure.canvas.units
    first_anchor = artist.xy[0]
    second_anchor = artist.xy[1]
    x_limits = sc.array(
        dims=['x'],
        values=[second_anchor, second_anchor + artist.height],
        unit=canvas_units['x'],
    )
    y_limits = sc.array(
        dims=['y'],
        values=[first_anchor, first_anchor + artist.width],
        unit=canvas_units['y'],
    )
    # min/max make the result independent of the direction the rectangle was drawn.
    return RectInfo(
        min_x=x_limits.min(),
        max_x=x_limits.max(),
        min_y=y_limits.min(),
        max_y=y_limits.max(),
    )


class MultiRectangleTool(ToggleTool):
    """Toggle button that lets the user draw several rectangles on a figure.

    The limits of all rectangles are published as a list of ``RectInfo``
    through a ``pp.Node`` that is added as a parent of ``destination``.

    Parameters
    ----------
    figure:
        Figure whose axes the rectangles are drawn on.
    destination:
        Node that receives the list of rectangles as an additional parent.
    value:
        Initial state of the toggle button.
    **kwargs:
        ``autostart`` (default ``False``) is passed to the ``Rectangles``
        drawing tool. All other keyword arguments are forwarded to
        ``ToggleTool``.
    """

    def __init__(self, figure, destination, value: bool = False, **kwargs):
        # ``autostart`` only configures the drawing tool, so take it out of
        # ``kwargs`` before the remaining options are forwarded to ``ToggleTool``.
        autostart = kwargs.pop('autostart', False)
        super().__init__(callback=self.start_stop, value=value, **kwargs)

        self._figure = figure
        self._tool = Rectangles(ax=self._figure.ax, autostart=autostart)
        self.rectangles = []
        self._draw_node = pp.Node(lambda: self.rectangles)  # Empty rectangle info.
        self._destination = destination
        self._destination.add_parents(self._draw_node)
        # Refresh the published rectangles whenever one is created or modified.
        self._tool.on_create(self.update_node)
        self._tool.on_change(self.update_node)

    def update_node(self, _):
        """
        Recompute the rectangle limits and notify the dependent nodes.
        """
        self.rectangles = [
            _get_rect_info(artist=artist, figure=self._figure)
            for artist in self._tool.children
        ]
        self._draw_node.func = lambda: self.rectangles
        self._draw_node.notify_children("")  # Empty message.

    def start_stop(self):
        """
        Toggle start or stop of the tool.
        """
        if self.value:
            self._tool.start()
        else:
            self._tool.stop()


In [None]:
# Prepare the data to be visualized
# NOTE(review): the 1024 x 1024 bins presumably mirror
# ``DefaultMcStasManualResolution`` — confirm and keep them in sync.
_binned_sample = normalized_sample.hist(x=1024, y=1024)
_binned_sample

In [None]:
# Prepare the merge method
from functools import reduce


def _show_histogram(da: sc.DataArray, rect_infos: list[RectInfo]) -> sc.DataArray:
    """Sum ``da`` over x and y, restricted to the selected rectangles.

    Parameters
    ----------
    da:
        Data with the dimensions 'x' and 'y'. The slicing below assumes that
        its 'x' and 'y' coordinates are bin edges, i.e. one element longer
        than the data.
    rect_infos:
        Rectangles that select the region of interest. If the list is empty,
        the whole image is summed.

    Returns
    -------
    :
        ``da`` summed over 'x' and 'y'. Everything outside of all rectangles is
        masked before summing.
    """
    if len(rect_infos) > 0:
        # Use the coordinates of the array that is being masked rather than
        # those of a notebook global, so the mask always matches ``da``.
        x = da.coords['x']
        y = da.coords['y']
        outside_each = [
            (x > rect_info.max_x)
            | (x < rect_info.min_x)
            | (y > rect_info.max_y)
            | (y < rect_info.min_y)
            for rect_info in rect_infos
        ]
        # A point is masked only if it lies outside of every rectangle.
        outside_all = reduce(lambda a, b: a & b, outside_each)
        selected_sample_region = da.copy(deep=False)
        # Drop the last edge along each axis so the mask has the shape of the data.
        selected_sample_region.masks['roi'] = outside_all['x', :-1]['y', :-1]
        return selected_sample_region.sum('x').sum('y')
    else:
        return da.sum('x').sum('y')


# ``merge_node`` calls ``_show_histogram``; the binned sample is its first parent.
# The rectangle list is added as a further parent by ``MultiRectangleTool``.
original_data_node = pp.Node(_binned_sample)
merge_node = pp.Node(_show_histogram)
merge_node.add_parents(original_data_node)

In [None]:
# 2D image (summed over time of arrival) on which the rectangles are drawn.
data_node = pp.Node(_binned_sample.sum('time_of_arrival'))
f2d = pp.imagefigure(
    data_node, norm='log', title="Region of Interest Selection", cbar=True
)

# The tool adds its rectangle node as a parent of ``merge_node``, so the 1D
# figure below is fed by ``_show_histogram`` with the current rectangles.
r = MultiRectangleTool(figure=f2d, destination=merge_node, icon='vector-square')
f1d = pp.linefigure(
    merge_node, title="Time of Arrival Spectrum in ROI", grid=True, vmin=0
)

# Register the tool in the toolbar of the image and show both figures side by side.
f2d.toolbar['roi'] = r
box = HBar([f2d, f1d])

In [None]:
box

If you want to use Qt instead of the notebook backend, you can use the snippet below.

```python
# ``%matplotlib qt`` should be at the top of the notebook.
%matplotlib qt
import matplotlib.pyplot as plt

fig, ax = plt.subplots(ncols=2)
data_node = pp.Node(_binned_sample.sum('time_of_arrival'))
f2d = pp.imagefigure(
    data_node, norm='log', title="Region of Interest Selection", cbar=True, ax=ax[0]
)

original_data_node = pp.Node(_binned_sample)
merge_node = pp.Node(_show_histogram)
merge_node.add_parents(original_data_node)

r = MultiRectangleTool(figure=f2d, destination=merge_node, icon='vector-square', autostart=True)
f1d = pp.linefigure(
    merge_node, title="Time of Arrival Spectrum in ROI", grid=True, vmin=0, ax=ax[1]
)

fig.show()  # Will start the QT application.
```

In [None]:
# Simulate mouse clicks on the drawing tool so that the notebook produces
# rectangles without user interaction.
# NOTE(review): presumably each pair of clicks creates one rectangle (two in
# total) — confirm against mpltoolbox.
# ``_tool`` is a private attribute of ``MultiRectangleTool``.
r._tool.click(-0.028, -0.028)
r._tool.click(0.0, 0.01)
r._tool.click(-0.005, 0.028)
r._tool.click(0.028, -0.005)
f1d.focus()
# Collect the limits of all drawn rectangles in one data group for display.
rois = sc.DataGroup(
    {f'rectangle_{i}': rectangle for i, rectangle in enumerate(r.rectangles)}
)
rois

## Choppers

> Choppers can be retrieved from the NeXus file automatically. <br>
> We are hardcoding them here for the simulation data reduction. <br>
> We may automate this once the McStas NeXus format is stabilized.

In [None]:
from scippneutron.chopper import DiskChopper

# Collect choppers
## WFM choppers
# Both WFM choppers share the frequency and beam angle defined here.
# ``beam_angle`` is reused by the chopper definitions in the following cells.
wfm_frequency = sc.scalar(value=56.0, unit='Hz')
beam_angle = sc.scalar(value=0.0, unit='deg')

# Each WFM chopper is defined with six slits:
# ``slit_begin[i]`` and ``slit_end[i]`` belong to slit ``i``.
WFMC_1 = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 6.850000), unit='m'),
    frequency=wfm_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=93.244, unit='deg'),
    slit_begin=sc.array(
        dims=['slit'],
        values=[-1.9419, 49.5756, 98.9315, 146.2165, 191.5176, 234.9179],
        unit='deg',
    ),
    slit_end=sc.array(
        dims=['slit'],
        values=[1.9419, 55.7157, 107.2332, 156.5891, 203.8741, 249.1752],
        unit='deg',
    ),
)

WFMC_2 = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 7.150000), unit='m'),
    frequency=wfm_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=152.029879, unit='deg'),
    slit_begin=sc.array(
        dims=['slit'],
        values=[-1.9419, 51.8318, 103.3493, 152.7052, 199.9903, 245.2914],
        unit='deg',
    ),
    slit_end=sc.array(
        dims=['slit'],
        values=[1.9419, 57.9719, 111.6510, 163.0778, 212.3468, 259.5486],
        unit='deg',
    ),
)

WFMC_2  # Display one of the WFM choppers

In [None]:
## FOC choppers
# All five FOC choppers share ``foc_frequency`` and are defined with six slits each.
# ``beam_angle`` is defined in the WFM chopper cell.
foc_frequency = sc.scalar(value=42.0, unit='Hz')

F01 = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 8.400000), unit='m'),
    frequency=foc_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=81.303297, unit='deg'),
    slit_begin=sc.array(
        dims=['slit'],
        values=[-5.1362, 42.5536, 88.2425, 132.0144, 173.9497, 216.7867],
        unit='deg',
    ),
    slit_end=sc.array(
        dims=['slit'],
        values=[5.1362, 54.2095, 101.2237, 146.2653, 189.417, 230.7582],
        unit='deg',
    ),
)

F02 = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 12.200000), unit='m'),
    frequency=foc_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=107.013442, unit='deg'),
    slit_begin=sc.array(
        dims=['slit'],
        values=[-16.3227, 53.7401, 120.8633, 185.1701, 246.7787, 307.0165],
        unit='deg',
    ),
    slit_end=sc.array(
        dims=['slit'],
        values=[16.3227, 86.8303, 154.3794, 218.7551, 280.7508, 340.3188],
        unit='deg',
    ),
)

F03 = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 17.000000), unit='m'),
    frequency=foc_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=158.294923, unit='deg'),
    slit_begin=sc.array(
        dims=['slit'],
        values=[-20.302, 45.247, 108.0457, 168.2095, 225.8489, 282.2199],
        unit='deg',
    ),
    slit_end=sc.array(
        dims=['slit'],
        values=[20.302, 85.357, 147.6824, 207.3927, 264.5977, 319.4024],
        unit='deg',
    ),
)

F04 = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 23.690000), unit='m'),
    frequency=foc_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=61.584, unit='deg'),
    slit_begin=sc.array(
        dims=['slit'],
        values=[-16.7157, 29.1882, 73.1661, 115.2988, 155.6636, 195.5254],
        unit='deg',
    ),
    slit_end=sc.array(
        dims=['slit'],
        values=[16.7157, 61.8217, 105.0352, 146.4355, 186.0987, 224.0978],
        unit='deg',
    ),
)

F05 = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 33.000000), unit='m'),
    frequency=foc_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=145.973844, unit='deg'),
    slit_begin=sc.array(
        dims=['slit'],
        values=[-25.8514, 38.3239, 99.8064, 160.1254, 217.4321, 272.5426],
        unit='deg',
    ),
    slit_end=sc.array(
        dims=['slit'],
        values=[25.8514, 88.4621, 147.4729, 204.0245, 257.7603, 313.7139],
        unit='deg',
    ),
)

F05  # Display one of the FOC choppers

In [None]:
## BP choppers
# Both BP choppers share ``bp_frequency`` and have a single slit.
# Unlike the WFM and FOC definitions, ``radius`` and ``slit_height`` are given here.
# ``beam_angle`` is defined in the WFM chopper cell.
bp_frequency = sc.scalar(value=7.0, unit='Hz')
BP01 = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 8.450000), unit='m'),
    frequency=bp_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=31.079597, unit='deg'),
    slit_begin=sc.array(dims=['slit'], values=[-23.6029], unit='deg'),
    slit_end=sc.array(dims=['slit'], values=[23.6029], unit='deg'),
    radius=sc.scalar(value=0.5, unit='m'),
    slit_height=sc.scalar(value=0.075000, unit='m'),
)

BP02 = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 12.250000), unit='m'),
    frequency=bp_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=44.223912, unit='deg'),
    slit_begin=sc.array(dims=['slit'], values=[-34.4663], unit='deg'),
    slit_end=sc.array(dims=['slit'], values=[34.4663], unit='deg'),
    radius=sc.scalar(value=0.5, unit='m'),
    slit_height=sc.scalar(value=0.080000, unit='m'),
)

BP02  # Display one of the BP choppers

In [None]:
## T0 choppers
# Both T0 choppers share ``t0_frequency`` and have a single slit.
# ``beam_angle`` is defined in the WFM chopper cell.
t0_frequency = sc.scalar(value=14.0, unit='Hz')

TALPHA = DiskChopper(
    axle_position=sc.vector(value=(0.026000, 0.000000, 13.500000), unit='m'),
    frequency=t0_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=179.672400, unit='deg'),
    slit_begin=sc.array(dims=['slit'], values=[-167.8986], unit='deg'),
    slit_end=sc.array(dims=['slit'], values=[167.8986], unit='deg'),
    radius=sc.scalar(value=0.3, unit='m'),
    slit_height=sc.scalar(value=0.075000, unit='m'),
)

# NOTE(review): the axle position of TBETA, (0, 0, 0.2) m, does not follow the
# pattern of the other choppers (x = 0.026 m) — confirm that it is intended.
TBETA = DiskChopper(
    axle_position=sc.vector(value=(0.000000, 0.000000, 0.200000), unit='m'),
    frequency=t0_frequency,
    beam_angle=beam_angle,
    phase=sc.scalar(value=179.672, unit='deg'),
    slit_begin=sc.array(dims=['slit'], values=[-167.8986], unit='deg'),
    slit_end=sc.array(dims=['slit'], values=[167.8986], unit='deg'),
    radius=sc.scalar(value=0.3, unit='m'),
    slit_height=sc.scalar(value=0.075000, unit='m'),
)

TBETA  # Display one of the T0 choppers

## WFM Stitching
