Skip to content

Commit

Permalink
Moved some tests out of run_in_graph_and_eager_mode in normalization (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
autoih committed Mar 30, 2020
1 parent f4933de commit c0a6566
Showing 1 changed file with 39 additions and 35 deletions.
74 changes: 39 additions & 35 deletions tensorflow_addons/layers/normalizations_test.py
Expand Up @@ -262,50 +262,54 @@ def test_groupnorm_correctness_1d(self):
self.assertAllClose(out.mean(), 0.0, atol=1e-1)
self.assertAllClose(out.std(), 1.0, atol=1e-1)

def test_groupnorm_2d_different_groups(self):
np.random.seed(0x2020)
groups = [2, 1, 10]
for i in groups:
model = tf.keras.models.Sequential()
norm = GroupNormalization(axis=1, groups=i, input_shape=(10, 3))
model.add(norm)
# centered and variance are 5.0 and 10.0, respectively
model.compile(loss="mse", optimizer="rmsprop")
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 3))
model.fit(x, x, epochs=5, verbose=0)
out = model.predict(x)
out -= np.reshape(self.evaluate(norm.beta), (1, 10, 1))
out /= np.reshape(self.evaluate(norm.gamma), (1, 10, 1))

self.assertAllClose(
out.mean(axis=(0, 1), dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
)
self.assertAllClose(
out.std(axis=(0, 1), dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
)

def test_groupnorm_convnet(self):
np.random.seed(0x2020)
@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_groupnorm_2d_different_groups():
np.random.seed(0x2020)
groups = [2, 1, 10]
for i in groups:
model = tf.keras.models.Sequential()
norm = GroupNormalization(axis=1, input_shape=(3, 4, 4), groups=3)
norm = GroupNormalization(axis=1, groups=i, input_shape=(10, 3))
model.add(norm)
model.compile(loss="mse", optimizer="sgd")

# centered = 5.0, variance = 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
model.fit(x, x, epochs=4, verbose=0)
# centered and variance are 5.0 and 10.0, respectively
model.compile(loss="mse", optimizer="rmsprop")
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 3))
model.fit(x, x, epochs=5, verbose=0)
out = model.predict(x)
out -= np.reshape(self.evaluate(norm.beta), (1, 3, 1, 1))
out /= np.reshape(self.evaluate(norm.gamma), (1, 3, 1, 1))
out -= np.reshape(norm.beta.numpy(), (1, 10, 1))
out /= np.reshape(norm.gamma.numpy(), (1, 10, 1))

self.assertAllClose(
np.mean(out, axis=(0, 2, 3), dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
np.testing.assert_allclose(
out.mean(axis=(0, 1), dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
)
self.assertAllClose(
np.std(out, axis=(0, 2, 3), dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
np.testing.assert_allclose(
out.std(axis=(0, 1), dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_groupnorm_convnet():
np.random.seed(0x2020)
model = tf.keras.models.Sequential()
norm = GroupNormalization(axis=1, input_shape=(3, 4, 4), groups=3)
model.add(norm)
model.compile(loss="mse", optimizer="sgd")

# centered = 5.0, variance = 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4))
model.fit(x, x, epochs=4, verbose=0)
out = model.predict(x)
out -= np.reshape(norm.beta.numpy(), (1, 3, 1, 1))
out /= np.reshape(norm.gamma.numpy(), (1, 3, 1, 1))

np.testing.assert_allclose(
np.mean(out, axis=(0, 2, 3), dtype=np.float32), (0.0, 0.0, 0.0), atol=1e-1
)
np.testing.assert_allclose(
np.std(out, axis=(0, 2, 3), dtype=np.float32), (1.0, 1.0, 1.0), atol=1e-1
)


@pytest.mark.usefixtures("maybe_run_functions_eagerly")
def test_groupnorm_convnet_no_center_no_scale():
np.random.seed(0x2020)
Expand Down

0 comments on commit c0a6566

Please sign in to comment.