Skip to content

Commit

Permalink
Update on "[quant] Add quantized::leaky_relu that takes scale/zero_point as input"
Browse files Browse the repository at this point in the history

Summary:
#45593

Previously, quantized leaky_relu did not require observation and just inherited
the quantization parameters from its input, but that does not work very well in QAT.
This PR adds a quantized::leaky_relu that has observation for the output, and it will
become the default leaky_relu that our quantization tools produce (eager/graph mode).

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D24067681](https://our.internmc.facebook.com/intern/diff/D24067681)

[ghstack-poisoned]
  • Loading branch information
jerryzh168 committed Oct 6, 2020
2 parents c37f812 + 162717e commit 762545a
Show file tree
Hide file tree
Showing 146 changed files with 3,706 additions and 1,078 deletions.
6 changes: 3 additions & 3 deletions .circleci/docker/common/install_cache.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fi
chmod a+x /opt/cache/bin/sccache

function write_sccache_stub() {
printf "#!/bin/sh\nexec sccache $(which $1) \$*" > "/opt/cache/bin/$1"
printf "#!/bin/sh\nexec sccache $(which $1) \"\$@\"" > "/opt/cache/bin/$1"
chmod a+x "/opt/cache/bin/$1"
}

Expand Down Expand Up @@ -57,8 +57,8 @@ if [ -n "$ROCM_VERSION" ]; then
TOPDIR=$(dirname $OLDCOMP)
WRAPPED="$TOPDIR/original/$COMPNAME"
mv "$OLDCOMP" "$WRAPPED"
printf "#!/bin/sh\nexec sccache $WRAPPED \$*" > "$OLDCOMP"
chmod a+x "$1"
printf "#!/bin/sh\nexec sccache $WRAPPED \"\$@\"" > "$OLDCOMP"
chmod a+x "$OLDCOMP"
}

if [[ -e "/opt/rocm/hcc/bin/hcc" ]]; then
Expand Down
2 changes: 1 addition & 1 deletion .circleci/scripts/binary_ios_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ source ~/anaconda/bin/activate

# Install dependencies
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing requests --yes
conda install -c conda-forge valgrind
conda install -c conda-forge valgrind --yes
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}

# sync submodules
Expand Down
4 changes: 2 additions & 2 deletions .circleci/scripts/binary_populate_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ PIP_UPLOAD_FOLDER='nightly/'
# We put this here so that OVERRIDE_PACKAGE_VERSION below can read from it
export DATE="$(date -u +%Y%m%d)"
#TODO: We should be pulling semver version from the base version.txt
BASE_BUILD_VERSION="1.7.0.dev$DATE"
BASE_BUILD_VERSION="1.8.0.dev$DATE"
# Change BASE_BUILD_VERSION to git tag when on a git tag
# Use 'git -C' to make doubly sure we're in the correct directory for checking
# the git tag
Expand Down Expand Up @@ -130,7 +130,7 @@ if [[ "${BUILD_FOR_SYSTEM:-}" == "windows" ]]; then
fi
export DATE="$DATE"
export NIGHTLIES_DATE_PREAMBLE=1.7.0.dev
export NIGHTLIES_DATE_PREAMBLE=1.8.0.dev
export PYTORCH_BUILD_VERSION="$PYTORCH_BUILD_VERSION"
export PYTORCH_BUILD_NUMBER="$PYTORCH_BUILD_NUMBER"
export OVERRIDE_PACKAGE_VERSION="$PYTORCH_BUILD_VERSION"
Expand Down
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,7 @@ torch_cuda_half_options = [
"-DCUDA_HAS_FP16=1",
"-D__CUDA_NO_HALF_OPERATORS__",
"-D__CUDA_NO_HALF_CONVERSIONS__",
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-D__CUDA_NO_HALF2_OPERATORS__",
]

Expand Down
6 changes: 3 additions & 3 deletions CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
# Distributed package
# This list is mostly if you'd like to be tagged as reviewer, feel free to add
# or remove yourself from it.
/torch/lib/c10d/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma
/torch/csrc/distributed/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma
/torch/distributed/ @apaszke @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma
/torch/lib/c10d/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088
/torch/csrc/distributed/ @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088
/torch/distributed/ @apaszke @pietern @mrshenli @zhaojuanmao @pritamdamania87 @rohan-varma @mingzhe09088

