13 changes: 13 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -162,6 +162,10 @@
"quantized_fully_connected.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)
lib.define("where_Scalar(Tensor condition, float self, float other) -> (Tensor Z)")
lib.define(
"where_Scalar.out(Tensor condition, float self, float other, *, Tensor(a!) out) -> Tensor(a!)"
)

# ------------------------------------ #
# Migrated from custom_ops.yaml #
@@ -935,3 +939,12 @@ def transposed_im2row_meta(
output_size = torch.Size((batch_size, output_length, n_output_plane))

return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::where_Scalar")
def where_Scalar_meta(
condition: torch.Tensor,
self: float,
other: float,
) -> torch.Tensor:
return condition.new_empty(condition.size(), dtype=torch.float32)
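
For reference, the eager semantics implied by the where_Scalar schema can be sketched with stock torch ops. The helper below is a hypothetical reference, not part of this PR; it mirrors the float32 output dtype that the fake kernel above advertises:

import torch

def where_scalar_reference(
    condition: torch.Tensor, self_val: float, other_val: float
) -> torch.Tensor:
    # Pick self_val where condition is True and other_val elsewhere,
    # matching the float32 output dtype of the fake kernel.
    return torch.where(
        condition,
        torch.full(condition.size(), self_val, dtype=torch.float32),
        torch.full(condition.size(), other_val, dtype=torch.float32),
    )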
49 changes: 49 additions & 0 deletions backends/cadence/aot/replace_ops.py
@@ -2062,6 +2062,54 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
return PassResult(ret.graph_module, modified)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplaceWhereWithFullArgsWithWhereScalar(ExportPass):
"""Replaces where ops using two full ops as tensors with a scalar
version.
"""

def call_operator(
self,
op,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
if op not in {
exir_ops.edge.aten.where.self,
}:
return super().call_operator(op, args, kwargs, meta)

# If the args are not full ops, bail
# pyre-ignore[16]: `ProxyValue` has no attribute `node`.
if (args[1].node.target != exir_ops.edge.aten.full.default) or (
args[2].node.target != exir_ops.edge.aten.full.default
):
return super().call_operator(op, args, kwargs, meta)

        # If one of the full ops has a different size than the cond
        # tensor, broadcasting would be needed. Bail.
if (
# pyre-ignore[16]: `ProxyValue` has no attribute `node`.
list(args[0].to_tensor().shape) != args[1].node.args[0]
or list(args[0].to_tensor().shape) != args[2].node.args[0]
):
return super().call_operator(op, args, kwargs, meta)

# Get the scalar values from the full ops
scalar_value_1 = args[1].node.args[1]
scalar_value_2 = args[2].node.args[1]

# Replace the where op with a scalar where op
return super().call_operator(
exir_ops.edge.cadence.where_Scalar.default,
(args[0], scalar_value_1, scalar_value_2),
kwargs,
meta,
)



# This class encapsulates all the functions that replace/switch one op in the
# graph with another.
class CadenceReplaceOpsInGraph:
@@ -2100,4 +2148,5 @@ class CadenceReplaceOpsInGraph:
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
ReplaceWhereWithFullArgsWithWhereScalar,
]
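
To make the rewrite concrete, here is a sketch of the pattern the new pass matches and the node it emits, written in eager-op notation (the pass itself matches edge-dialect nodes via exir_ops.edge, and cond stands for any boolean tensor whose shape equals the shape argument of both full ops):

# Before the pass: both branches of the where are constant fills.
a = torch.ops.aten.full.default([4, 8], 0.0)
b = torch.ops.aten.full.default([4, 8], 1.0)
out = torch.ops.aten.where.self(cond, a, b)

# After the pass: the fill values travel as scalars, and the full ops
# are left behind as dead code for later elimination passes.
out = torch.ops.cadence.where_Scalar(cond, 0.0, 1.0)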
84 changes: 84 additions & 0 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -44,6 +44,7 @@
ReplaceTCopyWithTransposePass,
ReplaceTransposedConvWithLinearPass,
ReplaceTrivialConvWithLinear,
ReplaceWhereWithFullArgsWithWhereScalar,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
@@ -1217,6 +1218,89 @@ def forward(self, x: torch.Tensor):
1,
)

def test_replace_aten_where_with_cadence_where_Scalar(self):
class WhereScalarModel(torch.nn.Module):
def forward(self, cond: torch.Tensor):
a = torch.ops.aten.full.default(a_shape, val1)
b = torch.ops.aten.full.default(b_shape, val2)
return torch.where(cond > 0, a, b)

cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (4, 8), (4, 8), 0.0, 1.0]
cond = torch.randn(cond_shape)

graph_module = (
export_to_edge(WhereScalarModel(), (cond,)).exported_program().graph_module
)

p = ReplaceWhereWithFullArgsWithWhereScalar()
graph_after_passes = cast(PassResult, p(graph_module)).graph_module

# Assert that aten.where op was replaced by a
# cadence.where_Scalar op
self.assertEqual(
count_node(
graph_after_passes,
exir_ops.edge.aten.where.self,
),
0,
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.cadence.where_Scalar.default),
1,
)

class WhereBroadcastModel(torch.nn.Module):
def forward(self, cond: torch.Tensor):
a = torch.ops.aten.full.default(a_shape, val1)
b = torch.ops.aten.full.default(b_shape, val2)
return torch.where(cond > 0, a, b)

        # Tensor a is bigger than cond and b
cond_shape, a_shape, b_shape, val1, val2 = [(8,), (4, 8), (8,), 0.0, 1.0]
cond = torch.randn(cond_shape)

graph_module = (
export_to_edge(WhereBroadcastModel(), (cond,))
.exported_program()
.graph_module
)

p = ReplaceWhereWithFullArgsWithWhereScalar()
graph_after_passes = cast(PassResult, p(graph_module)).graph_module

# Assert that aten.where op is still in the graph since where_Scalar does not
# support broadcast
self.assertEqual(
count_node(
graph_after_passes,
exir_ops.edge.aten.where.self,
),
1,
)

# cond tensor bigger than a and b
cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (8,), (8,), 0.0, 1.0]
cond = torch.randn(cond_shape)

graph_module = (
export_to_edge(WhereBroadcastModel(), (cond,))
.exported_program()
.graph_module
)

p = ReplaceWhereWithFullArgsWithWhereScalar()
graph_after_passes = cast(PassResult, p(graph_module)).graph_module

# Assert that aten.where op is still in the graph since where_Scalar does not
# support broadcast
self.assertEqual(
count_node(
graph_after_passes,
exir_ops.edge.aten.where.self,
),
1,
)


class TestReplaceIm2rowWithViewPass(unittest.TestCase):
def test_no_replacement_for_conv(self):