diff --git a/backends/qualcomm/_passes/__init__.py b/backends/qualcomm/_passes/__init__.py index 15fce79ea12..423ec64bbd3 100644 --- a/backends/qualcomm/_passes/__init__.py +++ b/backends/qualcomm/_passes/__init__.py @@ -13,6 +13,7 @@ from .convert_linear_to_conv2d import ConvertLinearToConv2d from .convert_square_to_pow import ConvertSquareToPow from .decompose_any import DecomposeAny +from .decompose_binary_alpha import DecomposeBinaryAlpha from .decompose_cdist import DecomposeCDist from .decompose_col_im import DecomposeColIm from .decompose_einsum import DecomposeEinsum @@ -53,6 +54,7 @@ ConvertLinearToConv2d, ConvertSquareToPow, DecomposeAny, + DecomposeBinaryAlpha, DecomposeCDist, DecomposeColIm, DecomposeEinsum, diff --git a/backends/qualcomm/_passes/canonicalize_conv.py b/backends/qualcomm/_passes/canonicalize_conv.py index 3804fb05da0..dc5c26c1a94 100644 --- a/backends/qualcomm/_passes/canonicalize_conv.py +++ b/backends/qualcomm/_passes/canonicalize_conv.py @@ -34,6 +34,7 @@ def __init__(self, edge_program: torch.export.ExportedProgram): self.transpose_conv_set = { torch.ops.aten.conv_transpose1d.default, torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, } def dilate(self, tensor, dilation): diff --git a/backends/qualcomm/_passes/decompose_binary_alpha.py b/backends/qualcomm/_passes/decompose_binary_alpha.py new file mode 100644 index 00000000000..df767f10ca9 --- /dev/null +++ b/backends/qualcomm/_passes/decompose_binary_alpha.py @@ -0,0 +1,61 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
import torch
from executorch.exir.pass_base import ExportPass, PassResult

from .utils import copy_meta

# Binary ops whose keyword-only ``alpha`` argument this pass folds away.
decomp_set = {torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor}


