In [None]:
from copy import deepcopy
from typing import Any, NewType

import scipp as sc
import sciline
from ess import bifrost
from ess.bifrost.data import (
    simulated_elastic_incoherent_with_phonon,
    tof_lookup_table_simulation,
)
# The star import supplies the pipeline key types used below (Filename, SampleRun,
# NeXusData, EnergyData, InstrumentAngles, RunType, ...). It stays before the
# explicit imports that follow so those still win on any name clash.
from ess.spectroscopy.types import *
import scippnexus as snx
from ess.reduce.streaming import Accumulator, EternalAccumulator, StreamProcessor

In [None]:
# Pipeline parameters for the 2-D Q cut: one projection direction and one set of
# bin edges per cut axis.
CutAxis1 = NewType("CutAxis1", sc.Variable)
CutBins1 = NewType("CutBins1", sc.Variable)
CutAxis2 = NewType("CutAxis2", sc.Variable)
CutBins2 = NewType("CutBins2", sc.Variable)

class CutData(sciline.Scope[RunType, sc.DataArray], sc.DataArray):
    """Events histogrammed along their projections ``Q1`` (onto ``CutAxis1``) and ``Q2`` (onto ``CutAxis2``)."""

def cut(
    data: EnergyData[RunType],
    *,
    axis_1: CutAxis1,
    axis_2: CutAxis2,
    bins_1: CutBins1,
    bins_2: CutBins2,
) -> CutData[RunType]:
    """Histogram events on a 2-D plane in momentum-transfer space.

    All bins of ``data`` are concatenated into one. Each event's
    ``sample_table_momentum_transfer`` vector is then projected onto ``axis_1``
    and ``axis_2`` with a dot product, so the axes are used exactly as given and
    are not normalised. The projections become the event coords ``Q1`` and
    ``Q2``, which are histogrammed with the edges ``bins_1`` / ``bins_2``.
    """
    events = data.bins.concat()
    momentum = events.bins.coords['sample_table_momentum_transfer']
    projected = events.bins.assign_coords(
        {
            "Q1": sc.dot(axis_1, momentum),
            "Q2": sc.dot(axis_2, momentum),
        }
    )
    histogram = projected.hist(Q2=bins_2, Q1=bins_1)
    return CutData[RunType](histogram)

In [None]:
# Fetch the simulated run and list the NXdetector groups it contains.
fname = simulated_elastic_incoherent_with_phonon()
with snx.File(fname) as nexus_file:
    all_detector_names = list(nexus_file['entry/instrument'][snx.NXdetector])
# Keep only the first two detectors (presumably to keep this example fast).
detector_names = all_detector_names[:2]

In [None]:
# Build the BIFROST reduction workflow for the selected detectors.
workflow = bifrost.BifrostWorkflow(detector_names)
# Reuse `fname` from the previous cell. The detector names were read from this
# file, so the workflow is guaranteed to reduce that same file, and the data
# helper is not called a second time.
workflow[Filename[SampleRun]] = fname
workflow[TimeOfFlightLookupTableFilename] = tof_lookup_table_simulation()
workflow[PreopenNeXusFile] = PreopenNeXusFile(True)

# Add the 2-D cut as an extra provider and set its parameters.
workflow.insert(cut)
# Project onto the first and third components of the momentum-transfer vector.
workflow[CutAxis1] = sc.vector([1, 0, 0])
workflow[CutAxis2] = sc.vector([0, 0, 1])
# 200 bin edges per axis -> 199 bins over [-1.25, 1.25] 1/Å.
workflow[CutBins1] = sc.linspace('Q1', -1.25, 1.25, 200, unit='1/Å')
workflow[CutBins2] = sc.linspace('Q2', -1.25, 1.25, 200, unit='1/Å')

In [None]:
# Show the task graph that produces the Q cut, laid out left-to-right (graphviz
# "rankdir"). The returned graph is this cell's displayed output.
workflow.visualize(CutData[SampleRun], graph_attr={"rankdir": "LR"}, compact=True)

In [None]:
# Compute the raw detector events and the instrument angles once. The streaming
# demo below re-feeds them group by group.
scheduler = sciline.scheduler.NaiveScheduler()
base_keys = [NeXusData[snx.NXdetector, SampleRun], InstrumentAngles[SampleRun]]
base_results = workflow.compute(base_keys, scheduler=scheduler)
base_data = base_results[NeXusData[snx.NXdetector, SampleRun]]
angles = base_results[InstrumentAngles[SampleRun]]

