In [2]:
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
import numpy as np

In [3]:
@tf.function
def tf_cube(x):
    """Return ``x`` cubed; ``tf.function`` traces each new argument kind into a graph."""
    # The parameter name `x` and the `**` operator determine the op names
    # ('x' placeholder, 'pow', 'pow/y') that later cells inspect, so keep them as-is.
    return x**3

# Python int argument: traced with the value baked into the graph as a constant.
tf_cube(2)
# Tensor argument: traced against a float32 scalar input; only this last result is displayed.
tf_cube(tf.constant(2.0))

<tf.Tensor: shape=(), dtype=float32, numpy=8.0>

In [17]:
# A Python int argument is baked into the graph, so this concrete function takes no inputs.
int_cube_fn = tf_cube.get_concrete_function(9)
print(int_cube_fn())

# A tensor argument yields a concrete function with a float32 scalar input,
# which is then called with a different float32 scalar tensor.
float_cube_fn = tf_cube.get_concrete_function(tf.constant(3.0))
print(float_cube_fn(tf.constant(1.0)))

tf.Tensor(729, shape=(), dtype=int32)
tf.Tensor(1.0, shape=(), dtype=float32)


In [21]:
# Grab the float32-scalar trace of tf_cube (reused by the next few cells) and display its FuncGraph.
concrete_function = tf_cube.get_concrete_function(tf.constant(3.0))
concrete_function.graph

<tensorflow.python.framework.func_graph.FuncGraph at 0x7fe7343551d0>

In [22]:
# List every operation recorded in the traced graph of tf_cube.
ops = concrete_function.graph.get_operations()
ops

[<tf.Operation 'x' type=Placeholder>,
 <tf.Operation 'pow/y' type=Const>,
 <tf.Operation 'pow' type=Pow>,
 <tf.Operation 'Identity' type=Identity>]

In [30]:
# ops[2] is the 'pow' operation in the listing above; its inputs are the 'x'
# placeholder and the constant exponent 'pow/y'.
# NOTE(review): the positional index is tied to that listing order;
# concrete_function.graph.get_operation_by_name('pow') would not be.
pow_op = ops[2]
list(pow_op.inputs)

[<tf.Tensor 'x:0' shape=() dtype=float32>,
 <tf.Tensor 'pow/y:0' shape=() dtype=float32>]

In [32]:
# Look an operation up by name instead of by position; 'x' is the input placeholder.
concrete_function.graph.get_operation_by_name('x')

<tf.Operation 'x' type=Placeholder>

In [35]:
# The FunctionDef signature: the generated function name plus its typed inputs and outputs.
concrete_function.function_def.signature

name: "__inference_tf_cube_13"
input_arg {
  name: "x"
  type: DT_FLOAT
}
output_arg {
  name: "identity"
  type: DT_FLOAT
}

In [39]:
# Redefines tf_cube (shadowing the earlier definition; later cells use this one)
# to show when Python side effects such as print() actually run.
@tf.function
def tf_cube(x):
    """Return ``x`` cubed, printing the argument whenever a new trace is made.

    ``print`` is a Python side effect, so it runs while the function is being
    traced. For a tensor input, ``x`` is then a symbolic Tensor rather than a
    value, as the output below shows. Use ``tf.print`` to print on every call.
    """
    # print() already inserts one space between arguments, so the label must not
    # end with a space of its own (otherwise the output reads "x =  ...").
    print("x =", x)
    return x ** 3

result = tf_cube(tf.constant(2.0))

x =  Tensor("x:0", shape=(), dtype=float32)


In [45]:
# A Python int is a new kind of argument for the redefined tf_cube, so this call
# should create a fresh trace and therefore run the trace-time print.
# NOTE(review): the saved output is empty. Presumably this cell had already been run
# before (execution count is 45 vs. 39 for the definition); confirm on a fresh kernel.
result = tf_cube(2)

In [51]:
def test_tracing(inputs):
    """Identity function that prints whatever it receives, to observe how Keras calls it."""
    print(inputs)
    return inputs

# Wrap the function in a Lambda layer inside a one-layer Sequential model and call
# the model directly; the saved output shows an eager tf.Tensor with a concrete value.
lambda_layer = keras.layers.Lambda(function=test_tracing)
test_model = keras.models.Sequential(layers=[lambda_layer])
test_model(tf.constant(2.0))

tf.Tensor(2.0, shape=(), dtype=float32)


True

In [55]:
@tf.function(input_signature=[tf.TensorSpec([None, 28, 28], tf.float32)])
def shrink(images):
    """Halve each spatial dimension of a batch of 28x28 images (keep every other row/column).

    The input_signature fixes a float32 [None, 28, 28] input, leaving the batch
    dimension open so batches of different sizes match the same declared signature.
    """
    return images[:, ::2, ::2]


