From 9919ef0da5173a1b2f2f113bd327df45d768bdf0 Mon Sep 17 00:00:00 2001 From: Pat Notz Date: Mon, 15 May 2023 18:38:22 -0700 Subject: [PATCH] Re-enable Switch, Merge, Enter/Exit ops on TPU PiperOrigin-RevId: 532290710 --- tensorflow/core/kernels/control_flow_ops.cc | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index d4b284759406b1..eb574c981f8127 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -55,6 +55,12 @@ void SwitchNOp::Compute(OpKernelContext* context) { REGISTER_KERNEL_BUILDER( Name("Switch").Device(DEVICE_TPU_SYSTEM).HostMemory("pred"), SwitchOp); +REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE_TPU).HostMemory("pred"), + SwitchOp); + +REGISTER_KERNEL_BUILDER( + Name("_SwitchN").Device(DEVICE_TPU).HostMemory("output_index"), SwitchNOp); + #define REGISTER_CPU_SWITCH(type) \ REGISTER_KERNEL_BUILDER(Name("Switch") \ .Device(DEVICE_CPU) \ @@ -297,6 +303,8 @@ void MergeOp::Compute(OpKernelContext* context) { REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp); REGISTER_KERNEL_BUILDER( Name("Merge").Device(DEVICE_TPU_SYSTEM).HostMemory("value_index"), MergeOp); +REGISTER_KERNEL_BUILDER( + Name("Merge").Device(DEVICE_TPU).HostMemory("value_index"), MergeOp); REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp); #define REGISTER_GPU_KERNEL(type) \ @@ -405,6 +413,7 @@ void EnterOp::Compute(OpKernelContext* context) { } REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_TPU_SYSTEM), EnterOp); +REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_TPU), EnterOp); REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp); #define REGISTER_GPU_KERNEL(type) \ @@ -502,6 +511,7 @@ void ExitOp::Compute(OpKernelContext* context) { } REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_TPU_SYSTEM), ExitOp); +REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_TPU), ExitOp); REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp); #define REGISTER_GPU_KERNEL(type) \ @@ -590,6 +600,8 @@ void NextIterationOp::Compute(OpKernelContext* context) { REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_TPU_SYSTEM), NextIterationOp); +REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_TPU), + NextIterationOp); REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU), NextIterationOp);