# Nest Tutorial
`tf.contrib.framework.nest` is a utility library that flattens and packs arbitrary nested structures. A nested structure is a Python sequence (such as a list), tuple (including `namedtuple`), or dict that can contain
further sequences, tuples, and dicts. The utilities here assume (and do not check) that the nested structures form a 'tree'.
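
To make "flatten" concrete, the following plain-Python sketch collects the leaves of a nested structure depth-first. `flatten_sketch` is a hypothetical name written for this tutorial, not nest's actual implementation. Like nest, it visits dict entries in sorted key order rather than insertion order. Unlike nest, it only recognizes lists, tuples (including namedtuples), and dicts, and treats everything else, including strings, as a leaf.

```python
import collections


def flatten_sketch(structure):
    """Return the leaves of `structure` as a flat list (illustrative only).

    Recurses into lists, tuples (including namedtuples), and dicts; any
    other value is a leaf. Dict values are visited in sorted key order.
    """
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure)
                for leaf in flatten_sketch(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten_sketch(item)]
    return [structure]


Point = collections.namedtuple('Point', ['x', 'y'])
# Keys are visited as 'a' then 'b', regardless of insertion order.
nested = {'b': [1, (2, 3)], 'a': Point(x=4, y={'z': 5})}
print(flatten_sketch(nested))  # [4, 5, 1, 2, 3]
```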

For example, if you have a robot with two types of sensor inputs, you might want to store them in a named tuple such as `collections.namedtuple('SensorInputs', ['input_1', 'input_2'])` and then feed it into a machine learning system as input. However, some machine learning APIs (for example TensorFlow's `tf.py_func`, covered below) accept only flat lists of arrays as inputs, so the user must flatten the data, and later pack it back into its structure, by hand. Nest abstracts away this tedious work so that the user can focus on the application logic.
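
For a single level of nesting, the manual work is small, as the sketch below shows. The sensor values are made up for illustration. Every deeper or differently shaped structure, however, would need its own hand-written flatten and pack code, which is the work nest generalizes.

```python
import collections

# Field names are passed as one list (a space-separated string also works).
SensorInputs = collections.namedtuple('SensorInputs', ['input_1', 'input_2'])

# Hypothetical sensor readings, for illustration only.
reading = SensorInputs(input_1=0.5, input_2=3.0)

# Manual flattening: a namedtuple iterates over its fields in order.
flat = list(reading)
print(flat)  # [0.5, 3.0]

# Manual packing: rebuild the namedtuple from the flat list.
restored = SensorInputs(*flat)
print(restored == reading)  # True
```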

Below we provide three examples to illustrate the usage of nest:

1. Using nest to flatten and pack nested structures.
2. Using `nest.map_structure` to transform nested structures.
3. Using nest to augment `py_func` to handle structured input/output.

# nest accepts lists, tuples, namedtuples, and dicts as input

In [0]:
import numpy as np
import tensorflow as tf
nest = tf.contrib.framework.nest
import unittest
import collections

In [0]:
class NestInputTest(unittest.TestCase):

  def testNestFlattensArray(self):
    self.assertEqual(nest.flatten([[1, 2, 3], [4, 5]]), [1, 2, 3, 4, 5])

  def testNestFlattensDict(self):
    # dict values are flattened in sorted key order
    self.assertEqual(nest.flatten({'key_1': 1, 'key_2': 2}), [1, 2])

  def testNestFlattensTuple(self):
    self.assertEqual(nest.flatten((1, 2)), [1, 2])

  def testNestFlattensNamedTuple(self):
    NamedTuple = collections.namedtuple('NamedTuple', 'str_attr int_attr')
    self.assertEqual(
        nest.flatten(NamedTuple(str_attr='1', int_attr=2)), ['1', 2])

  def testNestPacksSequences(self):
    # the first argument only supplies the structure; its leaves are ignored
    packed_sequence = nest.pack_sequence_as(([None, None], (None,)), [1, 2, 3])
    self.assertEqual(packed_sequence, ([1, 2], (3,)))


suite = unittest.TestLoader().loadTestsFromTestCase(NestInputTest)
unittest.TextTestRunner().run(suite)
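
To see what `nest.pack_sequence_as` does with its first argument, here is a plain-Python sketch of the same idea. `pack_sequence_as_sketch` is a hypothetical name written for this tutorial, not part of nest, and not its actual implementation. The first argument only supplies the shape; the flat values are poured into its leaf positions in order, with dict entries filled in sorted key order so that packing mirrors flattening.

```python
def pack_sequence_as_sketch(structure, flat_sequence):
    """Rebuild `structure`, taking its leaves in order from `flat_sequence`.

    Illustrative only. Leaves of `structure` (anything that is not a list,
    tuple, or dict) are placeholders whose values are ignored.
    """
    flat_iter = iter(flat_sequence)

    def next_leaf():
        try:
            return next(flat_iter)
        except StopIteration:
            raise ValueError('flat_sequence has too few elements') from None

    def build(node):
        if isinstance(node, dict):
            # Consume values in sorted key order.
            return {key: build(node[key]) for key in sorted(node)}
        if isinstance(node, tuple) and hasattr(node, '_fields'):
            # namedtuple: rebuild from positional field values.
            return type(node)(*[build(item) for item in node])
        if isinstance(node, (list, tuple)):
            return type(node)([build(item) for item in node])
        return next_leaf()

    packed = build(structure)
    sentinel = object()
    if next(flat_iter, sentinel) is not sentinel:
        raise ValueError('flat_sequence has too many elements')
    return packed


print(pack_sequence_as_sketch(([None, None], (None,)), [1, 2, 3]))
# ([1, 2], (3,))
```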

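Below we look at a more complex example that uses `nest.map_structure` to transform inputs. `nest.map_structure(func, *structures)` calls `func` on the corresponding leaves of one or more structures that share the same nesting, and returns the results packed in that same structure. If the inputs are single leaves (such as two tensors), `func` is simply applied to them once.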


In [0]:
class NestMapStructureTest(tf.test.TestCase):

  def testAddTensors(self):
    add_op = nest.map_structure(lambda x, y: x + y, tf.constant(1),
                                tf.constant(1))
    with self.test_session():
      self.assertEqual(add_op.eval(), 2)

  def testGroupTuplesByElement(self):
    grouped = nest.map_structure(lambda *arr: arr, ('a', 'b', ('c', 'd')),
                                 (3, 4, (5, 6)), (False, True, (True, True)))
    np.testing.assert_equal(grouped, (('a', 3, False), ('b', 4, True),
                                      (('c', 5, True), ('d', 6, True))))


suite = unittest.TestLoader().loadTestsFromTestCase(NestMapStructureTest)
unittest.TextTestRunner().run(suite)
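
The behaviour exercised by these tests can be approximated by a short recursive function. The sketch below (a hypothetical `map_structure_sketch`, not nest's implementation) walks all structures in lockstep and calls `func` on each group of corresponding leaves. Unlike nest, it does not check that the structures match, and `zip` silently truncates on a length mismatch.

```python
def map_structure_sketch(func, *structures):
    """Apply `func` to corresponding leaves of same-shaped structures.

    Illustrative only: no validation that the structures match. The
    result takes the shape of the first structure.
    """
    first = structures[0]
    if isinstance(first, dict):
        return {key: map_structure_sketch(func, *[s[key] for s in structures])
                for key in first}
    if isinstance(first, tuple) and hasattr(first, '_fields'):
        # namedtuple: rebuild from positional field values.
        return type(first)(*[map_structure_sketch(func, *items)
                             for items in zip(*structures)])
    if isinstance(first, (list, tuple)):
        return type(first)([map_structure_sketch(func, *items)
                            for items in zip(*structures)])
    # Leaves: call func on the corresponding element of every structure.
    return func(*structures)


grouped = map_structure_sketch(lambda *arr: arr, ('a', 'b', ('c', 'd')),
                               (3, 4, (5, 6)), (False, True, (True, True)))
print(grouped)
# (('a', 3, False), ('b', 4, True), (('c', 5, True), ('d', 6, True)))
```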

# Using nest to work with structured input/output with py_func
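`tf.py_func` accepts only a flat list of tensors as input and returns its results as a flat list of tensors, so nest is very useful when we want to process structured input and output with `py_func`: flatten the inputs before the call, pack them back inside the Python function, and pack the flat results into the output structure afterwards.
Below is an example that adds three numbers and returns the original operands along with their sum.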


In [0]:
# define input/output structures
input_structure = {'input1': (None, None), 'input2': {'num3': None}}
output_structure = {'num1': None, 'num2': None, 'num3': None, 'sum': None}


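def adds(*args):
  # py_func passes the inputs as a flat list; pack it back into structured data
  packed_input = nest.pack_sequence_as(input_structure, args)
  num1, num2 = packed_input['input1']
  num3 = packed_input['input2']['num3']

  # return a flat list; nest orders dict entries by sorted key, so this must
  # follow the order of output_structure's keys: num1, num2, num3, sum
  return [num1, num2, num3, num1 + num2 + num3]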


with tf.Graph().as_default():
  inputs = {
      'input1': (tf.constant(1), tf.constant(2)),
      'input2': {
          'num3': tf.constant(3)
      }
  }

  # flatten the structured inputs before passing them to py_func
  flat_result = tf.py_func(adds, nest.flatten(inputs),
                           [tf.int32, tf.int32, tf.int32, tf.int32])

  # pack the result into structured tensors
  nest_result = nest.pack_sequence_as(output_structure, flat_result)

  tf_num1 = nest_result['num1']
  tf_num2 = nest_result['num2']
  tf_num3 = nest_result['num3']
  tf_sum = nest_result['sum']

  with tf.Session() as sess:
    # verify it all works
    np.testing.assert_equal(
        sess.run([tf_num1, tf_num2, tf_num3, tf_sum]), [1, 2, 3, 6])
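
The order of the list returned by `adds` matters: `nest.flatten` and `nest.pack_sequence_as` visit dict entries in sorted key order, so `output_structure` is filled as `num1`, `num2`, `num3`, `sum`. The plain-Python sketch below mimics that contract for a one-level dict without TensorFlow. `call_flat_only` and `add_three` are hypothetical names; `call_flat_only` merely stands in for a flat-list-only API such as `tf.py_func`.

```python
def call_flat_only(func, flat_inputs):
    """Hypothetical stand-in for a flat-list-only API such as tf.py_func.

    Passes the inputs as positional arguments and returns the results as
    a flat list, with no knowledge of their structure.
    """
    return list(func(*flat_inputs))


def add_three(num1, num2, num3):
    # Return values in the sorted-key order of the output dict.
    return [num1, num2, num3, num1 + num2 + num3]


output_keys = sorted(['sum', 'num1', 'num2', 'num3'])
print(output_keys)  # ['num1', 'num2', 'num3', 'sum']

flat_result = call_flat_only(add_three, [1, 2, 3])
# Re-pack a one-level dict by pairing the sorted keys with the flat values.
result = dict(zip(output_keys, flat_result))
print(result)  # {'num1': 1, 'num2': 2, 'num3': 3, 'sum': 6}
```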