From 4abc055ad1344058aa0a1e2347c0bd5a11658dea Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 8 Feb 2024 17:48:07 -0800
Subject: [PATCH] Linear

Summary: Add fp16 Linear

Differential Revision: D53333693
---
 backends/xnnpack/operators/op_linear.py  |  2 ++
 backends/xnnpack/runtime/XNNCompiler.cpp | 16 +++++++------
 backends/xnnpack/test/ops/linear.py      | 30 ++++++++++++++++++++----
 3 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/backends/xnnpack/operators/op_linear.py b/backends/xnnpack/operators/op_linear.py
index 3517403b008..7fb0de8228d 100644
--- a/backends/xnnpack/operators/op_linear.py
+++ b/backends/xnnpack/operators/op_linear.py
@@ -59,6 +59,7 @@ def define_node(
             xnn_graph,
             vals_to_ids,
             quant_params=weight_quant_params,
+            fp32_static_weights=True,
         )
         filter_id = vals_to_ids[weight_node]
 
@@ -73,6 +74,7 @@ def define_node(
                 xnn_graph,
                 vals_to_ids,
                 quant_params=bias_quant_params,
+                fp32_static_weights=True,
             )
             bias_id = vals_to_ids[bias_node]
         else:
diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp
index bc94e5df152..80d4248f3ec 100644
--- a/backends/xnnpack/runtime/XNNCompiler.cpp
+++ b/backends/xnnpack/runtime/XNNCompiler.cpp
@@ -256,23 +256,25 @@ Error defineTensor(
         /*flags=*/0, // this is netiher external input or output
         /*id_out=*/&id);
 
-    // this is the FP32 external value that is dynamically quantized
-    uint32_t fp32_id;
+    // this is the FP16 or FP32 external value that is being dynamically
+    // quantized
+    uint32_t float_id;
+    enum xnn_datatype fp_datatype = getDataType(tensor_value->datatype());
     status = xnn_define_tensor_value(
         /*subgraph=*/subgraph_ptr,
-        /*datatype=*/xnn_datatype_fp32, // always fp32
+        /*datatype=*/fp_datatype,
         /*num_dims=*/tensor_value->num_dims(),
         /*dims=*/dims_data.data(),
         /*data=*/buffer_ptr,
         /*external_id=*/tensor_value->external_id(),
         /*flags=*/tensor_value->flags(),
-        /*id_out=*/&fp32_id);
-    executor->addDynamicQinput(fp32_id);
+        /*id_out=*/&float_id);
+    executor->addDynamicQinput(float_id);
 
-    // Define dynamic conversion from fp32 to qdint8
+    // Define dynamic conversion from float to qdint8
     status = xnn_define_convert(
         /*subgraph=*/subgraph_ptr,
-        /*input_id=*/fp32_id,
+        /*input_id=*/float_id,
         /*output_id=*/id,
         /*flags=*/0);
     break;
diff --git a/backends/xnnpack/test/ops/linear.py b/backends/xnnpack/test/ops/linear.py
index b1474d505d2..f7c7a840dc0 100644
--- a/backends/xnnpack/test/ops/linear.py
+++ b/backends/xnnpack/test/ops/linear.py
@@ -23,6 +23,17 @@
 
 
 class TestLinear(unittest.TestCase):
+    def test_fp16_linear(self):
+        for use_bias in (True, False):
+            self._test_linear(
+                lambda in_size, out_size: torch.nn.Linear(
+                    in_size, out_size, bias=use_bias  # noqa
+                ),
+                uses_bias=use_bias,
+                dtype=torch.float16,
+                atol=5e-2,
+            )
+
     def test_fp32_linear(self):
         for use_bias in (True, False):
             self._test_linear(
@@ -284,7 +295,14 @@ def forward(self, x):
                 quant=True,
             )
 
-    def _test_linear(self, make_module, uses_bias, quant=False):
+    def _test_linear(
+        self,
+        make_module,
+        uses_bias,
+        quant=False,
+        dtype: torch.dtype = torch.float,
+        atol=1e-03,
+    ):
         aten_op, edge_op = (
             (
                 "aten.addmm.default",
@@ -309,9 +327,10 @@ def _test_linear(self, make_module, uses_bias, quant=False):
             in_size = int(in_sizes[i])
             input_size = int(input_sizes[i])
             output_size = int(output_sizes[i])
+            print(f"Testing {in_size} {input_size} {output_size}")
 
-            module = make_module(input_size, output_size).eval()
-            inputs = (torch.randn(in_size, input_size),)
+            module = make_module(input_size, output_size).eval().to(dtype)
+            inputs = (torch.randn(in_size, input_size).to(dtype),)
 
             tester = Tester(module, inputs)
 
@@ -336,7 +355,8 @@ def _test_linear(self, make_module, uses_bias, quant=False):
             tester.to_executorch()
             tester.serialize()
             tester.run_method()
-            tester.compare_outputs(qtol=quant)
+            tester.compare_outputs(qtol=quant, atol=atol)
+            print("success")
 
     def _test_dqlinear(
         self,
@@ -370,7 +390,7 @@ def _test_dqlinear(
         tester.export()
         tester.check_count({aten_op: linear_count})
         tester.check(["torch.ops.quantized_decomposed"])
-
+        tester.dump_artifact()
         tester.to_edge()
         tester.check_count({edge_op: linear_count})
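
Note (not part of the patch): below is a minimal, standalone sketch in plain PyTorch
of the numerical check that the new test_fp16_linear case performs, an nn.Linear cast
to float16 compared against its fp32 reference under the same relaxed atol=5e-2 used
above. The helper name and default sizes are made up for illustration, and it assumes
a recent PyTorch build with fp16 matmul support on the host CPU; the actual test
drives the XNNPACK delegate through the ExecuTorch Tester instead.

import copy

import torch


def fp16_linear_close_to_fp32(in_size=2, input_size=4, output_size=8, use_bias=True):
    # fp32 reference module with the same shape the test constructs.
    ref = torch.nn.Linear(input_size, output_size, bias=use_bias).eval()
    # nn.Module.to() casts parameters in place, so copy before converting to half.
    fp16 = copy.deepcopy(ref).to(torch.float16)

    x = torch.randn(in_size, input_size)
    with torch.no_grad():
        y_ref = ref(x)
        y_fp16 = fp16(x.to(torch.float16))

    # float16 keeps ~10 bits of mantissa, hence the relaxed atol=5e-2 in
    # test_fp16_linear versus the 1e-03 default that _test_linear uses for fp32.
    return torch.allclose(y_fp16.to(torch.float32), y_ref, atol=5e-2)


if __name__ == "__main__":
    assert fp16_linear_close_to_fp32(use_bias=True)
    assert fp16_linear_close_to_fp32(use_bias=False)

The looser tolerance only reflects half-precision rounding; the fp32 path keeps the
tighter 1e-03 default in _test_linear.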