[OpBench] fix jit tracing with quantized op/tensor by enabling `_compare_tensors_internal` to compare quantized tensors (#46772)

Summary:
Pull Request resolved: #46772

When running `buck run caffe2/benchmarks/operator_benchmark/pt:qactivation_test -- --use_jit`, I encountered the following error: P146518683. The error was traced to the fact that `torch.allclose` does not work with quantized tensors. (It was triggered by this particular multiplication, https://fburl.com/diffusion/8vw647o6, since native mul cannot operate on a float scalar and a quantized tensor.)

Minimal example to reproduce:
```
(Pdb) input = torch.ones(5)
(Pdb) aa = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
(Pdb) bb = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
(Pdb) torch.allclose(aa, bb)
Comparison exception: 	promoteTypes with quantized numbers is not handled yet; figure out what the correct rules should be, offending types: QUInt8 Float
```

The fix proposed here is to compare quantized tensors for strict elementwise equality within `_compare_tensors_internal`.
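For illustration, a minimal sketch (not part of the PR) of why strict elementwise comparison works here: `==` is defined for quantized tensors, so no type promotion is involved. This is the same `(a == b)` check the patched code path relies on.

```python
import torch

x = torch.ones(5)
aa = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
bb = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)

# Elementwise equality is implemented for quantized tensors, so a strict
# check like the one in _compare_tensors_internal succeeds where
# torch.allclose raises a promotion error.
print((aa == bb).all().item())  # True
```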

Two other possible fixes would be:
1. Convert quantized tensors to float tensors before passing them to `torch.allclose` (see the sketch after this list).
2. Change `torch.allclose` to handle quantized tensors.
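For comparison, a minimal sketch of alternative 1 (dequantize first), which this diff does not adopt:

```python
import torch

aa = torch.quantize_per_tensor(torch.ones(5), scale=1.0, zero_point=0, dtype=torch.quint8)
bb = torch.quantize_per_tensor(torch.ones(5), scale=1.0, zero_point=0, dtype=torch.quint8)

# Tensor.dequantize() produces float tensors, which torch.allclose accepts.
print(torch.allclose(aa.dequantize(), bb.dequantize()))  # True
```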

Test Plan: `buck run caffe2/benchmarks/operator_benchmark/pt:qactivation_test -- --use_jit`

Reviewed By: kimishpatel

Differential Revision: D24506723

fbshipit-source-id: 6426ea2a88854b4fb89abef0edd2b49921283796
Yang Wang authored and facebook-github-bot committed Oct 28, 2020
1 parent 8e37dcb commit 810c68f
Showing 1 changed file with 8 additions and 1 deletion: torch/testing/__init__.py
```diff
@@ -24,6 +24,9 @@ def is_integral(dtype: torch.dtype) -> bool:
     dtypes = [x for x in get_all_dtypes() if x not in get_all_complex_dtypes()]
     return dtype in dtypes and not dtype.is_floating_point
 
+def is_quantized(dtype: torch.dtype) -> bool:
+    return dtype in (torch.quint8, torch.qint8, torch.qint32, torch.quint4x2)
+
 # Helper function that maps a flattened index back into the given shape
 # TODO: consider adding torch.unravel_index
 def _unravel_index(flat_index, shape):
@@ -70,7 +73,11 @@ def _compare_tensors_internal(a: torch.Tensor, b: torch.Tensor, *, rtol, atol, e
     debug_msg : Optional[str]
     # Integer (including bool) comparisons are identity comparisons
     # when rtol is zero and atol is less than one
-    if (is_integral(a.dtype) and rtol == 0 and atol < 1) or a.dtype is torch.bool:
+    if (
+        (is_integral(a.dtype) and rtol == 0 and atol < 1)
+        or a.dtype is torch.bool
+        or is_quantized(a.dtype)
+    ):
         if (a == b).all().item():
             return (True, None)
```