In [1]:
from nn_globals import *

%matplotlib inline

[INFO    ] Using cmssw CMSSW_10_1_7
[INFO    ] Using numpy 1.14.1
  from ._conv import register_converters as _register_converters
[INFO    ] Using tensorflow 1.5.0
Using TensorFlow backend.
[INFO    ] Using keras 2.1.4
[INFO    ] .. list devices: [_DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 268435456)]
[INFO    ] Using scipy 1.0.0
[INFO    ] Using sklearn 0.19.2


In [2]:
# Divisor that fixes the image width: the imaging helpers below build images
# with 5040 // superstrip_size columns.
superstrip_size = 16

# Notebook-wide logger; cnn_data() reports its progress and failures through it.
from nn_logging import getLogger
logger = getLogger()

In [3]:
def cnn_data(filename):
  """Load the image, label and parameter arrays stored in an .npz archive.

  Args:
    filename: path handed to np.load. The archive must provide the keys
      'image_pixels', 'image_channels', 'labels' and 'parameters'.

  Returns:
    The tuple (image_pixels, image_channels, labels, parameters).

  Raises:
    Exception: any error raised while opening the file or reading a key is
      logged and then re-raised unchanged, so the real cause is visible
      instead of a later NameError on the unassigned arrays.
    AssertionError: if the four arrays do not share the same first dimension.
  """
  logger.info('Loading cnn data from {0} ...'.format(filename))
  try:
    loaded = np.load(filename)
    the_image_pixels   = loaded['image_pixels']
    the_image_channels = loaded['image_channels']
    the_labels         = loaded['labels']
    the_parameters     = loaded['parameters']
  except Exception:
    # Log for the notebook output, then propagate: continuing would only fail
    # a few lines below with an unrelated error about undefined names.
    logger.error('Failed to load data from file: {0}'.format(filename))
    raise

  logger.info('Loaded the images with shape {0},{1}'.format(the_image_pixels.shape, the_image_channels.shape))
  logger.info('Loaded the labels with shape {0}'.format(the_labels.shape))
  logger.info('Loaded the parameters with shape {0}'.format(the_parameters.shape))

  # All four arrays are indexed by the same entry number along axis 0.
  n_entries = the_image_pixels.shape[0]
  assert n_entries == the_image_channels.shape[0]
  assert n_entries == the_labels.shape[0]
  assert n_entries == the_parameters.shape[0]

  return the_image_pixels, the_image_channels, the_labels, the_parameters

In [4]:
# Load the four arrays from the file named by `infile_images`
# (not defined in this notebook; presumably provided by nn_globals).
image_pixels, image_channels, labels, parameters = cnn_data(infile_images)

from sklearn.model_selection import train_test_split
test_size = 0.4
shuffle = False

# Split all four arrays with the same partition: the last 40% of the entries
# become the test sample, and the order is kept because shuffle is False.
(image_pixels_train, image_pixels_test,
 image_channels_train, image_channels_test,
 labels_train, labels_test,
 parameters_train, parameters_test) = train_test_split(
    image_pixels, image_channels, labels, parameters,
    test_size=test_size, shuffle=shuffle)

[INFO    ] Loading cnn data from ../test7/histos_tbe.17.npz ...
[INFO    ] Loaded the images with shape (3535956, 50, 2),(3535956, 50, 3)
[INFO    ] Loaded the labels with shape (3535956, 3)
[INFO    ] Loaded the parameters with shape (3535956, 3)


In [5]:
if True:
  # Redo the split on only the first `nentries` entries; this overwrites the
  # *_train / *_test arrays produced from the full sample above.
  nentries = 200000
  (image_pixels_train, image_pixels_test,
   image_channels_train, image_channels_test,
   labels_train, labels_test,
   parameters_train, parameters_test) = train_test_split(
      image_pixels[:nentries], image_channels[:nentries],
      labels[:nentries], parameters[:nentries],
      test_size=test_size, shuffle=shuffle)

In [6]:
def imaging(pixels, channels, superstrip_size):
  zone_size = 7
  m_size = 11
  n_size = 5040 // superstrip_size
  chn_size = 3
  image = np.zeros((m_size*zone_size, n_size, chn_size), dtype=np.float32)
  mask = (pixels[:,0] != -99)
  image[pixels[mask,0], pixels[mask,1]] = channels[mask]
  return image

def labeling(labels):
  """Build a float32 volume of shape (21, 128, 7) with a single entry set to 1.

  Args:
    labels: sequence of three integer indices; going by the size names used in
      the original code they select the pt, phi and eta bins.

  Returns:
    Array that is zero everywhere except at [labels[0], labels[1], labels[2]].
  """
  volume_shape = (21, 128, 7)  # (pt, phi, eta) bin counts
  volume = np.zeros(volume_shape, dtype=np.float32)
  i_pt, i_phi, i_eta = labels[0], labels[1], labels[2]
  volume[i_pt, i_phi, i_eta] = 1
  return volume

