 - hash name
 - get python types from function type
 - add a compile function that takes (name, input types, output types) and returns a callable

In [1]:
n = 1000000

In [2]:
from pathlib import Path
from xdsl.builder import ImplicitBuilder

from xdsl.dialects import func, arith, memref, scf
from xdsl.dialects.builtin import ModuleOp, i64, f64, IndexType
from xdsl.ir.core import Block, Region

index = IndexType()

def get_module() -> ModuleOp:
    memref_t = memref.MemRefType.from_element_type_and_shape(f64, [n])
    module = ModuleOp([])
    with ImplicitBuilder(module.body):
        f = func.FuncOp("hello", ((i64, i64), (i64,)))
        with ImplicitBuilder(f.body) as (lhs, rhs):
            res = arith.Addi(lhs, rhs).result
            func.Return(res)

        g = func.FuncOp("dist", ((memref_t, memref_t), (f64,)))
        with ImplicitBuilder(g.body) as (lhs, rhs):
            zero = arith.Constant.from_int_and_width(0, index).result
            one = arith.Constant.from_int_and_width(1, index).result
            hundo = arith.Constant.from_int_and_width(n, index).result

            initial = arith.Constant.from_float_and_width(0.0, f64).result

            body = Region(Block(arg_types=(index, f64)))
            norm_squared = scf.For.get(zero, hundo, one, (initial,), body)
            with ImplicitBuilder(norm_squared.body) as (i, acc):
                lhs_el = memref.Load.get(lhs, (i,)).res
                rhs_el = memref.Load.get(rhs, (i,)).res
                sq = arith.Mulf(lhs_el, rhs_el).result
                new_acc = arith.Addf(acc, sq)
                scf.Yield.get(new_acc)
            
            func.Return(norm_squared)

    return module

m = get_module()
m.verify()

str(get_module())

with open(Path() / "input.mlir", 'w') as f:
    f.write(str(get_module()))

In [11]:
from xdsl.jit import jit_module
import ctypes

dbl_ptr_type = ctypes.POINTER(ctypes.c_double)

hey = jit_module(get_module(), "hello", types=((int, int), int))
dist_mlir = jit_module(get_module(), "dist", types=((dbl_ptr_type, dbl_ptr_type), float))
hey(5,6)

11

In [12]:
import numpy as np

In [13]:

a = np.random.random(n)
b = np.random.random(n)

In [14]:
a_data_ptr = a.ctypes.data_as(dbl_ptr_type)
b_data_ptr = b.ctypes.data_as(dbl_ptr_type)
%time dist_mlir(a_data_ptr, b_data_ptr)

CPU times: user 8.59 ms, sys: 1.87 ms, total: 10.5 ms
Wall time: 8.96 ms


249970.96343905377

In [7]:
%time hey(5, 6)

CPU times: user 13 µs, sys: 59 µs, total: 72 µs
Wall time: 10 µs


11

In [8]:
%time 1 + 2

CPU times: user 1 µs, sys: 1 µs, total: 2 µs
Wall time: 1.67 µs


3

In [9]:
from typing import Any


def dist_np(lhs: np.ndarray[Any, np.dtype[np.float64]], rhs: np.ndarray[Any, np.dtype[np.float64]]) -> np.float64:
    return lhs.dot(rhs)

%time dist_np(a, b)

CPU times: user 1.03 ms, sys: 2.83 ms, total: 3.86 ms
Wall time: 520 µs


250036.2650088617

In [10]:
def dist_python(lhs: np.ndarray[Any, np.dtype[np.float64]], rhs: np.ndarray[Any, np.dtype[np.float64]]) -> np.float64:
    return sum(l * r for (l, r) in zip(lhs, rhs))

%time dist_python(a, b)

CPU times: user 766 ms, sys: 184 ms, total: 950 ms
Wall time: 107 ms


250036.26500887092