## Kernel Composition

In [20]:
# import packages
import allo
from allo.ir.types import int32, float32

In [21]:
# Problem size shared by the kernels below (all three dimensions are 32).
M = 32  # rows of gemm's A and of every M x N matrix
K = 32  # reduction dimension in gemm (columns of A, rows of B)
N = 32  # columns of gemm's B and of every M x N matrix

# Element-wise increment kernel: returns a new M x N int32 matrix B with
# B[i, j] = A[i, j] + 1. The input A is only read.
def madd(A: int32[M, N]) -> int32[M, N]:
  # output buffer, zero-initialized before the loop nest overwrites every entry
  B: int32[M, N] = 0
  # the loop variable names i and j become the loop names that schedule
  # primitives refer to (e.g. s1.pipeline("j") below)
  for i, j in allo.grid(M, N):
    B[i, j] = A[i, j] + 1
  return B

# Matrix-multiply kernel: for an M x K matrix A and a K x N matrix B, returns
# the M x N int32 matrix C with C[i, j] = sum over k of A[i, k] * B[k, j].
def gemm(A: int32[M, K], B: int32[K, N]) -> int32[M, N]:
  # accumulator, zero-initialized so the += below starts from 0
  C: int32[M, N] = 0
  for i, j in allo.grid(M, N):
    # k is the reduction axis; schedules address the loops by these variable
    # names (e.g. s2.reorder("k", "j") below)
    for k in allo.reduction(K):
      C[i, j] += A[i, k] * B[k, j]
  return C

# compose kernels in a top-level function:
# top(A, B) = madd(gemm(A, B)), i.e. the matrix product with 1 added to
# every element
def top(A: int32[M, K], B: int32[K, N]) -> int32[M, N]:
  C = gemm(A, B)  # M x N product of A and B
  D = madd(C)     # element-wise C + 1
  return D

In [22]:
# create a schedule and optimize matrix add
s1 = allo.customize(madd)
# pipeline the inner loop "j"; in the printed IR this shows up as a
# pipeline_ii attribute on the loop named "j"
s1.pipeline("j")
print(s1.module)

module {
  func.func @madd(%arg0: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "s", otypes = "s"} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "B"} : memref<32x32xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg1 = 0 to 32 {
      affine.for %arg2 = 0 to 32 {
        %0 = affine.load %arg0[%arg1, %arg2] {from = "A"} : memref<32x32xi32>
        %1 = arith.extsi %0 : i32 to i33
        %c1_i32 = arith.constant 1 : i32
        %2 = arith.extsi %c1_i32 : i32 to i33
        %3 = arith.addi %1, %2 : i33
        %4 = arith.trunci %3 : i33 to i32
        affine.store %4, %alloc[%arg1, %arg2] {to = "B"} : memref<32x32xi32>
      } {loop_name = "j", pipeline_ii = 1 : ui32}
    } {loop_name = "i", op_name = "S_i_j_0"}
    return %alloc : memref<32x32xi32>
  }
}



In [23]:
# separately optimize gemm
s2 = allo.customize(gemm)
# swap the two inner loops so that the reduction loop "k" encloses "j"
# (in the printed IR, "j" is the innermost loop afterwards)
s2.reorder("k", "j")
# accumulate C through a 32-element buffer allocated inside the "i" loop
# (the printed IR initializes it in a separate "j_init" loop)
s2.buffer_at(s2.C, axis="i")
# pipeline the "j" loop
s2.pipeline("j")
print(s2.module)

