Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions src/ess/imaging/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def calculate_scale_factor(
def apply_threshold_to_sample_images(
samples: CleansedSampleImages, sample_threshold: SamplePixelThreshold
) -> SampleImageStacks:
"""Apply the threshold to the sample image stack.
"""Apply a mask based on the threshold to the sample image stack.

Parameters
----------
Expand All @@ -190,21 +190,18 @@ def apply_threshold_to_sample_images(

sample_threshold:
Threshold for the sample pixel values.
Any pixel values less than ``sample_threshold``
are replaced with ``sample_threshold``.
Any pixel values less than ``sample_threshold`` will be masked.

"""
samples = CleansedSampleImages(samples.copy(deep=False))
samples.data = sc.where(
samples.data < sample_threshold, sample_threshold, samples.data
return SampleImageStacks(
samples.assign_masks(counts=samples.data < sample_threshold)
)
return SampleImageStacks(samples)


def apply_threshold_to_background_image(
background: CleansedOpenBeamImage, background_threshold: BackgroundPixelThreshold
) -> BackgroundImage:
"""Apply the threshold to the background image.
"""Apply a mask based on the threshold to the background image.

Parameters
----------
Expand All @@ -213,15 +210,12 @@ def apply_threshold_to_background_image(

background_threshold:
Threshold for the background pixel values.
Any pixel values less than ``background_threshold``
are replaced with ``background_threshold``.
Any pixel values less than ``background_threshold`` will be masked.

"""
background = CleansedOpenBeamImage(background.copy(deep=False))
background.data = sc.where(
background.data < background_threshold, background_threshold, background.data
return BackgroundImage(
background.assign_masks(counts=background.data < background_threshold)
)
return BackgroundImage(background)


def normalize_sample_images(
Expand Down
42 changes: 34 additions & 8 deletions tests/image_normalize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,32 +188,43 @@ def test_apply_threshold_to_sample_images() -> None:
},
)
threshold = sc.scalar(1.0, unit="counts")
mask = sc.array(
dims=["time", "dim_1", "dim_2"],
values=[[[False, False], [False, True]], [[False, False], [False, True]]],
)
thresholded_sample_images = apply_threshold_to_sample_images(
CleansedSampleImages(sample_images_with_negative_values),
SamplePixelThreshold(threshold),
)
with pytest.raises(AssertionError, match="Arrays are not equal"):
assert_identical(sample_images_with_negative_values.data.min(), threshold)
assert_identical(thresholded_sample_images.data.min(), threshold)
assert_identical(
thresholded_sample_images,
sample_images_with_negative_values.assign_masks(counts=mask),
)


def test_apply_threshold_to_background_image() -> None:
background_image_with_negative_values = sc.DataArray(
background_image_with_zeros = sc.DataArray(
data=sc.array(
dims=["dim_1", "dim_2"],
values=[[3.0, 3.0], [3.0, -1.0]],
values=[[3.0, 3.0], [3.0, 0.0]],
unit="counts",
),
coords={},
)
threshold = sc.scalar(1.0, unit="counts")
mask = sc.array(dims=["dim_1", "dim_2"], values=[[False, False], [False, True]])
thresholded_background_image = apply_threshold_to_background_image(
CleansedOpenBeamImage(background_image_with_negative_values),
CleansedOpenBeamImage(background_image_with_zeros),
BackgroundPixelThreshold(threshold),
)
with pytest.raises(AssertionError, match="Arrays are not equal"):
assert_identical(background_image_with_negative_values.data.min(), threshold)
assert_identical(thresholded_background_image.data.min(), threshold)
assert_identical(background_image_with_zeros.data.min(), threshold)
assert_identical(
thresholded_background_image,
background_image_with_zeros.assign_masks(counts=mask),
)


def test_normalize_negative_scale_factor_raises(
Expand Down Expand Up @@ -242,14 +253,29 @@ def test_normalize_workflow(
data=sc.array(
dims=["time", "dim_1", "dim_2"],
values=[
[[1 / 3 * (5 / 3), 1 / 3 * (5 / 3)], [1 / 3 * (5 / 3), 0.0]],
[[3 / 3 * (5 / 3), 3 / 3 * (5 / 3)], [3 / 3 * (5 / 3), 0.0]],
[
[(2 - 1) / ((4 - 1) / 1.6), (2 - 1) / ((4 - 1) / 1.6)],
[(2 - 1) / ((4 - 1) / 1.6), (0 - 1) / ((0 - 1) / 1.6)],
],
[
[(4 - 1) / ((4 - 1) / 1.6), (4 - 1) / ((4 - 1) / 1.6)],
[(4 - 1) / ((4 - 1) / 1.6), (0 - 1) / ((0 - 1) / 1.6)],
],
],
unit="counts",
),
coords={
"time": sc.array(dims=["time"], values=[1, 2], unit="s"),
},
masks={
"counts": sc.array(
dims=["time", "dim_1", "dim_2"],
values=[
[[False, False], [False, True]],
[[False, False], [False, True]],
],
)
},
)

wf = YmirImageNormalizationWorkflow()
Expand Down
Loading