Update on "[wip] quantization: store input_qrange_le_128 flag on quan…
Browse files Browse the repository at this point in the history
…tized conv"


Summary:

This is a start on fixing the problems surfaced in #46749.
This particular PR only fixes a small part of them:
1. If a conv module is unsafe to run in fbgemm, we now persist this
   information by setting an `input_qrange_le_128` boolean flag on
   `ConvPackedParams{n}d` to False.
2. If we are in an fbgemm kernel and we detect that the current conv
   packed params are tagged as unsafe, we throw an error (see the sketch
   below).
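
A minimal standalone sketch of the intended check in C++ (the struct name,
accessor, and error message below are illustrative assumptions, not the
actual ATen code):

```
#include <stdexcept>

// Standalone model: the packed conv params carry a boolean recording whether
// the input quantizer was observed with reduce_range, i.e. its quantized
// range fits in <= 128 values and is safe for fbgemm's 16-bit accumulation.
struct ConvPackedParamsSketch {
  bool input_qrange_le_128;  // false => could saturate in fbgemm
};

// What an fbgemm conv kernel would do before computing, per point 2 above:
// refuse to run if the params were tagged as unsafe.
void check_safe_for_fbgemm(const ConvPackedParamsSketch& params) {
  if (!params.input_qrange_le_128) {
    throw std::runtime_error(
        "conv packed params were observed without reduce_range; "
        "running them through fbgemm could saturate the accumulator");
  }
}
```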

For now, this PR is a WIP to get some early feedback on whether this is the
right direction, since the iteration cost on this is high. In particular,
things still missing here are:
* testing serialization of saving v3 and loading it back
* handling all the conv callsites (currently just the module + conv2d are handled)

Note: there were some potential improvements discussed around dynamically
dispatching to qnnpack if it is available and the flag is set. This PR
does not attempt to solve that - it can be addressed in future PRs.

Test Plan:

```
# test that the error gets thrown when we are trying to run an operation which could
# saturate, and does not get thrown otherwise
python test/test_quantization.py TestQuantizedOps.test_conv_reduce_range

# test that loading older versions of conv packed params works as expected
# TODO(before land): extend these tests with the v3 files
python test/test_quantization.py TestSerialization
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D29175285](https://our.internmc.facebook.com/intern/diff/D29175285)

[ghstack-poisoned]
vkuzo committed Jun 30, 2021
2 parents 6f6edf1 + 3650685 commit d267788
Showing 11 changed files with 374 additions and 87 deletions.
40 changes: 38 additions & 2 deletions .circleci/scripts/binary_windows_build.sh
@@ -16,7 +16,42 @@ else
fi

if [[ "${DESIRED_CUDA}" == "cu111" || "${DESIRED_CUDA}" == "cu113" ]]; then
export BUILD_SPLIT_CUDA="ON"

echo "Free Space for CUDA DEBUG BUILD"
if [[ "$CIRCLECI" == 'true' ]]; then
if [[ -d "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Commnuity" ]]; then
rm -rf "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Commnuity"
fi

if [[ -d "C:\\Program Files (x86)\\Microsoft Visual Studio 14.0" ]]; then
rm -rf "C:\\Program Files (x86)\\Microsoft Visual Studio 14.0"
fi

if [[ -d "C:\\Program Files (x86)\\Microsoft.NET" ]]; then
rm -rf "C:\\Program Files (x86)\\Microsoft.NET"
fi

if [[ -d "C:\\Program Files\\dotnet" ]]; then
rm -rf "C:\\Program Files\\dotnet"
fi

if [[ -d "C:\\Program Files (x86)\\dotnet" ]]; then
rm -rf "C:\\Program Files (x86)\\dotnet"
fi

if [[ -d "C:\\Program Files (x86)\\Microsoft SQL Server" ]]; then
rm -rf "C:\\Program Files (x86)\\Microsoft SQL Server"
fi

if [[ -d "C:\\Program Files (x86)\\Xamarin" ]]; then
rm -rf "C:\\Program Files (x86)\\Xamarin"
fi

if [[ -d "C:\\Program Files (x86)\\Google" ]]; then
rm -rf "C:\\Program Files (x86)\\Google"
fi
fi
fi

set +x
@@ -32,7 +67,8 @@ if [[ "$CIRCLECI" == 'true' && -d "C:\\ProgramData\\Microsoft\\VisualStudio\\Pac
fi

if [[ "$CIRCLECI" == 'true' && -d "C:\\Microsoft" ]]; then
rm -rf "C:\\Microsoft\\Android*"
# don't use quotes here
rm -rf /c/Microsoft/AndroidNDK*
fi

echo "Free space on filesystem before build:"
8 changes: 6 additions & 2 deletions aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
@@ -16,11 +16,15 @@ namespace {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

static inline void cvtbf16_fp32(const __m128i& a, __m256& o) {
o = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16));
}

static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) {
__m128i lo = _mm256_extractf128_si256(a, 0);
__m128i hi = _mm256_extractf128_si256(a, 1);
o1 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(lo), 16));
o2 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(hi), 16));
cvtbf16_fp32(lo, o1);
cvtbf16_fp32(hi, o2);
}
static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) {
__m256i lo = _mm256_castps_si256(a);
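The conversion refactored in the hunk above relies on bf16 being the upper
16 bits of an IEEE fp32 value. A standalone scalar sketch of the same idea
(plain C++, no intrinsics; not the ATen code):

```
#include <cstdint>
#include <cstring>

// Scalar model of cvtbf16_fp32: a bfloat16 value is the high 16 bits of a
// float32, so widening is just "place the bits in the top half and zero the
// bottom half" - what _mm256_slli_epi32(..., 16) above does on eight lanes
// at once.
float bf16_to_fp32(uint16_t bf16_bits) {
  uint32_t fp32_bits = static_cast<uint32_t>(bf16_bits) << 16;
  float out;
  std::memcpy(&out, &fp32_bits, sizeof(out));
  return out;
}
```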
12 changes: 6 additions & 6 deletions aten/src/ATen/cuda/CUDAApplyUtils.cuh
@@ -387,9 +387,9 @@ __host__ __device__ __forceinline__ T ATenCeilDiv(T a, T b) {
}

template <int step = 1>
inline bool getApplyGrid(uint64_t totalElements, dim3& grid, int64_t curDevice) {
inline bool getApplyGrid(uint64_t totalElements, dim3& grid, int64_t curDevice, int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
if (curDevice == -1) return false;
uint64_t numel_per_thread = static_cast<uint64_t>(AT_APPLY_THREADS_PER_BLOCK) * static_cast<uint64_t>(step);
uint64_t numel_per_thread = static_cast<uint64_t>(max_threads_per_block) * static_cast<uint64_t>(step);
uint64_t numBlocks = ATenCeilDiv(totalElements, numel_per_thread);
uint64_t maxGridX = at::cuda::getDeviceProperties(curDevice)->maxGridSize[0];
if (numBlocks > maxGridX)
@@ -406,8 +406,8 @@ constexpr int getApplyBlockSize() {
return AT_APPLY_THREADS_PER_BLOCK;
}

inline dim3 getApplyBlock() {
return dim3(AT_APPLY_THREADS_PER_BLOCK);
inline dim3 getApplyBlock(int max_threads_per_block=AT_APPLY_THREADS_PER_BLOCK) {
return dim3(max_threads_per_block);
}

template <typename scalar1, typename scalar2, int step, typename Op,
@@ -434,12 +434,12 @@ inline bool CUDA_tensor_apply2(at::Tensor a,
// Empty tensor; do nothing
return true;
}
const dim3 block = getApplyBlock();
const dim3 block = getApplyBlock(max_threads_per_block);

dim3 grid;
int64_t curDevice = current_device();
if (curDevice == -1) return false;
if (!getApplyGrid<step>(totalElements, grid, curDevice)) {
if (!getApplyGrid<step>(totalElements, grid, curDevice, max_threads_per_block)) {
return false;
}

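The new max_threads_per_block parameter feeds directly into the grid sizing
in getApplyGrid. A standalone sketch of that arithmetic (names simplified;
handling of the maxGridSize[0] limit is omitted, and this is not the CUDA
code itself):

```
#include <cstdint>

// Each block covers max_threads_per_block * step elements (previously the
// fixed AT_APPLY_THREADS_PER_BLOCK * step), and the grid's x dimension is
// the ceiling division of the element count by that amount.
uint64_t apply_grid_x(uint64_t total_elements, int max_threads_per_block, int step) {
  uint64_t numel_per_block =
      static_cast<uint64_t>(max_threads_per_block) * static_cast<uint64_t>(step);
  return (total_elements + numel_per_block - 1) / numel_per_block;  // ceil div
}
```

For example, with 130000 elements, 256 threads per block, and step = 2, this
yields ceil(130000 / 512) = 254 blocks.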
91 changes: 71 additions & 20 deletions aten/src/ATen/native/Normalization.cpp
@@ -193,14 +193,19 @@ std::tuple<Tensor,Tensor> batch_norm_cpu_update_stats_template(
}

// non-contiguous path
auto channel_stride = input.strides()[1];
auto in_data = input.data_ptr<scalar_t>();
auto reduce_iter = TensorIteratorConfig()
.add_input(input)
.resize_outputs(false)
.declare_static_shape(input.sizes(), /*squash_dims=*/1)
.build();

parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
TensorIterator iter(reduce_iter);
for (int64_t f = b_begin; f < b_end; ++f) {
Tensor in = input.select(1, f);

// compute variance per input
auto iter = TensorIteratorConfig()
.add_input(in)
.build();
iter.unsafe_replace_operand(0, in_data + channel_stride * f);
accscalar_t var_sum = 0;
auto mean = static_cast<accscalar_t>(save_mean_a[f]);
cpu_serial_kernel(iter, [&](const scalar_t i) -> void {
@@ -279,11 +284,47 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
auto sum = at::sum(grad_out_, /*dims=*/reduce_dims);
auto sum_a = sum.accessor<scalar_t, 1>();

auto reduce_iter = TensorIteratorConfig()
.add_input(input)
.add_input(grad_out_)
.resize_outputs(false)
.declare_static_shape(input.sizes(), /*squash_dims=*/1)
.build();

TensorIterator unary_iter;
TensorIterator binary_iter;
if (grad_input_mask[0]) {
unary_iter.build(
TensorIteratorConfig()
.add_output(grad_input)
.add_input(train ? input : grad_out_)
.resize_outputs(false)
.declare_static_shape(input.sizes(), /*squash_dims=*/1));

if (train) {
binary_iter.build(
TensorIteratorConfig()
.add_output(grad_input)
.add_input(grad_input)
.add_input(grad_out_)
.resize_outputs(false)
.declare_static_shape(input.sizes(), /*squash_dims=*/1));
}
}

auto in_channel_stride = input.strides()[1];
auto in_data = input.data_ptr<scalar_t>();
auto grad_in_channel_stride = grad_input_mask[0] ? grad_input.strides()[1] : 0;
auto grad_in_data = grad_input_mask[0] ? grad_input.data_ptr<scalar_t>() : nullptr;
auto grad_out_channel_stride = grad_out_.strides()[1];
auto grad_out_data = grad_out_.data_ptr<scalar_t>();

parallel_for(0, n_input, 1, [&](int64_t b_begin, int64_t b_end) {
for (int64_t f = b_begin; f < b_end; ++f) {
Tensor in = input.select(1, f);
Tensor grad_out = grad_out_.select(1, f);
TensorIterator reduce_iter_local(reduce_iter);
TensorIterator unary_iter_local(unary_iter);
TensorIterator binary_iter_local(binary_iter);

for (int64_t f = b_begin; f < b_end; ++f) {
scalar_t w = weight.defined() ? weight_a[f] : 1;

scalar_t mean, invstd;
@@ -297,16 +338,16 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(

// dot product of the Q(X) and gradOuput
accscalar_t dotp = 0;
auto iter = TensorIteratorConfig()
.add_input(in)
.add_input(grad_out)
.build();
cpu_serial_kernel(iter, [&](const scalar_t i, const scalar_t go) -> void {
reduce_iter_local.unsafe_replace_operand(
0, in_data + f * in_channel_stride);
reduce_iter_local.unsafe_replace_operand(
1, grad_out_data + f * grad_out_channel_stride);

cpu_serial_kernel(reduce_iter_local, [&](const scalar_t i, const scalar_t go) -> void {
dotp += (i - mean) * go;
});

if (grad_input_mask[0]) {
Tensor grad_in = grad_input.select(1, f);
if (train) {
// when in training mode
// Q(X) = X - E[x] ; i.e. input centered to zero mean
@@ -316,16 +357,23 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
// projection of gradOutput on to output scaled by std
scalar_t k = (scalar_t) dotp * invstd * invstd / n;
{
auto iter = TensorIterator::unary_op(grad_in, in);
cpu_serial_kernel(iter, [&](const scalar_t i) -> scalar_t {
unary_iter_local.unsafe_replace_operand(
0, grad_in_data + f * grad_in_channel_stride);
unary_iter_local.unsafe_replace_operand(
1, in_data + f * in_channel_stride);
cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
return (i - mean) * k;
});
}

scalar_t grad_mean = sum_a[f] / n;
{
auto iter = TensorIterator::borrowing_binary_op(grad_in, grad_in, grad_out);
cpu_serial_kernel(iter, [&](scalar_t gi, scalar_t go) -> scalar_t {
auto gI_data = grad_in_data + f * grad_in_channel_stride;
binary_iter_local.unsafe_replace_operand(0, gI_data);
binary_iter_local.unsafe_replace_operand(1, gI_data);
binary_iter_local.unsafe_replace_operand(
2, grad_out_data + f * grad_out_channel_stride);
cpu_serial_kernel(binary_iter_local, [&](scalar_t gi, scalar_t go) -> scalar_t {
return (go - grad_mean - gi) * invstd * w;
});
}
@@ -335,8 +383,11 @@ std::tuple<Tensor, Tensor, Tensor> batch_norm_backward_cpu_template(
// Y = Q(X) / running_std ; i.e. BN output before weight and bias
// dL/dX = w / running_std
{
auto iter = TensorIterator::unary_op(grad_in, grad_out);
cpu_serial_kernel(iter, [&](const scalar_t i) -> scalar_t {
unary_iter_local.unsafe_replace_operand(
0, grad_in_data + f * grad_in_channel_stride);
unary_iter_local.unsafe_replace_operand(
1, grad_out_data + f * grad_out_channel_stride);
cpu_serial_kernel(unary_iter_local, [&](const scalar_t i) -> scalar_t {
return i * invstd * w;
});
}
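The Normalization.cpp changes above all follow one pattern: build a
TensorIterator once outside the per-channel loop (declare_static_shape with
squash_dims so the channel dimension is factored out), then inside the loop
re-point its operands at each channel's data with unsafe_replace_operand
instead of rebuilding the iterator per channel. A standalone sketch of that
hoisting pattern (generic C++, not the ATen API; layout simplified to
contiguous per-channel blocks):

```
#include <cstddef>
#include <vector>

// Generic model of the refactor: construct the per-channel "iterator" once,
// then per channel only swap the base pointers it reads/writes (roughly what
// unsafe_replace_operand does on a copied TensorIterator).
struct ChannelIter {
  const float* in = nullptr;          // operand 0
  float* out = nullptr;               // operand 1
  std::size_t numel_per_channel = 0;
};

void scale_per_channel(const std::vector<float>& input,
                       std::vector<float>& output,
                       std::size_t channels,
                       float scale) {
  ChannelIter iter;                                  // built once, outside the loop
  iter.numel_per_channel = input.size() / channels;

  for (std::size_t c = 0; c < channels; ++c) {
    // Cheap per-channel re-pointing, analogous to unsafe_replace_operand.
    iter.in = input.data() + c * iter.numel_per_channel;
    iter.out = output.data() + c * iter.numel_per_channel;
    for (std::size_t i = 0; i < iter.numel_per_channel; ++i) {
      iter.out[i] = iter.in[i] * scale;
    }
  }
}
```

The payoff is that the iterator construction cost (shape/stride computation
and allocation) is paid once per parallel_for task instead of once per
channel.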
