
Commit a77e4ae

ebrevdo authored and tensorflower-gardener committed
AddN for variants adds in a tree structure (pairwise summation)
Improves numerical precision (if applicable) using pairwise summation: https://en.wikipedia.org/wiki/Pairwise_summation

Thanks to Rasmus Larsen for the succinct binary tree aggregation pseudocode.

Also adds AsString for Variant types: this emits the Variant as a string via its DebugString().

PiperOrigin-RevId: 340246073
Change-Id: I009281f46cbea30d6e33ecf79a1723d62e96cc6d
1 parent faac5b2 commit a77e4ae
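
For readers unfamiliar with the reduction order, here is a minimal standalone sketch of the skip-doubling pairwise (tree) summation the kernel now uses, written over plain doubles rather than Variants. The function name PairwiseSum and the use of double are illustrative assumptions only; the actual kernel applies the same index pattern to Variant scalars and combines operands through BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...), as shown in the aggregate_ops.h diff below.

#include <vector>

// Sketch only: pairwise summation via skip doubling, accumulating in place.
double PairwiseSum(std::vector<double> x) {
  int n = static_cast<int>(x.size());
  if (n == 0) return 0.0;
  for (int skip = 1; skip < n; skip *= 2) {
    int i = skip;
    for (; i < n; i += 2 * skip) {
      x[i - skip] += x[i];  // combine neighbors that are `skip` apart
    }
    if (i == n) {
      x[0] += x[i - skip];  // fold a trailing, unpaired partial sum into x[0]
      n -= skip;
    }
  }
  return x[0];  // x[0] now holds the full sum
}

For example, with five inputs x0..x4 this computes ((x0 + x1) + x4) + (x2 + x3), keeping the depth of each accumulation chain logarithmic in the number of inputs instead of folding every term into one running sum.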

File tree

7 files changed: +122 -31 lines

tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td

Lines changed: 2 additions & 2 deletions
@@ -509,7 +509,7 @@ array([b'3.14', b'2.72'], dtype=object)
   }];

   let arguments = (ins
-    TensorOf<[TF_Bool, TF_Complex128, TF_Complex64, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$input,
+    TensorOf<[TF_Bool, TF_Complex128, TF_Complex64, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Variant]>:$input,

     DefaultValuedAttr<I64Attr, "-1">:$precision,
     DefaultValuedAttr<BoolAttr, "false">:$scientific,
@@ -15226,4 +15226,4 @@ execution the transfer corresponds to.}]>:$dynamic_key,
   let results = (outs);

   TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
-}
+}

tensorflow/core/kernels/aggregate_ops.h

Lines changed: 66 additions & 13 deletions
@@ -370,24 +370,77 @@ class AddNOp<Device, Variant, OpKernelT, OpKernelConstructionT,
                     i, " has shape: ", ctx->input(i).shape().DebugString(), "."));
     }

-    // Step 2: attempt to add using
+    // Step 2: Sum input variants in a tree-like structure using
     // BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
     // For the output create a default-constructed variant object.
-    // TODO(ebrevdo): Perform summation in a tree-structure.
-    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
-    Variant* v_out = &(out.scalar<Variant>()());
-    OP_REQUIRES_OK(ctx, BinaryOpVariants<Device>(
-                            ctx, ADD_VARIANT_BINARY_OP,
-                            ctx->input(0).template scalar<Variant>()(),
-                            ctx->input(1).template scalar<Variant>()(), v_out));
-    for (int i = 2; i < num; ++i) {
-      const Variant tmp = std::move(*v_out);
-      const Variant& inp = ctx->input(i).template scalar<Variant>()();
-      OP_REQUIRES_OK(ctx, BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP,
-                                                   inp, tmp, v_out));
+    //
+    // Pairwise summation provides better numerical precision by
+    // reducing round-off error:
+    //
+    //   https://en.wikipedia.org/wiki/Pairwise_summation
+    //
+    // These two vectors are used to store and mark intermediate sums.
+    gtl::InlinedVector<bool, 4> temp_filled(num, false);
+    gtl::InlinedVector<Variant, 4> temp(num);
+
+    // Tree-based summation.
+    int skip = 1;
+    int n = num;
+    while (skip < n) {
+      int i = skip;
+      while (i < n) {
+        // TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the
+        // inner loop if the variants are "large".
+
+        // x[i - skip] += x[i]
+        OP_REQUIRES_OK(ctx,
+                       AddVariantTo(ctx, i - skip, i, &temp, &temp_filled));
+        // We won't use this index again, recover its memory.
+        temp[i].clear();
+        i += 2 * skip;
+      }
+      if (i == n) {
+        // x[0] += x[i - skip]
+        OP_REQUIRES_OK(ctx,
+                       AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled));
+        // We won't use this index again, recover its memory.
+        temp[i - skip].clear();
+        n -= skip;
+      }
+      skip *= 2;
     }
+
+    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
+    out.scalar<Variant>()() = std::move(temp[0]);
     ctx->set_output(0, out);
   }
+
+ private:
+  // AddVariantTo efficiently performs:
+  //    temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
+  // where array(ix) := (temp_filled[ix]
+  //                     ? temp[ix]
+  //                     : ctx->input(ix).scalar<Variant>()())
+  // This reduces (possibly expensive) copying of Variants from
+  // the inputs into temp at the lowest levels of the summation tree.
+  static inline Status AddVariantTo(OpKernelContextT* ctx, const int lhs_ix,
+                                    const int rhs_ix,
+                                    gtl::InlinedVector<Variant, 4>* temp,
+                                    gtl::InlinedVector<bool, 4>* temp_filled) {
+    Variant tmp;
+    if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
+    const Variant& a = temp_filled->at(lhs_ix)
+                           ? tmp
+                           : ctx->input(lhs_ix).template scalar<Variant>()();
+    const Variant& b = temp_filled->at(rhs_ix)
+                           ? temp->at(rhs_ix)
+                           : ctx->input(rhs_ix).template scalar<Variant>()();
+    Variant* c = &temp->at(lhs_ix);
+    TF_RETURN_IF_ERROR(
+        BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
+    temp_filled->at(lhs_ix) = true;
+    return Status::OK();
+  }
 };

 }  // namespace tensorflow
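
The AddVariantTo helper above only materializes a temp slot once it actually holds a partial sum; until then, operands are read directly from the op's inputs, which avoids copying every input Variant into temp just to participate in the first level of the tree. A minimal non-TensorFlow sketch of that lazy-fill pattern follows; std::string concatenation stands in for variant addition, and all names here are illustrative assumptions rather than TensorFlow APIs.

#include <string>
#include <utility>
#include <vector>

// Sketch only: strings stand in for Variants, and operator+ stands in for
// BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...). Inputs are never copied into
// `temp`; each slot is read from `inputs` until it has been written once.
void AddTo(const std::vector<std::string>& inputs, int lhs_ix, int rhs_ix,
           std::vector<std::string>* temp, std::vector<bool>* filled) {
  // Move any existing partial sum aside so that writing temp[lhs_ix] below
  // cannot alias the left-hand operand.
  std::string tmp;
  if ((*filled)[lhs_ix]) tmp = std::move((*temp)[lhs_ix]);
  const std::string& a = (*filled)[lhs_ix] ? tmp : inputs[lhs_ix];
  const std::string& b = (*filled)[rhs_ix] ? (*temp)[rhs_ix] : inputs[rhs_ix];
  (*temp)[lhs_ix] = a + b;
  (*filled)[lhs_ix] = true;
}

This matters in the kernel because copying a Variant can be expensive, which is exactly the cost the "reduces (possibly expensive) copying" comment in the diff refers to.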

tensorflow/core/kernels/as_string_op.cc

Lines changed: 11 additions & 0 deletions
@@ -20,6 +20,9 @@ limitations under the License.
 #include "tensorflow/core/framework/kernel_def_builder.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
@@ -112,6 +115,8 @@ class AsStringOp : public OpKernel {
         break;
       case DT_BOOL:
         break;
+      case DT_VARIANT:
+        break;
       default:
         bool type_not_supported = true;
         OP_REQUIRES(ctx, !type_not_supported,
@@ -156,6 +161,12 @@ class AsStringOp : public OpKernel {
          output_flat(i) = (input_flat(i)) ? "true" : "false";
        }
      } break;
+      case (DT_VARIANT): {
+        const auto& input_flat = input_tensor->flat<Variant>();
+        for (int i = 0; i < input_flat.size(); ++i) {
+          output_flat(i) = input_flat(i).DebugString();
+        }
+      } break;
      case (DT_COMPLEX64): {
        const auto& input_flat = input_tensor->flat<complex64>();
        for (int i = 0; i < input_flat.size(); ++i) {

tensorflow/core/kernels/as_string_op_test.cc

Lines changed: 22 additions & 0 deletions
@@ -18,6 +18,9 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
 #include "tensorflow/core/kernels/ops_testutil.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -148,6 +151,25 @@ TEST_F(AsStringGraphTest, Bool) {
   test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
 }

+TEST_F(AsStringGraphTest, Variant) {
+  TF_ASSERT_OK(Init(DT_VARIANT));
+
+  AddInput(DT_VARIANT, TensorShape({4}));
+  auto inputs = mutable_input(0)->flat<Variant>();
+  inputs(0) = 2;
+  inputs(1) = 3;
+  inputs(2) = true;
+  inputs(3) = Tensor("hi");
+  TF_ASSERT_OK(RunOpKernel());
+  Tensor expected(allocator(), DT_STRING, TensorShape({4}));
+  test::FillValues<tstring>(
+      &expected, {"Variant<type: int value: 2>", "Variant<type: int value: 3>",
+                  "Variant<type: bool value: 1>",
+                  ("Variant<type: tensorflow::Tensor value: Tensor<type: string"
+                   " shape: [] values: hi>>")});
+  test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
+}
+
 TEST_F(AsStringGraphTest, String) {
   Status s = Init(DT_STRING);
   ASSERT_EQ(error::INVALID_ARGUMENT, s.code());

tensorflow/core/ops/string_ops.cc

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ REGISTER_OP("AsString")
     .Output("output: string")
     .Attr(
         "T: {int8, int16, int32, int64, complex64, complex128, float, double, "
-        "bool}")
+        "bool, variant}")
    .Attr("precision: int = -1")
    .Attr("scientific: bool = false")
    .Attr("shortest: bool = false")

tensorflow/python/kernel_tests/BUILD

Lines changed: 1 addition & 0 deletions
@@ -1606,6 +1606,7 @@ cuda_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:string_ops",
         "//third_party/py/numpy",
     ],
 )

tensorflow/python/kernel_tests/aggregate_ops_test.py

Lines changed: 19 additions & 15 deletions
@@ -26,8 +26,8 @@
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test


@@ -100,24 +100,28 @@ def create_constant_variant(value):
     # TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant
     # copying between CPU and GPU is supported.
     with self.session(use_gpu=False):
-      variant_const_3 = create_constant_variant(3)
-      variant_const_4 = create_constant_variant(4)
-      variant_const_5 = create_constant_variant(5)
-      # 3 + 3 + 5 + 4 = 15.
-      result = math_ops.add_n((variant_const_3, variant_const_3,
-                               variant_const_5, variant_const_4))
+      num_tests = 127
+      values = list(range(100))
+      variant_consts = [create_constant_variant(x) for x in values]
+      sum_count_indices = np.random.randint(1, 29, size=num_tests)
+      sum_indices = [
+          np.random.randint(100, size=count) for count in sum_count_indices]
+      expected_sums = [np.sum(x) for x in sum_indices]
+      variant_sums = [math_ops.add_n([variant_consts[i] for i in x])
+                      for x in sum_indices]

-      # Smoke test -- ensure this executes without trouble.
+      # We use as_string() to get the Variant DebugString for the
+      # variant_sums; we know its value so we can check via string equality
+      # here.
+      #
       # Right now, non-numpy-compatible objects cannot be returned from a
       # session.run call; similarly, objects that can't be converted to
       # native numpy types cannot be passed to ops.convert_to_tensor.
-      # For now, run the test and examine the output to see that the result is
-      # equal to 15.
-      result_op = logging_ops.Print(
-          result, [variant_const_3, variant_const_4, variant_const_5, result],
-          message=("Variants stored an int: c(3), c(4), c(5), "
-                   "add_n(c(3), c(3), c(5), c(4)): ")).op
-      result_op.run()
+      variant_sums_string = string_ops.as_string(variant_sums)
+      self.assertAllEqual(
+          variant_sums_string,
+          ["Variant<type: int value: {}>".format(s).encode("utf-8")
+           for s in expected_sums])


 if __name__ == "__main__":
