# MPI Dialect example

This example shows how to use the MPI dialect together with `memref`, and how to lower it to MLIR's built-in dialects plus the `llvm` dialect.

In [1]:
from xdsl.dialects import mpi, scf, memref, builtin, arith, func
from xdsl.printer import Printer

f64 = builtin.f64
printer = Printer(target=Printer.Target.MLIR)


def _build_main_ops():
    """Return the operations that make up the body of `main`.

    The program initialises MPI, queries the rank, and allocates a
    100x14x14 f64 buffer (alignment 32). Rank 0 sends the buffer to rank 1
    with tag 1; every other rank receives into the buffer from rank 0 with
    tag 1. MPI is then finalised and the function returns.
    """
    init = mpi.Init.build()
    rank = mpi.CommRank.get()
    zero = arith.Constant.from_int_and_width(0, 32)
    is_rank_zero = arith.Cmpi.from_mnemonic(rank, zero, 'eq')
    buffer = memref.Alloc.get(f64, 32, [100, 14, 14])

    # "then" region: executed when rank == 0
    dest_rank = arith.Constant.from_int_and_width(1, mpi.t_int)
    then_ops = [
        dest_rank,
        mpi.Send.get(buffer, dest_rank, 1),
        scf.Yield.get(),
    ]

    # "else" region: executed on every other rank
    source_rank = arith.Constant.from_int_and_width(0, mpi.t_int)
    else_ops = [
        source_rank,
        mpi.Recv.get(source_rank, buffer, 1),
        scf.Yield.get(),
    ]

    branch = scf.If.get(is_rank_zero, [], then_ops, else_ops)

    return [
        init,
        rank,
        zero,
        is_rank_zero,
        buffer,
        branch,
        mpi.Finalize.build(),
        func.Return.get(),
    ]


# Wrap the op list in a `main` function (no arguments, no results) inside a module.
main_func = func.FuncOp.from_callable('main', [], [], _build_main_ops)
given = builtin.ModuleOp.from_region_or_ops([main_func])

printer.print(given)

"builtin.module"() ({
  "func.func"() ({
    "mpi.init"() : () -> ()
    %0 = "mpi.comm.rank"() : () -> i32
    %1 = "arith.constant"() {"value" = 0 : i32} : () -> i32
    %2 = "arith.cmpi"(%0, %1) {"predicate" = 0 : i64} : (i32, i32) -> i1
    %3 = "memref.alloc"() {"alignment" = 32 : i64, "operand_segment_sizes" = array<i32: 0, 0>} : () -> memref<100x14x14xf64>
    "scf.if"(%2) ({
      %4 = "arith.constant"() {"value" = 1 : i32} : () -> i32
      "mpi.send"(%3, %4) {"tag" = 1 : i32} : (memref<100x14x14xf64>, i32) -> ()
      "scf.yield"() : () -> ()
    }, {
      %5 = "arith.constant"() {"value" = 0 : i32} : () -> i32
      "mpi.recv"(%5, %3) {"tag" = 1 : i32} : (i32, memref<100x14x14xf64>) -> ()
      "scf.yield"() : () -> ()
    }) : (i1) -> ()
    "mpi.finalize"() : () -> ()
    "func.return"() : () -> ()
  }) {"sym_name" = "main", "function_type" = () -> (), "sym_visibility" = "private"} : () -> ()
}) : () -> ()


In [2]:

from xdsl.pattern_rewriter import PatternRewriteWalker, GreedyRewritePatternApplier


# This is not the best place for this, but it's the best we have:
def apply_rewrites(module, *rewriters):
    """Rewrite `module` in place using the given rewrite patterns.

    All `rewriters` are bundled into a single GreedyRewritePatternApplier,
    which a PatternRewriteWalker then runs over `module`.
    """
    applier = GreedyRewritePatternApplier(rewriters)
    walker = PatternRewriteWalker(applier)
    walker.rewrite_module(module)

# The lowerings live in mpi.MpiLowerings. They need an mpi.MpiLibraryInfo()
# object, which provides info about the MPI library that should be linked.
mpi_library_info = mpi.MpiLibraryInfo()
lowerings = mpi.MpiLowerings(mpi_library_info)

# Step 1: rewrite the mpi operations in the module.
apply_rewrites(given, lowerings)

# Step 2: the lowered code calls external functions, so their definitions
# also have to be added to the module.
lowerings.insert_externals_into_module(given)

printer.print(given)

"builtin.module"() ({
  "func.func"() ({}) {"sym_name" = "MPI_Finalize", "function_type" = () -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({}) {"sym_name" = "MPI_Recv", "function_type" = (!llvm.ptr, i32, i32, i32, i32, i32, !llvm.ptr) -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({}) {"sym_name" = "MPI_Send", "function_type" = (!llvm.ptr, i32, i32, i32, i32, i32) -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({}) {"sym_name" = "MPI_Comm_rank", "function_type" = (i32, !llvm.ptr<i32>) -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({}) {"sym_name" = "MPI_Init", "function_type" = (!llvm.ptr, !llvm.ptr) -> i32, "sym_visibility" = "private"} : () -> ()
  "func.func"() ({
    %6 = "llvm.mlir.null"() : () -> !llvm.ptr
    %7 = "func.call"(%6, %6) {"callee" = @MPI_Init} : (!llvm.ptr, !llvm.ptr) -> i32
    %8 = "arith.constant"() {"value" = 1140850688 : i32} : () -> i32
    %9 = "arith.constant"() {"value" = 1 : i64} 