module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 32 {
      %alloc_0 = memref.alloc() : memref<32xi32>
      affine.for %arg3 = 0 to 32 {
        affine.store %c0_i32, %alloc_0[%arg3] : memref<32xi32>
      } {buffer, loop_name = "j_init", pipeline_ii = 1 : i32}
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %0 = affine.load %arg0[%arg2, %arg3] {from = "A"} : memref<32x32xi32>
          %1 = affine.load %arg1[%arg3, %arg4] {from = "B"} : memref<32x32xi32>
          %2 = arith.extsi %0 : i32 to i64
          %3 = arith.extsi %1 : i32 to i64
          %4 = arith.muli %2, %3 : i64
          %5 = affine.load %alloc_0[%arg4] : memref<32xi32>
          %6 = arith.trunci %4 :

In [24]:
# create a schedule for top-level
# the printed module still contains gemm in its original loop order with no
# row buffer: the optimizations applied through s1/s2 are not part of this
# schedule until they are composed in
s = allo.customize(top)
print(s.module) # not optimized yet

module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %c0_i32 = arith.constant 0 : i32
    %c0_i32_0 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    linalg.fill ins(%c0_i32_0 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 32 {
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %0 = affine.load %arg0[%arg2, %arg4] {from = "A"} : memref<32x32xi32>
          %1 = affine.load %arg1[%arg4, %arg3] {from = "B"} : memref<32x32xi32>
          %2 = arith.extsi %0 : i32 to i64
          %3 = arith.extsi %1 : i32 to i64
          %4 = arith.muli %2, %3 : i64
          %5 = affine.load %alloc[%arg2, %arg3] {from = "C"} : memref<32x32xi32>
          %6 = arith.trunci %4 : i64 to i32
          %7 = arith.addi %5, %6 : i32
          affine.store %7, %alloc[%arg2, %arg3] {to = "C"} : memref<32x32xi32>
        } {loop_n

In [25]:
# compose the optimizations
# bring the madd schedule (s1) and the gemm schedule (s2) into the top-level
# schedule; afterwards the printed gemm shows the reordered, buffered loop
# nest produced by s2
s.compose([s1, s2])
print(s.module)

module {
  func.func @gemm(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>) -> memref<32x32xi32> attributes {itypes = "ss", otypes = "s"} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "C"} : memref<32x32xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32x32xi32>)
    affine.for %arg2 = 0 to 32 {
      %alloc_0 = memref.alloc() : memref<32xi32>
      affine.for %arg3 = 0 to 32 {
        affine.store %c0_i32, %alloc_0[%arg3] : memref<32xi32>
      } {buffer, loop_name = "j_init", pipeline_ii = 1 : i32}
      affine.for %arg3 = 0 to 32 {
        affine.for %arg4 = 0 to 32 {
          %0 = affine.load %arg0[%arg2, %arg3] {from = "A"} : memref<32x32xi32>
          %1 = affine.load %arg1[%arg3, %arg4] {from = "B"} : memref<32x32xi32>
          %2 = arith.extsi %0 : i32 to i64
          %3 = arith.extsi %1 : i32 to i64
          %4 = arith.muli %2, %3 : i64
          %5 = affine.load %alloc_0[%arg4] : memref<32xi32>
          %6 = arith.trunci %4 :

### Template Composition

In [26]:
# define templated kernel
def kernel[T_in, T_out, S](A: "T_in[S]") -> "T_out[S]":
  B: T_out[S] = 0
  for i in range (S):
    with allo.meta_if(T_out == int32):
      B[i] = A[i] + 1
    with allo.meta_else():
      B[i] = A[i] - 1
  return B

# Chains two instantiations of `kernel` over vectors of length M:
# "K1" maps int32 -> int32 and "K2" maps int32 -> float32.
def top2(A: int32[M]) -> float32[M]:
  # last argument of template is kernel id
  # (these ids are what s.compose(..., id=...) refers to further down)
  C = kernel[int32, int32, M, "K1"](A)
  D = kernel[int32, float32, M, "K2"](C)
  return D

In [27]:
# optimize two instances of kernel
# NOTE: s1 and s2 are rebound here; from this cell on they no longer refer to
# the madd/gemm schedules of the previous section

# int32 -> int32 instance: unroll loop "i" by a factor of 4
s1 = allo.customize(kernel, instantiate=[int32, int32, M])
s1.unroll("i", factor=4)
print(s1.module)

# int32 -> float32 instance: pipeline loop "i"
s2 = allo.customize(kernel, instantiate=[int32, float32, M])
s2.pipeline("i")
print(s2.module)

module {
  func.func @kernel(%arg0: memref<32xi32>) -> memref<32xi32> attributes {itypes = "s", otypes = "s"} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "B"} : memref<32xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32xi32>)
    affine.for %arg1 = 0 to 32 {
      %0 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
      %1 = arith.extsi %0 : i32 to i33
      %c1_i32 = arith.constant 1 : i32
      %2 = arith.extsi %c1_i32 : i32 to i33
      %3 = arith.addi %1, %2 : i33
      %4 = arith.trunci %3 : i33 to i32
      affine.store %4, %alloc[%arg1] {to = "B"} : memref<32xi32>
    } {loop_name = "i", op_name = "S_i_0", unroll = 4 : i32}
    return %alloc : memref<32xi32>
  }
}

module {
  func.func @kernel(%arg0: memref<32xi32>) -> memref<32xf32> attributes {itypes = "s", otypes = "_"} {
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.sitofp %c0_i32 : i32 to f32
    %alloc = memref.alloc() {name = "B"} : memref<32xf32>
    linalg.fil

In [28]:
# NOTE: s is rebound to a fresh schedule for top2
s = allo.customize(top2)
# compose with id to optimize correct instantiation
# ("K1" / "K2" are the kernel ids used inside top2; the printed module names
# the resulting functions kernel_K1 and kernel_K2)
s.compose(s1, id="K1")
s.compose(s2, id="K2")
print(s.module)

module {
  func.func @kernel_K1(%arg0: memref<32xi32>) -> memref<32xi32> attributes {itypes = "s", otypes = "s"} {
    %c0_i32 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "B"} : memref<32xi32>
    linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<32xi32>)
    affine.for %arg1 = 0 to 32 {
      %0 = affine.load %arg0[%arg1] {from = "A"} : memref<32xi32>
      %1 = arith.extsi %0 : i32 to i33
      %c1_i32 = arith.constant 1 : i32
      %2 = arith.extsi %c1_i32 : i32 to i33
      %3 = arith.addi %1, %2 : i33
      %4 = arith.trunci %3 : i33 to i32
      affine.store %4, %alloc[%arg1] {to = "B"} : memref<32xi32>
    } {loop_name = "i", op_name = "S_i_0", unroll = 4 : i32}
    return %alloc : memref<32xi32>
  }
  func.func @kernel_K2(%arg0: memref<32xi32>) -> memref<32xf32> attributes {itypes = "s", otypes = "_"} {
    %c0_i32 = arith.constant 0 : i32
    %0 = arith.sitofp %c0_i32 : i32 to f32
    %alloc = memref.alloc() {name = "B"} : memref<32xf32>
    linalg.fill ins(