add cuda fq bench on "fake_quant: add a more memory efficient backward"
Summary:

tl;dr: add an alternative implementation of `fake_quantize` that saves a mask of which input elements were clamped during the forward pass and reuses that mask to calculate the backward.  The math:

```
# before - forward (pseudocode)
def fq_forward(x, scale, zp, qmin, qmax):
    q_val = clamp(nearby_int(x / scale) + zp, qmin, qmax)
    fq_val = (q_val - zp) * scale
    return fq_val

# before - backward (pseudocode)
def fq_backward(dy, x, scale, zp, qmin, qmax):
    q_val_unclamped = nearby_int(x / scale) + zp
    mask = qmin <= q_val_unclamped and q_val_unclamped <= qmax
    return dy * mask

# after - forward (pseudocode)
def fq_forward(x, scale, zp, qmin, qmax):
    q_val_unclamped = nearby_int(x / scale) + zp
    mask = qmin <= q_val_unclamped and q_val_unclamped <= qmax
    q_val = clamp(q_val_unclamped, qmin, qmax)
    fq_val = (q_val - zp) * scale
    return fq_val, mask

# after - backward (pseudocode)
def fq_backward(dy, mask):
    return dy * mask
```

This way the backward function no longer needs the input Tensor, so autograd can free it earlier.
Instead of saving `x: FloatTensor` for the backward, we save a `mask: BoolTensor` with the same
number of elements.  A `BoolTensor` uses 1 byte per element versus 4 bytes for a `float32` element,
so we expect an upper bound of a 75% reduction in this memory overhead.  We observe a 73% overhead
reduction on torchvision's MobileNetV2 in real-world tests.  Packing the bools into a custom storage
format that takes 1 bit per element is an optimization left for the future.
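
For illustration, the same idea can be written as a standalone autograd function.  This is a minimal sketch, not the ATen kernel added by this PR, and the name `FakeQuantMaskFn` is hypothetical:

```
import torch

class FakeQuantMaskFn(torch.autograd.Function):
    """Per-tensor fake quantize that saves only a bool mask for the backward."""

    @staticmethod
    def forward(ctx, x, scale, zero_point, quant_min, quant_max):
        q_unclamped = torch.round(x / scale) + zero_point
        mask = (q_unclamped >= quant_min) & (q_unclamped <= quant_max)  # BoolTensor
        ctx.save_for_backward(mask)  # x itself is not saved, so autograd can free it
        q = torch.clamp(q_unclamped, quant_min, quant_max)
        return (q - zero_point) * scale

    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        # straight-through estimator: pass gradient only where x was not clamped
        return grad_output * mask, None, None, None, None

x = torch.randn(8, requires_grad=True)
y = FakeQuantMaskFn.apply(x, 0.1, 0, -128, 127)
y.sum().backward()  # x.grad is 1 where the quantized value stayed in [qmin, qmax], else 0
```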

The performance impact seems negligible: I observed a 1% to 5% latency regression on MobileNetV2, but it is unclear whether the regression is real.

This is added as a new function (as opposed to replacing the old implementation) to make testing easy, but
it might be worth deleting the old fake_quant backward in a future PR.  We can also adjust the signature
of this function to take `model.training` as an additional parameter and skip the mask computation for eval, as sketched below.
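
A minimal sketch of what that adjusted signature could look like (hypothetical helper, not code from this diff):

```
import torch

def fake_quantize_with_mask(x, scale, zero_point, quant_min, quant_max, training=True):
    q_unclamped = torch.round(x / scale) + zero_point
    q = torch.clamp(q_unclamped, quant_min, quant_max)
    # only materialize the BoolTensor mask when a backward pass will actually use it
    mask = (q_unclamped >= quant_min) & (q_unclamped <= quant_max) if training else None
    return (q - zero_point) * scale, mask
```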

Test Plan:

QAT on MobileNetV2 on FB infra, with `opt` build flags and batch_size = 32.  Results below are for the fbgemm settings; qnnpack results are similar.
```
# qat_fp32: model with fake_quants turned off (baseline)
# qat_1: step 2 of qat, with observers disabled and fake_quants enabled (all of the overhead is the fake_quants)

# before: fbgemm - qat_fp32 -> qat_1
max memory usage (MiB): 3299 -> 4170 (overhead: 26.4%)
latency (ms):  147 -> 181

# after: fbgemm - qat_fp32 -> qat_1
max memory usage (MiB): 3302 -> 3528 (overhead: 7.1%)
latency (ms):  147 -> 183
```

Note: similar metrics are observed in an OSS / torchvision / MobileNetV2 setup, with this command:
```
python references/classification/train_quantization.py \
  --print-freq 1 \
  --data-path /data/local/packages/ai-group.imagenet-256-smallest-side/prod/ \
  --output-dir ~/nfs/pytorch_vision_tests/ \
  --backend qnnpack \
  --epochs 5
```

All CI tests here: #50849

PyTorch microbenchmarks show CUDA performance is about the same: https://gist.github.com/vkuzo/11a7bed73fe60e340862d37e7975e9cd

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D25918519](https://our.internmc.facebook.com/intern/diff/D25918519)

[ghstack-poisoned]
vkuzo committed Jan 26, 2021
2 parents dab2d44 + f7b339d commit 671ae05
Showing 340 changed files with 8,057 additions and 3,247 deletions.
2 changes: 2 additions & 0 deletions .circleci/docker/common/install_conda.sh
@@ -92,6 +92,8 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
conda_install magma-cuda110 -c pytorch
elif [[ "$CUDA_VERSION" == 11.1* ]]; then
conda_install magma-cuda111 -c pytorch
elif [[ "$CUDA_VERSION" == 11.2* ]]; then
conda_install magma-cuda112 -c pytorch
fi

# TODO: This isn't working atm
2 changes: 1 addition & 1 deletion .gitignore
@@ -38,6 +38,7 @@ docs/cpp/source/html/
docs/cpp/source/latex/
docs/source/generated/
log
test-reports/
test/.coverage
test/.hypothesis/
test/cpp/api/mnist
@@ -50,7 +51,6 @@ dropout_model.pt
test/generated_type_hints_smoketest.py
test/htmlcov
test/cpp_extensions/install/
test/test-reports/
third_party/build/
tools/shared/_utils_internal.py
tools/fast_nvcc/wrap_nvcc.sh
2 changes: 2 additions & 0 deletions BUILD.bazel
@@ -134,6 +134,8 @@ genrule(
"aten/src/ATen/RegisterMeta.cpp",
"aten/src/ATen/RegisterDefaultBackend.cpp",
"aten/src/ATen/RegisterSchema.cpp",
"aten/src/ATen/CPUFunctions.h",
"aten/src/ATen/CUDAFunctions.h",
"aten/src/ATen/Functions.h",
"aten/src/ATen/Functions.cpp",
"aten/src/ATen/NativeFunctions.h",
23 changes: 3 additions & 20 deletions CMakeLists.txt
@@ -171,7 +171,8 @@ cmake_dependent_option(
USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
"USE_CUDNN" OFF)
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_KINETO "Use Kineto profiling library" OFF)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" OFF)
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
option(USE_FFMPEG "Use ffmpeg" OFF)
option(USE_GFLAGS "Use GFLAGS" OFF)
@@ -248,6 +249,7 @@ cmake_dependent_option(
option(USE_TBB "Use TBB" OFF)
option(ONNX_ML "Enable traditional ONNX ML API." ON)
option(HAVE_SOVERSION "Whether to add SOVERSION to the shared objects" OFF)
option(USE_DEPLOY "Enable torch::deploy embedded python interpreter" OFF)

# Since TensorPipe does not support Windows, set it to OFF when WIN32 detected
# On Windows platform, if user does not install libuv in build conda env and
@@ -545,31 +547,12 @@ if(USE_FBGEMM AND ((CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND CMAKE_SIZEOF_VO
set(USE_FBGEMM OFF)
endif()

if(USE_KINETO AND INTERN_BUILD_MOBILE)
message(STATUS "Not using libkineto in a mobile build.")
set(USE_KINETO OFF)
endif()

if(USE_KINETO AND (NOT USE_CUDA))
message(STATUS "Not using libkineto in a non-CUDA build.")
set(USE_KINETO OFF)
endif()

if(USE_KINETO AND MSVC)
message(STATUS "Not using libkineto in a Windows build.")
set(USE_KINETO OFF)
endif()

include(cmake/Dependencies.cmake)

if(USE_FBGEMM)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_FBGEMM")
endif()

if(USE_KINETO)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_KINETO")
endif()

if(USE_QNNPACK)
string(APPEND CMAKE_CXX_FLAGS " -DUSE_QNNPACK")
endif()
@@ -4,8 +4,6 @@

#include "jni.h"

#define clamp0255(x) x > 255 ? 255 : x < 0 ? 0 : x

namespace pytorch_vision_jni {

static void imageYUV420CenterCropToFloatBuffer(
@@ -65,7 +63,7 @@ static void imageYUV420CenterCropToFloatBuffer(
const uint8_t* vData = (uint8_t*)jniEnv->GetDirectBufferAddress(vBuffer);

float scale = cropWidthAfterRtn / tensorWidth;
int uvRowStride = uRowStride >> 1;
int uvRowStride = uRowStride;
int cropXMult = 1;
int cropYMult = 1;
int cropXAdd = offsetX;
@@ -91,7 +89,7 @@ static void imageYUV420CenterCropToFloatBuffer(
float normStdBm255 = 255 * normStdRGB[2];

int xBeforeRtn, yBeforeRtn;
int yIdx, uvIdx, ui, vi, a0, ri, gi, bi;
int yi, yIdx, uvIdx, ui, vi, a0, ri, gi, bi;
int channelSize = tensorWidth * tensorHeight;
int wr = outOffset;
int wg = wr + channelSize;
@@ -101,16 +99,23 @@ static void imageYUV420CenterCropToFloatBuffer(
xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
uvIdx = (yBeforeRtn >> 1) * uvRowStride + xBeforeRtn * uvPixelStride;
uvIdx = (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
ui = uData[uvIdx];
vi = vData[uvIdx];
a0 = 1192 * (yData[yIdx] - 16);
ri = (a0 + 1634 * (vi - 128)) >> 10;
gi = (a0 - 832 * (vi - 128) - 400 * (ui - 128)) >> 10;
bi = (a0 + 2066 * (ui - 128)) >> 10;
outData[wr++] = (clamp0255(ri) - normMeanRm255) / normStdRm255;
outData[wg++] = (clamp0255(gi) - normMeanGm255) / normStdGm255;
outData[wb++] = (clamp0255(bi) - normMeanBm255) / normStdBm255;
yi = yData[yIdx];
yi = (yi - 16) < 0 ? 0 : (yi - 16);
ui -= 128;
vi -= 128;
a0 = 1192 * yi;
ri = (a0 + 1634 * vi) >> 10;
gi = (a0 - 833 * vi - 400 * ui) >> 10;
bi = (a0 + 2066 * ui) >> 10;
ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
outData[wr++] = (ri - normMeanRm255) / normStdRm255;
outData[wg++] = (gi - normMeanGm255) / normStdGm255;
outData[wb++] = (bi - normMeanBm255) / normStdBm255;
}
}
}
1 change: 1 addition & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -325,6 +325,7 @@ if(USE_CUDA AND NOT USE_ROCM)
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcufft_static_nocallback.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcusolver_static.a
${CUDA_TOOLKIT_ROOT_DIR}/lib64/liblapack_static.a # needed for libcusolver_static
)
else()
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
3 changes: 2 additions & 1 deletion aten/src/ATen/CPUGeneratorImpl.cpp
@@ -2,6 +2,7 @@
#include <ATen/Utils.h>
#include <ATen/core/MT19937RNGEngine.h>
#include <c10/util/C++17.h>
#include <c10/util/MathConstants.h>
#include <algorithm>

namespace at {
@@ -153,7 +154,7 @@ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
// intermediate values.
if (legacy_pod->normal_is_valid) {
auto r = legacy_pod->normal_rho;
auto theta = 2.0 * M_PI * legacy_pod->normal_x;
auto theta = 2.0 * c10::pi<double> * legacy_pod->normal_x;
// we return the sin version of the normal sample when in caching mode
double_normal_sample = c10::optional<double>(r * ::sin(theta));
}
8 changes: 4 additions & 4 deletions aten/src/ATen/CUDAGeneratorImpl.h
@@ -119,7 +119,7 @@ struct PhiloxCudaState {
bool captured_ = false;
};

struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl {
struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
// Constructors
CUDAGeneratorImpl(DeviceIndex device_index = -1);
~CUDAGeneratorImpl() = default;
@@ -155,10 +155,10 @@ struct TORCH_CUDA_API CUDAGeneratorImpl : public c10::GeneratorImpl {
namespace cuda {
namespace detail {

TORCH_CUDA_API const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1);
TORCH_CUDA_API Generator createCUDAGenerator(DeviceIndex device_index = -1);
TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(
DeviceIndex device_index = -1);
TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);

} // namespace detail
} // namespace cuda
} // namespace at

22 changes: 11 additions & 11 deletions aten/src/ATen/Context.cpp
@@ -60,25 +60,25 @@ void Context::setDeterministicCuDNN(bool b) {
deterministic_cudnn = b;
}

bool Context::deterministic() const {
return _deterministic;
bool Context::deterministicAlgorithms() const {
return _deterministic_algorithms;
}

void Context::setDeterministic(bool b) {
void Context::setDeterministicAlgorithms(bool b) {
if (b) {
TORCH_WARN_ONCE("torch.set_deterministic is in beta, and its design and "
TORCH_WARN_ONCE("torch.use_deterministic_algorithms is in beta, and its design and"
" functionality may change in the future.");
}

_deterministic = b;
_deterministic_algorithms = b;
}

void Context::alertNotDeterministic(c10::string_view const& caller) {
if (globalContext().deterministic()) {
if (globalContext().deterministicAlgorithms()) {
TORCH_CHECK(false,
caller, " does not have a deterministic implementation, but you set "
"'torch.set_deterministic(True)'. You can turn off determinism just "
"for this operation if that's acceptable for your application. You "
"'torch.use_deterministic_algorithms(True)'. You can turn off determinism ",
"just for this operation if that's acceptable for your application. You "
"can also file an issue at https://github.com/pytorch/pytorch/issues "
"to help us prioritize adding deterministic support for this operation.");
}
@@ -111,9 +111,9 @@ bool Context::checkCuBLASConfigDeterministic() {

void Context::alertCuBLASConfigNotDeterministic() {
static bool cublas_config_deterministic = checkCuBLASConfigDeterministic();
TORCH_CHECK(!deterministic() || cublas_config_deterministic,
"Deterministic behavior was enabled with either `torch.set_deterministic(True)` or ",
"`at::Context::setDeterministic(true)`, but this operation is not deterministic because ",
TORCH_CHECK(!deterministicAlgorithms() || cublas_config_deterministic,
"Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or ",
"`at::Context::setDeterministicAlgorithms(true)`, but this operation is not deterministic because ",
"it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this ",
"case, you must set an environment variable before running your PyTorch application: ",
cublas_config_var_name, "=", cublas_deterministic_configs[0], " or ",
34 changes: 18 additions & 16 deletions aten/src/ATen/Context.h
@@ -120,27 +120,27 @@ class TORCH_API Context {
//
// * Include this comment: "See Note [Enabling Deterministic Operations]"
//
// * Check the value of `at::globalContext().deterministic()` to toggle between
// nondeterministic and deterministic implementations.
// * Check the value of `at::globalContext().deterministicAlgorithms()` to toggle
// between nondeterministic and deterministic implementations.
//
// * Have an entry in the list of PyTorch operations that toggle between nondeterministic
// and deterministic implementations, in the docstring of `set_deterministic()`
// and deterministic implementations, in the docstring of `use_deterministic_algorithms()`
// in torch/__init__.py
//
// `example_func()` below shows an example of toggling between nondeterministic and
// deterministic implementations:
//
// void example_func() {
// // See Note [Enabling Deterministic Operations]
// if (at::globalContext().deterministic()) {
// if (at::globalContext().deterministicAlgorithms()) {
// example_func_deterministic();
// } else {
// example_func_nondeterministic();
// }
// }

bool deterministic() const;
void setDeterministic(bool);
bool deterministicAlgorithms() const;
void setDeterministicAlgorithms(bool);

// Note [Writing Nondeterministic Operations]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -151,16 +151,18 @@ class TORCH_API Context {
//
// * Include a comment explaining why the operation is nondeterministic.
//
// * Throw an error when `Context::deterministic()` is true. Most of the time, this
// should be accomplished by calling `at::globalContext().alertNotDeterminstic()`.
// However, if the nondeterministic behavior is caused by the CuBLAS workspace
// * Throw an error when `Context::deterministicAlgorithms()` is true. Most
// of the time, this should be accomplished by calling
// `at::globalContext().alertNotDeterminstic()`. However, if the
// nondeterministic behavior is caused by the CuBLAS workspace
// configuration in CUDA >= 10.2,
// `at::globalContext().alertCuBLASConfigNotDeterministic()` should
// be called instead (in this case, a comment explaining why the operation is
// nondeterministic is not necessary). See below for details on these methods.
// `at::globalContext().alertCuBLASConfigNotDeterministic()` should be
// called instead (in this case, a comment explaining why the operation is
// nondeterministic is not necessary). See below for details on these
// methods.
//
// * Have an entry in the list of nondeterministic PyTorch operations in the
// docstring of `set_deterministic()` in torch/__init__.py
// docstring of `use_deterministic_algorithms()` in torch/__init__.py
//
// `example_func()` below shows an example of the comments and error-throwing code
// for a nondeterministic operation:
@@ -172,10 +174,10 @@ class TORCH_API Context {
// ...
// }

// Throws an error if `Context::deterministic()` is true
// Throws an error if `Context::deterministicAlgorithms()` is true
void alertNotDeterministic(c10::string_view const& caller);

// Throws an error if `Context::deterministic()` is true, CUDA >= 10.2, and
// Throws an error if `Context::deterministicAlgorithms()` is true, CUDA >= 10.2, and
// CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or ":4096:8". For more details:
// https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
void alertCuBLASConfigNotDeterministic();
@@ -210,7 +212,7 @@ class TORCH_API Context {
std::once_flag thh_init;
bool enabled_cudnn = true;
bool deterministic_cudnn = false;
bool _deterministic = false;
bool _deterministic_algorithms = false;
bool benchmark_cudnn = false;
bool allow_tf32_cudnn = true;
bool allow_tf32_cublas = true;