[TF:XLA] Add partial implementation of tf.FIFOQueue for XLA devices (e.g., TPU).

The idea is to have a host-side queue of device tensors.

The dequeue_many, enqueue_many, and dequeue_up_to operators are not yet implemented because they require splitting or concatenating tensors, which would require calling into a compiled XLA computation.

Refactor queue operator implementations into libraries separate from the kernel registrations.

Add support for ResourceOpKernels that are placed on non-CPU devices. Add support for allocating host-memory tensors during OpKernel construction.

PiperOrigin-RevId: 202590292
hawkinsp authored and tensorflower-gardener committed Jun 29, 2018
1 parent f04400f commit 5083915
Showing 13 changed files with 883 additions and 466 deletions.
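To make the user-visible behavior concrete, here is a minimal sketch modeled on the new test added below. The XLA device string is an assumption and depends on the local build; any XLA-backed device (e.g., XLA_CPU, TPU) should behave the same way.

# Sketch: tf.FIFOQueue with its enqueue/dequeue kernels placed on an XLA
# device. The queue itself is a host-side resource that stores device tensors.
import tensorflow as tf

with tf.Session() as sess:
  with tf.device("/device:XLA_CPU:0"):  # assumed XLA device name
    q = tf.FIFOQueue(10, tf.float32)
    enqueue_op = q.enqueue((10.0,))
    dequeued = q.dequeue()
  sess.run(enqueue_op)
  print(sess.run(dequeued))  # 10.0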
2 changes: 2 additions & 0 deletions tensorflow/compiler/jit/BUILD
@@ -176,9 +176,11 @@ cc_library(
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:fifo_queue",
"//tensorflow/core/kernels:identity_n_op",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:queue_op",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:shape_ops",
29 changes: 28 additions & 1 deletion tensorflow/compiler/jit/xla_device_ops.h
@@ -23,9 +23,11 @@ limitations under the License.
#include "tensorflow/core/kernels/cast_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/fifo_queue.h"
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/kernels/identity_op.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/kernels/queue_op.h"
#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/sendrecv_ops.h"
#include "tensorflow/core/kernels/shape_ops.h"
@@ -145,7 +147,32 @@ class XlaAssignVariableOp : public AsyncOpKernel {
.Device(DEVICE) \
.HostMemory("input") \
.HostMemory("output"), \
-                        LoopCondOp);
+                        LoopCondOp);                                          \
+                                                                              \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("QueueEnqueueV2").Device(DEVICE).HostMemory("handle"), EnqueueOp); \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("QueueDequeueV2").Device(DEVICE).HostMemory("handle"), DequeueOp); \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("QueueCloseV2").Device(DEVICE).HostMemory("handle"), QueueCloseOp); \
+  REGISTER_KERNEL_BUILDER(Name("QueueSizeV2")                                 \
+                              .Device(DEVICE)                                 \
+                              .HostMemory("size")                             \
+                              .HostMemory("handle"),                          \
+                          QueueSizeOp);                                       \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("QueueIsClosedV2").Device(DEVICE).HostMemory("handle"),            \
+      QueueIsClosedOp);                                                       \
+                                                                              \
+  REGISTER_KERNEL_BUILDER(                                                    \
+      Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);
+
+// TODO(phawkins): currently we do not register the QueueEnqueueMany,
+// QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
+// and write the tensors they access in order to concatenate them into a batch.
+// We would need either to call out to an XLA computation to perform the
+// concatenation, or we would need to refactor those kernels so the splitting
+// or merging is done in a separate operator that can be compiled.

} // namespace tensorflow

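The TODO above means the batched operators have no kernel on XLA devices for now; a caller can emulate them element by element. A hypothetical sketch (dequeue_n is not part of this commit):

# Sketch: emulate q.dequeue_many(n) with n single-element dequeues; any
# batching/stacking then happens on the host rather than inside the kernel.
def dequeue_n(sess, dequeue_op, n):
  return [sess.run(dequeue_op) for _ in range(n)]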
14 changes: 14 additions & 0 deletions tensorflow/compiler/tests/BUILD
@@ -371,6 +371,20 @@ tf_xla_py_test(
],
)

tf_xla_py_test(
name = "fifo_queue_test",
size = "medium",
srcs = ["fifo_queue_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:extra_py_tests_deps",
"//tensorflow/python:framework",
"//tensorflow/python:platform_test",
],
)

tf_xla_py_test(
name = "fft_test",
size = "medium",
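With the target above in place, the new test should be runnable in the usual way, e.g. bazel test //tensorflow/compiler/tests:fifo_queue_test, subject to the local TensorFlow build configuration.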
201 changes: 201 additions & 0 deletions tensorflow/compiler/tests/fifo_queue_test.py
@@ -0,0 +1,201 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.data_flow_ops.FIFOQueue."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.platform import test


class FIFOQueueTest(xla_test.XLATestCase):

def testEnqueue(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
enqueue_op.run()

def testEnqueueWithShape(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32, shapes=(3, 2))
enqueue_correct_op = q.enqueue(([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],))
enqueue_correct_op.run()
with self.assertRaises(ValueError):
q.enqueue(([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],))
self.assertEqual(1, q.size().eval())

def testMultipleDequeues(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q.enqueue([1]))
self.evaluate(q.enqueue([2]))
self.evaluate(q.enqueue([3]))
a, b, c = self.evaluate([q.dequeue(), q.dequeue(), q.dequeue()])
self.assertAllEqual(set([1, 2, 3]), set([a, b, c]))

def testQueuesDontShare(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q.enqueue(1))
q2 = data_flow_ops.FIFOQueue(10, [dtypes_lib.int32], shapes=[()])
self.evaluate(q2.enqueue(2))
self.assertAllEqual(self.evaluate(q2.dequeue()), 2)
self.assertAllEqual(self.evaluate(q.dequeue()), 1)

def testEnqueueDictWithoutNames(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
with self.assertRaisesRegexp(ValueError, "must have names"):
q.enqueue({"a": 12.0})

def testParallelEnqueue(self):
with self.test_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
dequeued_t = q.dequeue()

# Run one producer thread for each element in elems.
def enqueue(enqueue_op):
sess.run(enqueue_op)

threads = [
self.checkedThread(target=enqueue, args=(e,)) for e in enqueue_ops
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()

# Dequeue every element using a single thread.
results = []
for _ in xrange(len(elems)):
results.append(dequeued_t.eval())
self.assertItemsEqual(elems, results)

def testParallelDequeue(self):
with self.test_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
dequeued_t = q.dequeue()

# Enqueue every element using a single thread.
for enqueue_op in enqueue_ops:
enqueue_op.run()

# Run one consumer thread for each element in elems.
results = []

def dequeue():
results.append(sess.run(dequeued_t))

threads = [self.checkedThread(target=dequeue) for _ in enqueue_ops]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
self.assertItemsEqual(elems, results)

def testDequeue(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
dequeued_t = q.dequeue()

for enqueue_op in enqueue_ops:
enqueue_op.run()

for i in xrange(len(elems)):
vals = dequeued_t.eval()
self.assertEqual([elems[i]], vals)

def testEnqueueAndBlockingDequeue(self):
with self.test_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(3, dtypes_lib.float32)
elems = [10.0, 20.0, 30.0]
enqueue_ops = [q.enqueue((x,)) for x in elems]
dequeued_t = q.dequeue()

def enqueue():
# The enqueue_ops should run after the dequeue op has blocked.
# TODO(mrry): Figure out how to do this without sleeping.
time.sleep(0.1)
for enqueue_op in enqueue_ops:
sess.run(enqueue_op)

results = []

def dequeue():
for _ in xrange(len(elems)):
results.append(sess.run(dequeued_t))

enqueue_thread = self.checkedThread(target=enqueue)
dequeue_thread = self.checkedThread(target=dequeue)
enqueue_thread.start()
dequeue_thread.start()
enqueue_thread.join()
dequeue_thread.join()

for elem, result in zip(elems, results):
self.assertEqual([elem], result)

def testMultiEnqueueAndDequeue(self):
with self.test_session() as sess, self.test_scope():
q = data_flow_ops.FIFOQueue(10, (dtypes_lib.int32, dtypes_lib.float32))
elems = [(5, 10.0), (10, 20.0), (15, 30.0)]
enqueue_ops = [q.enqueue((x, y)) for x, y in elems]
dequeued_t = q.dequeue()

for enqueue_op in enqueue_ops:
enqueue_op.run()

for i in xrange(len(elems)):
x_val, y_val = sess.run(dequeued_t)
x, y = elems[i]
self.assertEqual([x], x_val)
self.assertEqual([y], y_val)

def testQueueSizeEmpty(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
self.assertEqual([0], q.size().eval())

def testQueueSizeAfterEnqueueAndDequeue(self):
with self.test_session(), self.test_scope():
q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
enqueue_op = q.enqueue((10.0,))
dequeued_t = q.dequeue()
size = q.size()
self.assertEqual([], size.get_shape())

enqueue_op.run()
self.assertEqual(1, size.eval())
dequeued_t.op.run()
self.assertEqual(0, size.eval())


if __name__ == "__main__":
test.main()
1 change: 1 addition & 0 deletions tensorflow/contrib/makefile/tf_op_files.txt
@@ -92,6 +92,7 @@ tensorflow/core/kernels/reduction_ops_common.cc
tensorflow/core/kernels/reduction_ops_any.cc
tensorflow/core/kernels/reduction_ops_all.cc
tensorflow/core/kernels/roll_op.cc
tensorflow/core/kernels/queue_op.cc
tensorflow/core/kernels/queue_ops.cc
tensorflow/core/kernels/queue_base.cc
tensorflow/core/kernels/pooling_ops_common.cc
25 changes: 18 additions & 7 deletions tensorflow/core/framework/resource_op_kernel.h
@@ -43,9 +43,15 @@ template <typename T>
class ResourceOpKernel : public OpKernel {
public:
explicit ResourceOpKernel(OpKernelConstruction* context) : OpKernel(context) {
-    OP_REQUIRES_OK(context,
-                   context->allocate_persistent(DT_STRING, TensorShape({2}),
-                                                &handle_, nullptr));
+    has_resource_type_ = (context->output_type(0) == DT_RESOURCE);
+    if (!has_resource_type_) {
+      // The resource variant of the op may be placed on non-CPU devices, but
+      // this allocation is always on the host. Fortunately we don't need it in
+      // the resource case.
+      OP_REQUIRES_OK(context,
+                     context->allocate_persistent(DT_STRING, TensorShape({2}),
+                                                  &handle_, nullptr));
+    }
}

// The resource is deleted from the resource manager only when it is private
@@ -89,12 +95,14 @@ class ResourceOpKernel : public OpKernel {
return;
}

-    auto h = handle_.AccessTensor(context)->template flat<string>();
-    h(0) = cinfo_.container();
-    h(1) = cinfo_.name();
+    if (!has_resource_type_) {
+      auto h = handle_.AccessTensor(context)->template flat<string>();
+      h(0) = cinfo_.container();
+      h(1) = cinfo_.name();
+    }
resource_ = resource;
}
-  if (context->expected_output_dtype(0) == DT_RESOURCE) {
+  if (has_resource_type_) {
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
context, 0, cinfo_.container(), cinfo_.name(),
MakeTypeIndex<T>()));
@@ -122,6 +130,9 @@
virtual Status VerifyResource(T* resource) { return Status::OK(); }

PersistentTensor handle_ GUARDED_BY(mu_);

+  // Is the output of the operator of type DT_RESOURCE?
+  bool has_resource_type_;
};
} // namespace tensorflow

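The has_resource_type_ branch is the one Python's FIFOQueue takes, since the V2 queue ops emit a DT_RESOURCE handle. A rough check, assuming QueueBase exposes the underlying handle as queue_ref:

# Sketch: FIFOQueue builds FIFOQueueV2, whose output is a resource handle,
# so the kernel skips allocating the host-side string handle entirely.
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.ops import data_flow_ops

q = data_flow_ops.FIFOQueue(10, dtypes_lib.float32)
assert q.queue_ref.dtype == dtypes_lib.resource  # assumed attribute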
5 changes: 4 additions & 1 deletion tensorflow/core/kernels/BUILD
@@ -368,6 +368,7 @@ cc_library(

cc_library(
name = "queue_op",
srcs = ["queue_op.cc"],
hdrs = ["queue_op.h"],
deps = [
":queue_base",
@@ -1885,9 +1886,10 @@ cc_library(
name = "fifo_queue",
srcs = ["fifo_queue.cc"],
hdrs = ["fifo_queue.h"],
visibility = ["//visibility:private"],
visibility = [":friends"],
deps = [
":queue_base",
":queue_op",
":typed_queue",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -5076,6 +5078,7 @@ filegroup(
"padding_fifo_queue.cc",
"padding_fifo_queue_op.cc",
"queue_base.cc",
"queue_op.cc",
"queue_ops.cc",
"random_op.cc",
"reduction_ops_all.cc",
15 changes: 15 additions & 0 deletions tensorflow/core/kernels/fifo_queue.cc
@@ -366,4 +366,19 @@ Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
return Status::OK();
}

// Defines a FIFOQueueOp, which produces a Queue (specifically, one
// backed by FIFOQueue) that persists across different graph
// executions and sessions. Running this op produces a single-element
// tensor of handles to Queues in the corresponding device.
FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context)
: TypedQueueOp(context) {
OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_));
}

Status FIFOQueueOp::CreateResource(QueueInterface** ret) {
FIFOQueue* queue = new FIFOQueue(capacity_, component_types_,
component_shapes_, cinfo_.name());
return CreateTypedQueue(queue, ret);
}

} // namespace tensorflow
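The persistence described in the comment above is observable from Python; a small sketch:

# Sketch: the queue resource outlives individual Session.run calls, so an
# element enqueued in one execution can be dequeued in a later one.
import tensorflow as tf

q = tf.FIFOQueue(2, tf.float32)
enq = q.enqueue((1.0,))
deq = q.dequeue()
with tf.Session() as sess:
  sess.run(enq)         # first run: enqueue
  print(sess.run(deq))  # later run: prints 1.0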