In [None]:
# This is similar to `group_by_rotation`, but it keeps `event_time_zero` as both
# a coord and a dim, so each element of `angle_groups` has the layout of NeXusData.
a3_lookup = sc.lookup(angles['a3'], 'time')
a4_lookup = sc.lookup(angles['a4'], 'time')
# transform_coords matches lambda parameter names to coordinate names, so the
# parameter must be called `event_time_zero`.
angle_graph = {
    'a3': lambda event_time_zero: a3_lookup[event_time_zero],
    'a4': lambda event_time_zero: a4_lookup[event_time_zero],
}
# Broadcast the per-pulse time onto every event so angles can be evaluated per event.
events_with_time = base_data.bins.assign_coords(
    {'event_time_zero': sc.bins_like(base_data.data, base_data.coords['event_time_zero'])}
)
grouped = events_with_time.transform_coords(('a3', 'a4'), graph=angle_graph).group('a3', 'a4')
# Only a single a4 value is supported. Fail loudly rather than silently dropping
# the events recorded at the other a4 positions (only index 0 is kept below).
n_a4 = grouped.sizes['a4']
assert n_a4 == 1, f"expected exactly one a4 value, got {n_a4}"
angle_groups = [
    grouped['a3', a3_value]['a4', 0].group('event_time_zero')
    for a3_value in grouped.coords['a3']
]

In [None]:
from copy import deepcopy


class BinAccumulator(Accumulator[sc.DataArray]):
    """Accumulator that appends the bin contents of every pushed binned DataArray.

    The stored value grows by concatenating each new push into its bins, so the
    accumulated result holds all events pushed since the last ``clear``.
    """

    def __init__(self, **kwargs: Any) -> None:
        # Passes ``preprocess=None`` to the base class; other kwargs go through unchanged.
        super().__init__(preprocess=None, **kwargs)
        self._value = None

    @property
    def is_empty(self) -> bool:
        """True until the first push, and again after ``clear``."""
        return self._value is None

    def _get_value(self):
        # Return a copy so callers cannot mutate the accumulated state.
        return deepcopy(self._value)

    def _do_push(self, value) -> None:
        if self._value is not None:
            # Append the new events to the existing bins, writing into the stored value.
            self._value.bins.concatenate(value, out=self._value)
            return
        # First push: keep a private copy of the caller's data.
        self._value = deepcopy(value)

    def clear(self) -> None:
        """Clear the accumulated value."""
        self._value = None


# Accumulators for the stream processor. Only CutData is accumulated; the
# disabled entry shows where `BinAccumulator` would be plugged in.
cut_accumulators = {
    # NOTE(review): currently disabled; BinAccumulator is defined above but unused.
    # DataGroupedByRotation[SampleRun]: BinAccumulator(),
    CutData[SampleRun]: EternalAccumulator(),
}

# Detector events are pushed as dynamic data, instrument angles are set as context,
# and CutData is the target collected by `finalize`.
sp = StreamProcessor(
    workflow,
    dynamic_keys=(NeXusData[snx.NXdetector, SampleRun],),
    context_keys=(InstrumentAngles[SampleRun],),
    target_keys=(CutData[SampleRun],),
    accumulators=cut_accumulators,
)

In [None]:
# Simulate streaming: push one angle group at a time into the stream processor.
for group in angle_groups:
    # Use a distinct name. Rebinding the module-level `angles` would make the
    # angle-grouping cell use this group's scalar angles if it were re-run later.
    group_angles = sc.DataGroup(
        a3=sc.DataArray(group.coords['a3']),
        a4=sc.DataArray(group.coords['a4']),
    )
    events = group.drop_coords(['a3', 'a4'])  # NeXusData does not have these coords

    sp.set_context({InstrumentAngles[SampleRun]: group_angles})
    sp.accumulate({NeXusData[snx.NXdetector, SampleRun]: events})

In [None]:
# Collect the accumulated cut after all groups have been pushed. This does not
# overwrite `base_results` / `results` from the earlier compute cell.
data = sp.finalize()[CutData[SampleRun]]

In [None]:
# Display the accumulated 2-D cut (rich repr as the cell output).
data

In [None]:
# Plot the cut with a logarithmic colour scale. The returned figure is the cell output.
data.plot(norm='log')