def draw(image, label):
  """Plot the occupancy of one image and of its label, printing the non-zero indices.

  Neither input array is modified.

  Args:
    image: array indexed [row, column, channel]. It is collapsed over the last
      axis with a max and displayed as a binary (0/1) map.
    label: array belonging to the same example. A trailing axis of length 1 is
      appended before it is displayed as a binary map.
  """
  aspect = 'auto'
  extent = (0, image.shape[1], 0, image.shape[0])
  # To inspect a single channel instead, imshow image[:,:,k] with the same
  # cmap/interpolation/extent/aspect arguments.

  # np.max allocates a new array, so binarising it leaves `image` untouched.
  image_2d = np.max(image, axis=-1)
  image_2d[np.nonzero(image_2d)] = 1
  plt.imshow(image_2d, cmap='viridis', interpolation='none', origin='lower', extent=extent, aspect=aspect)
  # Faint horizontal guide lines every 11 rows.
  for y in [11,22,33,44,55,66]:
    plt.axhline(y=y,linewidth=1, color='w', alpha=0.4)
  plt.show()
  print(np.where(image_2d))

  # np.expand_dims returns a view of `label`; copy it so that the in-place
  # binarisation below cannot overwrite the caller's array.
  # FIXME (kept from the original): np.max(label, axis=-1) was the alternative.
  label_2d = np.expand_dims(label, axis=-1).copy()
  print(label_2d.shape)
  label_2d[np.nonzero(label_2d)] = 1
  plt.imshow(label_2d, cmap='viridis', interpolation='none', origin='lower')
  plt.show()
  print(np.where(label_2d))

