In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [3]:
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
    """Context manager asserting that the wrapped block raises ``error_class``.

    Args:
        error_class: Exception class (or tuple of classes) that the body of
            the ``with`` statement is expected to raise.

    Raises:
        Exception: If the body finishes without raising anything.
        Any other exception type raised by the body is reported on stdout and
        then re-raised unchanged, so an unrelated failure is never mistaken
        for the expected one.
    """
    try:
        yield
    except error_class as e:
        # The demonstration succeeded: show the error and suppress it.
        print('Caught expected exception \n  {}: {}'.format(error_class, e))
    except Exception as e:
        # A different error means the demonstration is broken; report it but
        # do not swallow it, otherwise the cell would appear to succeed.
        print('Got unexpected exception \n  {}: {}'.format(type(e), e))
        raise
    else:
        raise Exception('Expected {} to be raised but no error was raised!'.format(
            error_class))

In [5]:
@tf.function
def add(a, b):
    """Return ``a + b``; the ``tf.function`` decorator wraps the computation."""
    total = a + b
    return total

# Call the decorated function like a regular Python function.
add(tf.ones([2, 2]), tf.ones([2, 2]))

<tf.Tensor: id=24, shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>

In [10]:
# Functions have gradients

@tf.function
def add(a, b):
    """Return ``a + b`` (redefined here so this cell is self-contained)."""
    summed = a + b
    return summed

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
    # The call to the decorated function is recorded by the tape.
    result = add(v, 1.0)
# Gradient of `result` with respect to `v`.
tape.gradient(target=result, sources=v)

<tf.Tensor: id=107, shape=(), dtype=float32, numpy=1.0>

In [11]:
@tf.function
def dense_layer(x, w, b):
    """Return ``add(tf.matmul(x, w), b)``.

    Shows one ``tf.function`` (this one) calling another (``add``).
    """
    projected = tf.matmul(x, w)
    return add(projected, b)

dense_layer(
    tf.ones([3, 2]),
    tf.ones([2, 2]),
    tf.ones([2]),
)

<tf.Tensor: id=133, shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

In [12]:
# Functions are polymorphic

@tf.function
def double(a):
    """Return ``a + a``, announcing on stdout each time the Python body runs."""
    # Plain Python print (not tf.print): it fires when this body is executed
    # as Python code, which makes each new trace visible.
    print("Tracing with", a)
    doubled = a + a
    return doubled

# One call per input dtype; each result is followed by an empty line.
print(double(tf.constant(1)), end="\n\n")
print(double(tf.constant(1.1)), end="\n\n")
print(double(tf.constant("a")), end="\n\n")

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)



In [None]:
# Request the concrete function of `double` for a string TensorSpec whose
# shape is left unspecified (shape=None), and keep it for direct calls.
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(
    tf.TensorSpec(shape=None, dtype=tf.string))

In [14]:
# Obtain the concrete function of `double` for a string TensorSpec whose shape
# is left unspecified (shape=None).
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
# The concrete function is called once positionally and once with keyword `a`.
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
# Passing an int32 tensor to the string-typed concrete function is expected to
# raise InvalidArgumentError, which assert_raises reports.
with assert_raises(tf.errors.InvalidArgumentError):
    double_strings(tf.constant(1))

Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
Using a concrete trace with incompatible types will throw an error
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute __inference_double_160 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_160]


In [15]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
    """Return ``x // 2`` where ``x`` is even and ``3 * x + 1`` elsewhere.

    The ``input_signature`` on the decorator declares a single int32 argument
    with shape ``[None]``.
    """
    print("Tracing with", x)
    is_even = x % 2 == 0
    halved = x // 2
    tripled_plus_one = 3 * x + 1
    return tf.where(is_even, halved, tripled_plus_one)

# A rank-1 int32 tensor matches the declared signature.
print(next_collatz(tf.constant([1, 2])))
# A rank-2 tensor does not match it; a ValueError is expected.
with assert_raises(ValueError):
    next_collatz(tf.constant([[1, 2], [3, 4]]))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))


In [16]:
# Module-level record of every value handed to `side_effect`.
external_list = list()


def side_effect(x):
    """Print a marker line and append ``x`` to ``external_list``."""
    print('Python side effect')
    external_list.append(x)
    
