diff --git a/src/pals/parameters/MagneticMultipoleParameters.py b/src/pals/parameters/MagneticMultipoleParameters.py index a1b7dd0..7c4b176 100644 --- a/src/pals/parameters/MagneticMultipoleParameters.py +++ b/src/pals/parameters/MagneticMultipoleParameters.py @@ -1,62 +1,67 @@ from pydantic import BaseModel, ConfigDict, model_validator -from typing import Any, Dict +from typing import Any + +# Valid parameter prefixes, their expected format and description +_PARAMETER_PREFIXES = { + "tilt": ("tiltN", "Tilt"), + "Bn": ("BnN", "Normal component"), + "Bs": ("BsN", "Skew component"), + "Kn": ("KnN", "Normalized normal component"), + "Ks": ("KsN", "Normalized skew component"), +} + + +def _validate_order( + key_num: str, parameter_name: str, prefix: str, expected_format: str +) -> None: + """Validate that the order number is a non-negative integer without leading zeros.""" + error_msg = ( + f"Invalid {parameter_name}: '{prefix}{key_num}'. " + f"Parameter must be of the form '{expected_format}', where 'N' is a non-negative integer without leading zeros." + ) + if not key_num.isdigit() or (key_num.startswith("0") and key_num != "0"): + raise ValueError(error_msg) class MagneticMultipoleParameters(BaseModel): - """Magnetic multipole parameters""" + """Magnetic multipole parameters - # Allow arbitrary fields - model_config = ConfigDict(extra="allow") + Valid parameter formats: + - tiltN: Tilt of Nth order multipole + - BnN: Normal component of Nth order multipole + - BsN: Skew component of Nth order multipole + - KnN: Normalized normal component of Nth order multipole + - KsN: Normalized skew component of Nth order multipole + - *NL: Length-integrated versions of components (e.g., Bn3L, KsNL) + + Where N is a positive integer without leading zeros (except "0" itself). + """ - # Custom validation of magnetic multipole order - def _validate_order(key_num, msg): - if key_num.isdigit(): - if key_num.startswith("0") and key_num != "0": - raise ValueError(msg) - else: - raise ValueError(msg) + model_config = ConfigDict(extra="allow") - # Custom validation to be applied before standard validation @model_validator(mode="before") - def validate(cls, values: Dict[str, Any]) -> Dict[str, Any]: - # loop over all attributes + @classmethod + def validate(cls, values: dict[str, Any]) -> dict[str, Any]: + """Validate all parameter names match the expected multipole format.""" for key in values: - # validate tilt parameters 'tiltN' - if key.startswith("tilt"): - key_num = key[4:] - msg = " ".join( - [ - f"Invalid tilt parameter: '{key}'.", - "Tilt parameter must be of the form 'tiltN', where 'N' is an integer.", - ] - ) - cls._validate_order(key_num, msg) - # validate normal component parameters 'BnN' - elif key.startswith("Bn"): - key_num = key[2:] - msg = " ".join( - [ - f"Invalid normal component parameter: '{key}'.", - "Normal component parameter must be of the form 'BnN', where 'N' is an integer.", - ] - ) - cls._validate_order(key_num, msg) - # validate skew component parameters 'BsN' - elif key.startswith("Bs"): - key_num = key[2:] - msg = " ".join( - [ - f"Invalid skew component parameter: '{key}'.", - "Skew component parameter must be of the form 'BsN', where 'N' is an integer.", - ] - ) - cls._validate_order(key_num, msg) + # Check if key ends with 'L' for length-integrated values + is_length_integrated = key.endswith("L") + base_key = key[:-1] if is_length_integrated else key + + # No length-integrated values allowed for tilt parameter + if is_length_integrated and base_key.startswith("tilt"): + raise ValueError(f"Invalid magnetic multipole parameter: '{key}'. ") + + # Find matching prefix + for prefix, (expected_format, description) in _PARAMETER_PREFIXES.items(): + if base_key.startswith(prefix): + key_num = base_key[len(prefix) :] + _validate_order(key_num, description, prefix, expected_format) + break else: - msg = " ".join( - [ - f"Invalid magnetic multipole parameter: '{key}'.", - "Magnetic multipole parameters must be of the form 'tiltN', 'BnN', or 'BsN', where 'N' is an integer.", - ] + raise ValueError( + f"Invalid magnetic multipole parameter: '{key}'. " + f"Parameters must be of the form 'tiltN', 'BnN', 'BsN', 'KnN', or 'KsN' " + f"(with optional 'L' suffix for length-integrated), where 'N' is a non-negative integer." ) - raise ValueError(msg) return values diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 21eb8f9..ba904e9 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -43,10 +43,19 @@ def test_ParameterClasses(): # assert emp.En1 == 1.0 # Test MagneticMultipoleParameters - mmp = MagneticMultipoleParameters(Bn1=1.0, Bs1=0.5) + mmp = MagneticMultipoleParameters(tilt1=1.2, Bn1=1.0, Bs1=0.5) + assert mmp.tilt1 == 1.2 assert mmp.Bn1 == 1.0 assert mmp.Bs1 == 0.5 + mmp2 = MagneticMultipoleParameters(Kn0=1.0, Ks1=0.5) + assert mmp2.Kn0 == 1.0 + assert mmp2.Ks1 == 0.5 + + mmp3 = MagneticMultipoleParameters(Bn1L=1.0, Bs1L=0.5) + assert mmp3.Bn1L == 1.0 + assert mmp3.Bs1L == 0.5 + # catch typos with pytest.raises(ValidationError): _ = MagneticMultipoleParameters(Bm1=1.0, Bs1=0.5) @@ -54,6 +63,10 @@ def test_ParameterClasses(): _ = MagneticMultipoleParameters(Bn1=1.0, Bv1=0.5) with pytest.raises(ValidationError): _ = MagneticMultipoleParameters(Bn01=1.0, Bs01=0.5) + with pytest.raises(ValidationError): + _ = MagneticMultipoleParameters(Bn1v=1.0, Bs1l=0.5) + with pytest.raises(ValidationError): + _ = MagneticMultipoleParameters(tilt1L=1.2) # Test SolenoidParameters sol = SolenoidParameters(Ksol=0.1, Bsol=0.2)