# Writing SQW for BIFROST

In [None]:
import scipp as sc
import sciline
from scippneutron.io import sqw
from pathlib import Path
import scippnexus as snx
import numpy as np
import scipp.constants
import dataclasses

from ess import bifrost
from ess.bifrost.data import (
    simulated_elastic_incoherent_with_phonon,
    tof_lookup_table_simulation
)
from ess.spectroscopy.types import *

In [None]:
# Size of the subset that is processed: the first `n_det` detectors of the
# input file and the first `n_angle` a3 settings.
n_det = 3
n_angle = 5

# Number of bins along each of the four projection axes u1..u4.
bin_sizes = dict(u1=6, u2=7, u3=8, u4=9)

# Destination of the generated SQW file.
out_file = Path("bifrost-simulated.sqw")

# Q projections
u = sc.vector([1, 0, 0], unit="1/angstrom")
v = sc.vector([0, 1, 0], unit="1/angstrom")
# Third projection axis; must be orthogonal to u and v for now.
w = sc.cross(u, v)

In [None]:
# Look up the detector names in the input file and keep the first `n_det`.
with snx.File(simulated_elastic_incoherent_with_phonon()) as nexus_file:
    all_detector_names = list(nexus_file['entry/instrument'][snx.NXdetector])
detector_names = all_detector_names[:n_det]

# Set up the simulation workflow for the selected detectors.
workflow = bifrost.BifrostSimulationWorkflow(detector_names)
workflow[Filename[SampleRun]] = simulated_elastic_incoherent_with_phonon()
workflow[TimeOfFlightLookupTable] = sc.io.load_hdf5(tof_lookup_table_simulation())
workflow[PreopenNeXusFile] = PreopenNeXusFile(True)
scheduler = sciline.scheduler.NaiveScheduler()

energy_data = workflow.compute(EnergyData[SampleRun], scheduler=scheduler)

# Keep the first `n_angle` a3 settings and merge the three detector dims
# into a single 'detector' dim.
data = energy_data['a3', :n_angle].flatten(
    ['triplet', 'tube', 'length'], 'detector'
)

# Flatten angles and define a setting index where a3 is the fastest running index.
n_a3 = data.sizes['a3']
n_a4 = data.sizes['a4']
data.coords['i_a3'] = sc.arange('a3', n_a3, unit=None)
data.coords['i_a4'] = sc.arange('a4', n_a4, unit=None)
data = data.flatten(['a3', 'a4'], 'setting')
# The temporary per-angle indices are removed again once combined.
i_a3 = data.coords.pop('i_a3')
i_a4 = data.coords.pop('i_a4')
data.coords['setting'] = i_a3 + i_a4 * sc.index(n_a3)

In [None]:
# Show the energy data after flattening to the 'detector' and 'setting' dims.
data

In [None]:
# Histogram into `bin_sizes['u4']` equal-width energy-transfer bins.
energy_edges = sc.linspace(
    'energy_transfer', -0.05, 0.05, bin_sizes['u4'] + 1, unit='meV'
)
index_binned = data.hist(energy_transfer=energy_edges)

# Replace the bin-edge coordinate by the bin centres.
edge_coord = index_binned.coords['energy_transfer']
index_binned.coords['energy_transfer'] = sc.midpoints(edge_coord)

# Zero-based index of each energy bin, stored as float32.
index_binned.coords['ien'] = sc.arange(
    'energy_transfer', bin_sizes['u4'], unit=None, dtype='float32'
)
index_binned

In [None]:
en = index_binned.coords['energy_transfer']
ef = index_binned.coords['final_energy']
kf = index_binned.coords['final_wavevector']
a3 = index_binned.coords['a3']

# Direction of the incident beam; should get this from the data.
bi = sc.vector([0, 0, 1])

# Incident wavevector: |k_i| = 2*pi/h * sqrt(2 * m_n * E_i) with
# E_i = energy_transfer + final_energy.
ei = en + ef
ki_magnitude = 2 * np.pi / sc.constants.h * sc.sqrt(2 * sc.constants.m_n * ei)
ki = ki_magnitude * bi

# Rotate k_i - k_f by -a3 about the axis [0, 1, 0].
Ra3 = sc.spatial.rotations_from_rotvecs(-a3 * sc.vector([0, 1, 0]))
ki_minus_kf = ki.to(unit=kf.unit) - kf
Q = Ra3 * ki_minus_kf

In [None]:
# Inspect Q = Ra3 * (ki - kf) as computed in the previous cell.
Q

