@@ -370,24 +370,77 @@ class AddNOp<Device, Variant, OpKernelT, OpKernelConstructionT,
370370 i, " has shape: " , ctx->input (i).shape ().DebugString (), " ." ));
371371 }
372372
373- // Step 2: attempt to add using
373+ // Step 2: Sum input variants in a tree-like structure using
374374 // BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
375375 // For the output create a default-constructed variant object.
376- // TODO(ebrevdo): Perform summation in a tree-structure.
377- Tensor out (cpu_allocator (), DT_VARIANT, TensorShape ({}));
378- Variant* v_out = &(out.scalar <Variant>()());
379- OP_REQUIRES_OK (ctx, BinaryOpVariants<Device>(
380- ctx, ADD_VARIANT_BINARY_OP,
381- ctx->input (0 ).template scalar <Variant>()(),
382- ctx->input (1 ).template scalar <Variant>()(), v_out));
383- for (int i = 2 ; i < num; ++i) {
384- const Variant tmp = std::move (*v_out);
385- const Variant& inp = ctx->input (i).template scalar <Variant>()();
386- OP_REQUIRES_OK (ctx, BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP,
387- inp, tmp, v_out));
+    //
+    // Pairwise summation provides better numerical precision by
+    // reducing round-off error:
+    //
+    //   https://en.wikipedia.org/wiki/Pairwise_summation
+    //
+    // These two vectors are used to store and mark intermediate sums.
+    gtl::InlinedVector<bool, 4> temp_filled(num, false);
+    gtl::InlinedVector<Variant, 4> temp(num);
+
+    // Tree-based summation.
+    int skip = 1;
+    int n = num;
+    while (skip < n) {
+      int i = skip;
+      while (i < n) {
+        // TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the
+        // inner loop if the variants are "large".
+
+        // x[i - skip] += x[i]
+        OP_REQUIRES_OK(ctx,
+                       AddVariantTo(ctx, i - skip, i, &temp, &temp_filled));
+        // We won't use this index again, recover its memory.
+        temp[i].clear();
+        i += 2 * skip;
+      }
+      if (i == n) {
+        // x[0] += x[i - skip]
+        OP_REQUIRES_OK(ctx,
+                       AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled));
+        // We won't use this index again, recover its memory.
+        temp[i - skip].clear();
+        n -= skip;
+      }
+      skip *= 2;
     }
+
+    Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
+    out.scalar<Variant>()() = std::move(temp[0]);
     ctx->set_output(0, out);
   }
+
+ private:
+  // AddVariantTo efficiently performs:
+  //      temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
+  // where array(ix) := (temp_filled[ix]
+  //                     ? temp[ix]
+  //                     : ctx->input(ix).scalar<Variant>()())
+  // This reduces (possibly expensive) copying of Variants from
+  // the inputs into temp at the lowest levels of the summation tree.
+  static inline Status AddVariantTo(OpKernelContextT* ctx, const int lhs_ix,
+                                    const int rhs_ix,
+                                    gtl::InlinedVector<Variant, 4>* temp,
+                                    gtl::InlinedVector<bool, 4>* temp_filled) {
+    Variant tmp;
+    if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
+    const Variant& a = temp_filled->at(lhs_ix)
+                           ? tmp
+                           : ctx->input(lhs_ix).template scalar<Variant>()();
+    const Variant& b = temp_filled->at(rhs_ix)
+                           ? temp->at(rhs_ix)
+                           : ctx->input(rhs_ix).template scalar<Variant>()();
+    Variant* c = &temp->at(lhs_ix);
+    TF_RETURN_IF_ERROR(
+        BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
+    temp_filled->at(lhs_ix) = true;
+    return Status::OK();
+  }
 };
 
 } // namespace tensorflow
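
The replacement Compute body sums the variants with a skip-doubling pass: in each round, element i - skip absorbs element i, any leftover tail element is folded into slot 0, and skip doubles until a single partial sum remains in slot 0. As a minimal standalone sketch of that same index pattern applied to plain doubles (not TensorFlow code; PairwiseSumInPlace is an invented name for this illustration):

#include <cstdio>
#include <vector>

// Sums x in place with the tree pattern used by the kernel: element i - skip
// absorbs element i, the stray tail element is folded into x[0], and skip
// doubles until the full sum sits in x[0].
double PairwiseSumInPlace(std::vector<double>& x) {
  int n = static_cast<int>(x.size());
  int skip = 1;
  while (skip < n) {
    int i = skip;
    while (i < n) {
      x[i - skip] += x[i];  // pairwise step: x[i - skip] += x[i]
      i += 2 * skip;
    }
    if (i == n) {
      // One element had no partner this round; fold it into x[0] and
      // shrink the active range, exactly as the kernel does.
      x[0] += x[i - skip];
      n -= skip;
    }
    skip *= 2;
  }
  return x.empty() ? 0.0 : x[0];
}

int main() {
  std::vector<double> x = {1, 2, 3, 4, 5, 6, 7};
  std::printf("%g\n", PairwiseSumInPlace(x));  // prints 28
  return 0;
}

Because each value participates in roughly log2(num) additions rather than up to num - 1 of them, the worst-case accumulated round-off grows with log n instead of n, which is the precision benefit the Wikipedia link in the comment refers to.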
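
The private AddVariantTo helper keeps intermediate results in temp but only materializes a temp slot the first time it is written; until then each operand is read directly from the corresponding input. A minimal analog of that lazy-fill idea, with std::string standing in for an expensive-to-copy Variant (AddInto, scratch, and filled are made-up names for this sketch):

#include <cstdio>
#include <string>
#include <vector>

// scratch[lhs] <- array(lhs) + array(rhs), where array(ix) reads scratch[ix]
// once it has been filled and the read-only input otherwise, so leaf-level
// sums never copy inputs into scratch ahead of time.
void AddInto(const std::vector<std::string>& inputs,
             std::vector<std::string>* scratch, std::vector<bool>* filled,
             int lhs, int rhs) {
  const std::string& a = (*filled)[lhs] ? (*scratch)[lhs] : inputs[lhs];
  const std::string& b = (*filled)[rhs] ? (*scratch)[rhs] : inputs[rhs];
  (*scratch)[lhs] = a + b;  // "+" stands in for ADD_VARIANT_BINARY_OP.
  (*filled)[lhs] = true;
}

int main() {
  const std::vector<std::string> inputs = {"a", "b", "c"};
  std::vector<std::string> scratch(inputs.size());
  std::vector<bool> filled(inputs.size(), false);
  AddInto(inputs, &scratch, &filled, 0, 1);  // scratch[0] = "ab"
  AddInto(inputs, &scratch, &filled, 0, 2);  // scratch[0] = "abc"
  std::printf("%s\n", scratch[0].c_str());
  return 0;
}

The real helper additionally moves temp[lhs_ix] into a local tmp before calling BinaryOpVariants, so the output slot it writes never aliases the operand it is still reading.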