# Writing SQW for BIFROST

**Idea for determining position in pixel block:**

- Compute an upper bound size for each (Q,E,irun,idet,ien) bin.
- For each event, compute an index based on that upper bound.
- `group` or `groupby` on that index. This only retains indices where we have data and orders the data by the index. If the index is computed correctly, that is the order in a pixel block.
- `sum` the bins.
- Write the result into the pixel block.

The extra index event coord should not be too large, we should be able to fit that into memory.

If we can make `groupby` lazy in its groups, we can stream the data to file.

Or we process the events in arbitrary chunks.
For each chunk, compute the index, group, sum.
Write to file by: if index out of bounds, grow file and fill with 0s. Then add the new value onto the existing one. That way, we accumulate events directly in the file.
This might be quite slow because of the somewhat random file access and the many reads and writes, and because it needs a tight loop (-> not Python?).
**Maybe not:** this would require knowing up front all index values that can contribute, in order to do the grouping.

In [None]:
import scipp as sc
import sciline
from scippneutron.io import sqw
from pathlib import Path
import scippnexus as snx
import numpy as np
import dataclasses

from ess import bifrost
from ess.bifrost.data import (
    simulated_elastic_incoherent_with_phonon,
    tof_lookup_table_simulation
)
from ess.spectroscopy.types import *

In [None]:
# Limits on how much data is used further down: `n_angle` slices the 'a3'
# dim of the computed data and `n_det` slices the list of detector names.
n_angle = 5
n_det = 3

# Number of bins along each of the four projection axes.
bin_sizes = dict(u1=6, u2=7, u3=8, u4=9)

# Destination of the SQW file written by this notebook.
out_file = Path("bifrost-simulated.sqw")

# Q projections
u, v = (
    sc.vector(axis, unit="1/angstrom")
    for axis in ([1, 0, 0], [0, 1, 0])
)
w = sc.cross(u, v)  # must be orthogonal to u and v for now

In [None]:
# Collect the keys of all NXdetector groups under 'entry/instrument' in the
# simulated NeXus file; only the first `n_det` of them are used.
with snx.File(simulated_elastic_incoherent_with_phonon()) as f:
    detector_names = list(f['entry/instrument'][snx.NXdetector])
detector_names = detector_names[:n_det]

# Set up the BIFROST simulation workflow for the selected detectors and
# provide its inputs: the sample-run file and the time-of-flight lookup table.
workflow = bifrost.BifrostSimulationWorkflow(detector_names)
workflow[Filename[SampleRun]] = simulated_elastic_incoherent_with_phonon()
workflow[TimeOfFlightLookupTable] = sc.io.load_hdf5(tof_lookup_table_simulation())
workflow[PreopenNeXusFile] = PreopenNeXusFile(True)
scheduler = sciline.scheduler.NaiveScheduler()

data = workflow.compute(EnergyData[SampleRun], scheduler=scheduler)

# Keep the first `n_angle` entries along 'a3' and merge the 'triplet', 'tube'
# and 'length' dims into a single 'detector' dim.
data = data['a3', :n_angle].flatten(['triplet', 'tube', 'length'], 'detector')

In [None]:
# One integer label per (a3, a4) combination, carrying the a3 and a4 coords
# of `data`, flattened into a single 'setting' dim.
settings = sc.DataArray(
    sc.arange('s', data.sizes['a3'] * data.sizes['a4'], unit=None).fold(
        's', sizes={dim: data.sizes[dim] for dim in ('a3', 'a4')}
    ),
    coords={name: data.coords[name] for name in ('a3', 'a4')},
).flatten(to='setting')

In [None]:
def project_onto(direction, vec):
    """Return the component of ``vec`` along ``direction``.

    ``direction`` is divided by its norm before the dot product is taken,
    so only its orientation matters, not its length.
    """
    unit_direction = direction / sc.norm(direction)
    return sc.dot(unit_direction, vec)


