Skip to content

Commit

Permalink
[quant] creating quint4x2 dtype for quantized tensors
Browse files Browse the repository at this point in the history
Summary:
This is a prototype PR that introduces 4-bit qtensors. The new dtype added for this is c10::quint4x2.
The underlying storage is still uint8_t, so two 4-bit values are packed into each byte during quantization.

This change uses most of the existing scaffolding for qtensor storage. We allocate storage
based on the dtype before creating a new qtensor.

It also adds a dispatch mechanism for this dtype that exposes the bit width, qmin and qmax,
which are needed while quantizing and packing the qtensor (and can be reused when a 2-bit qtensor is added).

Kernels that use this dtype should be aware of the packing format.

Test Plan:
Locally tested
```
x = torch.ones((100, 100), dtype=torch.float)
qx_8bit = torch.quantize_per_tensor(x, scale=1.0, zero_point=2, dtype=torch.quint8)
qx = torch.quantize_per_tensor(x, scale=1.0, zero_point=2, dtype=torch.quint4x2)

torch.save(x, "temp.p")
print('Size float (B):', os.path.getsize("temp.p"))
os.remove('temp.p')

torch.save(qx_8bit, "temp.p")
print('Size quantized 8bit(B):', os.path.getsize("temp.p"))
os.remove('temp.p')

torch.save(qx, "temp.p")
print('Size quantized 4bit(B):', os.path.getsize("temp.p"))
os.remove('temp.p')
```

Size float (B): 40760
Size quantized 8bit(B): 10808
Size quantized 4bit(B): 5816

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e7098e5bd72be2b5dd5f0b4a3a5c917df49d8edf
Pull Request resolved: #44678
  • Loading branch information
supriyar committed Sep 30, 2020
1 parent 6375704 commit 6e63f37
Show file tree
Hide file tree
Showing 22 changed files with 274 additions and 39 deletions.
7 changes: 2 additions & 5 deletions aten/src/ATen/DLConvertor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,10 @@ DLDataType getDLDataType(const Tensor& t) {
throw std::logic_error("BFloat16 is not supported by dlpack");
break;
case ScalarType::QInt8:
throw std::logic_error("QInt8 is not supported by dlpack");
break;
case ScalarType::QUInt8:
throw std::logic_error("QUInt8 is not supported by dlpack");
break;
case ScalarType::QInt32:
throw std::logic_error("QInt32 is not supported by dlpack");
case ScalarType::QUInt4x2:
throw std::logic_error("QUInt/QInt types are not supported by dlpack");
break;
case ScalarType::ComplexHalf:
throw std::logic_error("ComplexHalf is not supported by dlpack");
Expand Down
34 changes: 34 additions & 0 deletions aten/src/ATen/Dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@
return __VA_ARGS__(); \
}

// Expands to one `case` of a dispatch `switch` for a quantized dtype.
// Inside the case it introduces the names below and then returns the result
// of invoking the callable passed as the trailing (variadic) argument:
//   scalar_t        - `type`, the quantized C++ type
//   underlying_t    - `scalar_t::underlying`
//   SCALAR_TYPE     - `enum_type`
//   UNDERLYING_TYPE - `toUnderlying(enum_type)`
//   bit_width       - `bitwidth`, stored as int
//   quant_min       - `qmin`, stored as int64_t
//   quant_max       - `qmax`, stored as int64_t
// Note: the `underlying_type` parameter is accepted but never referenced in
// the expansion; `underlying_t` is taken from `scalar_t::underlying` instead.
#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
case enum_type: { \
using scalar_t = type; \
using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
scalar_t::underlying; \
const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
toUnderlying(enum_type); \
int bit_width = bitwidth; \
int64_t quant_min = qmin; \
int64_t quant_max = qmax; \
return __VA_ARGS__(); \
}

// This macro should be used to skip bfloat16 dispatch on non-ROCm platforms and
// should be removed once the bfloat16 bringup is complete on other platforms.
// This is supposed to be used as a wrapper around the lambda function passed to
Expand Down Expand Up @@ -346,6 +361,25 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
} \
}()

