# Compiling to GPU w/ xDSL through MLIR

Let's look at a toy example of square matrix multiplication. We will have the following boilerplate for matrix initialization, printing, and a main function to run them. This main function expects an external `sqmm(C, A, B, N)` function, which stores the product of the two N×N matrices A and B into C.

You should probably skip to the next part if you are mainly interested in the GPU kernel example.

In [None]:
# Boilerplate MLIR module (generic syntax) shared by both executables:
#   - declares an external `sqmm(C, A, B, N)` that each implementation below provides;
#   - `print_mat(A, N)`: printf-s every element of the N x N matrix A with "%8d ", one row per line;
#   - `init_mat(A, N)`: fills A with A[i][j] = i * 3 + j (cast to i32);
#   - `main`: allocates three 2048 x 2048 i32 matrices, initializes A and B,
#     calls sqmm(C, A, B, N), prints C, then frees the three buffers.
# The string is interpolated into a single-quoted shell `echo` below, so it must not contain `'`.
common = """
"builtin.module"() ({
	"func.func"() ({}) {"function_type" = (memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>, index) -> (), "sym_name" = "sqmm", "sym_visibility" = "private"} : () -> ()
	"llvm.func"() ({}) {"function_type" = !llvm.func<i32 (ptr<i8>, ...)>, "linkage" = #llvm.linkage<external>, "sym_name" = "printf"} : () -> ()
	"memref.global"() {"initial_value" = dense<[37, 56, 100, 32,0]> : tensor<5xi8>, "sym_name" = "format", "type" = memref<5xi8>} : () -> ()
	"memref.global"() {"initial_value" = dense<[10, 0]> : tensor<2xi8>, "sym_name" = "newline", "type" = memref<2xi8>} : () -> ()
	"func.func"() ({
	^body(%A : memref<?x?xi32>, %N : index):
		%zero = "arith.constant"() {"value" = 0 : index} : () -> index
        %one = "arith.constant"() {"value" = 1 : index} : () -> index
        %format = "memref.get_global"() {"name" = @format} : () -> memref<5xi8>
		%f0 = "memref.extract_aligned_pointer_as_index"(%format) : (memref<5xi8>) -> index
		%f1 = "arith.index_cast"(%f0) : (index) -> i64
		%format_ptr = "llvm.inttoptr"(%f1) : (i64) -> !llvm.ptr<i8>
        %newline = "memref.get_global"() {"name" = @newline} : () -> memref<2xi8>
		%n0 = "memref.extract_aligned_pointer_as_index"(%newline) : (memref<2xi8>) -> index
		%n1 = "arith.index_cast"(%n0) : (index) -> i64
		%newline_ptr = "llvm.inttoptr"(%n1) : (i64) -> !llvm.ptr<i8>
		"scf.for"(%zero, %N, %one) ({
		^bb(%i : index):
			"scf.for"(%zero, %N, %one) ({
			^bb(%j : index):
				%Aij = "memref.load"(%A, %i, %j) : (memref<?x?xi32>, index, index) -> i32
                %ignored = "llvm.call"(%format_ptr, %Aij) {"callee" = @printf} : (!llvm.ptr<i8>, i32) -> i32 
				"scf.yield"() : () -> ()
			}) : (index, index, index) -> ()
			"llvm.call"(%newline_ptr) {"callee" = @printf} : (!llvm.ptr<i8>) -> i32 
			"scf.yield"() : () -> ()
		}) : (index, index, index) -> ()
		"func.return"() : () -> ()
	}) {"function_type" = (memref<?x?xi32>, index) -> (), "sym_name" = "print_mat"} : () -> ()
    
	"func.func"() ({
	^body(%A : memref<?x?xi32>, %N : index):
		%zero = "arith.constant"() {"value" = 0 : index} : () -> index
        %one = "arith.constant"() {"value" = 1 : index} : () -> index
        %a = "arith.constant"() {"value" = 3 : index} : () -> index
        
		"scf.for"(%zero, %N, %one) ({
		^bb(%i : index):
			"scf.for"(%zero, %N, %one) ({
			^bb(%j : index):
				%0 = "arith.muli"(%i, %a) : (index, index) -> index
				%v = "arith.addi"(%0, %j) : (index, index) -> index
                %vi = "index.casts"(%v) : (index) -> i32
				"memref.store"(%vi, %A, %i, %j) : (i32, memref<?x?xi32>, index, index) -> ()
				"scf.yield"() : () -> ()
			}) : (index, index, index) -> ()
			"scf.yield"() : () -> ()
		}) : (index, index, index) -> ()
		"func.return"() : () -> ()
	}) {"function_type" = (memref<?x?xi32>, index) -> (), "sym_name" = "init_mat"} : () -> ()
    
	"func.func"() ({
		%zero = "arith.constant"() {"value" = 0 : index} : () -> index
        %one = "arith.constant"() {"value" = 1 : index} : () -> index
        
		%N = "arith.constant"() {"value" = 2048 : index} : () -> index
		%A = "memref.alloc"(%N, %N) {"alignment" = 8 : i64, "operand_segment_sizes" = array<i32:2,0>} : (index, index) -> memref<?x?xi32>
		%B = "memref.alloc"(%N, %N) {"alignment" = 8 : i64, "operand_segment_sizes" = array<i32:2,0>} : (index, index) -> memref<?x?xi32>
		%C = "memref.alloc"(%N, %N) {"alignment" = 8 : i64, "operand_segment_sizes" = array<i32:2,0>} : (index, index) -> memref<?x?xi32>
        
        "func.call"(%A, %N) {"callee" = @init_mat} : (memref<?x?xi32>, index) -> ()
        "func.call"(%B, %N) {"callee" = @init_mat} : (memref<?x?xi32>, index) -> ()
        "func.call"(%C, %A, %B, %N) {"callee" = @sqmm} : (memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>, index) -> ()
        "func.call"(%C, %N) {"callee" = @print_mat} : (memref<?x?xi32>, index) -> ()
        
        
		"memref.dealloc"(%A) : (memref<?x?xi32>) -> ()
		"memref.dealloc"(%B) : (memref<?x?xi32>) -> ()
		"memref.dealloc"(%C) : (memref<?x?xi32>) -> ()
		"func.return"(%zero) : (index) -> ()
	}) {"function_type" = () -> (index), "sym_name" = "main"} : () -> ()
}) : () -> ()
"""

Let's compile the boilerplate to an object, to link with our multiplication kernels later.

In [None]:
!echo '{common}' | mlir-opt - -test-lower-to-llvm | mlir-translate - --mlir-to-llvmir | clang -x ir -c - -o /tmp/common.o

Now, let's look at a naive, CPU-only implementation:

In [None]:
# CPU-only reference implementation of `sqmm`: C[i][j] = sum over k of A[i][k] * B[k][j],
# written as three nested `scf.for` loops (the innermost one carries the running i32 sum).
# `print_mat` and `init_mat` are only declared here; their definitions are in /tmp/common.o.
# Accumulation uses plain i32 `arith.muli`/`arith.addi`, which wrap on overflow, so the
# result can be compared exactly with the GPU version.
# The string is interpolated into a single-quoted shell `echo`, so it must not contain `'`.
non_gpu = """
"builtin.module"() ({
	"func.func"() ({}) {"function_type" = (memref<?x?xi32>, index) -> (), "sym_name" = "print_mat", "sym_visibility" = "private"} : () -> ()
	"func.func"() ({}) {"function_type" = (memref<?x?xi32>, index) -> (), "sym_name" = "init_mat", "sym_visibility" = "private"} : () -> ()
	"func.func"() ({
	^body(%C : memref<?x?xi32>, %A : memref<?x?xi32>, %B : memref<?x?xi32>, %N : index):
		%zero = "arith.constant"() {"value" = 0 : index} : () -> index
		%zeroi = "arith.constant"() {"value" = 0 : i32} : () -> i32
		%one = "arith.constant"() {"value" = 1 : index} : () -> index
		"scf.for"(%zero, %N, %one) ({
		^bb(%i : index):
			"scf.for"(%zero, %N, %one) ({
			^bb(%j : index):
				%Cij = "scf.for"(%zero, %N, %one, %zeroi) ({
				^bb(%k : index, %sum : i32):
					%Aik = "memref.load"(%A, %i, %k) : (memref<?x?xi32>, index, index) -> i32
					%Bkj = "memref.load"(%B, %k, %j) : (memref<?x?xi32>, index, index) -> i32
					%prodk = "arith.muli"(%Aik, %Bkj) : (i32, i32) -> i32
					%new_sum = "arith.addi"(%sum, %prodk) : (i32, i32) -> i32
					"scf.yield"(%new_sum) : (i32) -> ()
				}) : (index, index, index, i32) -> i32
				"memref.store"(%Cij, %C, %i, %j) : (i32, memref<?x?xi32>, index, index) -> ()
				"scf.yield"() : () -> ()
			}) : (index, index, index) -> ()
			"scf.yield"() : () -> ()
		}) : (index, index, index) -> ()
		"func.return"() : () -> ()
	}) {"function_type" = (memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>, index) -> (), "sym_name" = "sqmm"} : () -> ()
}) {} : () -> ()
"""

We can compile it to an object file, and compile an executable by linking it to the boilerplate this way:

In [None]:
!echo '{non_gpu}' | xdsl-opt -f mlir -t mlir | mlir-opt --test-lower-to-llvm | mlir-translate --mlir-to-llvmir | clang -x ir -o /tmp/non_gpu.o -c -
!clang /tmp/non_gpu.o /tmp/common.o -o /tmp/non_gpu

Let's see how it runs:

In [None]:
# Time the CPU-only executable; its stdout (the printed matrix C) is saved for the comparison below.
!time /tmp/non_gpu > /tmp/non_gpu_output

That's not great. Let's implement the same multiplication on a GPU, with one GPU thread computing each element of C:

In [None]:
# Naive GPU implementation of `sqmm`: C[i][j] = sum over k of A[i][k] * B[k][j].
# - The three host buffers are cast to unranked memrefs and passed to `gpu.host_register`
#   before the kernel accesses them.
# - `gpu.launch` uses a ceil(N/32) x ceil(N/32) grid of 32 x 32 thread blocks; each thread
#   computes one element C[i][j].
# - Row i comes from the y dimension and column j from the x dimension, so threads with
#   consecutive `tx` read consecutive elements of a row of B and write consecutive elements of C.
# - Because the grid size is rounded up, threads with i >= N or j >= N must do nothing:
#   the `arith.cmpi` ("predicate" = 6 is unsigned less-than) + `scf.if` guard skips them.
# - Accumulation uses plain i32 arithmetic, matching the CPU version exactly.
# The string is interpolated into a single-quoted shell `echo`, so it must not contain `'`.
gpu_stuff = """
"builtin.module"() ({
	"func.func"() ({}) {"function_type" = (memref<?x?xi32>, index) -> (), "sym_name" = "print_mat", "sym_visibility" = "private"} : () -> ()
	"func.func"() ({}) {"function_type" = (memref<?x?xi32>, index) -> (), "sym_name" = "init_mat", "sym_visibility" = "private"} : () -> ()
	"func.func"() ({
	^body(%C : memref<?x?xi32>, %A : memref<?x?xi32>, %B : memref<?x?xi32>, %N : index):
		%zero = "arith.constant"() {"value" = 0 : index} : () -> index
		%one = "arith.constant"() {"value" = 1 : index} : () -> index
		%zeroi = "arith.constant"() {"value" = 0 : i32} : () -> i32
		%block_size = "arith.constant"() {"value" = 32 : index} : () -> index
		%blocks = "arith.ceildivui"(%N, %block_size) : (index, index) -> index
		%uA = "memref.cast"(%A) : (memref<?x?xi32>) -> memref<*xi32>
		%uB = "memref.cast"(%B) : (memref<?x?xi32>) -> memref<*xi32>
		%uC = "memref.cast"(%C) : (memref<?x?xi32>) -> memref<*xi32>
		"gpu.host_register"(%uA) : (memref<*xi32>) -> ()
		"gpu.host_register"(%uB) : (memref<*xi32>) -> ()
		"gpu.host_register"(%uC) : (memref<*xi32>) -> ()
		"gpu.launch"(%blocks, %blocks, %one, %block_size, %block_size, %one) ({
		^bb0(%bx : index, %by : index, %bz : index,
			 %tx : index, %ty : index, %tz : index,
			 %num_bx : index, %num_by : index, %num_bz : index,
			 %num_tx : index, %num_ty : index, %num_tz : index):
			%bi = "arith.muli"(%by, %num_ty) : (index, index) -> index
			%bj = "arith.muli"(%bx, %num_tx) : (index, index) -> index
			%i = "arith.addi"(%bi, %ty) : (index, index) -> index
			%j = "arith.addi"(%bj, %tx) : (index, index) -> index
			%i_ok = "arith.cmpi"(%i, %N) {"predicate" = 6 : i64} : (index, index) -> i1
			%j_ok = "arith.cmpi"(%j, %N) {"predicate" = 6 : i64} : (index, index) -> i1
			%in_bounds = "arith.andi"(%i_ok, %j_ok) : (i1, i1) -> i1
			"scf.if"(%in_bounds) ({
				%one2 = "arith.constant"() {"value" = 1 : index} : () -> index
				%Cij = "scf.for"(%zero, %N, %one2, %zeroi) ({
				^bb(%k : index, %sum : i32):
					%Aik = "memref.load"(%A, %i, %k) : (memref<?x?xi32>, index, index) -> i32
					%Bkj = "memref.load"(%B, %k, %j) : (memref<?x?xi32>, index, index) -> i32
					%prodk = "arith.muli"(%Aik, %Bkj) : (i32, i32) -> i32
					%new_sum = "arith.addi"(%sum, %prodk) : (i32, i32) -> i32
					"scf.yield"(%new_sum) : (i32) -> ()
				}) : (index, index, index, i32) -> i32
				"memref.store"(%Cij, %C, %i, %j) : (i32, memref<?x?xi32>, index, index) -> ()
				"scf.yield"() : () -> ()
			}, {
			}) : (i1) -> ()
			"gpu.terminator"() : () -> ()
		}) {"operand_segment_sizes" = array<i32: 0, 1, 1, 1, 1, 1, 1, 0>} : (index, index, index, index, index, index) -> ()
		"func.return"() : () -> ()
	}) {"function_type" = (memref<?x?xi32>, memref<?x?xi32>, memref<?x?xi32>, index) -> (), "sym_name" = "sqmm"} : () -> ()
}) {} : () -> ()
"""

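We compile it the same way, except that the MLIR pass pipeline also outlines the `gpu.launch` body into a GPU kernel and compiles it to a CUBIN. The final executable is additionally linked against the `mlir_cuda_runtime` library: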

In [None]:
!echo '{gpu_stuff}' | xdsl-opt -f mlir -t mlir | mlir-opt --pass-pipeline="builtin.module(gpu-kernel-outlining, convert-scf-to-cf, gpu.module(convert-gpu-to-nvvm, reconcile-unrealized-casts, symbol-dce, gpu-to-cubin), gpu-to-llvm, arith-expand, convert-arith-to-llvm, convert-index-to-llvm, reconcile-unrealized-casts, symbol-dce)" | mlir-translate --mlir-to-llvmir | clang -x ir -c -o /tmp/gpu.o -c -
!clang /tmp/gpu.o /tmp/common.o -o /tmp/gpu -lmlir_cuda_runtime

And run it:

In [None]:
# Time the GPU executable; its stdout (the printed matrix C) is saved for the comparison below.
!time /tmp/gpu > /tmp/gpu_output

Check that the results were the same:

In [None]:
!if diff /tmp/gpu_output /tmp/non_gpu_output; then echo "Outputs are the same!"; else echo "Outputs are different!"; fi

Et voilà, a naive GPU kernel hopefully giving some speedup!