Skip to content

Commit

Permalink
slice_scatter decomposition (#2519)
Browse files Browse the repository at this point in the history
  • Loading branch information
apbose committed May 30, 2024
1 parent 6cc61b4 commit 6152607
Show file tree
Hide file tree
Showing 2 changed files with 234 additions and 0 deletions.
39 changes: 39 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from torch._decomp import register_decomposition
from torch._ops import OpOverload
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim

from ._decomposition_groups import (
ENABLED_TORCH_DECOMPOSITIONS,
Expand Down Expand Up @@ -174,6 +175,44 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)


@register_torch_trt_decomposition(
    torch.ops.aten.slice_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def slice_scatter_decomposition(
    input_tensor: torch.Tensor,
    src_tensor: torch.Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: Optional[int] = None,
) -> torch.Tensor:
    """Decompose ``aten.slice_scatter`` into a single ``aten.scatter``.

    An int64 index tensor with the same shape as ``src_tensor`` is built whose
    entries along ``dim`` are the positions ``range(start, end, step)``;
    ``src_tensor`` is then scattered into ``input_tensor`` at those positions.

    Args:
        input_tensor: Tensor that receives the slice.
        src_tensor: Values to embed. Its size along ``dim`` is expected to
            equal the number of positions selected by the slice.
        dim: Axis being sliced; negative values count from the last axis.
        start: First position of the slice (``None`` means 0); may be negative.
        end: One-past-last position (``None`` means the size of ``dim``); may
            be negative, and values past the end of the axis are clamped.
        step: Stride of the slice (``None`` means 1). ``step == 0`` is not a
            valid torch case and is not handled here.

    Returns:
        ``src_tensor`` itself when the slice covers the whole axis,
        ``input_tensor`` itself when the slice selects nothing, otherwise the
        result of ``torch.scatter``.
    """
    # Normalize a negative axis first: the shape filter below compares axis
    # numbers by value, so dim=-1 would otherwise never match.
    dim = get_positive_dim(dim, input_tensor.dim())
    dim_size = input_tensor.shape[dim]

    # None is the declared default for start/end, so resolve it before
    # handing the value to get_positive_dim.
    start = 0 if start is None else get_positive_dim(start, dim_size)
    end = dim_size if end is None else get_positive_dim(end, dim_size)
    # Slicing clamps out-of-range bounds; without this, a very large `end`
    # (commonly used to mean "to the end") would make range() enumerate
    # positions far beyond the axis.
    start = min(start, dim_size)
    end = min(end, dim_size)
    if step is None:
        step = 1

    # Whole-axis replacement: the result is exactly the source tensor.
    if start == 0 and end == dim_size and step == 1:
        return src_tensor

    positions = range(start, end, step)
    if len(positions) == 0:
        # Nothing is selected, so nothing is overwritten; also avoids calling
        # torch.stack on an empty list.
        return input_tensor

    # Shape of one index "slab": the source shape with the sliced axis removed.
    index_tensor_shape = [
        size for axis, size in enumerate(src_tensor.shape) if axis != dim
    ]
    cat_tensors = [
        index * torch.ones(index_tensor_shape, dtype=torch.long)
        for index in positions
    ]
    # Follow the input's device instead of assuming CUDA.
    index_tensor = torch.stack(cat_tensors, dim).to(input_tensor.device)
    return torch.scatter(input_tensor, dim, index_tensor, src_tensor)


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
195 changes: 195 additions & 0 deletions tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,201 @@ def forward(self, x):
f"The optimized model results shape and torch model results shape should be equal in empty_like",
)

def test_lowering_slice_scatter_dimOne_module(self):
    """Check slice_scatter lowering along dim 1 of a 2-D tensor.

    An (8, 2) tensor of ones is written into columns 6.. of an (8, 8) tensor
    of zeros (start=6, end=None, step=1). The lowered graph must contain
    ``aten.scatter.src`` and no ``aten.slice_scatter.default``, and the
    Torch-TRT output must match the eager output.
    """

    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start=None, end=None, step=1):
            y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
            return y

    # Ops that must appear in the lowered graph once the decomposition ran
    expected_ops = {
        torch.ops.aten.scatter.src,
    }
    # Ops that must be gone after decomposition. This test exercises
    # slice_scatter (not select_scatter), and the `.default` overload is the
    # one the Torch-TRT decomposition is registered for.
    unexpected_ops = {torch.ops.aten.slice_scatter.default}

    inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1]

    fx_graph = torch.fx.symbolic_trace(sliceScatter())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        "Slice_scatter TRT outputs don't match with the original model.",
    )

def test_lowering_slice_scatter_dimZero_StepTwo_module(self):
    """Check slice_scatter lowering along dim 0 with a step of 2.

    A (2, 8) tensor of ones is written into rows 2 and 4 of an (8, 8) tensor
    of zeros (start=2, end=6, step=2). The lowered graph must contain
    ``aten.scatter.src`` and no ``aten.slice_scatter.default``, and the
    Torch-TRT output must match the eager output.
    """

    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start, end, step):
            y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
            return y

    # Ops that must appear in the lowered graph once the decomposition ran
    expected_ops = {
        torch.ops.aten.scatter.src,
    }
    # Ops that must be gone after decomposition. Name the concrete `.default`
    # overload that forward() calls rather than the overload packet.
    unexpected_ops = {torch.ops.aten.slice_scatter.default}

    inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 2, 6, 2]

    fx_graph = torch.fx.symbolic_trace(sliceScatter())

    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        "Slice_scatter TRT outputs don't match with the original model.",
    )

def test_lowering_slice_scatter_dimOne_3d_module(self):
    """Check slice_scatter lowering along dim 1 of a 3-D tensor.

    An (8, 2, 8) tensor of ones is written into positions 6.. along dim 1 of
    an (8, 8, 8) tensor of zeros (start=6, end=None, step=1). The lowered
    graph must contain ``aten.scatter.src`` and no
    ``aten.slice_scatter.default``, and the Torch-TRT output must match the
    eager output.
    """

    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start, end, step):
            y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
            return y

    # Ops that must appear in the lowered graph once the decomposition ran
    expected_ops = {
        torch.ops.aten.scatter.src,
    }
    # Ops that must be gone after decomposition. Name the concrete `.default`
    # overload that forward() calls rather than the overload packet.
    unexpected_ops = {torch.ops.aten.slice_scatter.default}

    inputs = [
        torch.zeros(8, 8, 8).cuda(),
        torch.ones(8, 2, 8).cuda(),
        1,
        6,
        None,
        1,
    ]

    fx_graph = torch.fx.symbolic_trace(sliceScatter())

    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: {unexpected_ops_seen}",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: {expected_ops_unseen}",
    )

    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        "Slice_scatter TRT outputs don't match with the original model.",
    )


if __name__ == "__main__":
run_tests()

0 comments on commit 6152607

Please sign in to comment.