Skip to content

Commit

Permalink
More carve-outs.
Browse files Browse the repository at this point in the history
  • Loading branch information
psobot committed Jun 9, 2023
1 parent e44732f commit ec171b6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 37 deletions.
34 changes: 33 additions & 1 deletion tests/callbacks/test_spectrogram_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,19 @@

from typing import Any

import platform
import pytest
import numpy as np
import tensorflow as tf

from realbook.callbacks.spectrogram_visualization import SpectrogramVisualizationCallback
try:
from realbook.callbacks.spectrogram_visualization import SpectrogramVisualizationCallback
except ImportError as e:
if "numpy.core.multiarray failed to import" in str(e) and platform.system() == "Windows":
SpectrogramVisualizationCallback = None # type: ignore
else:
raise

from realbook.layers.signal import Spectrogram


Expand Down Expand Up @@ -52,6 +60,10 @@ def flush(self) -> None:
TEST_AUDIO = np.linspace(0, 1, num=DEFAULT_SAMPLE_RATE * 10)


@pytest.mark.skipif(
SpectrogramVisualizationCallback is None,
reason="SpectrogramVisualizationCallback import fails on this platform",
)
def test_spectrogram_visualization_callback() -> None:
fake_data = tf.data.Dataset.zip(
(
Expand Down Expand Up @@ -80,6 +92,10 @@ def test_spectrogram_visualization_callback() -> None:
assert True


@pytest.mark.skipif(
SpectrogramVisualizationCallback is None,
reason="SpectrogramVisualizationCallback import fails on this platform",
)
def test_callback_fails_on_unbatched_input() -> None:
fake_data = tf.data.Dataset.zip(
(
Expand Down Expand Up @@ -110,6 +126,10 @@ def test_callback_fails_on_unbatched_input() -> None:
assert "shape" in str(excinfo.value)


@pytest.mark.skipif(
SpectrogramVisualizationCallback is None,
reason="SpectrogramVisualizationCallback import fails on this platform",
)
def test_callback_logs_but_doesnt_throw_by_default(caplog: pytest.LogCaptureFixture) -> None:
fake_data = tf.data.Dataset.zip(
(
Expand All @@ -133,6 +153,10 @@ def test_callback_logs_but_doesnt_throw_by_default(caplog: pytest.LogCaptureFixt
assert "shape" in caplog.text


@pytest.mark.skipif(
SpectrogramVisualizationCallback is None,
reason="SpectrogramVisualizationCallback import fails on this platform",
)
def test_fails_on_no_image_like_layers() -> None:
fake_data = tf.data.Dataset.zip(
(
Expand Down Expand Up @@ -162,6 +186,10 @@ def test_fails_on_no_image_like_layers() -> None:
assert "spectrogram" in str(excinfo.value)


@pytest.mark.skipif(
SpectrogramVisualizationCallback is None,
reason="SpectrogramVisualizationCallback import fails on this platform",
)
def test_flexible_with_input_shapes() -> None:
fake_data = tf.data.Dataset.zip(
(
Expand Down Expand Up @@ -192,6 +220,10 @@ def test_flexible_with_input_shapes() -> None:
assert True


@pytest.mark.skipif(
SpectrogramVisualizationCallback is None,
reason="SpectrogramVisualizationCallback import fails on this platform",
)
def test_keras_functional_api_with_tfop_lambda() -> None:
fake_data = tf.data.Dataset.zip(
(
Expand Down
43 changes: 7 additions & 36 deletions tests/layers/test_nnaudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,53 +23,21 @@

try:
import librosa
from realbook.layers import nnaudio as our_nnaudio
except ImportError as e:
if "numpy.core.multiarray failed to import" in str(e) and platform.system() == "Windows":
librosa = None
our_nnaudio = None # type: ignore
else:
raise

from typing import List, Tuple, Union
from typing import Tuple, Union

from realbook.layers import nnaudio as our_nnaudio
from nnAudio.Spectrogram import CQT2010v2

TEST_SAMPLE_RATE = 22050


# Test using this model directly, as well as wrapping it in a Lambda layer.
def get_parameterized_model_variants(
match_torch_exactly_values: Tuple[bool, bool] = (True, False)
) -> List[tf.keras.layers.Layer]:
possible_models = [
our_nnaudio.CQT(match_torch_exactly=v, trainable=trainable)
for v in match_torch_exactly_values
for trainable in (True, False)
]

return [
item
for models in [
[tf.keras.Sequential([tf.keras.layers.InputLayer((TEST_SAMPLE_RATE,)), model])]
+ (
[
tf.keras.Sequential(
[
tf.keras.layers.InputLayer((TEST_SAMPLE_RATE,)),
tf.keras.layers.Lambda(lambda x: model(x)),
]
)
]
# Using a layer with trainable weights inside a Lambda layer isn't supported.
if not model.trainable
else []
)
for model in possible_models
]
for item in models
]


@pytest.mark.skipif(librosa is None, reason="Librosa failed to import on this platform.")
@pytest.mark.parametrize(
"match_torch_exactly,threshold,trainable",
Expand All @@ -91,12 +59,14 @@ def test_cqt(match_torch_exactly: bool, threshold: float, trainable: bool) -> No


def build_layer(
layer: tf.keras.layers.Layer, input_shape: Union[Tuple[int], Tuple[int, int]] = (1, TEST_SAMPLE_RATE)
layer: tf.keras.layers.Layer,
input_shape: Union[Tuple[int], Tuple[int, int]] = (1, TEST_SAMPLE_RATE),
) -> tf.keras.layers.Layer:
layer.build(input_shape)
return layer


@pytest.mark.skipif(our_nnaudio is None, reason="nnaudio failed to import on this platform.")
def test_cqt_trainable_weights() -> None:
assert not build_layer(our_nnaudio.CQT(trainable=False)).trainable
assert not build_layer(our_nnaudio.CQT(trainable=False)).trainable_weights
Expand All @@ -107,6 +77,7 @@ def test_cqt_trainable_weights() -> None:


@pytest.mark.skipif(librosa is None, reason="Librosa failed to import on this platform.")
@pytest.mark.skipif(our_nnaudio is None, reason="nnaudio failed to import on this platform.")
@pytest.mark.parametrize("train", (True, False))
def test_cqt_trainable_layers_change_on_training(train: bool) -> None:
# Make a model that's trainable, then train it and ensure the weights change from the default.
Expand Down

0 comments on commit ec171b6

Please sign in to comment.