# Two batches that differ only in batch size; both match the declared signature.
img_batch_1 = tf.random.uniform(shape=[100, 28, 28])
img_batch_2 = tf.random.uniform(shape=[50, 28, 28])
for img_batch in (img_batch_1, img_batch_2):
    print(shrink(img_batch).shape)

(100, 14, 14)
(50, 14, 14)


In [59]:
@tf.function
def add_10(x):
    """Add 1 to ``x`` ten times using a plain Python loop.

    A Python ``range`` loop runs during tracing, so it is unrolled: the graph
    gets a separate Const/AddV2 pair for every iteration.
    """
    for _ in range(10):
        x = x + 1
    return x

# Inspect the graph traced for an int32 scalar tensor input.
add_10.get_concrete_function(tf.constant(0)).graph.get_operations()

[<tf.Operation 'x' type=Placeholder>,
 <tf.Operation 'add/y' type=Const>,
 <tf.Operation 'add' type=AddV2>,
 <tf.Operation 'add_1/y' type=Const>,
 <tf.Operation 'add_1' type=AddV2>,
 <tf.Operation 'add_2/y' type=Const>,
 <tf.Operation 'add_2' type=AddV2>,
 <tf.Operation 'add_3/y' type=Const>,
 <tf.Operation 'add_3' type=AddV2>,
 <tf.Operation 'add_4/y' type=Const>,
 <tf.Operation 'add_4' type=AddV2>,
 <tf.Operation 'add_5/y' type=Const>,
 <tf.Operation 'add_5' type=AddV2>,
 <tf.Operation 'add_6/y' type=Const>,
 <tf.Operation 'add_6' type=AddV2>,
 <tf.Operation 'add_7/y' type=Const>,
 <tf.Operation 'add_7' type=AddV2>,
 <tf.Operation 'add_8/y' type=Const>,
 <tf.Operation 'add_8' type=AddV2>,
 <tf.Operation 'add_9/y' type=Const>,
 <tf.Operation 'add_9' type=AddV2>,
 <tf.Operation 'Identity' type=Identity>]

In [60]:
# Redefines add_10 (shadowing the unrolled version above) to compare against a tf.range loop.
@tf.function
def add_10(x):
    """Add 1 to ``x`` ten times using a ``tf.range`` loop.

    Iterating over a tensor is converted into a single graph-level loop (the
    'while' StatelessWhile op in the listing below) instead of being unrolled.
    The Range/Sub/FloorDiv/... ops around it appear to compute the trip count.
    """
    for i in tf.range(10):
        x += 1
    return x

add_10.get_concrete_function(tf.constant(0)).graph.get_operations()

[<tf.Operation 'x' type=Placeholder>,
 <tf.Operation 'range/start' type=Const>,
 <tf.Operation 'range/limit' type=Const>,
 <tf.Operation 'range/delta' type=Const>,
 <tf.Operation 'range' type=Range>,
 <tf.Operation 'sub' type=Sub>,
 <tf.Operation 'floordiv' type=FloorDiv>,
 <tf.Operation 'mod' type=FloorMod>,
 <tf.Operation 'zeros_like' type=Const>,
 <tf.Operation 'NotEqual' type=NotEqual>,
 <tf.Operation 'Cast' type=Cast>,
 <tf.Operation 'add' type=AddV2>,
 <tf.Operation 'zeros_like_1' type=Const>,
 <tf.Operation 'Maximum' type=Maximum>,
 <tf.Operation 'while/loop_counter' type=Const>,
 <tf.Operation 'while' type=StatelessWhile>,
 <tf.Operation 'Identity' type=Identity>]

In [64]:
# A stateful int32 counter that lives outside the traced function.
counter = tf.Variable(initial_value=0)

@tf.function
def increment(counter, c=1):
    """Add ``c`` to the variable ``counter`` in place and return its updated value.

    The parameter intentionally shares the global's name; it is also the input
    name that appears in the traced function's signature.
    """
    return counter.assign_add(delta=c)

# Both calls mutate the same variable, so the second (displayed) result is 2.
increment(counter)
increment(counter)


<tf.Tensor: shape=(), dtype=int32, numpy=2>

In [66]:
# Signature of the graph traced for a variable argument: the variable arrives as a
# DT_RESOURCE handle, and the in-place add is listed as a stateful control output.
function_def = increment.get_concrete_function(counter).function_def
function_def.signature

name: "__inference_increment_303"
input_arg {
  name: "counter"
  type: DT_RESOURCE
}
output_arg {
  name: "identity"
  type: DT_INT32
}
is_stateful: true
control_output: "AssignAddVariableOp"