In [7]:
def tf_imaging(pixels, channels, superstrip_size=superstrip_size):
  """Graph version of imaging(): scatter channel values into a dense image tensor.

  Args:
    pixels: tensor whose rows are (row, column) coordinates; rows whose first
      entry equals -99 are dropped.
    channels: tensor with one row of 3 channel values per row of `pixels`.
    superstrip_size: divisor of 5040 that sets the number of image columns
      (defaults to the module-level value at definition time).

  Returns:
    Tensor of shape (77, 5040 // superstrip_size, 3).
  """
  # NOTE(review): per the TensorFlow docs tf.scatter_nd accumulates duplicate
  # indices, whereas the NumPy imaging() overwrites them - confirm that
  # duplicates cannot occur.
  plane_shape = (11 * 7, 5040 // superstrip_size)

  keep = tf.not_equal(pixels[:,0], -99)
  coords = tf.boolean_mask(pixels, keep)
  hit_values = tf.boolean_mask(channels, keep)

  # One 2-D scatter per channel, stacked along a new last axis.
  planes = [tf.scatter_nd(coords, hit_values[:,k], plane_shape) for k in range(3)]
  return tf.stack(planes, axis=-1)

def tf_labeling(labels):
  """Graph version of labeling(): dense (21, 128, 7) tensor with the labelled entries set to 1.

  Args:
    labels: tensor that is reshaped into rows of three indices.

  Returns:
    Tensor of shape (21, 128, 7) holding 1 at each index triple and 0 elsewhere.
  """
  dense_shape = (21, 128, 7)  # (pt, phi, eta) bin counts
  # One row of three indices per entry to switch on.
  hot_indices = tf.reshape(labels, [-1,3])
  return tf.sparse_to_dense(hot_indices, dense_shape, 1, default_value=0, validate_indices=False)

def tf_labeling_one_hot(labels):
  """One-hot encode the first label component over 21 classes."""
  return tf.one_hot(labels[0], depth=21)

def tf_labeling_no_one_hot(labels):
  """Return the first label component unchanged, for use as a sparse class index."""
  return labels[0]

In [8]:
# Session obtained from the Keras backend `K`; the sanity-check cells run their ops in it.
sess = K.get_session()

# Switch for the `if sanity_check:` cells; they are skipped while this is False.
sanity_check= False

In [9]:
if sanity_check:
  # One placeholder per training array, with the array's dtype and full shape.
  image_pixels_ph   = tf.placeholder(image_pixels_train.dtype, image_pixels_train.shape)
  image_channels_ph = tf.placeholder(image_channels_train.dtype, image_channels_train.shape)
  labels_ph         = tf.placeholder(labels_train.dtype, labels_train.shape)
  # Defined for completeness; the feed_dict built in a later sanity-check cell does not feed it.
  parameters_ph     = tf.placeholder(parameters_train.dtype, parameters_train.shape)

In [10]:
if sanity_check:
  # Images: slice the (pixels, channels) placeholders per entry and scatter each into a dense image.
  dataset1 = tf.data.Dataset.from_tensor_slices((image_pixels_ph, image_channels_ph))
  dataset1 = dataset1.map(tf_imaging)
  # Labels: one-hot encode the first label component of every entry.
  dataset2 = tf.data.Dataset.from_tensor_slices((labels_ph))
  #dataset2 = dataset2.map(tf_labeling)
  dataset2 = dataset2.map(tf_labeling_one_hot)
  # Pair each image with its label.
  dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
  print(dataset3.output_types)
  print(dataset3.output_shapes)

In [11]:
if sanity_check:
  # Batches of one example, so each sess.run(next_element) yields a single image/label pair.
  #batched_dataset = dataset3.batch(4)
  batched_dataset = dataset3.batch(1)
  iterator = batched_dataset.make_initializable_iterator()
  # The placeholders are bound to the training arrays when the iterator is initialised.
  feed_dict = {image_pixels_ph: image_pixels_train, image_channels_ph: image_channels_train, labels_ph: labels_train}
  sess.run(iterator.initializer, feed_dict=feed_dict)

  # Tensor pair that advances the iterator every time it is evaluated.
  next_element = iterator.get_next()

  #print(sess.run(next_element))
  #print(sess.run(next_element))
  #print(sess.run(next_element))
  #print(sess.run(next_element))

In [12]:
if sanity_check:
  # Use 10x5 figures for the sanity-check plots (changes matplotlib's global rcParams).
  import matplotlib as mpl
  mpl.rcParams['figure.figsize'] = (10,5)
  #mpl.rcParams['axes.labelpad'] = 0
  #mpl.rcParams['axes.labelsize'] = 0
  #mpl.rcParams['xtick.labelsize'] = 0
  #mpl.rcParams['ytick.labelsize'] = 0

In [13]:
if sanity_check:
  # Pull the next single-example batch from the iterator and plot its image and label.
  image, label = sess.run(next_element)
  draw(image[0], label[0])

In [14]:
if sanity_check:
  # Pull the next single-example batch from the iterator and plot its image and label.
  image, label = sess.run(next_element)
  draw(image[0], label[0])

In [15]:
if sanity_check:
  # Pull the next single-example batch from the iterator and plot its image and label.
  image, label = sess.run(next_element)
  draw(image[0], label[0])

In [16]:
if sanity_check:
  # Pull the next single-example batch from the iterator and plot its image and label.
  image, label = sess.run(next_element)
  draw(image[0], label[0])

In [17]:
if sanity_check:
  # Pull the next single-example batch from the iterator and plot its image and label.
  image, label = sess.run(next_element)
  draw(image[0], label[0])

In [18]:
if sanity_check:
  # Pull the next single-example batch from the iterator and plot its image and label.
  image, label = sess.run(next_element)
  draw(image[0], label[0])

### MNIST in TensorFlow

In [19]:
def create_model(data_format, n_rows=28, n_columns=28, n_channels=1, n_classes=10, dropout=0.4):
  """Model to recognize digits in the MNIST dataset.
  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
  But uses the tf.keras API.
  Args:
    data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
      typically faster on GPUs while 'channels_last' is typically faster on
      CPUs. See
      https://www.tensorflow.org/performance/performance_guide#data_formats
    n_rows: Height of the input image.
    n_columns: Width of the input image.
    n_channels: Number of channels of the input image.
    n_classes: Number of output logits.
    dropout: Rate for the Dropout layer. Currently unused because that layer
      is commented out.
  Returns:
    A tf.keras.Model producing unnormalised logits of size n_classes.
  """
  if data_format == 'channels_first':
    input_shape = [n_channels, n_rows, n_columns]
  else:
    assert data_format == 'channels_last'
    input_shape = [n_rows, n_columns, n_channels]

  l = tf.keras.layers

  # The Reshape layer must be declared with as many input elements as its
  # target shape holds. The original MNIST value n_rows * n_columns is only
  # right for a single channel, so the channel count is included here.
  flat_input_size = n_rows * n_columns * n_channels

  # The model consists of a sequential chain of layers, so tf.keras.Sequential
  # (a subclass of tf.keras.Model) makes for a compact description.
  return tf.keras.Sequential(
      [
          l.Reshape(
              target_shape=input_shape,
              input_shape=(flat_input_size,)),
          l.Conv2D(
              32,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          l.MaxPooling2D(
              (2, 2),
              (2, 2),
              padding='same',
              data_format=data_format),
          l.Conv2D(
              64,
              5,
              padding='same',
              data_format=data_format,
              activation=tf.nn.relu),
          l.MaxPooling2D(
              (2, 2),
              (2, 2),
              padding='same',
              data_format=data_format),
          l.Flatten(),
          l.Dense(1024, activation=tf.nn.relu),
          # Dropout stays disabled: model_fn calls the model without a
          # `training` argument, so it could not be switched off at inference.
          #l.Dropout(dropout),
          l.Dense(n_classes),
      ])

In [20]:
def model_fn(features, labels, mode, params):
  """Estimator model_fn: builds the network and returns the EstimatorSpec for `mode`.

  Args:
    features: image tensor, or a dict holding it under the key 'image'.
    labels: integer class indices (not used in PREDICT mode).
    mode: one of tf.estimator.ModeKeys.{PREDICT, TRAIN, EVAL}.
    params: dict with the keys 'data_format', 'n_rows', 'n_columns',
      'n_channels', 'n_classes', 'dropout', 'learning_rate' and, optionally,
      'multi_gpu'.

  Returns:
    A tf.estimator.EstimatorSpec, or None if `mode` is none of the three above.
  """
  mode_keys = tf.estimator.ModeKeys
  network = create_model(params['data_format'], params['n_rows'], params['n_columns'], params['n_channels'], params['n_classes'], params['dropout'])
  learning_rate = params['learning_rate']

  # Unwrap the tensor when features arrive as a dict.
  inputs = features['image'] if isinstance(features, dict) else features

  if mode == mode_keys.PREDICT:
    # The model is called without `training`: tensorflow 1.5 has no such keyword.
    logits = network(inputs)
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits),
    }
    # PREDICT only requires the predictions.
    export_outputs = {'classify': tf.estimator.export.PredictOutput(predictions)}
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs=export_outputs)

  if mode == mode_keys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    if params.get('multi_gpu'):
      # Multi-GPU training needs the optimizer wrapped.
      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

    logits = network(inputs)  # no `training` keyword in tensorflow 1.5
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    predicted_classes = tf.argmax(logits, axis=1)
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predicted_classes)
    accuracy_update = accuracy[1]

    # Give the logged tensors the names that LoggingTensorHook looks up.
    tf.identity(learning_rate, 'learning_rate')
    tf.identity(loss, 'cross_entropy')
    tf.identity(accuracy_update, name='train_accuracy')

    # Also expose the accuracy to TensorBoard.
    tf.summary.scalar('train_accuracy', accuracy_update)

    # TRAIN requires a loss and a train_op.
    global_step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

  if mode == mode_keys.EVAL:
    logits = network(inputs)  # no `training` keyword in tensorflow 1.5
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    predicted_classes = tf.argmax(logits, axis=1)
    # EVAL requires a loss; accuracy is reported as an additional metric.
    metric_ops = {
        'accuracy': tf.metrics.accuracy(labels=labels, predictions=predicted_classes),
    }
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=metric_ops)