// Dispatches on a quantized scalar type, including the sub-byte QUInt4x2.
//
// TYPE is evaluated exactly once and converted to an at::ScalarType. The
// callable passed as the trailing argument is then invoked inside the matching
// case, where scalar_t, underlying_t, SCALAR_TYPE, UNDERLYING_TYPE, bit_width,
// quant_min and quant_max are in scope (see
// AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE). NAME is stringized with #NAME and used
// only in the error raised for an unsupported type.
#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
  [&] { \
    const auto& the_type = TYPE; \
    /* don't use TYPE again in case it is an expensive or side-effect op */ \
    at::ScalarType _st = ::detail::scalar_type(the_type); \
    switch (_st) { \
      AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
          at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
      AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
          at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
      AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
          at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \
      AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
          at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__) \
      default: \
        /* report the already-computed _st so TYPE is not evaluated twice */ \
        AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
    } \
  }()

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
Expand Down
28 changes: 22 additions & 6 deletions aten/src/ATen/native/quantized/affine_quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ DEFINE_DISPATCH(quantize_tensor_per_channel_float_qparams_stub);
DEFINE_DISPATCH(dequantize_tensor_per_tensor_affine_stub);
DEFINE_DISPATCH(dequantize_tensor_per_channel_affine_stub);
DEFINE_DISPATCH(dequantize_tensor_per_channel_float_qparams_stub);
// Dispatch stubs for the sub-byte per-tensor affine quantize / dequantize
// kernels (selected below when the quantized dtype is QUInt4x2).
DEFINE_DISPATCH(quantize_tensor_per_tensor_affine_sub_byte_stub);
DEFINE_DISPATCH(dequantize_tensor_per_tensor_affine_sub_byte_stub);

namespace {

Expand Down Expand Up @@ -55,7 +57,8 @@ void checkQuantizedTensor(const std::string& fn_name, Tensor t) {
fn_name,
" expects a ",
caffe2::TypeMeta::Make<T>(),
" Tensor");
" Tensor, got ",
t.scalar_type());
}

template <typename T>
Expand Down Expand Up @@ -103,13 +106,21 @@ Tensor quantize_tensor_per_tensor_affine(
checkSameDevice(fn_name, rtensor, qtensor);
checkSameSize(fn_name, qtensor, rtensor);

AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
checkQuantizedTensor<scalar_t>(fn_name, qtensor);
checkZeroPoint<underlying_t>(fn_name, zero_point);
});

quantize_tensor_per_tensor_affine_stub(
// Temporary solution to pack the tensor if dtype is torch.quint4x2
// Can move this into the fbgemm::Quantize op.
if (qtensor.scalar_type() == at::ScalarType::QUInt4x2) {
quantize_tensor_per_tensor_affine_sub_byte_stub(
rtensor.device().type(), rtensor, qtensor, scale, zero_point);
}
else {
quantize_tensor_per_tensor_affine_stub(
rtensor.device().type(), rtensor, qtensor, scale, zero_point);
}
return qtensor;
}

