diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 605a8bb0321..9e604ae42aa 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -162,6 +162,10 @@ "quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" ) +lib.define("where_Scalar(Tensor condition, float self, float other) -> (Tensor Z)") +lib.define( + "where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)" +) # ------------------------------------ # # Migrated from custom_ops.yaml # @@ -935,3 +939,12 @@ def transposed_im2row_meta( output_size = torch.Size((batch_size, output_length, n_output_plane)) return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::where_Scalar") +def where_Scalar_meta( + condition: torch.Tensor, + self: float, + other: float, +) -> torch.Tensor: + return condition.new_empty(condition.size(), dtype=torch.float32) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index bda565a5904..050b27818d4 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2062,6 +2062,54 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: return PassResult(ret.graph_module, modified) +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceWhereWithFullArgsWithWhereScalar(ExportPass): + """Replaces where ops using two full ops as tensors with a scalar + version. 
+ """ + def call_operator( + self, + op, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op not in { + exir_ops.edge.aten.where.self, + }: + return super().call_operator(op, args, kwargs, meta) + + # If the args are not full ops, bail + # pyre-ignore[16]: `ProxyValue` has no attribute `node`. + if (args[1].node.target != exir_ops.edge.aten.full.default) or ( + args[2].node.target != exir_ops.edge.aten.full.default + ): + return super().call_operator(op, args, kwargs, meta) + + # If one of the full ops is a different size than the cond tensor, we need to broadcast. Bail. + if ( + # pyre-ignore[16]: `ProxyValue` has no attribute `node`. + list(args[0].to_tensor().shape) != args[1].node.args[0] + or list(args[0].to_tensor().shape) != args[2].node.args[0] + ): + return super().call_operator(op, args, kwargs, meta) + + # Get the scalar values from the full ops + scalar_value_1 = args[1].node.args[1] + scalar_value_2 = args[2].node.args[1] + + # Replace the where op with a scalar where op + return super().call_operator( + exir_ops.edge.cadence.where_Scalar.default, + (args[0], scalar_value_1, scalar_value_2), + kwargs, + meta, + ) + + return super().call_operator(op, args, kwargs, meta) + + # This class encapsulates all the functions that replace/switch one op in the # graph with another. 
class CadenceReplaceOpsInGraph: @@ -2100,4 +2148,5 @@ class CadenceReplaceOpsInGraph: ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, ReplaceAtenAvgPoolWithJarvisAvgPoolPass, ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass, + ReplaceWhereWithFullArgsWithWhereScalar, ] diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 69a4d552a18..e40c26c0f4e 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -44,6 +44,7 @@ ReplaceTCopyWithTransposePass, ReplaceTransposedConvWithLinearPass, ReplaceTrivialConvWithLinear, + ReplaceWhereWithFullArgsWithWhereScalar, ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -1217,6 +1218,89 @@ def forward(self, x: torch.Tensor): 1, ) + def test_replace_aten_where_with_cadence_where_Scalar(self): + class WhereScalarModel(torch.nn.Module): + def forward(self, cond: torch.Tensor): + a = torch.ops.aten.full.default(a_shape, val1) + b = torch.ops.aten.full.default(b_shape, val2) + return torch.where(cond > 0, a, b) + + cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (4, 8), (4, 8), 0.0, 1.0] + cond = torch.randn(cond_shape) + + graph_module = ( + export_to_edge(WhereScalarModel(), (cond,)).exported_program().graph_module + ) + + p = ReplaceWhereWithFullArgsWithWhereScalar() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that aten.where op was replaced by a + # cadence.where_Scalar op + self.assertEqual( + count_node( + graph_after_passes, + exir_ops.edge.aten.where.self, + ), + 0, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.where_Scalar.default), + 1, + ) + + class WhereBroadcastModel(torch.nn.Module): + def forward(self, cond: torch.Tensor): + a = torch.ops.aten.full.default(a_shape, val1) + b = torch.ops.aten.full.default(b_shape, 
val2) + return torch.where(cond > 0, a, b) + + # a tensor bigger than cond and b + cond_shape, a_shape, b_shape, val1, val2 = [(8,), (4, 8), (8,), 0.0, 1.0] + cond = torch.randn(cond_shape) + + graph_module = ( + export_to_edge(WhereBroadcastModel(), (cond,)) + .exported_program() + .graph_module + ) + + p = ReplaceWhereWithFullArgsWithWhereScalar() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that aten.where op is still in the graph since where_Scalar does not + # support broadcast + self.assertEqual( + count_node( + graph_after_passes, + exir_ops.edge.aten.where.self, + ), + 1, + ) + + # cond tensor bigger than a and b + cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (8,), (8,), 0.0, 1.0] + cond = torch.randn(cond_shape) + + graph_module = ( + export_to_edge(WhereBroadcastModel(), (cond,)) + .exported_program() + .graph_module + ) + + p = ReplaceWhereWithFullArgsWithWhereScalar() + graph_after_passes = cast(PassResult, p(graph_module)).graph_module + + # Assert that aten.where op is still in the graph since where_Scalar does not + # support broadcast + self.assertEqual( + count_node( + graph_after_passes, + exir_ops.edge.aten.where.self, + ), + 1, + ) + class TestReplaceIm2rowWithViewPass(unittest.TestCase): def test_no_replacement_for_conv(self):