Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/user-guide/amor/amor-reduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
"from ess.reflectometry.tools import scale_reflectivity_curves_to_overlap\n",
"results_scaled = dict(zip(\n",
" results.keys(),\n",
" scale_reflectivity_curves_to_overlap(results.values()),\n",
" scale_reflectivity_curves_to_overlap(results.values())[0],\n",
" strict=True\n",
"))\n",
"sc.plot(results_scaled, norm='log', vmin=1e-5)"
Expand Down
35 changes: 24 additions & 11 deletions src/ess/reflectometry/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,28 +170,43 @@ def _interpolate_on_qgrid(curves, grid):

def scale_reflectivity_curves_to_overlap(
curves: Sequence[sc.DataArray],
return_scaling_factors=False,
) -> list[sc.DataArray] | list[sc.scalar]:
critical_edge_interval: tuple[sc.Variable, sc.Variable] | None = None,
) -> tuple[list[sc.DataArray], list[sc.Variable]]:
'''Make the curves overlap by scaling all except the first by a factor.
The scaling factors are determined by a maximum likelihood estimate
(assuming the errors are normal distributed).

If :code:`critical_edge_interval` is provided then all curves are scaled.

    All curves must have the same unit for the data and the Q-coordinate.

    Parameters
    ----------
curves:
the reflectivity curves that should be scaled together
return_scaling_factor:
If True the return value of the function
is a list of the scaling factors that should be applied.
If False (default) the function returns the scaled curves.
critical_edge_interval:
a tuple denoting an interval that is known to belong
to the critical edge, i.e. where the reflectivity is
known to be 1.

    Returns
    -------
:
A list of scaled reflectivity curves or a list of scaling factors.
A list of scaled reflectivity curves and a list of the scaling factors.
'''
if critical_edge_interval is not None:
q = next(iter(curves)).coords['Q']
N = (
((q >= critical_edge_interval[0]) & (q < critical_edge_interval[1]))
.sum()
.value
)
edge = sc.DataArray(
data=sc.ones(dims=('Q',), shape=(N,), with_variances=True),
coords={'Q': sc.linspace('Q', *critical_edge_interval, N + 1)},
)
curves, factors = scale_reflectivity_curves_to_overlap([edge, *curves])
return curves[1:], factors[1:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to set the scale factor as a coord in each curve? Returning a tuple of lists seems a bit error prone.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now I prefer the tuple of lists because it's simple and obvious. For example, setting it as a coord requires the user to know the name of the coord to extract the return value. I don't directly see why it would be more error prone, did you have a particular situation in mind?

if len({c.data.unit for c in curves}) != 1:
raise ValueError('The reflectivity curves must have the same unit')
if len({c.coords['Q'].unit for c in curves}) != 1:
Expand All @@ -214,13 +229,11 @@ def cost(scaling_factors):
return np.nansum((r_scaled - r_avg) ** 2 * inv_v_scaled)

sol = opt.minimize(cost, [1.0] * (len(curves) - 1))
scaling_factors = (1.0, *sol.x)
if return_scaling_factors:
return [sc.scalar(x) for x in scaling_factors]
scaling_factors = (1.0, *map(float, sol.x))
return [
scaling_factor * curve
for scaling_factor, curve in zip(scaling_factors, curves, strict=True)
]
], scaling_factors


def combine_curves(
Expand Down
23 changes: 15 additions & 8 deletions tests/tools_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import scipp as sc
from numpy.testing import assert_allclose as np_assert_allclose
from scipp.testing import assert_allclose

from ess.reflectometry.tools import combine_curves, scale_reflectivity_curves_to_overlap
Expand All @@ -20,16 +21,17 @@ def test_reflectivity_curve_scaling():
)
data.variances[:] = 0.1

curves = scale_reflectivity_curves_to_overlap(
curves, factors = scale_reflectivity_curves_to_overlap(
(curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)),
)

assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5))
assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5))
assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5))
np_assert_allclose((1, 0.5 / 0.8, 0.25 / 0.1), factors, 1e-4)


def test_reflectivity_curve_scaling_return_factors():
def test_reflectivity_curve_scaling_with_critical_edge():
data = sc.concat(
(
sc.ones(dims=['Q'], shape=[10], with_variances=True),
Expand All @@ -39,14 +41,19 @@ def test_reflectivity_curve_scaling_return_factors():
)
data.variances[:] = 0.1

factors = scale_reflectivity_curves_to_overlap(
(curve(data, 0, 0.3), curve(0.8 * data, 0.2, 0.7), curve(0.1 * data, 0.6, 1.0)),
return_scaling_factors=True,
curves, factors = scale_reflectivity_curves_to_overlap(
(
2 * curve(data, 0, 0.3),
curve(0.8 * data, 0.2, 0.7),
curve(0.1 * data, 0.6, 1.0),
),
critical_edge_interval=(sc.scalar(0.01), sc.scalar(0.05)),
)

assert_allclose(factors[0], sc.scalar(1.0), rtol=sc.scalar(1e-5))
assert_allclose(factors[1], sc.scalar(0.5 / 0.8), rtol=sc.scalar(1e-5))
assert_allclose(factors[2], sc.scalar(0.25 / 0.1), rtol=sc.scalar(1e-5))
assert_allclose(curves[0].data, data, rtol=sc.scalar(1e-5))
assert_allclose(curves[1].data, 0.5 * data, rtol=sc.scalar(1e-5))
assert_allclose(curves[2].data, 0.25 * data, rtol=sc.scalar(1e-5))
np_assert_allclose((0.5, 0.5 / 0.8, 0.25 / 0.1), factors, 1e-4)


def test_combined_curves():
Expand Down