# Merge 'a3' and 'a4' into one 'setting' dim, project the sample-table
# momentum transfer onto u, v and w (coords u1..u3), expose the energy
# transfer as u4, and drop coords that are not used afterwards.
aux = (data
       .flatten(['a3', 'a4'], 'setting')
       # Split into two calls because of https://github.com/scipp/scipp/issues/3766
       .transform_coords(
    u1=lambda sample_table_momentum_transfer: project_onto(u, sample_table_momentum_transfer),
    u2=lambda sample_table_momentum_transfer: project_onto(v, sample_table_momentum_transfer),
    u3=lambda sample_table_momentum_transfer: project_onto(w, sample_table_momentum_transfer),
    keep_inputs=False,
).transform_coords(u4="energy_transfer", keep_inputs=False)
       .bins.drop_coords(['incident_energy', 'incident_wavelength', 'lab_momentum_transfer'])
       .drop_coords(['a3', 'a4', 'secondary_flight_time'])
       )
# Copy the setting label and the detector number into per-event coords with
# `sc.bins_like`, concatenate all bins and re-bin the events by `bin_sizes`.
# NOTE: `aux.coords.pop('detector_number')` mutates `aux`; it is evaluated
# after the 'setting' entry of the dict has been built.
binned = (aux
          .bins.assign_coords({
    'setting': sc.bins_like(aux, settings.data),
    'detector_number': sc.bins_like(aux, aux.coords.pop('detector_number')),
})
          .bins.concat()
          .bin(bin_sizes)
          )

In [None]:
def coord_range(da: sc.DataArray, coord: str) -> sc.Variable:
    """Return the first and last value of coordinate ``coord`` of ``da``.

    The result is a two-element variable along dim ``coord`` that carries the
    unit of the coordinate. The coordinate is asserted to be linearly spaced.
    """
    values = da.coords[coord]
    assert sc.islinspace(values)
    first, last = values[0].value, values[-1].value
    return sc.array(dims=[coord], values=[first, last], unit=values.unit)

In [None]:
# Sample description: identical lattice spacing along all three axes and
# 90 degree lattice angles.
sample = sqw.SqwIXSample(
    name="Vibranium",
    lattice_spacing=sc.vector([2.86] * 3, unit="angstrom"),
    lattice_angle=sc.vector([90.0] * 3, unit="deg"),
)

In [None]:
# Metadata for the image (DND) data: axes description and projection.
# The three momentum axes use 1/angstrom, the energy axis uses meV.
dnd_metadata = sqw.SqwDndMetadata(
    axes=sqw.SqwLineAxes(
        title="My Axes",
        label=[f"u{i}" for i in range(1, 5)],
        img_scales=[
            sc.scalar(1.0, unit=unit)
            for unit in ("1/angstrom", "1/angstrom", "1/angstrom", "meV")
        ],
        # First and last coordinate value of `binned` along every axis.
        img_range=[coord_range(binned, f"u{i}") for i in range(1, 5)],
        n_bins_all_dims=sc.array(
            dims=["axis"],
            values=[binned.sizes[name] for name in ("u1", "u2", "u3", "u4")],
            unit=None,
        ),
        single_bin_defines_iax=sc.array(
            dims=["axis"], values=[True, True, True, True]
        ),
        dax=sc.arange("axis", 4, unit=None),
        offset=[
            sc.scalar(0.0, unit=unit)
            for unit in ("1/angstrom", "1/angstrom", "1/angstrom", "meV")
        ],
        changes_aspect_ratio=True,
    ),
    proj=sqw.SqwLineProj(
        title="My Projection",
        lattice_spacing=sample.lattice_spacing,
        lattice_angle=sample.lattice_angle,
        offset=[
            sc.scalar(0.0, unit=unit)
            for unit in ("1/angstrom", "1/angstrom", "1/angstrom", "meV")
        ],
        label=[f"u{i}" for i in range(1, 5)],
        u=u,
        v=v,
        w=None,
        non_orthogonal=False,
        type="aaa",
    ),
)

