From 6898e070c2739531ab5950345c85ea107a6fecce Mon Sep 17 00:00:00 2001 From: zrphercule Date: Fri, 6 Dec 2019 15:16:47 -0800 Subject: [PATCH] Add quantized packing weights support --- .circleci/build.sh | 2 +- torch_glow/src/GlowIValue.cpp | 22 + torch_glow/src/GlowIValue.h | 14 + torch_glow/src/PyTorchCommon.cpp | 60 ++- torch_glow/src/PyTorchCommon.h | 10 + torch_glow/src/PyTorchModelLoader.cpp | 388 ++++++++++++++++-- torch_glow/src/PyTorchModelLoader.h | 20 +- .../tests/nodes/quantized_conv2d_relu_test.py | 44 ++ .../tests/nodes/quantized_conv2d_test.py | 32 +- .../tests/nodes/quantized_linear_test.py | 25 ++ 10 files changed, 581 insertions(+), 36 deletions(-) create mode 100644 torch_glow/tests/nodes/quantized_conv2d_relu_test.py diff --git a/.circleci/build.sh b/.circleci/build.sh index 91cc19f102..244731f942 100755 --- a/.circleci/build.sh +++ b/.circleci/build.sh @@ -135,7 +135,7 @@ elif [[ "$CIRCLE_JOB" == "PYTORCH" ]]; then git clone https://github.com/pytorch/pytorch.git --recursive --depth 1 cd pytorch pip install -r requirements.txt - BUILD_BINARY=OFF BUILD_TEST=0 BUILD_CAFFE2_OPS=0 python setup.py install + BUILD_BINARY=OFF BUILD_TEST=0 BUILD_CAFFE2_OPS=0 USE_FBGEMM=ON python setup.py install cd ${GLOW_DIR} cd build elif [[ "$CIRCLE_JOB" == "OPENCL" ]]; then diff --git a/torch_glow/src/GlowIValue.cpp b/torch_glow/src/GlowIValue.cpp index 7a7193a930..f743c721a2 100644 --- a/torch_glow/src/GlowIValue.cpp +++ b/torch_glow/src/GlowIValue.cpp @@ -46,6 +46,8 @@ const char *GlowIValue::tagToStr(GlowIValue::Tag tag) { return "BoolList"; case GlowIValue::Tag::Tuple: return "Tuple"; + case GlowIValue::Tag::PTTensor: + return "PyTorch Tensor"; } LOG(DFATAL) << "Cannot reach here."; } @@ -67,6 +69,9 @@ void GlowIValue::reset() { case Tag::Tuple: delete payload_.asTuple; break; + case Tag::PTTensor: + delete payload_.asPTTensor; + break; case Tag::None: case Tag::Double: case Tag::Int: @@ -104,6 +109,7 @@ bool GlowIValue::isIntList() const { return Tag::IntList == tag_; } bool GlowIValue::isDoubleList() const { return Tag::DoubleList == tag_; } bool GlowIValue::isBoolList() const { return Tag::BoolList == tag_; } bool GlowIValue::isTuple() const { return Tag::Tuple == tag_; } +bool GlowIValue::isPTTensor() const { return Tag::PTTensor == tag_; } #define ExpectTag(EXPECTED_TAG) \ RETURN_ERR_IF_NOT(tag_ == (EXPECTED_TAG), \ @@ -175,6 +181,16 @@ Expected *> GlowIValue::toTuple() const { return payload_.asTuple; } +Expected GlowIValue::toPTTensor() { + ExpectTag(Tag::PTTensor); + return payload_.asPTTensor; +} + +Expected GlowIValue::toPTTensor() const { + ExpectTag(Tag::PTTensor); + return payload_.asPTTensor; +} + #undef ExpectTag void GlowIValue::fromNone() { @@ -234,6 +250,12 @@ void GlowIValue::fromTuple(std::vector glowIValList) { std::swap(glowIValList, *payload_.asTuple); } +void GlowIValue::fromPTTensor(at::Tensor tensor) { + reset(); + tag_ = Tag::PTTensor; + payload_.asPTTensor = new at::Tensor(tensor); +} + Error GlowIValue::fromIValue(const at::IValue &ival) { reset(); if (ival.isNone()) { diff --git a/torch_glow/src/GlowIValue.h b/torch_glow/src/GlowIValue.h index 5686e78c88..3572943527 100644 --- a/torch_glow/src/GlowIValue.h +++ b/torch_glow/src/GlowIValue.h @@ -41,6 +41,7 @@ class GlowIValue { DoubleList, BoolList, Tuple, + PTTensor, }; private: @@ -56,6 +57,7 @@ class GlowIValue { std::vector *asDoubleList; std::vector *asBoolList; std::vector *asTuple; + at::Tensor *asPTTensor; }; Tag tag_ = Tag::None; @@ -96,6 +98,7 @@ class GlowIValue { bool isDoubleList() const; bool 
isBoolList() const; bool isTuple() const; + bool isPTTensor() const; /// \returns Payload a glow Tensor or error if the tag is not Tensor. Expected toTensor(); @@ -138,12 +141,23 @@ class GlowIValue { /// \returns Payload a vector of GlowIValues or error if the tag is not Tuple. Expected *> toTuple() const; + /// \returns Payload a PyTorch Tensor* or error if the tag is not a PyTorch + /// Tensor. + Expected toPTTensor(); + + /// \returns Payload a const Pytorch Tensor* or error if the tag is not + /// Tensor. + Expected toPTTensor() const; + /// Set the tag to None. void fromNone(); /// Set the tag to Tensor. void fromTensor(Tensor tensor); + /// Set the tag to PyTorch Tensor. + void fromPTTensor(at::Tensor tensor); + /// Set the tag to Double. void fromDouble(double d); diff --git a/torch_glow/src/PyTorchCommon.cpp b/torch_glow/src/PyTorchCommon.cpp index 2ec1383b71..7036dd7967 100644 --- a/torch_glow/src/PyTorchCommon.cpp +++ b/torch_glow/src/PyTorchCommon.cpp @@ -132,6 +132,14 @@ glow::ElemKind scalarTypeToElemKind(c10::ScalarType ty) { return ElemKind::Int64ITy; } else if (ty == at::kBool) { return ElemKind::BoolTy; + } else if (ty == at::kByte) { + // We should have an 8-byte non-quantized integer type eventually + // Currently usage of Bool is fine + return ElemKind::BoolTy; + } else if (ty == at::kQInt8) { + return ElemKind::Int8QTy; + } else if (ty == at::kQUInt8) { + return ElemKind::UInt8QTy; } else { LOG(DFATAL) << "ScalarType " << static_cast(ty) << " not supported yet."; @@ -177,6 +185,20 @@ glow::Type ptTypeToGlowType(const c10::TensorType &ptType) { return glow::Type(scalarTypeToElemKind(scalarType), dims); } +glow::Type ptTypeToGlowType(const c10::TensorType &ptType, float scale, + int32_t zero_point) { + DCHECK(ptType.scalarType().has_value()) + << "TensorType has no associated scalar type."; + const auto concreteSizes = ptType.sizes().concrete_sizes().value(); + std::vector dims; + for (const auto &size : concreteSizes) { + dims.push_back(static_cast(size)); + } + + auto scalarType = ptType.scalarType().value(); + return glow::Type(scalarTypeToElemKind(scalarType), dims, scale, zero_point); +} + at::Tensor glowTypeToEmptyPTTensor(const glow::Type &glowType) { std::vector sizes; for (const auto dim : glowType.dims()) { @@ -188,7 +210,41 @@ at::Tensor glowTypeToEmptyPTTensor(const glow::Type &glowType) { } glow::Tensor ptTensorToGlowTensor(const at::Tensor &ptTensor) { - auto glowType = ptTypeToGlowType(*c10::TensorType::create(ptTensor)); - return glow::Tensor(ptTensor.data_ptr(), &glowType); + if (ptTensor.is_quantized()) { + float scale = 1.0; + int32_t offset = 0; + if (ptTensor.qscheme() == at::kPerChannelAffine) { + // If it is channel wise quantized, which means + // this tensor is the weight of quantized linear or conv + // Then we dont deal with the qparams here, + // and only set up soome dummy scale & offset by using the first + // elements's scale & offset. 
+ scale = ptTensor.q_per_channel_scales()[0].item(); + offset = ptTensor.q_per_channel_zero_points()[0].item(); + } else if (ptTensor.qscheme() == at::kPerTensorAffine) { + scale = static_cast(ptTensor.q_scale()); + offset = static_cast(ptTensor.q_zero_point()); + } else { + LOG(DFATAL) + << "PyTorch tensor with unsupported quantization scheme detected."; + } + auto glowType = + ptTypeToGlowType(*c10::TensorType::create(ptTensor), scale, offset); + return glow::Tensor(ptTensor.data_ptr(), &glowType); + } else { + auto glowType = ptTypeToGlowType(*c10::TensorType::create(ptTensor)); + return glow::Tensor(ptTensor.data_ptr(), &glowType); + } +} + +at::Tensor glowTensorToPTTensor(const glow::Tensor &glowTensor, + const at::ScalarType &torch_type) { + std::vector sizes; + for (const auto dim : glowTensor.dims()) { + sizes.push_back(dim); + } + return at::from_blob(glowTensor.getUnsafePtr(), sizes, + at::device(at::kCPU).dtype(torch_type)); } + } // namespace glow diff --git a/torch_glow/src/PyTorchCommon.h b/torch_glow/src/PyTorchCommon.h index 0fbad80e51..04768dddbe 100644 --- a/torch_glow/src/PyTorchCommon.h +++ b/torch_glow/src/PyTorchCommon.h @@ -26,6 +26,11 @@ namespace glow { +/// For Glow: -128 <= orig_fp32/scale_1 + offset_1 <= 127 +/// For PyTorch: 0 <= orig_fp32/scale_2 + offset_2 <= 255 +/// Therefore, we can make scale_1 == scale_2, and offset_1 = offset2 - 128 +const int32_t OFFSETSHIFT = 128; + extern bool GlowCompilePyTorchModule; /// Various settings to be used by code that loads PyTorch models. There should /// only be one of these and it should be obtained by calling @@ -88,6 +93,11 @@ glow::Tensor ptTensorToGlowTensor(const at::Tensor &ptTensor); /// matching type. at::Tensor glowTypeToEmptyPTTensor(const glow::Type &glowType); +/// Given a Glow Tensor \p glowTensor, \returns a PyTorch Tensor with the same +/// type, shape and content. +at::Tensor glowTensorToPTTensor(const glow::Tensor &glowTensor, + const at::ScalarType &torch_type); + } // namespace glow #endif // GLOW_TORCH_GLOW_SRC_COMMON_H diff --git a/torch_glow/src/PyTorchModelLoader.cpp b/torch_glow/src/PyTorchModelLoader.cpp index 83c41a2aca..289e211684 100644 --- a/torch_glow/src/PyTorchModelLoader.cpp +++ b/torch_glow/src/PyTorchModelLoader.cpp @@ -22,6 +22,8 @@ #include "glow/Support/Support.h" #include +#include +#include #include namespace glow { @@ -32,11 +34,6 @@ namespace { /// read from quantized pytorch model, we need to subtract 128(i.e. INT8_MIN) to /// make the activations becomes int8_t. -/// For Glow: -128 <= orig_fp32/scale_1 + offset_1 <= 127 -/// For PyTorch: 0 <= orig_fp32/scale_2 + offset_2 <= 255 -/// Therefore, we can make scale_1 == scale_2, and offset_1 = offset2 - 128 -const int32_t OFFSETSHIFT = 128; - /// Downcast a double to a float. Expected to32Bit(double val) { RETURN_ERR_IF_NOT(val <= std::numeric_limits::max() || @@ -219,6 +216,16 @@ iValToIntList(Expected expectedIVal) { } } +/// Unwrap a Expected \p expectedIVal and call toPTTensor, +/// propogate any Errors. +Expected iValToPTTensor(Expected expectedIVal) { + if (expectedIVal) { + return (*expectedIVal)->toPTTensor(); + } else { + return expectedIVal.takeError(); + } +} + /// Given Node inputs and outputs, check the expected sizes. Negative size /// indicates that the size should be equal to or greater than that size (for /// example -2 means at least 2). @@ -528,6 +535,20 @@ struct QuantizedUnpackedConv2dInputs { }; }; +/// Indexes of quantized::conv2d and quantized::conv2d_relu inputs. 
+struct QuantizedConv2dInputs { + enum { + input = 0, // NCHW + packed_weights = 1, + stride = 2, + padding = 3, + dilation = 4, + group = 5, + scale = 6, + zero_point = 7, + }; +}; + /// Indexes of quantized::add_relu inputs. struct QuantizedAddReluInputs { enum { @@ -549,7 +570,7 @@ struct QuantizedAddInputs { }; /// Indexes of glow::unpacked_quantized_linear inputs. -struct QuantizedLinearInputs { +struct QuantizedUnpackedLinearInputs { enum { input = 0, weight = 1, @@ -559,6 +580,16 @@ struct QuantizedLinearInputs { }; }; +/// Indexes of quantized::linear inputs. +struct QuantizedLinearInputs { + enum { + input = 0, + packed_weights = 1, + scale = 2, + zero_point = 3, + }; +}; + /// Indexes of aten::quantize_per_tensor inputs. struct QuantizeInputs { enum { @@ -761,13 +792,32 @@ PyTorchModelLoader::getSymbolsMapping() { QuantizedUnpackedConv2dInputs::scale, QuantizedUnpackedConv2dInputs::zero_point}}, {{"glow::unpacked_quantized_linear"}, + &PyTorchModelLoader::loadQuantizedLinearUnpacked, + { + QuantizedUnpackedLinearInputs::weight, + QuantizedUnpackedLinearInputs::bias, + QuantizedUnpackedLinearInputs::scale, + QuantizedUnpackedLinearInputs::zero_point, + }}, + {{"quantized::linear"}, &PyTorchModelLoader::loadQuantizedLinear, { - QuantizedLinearInputs::weight, - QuantizedLinearInputs::bias, + QuantizedLinearInputs::packed_weights, QuantizedLinearInputs::scale, QuantizedLinearInputs::zero_point, }}, + {{"quantized::conv2d"}, + &PyTorchModelLoader::loadQuantizedConv, + {QuantizedConv2dInputs::packed_weights, QuantizedConv2dInputs::stride, + QuantizedConv2dInputs::padding, QuantizedConv2dInputs::dilation, + QuantizedConv2dInputs::group, QuantizedConv2dInputs::scale, + QuantizedConv2dInputs::zero_point}}, + {{"quantized::conv2d_relu"}, + &PyTorchModelLoader::loadQuantizedConvRelu, + {QuantizedConv2dInputs::packed_weights, QuantizedConv2dInputs::stride, + QuantizedConv2dInputs::padding, QuantizedConv2dInputs::dilation, + QuantizedConv2dInputs::group, QuantizedConv2dInputs::scale, + QuantizedConv2dInputs::zero_point}}, {{"aten::quantize_per_tensor"}, &PyTorchModelLoader::loadQuantize, {QuantizeInputs::scale, QuantizeInputs::zero_point, @@ -1121,6 +1171,179 @@ glow::NodeValue PyTorchModelLoader::rescaleIntToUint(glow::NodeValue input) { } } +Expected +PyTorchModelLoader::loadQuantizedConvImpl(const torch::jit::Node *ptNode, + const bool isRelu) { + auto inputs = ptNode->inputs(); + auto outputs = ptNode->outputs(); + const glow::TransposeNode *output; + + RETURN_IF_ERR(checkInputAndOutputSizes(inputs, 8, outputs, 1)); + + // input + glow::NodeValue input; + ASSIGN_VALUE_OR_RETURN_ERR( + input, getGlowNodeValueForValue(inputs[QuantizedConv2dInputs::input])); + input = rescaleUIntToInt(input); + + input = F_.createTranspose("qconv_input_transposed", input, NCHW2NHWC); + glow::ShapeNHWC inputShape(input.dims()); + + // groups + glow::unsigned_t groups; + ASSIGN_VALUE_OR_RETURN_ERR( + groups, + static_cast_expected(iValToInt( + getGlowIValueForValue(inputs[QuantizedConv2dInputs::group])))); + + // weight and bias + at::Tensor *ptTensor; + ASSIGN_VALUE_OR_RETURN_ERR( + ptTensor, iValToPTTensor(getGlowIValueForValue( + inputs[QuantizedConv2dInputs::packed_weights]))); + + auto op = + c10::Dispatcher::singleton().findSchema({"quantized::conv2d_unpack", ""}); + CHECK(op.has_value()); + auto unpackedParams = callOp(*op, *ptTensor); + const at::Tensor ptWeightTensor = unpackedParams[0].toTensor(); + const c10::optional ptBiasTensor = + unpackedParams[1].toOptional(); + + // if groups == 1 it is 
regular conv + bool isGroupwiseQuantized = ptWeightTensor.is_quantized() && + ptWeightTensor.qscheme() == at::kPerChannelAffine; + isGroupwiseQuantized &= (groups > 1); + + // unpacked weights + auto weightTensor = ptTensorToGlowTensor(ptWeightTensor); + glow::Tensor weightTensorTransposed; + weightTensor.transpose(&weightTensorTransposed, NCHW2NHWC); + glow::Constant *weightConstant = F_.getParent()->createConstant( + "quantized_conv2d_weights", std::move(weightTensorTransposed)); + auto weight = weightConstant->getOutput(); + weight = rescaleUIntToInt(weight); + + // unpacked bias + glow::Tensor biasTensor; + glow::NodeValue bias; + glow::ShapeNHWC weightShape(weight.dims()); + if (ptBiasTensor.has_value()) { + biasTensor = ptTensorToGlowTensor(ptBiasTensor.value()); + } else { + biasTensor = glow::Tensor(glow::ElemKind::FloatTy, {weightShape.n}); + biasTensor.zero(); + } + glow::Constant *biasConstant = F_.getParent()->createConstant( + "quantized_conv2d_bias", std::move(biasTensor)); + biasConstant->ensureIsOwned(); + // bias is not used for groupwised quantization. + // Instead we use biasConstant + bias = biasConstant->getOutput(); + auto biasType = F_.getParent()->uniqueType( + glow::ElemKind::Int32QTy, bias.dims(), + input.getType()->getScale() * weight.getType()->getScale(), 0); + bias = F_.createQuantize("quantize_bias", bias, biasType); + + // strides + std::vector strides; + ASSIGN_VALUE_OR_RETURN_ERR( + strides, + castVector(expandIntIValIfNeeded( + getGlowIValueForValue(inputs[QuantizedConv2dInputs::stride]), 2))); + + // pad + glow::unsigned_t pad; + ASSIGN_VALUE_OR_RETURN_ERR( + pad, static_cast_expected(contractIntIValIfNeeded( + getGlowIValueForValue(inputs[QuantizedConv2dInputs::padding])))); + std::vector pads = {pad, pad, pad, pad}; + + // dilation + glow::unsigned_t dilation; + ASSIGN_VALUE_OR_RETURN_ERR( + dilation, + static_cast_expected(contractIntIValIfNeeded( + getGlowIValueForValue(inputs[QuantizedConv2dInputs::dilation])))); + + // quantized params + float outScale; + ASSIGN_VALUE_OR_RETURN_ERR(outScale, + iValToDouble(getGlowIValueForValue( + inputs[QuantizedConv2dInputs::scale]))); + + int32_t outOffset; + ASSIGN_VALUE_OR_RETURN_ERR(outOffset, + iValToInt(getGlowIValueForValue( + inputs[QuantizedConv2dInputs::zero_point]))); + + // calc output type + std::vector kernels = { + static_cast(weightShape.h), + static_cast(weightShape.w)}; + auto outSz = glow::calculateConvPoolOutputDims( + inputShape.h, inputShape.w, kernels, strides, pads, dilation); + std::array outDims = { + {input.dims()[0], outSz.first, outSz.second, weightShape.n}}; + glow::TypeRef outTy = F_.getParent()->uniqueType( + glow::ElemKind::Int8QTy, outDims, outScale, outOffset); + + glow::NodeValue output_not_transposed; + if (isGroupwiseQuantized) { + RETURN_ERR_IF_NOT(dilation <= 1, + "Dilation not supported for group quantized convolution"); + + // extract qparams from ptWeightTensor. + // Notice since the memory of qparams may not be continous + // we CANNOT use the data ptr of this chunk of memory and + // convert them into glow tensor directly by using PtTensorToGlowTensor. + // Instead, we extract them one after one. 
+ std::vector scalesVector; + std::vector offsetsVector; + std::vector dims; + const int n = ptWeightTensor.q_per_channel_scales().size(0); + dims.push_back(n); + for (int i = 0; i < n; i++) { + float scale = + ptWeightTensor.q_per_channel_scales().to(at::kFloat)[i].item(); + int32_t offset = ptWeightTensor.q_per_channel_zero_points() + .to(at::kInt)[i] + .item(); + scalesVector.push_back(scale); + offsetsVector.push_back(offset); + } + + // construct qparam constants + auto scaleType = glow::Type(ElemKind::FloatTy, dims); + auto offsetType = glow::Type(ElemKind::Int32ITy, dims); + auto wScalesTensor = glow::Tensor(scalesVector.data(), &scaleType); + auto wOffsetsTensor = glow::Tensor(offsetsVector.data(), &offsetType); + + auto wScales = F_.getParent()->createConstant( + "channel_wised_scales_of_qconv", std::move(wScalesTensor)); + wScales->ensureIsOwned(); + auto wOffsets = F_.getParent()->createConstant( + "channel_wised_offsets_of_qconv", std::move(wOffsetsTensor)); + wOffsets->ensureIsOwned(); + + auto qconv = F_.createChannelwiseQuantizedConv( + "qconv_channel_wised", input, weightConstant, biasConstant, wScales, + wOffsets, outTy, kernels, strides, pads, groups); + output_not_transposed = qconv->getResult(); + } else { + auto qconv = F_.createConv("qconv", input, weight, bias, outTy, kernels, + strides, pads, groups, dilation); + glow::NodeValue output_not_transposed = qconv->getResult(); + } + if (isRelu) { + glow::ReluNode *qrelu = F_.createRELU("qconv_relu", output_not_transposed); + output_not_transposed = qrelu->getResult(); + } + output = F_.createTranspose("channel_wised_qconv_relu_output_transposed", + output_not_transposed, NHWC2NCHW); + return Expected(output->getResult()); +} + template NodeValue PyTorchModelLoader::loadNodeValueOrCreateBroadcastedConstant( const torch::jit::Value *value, llvm::StringRef name, const Type &ty, @@ -1212,20 +1435,63 @@ Error PyTorchModelLoader::loadQuantizedAddRelu(const torch::jit::Node *ptNode) { Error PyTorchModelLoader::loadQuantizedLinear(const torch::jit::Node *ptNode) { auto inputs = ptNode->inputs(); auto outputs = ptNode->outputs(); - RETURN_IF_ERR(checkInputAndOutputSizes(inputs, 5, outputs, 1)); + RETURN_IF_ERR(checkInputAndOutputSizes(inputs, 4, outputs, 1)); glow::NodeValue input; ASSIGN_VALUE_OR_RETURN_ERR( input, getGlowNodeValueForValue(inputs[QuantizedLinearInputs::input])); input = rescaleUIntToInt(input); - glow::NodeValue weight; - ASSIGN_VALUE_OR_RETURN_ERR( - weight, getGlowNodeValueForValue(inputs[QuantizedLinearInputs::weight])); + at::Tensor *ptTensor; + ASSIGN_VALUE_OR_RETURN_ERR( + ptTensor, iValToPTTensor(getGlowIValueForValue( + inputs[QuantizedLinearInputs::packed_weights]))); + + auto op = + c10::Dispatcher::singleton().findSchema({"quantized::linear_unpack", ""}); + CHECK(op.has_value()); + auto unpackedParams = callOp(*op, *ptTensor); + const at::Tensor ptWeightTensor = unpackedParams[0].toTensor(); + const c10::optional ptBiasTensor = + unpackedParams[1].toOptional(); + + // unpacked weights + auto weightTensor = ptTensorToGlowTensor(ptWeightTensor); + glow::Constant *weightConstant = F_.getParent()->createConstant( + "quantized_linear_weights", std::move(weightTensor)); + weightConstant->ensureIsOwned(); + auto weight = weightConstant->getOutput(); weight = rescaleUIntToInt(weight); - RETURN_ERR_IF_NOT(weight.dims().size() == 2, "Expected 2d Linear weights"); + // unpacked bias + glow::Tensor biasTensor; + if (ptBiasTensor.has_value()) { + biasTensor = ptTensorToGlowTensor(ptBiasTensor.value()); + } else { 
+ biasTensor = glow::Tensor(glow::ElemKind::FloatTy, {weight.dims()[1]}); + biasTensor.zero(); + } + + // Choose bias quantization params and quantize it. + glow::Constant *biasConstant = F_.getParent()->createConstant( + "quantized_linear_bias", std::move(biasTensor)); + biasConstant->ensureIsOwned(); + RETURN_ERR_IF_NOT(biasConstant, "quantized::linear bias must be constant"); + const auto biasHandle = biasConstant->getPayload().getHandle(); + const auto biasMinMaxIdx = biasHandle.minMaxArg(); + + const auto biasQParams = chooseQuantizationParams( + biasHandle.raw(biasMinMaxIdx.first), biasHandle.raw(biasMinMaxIdx.second), + glow::quantization::Schema::Asymmetric, glow::ElemKind::Int32QTy); + + auto bias = biasConstant->getOutput(); + + auto biasType = + F_.getParent()->uniqueType(glow::ElemKind::Int32QTy, bias.dims(), + biasQParams.scale, biasQParams.offset); + bias = F_.createQuantize("quantize_bias", bias, biasType); + RETURN_ERR_IF_NOT(weight.dims().size() == 2, "Expected 2d Linear weights"); weight = F_.createTranspose("weight_transpose", weight, {1, 0}); float outScale; @@ -1242,14 +1508,56 @@ Error PyTorchModelLoader::loadQuantizedLinear(const torch::jit::Node *ptNode) { {input.dims()[0], weight.dims()[1]}, outScale, outZeroPoint - OFFSETSHIFT); + // TODO Rowwise quantization is not enabled + // We should enable it later. + auto fc = F_.createFullyConnected("quantized_fc", input, weight, bias, outTy); + return addValueMapping(outputs[0], rescaleIntToUint(fc->getResult())); +} + +Error PyTorchModelLoader::loadQuantizedLinearUnpacked( + const torch::jit::Node *ptNode) { + auto inputs = ptNode->inputs(); + auto outputs = ptNode->outputs(); + RETURN_IF_ERR(checkInputAndOutputSizes(inputs, 5, outputs, 1)); + + glow::NodeValue input; + ASSIGN_VALUE_OR_RETURN_ERR( + input, + getGlowNodeValueForValue(inputs[QuantizedUnpackedLinearInputs::input])); + input = rescaleUIntToInt(input); + + glow::NodeValue weight; + ASSIGN_VALUE_OR_RETURN_ERR( + weight, + getGlowNodeValueForValue(inputs[QuantizedUnpackedLinearInputs::weight])); + weight = rescaleUIntToInt(weight); + + RETURN_ERR_IF_NOT(weight.dims().size() == 2, "Expected 2d Linear weights"); + + weight = F_.createTranspose("weight_transpose", weight, {1, 0}); + + float outScale; + ASSIGN_VALUE_OR_RETURN_ERR( + outScale, to32Bit(iValToDouble(getGlowIValueForValue( + inputs[QuantizedUnpackedLinearInputs::scale])))); + + int64_t outZeroPoint; + ASSIGN_VALUE_OR_RETURN_ERR( + outZeroPoint, iValToInt(getGlowIValueForValue( + inputs[QuantizedUnpackedLinearInputs::zero_point]))); + + auto outTy = F_.getParent()->uniqueType(ElemKind::Int8QTy, + {input.dims()[0], weight.dims()[1]}, + outScale, outZeroPoint - OFFSETSHIFT); + // Get bias or create a zero bias if no bias is found. glow::NodeValue bias = loadNodeValueOrCreateBroadcastedConstant( - inputs[QuantizedLinearInputs::bias], "quantized_linear_bias", + inputs[QuantizedUnpackedLinearInputs::bias], "quantized_linear_bias", glow::Type(ElemKind::FloatTy, {weight.dims()[1]}), 0.0); // Choose bias quantization params and quantize it. glow::Constant *biasConstant = llvm::dyn_cast(bias.getNode()); - RETURN_ERR_IF_NOT(biasConstant, "quantized::linear bias must be constant"); + const auto biasHandle = biasConstant->getPayload().getHandle(); const auto biasMinMaxIdx = biasHandle.minMaxArg(); @@ -1474,7 +1782,6 @@ Error PyTorchModelLoader::loadListConstruct(const torch::jit::Node *ptNode) { auto outputs = ptNode->outputs(); // Requires -1 because this requires at least one input. 
RETURN_IF_ERR(checkInputAndOutputSizes(inputs, -1, outputs, 1)); - // Get the Tag of the first input to use for the whole list. GlowIValue *firstInputIVal; ASSIGN_VALUE_OR_RETURN_ERR(firstInputIVal, getGlowIValueForValue(inputs[0])); @@ -1958,7 +2265,6 @@ Error PyTorchModelLoader::loadQuantize(const torch::jit::Node *ptNode) { } else { return MAKE_ERR("Quantize only supports QUInt8 and QInt8"); } - glow::QuantizeNode *qn = F_.createQuantize("quantize", input, outTy); return addValueMapping(outputs[0], qn->getResult()); @@ -1976,6 +2282,21 @@ Error PyTorchModelLoader::loadDequantize(const torch::jit::Node *ptNode) { return addValueMapping(outputs[0], dn->getResult()); } +Error PyTorchModelLoader::loadQuantizedConvRelu( + const torch::jit::Node *ptNode) { + auto outputs = ptNode->outputs(); + glow::NodeValue output; + ASSIGN_VALUE_OR_RETURN_ERR(output, loadQuantizedConvImpl(ptNode, true)); + return addValueMapping(outputs[0], rescaleIntToUint(output)); +} + +Error PyTorchModelLoader::loadQuantizedConv(const torch::jit::Node *ptNode) { + auto outputs = ptNode->outputs(); + glow::NodeValue output; + ASSIGN_VALUE_OR_RETURN_ERR(output, loadQuantizedConvImpl(ptNode, true)); + return addValueMapping(outputs[0], rescaleIntToUint(output)); +} + Error PyTorchModelLoader::loadQuantizedConvUnpacked( const torch::jit::Node *ptNode) { auto inputs = ptNode->inputs(); @@ -2214,7 +2535,7 @@ Error PyTorchModelLoader::loadAdaptiveAvgPool2d( size_t inputH = input.dims()[1]; size_t inputW = input.dims()[2]; - + input = rescaleUIntToInt(input); input = F_.createTranspose("adaptive_avg_pool2d_input_transposed", input, NCHW2NHWC); @@ -2238,7 +2559,7 @@ Error PyTorchModelLoader::loadAdaptiveAvgPool2d( F_.createAdaptiveAvgPool("adaptive_avg_pool2d", input, outTy); output = F_.createTranspose("adaptive_avg_pool2d_output_transposed", output, NHWC2NCHW); - return addValueMapping(outputs[0], output); + return addValueMapping(outputs[0], rescaleIntToUint(output)); } Error PyTorchModelLoader::loadT(const torch::jit::Node *ptNode) { @@ -2705,6 +3026,7 @@ Error PyTorchModelLoader::loadFlatten(const torch::jit::Node *ptNode) { glow::NodeValue in; ASSIGN_VALUE_OR_RETURN_ERR( in, getGlowNodeValueForValue(inputs[FlattenInputs::input])); + in = rescaleUIntToInt(in); int64_t startDim; ASSIGN_VALUE_OR_RETURN_ERR(startDim, iValToInt(getGlowIValueForValue( @@ -2717,7 +3039,7 @@ Error PyTorchModelLoader::loadFlatten(const torch::jit::Node *ptNode) { auto xDim = glow::flattenCdr(in.dims(), startDim); auto *glowNode = F_.createReshape("flatten", in, {xDim.first, xDim.second}); - return addValueMapping(outputs[0], glowNode); + return addValueMapping(outputs[0], rescaleIntToUint(glowNode->getResult())); } Error PyTorchModelLoader::loadTopK(const torch::jit::Node *ptNode) { @@ -3006,16 +3328,24 @@ Error PyTorchModelLoader::loadAttributes( std::make_pair(&ival.toObjectRef(), newNameHierarchy); continue; } else if (ival.isTensor()) { - const auto &ptTensor = ival.toTensor(); - auto glowTensor = ptTensorToGlowTensor(ptTensor); - - glow::Constant *glowConstant = F_.getParent()->createConstant( - newNameHierarchy, std::move(glowTensor)); + const auto ptTensor = ival.toTensor(); + // PyTorch Tensor extracted type is kByte + // indicate it is the address of stored weights of quantized + // linear or conv. 
+ if (ptTensor.scalar_type() == at::kByte) { + GlowIValue glowIVal; + glowIVal.fromPTTensor(ptTensor); + RETURN_IF_ERR(addValueMapping(outputValue, std::move(glowIVal))); + } else { + auto glowTensor = ptTensorToGlowTensor(ptTensor); + glow::Constant *glowConstant = F_.getParent()->createConstant( + newNameHierarchy, std::move(glowTensor)); - if (copyTensorMemory_) { - glowConstant->ensureIsOwned(); + if (copyTensorMemory_) { + glowConstant->ensureIsOwned(); + } + RETURN_IF_ERR(addValueMapping(outputValue, glowConstant->getOutput())); } - RETURN_IF_ERR(addValueMapping(outputValue, glowConstant->getOutput())); } else { GlowIValue glowIVal; RETURN_IF_ERR(glowIVal.fromIValue(ival)); @@ -3064,7 +3394,6 @@ PyTorchModelLoader::PyTorchModelLoader( for (size_t i = 0; i < graphInputValues.size(); ++i) { const torch::jit::Value *inputValue = graphInputValues[i]; glow::Placeholder *ph; - if (!inputMeta.empty()) { if (inputValue->type()->kind() == c10::TypeKind::TensorType) { glow::Type t(scalarTypeToElemKind(inputMeta[i].type), @@ -3102,7 +3431,6 @@ PyTorchModelLoader::PyTorchModelLoader( } RETURN_IF_ERR(loadAttributes(graph, inputs)); - RETURN_IF_ERR(loadNodes(graph)); // Create Glow Placeholders for outputs. diff --git a/torch_glow/src/PyTorchModelLoader.h b/torch_glow/src/PyTorchModelLoader.h index f526394f42..d8992dbb46 100644 --- a/torch_glow/src/PyTorchModelLoader.h +++ b/torch_glow/src/PyTorchModelLoader.h @@ -279,6 +279,12 @@ class PyTorchModelLoader { /// Rescale a int8 NodeValue \p input to the equivalent uint8 NodeValue. glow::NodeValue rescaleIntToUint(glow::NodeValue input); + /// Load a quantized conv node from ptNode to qconv. + /// a wrapper function of loadQuantizedConv and loadQuantizedConvRelu. + /// Returns error on failure. + Expected loadQuantizedConvImpl(const torch::jit::Node *ptNode, + const bool isRelu); + /// For each Placeholder input to \p ptNode, if this input has been marked /// as being an input that should be frozen in MappingOfMemberFunctions, /// create a glow Constant for that Placeholder with the iValue from the stack @@ -399,12 +405,24 @@ class PyTorchModelLoader { /// \return error on failure. Error loadQuantizedAddRelu(const torch::jit::Node *ptNode); - /// Load a PyTorch glow::unpacked_quantized_conv node. + /// Load a glow::unpacked_quantized_conv node. // \return error on failure. Error loadQuantizedConvUnpacked(const torch::jit::Node *ptNode); + /// Load a PyTorch quantized::conv2d node. + // \return error on failure. + Error loadQuantizedConv(const torch::jit::Node *ptNode); + + /// Load a PyTorch quantized::conv2d_relu node. + // \return error on failure. + Error loadQuantizedConvRelu(const torch::jit::Node *ptNode); + /// Load a glow::unpacked_quantized_linear node. /// \return error on failure. + Error loadQuantizedLinearUnpacked(const torch::jit::Node *ptNode); + + /// Load a PyTorch quantized::linear node. + /// \return error on failure. Error loadQuantizedLinear(const torch::jit::Node *ptNode); /// Load a PyTorch quantize_per_tensor node. 
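Throughout the loader, OFFSETSHIFT encodes the mapping the header comment describes between PyTorch's quint8 activations and Glow's int8 representation: the scale is shared and the zero point is shifted by 128. That relationship can be checked with a small, self-contained PyTorch snippet (illustrative only, not part of the patch; the scale and zero-point values below are arbitrary):

    import torch

    x = torch.tensor([-1.0, -0.04, 0.0, 0.7, 2.5])
    scale, zero_point = 0.1, 130  # arbitrary quint8 qparams

    q_u8 = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)       # PyTorch-style quint8
    q_s8 = torch.quantize_per_tensor(x, scale, zero_point - 128, torch.qint8)  # Glow-style int8

    # The same real values are represented...
    assert torch.allclose(q_u8.dequantize(), q_s8.dequantize())
    # ...and for these in-range values the raw integer codes differ by exactly
    # OFFSETSHIFT (128).
    assert torch.equal(q_u8.int_repr().to(torch.int16) - 128,
                       q_s8.int_repr().to(torch.int16))
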
diff --git a/torch_glow/tests/nodes/quantized_conv2d_relu_test.py b/torch_glow/tests/nodes/quantized_conv2d_relu_test.py new file mode 100644 index 0000000000..e8b6234b8b --- /dev/null +++ b/torch_glow/tests/nodes/quantized_conv2d_relu_test.py @@ -0,0 +1,44 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import torch + +from tests.utils import jitVsGlow +import pytest + +from collections import OrderedDict + + +def test_quantized_conv2d_relu_packed_groupwise(): + """Basic test of PyTorch quantized::conv2d_relu Node with packed weights on Glow.""" + + x = torch.tensor(range(5), dtype=torch.float) * 1.5 + x = torch.cat((x, x, x, x, x)) + x = torch.cat((x, x, x)) + x = torch.reshape(x, [1, 3, 5, 5]) + q = torch.nn.quantized.Quantize(0.2, 2, torch.quint8) + conv = torch.nn.Conv2d(3, 3, [2, 2], groups=3) + relu = torch.nn.ReLU() + dq = torch.nn.quantized.DeQuantize() + + # Due to the off-by-one error, we cannot let the weights, bias & input + # to be totally random. + conv.weight.data.fill_(1.5) + conv.bias.data.fill_(2.5) + + model = torch.nn.Sequential(OrderedDict([ + ('quantize', q), + ('conv1', conv), + ('relu1', relu), + ('deuantize', dq)])) + model.eval() + model.qconfig = torch.quantization.get_default_qconfig('fbgemm') + + # Fuse conv and relu to conv_relu + model = torch.quantization.fuse_modules(model, [['conv1', 'relu1']]) + + torch.quantization.prepare(model, inplace=True) + torch.quantization.convert(model, inplace=True) + + jitVsGlow(model, x, expected_fused_ops={"aten::quantize_per_tensor", + "quantized::conv2d_relu", + "aten::dequantize"}) diff --git a/torch_glow/tests/nodes/quantized_conv2d_test.py b/torch_glow/tests/nodes/quantized_conv2d_test.py index f780595559..53a9b36eaa 100644 --- a/torch_glow/tests/nodes/quantized_conv2d_test.py +++ b/torch_glow/tests/nodes/quantized_conv2d_test.py @@ -6,8 +6,8 @@ import pytest -def test_quantized_conv2d(): - """Basic test of the PyTorch quantized onv2d Node on Glow.""" +def test_quantized_conv2d_unpacked(): + """Basic test of the PyTorch quantize::conv2d Node with unpacked weights on Glow.""" def test_f(a, w, b): qu = torch.nn.quantized.Quantize(1/16, 0, torch.quint8) @@ -38,6 +38,34 @@ def test_f(a, w, b): "aten::dequantize"}) +def test_quantized_conv2d_packed_groupwise(): + """Basic test of PyTorch quantize::conv2d Node with packed weights on Glow.""" + + x = torch.tensor(range(5), dtype=torch.float) + x = torch.cat((x, x, x, x, x)) + x = torch.cat((x, x, x)) + x = torch.reshape(x, [1, 3, 5, 5]) + q = torch.nn.quantized.Quantize(0.1, 2, torch.quint8) + conv = torch.nn.Conv2d(3, 3, [2, 2], groups=3) + dq = torch.nn.quantized.DeQuantize() + + # Due to the off-by-one error, we cannot let the weights, bias & input + # to be totally random. 
+ conv.weight.data.fill_(2.0) + conv.bias.data.fill_(1.0) + + model = torch.nn.Sequential(q, conv, dq) + model.eval() + model.qconfig = torch.quantization.get_default_qconfig('fbgemm') + + torch.quantization.prepare(model, inplace=True) + torch.quantization.convert(model, inplace=True) + + jitVsGlow(model, x, expected_fused_ops={"aten::quantize_per_tensor", + "quantized::conv2d", + "aten::dequantize"}) + + @pytest.mark.skip(reason="accuracy between glow & pytorch") def test_quantized_conv2d_nonfunctional(): """Basic test of the PyTorch quantized conv2d Node with external quantized diff --git a/torch_glow/tests/nodes/quantized_linear_test.py b/torch_glow/tests/nodes/quantized_linear_test.py index 2f7a693b81..d71bb33dcd 100644 --- a/torch_glow/tests/nodes/quantized_linear_test.py +++ b/torch_glow/tests/nodes/quantized_linear_test.py @@ -5,6 +5,31 @@ from tests.utils import jitVsGlow +def test_quantized_linear_packed(): + """Basic test of the PyTorch quantized::linear Node on Glow.""" + + q = torch.nn.quantized.Quantize(scale=1/25, zero_point=17, + dtype=torch.quint8) + dq = torch.nn.quantized.DeQuantize() + linear = torch.nn.Linear(5, 5) + + linear.weight.data.fill_(1.2) + linear.bias.data.fill_(3.0) + + model = torch.nn.Sequential(q, linear, dq) + model.qconfig = torch.quantization.get_default_qconfig('fbgemm') + torch.quantization.prepare(model, inplace=True) + torch.quantization.convert(model, inplace=True) + + x = torch.tensor(range(5), dtype=torch.float) + x = torch.cat((x, x, x, x, x)) + x = torch.reshape(x, [5, 5]) + + jitVsGlow(model, x, expected_fused_ops={"aten::quantize_per_tensor", + "quantized::linear", + "aten::dequantize"}) + + def test_quantized_linear_random_input(): """Basic test of the PyTorch quantized::linear Node on Glow."""
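
The packed-weight tests above rely on quantized::linear and quantized::conv2d carrying their weights inside opaque packed objects, which the loader unpacks through the quantized::linear_unpack / quantized::conv2d_unpack ops before turning them into Glow constants. For the linear case, the Python-side counterpart of that unpacking step looks roughly like the sketch below (an assumption-laden illustration, not part of the patch: it presumes a PyTorch build with FBGEMM enabled, and the tensors and qparams are made up). It is handy for checking what the loader will see, including the per-channel scales and zero points that the channel-wise conv path reads:

    import torch

    # Per-tensor quantized linear weight, packed and then unpacked again.
    w = torch.randn(4, 5)
    bias = torch.zeros(4)
    w_q = torch.quantize_per_tensor(w, scale=0.05, zero_point=0, dtype=torch.qint8)

    packed = torch.ops.quantized.linear_prepack(w_q, bias)
    w_unpacked, b_unpacked = torch.ops.quantized.linear_unpack(packed)
    print(w_unpacked.qscheme(), w_unpacked.q_scale(), w_unpacked.q_zero_point())

    # Channel-wise quantized weight: one scale/zero_point per output channel,
    # exposed by q_per_channel_scales() / q_per_channel_zero_points(), the same
    # accessors the loader uses for channel-wise conv weights.
    scales = torch.tensor([0.05, 0.1, 0.2, 0.4], dtype=torch.double)
    zero_points = torch.zeros(4, dtype=torch.long)
    w_qc = torch.quantize_per_channel(w, scales, zero_points, axis=0, dtype=torch.qint8)
    print(w_qc.q_per_channel_scales(), w_qc.q_per_channel_zero_points())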