Linear #1901 (Closed)

Changes from all commits
2 changes: 2 additions & 0 deletions backends/xnnpack/operators/op_linear.py
@@ -59,6 +59,7 @@ def define_node(
xnn_graph,
vals_to_ids,
quant_params=weight_quant_params,
+ fp32_static_weights=True,
)
filter_id = vals_to_ids[weight_node]

@@ -73,6 +74,7 @@ def define_node(
xnn_graph,
vals_to_ids,
quant_params=bias_quant_params,
+ fp32_static_weights=True,
)
bias_id = vals_to_ids[bias_node]
else:
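Note on the new argument: judging from this diff, `fp32_static_weights=True` tells `define_tensor` to serialize the linear op's static weight and bias tensors as fp32 even when the module itself runs in fp16; XNNPACK's fully connected path can consume fp32 static weights during fp16 inference. A minimal sketch of that decision, assuming a hypothetical `prepare_static_weight` helper (this is not the delegate's actual code):

import torch

def prepare_static_weight(w: torch.Tensor, fp32_static_weights: bool) -> torch.Tensor:
    # Hypothetical helper: with the hint set, an fp16 static weight is
    # widened to fp32 before serialization; otherwise its dtype is kept.
    if fp32_static_weights and w.dtype == torch.float16:
        return w.to(torch.float32)
    return w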
16 changes: 9 additions & 7 deletions backends/xnnpack/runtime/XNNCompiler.cpp
@@ -256,23 +256,25 @@ Error defineTensor(
/*flags=*/0, // this is neither external input nor output
/*id_out=*/&id);

- // this is the FP32 external value that is dynamically quantized
- uint32_t fp32_id;
+ // this is the FP16 or FP32 external value that is being dynamically
+ // quantized
+ uint32_t float_id;
+ enum xnn_datatype fp_datatype = getDataType(tensor_value->datatype());
status = xnn_define_tensor_value(
/*subgraph=*/subgraph_ptr,
- /*datatype=*/xnn_datatype_fp32, // always fp32
+ /*datatype=*/fp_datatype,
/*num_dims=*/tensor_value->num_dims(),
/*dims=*/dims_data.data(),
/*data=*/buffer_ptr,
/*external_id=*/tensor_value->external_id(),
/*flags=*/tensor_value->flags(),
- /*id_out=*/&fp32_id);
- executor->addDynamicQinput(fp32_id);
+ /*id_out=*/&float_id);
+ executor->addDynamicQinput(float_id);

- // Define dynamic conversion from fp32 to qdint8
+ // Define dynamic conversion from float to qdint8
status = xnn_define_convert(
/*subgraph=*/subgraph_ptr,
- /*input_id=*/fp32_id,
+ /*input_id=*/float_id,
/*output_id=*/id,
/*flags=*/0);
break;
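The convert node above is what implements dynamic quantization at runtime: the external float activation, now fp16 or fp32, is quantized to qdint8 on the fly before the quantized linear kernel runs. A rough per-tensor model of that conversion in Python, as an illustration only (the actual XNNPACK qdint8 kernel and its parameter layout may differ):

import torch

def dynamic_quantize_to_int8(x: torch.Tensor):
    # fp16 inputs follow the same math after widening to fp32.
    x = x.to(torch.float32)
    rmin = min(float(x.min()), 0.0)  # quantization range must include zero
    rmax = max(float(x.max()), 0.0)
    scale = (rmax - rmin) / 255.0
    if scale == 0.0:  # guard against an all-zero input
        scale = 1.0
    zero_point = int(round(-128 - rmin / scale))
    q = torch.clamp(torch.round(x / scale) + zero_point, -128, 127).to(torch.int8)
    return q, scale, zero_point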
30 changes: 25 additions & 5 deletions backends/xnnpack/test/ops/linear.py
@@ -23,6 +23,17 @@


class TestLinear(unittest.TestCase):
+ def test_fp16_linear(self):
+     for use_bias in (True, False):
+         self._test_linear(
+             lambda in_size, out_size: torch.nn.Linear(
+                 in_size, out_size, bias=use_bias  # noqa
+             ),
+             uses_bias=use_bias,
+             dtype=torch.float16,
+             atol=5e-2,
+         )
+
def test_fp32_linear(self):
for use_bias in (True, False):
self._test_linear(
@@ -284,7 +295,14 @@ def forward(self, x):
quant=True,
)

- def _test_linear(self, make_module, uses_bias, quant=False):
+ def _test_linear(
+     self,
+     make_module,
+     uses_bias,
+     quant=False,
+     dtype: torch.dtype = torch.float,
+     atol=1e-03,
+ ):
aten_op, edge_op = (
(
"aten.addmm.default",
@@ -309,9 +327,10 @@ def _test_linear(self, make_module, uses_bias, quant=False):
in_size = int(in_sizes[i])
input_size = int(input_sizes[i])
output_size = int(output_sizes[i])
+ print(f"Testing {in_size} {input_size} {output_size}")

- module = make_module(input_size, output_size).eval()
- inputs = (torch.randn(in_size, input_size),)
+ module = make_module(input_size, output_size).eval().to(dtype)
+ inputs = (torch.randn(in_size, input_size).to(dtype),)

tester = Tester(module, inputs)

@@ -336,7 +355,8 @@ def _test_dqlinear(
tester.to_executorch()
tester.serialize()
tester.run_method()
- tester.compare_outputs(qtol=quant)
+ tester.compare_outputs(qtol=quant, atol=atol)
+ print("success")

def _test_dqlinear(
self,
@@ -370,7 +390,7 @@ def _test_dqlinear(
tester.export()
tester.check_count({aten_op: linear_count})
tester.check(["torch.ops.quantized_decomposed"])
-
+ tester.dump_artifact()
tester.to_edge()
tester.check_count({edge_op: linear_count})

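On the looser atol=5e-2 chosen for the fp16 test: half precision carries roughly three decimal digits of mantissa, and matmul accumulation compounds the rounding error, so the fp32 default of 1e-03 would make the comparison flaky. A quick sanity check of the expected error magnitude, with illustrative shapes rather than the test's own:

import torch

x = torch.randn(4, 32)
w = torch.randn(16, 32)
ref = torch.nn.functional.linear(x, w)
half = torch.nn.functional.linear(x.half(), w.half()).float()
# Typically lands around 1e-3 to 1e-2 for these sizes, under the 5e-2 bound.
print((ref - half).abs().max().item())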