@tf.function
def f(x):
    """Invoke ``side_effect`` on ``x`` through ``tf.py_function``.

    ``Tout=[]`` declares that the wrapped Python function yields no outputs.
    """
    tf.py_function(func=side_effect, inp=[x], Tout=[])

# Three separate calls ...
f(1)
f(1)
f(1)

# ... must leave three recorded values, the first of which equals 1.
assert len(external_list) == 3
assert external_list[0].numpy() == 1

Python side effect
Python side effect
Python side effect


In [19]:
external_val = tf.Variable(0)


@tf.function
def buggy_consume_next(iterator):
    """Deliberately buggy demo: pulls from a Python iterator inside a tf.function.

    ``next(iterator)`` is ordinary Python code. The recorded output of this
    cell shows the same value (0) for every call: the value taken from the
    iterator during the first call is reused instead of a fresh item being
    consumed on each call. That behaviour is the point of the demonstration
    and is intentionally left in place.

    Args:
        iterator: A Python iterator yielding values to add to ``external_val``.
    """
    external_val.assign_add(next(iterator))
    # Print the variable that was actually updated. The previous code referred
    # to an undefined name `external_var`, which fails on a fresh kernel.
    # The label text is kept as-is so it still matches the recorded output.
    tf.print("Value of external_var:", external_val)


iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# The following calls reuse the first value from the iterator rather than
# consuming the next one.
buggy_consume_next(iterator)
buggy_consume_next(iterator)

Value of external_var: 0
Value of external_var: 0
Value of external_var: 0


In [20]:
def measure_graph_size(f, *args):
    """Print the number of nodes in the graph that ``f`` builds for ``args``.

    Args:
        f: Object providing ``get_concrete_function(*args)`` and ``__name__``.
        *args: Arguments forwarded to ``get_concrete_function``; their ``str``
            forms are also shown in the printed line.
    """
    graph = f.get_concrete_function(*args).graph
    node_count = len(graph.as_graph_def().node)
    rendered_args = ', '.join(str(arg) for arg in args)
    message = "{}({}) contains {} nodes in its graph".format(
        f.__name__, rendered_args, node_count)
    print(message)
          
@tf.function
def train(dataset):
    """Accumulate ``tf.abs(y - x)`` over every ``(x, y)`` pair in ``dataset``.

    Args:
        dataset: Iterable of ``(x, y)`` pairs. This notebook passes both
            Python lists of tuples and ``tf.data.Dataset`` objects.

    Returns:
        The accumulated loss, starting from ``tf.constant(0)``.
    """
    loss = tf.constant(0)
    # The recorded graph sizes differ by input kind: the node count grows with
    # the length of a Python list but stays the same for the two datasets.
    for x, y in dataset:
        loss += tf.abs(y - x)
    return loss

# Two Python lists holding the same (1, 1) pair, 2 and 10 times respectively.
small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10
# Measure the graph built when `train` receives the Python lists directly.
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

# Measure again with the same data wrapped in a tf.data.Dataset that yields
# (int32, int32) pairs from a generator.
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))

train([(1, 1), (1, 1)]) contains 8 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 5 nodes in its graph
train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 5 nodes in its graph


In [23]:
@tf.function
def f(x):
    """Apply ``tf.tanh`` to ``x`` repeatedly while ``tf.reduce_sum(x) > 1``.

    The current value of ``x`` is printed with ``tf.print`` before each
    ``tanh`` step.

    Returns:
        The value of ``x`` once the loop condition no longer holds.
    """
    # The loop condition is a tensor expression, not a Python bool.
    while tf.reduce_sum(x) > 1:
        tf.print(x)
        x = tf.tanh(x)
    return x

# Random input of 5 values; no seed is set in this cell, so the printed
# numbers presumably differ between runs.
f(tf.random.uniform([5]))

