In [1]:
%env WAVE_CACHE_ON=0
%env PYTHONPATH=/home/tim/iree/build/compiler/bindings/python:/home/tim/iree/build/runtime/bindings/python
%env IREE_SAVE_TEMPS=/home/tim/iree-turbine/dump

env: WAVE_CACHE_ON=0
env: PYTHONPATH=/home/tim/iree/build/compiler/bindings/python:/home/tim/iree/build/runtime/bindings/python
env: IREE_SAVE_TEMPS=/home/tim/iree-turbine/dump


In [2]:
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.utils.general_utils import (
    get_default_scheduling_params,
)
from iree.turbine.kernel.wave.scheduling.schedule import SchedulingType
from iree.turbine.kernel.wave.compile import WaveCompileOptions, wave_compile



def test_scanop_cumsum():
    """Compile a single-tile Wave kernel that calls ``tkw.scatter_min`` and
    print the resulting MLIR.

    NOTE(review): despite the ``scanop_cumsum`` name, the kernel body below
    issues ``tkw.scatter_min`` -- there is no cumulative-sum scan here.
    Presumably the name is left over from a scan test; confirm and rename.

    NOTE(review): in the last recorded run the emitted
    ``func.func @scanop_cumsum`` body contained only ``return``, i.e. neither
    the read nor the scatter was lowered. Check whether ``scatter_min``
    lowering is implemented and whether it receives the arguments it expects.
    """
    M = tkl.sym.M
    N = tkl.sym.N
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE

    # Tile sizes are plain Python ints (not symbols), so they are used directly
    # in the constraints below and must not appear as keys in ``subs``.
    BLOCK_M = 1
    BLOCK_N = 64

    constraints: list[tkw.Constraint] = [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
            vector_shapes={M: BLOCK_M, N: BLOCK_N},
        )
    ]
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 0)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N)]

    @tkw.wave(constraints)
    def scanop_cumsum(
        out: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f16],
        # NOTE(review): indices are declared f16; an integer dtype seems more
        # likely intended for a scatter index -- confirm.
        idx: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f16],
        src: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f16],
    ):
        idx_reg = tkw.read(idx)
        tkw.scatter_min(out, idx_reg, src, dim=N)

    options = WaveCompileOptions(
        subs={
            # Concrete problem size bound to the symbolic dims.
            M: 1,
            N: 64,
            # Bound here, but not referenced by the kernel signature above,
            # which uses GLOBAL_ADDRESS_SPACE directly.
            ADDRESS_SPACE: tkl.AddressSpace.GLOBAL_MEMORY.value,
        },
        canonicalize=True,
        compile_to_mlir=True,
    )
    compiled = wave_compile(options, scanop_cumsum)
    print(compiled.asm)


test_scanop_cumsum()



#translation = #iree_codegen.translation_info<pipeline = None workgroup_size = [64, 1, 1] subgroup_size = 64>
module attributes {transform.with_named_sequence} {
  stream.executable private @scanop_cumsum {
    stream.executable.export public @scanop_cumsum workgroups() -> (index, index, index) {
      %c1 = arith.constant 1 : index
      stream.return %c1, %c1, %c1 : index, index, index
    }
    builtin.module {
      func.func @scanop_cumsum(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding) attributes {translation_info = #translation} {
        return
      }
    }
  }
  func.func @isolated_benchmark(%arg0: tensor<1x64xf16>, %arg1: tensor<1x64xf16>, %arg2: tensor<1x64xf16>) {
    flow.dispatch @scanop_cumsum::@scanop_cumsum(%arg0, %arg1, %arg2) : (tensor<1x64xf16>, tensor<1x64xf16>, tensor<1x64xf16>) -> ()
    return
  }
}

