## XLA intermediate files

In [1]:
%%writefile mini_net.py

import argparse
import numpy as np
import tensorflow as tf
from tensorflow import keras

NUM_SAMPLES = 1000
DIM = 32
CLASSES = 5
BATCH_SIZE = 128


def get_model():
    inputs = keras.Input(shape=(DIM,))
    x = keras.layers.Dense(128, activation="relu")(inputs)
    x = keras.layers.Dense(64, activation="relu")(x)
    outputs = keras.layers.Dense(CLASSES, activation="softmax")(x)
    return keras.Model(inputs, outputs)


def generate_dataset():
    x = np.random.randn(NUM_SAMPLES, DIM)
    y = np.random.randint(0, CLASSES, size=(NUM_SAMPLES,))
    dataset = tf.data.Dataset.from_tensor_slices((x, y))
    return dataset.batch(BATCH_SIZE)


def train_model(args):
    model = get_model()
    if not args.jit_compile:
        jit_compile = None
    else:
        jit_compile = True
    model.compile(loss="sparse_categorical_crossentropy", jit_compile=jit_compile)

    dataset = generate_dataset()
    model.fit(dataset, epochs=5)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--jit_compile", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    train_model(args)

Overwriting mini_net.py


In [2]:
# From https://www.tensorflow.org/xla#reproducible_bug_reports
!TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    python mini_net.py

2023-03-12 15:00:29.418558: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-12 15:00:30.369614: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia
2023-03-12 15:00:30.369713: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia
2023-03-12 15:00:33.350366: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Over

In [3]:
!TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    python mini_net.py --jit_compile

2023-03-12 15:00:42.415271: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-12 15:00:43.875446: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia
2023-03-12 15:00:43.875543: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia
2023-03-12 15:00:46.740244: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Over

In [4]:
!ls /tmp/generated

before_increase_dynamism_for_auto_jit_pass_10.pbtxt
before_increase_dynamism_for_auto_jit_pass_11.pbtxt
before_increase_dynamism_for_auto_jit_pass_12.pbtxt
before_increase_dynamism_for_auto_jit_pass_13.pbtxt
before_increase_dynamism_for_auto_jit_pass_14.pbtxt
before_increase_dynamism_for_auto_jit_pass_15.pbtxt
before_increase_dynamism_for_auto_jit_pass_16.pbtxt
before_increase_dynamism_for_auto_jit_pass_17.pbtxt
before_increase_dynamism_for_auto_jit_pass_18.pbtxt
before_increase_dynamism_for_auto_jit_pass_19.pbtxt
before_increase_dynamism_for_auto_jit_pass_1.pbtxt
before_increase_dynamism_for_auto_jit_pass_20.pbtxt
before_increase_dynamism_for_auto_jit_pass_21.pbtxt
before_increase_dynamism_for_auto_jit_pass_22.pbtxt
before_increase_dynamism_for_auto_jit_pass_23.pbtxt
before_increase_dynamism_for_auto_jit_pass_24.pbtxt
before_increase_dynamism_for_auto_jit_pass_25.pbtxt
before_increase_dynamism_for_auto_jit_pass_26.pbtxt
before_increase_dynamism_for_auto_jit_pass_27.pbtxt
before_increa

