[quant] Add 4-bit embedding_bag prepack/unpack support using quint4x2 #45751

Closed · wants to merge 4 commits (showing changes from 2 commits)
101 changes: 71 additions & 30 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
@@ -14,21 +14,35 @@ torch::class_<EmbeddingPackedParamsBase> register_embedding_params();
* To prepack the weights we store the scale and bias (where bias is Xmin)
* for each row along with the quantized weights.
*/
// TODO: Extend this to support 4-bits once 4-bit qtensor support is added.
c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
at::Tensor qweight) {
static constexpr int64_t version = 1;
TORCH_CHECK(
qweight.dim() == 2,
"quantized::embedding_bag_prepack weight tensor rank should be 2");
TORCH_CHECK(
qweight.scalar_type() == c10::kQUInt8,
"qembedding_bag_prepack currently only supports quint8 weights");
qweight.scalar_type() == c10::kQUInt8 ||
qweight.scalar_type() == c10::kQUInt4x2,
"qembedding_bag_prepack currently only supports quint8 and quint4x2 weights");

at::Tensor weight_contig =
qweight.contiguous(qweight.suggest_memory_format());
const uint8_t* weight_data =
reinterpret_cast<uint8_t*>(weight_contig.data_ptr<c10::quint8>());

uint8_t* weight_data;
int bit_width, scale_bias_bytes;
if (qweight.scalar_type() == c10::kQUInt8) {
weight_data =
reinterpret_cast<uint8_t*>(weight_contig.data_ptr<c10::quint8>());
Review thread on the preceding data_ptr cast:
Contributor: I remember data_ptr does not do type check, so maybe you can just do weight_contig.data_ptr<uint8_t*>
Contributor (author): I don't think it is applicable for qint types. It throws the error "expected scalar type Byte but found QUInt8".
Contributor: oh sorry, I meant this: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/templates/TensorBody.h#L354
so: static_cast<uint8_t*>(weight_contig.data_ptr())

bit_width = 8;
scale_bias_bytes = 8; // extra 8 bytes to store FP scale and bias per row.
} else {
weight_data =
reinterpret_cast<uint8_t*>(weight_contig.data_ptr<c10::quint4x2>());
bit_width = 4;
scale_bias_bytes =
4; // extra 4 bytes to store at::Half scale and bias per row.
}
const auto num_elem_per_byte = 8 / bit_width;

int64_t embedding_rows = qweight.size(0);
int64_t embedding_cols = qweight.size(1);
@@ -50,8 +64,9 @@ c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(

std::vector<int64_t> output_shape = {
embedding_rows,
embedding_cols +
8}; // extra 8 bytes to store FP scale and zero_point per row.
static_cast<std::int64_t>(
(embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte +
scale_bias_bytes)}; // extra bytes to store scale and bias per row.
size_t output_columns = output_shape[1];

// Allocate output packed weights.
@@ -61,28 +76,46 @@ c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
weight_contig.suggest_memory_format());
auto* output_data = output.data_ptr<uint8_t>();

at::parallel_for(
0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
for (int64_t row = start_idx; row < end_idx; ++row) {
const uint8_t* input_row = weight_data + row * embedding_cols;
std::uint8_t* output_row = output_data + row * output_columns;
float* output_row_scale_bias =
reinterpret_cast<float*>(output_row + embedding_cols);
output_row_scale_bias[0] = weight_scales[row];
output_row_scale_bias[1] = weight_bias[row];
for (int64_t col = 0; col < embedding_cols; ++col) {
output_row[col] = input_row[col];
if (bit_width == 8) {
at::parallel_for(
0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
for (int64_t row = start_idx; row < end_idx; ++row) {
const uint8_t* input_row = weight_data + row * embedding_cols;
std::uint8_t* output_row = output_data + row * output_columns;
float* output_row_scale_bias =
reinterpret_cast<float*>(output_row + embedding_cols);
output_row_scale_bias[0] = weight_scales[row];
output_row_scale_bias[1] = weight_bias[row];
for (int64_t col = 0; col < embedding_cols; ++col) {
output_row[col] = input_row[col];
}
}
}
});
});
} else {
// Re-calculate the number of embedding_cols, to account for values packed
// in a byte.
embedding_cols =
(embedding_cols + num_elem_per_byte - 1) / num_elem_per_byte;
at::parallel_for(
0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
for (int64_t row = start_idx; row < end_idx; ++row) {
const uint8_t* input_row = weight_data + row * embedding_cols;
std::uint8_t* output_row = output_data + row * output_columns;
at::Half* output_row_scale_bias =
reinterpret_cast<at::Half*>(output_row + embedding_cols);
output_row_scale_bias[0] = weight_scales[row];
output_row_scale_bias[1] = weight_bias[row];
for (int64_t col = 0; col < embedding_cols; ++col) {
// The weight values have already been packed, so here we just
// store it in the output tensor.
output_row[col] = input_row[col];
}
}
});
Review thread on lines +95 to +110:
Contributor: optional: seems like this could be reused with the above section if the at::Half is templatized? Def optional though since LOC is low.
Contributor (author): Yes, I think there would be similar LOC added in templatizing it so planning to skip it for now.

}

auto packed_ptr = c10::make_intrusive<PackedEmbeddingBagWeight>(
output,
weight_scales,
weight_zero_points,
8 /* bit rate */,
qtype,
version);
output, weight_scales, weight_zero_points, bit_width, qtype, version);

return packed_ptr;
}
@@ -290,13 +323,21 @@ class QEmbeddingPackWeights final {
};

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"), TORCH_FN(qembeddingbag_byte_prepack));
m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"), TORCH_FN(qembeddingbag_4bit_prepack));
m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"), TORCH_FN(qembeddingbag_2bit_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"),
TORCH_FN(qembeddingbag_byte_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"),
TORCH_FN(qembeddingbag_4bit_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"),
TORCH_FN(qembeddingbag_2bit_prepack));
}

TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_prepack"), TORCH_FN(QEmbeddingPackWeights::run));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_prepack"),
TORCH_FN(QEmbeddingPackWeights::run));
}

} // namespace
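
A minimal sketch (not part of this PR) of the packed-row arithmetic implemented by the prepack path above, assuming the layout described in the diff: 8-bit rows append a float scale and bias (8 extra bytes), while 4-bit rows pack two values per byte and append an at::Half scale and bias (4 extra bytes). The helper name is made up for illustration.

# Sketch of the packed-row size computed by qembeddingbag_prepack.cpp above
# (illustrative only; mirrors the output_shape computation in the diff).
def packed_row_bytes(embedding_cols, bit_width):
    num_elem_per_byte = 8 // bit_width              # 1 for 8-bit, 2 for 4-bit
    scale_bias_bytes = 8 if bit_width == 8 else 4   # fp32 vs fp16 scale + bias
    data_bytes = (embedding_cols + num_elem_per_byte - 1) // num_elem_per_byte
    return data_bytes + scale_bias_bytes

# Example: a row of 12 columns.
assert packed_row_bytes(12, bit_width=8) == 12 + 8  # 20 bytes per packed row
assert packed_row_bytes(12, bit_width=4) == 6 + 4   # 10 bytes per packed row
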
88 changes: 65 additions & 23 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp
@@ -9,38 +9,71 @@ torch::class_<EmbeddingPackedParamsBase> register_embedding_params();
at::Tensor PackedEmbeddingBagWeight::unpack() {
auto packed_weight = packed_w;
at::Tensor weight_origin;
if (bit_rate_ == 8) {

if (bit_rate_ == 8 || bit_rate_ == 4) {
const auto input_rows = packed_weight.size(0);
const auto input_columns = packed_weight.size(1);

// The last 2 values are used to store the FP32 scale and zero_point values
// per row.
int output_columns = input_columns - 2 * sizeof(float);
int scale_bias_bytes;
const auto num_elem_per_byte = 8 / bit_rate_;
if (bit_rate_ == 8) {
// The last 2 values are used to store the FP32 scale and zero_point
// values per row.
scale_bias_bytes = 8;
} else {
scale_bias_bytes = 4;
}

const auto* input = packed_weight.data_ptr<uint8_t>();
std::vector<int64_t> output_shape = {input_rows, output_columns};
// Calculate the output shape, accounting for the last n bytes to be used
// for scale/bias rest of the entries are packed depending on the bit_width.
std::vector<int64_t> output_shape = {
input_rows,
static_cast<std::int64_t>(input_columns - scale_bias_bytes) *
num_elem_per_byte};

auto scales = at::from_blob(
w_scale.data(), w_scale.size(), device(c10::kCPU).dtype(c10::kFloat));
auto zero_points = at::from_blob(
w_zp.data(), w_zp.size(), device(c10::kCPU).dtype(c10::kFloat));

weight_origin = at::_empty_per_channel_affine_quantized(
output_shape,
scales.toType(c10::kFloat),
zero_points.toType(c10::kFloat),
0, // The output channel axis is 0
device(c10::kCPU).dtype(c10::kQUInt8));

uint8_t* output_data =
reinterpret_cast<uint8_t*>(weight_origin.data_ptr<c10::quint8>());

auto output_columns = output_shape[1];
uint8_t* output_data;

// Allocate output weight tensor based on the bit_width
if (bit_rate_ == 8) {
weight_origin = at::_empty_per_channel_affine_quantized(
output_shape,
scales.toType(c10::kFloat),
zero_points.toType(c10::kFloat),
0, // The output channel axis is 0
device(c10::kCPU).dtype(c10::kQUInt8));
output_data =
reinterpret_cast<uint8_t*>(weight_origin.data_ptr<c10::quint8>());
} else {
// We create empty qtensor with the full output shape, and dtype set to
// quint4x2 This will internally allocate appropriate storage bytes to
// account for the packed nature of this dtype.
weight_origin = at::_empty_per_channel_affine_quantized(
output_shape,
scales.toType(c10::kFloat),
zero_points.toType(c10::kFloat),
0, // The output channel axis is 0
device(c10::kCPU).dtype(c10::kQUInt4x2));
output_data =
reinterpret_cast<uint8_t*>(weight_origin.data_ptr<c10::quint4x2>());
}

Review thread on the quint4x2 data_ptr cast above:
Contributor: same here [i.e. the static_cast<uint8_t*>(weight_contig.data_ptr()) suggestion from the prepack review applies to this cast as well].

// Copy over the data from the packed weight to the output.
// For sub-byte tensors this will copy the packed bytes over since the
// sub_byte qtensors are expected to store data in packed format.
at::parallel_for(0, input_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
for (int64_t row = start_idx; row < end_idx; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
uint8_t* output_row = output_data + row * output_columns;
uint8_t* output_row =
output_data + row * output_columns / num_elem_per_byte;

for (std::size_t col = 0; col < output_columns; ++col) {
for (std::size_t col = 0; col < output_columns / num_elem_per_byte;
++col) {
output_row[col] = input_row[col];
} // output_columns
}
@@ -49,7 +82,8 @@ at::Tensor PackedEmbeddingBagWeight::unpack() {
return weight_origin;
}
TORCH_INTERNAL_ASSERT(
"Currently only supporting 8-bit quantization of embedding bag.");
false,
"We currently only support 8-bit and 4-bit quantization of embedding_bag.");
return weight_origin;
}

@@ -171,15 +205,23 @@ class QEmbeddingUnpackWeights final {
};

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_unpack"), qembeddingbag_byte_unpack);
m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_unpack"), qembeddingbag_4bit_unpack);
m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_unpack"), qembeddingbag_2bit_unpack);
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_unpack"),
qembeddingbag_byte_unpack);
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_unpack"),
qembeddingbag_4bit_unpack);
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_unpack"),
qembeddingbag_2bit_unpack);
}

TORCH_LIBRARY_IMPL(quantized, CatchAll, m) {
// Unpack the packed embedding_bag weights using TorchBind custom class.
// TODO extend to support 4-bit qtensor.
m.impl(TORCH_SELECTIVE_NAME("quantized::embedding_bag_unpack"), TORCH_FN(QEmbeddingUnpackWeights::run));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_unpack"),
TORCH_FN(QEmbeddingUnpackWeights::run));
}

} // namespace
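
The inverse of the prepack layout, as computed by the unpack path above: a hedged sketch assuming the output_shape formula in qembeddingbag_unpack.cpp, i.e. logical columns = (packed columns - scale/bias bytes) * elements per byte. Note that under this formula an odd 4-bit column count unpacks to the padded (even) width, since the padding nibble cannot be distinguished from data. The helper name is made up for illustration.

# Sketch of the logical width recovered by unpack (illustrative only;
# companion to packed_row_bytes from the prepack sketch).
def unpacked_cols(packed_cols, bit_rate):
    num_elem_per_byte = 8 // bit_rate
    scale_bias_bytes = 8 if bit_rate == 8 else 4
    return (packed_cols - scale_bias_bytes) * num_elem_per_byte

# Round-tripping the 12-column example from the prepack sketch.
assert unpacked_cols(20, bit_rate=8) == 12
assert unpacked_cols(10, bit_rate=4) == 12
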
9 changes: 5 additions & 4 deletions test/quantization/test_quantized_op.py
@@ -2771,23 +2771,24 @@ class TestQuantizedEmbeddingOps(TestCase):
def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, optimized_qparams):
weights = torch.from_numpy((np.random.random_sample((
num_embeddings, embedding_dim)) + 1).astype(np.float32))

qtype = torch.quint8
if bit_rate == 8:
w_packed = pack_fn(weights)
else:
w_packed = pack_fn(weights, optimized_qparams=optimized_qparams)
w_unpacked = unpack_fn(w_packed)

if bit_rate == 8:
if bit_rate == 8 or bit_rate == 4:
# Check numerics of prepack function that accepts qtensor as input.
# We use min-max observer to mimic the quantization performed in the original function.
obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
obs(weights)
# Get the scale and zero point for the weight tensor
qparams = obs.calculate_qparams()

if bit_rate == 4:
qtype = torch.quint4x2
# Quantize the weights to 8bits
qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=torch.quint8)
qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qtype)
real_packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight)
self.assertEqual(isinstance(real_packed_weight, torch._C.ScriptObject), True)
unpacked_weight = torch.ops.quantized.embedding_bag_unpack(real_packed_weight)
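
For reference, a hedged end-to-end sketch of the round trip the test above exercises for 4-bit weights. It assumes a PyTorch build that contains this change, and that torch.quantize_per_channel accepts torch.quint4x2 as the updated test implies; it is illustrative, not part of the PR.

# Sketch: 4-bit embedding_bag prepack/unpack round trip (assumes this PR).
import numpy as np
import torch
from torch.quantization import PerChannelMinMaxObserver

num_embeddings, embedding_dim = 10, 12
weights = torch.from_numpy(
    (np.random.random_sample((num_embeddings, embedding_dim)) + 1).astype(np.float32))

# Min-max observer with float qparams, mirroring the test's quantization setup.
obs = PerChannelMinMaxObserver(
    dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
obs(weights)
scales, zero_points = obs.calculate_qparams()

# Quantize to the packed 4-bit dtype, then prepack and unpack via the quantized ops.
qweight = torch.quantize_per_channel(
    weights, scales, zero_points, axis=0, dtype=torch.quint4x2)
packed = torch.ops.quantized.embedding_bag_prepack(qweight)
unpacked = torch.ops.quantized.embedding_bag_unpack(packed)

assert unpacked.dtype == torch.quint4x2
assert unpacked.shape == qweight.shape  # 12 columns is even, so no padding column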