Expand Down Expand Up @@ -195,13 +206,18 @@ Tensor dequantize_tensor_per_tensor_affine(
checkSameDevice(fn_name, rtensor, qtensor);
checkSameSize(fn_name, qtensor, rtensor);

AT_DISPATCH_QINT_TYPES(qtensor.scalar_type(), fn_name, [&]() {
AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(qtensor.scalar_type(), fn_name, [&]() {
checkQuantizedTensor<scalar_t>(fn_name, qtensor);
checkZeroPoint<underlying_t>(fn_name, zero_point);
});

dequantize_tensor_per_tensor_affine_stub(
qtensor.device().type(), qtensor, rtensor, scale, zero_point);
if (qtensor.scalar_type() == at::ScalarType::QUInt4x2) {
dequantize_tensor_per_tensor_affine_sub_byte_stub(
qtensor.device().type(), qtensor, rtensor, scale, zero_point);
} else {
dequantize_tensor_per_tensor_affine_stub(
qtensor.device().type(), qtensor, rtensor, scale, zero_point);
}
return rtensor;
}

Expand Down
13 changes: 13 additions & 0 deletions aten/src/ATen/native/quantized/affine_quantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
Tensor zero_points,
int64_t axis);

// Signature of a kernel that quantizes the real-valued `rtensor` into the
// sub-byte quantized `qtensor` using a single per-tensor `scale` and
// `zero_point`.
using quantize_tensor_per_tensor_affine_sub_byte_fn =
void (*)(Tensor rtensor, Tensor qtensor, float scale, float zero_point);

// Signature of the inverse kernel: dequantizes the sub-byte quantized
// `qtensor` into the real-valued `rtensor` using the same per-tensor
// `scale` and `zero_point`.
using dequantize_tensor_per_tensor_affine_sub_byte_fn =
void (*)(Tensor qtensor, Tensor rtensor, float scale, float zero_point);

DECLARE_DISPATCH(
quantize_tensor_per_tensor_affine_fn,
quantize_tensor_per_tensor_affine_stub);
Expand All @@ -97,6 +103,13 @@ DECLARE_DISPATCH(
dequantize_tensor_per_channel_float_qparams_fn,
dequantize_tensor_per_channel_float_qparams_stub);

// Declares the dispatch stub for the sub-byte per-tensor affine quantize kernel.
DECLARE_DISPATCH(
quantize_tensor_per_tensor_affine_sub_byte_fn,
quantize_tensor_per_tensor_affine_sub_byte_stub);

// Declares the dispatch stub for the sub-byte per-tensor affine dequantize kernel.
DECLARE_DISPATCH(
dequantize_tensor_per_tensor_affine_sub_byte_fn,
dequantize_tensor_per_tensor_affine_sub_byte_stub);

// Quantize a float value into a uint value given scale and zero_point
template <typename T>
Expand Down
34 changes: 23 additions & 11 deletions aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,29 @@ namespace native {
// format of the output the same as input
// Returns the raw integer representation of a quantized CPU tensor.
//
// For 4-bit dtypes (bit_width == 4) the result is a 1-D tensor of the
// underlying integer type containing the packed storage bytes of `self`:
// two 4-bit values per byte, ceil(numel / 2) bytes in total.
// For every other quantized dtype the result has the same sizes and suggested
// memory format as `self`, with each element replaced by its `val_`.
Tensor int_repr_quantized_cpu(const Tensor& self) {
  Tensor dst;
  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(self.scalar_type(), "int_repr", [&]() {
    if (bit_width == 4) {
      // Integer ceil(numel / 2); avoids a round trip through double.
      const int64_t numel = self.numel();
      const int64_t out_size = numel / 2 + numel % 2;
      // NOTE(review): a 1-D output is requested with self's suggested memory
      // format; presumably only valid when that format is contiguous - confirm
      // for channels-last inputs.
      dst = at::empty(
          {out_size},
          self.options().dtype(UNDERLYING_TYPE),
          self.suggest_memory_format());
      const underlying_t* qdata =
          reinterpret_cast<underlying_t*>(self.data_ptr<scalar_t>());
      // Copy the packed bytes through raw pointers; indexing `dst[i]` would
      // create and fill a temporary Tensor for every single element.
      underlying_t* out = dst.data_ptr<underlying_t>();
      for (int64_t i = 0; i < out_size; ++i) {
        out[i] = qdata[i];
      }
    } else {
      dst = at::empty(
          self.sizes(),
          self.options().dtype(UNDERLYING_TYPE),
          self.suggest_memory_format());
      auto iter = TensorIteratorConfig()
                      .check_all_same_dtype(false)
                      .add_output(dst)
                      .add_input(self)
                      .build();
      cpu_kernel(iter, [](scalar_t value) -> underlying_t { return value.val_; });
    }
  });
  return dst;
}
Expand Down
61 changes: 61 additions & 0 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2716,6 +2716,60 @@ void dequantize_tensor_per_channel_float_qparams_cpu(
});
}

// Quantizes the float tensor `rtensor` into the sub-byte quantized tensor
// `qtensor` with a per-tensor `scale` and `zero_point`, packing
// CHAR_BIT / bit_width values into each storage byte.
//
// Each value is computed as nearbyint(x / scale) + zero_point (a scale of 0
// is treated as an inverse scale of 1) and clamped to [quant_min, quant_max].
// Element i goes into byte i / elem_per_byte at bit offset
// (i % elem_per_byte) * bit_width, so for 4-bit data even indices occupy the
// lower nibble and odd indices the upper nibble.
//
// NOTE(review): the dispatch macro also accepts 32-bit types, for which
// elem_per_byte would be 0; the visible caller only routes QUInt4x2 here.
void quantize_tensor_per_tensor_affine_sub_byte_cpu(
    Tensor rtensor,
    Tensor qtensor,
    float scale,
    float zero_point) {
  // TODO Use fbgemm kernel to pack values
  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
      qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
        check_tensor_memory_format(rtensor, qtensor);
        const float* const rdata = rtensor.data_ptr<float>();
        auto qdata = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
        const int64_t numel = rtensor.numel();
        const int64_t elem_per_byte = CHAR_BIT / bit_width;
        // Loop-invariant, so compute it once.
        const float inv_scale = scale == 0 ? 1.0f : 1.0f / scale;

        // 64-bit index: numel is int64_t and may exceed the range of int.
        for (int64_t i = 0; i < numel; ++i) {
          // int64_t matches quant_min / quant_max so std::min / std::max
          // deduce a single argument type.
          int64_t qvalue = lrintf(std::nearbyint(rdata[i] * inv_scale) + zero_point);
          qvalue = std::max(quant_min, std::min(qvalue, quant_max));

          // We pack sub_byte values and align them to a byte.
          // Eg. for 4-bits Index 0 is packed in the lower 4-bits
          // and index 1 is packed in the upper 4-bits.
          const int64_t slot = i % elem_per_byte;
          if (slot == 0) {
            // First value of a byte overwrites it, clearing any stale bits.
            qdata[i / elem_per_byte] = static_cast<underlying_t>(qvalue);
          } else {
            qdata[i / elem_per_byte] |=
                static_cast<underlying_t>(qvalue << (slot * bit_width));
          }
        } // for numel
      });
}

