Revisit how __truediv__ works for QuantTensor #740

Closed
nickfraser opened this issue Oct 27, 2023 · 3 comments

nickfraser commented Oct 27, 2023

__truediv__ in QuantTensor currently performs the inverse operation of multiplication, with respect to how the output bitwidth and scale are calculated from the inputs.

For d = a / b, we currently do the following:

d_bitwidth = a_bitwidth - b_bitwidth
d_scale = a_scale / b_scale
d = a / b
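
For illustration, a minimal sketch of the current rule in plain Python (illustrative names, not the Brevitas API):

# Current __truediv__ metadata rule: the inverse of the multiplication rule,
# which adds bitwidths and multiplies scales.
def current_div_metadata(a_bitwidth, a_scale, b_bitwidth, b_scale):
    d_bitwidth = a_bitwidth - b_bitwidth
    d_scale = a_scale / b_scale
    return d_bitwidth, d_scale

# e.g. an 8-bit numerator divided by a 4-bit denominator yields a 4-bit result:
print(current_div_metadata(8, 0.5, 4, 0.25))  # (4, 2.0)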

This makes sense when the numerator is the result of a previous multiplication, but it perhaps doesn't generalise to a standalone division operation. It also diverges a little from what "traditional"* fixed-point arithmetic does. My suggestion is to generalise the traditional fixed-point arithmetic rules to floating-point scales, so that the extremes of the input and output ranges can still be represented.

A convenient way to do this would be to decompose division (a / b) into a * (1 / b), define a simple rule for calculating the output bitwidth of 1 / b, and then apply the regular multiplication rules. My suggestion is then as follows:

c = 1 / b
c_bitwidth = b_bitwidth
c_scale = 1 / b_threshold # where b_threshold = b_scale * 2^b_bitwidth (unsigned) or b_scale * 2^(b_bitwidth - 1) (signed)
c_scale = 1 / (b_scale * 2^(b_bitwidth - int(signed)))

Adding the multiply leads to:

d = a / b = a * c = a * (1 / b)
d_bitwidth = a_bitwidth + c_bitwidth = a_bitwidth + b_bitwidth
d_scale = a_scale * c_scale = a_scale / (b_scale * 2^(b_bitwidth - int(signed)))

The resulting value would then be:

d = a * QuantInt(1 / b, c_bitwidth, c_scale) # Requantization is required to ensure the output QuantTensor contains only valid values, which can be converted back to integers
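
A minimal sketch of the proposed metadata rule in plain Python (illustrative names, not Brevitas code; the requantization step above is omitted):

# Proposed rule: decompose a / b into a * (1 / b), derive the metadata of the
# reciprocal c = 1 / b, then apply the usual multiplication rules.
def proposed_div_metadata(a_bitwidth, a_scale, b_bitwidth, b_scale, b_signed):
    b_threshold = b_scale * 2 ** (b_bitwidth - int(b_signed))
    c_bitwidth = b_bitwidth      # bitwidth assigned to 1 / b
    c_scale = 1.0 / b_threshold  # scale assigned to 1 / b
    d_bitwidth = a_bitwidth + c_bitwidth
    d_scale = a_scale * c_scale
    return d_bitwidth, d_scale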

I believe this would match traditional fixed point arithmetic rules in the power-of-two case and do something not completely unreasonable in the floating point scaling case. However, the floating point scaling case should be studied further to understand the ramifications of making such a choice. I expect that the decomposition of division into inversion and multiplication is not unreasonable.

*traditional = power-of-two scaling

nickfraser commented

@preusser, I would appreciate feedback on my proposal here.

nickfraser changed the title from "Revisit how __truediv__ works for TensorQuant" to "Revisit how __truediv__ works for QuantTensor" on Oct 27, 2023

preusser commented Nov 1, 2023

d = a / b = a * c = a * (1 / b)
d_bitwidth = a_bitwidth + c_bitwidth = a_bitwidth + b_bitwidth
d_scale = a_scale * c_scale = a_scale / (b_scale * 2^(b_bitwidth - int(signed)))

In most cases, division is a lossy operation. So, I assume the goal would be to preserve the accuracy that can reasonably be expected of the operands?
These proposed bitwidths and scales are reasonable choices. I'd argue that they remain so even when values are represented by a (fixed) floating-point scale (_scale) and a variable integer value (_val):

d = a / b
  = a_scale/b_scale * a_val/b_val
  = (a_scale/b_scale / 2^b_ws) * (a_val * 2^b_ws / b_val)    // b_ws = b_bitwidth - int(signed)
  = d_scale                    * d_val

The scaling performed for the value computation ensures that even the biggest b can fit into the smallest a, preserving whatever accuracy a brought into the operation.
However, I would definitely go for this direct implementation rather than taking a detour through 1/b, which would introduce an avoidable accuracy bottleneck.
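
To make the direct route concrete, a minimal sketch in plain Python (assumed names, not Brevitas code) that follows the derivation above and never materialises 1 / b:

# Direct rule: pre-shift the numerator's integer value by b_ws before dividing,
# and fold the same shift into the output scale.
def direct_div(a_val, a_scale, b_val, b_scale, b_bitwidth, b_signed):
    b_ws = b_bitwidth - int(b_signed)
    d_scale = a_scale / b_scale / 2 ** b_ws
    d_val = round(a_val * 2 ** b_ws / b_val)  # integer result; the rounding mode is a design choice
    return d_val, d_scale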

nickfraser commented

Thanks @preusser, I will adopt your suggestions into the proposal.

nickfraser added a commit to nickfraser/brevitas that referenced this issue Nov 23, 2023
nickfraser added a commit to nickfraser/brevitas that referenced this issue Dec 20, 2023
Giuseppe5 pushed a commit that referenced this issue Dec 21, 2023: "…d fixed point rules" (#769)

* [quant_tensor] Updated `__truediv__` behaviour based on #740

* [quant_tensor] Updated div behaviour to throw RuntimeError when non-zero zero-point operands are used

* Fix: changed other.tensor -> other.value