Update prototype_source/pt2e_quant_ptq_static.rst
Svetlana Karslioglu authored Jul 28, 2023
1 parent a41e859 commit 9e3850d
Showing 1 changed file with 10 additions and 9 deletions.
prototype_source/pt2e_quant_ptq_static.rst
@@ -349,15 +349,16 @@ the statistics of the Tensors and we can later use this information to calculate
The model produced here also has some improvements over the previous `representations <https://github.com/pytorch/rfcs/blob/master/RFC-0019-Extending-PyTorch-Quantization-to-Custom-Backends.md>`_ in FX graph mode quantization. Previously, all quantized operators were represented as ``dequantize -> fp32_op -> quantize``; in the new flow, we choose to represent some of the operators with integer computation so that it is closer to the computation that happens in hardware.
For example, here is how we plan to represent a quantized linear operator:

def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_int32, bias_scale, bias_zero_point, output_scale, output_zero_point):
    x_int16 = x_int8.to(torch.int16)
    weight_int16 = weight_int8.to(torch.int16)
    acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
    acc_rescaled_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale)
    bias_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, bias_int32 - bias_zero_point, bias_scale / output_scale)
    out_int8 = torch.ops.aten.clamp(acc_rescaled_int32 + bias_int32 + output_zero_point, qmin, qmax).to(torch.int8)
    return out_int8
```
.. code-block:: python

   def quantized_linear(x_int8, x_scale, x_zero_point, weight_int8, weight_scale, weight_zero_point, bias_int32, bias_scale, bias_zero_point, output_scale, output_zero_point):
       x_int16 = x_int8.to(torch.int16)
       weight_int16 = weight_int8.to(torch.int16)
       acc_int32 = torch.ops.out_dtype(torch.mm, torch.int32, (x_int16 - x_zero_point), (weight_int16 - weight_zero_point))
       acc_rescaled_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, acc_int32, x_scale * weight_scale / output_scale)
       bias_int32 = torch.ops.out_dtype(torch.ops.aten.mul.Scalar, torch.int32, bias_int32 - bias_zero_point, bias_scale / output_scale)
       out_int8 = torch.ops.aten.clamp(acc_rescaled_int32 + bias_int32 + output_zero_point, qmin, qmax).to(torch.int8)
       return out_int8
For more details, please see: `Quantized Model Representation <https://docs.google.com/document/d/17h-OEtD4o_hoVuPqUFsdm5uo7psiNMY8ThN03F9ZZwg/edit>`_ (TODO: make this a public API doc/issue).
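
To make the arithmetic above concrete, here is a minimal runnable sketch (an editorial illustration, not part of the tutorial or the commit) of the same integer quantized-linear math using plain eager-mode tensor ops instead of the ``out_dtype`` higher-order operator. The function name, the ``qmin``/``qmax`` defaults, the int64 accumulator (the pseudocode uses int32), and the float rescale-plus-round standing in for the fixed-point rescale are all assumptions for illustration:

```python
import torch

def quantized_linear_reference(x_int8, x_scale, x_zero_point,
                               weight_int8, weight_scale, weight_zero_point,
                               bias_int32, bias_scale, bias_zero_point,
                               output_scale, output_zero_point,
                               qmin=-128, qmax=127):
    # Widen before the matmul so accumulation cannot overflow int8.
    # The pseudocode accumulates in int32 via out_dtype; int64 is used
    # here because integer matmul on CPU is most reliable in int64.
    x_int = x_int8.to(torch.int64) - x_zero_point
    w_int = weight_int8.to(torch.int64) - weight_zero_point
    acc = torch.mm(x_int, w_int)
    # Rescale the accumulator and the bias into the output quantization
    # domain. A float multiply plus round approximates the fixed-point
    # rescale expressed by out_dtype(aten.mul.Scalar, torch.int32, ...).
    acc_rescaled = torch.round(acc.double() * (x_scale * weight_scale / output_scale))
    bias_rescaled = torch.round((bias_int32 - bias_zero_point).double() * (bias_scale / output_scale))
    out = acc_rescaled + bias_rescaled + output_zero_point
    return torch.clamp(out, qmin, qmax).to(torch.int8)

# Example call with small random tensors (x is MxK, weight is KxN).
x = torch.randint(-128, 128, (2, 4), dtype=torch.int8)
w = torch.randint(-128, 128, (4, 3), dtype=torch.int8)
b = torch.randint(-512, 512, (3,), dtype=torch.int32)
y = quantized_linear_reference(x, 0.1, 0, w, 0.05, 0, b, 0.1 * 0.05, 0, 0.2, 0)
assert y.dtype == torch.int8 and y.shape == (2, 3)
```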

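For contrast with the older FX graph mode pattern mentioned above, here is a sketch (again an editorial illustration with a hypothetical function name) of the ``dequantize -> fp32_op -> quantize`` reference form that the integer representation replaces:

```python
import torch

def dq_fp32_q_linear(x_int8, x_scale, x_zero_point,
                     weight_int8, weight_scale, weight_zero_point,
                     output_scale, output_zero_point):
    # dequantize: map int8 values back to approximate fp32 values
    x_fp32 = (x_int8.to(torch.float32) - x_zero_point) * x_scale
    w_fp32 = (weight_int8.to(torch.float32) - weight_zero_point) * weight_scale
    # fp32_op: the actual computation runs in floating point
    out_fp32 = torch.mm(x_fp32, w_fp32)
    # quantize: map the fp32 result back into the int8 output domain
    out = torch.round(out_fp32 / output_scale) + output_zero_point
    return torch.clamp(out, -128, 127).to(torch.int8)
```

Every operator in this old form round-trips through fp32, whereas the representation shown in this commit keeps the matmul and rescale in integer arithmetic, which maps more directly onto integer hardware kernels.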
