diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 5a17d482598288..7693c51c5732e1 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -658,13 +658,13 @@ cuda_py_strict_test( srcs_version = "PY3", tags = [ "multi_and_single_gpu", - "notap", # TODO(b/341375925): Re-enable this test when flakiness is fixed. ], xla_enabled = True, deps = [ ":multi_process_runner", ":multi_worker_test_base", "//tensorflow/core:protos_all_py", + "//tensorflow/python/compat:v2_compat", "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", diff --git a/tensorflow/python/distribute/mwms_pjrt_gpu_test.py b/tensorflow/python/distribute/mwms_pjrt_gpu_test.py index dc5f85aef2486c..de0de23890ce2f 100644 --- a/tensorflow/python/distribute/mwms_pjrt_gpu_test.py +++ b/tensorflow/python/distribute/mwms_pjrt_gpu_test.py @@ -16,6 +16,7 @@ import copy from tensorflow.core.protobuf import tensorflow_server_pb2 +from tensorflow.python.compat import v2_compat from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib from tensorflow.python.distribute import multi_process_runner from tensorflow.python.distribute import multi_worker_test_base @@ -113,4 +114,5 @@ def f(x): if __name__ == "__main__": + v2_compat.enable_v2_behavior() multi_process_runner.test_main()