Implement tensor slice in inductor to stop falling back for aten.index #111015

Closed · wants to merge 5 commits

Changes from 3 commits
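
For context, here is a minimal sketch (adapted from the tests added in this PR; the standalone script itself is not part of the change) of the slice-plus-tensor-index patterns that previously made inductor fall back to the eager aten.index kernel and that this PR lowers directly:

import torch

def fn(a, x, y):
    # Tensor indices mixed with slices; these shapes previously hit the
    # aten.index fallback path in inductor.
    return a[:, x, y], a[x, :, y]

a = torch.arange(3 * 4 * 5).view(3, 4, 5)
x = torch.tensor([1, 2])
y = torch.tensor([2, 3])

compiled = torch.compile(fn)
for ref, out in zip(fn(a, x, y), compiled(a, x, y)):
    torch.testing.assert_close(ref, out)
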
44 changes: 44 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -4430,6 +4430,50 @@ def fn(a):
        if self.device != "cpu":
            self.assertTrue(same(arg1, arg2))

    def test_tensor_index_slice(self):
        def fn(a):
            x = torch.tensor([1, 2], device=self.device)
            y = torch.tensor([2, 3], device=self.device)
            return [
                a[x, y],
                a[:, x, y],
                a[:, x, y, :],
                a[x, :, y],
                a[:, x, :, y, :],
            ]

        a = torch.arange(3 * 4 * 5 * 6 * 7, device=self.device).view(3, 4, 5, 6, 7)
        refs = fn(a)
        tests = torch.compile(fn)(a)
        for ref, test in zip(refs, tests):
            torch.testing.assert_close(ref, test)

    def test_tensor_index_put_slice(self):
        def fn(a, version):
            x = torch.tensor([1, 2], device=self.device, dtype=torch.int32)
            y = torch.tensor([2, 3], device=self.device, dtype=torch.int32)

            if version == 0:
                a[x, y] = torch.zeros_like(a[x, y])
            elif version == 1:
                a[:, x, y] = torch.zeros_like(a[:, x, y])
            elif version == 2:
                a[:, x, y, :] = torch.zeros_like(a[:, x, y, :])
            elif version == 3:
                a[x, :, y] = torch.zeros_like(a[x, :, y])
            elif version == 4:
                a[:, x, :, y, :] = torch.zeros_like(a[:, x, :, y, :])

            return a

        a = torch.arange(3 * 4 * 5 * 6 * 7, device=self.device, dtype=torch.int32).view(
            3, 4, 5, 6, 7
        )
        for i in range(5):
            ref = fn(torch.clone(a), i)
            test = torch.compile(fn)(torch.clone(a), i)
            torch.testing.assert_close(ref, test)
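
The version branches above cover advanced-indexing assignment with slices on either side of the tensor indices. For reference, a rough sketch (not part of the diff; the explicit op call is a presumed equivalent) of why the lowering receives None entries in indices for the sliced dimensions:

import torch

a = torch.arange(3 * 4 * 5, dtype=torch.float32).view(3, 4, 5)
x = torch.tensor([1, 2])
y = torch.tensor([2, 3])

b = a.clone()
b[:, x, y] = 0.0  # slice on dim 0, tensor indices on dims 1 and 2

c = a.clone()
# Presumed-equivalent explicit call: None stands in for the sliced dim 0.
torch.ops.aten.index_put_(c, [None, x, y], torch.tensor(0.0))
assert torch.equal(b, c)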

    def test_indirect_load_broadcast(self):
        def fn(in_ptr0, in_ptr1, in_ptr2):
            return torch.gather(in_ptr1, 0, in_ptr2) + in_ptr0
1 change: 0 additions & 1 deletion test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -159,7 +159,6 @@ def run(*ex, **kwargs):
"test_empty1_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_empty2_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_empty_strided_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_index3_dynamic_shapes": TestFailure(("cpu", "cuda")),
"test_bucketize_dynamic_shapes": TestFailure("cpu"),
"test_bucketize_default_kwargs_dynamic_shapes": TestFailure("cpu"),
"test_bucketize_int_dynamic_shapes": TestFailure("cpu"),
154 changes: 90 additions & 64 deletions torch/_inductor/lowering.py
@@ -2571,63 +2571,100 @@ def check_and_broadcast_indices(indices, device):
if x.get_device() != device:
raise NotImplementedError("Fallback when indices is on a different device")
new_indices[i] = x
output_dim = len(x.get_size())
start_offset = 0
# only support None at start or end for now
tmp = list(new_indices)
while tmp and tmp[-1] is None:
tmp.pop()
while tmp and tmp[0] is None:
tmp.pop(0)
start_offset += 1
if any((i is None) for i in tmp):
raise NotImplementedError("Fallback when None is in the middle of indices")

end_offset = output_dim + start_offset
return new_indices, start_offset, end_offset
return new_indices, valid_idxs


def index_output_size_and_inner_fn(
x_size, indices, valid_idxs, tensor_size, indices_loaders, indexed_size, x_loader
Contributor comment: I think valid_idxs is a bit of a bad name. Perhaps "tensor_indices"?

Contributor comment: In general I think the mixup of idxs and indices is a bit confusing.

):
# Note that behavior of indexing differs when there are non consecutive
# tensors. In this case, the tensor index is pulled to the beginning.
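# For example (illustrative, not part of this diff): with a of shape (3, 4, 5, 6)
# and x, y both of shape (2,), a[:, x, y] arrives here as indices == [None, x, y]
# (tensor indices at positions [1, 2], consecutive) and produces shape (3, 2, 6),
# while a[x, :, y] arrives as indices == [x, None, y] (positions [0, 2],
# non consecutive) and produces shape (2, 4, 6), with the index dims at the front.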
non_consecutive_tensors = False
for previous, current in zip(valid_idxs, valid_idxs[1:]):
if current - previous != 1:
non_consecutive_tensors = True

output_size = [x_size[i] for i, val in enumerate(indices) if val is None]
output_size = [*output_size, *x_size[len(output_size) + len(valid_idxs) :]]

first_tensor_index = valid_idxs[0]
if non_consecutive_tensors:
output_size = tensor_size + output_size
else:
output_size = (
output_size[:first_tensor_index]
+ tensor_size
+ output_size[first_tensor_index:]
)

def fn(idx):
assert len(idx) == len(output_size)
assert len(indices_loaders) == len(indexed_size)

rank = len(tensor_size)
new_index = []
first_tensor_index = valid_idxs[0]
start_offset = 0 if non_consecutive_tensors else first_tensor_index
next_idx = 0
for i in range(valid_idxs[-1] + 1):
if i == start_offset:
next_idx += rank
if indices[i] is None:
assert next_idx < len(idx)
new_index.append(idx[next_idx])
next_idx += 1
else:
loader = indices_loaders[i]
assert loader is not None
size = indexed_size[i]
new_index.append(
ops.indirect_indexing(
loader(idx[start_offset : start_offset + rank]),
size,
check=check,
)
)
new_index = [
*new_index,
*idx[next_idx:],
]
return new_index if x_loader is None else x_loader(new_index)

return output_size, fn


def index_impl(x, indices, check):
assert isinstance(indices, (list, tuple))
x_loader = x.make_loader()
indices, start_offset, end_offset = check_and_broadcast_indices(
indices, x.get_device()
)
indices, valid_idxs = check_and_broadcast_indices(indices, x.get_device())
assert len(valid_idxs) > 0, "Must have at least one valid idx"

indices_sizes = [i.get_size() for i in indices if i is not None]
indices_loaders = [i.make_loader() for i in indices if i is not None]
indices_loaders = [i.make_loader() if i is not None else None for i in indices]
# no guards on output size, all the guards are set in broadcast_tensors

output_size = list(indices_sizes[0])
# We can use the first one since they are all required to be the same size
tensor_size = list(indices[valid_idxs[0]].get_size())

x_size = x.get_size()

indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None]
if 0 in indexed_size and 0 not in output_size:
indexed_size = [x_size[i] for i in range(len(indices))]
if 0 in indexed_size and 0 not in tensor_size:
raise IndexError("index is out of bounds for dimension with size 0")

output_size = [
*x_size[:start_offset],
*output_size,
*x_size[start_offset + len(indices_loaders) :],
]

def fn(idx):
assert len(idx) == len(output_size)
assert len(indices_loaders) == len(indexed_size)
new_index = [
ops.indirect_indexing(
loader(idx[start_offset:end_offset]), size, check=check
)
for loader, size in zip(indices_loaders, indexed_size)
]
new_index = [*idx[:start_offset], *new_index, *idx[end_offset:]]
return x_loader(new_index)
output_size, inner_fn = index_output_size_and_inner_fn(
x_size,
indices,
valid_idxs,
tensor_size,
indices_loaders,
indexed_size,
x_loader,
)

return Pointwise.create(
device=x.get_device(),
dtype=x.get_dtype(),
inner_fn=fn,
inner_fn=inner_fn,
ranges=output_size,
)
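
As a rough illustration (a sketch, not the PR's code): for the pattern a[:, x, y], the pointwise kernel produced by this lowering maps every output coordinate to a single indirect load from the source tensor, roughly:

import torch

a = torch.arange(3 * 4 * 5).view(3, 4, 5)
x = torch.tensor([1, 2])
y = torch.tensor([2, 3])

# Output shape (3, 2): the leading slice keeps dim 0, the broadcast tensor
# indices contribute one dim of size 2.
out = torch.empty(3, 2, dtype=a.dtype)
for i in range(3):
    for j in range(2):
        # indirect_indexing: x and y select positions along dims 1 and 2.
        out[i, j] = a[i, x[j], y[j]]

assert torch.equal(out, a[:, x, y])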

@@ -2722,14 +2759,14 @@ def index_put_impl_(self, indices, values, accumulate, check):
return self

values = to_dtype(values, self.get_dtype())

try:
indices, start_offset, end_offset = check_and_broadcast_indices(
indices, self.get_device()
)
# Note that code will only get here when dtype is uint32
indices, valid_idxs = check_and_broadcast_indices(indices, self.get_device())
except NotImplementedError:
return index_put_fallback(self, indices, values, accumulate)
indices_sizes = [i.get_size() for i in indices if i is not None]
indices_loaders = [i.make_loader() for i in indices if i is not None]

indices_loaders = [i.make_loader() if i is not None else None for i in indices]

assert isinstance(self, TensorBox)
self.realize()
@@ -2738,34 +2775,23 @@ def index_put_impl_(self, indices, values, accumulate, check):
if x_ndim == 0:
self = view(self, [1])

output_size = list(indices_sizes[0])
expected_vals_size = [
*x_size[:start_offset],
*output_size,
*x_size[start_offset + len(indices_sizes) :],
]
indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None]
# We can use the first one since they are all required to be the same size
tensor_size = list(indices[valid_idxs[0]].get_size())
indexed_size = [x_size[i] for i in range(len(indices))]

expected_vals_size, inner_fn = index_output_size_and_inner_fn(
x_size, indices, valid_idxs, tensor_size, indices_loaders, indexed_size, None
)

values = expand(values, expected_vals_size)
# all guards are set above during broadcast_tensors and expand

def output_indexer(index):
assert len(index) == len(expected_vals_size)
new_index = [
ops.indirect_indexing(
loader(index[start_offset:end_offset]), size, check=check
)
for loader, size in zip(indices_loaders, indexed_size)
]
new_index = [*index[:start_offset], *new_index, *index[end_offset:]]
return new_index

scatter = ir.Scatter(
device=self.get_device(),
dtype=self.get_dtype(),
inner_fn=values.make_loader(),
ranges=expected_vals_size, # iter_ranges,
output_indexer=output_indexer,
output_indexer=inner_fn,
scatter_mode="atomic_add" if accumulate else None,
)
buffer = ir.ComputedBuffer(