diff --git a/sfs/util.py b/sfs/util.py index 006bf47f..f91c4a6a 100644 --- a/sfs/util.py +++ b/sfs/util.py @@ -361,7 +361,7 @@ def db(x, power=False): """ with np.errstate(divide='ignore'): - return 10 if power else 20 * np.log10(np.abs(x)) + return (10 if power else 20) * np.log10(np.abs(x)) class XyzComponents(np.ndarray): diff --git a/tests/test_util.py b/tests/test_util.py index 89ff89d1..d46fd76c 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,14 +4,15 @@ import sfs cart_sph_data = [ - ((1, 1, 1), (np.pi/4, np.arccos(1/np.sqrt(3)), np.sqrt(3))), - ((-1, 1, 1), (np.arctan2(1, -1), np.arccos(1/np.sqrt(3)), np.sqrt(3))), - ((1, -1, 1), (-np.pi/4, np.arccos(1/np.sqrt(3)), np.sqrt(3))), - ((-1, -1, 1), (np.arctan2(-1, -1), np.arccos(1/np.sqrt(3)), np.sqrt(3))), - ((1, 1, -1), (np.pi/4, np.arccos(-1/np.sqrt(3)), np.sqrt(3))), - ((-1, 1, -1), (np.arctan2(1, -1), np.arccos(-1/np.sqrt(3)), np.sqrt(3))), - ((1, -1, -1), (-np.pi/4, np.arccos(-1/np.sqrt(3)), np.sqrt(3))), - ((-1, -1, -1), (np.arctan2(-1, -1), np.arccos(-1/np.sqrt(3)), np.sqrt(3))), + ((1, 1, 1), (np.pi / 4, np.arccos(1 / np.sqrt(3)), np.sqrt(3))), + ((-1, 1, 1), (np.arctan2(1, -1), np.arccos(1 / np.sqrt(3)), np.sqrt(3))), + ((1, -1, 1), (-np.pi / 4, np.arccos(1 / np.sqrt(3)), np.sqrt(3))), + ((-1, -1, 1), (np.arctan2(-1, -1), np.arccos(1 / np.sqrt(3)), np.sqrt(3))), + ((1, 1, -1), (np.pi / 4, np.arccos(-1 / np.sqrt(3)), np.sqrt(3))), + ((-1, 1, -1), (np.arctan2(1, -1), np.arccos(-1 / np.sqrt(3)), np.sqrt(3))), + ((1, -1, -1), (-np.pi / 4, np.arccos(-1 / np.sqrt(3)), np.sqrt(3))), + ((-1, -1, -1), (np.arctan2(-1, -1), + np.arccos(-1 / np.sqrt(3)), np.sqrt(3))), ] @@ -27,3 +28,43 @@ def test_sph2cart(coord, polar): alpha, beta, r = polar b = sfs.util.sph2cart(alpha, beta, r) assert_allclose(b, coord) + + +direction_vector_data = [ + ((np.pi / 4, np.pi / 4), (0.5, 0.5, np.sqrt(2) / 2)), + ((3 * np.pi / 4, 3 * np.pi / 4), (-1 / 2, 1 / 2, -np.sqrt(2) / 2)), + ((3 * np.pi / 4, -3 * np.pi / 4), (1 / 2, -1 / 2, -np.sqrt(2) / 2)), + ((np.pi / 4, -np.pi / 4), (-1 / 2, -1 / 2, np.sqrt(2) / 2)), + ((-np.pi / 4, np.pi / 4), (1 / 2, -1 / 2, np.sqrt(2) / 2)), + ((-3 * np.pi / 4, 3 * np.pi / 4), (-1 / 2, -1 / 2, -np.sqrt(2) / 2)), + ((-3 * np.pi / 4, -3 * np.pi / 4), (1 / 2, 1 / 2, -np.sqrt(2) / 2)), + ((-np.pi / 4, -np.pi / 4), (-1 / 2, 1 / 2, np.sqrt(2) / 2)), +] + + +@pytest.mark.parametrize('input, vector', direction_vector_data) +def test_direction_vector(input, vector): + alpha, beta = input + c = sfs.util.direction_vector(alpha, beta) + assert_allclose(c, vector) + + +db_data = [ + (0, -np.inf), + (0.5, -3.01029995663981), + (1, 0), + (2, 3.01029995663981), + (10, 10), +] + + +@pytest.mark.parametrize('linear, power_db', db_data) +def test_db_amplitude(linear, power_db): + d = sfs.util.db(linear) + assert_allclose(d, power_db * 2) + + +@pytest.mark.parametrize('linear, power_db', db_data) +def test_db_power(linear, power_db): + d = sfs.util.db(linear, power=True) + assert_allclose(d, power_db)