Add quantized::cat support to glow (#5513)
Summary:
Pull Request resolved: #5513

Add `quantized::cat` support to Glow, using a dequant->concat->quant strategy.
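
For context, a minimal PyTorch-level sketch of the same dequant->concat->quant flow (illustrative only; the shapes and the output scale/zero_point below are arbitrary examples, not values taken from this change):

```python
import torch

# Two per-tensor-quantized inputs that may carry different scales/zero points.
a = torch.quantize_per_tensor(
    torch.randn(2, 3), scale=0.3, zero_point=0, dtype=torch.quint8
)
b = torch.quantize_per_tensor(
    torch.randn(4, 3), scale=0.2, zero_point=0, dtype=torch.quint8
)

# The strategy mirrored by the Glow loader: dequantize each input to float,
# concatenate, then requantize to the scale/zero_point requested by the op.
out = torch.quantize_per_tensor(
    torch.cat([a.dequantize(), b.dequantize()], dim=0),
    scale=0.05,
    zero_point=0,
    dtype=torch.quint8,
)

# The eager-mode op this lowering targets; results agree up to one
# quantization step of rounding.
ref = torch.ops.quantized.cat([a, b], dim=0, scale=0.05, zero_point=0)
assert torch.allclose(out.dequantize(), ref.dequantize(), atol=0.05)
```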

Reviewed By: jackm321

Differential Revision: D27583217

fbshipit-source-id: ed8e7cfd489f724c82563b28e2c7e93fd937d86b
spaugh authored and facebook-github-bot committed Apr 6, 2021
1 parent bb6dc0d commit 353e97c
Showing 3 changed files with 174 additions and 5 deletions.
60 changes: 55 additions & 5 deletions torch_glow/src/PyTorchModelLoader.cpp
@@ -1273,6 +1273,7 @@ PyTorchModelLoader::buildSymbolsMapping() {
{{"quantized::add"}, &PyTorchModelLoader::loadQuantizedAdd},
{{"quantized::add_relu"}, &PyTorchModelLoader::loadQuantizedAddRelu},
{{"quantized::mul"}, &PyTorchModelLoader::loadQuantizedMul},
{{"quantized::cat"}, &PyTorchModelLoader::loadQuantizedCat},
{{"quantized::mul_scalar"}, &PyTorchModelLoader::loadQuantizedMul},
{{"glow::fused_linear"}, &PyTorchModelLoader::loadGlowFusedLinear},
{{"glow::unpacked_quantized_conv2d"},
@@ -2267,6 +2268,54 @@ Error PyTorchModelLoader::loadQuantizedMul(const torch::jit::Node *ptNode) {
RETURN_ERR(addValueMapping(outputs[0], output, dtype));
}
}

Error PyTorchModelLoader::loadQuantizedCat(const torch::jit::Node *ptNode) {
  // Current strategy for quantized::cat is dequantize->cat->requantize.
  // TODO: Remove the quantization step to potentially improve performance.

  RETURN_IF_ERR(
      checkInputAndOutputSizes(ptNode->inputs(), -2, ptNode->outputs(), 1));

  std::vector<glow::NodeValue> *inputTensors;
  float quantizationScale;
  int32_t quantizationOffset;
  int64_t concatDimension;

  ASSIGN_VALUE_OR_RETURN_ERR(
      inputTensors,
      iValToNodeValueList(getGlowIValueForValue(ptNode->input(0))));
  ASSIGN_VALUE_OR_RETURN_ERR(
      quantizationScale, iValToDouble(getGlowIValueForValue(ptNode->input(2))));
  ASSIGN_VALUE_OR_RETURN_ERR(
      quantizationOffset, iValToInt(getGlowIValueForValue(ptNode->input(3))));
  ASSIGN_VALUE_OR_RETURN_ERR(
      concatDimension, iValToInt(getGlowIValueForValue(ptNode->input(1))));

  std::vector<glow::NodeValue> dequantizedInputs;
  for (glow::NodeValue input : *inputTensors) {
    // Legacy behavior suggests supporting concat of empty tensors, but only
    // for this specific shape. See the legacy_cat_wrap_dim function in
    // caffe2/aten/src/ATen/WrapDimUtils.h for more info.
    if (input.dims() == llvm::ArrayRef<dim_t>({0})) {
      continue;
    }
    dequantizedInputs.emplace_back(F_.createDequantize(
        "quantized_cat_dequantize", input, ElemKind::FloatTy));
  }

  glow::NodeValue concatResult =
      F_.createConcat("quantized_cat_nested_cat", dequantizedInputs,
                      concatDimension)
          ->getResult();

  auto *outputTy = F_.getParent()->uniqueType(
      ElemKind::Int8QTy, concatResult.dims(), quantizationScale,
      quantizationOffset - UINT8_TO_INT8_SHIFT);

  auto quantizedResult =
      F_.createQuantize("quantized_cat_requantize", concatResult, outputTy);

  RETURN_ERR(addValueMapping(ptNode->output(0), quantizedResult));
}

Error PyTorchModelLoader::loadLinear(const torch::jit::Node *ptNode) {
auto inputs = ptNode->inputs();
@@ -2732,8 +2781,8 @@ Expected<NodeValue> PyTorchModelLoader::loadArithmeticNode(

// For aten::div, it will promote the output to default scalar type if both
// inputs are of integer type. However, Glow requires inputs and output have
// the same type. In order to achieve same behavior as Pytorch div, we convert
// the inputs to default scalar type if they are both integer.
// the same type. In order to achieve same behavior as Pytorch div, we
// convert the inputs to default scalar type if they are both integer.
if (convertToDefaultType) {
if (isNonQuantizedIntElemKind(rhsInput.getElementType()) &&
isNonQuantizedIntElemKind(lhsInput.getElementType())) {
@@ -2998,8 +3047,8 @@ Error PyTorchModelLoader::loadSum(const torch::jit::Node *ptNode) {
if (!keepDim) {
return addValueMapping(outputs[0], batchedReduceAddNode);
} else {
// If keepDim is true we need to insert the removed dimension(s) manually by
// reshaping
// If keepDim is true we need to insert the removed dimension(s) manually
// by reshaping
std::vector<dim_t> shape =
batchedReduceAddNode->getResult().getType()->dims();
std::sort(glowAxes.begin(), glowAxes.end());
@@ -3797,7 +3846,8 @@ Error PyTorchModelLoader::loadReshape(const torch::jit::Node *ptNode) {
}
}

// If there was a negative index, replace it with the remaining dims in input.
// If there was a negative index, replace it with the remaining dims in
// input.
if (negOneIndex >= 0) {
shape[negOneIndex] = inputTotalDims / shapeTotalDims;
}
2 changes: 2 additions & 0 deletions torch_glow/src/PyTorchModelLoader.h
@@ -664,6 +664,8 @@ class PyTorchModelLoader {
/// \return error on failure.
Error loadQuantizedMul(const torch::jit::Node *ptNode);

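/// Load a quantized::cat node.
/// \return error on failure.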
Error loadQuantizedCat(const torch::jit::Node *ptNode);

/// Load a glow::unpacked_quantized_conv node.
// \return error on failure.
Error loadQuantizedConvUnpacked(const torch::jit::Node *ptNode);
117 changes: 117 additions & 0 deletions torch_glow/tests/nodes/quantized_cat_test.py
@@ -0,0 +1,117 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
from tests import utils


class SimpleQuantizedCatModel(torch.nn.Module):
    def __init__(self, dimension, scale, zero_point):
        super(SimpleQuantizedCatModel, self).__init__()
        self.dimension = dimension
        self.scale = scale
        self.zero_point = zero_point

    def forward(self, a, b):
        return torch.nn.quantized.DeQuantize()(
            torch.ops.quantized.cat(
                (a, b),
                dim=self.dimension,
                scale=self.scale,
                zero_point=self.zero_point,
            )
        )


class TestQuantizedCat(utils.TorchGlowTestCase):
    @utils.deterministic_expand(
        [
            lambda: (
                "zero_offset",
                SimpleQuantizedCatModel(
                    0,
                    0.05,
                    0,
                ),
                (
                    torch.nn.quantized.Quantize(
                        scale=0.3,
                        zero_point=0,
                        dtype=torch.quint8,
                    )(torch.randn([1, 2, 3, 4], dtype=torch.float32)),
                    torch.nn.quantized.Quantize(
                        scale=0.3,
                        zero_point=0,
                        dtype=torch.quint8,
                    )(torch.randn([5, 2, 3, 4], dtype=torch.float32)),
                ),
            ),
            lambda: (
                "basic",
                SimpleQuantizedCatModel(
                    1,
                    0.05,
                    0,
                ),
                (
                    torch.nn.quantized.Quantize(
                        scale=0.3,
                        zero_point=0.3,
                        dtype=torch.quint8,
                    )(torch.randn([8, 8, 8, 8], dtype=torch.float32)),
                    torch.nn.quantized.Quantize(
                        scale=0.3,
                        zero_point=0.3,
                        dtype=torch.quint8,
                    )(torch.randn([8, 8, 8, 8], dtype=torch.float32)),
                ),
            ),
            lambda: (
                "with_empty_tensor",
                SimpleQuantizedCatModel(
                    0,
                    0.05,
                    0,
                ),
                (
                    torch.nn.quantized.Quantize(
                        scale=0.2,
                        zero_point=0.1,
                        dtype=torch.quint8,
                    )(torch.empty(0, dtype=torch.float32)),
                    torch.nn.quantized.Quantize(
                        scale=0.2,
                        zero_point=0.1,
                        dtype=torch.quint8,
                    )(torch.randn([8, 8], dtype=torch.float32)),
                ),
            ),
            lambda: (
                "with_differing_quantizations",
                SimpleQuantizedCatModel(
                    2,
                    0.05,
                    0,
                ),
                (
                    torch.nn.quantized.Quantize(
                        scale=0.6,
                        zero_point=0.2,
                        dtype=torch.quint8,
                    )(torch.randn([7, 7, 7], dtype=torch.float32)),
                    torch.nn.quantized.Quantize(
                        scale=0.2,
                        zero_point=0.1,
                        dtype=torch.quint8,
                    )(torch.randn([7, 7, 7], dtype=torch.float32)),
                ),
            ),
        ]
    )
    def test_quantized_cat(self, _, module, tensors, fusion_blocklist=None):
        utils.compare_tracing_methods(
            module,
            *tensors,
            fusible_ops={"quantized::cat"},
            fusion_blocklist=None,
            skip_to_glow=False,
        )