// Dequantizes the sub-byte quantized tensor `qtensor` into the float tensor
// `rtensor` with a per-tensor `scale` and `zero_point`.
//
// Element i is read from byte i / elem_per_byte at bit offset
// (i % elem_per_byte) * bit_width, masked to bit_width bits, and converted as
// (q - zero_point) * scale.
void dequantize_tensor_per_tensor_affine_sub_byte_cpu(
    Tensor qtensor,
    Tensor rtensor,
    float scale,
    float zero_point) {
  // TODO Use fbgemm kernel to pack values
  AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(
      qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_sub_byte_cpu", [&]() {
        check_tensor_memory_format(rtensor, qtensor);
        auto rdata = rtensor.data_ptr<float>();
        const underlying_t* qdata =
            reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
        const int64_t numel = rtensor.numel();
        const int64_t elem_per_byte = CHAR_BIT / bit_width;

        // 64-bit index: numel is int64_t and may exceed the range of int.
        for (int64_t i = 0; i < numel; ++i) {
          underlying_t qvalue = qdata[i / elem_per_byte];
          // Shift the wanted value down to the low bits, then drop the rest.
          qvalue >>= (i % elem_per_byte) * bit_width;
          qvalue &= (1 << bit_width) - 1;
          rdata[i] = (static_cast<float>(qvalue) - zero_point) * scale;
        }
      });
}

} // namespace

REGISTER_DISPATCH(dequantize_tensor_per_channel_affine_stub,
Expand Down Expand Up @@ -2773,6 +2827,13 @@ REGISTER_DISPATCH(
REGISTER_DISPATCH(quantized_normalize_stub, &quantized_normalize_kernel);
REGISTER_DISPATCH(qupsample_bilinear2d_nhwc_stub,
&qupsample_bilinear2d_nhwc_kernel);
// Register the CPU implementations of the sub-byte per-tensor affine
// quantize / dequantize kernels with their dispatch stubs.
REGISTER_DISPATCH(
quantize_tensor_per_tensor_affine_sub_byte_stub,
&quantize_tensor_per_tensor_affine_sub_byte_cpu);
REGISTER_DISPATCH(
dequantize_tensor_per_tensor_affine_sub_byte_stub,
&dequantize_tensor_per_tensor_affine_sub_byte_cpu);


} // namespace native
} // namespace at
16 changes: 15 additions & 1 deletion aten/src/ATen/quantized/Quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,18 @@ QTensorImpl* get_qtensorimpl(const Tensor& self) {
return static_cast<QTensorImpl*>(self.unsafeGetTensorImpl());
}

// Returns the number of storage bytes to allocate for a quantized tensor of
// scalar type `t`, given `size_bytes` computed by the caller as element count
// times dtype item size.
//
// QUInt4x2 stores two elements per byte, so its size is halved and rounded
// up; every other type keeps `size_bytes` unchanged.
// `size_bytes` is expected to be non-negative.
int64_t get_sub_byte_tensor_size(int64_t size_bytes, at::ScalarType t) {
  switch (t) {
    case at::ScalarType::QUInt4x2:
      // Integer ceil(size_bytes / 2): exact for any non-negative int64_t and
      // avoids converting through double.
      return size_bytes / 2 + size_bytes % 2;
    default:
      return size_bytes;
  }
}

