# ONNX to Snitch

This notebook uses Marimo, a Jupyter-like notebook with interactive UI elements and reactive state.

In [None]:
rank = mo.ui.slider(1, 4, value=2, label="Rank")

# Underscore-prefixed names stay private to this marimo cell.
_slider_text = f"""
    For example, here is a slider, which can take on values from 1 to 4.

    {rank}
    """
mo.md(_slider_text)

In [None]:
# Each extra rank adds one dimension: rank 1 -> (2,), rank 2 -> (2, 3), ...
shape = tuple(range(2, 2 + rank.value))

# The dimensions joined by "x" (e.g. "2x3"); computed once and reused for
# every tensor type shown below. Cell-private thanks to the leading underscore.
_dims = "x".join(map(str, shape))

mo.md(
    f"""
    We use the slider to determine the shape of our inputs and outputs:

    ```
    A: {_dims}xf64
    B: {_dims}xf64
    C: {_dims}xf64
    ```
    """
)

In [None]:
# Section intro for the ONNX model; `shape` comes from the slider-driven cell.
_model_intro = f"""
    ### The ONNX model

    We use the ONNX API to build a simple function, one that returns the elementwise sum of two arrays of shape {shape}
    """
mo.md(_model_intro)

In [None]:
import onnx
from onnx import AttributeProto, GraphProto, TensorProto, ValueInfoProto, helper

<span class="codehilite"><div class="highlight"><pre><span></span><span class="gt">Traceback (most recent call last):</span>
  File <span class="nb">&quot;/Users/sasha/Developer/xdslproject/xdsl/.venv/lib/python3.12/site-packages/marimo/_runtime/executor.py&quot;</span>, line <span class="m">141</span>, in <span class="n">execute_cell</span>
<span class="w">    </span><span class="n">exec</span><span class="p">(</span><span class="n">cell</span><span class="o">.</span><span class="n">body</span><span class="p">,</span> <span class="n">glbls</span><span class="p">)</span>
  File <span class="nb">&quot;/var/folders/84/ql679qw90tdc6pkg78v59jl40000gn/T/marimo_84608/__marimo__cell_lEQa_.py&quot;</span>, line <span class="m">1</span>, in <span class="n">&lt;module&gt;</span>
<span class="w">    </span><span class="kn">import</span><span class="w"> </span><span class="nn">onnx</span>
<span class="gr">ModuleNotFoundError</span>: <span class="n">No module named &#39;onnx&#39;</span>
</pre></div

In [None]:
# Declare the two inputs (ValueInfoProto); both are f64 tensors of the
# slider-selected `shape`.
X1 = helper.make_tensor_value_info("X1", TensorProto.DOUBLE, shape)
X2 = helper.make_tensor_value_info("X2", TensorProto.DOUBLE, shape)

# Declare the single output (ValueInfoProto), with the same type and shape.
Y = helper.make_tensor_value_info("Y", TensorProto.DOUBLE, shape)

# Create the node (NodeProto): an elementwise `Add` computing Y = X1 + X2.
# This must be `Add` to match the surrounding text, which describes an
# elementwise sum lowered via `onnx.Add` and `linalg.add`.
node_def = helper.make_node(
    "Add",  # op_type
    ["X1", "X2"],  # inputs
    ["Y"],  # outputs
)

# Create the graph (GraphProto): nodes, graph name, inputs, outputs.
graph_def = helper.make_graph(
    [node_def],
    "main_graph",
    [X1, X2],
    [Y],
)

# Target version 18 of the default ("") ONNX operator set.
opset_import = [helper.make_operatorsetid("", 18)]

# Create the model (ModelProto).
model_def = helper.make_model(
    graph_def, producer_name="onnx-example", opset_imports=opset_import
)

# Validate the model; the checker raises if the model is invalid.
onnx.checker.check_model(model_def)

ONNX stores models in a serialized binary (protobuf) format, but a model can also be printed in a human-readable text format, which is useful for debugging.
Here is the textual form of our model:

In [None]:
# Show the textual protobuf dump of the model, collapsed behind an accordion
# so it doesn't dominate the page.
_graph_text = mo.plain_text(f"{model_def}")
mo.accordion({"ONNX Graph": _graph_text})

