Hi Helion team,
I am learning Helion and wrote a toy kernel that implements the following indirect einsum (summed over k):
C(m,n) = val(m,k) * B(col(m,k),n)
It seems that Helion does not currently support indirect indexing when the index tensor has more than one dimension. For example, B[col[tile_m], tile_n] works fine, but B[col[tile_m, tile_k], tile_n] fails. Supporting this would be really helpful.
As a workaround, I flattened B and used flat indexing, but the code still fails. Is this a bug, or have I misunderstood the language?
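For reference, the intended computation can be sketched in eager PyTorch. This is only a minimal CPU sketch of the formula above, not Helion code; the function name indirect_einsum_reference and the small sizes are made up for illustration.

```python
import torch


def indirect_einsum_reference(col, val, B):
    """Eager reference for C(m,n) = sum_k val(m,k) * B(col(m,k),n).

    col: [M, K] integer row indices into B
    val: [M, K] weights
    B:   [R, N] table whose rows are gathered (R = K in the reproducer)
    """
    gathered = B[col]  # advanced indexing gathers rows: [M, K, N]
    return (val[:, :, None] * gathered).sum(dim=1)  # [M, N]


torch.manual_seed(0)
M, K, N = 4, 3, 5
col = torch.randint(K, (M, K))
val = torch.rand(M, K)
B = torch.rand(K, N)

C = indirect_einsum_reference(col, val, B)

# Cross-check against an explicit loop over m and k.
expected = torch.zeros(M, N)
for m in range(M):
    for k in range(K):
        expected[m] += val[m, k] * B[col[m, k]]
```

A kernel that handles the 2D index tensor correctly should match C up to floating-point tolerance.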
Reproducer:
import torch
import helion
import helion.language as hl

@helion.kernel()
def test(
    col: torch.Tensor,  # [M, K] int64
    val: torch.Tensor,  # [M, K] fp32
    B: torch.Tensor,  # [K, N] fp32
) -> torch.Tensor:  # [M, N] fp32
    M, K = col.shape
    _, N = B.shape
    out_dtype = torch.promote_types(val.dtype, B.dtype)
    C = torch.empty((M, N), dtype=out_dtype, device=B.device)
    B_flat = B.reshape(-1)  # [K*N]
    for tile_m, tile_n in hl.tile([M, N]):
        # [tile_m, tile_n]
        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
        for tile_k in hl.tile(K):
            # [tile_m, tile_k]
            cols_2d = col[tile_m, tile_k]
            # [tile_m, tile_k, tile_n]
            B_slice = hl.load(
                B_flat,
                [(cols_2d * N)[:, :, None] + tile_n.index[None, None, :]]
            )
            # [tile_m, tile_k]
            vals_2d = val[tile_m, tile_k]
            # [tile_m, tile_k, tile_n]
            contrib = vals_2d[:, :, None] * B_slice
            # [tile_m, tile_n]
            contrib = contrib.sum(dim=1)
            # [tile_m, tile_n]
            acc = acc + contrib
        C[tile_m, tile_n] = acc.to(out_dtype)
    return C

M, K, N = 128, 128, 128
col = torch.randint(K, (M, K), device="cuda")
val = torch.rand((M, K), device="cuda")
B = torch.rand((K, N), device="cuda")
test(col, val, B)
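The flat-index arithmetic used in the workaround can be checked on its own in eager PyTorch. The sketch below runs on CPU with full tensors in place of tiles; torch.arange(N) stands in for tile_n.index, on the assumption that it holds the column offsets of the tile. It shows the shape the comments in the reproducer expect for B_slice.

```python
import torch

torch.manual_seed(0)
M, K, N = 4, 3, 5
col = torch.randint(K, (M, K))  # [M, K] row indices into B
B = torch.rand(K, N)            # [K, N]
B_flat = B.reshape(-1)          # [K*N], row-major

# Row r, column n of B lives at offset r * N + n in B_flat.
n_idx = torch.arange(N)  # assumed stand-in for tile_n.index
flat_idx = (col * N)[:, :, None] + n_idx[None, None, :]  # [M, K, N]

B_slice = B_flat[flat_idx]  # gathered values, [M, K, N]
```

In eager mode B_slice equals B[col], so the index expression itself appears sound; the open question is how hl.load treats a single 3D index tensor.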
Error message:
Traceback (most recent call last):
  File "/home/jaeyeon/test.py", line 48, in <module>
    test(col,val,B)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/runtime/kernel.py", line 285, in __call__
    return self.bind(args)(*args)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/runtime/kernel.py", line 168, in bind
    bound_kernel = BoundKernel(self, args)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/runtime/kernel.py", line 351, in __init__
    self.host_function: HostFunction = HostFunction(
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/host_function.py", line 113, in __init__
    propagate_types(self)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2291, in propagate_types
    prop.visit(stmt)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1589, in visit
    type_info = visitor(node)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2128, in visit_For
    body = self._loop_body(node.body)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2092, in _loop_body
    self.visit(stmt)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1589, in visit
    type_info = visitor(node)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2128, in visit_For
    body = self._loop_body(node.body)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2092, in _loop_body
    self.visit(stmt)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1589, in visit
    type_info = visitor(node)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 2016, in visit_Assign
    type_info = self.visit(node.value)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1589, in visit
    type_info = visitor(node)
  File "/home/jaeyeon/miniconda3/envs/helion/lib/python3.10/site-packages/helion/_compiler/type_propagation.py", line 1881, in visit_BinOp
    raise exc.TorchOpTracingError(e) from e
helion.exc.TorchOpTracingError: RuntimeError: The size of tensor a (u4) must match the size of tensor b (u6) at non-singleton dimension 1)
While processing:
  File "/home/jaeyeon/test.py", line 36, in test
    acc = acc + contrib
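The wrapped RuntimeError is PyTorch's ordinary broadcasting failure; u4 and u6 look like symbolic sizes rather than concrete numbers. The sketch below only reproduces the same class of message in eager PyTorch with made-up concrete sizes. It does not determine which operand in the kernel received an unexpected shape.

```python
import torch

acc = torch.zeros(4, 8)       # stands in for a [tile_m, tile_n] accumulator
contrib = torch.zeros(4, 16)  # dimension 1 differs from acc and is not 1

message = None
try:
    acc + contrib  # broadcasting cannot reconcile 8 with 16
except RuntimeError as e:
    message = str(e)
    print(message)
```

Since acc is [tile_m, tile_n], a mismatch at dimension 1 suggests contrib did not end up with tile_n as its last dimension during type propagation, which would point at the shape inferred for the hl.load result; this is a guess from the message, not a confirmed diagnosis.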