diff --git a/kwave/enums.py b/kwave/enums.py index d17fdb3d6..931cbea55 100644 --- a/kwave/enums.py +++ b/kwave/enums.py @@ -1,5 +1,17 @@ from enum import Enum + +class AlphaMode(str, Enum): + """Controls which absorption/dispersion terms are included in the equation of state.""" + + NO_ABSORPTION = "no_absorption" + NO_DISPERSION = "no_dispersion" + STOKES = "stokes" + + def __str__(self): + return self.value + + ################################################################ # literals that link the discrete cosine and sine transform types with # their type definitions in the functions dtt1D, dtt2D, and dtt3D diff --git a/kwave/kmedium.py b/kwave/kmedium.py index 2886299d3..8f554db70 100644 --- a/kwave/kmedium.py +++ b/kwave/kmedium.py @@ -1,10 +1,23 @@ import logging from dataclasses import dataclass -from typing import List +from typing import List, Optional, Union import numpy as np import kwave.utils.checks +from kwave.enums import AlphaMode + + +def _to_alpha_mode(value): + """Normalize a value to AlphaMode. Accepts None, AlphaMode, or a valid string.""" + if value is None or isinstance(value, AlphaMode): + return value + try: + return AlphaMode(value) + except (ValueError, TypeError): + raise ValueError( + f"medium.alpha_mode must be an AlphaMode enum value or one of " f"'no_absorption', 'no_dispersion', 'stokes', got {value!r}" + ) from None @dataclass @@ -20,8 +33,8 @@ class kWaveMedium(object): # power law absorption exponent alpha_power: np.array = None # optional input to force either the absorption or dispersion terms in the equation of state to be excluded; - # valid inputs are 'no_absorption' or 'no_dispersion' - alpha_mode: np.array = None + # valid inputs are AlphaMode.NO_ABSORPTION, AlphaMode.NO_DISPERSION, or the equivalent strings + alpha_mode: Optional[Union[AlphaMode, str]] = None # frequency domain filter applied to the absorption and dispersion terms in the equation of state alpha_filter: np.array = None # two element array used to control the sign of absorption and dispersion terms in the equation of state @@ -43,6 +56,7 @@ class kWaveMedium(object): def __post_init__(self): self.sound_speed = np.atleast_1d(self.sound_speed) + self.alpha_mode = _to_alpha_mode(self.alpha_mode) def check_fields(self, kgrid_shape: np.ndarray) -> None: """ @@ -54,13 +68,8 @@ def check_fields(self, kgrid_shape: np.ndarray) -> None: Returns: None """ - # check the absorption mode input is valid - if self.alpha_mode is not None: - assert self.alpha_mode in [ - "no_absorption", - "no_dispersion", - "stokes", - ], "medium.alpha_mode must be set to 'no_absorption', 'no_dispersion', or 'stokes'." + # re-normalize alpha_mode in case it was reassigned as a plain string post-construction + self.alpha_mode = _to_alpha_mode(self.alpha_mode) # check the absorption filter input is valid if self.alpha_filter is not None and not (self.alpha_filter.shape == kgrid_shape).all(): diff --git a/tests/test_kmedium.py b/tests/test_kmedium.py new file mode 100644 index 000000000..cbfdff60d --- /dev/null +++ b/tests/test_kmedium.py @@ -0,0 +1,46 @@ +"""Tests for kWaveMedium alpha_mode normalization and validation.""" +import numpy as np +import pytest + +from kwave.enums import AlphaMode +from kwave.kmedium import kWaveMedium + + +class TestAlphaModeNormalization: + def test_default_is_none(self): + m = kWaveMedium(sound_speed=1500) + assert m.alpha_mode is None + + def test_enum_passes_through(self): + m = kWaveMedium(sound_speed=1500, alpha_mode=AlphaMode.NO_DISPERSION) + assert m.alpha_mode is AlphaMode.NO_DISPERSION + + @pytest.mark.parametrize("value", ["no_absorption", "no_dispersion", "stokes"]) + def test_valid_string_normalized_at_construction(self, value): + m = kWaveMedium(sound_speed=1500, alpha_mode=value) + assert isinstance(m.alpha_mode, AlphaMode) + assert m.alpha_mode == value + + def test_invalid_string_at_construction_raises_friendly_error(self): + with pytest.raises(ValueError, match="must be an AlphaMode"): + kWaveMedium(sound_speed=1500, alpha_mode="garbage") + + def test_post_construction_string_assignment_accepted_by_check_fields(self): + m = kWaveMedium(sound_speed=1500, alpha_coeff=np.array(0.5), alpha_power=1.5) + m.alpha_mode = "no_dispersion" # plain string per type hint + m.check_fields(np.array([64, 64])) + # check_fields normalizes for downstream consumers + assert isinstance(m.alpha_mode, AlphaMode) + assert m.alpha_mode == "no_dispersion" + + def test_post_construction_invalid_string_rejected_by_check_fields(self): + m = kWaveMedium(sound_speed=1500, alpha_coeff=np.array(0.5), alpha_power=1.5) + m.alpha_mode = "garbage" + with pytest.raises(ValueError, match="must be an AlphaMode"): + m.check_fields(np.array([64, 64])) + + def test_string_comparison_still_works(self): + # AlphaMode inherits from str, so == against raw strings must keep working + m = kWaveMedium(sound_speed=1500, alpha_mode="no_dispersion") + assert m.alpha_mode == "no_dispersion" + assert m.alpha_mode in ["no_absorption", "no_dispersion"]