In [None]:
# Display the per-axis image ranges that were computed with `coord_range`.
dnd_metadata.axes.img_range

In [None]:
# Instrument description; later passed to `add_default_instrument` when the
# SQW file is built.
instrument = sqw.SqwIXNullInstrument(
    name="BIFROST",
    source=sqw.SqwIXSource(
        name="ESS",
        target_name="Tungsten wheel",
        frequency=sc.scalar(14, unit="Hz"),
    ),
)

In [None]:
# Broadcast the 'u4' coord of `binned` to one row per detector and rename
# the 'u4' dim to 'energy_transfer'.
en = binned.coords['u4'].broadcast(sizes={'detector': data.sizes['detector'], 'u4': len(binned.coords['u4'])}).rename(
    u4='energy_transfer')
# Template holding the values shared by all runs; `run_id` and `psi` are
# replaced per run below.
experiment_template = sqw.SqwIXExperiment(
    run_id=0,
    efix=data.coords['final_energy'],
    emode=sqw.EnergyMode.indirect,
    en=en,
    psi=sc.scalar(0.0, unit="rad"),
    u=u,
    v=v,
    omega=sc.scalar(0.0, unit="rad"),
    dpsi=sc.scalar(0.0, unit="rad"),
    gl=sc.scalar(0.0, unit="rad"),
    gs=sc.scalar(0.0, unit="rad"),
)
# The experiments below are built per a3 value only, so require that a4
# takes a single value.
assert np.unique(data.coords['a4'].values).size == 1
# One experiment per a3 value, with run ids starting at 1 and psi set to a3.
experiments = [
    dataclasses.replace(experiment_template, run_id=i, psi=a3)
    for i, a3 in enumerate(data.coords['a3'], 1)
]

In [None]:
# Display the binned data for inspection.
binned

In [None]:
# TODO add 0 observations
#  we lose some bins in the above binning processes, e.g., here, idet is only a subset of all detector numbers

In [None]:
# Start from a copy of `binned` without the position coords.
observations = binned.copy().drop_coords(['sample_position', 'source_position'])
# Rename the per-event 'detector_number' and 'setting' coords to 'idet' and
# 'irun'; each is cast to float32 and then shifted by one.
observations = observations.bins.assign_coords(
    {
        'idet': observations.bins.coords.pop('detector_number').to(dtype='float32') + sc.index(1),
        'irun': observations.bins.coords.pop('setting').to(dtype='float32') + sc.index(1),
        # +1 because of 1-based indexing
    }
)
# Replace each of the u1..u4 coords by its midpoints, stored as float32.
observations = observations.assign_coords(
    {
        f'u{i}': sc.midpoints(observations.coords.pop(f'u{i}')).to(dtype='float32', copy=False)
        for i in range(1, 5)
    }
)
# 1-based index along 'u4'.
observations.coords['ien'] = sc.arange('u4', 1, 1 + observations.sizes['u4'], dtype='float32', unit=None)
observations

In [None]:
# Scratch array shared between calls of `get_buffer`; allocated on demand.
buffer = None


def get_buffer(n: int):
    """Return a float32 scratch array with ``n`` rows and 9 columns.

    The module-level ``buffer`` is reused when it already has at least ``n``
    rows (a view of its first ``n`` rows is returned). Otherwise a new,
    uninitialised array of exactly ``n`` rows replaces it and is returned.
    The contents are not cleared between calls.
    """
    global buffer
    if buffer is not None and buffer.shape[0] >= n:
        # Large enough already: hand out a view instead of reallocating.
        return buffer[:n]
    buffer = np.empty((n, 9), dtype=np.float32)
    return buffer


