
Branch 200727545 #20064

Merged
merged 101 commits into from
Jun 15, 2018
Commits
11b3a9f
Documenting capabilities and limitations of AutoGraph
tensorflower-gardener Jun 13, 2018
ee3ecdf
Reversible residual network example with manually built gradient comp…
tensorflower-gardener Jun 13, 2018
5bfc42c
[tf.data] Factor out function argument restructuring into a helper.
mrry Jun 13, 2018
0104d4f
[TF:XLA] Bump open source llvm revision to r334593
Jun 13, 2018
cb2c5be
Add a test that checks memory usage by running a model 100k times.
Jun 13, 2018
6b7a17d
Automated g4 rollback of changelist 199870879
tensorflower-gardener Jun 13, 2018
74655a9
fix md link format
MarkDaoust Jun 13, 2018
106766c
Fix a build failure when cuda version is less than 9000.
tensorflower-gardener Jun 13, 2018
d40ca72
Switch Estimator from using DistributionStrategy.fetch() to .read_var().
tensorflower-gardener Jun 13, 2018
47b1c93
Initial application of runtime shapes to runtime kernels.
tensorflower-gardener Jun 13, 2018
8051c4b
Provide default name_scope in cond_v2.
skye Jun 13, 2018
642a043
[TF:XLA] Replace bespoke NodeSlot class in subgraph encapsulation cod…
hawkinsp Jun 13, 2018
4254b2c
Splits testLargeCase in metric_ops_test into a dedicated file for slo…
tensorflower-gardener Jun 13, 2018
9103442
[tf.data] Factor out a helper for creating flat args to `function.Def…
mrry Jun 13, 2018
b253e6b
support int16-quantized data in TFLite interpreter.
tensorflower-gardener Jun 13, 2018
7b033a1
[XLA] Make --xla_dump_executions_to actually dump the HloSnapshot.
tensorflower-gardener Jun 13, 2018
fbd920a
Split out HloInfeedIndexInstruction and HloOutfeedInstruction as subc…
tensorflower-gardener Jun 13, 2018
8be4327
[XLA:GPU] Move IsProfitableOperand implementation into the MultiOutpu…
thomasjoerg Jun 13, 2018
096b7dc
Pick up estimator docstrings from correct modules when generating API.
annarev Jun 13, 2018
bf920de
[contrib.cloud] Expose GCS config methods
saeta Jun 13, 2018
e1296c1
Fix assumptions that a Shape must be a tuple or an array.
meheffernan Jun 13, 2018
40e4beb
Add return statement to end of ToVlogString(dnn::DataType data_type)
jtkeeling Jun 13, 2018
2f7f04a
[XLA:GPU] Run HloCSE after multi-output fusion
d0k Jun 13, 2018
a3273e0
Variable Tensor API for TF Lite.
miaout17 Jun 13, 2018
e2213af
[XLA] Update the error message for AllReduce.
tensorflower-gardener Jun 13, 2018
88ad994
Make ops.colocate_with work with tower-local variables as well.
Jun 13, 2018
02c74ef
Add xla::ShapeUtil::TryGetSubshape that doesn't CHECK fail on invalid…
tensorflower-gardener Jun 13, 2018
31ea26d
Fix `Input` to allow scalar shape.
MarkDaoust Jun 13, 2018
4d48d1d
Uses a resource variable by default for the global step.
alextp Jun 13, 2018
ec927be
Subgroup CrossReplicaSum and change in TpuOptimizer.
toponado-zz Jun 13, 2018
b74197c
Upgrade the tpu profiler version to 1.7.0.
tensorflower-gardener Jun 13, 2018
11e1a45
Automated g4 rollback of changelist 200309129
tensorflower-gardener Jun 13, 2018
a6cccdc
[XLA] Add missing space in evaluator error message.
mkuperst Jun 13, 2018
d1ff8bc
Documentation style fix.
bileschi Jun 13, 2018
4986168
[XLA] Fix indentation in comment in EmitRowReduction.
Jun 13, 2018
1babacb
Minor fix for lt.map_fn, handling a case where Tensor type inference …
tensorflower-gardener Jun 14, 2018
462a7e0
Add sequential functionality to _SharedEmbeddingColumn.
tensorflower-gardener Jun 14, 2018
dac4634
Fix typo in register.h
tensorflower-gardener Jun 14, 2018
007fc38
Makes cond_v2 pass in device, container, colocation stacks, and colle…
tensorflower-gardener Jun 14, 2018
c62f4a5
Reduce runtime of metric_ops_test by increasing sharding and splitting
tensorflower-gardener Jun 14, 2018
0946c28
fully_connected_feed_test timing out, increase its size.
tensorflower-gardener Jun 14, 2018
2832528
Fix layout assignment CHECK failure on channel constraints.
tensorflower-gardener Jun 14, 2018
e9a7286
Automated g4 rollback of changelist 200495346
tensorflower-gardener Jun 14, 2018
c570211
Re-enable compilation for MacOS. This was unintentionally broken prev…
tensorflower-gardener Jun 14, 2018
0b8c580
Remove hardcoded dtype in tf.layers.xxx() function call to make them …
protoget Jun 14, 2018
8d9787b
Automated g4 rollback of changelist 200467580
tensorflower-gardener Jun 14, 2018
83a48e0
Provide the ability to specify, in tf.train.MonitoredTrainingSession(…
tensorflower-gardener Jun 14, 2018
03dd231
Extract HloExecutionProfiler into its own file.
akuegel Jun 14, 2018
915b138
Internal change.
tensorflower-gardener Jun 14, 2018
15430c5
[TF:XLA] Pass source tensors in original input graph to subgraph rewr…
hawkinsp Jun 14, 2018
ae26e86
Add support for propagating resource shapes via the TPUReplicatedInpu…
hawkinsp Jun 14, 2018
a7c1b03
Standardize the type notation for docstrings that require describing …
Jun 14, 2018
b704ab9
Make deleting HloInstruction safer.
tensorflower-gardener Jun 14, 2018
4ec3fcd
Adds support for explicitly assigning the replica to the VariableDevi…
tensorflower-gardener Jun 14, 2018
b22cfe5
[XLA:GPU] Turn on Loop-Loop sibling multi-output fusion
d0k Jun 14, 2018
3d5fa1f
Disable removing pairs of transposes across chains, while debugging b…
tensorflower-gardener Jun 14, 2018
5001a3f
Add tf.contrib.checkpoint.list_objects for listing all Python depende…
allenlavoie Jun 14, 2018
a4cadda
[tf.data] Add `StructuredFunctionWrapper` to encapsulate tf.data's en…
mrry Jun 14, 2018
e1b0ceb
Amend notes on eager compatibility for Estimator
martinwicke Jun 14, 2018
df9dd22
[XLA:GPU] Make alias analysis emit metadata for subshapes
d0k Jun 14, 2018
eb97901
Propagate the non-resource part of a resource tensor's shape in Enter…
hawkinsp Jun 14, 2018
f596bcc
Remove dead code from bulk_restore() but keep dead function parameter…
tensorflower-gardener Jun 14, 2018
3d7b33f
Make it possible to retrieve the variables used in a defined function.
akshayka Jun 14, 2018
3970b53
Switch "init_from_checkpoint" to use "DEBUG" log level.
MarkDaoust Jun 14, 2018
8f7afe0
Automated g4 rollback of changelist 200500606
tensorflower-gardener Jun 14, 2018
8e4c414
Optimized implementation of transpose conv. Uses an im2col array and …
tensorflower-gardener Jun 14, 2018
91ec6cc
[TF:XLA] Bump open source llvm revision to r334704
Jun 14, 2018
7ccf193
Factor a "capture_dependencies" scope out of Template.
allenlavoie Jun 14, 2018
d943de3
Support non-static shape in `tf.distributions.Categorical`.
csuter Jun 14, 2018
840aeb0
Merged commit includes the following changes:
tensorflower-gardener Jun 14, 2018
f01d254
Add support for TOKEN type to CPU/GPU backends.
meheffernan Jun 14, 2018
c4eafb4
Install Keras dependencies.
yifeif Jun 14, 2018
24b2043
Automated g4 rollback of changelist 200414970
miaout17 Jun 14, 2018
d57e9a6
Clarify reuse documentation in variable_scope and eager.
alextp Jun 14, 2018
f5c9d27
Internal Change.
tensorflower-gardener Jun 14, 2018
929474d
[tf.data] Convert GeneratorDataset to use StructuredFunctionWrapper.
mrry Jun 14, 2018
18b0f66
Export build_toco_convert_protos
Jun 14, 2018
e87b52a
[tf.data] Adding support for tf.data.Dataset.prefetch(buffer_size=0).
jsimsa Jun 14, 2018
261ab05
Automated g4 rollback of changelist 196296096
tensorflower-gardener Jun 14, 2018
9e4cbaf
Convert log(x+1) to log1p(x).
tensorflower-gardener Jun 15, 2018
7e05b8a
[TF:XLA] Account for subcomputations in heap simulator during schedul…
dimvar Jun 15, 2018
5ae938f
Speed up shuffle_dataset_op_test.
saxenasaurabh Jun 15, 2018
99d48bd
Small refactoring of code to check device crossing in dependency opti…
tensorflower-gardener Jun 15, 2018
889833b
Add HWNC and HWCN data format support
tensorflower-gardener Jun 15, 2018
d8adf4b
Correctly build and link in the GCS control ops
saeta Jun 15, 2018
332c4d6
Increase tolerance for depthwise convolution gradient tests.
tensorflower-gardener Jun 15, 2018
271c1a1
Split out HloAllReduceInstruction as a subclass of HloInstruction.
tensorflower-gardener Jun 15, 2018
7ebce39
Increase the numerical tolerance threshold temporarily to make the te…
aaroey Jun 15, 2018
7d5a7ec
[tf.data] Internal refactor of `tf.data.contrib.map_and_batch()`, swi…
jsimsa Jun 15, 2018
7f265d1
Move xla_sharding related code to third_party
toponado-zz Jun 15, 2018
7f3dbd0
Disable collective ops support on Android builds.
timonvo Jun 15, 2018
3cd4eda
Added comment to explain plugging on external sharding normalizers.
tensorflower-gardener Jun 15, 2018
284ad32
Improves the docstring and comments about feature column library.
Jun 15, 2018
9d67a56
Add resource type to Switch op.
tensorflower-gardener Jun 15, 2018
b84506e
Update demo app to use nightly TFLite build instead of latest release…
alanchiao Jun 15, 2018
7bd8cd2
Adds warm start capability to tf.contrib.estimator.DNNEstimator
tensorflower-gardener Jun 15, 2018
4944c27
Broad refactoring (part 1): Introduce a module dedicated to symbols t…
Jun 15, 2018
69e3c1d
Fix Makefile build for benchmarking code.
shashishekhar Jun 15, 2018
8ad3184
Add XLA support for the error function (and complement).
rjpower Jun 15, 2018
a8615a9
Merge commit for internal changes
Jun 15, 2018
eb8ed73
Fix bad manual merge.
Jun 15, 2018
2 changes: 1 addition & 1 deletion SECURITY.md
@@ -245,4 +245,4 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc=
### Known Vulnerabilities

For a list of known vulnerabilities and security advisories for TensorFlow,
-(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md)[click here].
+[click here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/security/index.md).
246 changes: 117 additions & 129 deletions tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -28,6 +28,9 @@ limitations under the License.
namespace tensorflow {

// A rewriting function to apply to each subgraph during encapsulation.
// 'arg_source_tensors' are the tensors corresponding to the arguments in the
// original source graph (*not* 'graph').
//
// 'graph' is the subgraph. The rewriting may renumber the inputs and outputs;
// 'input_permutation' is a mapping from old argument numbers to new argument
// numbers, whereas 'output_permutation' is the same for outputs. Both
@@ -37,6 +40,7 @@ namespace tensorflow {
// The rewrite may also change the NodeDef's operator name, and that
// name will be used as the name of the generated function.
typedef std::function<Status(
const std::vector<OutputTensor>& arg_source_tensors,
std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
std::vector<int>* output_permutation, NodeDef* node_def)>
RewriteSubgraphFn;
6 changes: 4 additions & 2 deletions tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -757,7 +757,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
std::vector<int>* output_permutation,
NodeDef* call_def) {
@@ -801,7 +802,8 @@ TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", "_outside", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
[&guaranteed_consts](const std::vector<OutputTensor>& arg_source_tensors,
std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
std::vector<int>* output_permutation,
NodeDef* call_def) {
9 changes: 9 additions & 0 deletions tensorflow/compiler/tests/BUILD
@@ -51,6 +51,15 @@ py_library(
],
)

py_test(
name = "xla_test_test",
size = "small",
srcs = ["xla_test_test.py"],
deps = [
":xla_test",
],
)

tf_xla_py_test(
name = "adagrad_test",
size = "small",
10 changes: 10 additions & 0 deletions tensorflow/compiler/tests/unary_ops_test.py
@@ -201,6 +201,16 @@ def testFloatOps(self):
expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284],
dtype=dtype))

# Disable float16 testing for now
if dtype != np.float16:
x = np.arange(-10, 10, 1).astype(dtype)
with self.test_session() as session:
erf_x = session.run(math_ops.erf(x))
erfc_x = session.run(math_ops.erfc(x))

self._assertOpOutputMatchesExpected(math_ops.erf, x, expected=erf_x)
self._assertOpOutputMatchesExpected(math_ops.erfc, x, expected=erfc_x)

self._assertOpOutputMatchesExpected(
math_ops.exp,
np.array([[-1, 1]], dtype=dtype),
57 changes: 33 additions & 24 deletions tensorflow/compiler/tests/xla_test.py
@@ -49,6 +49,32 @@
'Value to set the TF_XLA_FLAGS environment variable to')


def parse_disabled_manifest(manifest_content):
comments_re = re.compile('#.*$')
disabled_tests = []
disabled_method_types = []
for l in manifest_content.splitlines():
stripped = comments_re.sub('', l).strip()
if not stripped:
continue
entry = stripped.split(' ')
if len(entry) == 1:
disabled_tests.append(entry[0])
elif len(entry) == 2:
disabled_method_types.append((entry[0], entry[1].strip().split(',')))
else:
raise ValueError('Bad entry in manifest file.')

disabled_regex = '|'.join(disabled_tests)
method_types_filter = dict()
for method, types in disabled_method_types:
method_types_filter[method] = set([
dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
for name in types
])
return disabled_regex, method_types_filter


class XLATestCase(test.TestCase):
"""XLA test cases are parameterized test cases."""

@@ -85,38 +111,21 @@ def __init__(self, method_name='runTest'):

# Parse the manifest file, if any, into a regex identifying tests to
# disable
self.disabled_regex = None
self._method_types_filter = dict()
# TODO(xpan): Make it text proto if it doesn't scale.
# Each line of the manifest file specifies an entry. The entry can be
# 1) TestNameRegex // E.g. CumprodTest.* Or
# 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16
# The 1) disables the entire test. While 2) only filter some numeric types
# so that they are not used in those tests.
self.disabled_regex = None
self._method_types_filter = {}

if FLAGS.disabled_manifest is not None:
comments_re = re.compile('#.*$')
manifest_file = open(FLAGS.disabled_manifest, 'r')
disabled_tests = []
disabled_method_types = []
for l in manifest_file.read().splitlines():
if not l:
continue
entry = comments_re.sub('', l).strip().split(' ')
if len(entry) == 1:
disabled_tests.append(entry[0])
elif len(entry) == 2:
disabled_method_types.append(
(entry[0], entry[1].strip().split(',')))
else:
raise ValueError('Bad entry in manifest file.')

self.disabled_regex = re.compile('|'.join(disabled_tests))
for method, types in disabled_method_types:
self._method_types_filter[method] = set([
dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype
for name in types])
manifest_file.close()
with open(FLAGS.disabled_manifest, 'r') as manifest_file:
disabled_regex, self._method_types_filter = (
parse_disabled_manifest(manifest_file.read()))
if disabled_regex:
self.disabled_regex = re.compile(disabled_regex)

if FLAGS.tf_xla_flags is not None:
os.environ['TF_XLA_FLAGS'] = FLAGS.tf_xla_flags
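The manifest grammar described in the comments above (each line is either a test-name regex, or a test name followed by a comma-separated list of `DT_*` type names, with `#` comments stripped) can be sketched as a standalone parser. This is a simplified illustration of the `parse_disabled_manifest` helper introduced in this diff; the real function additionally converts each type name to a numpy dtype via `dtypes.as_dtype(types_pb2.DataType.Value(name))`, which this sketch omits to stay self-contained:

```python
import re

def parse_disabled_manifest_sketch(manifest_content):
    """Simplified sketch of parse_disabled_manifest, without the
    TensorFlow dtype lookup. Returns (regex_string, {test: [type names]})."""
    comments_re = re.compile('#.*$')
    disabled_tests = []
    disabled_method_types = {}
    for line in manifest_content.splitlines():
        stripped = comments_re.sub('', line).strip()
        if not stripped:
            continue  # blank lines and whole-line comments are skipped
        entry = stripped.split(' ')
        if len(entry) == 1:
            # Form 1: a regex that disables the whole test, e.g. CumprodTest.*
            disabled_tests.append(entry[0])
        elif len(entry) == 2:
            # Form 2: "TestName DT_TYPE1,DT_TYPE2" filters only those types
            disabled_method_types[entry[0]] = entry[1].strip().split(',')
        else:
            raise ValueError('Bad entry in manifest file.')
    return '|'.join(disabled_tests), disabled_method_types

manifest = """# disabled tests
CumprodTest.*
AdamOptimizerTest.testSharing DT_BFLOAT16
"""
regex, filters = parse_disabled_manifest_sketch(manifest)
```

Joining the whole-test entries with `|` is what makes an empty manifest safe: the tests added below in `xla_test_test.py` check precisely that blank lines and comment-only lines do not leak into the regex as empty alternatives that would match everything.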
44 changes: 44 additions & 0 deletions tensorflow/compiler/tests/xla_test_test.py
@@ -0,0 +1,44 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the XLATestCase test fixture base class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tests import xla_test
from tensorflow.python.platform import test


class XlaTestCaseTestCase(test.TestCase):

def testManifestEmptyLineDoesNotCatchAll(self):
manifest = """
testCaseOne
"""
disabled_regex, _ = xla_test.parse_disabled_manifest(manifest)
self.assertEqual(disabled_regex, "testCaseOne")

def testManifestWholeLineCommentDoesNotCatchAll(self):
manifest = """# I am a comment
testCaseOne
testCaseTwo
"""
disabled_regex, _ = xla_test.parse_disabled_manifest(manifest)
self.assertEqual(disabled_regex, "testCaseOne|testCaseTwo")


if __name__ == "__main__":
test.main()
46 changes: 46 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -16,9 +16,11 @@ limitations under the License.
// Native XLA implementations of simple unary Ops

#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"

@@ -185,5 +187,49 @@ XLAJIT_MAKE_UNARY(Imag, b->Imag(x));

#undef XLAJIT_MAKE_UNARY

// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial
// is used outside of this range.
class ErfOp : public XlaOpKernel {
public:
explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
xla::PrimitiveType primitive_type;
xla::XlaOp one = XlaHelpers::One(b, input_type(0));
xla::XlaOp x = ctx->Input(0);
xla::XlaOp abs_x = b->Abs(x);

OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(input_type(0), &primitive_type));

auto y = b->Select(b->Gt(abs_x, one),
b->Sub(one, ComputeErfc(b, x, primitive_type)),
ComputeErf(b, x, primitive_type));
ctx->SetOutput(0, y);
}
};
REGISTER_XLA_OP(Name("Erf"), ErfOp);

class ErfcOp : public XlaOpKernel {
public:
explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaBuilder* b = ctx->builder();
xla::XlaOp one = XlaHelpers::One(b, input_type(0));
xla::XlaOp x = ctx->Input(0);
xla::XlaOp abs_x = b->Abs(x);

xla::PrimitiveType primitive_type;
OP_REQUIRES_OK(ctx,
DataTypeToPrimitiveType(input_type(0), &primitive_type));

auto y = b->Select(b->Lt(abs_x, one),
b->Sub(one, ComputeErf(b, x, primitive_type)),
ComputeErfc(b, x, primitive_type));
ctx->SetOutput(0, y);
}
};
REGISTER_XLA_OP(Name("Erfc"), ErfcOp);

} // namespace
} // namespace tensorflow
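The branching in `ErfOp` and `ErfcOp` above picks whichever approximation is more accurate for the input range (the erf polynomial inside (-1, 1), the erfc rational approximation outside) and derives the requested function through the identity erfc(x) = 1 - erf(x). A plain-Python sketch of that selection logic, using `math.erf`/`math.erfc` as stand-ins for the cephes polynomials the XLA kernels actually evaluate:

```python
import math

def erf_via_select(x):
    # Mirrors ErfOp: for |x| > 1 the erfc approximation is preferred,
    # so compute erf(x) as 1 - erfc(x) there; otherwise use erf directly.
    if abs(x) > 1.0:
        return 1.0 - math.erfc(x)
    return math.erf(x)

def erfc_via_select(x):
    # Mirrors ErfcOp: for |x| < 1 the erf approximation is preferred,
    # so compute erfc(x) as 1 - erf(x) there; otherwise use erfc directly.
    if abs(x) < 1.0:
        return 1.0 - math.erf(x)
    return math.erfc(x)
```

The design point is numerical: near zero, erfc(x) is close to 1 and `1 - erf(x)` loses no precision, while for large |x| erfc decays to tiny values that the direct erfc polynomial captures but `1 - erf(x)` would cancel away.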
4 changes: 2 additions & 2 deletions tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -71,8 +71,8 @@ xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
}

// Check for zero lhs/rhs dim size.
-  if (xla::ShapeUtil::HasZeroElements(x_shape) ||
-      xla::ShapeUtil::HasZeroElements(y_shape)) {
+  if (xla::ShapeUtil::IsZeroElementArray(x_shape) ||
+      xla::ShapeUtil::IsZeroElementArray(y_shape)) {
std::vector<int64> dimensions(batch_dimension_numbers.size());
for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
1 change: 0 additions & 1 deletion tensorflow/compiler/xla/BUILD
@@ -309,7 +309,6 @@ cc_library(
":types",
":util",
":xla_data_proto",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
Expand Down
84 changes: 84 additions & 0 deletions tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -121,4 +121,88 @@ StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder) {
return builder->Reduce(predicates, f, logical_or, all_dimensions);
}

namespace {
xla::XlaOp FloatLiteral(xla::XlaBuilder* b, PrimitiveType data_type,
float value) {
return b->ConvertElementType(b->ConstantR0(value), data_type);
}

// Polynomials for computing erf/erfc. Originally from cephes.
// Note we use float for compatibility across devices, at the cost of some
// precision for 64 bit computations.
//
// Coefficients are in descending order.
std::array<float, 9> kErfcPCoefficient = {
2.46196981473530512524E-10, 5.64189564831068821977E-1,
7.46321056442269912687E0, 4.86371970985681366614E1,
1.96520832956077098242E2, 5.26445194995477358631E2,
9.34528527171957607540E2, 1.02755188689515710272E3,
5.57535335369399327526E2};
std::array<float, 9> kErfcQCoefficient = {
1.00000000000000000000E0, 1.32281951154744992508E1,
8.67072140885989742329E1, 3.54937778887819891062E2,
9.75708501743205489753E2, 1.82390916687909736289E3,
2.24633760818710981792E3, 1.65666309194161350182E3,
5.57535340817727675546E2};
std::array<float, 6> kErfcRCoefficient = {
5.64189583547755073984E-1, 1.27536670759978104416E0,
5.01905042251180477414E0, 6.16021097993053585195E0,
7.40974269950448939160E0, 2.97886665372100240670E0};
std::array<float, 7> kErfcSCoefficient = {
1.00000000000000000000E0, 2.26052863220117276590E0,
9.39603524938001434673E0, 1.20489539808096656605E1,
1.70814450747565897222E1, 9.60896809063285878198E0,
3.36907645100081516050E0};
std::array<float, 5> kErfTCoefficient = {
9.60497373987051638749E0, 9.00260197203842689217E1,
2.23200534594684319226E3, 7.00332514112805075473E3,
5.55923013010394962768E4};
std::array<float, 6> kErfUCoefficient = {
1.00000000000000000000E0, 3.35617141647503099647E1,
5.21357949780152679795E2, 4.59432382970980127987E3,
2.26290000613890934246E4, 4.92673942608635921086E4};
} // namespace

// Evaluate the polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x,
tensorflow::gtl::ArraySlice<float> coefficients,
PrimitiveType data_type) {
xla::XlaOp poly = FloatLiteral(b, data_type, 0.0);
for (float c : coefficients) {
poly = b->Add(b->Mul(poly, x), FloatLiteral(b, data_type, c));
}
return poly;
}

// Compute an approximation of the error function complement (1 - erf(x)).
xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x,
PrimitiveType data_type) {
xla::XlaOp zero = FloatLiteral(b, data_type, 0.0);
xla::XlaOp two = FloatLiteral(b, data_type, 2.0);
xla::XlaOp eight = FloatLiteral(b, data_type, 8.0);

xla::XlaOp abs_x = b->Abs(x);
xla::XlaOp z = b->Exp(b->Mul(b->Neg(x), x));

xla::XlaOp pp = EvaluatePolynomial(b, abs_x, kErfcPCoefficient, data_type);
xla::XlaOp pq = EvaluatePolynomial(b, abs_x, kErfcQCoefficient, data_type);
xla::XlaOp pr = EvaluatePolynomial(b, abs_x, kErfcRCoefficient, data_type);
xla::XlaOp ps = EvaluatePolynomial(b, abs_x, kErfcSCoefficient, data_type);

xla::XlaOp y = b->Select(b->Lt(abs_x, eight), b->Div(b->Mul(z, pp), pq),
b->Div(b->Mul(z, pr), ps));

return b->Select(b->Lt(x, zero), b->Sub(two, y), y);
}

// Compute a polynomial approximation of the error function.
xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x,
PrimitiveType data_type) {
xla::XlaOp z = b->Mul(x, x);
xla::XlaOp pt = EvaluatePolynomial(b, z, kErfTCoefficient, data_type);
xla::XlaOp pu = EvaluatePolynomial(b, z, kErfUCoefficient, data_type);
return b->Div(b->Mul(x, pt), pu);
}

} // namespace xla
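`EvaluatePolynomial` above is Horner's rule with coefficients supplied in descending order, and `ComputeErfc` uses it to build two cephes-style rational approximations, z * P(|x|) / Q(|x|) for |x| < 8 and z * R(|x|) / S(|x|) otherwise, where z = exp(-x * x), followed by the reflection erfc(x) = 2 - erfc(-x) for negative inputs. A scalar Python sketch of the Horner step (the builder plumbing and the cephes coefficient tables are omitted):

```python
def evaluate_polynomial(coefficients, x):
    """Horner's rule, mirroring EvaluatePolynomial above.
    Coefficients must be supplied in descending order of degree."""
    poly = 0.0
    for c in coefficients:
        poly = poly * x + c
    return poly

# 3*x^2 + 2*x + 1 evaluated at x = 2 is 17.
result = evaluate_polynomial([3.0, 2.0, 1.0], 2.0)
```

Horner's rule needs one multiply and one add per coefficient, which is why each coefficient table above maps to a short chain of fused `Mul`/`Add` ops in the emitted XLA computation.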
14 changes: 14 additions & 0 deletions tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -55,6 +55,20 @@ XlaComputation CreateScalarOrComputation(XlaBuilder* builder);
// Note: if predicates is zero-sized, Any() vacuously returns false.
StatusOr<XlaOp> Any(const XlaOp& predicates, XlaBuilder* builder);

// Evaluate the polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
xla::XlaOp EvaluatePolynomial(xla::XlaBuilder* b, const xla::XlaOp& x,
tensorflow::gtl::ArraySlice<float> coefficients,
PrimitiveType data_type);

// Compute an approximation of the error function complement (1 - erf(x)).
xla::XlaOp ComputeErfc(xla::XlaBuilder* b, const xla::XlaOp& x,
PrimitiveType data_type);

// Compute an approximation of the error function.
xla::XlaOp ComputeErf(xla::XlaBuilder* b, const xla::XlaOp& x,
PrimitiveType data_type);

} // namespace xla

#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_