class DecomposeBinaryAlpha(ExportPass):
    """
    QNN does not support alpha parameter for add/sub.
    Decompose to mul + add / mul + sub

    For every ``aten.add.Tensor`` / ``aten.sub.Tensor`` node that carries an
    ``alpha`` keyword different from 1, the keyword is removed and the second
    operand is scaled by ``alpha`` instead:

    * a Python ``int`` / ``float`` second operand is multiplied in place, so
      no extra node is created;
    * any other second operand is routed through a new ``aten.mul.Scalar``
      node inserted right before the add/sub node.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite ``graph_module`` in place so that no add/sub node keeps a
        non-trivial ``alpha``.

        Args:
            graph_module: the graph module whose graph is rewritten.

        Returns:
            A ``PassResult`` wrapping the same ``graph_module``. The modified
            flag is always ``True``, whether or not any node was rewritten.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target not in decomp_set:
                continue
            alpha = node.kwargs.get("alpha", 1)
            if alpha == 1:
                # Nothing to fold; the node is left exactly as it was.
                continue

            # node.kwargs is an immutable dict, so rebuild it without alpha.
            node.kwargs = {k: v for k, v in node.kwargs.items() if k != "alpha"}

            operands = list(node.args)
            other = operands[1]
            if isinstance(other, (int, float)):
                # If the second operand is a constant, we can just multiply
                # the value instead of emitting a mul node.
                operands[1] = other * alpha
            else:
                with graph.inserting_before(node):
                    mul_node = graph.create_node(
                        "call_function",
                        torch.ops.aten.mul.Scalar,
                        (other, alpha),
                    )
                # The new node reuses the add/sub node's metadata.
                # NOTE(review): that metadata describes the add/sub result,
                # which may differ from `other` under broadcasting - confirm.
                mul_node.meta = copy_meta(node.meta)
                # Replace only the second operand. Node.replace_input_with
                # would rewrite *every* use of `other` inside this node, so
                # add(x, x, alpha=a) would wrongly become add(mul, mul).
                operands[1] = mul_node
            node.args = tuple(operands)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
self.add_pass(RecomposeRmsNorm(quantization_capture=True)) self.add_pass(ReplaceArangeArgs()) + self.add_pass(DecomposeBinaryAlpha()) self.add_pass(DecomposeCDist()) self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoll()) @@ -208,6 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): def transform_for_export_pipeline( self, exported_program: ExportedProgram, convert_linear_to_conv2d: bool = False ): + self.add_pass(DecomposeBinaryAlpha()) self.add_pass(DecomposeCDist()) self.add_pass(DecomposeScaledDotProductAttention()) self.add_pass(DecomposeRoll()) diff --git a/backends/qualcomm/builders/README.md b/backends/qualcomm/builders/README.md index 6ba4eafb01f..61ae1061214 100644 --- a/backends/qualcomm/builders/README.md +++ b/backends/qualcomm/builders/README.md @@ -365,7 +365,7 @@ Please help update following table if you are contributing new operators: + 🚫 = Deprecated, supported with other QNN Ops -| Operators | HTP - 90/116 Enabled | +| Operators | HTP - 92/116 Enabled | |-----------|---------| | Argmax | ✓ | | Argmin | ✓ | @@ -375,7 +375,7 @@ Please help update following table if you are contributing new operators: | ChannelShuffle | ✗ | | Concat | ✓ | | Conv2d | ✓ | -| Conv3d | ✗ | +| Conv3d | ✓ | | Convert | ✓ | | CreateSparse | ✗ | | CumulativeSum | ✓ | @@ -481,7 +481,7 @@ Please help update following table if you are contributing new operators: | TopK | ✓ | | TransPose | ✓ | | TransPoseConv2d | ✓ | -| TransPoseConv3d | ✗ | +| TransPoseConv3d | ✓ | | Unpack | ✓ | ## Issues diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 9800fb7bdab..3fa8ae067fa 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -24,7 +24,7 @@ op_cat, op_ceil, op_clamp, - op_conv2d, + op_conv, op_copy, op_cos, op_cum_sum, @@ -129,7 +129,7 @@ op_cat, op_ceil, op_clamp, - op_conv2d, + op_conv, op_copy, op_cos, op_cum_sum, diff --git 
a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv.py similarity index 82% rename from backends/qualcomm/builders/op_conv2d.py rename to backends/qualcomm/builders/op_conv.py index 1cfc1e45c9b..2bc0b41524d 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv.py @@ -7,7 +7,6 @@ from typing import cast, Dict, List import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper - import numpy as np import torch from executorch.backends.qualcomm.utils.constants import QCOM_DATA @@ -16,8 +15,10 @@ from .node_visitor_manager import register_node_visitor from .qnn_constants import ( OpConv2d, + OpConv3d, OpDepthWiseConv2d, OpTransposeConv2d, + OpTransposeConv3d, QNN_OP_PACKAGE_NAME_QTI_AISW, ) from .utils import get_parameter @@ -66,7 +67,7 @@ def _add_conv_op_parameter( len(padding_shape), padding_shape, np.array( - [[padding[0], padding[0]], [padding[1], padding[1]]], + padding, dtype=np.uint32, ), True, @@ -108,8 +109,14 @@ def define_node( input_node = self.get_node(node.args[0]) input_tensor = self.get_tensor(input_node, node) assert ( - input_tensor.dim() == 4 + input_tensor.dim() != 3 ), "All Conv1D should be converted to Conv2D in CanonicalizeConv," + assert input_tensor.dim() in { + 4, + 5, + }, "Only Conv2d and Conv3d is supported in conv builder," + + is_conv2d = input_tensor.dim() == 4 input_tensor_wrapper = self.define_tensor( input_node, node, @@ -120,9 +127,15 @@ def define_node( filter_node = self.get_node(node.args[1]) filter_tensor = get_parameter(filter_node, self.edge_program) - # weight of pytorch OIHW(conv2d) | IOHW(conv_transpose2d), yet QNN is HWIO + # weight of pytorch OIHW(conv2d) / OIDHW(conv3d) or IOHW(conv_transpose2d) / IODHW(conv_transpose3d), + # yet QNN is HWIO or DHWIO is_transpose_conv = cast(bool, node.args[6]) - filter_axis_order = (2, 3, 0, 1) if is_transpose_conv else (2, 3, 1, 0) + if is_conv2d: + filter_axis_order = (2, 3, 0, 1) if is_transpose_conv 
else (2, 3, 1, 0) + else: + filter_axis_order = ( + (2, 3, 4, 0, 1) if is_transpose_conv else (2, 3, 4, 1, 0) + ) filter_tensor = filter_tensor.permute(dims=filter_axis_order).contiguous() filter_tensor_wrapper = self.define_tensor( filter_node, @@ -132,7 +145,6 @@ def define_node( nodes_to_wrappers, ) conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper] - if node.args[2] is not None: bias_node = self.get_node(node.args[2]) bias_tensor = get_parameter(bias_node, self.edge_program) @@ -159,11 +171,10 @@ def define_node( padding = cast(List[int], node.args[4]) dilation = cast(List[int], node.args[5]) output_padding = cast(List[int], node.args[7]) - groups = cast(int, node.args[8]) - # Qnn filter tensor is (H, W, Cin, Cout) - group_input_channels = filter_tensor.shape[2] - group_output_channels = int(filter_tensor.shape[3] / groups) + # Qnn filter tensor is (H, W, Cin, Cout) or (D, H, W, Cin, Cout) + group_input_channels = filter_tensor.shape[-2] + group_output_channels = int(filter_tensor.shape[-1] / groups) # 1) groups = input_channels (i.e. group_input_channels = 1) # 2) output_channels is a positive integer multiple of input channels # TODO: Currently, negative results will be zero with Depthwise conv2d when input_channel == groups == 1 @@ -175,18 +186,23 @@ def define_node( ) if len(padding) == 1: padding = padding + padding + padding = [[x, x] for x in padding] stride_shape = [len(stride)] - padding_shape = [2, 2] + padding_shape = [len(padding), len(padding[0])] dilation_shape = [len(dilation)] output_padding_shape = [len(output_padding)] - if is_depthwise_conv: + if is_transpose_conv: + assert all( + val == 1 for val in dilation + ), "CanonicalizeConv pass should perform dilate for transpose_conv." 
+ op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d + elif is_depthwise_conv: + assert is_conv2d, "DepthWise only supports Conv2d" op_class = OpDepthWiseConv2d - elif is_transpose_conv: - op_class = OpTransposeConv2d else: - op_class = OpConv2d + op_class = OpConv2d if is_conv2d else OpConv3d conv_op = PyQnnWrapper.PyQnnOpWrapper( node.name, diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index b0c44dcae80..79a1c93d50c 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -59,6 +59,15 @@ class OpConv2d: param_dilation: str = "dilation" +@dataclass(init=False, frozen=True) +class OpConv3d: + op_name: str = "Conv3d" + param_stride: str = "stride" + param_pad_amount: str = "pad_amount" + param_group: str = "group" + param_dilation: str = "dilation" + + @dataclass(init=False, frozen=True) class OpConvert: op_name: str = "Convert" @@ -573,6 +582,15 @@ class OpTransposeConv2d: param_output_padding: str = "output_padding" +@dataclass(init=False, frozen=True) +class OpTransposeConv3d: + op_name: str = "TransposeConv3d" + param_stride: str = "stride" + param_pad_amount: str = "pad_amount" + param_group: str = "group" + param_output_padding: str = "output_padding" + + @dataclass(init=False, frozen=True) class OpUnpack: op_name: str = "UnPack" diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 88109b51697..c3213af6338 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -1094,11 +1094,13 @@ def annotate_cdist(node: Node, quantization_config: QuantizationConfig) -> None: @register_annotator( [ + torch.ops.aten.conv1d.default, torch.ops.aten.conv2d.default, torch.ops.aten.conv2d.padding, - torch.ops.aten.conv1d.default, - torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv3d.default, torch.ops.aten.conv_transpose1d.default, + 
torch.ops.aten.conv_transpose2d.input, + torch.ops.aten.conv_transpose3d.input, torch.ops.aten.convolution.default, ] ) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 77ff1be4562..28ea224d747 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -66,6 +66,28 @@ def forward(self, x, y): return torch.add(x, y) +class AddAlpha(torch.nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = alpha + + def forward(self, x, y): + return torch.add(x, y, alpha=self.alpha) + + +class AddAlphaConstant(torch.nn.Module): + def __init__(self, alpha, constant_first=False): + super().__init__() + self.alpha = alpha + self.constant_first = constant_first + + def forward(self, x): + if self.constant_first: + return torch.add(5.0, x, alpha=self.alpha) + else: + return torch.add(x, 5.0, alpha=self.alpha) + + class AddConstantFloat(torch.nn.Module): def __init__(self): super().__init__() @@ -566,6 +588,28 @@ def forward(self, x): return self.second(self.first(x)) +class Conv3dSequential(torch.nn.Module): + def __init__(self, bias=True): + super().__init__() + self.first = torch.nn.Conv3d( + in_channels=1, + out_channels=3, + kernel_size=(3, 3, 3), + padding=1, + bias=bias, + ) + self.second = torch.nn.Conv3d( + in_channels=3, + out_channels=2, + kernel_size=(3, 3, 3), + padding=1, + bias=bias, + ) + + def forward(self, x): + return self.second(self.first(x)) + + class Conv2dSingle(torch.nn.Module): def __init__( self, @@ -588,40 +632,6 @@ def forward(self, x): return self.conv(x) -class ConvTranspose1dSingle(torch.nn.Module): - def __init__(self, bias=True, dilation=1): - super().__init__() - self.conv_transpose = torch.nn.ConvTranspose1d( - in_channels=1, - out_channels=3, - kernel_size=3, - stride=2, - padding=1, - dilation=dilation, - bias=bias, - ) - - def forward(self, x): - return self.conv_transpose(x) - - -class ConvTranspose2dSingle(torch.nn.Module): - def __init__(self, bias=True, 
dilation=1): - super().__init__() - self.conv_transpose = torch.nn.ConvTranspose2d( - in_channels=1, - out_channels=3, - kernel_size=3, - stride=2, - padding=1, - dilation=dilation, - bias=bias, - ) - - def forward(self, x): - return self.conv_transpose(x) - - class Conv2dDownUpSample(torch.nn.Module): def __init__(self, bias=True): super().__init__() @@ -706,6 +716,57 @@ def forward(self, x): return topk_values +class ConvTranspose1dSingle(torch.nn.Module): + def __init__(self, bias=True, dilation=1): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose1d( + in_channels=1, + out_channels=3, + kernel_size=3, + stride=2, + padding=1, + dilation=dilation, + bias=bias, + ) + + def forward(self, x): + return self.conv_transpose(x) + + +class ConvTranspose2dSingle(torch.nn.Module): + def __init__(self, bias=True, dilation=1): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose2d( + in_channels=1, + out_channels=3, + kernel_size=3, + stride=2, + padding=1, + dilation=dilation, + bias=bias, + ) + + def forward(self, x): + return self.conv_transpose(x) + + +class ConvTranspose3dSingle(torch.nn.Module): + def __init__(self, bias=True, dilation=1): + super().__init__() + self.conv_transpose = torch.nn.ConvTranspose3d( + in_channels=1, + out_channels=3, + kernel_size=3, + stride=2, + padding=1, + dilation=dilation, + bias=bias, + ) + + def forward(self, x): + return self.conv_transpose(x) + + class Cos(torch.nn.Module): def __init__(self): super().__init__() @@ -1854,6 +1915,28 @@ def forward(self, x, y): return torch.sub(x, y) +class SubAlpha(torch.nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = alpha + + def forward(self, x, y): + return torch.sub(x, y, alpha=self.alpha) + + +class SubAlphaConstant(torch.nn.Module): + def __init__(self, alpha, constant_first=False): + super().__init__() + self.alpha = alpha + self.constant_first = constant_first + + def forward(self, x): + if self.constant_first: + return 
torch.sub(5.0, x, alpha=self.alpha) + else: + return torch.sub(x, 5.0, alpha=self.alpha) + + class SubConstantFloat(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 5a86d5f286d..9cca7b4203d 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -282,6 +282,13 @@ def test_qnn_backend_conv2d_channel_last(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv3d_sequential(self): + modules = [Conv3dSequential(), Conv3dSequential(bias=False)] # noqa: F405 + sample_input = (torch.randn([2, 1, 10, 32, 32]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose1d(self): modules = [ ConvTranspose1dSingle(), # noqa: F405 @@ -306,6 +313,18 @@ def test_qnn_backend_conv_transpose2d(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose3d(self): + modules = [ + ConvTranspose3dSingle(), # noqa: F405 + ConvTranspose3dSingle(bias=False), # noqa: F405 + ConvTranspose3dSingle(dilation=2), # noqa: F405 + ConvTranspose3dSingle(dilation=(3, 2, 3)), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 3, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cos(self): module = Cos() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) @@ -372,6 +391,24 @@ def test_qnn_backend_element_wise_add(self): ], QCOM_SAMPLE_INPUTS: [(torch.randint(0, 10, size=(2, 3)),)], }, + { + QCOM_MODULE: [ + AddAlpha(alpha=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.tensor([[1.2, 1.3, 1.4]]), + torch.tensor([[0.8, 1.6, 0.2]]), + ) + ], + }, + { + QCOM_MODULE: [ + 
AddAlphaConstant(alpha=2, constant_first=True), # noqa: F405 + AddAlphaConstant(alpha=2, constant_first=False), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)], + }, ] index = 0 @@ -495,6 +532,24 @@ def test_qnn_backend_element_wise_sub(self): QCOM_MODULE: [SubConstantFloat()], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, + { + QCOM_MODULE: [ + SubAlpha(alpha=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.tensor([[1.2, 1.3, 1.4]]), + torch.tensor([[0.8, 1.6, 0.2]]), + ) + ], + }, + { + QCOM_MODULE: [ + SubAlphaConstant(alpha=2, constant_first=True), # noqa: F405 + SubAlphaConstant(alpha=2, constant_first=False), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)], + }, ] index = 0 @@ -1789,6 +1844,14 @@ def test_qnn_backend_conv2d_channel_last(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv3d_sequential(self): + modules = [Conv3dSequential(), Conv3dSequential(bias=False)] # noqa: F405 + sample_input = (torch.randn([2, 1, 10, 32, 32]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + qdq_module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(qdq_module, sample_input) + def test_qnn_backend_conv_transpose1d(self): modules = [ ConvTranspose1dSingle(), # noqa: F405 @@ -1814,6 +1877,19 @@ def test_qnn_backend_conv_transpose2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_conv_transpose3d(self): + modules = [ + ConvTranspose3dSingle(), # noqa: F405 + ConvTranspose3dSingle(bias=False), # noqa: F405 + ConvTranspose3dSingle(dilation=2), # noqa: F405 + ConvTranspose3dSingle(dilation=(3, 2, 3)), # noqa: F405 + ] + sample_input = (torch.randn([1, 1, 3, 3, 3]),) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = 
self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_cos(self): module = Cos() # noqa: F405 sample_input = (torch.randn(2, 5, 1, 3),) @@ -1863,6 +1939,24 @@ def test_qnn_backend_element_wise_add(self): QCOM_MODULE: [AddConstantFloat(), AddConstantLong()], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, + { + QCOM_MODULE: [ + AddAlpha(alpha=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.tensor([[1.2, 1.3, 1.4]]), + torch.tensor([[0.8, 1.6, 0.2]]), + ) + ], + }, + { + QCOM_MODULE: [ + AddAlphaConstant(alpha=2, constant_first=True), # noqa: F405 + AddAlphaConstant(alpha=2, constant_first=False), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)], + }, ] index = 0 @@ -1992,6 +2086,24 @@ def test_qnn_backend_element_wise_sub(self): QCOM_MODULE: [SubConstantFloat(), SubConstantLong()], # noqa: F405 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)], }, + { + QCOM_MODULE: [ + SubAlpha(alpha=2), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [ + ( + torch.tensor([[1.2, 1.3, 1.4]]), + torch.tensor([[0.8, 1.6, 0.2]]), + ) + ], + }, + { + QCOM_MODULE: [ + SubAlphaConstant(alpha=2, constant_first=True), # noqa: F405 + SubAlphaConstant(alpha=2, constant_first=False), # noqa: F405 + ], + QCOM_SAMPLE_INPUTS: [(torch.tensor([[1.2, 1.3, 1.4]]),)], + }, ] index = 0