diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 7a3a3c90ede..9f60581b985 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -259,63 +259,85 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceSelectWithViewOpPass(ExportPass): +class ReplaceSelectWithViewOpPass(RemoveOrReplacePassInterface): """ If the size along the select dim is 1, then the select op can be replaced by view op. """ - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.select_copy.int: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.select_copy.int] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Get the input tensor and shapes + in_tensor_node = node.args[0] + assert isinstance(in_tensor_node, torch.fx.Node) + in_shape = in_tensor_node.meta["val"].shape + out_shape = node.meta["val"].shape - # Glean the shape of input and output tensor - in_tensor = args[0].to_tensor() - in_shape = in_tensor.shape - out_shape = meta["val"].shape # Get the select dimension - select_dim = args[1] if args[1] >= 0 else args[1] + len(in_shape) + select_dim = node.args[1] + assert isinstance(select_dim, int) + select_dim = select_dim if select_dim >= 0 else select_dim + len(in_shape) if in_shape[select_dim] == 1: - # Return a view op with the new shape - view_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta - ) - return super().call_operator(op, args, kwargs, meta) + # Replace with view op with the new shape + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(node.args[0], list(out_shape)), + ) + # Important to copy metadata + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True + + return False @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceMMWithAddMMPass(ExportPass): +class ReplaceMMWithAddMMPass(RemoveOrReplacePassInterface): """ This pass replaces mm with addmm by introducing a zero bias. mm is not supported, so this is an opt_level=0 pass. 
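Reviewer note on the select-to-view rewrite above: a minimal sketch of the equivalence the pass relies on, written with core ATen/PyTorch ops rather than the edge dialect (shapes are illustrative, not taken from the diff):

```python
import torch

# When the selected dimension has size 1, select_copy only drops that
# dimension, so a view to the output shape reproduces it exactly.
x = torch.randn(4, 1, 8)

selected = torch.select_copy(x, dim=1, index=0)  # shape [4, 8]
viewed = x.view(selected.shape)                  # same data, same shape

assert torch.equal(selected, viewed)
```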
""" - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.mm.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mm.default] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # The mm op has two args: input, mat2 - assert len(args) == 2 - X, mat2 = args + assert len(node.args) == 2 + X, mat2 = node.args + assert isinstance(X, torch.fx.Node) + assert isinstance(mat2, torch.fx.Node) # Create a zero bias tensor, and insert it as a graph buffer before the # current node - mat2_tensor = mat2.to_tensor() + mat2_tensor = mat2.meta["val"] bias_size = mat2_tensor.size(1) - zero_bias = super().call_operator( - exir_ops.edge.aten.full.default, - ([bias_size], 0.0), - {"dtype": torch.float32}, - meta, - ) + + with node.graph.inserting_before(node): + zero_bias = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=([bias_size], 0.0), + kwargs={"dtype": torch.float32}, + ) + zero_bias.meta = node.meta # Replace mm with addmm new_args = (zero_bias, X, mat2) - return super().call_operator( - exir_ops.edge.aten.addmm.default, new_args, kwargs, meta - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.addmm.default, + args=new_args, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -578,28 +600,33 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplacePadWithCatPass(ExportPass): +class ReplacePadWithCatPass(RemoveOrReplacePassInterface): """ Replace constant pad nd op that does padding on outer-most dimension with Cat(left_padding_constant_tensor, X, right_padding_constant_tensor) """ - def call_operator(self, op, args, kwargs, meta): - if op != exir_ops.edge.aten.constant_pad_nd.default: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.constant_pad_nd.default] - assert len(args) >= 2 - input_node, orig_padding = args[:2] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + assert len(node.args) >= 2 + input_node, orig_padding = node.args[:2] + assert isinstance(input_node, torch.fx.Node) # if there is no padding, this op will be treated in removal pass. if not orig_padding: - return super().call_operator(op, args, kwargs, meta) + return False - value = 0 if len(args) == 2 else args[2] + value = 0 if len(node.args) == 2 else node.args[2] - arg_shape = input_node.to_tensor().shape + arg_shape = input_node.meta["val"].shape - padding = orig_padding + ([0] * (len(orig_padding) % 2 != 0)) + # Convert orig_padding to a list for manipulation + # pyre-ignore[6]: Argument type + padding_list = list(orig_padding) + padding = padding_list + ([0] * (len(padding_list) % 2 != 0)) assert len(padding) >= 2 (left_padding_size, right_padding_size) = padding[-2:] # Replace only if constant_pad_nd is along the innermost padding dimension. 
@@ -608,41 +635,47 @@ def call_operator(self, op, args, kwargs, meta): or left_padding_size < 0 or right_padding_size < 0 ): - return super().call_operator(op, args, kwargs, meta) + return False cat_tensors = [] dim = len(arg_shape) - len(padding) // 2 + graph = node.graph + # add left_padding if left_padding_size > 0: left_padding_shape = ( arg_shape[:dim] + (left_padding_size,) + arg_shape[dim + 1 :] ) - left_padding_node = super().call_operator( - exir_ops.edge.aten.full.default, - ( - left_padding_shape, - value, - ), - {"dtype": torch.float32}, - meta, - ) + with graph.inserting_before(node): + left_padding_node = graph.call_function( + exir_ops.edge.aten.full.default, + args=( + left_padding_shape, + value, + ), + kwargs={"dtype": torch.float32}, + ) + left_padding_node.meta = node.meta cat_tensors.append(left_padding_node) + # input_node cat_tensors.append(input_node) + # right_padding if right_padding_size > 0: right_padding_shape = ( arg_shape[:dim] + (right_padding_size,) + arg_shape[dim + 1 :] ) - right_padding_node = super().call_operator( - exir_ops.edge.aten.full.default, - ( - right_padding_shape, - value, - ), - {"dtype": torch.float32}, - meta, - ) + with graph.inserting_before(node): + right_padding_node = graph.call_function( + exir_ops.edge.aten.full.default, + args=( + right_padding_shape, + value, + ), + kwargs={"dtype": torch.float32}, + ) + right_padding_node.meta = node.meta cat_tensors.append(right_padding_node) assert len(cat_tensors) == 1 + (left_padding_size > 0) + ( @@ -650,12 +683,15 @@ def call_operator(self, op, args, kwargs, meta): ) new_args = (cat_tensors, dim) - return super().call_operator( - exir_ops.edge.aten.cat.default, - new_args, - kwargs, - meta, - ) + with graph.inserting_before(node): + new_node = graph.call_function( + exir_ops.edge.aten.cat.default, + args=new_args, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -818,7 +854,7 @@ def call_operator(self, op, args, kwargs, meta): @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceTrivialConvWithLinear(ExportPass): +class ReplaceTrivialConvWithLinear(RemoveOrReplacePassInterface): """ In nn.Conv1d, the operand shapes are: input - [batch, in_channels, in_length] @@ -840,30 +876,41 @@ class ReplaceTrivialConvWithLinear(ExportPass): exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, } - def call_operator(self, op, args, kwargs, meta): - if op not in self.trivial_conv_op_to_linear_op: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.trivial_conv_op_to_linear_op.keys()) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Parse the necessary args of the convolution node. Both convolution # and quantized_conv have the same first 8 args. The quantized op has # extra args holding at least the zero point and scale of input, weight, bias, # and output tensor. 
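The pad-to-cat rewrite above rests on the identity sketched below, using stock `F.pad`, `full`, and `cat` instead of the edge-dialect ops (single padded dimension, illustrative shapes):

```python
import torch

# Padding one dimension with a constant equals concatenating constant
# "full" blocks on either side of that dimension.
x = torch.randn(2, 3, 5)
left, right, value = 1, 2, 0.0

padded = torch.nn.functional.pad(x, (left, right), value=value)  # pads last dim

left_block = torch.full((2, 3, left), value)
right_block = torch.full((2, 3, right), value)
catted = torch.cat([left_block, x, right_block], dim=-1)

assert torch.equal(padded, catted)
```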
+ assert isinstance(node.target, EdgeOpOverload) quantized_op = ( - op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor - or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor + node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor + or node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor ) - assert (len(args) == 7 and not quantized_op) or ( - len(args) >= 12 and quantized_op + assert (len(node.args) == 7 and not quantized_op) or ( + len(node.args) >= 12 and quantized_op ), "Inconsistent args for convolution" - (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7] + (in_tensor, weight, bias, stride, padding, dilation, groups) = node.args[0:7] - # Glean the shapes of input, weight, and output - in_shape = in_tensor.to_tensor().shape + assert isinstance(in_tensor, torch.fx.Node) + assert isinstance(weight, torch.fx.Node) - weight_shape = weight.to_tensor().shape - out_shape = meta["val"].shape + # Glean the shapes of input, weight, and output + in_shape = in_tensor.meta["val"].shape + weight_shape = weight.meta["val"].shape + out_shape = node.meta["val"].shape assert None not in {in_shape, weight_shape, out_shape} + # pyre-ignore[6]: Argument type for iteration + stride_list = list(stride) + # pyre-ignore[6]: Argument type for iteration + padding_list = list(padding) + # pyre-ignore[6]: Argument type for iteration + dilation_list = list(dilation) + # Check the condition under which conv can be replaced by linear: (1) this # should not be a depthwise convolution; (2) the padding, stride, and dilation # should be standard; (3) The [channels, height, width] of input must match the @@ -872,37 +919,40 @@ def call_operator(self, op, args, kwargs, meta): # by linear. if ( groups != 1 - or any(x != 0 for x in padding) - or any(x != 1 for x in stride) - or any(x != 1 for x in dilation) + or any(x != 0 for x in padding_list) + or any(x != 1 for x in stride_list) + or any(x != 1 for x in dilation_list) or (list(in_shape[1:]) != list(weight_shape[1:])) ): - return super().call_operator(op, args, kwargs, meta) + return False # Reshape the weight to [out_channels, in_channels * X] K = math.prod(weight_shape[1:]) - # Weight is always a ProxyValue, so we need a view_copy operation - linear_weight = super().call_operator( - exir_ops.edge.aten.view_copy.default, - ( - weight, - [weight_shape[0], K], - ), - kwargs, - meta, - ) + graph = node.graph + + # Weight is always a Node, so we need a view_copy operation + with graph.inserting_before(node): + linear_weight = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=( + weight, + [weight_shape[0], K], + ), + ) + linear_weight.meta = node.meta # Reshape the input from 3d to 2d tensor - in_view = super().call_operator( - exir_ops.edge.aten.view_copy.default, - ( - in_tensor, - [in_shape[0], K], - ), - kwargs, - meta, - ) + with graph.inserting_before(node): + in_view = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=( + in_tensor, + [in_shape[0], K], + ), + ) + in_view.meta = node.meta + # Create the linear node, which multiplies the 2d input and weight # tensors, and adds the 1d bias to produce a 2d output. if quantized_op: @@ -912,13 +962,14 @@ def call_operator(self, op, args, kwargs, meta): bias_scale, out_scale, out_zero_point, - ) = args[7:12] + ) = node.args[7:12] # If the multiplier and shift tensors are provided, use them. 
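A quick sketch of the float-path identity ReplaceTrivialConvWithLinear exploits, in plain PyTorch with illustrative shapes (the quantized path is analogous but goes through quantized_linear):

```python
import torch

# When the kernel spatially matches the input (stride 1, no padding or
# dilation, groups == 1), conv2d collapses to a single matmul plus bias.
x = torch.randn(2, 6, 7, 7)   # [batch, in_channels, H, W]
w = torch.randn(12, 6, 7, 7)  # [out_channels, in_channels, H, W]
b = torch.randn(12)

conv_out = torch.nn.functional.conv2d(x, w, b)  # [2, 12, 1, 1]

K = w[0].numel()  # in_channels * H * W
linear_out = torch.nn.functional.linear(x.view(2, K), w.view(12, K), b)

assert torch.allclose(conv_out, linear_out.view(conv_out.shape), atol=1e-5)
```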
- if len(args) >= 14: - out_multiplier = args[12] - out_shift = args[13] + if len(node.args) >= 14: + out_multiplier = node.args[12] + out_shift = node.args[13] # If not, compute them. else: + # pyre-ignore[58]: Division operands requantize_scale = bias_scale / out_scale (out_multiplier, out_shift) = quantize_tensor_multiplier( requantize_scale @@ -936,21 +987,23 @@ def call_operator(self, op, args, kwargs, meta): ) else: linear_args = (in_view, linear_weight, bias) + with graph.inserting_before(node): + linear_res = graph.call_function( + self.trivial_conv_op_to_linear_op[cast(EdgeOpOverload, node.target)], + args=linear_args, + ) + linear_res.meta = node.meta - linear_res = super().call_operator( - self.trivial_conv_op_to_linear_op[op], - linear_args, - kwargs, - meta, - ) # Reshape the output of linear from 2d to 3d tensor - out_res = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (linear_res, list(out_shape)), - kwargs, - meta, - ) - return out_res + with graph.inserting_before(node): + out_res = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(linear_res, list(out_shape)), + ) + out_res.meta = node.meta + + node.replace_all_uses_with(out_res) + return True def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int: @@ -1089,7 +1142,7 @@ def call_operator( @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceConvWithIm2RowAndLinear(ExportPass): +class ReplaceConvWithIm2RowAndLinear(RemoveOrReplacePassInterface): """ Replace convolution where groups=1 with im2row followed by a linear op. """ @@ -1104,44 +1157,57 @@ class ReplaceConvWithIm2RowAndLinear(ExportPass): exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, } - def call_operator(self, op, args, kwargs, meta): - if op not in self.conv_op_to_linear_op: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.conv_op_to_linear_op.keys()) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the relevant args from convolution node. + assert isinstance(node.target, EdgeOpOverload) quantized_op = ( - op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor - or op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor + node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor + or node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor ) - assert (len(args) == 7 and not quantized_op) or ( - len(args) >= 12 and quantized_op + assert (len(node.args) == 7 and not quantized_op) or ( + len(node.args) >= 12 and quantized_op ), "Inconsistent args for convolution" - (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7] + (in_tensor, weight, bias, stride, padding, dilation, groups) = node.args[0:7] + + assert isinstance(in_tensor, torch.fx.Node) + assert isinstance(weight, torch.fx.Node) # We do not replace depthwise convolution with gemm yet. if groups != 1: - return super().call_operator(op, args, kwargs, meta) + return False + + weight_shape = weight.meta["val"].shape + + # pyre-ignore[6]: Argument type for iteration + stride_list = list(stride) + # pyre-ignore[6]: Argument type for iteration + padding_list = list(padding) + # pyre-ignore[6]: Argument type for iteration + dilation_list = list(dilation) - weight_shape = weight.to_tensor().shape # If this is a pointwise convolution, im2col will start dominating the # runtime. So we call convolution op for this case. 
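On the out_multiplier/out_shift computation: the sketch below shows the common gemmlowp-style fixed-point scheme such a helper typically implements. This is an assumption about the general idea, not the actual `quantize_tensor_multiplier` implementation, whose exact convention (e.g., sign of the shift) may differ:

```python
import math

def to_multiplier_and_shift(scale: float) -> tuple[int, int]:
    # Express a positive float scale as a Q31 multiplier and a power-of-two
    # shift so that scale ~= multiplier * 2**shift / 2**31.
    assert scale > 0
    mantissa, exponent = math.frexp(scale)    # scale = mantissa * 2**exponent, mantissa in [0.5, 1)
    multiplier = round(mantissa * (1 << 31))  # Q31 fixed-point mantissa
    if multiplier == (1 << 31):               # rounding overflow: renormalize
        multiplier //= 2
        exponent += 1
    return multiplier, exponent

m, s = to_multiplier_and_shift(0.0037)
assert abs(m * 2.0**s / 2**31 - 0.0037) < 1e-9
```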
if ( all(x == 1 for x in weight_shape[2:]) - and all(x == 1 for x in stride) - and all(x == 0 for x in padding) - and all(x == 1 for x in dilation) + and all(x == 1 for x in stride_list) + and all(x == 0 for x in padding_list) + and all(x == 1 for x in dilation_list) ): - return super().call_operator(op, args, kwargs, meta) + return False # Get the shapes - out_shape = meta["val"].shape + out_shape = node.meta["val"].shape assert None not in {weight_shape, out_shape} # Determine if the convolution is NCHW or NHWC. The NHWC, i.e., the # channel_last layout is specified by the channel_last arg of conv # op, which is either the last argument (15th) or implicitely False # if the op is quantized, or the last argument if not. - channel_last = op == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor + channel_last = node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor # The weight tensor is [out_channels, in_channels, X] for NCHW layout, # and [out_channels, X, in_channels] for NHWC layout. Here, X is the # kernel_width for conv1d, and X = kernel_height * kernel_width for @@ -1150,49 +1216,51 @@ def call_operator(self, op, args, kwargs, meta): # If the convolution op was quantized, we need the input tensor's # zero_point for im2row. Otherwise in_zero_point defaults to a zero # tensor. - in_zero_point = args[7] if quantized_op else 0 + in_zero_point = node.args[7] if quantized_op else 0 # im2row expects every kernel parameter to be 2d. So we extend the # parameters for conv1d by prepending their default values. - stride = ([1] + stride) if len(stride) == 1 else stride - padding = ([0] + padding) if len(padding) == 1 else padding - dilation = ([1] + dilation) if len(dilation) == 1 else dilation + stride_2d = ([1] + stride_list) if len(stride_list) == 1 else stride_list + padding_2d = ([0] + padding_list) if len(padding_list) == 1 else padding_list + dilation_2d = ([1] + dilation_list) if len(dilation_list) == 1 else dilation_list kernel_size = ([1] + kernel_size) if len(kernel_size) == 1 else kernel_size # Assert that kernel size does not have a 0 assert 0 not in kernel_size + graph = node.graph + # Create an im2row node with the input. This will create a 2d matrix of # shape [out_height*out_weight, X*in_channels]. X is as defined in the # comment above. im2row_args = ( in_tensor, kernel_size, - dilation, - padding, - stride, + dilation_2d, + padding_2d, + stride_2d, in_zero_point, channel_last, ) - im2row = super().call_operator( - exir_ops.edge.cadence.im2row.per_tensor, - im2row_args, - kwargs, - meta, - ) + with graph.inserting_before(node): + im2row = graph.call_function( + exir_ops.edge.cadence.im2row.per_tensor, + args=im2row_args, + ) + im2row.meta = node.meta # Get the product of the >2 dims of the weight K = math.prod(weight_shape[1:]) - # Weight is always a ProxyValue, so we need a view_copy operation - linear_weight = super().call_operator( - exir_ops.edge.aten.view_copy.default, - ( - weight, - [weight_shape[0], K], - ), - kwargs, - meta, - ) + # Weight is always a Node, so we need a view_copy operation + with graph.inserting_before(node): + linear_weight = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=( + weight, + [weight_shape[0], K], + ), + ) + linear_weight.meta = node.meta # Create the linear node, which multiplies the 3d input with 2d weight # tensors with bias addition. 
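The decomposition performed by ReplaceConvWithIm2RowAndLinear can be checked with stock PyTorch by substituting `F.unfold` (im2col) for `cadence.im2row.per_tensor`; the shapes below mirror the updated conv2d test in this diff:

```python
import torch

# conv2d == (im2col patches) x (flattened weight) + bias, reshaped back.
x = torch.randn(1, 2, 5, 5)
w = torch.randn(3, 2, 4, 4)
b = torch.randn(3)

conv_out = torch.nn.functional.conv2d(x, w, b)               # [1, 3, 2, 2]

patches = torch.nn.functional.unfold(x, kernel_size=(4, 4))  # [1, 2*4*4, 2*2]
flat_w = w.view(3, -1)                                       # [3, 2*4*4]
lin_out = flat_w @ patches + b.view(1, 3, 1)                 # [1, 3, 4]

assert torch.allclose(conv_out, lin_out.view(conv_out.shape), atol=1e-5)
```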
The outermost dimension of the input is @@ -1204,13 +1272,14 @@ def call_operator(self, op, args, kwargs, meta): bias_scale, out_scale, out_zero_point, - ) = args[7:12] + ) = node.args[7:12] # If the multiplier and shift tensors are provided, use them. - if len(args) >= 14: - out_multiplier = args[12] - out_shift = args[13] + if len(node.args) >= 14: + out_multiplier = node.args[12] + out_shift = node.args[13] # If not, compute them. else: + # pyre-ignore[58]: Division operands requantize_scale = bias_scale / out_scale (out_multiplier, out_shift) = quantize_tensor_multiplier( requantize_scale @@ -1228,30 +1297,36 @@ def call_operator(self, op, args, kwargs, meta): ) else: linear_args = (im2row, linear_weight, bias) - linear_res = super().call_operator( - self.conv_op_to_linear_op[op], - linear_args, - kwargs, - meta, - ) + + with graph.inserting_before(node): + linear_res = graph.call_function( + self.conv_op_to_linear_op[cast(EdgeOpOverload, node.target)], + args=linear_args, + ) + linear_res.meta = node.meta + # The output of linear is a 3D tensor. However, the output is in NHWC # layout by default, because an input vector of size X is multiplied # with the weight matrix, i.e., column values are contiguous. If the # channel_last is False, we want to transpose this output. if not channel_last: - linear_res = super().call_operator( - exir_ops.edge.aten.transpose_copy.int, - (linear_res, 1, 2), - kwargs, - meta, - ) + with graph.inserting_before(node): + linear_res = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, + args=(linear_res, 1, 2), + ) + linear_res.meta = node.meta + # And finally, we want to view the 3D output of linear op as 4D tensor - return super().call_operator( - exir_ops.edge.aten.view_copy.default, - (linear_res, list(out_shape)), - kwargs, - meta, - ) + with graph.inserting_before(node): + out_res = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(linear_res, list(out_shape)), + ) + out_res.meta = node.meta + + node.replace_all_uses_with(out_res) + return True @register_cadence_pass(CadencePassAttribute(opt_level=2)) @@ -2152,49 +2227,49 @@ def call_operator(self, op, args, kwargs, meta): @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding(ExportPass): +class ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding( + RemoveOrReplacePassInterface +): """ Replace torch.ops.quantized_decomposed.embedding_byte.dtype with torch.ops.cadence.quantized_embedding_byte """ - def call_operator( - self, - op: torch._ops.OpOverload, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - # Check if the op is the quantized_decomposed.embedding_byte.dtype - if ( - op == exir_ops.edge.quantized_decomposed.embedding_byte.default - or op == exir_ops.edge.quantized_decomposed.embedding_byte.dtype - ): - # Replace with cadence.quantized_embedding_byte - if len(args) < 6: + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.quantized_decomposed.embedding_byte.default, + exir_ops.edge.quantized_decomposed.embedding_byte.dtype, + ] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Replace with cadence.quantized_embedding_byte + if len(node.args) < 6: + raise AssertionError( + f"Expected 6 arguments for embedding_byte, got {len(node.args)}" + ) + embedding = node.args[0] + scales = node.args[1] + weight_zero_points = node.args[2] + indices = node.args[5] + + if node.target == 
exir_ops.edge.quantized_decomposed.embedding_byte.dtype: + dtype = node.kwargs.get("dtype", None) + if dtype is not None and dtype != torch.float32: raise AssertionError( - f"Expected 6 arguments for embedding_byte, got {len(args)}" + f"Unsupported output dtype for embedding_byte: {dtype}" ) - embedding = args[0] - scales = args[1] - weight_zero_points = args[2] - indices = args[5] - if op == exir_ops.edge.quantized_decomposed.embedding_byte.dtype: - dtype = kwargs.get("dtype", None) - if dtype is not None and dtype != torch.float32: - raise AssertionError( - f"Unsupported output dtype for embedding_byte: {dtype}" - ) - - new_args = (embedding, scales, weight_zero_points, indices, False) - new_kwargs = {} - return super().call_operator( + + new_args = (embedding, scales, weight_zero_points, indices, False) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( exir_ops.edge.cadence.quantized_embedding_byte.default, - new_args, - new_kwargs, - meta, + args=new_args, ) - return super().call_operator(op, args, kwargs, meta) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True class CommonReplacePasses: diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 663342b4e0e..cc891da4f46 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -779,8 +779,17 @@ def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]) -> N op=exir_ops.edge.aten.constant_pad_nd.default, args=(x, padding), ) + + gm_before = copy.deepcopy(original_gm) p = ReplacePadWithCatPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x] + validate(gm_before, graph_after_passes, inputs, "ReplacePadWithCatPass") + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.cat.default), 1, @@ -1013,8 +1022,17 @@ def test_replace_mm_with_addmm(self) -> None: op=exir_ops.edge.aten.mm.default, args=(x, y), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceMMWithAddMMPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, y] + validate(gm_before, graph_after_passes, inputs, "ReplaceMMWithAddMMPass") + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.addmm.default), @@ -1144,8 +1162,15 @@ def test_replace_conv1d_with_linear(self) -> None: args=(x, weights, bias, [1], [0], [1], 1), ) + gm_before = copy.deepcopy(original_gm) p2 = ReplaceTrivialConvWithLinear() - graph_after_passes = cast(PassResult, p2(original_gm)).graph_module + result = cast(PassResult, p2(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, weights, bias] + validate(gm_before, graph_after_passes, inputs, "ReplaceTrivialConvWithLinear") # Assert that conv1d is trivially converted to linear self.assertEqual( @@ -1164,17 +1189,24 @@ def test_replace_conv1d_with_linear(self) -> None: @torch.no_grad() def test_replace_conv2d_with_linear(self) -> None: - x = torch.randn(1, 96, 7, 7) - weights = torch.randn(192, 96, 7, 7) - 
bias = torch.randn(192) + x = torch.randn(1, 6, 7, 7) + weights = torch.randn(12, 6, 7, 7) + bias = torch.randn(12) original_gm = single_op_builder( placeholders=(x, weights, bias), op=exir_ops.edge.cadence.conv2d.default, args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1), ) + gm_before = copy.deepcopy(original_gm) p2 = ReplaceTrivialConvWithLinear() - graph_after_passes = cast(PassResult, p2(original_gm)).graph_module + result = cast(PassResult, p2(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, weights, bias] + validate(gm_before, graph_after_passes, inputs, "ReplaceTrivialConvWithLinear") # Assert that conv2d is trivially converted to linear self.assertEqual( @@ -1193,16 +1225,26 @@ def test_replace_conv2d_with_linear(self) -> None: @torch.no_grad() def test_replace_conv2d_with_im2row_and_linear(self) -> None: - x = torch.randn(1, 96, 47, 37) - weights = torch.randn(192, 96, 7, 7) - bias = torch.randn(192) + x = torch.randn(1, 2, 5, 5) + weights = torch.randn(3, 2, 4, 4) + bias = torch.randn(3) original_gm = single_op_builder( placeholders=(x, weights, bias), op=exir_ops.edge.cadence.conv2d.default, args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceConvWithIm2RowAndLinear() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, weights, bias] + validate( + gm_before, graph_after_passes, inputs, "ReplaceConvWithIm2RowAndLinear" + ) # Assert that the convolution is converted to im2row + linear self.assertEqual( @@ -1231,8 +1273,17 @@ def test_replace_select_with_view( op=exir_ops.edge.aten.select_copy.int, args=(x, dim, index), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceSelectWithViewOpPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x] + validate(gm_before, graph_after_passes, inputs, "ReplaceSelectWithViewOpPass") + # Assert that select op was replaced with view op self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 @@ -2174,8 +2225,20 @@ def test_replace_quantized_embedding( kwargs={"dtype": torch.float32} if name == "dtype" else {}, ) + gm_before = copy.deepcopy(original_gm) p = ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [embedding, scales, indices] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceTorchQuantizedEmbeddingWithCadenceQuantizedEmbedding", + ) self.assertEqual( count_node(
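The updated tests snapshot the module before the pass, assert `result.modified`, and call a shared `validate(...)` helper. That helper's implementation is not part of this diff; a minimal stand-in capturing the intent (run both graph modules on the same inputs and compare outputs) might look like:

```python
import torch

def check_numerical_equivalence(gm_before, gm_after, inputs, name, atol=1e-5):
    # Hypothetical helper: not the validate() used in the tests, just a
    # sketch of the numerical check it presumably performs.
    with torch.no_grad():
        ref = gm_before(*inputs)
        out = gm_after(*inputs)
    # Graph modules may return a single tensor or a tuple/list of tensors.
    ref_flat = ref if isinstance(ref, (tuple, list)) else (ref,)
    out_flat = out if isinstance(out, (tuple, list)) else (out,)
    for r, o in zip(ref_flat, out_flat):
        assert torch.allclose(r, o, atol=atol), f"{name}: outputs diverged"
```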