diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index de791851db..40ec9ac843 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -174,6 +174,7 @@ aten.full, aten.repeat, aten.var_mean, + aten.select_scatter, } torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = { aten._softmax.default, diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index b11d44cc10..9e0606b8b5 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -434,6 +434,7 @@ def forward(self, x, src, dim, index): torch.ops.aten.slice.Tensor, torch.ops.aten.squeeze.dim, torch.ops.aten.cat.default, + torch.ops.aten.reshape.default, } unexpected_ops = {torch.ops.aten.select_scatter.default} @@ -496,6 +497,7 @@ def forward(self, x, src, dim, index): expected_ops = { torch.ops.aten.slice.Tensor, torch.ops.aten.squeeze.dim, + torch.ops.aten.unsqueeze.default, torch.ops.aten.cat.default, } unexpected_ops = {torch.ops.aten.select_scatter.default}