-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from petrux/dev
Added input module
- Loading branch information
Showing
3 changed files
with
319 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
"""Utilities for input pipelines.""" | ||
|
||
import tensorflow as tf | ||
|
||
|
||
def shuffle(tensors,
            capacity=32,
            min_after_dequeue=16,
            num_threads=1,
            dtypes=None,
            shapes=None,
            seed=None,
            shared_name=None,
            name='shuffle'):
    """Wrapper around a `tf.RandomShuffleQueue` creation.

    Return a dequeue op that dequeues elements from `tensors` in a
    random order, through a `tf.RandomShuffleQueue` -- see the official
    TensorFlow documentation for further details.

    Arguments:
      tensors: an iterable of tensors.
      capacity: (Optional) the capacity of the queue; default value set to 32.
      min_after_dequeue: (Optional) minimum number of elements to remain in the
        queue after a `dequeue` or `dequeue_many` has been performed,
        in order to ensure better mixing of elements; default value set to 16.
      num_threads: (Optional) the number of threads to be used for the queue
        runner; default value set to 1.
      dtypes: (Optional) list of `DType` objects, one for each tensor in `tensors`;
        if not provided, will be inferred from `tensors`.
      shapes: (Optional) list of shapes, one for each tensor in `tensors`.
      seed: (Optional) seed for random shuffling.
      shared_name: (Optional) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name scope for the ops.

    Returns:
      The tuple of tensors that was randomly dequeued from `tensors`.
    """
    tensors = list(tensors)
    # BUG FIX: `tensors` must be passed as the `values` keyword argument of
    # `tf.name_scope(name, default_name=None, values=None)`; passed
    # positionally it lands on `default_name`, which expects a string.
    with tf.name_scope(name, values=tensors):
        dtypes = dtypes or [t.dtype for t in tensors]
        queue = tf.RandomShuffleQueue(
            seed=seed,
            shared_name=shared_name,
            name='random_shuffle_queue',
            dtypes=dtypes,
            shapes=shapes,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue)
        enqueue = queue.enqueue(tensors)
        # Register the runner so `tf.train.start_queue_runners` spawns the
        # enqueuing threads alongside every other queue in the graph.
        runner = tf.train.QueueRunner(queue, [enqueue] * num_threads)
        tf.train.add_queue_runner(runner)
        return queue.dequeue()
|
||
|
||
def shuffle_batch(tensors,
                  batch_size,
                  capacity=32,
                  num_threads=1,
                  min_after_dequeue=16,
                  dtypes=None,
                  shapes=None,
                  seed=None,
                  enqueue_many=False,
                  dynamic_pad=True,
                  allow_smaller_final_batch=False,
                  shared_name=None,
                  name='shuffle_batch'):
    """Create shuffled and padded batches of tensors in `tensors`.

    Dequeue elements from `tensors` shuffling, batching and dynamically
    padding them. First a `tf.RandomShuffleQueue` is created and fed with
    `tensors` (using the `shuffle` function of this module); the dequeued
    tensors shapes are then set and fed into a `tf.train.batch` function
    that provides batching and dynamic padding.

    Arguments:
      tensors: an iterable of tensors.
      batch_size: an `int` representing the batch size.
      capacity: (Optional) the capacity of the queues; default value set to 32.
      num_threads: (Optional) the number of threads to be used for the queue
        runner; default value set to 1.
      min_after_dequeue: (Optional) minimum number of elements to remain in the
        shuffling queue after a `dequeue` or `dequeue_many` has been performed,
        in order to ensure better mixing of elements; default value set to 16.
      dtypes: (Optional) list of `DType` objects, one for each tensor in `tensors`;
        if not provided, will be inferred from `tensors`.
      shapes: (Optional) list of shapes, one for each tensor in `tensors`.
      seed: (Optional) seed for random shuffling.
      enqueue_many: Whether each tensor in tensors is a single example.
      dynamic_pad: Boolean. Allow variable dimensions in input shapes.
        The given dimensions are padded upon dequeue so that tensors within
        a batch have the same shapes.
      allow_smaller_final_batch: (Optional) Boolean. If True, allow the final
        batch to be smaller if there are insufficient items left in the queue.
      shared_name: if set, the queues will be shared under the given name
        across different sessions.
      name: scope name for the given ops.

    Returns:
      A batch of tensors from `tensors`, shuffled and padded.
    """
    tensors = list(tensors)
    # BUG FIX: `tensors` must be passed as `values=`; passed positionally it
    # binds to `default_name` (a string) in the TF1 `tf.name_scope` signature.
    with tf.name_scope(name, values=tensors):
        dtypes = dtypes or [t.dtype for t in tensors]
        shapes = shapes or [t.get_shape() for t in tensors]
        # The shuffling queue is deliberately built without static shapes
        # (they are restored below on the dequeued tensors).
        inputs = shuffle(tensors,
                         seed=seed,
                         dtypes=dtypes,
                         capacity=capacity,
                         num_threads=num_threads,
                         min_after_dequeue=min_after_dequeue,
                         shared_name=shared_name,
                         name='shuffle')

        # Restore the static shape information lost through the
        # shape-less RandomShuffleQueue, so `tf.train.batch` can pad.
        for tensor, shape in zip(inputs, shapes):
            tensor.set_shape(shape)

        minibatch = tf.train.batch(
            tensors=inputs,
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=capacity,
            dynamic_pad=dynamic_pad,
            allow_smaller_final_batch=allow_smaller_final_batch,
            shared_name=shared_name,
            enqueue_many=enqueue_many,
            name='batch')
        return minibatch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