In [None]:
# Introduce the lowering section by rendering the initial xDSL module built
# from the ONNX graph.
mo.md(
    f"""
### Converting to `linalg`

Here is the xDSL representation of the function. It takes two `tensor` values of our chosen shape, passes them as operands to the `onnx.Add` operation, and returns the result:

{xmo.module_html(init_module)}
"""
)

In [None]:
# Build the initial xDSL module from the ONNX GraphProto.
_onnx_graph = model_def.graph
init_module = build_module(_onnx_graph)

In [None]:
# Create a context and register every dialect returned by
# `get_all_dialects()`, so the passes below can work with any of them.
ctx = MLContext()

# Underscore-prefixed loop variables are private to this marimo cell, so they
# don't leak into the notebook's global namespace (where marimo requires each
# name to be defined by exactly one cell).
for _dialect_name, _dialect_factory in get_all_dialects().items():
    ctx.register_dialect(_dialect_name, _dialect_factory)

xDSL seamlessly interoperates with MLIR, so we can use the `mlir-opt` tool alongside xDSL passes to compile the input into the form that we want to process:

In [None]:
# Run the lowering pipeline one step at a time. `xmo.pipeline_html` receives
# the starting context and module plus a sequence of (explanation, pass)
# pairs. It is unpacked into `bufferized_ctx`, `bufferized_module`, and
# `linalg_html`; presumably these are the final context, the final module, and
# an HTML rendering of every step (confirm against xdsl.utils.marimo).
bufferized_ctx, bufferized_module, linalg_html = xmo.pipeline_html(
    ctx,
    init_module,
    (
        (
            mo.md(
                """\
We can use a pass implemented in xDSL to convert the ONNX operations to operations from the `tensor` and `linalg` dialects. Here, the `tensor.empty` op creates our output tensor, and `linalg.add` represents the addition in destination-passing style:
"""
            ),
            ConvertOnnxToLinalgPass(),
        ),
        (
            mo.md(
                """
We can also call into MLIR, here to convert `linalg.add` into `linalg.generic`, a generic representation of structured operations, similar in spirit to Einstein summation:
"""
            ),
            MLIROptPass(
                generic=False,
                arguments=["--linalg-generalize-named-ops"],
            ),
        ),
        (
            mo.md(
                """We prepare the result tensors for bufferization:"""
            ),
            EmptyTensorToAllocTensorPass(),
        ),
        (
            mo.md(
                """We then use MLIR to bufferize our function:"""
            ),
            MLIROptPass(
                arguments=[
                    "--one-shot-bufferize=bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map",
                ],
            ),
        ),
    ),
)

linalg_html

From here, we can use a number of backends to generate executable code, such as LLVM, or emit RISC-V assembly directly.
Please see the other notebooks for details.

In [None]:
from xdsl.context import MLContext
from xdsl.frontend.onnx.ir_builder import build_module
from xdsl.ir import Attribute, SSAValue
from xdsl.passes import PipelinePass
from xdsl.tools.command_line_tool import get_all_dialects
from xdsl.transforms.convert_onnx_to_linalg import ConvertOnnxToLinalgPass
from xdsl.transforms.empty_tensor_to_alloc_tensor import EmptyTensorToAllocTensorPass
from xdsl.transforms.mlir_opt import MLIROptPass

<span class="codehilite"><div class="highlight"><pre><span></span><span class="gt">Traceback (most recent call last):</span>
  File <span class="nb">&quot;/Users/sasha/Developer/xdslproject/xdsl/.venv/lib/python3.12/site-packages/marimo/_runtime/executor.py&quot;</span>, line <span class="m">141</span>, in <span class="n">execute_cell</span>
<span class="w">    </span><span class="n">exec</span><span class="p">(</span><span class="n">cell</span><span class="o">.</span><span class="n">body</span><span class="p">,</span> <span class="n">glbls</span><span class="p">)</span>
  File <span class="nb">&quot;/var/folders/84/ql679qw90tdc6pkg78v59jl40000gn/T/marimo_84608/__marimo__cell_iLit_.py&quot;</span>, line <span class="m">2</span>, in <span class="n">&lt;module&gt;</span>
<span class="w">    </span><span class="kn">from</span><span class="w"> </span><span class="nn">xdsl.frontend.onnx.ir_builder</span><span class="w"> </span><span class="kn">import</span> <span class="n">build_module</spa

In [None]:
import marimo as mo

In [None]:
import xdsl.utils.marimo as xmo