In [7]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm

In [11]:
# Write toy SequenceExamples to two TFRecord files: the first half of the
# examples goes to file 1, the rest to file 2.
TFRECORD_PATHS = ('../data/seq_test1.tfrecord', '../data/seq_test2.tfrecord')

# Non-sequential (context) data: one integer label per example.
labels = [1, 2, 3, 4, 5] * 6
# Variable-length sequences: the example with label k is the value k repeated
# k times, i.e. [1], [2, 2], [3, 3, 3], ...
frames = [[label] * label for label in labels]

# Switch files at the integer midpoint, rounded up so an odd extra example goes
# to the first file. (With true division, `i == len(labels) / 2` never matches
# for an odd number of examples, which would leave the second file empty.)
split_at = (len(labels) + 1) // 2

# `with` guarantees both writers are closed even if writing fails midway.
with tf.python_io.TFRecordWriter(TFRECORD_PATHS[0]) as writer1, \
        tf.python_io.TFRecordWriter(TFRECORD_PATHS[1]) as writer2:
    writer = writer1
    for i in tqdm(range(len(labels))):
        if i == split_at:
            writer = writer2
            print('There are %d samples written into writer1' % i)
        label = labels[i]
        frame = frames[i]
        # Context (non-sequential) feature: a single int64 label.
        label_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
        # Sequence feature: one int64 Feature per time step.
        frame_feature = [tf.train.Feature(int64_list=tf.train.Int64List(value=[frame_]))
                         for frame_ in frame]

        seq_example = tf.train.SequenceExample(
            # Holds the non-sequential data.
            context=tf.train.Features(feature={'label': label_feature}),
            # Holds the variable-length sequence.
            feature_lists=tf.train.FeatureLists(
                feature_list={'frame': tf.train.FeatureList(feature=frame_feature)}),
        )
        writer.write(seq_example.SerializeToString())
print('finished')

100%|██████████| 30/30 [00:00<00:00, 8509.44it/s]

There are 15 samples written into writer1
finished





In [4]:
import tensorflow as tf
import numpy as np

# Build a TF1 queue-based pipeline that reads one SequenceExample at a time
# from both TFRecord files and parses it into a label and a frame sequence.
tfrecord_filename = ['../data/seq_test1.tfrecord', '../data/seq_test2.tfrecord']
file_queue = tf.train.string_input_producer(tfrecord_filename, shuffle=True, capacity=2)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(file_queue)

# 'label' is a per-example context feature; 'frame' is a per-time-step
# sequence feature of variable length.
context_parsed, sequence_parsed = tf.parse_single_sequence_example(
    serialized=serialized_example,
    context_features={'label': tf.FixedLenFeature([], dtype=tf.int64)},
    sequence_features={'frame': tf.FixedLenSequenceFeature([], dtype=tf.int64)},
)

labels, frames = context_parsed['label'], sequence_parsed['frame']
for tag, tensor in (('labels', labels), ('frames', frames)):
    print(tag, tensor)

labels Tensor("ParseSingleSequenceExample_1/ParseSingleSequenceExample:0", shape=(), dtype=int64)
frames Tensor("ParseSingleSequenceExample_1/ParseSingleSequenceExample:1", shape=(?,), dtype=int64)


In [5]:
# Group parsed examples into batches of 10; dynamic_pad=True zero-pads each
# `frames` row to a common length within the batch.
label_batch, frame_batch = tf.train.batch(
    [labels, frames],
    batch_size=10,
    num_threads=4,
    capacity=500,
    dynamic_pad=True,
    allow_smaller_final_batch=False)
print(label_batch, frame_batch)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# A Coordinator lets us stop and join the queue-runner threads when we are done;
# without it they keep running after the cell finishes and the session leaks.
coord = tf.train.Coordinator()
with tf.Session(config=config) as sess:
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(2):
            _label_batch, _frame_batch = sess.run([label_batch, frame_batch])
            print('i', i)
            print(_label_batch)
            print(_frame_batch)
    finally:
        coord.request_stop()
        coord.join(threads)

Tensor("batch:0", shape=(10,), dtype=int64) Tensor("batch:1", shape=(10, ?), dtype=int64)
i 0
[1 3 2 4 1 5 2 4 3 5]
[[1 0 0 0 0]
 [3 3 3 0 0]
 [2 2 0 0 0]
 [4 4 4 4 0]
 [1 0 0 0 0]
 [5 5 5 5 5]
 [2 2 0 0 0]
 [4 4 4 4 0]
 [3 3 3 0 0]
 [5 5 5 5 5]]
