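## Allo Examples

Four small Allo kernels — scalar addition, vector-vector addition, matrix-vector multiply and matrix-matrix multiply — each compiled for the LLVM backend and checked against NumPy in the Testing section below.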

In [55]:
import allo
import numpy as np
from allo.ir.types import int32

### Addition

In [56]:
def add(a: int32, b: int32) -> int32:
  # Scalar int32 addition: the sum is bound to a named local, then returned.
  total: int32 = a + b
  return total

### Vector-Vector Add

In [57]:
def vvadd[N](A: int32[N], B: int32[N]) -> int32[N]:
  # Element-wise sum of two length-N int32 vectors.
  # N is a generic parameter bound when the kernel is customized,
  # e.g. allo.customize(vvadd, instantiate=[20]).
  C: int32[N] = 0  # output buffer, zero-initialised; every element is overwritten below
  for i in range(N):
    C[i] = A[i] + B[i]
  return C

### Matrix-Vector Multiply

In [58]:
def mv[N](A: int32[N, N], x: int32[N]) -> int32[N]:
  C: int32[N] = 0
  for i, j in allo.grid(N, N):
    C[i] += A[i, j] * x[j]
  return C

### Matrix-Matrix Multiply (GEMM)

In [59]:
def mm[N](A: int32[N, N], B: int32[N, N]) -> int32[N, N]:
  C: int32[N, N] = 0
  for i, j, k in allo.grid(N, N, N):
    C[i, j] += A[i, k] * B[k, j]
  return C

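### Testing

Each kernel is built with `allo.customize(...).build(target="llvm")` and compared against the equivalent NumPy expression on random integer inputs in `[0, 100)`; the generic kernels are instantiated with `N = 20`.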

In [60]:
# Build each kernel for the LLVM (CPU) backend and compare its output against a
# NumPy reference on random inputs.

SIZE = 20    # value bound to the generic parameter N of vvadd / mv / mm
TRIALS = 10  # random input sets checked per kernel
SEED = 0     # fixed seed so Restart & Run All reproduces the same inputs

rng = np.random.default_rng(SEED)


def rand_int32(shape=None):
  """Draw random value(s) in [0, 100).

  Returns a plain Python int when `shape` is None (scalar kernel arguments),
  otherwise an int32 NumPy array of the given shape.
  """
  if shape is None:
    return int(rng.integers(0, 100))
  return rng.integers(0, 100, shape, dtype=np.int32)


def check_kernel(name, kernel, make_inputs, reference, instantiate=None, trials=TRIALS):
  """Compile `kernel` with Allo and assert it matches `reference` on random inputs.

  Args:
    name: label used in the pass/fail messages.
    kernel: the Allo kernel function to customize and build.
    make_inputs: zero-argument callable returning a tuple of kernel arguments.
    reference: callable computing the expected result from the same arguments.
    instantiate: values for the kernel's generic parameters (e.g. ``[SIZE]``);
      leave as None for non-generic kernels.
    trials: number of random input sets to check.

  Raises:
    AssertionError: if any trial's output differs from the reference.
  """
  # Only forward `instantiate` when given, so non-generic kernels are
  # customized exactly as before.
  kwargs = {} if instantiate is None else {"instantiate": instantiate}
  mod = allo.customize(kernel, **kwargs).build(target="llvm")
  for _ in range(trials):
    args = make_inputs()
    # array_equal covers both the scalar (add) and array (vvadd/mv/mm) results.
    assert np.array_equal(mod(*args), reference(*args)), f"{name} incorrect"
  print(f"{name} correct")


check_kernel("add", add,
             lambda: (rand_int32(), rand_int32()),
             lambda a, b: a + b)
check_kernel("vvadd", vvadd,
             lambda: (rand_int32(SIZE), rand_int32(SIZE)),
             lambda a, b: a + b, instantiate=[SIZE])
check_kernel("mv", mv,
             lambda: (rand_int32((SIZE, SIZE)), rand_int32(SIZE)),
             lambda a, x: a @ x, instantiate=[SIZE])
check_kernel("mm", mm,
             lambda: (rand_int32((SIZE, SIZE)), rand_int32((SIZE, SIZE))),
             lambda a, b: a @ b, instantiate=[SIZE])

add correct
vvadd correct
mv correct
mm correct
