## Getting Started

In [None]:
# import the allo package and the int32 type used in the kernel annotations
import allo
from allo.ir.types import int32

In [None]:

# create gemm function
# General matrix multiply kernel: returns C = A x B for two 32x32 int32
# matrices. The annotations use the int32 type imported from allo.ir.types,
# subscripted with the fixed shape of each array.
# The loop variable names matter: the schedule cells further down refer to
# the loops by the strings "i" and "j".
def gemm(A: int32[32, 32], B: int32[32, 32]) -> int32[32, 32]:
  # output matrix, declared as 32x32 int32 and initialized from the scalar 0
  C: int32[32, 32] = 0
  # allo.grid is a shorthand for loop iterator: one `for` over
  # allo.grid(32, 32, 32) provides the three indices i, j, k
  for i, j, k in allo.grid(32, 32, 32):
    # accumulate the product over the shared dimension k
    C[i, j] += A[i, k] * B[k, j]
  return C

In [4]:
# create schedule
# allo.customize(gemm) returns the schedule object `s`; the following cells
# transform the kernel through it (split, reorder) and build from it.
s = allo.customize(gemm)

In [None]:
# inspect intermediate representation
# s.module is printed here before any schedule primitive has been applied,
# so this output is the baseline to compare the later printouts against.
print(s.module)

In [None]:
# apply a split primitive
# Split loop "i" of gemm with factor 8. The reorder cell further down refers
# to the resulting loops by the names "i.outer" and "i.inner".
# NOTE(review): the call's return value is unused, so it presumably changes
# `s` itself; re-running this cell would then act on the already-split
# schedule — TODO confirm.
s.split("i", factor=8)
print(s.module)

In [None]:
# split j as well (for fun)
s.split("j", factor=8)

# reorder loops (essentially tile)
# The split loops are addressed by name as "<loop>.outer" / "<loop>.inner".
# NOTE(review): the arguments are presumably listed outermost-first, which
# would put both outer loops ahead of both inner loops — confirm against the
# Allo documentation.
s.reorder("i.outer", "j.outer", "i.inner", "j.inner")
print(s.module)

In [None]:
# generate an LLVM application
# s.build(target="llvm") produces `mod`, which the next cell calls directly
# with two NumPy arrays and whose result is compared against np.matmul.
mod = s.build(target="llvm")

In [11]:
import numpy as np

# prepare inputs and outputs for app
# A seeded generator makes the random test matrices identical on every run,
# so a failure can be reproduced after restarting the kernel.
rng = np.random.default_rng(0)
# values lie in [0, 100), so each 32-term dot product fits easily in int32
np_A = rng.integers(0, 100, size=(32, 32), dtype=np.int32)
np_B = rng.integers(0, 100, size=(32, 32), dtype=np.int32)

# get outputs from application
allo_C = mod(np_A, np_B)

# check results against numpy baseline
golden_C = np.matmul(np_A, np_B)
# Integer GEMM must match exactly; a relative tolerance would let results
# that are off by a few units slip through for entries of this magnitude.
np.testing.assert_array_equal(allo_C, golden_C)
print("Results are correct!")

Results are correct!
