##### Copyright 2021 The IREE Authors

In [1]:
#@title Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Variables and State

This notebook

1. Creates a TensorFlow program with basic tf.Variable use
2. Imports that program into IREE's compiler
3. Compiles the imported program to an IREE VM bytecode module
4. Tests running the compiled VM module using IREE's runtime
5. Downloads compilation artifacts for use with the native (C API) sample application

In [2]:
#@title General setup

import os
import tempfile

ARTIFACTS_DIR = os.path.join(tempfile.gettempdir(), "iree", "colab_artifacts")
os.makedirs(ARTIFACTS_DIR, exist_ok=True)
print(f"Using artifacts directory '{ARTIFACTS_DIR}'")

Using artifacts directory '/tmp/iree/colab_artifacts'


In [3]:
%%capture
!python -m pip install --upgrade tf-nightly  # Needed for stablehlo export in TF>=2.14

In [4]:
import tensorflow as tf

# Print version information for future notebook users to reference.
print("TensorFlow version: ", tf.__version__)

TensorFlow version:  2.15.0-dev20230831


## Create a program using TensorFlow and import it into IREE

This program uses `tf.Variable` to track state internal to the program then exports functions which can be used to interact with that variable.

Note that each function we want to be callable from our compiled program needs
to use `@tf.function` with an `input_signature` specified.

References:

* ["Introduction to Variables" Guide](https://www.tensorflow.org/guide/variable)
* [`tf.Variable` reference](https://www.tensorflow.org/api_docs/python/tf/Variable)
* [`tf.function` reference](https://www.tensorflow.org/api_docs/python/tf/function)

In [5]:
#@title Define a simple "counter" TensorFlow module

class CounterModule(tf.Module):
  def __init__(self):
    super().__init__()
    self.counter = tf.Variable(0)

  @tf.function(input_signature=[])
  def get_value(self):
    return self.counter

  @tf.function(input_signature=[tf.TensorSpec([], tf.int32)])
  def set_value(self, new_value):
    self.counter.assign(new_value)

  @tf.function(input_signature=[tf.TensorSpec([], tf.int32)])
  def add_to_value(self, x):
    self.counter.assign(self.counter + x)

  @tf.function(input_signature=[])
  def reset_value(self):
    self.counter.assign(0)

In [6]:
%%capture
!python -m pip install iree-compiler iree-runtime iree-tools-tf -f https://iree.dev/pip-release-links.html

In [7]:
# Print version information for future notebook users to reference.
!iree-compile --version

IREE (https://iree.dev):
  IREE compiler version 20230831.630 @ 9ed3dab7ac4fcda959f5b8ebbcd7732aeb4b0c8d
  LLVM version 18.0.0git
  Optimized build


In [8]:
#@title Import the TensorFlow program into IREE as MLIR

from IPython.display import clear_output

from iree.compiler import tf as tfc

compiler_module = tfc.compile_module(
    CounterModule(), import_only=True,
    output_mlir_debuginfo=False)
clear_output()  # Skip over TensorFlow's output.

# Save the imported MLIR to disk.
imported_mlirbc_path = os.path.join(ARTIFACTS_DIR, "counter.mlirbc")
with open(imported_mlirbc_path, "wb") as output_file:
  output_file.write(compiler_module)
print(f"Wrote MLIR to path '{imported_mlirbc_path}'")

# Copy MLIR bytecode to MLIR text and see how the compiler views this program.
# Note the 'stablehlo' and 'ml_program' ops and the public (exported) functions.
imported_mlir_path = os.path.join(ARTIFACTS_DIR, "counter.mlir")
!iree-ir-tool copy {imported_mlirbc_path} -o {imported_mlir_path}
print("Counter MLIR:")
!cat {imported_mlir_path}

Wrote MLIR to path '/tmp/iree/colab_artifacts/counter.mlirbc'
Counter MLIR:
module {
  ml_program.global public mutable @vars.__sm_node1__counter(dense<0> : tensor<i32>) : tensor<i32>
  func.func @add_to_value(%arg0: tensor<i32>) {
    %0 = ml_program.global_load @vars.__sm_node1__counter : tensor<i32>
    %1 = stablehlo.add %0, %arg0 : tensor<i32>
    ml_program.global_store @vars.__sm_node1__counter = %1 : tensor<i32>
    return
  }
  func.func @get_value() -> tensor<i32> {
    %0 = ml_program.global_load @vars.__sm_node1__counter : tensor<i32>
    return %0 : tensor<i32>
  }
  func.func @reset_value() {
    %0 = stablehlo.constant dense<0> : tensor<i32>
    ml_program.global_store @vars.__sm_node1__counter = %0 : tensor<i32>
    return
  }
  func.func @set_value(%arg0: tensor<i32>) {
    ml_program.global_store @vars.__sm_node1__counter = %arg0 : tensor<i32>
    return
  }
}

## Test the imported program

_Note: you can stop after each step and use intermediate outputs with other tools outside of Colab._

_See the [README](https://github.com/openxla/iree/tree/main/samples/variables_and_state#changing-compilation-options) for more details and example command line instructions._

* _The "imported MLIR" can be used by IREE's generic compiler tools_
* _The "flatbuffer blob" can be saved and used by runtime applications_

_The specific point at which you switch from Python to native tools will depend on your project._

In [9]:
#@title Compile the imported MLIR further into an IREE VM bytecode module

from iree.compiler import compile_str

flatbuffer_blob = compile_str(compiler_module, target_backends=["vmvx"], input_type="stablehlo")

# Save the compiled program to disk.
flatbuffer_path = os.path.join(ARTIFACTS_DIR, "counter_vmvx.vmfb")
with open(flatbuffer_path, "wb") as output_file:
  output_file.write(flatbuffer_blob)
print(f"Wrote .vmfb to path '{flatbuffer_path}'")

Wrote .vmfb to path '/tmp/iree/colab_artifacts/counter_vmvx.vmfb'


In [10]:
#@title Test running the compiled VM module using IREE's runtime

from iree import runtime as ireert

config = ireert.Config("local-task")
ctx = ireert.SystemContext(config=config)
vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, flatbuffer_blob)
ctx.add_vm_module(vm_module)

  vm_module = ireert.VmModule.from_flatbuffer(ctx.instance, flatbuffer_blob)


In [14]:
# Our @tf.functions are accessible by name on the module named 'module'
counter = ctx.modules.module

# These are buggy in Python but should still work from C
# TODO(scotttodd): figure out why and fix

# print(counter.get_value().to_host())
# counter.set_value(101)
# print(counter.get_value().to_host())

# counter.add_to_value(20)
# print(counter.get_value().to_host())
# counter.add_to_value(-50)
# print(counter.get_value().to_host())

# counter.reset_value()
# print(counter.get_value().to_host())

## Download compilation artifacts

In [12]:
ARTIFACTS_ZIP = "/tmp/variables_and_state_colab_artifacts.zip"

print(f"Zipping '{ARTIFACTS_DIR}' to '{ARTIFACTS_ZIP}' for download...")
!cd {ARTIFACTS_DIR} && zip -r {ARTIFACTS_ZIP} .

# Note: you can also download files using Colab's file explorer
try:
  from google.colab import files
  print("Downloading the artifacts zip file...")
  files.download(ARTIFACTS_ZIP)
except ImportError:
  print("Missing google_colab Python package, can't download files")

Zipping '/tmp/iree/colab_artifacts' to '/tmp/variables_and_state_colab_artifacts.zip' for download...
  adding: counter.mlir (deflated 71%)
  adding: counter_vmvx.vmfb (deflated 65%)
  adding: counter.mlirbc (deflated 29%)
Downloading the artifacts zip file...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>