i 1
[1 2 3 4 5 1 2 3 4 5]
[[1 0 0 0 0]
 [2 2 0 0 0]
 [3 3 3 0 0]
 [4 4 4 4 0]
 [5 5 5 5 5]
 [1 0 0 0 0]
 [2 2 0 0 0]
 [3 3 3 0 0]
 [4 4 4 4 0]
 [5 5 5 5 5]]


In [7]:
import tensorflow as tf
import math

# Capacity used for both the shuffle queue and the batching queue.
QUEUE_CAPACITY = 100
# Target `min_after_dequeue` for the shuffle queue: one fifth of the capacity.
# get_padded_batch caps it at the number of records actually available.
SHUFFLE_MIN_AFTER_DEQUEUE = int(QUEUE_CAPACITY / 5)

def _shuffle_inputs(input_tensors, capacity, min_after_dequeue, num_threads):
    """Shuffles tensors in `input_tensors`, maintaining grouping.

    All tensors are enqueued together as one element of a
    `tf.RandomShuffleQueue`, so tensors belonging to the same example stay
    together after shuffling.

    Args:
      input_tensors: A list of tensors that together form one example.
      capacity: Capacity passed to the `tf.RandomShuffleQueue`.
      min_after_dequeue: `min_after_dequeue` passed to the `tf.RandomShuffleQueue`.
      num_threads: Number of threads the registered QueueRunner uses to run the
          enqueue op.

    Returns:
      A list of dequeued tensors, in the same order as `input_tensors`, each
      with the static shape of the corresponding input tensor.
    """
    shuffle_queue = tf.RandomShuffleQueue(
        capacity, min_after_dequeue, dtypes=[t.dtype for t in input_tensors])
    enqueue_op = shuffle_queue.enqueue(input_tensors)
    # The same enqueue op is run by `num_threads` parallel threads.
    runner = tf.train.QueueRunner(shuffle_queue, [enqueue_op] * num_threads)
    tf.train.add_queue_runner(runner)

    output_tensors = shuffle_queue.dequeue()
    # `dequeue()` returns a bare Tensor (not a list) when the queue has a single
    # component; normalize so the code below works for any number of inputs.
    if isinstance(output_tensors, (list, tuple)):
        output_tensors = list(output_tensors)
    else:
        output_tensors = [output_tensors]

    # The queue was built without `shapes` (sequences are variable-length), so
    # restore each output's static shape from the matching input tensor.
    for output_tensor, input_tensor in zip(output_tensors, input_tensors):
        output_tensor.set_shape(input_tensor.shape)

    return output_tensors

def get_padded_batch(file_list, batch_size, num_enqueuing_threads=4, shuffle=False):
    """Reads batches of SequenceExamples from TFRecords and pads them.

    Can deal with variable length SequenceExamples by zero-padding the 'frame'
    sequences of each batch to a common length (`dynamic_pad=True`).

    Args:
      file_list: A list of paths to TFRecord files containing SequenceExamples
          with an int64 context feature 'label' and an int64 sequence feature
          'frame'.
      batch_size: The number of SequenceExamples to include in each batch.
      num_enqueuing_threads: The number of threads to use for enqueuing
          SequenceExamples. When shuffling, half of them (rounded up) feed the
          shuffle queue and the remainder feed the batching queue.
      shuffle: Whether to shuffle individual examples through a
          RandomShuffleQueue before they are batched.

    Returns:
      A list `[labels, frames]`:
        labels: A tensor of shape [batch_size] of int64s.
        frames: A tensor of shape [batch_size, num_steps] of int64s, zero-padded
            to a common num_steps within the batch.

    Raises:
      ValueError: If `shuffle` is True and `num_enqueuing_threads` is less than 2.
    """
    # NOTE(review): string_input_producer's own `shuffle` argument is left at
    # its default here, so file order may be shuffled even when shuffle=False.
    file_queue = tf.train.string_input_producer(file_list)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)

    context_features = {
        "label": tf.FixedLenFeature([], dtype=tf.int64)
    }
    sequence_features = {
        "frame": tf.FixedLenSequenceFeature([], dtype=tf.int64)
    }

    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    labels = context_parsed['label']
    frames = sequence_parsed['frame']
    input_tensors = [labels, frames]

    if shuffle:
        if num_enqueuing_threads < 2:
            raise ValueError(
                '`num_enqueuing_threads` must be at least 2 when shuffling.')
        # Half of the threads, rounded up, go to the shuffle queue. The division
        # must happen inside ceil(); ceil() of an int is a no-op. For n >= 2
        # this always leaves at least one thread for the batching queue.
        shuffle_threads = int(math.ceil(num_enqueuing_threads / 2.))

        # Since there may be fewer records than SHUFFLE_MIN_AFTER_DEQUEUE, take the
        # minimum of that number and the number of records.
        min_after_dequeue = count_records(
            file_list, stop_at=SHUFFLE_MIN_AFTER_DEQUEUE)
        input_tensors = _shuffle_inputs(
            input_tensors, capacity=QUEUE_CAPACITY,
            min_after_dequeue=min_after_dequeue,
            num_threads=shuffle_threads)

        num_enqueuing_threads -= shuffle_threads

    tf.logging.info(input_tensors)
    return tf.train.batch(
        input_tensors,
        batch_size=batch_size,
        capacity=QUEUE_CAPACITY,
        num_threads=num_enqueuing_threads,
        dynamic_pad=True,
        allow_smaller_final_batch=False)
        
