
Helion indirect indexing with higher dimensions #546

@nullplay


Hi Helion team,

While learning Helion, I wrote a toy kernel implementing the following indirect einsum:

C(m,n) = sum_k val(m,k) * B(col(m,k),n)
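For reference, the einsum can also be written directly in eager PyTorch, which is handy for checking a kernel's output. This is a minimal sketch that does not use Helion; the helper name indirect_einsum_ref is hypothetical.

```python
import torch

def indirect_einsum_ref(col: torch.Tensor, val: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Hypothetical eager reference for C(m, n) = sum_k val(m, k) * B(col(m, k), n).
    # B[col] gathers one row of B per (m, k) pair: shape [M, K, N].
    gathered = B[col]
    # Weight each gathered row by val(m, k) and reduce over k: shape [M, N].
    return torch.einsum("mk,mkn->mn", val, gathered)

# Small example on CPU.
col = torch.tensor([[0, 2], [1, 1]])
val = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
C = indirect_einsum_ref(col, val, B)
print(C)  # tensor([[  7.,  70.], [ 14., 140.]])
```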

It seems that Helion doesn't currently support indirect indexing with a multi-dimensional index tensor, i.e. indexing whose result has more than two dimensions.
For example, B[col[tile_m], tile_n] works fine, but B[col[tile_m, tile_k], tile_n], which should produce a [tile_m, tile_k, tile_n] result, fails. Supporting this would be really helpful.
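In eager PyTorch, a 2-D index tensor on the first dimension of B yields a 3-D result, which is the [tile_m, tile_k, tile_n] shape the reproducer below expects. The sketch uses small CPU tensors with arbitrary sizes and only illustrates PyTorch's own advanced-indexing rules; that Helion intends to follow the same rules here is an assumption.

```python
import torch

M, K, N = 4, 3, 5
B = torch.arange(K * N, dtype=torch.float32).reshape(K, N)  # [K, N]
col = torch.randint(K, (M, K))                              # [M, K] row indices into B

# 1-D index (like col[tile_m]): one gathered row of B per m -> [M, N]
rows_1d = B[col[:, 0]]
# 2-D index (like col[tile_m, tile_k]): one gathered row per (m, k) -> [M, K, N]
rows_2d = B[col]
# Adding a slice on the last dim (like tile_n) keeps the gathered dims first -> [M, K, 3]
rows_2d_sliced = B[col, 1:4]
print(rows_1d.shape, rows_2d.shape, rows_2d_sliced.shape)
# torch.Size([4, 5]) torch.Size([4, 3, 5]) torch.Size([4, 3, 3])
```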

As a workaround, I flattened B and used flat indexing, but the code still fails. Is this a bug, or have I misunderstood something about the language?
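The flat-index workaround relies on row-major layout: element (r, c) of a contiguous [K, N] tensor sits at offset r * N + c of its flattened view. The sketch below checks in eager PyTorch (CPU, arbitrary sizes) that the index expression used in the reproducer gathers the same values as direct 2-D indexing; n_idx is a hypothetical stand-in for tile_n.index.

```python
import torch

M, K, N = 4, 6, 8
B = torch.rand((K, N))
col = torch.randint(K, (M, K))   # [M, K] row indices into B
B_flat = B.reshape(-1)           # [K*N]; element (r, c) is at r * N + c

n_idx = torch.arange(2, 7)       # hypothetical stand-in for tile_n.index
# Same expression as the reproducer: [M, K, 1] + [1, 1, 5] broadcasts to [M, K, 5]
flat_idx = (col * N)[:, :, None] + n_idx[None, None, :]
gathered_flat = B_flat[flat_idx]
# Direct 2-D indexing for comparison: gather rows, then take columns 2..6
gathered_2d = B[col][:, :, 2:7]
assert torch.equal(gathered_flat, gathered_2d)
```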

Reproducer:

import torch
import helion
import helion.language as hl

@helion.kernel()
def test(
    col: torch.Tensor,   # [M, K] int64
    val: torch.Tensor,   # [M, K] fp32
    B: torch.Tensor,     # [K, N] fp32
) -> torch.Tensor:       # [M, N] fp32
    M, K = col.shape
    _, N = B.shape
    out_dtype = torch.promote_types(val.dtype, B.dtype)
    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
    B_flat = B.reshape(-1)  # [K*N]

    for tile_m, tile_n in hl.tile([M, N]):
        # [tile_m, tile_n]
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)

        for tile_k in hl.tile(K):
            # [tile_m, tile_k]
            cols_2d = col[tile_m, tile_k]
            # [tile_m, tile_k, tile_n]
            B_slice = hl.load(
                B_flat,
                [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]]
            )
            # [tile_m, tile_k]
            vals_2d = val[tile_m, tile_k]
            # [tile_m, tile_k, tile_n]
            contrib = vals_2d[:, :, None] * B_slice
            # [tile_m, tile_n]
            contrib = contrib.sum(dim=1)
            # [tile_m, tile_n]
            acc = acc + contrib

        C[tile_m, tile_n] = acc.to(out_dtype)

    return C


M,K,N = 128,128,128
col = torch.randint(K, (M,K), device="cuda")
val = torch.rand((M,K), device="cuda")
B = torch.rand((K,N), device="cuda")

test(col,val,B)

Error message

Traceback (most recent call last):
  File "/home/jaeyeon/test.py", line 48, in <module>
    test(col,val,B)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/runtime/kernel.py", line 285, in __call__
    return self.bind(args)(*args)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/runtime/kernel.py", line 168, in bind
    bound_kernel = BoundKernel(self, args)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/runtime/kernel.py", line 351, in __init__
    self.host_function: HostFunction = HostFunction(
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/host_function.py", line 113, in __init__
    propagate_types(self)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2291, in propagate_types
    prop.visit(stmt)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1589, in visit
    type_info = visitor(node)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2128, in visit_For
    body = self._loop_body(node.body)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2092, in _loop_body
    self.visit(stmt)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1589, in visit
    type_info = visitor(node)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2128, in visit_For
    body = self._loop_body(node.body)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2092, in _loop_body
    self.visit(stmt)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1589, in visit
    type_info = visitor(node)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2016, in visit_Assign
    type_info = self.visit(node.value)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1589, in visit
    type_info = visitor(node)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1881, in visit_BinOp
    raise exc.TorchOpTracingError(e) from e
helion.exc.TorchOpTracingError: RuntimeError: The size of tensor a (u4) must match the size of tensor b (u6) at non-singleton dimension 1)
While processing:
  File "/home/jaeyeon/test.py", line 36, in test
    acc = acc + contrib
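As a sanity check, the same tiled loop can be emulated in eager PyTorch, with explicit slices standing in for tile_m, tile_n and tile_k. The function name tiled_eager and the block sizes are hypothetical choices, and sizes are deliberately not multiples of the blocks so that edge tiles are exercised. Here acc + contrib is well-defined and the result matches a direct einsum, which suggests the indexing arithmetic is sound; it says nothing about where in Helion's type propagation the u4/u6 mismatch arises.

```python
import torch

def tiled_eager(col, val, B, block_m=32, block_n=32, block_k=16):
    # Hypothetical eager emulation of the reproducer's loops; block sizes are arbitrary.
    M, K = col.shape
    _, N = B.shape
    out_dtype = torch.promote_types(val.dtype, B.dtype)
    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
    B_flat = B.reshape(-1)  # [K*N], row-major
    for m0 in range(0, M, block_m):
        tile_m = slice(m0, min(m0 + block_m, M))
        for n0 in range(0, N, block_n):
            tile_n = slice(n0, min(n0 + block_n, N))
            # Stands in for tile_n.index
            n_index = torch.arange(tile_n.start, tile_n.stop, device=B.device)
            acc = torch.zeros((tile_m.stop - tile_m.start, len(n_index)),
                              dtype=torch.float32, device=B.device)  # [bm, bn]
            for k0 in range(0, K, block_k):
                tile_k = slice(k0, min(k0 + block_k, K))
                cols_2d = col[tile_m, tile_k]                          # [bm, bk]
                flat_idx = (cols_2d * N)[:, :, None] + n_index[None, None, :]
                B_slice = B_flat[flat_idx]                             # [bm, bk, bn]
                vals_2d = val[tile_m, tile_k]                          # [bm, bk]
                contrib = (vals_2d[:, :, None] * B_slice).sum(dim=1)   # [bm, bn]
                acc = acc + contrib                                    # shapes match here
            C[tile_m, tile_n] = acc.to(out_dtype)
    return C

torch.manual_seed(0)
M, K, N = 50, 40, 70
col = torch.randint(K, (M, K))
val = torch.rand((M, K))
B = torch.rand((K, N))
C = tiled_eager(col, val, B)
```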
