Skip to content

Commit

Permalink
Re-enable Switch, Merge, Enter/Exit ops on TPU
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 532290710
  • Loading branch information
patnotz authored and tensorflower-gardener committed May 16, 2023
1 parent b89217e commit 9919ef0
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tensorflow/core/kernels/control_flow_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ void SwitchNOp::Compute(OpKernelContext* context) {
REGISTER_KERNEL_BUILDER(
Name("Switch").Device(DEVICE_TPU_SYSTEM).HostMemory("pred"), SwitchOp);

REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE_TPU).HostMemory("pred"),
SwitchOp);

REGISTER_KERNEL_BUILDER(
Name("_SwitchN").Device(DEVICE_TPU).HostMemory("output_index"), SwitchNOp);

#define REGISTER_CPU_SWITCH(type) \
REGISTER_KERNEL_BUILDER(Name("Switch") \
.Device(DEVICE_CPU) \
Expand Down Expand Up @@ -297,6 +303,8 @@ void MergeOp::Compute(OpKernelContext* context) {
REGISTER_KERNEL_BUILDER(Name("Merge").Device(DEVICE_CPU), MergeOp);
REGISTER_KERNEL_BUILDER(
Name("Merge").Device(DEVICE_TPU_SYSTEM).HostMemory("value_index"), MergeOp);
REGISTER_KERNEL_BUILDER(
Name("Merge").Device(DEVICE_TPU).HostMemory("value_index"), MergeOp);
REGISTER_KERNEL_BUILDER(Name("RefMerge").Device(DEVICE_CPU), MergeOp);

#define REGISTER_GPU_KERNEL(type) \
Expand Down Expand Up @@ -405,6 +413,7 @@ void EnterOp::Compute(OpKernelContext* context) {
}

REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_TPU_SYSTEM), EnterOp);
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE_TPU), EnterOp);
REGISTER_KERNEL_BUILDER(Name("RefEnter").Device(DEVICE_CPU), EnterOp);

#define REGISTER_GPU_KERNEL(type) \
Expand Down Expand Up @@ -502,6 +511,7 @@ void ExitOp::Compute(OpKernelContext* context) {
}

REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_TPU_SYSTEM), ExitOp);
REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE_TPU), ExitOp);
REGISTER_KERNEL_BUILDER(Name("RefExit").Device(DEVICE_CPU), ExitOp);

#define REGISTER_GPU_KERNEL(type) \
Expand Down Expand Up @@ -590,6 +600,8 @@ void NextIterationOp::Compute(OpKernelContext* context) {

REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_TPU_SYSTEM),
NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE_TPU),
NextIterationOp);
REGISTER_KERNEL_BUILDER(Name("RefNextIteration").Device(DEVICE_CPU),
NextIterationOp);

Expand Down

0 comments on commit 9919ef0

Please sign in to comment.