-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathutils.py
257 lines (208 loc) · 7.19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
# -*- mode: python; coding: utf-8 -*-
#
# Copyright (C) 2023 Benjamin Thomas Schwertfeger
# All rights reserved.
# https://github.com/btschwertfeger
#
"""Module providing utility functions"""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
from cmethods.types import NPData_t, XRData_t
if TYPE_CHECKING:
from cmethods.types import NPData, XRData
class UnknownMethodError(Exception):
    """
    Error signalling that a requested method name is not known.

    The exception message names the rejected method and lists the methods
    that are available instead.

    :param method: Name of the method that was requested
    :type method: str
    :param available_methods: Methods that could have been used
    :type available_methods: list
    """

    def __init__(self: UnknownMethodError, method: str, available_methods: list):
        message = f'Unknown method "{method}"! Available methods: {available_methods}'
        super().__init__(message)
def check_adjust_called(
    function_name: str,
    adjust_called: Optional[bool] = None,  # noqa: FBT001
) -> None:
    """
    Emits a ``UserWarning`` unless the caller states that the correction
    function was reached via `CMethods.adjust`.

    :param function_name: Name of the function that was called; it is
        embedded in the warning text
    :type function_name: str
    :param adjust_called: Truthy if the function was called via adjust;
        ``None`` (the default) and ``False`` both trigger the warning
    :type adjust_called: Optional[bool]
    """
    if adjust_called:
        # Called the intended way, nothing to report.
        return

    text = f"Do not call {function_name} directly, use `CMethods.adjust` instead!"
    warnings.warn(message=text, category=UserWarning, stacklevel=1)
def ensure_xr_dataarray(obs: XRData, simh: XRData, simp: XRData) -> None:
    """
    Validates that all three inputs are instances of ``XRData_t``.
    **only used internally**

    The arguments are checked in the order ``obs``, ``simh``, ``simp`` and
    the first one that fails the check is reported.

    :param obs: First data set to validate
    :param simh: Second data set to validate
    :param simp: Third data set to validate
    :raises TypeError: If any argument is not an instance of ``XRData_t``
    """
    phrase: str = "must be type 'xarray.core.dataarray.DataArray'."

    for label, candidate in (("obs", obs), ("simh", simh), ("simp", simp)):
        if not isinstance(candidate, XRData_t):
            raise TypeError(f"'{label}' {phrase}")
def check_np_types(
    obs: NPData,
    simh: NPData,
    simp: NPData,
) -> None:
    """
    Validates that all three inputs are instances of ``NPData_t``.
    **only used internally**

    The arguments are checked in the order ``obs``, ``simh``, ``simp`` and
    the first one that fails the check is reported.

    :param obs: First data set to validate
    :param simh: Second data set to validate
    :param simp: Third data set to validate
    :raises TypeError: If any argument is not an instance of ``NPData_t``
    """
    phrase: str = "must be type list, np.ndarray, or np.generic"

    for label, candidate in (("obs", obs), ("simh", simh), ("simp", simp)):
        if not isinstance(candidate, NPData_t):
            raise TypeError(f"'{label}' {phrase}")
def nan_or_equal(value1: float, value2: float) -> bool:
    """
    Tells whether two values are equal, treating NaN as a wildcard.

    :param value1: First value to check
    :type value1: float
    :param value2: Second value to check
    :type value2: float
    :return: True if at least one value is NaN or if both values are equal
    :rtype: bool
    """
    any_nan = np.isnan(value1) or np.isnan(value2)
    if any_nan:
        return any_nan
    return value1 == value2
def ensure_dividable(
    numerator: Union[float, np.ndarray],
    denominator: Union[float, np.ndarray],
    max_scaling_factor: float,
) -> Union[float, np.ndarray]:
    """
    Divides ``numerator`` by ``denominator`` without producing infinite or
    NaN quotients.

    Wherever the quotient is infinite (e.g. division of a non-zero value by
    zero), it is replaced by the numerator multiplied by
    ``max_scaling_factor``. Wherever the quotient is NaN (e.g. ``0 / 0``), it
    is replaced by zero.

    Scalars and arrays may be mixed freely; the inputs are broadcast against
    each other following the usual NumPy rules.

    :param numerator: Numerator to use
    :type numerator: float | np.ndarray
    :param denominator: Denominator that can be zero
    :type denominator: float | np.ndarray
    :param max_scaling_factor: Factor applied to the numerator where the
        quotient would be infinite
    :type max_scaling_factor: float
    :return: Zero-ensured division
    :rtype: np.ndarray | float
    """
    # ``np.true_divide`` is used instead of ``/`` so that plain Python floats
    # yield inf/NaN as well; ``1.0 / 0.0`` would raise ZeroDivisionError
    # regardless of ``np.errstate``.
    with np.errstate(divide="ignore", invalid="ignore"):
        result = np.true_divide(numerator, denominator)

    # Branch on the *result*: it is an array as soon as either operand is one.
    if isinstance(result, np.ndarray):
        # Align the numerator with the (possibly broadcast) result so that the
        # boolean mask can index it, even if the numerator is a scalar.
        aligned_numerator = np.broadcast_to(numerator, result.shape)

        mask_inf = np.isinf(result)
        result[mask_inf] = aligned_numerator[mask_inf] * max_scaling_factor

        # Evaluated after the replacement above, so NaNs created by it
        # (e.g. inf * 0) are zeroed as well.
        mask_nan = np.isnan(result)
        result[mask_nan] = 0
        return result

    if np.isinf(result):
        return numerator * max_scaling_factor
    if np.isnan(result):
        return 0.0
    return result
