Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)

In [1]:
import tensorflow as tf
import numpy as np
import os
from tqdm import tqdm

In [2]:
# Tiny batch so the padded demo output below stays readable.
BATCH_SIZE = 3
# Passed as num_words: only the VOCAB_SIZE most frequent words keep their own index
# (per the Keras imdb.load_data docs; rarer words are mapped to the OOV index).
VOCAB_SIZE = 20000
TF_RECORD_PATH = './imdb_train_var_len.tfrecord'

# load_data returns ((x_train, y_train), (x_test, y_test)); only the training split is used.
# Index [0] instead of unpacking into `(_, _)`: binding `_` in an IPython namespace shadows
# the output-history variable and stops IPython from updating `_` / `__` / `___`.
X_train, y_train = tf.keras.datasets.imdb.load_data(num_words=VOCAB_SIZE)[0]

In [3]:
if not os.path.isfile(TF_RECORD_PATH):
    # Write into a temporary file and move it into place only after every record has been
    # written and the writer closed. Otherwise an interrupted run would leave a truncated
    # file that the isfile() guard above would treat as complete on every later run.
    tmp_path = TF_RECORD_PATH + '.tmp'
    # The context manager closes (and flushes) the writer even if an exception is raised.
    with tf.python_io.TFRecordWriter(tmp_path) as writer:
        for sent, label in tqdm(zip(X_train, y_train), total=len(X_train), ncols=70):
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        # One review: a variable-length list of word indices.
                        'sent': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=sent)),
                        # Its label, stored as a one-element int64 list.
                        'label': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[label])),
                    }))
            writer.write(example.SerializeToString())
    # os.replace is atomic on POSIX and, unlike os.rename, also overwrites on Windows.
    os.replace(tmp_path, TF_RECORD_PATH)

100%|██████████████████████████| 25000/25000 [00:37<00:00, 670.52it/s]


In [4]:
def _parse_fn(example_proto):
    """Decode one serialized tf.train.Example written by the cell above.

    Returns a dict of tensors keyed by 'sent' (the variable-length word-index list,
    parsed as a 1-D int64 tensor) and 'label' (a scalar int64 tensor).
    """
    feature_spec = {
        # FixedLenSequenceFeature with allow_missing=True parses the whole int64 list,
        # whatever its length, into a 1-D tensor.
        'sent': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    return tf.parse_single_example(example_proto, features=feature_spec)

# Read the records back in file order and decode each one. Batch them, padding 'sent' in
# every batch up to the longest review in that batch ([None] = unknown dimension).
# 'label' is a scalar ([]) and needs no padding.
dataset = (
    tf.data.TFRecordDataset([TF_RECORD_PATH])
    .map(_parse_fn)
    .padded_batch(BATCH_SIZE, padded_shapes={'sent': [None], 'label': []})
)
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
X_batch = batch['sent']
y_batch = batch['label']
# Static shapes: the batch size and the padded length are unknown until run time.
print(X_batch.get_shape(), y_batch.get_shape())

(?, ?) (?,)


In [5]:
sess = tf.Session()
# Materialize the first padded batch from the graph.
x, y = sess.run([X_batch, y_batch])

# Raw lengths of the first BATCH_SIZE reviews next to the padded batch shape.
print([len(review) for review in X_train[:BATCH_SIZE]], '->', x.shape)

# The pipeline does not shuffle, so the last row of the batch is X_train[BATCH_SIZE-1]:
# show the same review before and after padding.
print("\nOriginal:")
print(np.array(X_train[BATCH_SIZE - 1]))
print("\nPadded:")
print(x[-1])

[218, 189, 141] -> (3, 218)

Original:
[   1   14   47    8   30   31    7    4  249  108    7    4 5974   54
   61  369   13   71  149   14   22  112    4 2401  311   12   16 3711
   33   75   43 1829  296    4   86  320   35  534   19  263 4821 1301
    4 1873   33   89   78   12   66   16    4  360    7    4   58  316
  334   11    4 1716   43  645  662    8  257   85 1200   42 1228 2578
   83   68 3912   15   36  165 1539  278   36   69    2  780    8  106
   14 6905 1338   18    6   22   12  215   28  610   40    6   87  326
   23 2300   21   23   22   12  272   40   57   31   11    4   22   47
    6 2307   51    9  170   23  595  116  595 1352   13  191   79  638
   89    2   14    9    8  106  607  624   35  534    6  227    7  129
  113]

Padded:
[   1   14   47    8   30   31    7    4  249  108    7    4 5974   54
   61  369   13   71  149   14   22  112    4 2401  311   12   16 3711
   33   75   43 1829  296    4   86  320   35  534   19  263 4821 1301
    4 1873   33   89  