## Simple Allo -> DSLX flow (stateless dataflow)

In [1]:
%load_ext autoreload
%autoreload 2
import allo
from allo.ir.types import int32, uint32

In [2]:
# 32-bit unsigned adder (Allo kernel). In the DSLX printed below, each
# parameter becomes an input channel (in0, in1) and the return value becomes the
# output channel out0. The sum is formed at 33 bits and cast back to u32
# (tmp2..tmp5), so overflow wraps modulo 2**32.
def add(a: uint32, b: uint32) -> uint32:
  return a + b

# Trace `add` with allo.customize, lower it through the XLS backend
# (target='xls'), and print the generated DSLX proc.
s = allo.customize(add)
code = s.build(target='xls')
print(code)

pub proc add {
  in0: chan<u32> in;
  in1: chan<u32> in;
  out0: chan<u32> out;

  config(in0: chan<u32> in, in1: chan<u32> in, out0: chan<u32> out) { (in0, in1, out0) }

  init { () }

  next(state: ()) {
    let (tok0, tmp0) = recv(join(), in0);
    let (tok1, tmp1) = recv(join(), in1);
    let tmp2 = (tmp0 as u33);
    let tmp3 = (tmp1 as u33);
    let tmp4 = (tmp2 + tmp3);
    let tmp5 = (tmp4 as u32);
    let tok = join(tok0, tok1);
    send(tok, out0, tmp5);
  }

}


In [3]:
# Run the generated code through the XLS interpreter.
# NOTE(review): presumably this parses/type-checks and interprets the DSLX; confirm
# in the XLS backend. No output was captured for this cell.
code.interpret()





In [4]:
# Convert the DSLX to XLS IR. The displayed result is the unoptimized IR
# package: channel declarations plus the top proc.
code.to_ir()

package add

file_number 0 "abax/add.x"

chan add__in0(bits[32], id=0, kind=streaming, ops=receive_only, flow_control=ready_valid, strictness=proven_mutually_exclusive)
chan add__in1(bits[32], id=1, kind=streaming, ops=receive_only, flow_control=ready_valid, strictness=proven_mutually_exclusive)
chan add__out0(bits[32], id=2, kind=streaming, ops=send_only, flow_control=ready_valid, strictness=proven_mutually_exclusive)