def count_records(file_list, stop_at=None):
    """Count the records stored in the TFRecord files of `file_list`.

    Files are scanned in order. When `stop_at` is truthy, scanning stops as
    soon as the running total reaches it; for a positive integer `stop_at` the
    result is therefore min(total_records, stop_at).

    Args:
      file_list: Iterable of TFRecord file paths.
      stop_at: Optional cap on the count; None (or any falsy value) means
          count every record.

    Returns:
      The number of records seen, capped at `stop_at` when one is given.
    """
    limit = stop_at if stop_at else None
    seen = 0
    for path in file_list:
        tf.logging.info('Counting records in %s.', path)
        for _record in tf.python_io.tf_record_iterator(path):
            seen += 1
            if limit is not None and seen >= limit:
                tf.logging.info('Number of records is at least %d.', seen)
                return seen
    tf.logging.info('Total records: %d', seen)
    return seen

In [9]:
# Read shuffled, zero-padded batches of 10 examples from both files.
tfrecord_filename = ['../data/seq_test1.tfrecord', '../data/seq_test2.tfrecord']
label_batch, frame_batch = get_padded_batch(tfrecord_filename, 10, shuffle=True)
print(label_batch, frame_batch)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# Coordinate the queue-runner threads so they are stopped and joined, and close
# the session when done, instead of leaving both running after the cell ends.
coord = tf.train.Coordinator()
with tf.Session(config=config) as sess:
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(2):
            _label_batch, _frame_batch = sess.run([label_batch, frame_batch])
            print('i', i)
            print(_label_batch)
            print(_frame_batch)
    finally:
        coord.request_stop()
        coord.join(threads)

INFO:tensorflow:Counting records in ../data/seq_test1.tfrecord.
INFO:tensorflow:Counting records in ../data/seq_test2.tfrecord.
INFO:tensorflow:Number of records is at least 20.
INFO:tensorflow:[<tf.Tensor 'random_shuffle_queue_Dequeue:0' shape=() dtype=int64>, <tf.Tensor 'random_shuffle_queue_Dequeue:1' shape=(?,) dtype=int64>]
Tensor("batch:0", shape=(10,), dtype=int64) Tensor("batch:1", shape=(10, ?), dtype=int64)
i 0
[1 1 5 4 4 3 3 2 4 2]
[[1 0 0 0 0]
 [1 0 0 0 0]
 [5 5 5 5 5]
 [4 4 4 4 0]
 [4 4 4 4 0]
 [3 3 3 0 0]
 [3 3 3 0 0]
 [2 2 0 0 0]
 [4 4 4 4 0]
 [2 2 0 0 0]]
i 1
[2 4 4 5 1 5 1 3 3 5]
[[2 2 0 0 0]
 [4 4 4 4 0]
 [4 4 4 4 0]
 [5 5 5 5 5]
 [1 0 0 0 0]
 [5 5 5 5 5]
 [1 0 0 0 0]
 [3 3 3 0 0]
 [3 3 3 0 0]
 [5 5 5 5 5]]


In [4]:
tfrecord_file = '../data/seq_test1.tfrecord'
# Count records by exhausting the iterator; only the number of items matters.
num = sum(1 for _ in tf.python_io.tf_record_iterator(tfrecord_file))
print(num)

  from ._conv import register_converters as _register_converters


15


In [5]:
tfrecord_file = '../data/seq_test2.tfrecord'
# Count records by exhausting the iterator; only the number of items matters.
num = sum(1 for _ in tf.python_io.tf_record_iterator(tfrecord_file))
print(num)

15
