325 changes: 286 additions & 39 deletions docs/user-guide/zoom.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/ess/polarization/correction.py
@@ -103,6 +103,9 @@ def compute_polarizing_element_correction(
:
Correction matrix coefficients.
"""
if isinstance(channel, sc.Variable | sc.DataArray) and channel.bins is not None:
channel = channel.bins

t_plus = transmission.apply(channel, 'plus')
t_minus = transmission.apply(channel, 'minus')
t_minus *= -1
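A minimal sketch (not the package API) of why the new `channel.bins` branch above works: for binned event data, `.bins.coords` exposes the event coordinates as binned variables, so the same arithmetic the transmission function applies to dense coordinates runs per event.

```python
import scipp as sc

# Build a small binned data array with per-event time/wavelength coords,
# mirroring the kind of data handled by compute_polarizing_element_correction.
events = sc.DataArray(
    sc.ones(dims=['event'], shape=[6]),
    coords={
        'time': sc.linspace('event', 0.0, 5.0, 6, unit='s'),
        'wavelength': sc.linspace('event', 1.0, 2.0, 6, unit='angstrom'),
    },
)
binned = sc.bins(
    data=events,
    dim='event',
    begin=sc.array(dims=['Q'], values=[0, 3], unit=None),
)
# The branch added in the diff: operate on the bins accessor for binned input.
channel = binned.bins if binned.bins is not None else binned
wavelength = channel.coords['wavelength']  # binned variable, one value per event
# Any wavelength-dependent factor (here a made-up one) stays per-event.
fake_transmission = wavelength * sc.scalar(0.5, unit='1/angstrom')
print(fake_transmission)
```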
9 changes: 3 additions & 6 deletions src/ess/polarization/he3.py
@@ -99,10 +99,7 @@ def __call__(self, wavelength: sc.Variable) -> sc.Variable:
scale = broadcast_with_upper_bound_variances(
self.opacity0, prototype=wavelength
)
return sc.DataArray(
(scale * wavelength).to(unit='', copy=False),
coords={'wavelength': wavelength},
)
return (scale * wavelength).to(unit='', copy=False)


def he3_opacity_from_cell_params(
@@ -192,7 +189,7 @@ def T1(self) -> sc.Variable:
return self._T1

def __call__(self, time: sc.Variable) -> sc.Variable:
return sc.DataArray(self.C * sc.exp(-time / self.T1), coords={'time': time})
return self.C * sc.exp(-time / self.T1)


@dataclass
@@ -214,7 +211,7 @@ def __call__(
polarization *= -plus_minus
return self.transmission_empty_glass * sc.exp(-opacity * (1.0 + polarization))

def apply(self, data: sc.DataArray, plus_minus: PlusMinus) -> sc.DataArray:
def apply(self, data: sc.DataArray, plus_minus: PlusMinus) -> sc.Variable:
return self(
time=data.coords['time'],
wavelength=data.coords['wavelength'],
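A hedged sketch of the Variable-only return values introduced in `he3.py`: opacity and polarization are evaluated as plain `sc.Variable`s (no wrapping `DataArray` with coords) and combined into the cell transmission T± = T_glass · exp(-opacity · (1 ∓ P)), as in the `__call__` shown above. The numbers are illustrative, not instrument parameters.

```python
import scipp as sc

wavelength = sc.linspace('wavelength', 1.0, 5.0, 5, unit='angstrom')
time = sc.scalar(3600.0, unit='s')
opacity0 = sc.scalar(0.2, unit='1/angstrom')   # assumed cell opacity coefficient
C = sc.scalar(0.9)                             # assumed initial polarization
T1 = sc.scalar(40000.0, unit='s')              # assumed relaxation time
transmission_empty_glass = sc.scalar(0.9)      # assumed glass transmission

opacity = (opacity0 * wavelength).to(unit='')  # previously wrapped in a DataArray
polarization = C * sc.exp(-time / T1)          # previously wrapped in a DataArray
t_plus = transmission_empty_glass * sc.exp(-opacity * (1.0 - polarization))
t_minus = transmission_empty_glass * sc.exp(-opacity * (1.0 + polarization))
```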
12 changes: 6 additions & 6 deletions src/ess/polarization/supermirror.py
@@ -17,7 +17,7 @@ class SupermirrorEfficiencyFunction(Generic[PolarizingElement], ABC):
"""Base class for supermirror efficiency functions"""

@abstractmethod
def __call__(self, *, wavelength: sc.Variable) -> sc.DataArray:
def __call__(self, *, wavelength: sc.Variable) -> sc.Variable:
"""Return the efficiency of a supermirror for a given wavelength"""


@@ -44,7 +44,7 @@ class SecondDegreePolynomialEfficiency(
b: sc.Variable
c: sc.Variable

def __call__(self, *, wavelength: sc.Variable) -> sc.DataArray:
def __call__(self, *, wavelength: sc.Variable) -> sc.Variable:
"""Return the efficiency of a supermirror for a given wavelength"""
return (
(self.a * wavelength**2).to(unit='', copy=False)
@@ -71,9 +71,9 @@ def __post_init__(self):
table = self.table if self.table.variances is None else sc.values(self.table)
self._lut = sc.lookup(table, 'wavelength')

def __call__(self, *, wavelength: sc.Variable) -> sc.DataArray:
def __call__(self, *, wavelength: sc.Variable) -> sc.Variable:
"""Return the efficiency of a supermirror for a given wavelength"""
return sc.DataArray(self._lut(wavelength), coords={'wavelength': wavelength})
return self._lut(wavelength)

@classmethod
def from_file(
Expand Down Expand Up @@ -107,15 +107,15 @@ class SupermirrorTransmissionFunction(TransmissionFunction[PolarizingElement]):

def __call__(
self, *, wavelength: sc.Variable, plus_minus: PlusMinus
) -> sc.DataArray:
) -> sc.Variable:
"""Return the transmission fraction for a given wavelength"""
efficiency = self.efficiency_function(wavelength=wavelength)
if plus_minus == 'plus':
return 0.5 * (1 + efficiency)
else:
return 0.5 * (1 - efficiency)

def apply(self, data: sc.DataArray, plus_minus: PlusMinus) -> sc.DataArray:
def apply(self, data: sc.DataArray, plus_minus: PlusMinus) -> sc.Variable:
"""Apply the transmission function to a data array"""
return self(wavelength=data.coords['wavelength'], plus_minus=plus_minus)

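A hedged sketch of the lookup-table efficiency and the (1 ± e)/2 split used by `SupermirrorTransmissionFunction`, both now returning plain variables rather than data arrays. The table values are illustrative.

```python
import scipp as sc

# Efficiency table with wavelength bin edges; sc.lookup maps arbitrary
# wavelengths onto these table values, as the lookup-table efficiency above does.
table = sc.DataArray(
    sc.array(dims=['wavelength'], values=[0.90, 0.95, 0.97], unit=''),
    coords={
        'wavelength': sc.array(
            dims=['wavelength'], values=[1.0, 3.0, 6.0, 10.0], unit='angstrom'
        )
    },
)
lut = sc.lookup(table, 'wavelength')
wavelength = sc.array(dims=['event'], values=[2.0, 4.0, 8.0], unit='angstrom')
efficiency = lut(wavelength)          # plain variable, no coords attached
t_plus = 0.5 * (1 + efficiency)       # transmission of the 'plus' spin state
t_minus = 0.5 * (1 - efficiency)      # transmission of the 'minus' spin state
```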
8 changes: 7 additions & 1 deletion src/ess/polarization/types.py
@@ -32,7 +32,7 @@ class TransmissionFunction(Generic[PolarizingElement], ABC):
"""Wavelength- and time-dependent transmission for a given cell."""

@abstractmethod
def apply(self, data: sc.DataArray, plus_minus: PlusMinus) -> sc.DataArray: ...
def apply(self, data: sc.DataArray, plus_minus: PlusMinus) -> sc.Variable: ...


@dataclass
@@ -85,6 +85,12 @@ class PolarizationCorrectedData(Generic[PolarizerSpin, AnalyzerSpin]):
downup: sc.DataArray
downdown: sc.DataArray

def __post_init__(self):
self.upup.name = '(up, up)'
self.updown.name = '(up, down)'
self.downup.name = '(down, up)'
self.downdown.name = '(down, down)'


"""The sum of polarization corrected data from all flipper state channels."""
TotalPolarizationCorrectedData = NewType(
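A small illustration (a hypothetical stand-in, not the library class) of the `__post_init__` naming hook added to `PolarizationCorrectedData`: naming each channel's data array lets downstream plots and tables be labeled automatically.

```python
from dataclasses import dataclass

import scipp as sc


@dataclass
class SpinChannels:  # hypothetical two-channel analogue of PolarizationCorrectedData
    upup: sc.DataArray
    updown: sc.DataArray

    def __post_init__(self) -> None:
        # Runs right after the generated __init__, so every instance is named.
        self.upup.name = '(up, up)'
        self.updown.name = '(up, down)'


channels = SpinChannels(
    upup=sc.DataArray(sc.scalar(1.0)),
    updown=sc.DataArray(sc.scalar(2.0)),
)
assert channels.upup.name == '(up, up)'
```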
103 changes: 10 additions & 93 deletions src/ess/polarization/zoom.py
@@ -1,22 +1,15 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import threading
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic

import mantid.api as _mantid_api
import sciline as sl
import scipp as sc
from mantid import simpleapi as _mantid_simpleapi
from scippnexus import NXsource

import ess.isissans as isis
from ess.isissans.io import LoadedFileContents
from ess.isissans.mantidio import DataWorkspace, Period
from ess.reduce.nexus.types import Position
from ess.sans.types import (
EmptyBeamRun,
Filename,
Incident,
MonitorType,
@@ -26,53 +19,13 @@
SampleRun,
Transmission,
TransmissionRun,
UncertaintyBroadcastMode,
)

# In this case the "sample" is the analyzer cell, of which we want to measure
# the transmission fraction.
sample_run_type = RunType


def load_histogrammed_run(
filename: Filename[sample_run_type], period: Period
) -> DataWorkspace[sample_run_type]:
"""Load a non-event-data ISIS file"""
# Loading many small files with Mantid is, for some reason, very slow when using
# the default number of threads in the Dask threaded scheduler (1 thread worked
# best, 2 is a bit slower but still fast). We can either limit that thread count,
# or add a lock here, which is more specific.
with load_histogrammed_run.lock:
loaded = _mantid_simpleapi.Load(Filename=str(filename), StoreInADS=False)
if isinstance(loaded, _mantid_api.Workspace):
# A single workspace
data_ws = loaded
if isinstance(data_ws, _mantid_api.WorkspaceGroup):
if period is None:
raise ValueError(
f'Needs {Period} to be set to know what '
'section of the event data to load'
)
data_ws = data_ws.getItem(period)
else:
# Separate data and monitor workspaces
data_ws = loaded.OutputWorkspace
if isinstance(data_ws, _mantid_api.WorkspaceGroup):
if period is None:
raise ValueError(
f'Needs {Period} to be set to know what '
'section of the event data to load'
)
data_ws = data_ws.getItem(period)
data_ws.setMonitorWorkspace(loaded.MonitorWorkspace.getItem(period))
else:
data_ws.setMonitorWorkspace(loaded.MonitorWorkspace)
return DataWorkspace[sample_run_type](data_ws)


load_histogrammed_run.lock = threading.Lock()


def _get_time(dg: sc.DataGroup) -> sc.Variable:
start = sc.datetime(dg['run_start'].value)
end = sc.datetime(dg['run_end'].value)
@@ -100,46 +53,24 @@ def _get_unique_position(*positions: sc.DataArray) -> sc.DataArray:
return unique


@dataclass
class MonitorSpectrumNumber(Generic[MonitorType]):
value: int


def get_monitor_data(
dg: LoadedFileContents[RunType], nexus_name: NeXusMonitorName[MonitorType]
def get_monitor_data_no_variances(
dg: LoadedFileContents[RunType],
nexus_name: NeXusMonitorName[MonitorType],
spectrum_number: isis.MonitorSpectrumNumber[MonitorType],
) -> NeXusComponent[MonitorType, RunType]:
"""
Same as :py:func:`ess.isissans.get_monitor_data` but dropping variances.

Dropping variances is a workaround required since ESSsans does not handle
variance broadcasting when combining monitors. In our case some of the monitors
are time-dependent, so this is required for now.
"""
# See https://github.com/scipp/sciline/issues/52 why copy needed
mon = dg['monitors'][nexus_name]['data'].copy()
return NeXusComponent[MonitorType, RunType](
sc.DataGroup(data=sc.values(mon), position=mon.coords['position'])
monitor = isis.general.get_monitor_data(
dg, nexus_name=nexus_name, spectrum_number=spectrum_number
)


def get_monitor_data_from_empty_beam_run(
dg: LoadedFileContents[EmptyBeamRun],
spectrum_number: MonitorSpectrumNumber[MonitorType],
) -> NeXusComponent[MonitorType, EmptyBeamRun]:
"""
Extract incident or transmission monitor from ZOOM empty beam run

The files in this case do not contain detector data, only monitor data. Mantid
stores this as a Workspace2D, where each spectrum corresponds to a monitor.
"""
# Note we index with a scipp.Variable, i.e., by the spectrum number used at ISIS
monitor = sc.values(dg["data"]["spectrum", sc.index(spectrum_number.value)]).copy()
return sc.DataGroup(data=monitor, position=monitor.coords['position'])
monitor['data'] = sc.values(monitor['data'])
return NeXusComponent[MonitorType, RunType](monitor)
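A sketch of the variance-dropping workaround described in the docstring above: `sc.values` returns the data without variances while keeping coordinates, so time-dependent monitors can be combined without running into variance broadcasting.

```python
import scipp as sc

mon = sc.DataArray(
    sc.array(
        dims=['tof'], values=[10.0, 20.0], variances=[10.0, 20.0], unit='counts'
    ),
    coords={'tof': sc.array(dims=['tof'], values=[1.0, 2.0, 3.0], unit='us')},
)
stripped = sc.values(mon)  # same values and coords, variances dropped
assert stripped.variances is None
assert 'tof' in stripped.coords
```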


def get_monitor_data_from_transmission_run(
dg: LoadedFileContents[TransmissionRun[RunType]],
spectrum_number: MonitorSpectrumNumber[MonitorType],
spectrum_number: isis.MonitorSpectrumNumber[MonitorType],
) -> NeXusComponent[MonitorType, TransmissionRun[RunType]]:
"""
Extract incident or transmission monitor from ZOOM direct-beam run
@@ -172,10 +103,8 @@ def ZoomTransmissionFractionWorkflow(runs: Sequence[str]) -> sl.Pipeline:
List of filenames of the runs to use for the transmission fraction.
"""
workflow = isis.zoom.ZoomWorkflow()
workflow.insert(get_monitor_data)
workflow.insert(get_monitor_data_from_empty_beam_run)
workflow.insert(get_monitor_data_no_variances)
workflow.insert(get_monitor_data_from_transmission_run)
workflow.insert(load_histogrammed_run)

mapped = workflow.map({Filename[TransmissionRun[SampleRun]]: runs})
for mon_type in (Incident, Transmission):
@@ -186,16 +115,4 @@ def ZoomTransmissionFractionWorkflow(runs: Sequence[str]) -> sl.Pipeline:
Position[NXsource, TransmissionRun[SampleRun]]
].reduce(func=_get_unique_position)

# We are dealing with two different types of files, and monitors are identified
# differently in each case, so there is some duplication here.
workflow[MonitorSpectrumNumber[Incident]] = MonitorSpectrumNumber[Incident](3)
workflow[MonitorSpectrumNumber[Transmission]] = MonitorSpectrumNumber[Transmission](
4
)
workflow[NeXusMonitorName[Incident]] = NeXusMonitorName[Incident]("monitor3")
workflow[NeXusMonitorName[Transmission]] = NeXusMonitorName[Transmission](
"monitor4"
)
workflow[UncertaintyBroadcastMode] = UncertaintyBroadcastMode.upper_bound

return workflow
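A hedged, minimal sciline sketch of the map/reduce pattern used above: the `Filename` parameter is mapped over the list of runs, and a reduction collapses the mapped results back into a single node. The toy types and providers are hypothetical; only the `map`/`reduce` calls mirror the workflow code.

```python
from typing import NewType

import sciline as sl

Filename = NewType('Filename', str)
RunData = NewType('RunData', int)
CombinedData = NewType('CombinedData', int)


def load(filename: Filename) -> RunData:
    # Stand-in for the Mantid-based loader used in the real workflow.
    return RunData(len(filename))


def combine(*runs: RunData) -> CombinedData:
    # Reduction over all mapped branches, in the spirit of _get_unique_position.
    return CombinedData(sum(runs))


workflow = sl.Pipeline([load])
mapped = workflow.map({Filename: ['run1.nxs', 'run22.nxs', 'run333.nxs']})
workflow[CombinedData] = mapped[RunData].reduce(func=combine)
print(workflow.compute(CombinedData))
```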
24 changes: 24 additions & 0 deletions tests/correction_test.py
@@ -68,6 +68,30 @@ def test_compute_polarizing_element_correction() -> None:
assert_allclose(off_diag, -transmission_minus / denom)


def test_compute_polarizing_element_correction_binned_data() -> None:
time = sc.linspace('event', 1, 10, 10, unit='')
wavelength = sc.linspace('event', 0.1, 1, 10, unit='')
events = sc.DataArray(
sc.arange('event', 10),
coords={'time': time, 'wavelength': wavelength},
)
binned = sc.bins(
data=events, dim='event', begin=sc.array(dims=['Q'], values=[0, 3], unit=None)
)
transmission = SimpleTransmissionFunction()

result = compute_polarizing_element_correction(
channel=binned, transmission=transmission
)
diag = result.diag
off_diag = result.off_diag
transmission_plus = transmission(time, wavelength, 'plus')
transmission_minus = transmission(time, wavelength, 'minus')
denom = transmission_plus**2 - transmission_minus**2
assert_allclose(diag.bins.concat().value, transmission_plus / denom)
assert_allclose(off_diag.bins.concat().value, -transmission_minus / denom)


class FakeTransmissionFunction:
def __init__(self, coeffs: np.ndarray) -> None:
self.coeffs = coeffs
6 changes: 3 additions & 3 deletions tests/he3/opacity_test.py
@@ -18,7 +18,7 @@ def test_opacity_from_cell_params() -> None:
pressure=pressure, length=length, temperature=temperature
)
opacity_function = he3.he3_opacity_function_from_cell_opacity(opacity0)
opacity = opacity_function(wavelength).data
opacity = opacity_function(wavelength)
assert_identical(2 * opacity['pressure', 0], opacity['pressure', 1])
assert_identical(2 * opacity['cell_length', 0], opacity['cell_length', 1])
assert_identical(2 * opacity['wavelength', 0], opacity['wavelength', 1])
@@ -38,7 +38,7 @@ def test_opacity_from_cell_params_reproduces_literature_value() -> None:
pressure=pressure, length=length, temperature=temperature
)
opacity_function = he3.he3_opacity_function_from_cell_opacity(opacity0)
opacity = opacity_function(wavelength).data
opacity = opacity_function(wavelength)
assert sc.isclose(opacity, sc.scalar(0.0733, unit=''), rtol=sc.scalar(1e-3))


@@ -64,7 +64,7 @@ def test_opacity_from_beam_data() -> None:
transmission_fraction=transmission,
opacity0_initial_guess=opacity0 * 1.23, # starting guess imperfect
)
opacity = opacity_function(wavelength).data
opacity = opacity_function(wavelength)
assert sc.isclose(
opacity_function.opacity0, opacity0.to(unit=opacity_function.opacity0.unit)
)