In [21]:
from tensorflow.python.training import training
from tensorflow.python.data.ops import dataset_ops

class IteratorInitializerHook(training.SessionRunHook):
  """Session hook that runs a user-supplied iterator initializer.

  Both attributes start as None and are expected to be assigned later:
    iterator_initializer_func: callable taking the session.
    feed_fn: zero-argument callable returning a feed dict.
  """

  def __init__(self):
    super(IteratorInitializerHook, self).__init__()
    self.feed_fn = None
    self.iterator_initializer_func = None

  def after_create_session(self, session, coord):
    """Call the stored initializer with the newly created session."""
    initialize = self.iterator_initializer_func
    initialize(session)

class _DatasetInitializerHook(training.SessionRunHook):
  """Session hook that initialises a dataset iterator with a feed dict.

  Args:
    iterator: an initializable iterator; its `initializer` op is run.
    feed_fn: zero-argument callable returning the feed dict for that op.
  """

  def __init__(self, iterator, feed_fn):
    self.feed_fn = feed_fn
    self._iterator = iterator

  def begin(self):
    # Fetch the initializer op once, before the session exists.
    self._initializer = self._iterator.initializer

  def after_create_session(self, session, coord):
    del coord  # unused
    feed_dict = self.feed_fn()
    session.run(self._initializer, feed_dict=feed_dict)

