##### Copyright 2023 The IREE Authors

In [1]:
#@title Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png" height="20px"> PyTorch Ahead-of-time (AOT) export workflows using <img src="https://raw.githubusercontent.com/openxla/iree/main/docs/website/overrides/.icons/iree/ghost.svg" height="20px"> IREE

This notebook shows how to use [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) for export from a PyTorch session to [IREE](https://github.com/openxla/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.

SHARK-Turbine contains both a "simple" AOT exporter and an underlying advanced
API for complicated models and full feature availability. This notebook only
uses the "simple" exporter.

## Setup

In [1]:
%%capture
#@title Uninstall existing packages
#   This avoids some warnings when installing specific PyTorch packages below.
!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision

In [2]:
#@title Install SHARK-Turbine

# Limit cell height.
# The script calls `google.colab.output.setIframeHeight`, a Colab-specific
# JavaScript API, capping this cell's output area at 300px.
# NOTE(review): `display` is not imported in this cell; presumably it relies on
# the name IPython injects into the notebook namespace - confirm.
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# The package is installed without a version pin, so the versions reported by
# the next cell depend on when the notebook is run.
!python -m pip install shark-turbine

<IPython.core.display.Javascript object>



In [2]:
#@title Report version information
# Query pip for the installed shark_turbine package and keep only its
# "Version" line.
!echo "Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)"

# Print the version banner of the `iree-compile` command-line tool.
!echo -e "\nInstalled IREE, compiler version information:"
!iree-compile --version

# torch and torchvision are imported here and reused by later cells
# (torchvision provides the model that gets exported); only the torch version
# is printed.
import torch
import torchvision
print("\nInstalled PyTorch, version:", torch.__version__)

Installed SHARK-Turbine, Version: 0.9.2

Installed IREE, compiler version information:
IREE (https://iree.dev):
  IREE compiler version 20231113.707 @ e8c6432ee14e1d4bd917be8505465e2c96b94e28
  LLVM version 18.0.0git
  Optimized build

Installed PyTorch, version: 2.1.2+cu121


## Sample AOT workflow

1. Define a program using `torch.nn.Module`
2. Export the program using `aot.export()`
3. Compile to a deployable artifact
  * a: By staying within a Python session
  * b: By outputting MLIR and continuing using native tools

Useful documentation:

* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation
* IREE compiler and runtime [Python bindings](https://www.iree.dev/reference/bindings/python/)

In [3]:
#@title 1. Define a program using `torch.nn.Module`

# The program exported below is a pretrained torchvision classifier, which is
# itself a `torch.nn.Module`. Change `model_name` to try another architecture
# exposed as a constructor on `torchvision.models`.
model_name = "alexnet"

# `weights="DEFAULT"` requests the default pretrained weights for the chosen
# architecture; it replaces the deprecated `pretrained=True` flag, which emits
# a deprecation warning on current torchvision releases.
model = getattr(torchvision.models, model_name)(weights="DEFAULT")

# Put the module in evaluation mode before export so that layers which behave
# differently during training (e.g. dropout) run in their inference form.
model = model.eval()



In [4]:
#@title 2. Export the program using `aot.export()`
import shark_turbine.aot as aot

# Example input passed to `aot.export` alongside the model: one random tensor
# of shape [1, 3, 224, 224].
# NOTE(review): the output saved with this cell is a RuntimeError that mentions
# a symbolic size `s1`, while the later cells show results from a much later
# execution count. Re-run from a fresh kernel and confirm the export succeeds
# for this fixed-size input before trusting the recorded outputs.
example_arg = torch.randn([1, 3, 224, 224])
export_output = aot.export(model, example_arg)
# Show the Python type of the object returned by `aot.export`.
print(type(export_output))

RuntimeError: Given input size: (192x((s1 - 7)//8)x((s1 - 7)//8)). Calculated output size: (192x((((s1 - 7)//8) - 3)//2) + 1x((((s1 - 7)//8) - 3)//2) + 1). Output size is too small

In [32]:
#@title 3a. Compile fully to a deployable artifact, in our existing Python session

# Staying in Python gives the API a chance to reuse memory, improving
# performance when compiling large programs.

# NOTE(review): with `save_to=None` the result is presumably kept in memory
# only (it is read back through `map_memory()` below) - confirm against the
# SHARK-Turbine API.
compiled_binary = export_output.compile(save_to=None)

# Use the IREE runtime API to test the compiled program.
import numpy as np
import iree.runtime as ireert

config = ireert.Config("local-task")
vm_module = ireert.load_vm_module(
    ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
    config,
)

# The array is named `input_array` rather than `input` so that the Python
# builtin `input()` is not shadowed for the rest of the notebook session. It is
# allocated directly as float32 instead of creating a float64 array and
# converting it; shape matches the example tensor used for export.
input_array = np.ones((1, 3, 224, 224), dtype=np.float32)
result = vm_module.main(input_array)

# Print only the first 10 values along the last axis to keep the output short.
print(result.to_host()[..., :10])

[[-0.72738683  0.7084101  -1.8970113  -2.2323952  -1.0149305  -0.04174495
  -3.0588696  -0.30246565 -2.070094    0.08365561]]


In [35]:
#@title 3b. Output MLIR then continue from Python or native tools later

# Leaving Python allows for file system checkpointing and grants access to
# native development workflows.

mlir_file_path = "./linear_module_pytorch.mlir"
vmfb_file_path = "/tmp/linear_module_pytorch_llvmcpu.vmfb"

# export_output.print_readable()
export_output.save_mlir(mlir_file_path)

!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu {mlir_file_path} -o {vmfb_file_path}
!iree-run-module --module={vmfb_file_path} --device=local-task --input="1x3x224x224xf32=1.0"

EXEC @main
result[0]: hal.buffer_view
1x1000xf32=[-0.727387 0.70841 -1.89701 -2.2324 -1.01493 -0.041745 -3.05887 -0.302466 -2.07009 0.0836556 1.27432 0.252055 -0.154978 -0.29339 -0.801378 -0.599291 0.703493 -0.997593 -0.551698 -0.320853 -0.0638457 2.43159 0.802407 -0.136691 -0.418823 -1.04191 -0.212887 -0.247344 -0.160357 -0.113097 -1.86739 -0.443473 -1.21912 -0.96456 -1.09505 -2.09367 -0.0159107 -1.72138 1.55908 -1.37779 -1.24785 -0.38931 1.75551 1.64676 -0.706678 -0.670621 -0.947005 0.0929498 -2.29975 -1.75184 -1.81186 -0.586403 0.793219 -0.0113079 -0.955984 -1.51173 -0.706362 -1.67504 -0.443265 0.918081 -1.29618 -1.469 0.143649 1.23974 0.610443 -0.105185 0.455777 -1.77571 0.434015 -1.20014 0.335861 0.61077 -0.0843472 2.11978 -0.553347 1.91601 -1.66834 0.176666 2.20879 0.909475 1.95212 1.18162 -0.609748 0.528536 -1.88456 -0.462378 -0.186048 0.413895 -0.0721481 -0.425943 -0.615624 1.25467 1.09126 -0.852083 -0.44722 0.317173 -0.211035 -2.05171 1.2311 -1.09187 -0.551555 -1.76714 -1.0127