diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index 8de0af7311d..9e95460f2f5 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -43,7 +43,6 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from torch._subclasses import FakeTensor
 from torch.fx.node import Argument
 
 # A map to represent ops that:
@@ -90,11 +89,7 @@ def replace_logical_nop_where_with_where(
 
         # Get the third arg node and its input
         logical_not_node = node.args[0]
-        logical_not_input_tensor = (
-            logical_not_node.args[0].to_tensor()
-            if isinstance(logical_not_node.args[0], ProxyValue)
-            else logical_not_node.args[0]
-        )
+        logical_not_input_tensor = logical_not_node.args[0].to_tensor()
 
         # If the logical_not input is not a boolean tensor, bail.
         if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
@@ -263,7 +258,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Glean the shape of input and output tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         in_shape = in_tensor.shape
         out_shape = meta["val"].shape
         # Get the select dimension
@@ -295,7 +290,7 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Create a zero bias tensor, and insert it as a graph buffer before the
         # current node
-        mat2_tensor = mat2.to_tensor() if isinstance(mat2, ProxyValue) else mat2
+        mat2_tensor = mat2.to_tensor()
         bias_size = mat2_tensor.size(1)
         zero_bias = super().call_operator(
             exir_ops.edge.aten.full.default,
@@ -410,7 +405,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the old dim and new dim order
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         old_dims = tuple(range(in_tensor.dim()))
         new_dims = args[1]
 
@@ -488,11 +483,7 @@ def call_operator(self, op, args, kwargs, meta):
         repeats = args[1]
 
         # Glean the shapes of input tensor
-        in_shape = list(
-            in_tensor.to_tensor().shape
-            if isinstance(in_tensor, ProxyValue)
-            else in_tensor.shape
-        )
+        in_shape = list(in_tensor.to_tensor().shape)
 
         # If the size of repeats is more than the dimensionality of the tensor,
         # the output of repeat will be a higher-dimensional tensor. We reshape
@@ -793,15 +784,9 @@ def call_operator(self, op, args, kwargs, meta):
         (in_tensor, weight, bias, stride, padding, dilation, groups) = args[0:7]
 
         # Glean the shapes of input, weight, and output
-        in_shape = (
-            in_tensor.to_tensor().shape
-            if isinstance(in_tensor, ProxyValue)
-            else in_tensor.shape
-        )
+        in_shape = in_tensor.to_tensor().shape
 
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         out_shape = meta["val"].shape
         assert None not in {in_shape, weight_shape, out_shape}
 
@@ -823,26 +808,16 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Reshape the weight to [out_channels, in_channels * X]
         K = math.prod(weight_shape[1:])
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Reshape the input from 3d to 2d tensor
         in_view = super().call_operator(
@@ -865,11 +840,7 @@ def call_operator(self, op, args, kwargs, meta):
             out_zero_point,
         ) = args[7:12]
         # If the multiplier and shift tensors are provided, use them.
-        if (
-            len(args) >= 14
-            and isinstance(args[12], ProxyValue)
-            and isinstance(args[13], ProxyValue)
-        ):
+        if len(args) >= 14:
             out_multiplier = args[12]
             out_shift = args[13]
         # If not, compute them.
@@ -1073,9 +1044,7 @@ def call_operator(self, op, args, kwargs, meta):
         if groups != 1:
             return super().call_operator(op, args, kwargs, meta)
 
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         # If this is a pointwise convolution, im2col will start dominating the
         # runtime. So we call convolution op for this case.
         if (
@@ -1114,8 +1083,6 @@ def call_operator(self, op, args, kwargs, meta):
                     {"dtype": torch.int32},
                     meta,
                 )
-                if isinstance(in_tensor.to_tensor(), FakeTensor)
-                else get_zero_point(in_tensor.to_tensor())
             )
             if quantized_op
             else torch.tensor(0, dtype=torch.int32)
@@ -1151,26 +1118,16 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Get the product of the >2 dims of the weight
         K = math.prod(weight_shape[1:])
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Create the linear node, which multiplies the 3d input with 2d weight
         # tensors with bias addition. The outermost dimension of the input is
@@ -1184,11 +1141,7 @@ def call_operator(self, op, args, kwargs, meta):
             out_zero_point,
         ) = args[7:12]
         # If the multiplier and shift tensors are provided, use them.
-        if (
-            len(args) >= 14
-            and isinstance(args[12], ProxyValue)
-            and isinstance(args[13], ProxyValue)
-        ):
+        if len(args) >= 14:
             out_multiplier = args[12]
             out_shift = args[13]
         # If not, compute them.
@@ -1276,9 +1229,7 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Get the shapes
         out_shape = meta["val"].shape
-        weight_shape = (
-            weight.to_tensor().shape if isinstance(weight, ProxyValue) else weight.shape
-        )
+        weight_shape = weight.to_tensor().shape
         assert None not in {weight_shape, out_shape}
 
         # Determine if the transposed_convolution is NCHW or NHWC. The NHWC,
@@ -1332,26 +1283,16 @@ def call_operator(self, op, args, kwargs, meta):
 
         # Reshape the weight to [out_channels, in_channels * X]
         K = math.prod(weight_shape[1:])
-        # If weight is a ProxyValue, linear_weight needs to be the output of a
-        # graph operation (in this case a view_copy op) to be an explicit ProxyValue
-        # as well. If not, the view op can be done directly on the tensor.
-        linear_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.view_copy.default,
-                (
-                    weight,
-                    [weight_shape[0], K],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.contiguous().view(weight_shape[0], K)
+        # Weight is always a ProxyValue, so we need a view_copy operation
+        linear_weight = super().call_operator(
+            exir_ops.edge.aten.view_copy.default,
+            (
+                weight,
+                [weight_shape[0], K],
+            ),
+            kwargs,
+            meta,
         )
-        # From the previous check, if linear_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(linear_weight, FakeTensor):
-            linear_weight.constant = linear_weight
 
         # Create the linear node, which multiplies the 3d input with 2d weight
         # tensors with bias addition. The outermost dimension of the input is
@@ -1422,7 +1363,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the input tensor and shape
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         in_shape = in_tensor.shape
         # Get the output tensor shape
         out_shape = meta["val"].shape
@@ -1491,7 +1432,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Extract the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         leading_dims = math.prod(in_tensor.shape[:-1])
         # If the tensor is not a vector, do nothing.
         if leading_dims != 1:
@@ -1557,11 +1498,7 @@ def call_operator(self, op, args, kwargs, meta):
         return super().call_operator(
             exir_ops.edge.aten.full.default,
             (
-                (
-                    args[0].to_tensor().shape
-                    if isinstance(args[0], ProxyValue)
-                    else args[0].shape
-                ),
+                args[0].to_tensor().shape,
                 args[1],
             ),
             {},
@@ -1602,59 +1539,57 @@ class ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass(ExportPass):
     replaced_scalar_args: dict[
-        EdgeOpOverloadPacket, tuple[EdgeOpOverload, Sequence[int]]
+        EdgeOpOverload, tuple[EdgeOpOverload, Sequence[int]]
     ] = {
-        exir_ops.edge.cadence.quantized_add: (
+        exir_ops.edge.cadence.quantized_add.default: (
             exir_ops.edge.cadence.quantized_add.per_tensor,
             [1, 2, 4, 5],
         ),
-        exir_ops.edge.cadence.quantized_conv2d_nchw: (
+        exir_ops.edge.cadence.quantized_conv2d_nchw.default: (
             exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
             [8, 9, 12, 13],
         ),
-        exir_ops.edge.cadence.quantized_conv2d_nhwc: (
+        exir_ops.edge.cadence.quantized_conv2d_nhwc.default: (
             exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor,
             [8, 9, 12, 13],
        ),
-        exir_ops.edge.cadence.quantized_fully_connected: (
+        exir_ops.edge.cadence.quantized_fully_connected.default: (
             exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
             [4, 5, 6],
         ),
-        exir_ops.edge.cadence.quantized_layer_norm: (
+        exir_ops.edge.cadence.quantized_layer_norm.default: (
             exir_ops.edge.cadence.quantized_layer_norm.per_tensor,
             [1, 2],
         ),
-        exir_ops.edge.cadence.quantized_linear: (
+        exir_ops.edge.cadence.quantized_linear.default: (
             exir_ops.edge.cadence.quantized_linear.per_tensor,
             [4, 5, 6],
         ),
-        exir_ops.edge.cadence.quantized_relu: (
+        exir_ops.edge.cadence.quantized_relu.default: (
             exir_ops.edge.cadence.quantized_relu.per_tensor,
             [1, 3, 4],
         ),
-        exir_ops.edge.cadence.im2row: (
+        exir_ops.edge.cadence.im2row.default: (
             exir_ops.edge.cadence.im2row.per_tensor,
             [5],
         ),
-        exir_ops.edge.cadence.requantize: (
+        exir_ops.edge.cadence.requantize.default: (
             exir_ops.edge.cadence.requantize.per_tensor,
             [1, 2, 3, 4],
         ),
     }
 
     def call_operator(self, op, args, kwargs, meta):
-        op_edge_overload_packet = get_edge_overload_packet(op)
-
-        if op_edge_overload_packet not in self.replaced_scalar_args:
+        if op not in self.replaced_scalar_args:
             return super().call_operator(op, args, kwargs, meta)
 
         # Get all the args that need to be replaced.
-        new_op, args_to_be_replaced = self.replaced_scalar_args[op_edge_overload_packet]
+        new_op, args_to_be_replaced = self.replaced_scalar_args[op]
+
+        if op == new_op:
+            return super().call_operator(op, args, kwargs, meta)
 
         updated_args = list(args)
         for op_arg_index in args_to_be_replaced:
             arg = args[op_arg_index]
-            if not isinstance(arg, ProxyValue):
-                return super().call_operator(op, args, kwargs, meta)
-
             if not arg.is_tensor():
                 return super().call_operator(op, args, kwargs, meta)
 
@@ -1696,7 +1631,7 @@ def call_operator(self, op, args, kwargs, meta):
         # Determine if the op is avg_pool1d or avg_pool2d
         avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default
         # Get the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
 
         # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is
         # quantized, pass its zero_point tensor as arg to the custom avg_pool2d.
@@ -2062,7 +1997,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the second tensor
-        Y_tensor = Y_arg.to_tensor() if isinstance(Y_arg, ProxyValue) else Y_arg
+        Y_tensor = Y_arg.to_tensor()
         # Concretize the bias
         zero_bias = super().call_operator(
             exir_ops.edge.aten.full.default,
@@ -2071,19 +2006,14 @@ def call_operator(self, op, args, kwargs, meta):
             meta,
         )
 
-        # If the arg was a ProxyValue, insert a transpose node. Otherwise we
-        # can simply transpose the tensor inplace.
-        if isinstance(Y_arg, ProxyValue):
-            transpose_args = (Y_arg, -1, -2)
-            transpose_node = super().call_operator(
-                exir_ops.edge.aten.transpose_copy.int,
-                transpose_args,
-                {},
-                meta,
-            )
-            Y_arg_t = transpose_node
-        else:
-            Y_arg_t = Y_tensor.transpose(-1, -2)
+        # Y_arg is always a ProxyValue, so we insert a transpose node
+        transpose_args = (Y_arg, -1, -2)
+        Y_arg_t = super().call_operator(
+            exir_ops.edge.aten.transpose_copy.int,
+            transpose_args,
+            {},
+            meta,
+        )
 
         # Construct the new args, and return the transposed matmult op
         new_args = (
@@ -2178,7 +2108,7 @@ def call_operator(self, op, args, kwargs, meta):
             return super().call_operator(op, args, kwargs, meta)
 
         # Get the input tensor
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         # Permute NCHW to NHWC for computation
         in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
         in_tensor_shape = in_tensor_permuted.shape
diff --git a/backends/cadence/aot/simplify_ops.py b/backends/cadence/aot/simplify_ops.py
index bf836f09044..92c14cb0f5d 100644
--- a/backends/cadence/aot/simplify_ops.py
+++ b/backends/cadence/aot/simplify_ops.py
@@ -19,7 +19,7 @@
 from executorch.backends.cadence.aot.utils import rebind
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
-from executorch.exir.pass_base import ExportPass, ProxyValue
+from executorch.exir.pass_base import ExportPass
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -75,7 +75,7 @@ def call_operator(self, op, args, kwargs, meta):
         slice_scatter = op == exir_ops.edge.aten.slice_scatter.default
         # Parse the arguments
         # Extract the tensor to be sliced, and the slicing dimension
-        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        in_tensor = args[0].to_tensor()
         dim = args[1 + slice_scatter] if len(args) > 1 + slice_scatter else 0
         # Make dim non-negative
         dim = dim if dim >= 0 else dim + in_tensor.dim()
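
For illustration, a minimal sketch (not part of the diff) of the two patterns this change standardizes on: inside these Cadence ExportPass rewrites, tensor arguments reaching call_operator are ProxyValue wrappers, so .to_tensor() can be called unconditionally, and per-op dispatch is keyed on a concrete EdgeOpOverload (e.g. ".default") rather than the overload packet. The pass name and its handled-op set below are hypothetical.

import math

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass


class ExampleShapeGleaningPass(ExportPass):
    # Dispatch on concrete overloads, mirroring the replaced_scalar_args change
    # (exir_ops.edge.cadence.quantized_add -> ...quantized_add.default).
    _handled_ops = {exir_ops.edge.aten.view_copy.default}

    def call_operator(self, op, args, kwargs, meta):
        if op not in self._handled_ops:
            return super().call_operator(op, args, kwargs, meta)
        # args[0] is a ProxyValue, so no isinstance fallback is needed to
        # glean the input shape before deciding how to rewrite the op.
        in_shape = args[0].to_tensor().shape
        leading_dims = math.prod(in_shape[:-1])
        # A real pass would branch on leading_dims here; this sketch just
        # forwards the op unchanged.
        return super().call_operator(op, args, kwargs, meta)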