This guide provides an end-to-end implementation of a new custom op that is backwards compatible with an existing custom op. The new op accepts inputs that are Python lists of tensors, while remaining backwards compatible with an existing op that accepts only single tensors.
The existing op is a multiplexer that returns elements chosen from either of two single input tensors, `x` or `y`, depending on a single condition tensor.
The content on this page assumes familiarity with the high-level process for adding custom ops to TensorFlow. For additional context, read the OSS guide on creating custom ops.
This example demonstrates how you can create a custom multiplexer, `multiplex_4`, to register a new kernel that is backwards compatible with an existing `multiplex_2` op.
The new custom op registers a kernel (`multiplex_4_kernel.cc`) that takes lists of tensors as inputs, and is backwards compatible with the existing kernel (`multiplex_2_kernel.cc`) that takes only single tensors as inputs.
The `multiplex_4` op is similar to `numpy.select`, while the `multiplex_2` op is similar to `numpy.where`.
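To make the analogy concrete, here is a small NumPy-only demonstration (independent of the custom ops) of the two behaviors the kernels mirror:

import numpy as np

cond = np.array([True, False, True])
a = np.array([1, 2, 3])
b = np.array([10, 20, 30])
# multiplex_2-style: one cond and one a, like np.where.
np.where(cond, a, b)  # => array([ 1, 20,  3])

conds = [np.array([False, True, False]), np.array([True, True, False])]
choices = [np.array([1, 2, 3]), np.array([4, 5, 6])]
# multiplex_4-style: lists of conds and choices, like np.select; the first
# True cond wins, and the default is used where all conds are False.
np.select(conds, choices, default=0)  # => array([4, 2, 0])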
The lists of tensors that the new op takes as inputs are of a particular fixed size. Since the list size is defined as an `Attr`, it is fixed at graph construction time, when the constructor of the C++ kernel is called. Therefore, the size of the list cannot be data dependent. See Ragged tensors for variable-length lists.
This example contains C++ and Python code snippets to illustrate the code flow. These snippets may be missing namespace declarations, imports, and test cases.
This example uses a SavedModel from the existing `multiplex_2` custom op. The `multiplex_2_save.py` file uses `save` from `model_using_multiplex.py` to create a SavedModel named `model_using_multiplex` in the current working directory.
def save(multiplex_op, path):
  """Save a model that contains the given `multiplex_op`.

  Args:
    multiplex_op: A multiplex Custom Op, e.g. multiplex_4_op.multiplex. This is
      parameterized so it can also be used to create an "old" model with an
      older version of the op, e.g. multiplex_2_op.multiplex.
    path: Directory to save model to.
  """
  example_cond, example_a, example_b = _get_example_tensors()

  class UseMultiplex(tf.Module):

    @tf.function(input_signature=[
        tf.TensorSpec.from_tensor(example_cond),
        tf.TensorSpec.from_tensor(example_a),
        tf.TensorSpec.from_tensor(example_b)
    ])
    def use_multiplex(self, cond, a, b):
      return multiplex_op(cond, a, b)

  model = UseMultiplex()
  tf.saved_model.save(
      model,
      path,
      signatures=model.use_multiplex.get_concrete_function(
          tf.TensorSpec.from_tensor(example_cond),
          tf.TensorSpec.from_tensor(example_a),
          tf.TensorSpec.from_tensor(example_b)))
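For example, a save binary such as `multiplex_2_save.py` might call `save` like this (a minimal sketch; the import paths are assumptions modeled on the import shown later in this guide):

from tensorflow.examples.custom_ops_doc.multiplex_2 import multiplex_2_op  # assumed path
from tensorflow.examples.custom_ops_doc.multiplex_4 import model_using_multiplex  # assumed path

# Writes a SavedModel containing the old op to ./model_using_multiplex.
model_using_multiplex.save(multiplex_2_op.multiplex, 'model_using_multiplex')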
This SavedModel contains the old version of the custom op (`multiplex_2`), which only supports individual tensors as inputs. The following steps register a kernel that accepts lists of tensors as inputs, while maintaining backwards compatibility with the previous op.
Define the op interface and register it using the `REGISTER_OP` macro.
REGISTER_OP("Examples>MultiplexDense")
    .Input("cond: N * bool")
    .Input("a_values: N * T")
    .Input("b_values: T")
    .Output("output_values: T")
    .Attr("T: type")
    .Attr("N: int = 1")
    .SetShapeFn(MultiplexShapeFunction)
    .Doc(R"doc(
Return elements chosen from `a_values` or `b_values` depending on `cond`.

When `a_values` and `cond` are tensors (i.e. N=1), this is similar to `np.where`
and `tf.where`. When `a_values` and `cond` are lists of tensors (i.e. N>1),
this is similar to `np.select`. In either case these are simplified to only
handle dense tensors, no optional parameters, no broadcasting, etc.

cond: tf.Tensor or list of tf.Tensor of type bool. If it is a list, `a_values`
  must be a list of the same length. Where True, yield the corresponding
  element from `a_values` (with priority to the first one encountered in
  lists), otherwise yield `b_values`.
a_values: tf.Tensor or list of tf.Tensor. Each tensor has the same type and
  shape as `b_values`. If it is a list, `cond` must be a list of the same
  length.
b_values: tf.Tensor with the same type and shape as the `a_values` if it is a
  tensor, or as every element of `a_values` if `a_values` is a list.
output_values: A tf.Tensor with elements from `a_values` where `cond` is True,
  and elements from `b_values` elsewhere.
)doc");
While the `multiplex_2` op defined its inputs as single tensors, such as `cond: bool` and `a_values: T`, this op supports lists of tensors by adding `N *`, where `N` is the length of the lists. The default list size (`N`) is set to 1 with `.Attr("N: int = 1")`. If the inputs are single tensors, then `N` is equal to 1, which is backwards compatible with a previous definition of `.Input("x: T")`.
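Both call styles therefore lower to the same registered op. A brief sketch, using the `multiplex_4_op.multiplex` wrapper defined later in this guide and tensors like those in the usage example below:

# Single tensors: the wrapper wraps each in a one-element list, so N = 1,
# matching the old multiplex_2 behavior.
multiplex_4_op.multiplex(cond1, a1, b)

# Lists of tensors: N is inferred from the list length, here N = 3.
multiplex_4_op.multiplex([cond1, cond2, cond3], [a1, a2, a3], b)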
All lists in this example are of equal length (`N`). To support lists of different lengths, define an attribute for each unique length. For example:

.Input("short_list: short_len * float")
.Input("long_list: long_len * float")
.Attr("short_len: int = 1")
.Attr("long_len: int >= 10")
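For such an op, the generated Python wrapper would infer each length attribute from its list argument. A hypothetical call (the `gen_my_op.my_op` wrapper name is invented for illustration):

import tensorflow as tf

gen_my_op.my_op(
    short_list=[tf.constant(2.0)],                         # short_len inferred as 1
    long_list=[tf.constant(float(i)) for i in range(12)])  # long_len inferred as 12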
The C++ kernel in `multiplex_4_kernel.cc` implements a multiplexer that accepts lists of tensors as inputs. Register the kernel by calling the `REGISTER_KERNEL_BUILDER` macro.
#define REGISTER_KERNELS(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("Examples>MultiplexDense")       \
                              .Device(::tensorflow::DEVICE_CPU) \
                              .TypeConstraint<type>("T"),       \
                          MultiplexDenseOp<type>)

TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
In the `multiplex_4_kernel.cc` op kernel, create a class derived from `OpKernel` that implements a `Compute` method. This method retrieves and validates the input tensors, performs the computation, and creates the output tensor.
template <typename T>
class MultiplexDenseOp : public OpKernel {
 public:
  explicit MultiplexDenseOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("N", &num_cond_a_));
  }

  MultiplexDenseOp(const MultiplexDenseOp& other) = delete;
  MultiplexDenseOp& operator=(const MultiplexDenseOp& other) = delete;
  ~MultiplexDenseOp() override = default;

  void Compute(OpKernelContext* ctx) override {
    // Optional error checking: cond and a_values are lists of N, so there are
    // a total of 2N+1 inputs. Check that the number of inputs and the
    // `N` Attr is consistent.
    const int64_t expected_inputs = 2 * num_cond_a_ + 1;
    OP_REQUIRES(ctx, expected_inputs == ctx->num_inputs(),
                Internal("expected_inputs != num_inputs(): ", expected_inputs,
                         " != ", ctx->num_inputs()));
    VLOG(1) << "N " << num_cond_a_;

    const auto& first_a_values_tensor = ctx->input(num_cond_a_);
    const auto& b_values_tensor = ctx->input(2 * num_cond_a_);

    // Allow any shape, but require that a_values, b_values, and cond all
    // have the same shape.
    // Note that ::tensorflow::TensorShapeUtils has some useful functions
    // for checking shapes.
    for (int64_t i = 0; i < num_cond_a_; i++) {
      const auto& cond_tensor_i = ctx->input(i);
      const auto& a_values_tensor_i = ctx->input(num_cond_a_ + i);
      OP_REQUIRES(
          ctx, a_values_tensor_i.shape() == b_values_tensor.shape(),
          InvalidArgument(
              "a_values[", i,
              "] and b_values must have the same shape. "
              "a_values[",
              i, "] shape: ", a_values_tensor_i.shape().DebugString(),
              " b_values shape: ", b_values_tensor.shape().DebugString()));
      OP_REQUIRES(
          ctx, cond_tensor_i.shape() == b_values_tensor.shape(),
          InvalidArgument(
              "cond_values[", i,
              "] and b_values must have the same shape. "
              "cond_values[",
              i, "] shape: ", cond_tensor_i.shape().DebugString(),
              " b_values shape: ", b_values_tensor.shape().DebugString()));
    }

    // Create an output tensor
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, b_values_tensor.shape(), &output_tensor));
    auto output = output_tensor->template flat<T>();
    const auto b_values = b_values_tensor.template flat<T>();

    // np.select style behavior, `cond` and `a_values` are lists of tensors.
    // Also works for the np.where style case where there is only one `cond`
    // and one `a_values` tensor.
    const int64_t N = first_a_values_tensor.NumElements();
    for (int64_t i = 0; i < N; i++) {
      bool flag = false;
      for (int64_t list_index = 0; list_index < num_cond_a_; list_index++) {
        const auto& cond_tensor = ctx->input(list_index);
        const auto& a_values_tensor = ctx->input(num_cond_a_ + list_index);
        const auto cond = cond_tensor.template flat<bool>();
        const auto a_values = a_values_tensor.template flat<T>();
        if (cond(i)) {
          output(i) = a_values(i);
          flag = true;
          VLOG(1) << "A " << list_index << " for " << i;
          break;
        }
      }
      if (!flag) {
        output(i) = b_values(i);
        VLOG(1) << "B for " << i;
      }
    }
  }

 private:
  int64_t num_cond_a_;  // the number of `cond` and `a` input tensors
};
The kernel uses a private member variable (`num_cond_a_`) to hold the length of `cond` and `a`. The constructor saves the `N` attribute into this variable.
 private:
  int64_t num_cond_a_;  // the number of `cond` and `a` input tensors

explicit MultiplexDenseOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("N", &num_cond_a_));
}
The `num_cond_a_` variable is used to index the inputs in the following order: `cond`, `a`, `b`. The op interface specifies that `cond` and `a` are tensor lists of length `N` and that `b` is a single tensor. The inputs are indexed as follows:

*   `cond`: [0 ... N-1]
*   `a`: [N ... 2*N-1]
*   `b`: [2*N]
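For example, with `N` = 3 the kernel sees seven inputs. A small Python sketch of the same index arithmetic:

# Input layout for N = 3: [cond_0, cond_1, cond_2, a_0, a_1, a_2, b]
N = 3
cond_indices = list(range(0, N))   # [0, 1, 2]
a_indices = list(range(N, 2 * N))  # [3, 4, 5]
b_index = 2 * N                    # 6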
When `num_cond_a_` is equal to 1, the kernel implements `numpy.where`, as the `multiplex_2` op does. When `num_cond_a_` is greater than 1, the kernel implements `numpy.select`. This is achieved with the following `for` loop.
for (int64_t i = 0; i < N; i++) {
  bool flag = false;
  for (int64_t list_index = 0; list_index < num_cond_a_; list_index++) {
    const auto& cond_tensor = ctx->input(list_index);
    const auto& a_values_tensor = ctx->input(num_cond_a_ + list_index);
    const auto cond = cond_tensor.flat<bool>();
    const auto a_values = a_values_tensor.flat<T>();
    if (cond(i)) {
      output(i) = a_values(i);
      flag = true;
      break;
    }
  }
  if (!flag) {
    output(i) = b_values(i);
  }
}
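For intuition, here is an equivalent pure-Python sketch of the same selection logic (illustrative only, not part of the op):

def multiplex_reference(conds, a_lists, b):
  """Element i yields the first a_lists[j][i] whose conds[j][i] is True,
  falling back to b[i] when every cond is False."""
  out = []
  for i in range(len(b)):
    for cond, a in zip(conds, a_lists):
      if cond[i]:
        out.append(a[i])
        break
    else:  # no cond was True at position i
      out.append(b[i])
  return out

multiplex_reference([[False, True]], [[1, 2]], [10, 20])  # => [10, 2]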
Compile the C++ op to create a kernel library and Python wrapper that enables you to use the op with TensorFlow.
Create a `BUILD` file for the op, which declares the dependencies and the output build targets. Refer to building for OSS.
To create the Python wrapper, import and implement a function that serves as the op's public API and provides a docstring.
If `cond` and `a` are not already lists, the wrapper in `multiplex_4_op.py` puts them in lists before calling the generated op, preserving the `numpy.where`-style behavior for single tensors.
Note: The generated Python wrapper automatically sets the `N` attribute based on the length of the input lists.
def multiplex(cond, a, b, name=None):
  """Return elements chosen from `a` or `b` depending on `cond`.

  This is similar to `np.where` and `tf.where` if `cond` and `a` are tensors.
  This is similar to `np.select` if `cond` and `a` are lists of tensors.
  In either case, this is simplified to only handle the case of dense tensors,
  no optional parameters, no broadcasting, etc.

  >>> multiplex([True, False, False, True], [1,2,3,4], [100,200,300,400])
  <tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 200, 300,   4], ...)>

  >>> a1 = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
  >>> a2 = tf.constant([6, 7, 8, 9, 10], dtype=tf.int64)
  >>> a3 = tf.constant([11, 12, 13, 14, 15], dtype=tf.int64)
  >>> b = tf.constant([101, 102, 103, 104, 105], dtype=tf.int64)
  >>> cond1 = tf.constant([False, False, True, False, False], dtype=bool)
  >>> cond2 = tf.constant([False, False, False, False, True], dtype=bool)
  >>> cond3 = tf.constant([True, False, True, False, True], dtype=bool)
  >>> multiplex_4_op.multiplex([cond1, cond2, cond3], [a1, a2, a3], b)
  <tf.Tensor: shape=(5,), ... numpy=array([ 11, 102,   3, 104,  10], ...)>

  Args:
    cond: tf.Tensor or list of tf.Tensor of type bool. Where True, yield the
      corresponding element from `a`. When multiple corresponding `cond`
      elements are True, the first one encountered takes priority.
    a: tf.Tensor or list of tf.Tensor, each with the same type and shape as `b`.
    b: tf.Tensor with the same type and shape as `a` if `a` is a tensor, or as
      every element of `a` if `a` is a list. Yielded where all corresponding
      `cond` values are False.
    name: An optional name for the op.

  Returns:
    A tf.Tensor with elements from `a` where `cond` is True, and elements
    from `b` elsewhere.
  """
  if not isinstance(cond, (list, tuple)):
    # Support "old" use of multiplex where `cond` and `a` are tensors,
    # not lists of tensors.
    return gen_multiplex_4_op.examples_multiplex_dense(
        cond=[cond], a_values=[a], b_values=b, name=name)
  return gen_multiplex_4_op.examples_multiplex_dense(
      cond=cond, a_values=a, b_values=b, name=name)
Create op tests using classes derived from `tf.test.TestCase`.
When writing tests to ensure that the op works correctly in both graph and eager execution, note that errors in the op code may be detected in two distinct phases, depending on how the code is executed. Errors may be detected early by the shape function, or later by the logic in the `Compute` method. This can lead to differing error types and messages between the two modes.
@test_util.with_eager_op_as_function
class MultiplexOpTest(tf.test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def test_multiplex_int(self):
    a = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
    b = tf.constant([10, 20, 30, 40, 50], dtype=tf.int64)
    cond = tf.constant([True, False, True, False, True], dtype=bool)
    expect = np.where(self.evaluate(cond), self.evaluate(a), self.evaluate(b))
    # expected result is [1, 20, 3, 40, 5]
    result = multiplex_4_op.multiplex(cond, a, b)
    self.assertAllEqual(result, expect)

  @test_util.run_in_graph_and_eager_modes
  def test_multiplex_select(self):
    a1 = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
    a2 = tf.constant([6, 7, 8, 9, 10], dtype=tf.int64)
    a3 = tf.constant([11, 12, 13, 14, 15], dtype=tf.int64)
    a = [a1, a2, a3]
    b = tf.constant([101, 102, 103, 104, 105], dtype=tf.int64)
    cond1 = tf.constant([False, False, True, False, False], dtype=bool)
    cond2 = tf.constant([False, False, False, False, True], dtype=bool)
    cond3 = tf.constant([True, False, True, False, True], dtype=bool)
    cond = [cond1, cond2, cond3]
    expect = np.select([self.evaluate(i) for i in cond],
                       [self.evaluate(i) for i in a], self.evaluate(b))
    # expected result is [11, 102, 3, 104, 10]
    result = multiplex_4_op.multiplex(cond, a, b)
    self.assertAllEqual(result, expect)

  def test_multiplex_saved_model(self):
    path = os.path.join(self.create_tempdir(), 'model')
    model_using_multiplex.save(multiplex_4_op.multiplex, path)
    result = model_using_multiplex.load_and_use(path)
    self.assertAllEqual(result, tf.constant([1, 20, 3, 40, 5], dtype=tf.int64))

  # One tf.function that uses both multiplex with single tensors for `cond`
  # and `a` and with lists of tensors for `cond` and `a`, i.e. a graph
  # with two example_multiplex_dense kernels that have different numbers
  # of inputs.
  @tf.function
  def _both(self):
    a1 = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
    a2 = tf.constant([6, 7, 8, 9, 10], dtype=tf.int64)
    a3 = tf.constant([11, 12, 13, 14, 15], dtype=tf.int64)
    a_123 = [a1, a2, a3]
    b_123 = tf.constant([101, 102, 103, 104, 105], dtype=tf.int64)
    cond1 = tf.constant([False, False, True, False, False], dtype=bool)
    cond2 = tf.constant([False, False, False, False, True], dtype=bool)
    cond3 = tf.constant([True, False, True, False, True], dtype=bool)
    cond_123 = [cond1, cond2, cond3]
    mux_123 = multiplex_4_op.multiplex(cond_123, a_123, b_123)
    b4 = tf.constant([201, 202, 203, 204, 205], dtype=tf.int64)
    cond4 = tf.constant([True, True, True, False, False], dtype=bool)
    result = multiplex_4_op.multiplex(cond4, mux_123, b4)
    return result

  def test_both_single_and_list(self):
    result = self._both()
    self.assertAllEqual(result,
                        tf.constant([11, 102, 3, 204, 205], dtype=tf.int64))

  @test_util.run_in_graph_and_eager_modes
  def test_inconsistent_inputs_error(self):
    a1 = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
    a2 = tf.constant([6, 7, 8, 9, 10], dtype=tf.int64)
    a = [a1, a2]
    b = tf.constant([101, 102, 103, 104, 105], dtype=tf.int64)
    cond = tf.constant([False, False, True, False, False], dtype=bool)
    with self.assertRaisesRegex(
        (errors_impl.InvalidArgumentError, ValueError),
        # Eager mode raises InvalidArgumentError with the following message
        r'(a_values\[0\] and b_values must have the same shape'
        r')|('
        # Graph mode raises ValueError with the following message
        r'Shapes must be equal rank, but are 2 and 1)'):
      self.evaluate(multiplex_4_op.multiplex(cond, a, b))
The following `tf.function` in `multiplex_4_test.py` creates a graph with two multiplex custom ops: one that takes lists for its `cond` and `a` inputs, and another that takes single tensors.
@tf.function
def _both(self):
  a1 = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
  a2 = tf.constant([6, 7, 8, 9, 10], dtype=tf.int64)
  a3 = tf.constant([11, 12, 13, 14, 15], dtype=tf.int64)
  a_123 = [a1, a2, a3]
  b_123 = tf.constant([101, 102, 103, 104, 105], dtype=tf.int64)
  cond1 = tf.constant([False, False, True, False, False], dtype=bool)
  cond2 = tf.constant([False, False, False, False, True], dtype=bool)
  cond3 = tf.constant([True, False, True, False, True], dtype=bool)
  cond_123 = [cond1, cond2, cond3]
  mux_123 = multiplex_4_op.multiplex(cond_123, a_123, b_123)
  b4 = tf.constant([201, 202, 203, 204, 205], dtype=tf.int64)
  cond4 = tf.constant([True, True, True, False, False], dtype=bool)
  result = multiplex_4_op.multiplex(cond4, mux_123, b4)
  return result
The `model_using_multiplex.py` file has functions for creating and using SavedModels that contain a custom multiplex op. In this test, the `multiplex_4` op is used to both save and use models.
def test_multiplex_saved_model(self):
  path = os.path.join(self.create_tempdir(), 'model')
  model_using_multiplex.save(multiplex_4_op.multiplex, path)
  result = model_using_multiplex.load_and_use(path)
  self.assertAllEqual(result, tf.constant([1, 20, 3, 40, 5], dtype=tf.int64))
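The guide does not show `load_and_use` itself; a minimal sketch consistent with the test's expected result might look like this (the example tensors are assumptions based on `test_multiplex_int` above):

def load_and_use(path):
  """Sketch: load the SavedModel and call its `use_multiplex` signature."""
  model = tf.saved_model.load(path)
  cond = tf.constant([True, False, True, False, True], dtype=bool)
  a = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
  b = tf.constant([10, 20, 30, 40, 50], dtype=tf.int64)
  return model.use_multiplex(cond, a, b)  # expected: [1, 20, 3, 40, 5]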
Test the op with the following:
bazel test //third_party/tensorflow/examples/custom_ops_doc/multiplex_4:multiplex_4_test
Reuse the `BUILD` file to add build rules for the Python API wrapper and the op test.
py_strict_library(
    name = "multiplex_4_op",
    srcs = ["multiplex_4_op.py"],
    data = ["multiplex_4_kernel.so"],
    srcs_version = "PY3",
    deps = [
        "//third_party/py/tensorflow",
    ],
)

tf_py_test(
    name = "multiplex_4_test",
    size = "medium",  # This test blocks because it writes and reads a file,
    timeout = "short",  # but it still runs quickly.
    srcs = ["multiplex_4_test.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    tags = [
        "no_mac",
        "no_pip",
    ],
    deps = [
        ":model_using_multiplex",
        ":multiplex_4_op",
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
        "//third_party/tensorflow/python/framework:errors",
        "//third_party/tensorflow/python/framework:test_lib",
    ],
)
Build the op with the following:
bazel build //third_party/tensorflow/examples/custom_ops_doc/multiplex_4:multiplex_4_op
Import the op and call it using the following example:
import tensorflow as tf
from tensorflow.examples.custom_ops_doc.multiplex_4 import multiplex_4_op
a1 = tf.constant([1, 2, 3, 4, 5], dtype=tf.int64)
a2 = tf.constant([6, 7, 8, 9, 10], dtype=tf.int64)
a3 = tf.constant([11, 12, 13, 14, 15], dtype=tf.int64)
a = [a1, a2, a3]
b = tf.constant([101, 102, 103, 104, 105], dtype=tf.int64)
cond1 = tf.constant([False, False, True, False, False], dtype=bool)
cond2 = tf.constant([False, False, False, False, True], dtype=bool)
cond3 = tf.constant([True, False, True, False, True], dtype=bool)
cond = [cond1, cond2, cond3]
# expected result is [11, 102, 3, 104, 10]
result = multiplex_4_op.multiplex(cond, a, b)
The `multiplex_4_load_use.py` file uses `load_and_use` from `model_using_multiplex.py` to load a SavedModel created with the `multiplex_2` op. The SavedModel can be executed using the new kernel (`multiplex_4`), which supports both lists of tensors and single tensors for the `cond` and `a` inputs.
Since `Examples>MultiplexDense` can only be defined once in a binary, there must be two separate binaries. A binary can depend on either `multiplex_2_op` or `multiplex_4_op`, but not both. Because the custom ops are backwards compatible, you can use `save` with `multiplex_2` and `load_and_use` with `multiplex_4`.
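Putting it together, the cross-version round trip looks like the following sketch, with each snippet in its own binary as required by the single-registration constraint:

# In the binary that depends on multiplex_2_op (multiplex_2_save.py):
model_using_multiplex.save(multiplex_2_op.multiplex, 'model_using_multiplex')

# In a separate binary that depends on multiplex_4_op (multiplex_4_load_use.py):
result = model_using_multiplex.load_and_use('model_using_multiplex')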
In this example, you learned how to implement a new multiplexer kernel that is backwards compatible with an existing multiplexer kernel. This custom op handles inputs that are lists of tensors, while continuing to handle inputs of single tensors.
The tables below summarize the build rules and targets for building and testing the `multiplex_4` op.
| Op components | Build rule | Build target | Source |
| --- | --- | --- | --- |
| Kernels (C++) | `tf_custom_op_library` | `multiplex_4_kernel` | `multiplex_4_kernel.cc`, `multiplex_4_op.cc` |
| Wrapper (automatically generated) | N/A | `gen_multiplex_4_op` | N/A |
| Wrapper (with public API and docstring) | `py_strict_library` | `multiplex_4_op` | `multiplex_4_op.py` |
| Tests | `tf_py_test` | `multiplex_4_test` | `multiplex_4_test.py` |
| Op components | Build rule | Build target | Source |
| --- | --- | --- | --- |
| Common library | `py_strict_library` | `model_using_multiplex` | `model_using_multiplex.py` |
| Old op (with SavedModel) | `py_strict_binary` | `multiplex_2_save` | `multiplex_2_save.py` |
| New op (with SavedModel) | `py_strict_binary` | `multiplex_4_load_and_use` | `multiplex_4_load_and_use.py` |