
Indexed assignment of quantized Tensors yields unexpected results #29102

Open
t-vi opened this issue Nov 3, 2019 · 11 comments
@t-vi
Collaborator

@t-vi t-vi commented Nov 3, 2019

🐛 Bug

Indexed assignment of quantized Tensors effectively copies over the int_repr of the source tensors. This makes no sense in terms of values when the scale / zero_point do not match.

To Reproduce

import torch

t = torch.arange(4.0)
q = torch.quantize_per_tensor(t, 0.02, 0, torch.qint8)
q2 = torch._empty_affine_quantized(q.shape, scale=0.04, zero_point=0, dtype=torch.qint8)
q3 = torch._empty_affine_quantized(q.shape, scale=0.04, zero_point=0, dtype=torch.qint8)
q2[:] = q
q3[:] = q.dequantize()
q, q2, q3

gives

(tensor([0.0000, 1.0000, 2.0000, 2.5400], size=(4,), dtype=torch.qint8,
        quantization_scheme=torch.per_tensor_affine, scale=0.02, zero_point=0),
 tensor([0.0000, 2.0000, 4.0000, 5.0800], size=(4,), dtype=torch.qint8,
        quantization_scheme=torch.per_tensor_affine, scale=0.04, zero_point=0),
 tensor([0.0000, 1.0000, 2.0000, 2.5600], size=(4,), dtype=torch.qint8,
        quantization_scheme=torch.per_tensor_affine, scale=0.04, zero_point=0))

Expected behavior

I would expect q2 and q3 to be the same and close to q.

Environment

PyTorch built from (close to) current master

Additional context

I suspect this could be fixed by changing copy_ to re-quantize the source to the destination's quantization parameters, which I would consider reasonable semantics, though it of course differs from the current behavior.
If that is an acceptable change, I'd be very happy to send a PR.
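
For illustration, a minimal sketch of what the proposed semantics would mean, using only calls from the repro above (requantizing_copy is a hypothetical helper, not an existing API):

def requantizing_copy(dst, src):
    # Go through float so the values, not the raw int_repr, are carried over;
    # dst keeps its own scale / zero_point, as q3 does in the repro above.
    dst[:] = src.dequantize()
    return dst

With this, q2[:] = q would give the same result as q3[:] = q.dequantize() above.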

cc @jerryzh168 @jianyuh @dzhulgakov @raghuramank100 @jamesr66a

@jerryzh168 jerryzh168 self-assigned this Nov 9, 2019
@jerryzh168

Contributor

@jerryzh168 jerryzh168 commented Nov 9, 2019

The currently supported assignments are assigning from a float tensor to a quantized tensor, and copy_, which completely overwrites the destination quantized tensor. See https://github.com/pytorch/pytorch/wiki/Introducing-Quantized-Tensor

I haven't explored the path of quantized tensor -> quantized tensor assignment; it probably just works by coincidence. If we plan to support this use case, I think the semantics mentioned above make sense, but I'm not sure whether we should support it. The reason not to support it is that we probably want to make the re-quantization explicit rather than hiding everything underneath. Greg, do you have any suggestions @gchanan?

@t-vi

Collaborator Author

@t-vi t-vi commented Nov 9, 2019

Well, x.copy_(y) and x[:] = y both appear to work, as in they throw no error, but they produce different results (one of them being incorrect).
The problem is that PyTorch does support slices and views, and the current copy_ renders the results of these invalid. I know that views probably have more bugs when the quantization parameters change, but to me the semantics of copy_ feel quite un-PyTorchy: I would expect the values to be carried over while all other properties of the destination tensor remain the same (e.g. the shape does, too, because copy_ broadcasts).
Similarly b.view(-1).copy_(c.view(-1)) doesn't work.

That aside, if copy_ isn't it, what is the recommended way to effect requantization? Dequantize and quantize again?
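
For reference, a minimal dequantize-then-quantize sketch using only existing calls (the target scale / zero_point below are just example values):

import torch

q = torch.quantize_per_tensor(torch.arange(4.0), 0.02, 0, torch.qint8)
# Re-quantize to new parameters by going through float.
q_new = torch.quantize_per_tensor(q.dequantize(), 0.04, 0, torch.qint8)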

@jerryzh168

Contributor

@jerryzh168 jerryzh168 commented Nov 11, 2019

For requantization, I think we can either dequantize and then quantize again, or provide a requantize aten function.

For x[:] = y (where both x and y are quantized), we could also disable it by throwing an error; I'm not sure whether we have a use case for this.

x.copy_(y) was introduced because of

param.copy_(input_param)

I think we can probably disable this as well and call some other function there.

@gchanan

Contributor

@gchanan gchanan commented Nov 11, 2019

I agree with @t-vi's take and brought up the issue with copy_ with @jerryzh168 separately.

@jerryzh168 mentioned that this is sensible behavior for _load_from_state_dict, which currently uses copy_:

param.copy_(input_param)

This feels a bit like the tail wagging the dog: we should either serialize the quantized tensor in some way so we can create the tensor with the correct parameters first, or we should call some special function that handles the deserialization correctly.
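
A rough sketch of the first option (a hypothetical load path; it assumes the saved tensor's quantization parameters are available at load time and reuses the _empty_affine_quantized call from the repro above):

# Create the destination with the serialized tensor's quantization parameters
# first, so the subsequent copy does not need to re-quantize.
new_param = torch._empty_affine_quantized(
    input_param.shape,
    scale=input_param.q_scale(),
    zero_point=input_param.q_zero_point(),
    dtype=input_param.dtype)
new_param.copy_(input_param)  # parameters match, so copying the raw values is fine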

Thoughts?

@dzhulgakov

Member

@dzhulgakov dzhulgakov commented Nov 12, 2019

Probably changing load_from_state_dict to call something different is better. Or we can even add special handling for quantization there (which is ugly but works).

There can potentially be a corner case when the parameters have been made into a bigger tensor and we then load from a state dict. In that case I guess we can error out.

@t-vi

Collaborator Author

@t-vi t-vi commented Nov 12, 2019

So what would the new copy name be? copy_strongly_? (I take it that adding an additional optional boolean flag to copy_ would not be tremendously popular.)

@gchanan

Contributor

@gchanan gchanan commented Nov 13, 2019

copy_and_clobber_?

@jerryzh168

Contributor

@jerryzh168 jerryzh168 commented Nov 13, 2019

@gchanan do we plan to disable copy_ for quantized tensors?
I think we can make re-quantization more explicit, e.g. provide an aten::requantize op.

@jerryzh168

Contributor

@jerryzh168 jerryzh168 commented Nov 13, 2019

> copy_and_clobber_?

maybe just overwrite_?

@jerryzh168

Contributor

@jerryzh168 jerryzh168 commented Dec 5, 2019

Should I write a separate copy_ for quantized tensors?

@jerryzh168 jerryzh168 added the triaged label Dec 5, 2019
@t-vi

Collaborator Author

@t-vi t-vi commented Dec 5, 2019

My ideal resolution would be adding overwrite_ with the current copy_ semantics, plus a warning in the docs (unless we have some way to detect this) that it cannot be used on views. Then copy_ would work on views, and module loading would use overwrite_.
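
In code terms, a sketch of the proposed split (overwrite_ is only the hypothetical name discussed above, not an existing API; the helper below just illustrates the intended difference):

def value_preserving_copy(dst, src):
    # Proposed copy_ semantics: keep dst's scale / zero_point and re-quantize
    # the source values by going through float; safe on views and slices.
    dst[:] = src.dequantize()
    return dst

# overwrite_ would keep the current copy_ behaviour of completely overwriting
# the destination quantized tensor (per the wiki linked above), so it should
# not be used on views/slices; module loading would call it instead of copy_.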
