Skip to content

Commit

Permalink
implement select_scatter using slice_scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed Apr 5, 2024
1 parent c4ff602 commit cce29b7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
6 changes: 2 additions & 4 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,8 @@ def select_scatter_decomposition(
dim: int,
index: int,
) -> torch.Tensor:
unbind_tensors = torch.unbind(input_tensor, dim)
unbind_tensors_list = list(unbind_tensors)
unbind_tensors_list[index] = src_tensor
return torch.stack(tuple(unbind_tensors_list), dim)
src_tensor = torch.unsqueeze(src_tensor, dim)
return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)


def get_decompositions(
Expand Down
22 changes: 10 additions & 12 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,13 +430,11 @@ def forward(self, x, src, dim, index):
return y

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {
torch.ops.aten.slice.Tensor,
torch.ops.aten.squeeze.dim,
torch.ops.aten.cat.default,
torch.ops.aten.reshape.default,
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}
unexpected_ops = {torch.ops.aten.select_scatter.default}

inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0]

Expand Down Expand Up @@ -469,6 +467,7 @@ def forward(self, x, src, dim, index):
"torch_compile",
inputs,
min_block_size=1,
truncate_long_and_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
Expand All @@ -494,13 +493,11 @@ def forward(self, x, src, dim, index):
return y

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {
torch.ops.aten.slice.Tensor,
torch.ops.aten.squeeze.dim,
torch.ops.aten.unsqueeze.default,
torch.ops.aten.cat.default,
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}
unexpected_ops = {torch.ops.aten.select_scatter.default}

inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]

Expand Down Expand Up @@ -533,6 +530,7 @@ def forward(self, x, src, dim, index):
"torch_compile",
inputs,
min_block_size=1,
truncate_long_and_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
Expand Down

0 comments on commit cce29b7

Please sign in to comment.