In [3]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
np.random.seed(1337)  # for reproducibility

import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
import pdb

# Session configuration with GPU `allow_growth` switched on, and a session
# created from it.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.Session(config=config)


In [4]:

# NOTE(review): presumably the model/checkpoint directory; it is not
# referenced anywhere else in this section -- confirm it is still needed.
MODEL_DIR = './MLP/model_2'
# Directory containing the TFRecord files; MLP_Input falls back to this when
# no `data_dir` is given.
DATA_DIR = './MLP_data'

# `buffer_size` passed to each tf.data.TFRecordDataset in MLP_Input.
prefetch_buffer_size = 128
# `cycle_length` of the parallel_interleave over input files in MLP_Input.
num_files_infeed = 16
# Buffer size of the record-level shuffle in MLP_Input.
shuffle_buffer_size = 512
# `num_parallel_calls` of the parsing `map` step in MLP_Input.
num_parallel_calls = 32

In [5]:
# Glob for the test records under DATA_DIR, then a dataset of the matching
# file names.
file_pattern = os.path.join(DATA_DIR, 'MLP_data_test*')
dataset = tf.data.Dataset.list_files(file_pattern)

In [6]:
# Show the repr of the file-name dataset built by list_files.
dataset

<ShuffleDataset shapes: (), types: tf.string>

In [7]:
# Shell check: list the files that match the test glob.
!ls ./MLP_data/MLP_data_test*

./MLP_data/MLP_data_test.tfrecords


In [8]:
# Same listing, this time interpolating the Python variable `file_pattern`.
!ls $file_pattern

./MLP_data/MLP_data_test.tfrecords


In [9]:
class MLP_Input(object):
  """Callable input pipeline intended as the `input_fn` of a TPUEstimator.

  Reads TFRecord files named `MLP_data_train*` (when `is_training`) or
  `MLP_data_test*` (otherwise) from `data_dir`. Each record stores a raw
  float32 vector of 10000 values under key 'X' and one int64 label under 'y'.

  The pipeline depends on the notebook-level settings `DATA_DIR`,
  `prefetch_buffer_size`, `num_files_infeed`, `shuffle_buffer_size` and
  `num_parallel_calls`.

  Args:
    is_training: If true, read the training files, shuffle the file names
      and repeat the data indefinitely.
    is_eval: Together with `is_training`, selects what `__call__` returns
      (see `__call__`).
    data_dir: Directory holding the TFRecord files. Falls back to the
      notebook-level `DATA_DIR` when empty or None.
  """

  def __init__(self, is_training=True, is_eval=True, data_dir=None):
    self.is_eval = is_eval
    self.is_training = is_training
    self.data_dir = data_dir if data_dir else DATA_DIR

  def dataset_parser(self, value):
    """Parse one serialized tf.Example into a (features, label) pair.

    Args:
      value: Scalar string tensor holding a serialized tf.Example.

    Returns:
      A tuple `(X, y)`: `X` is a float32 tensor of shape [10000] decoded from
      the raw bytes stored under 'X'; `y` is an int64 tensor of shape [1].
    """
    keys_to_features = {
        'X': tf.FixedLenFeature([], dtype=tf.string),
        'y': tf.FixedLenFeature(shape=[1], dtype=tf.int64)
    }
    parsed = tf.parse_single_example(value, keys_to_features)
    X = tf.decode_raw(parsed['X'], tf.float32)
    X = tf.reshape(X, [10000])

    # 'y' is already parsed as int64, so no cast is required.
    y = parsed['y']
    return X, y

  def __call__(self, params):
    """Build the input pipeline for one shard.

    Args:
      params: Dict containing 'batch_size', the batch size for the current
        shard (see `tf.contrib.tpu.RunConfig` for how shards are computed).

    Returns:
      `(features, labels)` tensors for one batch when `is_training` or
      `is_eval` is set; otherwise the batched `tf.data.Dataset` itself.
    """
    batch_size = params['batch_size']

    file_pattern = os.path.join(
        self.data_dir,
        'MLP_data_train*' if self.is_training else 'MLP_data_test*')
    dataset = tf.data.Dataset.list_files(file_pattern)

    if self.is_training:
      # Shuffle the file names for better randomization, then loop forever.
      dataset = dataset.shuffle(buffer_size=128)
      dataset = dataset.repeat()

    def prefetch_dataset(filename):
      """Open a single TFRecord file with the configured read buffer."""
      return tf.data.TFRecordDataset(
          filename, buffer_size=prefetch_buffer_size)

    # Read several files concurrently; `sloppy=True` lets records arrive in
    # a non-deterministic order.
    dataset = dataset.apply(
        tf.contrib.data.parallel_interleave(
            prefetch_dataset, cycle_length=num_files_infeed,
            sloppy=True))
    # Record-level shuffle; applied in every mode, not only when training.
    dataset = dataset.shuffle(shuffle_buffer_size)

    dataset = dataset.map(
        self.dataset_parser,
        num_parallel_calls=num_parallel_calls)
    dataset = dataset.prefetch(batch_size)
    # Drop the final partial batch so every batch has a static size.
    dataset = dataset.apply(
        tf.contrib.data.batch_and_drop_remainder(batch_size))

    dataset = dataset.prefetch(2)  # Prefetch overlaps in-feed with training

    if not (self.is_training or self.is_eval):
      # The caller wants the dataset itself; do not add iterator ops to the
      # graph in that case.
      return dataset

    images, labels = dataset.make_one_shot_iterator().get_next()
    return images, labels

