Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax arithmetic rules in case only one audio object is involved #606

Merged
merged 7 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 20 additions & 10 deletions pyfar/classes/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,8 @@ def add(data: tuple, domain='freq'):

The `fft_norm` of the result is as follows

* If only one signal is involved in the operation, the result gets the same
normalization.
* If one signal has the FFT normalization ``'none'``, the results gets
the normalization of the other signal.
* If both signals have the same FFT normalization, the results gets the
Expand Down Expand Up @@ -951,6 +953,8 @@ def subtract(data: tuple, domain='freq'):

The `fft_norm` of the result is as follows

* If only one signal is involved in the operation, the result gets the same
normalization.
* If one signal has the FFT normalization ``'none'``, the results gets
the normalization of the other signal.
* If both signals have the same FFT normalization, the results gets the
Expand Down Expand Up @@ -996,6 +1000,8 @@ def multiply(data: tuple, domain='freq'):

The `fft_norm` of the result is as follows

* If only one signal is involved in the operation, the result gets the same
normalization.
* If one signal has the FFT normalization ``'none'``, the results gets
the normalization of the other signal.
* If both signals have the same FFT normalization, the results gets the
Expand Down Expand Up @@ -1040,6 +1046,8 @@ def divide(data: tuple, domain='freq'):

The `fft_norm` of the result is as follows

* If only one signal is involved in the operation, the result gets the same
normalization.
* If the denominator signal has the FFT normalization ``'none'``, the
result gets the normalization of the numerator signal.
* If both signals have the same FFT normalization, the results gets the
Expand Down Expand Up @@ -1084,6 +1092,8 @@ def power(data: tuple, domain='freq'):

The `fft_norm` of the result is as follows

* If only one signal is involved in the operation, the result gets the same
normalization.
* If one signal has the FFT normalization ``'none'``, the results gets
the normalization of the other signal.
* If both signals have the same FFT normalization, the results gets the
Expand Down Expand Up @@ -1169,6 +1179,8 @@ def matrix_multiplication(

The `fft_norm` of the result is as follows

* If only one signal is involved in the operation, the result gets the same
normalization.
* If one signal has the FFT normalization ``'none'``, the results gets
the normalization of the other signal.
* If both signals have the same FFT normalization, the results gets the
Expand Down Expand Up @@ -1348,26 +1360,25 @@ def _assert_match_for_arithmetic(data: tuple, domain: str, division: bool,
# properties that must match
sampling_rate = None
n_samples = None
fft_norm = 'none'
# None indicates that no audio object is yet involved in the operation
# it will change upon detection of the first audio object
fft_norm = None
times = None
frequencies = None
audio_type = type(None)
cshape = ()

# check input types and meta data
found_audio_data = False
for n, d in enumerate(data):
n_audio_objects = 0
for d in data:
if isinstance(d, (Signal, TimeData, FrequencyData)):
n_audio_objects += 1
# store meta data upon first appearance
if not found_audio_data:
if n_audio_objects == 1:
if isinstance(d, Signal):
sampling_rate = d.sampling_rate
n_samples = d.n_samples
# if a signal comes first (n==0) its fft_norm is taken
# directly. If a signal does not come first, (n>0, e.g.
# 1/signal), the fft norm is matched
fft_norm = d.fft_norm if n == 0 else \
_match_fft_norm(fft_norm, d.fft_norm, division)
fft_norm = d.fft_norm
elif isinstance(d, TimeData):
if domain != "time":
raise ValueError("The domain must be 'time'.")
Expand All @@ -1378,7 +1389,6 @@ def _assert_match_for_arithmetic(data: tuple, domain: str, division: bool,
frequencies = d.frequencies
if not matmul:
cshape = d.cshape
found_audio_data = True
audio_type = type(d)

# check if type and meta data matches after first appearance
Expand Down
33 changes: 25 additions & 8 deletions tests/test_audio_signal_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,20 +214,15 @@ def test_add_arrays():
z, x + y, atol=1e-15)


def test_signal_inversion():
@pytest.mark.parametrize('fft_norm', ['none', 'rms'])
def test_signal_inversion(fft_norm):
"""Test signal inversion with different FFT norms"""

# 'none' norm
signal = pf.Signal([2, 0, 0], 44100, fft_norm='none')
signal = pf.Signal([2, 0, 0], 44100, fft_norm=fft_norm)
signal_inv = 1 / signal
npt.assert_allclose(signal.time.flatten(), [2, 0, 0])
npt.assert_allclose(signal_inv.time.flatten(), [.5, 0, 0])

# 'rms' norm
signal.fft_norm = 'rms'
with raises(ValueError, match="Either fft_norm_2"):
1 / signal


def test_subtraction():
# only test one case - everything else is tested below
Expand Down Expand Up @@ -756,3 +751,25 @@ def test_matrix_multiplication_undocumented():
y = np.ones((3, 2, 10)) * np.array([[1, 2], [3, 4], [5, 6]])[..., None]
pf.matrix_multiplication(
(x, y), domain='time', axes=[(-2, -1), (-3, -2), (-2, -1)])


@pytest.mark.parametrize('audio_object', [
pf.Signal([1, -1, 1], 1, fft_norm='none'),
pf.Signal([1, -1, 1], 1, fft_norm='rms'),
pf.FrequencyData([1, -1, 1], [0, 1, 3]),
pf.TimeData([1, -1, 1], [1, 2, 3])])
@pytest.mark.parametrize('operation', [
pf.add, pf.subtract, pf.multiply, pf.divide, pf.power])
def test_audio_object_and_number(audio_object, operation):
"""
Test if arithmetic operations work regardless of the fft norm and
audio object type if only one audio object is involved.
"""

domain = 'time' if type(audio_object) is pf.TimeData else 'freq'

result = operation((1, audio_object), domain=domain)
assert type(result) is type(audio_object)

result = operation((audio_object, 1), domain=domain)
assert type(result) is type(audio_object)