# Distributed tests
# This list is mostly if you'd like to be tagged as reviewer, feel free to add
Expand Down
10 changes: 5 additions & 5 deletions android/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ repositories {
}
dependencies {
implementation 'org.pytorch:pytorch_android:1.5.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.5.0'
implementation 'org.pytorch:pytorch_android:1.6.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.6.0'
}
```

Expand All @@ -34,12 +34,12 @@ repositories {
dependencies {
...
implementation 'org.pytorch:pytorch_android:1.7.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision:1.7.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'
...
}
```
The current nightly(snapshots) version is the value of `VERSION_NAME` in `gradle.properties` in current folder, at this moment it is `1.7.0-SNAPSHOT`.
The current nightly(snapshots) version is the value of `VERSION_NAME` in `gradle.properties` in current folder, at this moment it is `1.8.0-SNAPSHOT`.

## Building PyTorch Android from Source

Expand Down
2 changes: 1 addition & 1 deletion android/gradle.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
ABI_FILTERS=armeabi-v7a,arm64-v8a,x86,x86_64

VERSION_NAME=1.7.0-SNAPSHOT
VERSION_NAME=1.8.0-SNAPSHOT
GROUP=org.pytorch
MAVEN_GROUP=org.pytorch
POM_URL=https://github.com/pytorch/pytorch/tree/master/android
Expand Down
6 changes: 3 additions & 3 deletions android/test_app/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ android {

tasks.all { task ->
// Disable externalNativeBuild for all but nativeBuild variant
if (task.name.startsWith('externalNativeBuild')
if (task.name.startsWith('externalNativeBuild')
&& !task.name.contains('NativeBuild')) {
task.enabled = false
}
Expand All @@ -149,8 +149,8 @@ dependencies {
//nativeBuildImplementation(name: 'pytorch_android_torchvision-release', ext: 'aar')
//extractForNativeBuild(name: 'pytorch_android-release', ext: 'aar')

nightlyImplementation 'org.pytorch:pytorch_android:1.7.0-SNAPSHOT'
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.7.0-SNAPSHOT'
nightlyImplementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'
nightlyImplementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'

aarImplementation(name:'pytorch_android', ext:'aar')
aarImplementation(name:'pytorch_android_torchvision', ext:'aar')
Expand Down
7 changes: 2 additions & 5 deletions aten/src/ATen/DLConvertor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,10 @@ DLDataType getDLDataType(const Tensor& t) {
throw std::logic_error("BFloat16 is not supported by dlpack");
break;
case ScalarType::QInt8:
throw std::logic_error("QInt8 is not supported by dlpack");
break;
case ScalarType::QUInt8:
throw std::logic_error("QUInt8 is not supported by dlpack");
break;
case ScalarType::QInt32:
throw std::logic_error("QInt32 is not supported by dlpack");
case ScalarType::QUInt4x2:
throw std::logic_error("QUInt/QInt types are not supported by dlpack");
break;
case ScalarType::ComplexHalf:
throw std::logic_error("ComplexHalf is not supported by dlpack");
Expand Down
34 changes: 34 additions & 0 deletions aten/src/ATen/Dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@
return __VA_ARGS__(); \
}

#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
enum_type, type, underlying_type, bitwidth, qmin, qmax, ...) \
case enum_type: { \
using scalar_t = type; \
using underlying_t C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
scalar_t::underlying; \
const auto& SCALAR_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = enum_type; \
const auto& UNDERLYING_TYPE C10_UNUSED_DISPATCH_CUDA_WORKAROUND = \
toUnderlying(enum_type); \
int bit_width = bitwidth; \
int64_t quant_min = qmin; \
int64_t quant_max = qmax; \
return __VA_ARGS__(); \
}

// This macro should be used to skip bfloat16 dispatch on non-ROCm platforms and
// should be removed once the bfloat16 bringup is complete on other platforms.
// This is supposed to be used as a wrapper around the lambda function passed to
Expand Down Expand Up @@ -346,6 +361,25 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
} \
}()

#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
at::kQInt8, at::qint8, int8_t, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
at::kQUInt8, at::quint8, uint8_t, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
at::kQInt32, at::qint32, int, CHAR_BIT * sizeof(int), INT_MIN, INT_MAX, __VA_ARGS__) \
AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \
at::kQUInt4x2, at::quint4x2, uint8_t, 4, 0, 15, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
} \
}()

#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
Expand Down

0 comments on commit 762545a

Please sign in to comment.