-
Notifications
You must be signed in to change notification settings - Fork 46
Open
Labels
Description
@helion.kernel
def kernel(As: list[torch.Tensor]) -> torch.Tensor:
    """Accumulate every tensor in ``As`` into one output, tile by tile.

    Minimal reproducer for indexing a Python list of tensors with a loop
    variable inside a tiled device loop.

    Args:
        As: Non-empty list of tensors. ``As[0]`` determines the output's
            shape/dtype/device via ``zeros_like``; the other entries are
            assumed to be indexable with the same tiles (not checked here).

    Returns:
        A tensor shaped like ``As[0]`` holding the elementwise sum of all
        tensors in ``As``.
    """
    out = torch.zeros_like(As[0])
    for tile in hl.tile(out.size()):
        # NOTE: the index-based loop is deliberate. `As[i]` with `i` coming
        # from `range(len(As))` is exactly the construct this reproducer
        # exercises, so do not rewrite it as `for a in As`.
        for i in range(len(As)):
            a = As[i]
            out[tile] += a[tile]
    return out
# Four random 1-D tensors of length 16, all created on DEVICE.
args = [torch.randn(16, device=DEVICE) for _ in range(4)]
# The list is wrapped in a 1-tuple, so it is presumably passed to `kernel` as its
# single positional argument `As` (confirm against code_and_output's contract).
code, result = code_and_output(kernel, (args,))
# Expected value: the plain elementwise sum of the input tensors.
torch.testing.assert_close(result, sum(args))
This doesn't work because type_prop thinks that `i` is a GridIndexType, and there is no getitem for a tensor with a GridIndexType index.