In [None]:
# Work on a copy so that `index_binned` keeps its original coordinates.
da = index_binned.copy()
# Store each Cartesian component of Q as its own coordinate (Qx, Qy, Qz).
for component in ('x', 'y', 'z'):
    da.coords[f'Q{component}'] = getattr(Q.fields, component)
# Collapse all dims into a single 'pix' dim.
da = da.flatten(to='pix')

In [None]:
# Number of bins per coordinate; the key order defines the order of the
# output dims (energy_transfer first, Qx last).
bin_counts = {
    'energy_transfer': bin_sizes['u4'],
    'Qz': bin_sizes['u3'],
    'Qy': bin_sizes['u2'],
    'Qx': bin_sizes['u1'],
}
u_binned = da.bin(bin_counts)
u_binned

In [None]:
# Rename the binned dims to the projection-axis names u1..u4.
projection_dims = {
    'energy_transfer': 'u4',
    'Qz': 'u3',
    'Qy': 'u2',
    'Qx': 'u1',
}
dnd = u_binned.data.rename_dims(projection_dims)
dnd

In [None]:
def coord_range(da: sc.DataArray, coord: str) -> sc.Variable:
    """Return the first and last value of a coordinate.

    Parameters
    ----------
    da:
        Data array holding the coordinate.
    coord:
        Name of the coordinate; also used as the dim of the result.

    Returns
    -------
    :
        Length-2 variable ``[first, last]`` with the unit of the coordinate.

    Raises
    ------
    AssertionError
        If the coordinate is not linearly spaced.
    """
    values = da.coords[coord]
    assert sc.islinspace(values)
    first, last = values[0].value, values[-1].value
    return sc.array(dims=[coord], values=[first, last], unit=values.unit)


# Coordinates in projection order u1, u2, u3, u4.
projection_coords = ("Qx", "Qy", "Qz", "energy_transfer")

# [first, last] of each binned coordinate.
img_range = [coord_range(u_binned, name) for name in projection_coords]

# Number of bins along each projection axis.
n_bins_all_dims = sc.array(
    dims=["axis"],
    values=[u_binned.sizes[name] for name in projection_coords],
    unit=None,
)

In [None]:
# Concatenate the contents of all bins of `u_binned` and extract the result.
# Presumably this yields the pixels in the order required by the file
# (the ordering has not been verified -- TODO confirm).
all_bins_merged = u_binned.bins.concat()
in_pixel_order = all_bins_merged.value
in_pixel_order

In [None]:
# check that Q computation makes sense
# needs input like index_binned but binned, not histogrammed

# for i, pix in enumerate(da[:1000]):
#     e_q = pix.bins.coords['sample_table_momentum_transfer'].copy().bins.concat().value
#     out = []
#     for dim in 'xyz':
#         q = getattr(e_q.fields, dim)
#         lo = q.min().value
#         hi = q.max().value
#         n = len(q)
#         Q = pix.coords[f'Q{dim}'].value
#         if n > 0:
#             out.append(f'  {dim}: {Q: .4f} | [{lo: .4f}, {hi: .4f}]')
#     if out:
#         print(i)
#         print('\n'.join(out))

In [None]:
# Build the pixel table: every coordinate is converted to float32, the Q
# components to 1/Å and the energy transfer to meV.
pixel_coords = in_pixel_order.coords
as_f32 = {'dtype': 'float32', 'copy': False}
pix = sc.DataArray(
    in_pixel_order.data,
    coords={
        'u1': pixel_coords['Qx'].to(unit='1/Å', **as_f32),
        'u2': pixel_coords['Qy'].to(unit='1/Å', **as_f32),
        'u3': pixel_coords['Qz'].to(unit='1/Å', **as_f32),
        'u4': pixel_coords['energy_transfer'].to(unit='meV', **as_f32),
        'idet': pixel_coords['detector_number'].to(**as_f32),
        'irun': pixel_coords['setting'].to(**as_f32),
        'ien': pixel_coords['ien'].to(**as_f32),
    },
)
pix

In [None]:
# Placeholder sample: all three lattice spacings equal, all angles 90 deg.
lattice_edge = 2.86
sample = sqw.SqwIXSample(
    name="Vibranium",
    lattice_spacing=sc.vector([lattice_edge] * 3, unit="angstrom"),
    lattice_angle=sc.vector([90.0] * 3, unit="deg"),
)

In [None]:
# Labels and units of the four projection axes (three Q axes + energy).
axis_labels = ("u1", "u2", "u3", "u4")
axis_units = ("1/angstrom", "1/angstrom", "1/angstrom", "meV")

