# MPI Dialect example

In this example we show how to use the MPI dialect with memref, and how to lower it to MLIR builtins + llvm

In [1]:
from xdsl.dialects import mpi, scf, memref, builtin, arith, func
from xdsl.printer import Printer

printer = Printer(target=Printer.Target.MLIR)

# given this use of the mpi IR
given = builtin.ModuleOp.from_region_or_ops([
    func.FuncOp.from_callable('main', [], [], lambda: [
        mpi.Init.build(),
        rank := mpi.CommRank.get(),
        lit0 := arith.Constant.from_int_and_width(0, 32),
        cond := arith.Cmpi.from_mnemonic(rank, lit0, 'eq'),
        scf.If.get(cond, [], [  # if rank == 0
            ref := memref.Alloc.get(builtin.f64, 32, [100, 14, 14]),
            dest := arith.Constant.from_int_and_width(1, mpi.t_int),
            req := mpi.ISend.get(ref, dest, 1),
            mpi.Wait.get(req),
        ], [  # else
            source := arith.Constant.from_int_and_width(1, mpi.t_int),
            recv := mpi.IRecv.get(source, memref.MemRefType.from_element_type_and_shape(builtin.f64, [100, 14, 14]), 1),
            mpi.Wait.get(recv.request),
        ]),
        mpi.Finalize.build(),
        func.Return.get()
    ])
])

printer.print(given)

"builtin.module"() ({
  "func.func"() ({
    "mpi.init"() : () -> ()
    %0 = "mpi.comm.rank"() : () -> i32
    %1 = "arith.constant"() {"value" = 0 : i32} : () -> i32
    %2 = "arith.cmpi"(%0, %1) {"predicate" = 0 : i64} : (i32, i32) -> i1
    "scf.if"(%2) ({
      %3 = "memref.alloc"() {"alignment" = 32 : i64, "operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<100x14x14xf64>
      %4 = "arith.constant"() {"value" = 1 : i32} : () -> i32
      %5 = "mpi.isend"(%3, %4) {"tag" = 1 : i32} : (memref<100x14x14xf64>, i32) -> !mpi.request
      %6 = "mpi.wait"(%5) : (!mpi.request) -> i32
    }, {
      %7 = "arith.constant"() {"value" = 1 : i32} : () -> i32
      %8, %9 = "mpi.irecv"(%7) {"tag" = 1 : i32} : (i32) -> (memref<100x14x14xf64>, !mpi.request)
      %10 = "mpi.wait"(%9) : (!mpi.request) -> i32
    }) : (i1) -> ()
    "mpi.finalize"() : () -> ()
    "func.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> (), "sym_visibility" = "private"} : () -> ()
}) : 

In [2]:

from xdsl.pattern_rewriter import PatternRewriteWalker, GreedyRewritePatternApplier


# This is not the best place for this, but it's the best we have:
def apply_rewrites(module, *rewriters):
    PatternRewriteWalker(GreedyRewritePatternApplier(rewriters)).rewrite_module(module)

# We can apply the lowerings defined in mpi.MpiLowerings
lowerings = mpi.MpiLowerings(mpi.MpiLibraryInfo())

# the lowerings need some mpi.MpiLibraryInfo() object passed into it which provides info about the MPI library that should be linked

# we apply the lowerings
apply_rewrites(given, lowerings)
# and then we also have to add the function definitions to the module:
lowerings.insert_externals_into_module(given)


In [3]:
printer.print(given)

"builtin.module"() ({
  "func.func"() ({}) {"sym_name" = "MPI_Finalize", "function_type" = () -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({}) {"sym_name" = "MPI_Irecv", "function_type" = (!llvm.ptr, i64, i32, i32, i32, i32, !llvm.ptr<i32>) -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({}) {"sym_name" = "MPI_Wait", "function_type" = (!llvm.ptr<i32>, !llvm.ptr<i32>) -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({}) {"sym_name" = "MPI_Isend", "function_type" = (!llvm.ptr, i64, i32, i32, i32, i32, !llvm.ptr<i32>) -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({}) {"sym_name" = "MPI_Comm_rank", "function_type" = (i32, !llvm.ptr<i32>) -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({}) {"sym_name" = "MPI_Init", "function_type" = (i32, i32) -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({
    %11 = "llvm.mlir.null"() : () -> i32
    %12 = "func.call"(%11, %11) {"callee" = @MPI