Skip to content

Commit

Permalink
Do not perform colocation checks for IdentityN since it just forwards…
Browse files Browse the repository at this point in the history
… its inputs.

PiperOrigin-RevId: 257230757
  • Loading branch information
saxenasaurabh authored and tensorflower-gardener committed Jul 9, 2019
1 parent 348fde8 commit ca57a9d
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tensorflow/core/kernels/BUILD
Expand Up @@ -1097,7 +1097,9 @@ tf_kernel_library(
tf_kernel_library(
name = "identity_n_op",
prefix = "identity_n_op",
deps = ARRAY_DEPS,
deps = ARRAY_DEPS + [
"//tensorflow/core:core_cpu_internal",
],
)

tf_kernel_library(
Expand Down
6 changes: 6 additions & 0 deletions tensorflow/core/kernels/identity_n_op.cc
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
#include "tensorflow/core/kernels/identity_n_op.h"

#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
Expand All @@ -24,5 +25,10 @@ limitations under the License.
namespace tensorflow {

REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_DEFAULT), IdentityNOp);
// Do not worry about colocating IdentityN op with its resource inputs since
// it just forwards it's inputs anyway. This is needed because we create
// IdentityN nodes to club "all" outputs of functional ops while lowering to
// make the original functional op fetchable.
REGISTER_INPUT_COLOCATION_EXEMPTION("IdentityN");

} // namespace tensorflow
1 change: 1 addition & 0 deletions tensorflow/python/distribute/BUILD
Expand Up @@ -963,6 +963,7 @@ distribute_py_test(
deps = [
":single_loss_example",
"//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/python/distribute/step_fn_test.py
Expand Up @@ -25,9 +25,11 @@
from tensorflow.python.distribute.single_loss_example import single_loss_example
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables


@test_util.with_control_flow_v2
class SingleLossStepTest(test.TestCase, parameterized.TestCase):

@combinations.generate(
Expand Down

0 comments on commit ca57a9d

Please sign in to comment.