Notebook for generating StableHLO mlir file for simple function. Meant to be run in a colab environment and jax needs to be reinstalled to have the latest version

In [1]:
!pip uninstall jax

Found existing installation: jax 0.4.26
Uninstalling jax-0.4.26:
  Would remove:
    /usr/local/lib/python3.10/dist-packages/jax-0.4.26.dist-info/*
    /usr/local/lib/python3.10/dist-packages/jax/*
Proceed (Y/n)? Y
  Successfully uninstalled jax-0.4.26


In [2]:
!pip install jax

Collecting jax
  Downloading jax-0.4.30-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jaxlib<=0.4.30,>=0.4.27 (from jax)
  Downloading jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl (79.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.26+cuda12.cudnn89
    Uninstalling jaxlib-0.4.26+cuda12.cudnn89:
      Successfully uninstalled jaxlib-0.4.26+cuda12.cudnn89
Successfully installed jax-0.4.30 jaxlib-0.4.30


In [3]:
import jax

In [4]:
def f(x, y): return 2 * x + y + 2

In [5]:
x, y = 3, 4

JAX has functionality for lowering functions and then running the compiled code. If you want to use the lowered function output later however, use the export module because that is guaranteed to have a serialized version (https://jax.readthedocs.io/en/latest/aot.html, https://jax.readthedocs.io/en/latest/export/export.html#support-for-reverse-mode-ad). Examples of both below

In [6]:
lowered = jax.jit(f).lower(x, y)

In [7]:
print(lowered.as_text())

module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    %c_0 = stablehlo.constant dense<2> : tensor<i32>
    %2 = stablehlo.add %1, %c_0 : tensor<i32>
    return %2 : tensor<i32>
  }
}



In [8]:
import re
import numpy as np
import jax
from jax import export

In [9]:
def f(x): return 2 * x * x

In [10]:
exp = export.export(jax.jit(f))(
   jax.ShapeDtypeStruct((), np.float32))


In [11]:
exp.mlir_module()

'#loc1 = loc("x")\nmodule @jit_f attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"} loc("x")) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {\n    %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32> loc(#loc)\n    %0 = stablehlo.multiply %cst, %arg0 : tensor<f32> loc(#loc31)\n    %1 = stablehlo.multiply %0, %arg0 : tensor<f32> loc(#loc31)\n    return %1 : tensor<f32> loc(#loc)\n  } loc(#loc)\n} loc(#loc)\n#loc = loc(unknown)\n#loc2 = loc("<ipython-input-9-018d23f7472d>":1:0)\n#loc3 = loc("<ipython-input-10-c708c2053347>":1:0)\n#loc4 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3553:0)\n#loc5 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3473:0)\n#loc6 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3257:0)\n#loc7 = loc