Branch 178185697 #15174

Merged: 126 commits, Dec 7, 2017
Commits (126)
e69bb08
Modify custom export strategy to account for multidimensional sparse …
tensorflower-gardener Dec 4, 2017
8cb4a0f
Better deprecation message for --input_type[s].
tensorflower-gardener Dec 4, 2017
50efd92
Enable dependency optimizer in Grappler by default.
tensorflower-gardener Dec 4, 2017
8842c73
Make `replicate_model_fn` available via tensorflow.contrib.estimator …
isaprykin Dec 4, 2017
6eed89c
Treat stream removal as an UNAVAILABLE error in Tensorflow instead of…
rjpower Dec 4, 2017
7ea70db
[XLA:CPU] Extract out a common helper to create an llvm::Function
Dec 4, 2017
ff71c27
Cleanup: Remove dummy_disabled_internal target.
miaout17 Dec 4, 2017
3ff72d4
Sanitize dtypes in filenames in normalization_test.
gunan Dec 4, 2017
171e4db
Changed tensorflow::StringPiece default constructor to set data_ to n…
tensorflower-gardener Dec 4, 2017
5ae2168
[XLA:CPU] Disable frame pointer elimination
Dec 4, 2017
2d3d5c2
Fixing wording in truncated_normal as per Github bug request #13686.
tensorflower-gardener Dec 4, 2017
36ac0f3
[tpu:profiler] Modify the fields in host-independent and host-…
tensorflower-gardener Dec 4, 2017
a1c2913
[XLA:GPU] Switch from specifying maxntid to reqntid.
Dec 4, 2017
8f1e63d
Actually use ApiDef when generating Python API.
annarev Dec 4, 2017
5917d48
Apply oss_serial tag to tests that use portpicker to create local clu…
caisq Dec 4, 2017
71cd06c
[TF:XLA] Fix wrong output of FloorDiv op for DT_HALF values.
hawkinsp Dec 4, 2017
e96c364
Sanitize formatting in IdTableWithHashBuckets doc comment.
yacoder Dec 4, 2017
89f4dd0
Getting rid of obsolete function is_variable_registered from LayerCol…
tensorflower-gardener Dec 4, 2017
a2df7bd
Update pin for bazel-toolchains to latest version
tensorflower-gardener Dec 4, 2017
c14ef60
Marks args as runtime consts in XLA EncapsulateSubgraphsPass.
vinuraja Dec 4, 2017
1b2bd86
Internal-only changes
caisq Dec 4, 2017
fbec215
[XLA:CPU] Use an AVX optimized reduction step for row-major matrix-ve…
Dec 4, 2017
c24b377
[XLA:CPU] Avoid over-aligning parameter buffers
Dec 4, 2017
2427923
Internal change.
annarev Dec 4, 2017
6f09a74
Fix TFGAN's `clip_weights_test.py` bugs.
tensorflower-gardener Dec 4, 2017
7b53245
Sort sections in operation semantics alphabetically.
nickdesaulniers Dec 4, 2017
6f7d8b5
Allow test_util.evaluate handle nested tensors.
sguada Dec 4, 2017
46af794
Fix edge case with ImportGraphDefOption.uniquify_names = true.
skye Dec 4, 2017
4a8e27e
[StreamExecutor] Add UnqueryableDeviceParams for all nvidia GPUs.
Dec 4, 2017
f5f7b85
Fix ResourceVariable's docstring example.
reedwm Dec 4, 2017
d87a76d
Enable bfloat16 use from Python:
hawkinsp Dec 4, 2017
e12c032
hsv_in_yiq gpu implementation.
tensorflower-gardener Dec 4, 2017
0946c24
Add a single capacity prefetch to `tf.contrib.data.read_batch_features`.
tensorflower-gardener Dec 4, 2017
c51221e
[XLA] Add a default implementation of Literal::ToString for rank >= 6…
hawkinsp Dec 4, 2017
eaf5124
[XLA:GPU] Use more threads per thread block.
Dec 4, 2017
aa8b75a
Fix minor typos in the doc of SpaceToDepth and DepthToSpace.
Dec 5, 2017
a5b01ff
Correct trivial spelling error in internal_convert_to_tensor
tensorflower-gardener Dec 5, 2017
868e2b3
[XLA] Add --print_result flag to replay_computation tool.
Dec 5, 2017
1298924
Fix bug with uniquified colocation attrs in ImportGraphDef.
skye Dec 5, 2017
c414e26
Add BF16 tests for reduce-window.
ukoxyz Dec 5, 2017
cfc1de5
Fix tf.identity(resource variable) with eager execution and a device
tensorflower-gardener Dec 5, 2017
2091f50
Treat integer default initializers like floating point ones.
alextp Dec 5, 2017
601687d
Modifying _get_examples in graph_io.py to utilize tf.cond.
tensorflower-gardener Dec 5, 2017
4ff0f28
Reproduce an issue with MonitoredSession when saving a variable on a …
isaprykin Dec 5, 2017
b7affde
[TF2XLA] Change the implementation of Diag and MatrixDiag to use arit…
tensorflower-gardener Dec 5, 2017
201d8d8
Enable transferring a tuple literal to a replicated device.
meheffernan Dec 5, 2017
c9038c9
Generates a warning if the global step is not increased.
Dec 5, 2017
c486a91
Fix bugs in neutral element code and add more unit tests to cover mat…
tensorflower-gardener Dec 5, 2017
5234185
[XLA] Mark Rng as side-effecting and add a rematerialization test to …
blakehechtman Dec 5, 2017
009f4f3
[TF:XLA] Add support for NCHW format to SpaceToDepth and DepthToSpace.
hawkinsp Dec 5, 2017
5ee30df
Estimate Placeholder as a no-op.
yacoder Dec 5, 2017
2b2500a
Add a helper to HloSharding to easily create trivial flat tuples with…
tensorflower-gardener Dec 5, 2017
b62a356
Make RevBlock a subclass of Layer
tensorflower-gardener Dec 5, 2017
09e9ca1
Improve handling of operations that are known to TOCO but not to TF L…
tensorflower-gardener Dec 5, 2017
77b60c1
Simplify code in dependency optimizer.
tensorflower-gardener Dec 5, 2017
f88cd91
Adding variant-based serialization and deserialization for sparse ten…
jsimsa Dec 5, 2017
c72bb97
nn_impl.py cleanup: used keepdims instead of deprecated keep_dims.
tensorflower-gardener Dec 5, 2017
e72ecbd
Add ImportGraphDefOptions::uniquify_prefix.
skye Dec 5, 2017
4b0a236
Add the tf2xla_supported_ops tool, which dumps ops supported by tf2xla.
tensorflower-gardener Dec 5, 2017
f38f92e
Only parse known flags in tf.app.run().
yilei Dec 5, 2017
33e3da5
[TF:XLA] Add support for FusedBatchNormGrad where is_training=False.
hawkinsp Dec 5, 2017
21e831d
Automated g4 rollback of changelist 177799252
tensorflower-gardener Dec 5, 2017
6afface
Add android rule helpers and cleanup input loops
angerson Dec 5, 2017
ad30cd2
Internal change.
tensorflower-gardener Dec 5, 2017
248176b
New document for Getting Started section about saving models.
tensorflower-gardener Dec 5, 2017
b352b38
Rather than make potentially complex modifications to the Hlo graph, …
nickdesaulniers Dec 5, 2017
3b930e3
[TF:XLA] Add test with while loop and many parameters.
tensorflower-gardener Dec 5, 2017
557e0ce
Use a macro to determine whether BF16 is supported.
ukoxyz Dec 5, 2017
aa4f491
Adds shards and increases size to dnn_linear_combined_test to prevent…
tensorflower-gardener Dec 5, 2017
7f75687
Replace `FunctionCallFrame` with a pure-virtual `CallFrameInterface`.
mrry Dec 5, 2017
b6ed812
Sets the master to '' for single node cluster.
Dec 5, 2017
d190669
Add SaveRestoreMeasuringCostEstimator to measure the memory and runti…
Dec 6, 2017
41c19af
Implement faster and less memory hungry version of topological sort t…
tensorflower-gardener Dec 6, 2017
cc7482f
Change InputArray.shape from being a repeated int field to being
tensorflower-gardener Dec 6, 2017
1e78249
Improve module docstrings, which show up in Google search.
tensorflower-gardener Dec 6, 2017
44cb388
Support a vector and a 4D tensor as inputs to a binary op.
Dec 6, 2017
af8a550
Always include the function library when exporting a MetaGraphDef.
mrry Dec 6, 2017
fb857dc
Adds an optional dict to hold tensors that are concatenated into the …
tensorflower-gardener Dec 6, 2017
d9a71ad
Fix some build incompatibilities with new versions of Bazel
angerson Dec 6, 2017
3d895eb
[StreamExecutor] When a kernel launch fails, print the kernel's name.
Dec 6, 2017
cefd2c7
[XLA] Humanize some print-outs.
tensorflower-gardener Dec 6, 2017
1a786ab
[XLA:GPU] Don't autotune while other kernels are running.
Dec 6, 2017
22767d5
Allow CrossReplicaSum to take multiple operands internally.
broune Dec 6, 2017
c231d21
Fix broken usage of mutexes in training ops like AdaDelta.
ebrevdo Dec 6, 2017
4985cf2
Add tutorial, model code and dataset conversion tool for the Quick, D…
tensorflower-gardener Dec 6, 2017
fca13de
[XLA] Change default argument to explicit 0 for default address space…
tensorflower-gardener Dec 6, 2017
9c79b0f
Optimize away NoOp in the common case where num_inputs=num_outputs=2
benoitsteiner Dec 6, 2017
041dc33
Disable arithmetic optimizations that require shape inference, since …
tensorflower-gardener Dec 6, 2017
63dd8ff
Add DataFormatDimMap op.
Dec 6, 2017
5694fb9
Create a new Var-like object, LegacyVar, which allows access to its m…
ebrevdo Dec 6, 2017
707d586
Support non-const axis for split and concat.
Dec 6, 2017
5b3e3c9
Use `CallFrameInterface` instead of `FunctionCallFrame` in the executor.
mrry Dec 6, 2017
45fef92
Fix typo: SyncReplicaOptimizer -> SyncReplicasOptimizer
Dec 6, 2017
e2a2747
Fix session_list_devices_test
saeta Dec 6, 2017
c60e32e
[TF:XLA] Support for DT_INT64 in the VariableShape operation.
asimshankar Dec 6, 2017
2ad791c
Automated g4 rollback of changelist 178054272
tensorflower-gardener Dec 6, 2017
7ac7aa8
Add version to image retraining setup instructions.
MarkDaoust Dec 6, 2017
f79c39e
Use sparse xent to avoid softmax_v2 warning in examples/learn
MarkDaoust Dec 6, 2017
a16b137
Simplified parts of the constant folding code and made it more robust…
benoitsteiner Dec 6, 2017
dfc3d94
Switch tf.data support for sparse tensors from string-based serializa…
jsimsa Dec 6, 2017
78f199d
fix link to quickdraw data
MarkDaoust Dec 6, 2017
60af864
[tf.data] Validate that all elements of a batch have the same shape.
mrry Dec 6, 2017
93aeeba
[tf] Change TensorArray "enable_identical_shapes" to infer_shape by d…
ebrevdo Dec 6, 2017
91c75ec
Allow SparseSegmentReduction ops to have missing segment IDs.
tensorflower-gardener Dec 6, 2017
62ed393
Style cleanups in `tf.contrib.distributions.bijectors.SoftmaxCentered`.
jvdillon Dec 6, 2017
c777889
Uniquify names and prefixes in import_graph_def with C API enabled.
skye Dec 6, 2017
4c5564a
Remove vestigial test modification for TensorArray.
ebrevdo Dec 6, 2017
4567c77
Fixed a few bugs in the dependency optimizer.
benoitsteiner Dec 6, 2017
278d358
Bugfix: Cast dtype of log_det_jacobian to match log_prob in Transform…
tensorflower-gardener Dec 6, 2017
4972095
Clean up bijectors by removing _impl files.
tensorflower-gardener Dec 6, 2017
e7fc1d5
Add a convenient function to client_library_test_base.
yunxing Dec 6, 2017
2463815
Check against passing the same array in --input_arrays and --output_a…
tensorflower-gardener Dec 6, 2017
e70d562
graphviz improvements:
tensorflower-gardener Dec 6, 2017
697fa59
Simplification of propagate_array_data_types:
tensorflower-gardener Dec 6, 2017
c2e6d55
[XLA:CPU] Factor out parallel function call logic into IrFunction (so…
tensorflower-gardener Dec 6, 2017
2d8206b
Add Python checks to prevent mixing ops from different while loops.
skye Dec 6, 2017
bcdcb78
Fix tests in control_flow_ops_test.py to not access Tensor._shape
skye Dec 6, 2017
9244afc
Minor changes to help debugging graph corruptions
benoitsteiner Dec 7, 2017
80375ff
Make note of configure script's WORKSPACE-updating ability
angerson Dec 7, 2017
afdb6c8
Add missing gradient registration for ConjugateTranspose.
tensorflower-gardener Dec 7, 2017
b10855a
Add feature_column.InputLayer, an object-oriented version of input_la…
akshayka Dec 7, 2017
cddf941
[TF:XLA] Add XLA lowering of ResizeBilinear and ResizeBilinearGrad fo…
hawkinsp Dec 7, 2017
a51cc58
Handle repeated applications of functools.partial correctly.
tensorflower-gardener Dec 7, 2017
8ad62af
Split the tests so that it doesn't time out.
Dec 7, 2017
fe84061
Merge changes from github.
caisq Dec 7, 2017
757f3c5
Merge commit for internal changes
caisq Dec 7, 2017
407 changes: 323 additions & 84 deletions configure.py

Large diffs are not rendered by default.

8 changes: 0 additions & 8 deletions tensorflow/BUILD
@@ -364,14 +364,6 @@ config_setting(
visibility = ["//visibility:public"],
)

# Make a dummy rule that we can change "default" in select statements to,
# in order to disable dependencies in copybara.
config_setting(
name = "dummy_disabled_internal",
values = {"define": "with_dummy_disabled_internal=true"},
visibility = ["//visibility:public"],
)

package_group(
name = "internal",
packages = [
10 changes: 10 additions & 0 deletions tensorflow/c/c_api.cc
@@ -1850,6 +1850,16 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts,
opts->opts.prefix = prefix;
}

void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts,
unsigned char uniquify_names) {
opts->opts.uniquify_names = uniquify_names;
}

void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts,
unsigned char uniquify_prefix) {
opts->opts.uniquify_prefix = uniquify_prefix;
}

void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts,
const char* src_name,
int src_index, TF_Output dst) {
14 changes: 14 additions & 0 deletions tensorflow/c/c_api.h
@@ -889,6 +889,20 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions(
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix(
TF_ImportGraphDefOptions* opts, const char* prefix);

// Set whether to uniquify imported operation names. If true, imported operation
// names will be modified if their name already exists in the graph. If false,
// conflicting names will be treated as an error. Note that this option has no
// effect if a prefix is set, since the prefix will guarantee all names are
// unique. Defaults to false.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames(
TF_ImportGraphDefOptions* opts, unsigned char uniquify_names);

// If true, the specified prefix will be modified if it already exists as an
// operation name or prefix in the graph. If false, a conflicting prefix will be
// treated as an error. This option has no effect if no prefix is specified.
TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix(
TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix);

// Set any imported nodes with input `src_name:src_index` to have that input
// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
// `dst` references a node already existing in the graph being imported into.
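
Usage note (not part of the diff): the two options above exist so that the same GraphDef can be imported into one graph more than once. Below is a minimal C API sketch under stated assumptions: the helper name ImportCopies is hypothetical, `graph_def` is assumed to already hold a serialized GraphDef, and error checking on `status` is omitted.

#include "tensorflow/c/c_api.h"

// Import the same serialized GraphDef into `graph` three times, exercising
// the new uniquify options on the later imports.
void ImportCopies(TF_Graph* graph, const TF_Buffer* graph_def) {
  TF_Status* status = TF_NewStatus();
  TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions();

  // First copy keeps the original operation names.
  TF_GraphImportGraphDef(graph, graph_def, opts, status);

  // Second copy, still unprefixed: rename operations whose names already
  // exist in `graph` instead of failing with a duplicate-name error.
  TF_ImportGraphDefOptionsSetUniquifyNames(opts, 1);
  TF_GraphImportGraphDef(graph, graph_def, opts, status);

  // Third copy: import under a prefix, and let the importer pick a fresh
  // prefix (e.g. "import_1") if "import" is already taken as a name or prefix.
  TF_ImportGraphDefOptionsSetPrefix(opts, "import");
  TF_ImportGraphDefOptionsSetUniquifyPrefix(opts, 1);
  TF_GraphImportGraphDef(graph, graph_def, opts, status);

  TF_DeleteImportGraphDefOptions(opts);
  TF_DeleteStatus(status);
}

As the header comment states, uniquify_names has no effect once a prefix is set, since the prefix already guarantees unique names; that is why the third import relies on uniquify_prefix alone.
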
53 changes: 53 additions & 0 deletions tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -48,6 +49,52 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs";

namespace {

bool AreAllParentsConst(const Node& n,
const gtl::FlatSet<const Node*>& runtime_const_nodes) {
if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") {
// If the current node is itself a cast-to-const, no need
// to look at the incoming edges.
return true;
}

bool all_parents_const = true;
bool atleast_one_non_control_edge = false;
for (const Edge* in : n.in_edges()) {
atleast_one_non_control_edge =
atleast_one_non_control_edge || !in->IsControlEdge();
if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) {
all_parents_const = false;
break;
}
}
return all_parents_const && atleast_one_non_control_edge;
}

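// Marks the `_Arg` node paired with each source in `src_arg_pairs` with the
// `_is_guaranteed_constant` attribute when that source is fed, transitively,
// only by Const or GuaranteeConst nodes. The walk goes backwards from the
// sources via ReverseDFSFrom and, on leave, records every node whose
// non-control inputs are all already known to be constant.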
void MarkGuaranteedConstants(
const Graph& graph,
const std::vector<std::pair<Node*, Node*>>& src_arg_pairs) {
gtl::FlatSet<const Node*> guaranteed_const_nodes;
std::vector<Node*> srcs;
srcs.reserve(src_arg_pairs.size());
for (const auto& src_arg : src_arg_pairs) {
srcs.push_back(src_arg.first);
}
ReverseDFSFrom(graph, srcs, /*enter=*/nullptr,
/*leave=*/[&guaranteed_const_nodes](Node* n) {
// TODO(vinuraja): Doesn't work in the presence of loops.
if (AreAllParentsConst(*n, guaranteed_const_nodes)) {
guaranteed_const_nodes.insert(n);
}
});

for (auto& src_arg : src_arg_pairs) {
if (guaranteed_const_nodes.count(src_arg.first) != 0) {
VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString();
src_arg.second->AddAttr("_is_guaranteed_constant", true);
}
}
}

// A node/slot pair.
// TODO(phawkins): is there a common definition of this?
struct NodeSlot {
@@ -175,9 +222,11 @@ Status Encapsulator::SplitIntoSubgraphs() {
// Map from input graph nodes to subgraph nodes.
std::unordered_map<Node*, Node*> node_images;

std::vector<std::pair<Node*, Node*>> src_arg_pairs;
// Copy all marked nodes to a subgraph. Do nothing for unmarked nodes.
for (Node* node : graph_in_->op_nodes()) {
string func_id = GetFunctionNameAttr(node);

if (func_id.empty()) continue;

Subgraph& subgraph = subgraphs_[func_id];
@@ -276,11 +325,13 @@ Status Encapsulator::SplitIntoSubgraphs() {
kArgOp);
builder.Attr("T", dtype);
builder.Attr("index", arg_index);

s = builder.Finalize(&arg_def);
if (!s.ok()) return s;

Node* arg = dst_subgraph.graph->AddNode(arg_def, &s);
if (!s.ok()) return s;
src_arg_pairs.push_back({edge->src(), arg});

dst_subgraph.args.push_back(arg);
}
@@ -292,6 +343,8 @@
}
}

MarkGuaranteedConstants(*graph_in_, src_arg_pairs);

for (auto& entry : subgraphs_) {
FixupSourceAndSinkEdges(entry.second.graph.get());
}
104 changes: 104 additions & 0 deletions tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -398,5 +398,109 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) {
EXPECT_EQ(expected_edges, GraphEdges(*graph));
}

const Node* FindNodeByName(const Graph& graph, const string& name) {
for (const Node* node : graph.nodes()) {
if (node->name() == name) return node;
}
return nullptr;
}

bool HasGuaranteeConstAttr(const Node& n) {
bool is_guaranteed_constant = false;
if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant",
&is_guaranteed_constant)
.ok()) {
return false;
}
return is_guaranteed_constant;
}

TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) {
Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
"/job:localhost/replica:0/task:0/cpu:0");
auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f);
auto const_guarantee_x1 =
ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2);
add1.node()->AddAttr("_encapsulate", "encapsulate1");

Graph graph_before(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&graph_before));

std::unique_ptr<Graph> graph_after;
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
std::vector<int>* output_permutation,
NodeDef* call_def) {
Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" &&
StringPiece(n->name()).starts_with("const")) {
++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else {
EXPECT_FALSE(HasGuaranteeConstAttr(*n));
}
}
return Status::OK();
},
/*parallel_checking=*/false,
/*reuse_existing_functions=*/false, &graph_after, &library));
EXPECT_EQ(2, guaranteed_consts);
}

TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) {
Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
"/job:localhost/replica:0/task:0/cpu:0");
auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
auto const_guarantee_x1 =
ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1);
auto const_guarantee_x2 =
ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2);
auto const_guarantee_add1 = ops::Add(root.WithOpName("const_guarantee_add1"),
const_guarantee_x1, const_guarantee_x2);
auto add2 = ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2);
auto mul1 = ops::Mul(root.WithOpName("mul1"), const_guarantee_add1, add2);
mul1.node()->AddAttr("_encapsulate", "encapsulate1");

Graph graph_before(OpRegistry::Global());
TF_ASSERT_OK(root.ToGraph(&graph_before));

std::unique_ptr<Graph> graph_after;
FunctionLibraryDefinition library(OpRegistry::Global(), {});
int guaranteed_consts = 0;
TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
"_encapsulate", graph_before,
/*rewrite_subgraph_fn=*/
[&guaranteed_consts](std::unique_ptr<Graph>* graph_ptr,
std::vector<int>* input_permutation,
std::vector<int>* output_permutation,
NodeDef* call_def) {
Graph* graph = graph_ptr->get();
for (const Node* n : graph->nodes()) {
if (n->type_string() == "_Arg" &&
StringPiece(n->name()).starts_with("const")) {
++guaranteed_consts;
EXPECT_TRUE(HasGuaranteeConstAttr(*n));
} else {
EXPECT_FALSE(HasGuaranteeConstAttr(*n));
}
}
return Status::OK();
},
/*parallel_checking=*/false,
/*reuse_existing_functions=*/false, &graph_after, &library));
// Only 1 runtime const, which is const_guarantee_add1. add2 has one const
// input and one non-const input, so it is not a guaranteed const.
EXPECT_EQ(1, guaranteed_consts);
}

} // namespace
} // namespace tensorflow
2 changes: 0 additions & 2 deletions tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -102,15 +102,13 @@ xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
}
void* data =
reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
TF_RET_CHECK(data != nullptr);
tensors_[data] = t;
return gpu::DeviceMemoryBase(data, size);
}

Status XlaAllocator::RegisterArgument(const Tensor* t) {
void* data =
reinterpret_cast<void*>(const_cast<char*>(t->tensor_data().data()));
TF_RET_CHECK(data != nullptr);
tensors_[data] = *t;
return Status::OK();
}
13 changes: 13 additions & 0 deletions tensorflow/compiler/tests/BUILD
@@ -279,6 +279,19 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "image_ops_test",
size = "small",
srcs = ["image_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:image_ops",
"//tensorflow/python:platform_test",
],
)

tf_xla_py_test(
name = "lrn_ops_test",
size = "medium",
51 changes: 49 additions & 2 deletions tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -155,7 +155,7 @@ def testLearning(self):
def testLearningWithGradientChecker(self):
self._testLearning(True)

def testGradient(self):
def testGradientTraining(self):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
channel = 3
@@ -175,7 +175,7 @@ def testGradient(self):
var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
grad, x, scale, mean, var, data_format="NHWC")
grad, x, scale, mean, var, data_format="NHWC", is_training=True)

grad_x_val, grad_scale_val, grad_offset_val = sess.run(
[grad_x, grad_scale, grad_offset], {
@@ -193,6 +193,53 @@ def testGradient(self):
self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)

def testGradientInference(self):
# TODO(b/64270657): Use gradient_checker here in addition to comparing with
# this reference implementation.
channel = 3
x_shape = [2, 2, 6, channel]
scale_shape = [channel]
grad_val = np.random.random_sample(x_shape).astype(np.float32)
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
mean_val = np.random.random_sample(scale_shape).astype(np.float32)
var_val = np.random.random_sample(scale_shape).astype(np.float32)

with self.test_session() as sess, self.test_scope():
grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad")
x = array_ops.placeholder(np.float32, shape=x_shape, name="x")
mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
with self.test_scope():
out = gen_nn_ops.fused_batch_norm_grad(
grad, x, scale, mean, var, data_format="NHWC", is_training=False)
grad_x, grad_scale, grad_offset, _, _ = out

ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
grad, x, scale, mean, var, data_format="NHWC", is_training=False)

grad_x_val, grad_scale_val, grad_offset_val, = sess.run(
[grad_x, grad_scale, grad_offset], {
grad: grad_val,
x: x_val,
mean: mean_val,
var: var_val,
scale: scale_val
})
grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run(
[ref_x, ref_scale, ref_offset], {
grad: grad_val,
x: x_val,
mean: mean_val,
var: var_val,
scale: scale_val
})

self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2)
self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)


if __name__ == "__main__":
test.main()
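
For context on testGradientInference above: with is_training=False, FusedBatchNormGrad receives the fixed population mean and variance as inputs rather than computing batch statistics, so the gradients the reference computation is compared against reduce to simple closed forms. This is a sketch of the standard frozen-batch-norm math, not code from this file; gamma is the scale, epsilon the batch-norm epsilon, and the sums run over the N, H, W axes of each channel:

\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}},
\qquad
\frac{\partial L}{\partial \gamma} = \sum_{n,h,w} \frac{\partial L}{\partial y} \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}},
\qquad
\frac{\partial L}{\partial \beta} = \sum_{n,h,w} \frac{\partial L}{\partial y}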