Skip to content

Commit

Permalink
[wip] quantized tensor: add support for advanced indexing
Browse files Browse the repository at this point in the history
Summary:

Implements support for the indexing of quantized tensors with lists of
dims, such as

```
xq_slice = xq[:, [0], :, :]
```

At least a few things need to happen before this is ready for review:
1. Verify whether the general idea of `TensorQuantizationOptions` is the right abstraction
2. Verify whether `TensorQuantizationOptions` needs to follow the same style as `TensorOptions` (no optionals, etc.)

Test Plan:

```
python test/test_quantization.py TestQuantizedOps.test_advanced_indexing
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 10e51c59583e07ae5c42307c8879d6dd6a368422
Pull Request resolved: #49129
  • Loading branch information
vkuzo committed Dec 10, 2020
1 parent c29f516 commit dfa8b9d
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 8 deletions.
41 changes: 41 additions & 0 deletions aten/src/ATen/Dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,47 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
} \
}()

// Dispatches the callable passed as __VA_ARGS__ on the scalar type derived
// from TYPE. The switch covers Byte, Char, Double, Float, Int, Long, Short,
// ComplexFloat, ComplexDouble, the quantized integer types (kQInt8, kQUInt8,
// kQInt32), and three additional caller-chosen scalar types
// (SCALARTYPE1..SCALARTYPE3). Any other scalar type raises an error that
// names NAME and the offending type.
//
// TYPE is evaluated exactly once: it is bound to `the_type` up front and every
// later use, including the error message, goes through that binding.
//
// NOTE(review): the callable is expected to refer to `scalar_t`; presumably
// the *_PRIVATE_CASE_TYPE helpers define it per case -- confirm against their
// definitions.
#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND_QINT_AND3( \
    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \
  [&] { \
    /* Bind TYPE once: it may be an expensive or side-effecting expression, */ \
    /* so it must not be expanded again anywhere below. */ \
    const auto& the_type = TYPE; \
    at::ScalarType _st = ::detail::scalar_type(the_type); \
    switch (_st) { \
      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE( \
          at::ScalarType::ComplexFloat, c10::complex<float>, __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE( \
          at::ScalarType::ComplexDouble, c10::complex<double>, __VA_ARGS__) \
      /* Quantized integer types. */ \
      AT_QINT_PRIVATE_CASE_TYPE( \
          at::kQInt8, at::qint8, at::kChar, int8_t, __VA_ARGS__) \
      AT_QINT_PRIVATE_CASE_TYPE( \
          at::kQUInt8, at::quint8, at::kByte, uint8_t, __VA_ARGS__) \
      AT_QINT_PRIVATE_CASE_TYPE( \
          at::kQInt32, at::qint32, at::kInt, int, __VA_ARGS__) \
      /* The three extra scalar types requested by the caller. */ \
      AT_PRIVATE_CASE_TYPE( \
          SCALARTYPE1, \
          decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), \
          __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE( \
          SCALARTYPE2, \
          decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), \
          __VA_ARGS__) \
      AT_PRIVATE_CASE_TYPE( \
          SCALARTYPE3, \
          decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE3>::t), \
          __VA_ARGS__) \
      default: \
        /* Report via the_type (not TYPE) so TYPE is not evaluated twice. */ \
        AT_ERROR(#NAME, " not implemented for '", the_type, "'"); \
    } \
  }()

#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_index_type = TYPE; \
Expand Down
42 changes: 38 additions & 4 deletions aten/src/ATen/TensorIterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,17 +475,38 @@ void TensorIteratorBase::allocate_or_resize_outputs() {
}
}
auto tensor_shape = invert_perm(shape_);

// For quantized output tensors, get the quantization params from the
// input.
c10::optional<TensorQuantizationOptions> quant_options = c10::nullopt;
const auto& options = op.options();
if (options.dtype() == at::kQInt8 || options.dtype() == at::kQUInt8 ||
options.dtype() == at::kQInt32) {
// get the first input and copy its quantization parameters
const auto& first_input = operands_[num_outputs_];
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(first_input.tensor.is_quantized());
TORCH_INTERNAL_ASSERT(
first_input.tensor.qscheme() == c10::kPerTensorAffine ||
first_input.tensor.qscheme() == c10::kPerTensorSymmetric,
"Advanced indexing of per-channel quantized tensors is not supported yet");
quant_options = TensorQuantizationOptions(
first_input.tensor.q_scale(), first_input.tensor.q_zero_point());
}

if (inverted) {
// can just return contiguous output
// it is faster because it avoids allocating 0 size tensor and
// resizing and restriding it
set_output(i, tensor_shape, {}, op.options(), names_);
set_output(
i, tensor_shape, {}, op.options(), names_, quant_options);
} else {
auto tensor_stride = invert_perm(op.stride_bytes);
for (int dim = 0; dim < ndim(); dim++) {
tensor_stride[dim] /= element_size;
}
set_output(i, tensor_shape, tensor_stride, op.options(), names_);
set_output(
i, tensor_shape, tensor_stride, op.options(), names_,
quant_options);
}
op.current_dtype = op.target_dtype;
} else if (op.tensor.defined() && !names_.empty()) {
Expand Down Expand Up @@ -1281,12 +1302,25 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
view_offsets_ = DimVector(ndim_offsets, 0);
}

void TensorIterator::set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
void TensorIterator::set_output(
int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
TensorOptions options, DimnameList names,
c10::optional<TensorQuantizationOptions> quant_options) {
auto& op = operands_[output_idx];
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
if (!op.tensor.defined()) {
if (strides.empty()) {
op.tensor = at::empty(sizes, options);
if (options.dtype() == at::kQInt8 || options.dtype() == at::kQUInt8 ||
options.dtype() == at::kQInt32) {
// quantized path
TORCH_INTERNAL_ASSERT(quant_options.has_value());
op.tensor = at::_empty_affine_quantized(
sizes, options, (*quant_options).q_scale,
(*quant_options).q_zero_point);
} else {
// non-quantized path
op.tensor = at::empty(sizes, options);
}
} else {
op.tensor = at::empty_strided(sizes, strides, options);
}
Expand Down
5 changes: 4 additions & 1 deletion aten/src/ATen/TensorIterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,10 @@ struct CAFFE2_API TensorIterator final : public TensorIteratorBase {
static TensorIterator reduce_op(Tensor& out, const Tensor& a);
static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a);

void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) override;
void set_output(
int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options,
DimnameList names,
c10::optional<TensorQuantizationOptions> quant_options = c10::nullopt) override;
};

class CAFFE2_API TensorIteratorConfig final {
Expand Down
5 changes: 4 additions & 1 deletion aten/src/ATen/TensorMeta.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ class Tensor;
namespace impl {

struct MetaBase {
virtual void set_output(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) = 0;
virtual void set_output(
int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options,
DimnameList names,
c10::optional<TensorQuantizationOptions> quant_options = c10::nullopt) = 0;
void set_output(IntArrayRef sizes, TensorOptions options) {
set_output(0, sizes, {}, options, {});
}
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/native/cpu/IndexKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
}

void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND_QINT_AND3(
ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16,
iter.dtype(), "index_cpu", [&] {
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)dst = *(scalar_t*)(src + offset);
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2073,7 +2073,7 @@
- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
variants: function, method
dispatch:
CPU, CUDA: index
CPU, CUDA, QuantizedCPU: index
# NB: This function is special-cased in tools/autograd/gen_variable_type.py
# NB: The following functions are declared in aten/src/ATen/templates/TensorBody.h and defined in aten/src/ATen/TensorIndexing.cpp:
# - Tensor Tensor::index(ArrayRef<TensorIndex> indices)
Expand Down
20 changes: 20 additions & 0 deletions c10/core/TensorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -691,9 +691,29 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
return DeviceType::Vulkan;
} else if (tid == DispatchKey::Metal) {
return DeviceType::Metal;
} else if (tid == DispatchKey::QuantizedCPU) {
return DeviceType::CPU;
} else {
AT_ASSERTM(false, "Unknown DispatchKey: ", tid);
}
}

// Extra tensor-construction options that only quantized Tensors need. They
// live in their own holder so that TensorOptions does not have to carry them.
struct C10_API TensorQuantizationOptions {
  // Scale used by per-Tensor quantization schemes.
  double q_scale;
  // Zero point used by per-Tensor quantization schemes.
  int64_t q_zero_point;
  // TODO(future PR): per-channel quantization parameters. Note: the per-Tensor
  // quantization parameters would need to become optional at that time.

  // Builds the options from a per-Tensor scale and zero point.
  TensorQuantizationOptions(double scale, int64_t zero_point)
      : q_scale{scale}, q_zero_point{zero_point} {}
};

// Size guard: the holder may occupy at most two 64-bit words.
static_assert(
    2 * sizeof(int64_t) >= sizeof(TensorQuantizationOptions),
    "TensorQuantizationOptions must fit in 128-bits");

} // namespace c10
28 changes: 28 additions & 0 deletions test/quantization/test_quantized_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2274,6 +2274,34 @@ def test_empty_batch(self):
result = torch.ops.quantized.linear_dynamic(X, w_packed)
self.assertEqual(result.shape, (0, 2))

def test_advanced_indexing(self):
"""
Verifies that the x[:, [0], :, :] syntax works for quantized tensors.
"""
x_q = torch.quantize_per_tensor(
torch.randn(1, 4, 4, 4), 0.1, 0, torch.qint8)
# reference
x_fp32 = x_q.dequantize()

# single dim, single index
x_q_s1 = x_q[:, [0], :, :]
x_fp32_s1 = x_fp32[:, [0], :, :]
self.assertTrue(torch.allclose(x_fp32_s1, x_q_s1.dequantize()))

# multiple dim, single index
x_q_s2 = x_q[:, [0], [2], :]
x_fp32_s2 = x_fp32[:, [0], [2], :]
self.assertTrue(torch.allclose(x_fp32_s2, x_q_s2.dequantize()))

# single dim, multiple indices
x_q_s3 = x_q[:, [2, 0, 1], :, :]
x_fp32_s3 = x_fp32[:, [2, 0, 1], :, :]
self.assertTrue(torch.allclose(x_fp32_s3, x_q_s3.dequantize()))

# multiple dim, multiple indices
x_q_s4 = x_q[:, [2, 0, 1], :, [1]]
x_fp32_s4 = x_fp32[:, [2, 0, 1], :, [1]]
self.assertTrue(torch.allclose(x_fp32_s4, x_q_s4.dequantize()))


class TestDynamicQuantizedLinear(TestCase):
Expand Down

0 comments on commit dfa8b9d

Please sign in to comment.