"""Test module for the `liteflow.input` module.""" |  ||
|
||
import datetime | ||
import os | ||
|
||
import tensorflow as tf | ||
|
||
from liteflow import input as linput | ||
|
||
|
||
def _timestamp(): | ||
frmt = "%Y-%m-%d--%H-%M-%S.%f" | ||
stamp = datetime.datetime.now().strftime(frmt) | ||
print 'STAMP: ' + stamp | ||
return stamp | ||
|
||
|
||
def _encode(key, vector):
    """Build a `tf.train.Example` holding a (key, vector) pair as int64 features."""
    key_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[key]))
    vector_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=vector))
    features = tf.train.Features(
        feature={'key': key_feature, 'vector': vector_feature})
    return tf.train.Example(features=features)
|
||
|
||
def _decode(message):
    """Parse a serialized example into a (key, dense vector) tensor pair."""
    parsed = tf.parse_single_example(
        serialized=message,
        features={
            'key': tf.FixedLenFeature([], tf.int64),
            'vector': tf.VarLenFeature(tf.int64),
        })
    # The variable-length feature comes back sparse; densify it for callers.
    return parsed['key'], tf.sparse_tensor_to_dense(parsed['vector'])
|
||
|
||
def _save_records(fpath, *records):
    """Serialize every record in `records` into the TFRecord file at `fpath`."""
    with tf.python_io.TFRecordWriter(fpath) as writer:
        for example in records:
            serialized = example.SerializeToString()
            writer.write(serialized)
|
||
|
||
def _read(fpath, num_epochs=None, shuffle=True):
    """Build the ops reading (key, vector) pairs from the TFRecord file `fpath`."""
    filename_queue = tf.train.string_input_producer(
        string_tensor=[fpath],
        num_epochs=num_epochs,
        shuffle=shuffle)
    _, serialized = tf.TFRecordReader().read(filename_queue)
    return _decode(serialized)
|
||
|
||
class ShuffleBatchTest(tf.test.TestCase):
    """Integration test for the `linput.shuffle_batch` function."""

    # Directory where the temporary TFRecord fixture file is written.
    TMP_DIR = '/tmp'

    def test_base(self):
        """Dequeue all batches, checking padded shapes and the shuffled key order."""

        # NOTA BENE: all the test depends on the value
        # used for the random seed, so if you change it
        # you HAVE TO re run the generation and check
        # manually in order to update the expected results.
        # Bottom line: DON'T CHANGE THE RANDOM SEED.
        tf.reset_default_graph()
        tf.set_random_seed(23)

        filename = os.path.join(self.TMP_DIR, _timestamp() + '.rio')
        # Each record is (key, [key] * key), so the padded length of the
        # vectors in a batch equals the maximum key in that batch.
        data = [
            (1, [1]),
            (2, [2, 2]),
            (3, [3, 3, 3]),
            (4, [4, 4, 4, 4]),
            (5, [5, 5, 5, 5, 5]),
            (6, [6, 6, 6, 6, 6, 6])]
        examples = [_encode(k, v) for k, v in data]
        _save_records(filename, *examples)
        tensors = _read(filename, num_epochs=4, shuffle=False)

        batch_size = 3
        batch = linput.shuffle_batch(tensors, batch_size)

        actual_keys = []
        expected_keys = [2, 5, 6, 1, 3, 6, 3, 4, 5, 1, 4, 1, 5, 6, 3, 2, 2, 4, 6, 1, 4, 5, 3, 2]

        with tf.Session() as sess:
            # num_epochs creates a local variable, so both initializers run.
            sess.run(tf.local_variables_initializer())
            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                while True:
                    bkey, bvector = sess.run(batch)
                    bkey = bkey.tolist()
                    # See fixture above: padded length == max key in the batch.
                    length = max(bkey)
                    self.assertEqual((batch_size, length), bvector.shape)
                    actual_keys = actual_keys + bkey

            except tf.errors.OutOfRangeError as ex:
                # OutOfRangeError signals epoch exhaustion; the coordinator
                # treats it as a clean stop condition.
                coord.request_stop(ex=ex)
            finally:
                coord.request_stop()
                coord.join(threads)

        # BUG FIX: `assertEquals` is a deprecated alias of `assertEqual`
        # (removed in Python 3.12); use the canonical name.
        self.assertEqual(actual_keys, expected_keys)
        os.remove(filename)
|
||
# Standard entry point: run all tf.test.TestCase classes in this module.
if __name__ == '__main__':
    tf.test.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters