Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)

In [1]:
import tensorflow as tf
import os
from tqdm import tqdm

In [2]:
MAX_LEN = 300
BATCH_SIZE = 32
VOCAB_SIZE = 20000
# The cache file name is derived from MAX_LEN so that changing the sequence
# length cannot silently reuse a TFRecord file written with a different
# length (records are read back as fixed-length features of MAX_LEN ints).
# NOTE: VOCAB_SIZE is not encoded in the name -- delete the file after
# changing it, otherwise the old vocabulary cut-off is reused.
TF_RECORD_PATH = './imdb_train_fixed{}.tfrecord'.format(MAX_LEN)

# Only download and pad the raw data when the TFRecord cache is missing;
# X_train / y_train are therefore defined only on that branch.
if not os.path.isfile(TF_RECORD_PATH):
    # num_words keeps only the VOCAB_SIZE most frequent words; the test split
    # is not used in this notebook.
    (X_train, y_train), (_, _) = tf.keras.datasets.imdb.load_data(num_words=VOCAB_SIZE)
    # Pad with zeros / cut at the end so every review is exactly MAX_LEN ids.
    X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train,
                                                            maxlen=MAX_LEN,
                                                            padding='post',
                                                            truncating='post')

In [3]:
# Serialize each (padded review, label) pair as a tf.train.Example and write
# them all to TF_RECORD_PATH. X_train / y_train exist only when the previous
# cell took its identically guarded download branch.
if not os.path.isfile(TF_RECORD_PATH):
    # The existence check above doubles as the cache guard, so a partially
    # written file (e.g. from an interrupted run) must never appear under
    # TF_RECORD_PATH: write to a temporary name and rename only on success.
    tmp_path = TF_RECORD_PATH + '.tmp'
    # The context manager guarantees the writer is closed even on error.
    with tf.python_io.TFRecordWriter(tmp_path) as writer:
        for sent, label in tqdm(zip(X_train, y_train), total=len(X_train), ncols=70):
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'sent': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=sent)),
                        'label': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[label])),
                    }))
            writer.write(example.SerializeToString())
    # The target does not exist at this point (guarded above), so a plain
    # rename is sufficient and portable.
    os.rename(tmp_path, TF_RECORD_PATH)

100%|██████████████████████████| 25000/25000 [00:45<00:00, 552.72it/s]


In [4]:
def _parse_fn(example_proto):
    """Decode one serialized tf.train.Example into a ``(sent, label)`` pair.

    Args:
        example_proto: a serialized Example record, as yielded by
            ``tf.data.TFRecordDataset``.

    Returns:
        ``(sent, label)`` where ``sent`` is an int64 tensor of shape
        ``[MAX_LEN]`` and ``label`` is a scalar int64 tensor.
    """
    feature_spec = {
        'sent': tf.FixedLenFeature([MAX_LEN], tf.int64),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    decoded = tf.parse_single_example(example_proto, features=feature_spec)
    sent, label = decoded['sent'], decoded['label']
    return sent, label

# Input pipeline: stream records from the TFRecord file, decode each one with
# _parse_fn, then group consecutive examples into batches of BATCH_SIZE.
# The final batch may be smaller, so the static batch dimension is unknown.
dataset = (
    tf.data.TFRecordDataset([TF_RECORD_PATH])
    .map(_parse_fn)
    .batch(BATCH_SIZE)
)
iterator = dataset.make_one_shot_iterator()
X_batch, y_batch = iterator.get_next()
# Static (graph-time) shapes.
print(X_batch.get_shape(), y_batch.get_shape())

(?, 300) (?,)


In [5]:
# Run the pipeline once to pull one concrete batch; unlike the static shapes
# above, the batch dimension is now a real number.
sess = tf.Session()
x, y = sess.run(fetches=(X_batch, y_batch))
print(x.shape, y.shape)

(32, 300) (32,)
