Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch 161686867 #11461

Merged
merged 17 commits into from Jul 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
baba8fc
Further BUILD cleanup in tensorflow/contrib/...
tensorflower-gardener Jul 12, 2017
272dee9
Java: Make maven release script runnable from any directory.
asimshankar Jul 12, 2017
b66ecbe
Add warning to Experiment so that users know to set environment=cloud…
tensorflower-gardener Jul 12, 2017
8d3913a
Mark SessionBundle targets as no longer supported.
sukritiramesh Jul 12, 2017
0720891
Use graph._c_graph instead of global USE_C_API
tensorflower-gardener Jul 12, 2017
1d1f99d
Make ClientLibraryTestBase able to test all layouts for all input arg…
tensorflower-gardener Jul 12, 2017
54137ff
Emit the correct LLVM opcode for kMinimum and kMaximum over unsigned
tensorflower-gardener Jul 12, 2017
a05c3ce
Add scratch allocator option for 1D, 2D, 3D, and batched cufft plan c…
Jul 12, 2017
e8c354f
[XLA] Propagate Status from DeviceMemoryAllocator on failure.
tensorflower-gardener Jul 12, 2017
bf1461a
Add platform bridge for grpc response reader.
Jul 12, 2017
8f66dd2
Add checks to TensorForest that help with debugging when labels are w…
tensorflower-gardener Jul 12, 2017
eb1fe50
[TF:XLA] Add initial implementation of the Stack operators to the TF/…
hawkinsp Jul 12, 2017
786bf6c
Refactor some of TensorForest V4 to make the tree model valid during …
tensorflower-gardener Jul 12, 2017
576c7b1
Automated g4 rollback of changelist 161218103
tensorflower-gardener Jul 12, 2017
de546d0
BUILD cleanup in tensorflow/compiler/...
tensorflower-gardener Jul 12, 2017
2195db6
Remove unused flag: xla_hlo_graph_for_compute_constant
tensorflower-gardener Jul 12, 2017
a891c24
Merge commit for internal changes
Jul 12, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow/compiler/tests/BUILD
Expand Up @@ -42,6 +42,7 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:random_seed",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)

Expand Down Expand Up @@ -366,6 +367,19 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "stack_ops_test",
size = "small",
srcs = ["stack_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)

tf_xla_py_test(
name = "tensor_array_ops_test",
size = "small",
Expand Down
104 changes: 104 additions & 0 deletions tensorflow/compiler/tests/stack_ops_test.py
@@ -0,0 +1,104 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.stack_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.platform import test


class StackOpTest(XLATestCase):

def testStackPushPop(self):
with self.test_session(), self.test_scope():
size = array_ops.placeholder(dtypes.int32)
v = array_ops.placeholder(dtypes.float32)
h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo")
c = gen_data_flow_ops._stack_push_v2(h, v)
with ops.control_dependencies([c]):
c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32)
self.assertAllClose([[4.0, 5.0]], c1.eval({size: 5, v: [[4.0, 5.0]]}))

def testStackPushPopSwap(self):
with self.test_session(), self.test_scope():
a = np.arange(2000)
x = array_ops.placeholder(dtypes.float32)
h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")
c = gen_data_flow_ops._stack_push_v2(h, x, swap_memory=True)
with ops.control_dependencies([c]):
c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32)
self.assertAllClose(a, c1.eval({x: a}))

def testMultiStack(self):
with self.test_session(), self.test_scope():
v = array_ops.placeholder(dtypes.float32)
h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")
c1 = gen_data_flow_ops._stack_push_v2(h1, v)
with ops.control_dependencies([c1]):
c1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32)
h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="bar")
c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0)
with ops.control_dependencies([c2]):
c2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32)
r = c1 + c2
self.assertAllClose(9.0, r.eval({v: 4.0}))

def testSameNameStacks(self):
"""Different stacks with the same name do not interfere."""
with self.test_session() as sess, self.test_scope():
v1 = array_ops.placeholder(dtypes.float32)
v2 = array_ops.placeholder(dtypes.float32)
h1 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")
h2 = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")

c1 = gen_data_flow_ops._stack_push_v2(h1, v1)
with ops.control_dependencies([c1]):
c2 = gen_data_flow_ops._stack_push_v2(h2, v2)
with ops.control_dependencies([c2]):
pop1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32)
pop2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32)

out1, out2 = sess.run([pop1, pop2], {v1: 4.0, v2: 5.0})
self.assertAllClose(out1, 4.0)
self.assertAllClose(out2, 5.0)

def testCloseStack(self):
with self.test_session() as sess, self.test_scope():
size = array_ops.placeholder(dtypes.int32)
h = gen_data_flow_ops._stack_v2(size, dtypes.float32, stack_name="foo")
c1 = gen_data_flow_ops._stack_close_v2(h)
sess.run(c1, {size: 5})

def testPushCloseStack(self):
with self.test_session() as sess, self.test_scope():
v = array_ops.placeholder(dtypes.float32)
h = gen_data_flow_ops._stack_v2(5, dtypes.float32, stack_name="foo")
c = gen_data_flow_ops._stack_push_v2(h, v)
with ops.control_dependencies([c]):
c1 = gen_data_flow_ops._stack_close_v2(h)
sess.run(c1, {v: [[4.0, 5.0]]})


if __name__ == "__main__":
test.main()
1 change: 1 addition & 0 deletions tensorflow/compiler/tf2xla/const_analysis.cc
Expand Up @@ -81,6 +81,7 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Split", "split_dim"},
{"SplitV", "split_dim"},
{"SplitV", "size_splits"},
{"StackV2", "max_size"},
{"StridedSlice", "begin"},
{"StridedSlice", "end"},
{"StridedSlice", "strides"},
Expand Down
1 change: 1 addition & 0 deletions tensorflow/compiler/tf2xla/kernels/BUILD
Expand Up @@ -54,6 +54,7 @@ tf_kernel_library(
"softmax_op.cc",
"spacetobatch_op.cc",
"split_op.cc",
"stack_ops.cc",
"strided_slice_op.cc",
"tensor_array_ops.cc",
"tile_ops.cc",
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/arg_op.cc
Expand Up @@ -60,6 +60,9 @@ class ArgOp : public XlaOpKernel {
case XlaCompiler::Argument::kTensorArray:
kind = XlaResource::kTensorArray;
break;
case XlaCompiler::Argument::kStack:
kind = XlaResource::kStack;
break;
default:
CHECK(false);
}
Expand Down