diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py index 749a37097f9645..1f535cbb846a4d 100644 --- a/tensorflow/python/keras/backend_test.py +++ b/tensorflow/python/keras/backend_test.py @@ -1880,5 +1880,37 @@ def test_get_session_different_graphs(self): self.assertIsNot(session, keras.backend.get_session()) +@test_util.run_all_in_graph_and_eager_modes +class ControlOpsTests(test.TestCase): + + def test_function_switch_basics(self): + x = array_ops.constant(2.0) + y = array_ops.constant(3.0) + + def xpowy(): + return keras.backend.pow(x, y) + def ypowx(): + return keras.backend.pow(y, x) + + tensor = keras.backend.switch(keras.backend.less(x, y), xpowy, ypowx) + self.assertEqual(keras.backend.eval(tensor), [8.0]) + + tensor = keras.backend.switch(keras.backend.greater(x, y), xpowy, ypowx) + self.assertEqual(keras.backend.eval(tensor), [9.0]) + + def test_unequal_rank(self): + x = ops.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6]]), + dtype='float32') + y = ops.convert_to_tensor(np.array([1, 2, 3]), dtype='float32') + def true_func(): + return x + + def false_func(): + return y + + with self.assertRaisesRegexp(ValueError, "Rank of `condition` should be less than"): + keras.backend.switch(keras.backend.equal(x, x), false_func, true_func) + + if __name__ == '__main__': test.main()