Skip to content

[torchlib] Make index_put dynamic #2263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 36 additions & 84 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4218,23 +4218,9 @@
# ]
#
# Need to transpose the result of GatherND to match this axes ordering.
first_not_none_position = reordered_positions[0] # x_None_front_m + 1
starting_position_of_none_in_back = (
advanced_indexing_rank + first_not_none_position
) # x_None_back_1
result_rank = self_rank - len(not_none_indices) + advanced_indexing_rank
perm = [
*range(
advanced_indexing_rank, starting_position_of_none_in_back
), # None_front_1...x_None_back_1
*range(advanced_indexing_rank), # 0...len(broadcasted_shape)
*range(
starting_position_of_none_in_back,
result_rank,
), # None_back_1...None_back_m
]
inverse_positions = np.argsort(reordered_positions).tolist()

Check warning on line 4221 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4221

Added line #L4221 was not covered by tests

return op.Transpose(self, perm=perm)
return op.Transpose(self, perm=inverse_positions)

Check warning on line 4223 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4223

Added line #L4223 was not covered by tests


@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
Expand Down Expand Up @@ -4324,91 +4310,57 @@
@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
def aten_index_put(
self: TReal,
indices: Sequence[INT64],
indices: Sequence[Optional[INT64]],
values: TReal,
accumulate: bool = False,
) -> TReal:
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor

See implementation of `torch.onnx.symbolic_opset11.index_put
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
"""

def _make_reshape_list_broadcastable(reshape_list, values_shape):
# Remove ones until the rank of reshape_list matches values_shape.
while len(reshape_list) > len(values_shape) and 1 in reshape_list:
reshape_list.remove(1)

# Now ensure each dimension is broadcastable:
# This is mandatory when mixing basic and advanced indexing
# Example: data((10, 3, 4)), indices([[0, 1], :, [0, 1]]) values(2, 3)
# the reshape list should be : [[2, 1], [1, 3], [2, 1]]
for i, r in enumerate(reshape_list):
if r not in (1, values_shape[i]):
value_index = values_shape.index(r)
# Swap elements
# For the example above the current reshape list is [1, 2] for last dim,
# to make it broadcastable, we swap the elements
reshape_list[value_index], reshape_list[i] = r, 1

return reshape_list
"""index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"""

# Ensure the number of indices matches the tensor rank.
self_rank = len(self.shape)
if len(indices) < self_rank:
indices = list(indices) + [None] * (self_rank - len(indices))

# Get values shape
values_shape = tuple(values.shape)

index_vectors = []
for i in range(self_rank):
if indices[i] is None:
# For a full slice along dim i, create a range index [0, self.shape[i]).
idx = op.Range(0, self.shape[i], 1)
reshape_update = self.shape[i]
else:
idx = indices[i]
reshape_update = math.prod(idx.shape)
# when Index is more than 1D, flatten it and also the values shape
# Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
# Indices -> (2*4,) and values shape (2*4, 32)
if len(idx.shape) > 1:
values_shape = (reshape_update, *values_shape[len(idx.shape) :])

# Flatten index (always working with 1D index in each dim)
idx = op.Reshape(idx, [-1])

# Create a reshape pattern: one value per index dimension,
# with the current dimension set to the update size.
reshape_list = [1] * len(indices)
reshape_list[i] = reshape_update
# 1. Reorder input tensor so that None-indexed axes are last
# This logic is identical to the aten.index implementation.
reordered_positions = sorted(range(len(indices)), key=lambda i: (indices[i] is None, i))
remaining_dims = [i for i in range(self_rank) if i not in reordered_positions]
reordered_positions.extend(remaining_dims)

Check warning on line 4326 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4324-L4326

Added lines #L4324 - L4326 were not covered by tests

# Adjust the reshape list to match the values shape.
reshape_list = _make_reshape_list_broadcastable(reshape_list, values_shape)
# Transpose the input data to group the indexed dimensions first
transposed_self = op.Transpose(self, perm=reordered_positions)

Check warning on line 4329 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4329

Added line #L4329 was not covered by tests

# Reshape and expand the index.
idx = op.Reshape(idx, reshape_list, allowzero=True)
idx = op.Expand(idx, values_shape)
# 2. Prepare indices for ScatterND
# This logic is also identical.
not_none_indices = [idx for idx in indices if idx is not None]
broadcast_shape = _shape_of_broadcast_tensors(*not_none_indices)

Check warning on line 4334 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4333-L4334

Added lines #L4333 - L4334 were not covered by tests

# Flatten the index to 1D and unsqueeze to form a column vector.
idx = op.Reshape(idx, [-1])
idx = op.Unsqueeze(idx, axes=[1])
index_vectors.append(idx)
final_index_parts = []

Check warning on line 4336 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4336

Added line #L4336 was not covered by tests
for idx in not_none_indices:
# Unsqueeze is needed to make indices broadcastable to the common shape
expanded_idx = op.Expand(idx, broadcast_shape)
final_index_parts.append(op.Unsqueeze(expanded_idx, [-1]))

Check warning on line 4340 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4339-L4340

Added lines #L4339 - L4340 were not covered by tests

# Concatenate the index vectors along axis=1 to form the final indices.
new_index = op.Concat(*index_vectors, axis=1)
final_index = op.Concat(*final_index_parts, axis=-1)

Check warning on line 4342 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4342

Added line #L4342 was not covered by tests

# Flatten values to match the indices
flat_values = op.Reshape(values, [-1])
# 3. Prepare the 'updates' tensor (values)
# The 'values' tensor must be broadcast to match the shape of the
# broadcasted indices.
expanded_values = op.Expand(values, broadcast_shape)

Check warning on line 4347 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4347

Added line #L4347 was not covered by tests
# TODO: Handle None
expanded_values = op.Transpose(expanded_values, perm=reordered_positions)

Check warning on line 4349 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4349

Added line #L4349 was not covered by tests

# 4. Perform the scatter operation
if accumulate:
result = op.ScatterND(self, new_index, flat_values, reduction="add")
scattered_data = op.ScatterND(transposed_self, final_index, expanded_values, reduction="add")

Check warning on line 4353 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4353

Added line #L4353 was not covered by tests
else:
result = op.ScatterND(self, new_index, flat_values)
scattered_data = op.ScatterND(transposed_self, final_index, expanded_values)

Check warning on line 4355 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4355

Added line #L4355 was not covered by tests

return result
# 5. Restore original dimension order
# The output of ScatterND has the same shape as the transposed input.
# We must apply an "inverse" transpose to get the final result.
inverse_positions = np.argsort(reordered_positions).tolist()
final_output = op.Transpose(scattered_data, perm=inverse_positions)

Check warning on line 4361 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4360-L4361

Added lines #L4360 - L4361 were not covered by tests

return final_output

Check warning on line 4363 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L4363

Added line #L4363 was not covered by tests

@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
Expand Down
Loading