# from tensorflow/python/estimator/estimator.py
def _get_features_and_labels_from_input_fn(self, input_fn, mode):
  """Extracts the `features` and labels from return values of `input_fn`.

  Replacement for the Estimator method of the same name. When `input_fn`
  returns a Dataset, its iterator is initialised through a
  _DatasetInitializerHook that is given a feed dict, so the dataset may be
  built on placeholders.

  Args:
    self: the Estimator this function is bound to. It must carry the
      attributes `train_input_hook` and `eval_input_hook`, each exposing a
      `feed_fn` callable.
    input_fn: the input function passed to train/evaluate.
    mode: a tf.estimator.ModeKeys value. TRAIN uses train_input_hook.feed_fn;
      every other mode uses eval_input_hook.feed_fn.

  Returns:
    Tuple (features, labels, input_hooks). labels is None when `input_fn`
    yields a single value.

  Raises:
    ValueError: if `input_fn` yields a list/tuple whose length is not 2.
  """
  result = self._call_input_fn(input_fn, mode)
  input_hooks = []
  # Debug trace of the two hook holders, written so that it prints the same
  # text under both Python 2 and Python 3.
  print('{0} {1}'.format(self.train_input_hook, self.eval_input_hook))
  if isinstance(result, dataset_ops.Dataset):
    iterator = result.make_initializable_iterator()
    if mode == tf.estimator.ModeKeys.TRAIN:
      feed_fn = self.train_input_hook.feed_fn
    else:  # any non-TRAIN mode
      feed_fn = self.eval_input_hook.feed_fn
    input_hooks.append(_DatasetInitializerHook(iterator, feed_fn))
    result = iterator.get_next()
  if isinstance(result, (list, tuple)):
    if len(result) != 2:
      raise ValueError(
          'input_fn should return (features, labels) as a len 2 tuple.')
    return result[0], result[1], input_hooks
  return result, None, input_hooks

In [22]:
# Image geometry and hyperparameters that run_mnist passes to the Estimator as `params`.
n_rows = 7 * 11  # same product as zone_size * m_size in the imaging helpers
n_columns = 5040 // superstrip_size  # same width as the imaging helpers
n_channels = 3
n_classes = 21  # equals pt_size in the labeling helpers
dropout = 0.2  # forwarded to create_model, where the Dropout layer is commented out
learning_rate = 1e-4