# Build the pixel table bin by bin: every (u1..u4) bin is grouped by detector
# and run, histogrammed, and written as rows of nine float32 columns
# (u1, u2, u3, u4, irun, idet, ien, signal, standard deviation).
buffers = []
# transpose to match fortran layout. But horace plots look the same, does this actually matter?
# NOTE(review): `flat` is never used below; the loop iterates over the
# non-transposed `observations.flatten(to='u')`. Confirm which order is
# intended before relying on the pixel order.
flat = observations.transpose(['u4', 'u3', 'u2', 'u1']).flatten(to='u')
for obs_bin in observations.flatten(to='u'):
    h = obs_bin.group('idet', 'irun').hist()
    n = h.sizes['idet'] * h.sizes['irun']
    h = h.flatten(to='obs')
    buf = get_buffer(n)
    # Columns 0-6 hold the coordinates, in the order given here.
    for column, name in enumerate(('u1', 'u2', 'u3', 'u4', 'irun', 'idet', 'ien')):
        buf[:, column] = h.coords[name].values
    buf[:, 7] = h.values
    buf[:, 8] = sc.stddevs(h).values
    # `buf` is a shared scratch array, so keep a private copy of the rows.
    buffers.append(buf.copy())

# `np.concatenate` instead of `np.concat`: the latter only exists in NumPy >= 2.0.
pix_buffer = np.concatenate(buffers, axis=0)

In [None]:
# Number of pixel rows produced for each bin, folded back from the flat 'u'
# dim to the dims of `observations`.
# NOTE: this rebinds `bin_sizes`, which earlier held the dict of bin counts
# passed to `.bin(...)`; re-running the binning cell after this one would
# pick up this variable instead.
bin_sizes = sc.array(dims=['u'], values=[buf.shape[0] for buf in buffers], unit=None)
bin_sizes = bin_sizes.fold(dim='u', sizes=observations.sizes)

In [None]:
# Assemble and write the SQW file from the loop-based pixel table:
# default instrument and sample, image (DND) data, then the pixel data.
builder = sqw.Sqw.build(out_file, title="Simulated data with phonon")
builder = builder.add_default_instrument(instrument)
builder = builder.add_default_sample(sample)
builder = builder.add_dnd_data(
    dnd_metadata,
    data=observations.hist().data,
    counts=bin_sizes,
)
builder = builder.add_pixel_data(pix_buffer, experiments=experiments)
builder.create()

In [None]:
# Re-open the written file, list its data blocks and load the
# ('data', 'nd_data') block for the checks in the following cells.
with sqw.Sqw.open(out_file) as f:
    print(f.data_block_names())
    d = f.read_data_block('data', "nd_data")

In [None]:
# Shape of the first item of the block that was read back from the file.
d[0].shape

In [None]:
# Histogram value of the bin at u1=3, u2=6, u3=1, u4=0, for comparison with
# the array read back from the file.
binned.hist().data['u1', 3]['u2', 6]['u3', 1]['u4', 0]

In [None]:
# TODO The dim order is reversed.
# Seems wrong, I thought we handled transposition properly now?!
# Same indices as in the previous cell, applied in reverse order (0, 1, 6, 3).
d[0][0][1][6][3]

In [None]:
# The same bin indices applied in the original order (3, 6, 1, 0).
d[0][3][6][1][0]

In [None]:
# Indices of all entries of the stored array that are equal to 127.0.
np.nonzero(d[0] == 127.0)

In [None]:
# Plot the histogram summed over u2 and u4, with dims ordered as (u3, u1).
binned.hist().sum('u2').sum('u4').transpose(['u3', 'u1']).plot()

### Build pix data using group

In [None]:
# Revert all indices to 0-based for faster index arithmetic
binned_indices = observations.copy().bins.assign_coords(
    **{f'u{i}': sc.bins_like(observations, sc.arange(f'u{i}', observations.sizes[f'u{i}'], unit=None, dtype='float32'))
       for i in range(1, 5)},
    ien=sc.bins_like(observations, observations.coords['ien'] - sc.index(1)),
    irun=observations.bins.coords['irun'] - sc.index(1),
    idet=observations.bins.coords['idet'] - sc.index(1),
)
binned_indices