top proc __add__add_0_next(__state: (), init={()}) {
  after_all.4: token = after_all(id=4)
  literal.3: bits[1] = literal(value=1, id=3)
  after_all.9: token = after_all(id=9)
  receive.5: (token, bits[32]) = receive(after_all.4, predicate=literal.3, channel=add__in0, id=5)
  receive.10: (token, bits[32]) = receive(after_all.9, predicate=literal.3, channel=add__in1, id=10)
  tmp0: bits[32] = tuple_index(receive.5, index=1, id=8, pos=[(0,10,15)])
  tmp1: bits[32] = tuple_index(receive.10, index=1, id=13, pos=[(0,11,15)])
  tmp2: bits[33] = zero_ext(tmp0, new_bit_count=3

In [5]:
# Run the XLS optimizer over the IR. Compared with the unoptimized IR, the
# receives no longer carry the constant-true predicate and both hang off a
# single after_all token.
code.opt()

package add

file_number 0 "abax/add.x"

chan add__in0(bits[32], id=0, kind=streaming, ops=receive_only, flow_control=ready_valid, strictness=proven_mutually_exclusive)
chan add__in1(bits[32], id=1, kind=streaming, ops=receive_only, flow_control=ready_valid, strictness=proven_mutually_exclusive)
chan add__out0(bits[32], id=2, kind=streaming, ops=send_only, flow_control=ready_valid, strictness=proven_mutually_exclusive)

top proc __add__add_0_next() {
  after_all.4: token = after_all(id=4)
  receive.37: (token, bits[32]) = receive(after_all.4, channel=add__in0, id=37)
  receive.38: (token, bits[32]) = receive(after_all.4, channel=add__in1, id=38)
  tok0: token = tuple_index(receive.37, index=0, id=7, pos=[(0,10,9)])
  tok1: token = tuple_index(receive.38, index=0, id=12, pos=[(0,11,9)])
  tmp0: bits[32] = tuple_index(receive.37, index=1, id=8, pos=[(0,10,15)])
  tmp1: bits[32] = tuple_index(receive.38, index=1, id=13, pos=[(0,11,15)])
  tok: token = after_all(tok0, tok1, id=18)
  tmp4__

In [6]:
# Generate Verilog. The module takes clk/rst and exposes a data/valid/ready
# port group per channel (add__in0, add__in1, add__out0).
code.to_vlog()

module __add__add_0_next(
  input wire clk,
  input wire rst,
  input wire [31:0] add__in0,
  input wire add__in0_vld,
  input wire [31:0] add__in1,
  input wire add__in1_vld,
  input wire add__out0_rdy,
  output wire add__in0_rdy,
  output wire add__in1_rdy,
  output wire [31:0] add__out0,
  output wire add__out0_vld
);
  reg [31:0] __add__in0_reg;
  reg __add__in0_valid_reg;
  reg [31:0] __add__in1_reg;
  reg __add__in1_valid_reg;
  reg [31:0] __add__out0_reg;
  reg __add__out0_valid_reg;
  wire add__out0_valid_inv;
  wire p0_all_active_inputs_valid;
  wire add__out0_valid_load_en;
  wire add__out0_load_en;
  wire p0_stage_done;
  wire add__in0_valid_inv;
  wire add__in1_valid_inv;
  wire add__in0_valid_load_en;
  wire add__in1_valid_load_en;
  wire add__in0_load_en;
  wire add__in1_load_en;
  wire [31:0] tmp4__1;
  assign add__out0_valid_inv = ~__add__out0_valid_reg;
  assign p0_all_active_inputs_valid = __add__in0_valid_reg & __add__in1_valid_reg;
  assign add__out0_valid_load_en =

In [7]:
# NOTE(review): presumably runs the whole lowering flow (interpret -> IR ->
# opt -> Verilog) in one call; confirm. No output was captured for this cell.
code.flow()

## Some other examples

In [8]:
# supports both unsigned and signed integers
# Multiply-accumulate. The emitted DSLX computes a * b at 64 bits and the sum
# at 65 bits, then casts the result back to s32 (tmp3..tmp9).
def mac(a: int32, b: int32, c: int32) -> int32:
  return (a * b) + c

s = allo.customize(mac)
code = s.build(target='xls')
print(code)
code.flow()

pub proc mac {
  in0: chan<s32> in;
  in1: chan<s32> in;
  in2: chan<s32> in;
  out0: chan<s32> out;

  config(in0: chan<s32> in, in1: chan<s32> in, in2: chan<s32> in, out0: chan<s32> out) { (in0, in1, in2, out0) }

  init { () }

  next(state: ()) {
    let (tok0, tmp0) = recv(join(), in0);
    let (tok1, tmp1) = recv(join(), in1);
    let (tok2, tmp2) = recv(join(), in2);
    let tmp3 = (tmp0 as s64);
    let tmp4 = (tmp1 as s64);
    let tmp5 = (tmp3 * tmp4);
    let tmp6 = (tmp5 as sN[65]);
    let tmp7 = (tmp2 as sN[65]);
    let tmp8 = (tmp6 + tmp7);
    let tmp9 = (tmp8 as s32);
    let tok = join(tok0, tok1, tok2);
    send(tok, out0, tmp9);
  }

}


In [9]:
# supports multiple outputs
# Each element of the returned tuple gets its own output channel
# (out0 = a | b, out1 = a & b, out2 = a ^ b). All three sends use the token
# joined from both receives.
def wsa(a: int32, b: int32) -> (int32, int32, int32):
  return a | b, a & b, a ^ b

s = allo.customize(wsa)
code = s.build(target='xls')
print(code)
code.flow()

pub proc wsa {
  in0: chan<s32> in;
  in1: chan<s32> in;
  out0: chan<s32> out;
  out1: chan<s32> out;
  out2: chan<s32> out;

  config(in0: chan<s32> in, in1: chan<s32> in, out0: chan<s32> out, out1: chan<s32> out, out2: chan<s32> out) { (in0, in1, out0, out1, out2) }

  init { () }

  next(state: ()) {
    let (tok0, tmp0) = recv(join(), in0);
    let (tok1, tmp1) = recv(join(), in1);
    let tmp2 = (tmp0 | tmp1);
    let tmp3 = (tmp0 & tmp1);
    let tmp4 = (tmp0 ^ tmp1);
    let tok = join(tok0, tok1);
    send(tok, out0, tmp2);
    send(tok, out1, tmp3);
    send(tok, out2, tmp4);
  }

}


In [10]:
# supports (basic) conditional statements
# The conditional expression lowers to a DSLX `if (...) { ... } else { ... }`
# expression.
# NOTE(review): this definition shadows Python's builtin `max` for the rest of
# the notebook session. The function name also becomes the DSLX proc name.
def max(a: int32, b: int32) -> int32:
  return a if (a > b) else b

s = allo.customize(max)
code = s.build(target='xls')
print(code)
# NOTE(review): flow is disabled here without a stated reason. In the printed
# DSLX, `let tmp3 = if (tmp2) { tmp0 } else { tmp1 }` is the only `let` with no
# trailing `;`. This may be an emitter bug that blocks the flow; confirm.
# code.flow()

pub proc max {
  in0: chan<s32> in;
  in1: chan<s32> in;
  out0: chan<s32> out;

  config(in0: chan<s32> in, in1: chan<s32> in, out0: chan<s32> out) { (in0, in1, out0) }

  init { () }

  next(state: ()) {
    let (tok0, tmp0) = recv(join(), in0);
    let (tok1, tmp1) = recv(join(), in1);
    let tmp2 = (tmp0 > tmp1);
    let tmp3 = if (tmp2) { tmp0 } else { tmp1 }
    let tok = join(tok0, tok1);
    send(tok, out0, tmp3);
  }

}


In [11]:
# supports integer constants in expressions (the literal 1 is emitted as s33:1)
# With a single input there is nothing to join: the send reuses tok0 directly.
def incr(a: int32) -> int32:
  return a + 1

s = allo.customize(incr)
code = s.build(target='xls')
print(code)
# NOTE(review): flow is disabled here without a stated reason; confirm.
# code.flow()

pub proc incr {
  in0: chan<s32> in;
  out0: chan<s32> out;

  config(in0: chan<s32> in, out0: chan<s32> out) { (in0, out0) }

  init { () }

  next(state: ()) {
    let (tok0, tmp0) = recv(join(), in0);
    let tmp1 = (tmp0 as s33);
    let tmp2 = (tmp1 + s33:1);
    let tmp3 = (tmp2 as s32);
    send(tok0, out0, tmp3);
  }

}


In [12]:
# Generate XLS/DSLX code for fact function
# Computes a! with a counted for-loop over an int32 accumulator. acc starts at
# 1 and range(a) is empty for a <= 0, so those inputs return 1.
# NOTE(review): 13! exceeds 2**31 - 1, so for a >= 13 the fixed-width result
# presumably wraps/truncates; confirm.
def fact(a: int32) -> int32:
  acc: int32 = 1
  for i in range(a):
    acc *= (i + 1)
  return acc

# Show the Allo MLIR, then the generated DSLX, then run it through the
# interpreter.
s = allo.customize(fact)
print(s.module)
code = s.build(target='xls')
print(code)
code.interpret()


module {
  func.func @fact(%arg0: i32) -> i32 attributes {itypes = "s", otypes = "s"} {
    %c1_i32 = arith.constant 1 : i32
    %c1_i32_0 = arith.constant 1 : i32
    %c1_i32_1 = arith.constant 1 : i32
    %c1_i32_2 = arith.constant 1 : i32
    %alloc = memref.alloc() {name = "acc"} : memref<i32>
    affine.store %c1_i32_2, %alloc[] {to = "acc"} : memref<i32>
    %c0_i32 = arith.constant 0 : i32
    %c0_i32_3 = arith.constant 0 : i32
    %c0_i32_4 = arith.constant 0 : i32
    %c0_i32_5 = arith.constant 0 : i32
    %0 = arith.index_cast %c0_i32_5 : i32 to index
    %1 = arith.index_cast %arg0 : i32 to index
    %c1_i32_6 = arith.constant 1 : i32
    %c1_i32_7 = arith.constant 1 : i32
    %c1_i32_8 = arith.constant 1 : i32
    %c1_i32_9 = arith.constant 1 : i32
    %2 = arith.index_cast %c1_i32_9 : i32 to index
    scf.for %arg1 = %0 to %1 step %2 {
      %4 = arith.index_cast %arg1 : index to i34
      %c1_i32_10 = arith.constant 1 : i32
      %c1_i32_11 = arith.constant 1 : i32
      

In [13]:
# Generate XLS/DSLX code for fibonacci function
# This tests multiple accumulators (prev, curr)
# Starting from prev=0, curr=1 and running n iterations, the result is F(n+1)
# in the F(0)=0, F(1)=1 convention:
#   fib(0)=1, fib(1)=1, fib(2)=2, fib(3)=3, fib(4)=5.
# The loop index i is unused; the loop only counts iterations.
def fib(n: int32) -> int32:
  prev: int32 = 0
  curr: int32 = 1
  for i in range(n):
    next_val: int32 = prev + curr
    prev = curr
    curr = next_val
  return curr

s = allo.customize(fib)
print(s.module)
code = s.build(target='xls')
print(code)
code.interpret()


module {
  func.func @fib(%arg0: i32) -> i32 attributes {itypes = "s", otypes = "s"} {
    %c0_i32 = arith.constant 0 : i32
    %c0_i32_0 = arith.constant 0 : i32
    %c0_i32_1 = arith.constant 0 : i32
    %c0_i32_2 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "prev"} : memref<i32>
    affine.store %c0_i32_2, %alloc[] {to = "prev"} : memref<i32>
    %c1_i32 = arith.constant 1 : i32
    %c1_i32_3 = arith.constant 1 : i32
    %c1_i32_4 = arith.constant 1 : i32
    %c1_i32_5 = arith.constant 1 : i32
    %alloc_6 = memref.alloc() {name = "curr"} : memref<i32>
    affine.store %c1_i32_5, %alloc_6[] {to = "curr"} : memref<i32>
    %c0_i32_7 = arith.constant 0 : i32
    %c0_i32_8 = arith.constant 0 : i32
    %c0_i32_9 = arith.constant 0 : i32
    %c0_i32_10 = arith.constant 0 : i32
    %0 = arith.index_cast %c0_i32_10 : i32 to index
    %1 = arith.index_cast %arg0 : i32 to index
    %c1_i32_11 = arith.constant 1 : i32
    %c1_i32_12 = arith.constant 1 : i32
    %c1_i32_13 = ar

In [14]:
# WHILE loop example: Count steps until n becomes 1 (Collatz-like)
# Simplified: divide by 2 if even, subtract 1 if odd, count steps
# Inputs n <= 1 return 0 because the loop condition fails immediately.
# `val % 2` lowers to arith.remsi (truncating remainder). That agrees with
# Python's floored `%` here only because val > 1 inside the loop body.
def count_steps(n: int32) -> int32:
    steps: int32 = 0
    val: int32 = n
    while val > 1:
        if val % 2 == 0:
            val = val // 2
        else:
            val = val - 1
        steps = steps + 1
    return steps

s = allo.customize(count_steps)
print("=== MLIR for count_steps (WHILE loop) ===")
print(s.module)
print("\n=== Building DSLX ===")
code = s.build(target='xls')
print(code)
code.interpret()


=== MLIR for count_steps (WHILE loop) ===
module {
  func.func @count_steps(%arg0: i32) -> i32 attributes {itypes = "s", otypes = "s"} {
    %c0_i32 = arith.constant 0 : i32
    %c0_i32_0 = arith.constant 0 : i32
    %c0_i32_1 = arith.constant 0 : i32
    %c0_i32_2 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "steps"} : memref<i32>
    affine.store %c0_i32_2, %alloc[] {to = "steps"} : memref<i32>
    %alloc_3 = memref.alloc() {name = "val"} : memref<i32>
    affine.store %arg0, %alloc_3[] {to = "val"} : memref<i32>
    scf.while : () -> () {
      %1 = affine.load %alloc_3[] {from = "val"} : memref<i32>
      %c1_i32 = arith.constant 1 : i32
      %c1_i32_4 = arith.constant 1 : i32
      %2 = arith.cmpi sgt, %1, %c1_i32_4 : i32
      scf.condition(%2)
    } do {
      %1 = affine.load %alloc_3[] {from = "val"} : memref<i32>
      %c2_i32 = arith.constant 2 : i32
      %c2_i32_4 = arith.constant 2 : i32
      %2 = arith.remsi %1, %c2_i32_4 : i32
      %c0_i32_5 = arith.c

In [15]:
# Nested FOR loop example: Sum of products (simulates 2D iteration)
# Computes sum of i*j for i in range(n), j in range(m)
# Closed form for n, m >= 0: (n*(n-1)/2) * (m*(m-1)/2). The result is 0 when
# either bound is <= 0, because the corresponding range is empty.
def nested_sum(n: int32, m: int32) -> int32:
    total: int32 = 0
    for i in range(n):
        for j in range(m):
            total = total + (i * j)
    return total

s = allo.customize(nested_sum)
print("=== MLIR for nested_sum (nested FOR loops) ===")
print(s.module)
print("\n=== Building DSLX ===")
code = s.build(target='xls')
print(code)
code.interpret()


=== MLIR for nested_sum (nested FOR loops) ===
module {
  func.func @nested_sum(%arg0: i32, %arg1: i32) -> i32 attributes {itypes = "ss", otypes = "s"} {
    %c0_i32 = arith.constant 0 : i32
    %c0_i32_0 = arith.constant 0 : i32
    %c0_i32_1 = arith.constant 0 : i32
    %c0_i32_2 = arith.constant 0 : i32
    %alloc = memref.alloc() {name = "total"} : memref<i32>
    affine.store %c0_i32_2, %alloc[] {to = "total"} : memref<i32>
    %c0_i32_3 = arith.constant 0 : i32
    %c0_i32_4 = arith.constant 0 : i32
    %c0_i32_5 = arith.constant 0 : i32
    %c0_i32_6 = arith.constant 0 : i32
    %0 = arith.index_cast %c0_i32_6 : i32 to index
    %1 = arith.index_cast %arg0 : i32 to index
    %c1_i32 = arith.constant 1 : i32
    %c1_i32_7 = arith.constant 1 : i32
    %c1_i32_8 = arith.constant 1 : i32
    %c1_i32_9 = arith.constant 1 : i32
    %2 = arith.index_cast %c1_i32_9 : i32 to index
    scf.for %arg2 = %0 to %1 step %2 {
      %c0_i32_10 = arith.constant 0 : i32
      %c0_i32_11 = arith.co

In [16]:
# GCD using WHILE loop (Euclidean algorithm)
# The loop only runs while y > 0, so b <= 0 returns a unchanged.
# NOTE(review): the MLIR lowers `x % y` to arith.remsi, a truncating remainder
# whose sign follows x. Python's `%` is floored instead, so for negative inputs
# the Python function and the lowered code can disagree. Example gcd(-6, 4):
#   - Python returns 2.
#   - The remsi version sets y = -2, exits the loop, and returns 4.
# Non-negative inputs are unaffected.
def gcd(a: int32, b: int32) -> int32:
    x: int32 = a
    y: int32 = b
    while y > 0:
        temp: int32 = y
        y = x % y
        x = temp
    return x

s = allo.customize(gcd)
print("=== MLIR for gcd (WHILE loop with two state vars) ===")
print(s.module)
print("\n=== Building DSLX ===")
code = s.build(target='xls')
print(code)
code.interpret()


=== MLIR for gcd (WHILE loop with two state vars) ===
module {
  func.func @gcd(%arg0: i32, %arg1: i32) -> i32 attributes {itypes = "ss", otypes = "s"} {
    %alloc = memref.alloc() {name = "x"} : memref<i32>
    affine.store %arg0, %alloc[] {to = "x"} : memref<i32>
    %alloc_0 = memref.alloc() {name = "y"} : memref<i32>
    affine.store %arg1, %alloc_0[] {to = "y"} : memref<i32>
    scf.while : () -> () {
      %1 = affine.load %alloc_0[] {from = "y"} : memref<i32>
      %c0_i32 = arith.constant 0 : i32
      %c0_i32_1 = arith.constant 0 : i32
      %2 = arith.cmpi sgt, %1, %c0_i32_1 : i32
      scf.condition(%2)
    } do {
      %1 = affine.load %alloc_0[] {from = "y"} : memref<i32>
      %alloc_1 = memref.alloc() {name = "temp"} : memref<i32>
      affine.store %1, %alloc_1[] {to = "temp"} : memref<i32>
      %2 = affine.load %alloc[] {from = "x"} : memref<i32>
      %3 = affine.load %alloc_0[] {from = "y"} : memref<i32>
      %4 = arith.remsi %2, %3 : i32
      affine.store %4, %a