diff --git a/docs/user-guide/amor/amor-reduction.ipynb b/docs/user-guide/amor/amor-reduction.ipynb index 01ffaca2..04115d19 100644 --- a/docs/user-guide/amor/amor-reduction.ipynb +++ b/docs/user-guide/amor/amor-reduction.ipynb @@ -172,7 +172,7 @@ "from ess.reflectometry.tools import scale_reflectivity_curves_to_overlap\n", "results_scaled = dict(zip(\n", " results.keys(),\n", - " scale_reflectivity_curves_to_overlap(results.values()),\n", + " scale_reflectivity_curves_to_overlap(results.values())[0],\n", " strict=True\n", "))\n", "sc.plot(results_scaled, norm='log', vmin=1e-5)" diff --git a/src/ess/reflectometry/tools.py b/src/ess/reflectometry/tools.py index 341fa4f0..86823efc 100644 --- a/src/ess/reflectometry/tools.py +++ b/src/ess/reflectometry/tools.py @@ -170,28 +170,43 @@ def _interpolate_on_qgrid(curves, grid): def scale_reflectivity_curves_to_overlap( curves: Sequence[sc.DataArray], - return_scaling_factors=False, -) -> list[sc.DataArray] | list[sc.scalar]: + critical_edge_interval: tuple[sc.Variable, sc.Variable] | None = None, +) -> tuple[list[sc.DataArray], list[sc.Variable]]: '''Make the curves overlap by scaling all except the first by a factor. The scaling factors are determined by a maximum likelihood estimate (assuming the errors are normal distributed). + If :code:`critical_edge_interval` is provided then all curves are scaled. + All curves must be have the same unit for data and the Q-coordinate. Parameters --------- curves: the reflectivity curves that should be scaled together - return_scaling_factor: - If True the return value of the function - is a list of the scaling factors that should be applied. - If False (default) the function returns the scaled curves. + critical_edge_interval: + a tuple denoting an interval that is known to belong + to the critical edge, i.e. where the reflectivity is + known to be 1. Returns --------- : - A list of scaled reflectivity curves or a list of scaling factors. + A list of scaled reflectivity curves and a list of the scaling factors. ''' + if critical_edge_interval is not None: + q = next(iter(curves)).coords['Q'] + N = ( + ((q >= critical_edge_interval[0]) & (q < critical_edge_interval[1])) + .sum() + .value + ) + edge = sc.DataArray( + data=sc.ones(dims=('Q',), shape=(N,), with_variances=True), + coords={'Q': sc.linspace('Q', *critical_edge_interval, N + 1)}, + ) + curves, factors = scale_reflectivity_curves_to_overlap([edge, *curves]) + return curves[1:], factors[1:] if len({c.data.unit for c in curves}) != 1: raise ValueError('The reflectivity curves must have the same unit') if len({c.coords['Q'].unit for c in curves}) != 1: @@ -214,13 +229,11 @@ def cost(scaling_factors): return np.nansum((r_scaled - r_avg) ** 2 * inv_v_scaled) sol = opt.minimize(cost, [1.0] * (len(curves) - 1)) - scaling_factors = (1.0, *sol.x) - if return_scaling_factors: - return [sc.scalar(x) for x in scaling_factors] + scaling_factors = (1.0, *map(float, sol.x)) return [ scaling_factor * curve for scaling_factor, curve in zip(scaling_factors, curves, strict=True) - ] + ], scaling_factors def combine_curves( diff --git a/tests/tools_test.py b/tests/tools_test.py index b447fcf4..193e9245 100644 --- a/tests/tools_test.py +++ b/tests/tools_test.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) import scipp as sc +from numpy.testing import assert_allclose as np_assert_allclose from scipp.testing import assert_allclose from ess.reflectometry.tools import combine_curves, scale_reflectivity_curves_to_overlap @@ -20,16 +21,17 @@ def test_reflectivity_curve_scaling(): ) data.variances[:] = 0.1 - curves = scale_reflectivity_curves_to_overlap( + curves, factors = scale_reflectivity_curves_to_overlap( (curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)), ) assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5)) assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5)) assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5)) + np_assert_allclose((1, 0.5 / 0.8, 0.25 / 0.1), factors, 1e-4) -def test_reflectivity_curve_scaling_return_factors(): +def test_reflectivity_curve_scaling_with_critical_edge(): data = sc.concat( ( sc.ones(dims=['Q'], shape=[10], with_variances=True), @@ -39,14 +41,19 @@ def test_reflectivity_curve_scaling_return_factors(): ) data.variances[:] = 0.1 - factors = scale_reflectivity_curves_to_overlap( - (curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)), - return_scaling_factors=True, + curves, factors = scale_reflectivity_curves_to_overlap( + ( + 2 * curve(data, 0, 0.3), + curve(0.8 * data, 0.2, 0.7), + curve(0.1 * data, 0.6, 1.0), + ), + critical_edge_interval=(sc.scalar(0.01), sc.scalar(0.05)), ) - assert_allclose(factors[0], sc.scalar(1.0), rtol=sc.scalar(1e-5)) - assert_allclose(factors[1], sc.scalar(0.5 / 0.8), rtol=sc.scalar(1e-5)) - assert_allclose(factors[2], sc.scalar(0.25 / 0.1), rtol=sc.scalar(1e-5)) + assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5)) + assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5)) + assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5)) + np_assert_allclose((0.5, 0.5 / 0.8, 0.25 / 0.1), factors, 1e-4) def test_combined_curves():