Skip to content

Commit

Permalink
Use TF2 for mwms_pjrt_gpu_test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 635898879
  • Loading branch information
SeeForTwo authored and tensorflower-gardener committed May 21, 2024
1 parent 68dd380 commit 6af67b7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensorflow/python/distribute/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -658,13 +658,13 @@ cuda_py_strict_test(
srcs_version = "PY3",
tags = [
"multi_and_single_gpu",
"notap", # TODO(b/341375925): Re-enable this test when flakiness is fixed.
],
xla_enabled = True,
deps = [
":multi_process_runner",
":multi_worker_test_base",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/compat:v2_compat",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/python/distribute/mwms_pjrt_gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy

from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.compat import v2_compat
from tensorflow.python.distribute import cluster_resolver as cluster_resolver_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
Expand Down Expand Up @@ -113,4 +114,5 @@ def f(x):


if __name__ == "__main__":
v2_compat.enable_v2_behavior()
multi_process_runner.test_main()

0 comments on commit 6af67b7

Please sign in to comment.