In [1]:
import tensorflow as tf
tf.__version__

'2.0.0-alpha0'

# TensorFlow 2: Concrete Functions

A Python function can be decorated with `@tf.function` to become a "TensorFlow function":

In [13]:
@tf.function
def tf_pow(x):
    return x ** 2

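This is similar to just-in-time (JIT) compilation, for example with `numba.jit`. On the first call, TensorFlow traces the Python code into a computational graph, and it reuses that graph on later calls with the same input types.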
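One way to see this JIT-like behavior is to put a Python side effect in the function body: it runs only while TensorFlow traces the function, not on every call. A minimal sketch; the names `traced_square` and `traced_shapes` are illustrative and not part of the original notebook:

```python
import tensorflow as tf

traced_shapes = []  # filled by a Python side effect, which only runs during tracing

@tf.function
def traced_square(x):
    traced_shapes.append(x.shape)  # executes once per trace, not once per call
    return x ** 2

traced_square(tf.constant(2.0))         # first call: traces a graph for scalar float32
traced_square(tf.constant(3.0))         # same dtype and shape: reuses that graph
traced_square(tf.constant([1.0, 2.0]))  # new shape: traces a second graph

print(len(traced_shapes))  # 2
```

Use `tf.print` instead of Python side effects when something must happen on every call.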

A *TensorFlow function* mostly behaves like a normal function, but
with a few subtleties.

First, it returns a tensor, even when called with a plain Python `int`:

In [14]:
tf_pow(32)

<tf.Tensor: id=37, shape=(), dtype=int32, numpy=1024>
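The returned value is an eager `tf.Tensor` whose dtype follows the input; `.numpy()` converts it back to a NumPy value. A short sketch, using a float tensor as input for illustration:

```python
import tensorflow as tf

@tf.function
def tf_pow(x):
    return x ** 2

result = tf_pow(tf.constant(3.0))
print(tf.is_tensor(result))  # True
print(result.dtype)          # <dtype: 'float32'>
print(result.numpy())        # 9.0 (a NumPy float32 scalar)
```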

Sometimes we need to export such a function, for example to run it in production.
To do so, we need to transform it into a concrete function.

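A concrete function is somewhat like a compiled C function: it is specialized to a single
input signature (dtype and shape), and it takes and returns `tf.Tensor` objects. It can be
obtained either from a `tf.TensorSpec` or from an example tensor: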

In [41]:
# From an explicit signature: a scalar (shape []) float32 input.
concrete_func = tf_pow.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))

In [42]:
# Equivalent: from an example tensor with the same dtype and shape.
concrete_func = tf_pow.get_concrete_function(tf.constant(1.0))

In [43]:
concrete_func.graph

<tensorflow.python.framework.func_graph.FuncGraph at 0xb2e0e9048>

In [44]:
concrete_func.graph.get_operations()

[<tf.Operation 'x' type=Placeholder>,
 <tf.Operation 'pow/y' type=Const>,
 <tf.Operation 'pow' type=Pow>,
 <tf.Operation 'Identity' type=Identity>]

In [45]:
concrete_func(tf.constant(3.0))

<tf.Tensor: id=71, shape=(), dtype=float32, numpy=9.0>
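The concrete function above only handles scalar `float32` inputs. A `tf.TensorSpec` with an unknown dimension (`None`) gives a concrete function that accepts any length along that axis, while a different dtype is still rejected. A sketch; the names `square` and `square_vec` are illustrative, and since the exception type raised for a dtype mismatch has varied between TensorFlow versions, several are caught:

```python
import tensorflow as tf

@tf.function
def square(x):
    return x ** 2

# One graph for every 1-D float32 input, whatever its length.
square_vec = square.get_concrete_function(tf.TensorSpec(shape=[None], dtype=tf.float32))

short = square_vec(tf.constant([1.0, 2.0]))
longer = square_vec(tf.constant([1.0, 2.0, 3.0]))

# An int32 tensor does not match the float32 signature, so the call fails.
try:
    square_vec(tf.constant([1, 2, 3]))
    rejected = False
except (TypeError, ValueError, tf.errors.InvalidArgumentError):
    rejected = True

print(short.numpy(), longer.numpy(), rejected)
```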

In [46]:
tf_pow(3)

<tf.Tensor: id=73, shape=(), dtype=int32, numpy=9>

In [49]:
try:
    concrete_func(3.0)
except ValueError:
    print('-> Raises ValueError because input is not a tf.Tensor.')

-> Raises ValueError because input is not a tf.Tensor.
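Since the concrete function only accepts tensors, a Python number has to be converted first, with a dtype that matches the signature the function was traced for. A minimal sketch:

```python
import tensorflow as tf

@tf.function
def tf_pow(x):
    return x ** 2

concrete_func = tf_pow.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))

# Explicit conversion; dtype=tf.float32 matches the traced signature.
x = tf.convert_to_tensor(3.0, dtype=tf.float32)
result = concrete_func(x)
print(result.numpy())  # 9.0
```

Without `dtype=`, `tf.convert_to_tensor(3)` would produce an `int32` tensor, which does not match this signature.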


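A useful property of a concrete function is that its computational graph (`concrete_func.graph`) can be inspected,
which shows how TensorFlow carries out the computation: for `tf_pow`, a `Placeholder` for `x`, a `Const` for the exponent, a `Pow` op and an `Identity` for the output.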
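The graph can be explored further: each `tf.Operation` exposes its name, type and output tensors, and `as_graph_def()` returns the same graph as a serializable `GraphDef` protocol buffer. A sketch, reusing the `tf_pow` function from above:

```python
import tensorflow as tf

@tf.function
def tf_pow(x):
    return x ** 2

concrete_func = tf_pow.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))

# Each tf.Operation in the FuncGraph has a name, a type and output tensors.
for op in concrete_func.graph.get_operations():
    print(op.name, op.type, [t.dtype.name for t in op.outputs])

# The same graph as a GraphDef protocol buffer.
graph_def = concrete_func.graph.as_graph_def()
op_types = [node.op for node in graph_def.node]

# The graph-level input and output tensors of the concrete function.
print(concrete_func.inputs, concrete_func.outputs)
```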

More examples of concrete functions, adapted from the TensorFlow Lite documentation:
- https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r2/convert/concrete_function.md

### Power function (with default power 2).

In [74]:
class PowerModel(tf.Module):

    def __init__(self):
        super().__init__()
        self.const = None  # the variable is created lazily, on the first trace

    @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
    def pow(self, x):
        if self.const is None:
            self.const = tf.Variable(2.)
        return x ** self.const

# Create the tf.Module object.
model = PowerModel()
model

<__main__.PowerModel at 0xb302aee10>
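The `if self.const is None` check in `pow` matters because a `tf.function` may only create variables on its first trace; creating one unconditionally raises a `ValueError`. A sketch contrasting the two patterns; `scale_bad` and `Scaler` are illustrative names:

```python
import tensorflow as tf

@tf.function
def scale_bad(x):
    v = tf.Variable(2.0)  # would create a new variable every time the body is traced
    return x * v

try:
    scale_bad(tf.constant(1.0))
    raised = False
except ValueError:
    raised = True
print(raised)  # True


class Scaler(tf.Module):
    def __init__(self):
        super().__init__()
        self.v = None

    @tf.function
    def __call__(self, x):
        if self.v is None:  # only true on the first, variable-creating trace
            self.v = tf.Variable(2.0)
        return x * self.v


scaler = Scaler()
print(scaler(tf.constant(1.0)).numpy())  # 2.0
```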

Compute the square:

In [71]:
# Get the concrete function.
concrete_func = model.pow.get_concrete_function()

concrete_func(tf.constant(4.0, shape=[1]))

<tf.Tensor: id=671, shape=(1,), dtype=float32, numpy=array([16.], dtype=float32)>

Compute the cube:

In [75]:
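# Update the variable captured by concrete_func in place; rebinding
# model.const to a new tf.Variable would not change the traced graph.
model.const.assign(3.)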
concrete_func(tf.constant(4.0, shape=[1]))

<tf.Tensor: id=683, shape=(1,), dtype=float32, numpy=array([64.], dtype=float32)>
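Because the concrete function captures the `tf.Variable` it saw while tracing, changing the exponent afterwards means updating that variable in place with `assign()`. A self-contained sketch of the square-then-cube sequence, based on the `PowerModel` class above (with `super().__init__()` added, as `tf.Module` subclasses normally do):

```python
import tensorflow as tf

class PowerModel(tf.Module):
    def __init__(self):
        super().__init__()
        self.const = None

    @tf.function(input_signature=[tf.TensorSpec(shape=[1], dtype=tf.float32)])
    def pow(self, x):
        if self.const is None:  # create the exponent variable on the first trace only
            self.const = tf.Variable(2.)
        return x ** self.const

model = PowerModel()
concrete_func = model.pow.get_concrete_function()

square = concrete_func(tf.constant([4.0]))  # exponent 2.0 -> [16.]

# assign() mutates the captured variable, so the same graph now computes the cube.
model.const.assign(3.)
cube = concrete_func(tf.constant([4.0]))    # exponent 3.0 -> [64.]

print(square.numpy(), cube.numpy())
print(len(model.variables))  # 1: tf.Module tracks the lazily created variable
```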

An alternative definition of the same idea, wrapping a lambda with `tf.function` (here with the exponent set to 3):

In [82]:
# Create the tf.Module object.
model = tf.Module()
model.const = tf.Variable(3.)
model.pow = tf.function(lambda x: x ** model.const)

# Get the concrete function.
input_data = tf.TensorSpec(shape=[1], dtype=tf.float32)
concrete_func = model.pow.get_concrete_function(input_data)

In [83]:
concrete_func(tf.constant(4.0, shape=[1]))

<tf.Tensor: id=771, shape=(1,), dtype=float32, numpy=array([64.], dtype=float32)>
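Exporting for production, the motivation for concrete functions given earlier, can be done with `tf.saved_model.save`, passing the concrete function as a serving signature. A sketch under these assumptions: the module owns the exponent variable, the lambda is kept in a separate variable `pow_fn` (an illustrative name), and the model is written to a temporary directory:

```python
import tempfile

import tensorflow as tf

# A tf.Module that owns the exponent variable, as above.
model = tf.Module()
model.const = tf.Variable(3.)
pow_fn = tf.function(lambda x: x ** model.const)

# name="x" fixes the input name used by the exported signature.
concrete_func = pow_fn.get_concrete_function(
    tf.TensorSpec(shape=[1], dtype=tf.float32, name="x"))

# Save the module; the concrete function becomes its default serving signature.
export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir, signatures=concrete_func)

# Reload and call the signature; signatures take keyword arguments.
loaded = tf.saved_model.load(export_dir)
serving_fn = loaded.signatures["serving_default"]
outputs = serving_fn(x=tf.constant([4.0]))  # a dict of output tensors
print(outputs)
```

Signatures return a dictionary even for a single output, which keeps the calling convention uniform for serving systems.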

### References

- https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/r2/convert/concrete_function.md
- https://github.com/ageron/tf2_course/issues/8#issuecomment-464645688