In [None]:
# Extent of every index: the dim sizes for u1..u4 and (largest value + 1)
# for idet, irun and ien.
n_u1 = sc.index(binned_indices.sizes['u1'])
n_u2 = sc.index(binned_indices.sizes['u2'])
n_u3 = sc.index(binned_indices.sizes['u3'])
n_u4 = sc.index(binned_indices.sizes['u4'])
n_idet = binned_indices.bins.coords['idet'].max() + sc.index(1)
n_irun = binned_indices.bins.coords['irun'].max() + sc.index(1)
n_ien = binned_indices.bins.coords['ien'].max() + sc.index(1)

# Per-event 0-based indices.
i_u1 = binned_indices.bins.coords['u1']
i_u2 = binned_indices.bins.coords['u2']
i_u3 = binned_indices.bins.coords['u3']
i_u4 = binned_indices.bins.coords['u4']
idet = binned_indices.bins.coords['idet']
irun = binned_indices.bins.coords['irun']
ien = binned_indices.bins.coords['ien']

# Combine all indices into one int64 value per event, mixed-radix style:
# u4 is the most significant digit, followed by u3, u2, u1, idet, irun and
# finally ien as the least significant one.
out_index = i_u4.to(dtype='int64')
for i, n in ((i_u3, n_u3), (i_u2, n_u2), (i_u1, n_u1), (idet, n_idet), (irun, n_irun), (ien, n_ien)):
    out_index *= n.to(dtype='int64')
    out_index += i.to(dtype='int64')

In [None]:
# TODO the data is shorter than when using the custom loop! (A bunch of u4 might be missing)

# Attach the combined index to every event and move the dense u1..u4 and
# ien coords into per-event coords, so that they are still available after
# all bins have been concatenated.
for_grouping = observations.bins.assign_coords(pix_index=out_index)
for_grouping = for_grouping.bins.assign_coords({f'u{i}': sc.bins_like(for_grouping, for_grouping.coords.pop(f'u{i}')) for i in range(1, 5)})
for_grouping.bins.coords['ien'] = sc.bins_like(for_grouping, for_grouping.coords.pop('ien'))

# Group all events by the combined index and sum each group.
grouped = for_grouping.bins.concat().group('pix_index').drop_coords('pix_index')
hist = grouped.bins.sum()
# indices in `observations` are 1-based
# Turn every remaining per-event coord into one value per group by taking
# the minimum within the group. `list(...)` snapshots the keys because the
# coords are popped while iterating.
hist = hist.assign_coords({name: grouped.bins.coords.pop(name).bins.min() for name in list(grouped.bins.coords.keys())})
hist

In [None]:
# Group the pixel table by u3 and u1 and plot the per-group sums.
hist.group('u3', 'u1').bins.sum().plot()

In [None]:
# Nine-column pixel table: the seven coordinate columns followed by the
# values of `hist` and their standard deviations.
# NOTE: this rebinds the module-level `buffer` that `get_buffer` also uses.
buffer = np.column_stack(
    [hist.coords[name].values for name in ('u1', 'u2', 'u3', 'u4', 'irun', 'idet', 'ien')]
    + [hist.values, sc.stddevs(hist).values]
)

In [None]:
# Group the pixel table by the four projection axes; the per-group sums and
# sizes are passed to `add_dnd_data` when the file is built.
dnd = hist.group('u4', 'u3', 'u2', 'u1')

In [None]:
# Write the SQW file from the index-grouped pixel table. This uses the same
# `out_file` path as the loop-based build.
# The title previously lacked its closing parenthesis.
builder = (
    sqw.Sqw.build(out_file, title="Simulated data with phonon (index method)")
    .add_default_instrument(instrument)
    .add_default_sample(sample)
)
builder = builder.add_dnd_data(dnd_metadata, data=dnd.bins.sum().data, counts=dnd.bins.size())
builder = builder.add_pixel_data(buffer, experiments=experiments)
builder.create()

In [None]:
# Display the per-bin sums that were passed to `add_dnd_data`.
dnd.bins.sum().data