def run_mnist(flags_obj):
  """Train the CNN on the notebook's training split and optionally export it.

  Builds a tf.estimator.Estimator around `model_fn`, feeds it the module-level
  *_train arrays through placeholder-backed tf.data pipelines and trains for
  train_epochs // epochs_between_evals rounds. The evaluation step is
  currently commented out.

  Args:
    flags_obj: An object containing parsed flag values. The attributes read
      here are num_gpus, batch_size, data_format, model_dir,
      epochs_between_evals, train_epochs, hooks and export_dir.
  """
  #model_helpers.apply_clean(flags_obj)
  model_function = model_fn

  # Get number of GPUs as defined by the --num_gpus flags and the number of
  # GPUs available on the machine.
  num_gpus = flags_obj.num_gpus
  multi_gpu = num_gpus > 1

  if multi_gpu:
    # Validate that the batch size can be split into devices.
    # NOTE(review): `distribution_utils` is not defined in the visible part of
    # this notebook; confirm it is in scope before using more than one GPU.
    distribution_utils.per_device_batch_size(flags_obj.batch_size, num_gpus)

    # There are two steps required if using multi-GPU: (1) wrap the model_fn,
    # and (2) wrap the optimizer. The first happens here, and (2) happens
    # in the model_fn itself when the optimizer is defined.
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN,
        devices=["/device:GPU:%d" % d for d in range(num_gpus)])

  data_format = flags_obj.data_format
  mnist_classifier = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=flags_obj.model_dir,
      params={
          'data_format': data_format,
          'multi_gpu': multi_gpu,
          'n_rows': n_rows,
          'n_columns': n_columns,
          'n_channels': n_channels,
          'n_classes': n_classes,
          'dropout': dropout,
          'learning_rate': learning_rate,
      })

  def _make_dataset_and_feed(pixels, channels, labels, parameters):
    """Build the placeholders, the (image, class) dataset and the feed dict for one sample."""
    image_pixels_ph   = tf.placeholder(pixels.dtype, [None]+list(pixels.shape[1:]))
    image_channels_ph = tf.placeholder(channels.dtype, [None]+list(channels.shape[1:]))
    labels_ph         = tf.placeholder(labels.dtype, [None]+list(labels.shape[1:]))
    # Created like the others but neither fed nor consumed by the pipeline.
    parameters_ph     = tf.placeholder(parameters.dtype, [None]+list(parameters.shape[1:]))
    feed_dict = {image_pixels_ph: pixels, image_channels_ph: channels, labels_ph: labels}

    dataset1 = tf.data.Dataset.from_tensor_slices((image_pixels_ph, image_channels_ph))
    dataset1 = dataset1.map(tf_imaging)
    dataset2 = tf.data.Dataset.from_tensor_slices((labels_ph))
    #dataset2 = dataset2.map(tf_labeling)
    dataset2 = dataset2.map(tf_labeling_no_one_hot)
    return tf.data.Dataset.zip((dataset1, dataset2)), feed_dict

  def _attach_to_hook(hook, ds, feed_dict):
    """Store on `hook` an iterator initializer for `ds` and a feed_fn returning `feed_dict`."""
    iterator = ds.make_initializable_iterator()
    def _init(sess):
      sess.run(iterator.initializer, feed_dict=feed_dict)
    hook.iterator_initializer_func = _init
    hook.feed_fn = lambda: feed_dict

  # Set up training and evaluation input functions.
  def get_train_input_fn_and_hook():
    iterator_initializer_hook = IteratorInitializerHook()

    def train_input_fn():
      with tf.name_scope('train_data'):
        dataset3, feed_dict_train = _make_dataset_and_feed(
            image_pixels_train, image_channels_train, labels_train, parameters_train)

        # When choosing shuffle buffer sizes, larger sizes result in better
        # randomness, while smaller sizes use less memory. The buffer holds
        # 50000 examples, so only samples up to that size are shuffled fully.
        ds = dataset3.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)

        # Iterate through the dataset a set number (`epochs_between_evals`) of times
        # during each training session.
        ds = ds.repeat(flags_obj.epochs_between_evals)

        _attach_to_hook(iterator_initializer_hook, ds, feed_dict_train)
        return ds
    return train_input_fn, iterator_initializer_hook

  def get_eval_input_fn_and_hook():
    iterator_initializer_hook = IteratorInitializerHook()

    def eval_input_fn():
      with tf.name_scope('test_data'):
        dataset3, feed_dict_test = _make_dataset_and_feed(
            image_pixels_test, image_channels_test, labels_test, parameters_test)

        #ds = dataset3.batch(flags_obj.batch_size).make_one_shot_iterator().get_next()
        ds = dataset3.batch(flags_obj.batch_size)

        _attach_to_hook(iterator_initializer_hook, ds, feed_dict_test)
        return ds
    return eval_input_fn, iterator_initializer_hook

  train_input_fn, train_input_hook = get_train_input_fn_and_hook()

  eval_input_fn, eval_input_hook = get_eval_input_fn_and_hook()

  # Set up the training hooks requested through the `hooks` flag.
  train_hooks = get_train_hooks(
      flags_obj.hooks, model_dir=flags_obj.model_dir,
      batch_size=flags_obj.batch_size)

  eval_hooks = []

  # Patch the function _get_features_and_labels_from_input_fn() so that the
  # Estimator initialises the dataset iterator with the feed dicts stored on
  # the two hook objects.
  import types
  mnist_classifier.train_input_hook = train_input_hook
  mnist_classifier.eval_input_hook = eval_input_hook
  mnist_classifier._get_features_and_labels_from_input_fn = types.MethodType(_get_features_and_labels_from_input_fn, mnist_classifier)

  # Train and evaluate model.
  for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
    mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    #eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn, hooks=eval_hooks)
    #print('\nEvaluation results:\n\t%s\n' % eval_results)

    #if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
    #                                     eval_results['accuracy']):
    #  break

  # Export the model
  if flags_obj.export_dir is not None:
    # The model reshapes its input to n_rows x n_columns x n_channels values
    # per example, so the serving placeholder needs the channel axis too.
    # Without it a batch of B images would be regrouped into B / n_channels.
    image = tf.placeholder(tf.float32, [None, n_rows, n_columns, n_channels])
    input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
        'image': image,
    })
    mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)

In [23]:
# Default tensors for the logging hooks: each display label maps to the tensor
# of the same name (the names model_fn assigns with tf.identity in TRAIN mode).
_TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate',
                                        'cross_entropy',
                                        'train_accuracy'])


def get_train_hooks(name_list, use_tpu=False, **kwargs):
  """Instantiate training hooks from their names.

  Args:
    name_list: iterable of hook names. Each name is stripped and lower-cased
      before being looked up in HOOKS. An empty or None value yields no hooks.
    use_tpu: when True, a warning is logged and no hooks are created.
    **kwargs: keyword arguments forwarded to every hook factory.

  Returns:
    List of instantiated hooks, in the order of `name_list`.

  Raises:
    ValueError: if a name has no entry in HOOKS.
  """
  if not name_list:
    return []

  if use_tpu:
    tf.logging.warning("hooks_helper received name_list `{}`, but a TPU is "
                       "specified. No hooks will be used.".format(name_list))
    return []

  def _build(name):
    # Resolve the factory registered under the normalised name and call it.
    factory = HOOKS.get(name.strip().lower())
    if factory is None:
      raise ValueError('Unrecognized training hook requested: {}'.format(name))
    return factory(**kwargs)

  return [_build(name) for name in name_list]