inline Tensor new_qtensor(
IntArrayRef sizes,
const TensorOptions& options,
Expand All @@ -99,7 +111,9 @@ inline Tensor new_qtensor(
TORCH_CHECK(
isQIntType(typeMetaToScalarType(dtype)),
"ScalarType is not supported in new_qtensor.");
int64_t size_bytes = nelements * dtype.itemsize();
auto scalar_type = typeMetaToScalarType(dtype);
int64_t size_bytes = get_sub_byte_tensor_size(nelements * dtype.itemsize(), scalar_type);

auto storage = c10::make_intrusive<StorageImpl>(
StorageImpl::use_byte_size_t(),
size_bytes,
Expand Down
1 change: 1 addition & 0 deletions aten/src/TH/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ install(FILES
THGenerateComplexTypes.h
THGenerateIntTypes.h
THGenerateQUInt8Type.h
THGenerateQUInt4x2Type.h
THGenerateQInt8Type.h
THGenerateQInt32Type.h
THGenerateQTypes.h
Expand Down
1 change: 1 addition & 0 deletions aten/src/TH/THGenerateQTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <TH/THGenerateQUInt8Type.h>
#include <TH/THGenerateQInt8Type.h>
#include <TH/THGenerateQInt32Type.h>
#include <TH/THGenerateQUInt4x2Type.h>

#ifdef THQLocalGenerateManyTypes
#undef THQLocalGenerateManyTypes
Expand Down
24 changes: 24 additions & 0 deletions aten/src/TH/THGenerateQUInt4x2Type.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Instantiates the generic TH file named by TH_GENERIC_FILE for the
// c10::quint4x2 quantized type. The includer must define TH_GENERIC_FILE
// before including this header.
#ifndef TH_GENERIC_FILE
#error "You must define TH_GENERIC_FILE before including THGenerateQUInt4x2Type.h"
#endif

// Type parameters seen by the generic file: the quantized type, its
// uint8_t storage type, and the name fragments used for it.
#define quantized_t c10::quint4x2
#define scalar_t uint8_t
#define Real QUInt4x2
#define RealUnderlying Byte
#define THQUANTIZED
// NOTE(review): THQUINT8 (not a 4x2-specific flag) is defined here, presumably
// so generic files reuse their quint8 code paths - confirm this is intended.
#define THQUINT8
#define TH_REAL_IS_BYTE
// Reset the reported line number so diagnostics refer to the generic file.
#line 1 TH_GENERIC_FILE
#include TH_GENERIC_FILE
// Undo every definition made above so the next type can be generated.
#undef scalar_t
#undef quantized_t
#undef Real
#undef RealUnderlying
#undef TH_REAL_IS_BYTE
#undef THQUINT8
#undef THQUANTIZED

// When several types are generated in one pass (THGenerateManyTypes defined),
// TH_GENERIC_FILE is left defined for the remaining types.
#ifndef THGenerateManyTypes
#undef TH_GENERIC_FILE
#endif
1 change: 1 addition & 0 deletions aten/src/TH/generic/THStorage.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#define THQUInt8Storage THStorage
#define THQInt8Storage THStorage
#define THQInt32Storage THStorage
#define THQUInt4x2Storage THStorage
#define THComplexFloatStorage THStorage
#define THComplexDoubleStorage THStorage

Expand Down
10 changes: 7 additions & 3 deletions c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ namespace c10 {
_(c10::qint8, QInt8) /* 12 */ \
_(c10::quint8, QUInt8) /* 13 */ \
_(c10::qint32, QInt32) /* 14 */ \
_(at::BFloat16, BFloat16) /* 15 */
_(at::BFloat16, BFloat16) /* 15 */ \
_(c10::quint4x2, QUInt4x2) /* 16 */


// If you want to support ComplexHalf for real, add ComplexHalf
Expand Down Expand Up @@ -154,7 +155,8 @@ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(SPECIALIZE_CppTypeToScalarType)
// X-macro over every quantized integer type. Invokes `_` once per type as
// _(cpp_type, ScalarTypeName), in the order QInt8, QUInt8, QInt32, QUInt4x2.
#define AT_FORALL_QINT_TYPES(_) \
  _(c10::qint8, QInt8) \
  _(c10::quint8, QUInt8) \
  _(c10::qint32, QInt32) \
  _(c10::quint4x2, QUInt4x2)

#define AT_FORALL_COMPLEX_TYPES(_) \
_(c10::complex<float>, ComplexFloat) \
Expand Down Expand Up @@ -279,7 +281,7 @@ static inline bool isComplexType(ScalarType t) {

// Returns true when `t` is one of the quantized integer scalar types
// (QInt8, QUInt8, QInt32 or QUInt4x2).
static inline bool isQIntType(ScalarType t) {
  // Don't forget to extend this when adding new QInt types
  return t == ScalarType::QInt8 || t == ScalarType::QUInt8 ||
      t == ScalarType::QInt32 || t == ScalarType::QUInt4x2;
}

static inline ScalarType toQIntType(ScalarType t) {
Expand All @@ -303,6 +305,8 @@ static inline ScalarType toUnderlying(ScalarType t) {
return ScalarType::Char;
case ScalarType::QInt32:
return ScalarType::Int;
case ScalarType::QUInt4x2:
return ScalarType::Byte;
default:
return t;
}
Expand Down

0 comments on commit 6e63f37

Please sign in to comment.