In [18]:
# Build the non-training pipeline and take its (features, labels) tensors for
# batches of ten records.
my_data = MLP_Input(is_training=False)
X, y = my_data(params={'batch_size': 10})

In [19]:
# Evaluate one batch of features and labels in a fresh session.
with tf.Session() as sess:
  valX, valy = sess.run(fetches=[X, y])
  
    

In [20]:
# Shape of the fetched feature batch.
valX.shape

(10, 10000)

In [21]:
# Labels of the fetched batch.
valy

array([[0],
       [1],
       [1],
       [0],
       [0],
       [0],
       [1],
       [1],
       [1],
       [0]])

In [22]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Cross-check: read the same TFRecord file back through the queue-based
# input API and print a few shuffled batches.
data_path = './MLP_data/MLP_data_test.tfrecords'  # TFRecord file to read
with tf.Session() as sess:
    feature = {
        'X': tf.FixedLenFeature([], dtype=tf.string),
        'y': tf.FixedLenFeature(shape=[1], dtype=tf.int64)
    }
    # Create a list of filenames and pass it to a queue (one pass over it)
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)
    # Define a reader and read the next record
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    # Decode the record read by the reader
    features = tf.parse_single_example(serialized_example, features=feature)
    # Convert the raw bytes back to float32 values
    image = tf.decode_raw(features['X'], tf.float32)

    # Cast label data into int32
    label = tf.cast(features['y'], tf.int32)
    # Reshape the decoded values into the original flat vector
    image = tf.reshape(image, [10000])

    # Any preprocessing here ...

    # Creates batches by randomly shuffling tensors
    images, labels = tf.train.shuffle_batch(
        [image, label], batch_size=10, capacity=30, num_threads=1,
        min_after_dequeue=10)

    # Initialize all global and local variables (`num_epochs` keeps its
    # counter in a local variable, hence the local initializer).
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    # Create a coordinator and run all QueueRunner objects
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        for batch_index in range(5):
            img, lbl = sess.run([images, labels])
            print(img.shape, lbl)
    except tf.errors.OutOfRangeError:
        # Raised when the single epoch runs out before five batches are read.
        print('Input exhausted before all batches were read')
    finally:
        # Always stop and join the queue threads, even if a run fails.
        coord.request_stop()
        coord.join(threads)
    # No explicit sess.close(): leaving the `with` block closes the session.
    

(10, 10000) [[0]
 [1]
 [1]
 [0]
 [0]
 [1]
 [0]
 [1]
 [1]
 [1]]
(10, 10000) [[1]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]
 [0]]
(10, 10000) [[1]
 [1]
 [1]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [0]]
(10, 10000) [[1]
 [1]
 [1]
 [0]
 [0]
 [0]
 [1]
 [1]
 [1]
 [1]]
(10, 10000) [[0]
 [0]
 [1]
 [1]
 [1]
 [0]
 [0]
 [1]
 [0]
 [0]]


In [23]:
# Feature batch from the last iteration of the read-back loop.
img

array([[-0.68277508, -1.70158255,  0.1727289 , ...,  0.10917469,
        -2.40945673, -1.14211142],
       [-0.52076739,  1.81326783, -0.51520407, ..., -0.57771862,
         1.21164513,  1.26358306],
       [-0.634489  ,  0.23487876, -1.58474731, ...,  2.48326564,
        -0.64282036,  1.32642758],
       ..., 
       [ 0.47943571, -0.19249211, -1.13439512, ..., -0.14228664,
        -0.03568394,  0.53308493],
       [-0.41281357,  0.07613051, -0.0979723 , ..., -0.198423  ,
         0.7150836 , -0.40859175],
       [ 0.82969844,  1.00123513,  1.00835955, ..., -0.66160816,
         0.07516851,  0.2780641 ]], dtype=float32)