# torch.mlir.export

This notebook shows a potential API for `torch_mlir`, modeled after `torch.onnx.export` and serving similar use cases. It is implemented as a pure-Python dependency that mostly provides a convenience wrapper around `torch_mlir`. We probably want a function like this in `torch_mlir` anyway, but the real point under discussion here is how to get some "official" blessing for exposing `torch_mlir` through a PyTorch `torch.mlir` Python module (or to determine that doing so isn't even desirable).

Anush has a [PR](https://github.com/pytorch/pytorch/pull/65880) to add Torch-MLIR as a CMake `ExternalProject_Add`, but I didn't use that approach here because it seems better suited to the case where PyTorch needs a C++-level dependency on Torch-MLIR, and it doesn't seem like we need that. It suffices to have a `torch_mlir` Python package built against the right PyTorch.

## PYTHONPATH setup

For this notebook, I've hardcoded the Python path (`sys.path`) to point at a `torch_mlir` that I built locally. So one of the big questions is how PyTorch should handle its dependency on `torch_mlir`.

This is one of the things we could use the most guidance on: how best to integrate with PT at the Python level. One good outcome would be for PyTorch's officially provided Torch/MLIR interop story (i.e. a `torch.mlir` Python package) to light up when the `torch_mlir` Python package is installed, and otherwise report a "feature unavailable" sort of error. **Would that be acceptable to PT devs?**

Some topics adjacent to that:
- What support matrix of packages ({Python version} x {PyTorch version} x {OS} x {...}) do we need? (Presumably torch/xla and other projects that depend on PT have been down this path before; any pointers would be very welcome.)
- How do we hook into the appropriate PT release channels (if that is desirable)?
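The "light up when installed, otherwise report feature unavailable" behavior asked about above is the usual optional-dependency pattern. A minimal sketch of what a hypothetical `torch/mlir/__init__.py` could look like, assuming PyTorch imports `torch_mlir` lazily and forwards calls to it. `FeatureUnavailableError`, `make_export`, `is_available`, and the `export` entry point on `torch_mlir` are all hypothetical names for illustration, not an existing PyTorch or Torch-MLIR API:

```python
import importlib.util
from types import ModuleType
from typing import Any, Callable, Optional


class FeatureUnavailableError(RuntimeError):
    """Hypothetical error raised when the torch_mlir backend is missing."""


def _try_import(module_name: str) -> Optional[ModuleType]:
    # Return the module if it imports cleanly, otherwise None, so that
    # `import torch` itself never fails because torch_mlir is absent
    # (or was built against a different PyTorch and fails to load).
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None


def is_available(module_name: str = "torch_mlir") -> bool:
    # Cheap check that does not execute the backend package.
    return importlib.util.find_spec(module_name) is not None


def make_export(module_name: str = "torch_mlir",
                entry_point: str = "export") -> Callable[..., Any]:
    """Build a torch.mlir.export-style function bound to an optional backend.

    `entry_point` is a hypothetical attribute name on the backend package.
    """
    backend = _try_import(module_name)

    def export(*args: Any, **kwargs: Any) -> Any:
        if backend is None:
            raise FeatureUnavailableError(
                f"torch.mlir is unavailable: the '{module_name}' package is "
                "not installed (it must be built against this PyTorch).")
        return getattr(backend, entry_point)(*args, **kwargs)

    return export


# What the hypothetical torch/mlir/__init__.py would then expose.
export = make_export()
```

The design choice is that the import is attempted once, at `torch.mlir` import time, and the error is deferred until someone actually calls `export`, so users who never touch MLIR see no difference.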

```python
import sys
sys.path += ["/usr/local/google/home/silvasean/pg/torch-mlir/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir"]
```
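A less machine-specific alternative to the hardcoded path is to read it from an environment variable and then check that `torch_mlir` actually resolves. A minimal sketch, assuming a hypothetical `TORCH_MLIR_PYTHON_PACKAGES` variable that points at the same `python_packages/torch_mlir` build directory used above; `add_torch_mlir_to_path` is likewise a hypothetical helper:

```python
import importlib
import importlib.util
import os
import sys


def add_torch_mlir_to_path(env_var: str = "TORCH_MLIR_PYTHON_PACKAGES") -> bool:
    """Append the directory named by `env_var` (hypothetical) to sys.path.

    Returns True if a `torch_mlir` package can be found afterwards.
    """
    path = os.environ.get(env_var)
    if path and path not in sys.path:
        sys.path.append(path)
        # Drop cached finder state so the new directory is searched.
        importlib.invalidate_caches()
    return importlib.util.find_spec("torch_mlir") is not None
```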

## Common setup

This setup and the ONNX export below are taken almost verbatim from https://pytorch.org/docs/master/onnx.html.

```python
import torch
import torchvision

dummy_input = torch.randn(10, 3, 224, 224)
model = torchvision.models.resnet18(pretrained=True)
model.eval();
```

## torch.onnx.export

```python
torch.onnx.export(model, dummy_input, "resnet.onnx", verbose=True, export_params=True)
```

```
graph(%input.1 : Float(10, 3, 224, 224, strides=[150528, 50176, 224, 1], requires_grad=0, device=cpu),
      %fc.weight : Float(1000, 512, strides=[512, 1], requires_grad=1, device=cpu),
      %fc.bias : Float(1000, strides=[1], requires_grad=1, device=cpu),
      %193 : Float(64, 3, 7, 7, strides=[147, 49, 7, 1], requires_grad=0, device=cpu),
      %194 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %196 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %197 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %199 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %200 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %202 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %203 : Float(64, strides=[1], requires_grad=0, device=cpu),
      %205 : Float(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=0, device=cpu),
      %206 : Float(64, strides=[1], requir
```

## torch.mlir.export

This is the same almost-one-liner as in the ONNX example above. We are missing some of the sugar that `torch.onnx.export` provides (for example, it accepts the `nn.Module` directly and writes the result to a file), but that is easy to add later.
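As a rough idea of how little that sugar takes, here is a minimal sketch of a `torch.onnx.export`-style wrapper: it accepts a single example input or a tuple of them, optionally scripts the model, and writes the textual MLIR to a path or file object. `export_with_sugar` is a hypothetical helper; the actual exporter (e.g. this notebook's prototype `torch.mlir.export`) and the scripting function (e.g. `torch.jit.script`) are passed in rather than imported, so nothing here assumes a particular `torch_mlir` API. The prototype call itself follows in the next cell.

```python
import os
from typing import IO, Any, Callable, Optional, Union


def _as_arg_list(args: Any) -> list:
    # torch.onnx.export accepts a single example input or a tuple of them;
    # normalize both forms to the list the prototype expects.
    if isinstance(args, (list, tuple)):
        return list(args)
    return [args]


def export_with_sugar(
    model: Any,
    args: Any,
    f: Union[str, os.PathLike, IO[str], None],
    export_fn: Callable[[Any, list], Any],
    script_fn: Optional[Callable[[Any], Any]] = None,
) -> Any:
    """Hypothetical torch.onnx.export-style wrapper around an MLIR exporter."""
    if script_fn is not None:
        model = script_fn(model)
    mlir_module = export_fn(model, _as_arg_list(args))
    if f is not None:
        # str() of an MLIR Python Module prints its textual assembly.
        text = str(mlir_module)
        if isinstance(f, (str, os.PathLike)):
            with open(f, "w") as out:
                out.write(text)
        else:
            f.write(text)
    return mlir_module
```

In the notebook this would be called as `export_with_sugar(model, dummy_input, "resnet.mlir", export_fn=torch.mlir.export, script_fn=torch.jit.script)`, assuming the prototype module is importable.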

```python
import torch.mlir
mlir_module = torch.mlir.export(torch.jit.script(model), [dummy_input])
print(mlir_module.operation.get_asm(large_elements_limit=10))
```

```
module attributes {torch.debug_module_name = "ResNet"}  {
  func @forward(%arg0: !torch.vtensor<[10,3,224,224],f32>) -> !torch.vtensor<[?,?],f32> {
    %int-1 = torch.constant.int -1
    %int3 = torch.constant.int 3
    %true = torch.constant.bool true
    %int0 = torch.constant.int 0
    %float1.000000e-05 = torch.constant.float 1.000000e-05
    %float1.000000e-01 = torch.constant.float 1.000000e-01
    %int2 = torch.constant.int 2
    %int1 = torch.constant.int 1
    %0 = torch.vtensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<64x3x7x7xf32>) : !torch.vtensor<[64,3,7,7],f32>
    %1 = torch.vtensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %2 = torch.vtensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %3 = torch.vtensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %4 = torch.vtensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<64xf32>) : !torch.vtensor<[64],f32>
    %5 =
```