Skip to content

Commit

Permalink
Moved tests out of run_in_graph_and_eager_mode in wrapper (#1513)
Browse files Browse the repository at this point in the history
  • Loading branch information
autoih committed Mar 31, 2020
1 parent 30282a5 commit 8749ffd
Showing 1 changed file with 55 additions and 40 deletions.
95 changes: 55 additions & 40 deletions tensorflow_addons/layers/wrappers_test.py
Expand Up @@ -110,48 +110,63 @@ def test_serialization(self, base_layer, rnn):
# After serialization: tensorflow.python.keras.layers.recurrent.LSTM
self.assertTrue(isinstance(new_wn_layer.layer, base_layer.__class__))

@parameterized.named_parameters(
    ["Dense", lambda: tf.keras.layers.Dense(1), [1]],
    ["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [None, 10]],
    ["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
    ["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]],
)
def test_model_build(self, base_layer_fn, input_shape):
    """A Sequential model containing a WeightNormalization-wrapped layer
    builds successfully for both data_init=True and data_init=False."""
    input_tensor = tf.keras.layers.Input(shape=input_shape)
    for use_data_init in (True, False):
        wrapped = wrappers.WeightNormalization(base_layer_fn(), use_data_init)
        model = tf.keras.models.Sequential(layers=[input_tensor, wrapped])
        model.build()

@parameterized.named_parameters(
    ["Dense", lambda: tf.keras.layers.Dense(1), [1]],
    ["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [10, 10]],
    ["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
    ["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]],
)
def test_save_file_h5(self, base_layer, input_shape):
    """Weights of a model wrapping `base_layer` in WeightNormalization can
    be written out in HDF5 format without error."""
    model = tf.keras.Sequential(
        layers=[wrappers.WeightNormalization(base_layer())]
    )
    model.build([None] + input_shape)
    with tempfile.TemporaryDirectory() as tmp_dir:
        target_path = os.path.join(tmp_dir, "wrapper_test_model.h5")
        model.save_weights(target_path)
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize("data_init", [True, False])
@pytest.mark.parametrize(
    "base_layer_fn, input_shape",
    [
        (lambda: tf.keras.layers.Dense(1), [1]),
        (lambda: tf.keras.layers.SimpleRNN(1), [None, 10]),
        (lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]),
        (lambda: tf.keras.layers.LSTM(1), [10, 10]),
    ],
)
def test_model_build(base_layer_fn, input_shape, data_init):
    """A Sequential model containing a WeightNormalization-wrapped layer
    builds successfully for every base layer / data_init combination."""
    input_tensor = tf.keras.layers.Input(shape=input_shape)
    wrapped = wrappers.WeightNormalization(base_layer_fn(), data_init)
    model = tf.keras.models.Sequential(layers=[input_tensor, wrapped])
    model.build()

@parameterized.named_parameters(
    ["Dense", lambda: tf.keras.layers.Dense(1), [1]],
    ["SimpleRNN", lambda: tf.keras.layers.SimpleRNN(1), [10, 10]],
    ["Conv2D", lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]],
    ["LSTM", lambda: tf.keras.layers.LSTM(1), [10, 10]],
)
def test_forward_pass(self, base_layer, input_shape):
    """With data_init=False the wrapped layer's output must match the
    base layer's output on identical input."""
    batch = np.ones([1] + input_shape, dtype=np.float32)
    layer = base_layer()
    expected = layer(batch)
    actual = wrappers.WeightNormalization(layer, False)(batch)
    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.assertAllClose(self.evaluate(expected), self.evaluate(actual))

@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize(
    "base_layer, input_shape",
    [
        (lambda: tf.keras.layers.Dense(1), [1]),
        (lambda: tf.keras.layers.SimpleRNN(1), [10, 10]),
        (lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]),
        (lambda: tf.keras.layers.LSTM(1), [10, 10]),
    ],
)
def test_save_file_h5(base_layer, input_shape):
    """Weights of a model wrapping `base_layer` in WeightNormalization can
    be written out in HDF5 format without error."""
    model = tf.keras.Sequential(
        layers=[wrappers.WeightNormalization(base_layer())]
    )
    model.build([None] + input_shape)
    with tempfile.TemporaryDirectory() as tmp_dir:
        target_path = os.path.join(tmp_dir, "wrapper_test_model.h5")
        model.save_weights(target_path)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
@pytest.mark.parametrize(
    "base_layer, input_shape",
    [
        (lambda: tf.keras.layers.Dense(1), [1]),
        (lambda: tf.keras.layers.SimpleRNN(1), [10, 10]),
        (lambda: tf.keras.layers.Conv2D(3, 1), [3, 3, 1]),
        (lambda: tf.keras.layers.LSTM(1), [10, 10]),
    ],
)
def test_forward_pass(base_layer, input_shape):
    """With data_init=False the wrapped layer's output must match the
    base layer's output on identical input."""
    sample_data = np.ones([1] + input_shape, dtype=np.float32)
    base_layer = base_layer()
    base_output = base_layer(sample_data)
    wn_layer = wrappers.WeightNormalization(base_layer, False)
    wn_output = wn_layer(sample_data)
    # The graph-mode version ran `self.evaluate(tf.compat.v1.global_variables_
    # initializer())`. The bare `tf.compat.v1.global_variables_initializer()`
    # call left here only *created* an init op without ever executing it, so
    # it was dead code; in TF2 variables are initialized on creation, so the
    # call is dropped entirely.
    np.testing.assert_allclose(base_output, wn_output)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
Expand Down

0 comments on commit 8749ffd

Please sign in to comment.