Skip to content

Commit

Permalink
Convert zero_point to an 8-bit type when exporting to ONNX
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Wu committed Aug 27, 2020
1 parent b208dbd commit 583b5ce
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torch/onnx/symbolic_opset13.py
@@ -1,13 +1,18 @@
from __future__ import absolute_import, division, print_function, unicode_literals

from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_helper import parse_args, cast_pytorch_to_onnx

@parse_args('v', 'v', 'v', 'i', 'i', 'i')
def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
if quant_min not in [0, -128] or quant_max not in [127, 255]:
raise RuntimeError(
"ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))

# ONNX defines zero_point to be int8 or uint8
if quant_min == 0:
zero_point = g.op("Cast", zero_point, to_i=cast_pytorch_to_onnx['Byte'])
else:
zero_point = g.op("Cast", zero_point, to_i=cast_pytorch_to_onnx['Char'])
return g.op(
"DequantizeLinear",
g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis),
Expand Down

0 comments on commit 583b5ce

Please sign in to comment.