diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py index c3c42ed483a..23481894f0d 100644 --- a/backends/qualcomm/builders/op_index_put.py +++ b/backends/qualcomm/builders/op_index_put.py @@ -1,14 +1,19 @@ import warnings +from collections import OrderedDict from typing import Dict import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import numpy as np import torch -from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS +from executorch.backends.qualcomm.utils.constants import ( + QCOM_DATA, + QCOM_DTYPE, + QCOM_QUANT_ATTRS, +) from executorch.exir.dialects._ops import ops as exir_ops -from .node_visitor import NodeVisitor, QNN_TENSOR_TYPE_MAP +from .node_visitor import NodeVisitor, QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP from .node_visitor_manager import register_node_visitor from .qnn_constants import ( OpConcat, @@ -26,7 +31,7 @@ class IndexPutVisitor(NodeVisitor): def __init__(self, *args) -> None: super().__init__(*args) - def define_node( + def define_node( # noqa: C901 self, node: torch.fx.Node, nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], @@ -37,6 +42,7 @@ def define_node( if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): quant_attrs = quant_attrs.copy() input_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + input_tensor = self.get_tensor(input_node, node) input_tensor_wrapper = self.define_tensor( input_node, @@ -46,52 +52,110 @@ def define_node( nodes_to_wrappers, ) - indicies_node = node.args[1] - index_node_dim = None - index_nodes = [] - index_tensors = [] + indices_nodes = ( + node.args[1] if isinstance(node.args[1], list) else [node.args[1]] + ) target_index = [] + all_range_index = OrderedDict() + index_dtype = [ + node.meta["val"].dtype for node in indices_nodes if node is not None + ][0] + + # preprocess: + # - broadcast dimension for multiple specified index + # - broadcast specified index if dimensions are not matched + 
max_indices_in_specified_index = 0 + for index, idx_node in enumerate(indices_nodes): + if isinstance(idx_node, torch.fx.Node): + last_specified_index_node = index + if max_indices_in_specified_index < idx_node.meta["val"].nelement(): + max_indices_in_specified_index = idx_node.meta["val"].nelement() # If there is None in a list, it means all range at that dimension - # E.g., indicies_node: [None, None, aten__to_copy_default_1] - if isinstance(indicies_node, list): - for index, idx_node in enumerate(indicies_node): - # First, collect the indice_node and index of None to construct the shape of index node - # E.g., shape of input: [1, 1024, 12, 64] - # For "None" axis (assume indicies_node: [None, None, aten__to_copy_default_1]), - # target_index: [1, 1024, x], x is the shape of index_tensor, index_node_dim: 2 - if isinstance(idx_node, torch.fx.Node): - index_nodes.append(idx_node) - index_tensors.append(self.get_tensor(idx_node, idx_node)) - target_index.extend(index_tensors[-1].size()) - index_node_dim = index - elif idx_node is None and index_node_dim is None: - # E.g., indicies_node: [None, aten__to_copy_default_1, None] - # Don't need to consider "None" after index_node. - target_index.append(input_tensor.size(index)) - else: - warnings.warn( - f"[QNN Delegate Op Builder]: Get the index {idx_node} that is neither a node nor None", - stacklevel=1, + for index, idx_node in enumerate(indices_nodes): + # First, collect the index_node and index of None to construct the shape of index node + # E.g., shape of input: [1, 1024, 12, 64] + # For "None" axis (assume indices_node: [None, None, aten__to_copy_default_1]), + # target_index: [1, 1024, x], x is the shape of index_tensor, index_node_dim: 2 + if isinstance(idx_node, torch.fx.Node): + # e.g. 
for case [index_node_0, None, index_node_1], nodes will have the same number of indices + target_index.append( + self.get_tensor(idx_node, idx_node).nelement() + if last_specified_index_node == index + else 1 + ) + elif idx_node is None: + # E.g., indices_node: [None, None, aten__to_copy_default_1] + all_range_index[index] = torch.arange( + input_tensor.size(index), dtype=index_dtype + ) + target_index.append(input_tensor.size(index)) + else: + warnings.warn( + f"[QNN Delegate Op Builder]: Get the index {idx_node} that is neither a node nor None", + stacklevel=1, + ) + return + + # preprocess all range indices if any + if None in indices_nodes: + all_range_tensor = torch.cartesian_prod(*all_range_index.values()) + # repeat all_range_tensor interleavely for future concatenation + # e.g. input_node = [5, 4, 3, 2], indices = [index_0_node, None, index_2_node] + # index_0.shape == index_2.shape == 2 (will guarantee this condition) + # where user specified (3, 4) for index_0, (0, 1) for index_2 + # --- + # we should have all_range_tensor: [0, 1, 2, 3] + # repeat interleavely with 2 to match future tiled index_0_node & index_2_node + # we'll have 1(index_0 -> same as index_2)*4(index_1)*2(index_2) indices in total: + # | index_0_node | None | index_2_node | + # | 3 | 0 | 0 | + # | 4 | 0 | 1 | + # | 3 | 1 | 0 | + # | 4 | 1 | 1 | + # | 3 | 2 | 0 | + # | 4 | 2 | 1 | + # | 3 | 3 | 0 | + # | 4 | 3 | 1 | + all_range_tensor_aug = all_range_tensor.repeat_interleave( + max_indices_in_specified_index, dim=0 + ) + for index in all_range_index.keys(): + # Repeat index for "None" axis in indices_nodes + range_index_node = torch.fx.Node( + node.graph, + node.name + f"_all_range_index_{index}", + "call_function", + exir_ops.edge.aten.tensor.default, + (), # args + {}, # kwargs + ) + range_indices = ( + ( + all_range_tensor_aug[:, index] + if all_range_tensor_aug.dim() > 1 + else + # if there is only one None + all_range_tensor_aug ) - return - # Assume that there is only one node in 
list - assert len(index_nodes) == 1, "Not support multiple indices tensor" - indice_node = index_nodes[0] - indice_tensor = index_tensors[0] - indices_tensor_wrapper = self.define_tensor( - indice_node, - node, - indice_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, - ) + .reshape(-1, 1) + .contiguous() + ) + target_index_tensor_wrapper = self.define_tensor( + range_index_node, + node, + range_indices, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + # store it for future concatenation + all_range_index[index] = (range_indices, target_index_tensor_wrapper) # Need to reconstruct the index tensor. # E.g., based on ScatterND Op Def in QNN Docs. # Torch: # Given that # shape of input: [1, 12, 1024, 64] - # indicies_node: [None, None, aten__to_copy_default_1] + # indices_node: [None, None, aten__to_copy_default_1] # shape of aten__to_copy_default_1: [1] # QNN: # Index tensor: @@ -104,113 +168,135 @@ def define_node( # update_indices = indices.shape[:-1] # for idx in np.ndindex(update_indices): # output[indices[idx]] = updates[idx] + specified_index = OrderedDict() + for i, indices_node in enumerate(indices_nodes): + if indices_node is None: + continue - # Append one dimension to specify x-tuple - index_shape = target_index + [1] - # Reshape the index_node for tile op - reshape_shape = [ - shape if id == index_node_dim else 1 for id, shape in enumerate(index_shape) - ] - reshape_output_tensor = indice_tensor.reshape(reshape_shape) - reshape_output_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + "_reshape", - tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=QNN_TENSOR_TYPE_MAP[reshape_output_tensor.dtype], - quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, - quant_configs={}, - dims=reshape_output_tensor.size(), - tensor=reshape_output_tensor, - is_fake_tensor=True, - nodes_to_wrappers=nodes_to_wrappers, - 
) - reshape_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpReshape.op_name, - ) - reshape_op.AddInputTensors([indices_tensor_wrapper]) - reshape_op.AddOutputTensors([reshape_output_tensor_wrapper]) - op_wrapper_list.append(reshape_op) - index_put_index_input_tensor_wrapper = reshape_output_tensor_wrapper + indices_tensor = self.get_tensor(indices_node, indices_node) + indices_tensor_wrapper = self.define_tensor( + indices_node, + node, + indices_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + ) + if indices_tensor.nelement() < max_indices_in_specified_index: + # broadcast the specified index + indices_tensor = indices_tensor.repeat(max_indices_in_specified_index) + indices_multiples = [max_indices_in_specified_index] + indices_multiples_shape = [len(indices_multiples)] + indices_tile_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + f"_indices_tile_{i}", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=QNN_TENSOR_TYPE_MAP[indices_tensor.dtype], + quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + quant_configs={}, + dims=indices_tensor.size(), + tensor=indices_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + tile_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTile.op_name, + ) + tile_op.AddInputTensors([indices_tensor_wrapper]) + tile_op.AddOutputTensors([indices_tile_tensor_wrapper]) + tile_op.AddTensorParam( + OpTile.param_multiples, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(indices_multiples_shape), + indices_multiples_shape, + np.array(indices_multiples, dtype=np.uint32), + True, + ) + op_wrapper_list.append(tile_op) + indices_tensor_wrapper = indices_tile_tensor_wrapper - # Tile the index_node and concat the target index - if None in indicies_node: - tile_output_tensor = reshape_output_tensor.expand(index_shape) - 
# Tile the index_node to align with the shape of target_index - # Only need to tile the dim of None axis - # E.g., indicies_node: [None, None, aten__to_copy_default_1] - # Should tile the first two dimension. - multiples = [ - shape if id != index_node_dim else 1 - for id, shape in enumerate(index_shape) - ] - multiples_shape = [len(index_shape)] - tile_output_tensor_wrapper = self.define_custom_tensor_wrapper( - node_name=node.name + "_tile", + # Append one dimension to specify x-tuple + # Reshape the index_node for tile op + reshape_shape = list(indices_tensor.shape) + [1] + reshape_output_tensor = indices_tensor.reshape(reshape_shape) + reshape_output_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + f"_reshape_{i}", tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - dtype=QNN_TENSOR_TYPE_MAP[tile_output_tensor.dtype], + dtype=QNN_TENSOR_TYPE_MAP[reshape_output_tensor.dtype], quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, quant_configs={}, - dims=tile_output_tensor.size(), - tensor=tile_output_tensor, + dims=reshape_output_tensor.size(), + tensor=reshape_output_tensor, is_fake_tensor=True, nodes_to_wrappers=nodes_to_wrappers, ) - tile_op = PyQnnWrapper.PyQnnOpWrapper( + reshape_op = PyQnnWrapper.PyQnnOpWrapper( node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, - OpTile.op_name, + OpReshape.op_name, ) - tile_op.AddInputTensors([reshape_output_tensor_wrapper]) - tile_op.AddOutputTensors([tile_output_tensor_wrapper]) - tile_op.AddTensorParam( - OpTile.param_multiples, - PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, - len(multiples_shape), - multiples_shape, - np.array(multiples, dtype=np.uint32), - True, - ) - op_wrapper_list.append(tile_op) + reshape_op.AddInputTensors([indices_tensor_wrapper]) + reshape_op.AddOutputTensors([reshape_output_tensor_wrapper]) + op_wrapper_list.append(reshape_op) + index_tensor_wrapper = reshape_output_tensor_wrapper + index_tensor = 
reshape_output_tensor - # Repeat index for "None" axis in indicies_node - ranges = [ - torch.arange(dim, dtype=indice_tensor.dtype) - for dim in target_index[:-1] - ] - target_index_shape = target_index + [len(ranges)] - target_index_tensor = torch.cartesian_prod(*ranges) - reshape_target_index_shape = [ - shape if id != index_node_dim else 1 - for id, shape in enumerate(target_index_shape) - ] - target_index_tensor = target_index_tensor.reshape( - reshape_target_index_shape - ) - target_index_tensor = target_index_tensor.expand( - target_index_shape - ).contiguous() - target_index_node = torch.fx.Node( - node.graph, - node.name + "_target_index", - "call_function", - exir_ops.edge.aten.tensor.default, - (), # args - {}, # kwargs - ) - target_index_tensor_wrapper = self.define_tensor( - target_index_node, - node, - target_index_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) + # Tile the index_node and concat the target index + if None in indices_nodes: + tile_output_tensor = reshape_output_tensor.repeat( + all_range_tensor.size(0), 1 + ) + # Tile the index_node to align with the shape of target_index + # Only need to tile the dim of None axis + # E.g., indices_node: [None, None, aten__to_copy_default_1] + # Should tile the number of indices combination of first two dimension + # times number of indices specified by aten__to_copy_default_1 + multiples = [all_range_tensor.size(0), 1] + multiples_shape = [len(multiples)] + tile_output_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + f"_tile_{i}", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=QNN_TENSOR_TYPE_MAP[tile_output_tensor.dtype], + quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + quant_configs={}, + dims=tile_output_tensor.size(), + tensor=tile_output_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + tile_op = PyQnnWrapper.PyQnnOpWrapper( + 
node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTile.op_name, + ) + tile_op.AddInputTensors([reshape_output_tensor_wrapper]) + tile_op.AddOutputTensors([tile_output_tensor_wrapper]) + tile_op.AddTensorParam( + OpTile.param_multiples, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(multiples_shape), + multiples_shape, + np.array(multiples, dtype=np.uint32), + True, + ) + op_wrapper_list.append(tile_op) + index_tensor_wrapper = tile_output_tensor_wrapper + index_tensor = tile_output_tensor - # Concat target_index and tile output to reconstruct index_node - # Cannot use QNN Pack (stack) since QNN Pack is not support int32 dtype - concat_output_tensor = torch.concat( - (target_index_tensor, tile_output_tensor), dim=-1 + specified_index[i] = (index_tensor, index_tensor_wrapper) + + # Concat target_index and tile output to reconstruct index_node + # Cannot use QNN Pack (stack) since QNN Pack does not support int32 dtype + index_tensors, index_tensor_wrappers = [], [] + for i, arg in enumerate(indices_nodes): + tensor, tensor_wrapper = ( + all_range_index[i] if arg is None else specified_index[i] ) + index_tensors.append(tensor) + index_tensor_wrappers.append(tensor_wrapper) + + if len(index_tensor_wrappers) > 1: + concat_output_tensor = torch.concat(index_tensors, dim=-1) concat_output_tensor_wrapper = self.define_custom_tensor_wrapper( node_name=node.name + "_concat", tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, @@ -227,9 +313,7 @@ def define_node( QNN_OP_PACKAGE_NAME_QTI_AISW, OpConcat.op_name, ) - concat_op.AddInputTensors( - [target_index_tensor_wrapper, tile_output_tensor_wrapper] - ) + concat_op.AddInputTensors(index_tensor_wrappers) concat_op.AddOutputTensors([concat_output_tensor_wrapper]) concat_op.AddScalarParam( OpConcat.param_axis, @@ -237,7 +321,6 @@ def define_node( {QCOM_DATA: np.uint32(concat_output_tensor.dim() - 1)}, ) op_wrapper_list.append(concat_op) - index_put_index_input_tensor_wrapper = concat_output_tensor_wrapper
value_node = self.get_node(node.args[2]) value_tensor = self.get_tensor(value_node, node) @@ -248,6 +331,94 @@ def define_node( PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, nodes_to_wrappers, ) + # handle broadcast scenario + # e.g. input_tensor: (1, 12, 1024, 64), value_tensor: (1, 64) + # => value_reshape_tensor: (1, 1, 1, 64) + new_value_shape = ( + *([1] * (input_tensor.dim() - value_tensor.dim())), + *value_tensor.shape, + ) + # reshape the value_node for tile op + value_quant_encoding, value_quant_configs = self.get_quant_encoding_conf( + value_node, node + ) + value_dtype = ( + QNN_TENSOR_TYPE_MAP[value_tensor.dtype] + if value_quant_encoding + == PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED + else QNN_QUANT_TYPE_MAP[ + ( + torch.uint16 + if value_quant_configs[QCOM_DTYPE] == torch.int32 + else value_quant_configs[QCOM_DTYPE] + ) + ] + ) + value_reshape_tensor = value_tensor.reshape(new_value_shape) + value_reshape_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_value_reshape", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=value_dtype, + quant_encoding=value_quant_encoding, + quant_configs=value_quant_configs, + dims=value_reshape_tensor.size(), + tensor=value_reshape_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + value_reshape_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReshape.op_name, + ) + value_reshape_op.AddInputTensors([value_tensor_wrapper]) + value_reshape_op.AddOutputTensors([value_reshape_tensor_wrapper]) + op_wrapper_list.append(value_reshape_op) + + # e.g. 
input_tensor: (1, 12, 1024, 64), index_tensor: (None, None, 2), value_tensor: (1, 64) + # => multiples: [1, 12, 2, 1] + value_multiples = [] + for i in range(input_tensor.dim() - 1, -1, -1): + if i in specified_index: + # all user specified index node will have the same dimension + multiplier = ( + indices_nodes[i].meta["val"].nelement() // new_value_shape[i] + if i == last_specified_index_node + else 1 + ) + else: + multiplier = input_tensor.shape[i] // new_value_shape[i] + value_multiples.insert(0, multiplier) + + value_tile_tensor = value_reshape_tensor.repeat(value_multiples) + value_multiples_shape = [len(value_multiples)] + value_tile_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_value_tile", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=value_dtype, + quant_encoding=value_quant_encoding, + quant_configs=value_quant_configs, + dims=value_tile_tensor.size(), + tensor=value_tile_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + value_tile_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpTile.op_name, + ) + value_tile_op.AddInputTensors([value_reshape_tensor_wrapper]) + value_tile_op.AddOutputTensors([value_tile_tensor_wrapper]) + value_tile_op.AddTensorParam( + OpTile.param_multiples, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(value_multiples_shape), + value_multiples_shape, + np.array(value_multiples, dtype=np.uint32), + True, + ) + op_wrapper_list.append(value_tile_op) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( @@ -263,11 +434,46 @@ def define_node( QNN_OP_PACKAGE_NAME_QTI_AISW, OpScatterNd.op_name, ) + # accumulation + if len(node.args) > 3 and node.args[3]: + index_put_op.AddScalarParam( + OpScatterNd.param_reduction, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + {QCOM_DATA: 1}, + ) + + # check final index_input tensor + index_input_tensor, index_input_tensor_wrapper = (
+ (concat_output_tensor, concat_output_tensor_wrapper) + if len(index_tensor_wrappers) > 1 + else specified_index[last_specified_index_node] + ) + target_index_reshape_tensor = index_input_tensor.reshape((*target_index, -1)) + target_index_reshape_tensor_wrapper = self.define_custom_tensor_wrapper( + node_name=node.name + "_target_index_reshape", + tensor_type=PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + dtype=QNN_TENSOR_TYPE_MAP[target_index_reshape_tensor.dtype], + quant_encoding=PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED, + quant_configs={}, + dims=target_index_reshape_tensor.size(), + tensor=target_index_reshape_tensor, + is_fake_tensor=True, + nodes_to_wrappers=nodes_to_wrappers, + ) + target_index_reshape_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpReshape.op_name, + ) + target_index_reshape_op.AddInputTensors([index_input_tensor_wrapper]) + target_index_reshape_op.AddOutputTensors([target_index_reshape_tensor_wrapper]) + op_wrapper_list.append(target_index_reshape_op) + index_put_op.AddInputTensors( [ input_tensor_wrapper, - index_put_index_input_tensor_wrapper, - value_tensor_wrapper, + target_index_reshape_tensor_wrapper, + value_tile_tensor_wrapper, ] ) index_put_op.AddOutputTensors([output_tensor_wrapper]) diff --git a/backends/qualcomm/builders/op_mean_dim.py b/backends/qualcomm/builders/op_mean_dim.py index 22cb47ee288..10644e17c79 100644 --- a/backends/qualcomm/builders/op_mean_dim.py +++ b/backends/qualcomm/builders/op_mean_dim.py @@ -55,7 +55,7 @@ def define_node( mean_dims = [dim_arg] else: mean_dims = list(dim_arg) - print("mean_dims: ", mean_dims, "rank: ", rank) + mean_dims = [ mean_dim % len(input_node.meta["val"].shape) for mean_dim in mean_dims ] diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 3240ad7a018..5ea6caf54ad 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -1141,20 
+1141,62 @@ def forward(self, input_pos, k_val): class IndexPut(torch.nn.Module): - def __init__(self, skip_mutable_buffer=False): + def __init__(self, skip_mutable_buffer=False, mode=0): super().__init__() self.skip_mutable_buffer = skip_mutable_buffer self.register_buffer( "k_cache", - torch.zeros((1, 1024, 12, 64), dtype=torch.float32), + torch.zeros((2, 1024, 12, 64), dtype=torch.float32), persistent=True, ) + self.mode = mode def forward(self, input_pos, k_val): - k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val) + match self.mode: + case 0: + k_out = torch.ops.aten.index_put_(self.k_cache, [input_pos], k_val) + case 1: + k_out = torch.ops.aten.index_put_( + self.k_cache, [None, input_pos], k_val + ) + case 2: + k_out = torch.ops.aten.index_put_( + self.k_cache, [None, None, input_pos], k_val + ) + case 3: + k_out = torch.ops.aten.index_put_( + self.k_cache, [input_pos[0], input_pos[1]], k_val + ) + case 4: + k_out = torch.ops.aten.index_put_( + self.k_cache, [None, input_pos[0], input_pos[1]], k_val + ) + case 5: + k_out = torch.ops.aten.index_put_( + self.k_cache, [input_pos[0], None, input_pos[1]], k_val + ) + return k_out + 0 +class IndexPutSuite(torch.nn.Module): + def __init__(self, accumulate=False, in_place=False): + super().__init__() + self.accumulate = accumulate + self.in_place = in_place + + def forward(self, x, indices, values): + if self.in_place: + # Clone the input to avoid modifying it in-place + result = x.clone() + # Apply index_put_ and return the modified tensor + result.index_put_(indices, values, self.accumulate) + return result + else: + # Use the non-in-place variant which returns a new tensor + return torch.index_put(x, indices, values, self.accumulate) + + class IndexSelect(torch.nn.Module): def __init__(self, dim): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 56983561e5f..2641acc5a2d 100644 --- 
a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import io +import itertools import json import subprocess import sys @@ -887,28 +888,191 @@ def test_qnn_backend_index_copy(self): ) def test_qnn_backend_index_put(self): - test_comb = [ - { - QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 - QCOM_SAMPLE_INPUTS: ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), + skip_mutable_buffer = [False, True] + total_test_combo = [] + # mode 0 + sample_inputs = [ + (torch.tensor([0], dtype=torch.int32), torch.randn([1, 1, 12, 64])), + (torch.tensor([0], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([2, 1, 12, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 1 + sample_inputs = [ + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 12, 64])), + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 2, 12, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 2 + sample_inputs = [ + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 1, 64])), + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 1, 2, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 3 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), ), - }, 
- { - QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 - QCOM_SAMPLE_INPUTS: ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), + torch.randn([2, 12, 64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), ), - }, + torch.randn([1, 64]), + ), ] - for i, test in enumerate(test_comb): + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 4 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([2, 64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([1, 64]), + ), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 5 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([1]), + ), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + + for i, test_combo in enumerate(total_test_combo): + for j, combo in enumerate(test_combo): + with self.subTest(f"mode_{i}-{j}"): + self.lower_module_and_test_output( + IndexPut(skip_mutable_buffer=combo[0], mode=i), # noqa: F405 + combo[1], + skip_mutable_buffer=combo[0], + ) + + def test_qnn_backend_index_put_suite(self): + accumulate = [False, True] + in_place = [False, True] + sample_inputs = [ + # basic + ( + torch.rand(5, 2) * 100, + (torch.tensor([0, 2]),), + torch.tensor([10.0, 20.0]), + ), + (torch.rand(5, 2), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), + # shape + (torch.rand(5), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), + ( + torch.rand(5, 2), + (torch.tensor([0, 2]), torch.tensor([1, 1])), + torch.tensor([10.0, 20.0]), + ), + ( + 
torch.rand(5, 3, 2), + (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1])), + torch.tensor([10.0, 20.0]), + ), + # TODO: not supported by HTP + # ( + # torch.rand(5, 3, 2, 4), + # (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1]), torch.tensor([2, 3])), + # torch.tensor([10.0]), + # ), + # indices + (torch.rand(5, 2), (torch.tensor([2]),), torch.tensor([10.0])), + ( + torch.rand(5, 3), + (torch.tensor([0, 2, 4]),), + torch.tensor([10.0, 20.0, 30.0]), + ), + ( + torch.rand(5), + (torch.tensor([1, 1, 3, 3]),), + torch.tensor([10.0, 20.0, 30.0, 40.0]), + ), + # broadcasting + (torch.rand(5, 3), (torch.tensor([0, 2, 4]),), torch.tensor([42.0])), + ( + torch.rand(3, 4), + (torch.tensor([0, 1]), torch.tensor([1, 2])), + torch.tensor([10.0, 20.0]), + ), + (torch.rand(4, 2), (torch.tensor([0, 2]),), torch.tensor([5.0, 15.0])), + ( + torch.rand(3, 2, 2), + (torch.tensor([0, 1]),), + torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + ), + (torch.rand(4, 2), (torch.tensor([1, 1, 1]),), torch.tensor([5.0])), + # two-index + ( + torch.rand(4, 3), + (torch.tensor([0, 1, 2]), torch.tensor([1, 0, 2])), + torch.tensor([10.0, 20.0, 30.0]), + ), + ( + torch.rand(3, 3), + (torch.tensor([0, 2]), torch.tensor([1, 1])), + torch.tensor([15.0, 25.0]), + ), + ( + torch.rand(3, 2), + (torch.tensor([1, 1, 2]), torch.tensor([0, 0, 1])), + torch.tensor([5.0, 10.0, 15.0]), + ), + ( + torch.rand(3, 2), + (torch.tensor([1]), torch.tensor([0, 0, 1])), + torch.tensor([5.0, 10.0, 15.0]), + ), + ] + test_combo = list(itertools.product(accumulate, in_place, sample_inputs)) + for i, combo in enumerate(test_combo): with self.subTest(i=i): self.lower_module_and_test_output( - test[QCOM_MODULE], - test[QCOM_SAMPLE_INPUTS], - skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + IndexPutSuite(accumulate=combo[0], in_place=combo[1]), # noqa: F405 + combo[2], ) def test_qnn_backend_index_select(self): @@ -2642,32 +2806,197 @@ def test_qnn_backend_index_copy(self): ) def 
test_qnn_backend_index_put(self): - test_comb = [ - { - QCOM_MODULE: IndexPut(skip_mutable_buffer=False), # noqa: F405 - QCOM_SAMPLE_INPUTS: ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), + skip_mutable_buffer = [False, True] + total_test_combo = [] + # mode 0 + sample_inputs = [ + (torch.tensor([0], dtype=torch.int32), torch.randn([1, 1, 12, 64])), + (torch.tensor([0], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([2, 1, 12, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 1 + sample_inputs = [ + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 12, 64])), + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 2, 12, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 2 + sample_inputs = [ + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 1, 1, 64])), + (torch.tensor([2], dtype=torch.int32), torch.randn([1, 64])), + (torch.tensor([0, 1], dtype=torch.int32), torch.randn([1, 1, 2, 64])), + (torch.tensor([2, 3], dtype=torch.int32), torch.randn([1, 64])), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 3 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), ), - }, - { - QCOM_MODULE: IndexPut(skip_mutable_buffer=True), # noqa: F405 - QCOM_SAMPLE_INPUTS: ( - torch.tensor([2], dtype=torch.int32), - torch.randn([1, 1, 12, 64]), + torch.randn([2, 12, 64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), ), - }, + torch.randn([1, 64]), + ), ] - for i, test in enumerate(test_comb): + 
total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 4 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([2, 64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([1, 64]), + ), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + # mode 5 + sample_inputs = [ + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([64]), + ), + ( + ( + torch.tensor([0, 1], dtype=torch.int32), + torch.tensor([2, 3], dtype=torch.int32), + ), + torch.randn([1]), + ), + ] + total_test_combo.append( + list(itertools.product(skip_mutable_buffer, sample_inputs)) + ) + + for i, test_combo in enumerate(total_test_combo): + for j, combo in enumerate(test_combo): + with self.subTest(f"mode_{i}-{j}"): + module = self.get_qdq_module( + IndexPut(skip_mutable_buffer=combo[0], mode=i), # noqa: F405 + combo[1], + ) + self.lower_module_and_test_output( + module, + combo[1], + skip_mutable_buffer=combo[0], + ) + + def test_qnn_backend_index_put_suite(self): + accumulate = [False, True] + in_place = [False, True] + sample_inputs = [ + # basic + ( + torch.rand(5, 2) * 100, + (torch.tensor([0, 2]),), + torch.tensor([10.0, 20.0]), + ), + (torch.rand(5, 2), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), + # shape + (torch.rand(5), (torch.tensor([0, 2]),), torch.tensor([10.0, 20.0])), + ( + torch.rand(5, 2), + (torch.tensor([0, 2]), torch.tensor([1, 1])), + torch.tensor([10.0, 20.0]), + ), + ( + torch.rand(5, 3, 2), + (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1])), + torch.tensor([10.0, 20.0]), + ), + # TODO: not supported by HTP + # ( + # torch.rand(5, 3, 2, 4), + # (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1]), torch.tensor([2, 3])), + # torch.tensor([10.0]), + 
# ), + # indices + (torch.rand(5, 2), (torch.tensor([2]),), torch.tensor([10.0])), + ( + torch.rand(5, 3), + (torch.tensor([0, 2, 4]),), + torch.tensor([10.0, 20.0, 30.0]), + ), + ( + torch.rand(5), + (torch.tensor([1, 1, 3, 3]),), + torch.tensor([10.0, 20.0, 30.0, 40.0]), + ), + # broadcasting + (torch.rand(5, 3), (torch.tensor([0, 2, 4]),), torch.tensor([42.0])), + ( + torch.rand(3, 4), + (torch.tensor([0, 1]), torch.tensor([1, 2])), + torch.tensor([10.0, 20.0]), + ), + (torch.rand(4, 2), (torch.tensor([0, 2]),), torch.tensor([5.0, 15.0])), + ( + torch.rand(3, 2, 2), + (torch.tensor([0, 1]),), + torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + ), + (torch.rand(4, 2), (torch.tensor([1, 1, 1]),), torch.tensor([5.0])), + # two-index + ( + torch.rand(4, 3), + (torch.tensor([0, 1, 2]), torch.tensor([1, 0, 2])), + torch.tensor([10.0, 20.0, 30.0]), + ), + ( + torch.rand(3, 3), + (torch.tensor([0, 2]), torch.tensor([1, 1])), + torch.tensor([15.0, 25.0]), + ), + ( + torch.rand(3, 2), + (torch.tensor([1, 1, 2]), torch.tensor([0, 0, 1])), + torch.tensor([5.0, 10.0, 15.0]), + ), + ( + torch.rand(3, 2), + (torch.tensor([1]), torch.tensor([0, 0, 1])), + torch.tensor([5.0, 10.0, 15.0]), + ), + ] + test_combo = list(itertools.product(accumulate, in_place, sample_inputs)) + for i, combo in enumerate(test_combo): with self.subTest(i=i): module = self.get_qdq_module( - test[QCOM_MODULE], test[QCOM_SAMPLE_INPUTS] - ) - self.lower_module_and_test_output( - module, - test[QCOM_SAMPLE_INPUTS], - skip_mutable_buffer=test[QCOM_MODULE].skip_mutable_buffer, + IndexPutSuite(accumulate=combo[0], in_place=combo[1]), # noqa: F405 + combo[2], ) + self.lower_module_and_test_output(module, combo[2]) def test_qnn_backend_index_select(self): module = IndexSelect(dim=1) # noqa: F405 diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 11b9ab88bfe..036c5060b12 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -918,24 +918,34 @@ def generate_inputs(dest_path: 
str, file_name: str, inputs=None): input_list_file = None input_files = [] + def prepare_input_file(tensor, fd, index, sub_index): + # transform torch.Tensor to raw file + input_file_name = f"input_{index}_{sub_index}.raw" + input_file_path = f"{dest_path}/{input_file_name}" + if not isinstance(tensor, torch.Tensor): + tensor = torch.tensor(tensor) + tensor.detach().numpy().tofile(input_file_path) + input_files.append(input_file_path) + # prepare input_list + if sub_index > 0: + fd.write(" ") + fd.write(input_file_name) + # Prepare input data if inputs is not None: input_list_file = f"{dest_path}/{file_name}" with open(input_list_file, "w") as f: for idx, data in enumerate(inputs): - for i, d in enumerate(data): - # transform torch.Tensor to raw file - file_name = f"input_{idx}_{i}.raw" - file_path = f"{dest_path}/{file_name}" - if not isinstance(d, torch.Tensor): - d = torch.tensor(d) - d.detach().numpy().tofile(file_path) - input_files.append(file_path) - - # prepare input_list - if i > 0: - f.write(" ") - f.write(file_name) + sub_index = 0 + for d in data: + if isinstance(d, (list, tuple)): + for sub_d in d: + prepare_input_file(sub_d, f, idx, sub_index) + sub_index += 1 + else: + prepare_input_file(d, f, idx, sub_index) + sub_index += 1 + f.write("\n") return input_list_file, input_files