update delegate for derived bias qspec
Summary: In the diff below, we added quantizer annotations to the bias of GEMM operations. Here we update some of our delegate lowering logic to allow for this change.

Differential Revision: D56959071
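For context, the quantizer-side annotation the summary refers to derives the bias quantization parameters from the input and weight observers rather than computing them in the delegate: the bias scale is the product of the input and weight scales, the zero-point is zero, and the dtype is int32, matching what the removed from_bias code below computed inline. The following is a minimal sketch of such an annotation, assuming the PT2E DerivedQuantizationSpec API; the node names (input_node, weight_node, gemm_node) and the helper derived_bias_qspec are hypothetical, and the actual quantizer change is not part of this commit.

import torch
from torch.ao.quantization.quantizer import DerivedQuantizationSpec
from torch.fx import Node

def _derive_bias_qparams(obs_or_fqs):
    # observers/fake-quants for the GEMM input and weight, in that order
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    bias_scale = act_scale * weight_scale  # scale_bias = scale_input * scale_weight
    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
    return bias_scale, bias_zero_point

def derived_bias_qspec(input_node: Node, weight_node: Node, gemm_node: Node) -> DerivedQuantizationSpec:
    # hypothetical helper: annotate a GEMM bias with qparams derived from its input and weight
    return DerivedQuantizationSpec(
        derived_from=[(input_node, gemm_node), (weight_node, gemm_node)],
        derive_qparams_fn=_derive_bias_qparams,
        dtype=torch.int32,
        quant_min=-(2**31),
        quant_max=2**31 - 1,
        qscheme=torch.per_tensor_symmetric,
    )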
mcr229 authored and facebook-github-bot committed May 4, 2024
1 parent d34f9f2 commit 0c9ec73
Showing 4 changed files with 17 additions and 38 deletions.
2 changes: 1 addition & 1 deletion backends/xnnpack/operators/node_visitor.py
@@ -625,7 +625,7 @@ def define_nodes_tensor_inputs_outputs(
         if input_type_map.node_bias is not None:
             bias_node = get_input_node(node, input_type_map.node_bias)
             bias_quant_params = QuantParams.from_bias(
-                bias_node, weight_quant_params, input_quant_params
+                bias_node, self._exported_program
             )
             self.define_tensor(
                 bias_node,
4 changes: 1 addition & 3 deletions backends/xnnpack/operators/op_conv2d.py
@@ -104,9 +104,7 @@ def define_node(
         if node.args[2] is not None:
             # If there is a bias
             bias_node = get_input_node(node, 2)
-            bias_quant_params = QuantParams.from_bias(
-                bias_node, weight_quant_params, input_quant_params
-            )
+            bias_quant_params = QuantParams.from_bias(bias_node, self._exported_program)
             self.define_tensor(
                 get_input_node(node, 2),
                 xnn_graph,
4 changes: 1 addition & 3 deletions backends/xnnpack/operators/op_linear.py
@@ -66,9 +66,7 @@ def define_node(
         # bias
         if len(node.args) > 2:
             bias_node = get_input_node(node, 2)
-            bias_quant_params = QuantParams.from_bias(
-                bias_node, weight_quant_params, input_quant_params
-            )
+            bias_quant_params = QuantParams.from_bias(bias_node, self._exported_program)
             self.define_tensor(
                 get_input_node(node, 2),
                 xnn_graph,
45 changes: 14 additions & 31 deletions backends/xnnpack/operators/quant_params.py
@@ -306,41 +306,24 @@ def from_outputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
 
     @classmethod
     def from_bias(
-        cls,
-        bias: torch.fx.Node,
-        weight_quantizer: Optional[QuantParams],
-        input_quantizer: Optional[QuantParams],
+        cls, tensor_node: torch.fx.Node, ep: ExportedProgram
     ) -> Optional[QuantParams]:
-        if weight_quantizer is None or input_quantizer is None:
-            check_or_raise(
-                weight_quantizer is None and input_quantizer is None,
-                "Weight and Input should both be quantized",
-            )
-            return None
+        # source node for quant params
+        dq = tensor_node
 
-        if input_quantizer.is_dynamic:
-            # No need to quantize bias for dyanamic quantization
+        if not is_dequant(dq):
             return None
 
-        check_or_raise(
-            not input_quantizer.per_channel,
-            "Input can not be quantized per channel",
-        )
+        src = dq
+
+        # is input of dq is q?
+        dq_input = dq.all_input_nodes[0]
+        if is_quant(dq_input):
+            src = dq_input
+
-        # Only per_tensor quantization is supported for input here
         check_or_raise(
-            isinstance(input_quantizer.scale, float),
-            f"q_input scale should be float, but got {input_quantizer.scale}",
-        )
-        return cls(
-            per_channel=weight_quantizer.per_channel,
-            q_input=bias,
-            scale=weight_quantizer.scale * cast(float, input_quantizer.scale),
-            zp=weight_quantizer.zp * 0,
-            axis=0,  # not using weight_quantizer.axis because bias is always of shape [out_channels] i.e. 1D
-            dtype=torch.int32,
-            qmin=-(2**31),
-            qmax=(2**31) - 1,
-            is_output=False,
-            is_input=False,
+            src.all_input_nodes[0].op in ["get_attr", "placeholder"],
+            f"Expected input to quant -> dequant chain from bias to be static tensor, but instead got: {src.all_input_nodes[0]}",
         )
 
+        return cls.from_q_dq_node(src, ep)
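Assembled from the added lines above, the new from_bias reads roughly as follows. This is a sketch that assumes the existing is_dequant, is_quant, check_or_raise, and from_q_dq_node helpers: instead of combining weight and input quant params in the delegate, it walks from the bias's dequantize node (and the quantize node feeding it, if present) back to a static tensor and extracts the parameters from that q/dq node.

    @classmethod
    def from_bias(
        cls, tensor_node: torch.fx.Node, ep: ExportedProgram
    ) -> Optional[QuantParams]:
        # source node for quant params
        dq = tensor_node

        # an unquantized bias has no dequantize node, so there are no params to extract
        if not is_dequant(dq):
            return None

        src = dq

        # if the dequantize is fed by a quantize node, read the params from that node instead
        dq_input = dq.all_input_nodes[0]
        if is_quant(dq_input):
            src = dq_input

        # the quant -> dequant chain must originate from a static tensor (parameter or constant)
        check_or_raise(
            src.all_input_nodes[0].op in ["get_attr", "placeholder"],
            "Expected input to quant -> dequant chain from bias to be static tensor, "
            f"but instead got: {src.all_input_nodes[0]}",
        )

        return cls.from_q_dq_node(src, ep)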