def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs):  # pylint: disable=unused-argument
  """Create a tf.train.LoggingTensorHook.

  Args:
    every_n_iter: `int`, log the tensor values once every N local steps.
    tensors_to_log: list of tensor names, or dict mapping labels to tensor
      names. Falls back to _TENSORS_TO_LOG when None.
    **kwargs: accepted so every factory shares one call signature; ignored.

  Returns:
    The configured LoggingTensorHook.
  """
  tensors = _TENSORS_TO_LOG if tensors_to_log is None else tensors_to_log
  return tf.train.LoggingTensorHook(tensors=tensors, every_n_iter=every_n_iter)


def get_profiler_hook(model_dir, save_steps=1000, **kwargs):  # pylint: disable=unused-argument
  """Create a tf.train.ProfilerHook.

  Args:
    model_dir: directory that receives the profile traces.
    save_steps: `int`, write a trace every N steps.
    **kwargs: accepted so every factory shares one call signature; ignored.

  Returns:
    The configured ProfilerHook.
  """
  profiler_hook = tf.train.ProfilerHook(output_dir=model_dir, save_steps=save_steps)
  return profiler_hook


def get_examples_per_second_hook(every_n_steps=100,
                                 batch_size=128,
                                 warm_steps=5,
                                 **kwargs):  # pylint: disable=unused-argument
  """Create an ExamplesPerSecondHook.

  Args:
    every_n_steps: `int`, report current and average examples per second
      every N steps.
    batch_size: `int`, total batch size used to compute examples/second.
    warm_steps: number of initial steps skipped before logging and averaging.
    **kwargs: accepted so every factory shares one call signature; ignored.

  Returns:
    The configured ExamplesPerSecondHook.
  """
  # NOTE(review): neither `hooks` nor a `get_benchmark_logger` attribute on
  # `logger` is defined in the visible part of this notebook - confirm both
  # exist before requesting this hook.
  hook_class = hooks.ExamplesPerSecondHook
  return hook_class(
      batch_size=batch_size,
      every_n_steps=every_n_steps,
      warm_steps=warm_steps,
      metric_logger=logger.get_benchmark_logger())


def get_logging_metric_hook(tensors_to_log=None,
                            every_n_secs=600,
                            **kwargs):  # pylint: disable=unused-argument
  """Create a LoggingMetricHook.

  Args:
    tensors_to_log: list of tensor names, or dict mapping labels to tensor
      names. Falls back to _TENSORS_TO_LOG when None.
    every_n_secs: `int`, logging period in seconds (default 600, i.e. 10 minutes).
    **kwargs: accepted so every factory shares one call signature; ignored.

  Returns:
    The configured LoggingMetricHook.
  """
  # NOTE(review): neither `metric_hook` nor a `get_benchmark_logger` attribute
  # on `logger` is defined in the visible part of this notebook - confirm both
  # exist before requesting this hook.
  tensors = _TENSORS_TO_LOG if tensors_to_log is None else tensors_to_log
  hook_class = metric_hook.LoggingMetricHook
  return hook_class(
      tensors=tensors,
      metric_logger=logger.get_benchmark_logger(),
      every_n_secs=every_n_secs)


# A dictionary to map one hook name and its corresponding function
# Registry consulted by get_train_hooks(): lower-case hook name -> factory
# function. Keys must be lower case because the lookup lower-cases the
# requested name.
HOOKS = {
    'loggingtensorhook': get_logging_tensor_hook,
    'profilerhook': get_profiler_hook,
    'examplespersecondhook': get_examples_per_second_hook,
    'loggingmetrichook': get_logging_metric_hook,
}

In [24]:
from absl import flags

