Skip to content

Commit

Permalink
Update on "Fix typing errors in torch.distributed.distributed_c10d.*"
Browse files Browse the repository at this point in the history
Differential Revision: [D24952501](https://our.internmc.facebook.com/intern/diff/D24952501)

[ghstack-poisoned]
  • Loading branch information
xuzhao9 committed Nov 16, 2020
2 parents 7c3bcd7 + 2981ef2 commit 23127ef
Show file tree
Hide file tree
Showing 45 changed files with 881 additions and 588 deletions.
2 changes: 1 addition & 1 deletion .circleci/cimodel/data/pytorch_build_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
]),
]),
]),
("11.0", [
("11.1", [
("3.8", [
X(True),
("libtorch", [
Expand Down
24 changes: 12 additions & 12 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7283,37 +7283,37 @@ workflows:
build_environment: "pytorch-libtorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.2-cudnn7-py3-gcc7"
- pytorch_linux_build:
name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
name: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build
requires:
- "docker-pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
- "docker-pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7"
filters:
branches:
only:
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
build_environment: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7"
- pytorch_linux_test:
name: pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_test
name: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test
requires:
- pytorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
- pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build
filters:
branches:
only:
- master
- /ci-all\/.*/
- /release\/.*/
build_environment: "pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-test"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
build_environment: "pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-test"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7"
use_cuda_docker_runtime: "1"
resource_class: gpu.medium
- pytorch_linux_build:
name: pytorch_libtorch_linux_xenial_cuda11_0_cudnn8_py3_gcc7_build
name: pytorch_libtorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build
requires:
- "docker-pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
build_environment: "pytorch-libtorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7"
- "docker-pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7"
build_environment: "pytorch-libtorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-build"
docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7"
- pytorch_linux_build:
name: pytorch_linux_bionic_py3_6_clang9_build
requires:
Expand Down
5 changes: 5 additions & 0 deletions .circleci/scripts/binary_checkout.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ else
export BUILDER_ROOT="$workdir/builder"
fi

# Try to extract PR number from branch if not already set
if [[ -z "${CIRCLE_PR_NUMBER:-}" ]]; then
CIRCLE_PR_NUMBER="$(echo ${CIRCLE_BRANCH} | sed -E -n 's/pull\/([0-9]*).*/\1/p')"
fi

# Clone the Pytorch branch
retry git clone https://github.com/pytorch/pytorch.git "$PYTORCH_ROOT"
pushd "$PYTORCH_ROOT"
Expand Down
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -455,8 +455,6 @@ if(INTERN_BUILD_MOBILE AND NOT BUILD_CAFFE2_MOBILE)
endif()

# ---[ Utils
# TODO: merge the following 3 files into cmake/public/utils.cmake.
include(cmake/Utils.cmake)
include(cmake/public/utils.cmake)

# ---[ Version numbers for generated libraries
Expand Down
48 changes: 28 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ If you want to compile with CUDA support, install
- [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) 9.2 or above
- [NVIDIA cuDNN](https://developer.nvidia.com/cudnn) v7 or above
- [Compiler](https://gist.github.com/ax3l/9489132) compatible with CUDA
Note: You could refer to the [cuDNN Support Matrix](https://docs.nvidia.com/deeplearning/cudnn/pdf/cuDNN-Support-Matrix.pdf) for cuDNN versions with the various supported CUDA, CUDA driver and NVIDIA hardwares

If you want to disable CUDA support, export environment variable `USE_CUDA=0`.
Other potentially useful environment variables may be found in `setup.py`.
Expand Down Expand Up @@ -233,44 +234,51 @@ Each CUDA version only supports one particular XCode version. The following comb

On Windows

At least Visual Studio 2017 version 15.6 with the toolset 14.13 and [NVTX](https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvidia_tools_extension_library_nvtx.htm) are needed.
Build with CPU

If the version of Visual Studio 2017 is higher than 15.6, installing of "VC++ 2017 version 15.6 v14.13 toolset" is strongly recommended.
<br/> If the version of Visual Studio 2017 is lesser than 15.6, please update Visual Studio 2017 to the latest version along with installing "VC++ 2017 version 15.6 v14.13 toolset".
<br/> There is no guarantee of the correct building with VC++ 2017 toolsets, others than version 15.6 v14.13.
<br/> "VC++ 2017 version 15.6 v14.13 toolset" might be installed onto already installed Visual Studio 2017 by running its installation once again and checking the corresponding checkbox under "Individual components"/"Compilers, build tools, and runtimes".
It's fairly easy to build with CPU. Visual Studio 2019 version 16.7.6 (MSVC toolchain version 14.27) or higher is recommended.

Build with CUDA
[NVTX](https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvidia_tools_extension_library_nvtx.htm) is needed to build Pytorch with CUDA.
NVTX is a part of CUDA distributive, where it is called "Nsight Compute". To install it onto already installed CUDA run CUDA installation once again and check the corresponding checkbox.
Be sure that CUDA with Nsight Compute is installed after Visual Studio 2017.
Make sure that CUDA with Nsight Compute is installed after Visual Studio.

Currently, VS 2017, VS 2019, and Ninja are supported as the generator of CMake. If `ninja.exe` is detected in `PATH`, then Ninja will be used as the default generator, otherwise, it will use VS 2017.
<br/> If Ninja is selected as the generator, the latest MSVC which is newer than VS 2015 (14.0) will get selected as the underlying toolchain. If you use CMake <= 3.14.2 and has VS 2019 installed, then even if you specify VS 2017 as the generator, VS 2019 will get selected as the generator.
Currently, VS 2017 / 2019, and Ninja are supported as the generator of CMake. If `ninja.exe` is detected in `PATH`, then Ninja will be used as the default generator, otherwise, it will use VS 2017 / 2019.
<br/> If Ninja is selected as the generator, the latest MSVC will get selected as the underlying toolchain.

CUDA and MSVC have strong version dependencies, so even if you use VS 2017 / 2019, you will get build errors like `nvcc fatal : Host compiler targets unsupported OS`. For this kind of problem, please install the corresponding VS toolchain in the table below, and then you can either specify the toolset during activation (recommended) or set `CUDAHOSTCXX` to override the Cuda host compiler (not recommended if there are big version differences).
CUDA, MSVC, and PyTorch versions are interdependent; please install matching versions from this table:
| CUDA version | Newest supported VS version | PyTorch version |
| ------------ | ------------------------------------------------------- | --------------- |
| 9.2 | Visual Studio 2017 Update 5 (15.5) (`_MSC_VER` <= 1912) | 0.4.1 ~ 1.5.1 |
| 10.1 | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930) | 1.3.0 ~ 1.7.0 |
| 10.2 | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930) | 1.5.0 ~ 1.7.0 |
| 11.0 | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930) | 1.7.0 |

Note: There's a [compilation issue](https://github.com/oneapi-src/oneDNN/issues/812) in serveral Visual Studio 2019 versions since 16.7.1, so please make sure your Visual Studio 2019 version is not in 16.7.1 ~ 16.7.5

Additional libraries such as
[Magma](https://developer.nvidia.com/magma), [oneDNN, a.k.a MKLDNN or DNNL](https://github.com/oneapi-src/oneDNN), and [Sccache](https://github.com/mozilla/sccache) are often needed. Please refer to the [installation-helper](https://github.com/pytorch/pytorch/tree/master/.jenkins/pytorch/win-test-helpers/installation-helpers) to install them.

You can refer to the [build_pytorch.bat](https://github.com/pytorch/pytorch/blob/master/.jenkins/pytorch/win-test-helpers/build_pytorch.bat) script for some other environment variables configurations

| CUDA version | Newest supported VS version |
| ------------ | ------------------------------------------------------- |
| 9.2 | Visual Studio 2017 Update 5 (15.5) (`_MSC_VER` <= 1912) |
| 10.0 | Visual Studio 2017 (15.X) (`_MSC_VER` < 1920) |
| 10.1 | Visual Studio 2019 (16.X) (`_MSC_VER` < 1930) |

```cmd
cmd
:: [Optional] If you want to build with VS 2019 generator, please change the value in the next line to `Visual Studio 16 2019`.
:: [Optional] If you want to build with the VS 2017 generator for old CUDA and PyTorch, please change the value in the next line to `Visual Studio 15 2017`.
:: Note: This value is useless if Ninja is detected. However, you can force that by using `set USE_NINJA=OFF`.
set CMAKE_GENERATOR=Visual Studio 15 2017
set CMAKE_GENERATOR=Visual Studio 16 2019
:: Read the content in the previous section carefully before you proceed.
:: [Optional] If you want to override the underlying toolset used by Ninja and Visual Studio with CUDA, please run the following script block.
:: "Visual Studio 2017 Developer Command Prompt" will be run automatically.
:: "Visual Studio 2019 Developer Command Prompt" will be run automatically.
:: Make sure you have CMake >= 3.12 before you do this when you use the Visual Studio generator.
set CMAKE_GENERATOR_TOOLSET_VERSION=14.11
set CMAKE_GENERATOR_TOOLSET_VERSION=14.27
set DISTUTILS_USE_SDK=1
for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,16^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%CMAKE_GENERATOR_TOOLSET_VERSION%
:: [Optional] If you want to override the Cuda host compiler
set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\VC\Tools\MSVC\14.11.25503\bin\HostX64\x64\cl.exe
:: [Optional] If you want to override the CUDA host compiler
set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.27.29110\bin\HostX64\x64\cl.exe
python setup.py install
Expand Down
25 changes: 21 additions & 4 deletions aten/src/ATen/native/Copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <ATen/metal/Context.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/Parallel.h>
#include <torch/library.h>

#ifdef USE_FBGEMM
Expand Down Expand Up @@ -111,14 +112,30 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
((self.is_contiguous() && src.is_contiguous()) ||
(self.is_non_overlapping_and_dense() && self.strides() == src.strides()))) {
if (src.dtype() == at::kFloat && self.dtype() == at::kHalf) {
auto* output_ptr = reinterpret_cast<fbgemm::float16*>(
self.data_ptr<at::Half>());
fbgemm::FloatToFloat16_simd(src.data_ptr<float>(), output_ptr, self.numel());
auto* output_ptr =
reinterpret_cast<fbgemm::float16*>(self.data_ptr<at::Half>());
at::parallel_for(
0,
self.numel(),
at::internal::GRAIN_SIZE,
[&](int64_t begin, int64_t end) {
fbgemm::FloatToFloat16_simd(
src.data_ptr<float>() + begin,
output_ptr + begin,
end - begin);
});
} else {
auto in_data = reinterpret_cast<fbgemm::float16*>(
src.data_ptr<at::Half>());
auto* output_ptr = self.data_ptr<float>();
fbgemm::Float16ToFloat_simd(in_data, output_ptr, self.numel());
at::parallel_for(
0,
self.numel(),
at::internal::GRAIN_SIZE,
[&](int64_t begin, int64_t end) {
fbgemm::Float16ToFloat_simd(
in_data + begin, output_ptr + begin, end - begin);
});
}
return self;
}
Expand Down
9 changes: 6 additions & 3 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,15 +497,18 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor&

if (contraction_size * res_rows * res_cols < 400) {
if (is_bmm_out) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(batch1.scalar_type(), "bmm", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "bmm", [&] {
baddbmm_cpu_kernel<scalar_t, true>(self_or_result, batch1, batch2, beta, alpha);
});
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX(batch1.scalar_type(), "baddbmm", [&] {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, batch1.scalar_type(), "baddbmm", [&] {
baddbmm_cpu_kernel<scalar_t, false>(self_or_result, batch1, batch2, beta, alpha);
});
}
} else if (at::hasMKL() && (at::native::is_floating_point(self_or_result) ||
} else if (at::hasMKL() && ((
self_or_result.scalar_type() != kHalf &&
self_or_result.scalar_type() != kBFloat16 &&
at::native::is_floating_point(self_or_result)) ||
at::native::is_complex(self_or_result))
&& batch_items_contiguous_or_transposed(batch1)
&& batch_items_contiguous_or_transposed(batch2)
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace at { namespace native {

void addcmul_cuda_kernel(TensorIterator& iter, Scalar value) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcmul_cuda", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.dtype(), "addcmul_cuda", [&]() {
auto alpha = value.to<scalar_t>();
gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
return a + alpha * b * c;
Expand All @@ -18,7 +18,7 @@ void addcmul_cuda_kernel(TensorIterator& iter, Scalar value) {
}

void addcdiv_cuda_kernel(TensorIterator& iter, Scalar value) {
AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "addcdiv_cuda", [&]() {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBFloat16, iter.dtype(), "addcdiv_cuda", [&]() {
auto alpha = value.to<scalar_t>();
gpu_kernel(iter, [alpha]GPU_LAMBDA(scalar_t a, scalar_t b, scalar_t c) -> scalar_t {
return a + alpha * (b / c);
Expand Down

0 comments on commit 23127ef

Please sign in to comment.