Notebook for generating the MLIR (StableHLO) for a forward pass of ResNet-18.

In [1]:
!pip install -U jax jaxlib flax transformers

Collecting jax
  Downloading jax-0.4.30-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
Collecting jaxlib
  Downloading jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl (79.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
Collecting flax
  Downloading flax-0.8.5-py3-none-any.whl (731 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m731.3/731.3 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
Collecting transformers
  Downloading transformers-4.42.0-py3-none-any.whl (9.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.3/9.3 MB[0m [31m18.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib, jax, transformers, flax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.26+cuda12.cudnn89
    Uninstalling jaxlib-0.4.26+cuda12.cudnn89:
      S

In [2]:
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir

def get_stablehlo_asm(module_str):
  """Return the pretty-printed assembly of a StableHLO module, eliding large constants.

  Args:
    module_str: MLIR module text to parse (for example the value returned by
      an exported function's ``mlir_module()``).

  Returns:
    The module's textual assembly, printed with ``large_elements_limit=20``
    so that big constant payloads are not dumped in full.
  """
  # One context serves as both the active context and the parse target.
  ir_context = jax_mlir.make_ir_context()
  with ir_context:
    stablehlo_module = ir.Module.parse(module_str, context=ir_context)
    return stablehlo_module.operation.get_asm(large_elements_limit=20)

# Suppress log records at WARNING severity and below so the rendered
# tutorial output stays uncluttered.
import logging
logging.disable(level=logging.WARNING)

In [3]:
import jax
from jax import export
import numpy as np

def plus(x, y, z):
  """Return ``(x + y) + z``, adding the three operands left to right."""
  return (x + y) + z

# Export the jitted `plus` for three scalar float32 arguments and print the
# resulting StableHLO module.
scalar_f32 = jax.ShapeDtypeStruct((), np.float32)
exp = export.export(jax.jit(plus))(scalar_f32, scalar_f32, scalar_f32).mlir_module()
print(get_stablehlo_asm(exp))
# Uncomment to print the unprocessed module text instead:
# print(exp)

module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}, %arg1: tensor<f32> {mhlo.layout_mode = "default"}, %arg2: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    %1 = stablehlo.add %0, %arg2 : tensor<f32>
    return %1 : tensor<f32>
  }
}



In [5]:
from transformers import AutoImageProcessor, FlaxResNetModel
import jax
import numpy as np
# Use the supported `jax.export` API; `jax.experimental.export` is deprecated.
from jax import export

# Construct flax model with sample inputs

resnet18 = FlaxResNetModel.from_pretrained("microsoft/resnet-18", return_dict=False)
# Only the shape and dtype of this sample are used to describe the input.
sample_input = np.random.randn(1, 3, 224, 224)
input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)

# Export to StableHLO
# `jax.export.export` expects a jitted callable, so wrap the model explicitly.
stablehlo_resnet18_export = export.export(jax.jit(resnet18))(input_shape)
resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())
print()
# print(resnet18_stablehlo[:600], "\n...\n", resnet18_stablehlo[-345:])
print(resnet18_stablehlo)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.5k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

  stablehlo_resnet18_export = export.export(resnet18)(input_shape)



module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x3x224x224xf32> {mhlo.layout_mode = "default"}) -> (tensor<1x512x7x7xf32> {mhlo.layout_mode = "default"}, tensor<1x512x1x1xf32> {mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense_resource<__elided__> : tensor<7x7x3x64xf32>
    %cst_0 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %cst_1 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %cst_2 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %cst_3 = stablehlo.constant dense_resource<__elided__> : tensor<64xf32>
    %cst_4 = stablehlo.constant dense_resource<__elided__> : tensor<3x3x64x64xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<64xf32>
    %cst_6 = stablehlo.constant dense<1.000000e+00> : tensor<64xf32>
    %cst_7 = stablehlo.constant dense<1.00