
Device (non-grid) for loops cannot subscript #598

@oulgen

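```python
import torch

import helion
import helion.language as hl
from helion._testing import DEVICE, code_and_output


@helion.kernel
def kernel(As: list[torch.Tensor]) -> torch.Tensor:
    out = torch.zeros_like(As[0])
    for tile in hl.tile(out.size()):  # grid loop
        for i in range(len(As)):  # device (non-grid) loop
            a = As[i]  # fails: `i` is inferred as a GridIndexType
            out[tile] += a[tile]
    return out


args = [torch.randn(16, device=DEVICE) for _ in range(4)]
code, result = code_and_output(kernel, (args,))
torch.testing.assert_close(result, sum(args))
```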

This doesn't work because type propagation (`type_prop`) infers `i`, the index of the inner device loop, as a `GridIndexType`, and there is no `getitem` rule for subscripting a tensor with a `GridIndexType`.
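To make the grid/device distinction concrete, here is a toy sketch. It is not Helion's `type_prop` code; it only classifies loop variables by walking the kernel's AST, under the assumed rule that an `hl.tile(...)`/`hl.grid(...)` loop not nested in another loop defines the grid and every other loop (including a plain `range()` loop inside it) is a device loop. The names `classify_loop_vars` and `GRID_FUNCS` are hypothetical.

```python
import ast
import textwrap

# Hypothetical sketch, NOT Helion's type_prop: label each `for` loop
# variable in a kernel's source as a grid index or a device-loop index.
GRID_FUNCS = {"tile", "grid"}  # assumed: hl.tile(...) / hl.grid(...) define the grid


def _is_hl_grid_call(node: ast.expr) -> bool:
    """True for calls shaped like hl.tile(...) or hl.grid(...)."""
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Attribute)
        and isinstance(node.func.value, ast.Name)
        and node.func.value.id == "hl"
        and node.func.attr in GRID_FUNCS
    )


def classify_loop_vars(src: str) -> dict[str, str]:
    """Map each simple loop variable name to "grid" or "device"."""
    fn = ast.parse(textwrap.dedent(src)).body[0]
    kinds: dict[str, str] = {}

    def visit(stmts: list[ast.stmt], loop_depth: int) -> None:
        for stmt in stmts:
            if isinstance(stmt, ast.For):
                # Assumed rule: only a non-nested hl.tile/hl.grid loop is a
                # grid loop; every other loop is a device loop.
                is_grid = loop_depth == 0 and _is_hl_grid_call(stmt.iter)
                if isinstance(stmt.target, ast.Name):
                    kinds[stmt.target.id] = "grid" if is_grid else "device"
                visit(stmt.body, loop_depth + 1)
            else:
                # Simplification: descend into other compound statements
                # (if, with, while, ...) without changing the loop depth;
                # else branches are ignored.
                body = getattr(stmt, "body", None)
                if isinstance(body, list):
                    visit(body, loop_depth)

    visit(fn.body, 0)
    return kinds


# The repro kernel from this issue, analyzed as source text only.
KERNEL_SRC = '''
@helion.kernel
def kernel(As: list[torch.Tensor]) -> torch.Tensor:
    out = torch.zeros_like(As[0])
    for tile in hl.tile(out.size()):
        for i in range(len(As)):
            a = As[i]
            out[tile] += a[tile]
    return out
'''

print(classify_loop_vars(KERNEL_SRC))  # {'tile': 'grid', 'i': 'device'}
```

Under this assumed rule, `i` is a device-loop index rather than a grid index, so `As[i]` would not need a `GridIndexType`-specific `getitem`.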
