Commit ee8330a: select_scatter decomp
Changing lowering of select_scatter

select_scatter changes

select_scatter changes

Test case for select_scatter

removing assertion

adding select_scatter decomp lowering ops in test

implement select_scatter using slice_scatter

adding test case

linting commit fix
apbose committed May 31, 2024
1 parent 6152607 commit ee8330a
Showing 2 changed files with 205 additions and 3 deletions.
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -213,6 +213,19 @@ def slice_scatter_decomposition(
return output_tensor


@register_torch_trt_decomposition(
    torch.ops.aten.select_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
)
def select_scatter_decomposition(
input_tensor: torch.Tensor,
src_tensor: torch.Tensor,
dim: int,
index: int,
) -> torch.Tensor:
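    # Unsqueeze src along dim so it occupies exactly the slice [index, index + 1) of input.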
src_tensor = torch.unsqueeze(src_tensor, dim)
return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
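For reference, the decomposition added above rewrites select_scatter(input, src, dim, index) as an unsqueeze of src along dim followed by a one-element slice_scatter. Below is a minimal eager-mode sketch of that equivalence; it is not part of the commit, and the tensor shapes are illustrative only:

```python
import torch

# Illustrative shapes: embed a (2, 4) src into a (2, 3, 4) input along dim=1.
x = torch.zeros(2, 3, 4)
src = torch.ones(2, 4)
dim, index = 1, 0

# Reference semantics of aten.select_scatter.
expected = torch.select_scatter(x, src, dim, index)

# Decomposed form: unsqueeze src, then write it back as a length-1 slice.
decomposed = torch.slice_scatter(
    x, torch.unsqueeze(src, dim), dim, index, index + 1, 1
)

assert torch.equal(expected, decomposed)
```

In the traced graph, the existing slice_scatter decomposition (defined earlier in this file) then lowers the resulting slice_scatter further, which is consistent with the tests below expecting scatter.src and unsqueeze.default to remain while select_scatter.default and slice_scatter.default are decomposed away.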
195 changes: 192 additions & 3 deletions tests/py/dynamo/lowering/test_decompositions.py
@@ -530,7 +530,7 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
"torch_compile",
inputs,
min_block_size=1,
-truncate_long_and_double=True,
+truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -593,7 +593,7 @@ def forward(self, x, src, dim, start, end, step):
"torch_compile",
inputs,
min_block_size=1,
-truncate_long_and_double=True,
+truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -663,7 +663,7 @@ def forward(self, x, src, dim, start, end, step):
"torch_compile",
inputs,
min_block_size=1,
-truncate_long_and_double=True,
+truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -679,6 +679,195 @@ def forward(self, x, src, dim, start, end, step):
f"Slice_scatter TRT outputs don't match with the original model.",
)

def test_lowering_select_scatter_dimZero_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src, dim, index):
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
return y

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}

inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0]

fx_graph = torch.fx.symbolic_trace(selectScatter())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)

def test_lowering_select_scatter_dimOne_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src, dim, index):
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
return y

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}

inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]

fx_graph = torch.fx.symbolic_trace(selectScatter())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)

def test_lowering_select_scatter_multidimension_module(self):
class selectScatter(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, x, src, dim, index):
y = torch.ops.aten.select_scatter.default(x, src, dim, index)
return y

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
unexpected_ops = {
torch.ops.aten.select_scatter.default,
torch.ops.aten.slice_scatter.default,
}

inputs = [torch.zeros(2, 3, 4).cuda(), torch.ones(2, 4).cuda(), 1, 0]

fx_graph = torch.fx.symbolic_trace(selectScatter())
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
truncate_double=True,
pass_through_build_failures=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
torch_model_results = fx_graph(*inputs).detach().cpu()

max_diff = float(
torch.max(torch.abs(optimized_model_results - torch_model_results))
)
self.assertAlmostEqual(
max_diff,
0,
DECIMALS_OF_AGREEMENT,
f"Select_scatter TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
