## Template Kernels

In [None]:
# import packages
import allo
from allo.ir.types import int32, float32

In [None]:
# Generic element-wise "+1" kernel parameterized by the element type T
# (PEP 695 type-parameter syntax). It is specialized by passing it to
# allo.customize with instantiate=[<type>]; it is not meant to be called as
# ordinary Python (B below is declared but never assigned a value).
#
# The parameter and return annotations are wrapped in strings so Python does
# not try to evaluate `T[10]` when the `def` statement runs (T is a type
# parameter, not a concrete array type); allo is expected to interpret the
# quoted text. The local annotation `B: T[10]` is never evaluated by Python,
# so it can stay unquoted.
def kernel[T](A: "T[10]") -> "T[10]":
  # Output buffer: declaration only, no initial value.
  B: T[10]
  # Trip count is hard-coded to match the fixed array size of 10.
  for i in range(10):
    B[i] = A[i] + 1
  return B

In [None]:
# Specialize `kernel` with T = int32, then dump the generated module.
s = allo.customize(
  kernel,
  instantiate=[int32],
)
print(s.module)

In [None]:
# Same kernel, now specialized with T = float32; print its module.
s = allo.customize(
  kernel,
  instantiate=[float32],
)
print(s.module)

In [None]:
# Same "+1" kernel, but the array length is also a type parameter: M is
# annotated int32 and is supplied at instantiation time together with T
# (e.g. instantiate=[int32, 20] gives T = int32, M = 20), so a single
# definition covers any array size.
def kernel2[T, M: int32](A: "T[M]") -> "T[M]":
  # Output buffer of M elements of type T (declaration only).
  B: T[M]
  # Loop bound follows the template size instead of a literal.
  for i in range(M):
    B[i] = A[i] + 1
  return B

# Specialize kernel2 with T = int32 and M = 20 (values are listed in the same
# order as the type parameters), then print the module.
s = allo.customize(
  kernel2,
  instantiate=[int32, 20],
)
print(s.module)

In [None]:
# Template kernel that picks its loop body at compile time (metaprogramming):
# allo.meta_if / allo.meta_else test the type parameter T while the kernel is
# being instantiated, not while it runs. Presumably only the selected branch
# is emitted into the generated module -- confirm by inspecting the printed IR.
def kernel3[T, M: int32](A: "T[M]") -> "T[M]":
  B: T[M]
  for i in range(M):
    # Compile-time branch on the element type: int32 -> add 1 ...
    with allo.meta_if(T == int32):
      B[i] = A[i] + 1
    # ... any other element type (e.g. float32) -> subtract 1.
    with allo.meta_else():
      B[i] = A[i] - 1
  return B

# T = int32 satisfies the meta_if condition, so the loop body should add 1.
s = allo.customize(
  kernel3,
  instantiate=[int32, 20],
)
print(s.module)
# T = float32 falls through to meta_else, so the loop body should subtract 1.
s = allo.customize(
  kernel3,
  instantiate=[float32, 20],
)
print(s.module)