[quant] Make choose_qparams_optimized return Tensors to preserve dtype #45530

Closed · wants to merge 4 commits
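For context, a minimal sketch of the new calling convention, based on the test added in this PR (the op previously returned a Python (float, float) pair; after this change it returns two 1-element float Tensors, which is what preserves the dtype):

```python
import torch

x = torch.randn(64, dtype=torch.float)
# New behavior: xmax and xmin come back as shape-(1,) float32 Tensors.
xmax, xmin = torch.choose_qparams_optimized(
    x, numel=64, n_bins=200, ratio=0.16, bit_width=4)
print(xmax.item(), xmin.item())  # extract Python floats if scalars are needed
```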
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
@@ -4437,7 +4437,7 @@
use_c10_dispatcher: full
variants: function

- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (float, float)
- func: choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor)
Contributor:
I was wondering about the reason for having numel as an argument. Isn't it the same as the input tensor's size?

Contributor (Author):
Not necessarily; it could be per-channel, in which case it will be the numel per channel.
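To illustrate the per-channel case, here is a hypothetical sketch mirroring how `_qembeddingbag_nbit_prepack_helper` in this PR calls the op once per row (the weight shape below is made up):

```python
import torch

weight = torch.randn(16, 64)   # 16 rows (channels), 64 elements per channel
embedding_cols = weight.size(1)
for row in range(weight.size(0)):
    # numel is the per-channel element count, not weight.numel()
    xmax, xmin = torch.choose_qparams_optimized(
        weight[row], numel=embedding_cols, n_bins=200, ratio=0.16, bit_width=4)
```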

use_c10_dispatcher: full
variants: function

17 changes: 10 additions & 7 deletions aten/src/ATen/native/quantized/QTensor.cpp
@@ -245,15 +245,14 @@ float calculate_quant_loss(
float scale = data_range == 0
? 1.0
: static_cast<float>(static_cast<at::Half>(data_range / qmax));
float inverse_scale = 1.0f / scale;
float inverse_scale = scale == 0 ? 1.0f : 1.0f / scale;

float norm = 0.0f;
constexpr int VLEN = 8;
int i = 0;

// TODO add FBGEMM kernel
// #ifdef USE_FBGEMM
// #endif
// TODO add FBGEMM kernel
// #ifdef USE_FBGEMM
// #endif

// remainder loop
for (; i < numel; i++) {
@@ -271,7 +270,7 @@
and tries to minimize the quant error by doing `torch.norm(x-fake_quant(x,s,z))`
Returns the optimized xmax and xmin value of the tensor.
*/
std::tuple<double, double> choose_qparams_optimized(
std::tuple<Tensor, Tensor> choose_qparams_optimized(
const at::Tensor& input_tensor,
int64_t numel,
const int64_t n_bins,
@@ -318,7 +317,11 @@ std::tuple<double, double> choose_qparams_optimized(
}
}

return std::make_tuple((float) xmax, (float) xmin);
at::Tensor xmax_tensor = at::empty({1});
at::Tensor xmin_tensor = at::empty({1});
xmax_tensor[0] = xmax;
xmin_tensor[0] = xmin;
return std::make_tuple(xmax_tensor, xmin_tensor);
}
} // namespace native
} // namespace at
16 changes: 13 additions & 3 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
@@ -196,8 +196,14 @@ Tensor _qembeddingbag_nbit_prepack_helper(

float Xmin, Xmax;
if (optimized_qparams) {
std::tie(Xmax, Xmin) = at::choose_qparams_optimized(
at::Tensor xmax_tensor, xmin_tensor;
std::tie(xmax_tensor, xmin_tensor) = at::choose_qparams_optimized(
weight_contig[row], embedding_cols, 200, 0.16, bit_width);
TORCH_CHECK(
xmax_tensor.numel() == 1 && xmin_tensor.numel() == 1,
"Expected choose_qparams_optimized to return min/max tensors of size 1");
Xmax = xmax_tensor.item<float>();
Xmin = xmin_tensor.item<float>();
} else {
Xmin = *std::min_element(input_row, input_row + embedding_cols);
Xmax = *std::max_element(input_row, input_row + embedding_cols);
@@ -254,7 +260,9 @@ Tensor _qembeddingbag_nbit_prepack_helper(
// To later de-quantize values, the scale (range / 15) and zero_point
// are stored alongside the data. More precisely, each row first has quantized
// values, and then 2-byte fp16 scale and 2-byte zero_offset.
Tensor qembeddingbag_4bit_prepack(const Tensor& weight, bool optimized_qparams) {
Tensor qembeddingbag_4bit_prepack(
const Tensor& weight,
bool optimized_qparams) {
return _qembeddingbag_nbit_prepack_helper(
weight, 4 /*bit_width*/, optimized_qparams);
}
@@ -267,7 +275,9 @@ Tensor qembeddingbag_4bit_prepack(const Tensor& weight, bool optimized_qparams)
// are stored alongside the data. More precisely, each row first has quantized
// values, and then 2-byte fp16 scale and 2-byte zero_offset.
// TODO() - Add 2Bit Embedding Lookup operator.
Tensor qembeddingbag_2bit_prepack(const Tensor& weight, bool optimized_qparams) {
Tensor qembeddingbag_2bit_prepack(
const Tensor& weight,
bool optimized_qparams) {
return _qembeddingbag_nbit_prepack_helper(
weight, 2 /*bit_width*/, optimized_qparams);
}
@@ -108,6 +108,7 @@
("aten::_foreach_sub_", datetime.date(2020, 10, 1)),
("aten::_foreach_div", datetime.date(2020, 10, 1)),
("aten::_foreach_sub", datetime.date(2020, 10, 1)),
("aten::choose_qparams_optimized", datetime.date(2020, 10, 5)),
]


77 changes: 77 additions & 0 deletions test/quantization/test_quantized_tensor.py
@@ -67,6 +67,75 @@ def _calculate_dynamic_qparams(X, dtype, reduce_range=False):
def get_supported_device_types():
return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu']

# Note we explicitly cast variables to np.float32 in a couple of places to avoid
# the default casting in Python, which often results in double precision, and to
# make sure we're doing the same numerics as the C++ code.
def param_search_greedy(x, bit_rate, n_bins=200, ratio=0.16):
xmin, xmax = np.min(x), np.max(x)
stepsize = (xmax - xmin) / np.float32(n_bins)
min_bins = np.float32(n_bins) * (np.float32(1) - np.float32(ratio))
xq, loss = _compress_uniform_simplified(x, bit_rate, xmin, xmax)

solutions = [] # [(left, right, loss)] # local optima solution

cur_min, cur_max, cur_loss = xmin, xmax, loss
thr = min_bins * stepsize
while cur_min + thr < cur_max:
# move left
xq, loss1 = _compress_uniform_simplified(
x, bit_rate, cur_min + stepsize, cur_max
)
# move right
xq, loss2 = _compress_uniform_simplified(
x, bit_rate, cur_min, cur_max - stepsize
)

if cur_loss < loss1 and cur_loss < loss2:
# found a local optima
solutions.append((cur_min, cur_max, cur_loss))
if loss1 < loss2:
cur_min, cur_max, cur_loss = cur_min + stepsize, cur_max, loss1
else:
cur_min, cur_max, cur_loss = cur_min, cur_max - stepsize, loss2
if len(solutions):
best = solutions[0]
for solution in solutions:
if solution[-1] < best[-1]:
best = solution
return best[1], best[0] # xmax, xmin
return xmax, xmin


def _compress_uniform_simplified(X, bit_rate, xmin, xmax, fp16_scale_bias=True):
# affine transform to put Xq in [0,2**bit_rate - 1]
# Xq = (2 ** bit_rate - 1) * (Xq - xmin) / data_range
if fp16_scale_bias:
xmin = xmin.astype(np.float16).astype(np.float32)
data_range = xmax - xmin
scale = np.where(
data_range == 0, np.float32(1), data_range / np.float32(2 ** bit_rate - 1)
)
if fp16_scale_bias:
scale = scale.astype(np.float16).astype(np.float32)
inverse_scale = np.float32(1) / scale
Xq = np.clip(np.round((X - xmin) * inverse_scale), 0, np.float32(2 ** bit_rate - 1))
Xq = Xq * scale + xmin

# Manually compute loss instead of using np.linalg.norm to use the same
# accumulation order used by C++ code
vlen = 8
loss_v = np.zeros(vlen).astype(np.float32)
for i in range(len(Xq) // vlen * vlen):
loss_v[i % vlen] += (X[i] - Xq[i]) * (X[i] - Xq[i])
loss = np.float32(0)
for i in range(vlen):
loss += loss_v[i]
for i in range(len(Xq) // vlen * vlen, len(Xq)):
loss += (X[i] - Xq[i]) * (X[i] - Xq[i])
loss = np.sqrt(loss)

return Xq, loss

class TestQuantizedTensor(TestCase):
def test_qtensor(self):
num_elements = 10
@@ -745,3 +814,11 @@ def test_fp16_saturate_op(self):
ref[0] = torch.ones(5) * -65504
y = torch._saturate_weight_to_fp16(x)
self.assertEqual(y, ref)

def test_choose_qparams_optimized(self):
for bit_width in [4, 2]:
x = torch.randn(64, dtype=torch.float)
y = torch.choose_qparams_optimized(x, numel=64, n_bins=200, ratio=0.16, bit_width=bit_width)
ref = param_search_greedy(x.numpy(), bit_rate=bit_width)
self.assertEqual(y[0].numpy(), ref[0])
self.assertEqual(y[1].numpy(), ref[1])
1 change: 0 additions & 1 deletion tools/autograd/gen_python_functions.py
@@ -405,7 +405,6 @@ def get_cpp_formal(arg, ensure_temp_safe=True):
'std::tuple<Tensor,Tensor,Tensor,Tensor,int64_t>',
'std::tuple<Tensor,Tensor,double,Tensor,int64_t>',
'std::tuple<double,int64_t>',
'std::tuple<double,double>',
'std::vector<Tensor>',
'Scalar', 'bool', 'int64_t', 'void*', 'void',
'QScheme', 'double',