[0.860282779 0.0538820028 0.673981428 0.715483546 0.0757763386]
[0.696403384 0.0538299121 0.587592781 0.614104 0.075631626]
[0.602079928 0.0537779741 0.528162122 0.547009587 0.0754877329]
[0.538527966 0.0537261814 0.483974934 0.498275608 0.0753446594]
[0.491872907 0.0536745377 0.449421316 0.460759938 0.0752024055]
[0.455701709 0.0536230505 0.421423167 0.430703372 0.0750609487]
[0.426574469 0.0535717048 0.398128718 0.405908972 0.0749202892]
[0.402454555 0.0535205081 0.378346711 0.384993494 0.0747804195]
[0.382047206 0.0534694567 0.361270815 0.367036164 0.0746413246]
[0.364484 0.0534185581 0.346332908 0.351396561 0.0745030046]
[0.349157542 0.0533677973 0.333119363 0.337613493 0.0743654519]
[0.335628182 0.0533171892 0.321320832 0.325345129 0.0742286593]
[0.323568821 0.053266719 0.310700715 0.314331979 0.0740926191]
[0.312730283 0.0532163903 0.301074386 0.304372907 0.0739573315]
[0.302918881 0.0531662069 0.292295486 0.295309305 0.073822774]


<tf.Tensor: id=451, shape=(5,), dtype=float32, numpy=
array([0.2939815 , 0.05311617, 0.28424618, 0.28701413, 0.07368895],
      dtype=float32)>

In [24]:
# Undecorated copy of the tanh loop: applies tf.tanh to x while
# tf.reduce_sum(x) > 1, printing x before each step. It is left as plain
# Python so that its converted source can be inspected.
def f(x):
    while tf.reduce_sum(x) > 1:
        tf.print(x)
        x = tf.tanh(x)
    return x

# Show the code that AutoGraph generates for `f`.
print(tf.autograph.to_code(f))

def tf__f(x):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  with ag__.FunctionScope('f', 'f_scope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as f_scope:

    def get_state():
      return ()

    def set_state(_):
      pass

    def loop_body(x):
      ag__.converted_call(tf.print, f_scope.callopts, (x,), None, f_scope)
      x = ag__.converted_call(tf.tanh, f_scope.callopts, (x,), None, f_scope)
      return x,

    def loop_test(x):
      return ag__.converted_call(tf.reduce_sum, f_scope.callopts, (x,), None, f_scope) > 1
    x, = ag__.while_stmt(loop_test, loop_body, get_state, set_state, (x,), ('x',), ())
    do_return = True
    retval_ = f_scope.mark_return_value(x)
  do_return,
  return ag__.retval(retval_)



In [25]:
def test_tf_cond(f, *args):
    g = f.get_concrete_function(*args).graph
    if any(node.name == 'cond' for node in g.as_graph_def().node):
        print("{}({}) uses tf.cond.".format(
            f.__name__, ', '.join(map(str, args))))
    else:
        print("{}({}) executes normally.".format(
            f.__name__, ', '.join(map(str, args))))

In [27]:
@tf.function
def hyperparam_cond(x, training=True):
    """Return ``x`` with dropout (rate 0.5) applied when ``training`` is truthy.

    Args:
        x: Input passed to ``tf.nn.dropout``.
        training: Flag choosing whether dropout is applied; defaults to True.

    Returns:
        The dropout result when ``training`` is truthy, otherwise ``x``.
    """
    # In this notebook `training` is only ever the Python default, never a
    # tensor.
    if training:
        x = tf.nn.dropout(x, rate=0.5)
    return x

@tf.function
def maybe_tensor_cond(x):
    """Branch on ``x < 0``; ``x`` may be a tensor or a plain Python number.

    Returns:
        ``x - x`` when ``x < 0``, otherwise ``x`` unchanged.
    """
    if x < 0:
        # NOTE(review): `x -= x` always leaves zero. If negation was intended
        # this should presumably be `x = -x` - confirm with the author.
        x -= x
    return x

# `training` keeps its Python default here, so the condition is not a tensor.
test_tf_cond(hyperparam_cond, tf.ones([1], dtype=tf.float32))
# Same function, once with a tensor argument and once with a Python int.
test_tf_cond(maybe_tensor_cond, tf.constant(-1))
test_tf_cond(maybe_tensor_cond, -1)

hyperparam_cond(tf.Tensor([1.], shape=(1,), dtype=float32)) executes normally.
maybe_tensor_cond(tf.Tensor(-1, shape=(), dtype=int32)) uses tf.cond.
maybe_tensor_cond(-1) executes normally.
