In [None]:
from typing import NewType
from dataclasses import dataclass
from collections.abc import Callable

import scipp as sc
import sciline
import matplotlib.pyplot as plt
import time
from ess import bifrost
from ess.bifrost.data import (
    simulated_elastic_incoherent_with_phonon,
    tof_lookup_table_simulation
)
from ess.spectroscopy.types import *
import scippnexus as snx
from ess.reduce.streaming import StreamProcessor,Accumulator,EternalAccumulator

In [None]:
# Use the interactive ipympl backend for all figures in this notebook.
%matplotlib widget

In [None]:
@dataclass(frozen=True, kw_only=True, slots=True)
class CutAxis:
    """One axis of a 2-D cut: how to compute a coordinate and how to bin it.

    Instances are consumed by ``cut``, which adds ``fn`` to a ``transform_coords``
    graph under the name ``output`` and histograms the result with ``bins``.
    """

    # Name of the coordinate produced by ``fn``; also the histogram dimension.
    output: str
    # Graph function for ``transform_coords``. Its *parameter names* select the
    # input coordinates, so they must match existing coord names.
    fn: Callable[..., sc.Variable]
    # Bin edges used when histogramming ``output``.
    bins: sc.Variable

    @classmethod
    def from_q_vector(cls, output: str, vec: sc.Variable, bins: sc.Variable) -> 'CutAxis':
        """Create an axis that projects the momentum transfer onto a direction.

        Parameters
        ----------
        output:
            Name of the projected coordinate / histogram dimension.
        vec:
            Direction to project onto. It is normalized here, so only its
            direction matters.
        bins:
            Bin edges for the projected coordinate.

        Raises
        ------
        ValueError
            If ``vec`` has zero length (its direction is undefined).
        """
        norm = sc.norm(vec)
        # Dividing by a zero norm would silently produce NaN projections.
        if sc.any(norm == sc.zeros_like(norm)).value:
            raise ValueError(f"Cannot build cut axis '{output}' from a zero-length vector.")
        direction = vec / norm
        return cls(
            output=output,
            # The parameter name selects the momentum-transfer coordinate; do not rename.
            fn=lambda sample_table_momentum_transfer: sc.dot(direction, sample_table_momentum_transfer),
            bins=bins,
        )

# Two distinct parameter types so that a sciline workflow can hold two separate
# `CutAxis` values, one per axis of the cut.
CutAxis1 = NewType('CutAxis1', CutAxis)
CutAxis2 = NewType('CutAxis2', CutAxis)

# Result type of `cut`: a 2-D histogram, generic over the run type.
class CutData(sciline.Scope[RunType, sc.DataArray], sc.DataArray): ...

def cut(data: EnergyData[RunType], *, axis_1: CutAxis1, axis_2: CutAxis2) -> CutData[RunType]:
    """Histogram the events in ``data`` onto the grid spanned by two cut axes.

    All event bins are merged with ``bins.concat()``. The two axis coordinates are
    then computed via ``transform_coords`` using each axis' ``fn``, and every other
    dense coordinate is dropped. Finally, the events are histogrammed with bin
    edges given in the order (``axis_2``, ``axis_1``).
    """
    targets = {axis_1.output, axis_2.output}
    coord_graph = {axis.output: axis.fn for axis in (axis_1, axis_2)}

    merged_events = data.bins.concat()
    projected = merged_events.transform_coords(targets, graph=coord_graph, keep_inputs=False)

    # Keep only the two cut coordinates on the (dense) outer data array.
    leftover_coords = [name for name in projected.coords if name not in targets]
    projected = projected.drop_coords(leftover_coords)

    edges = {axis_2.output: axis_2.bins, axis_1.output: axis_1.bins}
    return CutData[RunType](projected.hist(edges))

In [None]:
fname = simulated_elastic_incoherent_with_phonon()
with snx.File(fname) as nexus_file:
    instrument = nexus_file['entry/instrument']
    all_detector_names = list(instrument[snx.NXdetector])
# Only the first two detectors are processed (presumably to keep the run short;
# widen this slice to include more).
detector_names = all_detector_names[:2]

In [None]:
# Build the BIFROST workflow for the selected detectors and add the `cut` provider.
workflow = bifrost.BifrostWorkflow(detector_names)
# Reuse the path fetched when listing detectors instead of calling the data helper again.
workflow[Filename[SampleRun]] = fname
workflow[TimeOfFlightLookupTableFilename] = tof_lookup_table_simulation()
workflow[PreopenNeXusFile] = PreopenNeXusFile(True)

workflow.insert(cut)

# Alternative cut (disabled): project Q onto the x and z directions.
# workflow[CutAxis1] = CutAxis.from_q_vector(
#     output="Qx",
#     vec=sc.vector([1, 0, 0]),
#     bins=sc.linspace('Qx', -3.0, 3.0, 300, unit='1/Å')
# )
# workflow[CutAxis2] = CutAxis.from_q_vector(
#     output="Qz",
#     vec=sc.vector([0, 0, 1]),
#     bins=sc.linspace('Qz', -3.0, 3.0, 300, unit='1/Å')
# )

# Active cut: |Q| against energy transfer. The lambda parameter names are used by
# `transform_coords` to select input coordinates, so they must not be renamed.
workflow[CutAxis1] = CutAxis(
    output="|Q|",
    fn=lambda sample_table_momentum_transfer: sc.norm(sample_table_momentum_transfer),
    bins=sc.linspace('|Q|', 0.9, 3.0, 300, unit='1/Å')
)
workflow[CutAxis2] = CutAxis(
    output="E",
    fn=lambda energy_transfer: energy_transfer,
    bins=sc.linspace('E', -0.1, 0.1, 300, unit='meV')
)