# Axes description: unit scales, zero offset, ranges and bin counts taken
# from the binned data.
line_axes = sqw.SqwLineAxes(
    title="My Axes",
    label=list(axis_labels),
    img_scales=[sc.scalar(1.0, unit=unit) for unit in axis_units],
    img_range=img_range,
    n_bins_all_dims=n_bins_all_dims,
    single_bin_defines_iax=sc.array(dims=["axis"], values=[True] * 4),
    dax=sc.arange("axis", 4, unit=None),
    offset=[sc.scalar(0.0, unit=unit) for unit in axis_units],
    changes_aspect_ratio=True,
)

# Projection description: uses the lattice of `sample` and the u, v vectors
# from the configuration cell; no explicit w is passed.
line_proj = sqw.SqwLineProj(
    title="My Projection",
    lattice_spacing=sample.lattice_spacing,
    lattice_angle=sample.lattice_angle,
    offset=[sc.scalar(0.0, unit=unit) for unit in axis_units],
    label=list(axis_labels),
    u=u,
    v=v,
    w=None,
    non_orthogonal=False,
    type="aaa",
)

dnd_metadata = sqw.SqwDndMetadata(axes=line_axes, proj=line_proj)

In [None]:
# Source description stored with the instrument.
ess_source = sqw.SqwIXSource(
    name="ESS",
    target_name="Tungsten wheel",
    frequency=sc.scalar(14, unit="Hz"),
)
instrument = sqw.SqwIXNullInstrument(name="BIFROST", source=ess_source)

In [None]:
# Repeat the energy-transfer axis for every detector.
multi_en = en.broadcast(
    sizes={'detector': data.sizes['detector'], 'energy_transfer': len(en)}
)

# All goniometer angles default to zero; psi is overridden per run below.
zero_angles = {
    name: sc.scalar(0.0, unit="rad")
    for name in ('psi', 'omega', 'dpsi', 'gl', 'gs')
}
experiment_template = sqw.SqwIXExperiment(
    run_id=0,
    efix=data.coords['final_energy'],
    emode=sqw.EnergyMode.indirect,
    en=multi_en,
    u=u,
    v=v,
    **zero_angles,
)

# Only a3 may vary between settings; a4 must be constant.
assert np.unique(data.coords['a4'].values).size == 1

# One experiment per setting, with psi set to that setting's a3.
# NOTE(review): run_id starts at 1 here, while the 'setting' coordinate that
# becomes the pixel 'irun' starts at 0 -- confirm this is what the writer expects.
experiments = []
for run_id, psi in enumerate(data.coords['a3'], 1):
    experiments.append(
        dataclasses.replace(experiment_template, run_id=run_id, psi=psi)
    )

In [None]:
# Assemble the pixel table column by column:
# u1, u2, u3, u4, irun, idet, ien, signal, error.
pix_columns = [
    pix.coords[name].values
    for name in ('u1', 'u2', 'u3', 'u4', 'irun', 'idet', 'ien')
]
pix_columns.append(pix.values)
# NOTE(review): the error column holds standard deviations; check whether the
# SQW pixel format expects variances instead.
pix_columns.append(sc.stddevs(pix).values)
pix_buffer = np.c_[tuple(pix_columns)]

In [None]:
# Assemble the SQW file: default instrument and sample, the image (dnd) data
# with per-bin sums and event counts, and finally the pixel table.
# The title previously lacked its closing parenthesis.
builder = (
    sqw.Sqw.build(out_file, title="Simulated data with phonon (index method)")
    .add_default_instrument(instrument)
    .add_default_sample(sample)
    .add_dnd_data(dnd_metadata, data=dnd.bins.sum(), counts=dnd.bins.size())
    .add_pixel_data(pix_buffer, experiments=experiments)
)
builder.create()

Tested so far:

- can load
- shape matches
- ranges of all u_i match

## Test load

In [None]:
# Read back the blocks that were just written.
# The file handle gets its own name: binding it to `sqw` would replace the
# imported `scippneutron.io.sqw` module and break every earlier cell that
# uses `sqw.` when cells are re-run.
with sqw.Sqw.open(out_file) as sqw_file:
    m = sqw_file.read_data_block("data", "metadata")
    d = sqw_file.read_data_block("data", "nd_data")
    l_pix = sqw_file.read_data_block("pix", "data_wrap")

In [None]:
# Shape of the first element of the "nd_data" block read back from the file.
d[0].shape

In [None]:
# Show the in-memory pixel table again (presumably to compare with `l_pix`).
pix

In [None]:
# Smallest u4 (energy transfer in meV) among all pixels.
pix.coords['u4'].min()