In [5]:
!ls /tmp/generated/*.txt

/tmp/generated/module_0031.a_inference_run_step_633__.413.before_optimizations.txt
/tmp/generated/module_0031.a_inference_run_step_633__.413.sm_7.5_gpu_after_optimizations-buffer-assignment.txt
/tmp/generated/module_0031.a_inference_run_step_633__.413.sm_7.5_gpu_after_optimizations.txt
/tmp/generated/module_0031.cluster_1__XlaCompiledKernel_true__XlaHasReferenceVars_false__XlaNumConstantArgs_0__XlaNumResourceArgs_7_.188.before_optimizations.txt
/tmp/generated/module_0031.cluster_1__XlaCompiledKernel_true__XlaHasReferenceVars_false__XlaNumConstantArgs_0__XlaNumResourceArgs_7_.188.sm_7.5_gpu_after_optimizations-buffer-assignment.txt
/tmp/generated/module_0031.cluster_1__XlaCompiledKernel_true__XlaHasReferenceVars_false__XlaNumConstantArgs_0__XlaNumResourceArgs_7_.188.sm_7.5_gpu_after_optimizations.txt
/tmp/generated/module_0045.a_inference__update_step_xla_376__.31.before_optimizations.txt
/tmp/generated/module_0045.a_inference__update_step_xla_376__.31.sm_7.5_gpu_after_optimizations-buf

The operation semantics are explained in [this guide](https://www.tensorflow.org/xla/operation_semantics). 

In [6]:
!cat /tmp/generated/module_0049.a_inference__update_step_xla_376__.31.sm_7.5_gpu_after_optimizations.txt

HloModule a_inference__update_step_xla_376__.31, input_output_alias={ {0}: (1, {}, may-alias), {1}: (2, {}, may-alias) }, alias_passthrough_params=true, entry_computation_layout={(f32[128,64]{1,0},f32[128,64]{1,0},f32[128,64]{1,0},f32[])->(f32[128,64]{1,0}, f32[128,64]{1,0})}

fused_computation {
  param_0 = f32[128,64]{1,0} parameter(0)
  param_2.10 = f32[] parameter(2)
  broadcast.1 = f32[128,64]{1,0} broadcast(param_2.10), dimensions={}, metadata={op_type="Mul" op_name="mul_2" source_file="/usr/local/lib/python3.9/dist-packages/keras/optimizers/optimizer_experimental/rmsprop.py" source_line=191}
  param_1.9 = f32[128,64]{1,0} parameter(1)
  multiply.1 = f32[128,64]{1,0} multiply(broadcast.1, param_1.9), metadata={op_type="Mul" op_name="mul_2" source_file="/usr/local/lib/python3.9/dist-packages/keras/optimizers/optimizer_experimental/rmsprop.py" source_line=191}
  constant_2_clone_1 = f32[] constant(0.9), metadata={op_type="Mul" op_name="mul" source_file="/usr/local/lib/python3.9/dis

## Interacting with `tf.function`

Taken largely from [this guide](https://www.tensorflow.org/guide/function).

In [1]:
import tensorflow as tf
import timeit

In [2]:
def add(a, b):
    return a + b


def dense_layer(x, w, b):
    return add(tf.matmul(x, w), b)


@tf.function
def func_dense_layer(x, w, b):
    return dense_layer(x, w, b)


@tf.function(jit_compile=True)
def jit_dense_layer(x, w, b):
    return dense_layer(x, w, b)

In [3]:
print(
    "Eager dense:",
    timeit.timeit(
        lambda: dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2])), number=10
    ),
)
print(
    "Function dense:",
    timeit.timeit(
        lambda: func_dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2])),
        number=10,
    ),
)

# Warm-up
_ = jit_dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
print(
    "JIT dense:",
    timeit.timeit(
        lambda: jit_dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2])),
        number=10,
    ),
)

Eager dense: 2.2915861180008505
Function dense: 0.1730724110002484
JIT dense: 0.01554165000015928


### Actual models

In [4]:
mobilenet = tf.keras.applications.MobileNetV3Large()
random_inputs = tf.random.normal((10, 224, 224, 3))

# Warm-up
_ = mobilenet.predict(random_inputs)
print(
    "Eager:",
    timeit.timeit(lambda: mobilenet.predict(random_inputs, verbose=0), number=10),
)



Eager: 1.0998468459983997


In [5]:
tf.keras.backend.clear_session()

mobilenet = tf.keras.applications.MobileNetV3Large()
mobilenet.compile(jit_compile=True)

# Warm-up
_ = mobilenet.predict(random_inputs)
print("XLA:", timeit.timeit(
    lambda: mobilenet.predict(random_inputs, verbose=0),
    number=10)
)



XLA: 0.6547937049999746


Or, 

In [6]:
tf.keras.backend.clear_session()

mobilenet = tf.keras.applications.MobileNetV3Large()
model_call_fn = tf.function(mobilenet, jit_compile=True)

# Warm-up
_ = model_call_fn(random_inputs, training=False)
print(
    "XLA:",
    timeit.timeit(lambda: model_call_fn(random_inputs, training=False), number=10),
)



XLA: 0.051761404998615035