def get_pdf(
    x: Union[list, np.ndarray],
    xbins: Union[list, np.ndarray],
) -> np.ndarray:
    r"""
    Computes and returns the probability density function :math:`P(x)`
    of ``x`` based on ``xbins``.

    The result holds the number of values of ``x`` per bin as computed by
    ``np.histogram``; it is not normalized.

    :param x: The vector to get :math:`P(x)` from
    :type x: list | np.ndarray
    :param xbins: The boundaries/bins of :math:`P(x)`
    :type xbins: list | np.ndarray
    :return: The probability density function of ``x``
    :rtype: np.ndarray

    .. code-block:: python
        :linenos:
        :caption: Compute the probability density function :math:`P(x)`

        >>> from cmethods.utils import get_pdf

        >>> x = [1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 9, 10]
        >>> xbins = [0, 3, 6, 10]
        >>> print(get_pdf(x=x, xbins=xbins))
        [2 5 5]
    """
    # np.histogram returns (counts, bin_edges); only the counts are needed.
    return np.histogram(x, xbins)[0]
def get_cdf(
    x: Union[list, np.ndarray],
    xbins: Union[list, np.ndarray],
) -> np.ndarray:
    r"""
    Computes and returns the cumulative distribution function :math:`F(x)`
    of ``x`` based on ``xbins``.

    The returned array has one entry per bin boundary: it starts with zero
    and is scaled by its last entry (the total number of binned values).

    :param x: Vector to get :math:`F(x)` from
    :type x: list | np.ndarray
    :param xbins: The boundaries/bins of :math:`F(x)`
    :type xbins: list | np.ndarray
    :return: The cumulative distribution function of ``x``
    :rtype: np.ndarray

    .. code-block:: python
        :linenos:
        :caption: Compute the cumulative distribution function :math:`F(x)`

        >>> from cmethods.utils import get_cdf

        >>> x = [1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 9, 10]
        >>> xbins = [0, 3, 6, 10]
        >>> print(get_cdf(x=x, xbins=xbins))
        [0.         0.16666667 0.58333333 1.        ]
    """
    counts = np.histogram(x, xbins)[0]
    running_total = np.cumsum(counts)
    # Prepend a zero so that the CDF has a value at the lowest bin boundary.
    unscaled = np.insert(running_total, 0, 0.0)
    total = unscaled[-1]
    return unscaled / total
def get_inverse_of_cdf(
    base_cdf: Union[list, np.ndarray],
    insert_cdf: Union[list, np.ndarray],
    xbins: Union[list, np.ndarray],
) -> np.ndarray:
    r"""
    Returns the inverse cumulative distribution function as:
    :math:`F^{-1}_{x}\left[y\right]` where :math:`x` represents ``base_cdf`` and
    ``insert_cdf`` is represented by :math:`y`.

    The values are obtained by linear interpolation (``np.interp``): the
    pairs (``base_cdf``, ``xbins``) serve as sample points and are evaluated
    at the positions given by ``insert_cdf``.

    :param base_cdf: The basis
    :type base_cdf: list | np.ndarray
    :param insert_cdf: The CDF that gets inserted
    :type insert_cdf: list | np.ndarray
    :param xbins: Probability boundaries
    :type xbins: list | np.ndarray
    :return: The inverse CDF
    :rtype: np.ndarray
    """
    inverse_cdf = np.interp(x=insert_cdf, xp=base_cdf, fp=xbins)
    return inverse_cdf
def get_adjusted_scaling_factor(
    factor: float,
    max_scaling_factor: float,
) -> float:
    r"""
    Limits ``factor`` to the symmetric interval spanned by the absolute value
    of ``max_scaling_factor``.

    Returns:
        - :math:`x` if :math:`-|y| \le x \le |y|`,
        - :math:`|y|` if :math:`x > |y|`, or
        - :math:`-|y|` if :math:`x < -|y|`

    where:
        - :math:`x` is ``factor``
        - :math:`y` is ``max_scaling_factor``

    :param factor: The value to check for
    :type factor: int | float
    :param max_scaling_factor: The maximum/minimum allowed value
    :type max_scaling_factor: int | float
    :return: The correct value
    :rtype: float
    """
    limit = abs(max_scaling_factor)

    # ``limit`` is never negative, hence exceeding it already implies a
    # positive factor and undercutting ``-limit`` implies a negative one.
    if factor > limit:
        return limit
    if factor < -limit:
        return -limit
    return factor
# Names exported by this module when it is star-imported
# (``from <module> import *``); entries are kept in alphabetical order.
__all__ = [
"UnknownMethodError",
"check_adjust_called",
"check_np_types",
"ensure_dividable",
"ensure_xr_dataarray",
"get_adjusted_scaling_factor",
"get_cdf",
"get_inverse_of_cdf",
"get_pdf",
"nan_or_equal",
]