##### Copyright 2023 The IREE Authors

In [1]:
#@title Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# <img src="https://upload.wikimedia.org/wikipedia/commons/thumb/1/10/PyTorch_logo_icon.svg/640px-PyTorch_logo_icon.svg.png" height="20px"> PyTorch Ahead-of-time (AOT) export workflows using <img src="https://raw.githubusercontent.com/openxla/iree/main/docs/website/overrides/.icons/iree/ghost.svg" height="20px"> IREE

This notebook shows how to use [SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) for export from a PyTorch session to [IREE](https://github.com/openxla/iree), leveraging [torch-mlir](https://github.com/llvm/torch-mlir) under the covers.

SHARK-Turbine contains both a "simple" AOT exporter and an underlying advanced
API for complicated models and full feature availability. This notebook only
uses the "simple" exporter.

## Setup

In [1]:
%%capture
#@title Uninstall existing packages
#   This avoids some warnings when installing specific PyTorch packages below.
#   `%%capture` silences this cell's output; `-y` answers pip's confirmation
#   prompt so the cell runs unattended.
!python -m pip uninstall -y fastai torchaudio torchdata torchtext torchvision

In [2]:
#@title Install SHARK-Turbine

# Limit cell height (max 300px) so the pip install log below stays compact.
# This calls the `google.colab.output` JavaScript API, so it is Colab-specific.
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# NOTE(review): the package version is not pinned, so a re-run may install a
# different release than the one recorded in this notebook's outputs.
!python -m pip install shark-turbine

<IPython.core.display.Javascript object>



In [3]:
#@title Report version information
# Print the installed SHARK-Turbine version (the `Version:` line of `pip show`).
!echo "Installed SHARK-Turbine, $(python -m pip show shark_turbine | grep Version)"

# Print the version banner of the `iree-compile` command-line tool.
!echo -e "\nInstalled IREE, compiler version information:"
!iree-compile --version

# This is the notebook's `import torch`; later cells (e.g. the export step's
# `torch.randn`) rely on it having been run.
import torch
print("\nInstalled PyTorch, version:", torch.__version__)

Installed SHARK-Turbine, Version: 0.9.2

Installed IREE, compiler version information:
IREE (https://iree.dev):
  IREE compiler version 20231113.707 @ e8c6432ee14e1d4bd917be8505465e2c96b94e28
  LLVM version 18.0.0git
  Optimized build

Installed PyTorch, version: 2.1.2+cu121


## Sample AOT workflow

1. Define a program using `torch.nn.Module`
2. Export the program using `aot.export()`
3. Compile to a deployable artifact
  * a: By staying within a Python session
  * b: By outputting MLIR and continuing using native tools

Useful documentation:

* [PyTorch Modules](https://pytorch.org/docs/stable/notes/modules.html) (`nn.Module`) as building blocks for stateful computation
* IREE compiler and runtime [Python bindings](https://www.iree.dev/reference/bindings/python/)

In [4]:
#@title 1. Define a program using `torch.nn.Module`
# torch.manual_seed(0)

# class LinearModule(torch.nn.Module):
#   def __init__(self, in_features, out_features):
#     super().__init__()
#     self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
#     self.bias = torch.nn.Parameter(torch.randn(out_features))

#   def forward(self, input):
#     return (input @ self.weight) + self.bias

# linear_module = LinearModule(4, 3)


import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    """LeNet-style convolutional classifier producing 10 output scores.

    Layer stack (labels follow the classic LeNet naming):
      C1: conv, 1 -> 6 channels, 5x5 kernel, then ReLU
      S2: max-pool, 2x2 window, stride 2
      C3: conv, 6 -> 16 channels, 5x5 kernel, then ReLU
      S4: max-pool, 2x2 window, stride 2
      F5: linear, 16*5*5 -> 120, then ReLU
      F6: linear, 120 -> 84, then ReLU
      output: linear, 84 -> 10 (assuming 10 classes)

    The 16*5*5 flatten size corresponds to a 1-channel 32x32 input image.
    """

    # Features per sample after S4: 16 feature maps of 5x5.
    _FLAT_FEATURES = 16 * 5 * 5

    def __init__(self):
        super().__init__()
        # Convolution C1: 1 input channel, 6 output channels, 5x5 kernel.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        # Pooling S2: max-pool with a 2x2 window and stride 2.
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Convolution C3: 6 input channels, 16 output channels, 5x5 kernel.
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        # Pooling S4: max-pool with a 2x2 window and stride 2.
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Fully connected F5: 16*5*5 input features, 120 output features.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        # Fully connected F6: 120 input features, 84 output features.
        self.fc2 = nn.Linear(120, 84)
        # Output layer: 84 input features, 10 output features (10 classes assumed).
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on `x` and return the raw (un-normalized) scores.

        `x` is passed through both conv/ReLU/pool stages, reshaped to rows of
        16*5*5 features, and then through the three linear layers (ReLU after
        the first two). The result has 10 columns.
        """
        # Stage 1: C1 -> ReLU -> S2.
        stage1 = self.pool1(F.relu(self.conv1(x)))
        # Stage 2: C3 -> ReLU -> S4.
        stage2 = self.pool2(F.relu(self.conv2(stage1)))
        # Flatten the feature maps ahead of the fully connected layers.
        flat = stage2.view(-1, self._FLAT_FEATURES)
        # F5 and F6, each followed by ReLU.
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        # The output layer has no activation.
        return self.fc3(hidden)

# Create a LeNet model instance.
# NOTE(review): no random seed is set (the `torch.manual_seed(0)` line earlier
# in this cell is commented out), so the initial weights -- and therefore the
# numeric results printed by later cells -- can differ between runs.
lenet = LeNet()

# Print the model structure.
print(lenet)


LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [5]:
#@title 2. Export the program using `aot.export()`
import shark_turbine.aot as aot

# Example input handed to the exporter alongside the model: a random tensor of
# shape [1, 1, 32, 32] (one 1-channel 32x32 sample, matching LeNet's conv1).
example_arg = torch.randn([1, 1, 32, 32])
# `export_output` is reused by steps 3a (`.compile`) and 3b (`.save_mlir`).
export_output = aot.export(lenet, example_arg)
print(type(export_output))

<class 'shark_turbine.aot.exporter.ExportOutput'>


  return torch._C._cuda_getDeviceCount() > 0


In [6]:
#@title 3a. Compile fully to a deployable artifact, in our existing Python session

# Staying in Python gives the API a chance to reuse memory, improving
# performance when compiling large programs.

# With `save_to=None` nothing is written to a file path here; the compiled
# program is accessed through `compiled_binary.map_memory()` below.
compiled_binary = export_output.compile(save_to=None)

# Use the IREE runtime API to test the compiled program.
import numpy as np
import iree.runtime as ireert

# Load the compiled program into a VM module using the "local-task" driver.
config = ireert.Config("local-task")
vm_module = ireert.load_vm_module(
    ireert.VmModule.wrap_buffer(config.vm_instance, compiled_binary.map_memory()),
    config,
)

# All-ones float32 input with the same [1, 1, 32, 32] shape used for export.
# Named `input_array` (not `input`) so the Python builtin `input()` is not
# shadowed for the rest of the session, and created directly as float32
# instead of allocating float64 and converting.
input_array = np.ones((1, 1, 32, 32), dtype=np.float32)
result = vm_module.main(input_array)
# Copy the result back to the host and show (up to) the first 10 scores.
print(result.to_host()[...,:10])

[[-0.06582566 -0.08111075  0.10825731 -0.13008909 -0.05546051 -0.06898132
   0.10918593  0.11239541  0.07509246  0.05092049]]


In [8]:
#@title 3b. Output MLIR then continue from Python or native tools later

# Leaving Python allows for file system checkpointing and grants access to
# native development workflows.

mlir_file_path = "./lenet_pytorch.mlir"
# NOTE(review): this file name still says "linear_module" although the model
# exported above is LeNet, and it lives under /tmp while the MLIR file is
# written to the working directory -- consider aligning the two paths.
vmfb_file_path = "/tmp/linear_module_pytorch_llvmcpu.vmfb"

# export_output.print_readable()
export_output.save_mlir(mlir_file_path)

# `{mlir_file_path}` / `{vmfb_file_path}` are filled in by IPython from the
# Python variables above before the shell command runs.
!iree-compile --iree-input-type=torch --iree-hal-target-backends=llvm-cpu {mlir_file_path} -o {vmfb_file_path}
# Run the compiled module on an all-ones 1x1x32x32 float32 input, mirroring the
# `np.ones` input used in step 3a.
!iree-run-module --module={vmfb_file_path} --device=local-task --input="1x1x32x32xf32=1.0"

EXEC @main
result[0]: hal.buffer_view
1x10xf32=[-0.0658257 -0.0811108 0.108257 -0.130089 -0.0554605 -0.0689813 0.109186 0.112395 0.0750925 0.0509205]
