##### Copyright 2023 The IREE Authors

In [1]:
#@title Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png" height="20px"> PyTorch Ahead-of-time (AOT) export workflows using <img src="https://raw.githubusercontent.com/openxla/iree/main/docs/website/overrides/.icons/iree/ghost.svg" height="20px"> IREE

This notebook shows how to use [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) to export programs from a PyTorch session to [IREE](https://github.com/openxla/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.

SHARK-Turbine provides both a "simple" AOT exporter and an underlying advanced
API that exposes the full feature set for complicated models. This notebook
only uses the "simple" exporter.

## Setup

In [2]:
%%capture
#@title Uninstall existing packages
#   This avoids some warnings when installing specific PyTorch packages below.
#   NOTE(review): presumably these preinstalled packages depend on a particular
#   `torch` build, so pip would warn about conflicts if the install step below
#   changes `torch` -- confirm which package actually triggers the warnings.
#   `%%capture` (which must remain the first line of this cell) hides all of
#   this cell's output, including any pip errors; `-y` skips pip's confirmation
#   prompt so the command can run non-interactively.
!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision

In [3]:
#@title Install SHARK-Turbine

# Limit cell height.
# The JavaScript below calls the Colab frontend API (`google.colab.output`), so
# it only has an effect when the notebook is rendered by Colab. It asks Colab to
# cap this cell's output frame (`maxHeight: 300`) so the long pip log below does
# not dominate the page.
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# Installs shark-turbine without a version constraint; as the saved log below
# shows, pip also pulls in `iree-compiler` and `iree-runtime` as dependencies.
# NOTE(review): unpinned -- for reproducible runs consider pinning to the
# version reported by the next cell (0.9.1.dev3 when this notebook was saved).
!python -m pip install shark-turbine

<IPython.core.display.Javascript object>

Collecting shark-turbine
  Downloading shark-turbine-0.9.1.dev3.tar.gz (60 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 60.2/60.2 kB 1.3 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting iree-compiler>=20231004.665 (from shark-turbine)
  Downloading iree_compiler-20231004.665-cp310-cp310-manylinux_2_28_x86_64.whl (57.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 57.2/57.2 MB 13.7 MB/s eta 0:00:00
Collecting iree-runtime>=20231004.665 (from shark-turbine)
  Downloading iree_runtime-20231004.6

In [4]:
#@title Report version information
# Record the exact toolchain versions so the outputs saved in this notebook can
# be matched to them. The `$(...)` command substitution is performed by the
# shell (the saved output below shows it resolved to the pip "Version:" line).
!echo "Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)"

# `iree-compile` is presumably provided by the `iree-compiler` package pulled in
# by the install cell; the same tool is used again in step 3b.
!echo -e "\nInstalled IREE, compiler version information:"
!iree-compile --version

# `torch` is also used by later cells (e.g. steps 1 and 2), so this import must
# run before them.
import torch
print("\nInstalled PyTorch, version:", torch.__version__)

Installed SHARK-Turbine, Version: 0.9.1.dev3

Installed IREE, compiler version information:
IREE (https://iree.dev):
  IREE compiler version 20231004.665 @ bb51f6f1a1b4ee619fb09a7396f449dadb211447
  LLVM version 18.0.0git
  Optimized build

Installed PyTorch, version: 2.1.0+cu121


## Sample AOT workflow

1. Define a program using `torch.nn.Module`
2. Export the program using `aot.export()`
3. Compile to a deployable artifact
   * a: By staying within a Python session
   * b: By outputting MLIR and continuing with native tools

Useful documentation:

* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation
* IREE compiler and runtime [Python bindings](https://www.iree.dev/reference/bindings/python/)

In [5]:
#@title 1. Define a program using `torch.nn.Module`
# Fix the global RNG seed so the randomly initialized parameters below (and
# therefore every number printed by later cells) are reproducible.
torch.manual_seed(0)

class LinearModule(torch.nn.Module):
  """A minimal affine layer computing ``input @ weight + bias``.

  Both parameters are sampled from a standard normal distribution when the
  module is constructed: ``weight`` first, then ``bias``.

  Args:
    in_features: size of the trailing dimension of the input.
    out_features: size of the trailing dimension of the output.
  """

  def __init__(self, in_features, out_features):
    super().__init__()
    # Sample in this order (weight, then bias) so the values produced under the
    # fixed seed above stay the same.
    initial_weight = torch.randn(in_features, out_features)
    initial_bias = torch.randn(out_features)
    # Registered as Parameters, in this order, so they are tracked (and
    # exported) as module state.
    self.weight = torch.nn.Parameter(initial_weight)
    self.bias = torch.nn.Parameter(initial_bias)

  def forward(self, input):
    """Returns ``input @ self.weight + self.bias``.

    The trailing dimension of ``input`` must equal ``in_features`` for the
    matrix product to be defined.
    """
    projected = input @ self.weight
    return projected + self.bias

linear_module = LinearModule(4, 3)

In [6]:
#@title 2. Export the program using `aot.export()`
from shark_turbine import aot

# `aot.export` receives the module plus an example input; the exported entry
# point takes an argument of the same shape and dtype (`tensor<4xf32>` in the IR
# printed by step 3b).
example_arg = torch.randn(4)
export_output = aot.export(linear_module, example_arg)

In [7]:
#@title 3a. Compile fully to a deployable artifact, in our existing Python session

# Staying in Python gives the API a chance to reuse memory, improving
# performance when compiling large programs.

# With `save_to=None` the compiled artifact is not written to a file; it is read
# back from memory via `map_memory()` below.
compiled_binary = export_output.compile(save_to=None)

# Use the IREE runtime API to test the compiled program.
import numpy as np
import torch
import iree.runtime as ireert

# Load the in-memory binary on IREE's "local-task" device (the same device name
# passed to `iree-run-module` in step 3b).
config = ireert.Config("local-task")
vm_module = ireert.load_vm_module(
    ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
    config,
)

# Named `input_array` (not `input`) so the builtin `input()` is not shadowed in
# the notebook's global namespace.
input_array = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
result = vm_module.main(input_array)
print(result.to_host())

# Cross-check the compiled program against eager PyTorch on the same input, so
# a miscompile fails loudly instead of just printing different numbers.
with torch.no_grad():
  expected = linear_module(torch.from_numpy(input_array)).numpy()
np.testing.assert_allclose(result.to_host(), expected, rtol=1e-5, atol=1e-6)

[ 1.4178504 -1.2343317 -7.4767947]


In [8]:
#@title 3b. Output MLIR then continue from Python or native tools later

# Leaving Python allows for file system checkpointing and grants access to
# native development workflows.

# NOTE(review): hard-coded POSIX `/tmp` paths (fine on Colab/Linux); consider
# `tempfile.gettempdir()` if this notebook should also run elsewhere.
mlir_file_path = "/tmp/linear_module_pytorch.mlirbc"
vmfb_file_path = "/tmp/linear_module_pytorch_llvmcpu.vmfb"

# Print the exported module as textual MLIR, then write it to `mlir_file_path`.
# In the printed IR (see the output below), the parameters appear as
# `util.global` values and the entry point is `@main`.
export_output.print_readable()
export_output.save_mlir(mlir_file_path)

# IPython substitutes `{mlir_file_path}` / `{vmfb_file_path}` with the Python
# values above. The exported IR contains torch-dialect ops (`torch_c.*`,
# `!torch.vtensor`), hence `--iree-input-type=torch`; `llvm-cpu` selects the
# CPU code generation backend.
!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu {mlir_file_path} -o {vmfb_file_path}
# Execute the compiled module on the "local-task" device. The `4xf32=` input
# matches the `tensor<4xf32>` argument of `@main` and the values used in step 3a.
!iree-run-module --module={vmfb_file_path} --device=local-task --input="4xf32=[1.0, 2.0, 3.0, 4.0]"

module @LinearModule {
  util.global private @_params.weight {noinline} = dense<[[1.54099607, -0.293428898, -2.17878938], [0.568431258, -1.08452237, -1.39859545], [0.403346837, 0.838026344, -0.719257593], [-0.403343529, -0.596635341, 0.182036489]]> : tensor<4x3xf32>
  util.global private @_params.bias {noinline} = dense<[-0.856674611, 1.10060418, -1.07118738]> : tensor<3xf32>
  func.func @main(%arg0: tensor<4xf32>) -> tensor<3xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} {
    %0 = torch_c.from_builtin_tensor %arg0 : tensor<4xf32> -> !torch.vtensor<[4],f32>
    %1 = call @forwa