In [None]:
# Show the task graph that produces CutData[SampleRun], laid out left-to-right.
workflow.visualize(CutData[SampleRun], graph_attr={"rankdir": "LR"}, compact=True)

In [None]:
# Compute the raw detector events and the instrument angles once up front.
# They are split into per-a3 chunks further down to emulate streaming.
precompute_targets = [NeXusData[snx.NXdetector, SampleRun], InstrumentAngles[SampleRun]]
results = workflow.compute(precompute_targets, scheduler=sciline.scheduler.NaiveScheduler())
base_data = results[NeXusData[snx.NXdetector, SampleRun]]
angles = results[InstrumentAngles[SampleRun]]

In [None]:
# Inspect the a3 angle values returned in InstrumentAngles.
angles['a3']

In [None]:
# This is similar to `group_by_rotation` but preserves the event_time_zero coord and dim.
# The elements of `angle_groups` look like NeXusData.
# For simplicity, it assumes that there is only one a4 value; this is enforced below.
a3_lookup = sc.lookup(angles['a3'], 'time')
a4_lookup = sc.lookup(angles['a4'], 'time')
angle_graph = {
    # Parameter name `event_time_zero` selects the input coordinate; do not rename.
    'a3': lambda event_time_zero: a3_lookup[event_time_zero],
    'a4': lambda event_time_zero: a4_lookup[event_time_zero],
}
# Copy the per-bin event_time_zero onto every event so it survives the regrouping.
events_with_time_zero = base_data.bins.assign_coords(
    {'event_time_zero': sc.bins_like(base_data.data, base_data.coords['event_time_zero'])}
)
grouped = events_with_time_zero.transform_coords(('a3', 'a4'), graph=angle_graph).group('a3', 'a4')

# Selecting ['a4', 0] below would silently discard events at any other a4 value.
n_a4 = grouped.sizes['a4']
if n_a4 != 1:
    raise ValueError(f"Expected exactly one a4 value, got {n_a4}.")

angle_groups = [
    grouped['a3', i]['a4', 0].group('event_time_zero')
    for i in range(grouped.sizes['a3'])
]

In [None]:
from copy import deepcopy
# Imported explicitly: `Any` was previously only available if it leaked from the
# wildcard import of ess.spectroscopy.types.
from typing import Any


class BinAccumulator(Accumulator[sc.DataArray]):
    """Accumulate binned data by concatenating bins element-wise on every push.

    Note: not currently wired into the StreamProcessor in this notebook (its
    accumulator entry is commented out).
    """

    def __init__(self, **kwargs: Any) -> None:
        # `preprocess` is always None here; remaining keyword arguments are
        # forwarded to the base class.
        super().__init__(preprocess=None, **kwargs)
        self._value: sc.DataArray | None = None

    @property
    def is_empty(self) -> bool:
        """True until the first push, and again after `clear`."""
        return self._value is None

    def _get_value(self) -> sc.DataArray:
        # Hand out a copy so callers cannot mutate the accumulated state.
        return deepcopy(self._value)

    def _do_push(self, value: sc.DataArray) -> None:
        if self._value is None:
            # Copy, because later pushes concatenate into this object in place.
            self._value = deepcopy(value)
        else:
            self._value.bins.concatenate(value, out=self._value)

    def clear(self) -> None:
        """Clear the accumulated value."""
        self._value = None


# NeXusData is pushed chunk by chunk, InstrumentAngles is supplied as context,
# and CutData is accumulated with an EternalAccumulator.
accumulators = {
    # DataGroupedByRotation[SampleRun]: BinAccumulator(),
    CutData[SampleRun]: EternalAccumulator(),
}
sp = StreamProcessor(
    workflow,
    dynamic_keys=(NeXusData[snx.NXdetector, SampleRun],),
    context_keys=(InstrumentAngles[SampleRun],),
    target_keys=(CutData[SampleRun],),
    accumulators=accumulators,
)

In [None]:
# Emulate streaming: for each a3 step, set that step's angles as context and push
# its events, timing each set_context + accumulate pair.
times = []
for group in angle_groups:
    step_angles = sc.DataGroup(a3=sc.DataArray(group.coords['a3']), a4=sc.DataArray(group.coords['a4']))
    events = group.drop_coords(['a3', 'a4'])  # NeXusData does not have these coords

    # perf_counter is monotonic and high-resolution; time.time() can jump and is
    # not meant for measuring durations.
    start = time.perf_counter()
    sp.set_context({InstrumentAngles[SampleRun]: step_angles})
    sp.accumulate({NeXusData[snx.NXdetector, SampleRun]: events})
    times.append(time.perf_counter() - start)

if times:
    print(f"Sum: {sum(times):.3f}s  Mean: {sum(times)/len(times):.3f}s [{min(times):.3f}s, {max(times):.3f}s]")
else:
    # Without this guard, the mean divides by zero and min()/max() fail on an empty list.
    print("No angle groups were processed; nothing to time.")

In [None]:
# Per-step streaming cost; one point per a3 group in `angle_groups`.
fig, ax = plt.subplots()
ax.plot(times)
ax.set_xlabel('a3 step (index into angle_groups)')
ax.set_ylabel('set_context + accumulate time [s]')
ax.set_title('StreamProcessor time per a3 step');

In [None]:
# Finalize the stream and extract the accumulated 2-D cut.
final = sp.finalize()
data = final[CutData[SampleRun]]

In [None]:
# Inspect the accumulated cut (a histogrammed DataArray).
data

In [None]:
# Plot the 2-D cut with a logarithmic colour norm.
data.plot(norm='log')