def define_mnist_flags():
  """Register the absl flags read by run_mnist.

  Returns:
    List with the names of the flags treated as key flags, in definition
    order. `stop_threshold` and `num_gpus` are defined but not listed.
  """
  import functools
  wrap = functools.partial(flags.text_wrap, length=80, indent="",
                           firstline_indent="\n")

  # --- locations ---
  flags.DEFINE_string(
      name="data_dir", short_name="dd", default="/tmp",
      help=wrap("The location of the input data."))

  flags.DEFINE_string(
      name="model_dir", short_name="md", default="/tmp",
      help=wrap("The location of the model checkpoint files."))

  # --- training schedule ---
  flags.DEFINE_integer(
      name="train_epochs", short_name="te", default=1,
      help=wrap("The number of epochs used to train."))

  flags.DEFINE_integer(
      name="epochs_between_evals", short_name="ebe", default=1,
      help=wrap("The number of training epochs to run between "
                "evaluations."))

  flags.DEFINE_float(
      name="stop_threshold", short_name="st",
      default=None,
      help=wrap("If passed, training will stop at the earlier of "
                "train_epochs and when the evaluation metric is  "
                "greater than or equal to stop_threshold."))

  flags.DEFINE_integer(
      name="batch_size", short_name="bs", default=32,
      help=wrap("Batch size for training and evaluation. When using "
                "multiple gpus, this is the global batch size for "
                "all devices. For example, if the batch size is 32 "
                "and there are 4 GPUs, each GPU will get 8 examples on "
                "each step."))

  # --- devices ---
  flags.DEFINE_integer(
      name="num_gpus", short_name="ng",
      default=1 if tf.test.is_gpu_available() else 0,
      help=wrap(
          "How many GPUs to use with the DistributionStrategies API. The "
          "default is 1 if TensorFlow can detect a GPU, and 0 otherwise."))

  # --- hooks ---
  # Pretty summary of the registered hook names for the help text.
  hook_names = u"\n".join(u"\ufeff    {}".format(key) for key in HOOKS)
  hook_list_str = u"\ufeff  Hook:\n" + hook_names
  flags.DEFINE_list(
      name="hooks", short_name="hk", default="LoggingTensorHook",
      help=wrap(
          u"A list of (case insensitive) strings to specify the names of "
          u"training hooks.\n{}\n\ufeff  Example: `--hooks ProfilerHook,"
          u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper "
          u"for details.".format(hook_list_str))
  )

  # --- export / layout ---
  flags.DEFINE_string(
      name="export_dir", short_name="ed", default=None,
      help=wrap("If set, a SavedModel serialization of the model will "
                "be exported to this directory at the end of training. "
                "See the README for more details and relevant links.")
  )

  # NOTE(review): the help text speaks of an automatic choice when the flag is
  # left unspecified, but the default here is fixed to "channels_last".
  flags.DEFINE_enum(
      name="data_format", short_name="df", default="channels_last",
      enum_values=["channels_first", "channels_last"],
      help=wrap(
          "A flag to override the data format used in the model. "
          "channels_first provides a performance boost on GPU but is not "
          "always compatible with CPU. If left unspecified, the data format "
          "will be chosen automatically based on whether TensorFlow was "
          "built for CPU or GPU."))

  return ["data_dir", "model_dir", "train_epochs", "epochs_between_evals",
          "batch_size", "hooks", "export_dir", "data_format"]

def clear_flags():
  """Delete every flag currently registered on the global absl FLAGS object."""
  # Snapshot the names first: the registry changes while entries are deleted.
  registered_names = list(flags.FLAGS)
  for flag_name in registered_names:
    delattr(flags.FLAGS, flag_name)

def set_defaults(**kwargs):
  """Override flag defaults: each keyword names a flag and gives its new default value."""
  #flags.FLAGS.remove_flag_values(kwargs.keys())
  for flag_name in kwargs:
    flags.FLAGS.set_default(name=flag_name, value=kwargs[flag_name])
    
# Remove all registered flags before defining them (presumably so that this
# cell can be re-executed), then register the flags.
clear_flags()
define_mnist_flags()

['data_dir',
 'model_dir',
 'train_epochs',
 'epochs_between_evals',
 'batch_size',
 'hooks',
 'export_dir',
 'data_format']

In [None]:
# Override selected flag defaults for this notebook run.
set_defaults(data_dir='./mnist_data',
             model_dir='./mnist_model',
             batch_size=50,
             train_epochs=1)

# Parse a one-element argv (the entry only fills the program-name slot) so the
# flag values can be read.
flags.FLAGS(['lol'])

# Launch training with the parsed flags.
run_mnist(flags.FLAGS)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_is_chief': True, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f4e28185750>, '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_num_ps_replicas': 0, '_tf_random_seed': None, '_master': '', '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_model_dir': './mnist_model', '_save_summary_steps': 100}
<__main__.IteratorInitializerHook object at 0x7f4e28185590> <__main__.IteratorInitializerHook object at 0x7f4e281850d0>
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Restoring parameters from ./mnist_model/model.ckpt-1
INFO:tensorflow:Saving checkpoints for 2 into ./mnist_model/model.ckpt.
INFO:tensorflow:learning_rate = 1e-04, cross_entropy = 3.0426984, train_accuracy = 0.05
INFO:tensorflow:loss = 3.042698

In [None]:
# dtype and shape of each training array, one line per array.
print('{0} {1}'.format(image_pixels_train.dtype, image_pixels_train.shape))
print('{0} {1}'.format(image_channels_train.dtype, image_channels_train.shape))
print('{0} {1}'.format(labels_train.dtype, labels_train.shape))
print('{0} {1}'.format(parameters_train.dtype, parameters_train.shape))

In [None]:
# dtype and shape of each test array, one line per array.
print('{0} {1}'.format(image_pixels_test.dtype, image_pixels_test.shape))
print('{0} {1}'.format(image_channels_test.dtype, image_channels_test.shape))
print('{0} {1}'.format(labels_test.dtype, labels_test.shape))
print('{0} {1}'.format(parameters_test.dtype, parameters_test.shape))