diff --git a/.gitmodules b/.gitmodules index cd2e81e07a55..4239b5cda2b9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,7 @@ [submodule "third_party/pybind11"] ignore = dirty path = third_party/pybind11 - url = https://github.com/seemethere/pybind11.git + url = https://github.com/pybind/pybind11.git [submodule "third_party/cub"] ignore = dirty path = third_party/cub diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 6fedef185b21..2151b26634ac 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -325,6 +325,7 @@ if(USE_CUDA AND NOT USE_ROCM) ${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcublas_static.a ${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcufft_static_nocallback.a ${CUDA_TOOLKIT_ROOT_DIR}/lib64/libcusolver_static.a + ${CUDA_TOOLKIT_ROOT_DIR}/lib64/liblapack_static.a # needed for libcusolver_static ) else() list(APPEND ATen_CUDA_DEPENDENCY_LIBS diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.cpp b/aten/src/ATen/LegacyTHFunctionsCPU.cpp index f9249364f87a..c2e5e2e1da1e 100644 --- a/aten/src/ATen/LegacyTHFunctionsCPU.cpp +++ b/aten/src/ATen/LegacyTHFunctionsCPU.cpp @@ -776,53 +776,6 @@ std::tuple _th_geqrf(const Tensor & self) { } return std::tuple(res1, res2); } -Tensor & _th_orgqr_out(Tensor & result, const Tensor & self, const Tensor & input2) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type); - auto input2_ = checked_dense_tensor_unwrap(input2, "input2", 2, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type); - THDoubleTensor_orgqr(result_, self_, input2_); - break; - } - case ScalarType::Float: { - auto result_ = checked_dense_tensor_unwrap(result, "result", 0, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type); - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type); - auto input2_ = checked_dense_tensor_unwrap(input2, "input2", 2, "_th_orgqr_out", false, DeviceType::CPU, dispatch_scalar_type); - THFloatTensor_orgqr(result_, self_, input2_); - break; - } - default: - AT_ERROR("_th_orgqr_out not supported on CPUType for ", dispatch_scalar_type); - } - return result; -} -Tensor _th_orgqr(const Tensor & self, const Tensor & input2) { - // DeviceGuard omitted - auto dispatch_scalar_type = infer_scalar_type(self); - auto result_ = c10::make_intrusive(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CPU, scalarTypeToTypeMeta(dispatch_scalar_type)).release(); - auto result = Tensor(c10::intrusive_ptr::reclaim(result_)); - switch (dispatch_scalar_type) { - case ScalarType::Double: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_orgqr", false, DeviceType::CPU, dispatch_scalar_type); - auto input2_ = checked_dense_tensor_unwrap(input2, "input2", 2, "_th_orgqr", false, DeviceType::CPU, dispatch_scalar_type); - THDoubleTensor_orgqr(result_, self_, input2_); - break; - } - case ScalarType::Float: { - auto self_ = checked_dense_tensor_unwrap(self, "self", 1, "_th_orgqr", false, DeviceType::CPU, dispatch_scalar_type); - auto input2_ = checked_dense_tensor_unwrap(input2, "input2", 2, "_th_orgqr", false, DeviceType::CPU, dispatch_scalar_type); - THFloatTensor_orgqr(result_, 
self_, input2_);
-            break;
-        }
-        default:
-            AT_ERROR("_th_orgqr not supported on CPUType for ", dispatch_scalar_type);
-    }
-    return result;
-}
 Tensor & _th_ormqr_out(Tensor & result, const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) {
     // DeviceGuard omitted
     auto dispatch_scalar_type = infer_scalar_type(self);
diff --git a/aten/src/ATen/LegacyTHFunctionsCPU.h b/aten/src/ATen/LegacyTHFunctionsCPU.h
index 9a2ec45efefa..7cd8106202c2 100644
--- a/aten/src/ATen/LegacyTHFunctionsCPU.h
+++ b/aten/src/ATen/LegacyTHFunctionsCPU.h
@@ -42,8 +42,6 @@ Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
 Tensor _th_potri(const Tensor & self, bool upper);
 std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self);
 std::tuple<Tensor,Tensor> _th_geqrf(const Tensor & self);
-Tensor & _th_orgqr_out(Tensor & result, const Tensor & self, const Tensor & input2);
-Tensor _th_orgqr(const Tensor & self, const Tensor & input2);
 Tensor & _th_ormqr_out(Tensor & result, const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose);
 Tensor _th_ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose);
diff --git a/aten/src/ATen/core/Vitals.cpp b/aten/src/ATen/core/Vitals.cpp
new file mode 100644
index 000000000000..da1a208ab501
--- /dev/null
+++ b/aten/src/ATen/core/Vitals.cpp
@@ -0,0 +1,28 @@
+#include
+#include
+
+TorchVitalAttr& TorchVital::create(const std::string& attr) {
+  if (!torchVitalEnabled()) {
+    static TorchVitalAttr disabled;
+    return disabled;
+  }
+  auto iter = attrs.find(attr);
+  if (iter == attrs.end()) {
+    auto r = attrs.emplace(std::make_pair(attr, TorchVitalAttr()));
+    return r.first->second;
+  }
+  return iter->second;
+}
+
+bool torchVitalEnabled() {
+  // If this is a performance hit, make `enabled` variable static
+  // and return `const bool&` instead
+  bool enabled = []() {
+    auto e = getenv("TORCH_VITAL");
+    if (e != nullptr) {
+      return strlen(e) > 0;
+    }
+    return false;
+  }();
+  return enabled;
+}
diff --git a/aten/src/ATen/core/Vitals.h b/aten/src/ATen/core/Vitals.h
new file mode 100644
index 000000000000..19d3436ebfef
--- /dev/null
+++ b/aten/src/ATen/core/Vitals.h
@@ -0,0 +1,44 @@
+#pragma once
+#include
+#include
+#include
+#include
+
+bool torchVitalEnabled();
+
+struct TorchVitalAttr {
+  // always initialized to empty
+  std::string value = "";
+  template <typename T>
+  TorchVitalAttr& operator<<(const T& t) {
+    if (torchVitalEnabled()) {
+      std::stringstream ss;
+      ss << t;
+      value += ss.str();
+    }
+    return *this;
+  }
+};
+
+struct TorchVital {
+  std::string name;
+  std::unordered_map<std::string, TorchVitalAttr> attrs;
+
+  explicit TorchVital(std::string n) : name(std::move(n)) {}
+  TorchVital() = delete;
+
+  TorchVitalAttr& create(const std::string& attr);
+
+  ~TorchVital() {
+    for (const auto& m : attrs) {
+      std::cout << "[TORCH_VITAL] " << name << "."
<< m.first << "\t\t " + << m.second.value << "\n"; + } + } +}; + +#define TORCH_VITAL_DECLARE(name) extern TorchVital TorchVital_##name; + +#define TORCH_VITAL_DEFINE(name) TorchVital TorchVital_##name(#name); + +#define TORCH_VITAL(name, attr) TorchVital_##name.create(#attr) diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index 477e366ea18b..f309e9d7d0e0 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -21,8 +21,6 @@ #include #include -#include -#include #include #include #include diff --git a/aten/src/ATen/cpu/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec256/vec256_qint.h index 6e73dbae3b0d..3ba7f4f51101 100644 --- a/aten/src/ATen/cpu/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec256/vec256_qint.h @@ -5,7 +5,7 @@ #include #include -#include +#include #include #include #include diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 0521adf669c5..2ba9894e0c52 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -327,7 +327,6 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - TORCH_CHECK(prop->major >= 8, "BFloat16 bgemm in CUDA requires Ampere or later GPU"); TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k, (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, @@ -343,7 +342,7 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { (int) num_batches, rocblas_datatype_f32_r, rocblas_gemm_algo_standard, 0, 0, NULL, NULL)); #else - TORCH_CHECK(false, "BFloat16 bgemm in CUDA requires Ampere or later GPU"); + TORCH_CHECK(false, "CUDA BFloat16 bgemm requires CUDA 11 or later"); #endif // defined(CUDA_VERSION) && CUDA_VERSION >= 11000 } #endif // __HIP_PLATFORM_HCC__ @@ -550,37 +549,26 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { float fbeta = beta; _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); GEMM_CHECK_ARGVALUES(at::BFloat16); - cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); - if (prop->major >= 8) { - // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH - // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. - TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); - TORCH_CUDABLAS_CHECK(cublasGemmEx( - handle, - opa, - opb, - m, - n, - k, - &falpha, - a, - CUDA_R_16BF, - lda, - b, - CUDA_R_16BF, - ldb, - &fbeta, - c, - CUDA_R_16BF, - ldc, - CUDA_R_32F, - CUBLAS_GEMM_DFALT_TENSOR_OP)); - // On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH - // manually to be able to use tensor cores for FP16. On CUDA 11, this is no longer required. 
- TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); - } else { - TORCH_CHECK(false, "BFloat16 gemm in CUDA requires Ampere or later GPU"); - } + TORCH_CUDABLAS_CHECK(cublasGemmEx( + handle, + opa, + opb, + m, + n, + k, + &falpha, + a, + CUDA_R_16BF, + lda, + b, + CUDA_R_16BF, + ldb, + &fbeta, + c, + CUDA_R_16BF, + ldc, + CUDA_R_32F, + CUBLAS_GEMM_DFALT_TENSOR_OP)); } #endif diff --git a/aten/src/ATen/cuda/CUDASolver.cpp b/aten/src/ATen/cuda/CUDASolver.cpp index bcd630a06b9e..acf9a3f0443b 100644 --- a/aten/src/ATen/cuda/CUDASolver.cpp +++ b/aten/src/ATen/cuda/CUDASolver.cpp @@ -145,6 +145,196 @@ void getrs>( info)); } + +template<> +void gesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, float* A, int lda, float* S, float* U, + int ldu, float *V, int ldv, int *info, gesvdjInfo_t params +) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(float)*lwork); + + TORCH_CUSOLVER_CHECK(cusolverDnSgesvdj( + handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, + static_cast(dataPtr.get()), + lwork, info, params)); +} + +template<> +void gesvdj( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, double* A, int lda, double* S, double* U, + int ldu, double *V, int ldv, int *info, gesvdjInfo_t params +) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(double)*lwork); + + TORCH_CUSOLVER_CHECK(cusolverDnDgesvdj( + handle, jobz, econ, m, n, A, lda, S, U, ldu, V, ldv, + static_cast(dataPtr.get()), + lwork, info, params)); +} + +template<> +void gesvdj>( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, c10::complex* A, int lda, float* S, c10::complex* U, + int ldu, c10::complex *V, int ldv, int *info, gesvdjInfo_t params +) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnCgesvdj_bufferSize( + handle, jobz, econ, m, n, + reinterpret_cast(A), + lda, S, + reinterpret_cast(U), + ldu, + reinterpret_cast(V), + ldv, &lwork, params)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(cuComplex)*lwork); + + TORCH_CUSOLVER_CHECK(cusolverDnCgesvdj( + handle, jobz, econ, m, n, + reinterpret_cast(A), + lda, S, + reinterpret_cast(U), + ldu, + reinterpret_cast(V), + ldv, + static_cast(dataPtr.get()), + lwork, info, params)); +} + +template<> +void gesvdj>( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, c10::complex* A, int lda, double* S, c10::complex* U, + int ldu, c10::complex *V, int ldv, int *info, gesvdjInfo_t params +) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnZgesvdj_bufferSize( + handle, jobz, econ, m, n, + reinterpret_cast(A), + lda, S, + reinterpret_cast(U), + ldu, + reinterpret_cast(V), + ldv, &lwork, params)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex)*lwork); + + TORCH_CUSOLVER_CHECK(cusolverDnZgesvdj( + handle, jobz, econ, m, n, + reinterpret_cast(A), + lda, S, + reinterpret_cast(U), + ldu, + reinterpret_cast(V), + ldv, + static_cast(dataPtr.get()), + lwork, info, params)); +} + + +template<> +void gesvdjBatched( + cusolverDnHandle_t handle, 
cusolverEigMode_t jobz, int m, int n, float* A, int lda, float* S, float* U, + int ldu, float *V, int ldv, int *info, gesvdjInfo_t params, int batchSize +) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnSgesvdjBatched_bufferSize(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params, batchSize)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(float)*lwork); + + TORCH_CUSOLVER_CHECK(cusolverDnSgesvdjBatched( + handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, + static_cast(dataPtr.get()), + lwork, info, params, batchSize)); +} + +template<> +void gesvdjBatched( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, double* A, int lda, double* S, double* U, + int ldu, double *V, int ldv, int *info, gesvdjInfo_t params, int batchSize +) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnDgesvdjBatched_bufferSize(handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, &lwork, params, batchSize)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(double)*lwork); + + TORCH_CUSOLVER_CHECK(cusolverDnDgesvdjBatched( + handle, jobz, m, n, A, lda, S, U, ldu, V, ldv, + static_cast(dataPtr.get()), + lwork, info, params, batchSize)); +} + +template<> +void gesvdjBatched>( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, c10::complex* A, int lda, float* S, c10::complex* U, + int ldu, c10::complex *V, int ldv, int *info, gesvdjInfo_t params, int batchSize +) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnCgesvdjBatched_bufferSize( + handle, jobz, m, n, + reinterpret_cast(A), + lda, S, + reinterpret_cast(U), + ldu, + reinterpret_cast(V), + ldv, &lwork, params, batchSize)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(cuComplex)*lwork); + + TORCH_CUSOLVER_CHECK(cusolverDnCgesvdjBatched( + handle, jobz, m, n, + reinterpret_cast(A), + lda, S, + reinterpret_cast(U), + ldu, + reinterpret_cast(V), + ldv, + static_cast(dataPtr.get()), + lwork, info, params, batchSize)); +} + +template<> +void gesvdjBatched>( + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, c10::complex* A, int lda, double* S, c10::complex* U, + int ldu, c10::complex *V, int ldv, int *info, gesvdjInfo_t params, int batchSize +) { + int lwork; + TORCH_CUSOLVER_CHECK(cusolverDnZgesvdjBatched_bufferSize( + handle, jobz, m, n, + reinterpret_cast(A), + lda, S, + reinterpret_cast(U), + ldu, + reinterpret_cast(V), + ldv, &lwork, params, batchSize)); + + auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); + auto dataPtr = allocator.allocate(sizeof(cuDoubleComplex)*lwork); + + TORCH_CUSOLVER_CHECK(cusolverDnZgesvdjBatched( + handle, jobz, m, n, + reinterpret_cast(A), + lda, S, + reinterpret_cast(U), + ldu, + reinterpret_cast(V), + ldv, + static_cast(dataPtr.get()), + lwork, info, params, batchSize)); +} + } // namespace solver } // namespace cuda } // namespace at diff --git a/aten/src/ATen/cuda/CUDASolver.h b/aten/src/ATen/cuda/CUDASolver.h index 327c7b824c5e..7225fe7bb579 100644 --- a/aten/src/ATen/cuda/CUDASolver.h +++ b/aten/src/ATen/cuda/CUDASolver.h @@ -42,6 +42,41 @@ template<> void getrs>(CUDASOLVER_GETRS_ARGTYPES(c10::complex)); +#define CUDASOLVER_GESVDJ_ARGTYPES(Dtype, Vtype) \ + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int econ, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \ + int ldu, Dtype *V, int ldv, int *info, gesvdjInfo_t params + +template +void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(Dtype, 
Vtype)) { + TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj: not implemented for ", typeid(Dtype).name()); +} +template<> +void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(float, float)); +template<> +void gesvdj(CUDASOLVER_GESVDJ_ARGTYPES(double, double)); +template<> +void gesvdj>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex, float)); +template<> +void gesvdj>(CUDASOLVER_GESVDJ_ARGTYPES(c10::complex, double)); + + +#define CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype) \ + cusolverDnHandle_t handle, cusolverEigMode_t jobz, int m, int n, Dtype* A, int lda, Vtype* S, Dtype* U, \ + int ldu, Dtype *V, int ldv, int *info, gesvdjInfo_t params, int batchSize + +template +void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(Dtype, Vtype)) { + TORCH_INTERNAL_ASSERT(false, "at::cuda::solver::gesvdj: not implemented for ", typeid(Dtype).name()); +} +template<> +void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(float, float)); +template<> +void gesvdjBatched(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(double, double)); +template<> +void gesvdjBatched>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex, float)); +template<> +void gesvdjBatched>(CUDASOLVER_GESVDJ_BATCHED_ARGTYPES(c10::complex, double)); + } // namespace solver } // namespace cuda } // namespace at diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index d2f1358a345d..f2980ee95798 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -82,6 +82,22 @@ extern "C" void ssyevd_(char *jobz, char *uplo, int *n, float *a, int *lda, floa // geev extern "C" void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info); extern "C" void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info); +extern "C" void cgeev_(char *jobvl, char *jobvr, int *n, + std::complex *a, int *lda, + std::complex *w, + std::complex *vl, int *ldvl, + std::complex *vr, int *ldvr, + std::complex *work, int *lwork, + float *rwork, + int *info); +extern "C" void zgeev_(char *jobvl, char *jobvr, int *n, + std::complex *a, int *lda, + std::complex *w, + std::complex *vl, int *ldvl, + std::complex *vr, int *ldvr, + std::complex *work, int *lwork, + double *rwork, + int *info); // gesdd extern "C" void zgesdd_(char *jobz, int *m, int *n, std::complex *a, int *lda, @@ -127,9 +143,6 @@ void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, sc template void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); -template -void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); - template void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info); @@ -310,14 +323,44 @@ template<> void lapackSyevd(char jobz, char uplo, int n, float *a, int ld ssyevd_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); } -template<> void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, double *wr, double *wi, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, int *info) { +template<> void lapackEig(char jobvl, char jobvr, int n, double *a, int lda, double *w, double* vl, int ldvl, double *vr, int ldvr, double *work, int lwork, double *rwork, int 
*info) { + // lapack [sd]geev wants to separate output arrays: wr and wi for the real + // and imaginary parts + double *wr = w; + double *wi = w + n; + (void)rwork; // unused dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); } -template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, float *wr, float *wi, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, int *info) { +template<> void lapackEig(char jobvl, char jobvr, int n, float *a, int lda, float *w, float* vl, int ldvl, float *vr, int ldvr, float *work, int lwork, float *rwork, int *info) { + // lapack [sd]geev wants to separate output arrays: wr and wi for the real + // and imaginary parts + float *wr = w; + float *wi = w + n; + (void)rwork; // unused sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info); } +template<> void lapackEig, double>(char jobvl, char jobvr, int n, c10::complex *a, int lda, c10::complex *w, c10::complex *vl, int ldvl, c10::complex *vr, int ldvr, c10::complex *work, int lwork, double *rwork, int *info) { + zgeev_(&jobvl, &jobvr, &n, + reinterpret_cast*>(a), &lda, + reinterpret_cast*>(w), + reinterpret_cast*>(vl), &ldvl, + reinterpret_cast*>(vr), &ldvr, + reinterpret_cast*>(work), &lwork, + rwork, info); +} + +template<> void lapackEig, float>(char jobvl, char jobvr, int n, c10::complex *a, int lda, c10::complex *w, c10::complex *vl, int ldvl, c10::complex *vr, int ldvr, c10::complex *work, int lwork, float *rwork, int *info) { + cgeev_(&jobvl, &jobvr, &n, + reinterpret_cast*>(a), &lda, + reinterpret_cast*>(w), + reinterpret_cast*>(vl), &ldvl, + reinterpret_cast*>(vr), &ldvr, + reinterpret_cast*>(work), &lwork, + rwork, info); +} + template<> void lapackSvd, double>(char jobz, int m, int n, c10::complex *a, int lda, double *s, c10::complex *u, int ldu, c10::complex *vt, int ldvt, c10::complex *work, int lwork, double *rwork, int *iwork, int *info) { zgesdd_(&jobz, &m, &n, reinterpret_cast*>(a), &lda, s, reinterpret_cast*>(u), &ldu, @@ -982,44 +1025,6 @@ static void apply_geqrf(Tensor& self, Tensor& tau, int64_t m, int64_t n, #endif } -template -static void apply_orgqr(Tensor& self, const Tensor& tau, int64_t m, int64_t n_columns, - int64_t k, std::vector& infos) { -#ifndef USE_LAPACK - AT_ERROR("qr: LAPACK library not found in compilation"); -#else - using value_t = typename c10::scalar_value_type::type; - auto self_data = self.data_ptr(); - auto tau_data = tau.data_ptr(); - auto self_matrix_stride = matrixStride(self); - auto tau_stride = tau.size(-1); - auto batch_size = batchCount(self); - - int info; - // Run once, first to get the optimum work size. 
- // Since we deal with batches of matrices with the same dimensions, doing this outside - // the loop saves (batch_size - 1) workspace queries which would provide the same result - // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() - int lwork = -1; - scalar_t wkopt; - lapackOrgqr(m, n_columns, k, self_data, m, tau_data, &wkopt, lwork, &info); - lwork = static_cast(real_impl(wkopt)); - Tensor work = at::empty({lwork}, self.options()); - - for (int64_t i = 0; i < batch_size; i++) { - scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; - scalar_t* tau_working_ptr = &tau_data[i * tau_stride]; - - // now compute the actual Q - lapackOrgqr(m, n_columns, k, self_working_ptr, m, tau_working_ptr, work.data_ptr(), lwork, &info); - infos[i] = info; - if (info != 0) { - return; - } - } -#endif -} - std::tuple _linalg_qr_helper_cpu(const Tensor& self, std::string mode) { bool compute_q, reduced; std::tie(compute_q, reduced) = _parse_qr_mode(mode); @@ -1074,13 +1079,14 @@ std::tuple _linalg_qr_helper_cpu(const Tensor& self, std::string } // Next perform ORGQR for Q using the results (both raw R and TAU) from GEQRF + auto infos_orgqr = at::empty({std::max(1, batchCount(self))}, self.options().dtype(kInt)); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "qr_cpu", [&]{ - apply_orgqr(q_working_copy, tau_working_copy, m, n_columns_q, std::min(m, n), infos); + apply_orgqr(q_working_copy, tau_working_copy, infos_orgqr, n_columns_q); }); if (self.dim() > 2) { - batchCheckErrors(infos, "qr_cpu"); + batchCheckErrors(infos_orgqr, "qr_cpu"); } else { - singleCheckErrors(infos[0], "qr_cpu"); + singleCheckErrors(infos_orgqr.item().toInt(), "qr_cpu"); } return std::make_tuple(q_working_copy.narrow(-1, 0, n_columns_q), R); } @@ -1113,6 +1119,114 @@ std::tuple qr_out(Tensor& Q, Tensor& R, const Tensor& self, boo return at::linalg_qr_out(Q, R, self, mode); } +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ orgqr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +DEFINE_DISPATCH(orgqr_stub); + +/* + The orgqr function allows reconstruction of an orthogonal (or unitary) matrix Q, + from a sequence of elementary reflectors, such as is produced by the geqrf function. + + Args: + * `input` - Tensor with the directions of the elementary reflectors below the diagonal. + * `tau` - Tensor containing the magnitudes of the elementary reflectors. + * `result` - result Tensor, which will contain the orthogonal (or unitary) matrix Q. + * `infos` - Tensor to store LAPACK/MAGMA error codes + + For further details, please see the LAPACK/MAGMA documentation. 
+*/ +Tensor& orgqr_out_info(const Tensor& input, const Tensor& tau, Tensor& result, Tensor& infos) { + TORCH_INTERNAL_ASSERT(input.dim() >= 2); + TORCH_INTERNAL_ASSERT(input.size(-2) >= input.size(-1)); + TORCH_INTERNAL_ASSERT(input.size(-1) >= tau.size(-1)); + + TORCH_INTERNAL_ASSERT(input.scalar_type() == tau.scalar_type()); + TORCH_INTERNAL_ASSERT(input.device() == tau.device()); + + TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type()); + TORCH_INTERNAL_ASSERT(result.device() == input.device()); + + TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt); + TORCH_INTERNAL_ASSERT(infos.device() == at::kCPU); + TORCH_INTERNAL_ASSERT(infos.numel() == std::max(1, batchCount(input))); + + // if result has no elements we can modify it + if (result.numel() == 0) { + at::native::resize_as_(result, input.transpose(-2, -1), MemoryFormat::Contiguous); + result.transpose_(-2, -1); + } + + // result tensor must be in batched column major order (Fortran contiguous) + TORCH_INTERNAL_ASSERT(result.transpose(-2, -1).is_contiguous()); + TORCH_INTERNAL_ASSERT(result.sizes().equals(input.sizes())); + + // orgqr_stub (apply_orgqr) performs calculations in-place and result must be a copy of input + result.copy_(input); + + // infos must be contiguous + TORCH_INTERNAL_ASSERT(infos.is_contiguous()); + infos.fill_(0); + + auto n = input.size(-1); + result = orgqr_stub(result.device().type(), result, tau, infos, n); + return result; +} + +Tensor& orgqr_out(const Tensor& input, const Tensor& tau, Tensor& result) { + TORCH_CHECK(input.dim() >= 2, "orgqr: input must have at least 2 dimensions."); + TORCH_CHECK(input.size(-2) >= input.size(-1), "orgqr: input.shape[-2] must be greater than or equal to input.shape[-1]"); + TORCH_CHECK(input.size(-1) >= tau.size(-1), "orgqr: input.shape[-1] must be greater than or equal to tau.shape[-1]"); + + TORCH_CHECK(tau.scalar_type() == input.scalar_type(), + "orgqr: tau dtype ", tau.scalar_type(), " does not match input dtype ", input.scalar_type()); + TORCH_CHECK(input.device() == tau.device(), + "orgqr: Expected input and tau to be on the same device, but found input on ", + input.device(), " and tau on ", tau.device(), " instead."); + + TORCH_CHECK(result.scalar_type() == input.scalar_type(), + "orgqr: result dtype ", result.scalar_type(), " does not match the expected dtype ", input.scalar_type()); + TORCH_CHECK(result.device() == input.device(), + "orgqr: Expected result and input to be on the same device, but found result on ", + result.device(), " and input on ", input.device(), " instead."); + + // TODO: uncomment the following when passing incorrectly sized 'result' is not allowed + // if (result.numel() != 0) { + // // Resize messes up the strides, so let's not use at::native::resize_output + // TORCH_CHECK(result.sizes().equals(input.sizes()), + // "result shape ", result.sizes(), " does not match input shape ", input.sizes()); + // } + + // Single matrix MAGMA routine requires 'infos' to reside in CPU memory, + // therefore we create 'infos' only on CPU for now. 
+ // This should be changed if cuSOLVER would be used + auto infos = at::empty({std::max(1, batchCount(input))}, input.options().dtype(kInt).device(kCPU)); + + // if result is not empty and not in batched column major format we have to allocate a temporary tensor + if (result.numel() != 0 && !result.transpose(-2, -1).is_contiguous()) { + Tensor result_tmp = at::empty({0}, input.options()); + result_tmp = orgqr_out_info(input, tau, result_tmp, infos); + at::native::resize_output(result, result_tmp.sizes()); + result.copy_(result_tmp); + } else { + // use result's storage directly + result = orgqr_out_info(input, tau, result, infos); + } + + // Now check LAPACK/MAGMA error codes + if (result.dim() > 2) { + batchCheckErrors(infos, "orgqr"); + } else { + singleCheckErrors(infos.item().toInt(), "orgqr"); + } + return result; +} + +Tensor orgqr(const Tensor& input, const Tensor& tau) { + Tensor result = at::empty({0}, input.options()); + result = at::orgqr_outf(input, tau, result); + return result; +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // This function computes eigenvalues 'w' and eigenvectors 'v' of the input that is stored initially in 'v' @@ -1373,7 +1487,11 @@ std::tuple eig_out(Tensor& e, Tensor& v, const Tensor& self, b TORCH_CHECK(v.dtype() == self.dtype(), "Expected 'v' to have dtype ", self.dtype(), " but got ", v.dtype()); int64_t n = self.size(-1); - at::native::resize_output(e, {n, 2}); + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + at::native::resize_output(e, {n}); + } else { + at::native::resize_output(e, {n, 2}); + } if (eigenvectors) { at::native::resize_output(v, self.sizes()); } @@ -1498,6 +1616,8 @@ std::tuple _svd_helper_cpu(const Tensor& self, bool some VT_working_copy.zero_(); } // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + // Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V. + VT_working_copy = VT_working_copy.conj(); VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } @@ -1528,8 +1648,8 @@ std::tuple svd_out(Tensor& U, Tensor& S, Tensor& V, 1. the 2nd parameter is bool some=True, which if effectively the opposite of full_matrices=True - 2. svd returns V, while linalg.svd returns VT. To accommodate the - difference, we transpose() V upon return + 2. svd returns V, while linalg.svd returns VT = V^T (for real inputs) or VT = V^H (for complex inputs). 
+ To accommodate the difference, we transpose() and conj() V upon return */ std::tuple linalg_svd(const Tensor& self, bool full_matrices, bool compute_uv) { @@ -1540,7 +1660,7 @@ std::tuple linalg_svd(const Tensor& self, bool full_matr Tensor U, S, V; std::tie(U, S, V) = at::_svd_helper(self, some, compute_uv); if (compute_uv) { - Tensor VT = V.transpose(-2, -1); + Tensor VT = V.conj().transpose(-2, -1); return std::make_tuple(U, S, VT); } else { Tensor empty_U = at::empty({0}, self.options()); diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h index 95fc2c6097ce..38a9f92b158e 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.h +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -2,6 +2,8 @@ #include #include +#include +#include #include // for USE_LAPACK @@ -12,8 +14,11 @@ namespace at { namespace native { // Define per-batch functions to be used in the implementation of batched // linear algebra operations +template +void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *w, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, value_t *rwork, int *info); + template -void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info); +void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); #endif @@ -21,4 +26,72 @@ using eig_fn = std::tuple (*)(const Tensor&, bool&); DECLARE_DISPATCH(eig_fn, eig_stub); +/* + The orgqr function allows reconstruction of an orthogonal (or unitary) matrix Q, + from a sequence of elementary reflectors, such as produced by the geqrf function. + + Args: + * `self` - Tensor with the directions of the elementary reflectors below the diagonal, + it will be overwritten with the result + * `tau` - Tensor containing the magnitudes of the elementary reflectors + * `infos` - Tensor to store LAPACK's error codes + * `n_columns` - The number of columns of Q to be computed + + For further details, please see the LAPACK documentation for ORGQR and UNGQR. +*/ +template +inline void apply_orgqr(Tensor& self, const Tensor& tau, Tensor& infos, int64_t n_columns) { +#ifndef USE_LAPACK + TORCH_CHECK(false, "Calling torch.orgqr on a CPU tensor requires compiling ", + "PyTorch with LAPACK. Please use PyTorch built with LAPACK support."); +#else + // Some LAPACK implementations might not work well with empty matrices: + // workspace query might return lwork as 0, which is not allowed (requirement is lwork >= 1) + // We don't need to do any calculations in this case, so let's return early + if (self.numel() == 0) { + infos.fill_(0); + return; + } + + using value_t = typename c10::scalar_value_type::type; + auto self_data = self.data_ptr(); + auto tau_data = tau.data_ptr(); + auto infos_data = infos.data_ptr(); + auto self_matrix_stride = matrixStride(self); + auto tau_stride = tau.size(-1); + auto batch_size = batchCount(self); + auto m = self.size(-2); + auto k = tau.size(-1); + auto lda = std::max(1, m); + + // LAPACK's requirement + TORCH_INTERNAL_ASSERT(m >= n_columns); + TORCH_INTERNAL_ASSERT(n_columns >= k); + + // Run once, first to get the optimum work size. 
+ // Since we deal with batches of matrices with the same dimensions, doing this outside + // the loop saves (batch_size - 1) workspace queries which would provide the same result + // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() + int lwork = -1; + scalar_t wkopt; + lapackOrgqr(m, n_columns, k, self_data, lda, tau_data, &wkopt, lwork, &infos_data[0]); + lwork = static_cast(real_impl(wkopt)); + Tensor work = at::empty({lwork}, self.options()); + + for (int64_t i = 0; i < batch_size; i++) { + scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; + scalar_t* tau_working_ptr = &tau_data[i * tau_stride]; + int* info_working_ptr = &infos_data[i]; + // now compute the actual Q + lapackOrgqr(m, n_columns, k, self_working_ptr, lda, tau_working_ptr, work.data_ptr(), lwork, info_working_ptr); + if (*info_working_ptr != 0) { + return; + } + } +#endif +} + +using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/, Tensor& /*infos*/, int64_t /*n_columns*/); +DECLARE_DISPATCH(orgqr_fn, orgqr_stub); + }} // namespace at::native diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index d251245c60c5..f5afa96eb6bc 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include // for USE_LAPACK @@ -15,29 +16,38 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec TORCH_CHECK(false, "Calling torch.eig on a CPU tensor requires compiling ", "PyTorch with LAPACK. Please use PyTorch built with LAPACK support."); #else + using value_t = typename c10::scalar_value_type::type; + char jobvr = eigenvectors ? 'V' : 'N'; int64_t n = self.size(-1); auto self_data = self.data_ptr(); auto vals_data = vals_.data_ptr(); scalar_t* wr = vals_data; - scalar_t* wi = vals_data + n; scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr() : nullptr; int ldvr = eigenvectors ? n : 1; + Tensor rwork; + value_t* rwork_data = nullptr; + if (self.is_complex()) { + ScalarType real_dtype = toValueType(typeMetaToScalarType(self.dtype())); + rwork = at::empty({n*2}, self.options().dtype(real_dtype)); + rwork_data = rwork.data_ptr(); + } + if (n > 0) { // call lapackEig once to get the optimal size for work data scalar_t wkopt; int info; - lapackEig('N', jobvr, n, self_data, n, wr, wi, - nullptr, 1, vecs_data, ldvr, &wkopt, -1, &info); - int lwork = static_cast(wkopt); + lapackEig('N', jobvr, n, self_data, n, wr, + nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork_data, &info); + int lwork = static_cast(real_impl(wkopt)); // call again to do the actual work Tensor work = at::empty({lwork}, self.dtype()); - lapackEig('N', jobvr, n, self_data, n, wr, wi, - nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, &info); + lapackEig('N', jobvr, n, self_data, n, wr, + nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, rwork_data, &info); *info_ptr = info; } #endif @@ -55,23 +65,45 @@ std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvector self_.copy_(self); auto options = self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); - Tensor vals_ = at::empty_strided({n, 2}, {1, n}, options); + + // the API is slightly different for the complex vs real case: if the input + // is complex, eigenvals will be a vector of complex. 
If the input is real, + // eigenvals will be a (n, 2) matrix containing the real and imaginary parts + // in each column + Tensor vals_; + if (self.is_complex()) { + vals_ = at::empty({n}, options); + } else { + vals_ = at::empty_strided({n, 2}, {1, n}, options); + } Tensor vecs_ = eigenvectors ? at::empty_strided({n, n}, {1, n}, options) : Tensor(); int64_t info; - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cpu", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cpu", [&]{ apply_eig(self_, eigenvectors, vals_, vecs_, &info); }); singleCheckErrors(info, "eig_cpu"); return std::tuple(vals_, vecs_); } +// This is a type dispatching helper function for 'apply_orgqr' +Tensor& orgqr_kernel_impl(Tensor& result, const Tensor& tau, Tensor& infos, int64_t n_columns) { + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(result.scalar_type(), "orgqr_cpu", [&]{ + apply_orgqr(result, tau, infos, n_columns); + }); + return result; +} + } // anonymous namespace REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl); REGISTER_AVX_DISPATCH(eig_stub, &eig_kernel_impl); REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl); +REGISTER_ARCH_DISPATCH(orgqr_stub, DEFAULT, &orgqr_kernel_impl); +REGISTER_AVX_DISPATCH(orgqr_stub, &orgqr_kernel_impl); +REGISTER_AVX2_DISPATCH(orgqr_stub, &orgqr_kernel_impl); + }} // namespace at::native diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 880941059951..674c91597ae9 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -147,16 +147,14 @@ Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) { // If not Hermitian use singular value decomposition, else use eigenvalue decomposition if (!hermitian) { - // until https://github.com/pytorch/pytorch/issues/45821 is resolved - // svd() returns conjugated V for complex-valued input - Tensor U, S, V_conj; + Tensor U, S, V; // TODO: replace input.svd with linalg_svd - std::tie(U, S, V_conj) = input.svd(); + // using linalg_svd breaks pytorch/xla, see https://github.com/pytorch/xla/issues/2755 + std::tie(U, S, V) = input.svd(); Tensor max_val = at::narrow(S, /*dim=*/-1, /*start=*/0, /*length=*/1); // singular values are sorted in descending order Tensor S_pseudoinv = at::where(S > (rcond.unsqueeze(-1) * max_val), S.reciprocal(), at::zeros({}, S.options())).to(input.dtype()); - // computes V @ diag(S_pseudoinv) @ U.T.conj() - // TODO: replace V_conj.conj() -> V once https://github.com/pytorch/pytorch/issues/45821 is resolved - return at::matmul(V_conj.conj() * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1)); + // computes V @ diag(S_pseudoinv) @ U.conj().T + return at::matmul(V * S_pseudoinv.unsqueeze(-2), U.conj().transpose(-2, -1)); } else { Tensor S, U; std::tie(S, U) = at::linalg_eigh(input); @@ -2223,67 +2221,34 @@ Tensor chain_matmul(TensorList matrices) { /* Calculates the Kronecker product between two Tensors. */ -Tensor kron(const Tensor& self, const Tensor& other) { - /* - We can obtain the kron result using tensordot or einsum. The implementation below uses tensordot. - In einsum notation suppose we have `self` with dim 4 and `other` with dim 2 - the result of below tensordot is in einsum 0123, 45 -> 012345. - To obtain the correct kron we need to permute and reshape the array. 
- The permutation rule is the following: going from right to left - take axes in turn to form the permutation - with our example the correct permutation is 012435 and - the kron shape is (shape_self[0], shape_self[1], shape_self[3]*shape_other[0], - shape_self[4]*shape_other[1]) - */ - std::vector self_sizes = self.sizes().vec(); - std::vector other_sizes = other.sizes().vec(); - int64_t self_ndim = self.dim(); - int64_t other_ndim = other.dim(); - int64_t min_ndim = std::min(self_ndim, other_ndim); - int64_t ndim_diff = std::abs(self_ndim - other_ndim); - - std::vector a_axes(self_ndim); - std::vector b_axes(other_ndim); - std::iota(a_axes.begin(), a_axes.end(), 0); - std::iota(b_axes.begin(), b_axes.end(), 0 + self_ndim); - - bool is_a_larger = self_ndim >= other_ndim; - std::vector kron_permutation(self_ndim + other_ndim); - for (int64_t i = 0; i < ndim_diff; i++) { - kron_permutation[i] = is_a_larger ? a_axes[i] : b_axes[i]; - } - for (int64_t i = 0, j = 0; i < min_ndim; i++, j += 2) { - kron_permutation[self_ndim + other_ndim - 1 - j] = b_axes[other_ndim - 1 - i]; - kron_permutation[self_ndim + other_ndim - 1 - j - 1] = a_axes[self_ndim - 1 - i]; - } - - std::vector result_shape(std::max(self_ndim, other_ndim)); - for (int64_t i = 0; i < ndim_diff; i++) { - result_shape[i] = is_a_larger ? self_sizes[i] : other_sizes[i]; +Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) { + auto maxdim = std::max(self.dim(), other.dim()); + auto pad_self = maxdim - self.dim(); + auto pad_other = maxdim - other.dim(); + c10::SmallVector a_reshape(2 * maxdim); + c10::SmallVector b_reshape(2 * maxdim); + c10::SmallVector result_reshape(maxdim); + for (int i = 0; i < maxdim; i++) { + a_reshape[2 * i] = i >= pad_self ? self.sizes()[i - pad_self] : 1; + a_reshape[2 * i + 1] = 1; + b_reshape[2 * i] = 1; + b_reshape[2 * i + 1] = i >= pad_other ? other.sizes()[i - pad_other] : 1; + result_reshape[i] = a_reshape[2 * i] * b_reshape[2 * i + 1]; } - for (int64_t i = 0; i < min_ndim; i++) { - result_shape[ndim_diff + i] = is_a_larger - ? 
self_sizes[ndim_diff + i] * other_sizes[i] - : other_sizes[ndim_diff + i] * self_sizes[i]; + auto self_view = at::_unsafe_view(self, a_reshape); + auto other_view = at::_unsafe_view(other, b_reshape); + if (!result.defined()) { + result = at::_unsafe_view(at::mul(self_view, other_view), result_reshape); + } else { + at::mul_out(result, self_view, other_view); + result.resize_(result_reshape); } - - Tensor result = at::tensordot(self, other, {}, {}); - // Step 2: now permute result - result = result.permute(kron_permutation); - // Step 3: reshape - result = result.reshape(result_shape); - return result; } -Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) { - TORCH_CHECK(result.scalar_type() == self.scalar_type(), - "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); - - Tensor result_tmp = at::kron(self, other); - at::native::resize_output(result, result_tmp.sizes()); - result.copy_(result_tmp); - return result; +Tensor kron(const Tensor& self, const Tensor& other) { + at::Tensor result; + return at::kron_out(result, self, other); } } // namespace native diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 4322c4c79222..b114c468a05d 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -239,7 +239,14 @@ static inline std::tuple, } // Function to generate empty tensors of required size, strides and dtype for the SVD operation -static inline std::tuple _create_U_S_VT(const Tensor& input, bool some, bool compute_uv) { +static inline std::tuple _create_U_S_VT(const Tensor& input, bool some, bool compute_uv, + const bool svd_use_cusolver=false) { + + // U, S, VT are initialized as empty tensors. + // For CPU LAPACK and GPU MAGMA backend, the tensors are initialized on CPU. + // For GPU cuSOLVER backend, the tensors are initialized on GPU. + const auto usvt_device = svd_use_cusolver ? at::kCUDA : at::kCPU; + auto sizes = input.sizes().vec(); int64_t m = input.size(-2), n = input.size(-1); @@ -251,47 +258,21 @@ static inline std::tuple _create_U_S_VT(const Tensor& in strides[input.dim() - 1] = m; strides[input.dim() - 2] = 1; - Tensor U_empty; - if (!input.is_cuda()) { - U_empty = at::empty_strided(sizes, strides, input.options()); - } else { - // NB: U_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd - // (which is the driver routine for the divide and conquer SVD operation) - // takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that - // moves the inputs between devices internally. - U_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); - } + Tensor U_empty = at::empty_strided(sizes, strides, input.options().device(usvt_device)); + U_empty.zero_(); // VT should be a column-major or a batch of column-major matrices sizes[input.dim() - 2] = n; sizes[input.dim() - 1] = n; - strides = at::detail::defaultStrides(sizes); - strides[input.dim() - 1] = n; - strides[input.dim() - 2] = 1; - Tensor VT_empty; - if (!input.is_cuda()) { - VT_empty = at::empty_strided(sizes, strides, input.options()); - } else { - // NB: VT_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd - // (which is the driver routine for the divide and conquer SVD operation) - // takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that - // moves the inputs between devices internally. 
- VT_empty = at::empty_strided(sizes, strides, input.options().device(at::kCPU)); - } + // VT should be a column-major or a batch of column-major matrices + Tensor VT_empty = at::zeros(sizes, input.options().device(usvt_device)); + VT_empty.transpose_(-2, -1); sizes.pop_back(); sizes[input.dim() - 2] = std::min(m, n); - Tensor S_empty; ScalarType dtype = toValueType(typeMetaToScalarType(input.dtype())); - if (!input.is_cuda()) { - S_empty = at::empty(sizes, input.options().dtype(dtype)); - } else { - // NB: S_empty is an empty tensor created on the CPU intentionally, because magma_(d/s)gesdd - // (which is the driver routine for the divide and conquer SVD operation) - // takes in arrays on the CPU as input. This routine is a hybrid CPU-GPU routine that - // moves the inputs between devices internally. - S_empty = at::empty(sizes, input.options().dtype(dtype).device(at::kCPU)); - } + Tensor S_empty = at::empty(sizes, input.options().dtype(dtype).device(usvt_device)); + return std::tuple(U_empty, S_empty, VT_empty); } diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index f85dc2320e7b..e3b4644e79e0 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -6,6 +6,8 @@ #include #include #include +#include +#include #include #include diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index 4752ded27ee8..6acd1be47c6f 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -268,8 +268,8 @@ dispatch: ``` This specifies the actual name of the function you want to dispatch to, so you -can dispatch to different functions depending on whether or not you have CPU or -CUDA tensors. Technically, it is also possible to write `dispatch: func_name` +can dispatch to different functions depending on which backend the passed tensors +belong to. Technically, it is also possible to write `dispatch: func_name` to unconditionally dispatch to a native function whose name is different than the name in the public ATen API, but this is generally frowned upon (just name them the same thing!) @@ -277,17 +277,55 @@ them the same thing!) If two backends have the same dispatch function, you can write `CPU, CUDA: func` to reuse the same function name in both cases. -Available backend options can be found at -https://github.com/pytorch/pytorch/blob/master/tools/codegen/gen.py#L970. -In addition to the backends above, we also support the keywords: +Available backend options can be found by searching `dispatch_keys` in +[codegen](https://github.com/pytorch/pytorch/blob/master/tools/codegen/gen.py). +Among the supported backends, there're a few alias keys that maps to a set of backends: - `DefaultBackend`: an alias that maps to all backends. Functions registered to `DefaultBackend` should work for any backend inference. - `Math`: an alias that maps to all backend and autograd backend keys. Functions registered to `Math` key should be plain mathematical composition of other `at::` functions and support training and inference for any backend. -If you add `dispatch` section to any API that didn't have it before, you **have to** move -the old implementation to `Math` field so that it's still available for other backends to use. +`DefaultBackend` and `Math` keys act as defaults: for example, you can specify a custom +kernel for a particular backend using a backend-specific dispatch key, and use +`DefaultBackend` or `Math` to specify a generic kernel for the others. 
+
+
+Note that like those registered to `Math`, kernels registered to `DefaultBackend` are
+very often implemented as mathematical expressions built up from calls to other `at::`
+functions. This is because in both cases, the kernel needs to delegate backend-specific
+computation to the functions it calls. The difference between `DefaultBackend` and `Math`
+is that a `Math` kernel also implicitly defines a derivative formula: to do this, it must
+call only functions that themselves support autograd.
+
+For example, suppose `my_op` can be implemented in the following way:
+```
+at::Tensor my_op(const Tensor& self, const Tensor& other) {
+  return self + 2 * other;
+}
+```
+
+If we already know inference kernels and derivative formulas for operators `+` and `*` in our system,
+you can just register `my_op` to `Math` and both inference & autograd will just work.
+Although it seems we only write down the inference formula here, PyTorch autograd system would correctly
+set up the backward for `my_op` using the chain formula and derivatives of `+` & `*` operators.
+In other words `d_out/d_self = 1; d_out/d_other = 2` can be derived automatically from
+the `my_op` inference kernel. Of course if we don't have derivative formula defined for either `+` or `*`,
+backward of `my_op` can no longer be derived automatically.
+
+Whether to use `Math` or `DefaultBackend` for your kernel can be decided by the following steps:
+1. If you can, always start with a `Math` kernel that's composable from existing operators.
+2. If you don't want to use the derived gradient formula from `Math` kernel for autograd, either to
+   get better performance or better numerical stability, you should put the kernel in `DefaultBackend`
+   so that it's only used in inference.
+   Later for autograd, depending on whether your autograd kernel works for all backends or not,
+   you can put them in alias `Autograd` or specific keys like `AutogradCPU`.
+3. If you prefer to write backend-specific kernels, use reserved dispatch keys for your backend instead,
+   e.g. `CPU/AutogradCPU`.
+
+**Important**: because a `Math` kernel is implicitly registered for ops with no `dispatch:` section,
+when you add a backend-specific kernel (and hence a `dispatch:` section) to one of these, you **must** also
+add a `Math:` entry that names the old kernel implementation (it's named after the op, with _
+added if applicable), so that it's still available for other backends to use.
 If you implemented a native function in C++ and want to find out which dispatch keyword
 should be used in native_functions.yaml, please [follow steps in dispatch keywords](#choosing-the-right-dispatch-keyword)
@@ -474,8 +512,17 @@ Here're steps to follow to decide the right dispatch keyword:
    Note: current plan on record for ops using this boilerplate is to replace `at::` with
    `at::native` in the implementations and add dispatch section with device keywords instead.
+3. Validate the computed dispatch table matches what you want. You can use `PythonDispatcher` provided in
+[torch/_python_dispatcher.py](https://github.com/pytorch/pytorch/blob/master/torch/_python_dispacher.py).
+It shows for a certain operator, what the computed dispatch table looks like after your registrations.
+
+  ```
+  dispatcher = PythonDispatcher()
+  dispatcher.register(["CPU", "XLA", "AutogradCPU", "Math"])
+  print(dispatcher.dispatchTable()) # Tells you exactly which kernel is used for certain backend.
+  ```
-3. TODO: AutogradCPUOrCUDA
+4. 
TODO: AutogradCPUOrCUDA Note that in native_functions.yaml you can mix using backend keywords and alias keywords above for one op: - direct registration to backend always has higher precendence than alias @@ -483,6 +530,8 @@ Note that in native_functions.yaml you can mix using backend keywords and alias e.g. adding both `Math` and `DefaultBackend` kernels for one op will completely ignore `Math` kernel for both inference and training. Thus this will trigger an error when native_functions.yaml is parsed. + + ### Will this function be exposed to python? What are the namespaces? We don't generate python bindings for all functions. There're certain patterns in function diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index fd27b3e7efe5..2ea68995207c 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -502,9 +502,12 @@ static Tensor& prod_out_impl(Tensor& result, const Tensor& self, IntArrayRef dim // see https://github.com/pytorch/pytorch/pull/47305, Tensor trace_cpu(const Tensor& self) { Tensor result; + // Returns the ScalarType of the self tensor if the tensor is non integral type + // In the case, self is an integer type tensor, at::kLong is return since promote_integers + // is set to true ScalarType dtype = get_dtype(result, self, c10::nullopt, true); result = at::empty({}, self.options().dtype(dtype)); - AT_DISPATCH_ALL_TYPES(self.scalar_type(), "trace", [&] { + AT_DISPATCH_ALL_TYPES_AND_COMPLEX(self.scalar_type(), "trace", [&] { using accscalar_t = at::acc_type; accscalar_t sum = 0; const auto* t_data = self.data_ptr(); @@ -521,12 +524,11 @@ Tensor trace_cpu(const Tensor& self) { sum += t_data[i * (t_stride_0 + t_stride_1)]; } - // all integer types get promoted to kLong - if (result.scalar_type() == at::kLong) { - *result.data_ptr() = sum; - } else { - *result.data_ptr() = sum; - } + c10::guts::if_constexpr::value>( + // all integer types get promoted to kLong + [&] (auto _) { *result.data_ptr() = _(sum); }, // then-case, invalid for non-integral types + [&] (auto _) { *result.data_ptr() = _(sum); } // else-case, invalid for integral types + ); }); return result; @@ -843,7 +845,7 @@ Tensor any(const Tensor& self) { "any only supports CPU AND CUDA device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided || self.layout() == Layout::Sparse, "any only supports strided AND sparse layout, got: ", self.layout()); - + // Refer [all, any : uint8 compatibility] Tensor result; ScalarType out_dtype; diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index f2c3f6309a2e..4f3cd6c9055f 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1794,7 +1794,7 @@ Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) { shape.push_back(self.size(i)); } - return self.reshape(shape); + return native::reshape(self, shape); } Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim, Dimname out_dim) { diff --git a/aten/src/ATen/native/UpSampleNearest2d.cpp b/aten/src/ATen/native/UpSampleNearest2d.cpp index c7e3d1d4860a..8cfed01e606f 100644 --- a/aten/src/ATen/native/UpSampleNearest2d.cpp +++ b/aten/src/ATen/native/UpSampleNearest2d.cpp @@ -3,51 +3,9 @@ #include namespace at { -namespace native { -namespace { - -static void upsample_nearest2d_out_cpu_template( - Tensor& output, - const Tensor& input, - IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w) { - TORCH_CHECK( - 
output_size.size() == 2, - "It is expected output_size equals to 2, but got size ", - output_size.size()); - - int64_t output_height = output_size[0]; - int64_t output_width = output_size[1]; - - int64_t nbatch = input.size(0); - int64_t channels = input.size(1); - int64_t input_height = input.size(2); - int64_t input_width = input.size(3); +namespace meta { - upsample_2d_shape_check( - input, - Tensor(), - nbatch, - channels, - input_height, - input_width, - output_height, - output_width); - - output.resize_({nbatch, channels, output_height, output_width}, input.suggest_memory_format()); - - AT_ASSERT(input_width > 0 && output_width > 0); - upsample_nearest2d_kernel(kCPU, output, input, scales_h, scales_w); -} - -static void upsample_nearest2d_backward_out_cpu_template( - Tensor& grad_input, - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales_h, - c10::optional scales_w) { +static std::array upsample_nearest2d_common_check(IntArrayRef input_size, IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 2, "It is expected output_size equals to 2, but got size ", @@ -66,83 +24,100 @@ static void upsample_nearest2d_backward_out_cpu_template( int64_t input_height = input_size[2]; int64_t input_width = input_size[3]; - upsample_2d_shape_check( - Tensor(), - grad_output, - nbatch, - channels, + TORCH_CHECK( + input_height > 0 && input_width > 0 && output_height > 0 && + output_width > 0, + "Input and output sizes should be greater than 0," + " but got input (H: ", input_height, + ", W: ", input_width, + ") output (H: ", output_height, - output_width); - - grad_input.resize_({nbatch, channels, input_height, input_width}); - grad_input.zero_(); + ", W: ", + output_width, + ")"); - upsample_nearest2d_backward_kernel(kCPU, grad_input, grad_output, scales_h, scales_w); + return {nbatch, channels, output_height, output_width}; } -} // namespace -Tensor& upsample_nearest2d_out_cpu( - Tensor& output, - const Tensor& input, - IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w) { - upsample_nearest2d_out_cpu_template(output, input, output_size, scales_h, scales_w); - return output; +TORCH_META_FUNC(upsample_nearest2d) ( + const Tensor& input, IntArrayRef output_size, c10::optional scales_h, c10::optional scales_w +) { + auto full_output_size = upsample_nearest2d_common_check(input.sizes(), output_size); + + // Allow for empty batch size but not other dimensions + TORCH_CHECK( + input.numel() != 0 || prod_intlist(input.sizes().begin() + 1, input.sizes().end()), + "Non-empty 4D data tensor expected but got a tensor with sizes ", + input.sizes()); + + set_output(full_output_size, input.options()); } -Tensor upsample_nearest2d_cpu( - const Tensor& input, - IntArrayRef output_size, - c10::optional scales_h, - c10::optional scales_w) { - auto output = at::empty({0}, input.options()); - upsample_nearest2d_out_cpu_template(output, input, output_size, scales_h, scales_w); - return output; +TORCH_META_FUNC(upsample_nearest2d_backward) ( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + c10::optional scales_h, + c10::optional scales_w +) { + auto full_output_size = upsample_nearest2d_common_check(input_size, output_size); + + TORCH_CHECK( + grad_output.dim() == 4, + "Expected grad_output to be a tensor of dimension 4 but got: dimension ", grad_output.dim()); + + for (int i = 0; i < 4; ++i) { + TORCH_CHECK( + grad_output.size(i) == full_output_size[i], + "Expected grad_output to have the same shape as 
output;", + " output.size(", i, ") = ", full_output_size[i], + " but got grad_output.size(", i, ") = ", grad_output.size(i)); + } + + set_output(input_size, grad_output.options()); } -Tensor& upsample_nearest2d_backward_out_cpu( - Tensor& grad_input, - const Tensor& grad_output, +} // namespace meta + +namespace native { + +TORCH_IMPL_FUNC(upsample_nearest2d_out_cpu) ( + const Tensor& input, IntArrayRef output_size, - IntArrayRef input_size, c10::optional scales_h, - c10::optional scales_w) { - upsample_nearest2d_backward_out_cpu_template( - grad_input, grad_output, output_size, input_size, scales_h, scales_w); - return grad_input; + c10::optional scales_w, + Tensor& output +) { + upsample_nearest2d_kernel(kCPU, output, input, scales_h, scales_w); } -Tensor upsample_nearest2d_backward_cpu( +TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_cpu) ( const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, c10::optional scales_h, - c10::optional scales_w) { - auto grad_input = at::zeros(input_size, grad_output.options()); - upsample_nearest2d_backward_out_cpu_template( - grad_input, grad_output, output_size, input_size, scales_h, scales_w); - return grad_input; + c10::optional scales_w, + Tensor& grad_input) { + grad_input.zero_(); + upsample_nearest2d_backward_kernel(kCPU, grad_input, grad_output, scales_h, scales_w); } using at::native::upsample::compute_output_size; using at::native::upsample::get_scale_value; -Tensor upsample_nearest2d_cpu( +Tensor upsample_nearest2d( const Tensor& input, c10::optional output_size, c10::optional> scale_factors) { - auto output = at::empty({0}, input.options()); auto osize = compute_output_size(input.sizes(), output_size, scale_factors); auto scale_h = get_scale_value(scale_factors, 0); auto scale_w = get_scale_value(scale_factors, 1); - upsample_nearest2d_out_cpu_template(output, input, osize, scale_h, scale_w); - return output; + return at::upsample_nearest2d(input, osize, scale_h, scale_w); } -Tensor upsample_nearest2d_backward_cpu( +Tensor upsample_nearest2d_backward( const Tensor& grad_output, c10::optional output_size, IntArrayRef input_size, @@ -150,10 +125,7 @@ Tensor upsample_nearest2d_backward_cpu( auto osize = compute_output_size(input_size, output_size, scale_factors); auto scale_h = get_scale_value(scale_factors, 0); auto scale_w = get_scale_value(scale_factors, 1); - auto grad_input = at::zeros(input_size, grad_output.options()); - upsample_nearest2d_backward_out_cpu_template( - grad_input, grad_output, osize, input_size, scale_h, scale_w); - return grad_input; + return at::upsample_nearest2d_backward(grad_output, osize, input_size, scale_h, scale_w); } DEFINE_DISPATCH(upsample_nearest2d_kernel); diff --git a/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp b/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp index ba7f1af7eabb..8e23504cd4fd 100644 --- a/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp +++ b/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index 8153b75aae8c..cf928b0d10ba 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -133,11 +133,13 @@ void magmaSymeig( value_t* w, scalar_t* wA, magma_int_t ldwa, scalar_t* work, magma_int_t lwork, value_t* rwork, magma_int_t lrwork, magma_int_t* iwork, magma_int_t liwork, magma_int_t* info); -template +template void magmaEig( 
magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, scalar_t *A, magma_int_t lda, - scalar_t *wr, scalar_t *wi, scalar_t *VL, magma_int_t ldvl, - scalar_t *VR, magma_int_t ldvr, scalar_t *work, magma_int_t lwork, magma_int_t *info); + scalar_t *w, scalar_t *VL, magma_int_t ldvl, + scalar_t *VR, magma_int_t ldvr, scalar_t *work, magma_int_t lwork, + value_t *rwork, + magma_int_t *info); template void magmaSvd( @@ -1055,27 +1057,87 @@ void magmaSymeig, float>( ldwa, reinterpret_cast(work), lwork, rwork, lrwork, iwork, liwork, info); AT_CUDA_CHECK(cudaGetLastError()); } - + template<> void magmaEig( - magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, double *A, magma_int_t lda, - double *wr, double *wi, double *VL, magma_int_t ldvl, - double *VR, magma_int_t ldvr, double *work, magma_int_t lwork, magma_int_t *info) { + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, + double *A, magma_int_t lda, + double *w, + double *VL, magma_int_t ldvl, + double *VR, magma_int_t ldvr, + double *work, magma_int_t lwork, + double *rwork, + magma_int_t *info) { MagmaStreamSyncGuard guard; + // magma [sd]geev wants to separate output arrays: wr and wi for the real + // and imaginary parts + double *wr = w; + double *wi = w + n; + (void)rwork; // unused magma_dgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info); AT_CUDA_CHECK(cudaGetLastError()); } template<> void magmaEig( - magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, float *A, magma_int_t lda, - float *wr, float *wi, float *VL, magma_int_t ldvl, - float *VR, magma_int_t ldvr, float *work, magma_int_t lwork, magma_int_t *info) { + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, + float *A, magma_int_t lda, + float *w, + float *VL, magma_int_t ldvl, + float *VR, magma_int_t ldvr, + float *work, magma_int_t lwork, + float *rwork, + magma_int_t *info) { MagmaStreamSyncGuard guard; + float *wr = w; + float *wi = w + n; + (void)rwork; // unused magma_sgeev(jobvl, jobvr, n, A, lda, wr, wi, VL, ldvl, VR, ldvr, work, lwork, info); AT_CUDA_CHECK(cudaGetLastError()); } +template<> +void magmaEig, double>( + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, + c10::complex *A, magma_int_t lda, + c10::complex *w, + c10::complex *VL, magma_int_t ldvl, + c10::complex *VR, magma_int_t ldvr, + c10::complex *work, magma_int_t lwork, + double *rwork, + magma_int_t *info) { + MagmaStreamSyncGuard guard; + magma_zgeev(jobvl, jobvr, n, + reinterpret_cast(A), lda, + reinterpret_cast(w), + reinterpret_cast(VL), ldvl, + reinterpret_cast(VR), ldvr, + reinterpret_cast(work), lwork, + rwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + +template<> +void magmaEig, float>( + magma_vec_t jobvl, magma_vec_t jobvr, magma_int_t n, + c10::complex *A, magma_int_t lda, + c10::complex *w, + c10::complex *VL, magma_int_t ldvl, + c10::complex *VR, magma_int_t ldvr, + c10::complex *work, magma_int_t lwork, + float *rwork, + magma_int_t *info) { + MagmaStreamSyncGuard guard; + magma_cgeev(jobvl, jobvr, n, + reinterpret_cast(A), lda, + reinterpret_cast(w), + reinterpret_cast(VL), ldvl, + reinterpret_cast(VR), ldvr, + reinterpret_cast(work), lwork, + rwork, info); + AT_CUDA_CHECK(cudaGetLastError()); +} + template<> void magmaSvd( magma_vec_t jobz, magma_int_t m, magma_int_t n, double* A, @@ -1625,26 +1687,34 @@ AT_ERROR("cholesky: MAGMA library not found in " Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) { std::vector infos(batchCount(self), 0); - Tensor self_working_copy; - if (upper) { - self_working_copy = 
cloneBatchedColumnMajor(self.transpose(-1, -2)); + + Tensor result; + if (self.dim() > 2) { + // MAGMA's batched cholesky operator has an off-by-one error causing IMA + // (see https://github.com/pytorch/pytorch/issues/42666). This code is based + // on the #cloneBatchedColumnMajor function however it pads the input with + // one extra element utilizing the fact that the resize_as_ method preserves + // the storage even if it's larger than the new sizes. This way if MAGMA + // reads off bounds it will still be valid user memory. + const Tensor input = upper ? self : self.transpose(-1, -2); + result = at::empty(input.numel() + 1, input.options()); + result.resize_as_(input).copy_(input).transpose_(-1, -2); } else { - self_working_copy = cloneBatchedColumnMajor(self); + result = cloneBatchedColumnMajor(upper ? self.transpose(-1, -2) : self); } - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "cholesky_cuda", [&]{ - apply_cholesky(self_working_copy, false, infos); - }); + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( + self.scalar_type(), "cholesky_cuda", [&] { + apply_cholesky(result, false, infos); + }); + if (self.dim() > 2) { batchCheckErrors(infos, "cholesky_cuda"); } else { singleCheckErrors(infos[0], "cholesky_cuda"); } - if (upper) { - return self_working_copy.transpose(-1, -2); - } else { - return self_working_copy; - } + + return upper ? result.transpose_(-1, -2) : result; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -2054,13 +2124,13 @@ TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorc "Either transfer the tensor to the CPU before calling torch.eig or recompile with MAGMA."); #else TORCH_INTERNAL_ASSERT(self.device() == at::kCPU, "Internal error: apply_eig needs a CPU tensor"); + using value_t = typename c10::scalar_value_type::type; magma_vec_t jobvr = eigenvectors ? MagmaVec : MagmaNoVec; magma_int_t n = magma_int_cast(self.size(-1), "n"); auto self_data = self.data_ptr(); auto out_eigvals_data = out_eigvals.data_ptr(); scalar_t *wr = out_eigvals_data; - scalar_t *wi = out_eigvals_data+n; scalar_t *vr_data = NULL; magma_int_t ldvr = 1; @@ -2070,17 +2140,22 @@ TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorc ldvr = n; } + value_t *rwork_data = nullptr; + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + ALLOCATE_ARRAY(rwork_data, value_t, n*2); + } + if (n > 0) { // call magmaEig once to get the optimal size of work_data scalar_t wkopt; magma_int_t info; - magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, &wkopt, -1, &info); - magma_int_t lwork = (magma_int_t) wkopt; + magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, &wkopt, -1, rwork_data, &info); + magma_int_t lwork = static_cast(real_impl(wkopt)); // call it a 2nd time to to the actual work scalar_t *work_data = nullptr; ALLOCATE_ARRAY(work_data, scalar_t, lwork); - magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, wi, NULL, 1, vr_data, ldvr, work_data, lwork, &info); + magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, work_data, lwork, rwork_data, &info); *info_ptr = info; } #endif @@ -2103,13 +2178,18 @@ std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvector // tensors holding the results. 
We use empty_strided to make them column-ordered auto options = self.options().device(at::kCPU).memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); - auto out_eigvals = at::empty_strided({n, 2}, {1, n}, options); + Tensor out_eigvals; + if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { + out_eigvals = at::empty({n}, options); + } else { + out_eigvals = at::empty_strided({n, 2}, {1, n}, options); + } auto out_eigvecs = eigenvectors ? at::empty_strided({n, n}, {1, n}, options) : Tensor(); int64_t info; - AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "eig_cuda", [&]{ + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cuda", [&]{ apply_eig(self_working_copy, eigenvectors, out_eigvals, out_eigvecs, &info); }); singleCheckErrors(info, "eig_cuda"); @@ -2200,7 +2280,7 @@ AT_ERROR("svd: MAGMA library not found in " #endif } -std::tuple _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) { +std::tuple _svd_helper_cuda_legacy(const Tensor& self, bool some, bool compute_uv) { std::vector infos(batchCount(self), 0); int64_t m = self.size(-2), n = self.size(-1); int64_t k = std::min(m, n); @@ -2252,10 +2332,20 @@ std::tuple _svd_helper_cuda(const Tensor& self, bool som VT_working_copy = same_stride_to(VT_working_copy, self.options()).zero_(); } // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. + // Note that the 'apply_svd' routine returns VT = V^T (for real inputs) or VT = V^H (for complex inputs), not V. + VT_working_copy = VT_working_copy.conj(); VT_working_copy.transpose_(-2, -1); return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); } +std::tuple _svd_helper_cuda(const Tensor& self, bool some, bool compute_uv) { +#ifdef USE_CUSOLVER + return _svd_helper_cuda_lib(self, some, compute_uv); +#else + return _svd_helper_cuda_legacy(self, some, compute_uv); +#endif +} + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ template diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu index 534f257d55bb..3431091661dd 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu @@ -50,18 +50,10 @@ static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& in auto& allocator = *::c10::cuda::CUDACachingAllocator::get(); - if (use_loop_launch(batch_size, n)) { - auto main_stream = at::cuda::getCurrentCUDAStream(); - - at::cuda::CUDAEvent main_event; - main_event.record(main_stream); - + // Heuristic: For small batch size or large matrix size, we use for-loop to iterate over the batches instead of + // calling the batched cublas routine. 
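+ // (The thresholds below restate the removed use_loop_launch() helper from MiscUtils.h: batched cublas
+ //  helps little when there are only a few matrices, and for large matrices the per-matrix cusolver
+ //  path is dominated by compute rather than launch overhead. Illustrative cases of the condition,
+ //  assuming nothing beyond the two thresholds: batch_size = 4 -> loop; batch_size = 64, n = 1024 -> loop;
+ //  batch_size = 64, n = 128 -> batched cublas.)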
+ if (batch_size <= 8 || /* batch_size > 8 && */ n >= 512) { for (int64_t i = 0; i < batch_size; i++) { - auto stream = at::cuda::getStreamFromPool(); - at::cuda::CUDAStreamGuard guard(stream); - - main_event.block(stream); - auto dataPtr = allocator.allocate(sizeof(int) * lda); int* pivot = reinterpret_cast(dataPtr.get()); @@ -70,21 +62,17 @@ static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& in _apply_single_inverse_helper( &self_data[i * self_mat_stride], &self_inv_data[i * self_inv_mat_stride], pivot, infos_getrf_working_ptr, infos_getrs_working_ptr, n, lda); - - at::cuda::CUDAEvent finished; - finished.record(stream); - finished.block(main_stream); } } else { // cublas batched kernels require input be "device array of device pointers" Tensor self_array = at::arange( - reinterpret_cast(self_data), - reinterpret_cast(&self_data[(batch_size-1) * self_mat_stride]) + 1, - static_cast(self_mat_stride * sizeof(scalar_t)), self.options().dtype(at::kLong)); + reinterpret_cast(self_data), + reinterpret_cast(&self_data[(batch_size-1) * self_mat_stride]) + 1, + static_cast(self_mat_stride * sizeof(scalar_t)), self.options().dtype(at::kLong)); Tensor self_inv_array = at::arange( - reinterpret_cast(self_inv_data), - reinterpret_cast(&self_inv_data[(batch_size-1) * self_inv_mat_stride]) + 1, - static_cast(self_inv_mat_stride * sizeof(scalar_t)), self.options().dtype(at::kLong)); + reinterpret_cast(self_inv_data), + reinterpret_cast(&self_inv_data[(batch_size-1) * self_inv_mat_stride]) + 1, + static_cast(self_inv_mat_stride * sizeof(scalar_t)), self.options().dtype(at::kLong)); auto dataPtr = allocator.allocate(sizeof(int)*batch_size*lda); int* ipiv_array = reinterpret_cast(dataPtr.get()); @@ -134,6 +122,7 @@ Tensor& _linalg_inv_out_helper_cuda_lib(Tensor& result, Tensor& infos_getrf, Ten return result; } +// entrance of calculations of `inverse` using cusolver getrf + getrs, cublas getrfBatched + getriBatched Tensor _inverse_helper_cuda_lib(const Tensor& self) { Tensor self_working_copy = cloneBatchedColumnMajor(self); Tensor self_inv_working_copy = column_major_identity_matrix_like(self_working_copy); @@ -161,6 +150,151 @@ Tensor _inverse_helper_cuda_lib(const Tensor& self) { return self_inv_working_copy; } +// call cusolver gesvdj function to calculate svd +template +inline static void _apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S, Tensor& VT, Tensor& infos, bool compute_uv, bool some) { + using value_t = typename c10::scalar_value_type::type; + auto self_data = self.data_ptr(); + auto U_data = U.data_ptr(); + auto S_data = S.data_ptr(); + auto VT_data = VT.data_ptr(); + auto self_stride = matrixStride(self); + auto U_stride = matrixStride(U); + auto S_stride = S.size(-1); + auto VT_stride = matrixStride(VT); + + int batchsize = cuda_int_cast(batchCount(self), "batch size"); + int m = cuda_int_cast(self.size(-2), "m"); + int n = cuda_int_cast(self.size(-1), "n"); + + for(int i = 0; i < batchsize; i++){ + // gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU + gesvdjInfo_t gesvdj_params; + TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params)); + // TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, 1.0e-7)); + // TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, 15)); + + auto handle = at::cuda::getCurrentCUDASolverDnHandle(); + auto jobz = compute_uv ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + at::cuda::solver::gesvdj( + handle, jobz, /*econ=*/ some ? 
1 : 0, m, n, + self_data + i * self_stride, + m, + S_data + i * S_stride, + U_data + i * U_stride, + m, + VT_data + i * VT_stride, + n, + infos.data_ptr() + i, + gesvdj_params + ); + + TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params)); + } +} + +// wrapper around _apply_svd_lib_gesvdj that handles dtype dispatch, +// creates a working copy of the input, and creates V^H from the V returned by gesvdj +inline static void apply_svd_lib_gesvdj(const Tensor& self, Tensor& U, Tensor& S, Tensor& VT, Tensor& infos, bool compute_uv, bool some) { + const int64_t m = self.size(-2); + const int64_t n = self.size(-1); + Tensor self_working_copy = cloneBatchedColumnMajor(self); + VT = VT.transpose(-2, -1); // gesvdj returns V instead of V^H + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdj", [&] { + _apply_svd_lib_gesvdj(self_working_copy, U, S, VT, infos, compute_uv, some); + }); +} + +// call cusolver gesvdj batched function to calculate svd +template +inline static void _apply_svd_lib_gesvdjBatched(const Tensor& self, Tensor& U, Tensor& S, Tensor& VT, Tensor& infos, bool compute_uv) { + using value_t = typename c10::scalar_value_type::type; + auto self_data = self.data_ptr(); + auto U_data = U.data_ptr(); + auto S_data = S.data_ptr(); + auto VT_data = VT.data_ptr(); + auto self_stride = matrixStride(self); + auto U_stride = matrixStride(U); + auto S_stride = S.size(-1); + auto VT_stride = matrixStride(VT); + + int batchsize = cuda_int_cast(batchCount(self), "batch size"); + int m = cuda_int_cast(self.size(-2), "m"); + int n = cuda_int_cast(self.size(-1), "n"); + + TORCH_INTERNAL_ASSERT(m <= 32 && n <= 32, "gesvdjBatched requires both matrix dimensions not greater than 32, but got " + "m = ", m, " n = ", n); + + // gesvdj_params controls the numerical accuracy of cusolver gesvdj iterations on GPU + gesvdjInfo_t gesvdj_params; + TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&gesvdj_params)); + // TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(gesvdj_params, 1.0e-7)); + // TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(gesvdj_params, 15)); + TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetSortEig(gesvdj_params, 1)); + + auto handle = at::cuda::getCurrentCUDASolverDnHandle(); + auto jobz = compute_uv ? 
CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR; + at::cuda::solver::gesvdjBatched( + handle, jobz, m, n, self_data, m, S_data, U_data, m, VT_data, n, + infos.data_ptr(), gesvdj_params, batchsize + ); + + TORCH_CUSOLVER_CHECK(cusolverDnDestroyGesvdjInfo(gesvdj_params)); +} + +// wrapper around _apply_svd_lib_gesvdjBatched that handles dtype dispatch, +// creates a working copy of the input, and creates V^H from the V returned by gesvdj +inline static void apply_svd_lib_gesvdjBatched(const Tensor& self, Tensor& U, Tensor& S, Tensor& VT, Tensor& infos, bool compute_uv) { + const int64_t m = self.size(-2); + const int64_t n = self.size(-1); + Tensor self_working_copy = cloneBatchedColumnMajor(self); + VT = VT.transpose(-2, -1); // gesvdj returns V instead of V^H + + AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "svd_cuda_gesvdjBatched", [&] { + _apply_svd_lib_gesvdjBatched(self_working_copy, U, S, VT, infos, compute_uv); + }); +} + +// entrance of calculations of `svd` using cusolver gesvdj and gesvdjBatched +std::tuple _svd_helper_cuda_lib(const Tensor& self, bool some, bool compute_uv) { + const int64_t batch_size = batchCount(self); + at::Tensor infos = at::zeros({batch_size}, self.options().dtype(at::kInt)); + const int64_t m = self.size(-2); + const int64_t n = self.size(-1); + const int64_t k = std::min(m, n); + + Tensor U_working_copy, S_working_copy, VT_working_copy; + std::tie(U_working_copy, S_working_copy, VT_working_copy) = \ + _create_U_S_VT(self, some, compute_uv, /* svd_use_cusolver = */ true); + // U, S, V working copies are already column majored now + + if (self.numel() > 0) { + // heuristic for using `gesvdjBatched` over `gesvdj` + if (m <= 32 && n <= 32 && batch_size > 1 && (!some || m == n)) { + apply_svd_lib_gesvdjBatched(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv); + } else { + apply_svd_lib_gesvdj(self, U_working_copy, S_working_copy, VT_working_copy, infos, compute_uv, some); + } + + // A device-host sync will be performed. + batchCheckErrors(infos, "svd_cuda"); + + if (compute_uv) { + if (some) { + VT_working_copy = VT_working_copy.narrow(-2, 0, k); + } + } else { + VT_working_copy.zero_(); + U_working_copy.zero_(); + } + } + + // so far we have computed VT, but torch.svd returns V instead. Adjust accordingly. 
+ VT_working_copy.transpose_(-2, -1); + return std::make_tuple(U_working_copy, S_working_copy, VT_working_copy); +} + }} // namespace at::native #endif // USE_CUSOLVER diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h index 2be18137a64f..65fc9af1f654 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebraLib.h @@ -17,9 +17,13 @@ namespace at { namespace native { +// entrance of calculations of `inverse` using cusolver getrf + getrs, cublas getrfBatched + getriBatched Tensor _inverse_helper_cuda_lib(const Tensor& self); Tensor& _linalg_inv_out_helper_cuda_lib(Tensor& result, Tensor& infos_getrf, Tensor& infos_getrs); +// entrance of calculations of `svd` using cusolver gesvdj and gesvdjBatched +std::tuple _svd_helper_cuda_lib(const Tensor& self, bool some, bool compute_uv); + }} // namespace at::native #endif // USE_CUSOLVER diff --git a/aten/src/ATen/native/cuda/MiscUtils.h b/aten/src/ATen/native/cuda/MiscUtils.h index 8f78e8d78003..8baa0703d5eb 100644 --- a/aten/src/ATen/native/cuda/MiscUtils.h +++ b/aten/src/ATen/native/cuda/MiscUtils.h @@ -110,15 +110,5 @@ static inline Storage pin_memory(int64_t size) { /*resizable=*/false); } -// heuristic: -// cublas_x_batched doesn't work very well for small batchsize -// cublas_x_batched is intended to be used for matrices of small sizes where the launch overhead is a significant factor. -// with use_loop_launch = True, we will loop through all batches, and launch single matrix cusolver/cublas kernels -// (This heuristic was originally tested in getrf + getrs(getri), which may not work well on other kernels. ) -inline static bool use_loop_launch(int batch_size, int matrix_size) { - return (batch_size <= 8) || \ - (/* batch_size > 8 && */ matrix_size >= 512); -} - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu index 0ac02e292b28..1b7db0272a24 100644 --- a/aten/src/ATen/native/cuda/UpSampleNearest2d.cu +++ b/aten/src/ATen/native/cuda/UpSampleNearest2d.cu @@ -292,44 +292,24 @@ static void upsample_nearest2d_backward_out_cuda_template( } // namespace -Tensor& upsample_nearest2d_out_cuda( - Tensor& output, +TORCH_IMPL_FUNC(upsample_nearest2d_out_cuda) ( const Tensor& input, IntArrayRef output_size, c10::optional scales_h, - c10::optional scales_w) { + c10::optional scales_w, + Tensor& output) { upsample_nearest2d_out_cuda_template(output, input, output_size, scales_h, scales_w); - return output; } -Tensor upsample_nearest2d_cuda(const Tensor& input, IntArrayRef output_size, c10::optional scales_h, c10::optional scales_w) { - Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - upsample_nearest2d_out_cuda_template(output, input, output_size, scales_h, scales_w); - return output; -} - -Tensor& upsample_nearest2d_backward_out_cuda( - Tensor& grad_input, +TORCH_IMPL_FUNC(upsample_nearest2d_backward_out_cuda) ( const Tensor& grad_output, IntArrayRef output_size, IntArrayRef input_size, c10::optional scales_h, - c10::optional scales_w) { + c10::optional scales_w, + Tensor& grad_input) { upsample_nearest2d_backward_out_cuda_template( grad_input, grad_output, output_size, input_size, scales_h, scales_w); - return grad_input; -} - -Tensor upsample_nearest2d_backward_cuda( - const Tensor& grad_output, - IntArrayRef output_size, - IntArrayRef input_size, - c10::optional scales_h, - c10::optional scales_w) { - Tensor 
grad_input = at::empty_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT); - upsample_nearest2d_backward_out_cuda_template( - grad_input, grad_output, output_size, input_size, scales_h, scales_w); - return grad_input; } using at::native::upsample::compute_output_size; diff --git a/aten/src/ATen/native/cudnn/ConvShared.cpp b/aten/src/ATen/native/cudnn/ConvShared.cpp index e360008e2707..5e3d5e05f3e9 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.cpp +++ b/aten/src/ATen/native/cudnn/ConvShared.cpp @@ -54,7 +54,7 @@ namespace at { namespace native { // --------------------------------------------------------------------- // -// ConvolutionParams and ConvolutionArgs +// ConvolutionParams // // --------------------------------------------------------------------- @@ -86,10 +86,11 @@ void setConvolutionParams( memset(params, 0, sizeof(ConvolutionParams)); params->dataType = dataType; // ASSERT(weight.dim() == input.dim()) - for (int i = 0; i != input.dim(); ++i) { - params->input_size[i] = (int) input.size(i); - params->input_stride[i] = (int) input.stride(i); - params->weight_size[i] = (int) weight.size(i); + params->input_dim = input.dim(); + params->memory_format = input.suggest_memory_format(); + for (int i = 0; i != params->input_dim; ++i) { + params->input_size[i] = (int) input.sizes()[i]; + params->weight_size[i] = (int) weight.sizes()[i]; } // ASSERT(padding.size() == stride.size()) // ASSERT(padding.size() == dilation.size()) @@ -105,21 +106,21 @@ void setConvolutionParams( params->allow_tf32 = allow_tf32; } -std::string repro_from_args(const ConvolutionArgs& args) { +std::string repro_from_args(const ConvolutionParams& params) { auto pybool = [](bool b) -> const char* { return b ? "True" : "False"; }; std::string partial_dtype; - switch (args.params.dataType) { + switch (params.dataType) { case CUDNN_DATA_FLOAT: partial_dtype = "float"; break; case CUDNN_DATA_DOUBLE: partial_dtype = "double"; break; case CUDNN_DATA_HALF: partial_dtype = "half"; break; default: partial_dtype = "unsupported"; } const std::string full_dtype = "torch." + partial_dtype; - const int out_channels = args.weight.sizes()[0]; - const int in_channels = args.weight.sizes()[1] * args.params.groups; - const size_t dim = args.input.sizes().size(); + const int out_channels = params.weight_size[0]; + const int in_channels = params.weight_size[1] * params.groups; + const size_t dim = params.input_dim; const std::string channels_last_xd = dim == 4 ? "channels_last" : "channels_last_3d"; - const std::string to_channels_last = args.input.suggest_memory_format() == at::MemoryFormat::ChannelsLast \ + const std::string to_channels_last = params.memory_format == at::MemoryFormat::ChannelsLast \ ? ".to(memory_format=torch." 
+ channels_last_xd + ")" : ""; std::ostringstream ss; @@ -128,36 +129,22 @@ std::string repro_from_args(const ConvolutionArgs& args) { ss << "import torch\n"; ss << "torch.backends.cuda.matmul.allow_tf32 = " << pybool(at::globalContext().allowTF32CuBLAS()) << "\n"; ss << "torch.backends.cudnn.benchmark = " << pybool(at::globalContext().benchmarkCuDNN()) << "\n"; - ss << "torch.backends.cudnn.deterministic = " << pybool(args.params.deterministic) << "\n"; - ss << "torch.backends.cudnn.allow_tf32 = " << pybool(args.params.allow_tf32) << "\n"; - ss << "data = torch.randn(" << args.input.sizes() << ", dtype=" << full_dtype << ", "; + ss << "torch.backends.cudnn.deterministic = " << pybool(params.deterministic) << "\n"; + ss << "torch.backends.cudnn.allow_tf32 = " << pybool(params.allow_tf32) << "\n"; + ss << "data = torch.randn(" << ArrayRef(params.input_size, dim) << ", dtype=" << full_dtype << ", "; ss << "device='cuda', requires_grad=True)" << to_channels_last << "\n"; ss << "net = torch.nn.Conv" << dim-2 << "d(" << in_channels << ", " << out_channels << ", "; - ss << "kernel_size=" << args.weight.sizes().slice(2) << ", "; - ss << "padding=" << ArrayRef(args.params.padding, dim-2) << ", "; - ss << "stride=" << ArrayRef(args.params.stride, dim-2) << ", "; - ss << "dilation=" << ArrayRef(args.params.dilation, dim-2) << ", "; - ss << "groups=" << args.params.groups << ")\n"; + ss << "kernel_size=" << ArrayRef(¶ms.weight_size[2], dim - 2) << ", "; + ss << "padding=" << ArrayRef(params.padding, dim-2) << ", "; + ss << "stride=" << ArrayRef(params.stride, dim-2) << ", "; + ss << "dilation=" << ArrayRef(params.dilation, dim-2) << ", "; + ss << "groups=" << params.groups << ")\n"; ss << "net = net.cuda()." << partial_dtype << "()" << to_channels_last << "\n"; ss << "out = net(data)\n"; ss << "out.backward(torch.randn_like(out))\n"; ss << "torch.cuda.synchronize()\n\n"; - - return ss.str(); -} - -std::ostream& operator<<(std::ostream & out, const ConvolutionArgs& args) { - out << repro_from_args(args) // already has a trailing newline - << args.params // already has a trailing newline - << "input: " << args.idesc // already has a trailing newline - << "output: " << args.odesc // already has a trailing newline - << "weight: " << args.wdesc // already has a trailing newline - << "Pointer addresses: " << "\n" - << " input: " << args.input.data_ptr() << "\n" - << " output: " << args.output.data_ptr() << "\n" - << " weight: " << args.weight.data_ptr() << "\n"; - return out; + return ss.str(); } // --------------------------------------------------------------------- @@ -238,13 +225,16 @@ Tensor cudnn_convolution_forward( checkAllSameType(c, {input, weight}); checkAllSameGPU(c, {input, weight}); - auto layout = cudnn_conv_use_channels_last(*input, *weight) ? + auto memory_format = cudnn_conv_use_channels_last(*input, *weight) ? 
at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; - auto output_t = at::empty( + auto output_t = at::native::empty_cuda( conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation), - input->options(), - layout); + /*dtype=*/input->scalar_type(), + /*layout=*/c10::nullopt, + /*device=*/kCUDA, + /*pin_memory=*/c10::nullopt, + /*memory_format=*/memory_format); if (output_t.numel() == 0) { return output_t; @@ -255,11 +245,11 @@ Tensor cudnn_convolution_forward( convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups); // See #4500 - Tensor weight_contig = weight->contiguous(layout); + Tensor weight_contig = weight->contiguous(memory_format); // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), layout); - Tensor input_contig = input->contiguous(layout); - input_contig.resize_(input_contig.sizes(), layout); + weight_contig.resize_(weight_contig.sizes(), memory_format); + Tensor input_contig = input->contiguous(memory_format); + input_contig.resize_(input_contig.sizes(), memory_format); raw_cudnn_convolution_forward_out( *output, input_contig, weight_contig, @@ -340,21 +330,27 @@ Tensor cudnn_convolution_backward_input( checkAllSameType(c, {grad_output, weight}); checkAllSameGPU(c, {grad_output, weight}); - auto layout = cudnn_conv_use_channels_last(*grad_output, *weight) ? + auto memory_format = cudnn_conv_use_channels_last(*grad_output, *weight) ? at::MemoryFormat::ChannelsLast : at::MemoryFormat::Contiguous; - auto grad_input_t = at::empty(input_size, grad_output->options(), layout); + auto grad_input_t = at::native::empty_cuda( + input_size, + /*dtype=*/grad_output->scalar_type(), + /*layout=*/c10::nullopt, + /*device=*/kCUDA, + /*pin_memory=*/c10::nullopt, + /*memory_format=*/memory_format); // Avoid "grad_input" when this is being used as transposed convolution TensorArg grad_input{ grad_input_t, "result", 0 }; convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups); // See #4500 - Tensor weight_contig = weight->contiguous(layout); + Tensor weight_contig = weight->contiguous(memory_format); // Make sure that NC11 strides follow formula - weight_contig.resize_(weight_contig.sizes(), layout); + weight_contig.resize_(weight_contig.sizes(), memory_format); - Tensor grad_output_contig = grad_output->contiguous(layout); - grad_output_contig.resize_(grad_output_contig.sizes(), layout); + Tensor grad_output_contig = grad_output->contiguous(memory_format); + grad_output_contig.resize_(grad_output_contig.sizes(), memory_format); raw_cudnn_convolution_backward_input_out( *grad_input, grad_output_contig, weight_contig, diff --git a/aten/src/ATen/native/cudnn/ConvShared.h b/aten/src/ATen/native/cudnn/ConvShared.h index e30b5c7be581..b74a0477b1f5 100644 --- a/aten/src/ATen/native/cudnn/ConvShared.h +++ b/aten/src/ATen/native/cudnn/ConvShared.h @@ -19,7 +19,8 @@ struct ConvolutionParams { cudnnDataType_t dataType; int input_size[2 + max_dim]; - int input_stride[2 + max_dim]; + uint8_t input_dim; + at::MemoryFormat memory_format; int weight_size[2 + max_dim]; int padding[max_dim]; int stride[max_dim]; @@ -31,20 +32,6 @@ struct ConvolutionParams // forward and backward, so you can reuse the benchmark entry, }; -// Convenience struct for passing around descriptors and data -// pointers -struct ConvolutionArgs { - cudnnHandle_t handle; - ConvolutionParams params; - TensorDescriptor idesc, odesc; - FilterDescriptor wdesc; - const Tensor& input, output, weight; - 
ConvolutionDescriptor cdesc; - - ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight) : input(input), output(output), weight(weight) { - } -}; - std::ostream& operator<<(std::ostream & out, const ConvolutionParams& params); // NB: This can't be a constructor, because then ConvolutionParams @@ -58,9 +45,8 @@ void setConvolutionParams( IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool deterministic, bool allow_tf32); -std::string repro_from_args(const ConvolutionArgs& args); +std::string repro_from_args(const ConvolutionParams& args); -std::ostream& operator<<(std::ostream & out, const ConvolutionArgs& args); // --------------------------------------------------------------------- // diff --git a/aten/src/ATen/native/cudnn/Conv_v7.cpp b/aten/src/ATen/native/cudnn/Conv_v7.cpp index 5e1f124f1185..117f86504cca 100644 --- a/aten/src/ATen/native/cudnn/Conv_v7.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v7.cpp @@ -61,8 +61,33 @@ constexpr size_t operator "" _TiB(unsigned long long n) { namespace at { namespace native { -// TODO: Go through all the checking code again and make sure -// we haven't missed anything. +// Convenience struct for passing around descriptors and data +// pointers +struct ConvolutionArgs { + cudnnHandle_t handle; + ConvolutionParams params; + TensorDescriptor idesc, odesc; + FilterDescriptor wdesc; + const Tensor& input, output, weight; + ConvolutionDescriptor cdesc; + + ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight) : input(input), output(output), weight(weight) { + } +}; + +std::ostream& operator<<(std::ostream & out, const ConvolutionArgs& args) { + out << repro_from_args(args.params) // already has a trailing newline + << args.params // already has a trailing newline + << "input: " << args.idesc // already has a trailing newline + << "output: " << args.odesc // already has a trailing newline + << "weight: " << args.wdesc // already has a trailing newline + << "Pointer addresses: " << "\n" + << " input: " << args.input.data_ptr() << "\n" + << " output: " << args.output.data_ptr() << "\n" + << " weight: " << args.weight.data_ptr() << "\n"; + + return out; +} // --------------------------------------------------------------------- // diff --git a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm index d5ce5fe2ec68..d18951ccc303 100644 --- a/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm +++ b/aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm @@ -420,7 +420,7 @@ bool test_upsampling_nearest2d_vec() { __block std::vector size{1, 48, 24, 24}; return TEST(size, __PRETTY_FUNCTION__, ^bool { auto X1 = torch::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)); - auto Y1 = torch::native::upsample_nearest2d_cpu( + auto Y1 = at::native::upsample_nearest2d( X1, c10::optional({}), c10::optional>({2, 2})); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a2b615c1d94e..906dc08eddd5 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -332,7 +332,6 @@ structured_delegate: add.out variants: function, method dispatch: - CPU, CUDA: add SparseCPU, SparseCUDA: add_sparse MkldnnCPU: mkldnn_add @@ -340,7 +339,6 @@ variants: method structured_delegate: add.out dispatch: - CPU, CUDA: add_ SparseCPU, SparseCUDA: add_sparse_ MkldnnCPU: mkldnn_add_ @@ -5927,14 +5925,13 @@ CUDA: legacy::cuda::_th_geqrf - func: 
orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!) - use_c10_dispatcher: hacky_wrapper_for_legacy_signatures dispatch: - CPU: legacy::cpu::_th_orgqr_out + CPU: orgqr_out - func: orgqr(Tensor self, Tensor input2) -> Tensor variants: method, function dispatch: - CPU: legacy::cpu::_th_orgqr + CPU: orgqr - func: ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!) use_c10_dispatcher: hacky_wrapper_for_legacy_signatures @@ -8161,15 +8158,12 @@ - func: upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> Tensor python_module: nn dispatch: - CPU: upsample_nearest2d_cpu - CUDA: upsample_nearest2d_cuda - QuantizedCPU: upsample_nearest2d_quantized_cpu + DefaultBackend: upsample_nearest2d - func: upsample_nearest2d_backward.vec(Tensor grad_output, int[]? output_size, int[] input_size, float[]? scale_factors) -> Tensor python_module: nn dispatch: - CPU: upsample_nearest2d_backward_cpu - CUDA: upsample_nearest2d_backward_cuda + DefaultBackend: upsample_nearest2d_backward - func: upsample_nearest3d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> Tensor python_module: nn @@ -8313,31 +8307,28 @@ structured_delegate: upsample_nearest1d_backward.grad_input - func: upsample_nearest2d.out(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) - use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn + structured: True dispatch: CPU: upsample_nearest2d_out_cpu CUDA: upsample_nearest2d_out_cuda - func: upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn + structured_delegate: upsample_nearest2d.out dispatch: - CPU: upsample_nearest2d_cpu - CUDA: upsample_nearest2d_cuda QuantizedCPU: upsample_nearest2d_quantized_cpu - func: upsample_nearest2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) - use_c10_dispatcher: hacky_wrapper_for_legacy_signatures python_module: nn + structured: True dispatch: CPU: upsample_nearest2d_backward_out_cpu CUDA: upsample_nearest2d_backward_out_cuda - func: upsample_nearest2d_backward(Tensor grad_output, int[2] output_size, int[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor python_module: nn - dispatch: - CPU: upsample_nearest2d_backward_cpu - CUDA: upsample_nearest2d_backward_cuda + structured_delegate: upsample_nearest2d_backward.grad_input - func: upsample_nearest3d.out(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) use_c10_dispatcher: hacky_wrapper_for_legacy_signatures diff --git a/aten/src/ATen/native/quantized/affine_quantizer.cpp b/aten/src/ATen/native/quantized/affine_quantizer.cpp index ecbe1de4bbfa..a9cd3aab993a 100644 --- a/aten/src/ATen/native/quantized/affine_quantizer.cpp +++ b/aten/src/ATen/native/quantized/affine_quantizer.cpp @@ -290,219 +290,5 @@ Tensor dequantize_tensor_per_channel_float_qparams( return rtensor; } -#ifdef USE_FBGEMM -// Note: quantize_val is only explicitly used in test outside of this file -template -T quantize_val(double scale, int64_t zero_point, float value) { - // Internally, fbgemm::Quantize uses std::nearbyint. 
- // std::nearbyint results in nearest integer value according to the current - // rounding mode and the default rounding mode is rounds to even in half-way - // cases in most popular processor architectures like x86 and ARM. This is - // typically faster than an alternatives like std::round that rounds half-way - // cases away from zero, and can be consistent with SIMD implementations for - // example in x86 using _mm512_cvtps_epi32 or mm512_round_ps with - // _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode. - int32_t qvalue; - qvalue = fbgemm::Quantize( - value, - static_cast(zero_point), - static_cast(scale), - /*result_precision=*/CHAR_BIT * sizeof(typename T::underlying)); - return static_cast(qvalue); -} - -template -void quantize_vec( - double scale, - int64_t zero_point, - const float* src, - T* dst, - size_t count) { - fbgemm::Quantize( - src, - (typename T::underlying*)dst, - count, - fbgemm::TensorQuantizationParams{ - (float)scale, (int32_t)zero_point, precision}); -} - -template -inline float dequantize_val(double scale, int64_t zero_point, T value) { - fbgemm::TensorQuantizationParams qparams; - qparams.scale = static_cast(scale); - qparams.zero_point = static_cast(zero_point); - return fbgemm::Dequantize(value.val_, qparams); -} -#else // USE_FBGEMM - -#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) -template -inline float Round(const float x) { - return ::nearbyintf(x); -} -inline double Round(const double x) { - return ::nearbyint(x); -} -#else -template -inline T Round(const T x) { - return std::nearbyint(x); -} -#endif - -template -T quantize_val(double scale, int64_t zero_point, float value) { - // std::nearbyint results in nearest integer value according to the current - // rounding mode and the default rounding mode is rounds to even in half-way - // cases in most popular processor architectures like x86 and ARM. This is - // typically faster than an alternatives like std::round that rounds half-way - // cases away from zero, and can be consistent with SIMD implementations for - // example in x86 using _mm512_cvtps_epi32 or mm512_round_ps with - // _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode. 
- int64_t qvalue; - constexpr int64_t qmin = std::numeric_limits::min(); - constexpr int64_t qmax = std::numeric_limits::max(); - float inv_scale = 1.0f / static_cast(scale); - qvalue = static_cast(zero_point + Round(value * inv_scale)); - qvalue = std::max(qvalue, qmin); - qvalue = std::min(qvalue, qmax); - return static_cast(qvalue); -} - -uint8_t quantize_val_arm( - const float scale, - const int32_t zero_point, - const float value) { - const int32_t qmin = std::numeric_limits::min(); - const int32_t qmax = std::numeric_limits::max(); - float inv_scale = 1.0f / scale; - auto r = zero_point + static_cast(Round(value * inv_scale)); - r = std::max(r, qmin); - r = std::min(r, qmax); - return static_cast(r); -} - -template -void quantize_vec( - double scale, - int64_t zero_point, - const float* src, - T* dst, - size_t count) { - checkZeroPoint("quantize_vec", zero_point); - for (int64_t i = 0; i < count; ++i) { - dst[i] = quantize_val(scale, zero_point, src[i]); - } -} - -template -TORCH_API float dequantize_val(double scale, int64_t zero_point, T value) { - // We need to convert the qint8 value to float to ensure the subtraction - // subexpression returns a float - return (static_cast(value.val_) - zero_point) * scale; -} -#endif // USE_FBGEMM - -/* -* Quantize value based on the following equation -* Xq = Round(Xf * inv_scale + zero_point) -* where zero_point is in float. -* -* Note: For the case of embedding quantization we will set zero_point -* to (-Xmin/scale), where Xmin is the min value in input tensor row. -*/ -int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax) { - int qvalue; - - float inv_scale = scale == 0 ? 1.0f : 1.0f / scale; - qvalue = lrintf(value * inv_scale + zero_point); - qvalue = std::max(qmin, std::min(qvalue, qmax)); - return qvalue; -} - -template -DST_T requantize_val( - double src_scale, - int64_t src_zero_point, - double dst_scale, - int64_t dst_zero_point, - SRC_T src) { - const auto dq = dequantize_val(src_scale, src_zero_point, src); - return quantize_val(dst_scale, dst_zero_point, dq); -} - -template -DST_T requantize_from_int(double multiplier, int64_t zero_point, int64_t src) { - int64_t quantize_down = - zero_point + lrintf(src * static_cast(multiplier)); - int32_t min = std::numeric_limits::min(); - int32_t max = std::numeric_limits::max(); - return static_cast( - std::min(std::max(quantize_down, min), max)); -} - -template TORCH_API qint8 -quantize_val(double scale, int64_t zero_point, float value); -template TORCH_API quint8 -quantize_val(double scale, int64_t zero_point, float value); -template TORCH_API qint32 -quantize_val(double scale, int64_t zero_point, float value); -template TORCH_API void quantize_vec( - double scale, - int64_t zero_point, - const float* src, - c10::qint8* dst, - size_t count); -template TORCH_API void quantize_vec( - double scale, - int64_t zero_point, - const float* src, - c10::quint8* dst, - size_t count); -template TORCH_API void quantize_vec( - double scale, - int64_t zero_point, - const float* src, - c10::qint32* dst, - size_t count); - -template TORCH_API float dequantize_val( - double scale, - int64_t zero_point, - qint8 value); -template TORCH_API float dequantize_val( - double scale, - int64_t zero_point, - quint8 value); -template TORCH_API float dequantize_val( - double scale, - int64_t zero_point, - qint32 value); - -template TORCH_API qint8 -requantize_val(double, int64_t, double, int64_t, qint8); -template TORCH_API quint8 -requantize_val(double, int64_t, double, 
int64_t, qint8); -template TORCH_API qint32 -requantize_val(double, int64_t, double, int64_t, qint8); -template TORCH_API qint8 -requantize_val(double, int64_t, double, int64_t, quint8); -template TORCH_API quint8 -requantize_val(double, int64_t, double, int64_t, quint8); -template TORCH_API qint32 -requantize_val(double, int64_t, double, int64_t, quint8); -template TORCH_API qint8 -requantize_val(double, int64_t, double, int64_t, qint32); -template TORCH_API quint8 -requantize_val(double, int64_t, double, int64_t, qint32); -template TORCH_API qint32 -requantize_val(double, int64_t, double, int64_t, qint32); - -template TORCH_API qint8 requantize_from_int(double, int64_t, int64_t); -template TORCH_API quint8 -requantize_from_int(double, int64_t, int64_t); -template TORCH_API qint32 -requantize_from_int(double, int64_t, int64_t); - } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/affine_quantizer.h b/aten/src/ATen/native/quantized/affine_quantizer.h index d583106116e6..4e3c58f6ff81 100644 --- a/aten/src/ATen/native/quantized/affine_quantizer.h +++ b/aten/src/ATen/native/quantized/affine_quantizer.h @@ -2,6 +2,7 @@ #include #include +#include namespace at { namespace native { @@ -111,22 +112,6 @@ DECLARE_DISPATCH( dequantize_tensor_per_tensor_affine_sub_byte_fn, dequantize_tensor_per_tensor_affine_sub_byte_stub); -// Quantize a float value into a uint value given scale and zero_point -template -TORCH_API T quantize_val(double scale, int64_t zero_point, float value); -// TODO combine this with quantize_val once the numerics for ARM are aligned -// with it -uint8_t quantize_val_arm( - const float scale, - const int32_t zero_point, - const float value); -template -void quantize_vec( - double scale, - int64_t zero_point, - const float* src, - T* dst, - size_t count = 8); template TORCH_API Tensor quantize_tensor( Tensor rtensor, @@ -134,31 +119,11 @@ TORCH_API Tensor quantize_tensor( double scale, int64_t zero_point); template -TORCH_API float dequantize_val(double scale, int64_t zero_point, T value); -template -TORCH_API float dequantize_vec( - double scale, - int64_t zero_point, - const T* src, - float* dst, - size_t count = 8); -template TORCH_API Tensor dequantize_tensor( Tensor qtensor, Tensor rtensor, double scale, int64_t zero_point); -template -TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src); - -// Given a multiplier and a zero_point, requantize int32_t computed values back -// to quantized values. 
See comment above -// make_per_tensor_affine_quantizer function for the usage of int64_t -template -TORCH_API DST_T -requantize_from_int(double multiplier, int64_t zero_point, int64_t src); - -int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax); } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/affine_quantizer_base.cpp b/aten/src/ATen/native/quantized/affine_quantizer_base.cpp new file mode 100644 index 000000000000..4178ca77104d --- /dev/null +++ b/aten/src/ATen/native/quantized/affine_quantizer_base.cpp @@ -0,0 +1,250 @@ +#include +#include +#include + +#ifdef USE_FBGEMM +#include +#endif +#ifdef __ARM_NEON__ +#include +#endif + +namespace at { +namespace native { + +namespace { + +template +void checkZeroPoint(const std::string& fn_name, int64_t zero_point) { + TORCH_CHECK( + zero_point <= std::numeric_limits::max(), + fn_name, + " zero_point ", + zero_point, + " is out of range."); + TORCH_CHECK( + zero_point >= std::numeric_limits::min(), + fn_name, + " zero_point ", + zero_point, + " is out of range."); +} + +} // anonymous namespace + +#ifdef USE_FBGEMM +// Note: quantize_val is only explicitly used in test outside of this file +template +T quantize_val(double scale, int64_t zero_point, float value) { + // Internally, fbgemm::Quantize uses std::nearbyint. + // std::nearbyint results in nearest integer value according to the current + // rounding mode and the default rounding mode is rounds to even in half-way + // cases in most popular processor architectures like x86 and ARM. This is + // typically faster than an alternatives like std::round that rounds half-way + // cases away from zero, and can be consistent with SIMD implementations for + // example in x86 using _mm512_cvtps_epi32 or mm512_round_ps with + // _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode. + int32_t qvalue; + qvalue = fbgemm::Quantize( + value, + static_cast(zero_point), + static_cast(scale), + /*result_precision=*/CHAR_BIT * sizeof(typename T::underlying)); + return static_cast(qvalue); +} + +template +void quantize_vec( + double scale, + int64_t zero_point, + const float* src, + T* dst, + size_t count) { + fbgemm::Quantize( + src, + (typename T::underlying*)dst, + count, + fbgemm::TensorQuantizationParams{ + (float)scale, (int32_t)zero_point, precision}); +} + +template +inline float dequantize_val(double scale, int64_t zero_point, T value) { + fbgemm::TensorQuantizationParams qparams; + qparams.scale = static_cast(scale); + qparams.zero_point = static_cast(zero_point); + return fbgemm::Dequantize(value.val_, qparams); +} +#else // USE_FBGEMM + +#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) +template +inline float Round(const float x) { + return ::nearbyintf(x); +} +inline double Round(const double x) { + return ::nearbyint(x); +} +#else +template +inline T Round(const T x) { + return std::nearbyint(x); +} +#endif + +template +T quantize_val(double scale, int64_t zero_point, float value) { + // std::nearbyint results in nearest integer value according to the current + // rounding mode and the default rounding mode is rounds to even in half-way + // cases in most popular processor architectures like x86 and ARM. 
This is + // typically faster than an alternatives like std::round that rounds half-way + // cases away from zero, and can be consistent with SIMD implementations for + // example in x86 using _mm512_cvtps_epi32 or mm512_round_ps with + // _MM_FROUND_CUR_DIRECTION option that also follow the current rounding mode. + int64_t qvalue; + constexpr int64_t qmin = std::numeric_limits::min(); + constexpr int64_t qmax = std::numeric_limits::max(); + float inv_scale = 1.0f / static_cast(scale); + qvalue = static_cast(zero_point + Round(value * inv_scale)); + qvalue = std::max(qvalue, qmin); + qvalue = std::min(qvalue, qmax); + return static_cast(qvalue); +} + +uint8_t quantize_val_arm( + const float scale, + const int32_t zero_point, + const float value) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + float inv_scale = 1.0f / scale; + auto r = zero_point + static_cast(Round(value * inv_scale)); + r = std::max(r, qmin); + r = std::min(r, qmax); + return static_cast(r); +} + +template +void quantize_vec( + double scale, + int64_t zero_point, + const float* src, + T* dst, + size_t count) { + checkZeroPoint("quantize_vec", zero_point); + for (int64_t i = 0; i < count; ++i) { + dst[i] = quantize_val(scale, zero_point, src[i]); + } +} + +template +TORCH_API float dequantize_val(double scale, int64_t zero_point, T value) { + // We need to convert the qint8 value to float to ensure the subtraction + // subexpression returns a float + return (static_cast(value.val_) - zero_point) * scale; +} +#endif // USE_FBGEMM + +/* +* Quantize value based on the following equation +* Xq = Round(Xf * inv_scale + zero_point) +* where zero_point is in float. +* +* Note: For the case of embedding quantization we will set zero_point +* to (-Xmin/scale), where Xmin is the min value in input tensor row. +*/ +int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax) { + int qvalue; + + float inv_scale = scale == 0 ? 
1.0f : 1.0f / scale; + qvalue = lrintf(value * inv_scale + zero_point); + qvalue = std::max(qmin, std::min(qvalue, qmax)); + return qvalue; +} + +template +DST_T requantize_val( + double src_scale, + int64_t src_zero_point, + double dst_scale, + int64_t dst_zero_point, + SRC_T src) { + const auto dq = dequantize_val(src_scale, src_zero_point, src); + return quantize_val(dst_scale, dst_zero_point, dq); +} + +template +DST_T requantize_from_int(double multiplier, int64_t zero_point, int64_t src) { + int64_t quantize_down = + zero_point + lrintf(src * static_cast(multiplier)); + int32_t min = std::numeric_limits::min(); + int32_t max = std::numeric_limits::max(); + return static_cast( + std::min(std::max(quantize_down, min), max)); +} + +template TORCH_API qint8 +quantize_val(double scale, int64_t zero_point, float value); +template TORCH_API quint8 +quantize_val(double scale, int64_t zero_point, float value); +template TORCH_API qint32 +quantize_val(double scale, int64_t zero_point, float value); +template TORCH_API void quantize_vec( + double scale, + int64_t zero_point, + const float* src, + c10::qint8* dst, + size_t count); +template TORCH_API void quantize_vec( + double scale, + int64_t zero_point, + const float* src, + c10::quint8* dst, + size_t count); +template TORCH_API void quantize_vec( + double scale, + int64_t zero_point, + const float* src, + c10::qint32* dst, + size_t count); + +template TORCH_API float dequantize_val( + double scale, + int64_t zero_point, + qint8 value); +template TORCH_API float dequantize_val( + double scale, + int64_t zero_point, + quint8 value); +template TORCH_API float dequantize_val( + double scale, + int64_t zero_point, + qint32 value); + +template TORCH_API qint8 +requantize_val(double, int64_t, double, int64_t, qint8); +template TORCH_API quint8 +requantize_val(double, int64_t, double, int64_t, qint8); +template TORCH_API qint32 +requantize_val(double, int64_t, double, int64_t, qint8); +template TORCH_API qint8 +requantize_val(double, int64_t, double, int64_t, quint8); +template TORCH_API quint8 +requantize_val(double, int64_t, double, int64_t, quint8); +template TORCH_API qint32 +requantize_val(double, int64_t, double, int64_t, quint8); +template TORCH_API qint8 +requantize_val(double, int64_t, double, int64_t, qint32); +template TORCH_API quint8 +requantize_val(double, int64_t, double, int64_t, qint32); +template TORCH_API qint32 +requantize_val(double, int64_t, double, int64_t, qint32); + +template TORCH_API qint8 requantize_from_int(double, int64_t, int64_t); +template TORCH_API quint8 +requantize_from_int(double, int64_t, int64_t); +template TORCH_API qint32 +requantize_from_int(double, int64_t, int64_t); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/quantized/affine_quantizer_base.h b/aten/src/ATen/native/quantized/affine_quantizer_base.h new file mode 100644 index 000000000000..9e6a9ff58d24 --- /dev/null +++ b/aten/src/ATen/native/quantized/affine_quantizer_base.h @@ -0,0 +1,46 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +// Quantize a float value into a uint value given scale and zero_point +template +TORCH_API T quantize_val(double scale, int64_t zero_point, float value); +// TODO combine this with quantize_val once the numerics for ARM are aligned +// with it +uint8_t quantize_val_arm( + const float scale, + const int32_t zero_point, + const float value); +template +void quantize_vec( + double scale, + int64_t zero_point, + const float* src, + T* dst, + size_t count = 8); 
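As an editorial aside, the affine quantization scheme declared in this header reduces to a few lines of arithmetic. The sketch below is a minimal, illustrative Python version of the equations described in the comments earlier in affine_quantizer_base.cpp (quantize: Xq = clamp(round(Xf / scale) + zero_point, qmin, qmax); dequantize: Xf ~= (Xq - zero_point) * scale). It is not the ATen/FBGEMM implementation, and the qint8-like range used here is only an example; note that Python's round() rounds half-way cases to even, which happens to match the std::nearbyint behaviour discussed above.

    # Illustrative sketch of affine quantization, not the ATen implementation.
    def quantize_val(scale, zero_point, value, qmin=-128, qmax=127):
        # Xq = clamp(round(Xf / scale) + zero_point, qmin, qmax)
        q = round(value / scale) + zero_point
        return max(qmin, min(qmax, q))

    def dequantize_val(scale, zero_point, qvalue):
        # Xf ~= (Xq - zero_point) * scale
        return (qvalue - zero_point) * scale

    def requantize_val(src_scale, src_zp, dst_scale, dst_zp, qvalue):
        # Dequantize with the source qparams, then quantize with the target ones.
        return quantize_val(dst_scale, dst_zp,
                            dequantize_val(src_scale, src_zp, qvalue))

    q = quantize_val(0.1, 0, 1.234)   # -> 12
    x = dequantize_val(0.1, 0, q)     # -> 1.2000000000000002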
+template +TORCH_API float dequantize_val(double scale, int64_t zero_point, T value); +template +TORCH_API float dequantize_vec( + double scale, + int64_t zero_point, + const T* src, + float* dst, + size_t count = 8); +template +TORCH_API DST_T requantize_val(double, int64_t, double, int64_t, SRC_T src); + +// Given a multiplier and a zero_point, requantize int32_t computed values back +// to quantized values. See comment above +// make_per_tensor_affine_quantizer function for the usage of int64_t +template +TORCH_API DST_T +requantize_from_int(double multiplier, int64_t zero_point, int64_t src); + +int quantize_val_float_qparams(float scale, float zero_point, float value, int qmin, int qmax); + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp b/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp index 29e7a9b259bb..0f1cda18fe69 100644 --- a/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp +++ b/aten/src/ATen/native/quantized/cpu/int_repr_quant.cpp @@ -1,5 +1,7 @@ +#include #include #include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/quantized/cpu/make_per_tensor_quantized_tensor.cpp b/aten/src/ATen/native/quantized/cpu/make_per_tensor_quantized_tensor.cpp index a0bd512ae455..a321de08b994 100644 --- a/aten/src/ATen/native/quantized/cpu/make_per_tensor_quantized_tensor.cpp +++ b/aten/src/ATen/native/quantized/cpu/make_per_tensor_quantized_tensor.cpp @@ -1,5 +1,7 @@ #include #include +#include +#include namespace at { namespace native { diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-run.cc b/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-run.cc index 34629f29168a..5109ed1fec55 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-run.cc +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/src/conv-run.cc @@ -328,11 +328,17 @@ enum pytorch_qnnp_status qnnpackConv( // We need to check if the corresponding values on this // invocation is same as cached values. // If so we can skip setup step. 
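The QNNPACK conv-run.cc change just below skips the expensive indirection-buffer setup when the cached convolution parameters (input pointer, batch size, spatial size, pixel stride) are unchanged, while forcing a recalculation on Apple platforms. The following is a rough Python sketch of that "recompute only when the cache key changes" pattern, purely for illustration; the names are made up and are not QNNPACK APIs.

    # Redo the expensive setup only when the parameters it depends on change.
    class ConvSetupCache:
        def __init__(self, always_recalculate=False):
            self.always_recalculate = always_recalculate  # e.g. forced on some platforms
            self._key = None

        def maybe_setup(self, input_ptr, batch, height, width, stride, setup_fn):
            key = (input_ptr, batch, height, width, stride)
            if self.always_recalculate or key != self._key:
                self._key = key
                setup_fn(*key)  # expensive: rebuilds the indirection buffer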
- if (convolution->input != input || + + bool recalculate_indirection_buffer{true}; +#ifndef __APPLE__ + recalculate_indirection_buffer = + (convolution->input != input || convolution->batch_size != batch_size || convolution->input_height != input_height || convolution->input_width != input_width || - convolution->input_pixel_stride != input_pixel_stride) { + convolution->input_pixel_stride != input_pixel_stride); +#endif + if (recalculate_indirection_buffer) { pytorch_qnnp_status status = pytorch_qnnp_setup_convolution2d_nhwc_q8( convolution, batch_size, diff --git a/aten/src/ATen/test/vitals.cpp b/aten/src/ATen/test/vitals.cpp new file mode 100644 index 000000000000..d39e64ebf23b --- /dev/null +++ b/aten/src/ATen/test/vitals.cpp @@ -0,0 +1,73 @@ +#include + +#include +#include +#include + +TEST(Vitals, Basic) { + std::stringstream buffer; + + std::streambuf* sbuf = std::cout.rdbuf(); + std::cout.rdbuf(buffer.rdbuf()); + { + setenv("TORCH_VITAL", "1", 1); + TORCH_VITAL_DEFINE(Testing); + TORCH_VITAL(Testing, Attribute0) << 1; + TORCH_VITAL(Testing, Attribute1) << "1"; + TORCH_VITAL(Testing, Attribute2) << 1.0f; + TORCH_VITAL(Testing, Attribute3) << 1.0; + auto t = at::ones({1, 1}); + TORCH_VITAL(Testing, Attribute4) << t; + } + std::cout.rdbuf(sbuf); + + auto s = buffer.str(); + ASSERT_TRUE(s.find("Testing.Attribute0\t\t 1") != std::string::npos); + ASSERT_TRUE(s.find("Testing.Attribute1\t\t 1") != std::string::npos); + ASSERT_TRUE(s.find("Testing.Attribute2\t\t 1") != std::string::npos); + ASSERT_TRUE(s.find("Testing.Attribute3\t\t 1") != std::string::npos); + ASSERT_TRUE(s.find("Testing.Attribute4\t\t 1") != std::string::npos); +} + +TEST(Vitals, MultiString) { + std::stringstream buffer; + + std::streambuf* sbuf = std::cout.rdbuf(); + std::cout.rdbuf(buffer.rdbuf()); + { + setenv("TORCH_VITAL", "1", 1); + TORCH_VITAL_DEFINE(Testing); + TORCH_VITAL(Testing, Attribute0) << 1 << " of " << 2; + TORCH_VITAL(Testing, Attribute1) << 1; + TORCH_VITAL(Testing, Attribute1) << " of "; + TORCH_VITAL(Testing, Attribute1) << 2; + } + std::cout.rdbuf(sbuf); + + auto s = buffer.str(); + ASSERT_TRUE(s.find("Testing.Attribute0\t\t 1 of 2") != std::string::npos); + ASSERT_TRUE(s.find("Testing.Attribute1\t\t 1 of 2") != std::string::npos); +} + +TEST(Vitals, OnAndOff) { + for (auto i = 0; i < 2; ++i) { + std::stringstream buffer; + + std::streambuf* sbuf = std::cout.rdbuf(); + std::cout.rdbuf(buffer.rdbuf()); + { + setenv("TORCH_VITAL", i ? 
"1" : "", 1); + TORCH_VITAL_DEFINE(Testing); + TORCH_VITAL(Testing, Attribute0) << 1; + } + std::cout.rdbuf(sbuf); + + auto s = buffer.str(); + auto f = s.find("Testing.Attribute0\t\t 1"); + if (i) { + ASSERT_TRUE(f != std::string::npos); + } else { + ASSERT_TRUE(f == std::string::npos); + } + } +} diff --git a/aten/src/TH/generic/THLapack.cpp b/aten/src/TH/generic/THLapack.cpp index 5b4ef15e7c2c..48ae4a8e83aa 100644 --- a/aten/src/TH/generic/THLapack.cpp +++ b/aten/src/TH/generic/THLapack.cpp @@ -57,20 +57,6 @@ void THLapack_(geqrf)(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_ #endif } -/* Build Q from output of geqrf */ -void THLapack_(orgqr)(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info) -{ -#ifdef USE_LAPACK -#if defined(TH_REAL_IS_DOUBLE) - dorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info); -#else - sorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info); -#endif -#else - THError("orgqr: Lapack library not found in compile time\n"); -#endif -} - /* Multiply Q with a matrix using the output of geqrf */ void THLapack_(ormqr)(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info) { diff --git a/aten/src/TH/generic/THLapack.h b/aten/src/TH/generic/THLapack.h index 121eee871c67..fc5b4e68bd6c 100644 --- a/aten/src/TH/generic/THLapack.h +++ b/aten/src/TH/generic/THLapack.h @@ -11,8 +11,6 @@ TH_API void THLapack_(potri)(char uplo, int n, scalar_t *a, int lda, int *info); /* QR decomposition */ TH_API void THLapack_(geqrf)(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); -/* Build Q from output of geqrf */ -TH_API void THLapack_(orgqr)(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); /* Multiply Q with a matrix from output of geqrf */ TH_API void THLapack_(ormqr)(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info); diff --git a/aten/src/TH/generic/THTensorLapack.cpp b/aten/src/TH/generic/THTensorLapack.cpp index e6c200169191..5a8e05e5c231 100644 --- a/aten/src/TH/generic/THTensorLapack.cpp +++ b/aten/src/TH/generic/THTensorLapack.cpp @@ -308,64 +308,6 @@ void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a) c10::raw::intrusive_ptr::decref(work); } -/* - The orgqr function allows reconstruction of a matrix Q with orthogonal - columns, from a sequence of elementary reflectors, such as is produced by the - geqrf function. - - Args: - * `ra_` - result Tensor, which will contain the matrix Q. - * `a` - input Tensor, which should be a matrix with the directions of the - elementary reflectors below the diagonal. If NULL, `ra_` is used as - input. - * `tau` - input Tensor, containing the magnitudes of the elementary - reflectors. - - For further details, please see the LAPACK documentation. 
- -*/ -void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau) -{ - if (a == NULL) a = ra_; - THArgCheck(THTensor_nDimension(a) == 2, 1, "'input' should be 2 dimensional"); - THArgCheck(!a->is_empty(), 2, "'input' should not be empty"); - THArgCheck(!tau->is_empty(), 3, "'tau' should not be empty"); - - THTensor *ra__ = NULL; - ra__ = THTensor_(cloneColumnMajor)(ra_, a); - - int m = THTensor_(size)(ra__, 0); - int n = THTensor_(size)(ra__, 1); - int k = THTensor_sizeLegacyNoScalars(tau, 0); - - THArgCheck(m >= n, 1, "input.size(0) must be greater than or equal to input.size(1)"); - THArgCheck(n >= k, 1, "input.size(1) must be greater than or equal to input2.size(0)"); - - int lda = m; - - /* Dry-run to query the suggested size of the workspace. */ - int info = 0; - scalar_t wkopt = 0; - THLapack_(orgqr)(m, n, k, ra__->data(), lda, - tau->data(), - &wkopt, -1, &info); - - /* Allocate the workspace and call LAPACK to do the real work. */ - int lwork = (int)wkopt; - THTensor *work = THTensor_(newWithSize1d)(lwork); - THLapack_(orgqr)(m, n, k, ra__->data(), lda, - tau->data(), - work->data(), lwork, &info); - - THLapackCheckWithCleanup(" Lapack Error %s : unknown Lapack error. info = %i", - THCleanup( - c10::raw::intrusive_ptr::decref(ra__); - c10::raw::intrusive_ptr::decref(work);), - "orgqr", info,""); - THTensor_(freeCopyTo)(ra__, ra_); - c10::raw::intrusive_ptr::decref(work); -} - /* The ormqr function multiplies Q with another matrix from a sequence of elementary reflectors, such as is produced by the geqrf function. diff --git a/aten/src/TH/generic/THTensorLapack.h b/aten/src/TH/generic/THTensorLapack.h index c19df681cd6f..f08348308416 100644 --- a/aten/src/TH/generic/THTensorLapack.h +++ b/aten/src/TH/generic/THTensorLapack.h @@ -5,7 +5,6 @@ TH_API void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b_, THTensor *a_); TH_API void THTensor_(potri)(THTensor *ra_, THTensor *a, bool upper); TH_API void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a); -TH_API void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau); TH_API void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c, bool left, bool transpose); #endif diff --git a/benchmarks/cpp/tensorexpr/bench_approx.cpp b/benchmarks/cpp/tensorexpr/bench_approx.cpp index 220ea71497ff..5a8a81f5c768 100644 --- a/benchmarks/cpp/tensorexpr/bench_approx.cpp +++ b/benchmarks/cpp/tensorexpr/bench_approx.cpp @@ -4,10 +4,23 @@ #include #include #include +#include "caffe2/operators/tanh_op.h" +#include "caffe2/operators/logit_op.h" +using namespace torch::jit; using namespace torch::jit::tensorexpr; -static void log_sleef(benchmark::State& state) { +void optimizePointwise(tensorexpr::LoopNest* ln, tensorexpr::Tensor* target) { + std::vector loops = ln->getLoopStmtsFor(target); + For *outer, *inner, *tail; + ln->splitWithTail(loops[0], 16 * 8, &outer, &inner, &tail); + ln->vectorize(inner); + ln->splitWithTail(outer, 8, &outer, &inner, &tail); + Stmt* unrolled; + LoopNest::unroll(inner, &unrolled); +} + +static void log_nnc_sleef(benchmark::State& state) { KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); @@ -17,7 +30,7 @@ static void log_sleef(benchmark::State& state) { }); LoopNest ln({B}); ln.prepareForCodegen(); - ln.vectorizeInnerLoops(); + optimizePointwise(&ln, B); Stmt* s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; @@ -29,7 +42,7 @@ static void log_sleef(benchmark::State& state) { at::Tensor B_t = 
torch::randn({state.range(0)}); auto B_ref = at::log(A_t); cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); - assert(at::allclose(B_t, B_ref)); + TORCH_CHECK(at::allclose(B_t, B_ref)); for (auto _ : state) { cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); } @@ -37,7 +50,7 @@ static void log_sleef(benchmark::State& state) { uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); } -static void log_fast(benchmark::State& state) { +static void log_nnc_fast(benchmark::State& state) { KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); @@ -46,8 +59,8 @@ static void log_fast(benchmark::State& state) { return fast_log(A.load(i)); }); LoopNest ln({B}); + optimizePointwise(&ln, B); ln.prepareForCodegen(); - ln.vectorizeInnerLoops(); Stmt* s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; @@ -59,7 +72,7 @@ static void log_fast(benchmark::State& state) { at::Tensor B_t = torch::randn({state.range(0)}); auto B_ref = at::log(A_t); cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); - assert(at::allclose(B_t, B_ref)); + TORCH_CHECK(at::allclose(B_t, B_ref)); for (auto _ : state) { cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); } @@ -77,18 +90,24 @@ static void log_aten(benchmark::State& state) { uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); } -static void logit_fast(benchmark::State& state) { +static void logit_nnc_sleef(benchmark::State& state) { KernelScope ks; auto N = VarHandle("N", kInt); Placeholder A("A", kFloat, {N}); - torch::jit::tensorexpr::Tensor* B = - Compute("B", {N}, [&](const VarHandle& i) { - auto A_elem = A.load(i); - return fast_log(A_elem / (FloatImm::make(1.0f) - A_elem)); - }); + auto clamp = 1e-6f; + tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) { + auto A_elem = [&]() { + auto elem = A.load(i); + auto min = FloatImm::make(clamp); + auto max = FloatImm::make(1.0f - clamp); + elem = CompareSelect::make(elem, min, min, elem, kLT); + return CompareSelect::make(elem, max, max, elem, kGT); + }(); + return log(A_elem / (FloatImm::make(1.0f) - A_elem)); + }); LoopNest ln({B}); ln.prepareForCodegen(); - ln.vectorizeInnerLoops(); + optimizePointwise(&ln, B); Stmt* s = ln.root_stmt(); s = torch::jit::tensorexpr::IRSimplifier::simplify(s); std::vector args; @@ -98,9 +117,46 @@ static void logit_fast(benchmark::State& state) { LLVMCodeGen cg(s, args); at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); at::Tensor B_t = torch::randn({state.range(0)}); - auto B_ref = at::logit(A_t); + auto B_ref = at::logit(A_t, clamp); cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); - assert(at::allclose(B_t, B_ref)); + TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref))); + for (auto _ : state) { + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + } + state.counters["logit/s"] = benchmark::Counter( + uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); +} + +static void logit_nnc_fast(benchmark::State& state) { + KernelScope ks; + auto N = VarHandle("N", kInt); + Placeholder A("A", kFloat, {N}); + auto clamp = 1e-6f; + tensorexpr::Tensor* B = Compute("B", {N}, [&](const VarHandle& i) { + auto A_elem = [&]() { + auto elem = A.load(i); + auto min = FloatImm::make(clamp); + auto max = FloatImm::make(1.0f - clamp); + elem = CompareSelect::make(elem, min, min, elem, kLT); + return CompareSelect::make(elem, max, max, elem, kGT); + }(); + return fast_log(A_elem / 
(FloatImm::make(1.0f) - A_elem)); + }); + LoopNest ln({B}); + ln.prepareForCodegen(); + optimizePointwise(&ln, B); + Stmt* s = ln.root_stmt(); + s = torch::jit::tensorexpr::IRSimplifier::simplify(s); + std::vector args; + args.emplace_back(B); + args.emplace_back(A); + args.emplace_back(N); + LLVMCodeGen cg(s, args); + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + auto B_ref = at::logit(A_t, clamp); + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref))); for (auto _ : state) { cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); } @@ -111,19 +167,111 @@ static void logit_fast(benchmark::State& state) { static void logit_aten(benchmark::State& state) { at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); at::Tensor B_t = torch::randn({state.range(0)}); + auto clamp = 1e-6f; + for (auto _ : state) { + at::native::logit_out(B_t, A_t, clamp); + } + state.counters["logit/s"] = benchmark::Counter( + uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); +} + +template +void logit_caffe2_impl(int size, const T* X, T* Y, float eps_ = 1e-6f) { + using namespace caffe2; + ConstEigenVectorMap X_vec(X, size); + EigenVectorMap Y_vec(Y, size); + Y_vec = X_vec.array().min(static_cast(1.0f - eps_)); + Y_vec = Y_vec.array().max(eps_); + Y_vec = (Y_vec.array() / (T(1) - Y_vec.array())).log(); +} + +static void logit_caffe2(benchmark::State& state) { + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + at::Tensor B_ref = torch::randn({state.range(0)}); + auto N = state.range(0); + auto X = A_t.data_ptr(); + auto Y = B_t.data_ptr(); + auto clamp = 1e-6f; + at::native::logit_out(B_ref, A_t, clamp); + logit_caffe2_impl(N, X, Y, clamp); + TORCH_CHECK(at::allclose(at::nan_to_num(B_t), at::nan_to_num(B_ref))); + for (auto _ : state) { - at::native::logit_out(B_t, A_t); + logit_caffe2_impl(N, X, Y, clamp); } + state.counters["logit/s"] = benchmark::Counter( uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); } -BENCHMARK(log_sleef) +static void tanh_nnc_fast(benchmark::State& state) { + KernelScope ks; + auto N = VarHandle("N", kInt); + Placeholder A("A", kFloat, {N}); + torch::jit::tensorexpr::Tensor* B = + Compute("B", {N}, [&](const VarHandle& i) { + return fast_tanh(A.load(i)); + }); + LoopNest ln({B}); + optimizePointwise(&ln, B); + ln.prepareForCodegen(); + Stmt* s = ln.root_stmt(); + s = torch::jit::tensorexpr::IRSimplifier::simplify(s); + std::vector args; + args.emplace_back(B); + args.emplace_back(A); + args.emplace_back(N); + LLVMCodeGen cg(s, args); + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + auto B_ref = at::tanh(A_t); + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + TORCH_CHECK(at::allclose(B_t, B_ref, 1e-3f, 1e-6f)); + for (auto _ : state) { + cg.call({B_t.data_ptr(), A_t.data_ptr(), state.range(0)}); + } + state.counters["tanh/s"] = benchmark::Counter( + uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); +} + +static void tanh_aten(benchmark::State& state) { + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + for (auto _ : state) { + at::native::tanh_out(B_t, A_t); + } + state.counters["tanh/s"] = benchmark::Counter( + uint64_t(state.range(0) * 
state.iterations()), benchmark::Counter::kIsRate); +} + +static void tanh_caffe2(benchmark::State& state) { + at::Tensor A_t = torch::abs(torch::randn({state.range(0)})); + at::Tensor B_t = torch::randn({state.range(0)}); + at::Tensor B_ref = torch::randn({state.range(0)}); + + auto N = state.range(0); + auto X = A_t.data_ptr(); + auto Y = B_t.data_ptr(); + caffe2::CPUContext c; + auto tanh = caffe2::TanhFunctor(); + at::native::tanh_out(B_ref, A_t); + tanh(N, X, Y, &c); + TORCH_CHECK(at::native::allclose(B_t, B_ref, 1e-3f, 1e-6f)); + + for (auto _ : state) { + tanh(N, X, Y, &c); + } + state.counters["tanh/s"] = benchmark::Counter( + uint64_t(state.range(0) * state.iterations()), benchmark::Counter::kIsRate); +} + +BENCHMARK(log_nnc_sleef) ->Args({2<<5}) ->Args({2<<8}) ->Args({2<<12}) ->Args({2<<14}); -BENCHMARK(log_fast) +BENCHMARK(log_nnc_fast) ->Args({2<<5}) ->Args({2<<8}) ->Args({2<<12}) @@ -133,7 +281,12 @@ BENCHMARK(log_aten) ->Args({2<<8}) ->Args({2<<12}) ->Args({2<<14}); -BENCHMARK(logit_fast) +BENCHMARK(logit_nnc_sleef) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); +BENCHMARK(logit_nnc_fast) ->Args({2<<5}) ->Args({2<<8}) ->Args({2<<12}) @@ -143,3 +296,23 @@ BENCHMARK(logit_aten) ->Args({2<<8}) ->Args({2<<12}) ->Args({2<<14}); +BENCHMARK(logit_caffe2) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); +BENCHMARK(tanh_nnc_fast) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); +BENCHMARK(tanh_aten) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); +BENCHMARK(tanh_caffe2) + ->Args({2<<5}) + ->Args({2<<8}) + ->Args({2<<12}) + ->Args({2<<14}); diff --git a/benchmarks/cpp/tensorexpr/bench_ops.py b/benchmarks/cpp/tensorexpr/bench_ops.py new file mode 100644 index 000000000000..7ae8d8ee2c5f --- /dev/null +++ b/benchmarks/cpp/tensorexpr/bench_ops.py @@ -0,0 +1,67 @@ +import timeit +import torch + +torch._C._jit_override_can_fuse_on_cpu(True) +torch._C._debug_set_fusion_group_inlining(False) +torch.set_num_threads(1) + + +def hardswish(x): + return x * torch.clamp(x + 3.0, 0.0, 6.0) / 6.0 + + +unary_ops = [ + hardswish, + torch._C._nn.hardswish, + torch.sigmoid, + torch.reciprocal, + torch.neg, + torch.relu, + torch.isnan, + torch.log, + torch.log10, + torch.log1p, + torch.log2, + torch.exp, + torch.expm1, + torch.erf, + torch.erfc, + torch.cos, + torch.sin, + torch.tan, + torch.acos, + torch.asin, + torch.cosh, + torch.sinh, + torch.atan, + torch.tanh, + torch.sqrt, + torch.rsqrt, + torch.abs, + torch.ceil, + torch.floor, + torch.round, + torch.trunc, + torch.lgamma, +] + +print("{:20s} {:>10s} {:>10s} {:>10s}".format("op", "eager", "nnc", "speedup")) + +for op in unary_ops: + x = torch.rand((1024, 1024)) + traced = torch.jit.trace(lambda x: op(x), (x)) + + # Warmup. + warmup_iters = 8 + for _ in range(warmup_iters): + op(x) + traced(x) + + # Validate result. + torch.testing.assert_allclose(op(x), traced(x)) + + # Benchmark. 
+ bench_iters = 100 + teager = timeit.timeit(stmt="op(x)", globals=globals(), number=bench_iters) + tjit = timeit.timeit(stmt="traced(x)", globals=globals(), number=bench_iters) + print(f"{op.__name__:20s} {teager:10.3f} {tjit:10.3f} {teager/tjit:10.2f}") diff --git a/c10/util/Logging.h b/c10/util/Logging.h index 6fa7e93f26d8..0435123ea8bd 100644 --- a/c10/util/Logging.h +++ b/c10/util/Logging.h @@ -300,6 +300,27 @@ BINARY_COMP_HELPER(LessEquals, <=) C10_API void SetAPIUsageLogger(std::function logger); C10_API void LogAPIUsage(const std::string& context); +// PyTorch ddp usage logging capabilities +// DDPLoggingData holds data that can be logged in applications +// for analysis and debugging. Data structure is defined in +// c10 directory so that it can be easily imported by both c10 +// and torch files. +// TODO -- right now starting with logging a small set of straightforward +// fields, will add more fields as follow ups such as performance stats, +// internal states and env variables and etc. +struct DDPLoggingData { + // Data that can be got during DistributedDataParallel construction time + int world_size; + int rank; + std::string module_name; + std::vector device_ids; + int output_device; + bool broadcast_buffers; + int bucket_cap_mb; + bool find_unused_parameters; + bool gradient_as_bucket_view; +}; + namespace detail { // Return value is needed to do the static variable initialization trick C10_API bool LogAPIUsageFakeReturn(const std::string& context); diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index fde123d87442..62f9e8be3e4c 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1319,8 +1319,18 @@ if(BUILD_TEST) list(GET CPU_CAPABILITY_NAMES ${i} CPU_CAPABILITY) list(GET CPU_CAPABILITY_FLAGS ${i} FLAGS) separate_arguments(FLAGS UNIX_COMMAND "${FLAGS}") - add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}") - target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library gtest_main) + # Build vec256 with minimal dependencies on all platforms but Windows + if(NOT MSVC) + add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}" ../aten/src/ATen/native/quantized/affine_quantizer_base.cpp) + # TODO: Get rid of c10 dependency (which is only needed for the implementation of AT_ERROR) + target_link_libraries(${test_name}_${CPU_CAPABILITY} c10 sleef gtest_main) + if(USE_FBGEMM) + target_link_libraries(${test_name}_${CPU_CAPABILITY} fbgemm) + endif() + else() + add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}") + target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library gtest_main) + endif() target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $) target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $) target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE ${ATen_CPU_INCLUDE}) diff --git a/caffe2/contrib/fakelowp/int8_dequantize_op_nnpi.h b/caffe2/contrib/fakelowp/int8_dequantize_op_nnpi.h index d8aa77d99978..6769a978c391 100644 --- a/caffe2/contrib/fakelowp/int8_dequantize_op_nnpi.h +++ b/caffe2/contrib/fakelowp/int8_dequantize_op_nnpi.h @@ -25,7 +25,6 @@ void Int8DequantizeNNPI( for (auto i = 0; i < N; ++i) { out[i] = (float)(static_cast(in[i]) - X_offset) / X_scale_fp32; } - fbgemm::RoundToFloat16(out, out, N, FLAGS_caffe2_fbgemm_fake_fp16_clamp); } // namespace } // namespace @@ -46,10 +45,9 @@ class Int8DequantizeNNPIOp final : public Operator { X.t.numel(), X_scale, X_offset); - // UsingOneOverScale_); + // UsingOneOverScale_); return true; } - }; } // namespace int8 diff --git 
a/caffe2/python/operator_test/locally_connected_op_test.py b/caffe2/python/operator_test/locally_connected_op_test.py index 6eb3181ea9ad..79a10663ca71 100644 --- a/caffe2/python/operator_test/locally_connected_op_test.py +++ b/caffe2/python/operator_test/locally_connected_op_test.py @@ -6,11 +6,12 @@ from hypothesis import given, settings, assume import hypothesis.strategies as st -from caffe2.python import core, utils +from caffe2.python import core, utils, workspace import caffe2.python.hypothesis_test_util as hu import caffe2.python.serialized_test.serialized_test_util as serial + class TestLocallyConnectedOp(serial.SerializedTestCase): @given(N=st.integers(1, 3), C=st.integers(1, 3), @@ -104,6 +105,9 @@ def lc_2d_nhwc(X, W, b=None): **hu.gcs) @settings(deadline=1000) def test_lc_1d(self, N, C, size, M, kernel, op_name, use_bias, gc, dc): + if workspace.has_hip_support: + # Skip as the test is flaky on ROCm with the deadline set to 1000 + return if size < kernel: kernel = size diff --git a/docs/source/_static/img/pipeline_parallelism/no_pipe.png b/docs/source/_static/img/pipeline_parallelism/no_pipe.png new file mode 100644 index 000000000000..4b2b79514879 Binary files /dev/null and b/docs/source/_static/img/pipeline_parallelism/no_pipe.png differ diff --git a/docs/source/_static/img/pipeline_parallelism/pipe.png b/docs/source/_static/img/pipeline_parallelism/pipe.png new file mode 100644 index 000000000000..084b4552f110 Binary files /dev/null and b/docs/source/_static/img/pipeline_parallelism/pipe.png differ diff --git a/docs/source/distributed.rst b/docs/source/distributed.rst index b35a34fc0265..9ed635457859 100644 --- a/docs/source/distributed.rst +++ b/docs/source/distributed.rst @@ -417,6 +417,25 @@ Collective functions :class:`~torch.distributed.ReduceOp` is recommended to use instead. +Autograd-enabled communication primitives +----------------------------------------- + +If you want to use collective communication functions that support autograd, +you can find an implementation of them in the `torch.distributed.nn.*` module. + +The functions here are synchronous and are inserted into the autograd graph, so +you need to ensure that every process that participates in the collective operation +also runs the backward pass; otherwise the backward communication will not happen +and the program may deadlock. + +Please note that ``gloo`` is currently the only backend on which all of these functions are guaranteed to work. +.. autofunction:: torch.distributed.nn.broadcast +.. autofunction:: torch.distributed.nn.gather +.. autofunction:: torch.distributed.nn.scatter +.. autofunction:: torch.distributed.nn.reduce +.. autofunction:: torch.distributed.nn.all_gather +.. autofunction:: torch.distributed.nn.all_to_all +..
autofunction:: torch.distributed.nn.all_reduce Multi-GPU collective functions ------------------------------ diff --git a/docs/source/index.rst b/docs/source/index.rst index a334bffab01e..a105b6df84d2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -71,6 +71,7 @@ Features described in this documentation are classified by release status: onnx optim complex_numbers + pipeline quantization rpc torch.random diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index 50100720b33b..ef5986f2cf57 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -486,7 +486,8 @@ of a function from ``torch.Tensor`` subclasses, they must use Subclassing ``torch.Tensor`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -As of version 1.7.0, methods and functions applied on ``torch.Tensor`` subclasses +As of version 1.7.0, methods on ``torch.Tensor`` and functions in public +``torch.*`` namespaces applied on ``torch.Tensor`` subclasses will return subclass instances instead of ``torch.Tensor`` instances:: >>> class SubTensor(torch.Tensor): diff --git a/docs/source/pipeline.rst b/docs/source/pipeline.rst new file mode 100644 index 000000000000..6f52e6229141 --- /dev/null +++ b/docs/source/pipeline.rst @@ -0,0 +1,71 @@ +.. _pipeline-parallelism: + +Pipeline Parallelism +==================== + +Pipeline parallelism was originally introduced in the +`Gpipe `__ paper and is an efficient +technique to train large models on multiple GPUs. + +.. warning :: + Pipeline Parallelism is experimental and subject to change. + +Model Parallelism using multiple GPUs +------------------------------------- + +Typically, for large models which don't fit on a single GPU, model parallelism +is employed where certain parts of the model are placed on different GPUs. +However, if this is done naively for sequential models, the training process +suffers from GPU underutilization since only one GPU is active at a time, as +shown in the figure below: + +.. figure:: _static/img/pipeline_parallelism/no_pipe.png + + The figure represents a model with 4 layers placed on 4 different GPUs + (vertical axis). The horizontal axis represents training this model through + time demonstrating that only 1 GPU is utilized at a time + (`image source `__). + +Pipelined Execution +------------------- + +To alleviate this problem, pipeline parallelism splits the input minibatch into +multiple microbatches and pipelines the execution of these microbatches across +multiple GPUs. This is outlined in the figure below: + +.. figure:: _static/img/pipeline_parallelism/pipe.png + + The figure represents a model with 4 layers placed on 4 different GPUs + (vertical axis). The horizontal axis represents training this model through + time demonstrating that the GPUs are utilized much more efficiently. + However, there still exists a bubble (as demonstrated in the figure) where + certain GPUs are not utilized. + (`image source `__). + +Pipe APIs in PyTorch +-------------------- +.. autoclass:: torch.distributed.pipeline.sync.Pipe + :members: forward + +Skip connections +^^^^^^^^^^^^^^^^ + +Certain models like ResNeXt are not completely sequential and have skip +connections between layers. Naively implementing this as part of pipeline +parallelism would imply that we need to copy outputs for certain layers through +multiple GPUs until we eventually reach the GPU where the layer for the skip +connection resides. To avoid this copy overhead, we provide APIs below to stash +and pop Tensors in different layers of the model. + +..
autofunction:: torch.distributed.pipeline.sync.skip.skippable.skippable +.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.stash +.. autoclass:: torch.distributed.pipeline.sync.skip.skippable.pop +.. autofunction:: torch.distributed.pipeline.sync.skip.skippable.verify_skippables + +Acknowledgements +---------------- + +The implementation for pipeline parallelism is based on `fairscale's pipe implementation `__ and +`torchgpipe `__. We would like to +thank both teams for their contributions and guidance towards bringing pipeline +parallelism into PyTorch. diff --git a/mypy-strict.ini b/mypy-strict.ini index 00545679e8f1..717802cf59df 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -1,7 +1,10 @@ -# This is a MyPy config file but unlike mypy.ini, it enforces very strict typing -# rules. The intention is for this config file to be used to ENFORCE that -# people are using mypy on codegen files. -# +# This is the PyTorch mypy-strict.ini file (note: don't change this line! - +# test_run_mypy in test/test_type_hints.py uses this string) + +# Unlike mypy.ini, it enforces very strict typing rules. The intention is for +# this config file to be used to ENFORCE that people are using mypy on codegen +# files. + # For now, only code_template.py and benchmark utils Timer are covered this way [mypy] @@ -33,6 +36,7 @@ strict_equality = True files = tools/codegen/gen.py, tools/autograd/*.py, tools/pyi/*.py, + torch/testing/_internal/mypy_wrapper.py, torch/utils/benchmark/utils/common.py, torch/utils/benchmark/utils/timer.py, torch/utils/benchmark/utils/valgrind_wrapper/*.py, @@ -50,6 +54,10 @@ follow_imports = skip [mypy-torch.*] follow_imports = skip -# Missing stub. +# Missing stubs. + [mypy-numpy] ignore_missing_imports = True + +[mypy-mypy.*] +ignore_missing_imports = True diff --git a/mypy.ini b/mypy.ini index 578a2ddd1094..1d9526ffdd40 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,6 @@ -# This is the PyTorch MyPy config file (note: don't change this line! - +# This is the PyTorch mypy.ini file (note: don't change this line! 
- # test_run_mypy in test/test_type_hints.py uses this string) + [mypy] cache_dir = .mypy_cache/normal warn_unused_configs = True @@ -263,3 +264,6 @@ ignore_missing_imports = True [mypy-librosa.*] ignore_missing_imports = True + +[mypy-mypy.*] +ignore_missing_imports = True diff --git a/test/cpp/api/tensor_flatten.cpp b/test/cpp/api/tensor_flatten.cpp new file mode 100644 index 000000000000..36888e8b0356 --- /dev/null +++ b/test/cpp/api/tensor_flatten.cpp @@ -0,0 +1,39 @@ +#include +#include + +#include +#include +#include + +using namespace torch::test; + +TEST(UnflattenDenseTensorTest, TestEmptyTensor) { + auto emptyTensor1 = at::tensor(std::vector()); + auto emptyTensor2 = at::tensor(std::vector()); + auto tensor1 = at::tensor({1, 2, 3}); + auto tensor2 = at::tensor({4, 5}); + auto tensorList = std::vector({tensor1, emptyTensor1, emptyTensor2, tensor2}); + auto flatTensor = at::tensor({1, 2, 3, 4, 5}); + auto unflatten_results = torch::utils::unflatten_dense_tensors(flatTensor, tensorList); + ASSERT_EQ(unflatten_results.size(), 4); + ASSERT_EQ(unflatten_results.at(0).numel(), 3); + ASSERT_EQ(unflatten_results.at(1).numel(), 0); + ASSERT_EQ(unflatten_results.at(2).numel(), 0); + ASSERT_EQ(unflatten_results.at(3).numel(), 2); + + // empty tensor address is 0 as memory is not allocated yet + ASSERT_EQ(unflatten_results.at(1).data_ptr(), nullptr); + ASSERT_EQ(unflatten_results.at(2).data_ptr(), nullptr); + // without fix in unflatten_dense_tensors() for empty tensors, + // unflattend empty tensor unflatten_results.at(1) will share the same storage + // as other non-empty tenosr like unflatten_results.at(3). + // after fix, the empty tensor and non-empty tensor do not share the same + // storage. + ASSERT_NE(unflatten_results.at(1).data_ptr(), unflatten_results.at(3).data_ptr()); + unflatten_results.at(1).resize_(1); + unflatten_results.at(2).resize_(1); + // after resizing the two empty tensors, the resized tensors do not share + // the same storage. without fix in unflatten_dense_tensors() for empty tensors, + // the resized tensors will share the same storage. + ASSERT_NE(unflatten_results.at(1).data_ptr(), unflatten_results.at(2).data_ptr()); +} diff --git a/test/custom_backend/test_custom_backend.py b/test/custom_backend/test_custom_backend.py index 9b850508fab9..2456080891a8 100644 --- a/test/custom_backend/test_custom_backend.py +++ b/test/custom_backend/test_custom_backend.py @@ -1,12 +1,12 @@ import os import tempfile import torch -import unittest from backend import Model, to_custom_backend, get_custom_backend_library_path +from torch.testing._internal.common_utils import TestCase, run_tests -class TestCustomBackend(unittest.TestCase): +class TestCustomBackend(TestCase): def setUp(self): # Load the library containing the custom backend. 
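Stepping back from the custom-backend test for a moment: the C++ test added above in test/cpp/api/tensor_flatten.cpp pins down that unflattening a flat buffer gives empty tensors their own (unallocated) storage rather than aliasing a non-empty neighbour. Below is a small Python sketch of the same flatten/unflatten round trip; it assumes the private helpers torch._utils._flatten_dense_tensors and _unflatten_dense_tensors, whose location and exact behaviour may differ from the C++ utility exercised by the test.

    import torch
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    tensors = [torch.tensor([1., 2., 3.]), torch.tensor([]), torch.tensor([4., 5.])]
    flat = _flatten_dense_tensors(tensors)           # tensor([1., 2., 3., 4., 5.])
    outs = _unflatten_dense_tensors(flat, tensors)   # chunks shaped like the inputs

    print([t.numel() for t in outs])                 # [3, 0, 2]
    # The C++ test additionally asserts that the empty chunks do not end up
    # sharing storage with the non-empty ones once they are resized.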
self.library_path = get_custom_backend_library_path() @@ -51,4 +51,4 @@ def test_save_load(self): if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/custom_operator/test_custom_classes.py b/test/custom_operator/test_custom_classes.py index 0a3a874558ce..c67fa22ec76c 100644 --- a/test/custom_operator/test_custom_classes.py +++ b/test/custom_operator/test_custom_classes.py @@ -5,6 +5,9 @@ import glob import os +from torch.testing._internal.common_utils import TestCase, run_tests + + def get_custom_class_library_path(): library_filename = glob.glob("build/*custom_class*") assert (len(library_filename) == 1) @@ -18,7 +21,7 @@ def test_equality(f, cmp_key): obj2 = jit.script(f)() return (cmp_key(obj1), cmp_key(obj2)) -class TestCustomOperators(unittest.TestCase): +class TestCustomOperators(TestCase): def setUp(self): ops.load_library(get_custom_class_library_path()) @@ -77,4 +80,4 @@ def f(): if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/custom_operator/test_custom_ops.py b/test/custom_operator/test_custom_ops.py index 27c779fdac1b..3937abde9147 100644 --- a/test/custom_operator/test_custom_ops.py +++ b/test/custom_operator/test_custom_ops.py @@ -1,14 +1,14 @@ import os.path import tempfile -import unittest import torch from torch import ops from model import Model, get_custom_op_library_path +from torch.testing._internal.common_utils import TestCase, run_tests -class TestCustomOperators(unittest.TestCase): +class TestCustomOperators(TestCase): def setUp(self): self.library_path = get_custom_op_library_path() ops.load_library(self.library_path) @@ -90,4 +90,4 @@ def test_saving_and_loading_script_module_with_custom_op(self): if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py index 5492d6a9c3b2..1ee009ead45a 100755 --- a/test/distributed/test_c10d.py +++ b/test/distributed/test_c10d.py @@ -24,6 +24,7 @@ import torch.testing._internal.common_utils as common from torch import nn from torch._six import string_classes + from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_distributed import ( MultiProcessTestCase, @@ -49,6 +50,7 @@ slowTest, ) + # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests @@ -4618,7 +4620,7 @@ def test_gloo_barrier_device_ids(self): with self.assertRaisesRegex(RuntimeError, "device_ids not supported"): c10d.barrier(device_ids=[self.rank]) -if __name__ == "__main__": +if __name__ == '__main__': assert ( not torch.cuda._initialized ), "test_distributed must not have initialized CUDA context on main process" diff --git a/test/distributed/test_c10d_spawn.py b/test/distributed/test_c10d_spawn.py index 0bba18bf3ab9..9c49726bf514 100644 --- a/test/distributed/test_c10d_spawn.py +++ b/test/distributed/test_c10d_spawn.py @@ -11,12 +11,21 @@ from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU from torch.testing._internal.common_distributed import requires_gloo, \ - create_device + create_device, MultiProcessTestCase, skip_if_not_multigpu from torch.testing._internal.common_utils import TestCase, load_tests, \ run_tests, skipIfRocm from torch.testing._internal.common_utils import NO_MULTIPROCESSING_SPAWN, TEST_WITH_TSAN +# Torch distributed.nn is not available in windows +# check #42095, it errors on import. 
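Before the import guard and the TestDistributedNNFunctions suite below, here is a minimal usage sketch of the autograd-enabled collectives that these tests exercise (see the docs/source/distributed.rst addition earlier in this diff). It assumes two processes have already been launched with the usual env:// rendezvous variables set (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE), uses the gloo backend on CPU tensors, and omits error handling.

    import torch
    import torch.distributed as dist
    import torch.distributed.nn  # autograd-aware collectives

    dist.init_process_group(backend="gloo", init_method="env://")

    x = torch.ones(5, 5) + dist.get_rank()
    x.requires_grad_()

    # The collective is recorded in the autograd graph; every rank must run
    # backward(), otherwise the matching backward communication never happens.
    y = torch.distributed.nn.all_reduce(x, op=dist.ReduceOp.SUM)
    y.sin().sum().backward()
    print(dist.get_rank(), x.grad.sum().item())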
+_torch_dist_nn_available = True +try: + import torch.distributed.nn +except ImportError: + _torch_dist_nn_available = False + + # load_tests from common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests @@ -317,5 +326,186 @@ def forward(self, x, y): self._test_base(net, inp, check_allclose=False) +class TestDistributedNNFunctions(MultiProcessTestCase): + def setUp(self): + if not _torch_dist_nn_available: + raise unittest.SkipTest("torch.distributed.nn is not available") + super(TestDistributedNNFunctions, self).setUp() + self._spawn_processes() + + def tearDown(self): + super(TestDistributedNNFunctions, self).tearDown() + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def op_timeout_sec(self): + return 1 + + @property + def world_size(self): + return 2 + + @requires_gloo() + @skip_if_not_multigpu + def test_broadcast(self): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') + device = torch.device(f"cuda:{self.rank}") + x = torch.ones(5, 5, device=device) + self.rank + x.requires_grad = True + y = torch.distributed.nn.broadcast(x, 1) + self.assertEqual(y, 1 + torch.ones(5, 5)) + z = y.sin().sum() + z.backward() + # We can't check the gradient of communications numerically so we have to do some calculations + if self.rank == 1: + self.assertEqual(x.grad, 2 * torch.cos(x)) + elif self.rank == 0: + self.assertEqual(x.grad, torch.zeros(5, 5, device=device)) + + @requires_gloo() + @skip_if_not_multigpu + def test_gather(self): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') + device = torch.device(f"cuda:{self.rank}") + x = torch.ones(5, 5, device=device) + self.rank + x.requires_grad = True + tensors = torch.distributed.nn.gather(x, 1) + if self.rank == 1: + for i, t in enumerate(tensors): + self.assertEqual(t, torch.ones(5, 5, device=device) + i) + elif self.rank == 0: + for i, t in enumerate(tensors): + zeros = torch.zeros(5, 5, device=device) + self.assertEqual(t, zeros) + y = torch.sum(torch.stack(tensors), axis=0) + z = y.sin().sum() + z.backward() + + # Test gradient + x_s = 3 * torch.ones(5, 5, device=device) + self.assertEqual(x.grad, x_s.cos()) + + @requires_gloo() + @skip_if_not_multigpu + def test_scatter(self): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') + device = torch.device(f"cuda:{self.rank}") + x0 = torch.ones(5, 5, device=device) + x1 = torch.ones(5, 5, device=device) + 1 + x0.requires_grad = True + x1.requires_grad = True + + y = torch.distributed.nn.scatter([x0, x1], 1) + if self.rank == 1: + self.assertEqual(y, 1 + torch.ones(5, 5, device=device)) + elif self.rank == 0: + self.assertEqual(y, torch.ones(5, 5, device=device)) + z = y.sin().sum() + z.backward() + + # Test gradient + if self.rank == 1: + x0_s = torch.ones(5, 5, device=device).cos() + x1_s = (2 * torch.ones(5, 5, device=device)).cos() + 
self.assertEqual(x0.grad, x0_s) + self.assertEqual(x1.grad, x1_s) + if self.rank == 0: + self.assertEqual(x0.grad, torch.zeros(5, 5, device=device)) + + @requires_gloo() + @skip_if_not_multigpu + def test_reduce(self): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') + device = torch.device(f"cuda:{self.rank}") + x = torch.ones(5, 5, device=device) + self.rank + x.requires_grad = True + y = torch.distributed.nn.reduce(x, 1, op=c10d.ReduceOp.SUM) + + if self.rank == 1: + self.assertEqual(y, 3 * torch.ones(5, 5, device=device)) + + z = y.sin().sum() + z.backward() + # Gradients are broadcasted to both ranks + x_g = (3 * torch.ones(5, 5, device=device)).cos() + self.assertEqual(x.grad, x_g) + + @requires_gloo() + @skip_if_not_multigpu + def test_allreduce(self): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') + device = torch.device(f"cuda:{self.rank}") + x = torch.ones(5, 5, device=device) + self.rank + x.requires_grad = True + y = torch.distributed.nn.all_reduce(x, op=c10d.ReduceOp.SUM) + + self.assertEqual(y, 3 * torch.ones(5, 5, device=device)) + + z = y.sin().sum() + z.backward() + x_g = 2 * (3 * torch.ones(5, 5, device=device)).cos() + self.assertEqual(x.grad, x_g) + + @requires_gloo() + @skip_if_not_multigpu + def test_all_gather(self): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') + device = torch.device(f"cuda:{self.rank}") + x = torch.ones(5, 5, device=device) + self.rank + x.requires_grad = True + tensors = torch.distributed.nn.all_gather(x) + for i, t in enumerate(tensors): + self.assertEqual(t, torch.ones(5, 5, device=device) + i) + y = torch.sum(torch.stack(tensors), axis=0) + z = y.sin().sum() + z.backward() + + x_s = 2 * (3 * torch.ones(5, 5, device=device)).cos() + self.assertEqual(x.grad, x_s) + + @requires_gloo() + @skip_if_not_multigpu + def test_all_to_all(self): + store = c10d.FileStore(self.file_name, self.world_size) + # This is required because these functions calls directly to the .dist and needs + # the world to be initialized + c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='gloo') + device = torch.device(f"cuda:{self.rank}") + x0 = torch.ones(5, 5, device=device) + 2 * self.rank + x1 = torch.ones(5, 5, device=device) + 2 * self.rank + x0.requires_grad = True + x1.requires_grad = True + tensors = torch.distributed.nn.all_to_all([x0, x1]) + for i, t in enumerate(tensors): + self.assertEqual(t, torch.ones(5, 5, device=device) + 2 * i) + y = torch.sum(torch.stack(tensors), axis=0) + z = y.sin().sum() + z.backward() + x_s = (4 * torch.ones(5, 5, device=device)).cos() + self.assertEqual(x0.grad, x_s) + self.assertEqual(x1.grad, x_s) + + if __name__ == '__main__': run_tests() diff --git a/test/distributions/test_constraints.py b/test/distributions/test_constraints.py index ffff932dfa37..12d777cce280 100644 --- a/test/distributions/test_constraints.py +++ 
b/test/distributions/test_constraints.py @@ -68,7 +68,7 @@ def test_biject_to(constraint_fn, args, is_cuda): assert torch.allclose(x, x2), "Error in biject_to({}) inverse".format(constraint) j = t.log_abs_det_jacobian(x, y) - assert j.shape == x.shape[:x.dim() - t.input_event_dim] + assert j.shape == x.shape[:x.dim() - t.domain.event_dim] @pytest.mark.parametrize('constraint_fn, args', [(c[0], c[1:]) for c in CONSTRAINTS]) diff --git a/test/distributions/test_distributions.py b/test/distributions/test_distributions.py index 0c84ff0e7058..e8940021bdd0 100644 --- a/test/distributions/test_distributions.py +++ b/test/distributions/test_distributions.py @@ -878,6 +878,22 @@ def test_has_examples(self): self.assertIn(Dist, distributions_with_examples, "Please add {} to the EXAMPLES list in test_distributions.py".format(Dist.__name__)) + def test_support_attributes(self): + for Dist, params in EXAMPLES: + for param in params: + d = Dist(**param) + event_dim = len(d.event_shape) + self.assertEqual(d.support.event_dim, event_dim) + try: + self.assertEqual(Dist.support.event_dim, event_dim) + except NotImplementedError: + pass + is_discrete = d.support.is_discrete + try: + self.assertEqual(Dist.support.is_discrete, is_discrete) + except NotImplementedError: + pass + def test_distribution_expand(self): shapes = [torch.Size(), torch.Size((2,)), torch.Size((2, 1))] for Dist, params in EXAMPLES: @@ -1620,8 +1636,8 @@ def test_logisticnormal(self): self.assertEqual(LogisticNormal(mean, std).sample((7,)).size(), (7, 5, 6)) self.assertEqual(LogisticNormal(mean_1d, std_1d).sample((1,)).size(), (1, 2)) self.assertEqual(LogisticNormal(mean_1d, std_1d).sample().size(), (2,)) - self.assertEqual(LogisticNormal(0.2, .6).sample((1,)).size(), (2,)) - self.assertEqual(LogisticNormal(-0.7, 50.0).sample((1,)).size(), (2,)) + self.assertEqual(LogisticNormal(0.2, .6).sample().size(), (2,)) + self.assertEqual(LogisticNormal(-0.7, 50.0).sample().size(), (2,)) # sample check for extreme value of mean, std set_rng_seed(1) @@ -3832,6 +3848,16 @@ def test_kl_shape(self): 'Actual {}'.format(kl.shape), ])) + def test_kl_transformed(self): + # Regression test for https://github.com/pytorch/pytorch/issues/34859 + scale = torch.ones(2, 3) + loc = torch.zeros(2, 3) + normal = Normal(loc=loc, scale=scale) + diag_normal = Independent(normal, reinterpreted_batch_ndims=1) + trans_dist = TransformedDistribution(diag_normal, AffineTransform(loc=0., scale=2.)) + self.assertEqual(kl_divergence(diag_normal, diag_normal).shape, (2,)) + self.assertEqual(kl_divergence(trans_dist, trans_dist).shape, (2,)) + def test_entropy_monte_carlo(self): set_rng_seed(0) # see Note [Randomized statistical tests] for Dist, params in EXAMPLES: diff --git a/test/distributions/test_transforms.py b/test/distributions/test_transforms.py index b5e9144f0bd8..8dbbc5eb2b9f 100644 --- a/test/distributions/test_transforms.py +++ b/test/distributions/test_transforms.py @@ -4,10 +4,10 @@ import torch from torch.autograd.functional import jacobian -from torch.distributions import Dirichlet, Normal, TransformedDistribution, constraints +from torch.distributions import Dirichlet, Independent, Normal, TransformedDistribution, constraints from torch.distributions.transforms import (AbsTransform, AffineTransform, ComposeTransform, - CorrCholeskyTransform, ExpTransform, - LowerCholeskyTransform, PowerTransform, + CorrCholeskyTransform, ExpTransform, IndependentTransform, + LowerCholeskyTransform, PowerTransform, ReshapeTransform, SigmoidTransform, TanhTransform, 
SoftmaxTransform, StickBreakingTransform, identity_transform, Transform, _InverseTransform) @@ -22,6 +22,8 @@ def get_transforms(cache_size): cache_size=cache_size), PowerTransform(exponent=torch.tensor(5.).normal_(), cache_size=cache_size), + PowerTransform(exponent=torch.tensor(5.).normal_(), + cache_size=cache_size), SigmoidTransform(cache_size=cache_size), TanhTransform(cache_size=cache_size), AffineTransform(0, 1, cache_size=cache_size), @@ -57,6 +59,12 @@ def get_transforms(cache_size): torch.randn(4, 5), cache_size=cache_size), ]), + ReshapeTransform((4, 5), (2, 5, 2)), + IndependentTransform( + AffineTransform(torch.randn(5), + torch.randn(5), + cache_size=cache_size), + 1), ] transforms += [t.inv for t in transforms] return transforms @@ -92,7 +100,16 @@ def transform_id(x): def generate_data(transform): torch.manual_seed(1) + while isinstance(transform, IndependentTransform): + transform = transform.base_transform + if isinstance(transform, ReshapeTransform): + return torch.randn(transform.in_shape) + if isinstance(transform.inv, ReshapeTransform): + return torch.randn(transform.inv.out_shape) domain = transform.domain + while (isinstance(domain, constraints.independent) and + domain is not constraints.real_vector): + domain = domain.base_constraint codomain = transform.codomain x = torch.empty(4, 5) if domain is constraints.lower_cholesky or codomain is constraints.lower_cholesky: @@ -170,6 +187,7 @@ def test_forward_inverse(transform, test_cached): y = transform(x) except NotImplementedError: pytest.skip('Not implemented.') + assert y.shape == transform.forward_shape(x.shape) if test_cached: x2 = transform.inv(y) # should be implemented at least by caching else: @@ -177,6 +195,7 @@ def test_forward_inverse(transform, test_cached): x2 = transform.inv(y.clone()) # bypass cache except NotImplementedError: pytest.skip('Not implemented.') + assert x2.shape == transform.inverse_shape(y.shape) y2 = transform(x2) if transform.bijective: # verify function inverse @@ -316,25 +335,29 @@ def test_jacobian(transform): except NotImplementedError: pytest.skip('Not implemented.') # Test shape - target_shape = x.shape[:x.dim() - transform.input_event_dim] + target_shape = x.shape[:x.dim() - transform.domain.event_dim] assert actual.shape == target_shape # Expand if required transform = reshape_transform(transform, x.shape) ndims = len(x.shape) - event_dim = ndims - transform.input_event_dim + event_dim = ndims - transform.domain.event_dim x_ = x.view((-1,) + x.shape[event_dim:]) n = x_.shape[0] # Reshape to squash batch dims to a single batch dim transform = reshape_transform(transform, x_.shape) - # 1. Transforms with 0 off-diagonal elements - if transform.input_event_dim == 0: + # 1. Transforms with unit jacobian + if isinstance(transform, ReshapeTransform) or isinstance(transform.inv, ReshapeTransform): + expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim]) + expected = x.new_zeros(x.shape[x.dim() - transform.domain.event_dim]) + # 2. Transforms with 0 off-diagonal elements + elif transform.domain.event_dim == 0: jac = jacobian(transform, x_) # assert off-diagonal elements are zero assert torch.allclose(jac, jac.diagonal().diag_embed()) expected = jac.diagonal().abs().log().reshape(x.shape) - # 2. Transforms with non-0 off-diagonal elements + # 3. 
Transforms with non-0 off-diagonal elements else: if isinstance(transform, CorrCholeskyTransform): jac = jacobian(lambda x: tril_matrix_to_vec(transform(x), diag=-1), x_) @@ -361,5 +384,88 @@ def test_jacobian(transform): assert torch.allclose(actual, expected, atol=1e-5) +@pytest.mark.parametrize("event_dims", + [(0,), (1,), (2, 3), (0, 1, 2), (1, 2, 0), (2, 0, 1)], + ids=str) +def test_compose_affine(event_dims): + transforms = [AffineTransform(torch.zeros((1,) * e), 1, event_dim=e) for e in event_dims] + transform = ComposeTransform(transforms) + assert transform.codomain.event_dim == max(event_dims) + assert transform.domain.event_dim == max(event_dims) + + base_dist = Normal(0, 1) + if transform.domain.event_dim: + base_dist = base_dist.expand((1,) * transform.domain.event_dim) + dist = TransformedDistribution(base_dist, transform.parts) + assert dist.support.event_dim == max(event_dims) + + base_dist = Dirichlet(torch.ones(5)) + if transform.domain.event_dim > 1: + base_dist = base_dist.expand((1,) * (transform.domain.event_dim - 1)) + dist = TransformedDistribution(base_dist, transforms) + assert dist.support.event_dim == max(1, max(event_dims)) + + +@pytest.mark.parametrize("batch_shape", [(), (6,), (5, 4)], ids=str) +def test_compose_reshape(batch_shape): + transforms = [ReshapeTransform((), ()), + ReshapeTransform((2,), (1, 2)), + ReshapeTransform((3, 1, 2), (6,)), + ReshapeTransform((6,), (2, 3))] + transform = ComposeTransform(transforms) + assert transform.codomain.event_dim == 2 + assert transform.domain.event_dim == 2 + data = torch.randn(batch_shape + (3, 2)) + assert transform(data).shape == batch_shape + (2, 3) + + dist = TransformedDistribution(Normal(data, 1), transforms) + assert dist.batch_shape == batch_shape + assert dist.event_shape == (2, 3) + assert dist.support.event_dim == 2 + + +@pytest.mark.parametrize("sample_shape", [(), (7,)], ids=str) +@pytest.mark.parametrize("transform_dim", [0, 1, 2]) +@pytest.mark.parametrize("base_batch_dim", [0, 1, 2]) +@pytest.mark.parametrize("base_event_dim", [0, 1, 2]) +@pytest.mark.parametrize("num_transforms", [0, 1, 2, 3]) +def test_transformed_distribution(base_batch_dim, base_event_dim, transform_dim, + num_transforms, sample_shape): + shape = torch.Size([2, 3, 4, 5]) + base_dist = Normal(0, 1) + base_dist = base_dist.expand(shape[4 - base_batch_dim - base_event_dim:]) + if base_event_dim: + base_dist = Independent(base_dist, base_event_dim) + transforms = [AffineTransform(torch.zeros(shape[4 - transform_dim:]), 1), + ReshapeTransform((4, 5), (20,)), + ReshapeTransform((3, 20), (6, 10))] + transforms = transforms[:num_transforms] + transform = ComposeTransform(transforms) + + # Check validation in .__init__(). + if base_batch_dim + base_event_dim < transform.domain.event_dim: + with pytest.raises(ValueError): + TransformedDistribution(base_dist, transforms) + return + d = TransformedDistribution(base_dist, transforms) + + # Check sampling is sufficiently expanded. + x = d.sample(sample_shape) + assert x.shape == sample_shape + d.batch_shape + d.event_shape + num_unique = len(set(x.reshape(-1).tolist())) + assert num_unique >= 0.9 * x.numel() + + # Check log_prob shape on full samples. + log_prob = d.log_prob(x) + assert log_prob.shape == sample_shape + d.batch_shape + + # Check log_prob shape on partial samples. 
+ y = x + while y.dim() > len(d.event_shape): + y = y[0] + log_prob = d.log_prob(y) + assert log_prob.shape == d.batch_shape + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/test/jit/test_module_interface.py b/test/jit/test_module_interface.py index fb4ecc8c3c73..59055e996dfa 100644 --- a/test/jit/test_module_interface.py +++ b/test/jit/test_module_interface.py @@ -7,7 +7,7 @@ import os import sys from torch import Tensor -from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.jit_utils import JitTestCase, make_global # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -63,7 +63,6 @@ def forward(self, input: Tensor) -> Tensor: torch.jit.script(TestNotModuleInterfaceCall()) def test_module_interface(self): - global OneTwoModule, OneTwoClass @torch.jit.interface class OneTwoModule(nn.Module): def one(self, x: Tensor, y: Tensor) -> Tensor: @@ -107,6 +106,7 @@ def forward(self, x: Tensor) -> Tensor: def forward2(self, x: Tensor) -> Tensor: return self.two(self.one(x, x)) + 1 + make_global(OneTwoModule, OneTwoClass) def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor): return mod_list[0].forward(x) + mod_list[1].forward(x) @@ -155,7 +155,6 @@ def forward(self, input): self.checkModule(TestModule(), (input,)) def test_module_interface_subtype(self): - global OneTwoModule @torch.jit.interface class OneTwoModule(nn.Module): def one(self, x: Tensor, y: Tensor) -> Tensor: @@ -167,6 +166,7 @@ def two(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor: pass + make_global(OneTwoModule) @torch.jit.script def as_module_interface(x: OneTwoModule) -> OneTwoModule: return x @@ -200,22 +200,22 @@ def forward(self, x: Tensor) -> Tensor: as_module_interface(scripted_wrong_mod) # Check that interface implementations can be contravariant in argument types and covariant in return type. - global TensorToAny @torch.jit.interface class TensorToAny(nn.Module): def forward(self, input: torch.Tensor) -> Any: pass + make_global(TensorToAny) @torch.jit.script def as_tensor_to_any(x: TensorToAny) -> TensorToAny: return x - global AnyToAny @torch.jit.interface class AnyToAny(nn.Module): def forward(self, input: Any) -> Any: pass + make_global(AnyToAny) @torch.jit.script def as_any_to_any(x: AnyToAny) -> AnyToAny: return x diff --git a/test/jit/test_python_bindings.py b/test/jit/test_python_bindings.py new file mode 100644 index 000000000000..090efad55edd --- /dev/null +++ b/test/jit/test_python_bindings.py @@ -0,0 +1,37 @@ +import torch +from torch.testing._internal.jit_utils import JitTestCase + +if __name__ == "__main__": + raise RuntimeError( + "This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TestPythonBindings\n\n" + "instead." 
+ ) + + +class TestPythonBindings(JitTestCase): + def test_cu_get_functions(self): + @torch.jit.script + def test_get_python_cu_fn(x: torch.Tensor): + return 2 * x + + cu = torch.jit._state._python_cu + self.assertTrue( + "test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions()) + ) + + def test_cu_create_function(self): + @torch.jit.script + def fn(x: torch.Tensor): + return 2 * x + + cu = torch._C.CompilationUnit() + cu.create_function("test_fn", fn.graph) + + inp = torch.randn(5) + + self.assertEqual(inp * 2, cu.find_function("test_fn")(inp)) + self.assertEqual(cu.find_function("doesnt_exist"), None) + self.assertEqual(inp * 2, cu.test_fn(inp)) + with self.assertRaises(AttributeError): + cu.doesnt_exist(inp) diff --git a/test/mobile/test_lite_script_module.py b/test/mobile/test_lite_script_module.py index 3549582dcfac..033ab6a15011 100644 --- a/test/mobile/test_lite_script_module.py +++ b/test/mobile/test_lite_script_module.py @@ -1,4 +1,3 @@ -import unittest import torch import torch.utils.bundled_inputs from torch.utils.mobile_optimizer import * @@ -7,8 +6,9 @@ from collections import namedtuple from torch.jit.mobile import _load_for_lite_interpreter +from torch.testing._internal.common_utils import TestCase, run_tests -class TestLiteScriptModule(unittest.TestCase): +class TestLiteScriptModule(TestCase): def test_load_mobile_module(self): class MyTestModule(torch.nn.Module): @@ -287,4 +287,4 @@ def forward(self): script_module._save_to_buffer_for_lite_interpreter() if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py index 2a94e7cc43f1..bc68d722e342 100644 --- a/test/quantization/test_quantized_op.py +++ b/test/quantization/test_quantized_op.py @@ -2190,6 +2190,12 @@ def test_empty_batch(self): X = torch.ones((0, 2, 4, 4), dtype=torch.float32) qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point, dtype=torch.quint8) + + # upsample_nearest2d + qY = torch.nn.functional.upsample_nearest(qX, scale_factor=2) + np.testing.assert_equal(qY.size(), (0, 2, 8, 8), + "Quantized upsample_nearsest2d with batch size 0 failed.") + # relu qY = torch.nn.functional.relu(qX) np.testing.assert_equal(qY.size(), qX.size(), diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py index 8a70ae149c29..3ce7b020e38a 100644 --- a/test/quantization/test_workflow_module.py +++ b/test/quantization/test_workflow_module.py @@ -411,7 +411,8 @@ def test_state_dict_respects_device_affinity(self): MovingAveragePerChannelMinMaxObserver, # TODO: enable this (separate PR) # HistogramObserver, - PlaceholderObserver, RecordingObserver, NoopObserver]) + PlaceholderObserver, RecordingObserver, NoopObserver, + FakeQuantize]) for device_source, device_target, obs_cls in test_cases: # calibrated source model diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index 06f3df93a99e..0ee99e7b5f68 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -130,7 +130,7 @@ def _check_cuobjdump_output(expected_values, is_ptx=False): err, output)) actual_arches = sorted(re.findall(r'sm_\d\d', output)) - expected_arches = ['sm_' + xx for xx in expected_values] + expected_arches = sorted(['sm_' + xx for xx in expected_values]) self.assertEqual(actual_arches, expected_arches, msg="Flags: {}, Actual: {}, Expected: {}\n" "Stderr: {}\nOutput: {}".format( @@ -180,11 +180,12 @@ def test_jit_cuda_archflags(self): # - Architecture 
names # - With/without '+PTX' - capability = torch.cuda.get_device_capability() + n = torch.cuda.device_count() + capabilities = {torch.cuda.get_device_capability(i) for i in range(n)} # expected values is length-2 tuple: (list of ELF, list of PTX) # note: there should not be more than one PTX value archflags = { - '': (['{}{}'.format(capability[0], capability[1])], None), + '': (['{}{}'.format(capability[0], capability[1]) for capability in capabilities], None), "Maxwell+Tegra;6.1": (['53', '61'], None), "Pascal 3.5": (['35', '60', '61'], None), "Volta": (['70'], ['70']), diff --git a/test/test_dispatch.py b/test/test_dispatch.py index 60ce16192318..80b9f9adeac1 100644 --- a/test/test_dispatch.py +++ b/test/test_dispatch.py @@ -1,5 +1,6 @@ import torch._C as C from torch.testing._internal.common_utils import TestCase, run_tests +from torch._python_dispatcher import PythonDispatcher from collections import namedtuple import itertools @@ -753,5 +754,137 @@ def test_overwrite_math(self): ''' ) +class TestPythonDispatcher(TestCase): + def test_basic(self): + dispatcher = PythonDispatcher() + dispatcher.register(["CPU", "XLA", "Math"]) + self.assertExpectedInline( + dispatcher.dispatchTable(), + '''\ + +Computed Dispatch Table +key kernel +--------------------------- +CPU fn_CPU [kernel] +XLA fn_XLA [kernel] +QuantizedCPU fn_Math [math kernel] +AutogradOther fn_Math [math kernel] +AutogradCPU fallthrough [backend fallback] +AutogradXLA fallthrough [backend fallback] +''' + ) + + def test_math_autogradcpu(self): + dispatcher = PythonDispatcher() + dispatcher.register(["CPU", "XLA", "Math", "AutogradCPU"]) + self.assertExpectedInline( + dispatcher.dispatchTable(), + '''\ + +Computed Dispatch Table +key kernel +--------------------------- +CPU fn_CPU [kernel] +XLA fn_XLA [kernel] +QuantizedCPU fn_Math [math kernel] +AutogradOther fn_Math [math kernel] +AutogradCPU fn_AutogradCPU [kernel] +AutogradXLA fallthrough [backend fallback] +''' + ) + self.assertExpectedInline( + dispatcher.registrations(), + '''\ + +Registered Kernels +key kernel +--------------------------- +CPU fn_CPU +XLA fn_XLA +AutogradCPU fn_AutogradCPU +Math[alias] fn_Math +''' + ) + + def test_defaultbackend_autogradcpu(self): + dispatcher = PythonDispatcher() + dispatcher.register(["CPU", "XLA", "DefaultBackend", "AutogradCPU"]) + self.assertExpectedInline( + dispatcher.dispatchTable(), + '''\ + +Computed Dispatch Table +key kernel +--------------------------- +CPU fn_CPU [kernel] +XLA fn_XLA [kernel] +QuantizedCPU fn_DefaultBackend [default backend kernel] +AutogradOther fallthrough [backend fallback] +AutogradCPU fn_AutogradCPU [kernel] +AutogradXLA fallthrough [backend fallback] +''' + ) + + self.assertExpectedInline( + dispatcher.registrations(), + '''\ + +Registered Kernels +key kernel +--------------------------- +CPU fn_CPU +XLA fn_XLA +AutogradCPU fn_AutogradCPU +DefaultBackend[alias] fn_DefaultBackend +''' + ) + + def test_autogradother(self): + dispatcher = PythonDispatcher() + dispatcher.register(["CPU", "QuantizedCPU", "Math"]) + self.assertExpectedInline( + dispatcher.dispatchTable(), + '''\ + +Computed Dispatch Table +key kernel +--------------------------- +CPU fn_CPU [kernel] +XLA fn_Math [math kernel] +QuantizedCPU fn_QuantizedCPU [kernel] +AutogradOther ambiguous_autogradother [ambiguous autogradother] +AutogradCPU fallthrough [backend fallback] +AutogradXLA fn_Math [math kernel] +''' + ) + + self.assertExpectedInline( + dispatcher.registrations(), + '''\ + +Registered Kernels +key kernel 
+--------------------------- +CPU fn_CPU +QuantizedCPU fn_QuantizedCPU +Math[alias] fn_Math +''' + ) + + def test_duplicate_registrations(self): + dispatcher = PythonDispatcher() + + with self.assertRaisesRegex(RuntimeError, r"Overriden is not allowed"): + dispatcher.register(["CPU", "CPU"]) + + def test_defaultbackend_math(self): + dispatcher = PythonDispatcher() + + with self.assertRaisesRegex( + RuntimeError, + r"Registration to both Math and DefaultBackend is not allowed"): + dispatcher.register(["DefaultBackend", "Math"]) + + if __name__ == '__main__': run_tests() diff --git a/test/test_jit.py b/test/test_jit.py index b0924dd09148..4d37cd0a3ef9 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -23,6 +23,7 @@ from jit.test_peephole import TestPeephole # noqa: F401 from jit.test_save_load import TestSaveLoad # noqa: F401 from jit.test_module_containers import TestModuleContainers # noqa: F401 +from jit.test_python_bindings import TestPythonBindings # noqa: F401 from jit.test_python_ir import TestPythonIr # noqa: F401 from jit.test_functional_blocks import TestFunctionalBlocks # noqa: F401 from jit.test_remove_mutation import TestRemoveMutation # noqa: F401 diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py index e8694fd91aab..ad92ee0f281f 100644 --- a/test/test_jit_py3.py +++ b/test/test_jit_py3.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple from torch.testing._internal.common_utils import run_tests -from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.jit_utils import JitTestCase, make_global from torch.testing import FileCheck from torch import jit from textwrap import dedent @@ -727,7 +727,6 @@ def forward(self) -> int: def test_export_opnames_interface(self): - global OneTwoModule @torch.jit.interface class OneTwoModule(nn.Module): @@ -760,6 +759,8 @@ def two(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: return self.two(self.one(x, x)) + make_global(OneTwoModule) + class M(nn.Module): sub : OneTwoModule diff --git a/test/test_license.py b/test/test_license.py new file mode 100644 index 000000000000..cf04362d5e5e --- /dev/null +++ b/test/test_license.py @@ -0,0 +1,30 @@ +import io +import unittest + +from torch.testing._internal.common_utils import TestCase, run_tests + + +try: + from third_party.build_bundled import create_bundled +except ImportError: + create_bundled = None + +license_file = 'third_party/LICENSES_BUNDLED.txt' + +class TestLicense(TestCase): + + @unittest.skipIf(not create_bundled, "can only be run in a source tree") + def test_license_in_wheel(self): + current = io.StringIO() + create_bundled('third_party', current) + with open(license_file) as fid: + src_tree = fid.read() + if not src_tree == current.getvalue(): + raise AssertionError( + f'the contents of "{license_file}" do not ' + 'match the current state of the third_party files. 
Use ' + '"python third_party/build_bundled.py" to regenerate it') + + +if __name__ == '__main__': + run_tests() diff --git a/test/test_linalg.py b/test/test_linalg.py index 25fa0384a037..a115f1908641 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -24,7 +24,7 @@ skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, onlyOnCPUAndCUDA, dtypesIfCUDA, onlyCUDA) from torch.testing import floating_and_complex_types, floating_types, all_types -from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32 +from torch.testing._internal.common_cuda import SM53OrLater, tf32_on_and_off, CUDA11OrLater, CUDA9 from torch.autograd import gradcheck, gradgradcheck # Protects against includes accidentally setting the default dtype @@ -35,9 +35,6 @@ if TEST_SCIPY: import scipy -# TODO: make this common and import it -AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() - # See #49409, we should remove these if we end up with a global gradcheck setting gradcheck = partial(gradcheck, check_batched_grad=True) gradgradcheck = partial(gradgradcheck, check_batched_grad=True) @@ -247,7 +244,6 @@ def run_test(shape): # NOTE: old_cholesky* tests were moved here from test_torch.py and test_autograd.py @slowTest - @skipCUDAIf(True, "See issue #26789.") @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.double) @@ -864,7 +860,7 @@ def test_kron_errors_and_warnings(self, device, dtype): # dtypes should match out = torch.empty_like(a).to(torch.int) - with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + with self.assertRaisesRegex(RuntimeError, "can't be cast to the desired output type"): torch.kron(a, b, out=out) # This test confirms that torch.linalg.norm's dtype argument works @@ -996,6 +992,7 @@ def run_test_case(input, p, dim, keepdim): # their matrix norm results match @skipCUDAIfNoMagma @dtypes(torch.float, torch.double) + @precisionOverride({torch.float32: 2e-5}) def test_norm_matrix(self, device, dtype): def run_test_case(input, p, dim, keepdim): result = torch.linalg.norm(input, ord, dim, keepdim) @@ -1187,6 +1184,7 @@ def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex): @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.cfloat, torch.cdouble) + @precisionOverride({torch.cfloat: 2e-4}) def test_norm_complex(self, device, dtype): def gen_error_message(input_size, ord, keepdim, dim=None): return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % ( @@ -1229,6 +1227,7 @@ def gen_error_message(input_size, ord, keepdim, dim=None): # Test that linal.norm gives the same result as numpy when inputs # contain extreme values (inf, -inf, nan) + @skipCUDAIf(True, r"GPU Test is blocking torch.svd https://github.com/pytorch/pytorch/pull/48436") @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") @unittest.skipIf(IS_MACOS, "Skipped on MacOS!") @skipCUDAIfNoMagma @@ -1375,7 +1374,7 @@ def test_norm_fastpaths(self, device): @skipCPUIfNoLapack @skipCUDAIfNoMagma - @dtypes(torch.double, torch.float) + @dtypes(*floating_and_complex_types()) def test_eig_basic(self, device, dtype): a = torch.tensor([[1.96, 0.00, 0.00, 0.00, 0.00], [-6.49, 3.80, 0.00, 0.00, 0.00], @@ -1396,12 +1395,26 @@ def test_eig_basic(self, device, dtype): # # compare with numpy np_e, np_v = np.linalg.eig(a.cpu().numpy()) - # np_e.shape == (n, 2), where each column contain the real and - # imaginary parts of the result - self.assertEqual(ee[:, 0], np_e) # real part - self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part + if 
dtype.is_complex: + self.assertEqual(ee, np_e) + else: + # np_e.shape == (n, 2), where each column contain the real and + # imaginary parts of the result + self.assertEqual(ee[:, 0], np_e) # real part + self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part self.assertEqual(vv, np_v) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.complex64, torch.complex128) + def test_eig_backward_complex(self, device, dtype): + # torch.eig's backward is not supported yet for complex types. We + # should kill this test once it's implemented. + a = torch.tensor([[1., 2], [3, 4]], device=device, dtype=dtype, requires_grad=True) + with self.assertRaisesRegex(RuntimeError, + "eig does not support automatic differentiation for outputs with complex dtype"): + e, v = torch.eig(a, True) + @skipCPUIfNoLapack @skipCUDAIfNoMagma @dtypes(torch.double, torch.float) @@ -1594,7 +1607,7 @@ def gen_error_message(input_size, p, keepdim, dim=None): expected = np.linalg.norm(xn, p, keepdims=keepdim) msg = gen_error_message(x.size(), p, keepdim) self.assertEqual(res.shape, expected.shape, msg=msg) - self.assertEqual(res, expected, msg=msg) + self.assertEqual(res, expected, msg=msg, rtol=1.3e-6, atol=3e-4) # Ensure torch.norm with p='fro' and p=2 give the same results for mutually supported input combinations @dtypes(torch.float) @@ -1732,7 +1745,7 @@ def test_nuclear_norm_exceptions_old(self, device): self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) # ~~~ tests for torch.svd ~~~ - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(torch.double) def test_svd(self, device, dtype): @@ -1791,7 +1804,7 @@ def run_test(dims, some, compute_uv): for dims, some, compute_uv in product(shapes, [True, False], [True, False]): run_test(dims, some, compute_uv) - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(torch.float) def test_svd_no_singularvectors(self, device, dtype): @@ -1801,7 +1814,7 @@ def test_svd_no_singularvectors(self, device, dtype): u, s_actual, v = torch.svd(a, compute_uv=False) self.assertEqual(s_expect, s_actual, msg="Singular values don't match") - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(torch.double) def test_svd_lowrank(self, device, dtype): @@ -1870,17 +1883,18 @@ def run_subtest(actual_rank, matrix_size, batches, device, svd_lowrank, **option actual_rank, size, batches = 2, (17, 4), () run_subtest(actual_rank, size, batches, device, jitted) - @onlyCPU + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(torch.cfloat) def test_svd_complex(self, device, dtype): + # this test verifies that torch.svd really returns V and not V.conj() + # see: https://github.com/pytorch/pytorch/issues/45821 t = torch.randn((10, 10), dtype=dtype, device=device) U, S, V = torch.svd(t, some=False) - # note: from the math point of view, it is weird that we need to use - # V.T instead of V.T.conj(): torch.svd has a buggy behavior for - # complex numbers and it's deprecated. You should use torch.linalg.svd - # instead. 
- t2 = U @ torch.diag(S).type(dtype) @ V.T + # verify that t ≈ t2 + # t2 = U @ diag(S) @ Vá´´ + # Vá´´ is the conjugate transpose of V + t2 = U @ torch.diag(S).type(dtype) @ V.conj().T self.assertEqual(t, t2) def _test_svd_helper(self, shape, some, col_maj, device, dtype): @@ -1901,44 +1915,44 @@ def _test_svd_helper(self, shape, some, col_maj, device, dtype): for x, y in zip(cpu_result, device_result): self.assertEqual(x[..., :m].abs(), y[..., :m].abs(), atol=1e-5, rtol=0) - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) def test_svd_square(self, device, dtype): self._test_svd_helper((10, 10), True, False, device, dtype) - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(*floating_types()) def test_svd_square_col_maj(self, device, dtype): self._test_svd_helper((10, 10), True, True, device, dtype) - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(*floating_types()) def test_svd_tall_some(self, device, dtype): self._test_svd_helper((20, 5), True, False, device, dtype) - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(*floating_types()) def test_svd_tall_all(self, device, dtype): self._test_svd_helper((20, 5), False, False, device, dtype) - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(*floating_types()) def test_svd_tall_some_col_maj(self, device, dtype): self._test_svd_helper((5, 20), True, True, device, dtype) - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(*floating_types()) def test_svd_tall_all_col_maj(self, device, dtype): self._test_svd_helper((5, 20), False, True, device, dtype) # ~~~ tests for torch.linalg.svd ~~~ - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) def test_linalg_svd_compute_uv(self, device, dtype): @@ -1952,7 +1966,10 @@ def test_linalg_svd_compute_uv(self, device, dtype): # check linalg.svd vs numpy expected = np.linalg.svd(np_t, full_matrices, compute_uv=True) actual = torch.linalg.svd(t, full_matrices, compute_uv=True) - self.assertEqual(actual, expected) + # sign/phase of the singular vectors is not unique and therefore absolute values are compared + self.assertEqual(abs(actual[0]), abs(expected[0])) + self.assertEqual(actual[1], expected[1]) + self.assertEqual(abs(actual[2]), abs(expected[2])) # check linalg.svd vs linalg.svd(out=...) 
out = (torch.empty_like(actual[0]), torch.empty_like(actual[1]), @@ -1961,7 +1978,7 @@ def test_linalg_svd_compute_uv(self, device, dtype): self.assertEqual(actual, out) self.assertEqual(actual, out2) - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) def test_linalg_svd_no_compute_uv(self, device, dtype): @@ -1990,7 +2007,7 @@ def is_empty(x): assert USV.V is out[2] self.assertEqual(USV.S, np_s) - @skipCUDAIfNoMagma + @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @onlyCUDA @dtypes(torch.float) @@ -2060,7 +2077,6 @@ def test_cholesky_solve_batched_non_contiguous(self, device, dtype): self.assertEqual(x, x_exp) @slowTest - @skipCUDAIf(True, "See https://github.com/pytorch/pytorch/issues/48996") @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) @@ -2177,12 +2193,9 @@ def run_test(torch_inverse, matrix, batches, n): for torch_inverse in [torch.inverse, torch.linalg.inv]: for batches, n in itertools.product( - [[], [0], [1], [4], [2, 3]], + [[], [0], [1], [2], [4], [2, 3]], [0, 5, 64] ): - # large batch size and large matrix size will be tested in test_inverse_many_batches (slow test) - if batches and batches[0] == 32 and n == 256: - continue matrices = random_fullrank_matrix_distinct_singular_value(n, *batches, dtype=dtype).to(device) run_test(torch_inverse, matrices, batches, n) @@ -3826,12 +3839,12 @@ def call_torch_fn(*args, **kwargs): self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape]) @skipCUDAIfRocm - @dtypesIfCUDA(*(torch.float, torch.double, torch.cfloat, torch.cdouble) + - # This test is disabled on CUDA 9, due to: - # See: https://github.com/pytorch/pytorch/issues/31006 - ((torch.half,) if torch.version.cuda and not torch.version.cuda.startswith('9.') else ())) + @dtypesIfCUDA(torch.cfloat, torch.cdouble, + *torch.testing.get_all_fp_dtypes(include_half=not CUDA9, include_bfloat16=(CUDA11OrLater and SM53OrLater))) @dtypes(*(set(torch.testing.get_all_dtypes()) - {torch.half, torch.bool})) def test_blas_alpha_beta_empty(self, device, dtype): + # This test is disabled on CUDA 9 due to: + # See: https://github.com/pytorch/pytorch/issues/31006 if dtype is torch.bfloat16 and self.device_type == 'xla': # TODO (@zasdfgbnm): this causes the following error on test # TestTorchDeviceTypeXLA.test_blas_alpha_beta_empty_xla_bfloat16: @@ -3916,14 +3929,64 @@ def test_renorm_ps(self, device): @onlyCPU @skipCPUIfNoLapack - def test_orgqr_errors(self, device): + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_orgqr(self, device, dtype): + def generate_reflectors_and_tau(A): + """ + This function uses numpy.linalg.qr with mode "raw" to extract output of LAPACK's geqrf. + There is torch.geqrf function but it doesn't work with complex-valued input. 
+ """ + if A.numel() > 0: + A_cpu = A.cpu() + flattened_batch_shape = [-1, *A_cpu.shape[-2:]] + reflectors = torch.empty_like(A_cpu).view(*flattened_batch_shape) + tau_shape = [*A_cpu.shape[:-2], A_cpu.shape[-1]] + tau = torch.empty(tau_shape, dtype=dtype).view(-1, A_cpu.shape[-1]) + for A_i, reflectors_i, tau_i in zip(A_cpu.contiguous().view(*flattened_batch_shape), reflectors, tau): + reflectors_tmp, tau_i[:] = map(torch.from_numpy, np.linalg.qr(A_i, mode='raw')) + reflectors_i[:] = reflectors_tmp.T + reflectors = reflectors.view(*A_cpu.shape) + tau = tau.view(tau_shape) + return reflectors.to(A.device), tau.to(A.device) + + reflectors = torch.empty_like(A) + tau = torch.empty(*A.shape[:-2], A.shape[-1], dtype=dtype, device=device) + return reflectors, tau + + def run_test(shape): + A = torch.randn(*shape, dtype=dtype, device=device) + reflectors, tau = generate_reflectors_and_tau(A) + expected, _ = torch.linalg.qr(A) + actual = torch.orgqr(reflectors, tau) + # torch.linalg.qr does not work correctly for zero batch dimension tensors + # see https://github.com/pytorch/pytorch/issues/50576 + if (A.numel() > 0): + self.assertEqual(expected, actual) + else: + self.assertTrue(actual.shape == shape) + + out = torch.empty_like(A) + ans = torch.orgqr(reflectors, tau, out=out) + self.assertEqual(ans, out) + if (A.numel() > 0): + self.assertEqual(expected, out) + + shapes = [(0, 0), (5, 0), # Empty matrix + (5, 5), (5, 3), # Single matrix + (0, 0, 0), (0, 5, 5), (0, 5, 3), # Zero batch dimension tensors + (2, 5, 5), (2, 5, 3), # 3-dim tensors + (2, 1, 5, 5), (2, 1, 5, 3)] # 4-dim tensors + for shape in shapes: + run_test(shape) + + @onlyCPU + @skipCPUIfNoLapack + def test_orgqr_errors_and_warnings(self, device): test_cases = [ # input1 size, input2 size, error regex - ((10,), (2,), r"'input' should be 2 dimensional"), - ((10, 6), (20,), r"input.size\(1\) must be greater than or equal to input2.size\(0\)"), - ((6, 10), (5,), r"input.size\(0\) must be greater than or equal to input.size\(1\)"), - ((0, 0), (0,), r"'input' should not be empty"), - ((2, 2), (2, 0,), r"'tau' should not be empty") + ((10,), (2,), r"input must have at least 2 dimensions"), + ((10, 6), (20,), r"input.shape\[-1\] must be greater than or equal to tau.shape\[-1\]"), + ((6, 10), (5,), r"input.shape\[-2\] must be greater than or equal to input.shape\[-1\]"), ] for a_size, tau_size, error_regex in test_cases: a = torch.rand(*a_size, device=device) @@ -3931,6 +3994,40 @@ def test_orgqr_errors(self, device): with self.assertRaisesRegex(RuntimeError, error_regex): torch.orgqr(a, tau) + # if out tensor with wrong shape is passed a warning is given + reflectors = torch.randn(3, 3, device=device) + tau = torch.randn(3, device=device) + out = torch.empty(2, 3, device=device) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.orgqr(reflectors, tau, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(reflectors).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match the expected dtype"): + torch.orgqr(reflectors, tau, out=out) + + with self.assertRaisesRegex(RuntimeError, "tau dtype Int does not match input dtype"): + torch.orgqr(reflectors, tau.to(torch.int)) + + # TODO: enable the following tests when orgqr is implemented for CUDA + if torch.cuda.is_available(): + with self.assertRaisesRegex(RuntimeError, "the operator 
doesn't exist for this backend"): + # device of out and input should match + wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' + out = torch.empty_like(reflectors).to(wrong_device) + # with self.assertRaisesRegex(RuntimeError, "Expected result and input to be on the same device"): + torch.orgqr(reflectors, tau, out=out) + + # device of tau and input should match + wrong_device = 'cpu' if self.device_type != 'cpu' else 'cuda' + tau = tau.to(wrong_device) + # with self.assertRaisesRegex(RuntimeError, "Expected input and tau to be on the same device"): + torch.orgqr(reflectors, tau) + @precisionOverride({torch.complex64: 5e-6}) @skipCUDAIfNoMagma @skipCPUIfNoLapack @@ -4367,8 +4464,8 @@ def _test_addmm_addmv(self, f, t, m, v, *, alpha=None, beta=None, transpose_out= @precisionOverride({torch.bfloat16: 1e-0, torch.half: 5e-4, torch.float: 1e-4, torch.double: 1e-8, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), - *([torch.float32, torch.float64, torch.bfloat16] - if TEST_WITH_ROCM else torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))) + *torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)), + include_half=(not TEST_WITH_ROCM))) @dtypes(torch.bfloat16, torch.float, torch.double, torch.cfloat, torch.cdouble) def test_addmv(self, device, dtype): # have to use torch.randn(...).to(bfloat16) instead of @@ -4402,8 +4499,7 @@ def test_addmv(self, device, dtype): for m, v in itertools.product(ms, vs): self._test_addmm_addmv(torch.addmv, t, m, v, beta=0) - @dtypesIfCUDA(*([torch.half, torch.float, torch.double] - + ([torch.bfloat16] if TEST_WITH_ROCM else []))) + @dtypesIfCUDA(*torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)))) @dtypes(torch.float, torch.double) def test_addmv_rowmajor_colmajor_incx_incy_lda(self, device, dtype): # tests (o, s)*(s). o is output size, s is summed size. @@ -4434,7 +4530,8 @@ def _test(row_major, incx, incy, lda_tail): @precisionOverride({torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, torch.half: 1e-1, torch.cfloat: 1e-4, torch.cdouble: 1e-8}) - @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)) + @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), + *torch.testing.get_all_fp_dtypes(include_bfloat16=(TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater)))) @dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes()) @tf32_on_and_off(0.05) def test_addmm(self, device, dtype): @@ -4609,19 +4706,25 @@ def test_strided_mm_bmm(self, device, dtype): @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes()) @tf32_on_and_off(0.05) def test_bmm(self, device, dtype): + if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater: + # cuBLAS does not guarantee BFloat16 support on SM < 53. 
+ # So on PyTorch, we consider BFloat16 support on SM < 53 as + # undefined behavior + return + num_batches = 10 M, N, O = 23, 8, 12 numpy_dtype = dtype if dtype != torch.bfloat16 else torch.float32 - if self.device_type == 'cpu': - is_supported = True - elif self.device_type == 'cuda': - is_supported = True if dtype != torch.bfloat16 else AMPERE_OR_ROCM + is_supported = True + if dtype == torch.bfloat16 and self.device_type == 'cuda': + is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater) if not is_supported: b1 = torch.randn(num_batches, M, N, device=device).to(dtype) b2 = torch.randn(num_batches, N, O, device=device).to(dtype) - self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.bmm(b1, b2)) + self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", + lambda: torch.bmm(b1, b2)) return def invert_perm(p): @@ -4784,21 +4887,28 @@ def _test_addbmm_baddbmm(self, func, b1, b2, ref, out_tensor): @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes()) @tf32_on_and_off(0.05) def test_addbmm(self, device, dtype): + if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater: + # cuBLAS does not guarantee BFloat16 support on SM < 53. + # So on PyTorch, we consider BFloat16 support on SM < 53 as + # undefined behavior + return + num_batches = 2 M, N, O = 2, 3, 4 - if self.device_type == 'cpu': - is_supported = True - if dtype == torch.bfloat16: + is_supported = True + if dtype == torch.bfloat16: + if self.device_type == 'cpu': self.precision = 1 # 43 vs 43.75 - else: - is_supported = (dtype != torch.bfloat16 or AMPERE_OR_ROCM) + else: + is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater) if not is_supported: b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) t = make_tensor((M, O), device, dtype, low=-1, high=1) - self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.addbmm(t, b1, b2)) + self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", + lambda: torch.addbmm(t, b1, b2)) return def invert_perm(p): @@ -4850,19 +4960,25 @@ def generate_tensor(): @dtypes(*torch.testing.get_all_fp_dtypes(), *torch.testing.get_all_complex_dtypes()) @tf32_on_and_off(0.05) def test_baddbmm(self, device, dtype): + if self.device_type == 'cuda' and dtype is torch.bfloat16 and CUDA11OrLater and not SM53OrLater: + # cuBLAS does not guarantee BFloat16 support on SM < 53.
+ # So on PyTorch, we consider BFloat16 support on SM < 53 as + # undefined behavior + return + num_batches = 10 M, N, O = 12, 8, 5 - if self.device_type == 'cpu': - is_supported = True - elif self.device_type == 'cuda': - is_supported = True if dtype != torch.bfloat16 else AMPERE_OR_ROCM + is_supported = True + if dtype == torch.bfloat16 and self.device_type == 'cuda': + is_supported = TEST_WITH_ROCM or (CUDA11OrLater and SM53OrLater) if not is_supported: b1 = make_tensor((num_batches, M, N), device, dtype, low=-1, high=1) b2 = make_tensor((num_batches, N, O), device, dtype, low=-1, high=1) t = make_tensor((num_batches, M, O), device, dtype, low=-1, high=1) - self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|Ampere", lambda: torch.baddbmm(t, b1, b2)) + self.assertRaisesRegex(RuntimeError, "type|Type|not implemented|CUBLAS_STATUS_NOT_SUPPORTED", + lambda: torch.baddbmm(t, b1, b2)) return def invert_perm(p): diff --git a/test/test_nn.py b/test/test_nn.py index f8baa307daae..60e4db8930ef 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -8588,6 +8588,18 @@ def test_upsamplingNearest2d(self): self.assertEqual(torch.ones(1, 1, 4, 4).contiguous(memory_format=memory_format), out_t.data) self.assertEqual(torch.ones(1, 1, 4, 4, dtype=torch.uint8).contiguous(memory_format=memory_format), out_uint8_t.data) + # test forward when input's height is not the same as its width + m = nn.Upsample(size=(4, 2), mode='nearest') + in_t = torch.ones(1, 1, 2, 1).contiguous(memory_format=memory_format) + with warnings.catch_warnings(record=True) as w: + out_t = m(in_t) + self.assertEqual(torch.ones(1, 1, 4, 2).contiguous(memory_format=memory_format), out_t.data) + + # test backward when input's height is not the same as its width + input = torch.ones(1, 1, 2, 1, requires_grad=True).contiguous(memory_format=memory_format) + gradcheck(lambda x: F.interpolate(x, size=(4, 2), mode='nearest'), [input]) + gradgradcheck(lambda x: F.interpolate(x, size=(4, 2), mode='nearest'), [input]) + input = torch.randn(1, 1, 2, 2, requires_grad=True).contiguous(memory_format=memory_format) self.assertEqual( F.interpolate(input, 4, mode='nearest'), diff --git a/test/test_ops.py b/test/test_ops.py index 26a3ee69f95a..bd82aca3820a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -43,7 +43,6 @@ def test_unsupported_dtypes(self, device, dtype, op): # Verifies that ops have their supported dtypes # registered correctly by testing that each claimed supported dtype # does NOT throw a runtime error - @skipCUDAIfRocm @onlyOnCPUAndCUDA @ops(op_db, dtypes=OpDTypes.supported) def test_supported_dtypes(self, device, dtype, op): diff --git a/test/test_tensorexpr_pybind.py b/test/test_tensorexpr_pybind.py index d8db2fef89ff..71c51bf019c7 100644 --- a/test/test_tensorexpr_pybind.py +++ b/test/test_tensorexpr_pybind.py @@ -1,6 +1,6 @@ import torch -import unittest +from torch.testing._internal.common_utils import run_tests from torch.testing._internal.jit_utils import JitTestCase class kernel_arena_scope(object): @@ -37,4 +37,4 @@ def compute(i): torch.testing.assert_allclose(tA + tB, tC) if __name__ == '__main__': - unittest.main() + run_tests() diff --git a/test/test_testing.py b/test/test_testing.py index 4ff215233fe2..280a312914ee 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -1,11 +1,13 @@ import torch import math +from pathlib import PurePosixPath from torch.testing._internal.common_utils import \ (TestCase, make_tensor, run_tests, slowTest) from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, onlyCUDA, onlyOnCPUAndCUDA, dtypes) +from torch.testing._internal import mypy_wrapper # For testing TestCase methods and torch.testing functions class TestTesting(TestCase): @@ -487,5 +489,61 @@ def test_trivial_passing_test_case_on_cpu_cuda(self, device): instantiate_device_type_tests(TestTesting, globals()) + +class TestMypyWrapper(TestCase): + def test_glob(self): + # can match individual files + self.assertTrue(mypy_wrapper.glob( + pattern='test/test_torch.py', + filename=PurePosixPath('test/test_torch.py'), + )) + self.assertFalse(mypy_wrapper.glob( + pattern='test/test_torch.py', + filename=PurePosixPath('test/test_testing.py'), + )) + + # dir matters + self.assertFalse(mypy_wrapper.glob( + pattern='tools/codegen/utils.py', + filename=PurePosixPath('torch/nn/modules.py'), + )) + self.assertTrue(mypy_wrapper.glob( + pattern='setup.py', + filename=PurePosixPath('setup.py'), + )) + self.assertFalse(mypy_wrapper.glob( + pattern='setup.py', + filename=PurePosixPath('foo/setup.py'), + )) + self.assertTrue(mypy_wrapper.glob( + pattern='foo/setup.py', + filename=PurePosixPath('foo/setup.py'), + )) + + # can match dirs + self.assertTrue(mypy_wrapper.glob( + pattern='torch', + filename=PurePosixPath('torch/random.py'), + )) + self.assertTrue(mypy_wrapper.glob( + pattern='torch', + filename=PurePosixPath('torch/nn/cpp.py'), + )) + self.assertFalse(mypy_wrapper.glob( + pattern='torch', + filename=PurePosixPath('tools/fast_nvcc/fast_nvcc.py'), + )) + + # can match wildcards + self.assertTrue(mypy_wrapper.glob( + pattern='tools/autograd/*.py', + filename=PurePosixPath('tools/autograd/gen_autograd.py'), + )) + self.assertFalse(mypy_wrapper.glob( + pattern='tools/autograd/*.py', + filename=PurePosixPath('tools/autograd/deprecated.yaml'), + )) + + if __name__ == '__main__': run_tests() diff --git a/test/test_torch.py b/test/test_torch.py index 9c208994a014..6bfdf3b6e8b8 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6799,7 +6799,6 @@ def inner(self, device, dtype): 1e-5, 1e-5, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), ('topk', 'dim_desc_sort', _small_3d_unique, lambda t, d: [2, 1, True, True], 1e-5, 1e-5, 1e-5, torch.testing.get_all_dtypes(include_complex=False, include_bool=False), _cpu_types, False), - ('trace', '', _medium_2d, lambda t, d: [], 1e-3, 1e-5, 1e-5, _types, _cpu_types, False), ('tril', '', _medium_2d, lambda t, d: [],), ('tril', 'zero_stride', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, _cpu_types, False), ('tril', 'positive', _medium_2d, lambda t, d: [2], ), diff --git a/test/test_type_hints.py b/test/test_type_hints.py index cb3ab3bce796..97fa53cb4c83 100644 --- a/test/test_type_hints.py +++ b/test/test_type_hints.py @@ -1,5 +1,6 @@ import unittest from torch.testing._internal.common_utils import TestCase, run_tests, set_cwd +from torch.testing._internal.mypy_wrapper import config_files import tempfile import torch import doctest @@ -7,7 +8,7 @@ import inspect try: - import mypy.api # type: ignore + import mypy.api HAVE_MYPY = True except ImportError: HAVE_MYPY = False @@ -129,55 +130,45 @@ def test_doc_examples(self): @unittest.skipIf(not HAVE_MYPY, "need mypy") def test_run_mypy(self): """ - Runs mypy over all files specified in mypy.ini - Note that mypy.ini is not shipped in an installed version of PyTorch, - so this test will only run mypy in a development setup or in CI. 
+ Runs mypy over all files specified in our mypy configs + Note that our mypy configs are not shipped in an installed + version of PyTorch, so this test will only run mypy in a + development setup or in CI. """ def is_torch_mypyini(path_to_file): with open(path_to_file, 'r') as f: first_line = f.readline() - if first_line.startswith('# This is the PyTorch MyPy config file'): + name = os.path.basename(path_to_file) + if first_line.startswith(f'# This is the PyTorch {name} file'): return True return False - test_dir = os.path.dirname(os.path.realpath(__file__)) - repo_rootdir = os.path.join(test_dir, '..') - mypy_inifile = os.path.join(repo_rootdir, 'mypy.ini') - if not (os.path.exists(mypy_inifile) and is_torch_mypyini(mypy_inifile)): - self.skipTest("Can't find PyTorch MyPy config file") - - import numpy - if numpy.__version__ == '1.20.0.dev0+7af1024': - self.skipTest("Typeannotations in numpy-1.20.0-dev are broken") - - # TODO: Would be better not to chdir here, this affects the entire - # process! - with set_cwd(repo_rootdir): - (stdout, stderr, result) = mypy.api.run([]) - - if result != 0: - self.fail(f"mypy failed: {stdout} {stderr}") - - @unittest.skipIf(not HAVE_MYPY, "need mypy") - def test_run_mypy_strict(self): - """ - Runs mypy over all files specified in mypy-strict.ini - """ - test_dir = os.path.dirname(os.path.realpath(__file__)) - repo_rootdir = os.path.join(test_dir, '..') - mypy_inifile = os.path.join(repo_rootdir, 'mypy-strict.ini') - if not os.path.exists(mypy_inifile): - self.skipTest("Can't find PyTorch MyPy strict config file") - - with set_cwd(repo_rootdir): - (stdout, stderr, result) = mypy.api.run([ - '--config', mypy_inifile, - ]) - - if result != 0: - self.fail(f"mypy failed: {stdout} {stderr}") + # to add more configs, edit the implementation of the + # config_files function rather than editing this test or adding + # more tests to this suite + for ini in config_files(): + with self.subTest(msg=ini): + test_dir = os.path.dirname(os.path.realpath(__file__)) + repo_rootdir = os.path.join(test_dir, '..') + mypy_inifile = os.path.join(repo_rootdir, ini) + if not (os.path.exists(mypy_inifile) and is_torch_mypyini(mypy_inifile)): + self.skipTest("Can't find PyTorch MyPy config file") + + import numpy + if numpy.__version__.startswith('1.20.0.dev0'): + self.skipTest("Typeannotations in numpy-1.20.0-dev are broken") + + # TODO: Would be better not to chdir here, this affects + # the entire process! + with set_cwd(repo_rootdir): + (stdout, stderr, result) = mypy.api.run([ + '--config', mypy_inifile, + ]) + + if result != 0: + self.fail(f"mypy failed: {stdout} {stderr}") if __name__ == '__main__': run_tests() diff --git a/third_party/LICENSES_BUNDLED.txt b/third_party/LICENSES_BUNDLED.txt new file mode 100644 index 000000000000..c1c9a1783964 --- /dev/null +++ b/third_party/LICENSES_BUNDLED.txt @@ -0,0 +1,259 @@ +The Pytorch repository and source distributions bundle several libraries that are +compatibly licensed. We list these here. 
+ +Name: FP16 +License: MIT +Files: third_party/FP16 + For details, see third_party/FP16/LICENSE + +Name: FXdiv +License: MIT +Files: third_party/FXdiv + For details, see third_party/FXdiv/LICENSE + +Name: NNPACK +License: BSD-2-Clause +Files: third_party/NNPACK + For details, see third_party/NNPACK/LICENSE + +Name: QNNPACK +License: BSD-3-Clause +Files: third_party/QNNPACK + For details, see third_party/QNNPACK/LICENSE + +Name: XNNPACK +License: BSD-3-Clause +Files: third_party/XNNPACK + For details, see third_party/XNNPACK/LICENSE + +Name: benchmark +License: Apache-2.0 +Files: third_party/benchmark, + third_party/protobuf/third_party/benchmark, + third_party/onnx-tensorrt/third_party/onnx/third_party/benchmark, + third_party/onnx/third_party/benchmark + For details, see third_party/benchmark/LICENSE, + third_party/protobuf/third_party/benchmark/LICENSE, + third_party/onnx-tensorrt/third_party/onnx/third_party/benchmark/LICENSE, + third_party/onnx/third_party/benchmark/LICENSE + +Name: clog +License: BSD-2-Clause +Files: third_party/cpuinfo/deps/clog, + third_party/fbgemm/third_party/cpuinfo/deps/clog, + third_party/QNNPACK/deps/clog + For details, see third_party/cpuinfo/deps/clog/LICENSE, + third_party/fbgemm/third_party/cpuinfo/deps/clog/LICENSE, + third_party/QNNPACK/deps/clog/LICENSE + +Name: cpuinfo +License: BSD-2-Clause +Files: third_party/cpuinfo, + third_party/fbgemm/third_party/cpuinfo + For details, see third_party/cpuinfo/LICENSE, + third_party/fbgemm/third_party/cpuinfo/LICENSE + +Name: eigen +License: BSD-3-Clause +Files: third_party/eigen + For details, see third_party/eigen/COPYING.BSD + +Name: enum +License: BSD-3-Clause +Files: third_party/python-enum/enum + For details, see third_party/python-enum/enum/LICENSE + +Name: fbgemm +License: BSD-3-Clause +Files: third_party/fbgemm + For details, see third_party/fbgemm/LICENSE + +Name: fmt +License: MIT with exception +Files: third_party/kineto/libkineto/third_party/fmt, + third_party/fmt + For details, see third_party/kineto/libkineto/third_party/fmt/LICENSE.rst, + third_party/fmt/LICENSE.rst + +Name: foxi +License: MIT +Files: third_party/foxi + For details, see third_party/foxi/LICENSE + +Name: gemmlowp +License: Apache-2.0 +Files: third_party/gemmlowp/gemmlowp + For details, see third_party/gemmlowp/gemmlowp/LICENSE + +Name: generator +License: Apache-2.0 +Files: third_party/kineto/libkineto/third_party/googletest/googlemock/scripts/generator, + third_party/googletest/googlemock/scripts/generator, + third_party/fbgemm/third_party/googletest/googlemock/scripts/generator, + third_party/protobuf/third_party/googletest/googlemock/scripts/generator, + third_party/tensorpipe/third_party/googletest/googlemock/scripts/generator + For details, see third_party/kineto/libkineto/third_party/googletest/googlemock/scripts/generator/LICENSE, + third_party/googletest/googlemock/scripts/generator/LICENSE, + third_party/fbgemm/third_party/googletest/googlemock/scripts/generator/LICENSE, + third_party/protobuf/third_party/googletest/googlemock/scripts/generator/LICENSE, + third_party/tensorpipe/third_party/googletest/googlemock/scripts/generator/LICENSE + +Name: gloo +License: BSD-3-Clause +Files: third_party/gloo + For details, see third_party/gloo/LICENSE + +Name: googlemock +License: BSD-3-Clause +Files: third_party/kineto/libkineto/third_party/googletest/googlemock, + third_party/googletest/googlemock, + third_party/fbgemm/third_party/googletest/googlemock, + third_party/protobuf/third_party/googletest/googlemock, + 
third_party/tensorpipe/third_party/googletest/googlemock + For details, see third_party/kineto/libkineto/third_party/googletest/googlemock/LICENSE, + third_party/googletest/googlemock/LICENSE, + third_party/fbgemm/third_party/googletest/googlemock/LICENSE, + third_party/protobuf/third_party/googletest/googlemock/LICENSE, + third_party/tensorpipe/third_party/googletest/googlemock/LICENSE + +Name: googletest +License: BSD-3-Clause +Files: third_party/kineto/libkineto/third_party/googletest, + third_party/kineto/libkineto/third_party/googletest/googletest, + third_party/googletest, + third_party/googletest/googletest, + third_party/fbgemm/third_party/googletest, + third_party/fbgemm/third_party/googletest/googletest, + third_party/protobuf/third_party/googletest, + third_party/protobuf/third_party/googletest/googletest, + third_party/tensorpipe/third_party/googletest, + third_party/tensorpipe/third_party/googletest/googletest + For details, see third_party/kineto/libkineto/third_party/googletest/LICENSE, + third_party/kineto/libkineto/third_party/googletest/googletest/LICENSE, + third_party/googletest/LICENSE, + third_party/googletest/googletest/LICENSE, + third_party/fbgemm/third_party/googletest/LICENSE, + third_party/fbgemm/third_party/googletest/googletest/LICENSE, + third_party/protobuf/third_party/googletest/LICENSE, + third_party/protobuf/third_party/googletest/googletest/LICENSE, + third_party/tensorpipe/third_party/googletest/LICENSE, + third_party/tensorpipe/third_party/googletest/googletest/LICENSE + +Name: gtest +License: BSD-3-Clause +Files: third_party/ideep/mkl-dnn/tests/gtests/gtest + For details, see third_party/ideep/mkl-dnn/tests/gtests/gtest/LICENSE + +Name: ideep +License: MIT +Files: third_party/ideep + For details, see third_party/ideep/LICENSE + +Name: ios-cmake +License: BSD-3-Clause +Files: third_party/ios-cmake + For details, see third_party/ios-cmake/LICENSE + +Name: kineto +License: BSD-3-Clause +Files: third_party/kineto + For details, see third_party/kineto/LICENSE + +Name: libnop +License: Apache-2.0 +Files: third_party/tensorpipe/third_party/libnop + For details, see third_party/tensorpipe/third_party/libnop/LICENSE + +Name: libuv +License: MIT +Files: third_party/tensorpipe/third_party/libuv + For details, see third_party/tensorpipe/third_party/libuv/LICENSE + +Name: miniz-2.0.8 +License: MIT +Files: third_party/miniz-2.0.8 + For details, see third_party/miniz-2.0.8/LICENSE + +Name: mkl-dnn +License: Apache-2.0 +Files: third_party/ideep/mkl-dnn + For details, see third_party/ideep/mkl-dnn/LICENSE + +Name: nccl +License: BSD-3-Clause +Files: third_party/nccl/nccl + For details, see third_party/nccl/nccl/LICENSE.txt + +Name: neon2sse +License: BSD-Source-Code +Files: third_party/neon2sse + For details, see third_party/neon2sse/LICENSE + +Name: onnx +License: MIT +Files: third_party/onnx-tensorrt/third_party/onnx, + third_party/onnx + For details, see third_party/onnx-tensorrt/third_party/onnx/LICENSE, + third_party/onnx/LICENSE + +Name: onnx-tensorrt +License: MIT +Files: third_party/onnx-tensorrt + For details, see third_party/onnx-tensorrt/LICENSE + +Name: protobuf +License: BSD-3-Clause +Files: third_party/protobuf + For details, see third_party/protobuf/LICENSE + +Name: psimd +License: MIT +Files: third_party/psimd + For details, see third_party/psimd/LICENSE + +Name: pthreadpool +License: BSD-2-Clause +Files: third_party/pthreadpool + For details, see third_party/pthreadpool/LICENSE + +Name: pybind11 +License: BSD-3-Clause +Files: third_party/pybind11, + 
third_party/onnx-tensorrt/third_party/onnx/third_party/pybind11, + third_party/onnx/third_party/pybind11, + third_party/tensorpipe/third_party/pybind11 + For details, see third_party/pybind11/LICENSE, + third_party/onnx-tensorrt/third_party/onnx/third_party/pybind11/LICENSE, + third_party/onnx/third_party/pybind11/LICENSE, + third_party/tensorpipe/third_party/pybind11/LICENSE + +Name: python-peachpy +License: BSD-2-Clause +Files: third_party/python-peachpy + For details, see third_party/python-peachpy/LICENSE.rst + +Name: python-six +License: MIT +Files: third_party/python-six + For details, see third_party/python-six/LICENSE + +Name: sleef +License: BSL-1.0 +Files: third_party/sleef + For details, see third_party/sleef/LICENSE.txt + +Name: tbb +License: Apache-2.0 +Files: third_party/tbb + For details, see third_party/tbb/LICENSE + +Name: tensorpipe +License: BSD-3-Clause +Files: third_party/tensorpipe + For details, see third_party/tensorpipe/LICENSE.txt + +Name: zstd +License: BSD-3-Clause +Files: third_party/zstd + For details, see third_party/zstd/LICENSE + diff --git a/third_party/build_bundled.py b/third_party/build_bundled.py new file mode 100644 index 000000000000..0777450e9e91 --- /dev/null +++ b/third_party/build_bundled.py @@ -0,0 +1,158 @@ +import os + + +mydir = os.path.dirname(__file__) +licenses = {'LICENSE', 'LICENSE.txt', 'LICENSE.rst', 'COPYING.BSD'} + + +def collect_license(current): + collected = {} + for root, dirs, files in os.walk(current): + license = list(licenses & set(files)) + if license: + name = root.split('/')[-1] + license_file = os.path.join(root, license[0]) + try: + ident = identify_license(license_file) + except ValueError: + raise ValueError('could not identify license file ' + f'for {root}') from None + val = { + 'Name': name, + 'Files': [root], + 'License': ident, + 'License_file': [license_file], + } + if name in collected: + # Only add it if the license is different + if collected[name]['License'] == ident: + collected[name]['Files'].append(root) + collected[name]['License_file'].append(license_file) + else: + collected[name + f' ({root})'] = val + else: + collected[name] = val + return collected + + +def create_bundled(d, outstream): + """Write the information to an open outstream""" + collected = collect_license(d) + sorted_keys = sorted(collected.keys()) + outstream.write('The Pytorch repository and source distributions bundle ' + 'several libraries that are \n') + outstream.write('compatibly licensed. We list these here.\n\n') + for k in sorted_keys: + c = collected[k] + files = ',\n '.join(c['Files']) + license_file = ',\n '.join(c['License_file']) + outstream.write(f"Name: {c['Name']}\n") + outstream.write(f"License: {c['License']}\n") + outstream.write(f"Files: {files}\n") + outstream.write(' For details, see ') + outstream.write(license_file) + outstream.write('\n\n') + + +def identify_license(f, exception=''): + """ + Read f and try to identify the license type + This is __very__ rough and probably not legally binding, it is specific for + this repo. + """ + def squeeze(t): + """Remove 'n and ' ', normalize quotes + """ + t = t.replace('\n', '').replace(' ', '') + t = t.replace('``', '"').replace("''", '"') + return t + + with open(f) as fid: + txt = fid.read() + if not exception and 'exception' in txt: + license = identify_license(f, 'exception') + return license + ' with exception' + txt = squeeze(txt) + if 'ApacheLicense' in txt: + # Hmm, do we need to check the text? 
+ return 'Apache-2.0' + elif 'MITLicense' in txt: + # Hmm, do we need to check the text? + return 'MIT' + elif 'BSD-3-ClauseLicense' in txt: + # Hmm, do we need to check the text? + return 'BSD-3-Clause' + elif 'BSD3-ClauseLicense' in txt: + # Hmm, do we need to check the text? + return 'BSD-3-Clause' + elif 'BoostSoftwareLicense-Version1.0' in txt: + # Hmm, do we need to check the text? + return 'BSL-1.0' + elif all([squeeze(m) in txt.lower() for m in bsd3_txt]): + return 'BSD-3-Clause' + elif all([squeeze(m) in txt.lower() for m in bsd3_v1_txt]): + return 'BSD-3-Clause' + elif all([squeeze(m) in txt.lower() for m in bsd2_txt]): + return 'BSD-2-Clause' + elif all([squeeze(m) in txt.lower() for m in bsd3_src_txt]): + return 'BSD-Source-Code' + elif all([squeeze(m) in txt.lower() for m in mit_txt]): + return 'MIT' + else: + raise ValueError('unknown license') + +mit_txt = ['permission is hereby granted, free of charge, to any person ' + 'obtaining a copy of this software and associated documentation ' + 'files (the "software"), to deal in the software without ' + 'restriction, including without limitation the rights to use, copy, ' + 'modify, merge, publish, distribute, sublicense, and/or sell copies ' + 'of the software, and to permit persons to whom the software is ' + 'furnished to do so, subject to the following conditions:', + + 'the above copyright notice and this permission notice shall be ' + 'included in all copies or substantial portions of the software.', + + 'the software is provided "as is", without warranty of any kind, ' + 'express or implied, including but not limited to the warranties of ' + 'merchantability, fitness for a particular purpose and ' + 'noninfringement. in no event shall the authors or copyright holders ' + 'be liable for any claim, damages or other liability, whether in an ' + 'action of contract, tort or otherwise, arising from, out of or in ' + 'connection with the software or the use or other dealings in the ' + 'software.' + ] + +bsd3_txt = ['redistribution and use in source and binary forms, with or without ' + 'modification, are permitted provided that the following conditions ' + 'are met:', + + 'redistributions of source code', + + 'redistributions in binary form', + + 'neither the name', + + 'this software is provided by the copyright holders and ' + 'contributors "as is" and any express or implied warranties, ' + 'including, but not limited to, the implied warranties of ' + 'merchantability and fitness for a particular purpose are disclaimed.', + ] + +# BSD2 is BSD3 without the "neither the name..." 
clause +bsd2_txt = bsd3_txt[:3] + bsd3_txt[4:] + +# This BSD3 variant leaves "and contributors" out of the last clause of BSD-3, +# which is still valid BSD-3 +v1 = bsd3_txt[4].replace('and contributors', '') +bsd3_v1_txt = bsd3_txt[:3] + [v1] + +# This source variant of BSD-3 leaves the "redistributions in binary form" out +# which is https://spdx.org/licenses/BSD-Source-Code.html +bsd3_src_txt = bsd3_txt[:2] + bsd3_txt[4:] + + +if __name__ == '__main__': + third_party = os.path.join(mydir) + fname = os.path.join(third_party, 'LICENSES_BUNDLED.txt') + with open(fname, 'w') as fid: + create_bundled(third_party, fid) diff --git a/third_party/pybind11 b/third_party/pybind11 index a1cb7c23d3b4..8e5d3d234ef3 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit a1cb7c23d3b47a2bca24b136e5222e497de9575a +Subproject commit 8e5d3d234ef3bbd9efdbba865d1e606d4c5e97bb diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 69e8792fbcb3..c43035b36a6b 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -192,7 +192,7 @@ def is_cuda_dispatch_key(dk: str) -> bool: # Structured kernel generation is only supported for certain key types; # otherwise use old-style def is_structured_dispatch_key(dk: str) -> bool: - return dk in {'CUDA', 'CPU'} + return dk in STRUCTURED_DISPATCH_KEYS # Generates RegisterSchema.cpp. Depending on the selector, either # all schemas are registered, or only some are (in the case of @@ -373,10 +373,10 @@ def gen_structured(self, g: StructuredNativeFunctions) -> List[str]: assert self.dispatch_key not in g.out.dispatch, \ "Do not explicitly specify Meta dispatch key on structured " \ "functions, they will be automatically generated for you" - elif self.dispatch_key not in g.out.dispatch: - return [] elif not is_structured_dispatch_key(self.dispatch_key): return list(mapMaybe(self.gen_unstructured, g.functions())) + elif self.dispatch_key not in g.out.dispatch: + return [] # Inner helper function to close over g # TODO: This function has a lot of similarity with gen_unstructured. If diff --git a/tools/codegen/model.py b/tools/codegen/model.py index 7ecf25801861..53bc0120271f 100644 --- a/tools/codegen/model.py +++ b/tools/codegen/model.py @@ -51,6 +51,8 @@ class UseC10Dispatcher(Enum): full = 0 hacky_wrapper_for_legacy_signatures = 1 +STRUCTURED_DISPATCH_KEYS = {'CUDA', 'CPU'} + # The basic input to the code generation is native_functions.yaml. # The name "native", BTW, comes from the distinction between native # functions and legacy TH functions. 
The legacy TH functions are gone, @@ -236,7 +238,7 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction': assert isinstance(v, str), e for k in ks.split(","): dispatch[k.strip()] = v - else: + elif not structured and structured_delegate is None: from tools.codegen.api import cpp dispatch['Math'] = cpp.name(func) @@ -305,6 +307,13 @@ def __post_init__(self) -> None: if self.structured or self.structured_delegate: assert self.use_c10_dispatcher is UseC10Dispatcher.full, \ "Structured kernels MUST be use_c10_dispatcher: full; port your argument order" + if self.structured_inherits is not None: + assert self.structured, "structured_inherits must also imply structured: True" + if self.structured_delegate is not None: + for k in STRUCTURED_DISPATCH_KEYS: + assert k not in self.dispatch, \ + f"if structured_delegate, then must not have {k} in dispatch dictionary " \ + "(it is delegated!)" SchemaKind = Enum('SchemaKind', ('functional', 'inplace', 'out')) diff --git a/tools/fast_nvcc/fast_nvcc.py b/tools/fast_nvcc/fast_nvcc.py index fc14191ffd14..7f126e84f7fb 100755 --- a/tools/fast_nvcc/fast_nvcc.py +++ b/tools/fast_nvcc/fast_nvcc.py @@ -337,7 +337,10 @@ async def run_command(command, *, env, deps, gather_data, i, save): Run the command with the given env after waiting for deps. """ for task in deps: - await task + dep_result = await task + # abort if a previous step failed + if 'exit_code' not in dep_result or dep_result['exit_code'] != 0: + return {} if gather_data: t1 = time.monotonic() proc = await asyncio.create_subprocess_shell( @@ -368,7 +371,7 @@ async def run_command(command, *, env, deps, gather_data, i, save): return results -async def run_graph(*, env, commands, graph, gather_data, save): +async def run_graph(*, env, commands, graph, gather_data=False, save=None): """ Return outputs/errors (and optionally time/file info) from commands. """ @@ -391,8 +394,8 @@ def print_command_outputs(command_results): Print captured stdout and stderr from commands. """ for result in command_results: - sys.stdout.write(result['stdout'].decode('ascii')) - sys.stderr.write(result['stderr'].decode('ascii')) + sys.stdout.write(result.get('stdout', b'').decode('ascii')) + sys.stderr.write(result.get('stderr', b'').decode('ascii')) def write_log_csv(command_parts, command_results, *, filename): @@ -401,15 +404,15 @@ def write_log_csv(command_parts, command_results, *, filename): """ tmp_files = [] for result in command_results: - tmp_files.extend(result['files'].keys()) + tmp_files.extend(result.get('files', {}).keys()) with open(filename, 'w', newline='') as csvfile: fieldnames = ['command', 'seconds'] + list(dict.fromkeys(tmp_files)) writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() for i, result in enumerate(command_results): command = f'{i} {os.path.basename(command_parts[i][0])}' - row = {'command': command, 'seconds': result['time']} - writer.writerow({**row, **result['files']}) + row = {'command': command, 'seconds': result.get('time', 0)} + writer.writerow({**row, **result.get('files', {})}) def exit_code(results): @@ -417,7 +420,7 @@ def exit_code(results): Aggregate individual exit codes into a single code. 
""" for result in results: - code = result['exit_code'] + code = result.get('exit_code', 0) if code != 0: return code return 0 diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 7a37ab134dcb..8398b18b040e 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -125,7 +125,7 @@ def should_bind_method(python_func: PythonSignatureNativeFunctionPair) -> bool: 'radd', 'rsub', 'rmul', 'rtruediv', 'rfloordiv', 'rpow', # reverse arithmetic 'and', 'or', 'xor', # logic 'iadd', 'iand', 'idiv', 'ilshift', 'imul', - 'ior', 'irshift', 'isub', 'ixor', # inplace ops + 'ior', 'irshift', 'isub', 'ixor', 'ifloordiv', 'imod', # inplace ops ) comparison_ops = ('eq', 'ne', 'ge', 'gt', 'lt', 'le') unary_ops = ('neg', 'abs', 'invert') @@ -324,6 +324,32 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) - 'is_grad_enabled': ['def is_grad_enabled() -> _bool: ...'], 'nonzero': ['def nonzero(input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: ...', 'def nonzero(input: Tensor, *, as_tuple: bool=...) -> Tensor: ...'], + 'binary_cross_entropy_with_logits': ['def binary_cross_entropy_with_logits(input: Tensor, target: Tensor, ' + 'weight: Optional[Tensor] = None, size_average: Optional[bool] = None, ' + 'reduce: Optional[bool] = None, reduction: str = ..., ' + 'pos_weight: Optional[Tensor] = None) -> Tensor: ...'], + 'cosine_embedding_loss': ['def cosine_embedding_loss(input1: Tensor, input2: Tensor, ' + 'target: Tensor, margin: float = ..., size_average: Optional[bool] = ..., ' + 'reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'], + 'ctc_loss': ['def ctc_loss(log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor,' + ' blank: int = ..., reduction: str = ..., zero_infinity: bool = ...) -> Tensor: ...'], + 'hinge_embedding_loss': ['def hinge_embedding_loss(input: Tensor, target: Tensor, margin: float = ...,' + ' size_average: Optional[bool] = ..., reduce: Optional[bool] = ..., ' + 'reduction: str = ...) -> Tensor: ...'], + 'kl_div': ['def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., ' + 'reduce: Optional[bool] = ..., reduction: str = ..., log_target: bool = ...) -> Tensor: ...'], + 'margin_ranking_loss': ['def margin_ranking_loss(input1: Tensor, input2: Tensor, target: Tensor,' + ' margin: float = ..., size_average: Optional[bool] = ..., ' + ' reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...'], + 'triplet_margin_loss': ['def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, ' + 'margin: float = ..., p: float = ..., eps: float = ..., swap: bool = ..., ' + 'size_average: Optional[bool] = ..., ' + 'reduce: Optional[bool] = ..., reduction: str = ...) 
-> Tensor: ...'], + 'dsmm': ['def dsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'], + 'hsmm': ['def hsmm(input: Tensor, mat2: Tensor) -> Tensor: ...'], + 'saddmm': ['def saddmm(input: Tensor, mat1: Tensor, mat2: Tensor, *, beta: Number=1, ' + 'alpha: Number=1, out: Optional[Tensor]=None) -> Tensor: ...'], + 'spmm': ['def spmm(input: Tensor, mat2: Tensor) -> Tensor: ...'], }) for binop in ['mul', 'div', 'true_divide', 'floor_divide']: unsorted_function_hints[binop].append( @@ -382,10 +408,12 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) - 'def __init__(self, size: _size, *, {}) -> None: ...'.format(DEVICE_PARAM), ], 'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."], + '_make_subclass': ["def _make_subclass(cls, data: Tensor, require_grad: _bool = False) -> Tensor: ..."], # clamp has no default values in the Declarations 'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf," " *, out: Optional[Tensor]=None) -> Tensor: ..."], 'clamp_': ["def clamp_(self, min: _float=-inf, max: _float=inf) -> Tensor: ..."], + '__get__': ["def __get__(self, instance, owner=None) -> Tensor: ..."], '__getitem__': ["def __getitem__(self, {}) -> Tensor: ...".format(INDICES)], '__setitem__': ["def __setitem__(self, {}, val: Union[Tensor, Number])" " -> None: ...".format(INDICES)], @@ -402,13 +430,17 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) - 'numpy': ['def numpy(self) -> Any: ...'], 'apply_': ['def apply_(self, callable: Callable) -> Tensor: ...'], 'map_': ['def map_(self, tensor: Tensor, callable: Callable) -> Tensor: ...'], + 'map2_': ['def map2_(self, x: Tensor, y: Tensor, callable: Callable) -> Tensor: ...'], 'storage': ['def storage(self) -> Storage: ...'], + 'storage_type': ['def storage_type(self) -> Storage: ...'], 'type': ['def type(self, dtype: None=None, non_blocking: _bool=False) -> str: ...', 'def type(self, dtype: Union[str, _dtype], non_blocking: _bool=False) -> Tensor: ...', ], 'get_device': ['def get_device(self) -> _int: ...'], 'contiguous': ['def contiguous(self, memory_format=torch.contiguous_format) -> Tensor: ...'], + 'has_names': ['def has_names(self) -> _bool: ...'], 'is_contiguous': ['def is_contiguous(self, memory_format=torch.contiguous_format) -> _bool: ...'], + '_is_view': ['def _is_view(self) -> _bool: ...'], 'is_cuda': ['is_cuda: _bool'], 'is_leaf': ['is_leaf: _bool'], 'is_sparse': ['is_sparse: _bool'], diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 1f4d7a070d53..1d28f733c415 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -27,6 +27,8 @@ class device: type: str # THPDevice_type index: _int # THPDevice_index + def __get__(self, instance, owner=None) -> device: ... + # THPDevice_pynew @overload def __init__(self, device: Union[_device, _int, str]) -> None: ... @@ -249,6 +251,9 @@ def _jit_is_script_object(obj: Any) -> _bool: ... def _last_executed_optimized_graph() -> Graph: ... def parse_type_comment(comment: str) -> Decl: ... def merge_type_from_type_comment(decl: Decl, type_annotation_decl: Decl, is_method: _bool) -> Decl: ... +def parse_ir(input: str) -> Graph: ... +def parse_schema(schema: str) -> FunctionSchema: ... +def get_device(input: Tensor) -> _int: ... def _resolve_type_from_object(obj: Any, range: SourceRange, rcb: ResolutionCallback) -> JitType: ... def _create_module_with_type(ty: JitType) -> ScriptModule: ... def _run_emit_module_hook(m: ScriptModule): ... 
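
The stubs added in the hunk above describe bindings that already exist on torch._C. A minimal sketch of how parse_schema and parse_ir might be exercised; the schema and IR strings below are illustrative inputs, not taken from this patch.

    import torch

    # Parse a standalone operator schema string into a FunctionSchema object.
    schema = torch._C.parse_schema("aten::relu(Tensor self) -> Tensor")
    print(schema)

    # Parse a TorchScript IR listing into a Graph object.
    ir = (
        "graph(%x : Tensor):\n"
        "  %y : Tensor = aten::relu(%x)\n"
        "  return (%y)\n"
    )
    graph = torch._C.parse_ir(ir)
    print(graph)
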
@@ -415,10 +420,13 @@ class ErrorReport: def call_stack() -> str: ... class CompilationUnit: - def __init__(self) -> None: ... + def __init__(self, lang: str=..., _frames_up: _int=...) -> None: ... def find_function(self, name: str) -> ScriptFunction: ... - def define(self, script: str, rcb: ResolutionCallback): ... + def __getattr__(self, name: str) -> ScriptFunction: ... + def define(self, script: str, rcb: ResolutionCallback=..., _frames_up: _int=...): ... def get_interface(self, name: str) -> InterfaceType: ... + def get_functions(self) -> List[ScriptFunction]: ... + def create_function(self, name: str, graph: Graph, shouldMangle: _bool=...) -> ScriptFunction: ... class ScriptModule: def setattr(self, name: str, value: Any): ... @@ -429,6 +437,7 @@ class ScriptFunction: def __call__(self, *args, **kwargs) -> Tensor: ... def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ... def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ... + @property def graph(self) -> Graph: ... def inlined_graph(self) -> Graph: ... def schema(self) -> FunctionSchema: ... @@ -502,6 +511,7 @@ def _get_qengine() -> _int: ... # THPModule_qEngine def _set_qengine(qegine: _int) -> None: ... # THPModule_setQEngine def _supported_qengines() -> List[_int]: ... # THPModule_supportedQEngines def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK +def _is_torch_function_enabled() -> _bool: ... # THPModule_isEnabledTorchFunction def _has_torch_function(args: Iterable[Any]) -> _bool: ... # THPModule_has_torch_function def _has_torch_function_unary(Any) -> _bool: ... # THPModule_has_torch_function_unary def _has_torch_function_variadic(*args: Any) -> _bool: ... # THPModule_has_torch_function_variadic @@ -509,6 +519,7 @@ def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_n def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython def _demangle(str) -> str: ... # c10::demangle +def _disabled_torch_function_impl(func: Callable, types: Iterable[Type], args: Tuple, kwargs: Dict) -> Any: ... # THPModule_disable_torch_function # Defined in `valgrind.h` and `callgrind.h` respecitively. def _valgrind_supported_platform() -> _bool: ... # NVALGRIND @@ -640,9 +651,12 @@ class _TensorBase(object): imag: Tensor T: Tensor ndim: _int + output_nr: _int _version: _int _base: Optional[Tensor] + _cdata: _int grad_fn: Any + _grad_fn: Any _grad: Optional[Tensor] _backward_hooks: Optional[Dict[_int, Callable[[Tensor], Optional[Tensor]]]] ${tensor_method_hints} diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 5ac2c0a8315d..83089072883b 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -15,6 +15,14 @@ class BuiltinCommHookType(Enum): def _register_comm_hook(reducer: Reducer, state: Any, comm_hook: Any): ... def _register_builtin_comm_hook(reducer: Reducer, comm_hook_type: BuiltinCommHookType): ... +def _get_ddp_logging_data(reducer: Reducer): ... +def _set_construction_logging_data( + reducer: Reducer, + module_name: str, + device_ids: List[int], + output_device: int, + broadcast_buffers: bool): ... + class _GradBucket: def __init__(self, tensors: List[Tensor]): ... def get_tensors(self) -> List[Tensor]: ... 
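
A rough sketch of how the logging bindings stubbed above might be driven from Python. The reducer handle (ddp_model.reducer) and the module attributes used below are assumptions about DDP internals, not part of this patch.

    from torch._C._distributed_c10d import (
        _set_construction_logging_data,
        _get_ddp_logging_data,
    )

    def log_ddp_construction(ddp_model, device_ids, output_device):
        # Record construction-time fields on the reducer ...
        _set_construction_logging_data(
            ddp_model.reducer,                    # assumed internal attribute
            ddp_model.module.__class__.__name__,  # module_name
            device_ids,
            output_device,
            ddp_model.broadcast_buffers,
        )
        # ... and read back the aggregated DDPLoggingData struct.
        data = _get_ddp_logging_data(ddp_model.reducer)
        print(data.world_size, data.rank, data.module_name)
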
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 4ccffc2c8362..be287d0a9a3b 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -194,6 +194,9 @@ def get_closure(fn): # functions will be defined at a global scope like MyGlobalClass. In cases # where they are not, it is possible to work around issues by declaring the # values global in the function. +# In Python 3.9 declaring class as global will make it invisible to +# `inspect.getsource`, see https://bugs.python.org/issue42666 . +# This could be worked around by manualy adding it to `global()` dictionary. diff --git a/torch/_python_dispatcher.py b/torch/_python_dispatcher.py new file mode 100644 index 000000000000..9b4a60338596 --- /dev/null +++ b/torch/_python_dispatcher.py @@ -0,0 +1,156 @@ +import re +import torch._C as C + + +""" +PythonDispatcher class is a thin python-binding to C++ dispatcher and it +is designed to show how dispatcher precompute works. In particular, +it shows for a certain op `foo`, what the computed dispatch table looks +like after user register their kernels to certains dispatch keys. + +In the real C++ dispatcher we support many dispatch keys for different +functionalities. For simplicity PythonDispatcher only supports dispatch +keys for a single example of each use case. These use cases are listed below: + +- CPU/AutogradCPU: represents in-tree backends which we usually have dedicated inference & + autograd kernel in pytorch core library. + E.g. CPU, CUDA +- QuantizedCPU/AutogradOther: represents in-tree backends which we usually have backend specific + inference kernels, but they share the same autograd kernel specified in AutogradOther. + E.g. QuantizedCPU, QuantizedCUDA +- XLA/AutogradXLA: represents out-of-tree backends which we don't have either inference or autograd + kernel defined in pytorch core library. Backend owner is responsible for registering both + inference & autograd kernels in their extensions(e.g. torch-xla) for the operators they support. + E.g. XLA, XPU, MLC +- DefaultBackend: alias key mapped to inference kernels of all backends like CPU, CUDA, XLA etc. + Kernels registered to this key MUST work for inference for all backends. +- Autograd: alias key mapped to autograd of all backends like AutogradCPU, AutogradXLA, AutogradOther. + Kernels registered to this key MUST work for autograd for all backends. +- Math: alias key Math = DefaultBackend + Autograd + Kernels registered to this key MUST work for both inference + autograd for all backends. + +Note we only allow registrations to alias keys inside pytorch core library. E.g you shouldn't register +a Math or DefaultBackend kernel from torch-xla extension, instead you should upstream the kernel into +pytorch/pytorch repo so that it's available for all backends and continuously tested even without the extension. + +Usage: + dispatcher = PythonDispatcher() + dispatcher.register(["CPU", "XLA", "Math"]) + print(dispatcher.dispatchTable()) # This tells you exactly which kernel is used for certain backend. + # For more debugging information + # print(dispatcher.keys()) + # print(dispatcher.registrations()) + # print(dispatcher.rawRegistrations()) + # print(dispatcher.rawDispatchTable()) +PythonDispatcher calls C++ dispatcher under the hood for to precompute dispatch table. 
+This file only provides the simplified API for developers, revelant test code is located in +test/test_dispatch.py +""" +class PythonDispatcher: + namespace = "__test__" + name = "foo" + runtime_keys = [ + "CPU", "AutogradCPU", + "QuantizedCPU", "AutogradOther", + "XLA", "AutogradXLA", + ] + alias_keys = [ + "DefaultBackend", + "Autograd", + "Math", + ] + supported_keys = runtime_keys + alias_keys + + def __init__(self): + C._dispatch_check_invariants(self.name) # type: ignore[attr-defined] + self.ref = C._dispatch_library("FRAGMENT", self.namespace, "") # type: ignore[attr-defined] + self.ref.def_("foo(Tensor x) -> Tensor") + + """ + Returns a list of dispatch keys supported by PythonDispatcher. + You can register kernels to these keys. + """ + def keys(self): + return self.supported_keys + + """ + Register kernels to the target dispatchKeys. + dispatchKeys(list[str]): a list of dispatch keys that you want to register + your own kernel. Note that you don't need to write the kernel yourself in + this PythonDispatcher.E.g. for CPU key, a kernel(e.g fn_CPU for CPU) is + automatically generated and registered. + """ + def register(self, dispatchKeys): + # Overriden is not supported and triggers a warning in C++ dispatcher. + if len(set(dispatchKeys)) != len(dispatchKeys): + raise RuntimeError(f"Overriden is not allowed but found duplicates in {dispatchKeys}.") + # We currently forbid this in codegen instead of C++ dispatcher. + if 'Math' in dispatchKeys and 'DefaultBackend' in dispatchKeys: + raise RuntimeError("Registration to both Math and DefaultBackend is not allowed.") + for key in dispatchKeys: + if key not in self.supported_keys: + raise RuntimeError(f"{key} is not supported, please select a dispatch key in {self.supported_keys}.") + self.ref.impl_t_t("foo", dispatch=key, debug="fn_" + key) + + """ + Helper function to format (key, kernel). + """ + def _format_line(self, key, kernel): + return "{:<15} {}\n".format(key, kernel) + + """ + Helper function to print a table header. + """ + def _format_header(self, header): + s = f""" +{header} +""" + s += self._format_line("key", "kernel") + s += "---------------------------\n" + return s + + """ + Returns raw output of all registration info for debugging only. + Use registrations() for a simplified version. + """ + def rawRegistrations(self): + return C._dispatch_dump("{}::{}".format(self.namespace, self.name)) # type: ignore[attr-defined] + + """ + Returns raw output of computed dispatch table for debugging only. + Use dispatchTable() for a simplified version. + """ + def rawDispatchTable(self): + return C._dispatch_dump_table("{}::{}".format(self.namespace, self.name)) # type: ignore[attr-defined] + + """ + Returns a table(str) including all the registrations from users. + Note this includes registrations to both runtime keys and alias keys. + """ + def registrations(self): + output = self._format_header("Registered Kernels") + state = self.rawRegistrations() + state_entries = state.split('\n') + for line in state_entries: + first = line.split(":")[0] + if any(first.startswith(k) for k in self.supported_keys): + kernel = line.split("::")[0].split(" ")[1] + output += self._format_line(first, kernel) + return output + + """ + Returns the computed dispatch table(str). Note this only include + runtime keys, registrations to alias keys have been decoded to their + mapped runtime keys. 
+ """ + def dispatchTable(self): + output = self._format_header("Computed Dispatch Table") + table = self.rawDispatchTable() + table_entries = table.split('\n') + regex = re.compile(r"registered at .*FallbackKernel\.cpp.*(\[)") + for line in table_entries: + k = line.split(":")[0] + if k in self.runtime_keys: + entry = regex.sub('[', line) + output += self._format_line(k, entry.split(": ")[1]) + return output diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index b787e94de663..93b0d10393f0 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -8312,53 +8312,52 @@ def merge_dicts(*dicts): svd(input, some=True, compute_uv=True, *, out=None) -> (Tensor, Tensor, Tensor) Computes the singular value decomposition of either a matrix or batch of -matrices :attr:`input`." The singular value decomposition is represented as a -namedtuple ``(U, S, V)``, such that :math:`input = U \mathbin{@} diag(S) \times -V^T`, where :math:`V^T` is the transpose of ``V``. If :attr:`input` is a batch -of tensors, then ``U``, ``S``, and ``V`` are also batched with the same batch -dimensions as :attr:`input`. +matrices :attr:`input`. The singular value decomposition is represented as a +namedtuple (`U,S,V`), such that +:attr:`input` = `U` diag(`S`) `Vá´´`, +where `Vá´´` is the transpose of `V` for the real-valued inputs, +or the conjugate transpose of `V` for the complex-valued inputs. +If :attr:`input` is a batch of tensors, then `U`, `S`, and `V` are also +batched with the same batch dimensions as :attr:`input`. If :attr:`some` is ``True`` (default), the method returns the reduced singular value decomposition i.e., if the last two dimensions of :attr:`input` are -``m`` and ``n``, then the returned `U` and `V` matrices will contain only -:math:`min(n, m)` orthonormal columns. +`m` and `n`, then the returned `U` and `V` matrices will contain only +min(`n, m`) orthonormal columns. If :attr:`compute_uv` is ``False``, the returned `U` and `V` will be -zero-filled matrices of shape :math:`(m \times m)` and :math:`(n \times n)` +zero-filled matrices of shape `(m × m)` and `(n × n)` respectively, and the same device as :attr:`input`. The :attr:`some` -argument has no effect when :attr:`compute_uv` is False. +argument has no effect when :attr:`compute_uv` is ``False``. -The dtypes of ``U`` and ``V`` are the same as :attr:`input`'s. ``S`` will +Supports input of float, double, cfloat and cdouble data types. +The dtypes of `U` and `V` are the same as :attr:`input`'s. `S` will always be real-valued, even if :attr:`input` is complex. -.. warning:: ``torch.svd`` is deprecated. Please use ``torch.linalg.`` - :func:`~torch.linalg.svd` instead, which is similar to NumPy's +.. warning:: :func:`torch.svd` is deprecated. Please use + :func:`torch.linalg.svd` instead, which is similar to NumPy's ``numpy.linalg.svd``. -.. note:: **Differences with** ``torch.linalg.`` :func:`~torch.linalg.svd`: +.. note:: Differences with :func:`torch.linalg.svd`: - * :attr:`some` is the opposite of ``torch.linalg.`` - :func:`~torch.linalg.svd`'s :attr:`full_matricies`. Note that + * :attr:`some` is the opposite of + :func:`torch.linalg.svd`'s :attr:`full_matricies`. Note that default value for both is ``True``, so the default behavior is effectively the opposite. - * it returns ``V``, whereas ``torch.linalg.`` - :func:`~torch.linalg.svd` returns ``Vh``. The result is that - when using ``svd`` you need to manually transpose - ``V`` in order to reconstruct the original matrix. 
+ * :func:`torch.svd` returns `V`, whereas :func:`torch.linalg.svd` returns `Vá´´`. - * If :attr:`compute_uv=False`, it returns zero-filled tensors for - ``U`` and ``Vh``, whereas :meth:`~torch.linalg.svd` returns + * If :attr:`compute_uv=False`, :func:`torch.svd` returns zero-filled tensors for + ``U`` and ``Vh``, whereas :func:`torch.linalg.svd` returns empty tensors. -Supports real-valued and complex-valued input. - .. note:: The singular values are returned in descending order. If :attr:`input` is a batch of matrices, then the singular values of each matrix in the batch is returned in descending order. .. note:: The implementation of SVD on CPU uses the LAPACK routine `?gesdd` (a divide-and-conquer - algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine - `gesdd` as well. + algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the cuSOLVER routines + `gesvdj` and `gesvdjBatched` on CUDA 10.1.243 and later, and uses the MAGMA routine `gesdd` + on earlier versions of CUDA. .. note:: The returned matrix `U` will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()`. @@ -8372,15 +8371,19 @@ def merge_dicts(*dicts): .. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. - .. note:: With the complex-valued input the backward operation works correctly only for gauge invariant loss functions. Please look at `Gauge problem in AD`_ for more details. +.. note:: Since `U` and `V` of an SVD is not unique, each vector can be multiplied by + an arbitrary phase factor :math:`e^{i \phi}` while the SVD result is still correct. + Different platforms, like Numpy, or inputs on different device types, may produce different + `U` and `V` tensors. + Args: - input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more - batch dimensions consisting of :math:`m \times n` matrices. + input (Tensor): the input tensor of size `(*, m, n)` where `*` is zero or more + batch dimensions consisting of `(m × n)` matrices. some (bool, optional): controls whether to compute the reduced or full decomposition, and - consequently the shape of returned ``U`` and ``V``. Defaults to True. + consequently the shape of returned `U` and `V`. Defaults to True. compute_uv (bool, optional): option whether to compute `U` and `V` or not. Defaults to True. Keyword args: diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py index 70961cef9744..02fd0c47de40 100644 --- a/torch/autograd/functional.py +++ b/torch/autograd/functional.py @@ -214,8 +214,11 @@ def vjp(func, inputs, v=None, create_graph=False, strict=False): Defaults to ``False``. Returns: - vjp (tuple of Tensors or Tensor): result of the dot product with - the same shape as the inputs. + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + vjp (tuple of Tensors or Tensor): result of the dot product with + the same shape as the inputs. Example: @@ -298,8 +301,11 @@ def jvp(func, inputs, v=None, create_graph=False, strict=False): Defaults to ``False``. Returns: - jvp (tuple of Tensors or Tensor): result of the dot product with - the same shape as the output. + output (tuple): tuple with: + func_output (tuple of Tensors or Tensor): output of ``func(inputs)`` + + jvp (tuple of Tensors or Tensor): result of the dot product with + the same shape as the output. 
Example: diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 27dd4ccce649..38248ca93d4e 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1949,21 +1949,16 @@ Tensor svd_backward(const std::vector &grads, const T auto gsigma = grads[1]; auto u = raw_u; - // Currently torch.svd for complex dtypes returns the conjugate of V, - // while the backward formula is derived with just V (without the conjugation) - // therefore here we need to conjugate the V output of SVD and grads[2]. - // Once https://github.com/pytorch/pytorch/issues/45821 is resolved - // extra .conj(), that are marked below in the code, shall be removed. - auto v = raw_v.conj(); // TODO: remove .conj() + auto v = raw_v; auto gu = grads[0]; - auto gv = grads[2].conj(); // TODO: remove .conj() + auto gv = grads[2]; if (!some) { // We ignore the free subspace here because possible base vectors cancel // each other, e.g., both -v and +v are valid base for a dimension. // Don't assume behavior of any particular implementation of svd. u = raw_u.narrow(-1, 0, k); - v = raw_v.narrow(-1, 0, k).conj(); // TODO: remove .conj() + v = raw_v.narrow(-1, 0, k); if (gu.defined()) { gu = gu.narrow(-1, 0, k); } diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index a9c7d709466e..2b19536c6baf 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -48,6 +48,10 @@ Engine& PythonEngine::get_python_engine() { return engine; } +#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 9 +#define IS_PYTHON_3_9_PLUS +#endif + void PythonEngine::thread_init(int device, const std::shared_ptr& ready_queue, bool should_increment) { // Increment thread usage count before acquiring the GIL if (should_increment) { @@ -56,7 +60,11 @@ void PythonEngine::thread_init(int device, const std::shared_ptr& re // Create a PyThreadState, but release the GIL. This lets pybind11::gil_scoped_acquire calls // inside thread_main acquire the GIL without having to create a new // PyThreadState each time. +#ifdef IS_PYTHON_3_9_PLUS + auto gil = std::make_unique(); +#else pybind11::gil_scoped_acquire gil; +#endif pybind11::gil_scoped_release no_gil; Engine::thread_init(device, ready_queue, false); @@ -64,6 +72,15 @@ void PythonEngine::thread_init(int device, const std::shared_ptr& re // Decrement the count during shutdown if we incremented earlier. 
decrement_non_reentrant_thread_count(); } + +#ifdef IS_PYTHON_3_9_PLUS + // Do not call PyEval_RestoreThread, PyThreadState_[Clear|DeleteCurrent] if runtime is finalizing + if (_Py_IsFinalizing()) { + no_gil.disarm(); + // TODO: call disarm rather than leak gil_scoped_acquired once PyThreadState_Clear can safely be called from finalize + gil.release(); + } +#endif } void PythonEngine::thread_on_exception( diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index fc798e537b2f..9502b3236874 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -180,7 +180,29 @@ PyObject* c10d_init(PyObject* _unused, PyObject* noargs) { "_register_builtin_comm_hook", &_register_builtin_comm_hook, py::arg("reducer"), - py::arg("comm_hook_type")); + py::arg("comm_hook_type")) + .def( + "_set_construction_logging_data", + []( + ::c10d::Reducer& reducer, + const std::string& module_name, + const std::vector& device_ids, + int output_device, + bool broadcast_buffers) -> void { + reducer.set_construction_logging_data( + module_name, device_ids, output_device, broadcast_buffers); + }, + py::arg("reducer"), + py::arg("module_name"), + py::arg("device_ids"), + py::arg("output_device"), + py::arg("broadcast_buffers")) + .def( + "_get_ddp_logging_data", + [](::c10d::Reducer& reducer) -> c10::DDPLoggingData { + return reducer.get_ddp_logging_data(); + }, + py::arg("reducer")); shared_ptr_class_<::c10d::GradBucket>(module, "_GradBucket") .def( @@ -1159,6 +1181,18 @@ that adds a prefix to each key inserted to the store. Note that ``fut.done()`` returns only whether the operation has been enqueued on the GPU. )"); + py::class_(module, "DDPLoggingData") + .def(py::init<>()) + .def_readwrite("world_size", &c10::DDPLoggingData::world_size) + .def_readwrite("rank", &c10::DDPLoggingData::rank) + .def_readwrite("module_name", &c10::DDPLoggingData::module_name) + .def_readwrite("device_ids", &c10::DDPLoggingData::device_ids) + .def_readwrite("output_device", &c10::DDPLoggingData::output_device) + .def_readwrite("broadcast_buffers", &c10::DDPLoggingData::broadcast_buffers) + .def_readwrite("bucket_cap_mb", &c10::DDPLoggingData::bucket_cap_mb) + .def_readwrite("find_unused_parameters", &c10::DDPLoggingData::find_unused_parameters) + .def_readwrite("gradient_as_bucket_view", &c10::DDPLoggingData::gradient_as_bucket_view); + module.def( "_compute_bucket_assignment_by_size", &::c10d::compute_bucket_assignment_by_size, @@ -1668,7 +1702,6 @@ static const auto DistributedC10dFrontendTorchBind = .def( "get_name_of_process_group", &::c10d::DistributedC10d::getNameOfProcessGroup); - } // namespace // c10d methods on torch._C diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 52294c38ad11..fe4278a115db 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -230,6 +230,7 @@ void parseMethods( class BytecodeDeserializer final { public: explicit BytecodeDeserializer(std::unique_ptr reader); + mobile::Module deserialize(c10::optional device); mobile::Module deserialize( c10::optional device, ExtraFilesMap& extra_files); @@ -274,6 +275,12 @@ mobile::Module BytecodeDeserializer::deserialize( std::string(static_cast(meta_ptr.get()), meta_size); } } + return deserialize(device); +} + +mobile::Module BytecodeDeserializer::deserialize( + c10::optional device) { + device_ = device; auto mcu = std::make_shared(); // bvals can have 2 possible formats: diff --git a/torch/csrc/jit/python/pybind_utils.cpp 
b/torch/csrc/jit/python/pybind_utils.cpp index 30689401102b..9ce62442f1e1 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -264,7 +264,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, c10::optional N) { case TypeKind::AnyEnumType: break; case TypeKind::ComplexDoubleType: - AT_ASSERT(false); + AT_ASSERT(false); case TypeKind::EnumType: EnumTypePtr enum_type = type->expect(); py::object py_obj = py::reinterpret_borrow(obj); diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index f7dd766d5da7..356c91f2a03b 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -51,7 +51,7 @@ namespace jit { using ::c10::Argument; using ::c10::FunctionSchema; -using ResolutionCallback = std::function; +using ResolutionCallback = std::function; using FunctionDefaults = std::unordered_map; using ClassMethodDefaults = std::unordered_map; @@ -710,6 +710,22 @@ void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) { } } +void pyCompilationUnitDefine( + CompilationUnit& cu, + const std::string& src, + const ResolutionCallback* rcb, + const uint32_t _frames_up) { + if (rcb && *rcb) { + cu.define(c10::nullopt, src, pythonResolver(*rcb), nullptr); + } else { + py::object py_default_rcb = + py::module::import("torch._jit_internal") + .attr("createResolutionCallbackFromFrame")(_frames_up); + auto default_rcb = py_default_rcb.cast(); + cu.define(c10::nullopt, src, pythonResolver(default_rcb), nullptr); + } +} + void initJitScriptBindings(PyObject* module) { auto m = py::handle(module).cast(); @@ -1114,21 +1130,72 @@ void initJitScriptBindings(PyObject* module) { py::class_>( m, "CompilationUnit") - .def(py::init<>()) + .def( + py::init([](const std::string& lang, const uint32_t _frames_up) { + auto cu = std::make_shared(); + if (lang.size() > 0) { + pyCompilationUnitDefine(*cu, lang, nullptr, _frames_up); + } + return cu; + }), + py::arg("lang") = "", + py::arg("_frames_up") = 0) + .def( "find_function", [](std::shared_ptr self, const std::string& name) { - auto& fn = self->get_function(QualifiedName(name)); - return StrongFunctionPtr(std::move(self), &fn); + auto fn = self->find_function(QualifiedName(name)); + if (fn) { + return c10::optional( + StrongFunctionPtr(std::move(self), fn)); + } else { + return c10::optional(c10::nullopt); + } + }) + .def( + "__getattr__", + [](std::shared_ptr self, const std::string& name) { + auto fn = self->find_function(QualifiedName(name)); + if (fn) { + return StrongFunctionPtr(std::move(self), fn); + } else { + throw AttributeError( + "'CompilationUnit' has no attribute '%s'", name.c_str()); + } + }) + .def( + "get_functions", + [](const std::shared_ptr& self) { + auto raw_functions = self->get_functions(); + std::vector functions; + functions.reserve(raw_functions.size()); + for (auto fn : raw_functions) { + if (fn) { + functions.emplace_back(self, fn); + } + } + return functions; }) .def("set_optimized", &CompilationUnit::set_optimized) .def( "define", - [](CompilationUnit& cu, - const std::string& src, - const ResolutionCallback& rcb) { - cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr); - }) + pyCompilationUnitDefine, + py::arg("src"), + py::arg("rcb") = nullptr, + py::arg("_frames_up") = 0) + .def( + "create_function", + [](std::shared_ptr& self, + const std::string& qualified_name, + std::shared_ptr graph, + bool should_mangle) { + Function* fn = self->create_function( + qualified_name, std::move(graph), 
should_mangle); + return StrongFunctionPtr(std::move(self), fn); + }, + py::arg("qualified_name"), + py::arg("graph"), + py::arg("should_mangle") = false) .def( "get_interface", [](const std::shared_ptr& self, diff --git a/torch/csrc/jit/runtime/static/ops.cpp b/torch/csrc/jit/runtime/static/ops.cpp index 42d7b99da320..fdd9fbb31fbc 100644 --- a/torch/csrc/jit/runtime/static/ops.cpp +++ b/torch/csrc/jit/runtime/static/ops.cpp @@ -281,8 +281,8 @@ REGISTER_OPERATOR_FUNCTOR_OPT( true, [](Node* n) -> SROperator { return [](ProcessedNode* p_node) { - auto weight = p_node->Input(0).toTensor(); - auto indices = p_node->Input(1).toTensor(); + auto& weight = p_node->Input(0).toTensor(); + auto& indices = p_node->Input(1).toTensor(); auto offsets = p_node->Input(2).toOptional(); auto pruned_weights = p_node->Input(5).toBool(); auto per_sample_weights = p_node->Input(6).toOptional(); @@ -293,7 +293,7 @@ REGISTER_OPERATOR_FUNCTOR_OPT( p_node->Output(0) = at::empty({0}, weight.options().dtype(at::kFloat)); } - auto out_t = p_node->Output(0).toTensor(); + auto& out_t = p_node->Output(0).toTensor(); fastResizeToZero(out_t); return at::native::embedding_bag_byte_rowwise_offsets_out( out_t, @@ -312,13 +312,13 @@ REGISTER_OPERATOR_FUNCTOR_OPT( // The out variant takes precedence over native REGISTER_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator { return [](ProcessedNode* p_node) { - auto self = p_node->Input(0).toTensor(); // self + auto& self = p_node->Input(0).toTensor(); // self auto dim = p_node->Input(1).toInt(); // dim int64_t start = 0; if (p_node->Input(2).isScalar()) { start = p_node->Input(2).toInt(); } else { - auto t = p_node->Input(2).toTensor(); + auto& t = p_node->Input(2).toTensor(); start = t.item(); } auto length = p_node->Input(3).toInt(); // length @@ -326,7 +326,7 @@ REGISTER_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator { if (p_node->Output(0).isNone()) { p_node->Output(0) = create_empty_from(self); } - auto output = p_node->Output(0).toTensor(); + auto& output = p_node->Output(0).toTensor(); output.resize_({0}); at::native::narrow_copy_dense_cpu_out(self, dim, start, length, output); }; diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index e7cf55f7d194..f363fe73f1e9 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -80,7 +80,8 @@ void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { case AnyEnumType::Kind: // no op, there is nothing to tag break; - // TODO(@anjali411): Implement serialization/deserialization for complex numbers + // TODO(@anjali411): Implement serialization/deserialization for complex + // numbers case ComplexDoubleType::Kind: case EnumType::Kind: // TODO(gmagogsfm): Implement serialization/deserialization of Enum. 
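
A small sketch of the CompilationUnit surface bound above: construct directly from a source string, resolve a compiled function through __getattr__, and enumerate functions with get_functions(). The example source is illustrative.

    import torch

    src = "def double(x):\n    return x * 2\n"

    cu = torch._C.CompilationUnit(src)      # new (lang, _frames_up) constructor
    fn = cu.double                          # __getattr__ -> ScriptFunction
    print(fn(torch.ones(3)))                # tensor([2., 2., 2.])
    print(len(cu.get_functions()))          # 1
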
diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp index c922797b50dc..d52227d7c04a 100644 --- a/torch/csrc/jit/tensorexpr/expr.cpp +++ b/torch/csrc/jit/tensorexpr/expr.cpp @@ -212,14 +212,12 @@ ExprHandle fast_log(const ExprHandle& v) { t = mlaf(t, x2, 0.666666686534881591796875); t = mlaf(t, x2, 2.0); x = x * t + FloatImm::make(0.693147180559945286226764) * e; - x = IfThenElse::make( - v < FloatImm::make(0), - FloatImm::make(std::numeric_limits::quiet_NaN()), - x); - x = IfThenElse::make( - v == FloatImm::make(0), - FloatImm::make(-std::numeric_limits::infinity()), - x); + + auto zero = FloatImm::make(0); + auto nan = FloatImm::make(std::numeric_limits::quiet_NaN()); + auto neg_inf = FloatImm::make(-std::numeric_limits::infinity()); + x = CompareSelect::make(v, zero, nan, x, kLT); + x = CompareSelect::make(v, zero, neg_inf, x, kEQ); return x; } diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp index 043f320ed3f5..fb93851974f6 100644 --- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp +++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp @@ -1,6 +1,9 @@ #include #include #include +#ifdef USE_CUDA +#include +#endif #include #include #include @@ -427,6 +430,14 @@ void initTensorExprBindings(PyObject* module) { "simplify", &tensorexpr::LoopNest::simplify, py::return_value_policy::reference) + .def( + "set_GPU_block_index", + &tensorexpr::LoopNest::setGPUBlockIndex, + py::return_value_policy::reference) + .def( + "set_GPU_thread_index", + &tensorexpr::LoopNest::setGPUThreadIndex, + py::return_value_policy::reference) .def( "__str__", [](const tensorexpr::LoopNest& self) { @@ -479,6 +490,12 @@ void initTensorExprBindings(PyObject* module) { cg = new tensorexpr::LLVMCodeGen(stmt, args); #else throw std::runtime_error("PyTorch not compiled with LLVM support!"); +#endif + } else if (name == "cuda") { +#ifdef USE_CUDA + cg = new tensorexpr::CudaCodeGen(stmt, args); +#else + throw std::runtime_error("PyTorch not compiled with CUDA support!"); #endif } else { cg = new tensorexpr::SimpleIREvaluator(stmt, args); diff --git a/torch/csrc/utils/tensor_flatten.h b/torch/csrc/utils/tensor_flatten.h index 3a6bc16a9eb4..cb54bbb53a75 100644 --- a/torch/csrc/utils/tensor_flatten.h +++ b/torch/csrc/utils/tensor_flatten.h @@ -20,8 +20,16 @@ inline std::vector unflatten_dense_tensors(const at::Tensor& flat, a size_t offset = 0; for (const auto & tensor : tensors) { auto numel = tensor.numel(); - outputs.push_back(flat.narrow(0, offset, numel).view(tensor.sizes())); - offset += numel; + // If unflatten an empty tensor, create a new empty tensor using + // flat tensor Options. + // This can avoid the unflattened empty tensor to share the same storage + // with other unflatten tensors. 
+ if (numel == 0) { + outputs.push_back(at::empty({0}, flat.options())); + } else { + outputs.push_back(flat.narrow(0, offset, numel).view(tensor.sizes())); + offset += numel; + } } return outputs; } diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 1335fe9d1d6d..7fe880c83900 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -34,6 +34,8 @@ def is_available(): _broadcast_coalesced, _compute_bucket_assignment_by_size, _test_python_store, + _set_construction_logging_data, + _get_ddp_logging_data ) if sys.platform != 'win32': from torch._C._distributed_c10d import ( diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index bf9dbe8ad652..32ef2f7c35c2 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -20,7 +20,7 @@ def _orthogonalize(matrix, epsilon=1e-8): # This epsilon is not needed if the input matrix covers the gradients of at least one entire layer in the neural network. if epsilon == 0: # Note that col ** 2 can underflow/overflow if we use FP16. - # May need to consder multiplying a scaling factor and divding it later, or using bfloat16 isntead. + # May need to consider multiplying a scaling factor and dividing it later, or using bfloat16 instead. col /= torch.sqrt(torch.sum(col ** 2)) else: col /= torch.sqrt(torch.sum(col ** 2)) + epsilon @@ -33,32 +33,57 @@ def _orthogonalize(matrix, epsilon=1e-8): class PowerSGDState(object): __slots__ = [ "process_group", + # The two fields below are the configs that usually need to be tuned by the user. "matrix_approximation_rank", + "start_powerSGD_iter", + # The two fields below are the configs that usually need to be turned on for performance. "use_error_feedback", "warm_start", + # The fields below are not configs. "rng", "error_dict", "p_memory_dict", "q_memory_dict", + "iter", ] def __init__( self, process_group, matrix_approximation_rank=1, + start_powerSGD_iter=10, use_error_feedback=True, warm_start=True, random_seed=0, ): self.process_group = process_group - # The low rank for matrix approximation. - # Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf. + # The low rank for matrix approximation controls the size of compressed low-rank tensors, + # which determines the computation ratio. + # Typically only a small value 1-4 is used. + # For some NLP tasks (as shown in Appendix D of the original paper + # https://arxiv.org/pdf/1905.13727.pdf, the rank value has been increased to 32. + # A high rank value will increase the computation costs of compression exponentially. + # A good choice depends on how much extra computation can be hidden by the dominating communication costs. self.matrix_approximation_rank = matrix_approximation_rank + # This defers PowerSGD compression util step 'start_powerSGD_iter', + # and vanilla allreduce runs before step 'start_powerSGD_iter'. + # This hybrid scheme of vanilla allreduce + PowerSGD can have two advantages: + # 1) It turns out that PowerSGD may lead to a non-trivial accuracy loss, + # even if the matrix approximation rank is increased to a large value. + # To mitigate the accuracy loss, a simple yet effective way is mixing vanilla allreduce + # (or a more convervative compression such as FP16 compression) with PowerSGD. + # 2) There is an internal optimization of rebuilding buckets process in DDP, + # in order to save the memory space. 
+ # This step takes place after the first iteration. + # However, this means that the shape of input bucketized tensors is subject to change, + # which will complicate the implementations of error feedback and warm-up. + # Running vanilla allreduce in the first few iterations can avoid this complexity. + self.start_powerSGD_iter = start_powerSGD_iter # Error feedback is usually crucial for both for convergence and generalization, # because PowerSGD is a biased compressor, # i.e., compressing and decompressing a random gradient does not yield the original in expectation. # This mechanism requires a temporary copy of the input gradients, - # so it increases the peak memory consumption by the size of gradient tensor. + # so it increases the peak memory consumption by the size of the gradient tensor. # However, if the target matrices are known to be exactly low-ranked (instead of just low stable rank), # sometimes it is possible to converge to the optima without error feedback. # See: http://proceedings.mlr.press/v54/yurtsever17a/yurtsever17a.pdf @@ -80,6 +105,29 @@ def __init__( self.error_dict = {} self.p_memory_dict = {} self.q_memory_dict = {} + # Iteration/step in the training loop. + self.iter = 0 + + logging.info( + "PowerSGD config: matrix_approximation_rank = {}; " + "start_powerSGD_iter = {}; use_error_feedback = {}; warm_start = {}.".format( + self.matrix_approximation_rank, + self.start_powerSGD_iter, + self.use_error_feedback, + self.warm_start, + ) + ) + + def maybe_increase_iter(self, bucket): + # Since bucket 0 is the last bucket to allreduce in an iteration. + # Only increase `iter` when bucket 0 is processed. + if bucket.get_index() == 0: + self.iter += 1 + + if self.iter == self.start_powerSGD_iter: + logging.info( + "Starting to apply PowerSGD after {} iterations.".format(self.iter) + ) def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: @@ -93,20 +141,20 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: 2) Handles rank-1 tensors by allreducing them without compression: 2.1) Allocate contiguous memory for those rank-1 tensors, and allreduces all the rank-1 tensors as a batch, without compression; - 2.2) Copies the indvidual rank-1 tensors from the contiguous memory back to the input tensor. + 2.2) Copies the individual rank-1 tensors from the contiguous memory back to the input tensor. 3) Handles high-rank tensors by PowerSGD compression: 3.1) For each high-rank tensor M, creates two low-rank tensors P and Q for decomposing M, such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; 3.2) Computes each P in Ps, which is equal to MQ; 3.3) Allreduces Ps as a batch; - 3.4) Orthogonizes each P in Ps; + 3.4) Orthogonalizes each P in Ps; 3.5) Computes each Q in Qs, which is approximately equal to M^TP; 3.6) Allreduces Qs as a batch; 3.7) Computes each M among all the high-rank tensors, which is approximately equal to PQ^T. TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration -- one left multiplication and one right multiplication. - For warm start, can take one such step at a time, and alternate between them. + For warm-start, can take one such step at a time, and alternate between them. Args: state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. 
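
A hedged usage sketch for the hybrid allreduce-plus-PowerSGD scheme described above; process-group initialization and the wrapped module are assumed to exist already, and the hyperparameters mirror the Example in this hunk.

    import torch.distributed as dist
    from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import (
        PowerSGDState,
        powerSGD_hook,
    )
    from torch.nn.parallel import DistributedDataParallel as DDP

    def wrap_with_powersgd(module, device_id):
        ddp_model = DDP(module, device_ids=[device_id])
        state = PowerSGDState(
            process_group=dist.group.WORLD,   # default process group
            matrix_approximation_rank=1,
            start_powerSGD_iter=10,           # vanilla allreduce before this step
        )
        ddp_model.register_comm_hook(state, powerSGD_hook)
        return ddp_model

Because state.iter only advances when bucket 0 is processed, the switch from plain allreduce to compression happens at iteration start_powerSGD_iter on every rank.
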
@@ -118,7 +166,7 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: Future handler of the communication, which updates the gradients in place. Example:: - state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) + state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10) >>> ddp_model.register_comm_hook(state, powerSGD_hook) """ process_group = state.process_group @@ -127,6 +175,20 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: # The input tensor is a flattened 1D tensor. input_tensor = bucket.get_tensors()[0] + + # Run vanilla allreduce in the first `start_powerSGD_iter` iterations. + if state.iter < state.start_powerSGD_iter: + fut = dist.all_reduce( + input_tensor, group=group_to_use, async_op=True + ).get_future() + + def div_callback(fut): + return [fut.value()[0].div_(world_size)] + + state.maybe_increase_iter(bucket) + return fut.then(div_callback) + + # Apply PowerSGD after `start_powerSGD_iter` iterations. device = input_tensor.device dtype = input_tensor.dtype @@ -243,7 +305,7 @@ def powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: with torch.random.fork_rng(devices=[]): # Fork this RNG to avoid changing the seed globally and affecting the random sampling anywhere else in the training. # The seed makes sure that the initial random values are the same across all the DDP replicas. - # Such seed should differ at every step. + # This seed should differ at every step. # Since it is very slow to fork RNG state across all the CUDA devices, # only fork on CPU and then move the generated tensor to the CUDA device (by overwriting q). torch.manual_seed(state.rng.randint(1_000_000_000)) @@ -275,7 +337,7 @@ def unpack_rank1_tensors_and_allreduce_ps(fut): tensor.copy_(rank1_tensors_memory[idx : idx + tensor.shape[0]]) idx += tensor.shape[0] - # Since these Ps will be orthogonized later, no need to divide them by world size. + # Since these Ps will be orthogonalized later, no need to divide them by world size. return [ dist.all_reduce( state.p_memory_dict[bucket_index], group=group_to_use, async_op=True @@ -317,6 +379,8 @@ def decompress(fut): state.p_memory_dict.clear() state.q_memory_dict.clear() + state.maybe_increase_iter(bucket) + return [input_tensor] return ( @@ -339,7 +403,7 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: such that M = PQ^T, where Q is initialized from a standard normal distribution and orthogonalized; 2) Computes P, which is equal to MQ; 3) Allreduces P; - 4) Orthogonizes P; + 4) Orthogonalizes P; 5) Computes Q, which is approximately equal to M^TP; 6) Allreduces Q; 7) Computes M, which is approximately equal to PQ^T. @@ -347,7 +411,7 @@ def batched_powerSGD_hook(state: PowerSGDState, bucket) -> torch.futures.Future: TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration -- one left multiplication and one right multiplication. - For warm start, can take one such step at a time, and alternate between them. + For warm-start, can take one such step at a time, and alternate between them. Args: state (PowerSGDState): State information to configure the compression rate and support error feedback, warm start, etc. @@ -430,7 +494,7 @@ def create_low_rank_tensor(fill_random_values, rng): # Fork this RNG to avoid changing the seed globally and affecting the random sampling # anywhere else in the training. 
# The seed makes sure that the initial random values are the same across all the DDP replicas. - # Such seed should differ at every step. + # This seed should differ at every step. # Since it is very slow to fork RNG state across all the CUDA devices, # only fork on CPU and then move the generated tensor to the CUDA device. torch.manual_seed(rng.randint(1_000_000_000)) diff --git a/torch/distributed/nn/__init__.py b/torch/distributed/nn/__init__.py index c2fa8d773bb6..06af28d20868 100644 --- a/torch/distributed/nn/__init__.py +++ b/torch/distributed/nn/__init__.py @@ -1 +1,2 @@ from .api.remote_module import RemoteModule +from .functional import * diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py new file mode 100644 index 000000000000..feb69df4984c --- /dev/null +++ b/torch/distributed/nn/functional.py @@ -0,0 +1,263 @@ +import torch +from torch.autograd import Function +import torch.distributed as dist + + +def broadcast(tensor, src, group=dist.group.WORLD): + """ + Broadcasts the tensor to the whole group. + + ``tensor`` must have the same number of elements in all processes + participating in the collective. + + Arguments: + tensor (Tensor): Data to be sent if ``src`` is the rank of current + process. + src (int): Source rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Received tensor from the broadcast op. + + """ + return _Broadcast.apply(src, group, tensor) + + +def gather(tensor, dst=0, group=dist.group.WORLD): + """ + Gathers a list of tensors in a single process. + + Arguments: + tensor (Tensor): Input tensor. + dst (int, optional): Destination rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple[Tensor]: List of appropriately-sized tensors with the gathered data. + """ + return _Gather.apply(dst, group, tensor) + + +def scatter(tensors, src=0, group=dist.group.WORLD): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Arguments: + tensors (list[Tensor]): List of tensors to scatter on the source rank. + Receivers must pass ``None`. + src (int, optional): Source rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output tensor from the scatter operation. + + """ + return _Scatter.apply(src, group, *tensors) + + +def reduce(tensor, dst, op=dist.ReduceOp.SUM, group=dist.group.WORLD): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Arguments: + tensor (Tensor): Input of the collective. + dst (int): Destination rank. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce.apply(dst, op, group, tensor) + + +def all_gather(tensor, group=dist.group.WORLD): + """ + Gathers tensors from the whole group in a list. + + Arguments: + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple[Tensor]): Output of the collective. 
+ + """ + return _AllGather.apply(group, tensor) + + +def all_to_all(tensors, group=dist.group.WORLD): + """ + Each process scatters list of input tensors to all processes in a group and + return gathered list of tensors in output list. + + Arguments: + tensors (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple[Tensor]): Output of the collective. + + """ + return _AlltoAll.apply(group, *tensors) + + +def all_reduce(tensor, op=dist.ReduceOp.SUM, group=dist.group.WORLD): + """ + Reduces the tensor data across all machines in such a way that all get + the final result. + + After the call the returned tensor is going to be bitwise + identical in all processes. + + Arguments: + tensor (Tensor): Input of the collective. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective + + """ + return _AllReduce.apply(op, group, tensor) + + +class _Broadcast(Function): + @staticmethod + def forward(ctx, src, group, tensor): + ctx.src = src + ctx.group = group + ctx.rank = dist.get_rank() + # torch.distributed makes all the calls in place + # we allocate new tensors to avoid this + tensor = tensor.clone() + dist.broadcast(tensor, src, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + gx = _Reduce.apply(ctx.src, dist.ReduceOp.SUM, ctx.group, grad_output) + if ctx.src != ctx.rank: + gx.zero_() + return (None, None, gx) + + +class _Gather(Function): + @staticmethod + def forward(ctx, dst, group, tensor): + ctx.dst = dst + ctx.group = group + # Need to create a list of tensors here to do the + # aggregation, get it from the group size + # tensor should be correctly sized for the method + # gathering + tensor_list = [ + torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group)) + ] + if dist.get_rank(group=group) == dst: + dist.gather(tensor, tensor_list, dst, group=group) + else: + dist.gather(tensor, None, dst, group=group) + return tuple(tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),) + + +class _Scatter(Function): + @staticmethod + def forward(ctx, src, group, *tensors): + ctx.src = src + ctx.group = group + assert all(t.size() == tensors[0].size() for t in tensors) + output = torch.zeros_like(tensors[0]) + if dist.get_rank(group=group) == src: + dist.scatter(output, list(tensors), src, group=group) + else: + dist.scatter(output, None, src, group=group) + return output + + @staticmethod + def backward(ctx, grad_output): + return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output) + + +class _Reduce(Function): + @staticmethod + def forward(ctx, src, op, group, tensor): + ctx.src = src + ctx.group = group + tensor = tensor.clone() + dist.reduce(tensor, src, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),) + + +class _AllGather(Function): + @staticmethod + def forward(ctx, group, tensor): + ctx.group = group + out_tensor_list = [ + torch.empty_like(tensor) for i in range(dist.get_world_size(group=group)) + ] + dist.all_gather(out_tensor_list, tensor, group=group) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + gxs = 
_AlltoAll.apply(ctx.group, *grad_outputs) + gx = torch.sum(torch.stack(gxs), dim=0) + return (None, gx) + + +class _AlltoAll(Function): + @staticmethod + def forward(ctx, group, *tensors): + ctx.group = group + out_tensor_list = [ + torch.empty_like(tensors[i]) for i in range(dist.get_world_size(group=group)) + ] + reqs = [None] * dist.get_world_size(group=group) + my_rank = dist.get_rank(group=group) + # Implement it on means of scatter/gather, send/recv async operations have issues + if dist.get_backend(group=group) is dist.Backend.GLOO: + for i in range(dist.get_world_size(group=group)): + to_send = None + if i == my_rank: + to_send = list(tensors) + dist.scatter(out_tensor_list[i], to_send, i, group=group) + else: + dist.all_to_all(out_tensor_list, list(tensors), group=group) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None,) + _AlltoAll.apply(ctx.group, *grad_outputs) + + +class _AllReduce(Function): + @staticmethod + def forward(ctx, op, group, tensor): + ctx.group = group + ctx.op = op + tensor = tensor.clone() + dist.all_reduce(tensor, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),) diff --git a/torch/distributed/optim/functional_adadelta.py b/torch/distributed/optim/functional_adadelta.py new file mode 100644 index 000000000000..f8cee2494e9b --- /dev/null +++ b/torch/distributed/optim/functional_adadelta.py @@ -0,0 +1,82 @@ +from typing import List, Dict, Optional +import torch +import torch.optim.functional as F + +from torch import Tensor + +# Define a TorchScript compatible Functional Adadelta Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdadelta(object): + def __init__( + self, + params: List[Tensor], + lr: float = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + ): + self.defaults = { + "lr": lr, + "rho": rho, + "eps": eps, + "weight_decay": weight_decay, + } + + if len(params) == 0: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group['params'] + grads = [] + square_avgs = [] + acc_deltas = [] + lr = self.defaults['lr'] + rho = self.defaults['rho'] + eps = self.defaults['eps'] + weight_decay = self.defaults['weight_decay'] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. 
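To make the functional-optimizer pattern introduced here concrete: gradients are computed elsewhere and handed to `step` explicitly, instead of being read from `param.grad`. A single-process illustration of how the distributed optimizer internals would drive such a class (it is internal, not a public API)::

    import torch
    from torch.distributed.optim.functional_adadelta import _FunctionalAdadelta

    params = [torch.randn(3, requires_grad=True)]
    opt = _FunctionalAdadelta(params, lr=1.0)

    # Gradients are passed in explicitly, e.g. as produced by distributed autograd.
    grads = [torch.ones_like(params[0])]
    opt.step(grads)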
" + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(params, gradients): + if gradient is not None: + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state['step'] = torch.tensor(0.0) + state['square_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + state['acc_delta'] = torch.zeros_like(param, memory_format=torch.preserve_format) + + state = self.state[param] + square_avgs.append(state['square_avg']) + acc_deltas.append(state['acc_delta']) + + with torch.no_grad(): + F.adadelta(params, + grads, + square_avgs, + acc_deltas, + lr, + rho, + eps, + weight_decay) diff --git a/torch/distributed/optim/functional_adam.py b/torch/distributed/optim/functional_adam.py new file mode 100644 index 000000000000..0df226adde55 --- /dev/null +++ b/torch/distributed/optim/functional_adam.py @@ -0,0 +1,113 @@ +from typing import List, Dict, Optional, Tuple +import torch +import torch.optim.functional as F + +from torch import Tensor + +# Define a TorchScript compatible Functional Adam Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdam(object): + def __init__( + self, + params: List[Tensor], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + amsgrad: bool = False + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + + if len(params) == 0: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group['params'] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: List[int] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. 
" + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(self.param_group['params'], gradients): + if gradient is not None: + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state['step'] = torch.tensor(0.0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + + state = self.state[param] + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if self.amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step'].item()) + + with torch.no_grad(): + F.adam(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + self.amsgrad, + self.defaults['beta1'], + self.defaults['beta2'], + self.defaults['lr'], + self.defaults['weight_decay'], + self.defaults['eps']) diff --git a/torch/distributed/optim/functional_adamw.py b/torch/distributed/optim/functional_adamw.py new file mode 100644 index 000000000000..fe736167cfc4 --- /dev/null +++ b/torch/distributed/optim/functional_adamw.py @@ -0,0 +1,113 @@ +from typing import List, Dict, Optional, Tuple +import torch +import torch.optim.functional as F + +from torch import Tensor + +# Define a TorchScript compatible Functional AdamW Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdamW(object): + def __init__( + self, + params: List[Tensor], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + + if len(params) == 0: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. 
+ self.param_group = {"params": params} + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group['params'] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: List[int] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(self.param_group['params'], gradients): + if gradient is not None: + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state['step'] = torch.tensor(0.0) + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(param, memory_format=torch.preserve_format) + + state = self.state[param] + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + + if self.amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step'].item()) + + with torch.no_grad(): + F.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + self.amsgrad, + self.defaults['beta1'], + self.defaults['beta2'], + self.defaults['lr'], + self.defaults['weight_decay'], + self.defaults['eps']) diff --git a/torch/distributed/optim/functional_rmsprop.py b/torch/distributed/optim/functional_rmsprop.py new file mode 100644 index 000000000000..600f165e3d1e --- /dev/null +++ b/torch/distributed/optim/functional_rmsprop.py @@ -0,0 +1,99 @@ +from typing import List, Dict, Optional +import torch +import torch.optim.functional as F + +from torch import Tensor + +# Define a TorchScript compatible Functional RMSprop Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalRMSprop(object): + def __init__( + self, + params: List[Tensor], + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0.0, + momentum: float = 0.0, + centered: bool = False + ): + self.defaults = { + "lr": lr, + "alpha": alpha, + "eps": eps, + "weight_decay": weight_decay, + "momentum": momentum, + } + self.centered = centered + + if len(params) == 0: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. 
+ self.param_group = {"params": params} + + self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group['params'] + grads = [] + square_avgs = [] + grad_avgs = [] + momentum_buffer_list = [] + lr = self.defaults['lr'] + alpha = self.defaults['alpha'] + eps = self.defaults['eps'] + momentum = self.defaults['momentum'] + weight_decay = self.defaults['weight_decay'] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(params, gradients): + if gradient is not None: + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state['step'] = torch.tensor(0.0) + state['square_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + if momentum > 0: + state['momentum_buffer'] = torch.zeros_like(param, memory_format=torch.preserve_format) + if self.centered: + state['grad_avg'] = torch.zeros_like(param, memory_format=torch.preserve_format) + + state = self.state[param] + square_avgs.append(state['square_avg']) + if momentum > 0: + momentum_buffer_list.append(state['momentum_buffer']) + if self.centered: + grad_avgs.append(state['grad_avg']) + + state['step'] += 1 + + with torch.no_grad(): + F.rmsprop(params, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + lr, + alpha, + eps, + weight_decay, + momentum, + self.centered) diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py new file mode 100644 index 000000000000..39045f29a2c6 --- /dev/null +++ b/torch/distributed/optim/functional_sgd.py @@ -0,0 +1,87 @@ +from typing import List, Optional, Dict +import torch +import torch.optim.functional as F + +from torch import Tensor + +# Define a TorchScript compatible Functional SGD Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalSGD(object): + def __init__( + self, + params: List[Tensor], + lr: float = 1e-2, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False + ): + self.defaults = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "weight_decay": weight_decay, + } + self.nesterov = nesterov + self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) + + if len(params) == 0: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. 
+ self.param_group = {"params": params} + + def step(self, gradients: List[Optional[Tensor]]): + params = self.param_group['params'] + grads = [] + momentum_buffer_list: List[Optional[Tensor]] = [] + lr = self.defaults['lr'] + weight_decay = self.defaults['weight_decay'] + momentum = self.defaults['momentum'] + dampening = self.defaults['dampening'] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(params, gradients): + if gradient is not None: + grads.append(gradient) + + if param not in self.state: + self.state[param] = {} + + state = self.state[param] + if 'momentum_buffer' not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state['momentum_buffer']) + + with torch.no_grad(): + F.sgd(params, + grads, + momentum_buffer_list, + weight_decay, + momentum, + lr, + dampening, + self.nesterov) + + # update momentum_buffers in state + for i, p in enumerate(params): + state = self.state[p] + momentum_buffer = momentum_buffer_list[i] + if momentum_buffer is not None: + state['momentum_buffer'] = momentum_buffer diff --git a/torch/distributed/optim/optimizer.py b/torch/distributed/optim/optimizer.py index c7f8e3236776..a68e9fc9c113 100644 --- a/torch/distributed/optim/optimizer.py +++ b/torch/distributed/optim/optimizer.py @@ -1,4 +1,5 @@ from typing import List, Optional +import logging import torch.distributed.rpc as rpc import torch.optim as optim @@ -7,12 +8,18 @@ from torch import Tensor from torch.distributed.rpc import RRef from .functional_adagrad import _FunctionalAdagrad +from .functional_adam import _FunctionalAdam +from .functional_adamw import _FunctionalAdamW +from .functional_sgd import _FunctionalSGD +from .functional_adadelta import _FunctionalAdadelta +from .functional_rmsprop import _FunctionalRMSprop import torch.distributed.autograd as dist_autograd from collections import defaultdict from threading import Lock +logger = logging.getLogger(__name__) # XXX: we define a _ScriptModuleOptimizer here to explicitly # compile the FunctionalOptimizer class into TorchScript @@ -181,6 +188,11 @@ class DistributedOptimizer: # functional optimizer to user and still provide the same API. functional_optim_map = { optim.Adagrad: _FunctionalAdagrad, + optim.Adam: _FunctionalAdam, + optim.AdamW: _FunctionalAdamW, + optim.SGD: _FunctionalSGD, + optim.Adadelta: _FunctionalAdadelta, + optim.RMSprop: _FunctionalRMSprop, } def __init__(self, optimizer_class, params_rref, *args, **kwargs): @@ -188,12 +200,22 @@ def __init__(self, optimizer_class, params_rref, *args, **kwargs): for param in params_rref: per_worker_params_rref[param.owner()].append(param) - optim_ctor = DistributedOptimizer.functional_optim_map.get(optimizer_class, optimizer_class) + if optimizer_class in DistributedOptimizer.functional_optim_map and jit._state._enabled: + optim_ctor = DistributedOptimizer.functional_optim_map.get(optimizer_class) + else: + optim_ctor = optimizer_class self.is_functional_optim = (optim_ctor != optimizer_class) if self.is_functional_optim: optimizer_new_func = _new_script_local_optimizer else: + logger.warn( + f"Creating the optimizer {optimizer_class} without TorchScript support, " + "this might result in slow computation time in multithreading environment" + "(i.e. Distributed Model Parallel training on CPU) due to the Python's " + "Global Interpreter Lock (GIL). 
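With the expanded functional_optim_map, passing one of the mapped optimizer classes to DistributedOptimizer now transparently constructs its TorchScript functional counterpart on each worker. A sketch of the call site, assuming an RPC-enabled trainer where `remote_param_rrefs` is a list of RRefs to remote parameters::

    import torch.distributed.autograd as dist_autograd
    import torch.optim as optim
    from torch.distributed.optim import DistributedOptimizer

    dist_optim = DistributedOptimizer(
        optim.Adam,            # mapped to _FunctionalAdam when TorchScript is enabled
        remote_param_rrefs,
        lr=1e-3,
    )

    with dist_autograd.context() as context_id:
        # ... run the forward pass and dist_autograd.backward(context_id, [loss]) ...
        dist_optim.step(context_id)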
Please file an issue if you need this " + "optimizer in TorchScript. " + ) optimizer_new_func = _new_local_optimizer remote_optim_futs = [] diff --git a/torch/distributed/pipeline/sync/skip/skippable.py b/torch/distributed/pipeline/sync/skip/skippable.py index e0b0dae584a2..6f6bcd7b5614 100644 --- a/torch/distributed/pipeline/sync/skip/skippable.py +++ b/torch/distributed/pipeline/sync/skip/skippable.py @@ -242,7 +242,7 @@ def skippable( """The decorator to define a :class:`nn.Module ` with skip connections. Decorated modules are called "skippable". This functionality works perfectly fine even when the module is not wrapped by - :class:`~torchpipe.Pipe`. + :class:`~torch.distributed.pipeline.sync.Pipe`. Each skip tensor is managed by its name. Before manipulating skip tensors, a skippable module must statically declare the names for skip tensors by @@ -282,23 +282,10 @@ def forward(self, input): return input + carol Every skip tensor must be associated with exactly one pair of `stash` and - `pop`. :class:`~torchpipe.Pipe` checks this restriction automatically - when wrapping a module. You can also check the restriction by - :func:`~torchpipe.skip.verify_skippables` without - :class:`~torchpipe.Pipe`. - - .. note:: - - :func:`@skippable ` changes the type of the wrapped class. - But currently (mypy v0.740), mypy could not understand class decorators - yet (`#3135 `_). - - There are two workarounds: - - 1. Naively ignore type errors by ``# type: ignore``. - 2. Use ``skippable()()`` as a function instead of a decorator. - - .. seealso:: :ref:`Long Skip Connections` + `pop`. :class:`~torch.distributed.pipeline.sync.Pipe` checks this + restriction automatically when wrapping a module. You can also check the + restriction by :func:`verify_skippables` + without :class:`~torch.distributed.pipeline.sync.Pipe`. 
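As the updated docstring notes, skip-name pairing can be checked without wrapping the module in Pipe. A short sketch, with the import path assumed from this package layout::

    import torch.nn as nn
    from torch.distributed.pipeline.sync.skip import pop, skippable, stash, verify_skippables

    @skippable(stash=['skip'])
    class Stash(nn.Module):
        def forward(self, input):
            yield stash('skip', input)
            return input

    @skippable(pop=['skip'])
    class Pop(nn.Module):
        def forward(self, input):
            skip = yield pop('skip')
            return input + skip

    # Verifies that every declared stash has a matching pop; raises on mismatch.
    verify_skippables(nn.Sequential(Stash(), Pop()))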
""" stashable_names = frozenset(stash) diff --git a/torch/distributions/binomial.py b/torch/distributions/binomial.py index caafcfacb166..42e2741fd93a 100644 --- a/torch/distributions/binomial.py +++ b/torch/distributions/binomial.py @@ -67,7 +67,7 @@ def expand(self, batch_shape, _instance=None): def _new(self, *args, **kwargs): return self._param.new(*args, **kwargs) - @constraints.dependent_property(is_discrete=True) + @constraints.dependent_property(is_discrete=True, event_dim=0) def support(self): return constraints.integer_interval(0, self.total_count) diff --git a/torch/distributions/categorical.py b/torch/distributions/categorical.py index ec22a7e4f802..d5a5bca9de80 100644 --- a/torch/distributions/categorical.py +++ b/torch/distributions/categorical.py @@ -76,7 +76,7 @@ def expand(self, batch_shape, _instance=None): def _new(self, *args, **kwargs): return self._param.new(*args, **kwargs) - @constraints.dependent_property(is_discrete=True) + @constraints.dependent_property(is_discrete=True, event_dim=0) def support(self): return constraints.integer_interval(0, self._num_events - 1) diff --git a/torch/distributions/constraint_registry.py b/torch/distributions/constraint_registry.py index 63fd4b8bf9ce..cbe987e72c79 100644 --- a/torch/distributions/constraint_registry.py +++ b/torch/distributions/constraint_registry.py @@ -160,12 +160,16 @@ def _transform_to_real(constraint): @biject_to.register(constraints.independent) def _biject_to_independent(constraint): - return biject_to(constraint.base_constraint) + base_transform = biject_to(constraint.base_constraint) + return transforms.IndependentTransform( + base_transform, constraint.reinterpreted_batch_ndims) @transform_to.register(constraints.independent) def _transform_to_independent(constraint): - return transform_to(constraint.base_constraint) + base_transform = transform_to(constraint.base_constraint) + return transforms.IndependentTransform( + base_transform, constraint.reinterpreted_batch_ndims) @biject_to.register(constraints.positive) diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 9a8c9bcebdfa..fb05d17ac271 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -64,13 +64,20 @@ class Constraint(object): A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized. + + Attributes: + is_discrete (bool): Whether constrained space is discrete. + Defaults to False. + event_dim (int): Number of rightmost dimensions that together define + an event. The :meth:`check` method will remove this many dimensions + when computing validity. """ is_discrete = False # Default to continuous. event_dim = 0 # Default to univariate. def check(self, value): """ - Returns a byte tensor of `sample_shape + batch_shape` indicating + Returns a byte tensor of ``sample_shape + batch_shape`` indicating whether each event in value satisfies this constraint. """ raise NotImplementedError @@ -83,22 +90,42 @@ class _Dependent(Constraint): """ Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints. - """ - def __init__(self, *, is_discrete=False, event_dim=0): - self.is_discrete = is_discrete - self.event_dim = event_dim + + Args: + is_discrete (bool): Optional value of ``.is_discrete`` in case this + can be computed statically. If not provided, access to the + ``.is_discrete`` attribute will raise a NotImplementedError. 
+ event_dim (int): Optional value of ``.event_dim`` in case this + can be computed statically. If not provided, access to the + ``.event_dim`` attribute will raise a NotImplementedError. + """ + def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): + self._is_discrete = is_discrete + self._event_dim = event_dim super().__init__() - def __call__(self, *, is_discrete=None, event_dim=None): + @property + def is_discrete(self): + if self._is_discrete is NotImplemented: + raise NotImplementedError(".is_discrete cannot be determined statically") + return self._is_discrete + + @property + def event_dim(self): + if self._event_dim is NotImplemented: + raise NotImplementedError(".event_dim cannot be determined statically") + return self._event_dim + + def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): """ Support for syntax to customize static attributes:: constraints.dependent(is_discrete=True, event_dim=1) """ - if is_discrete is None: - is_discrete = self.is_discrete - if event_dim is None: - event_dim = self.event_dim + if is_discrete is NotImplemented: + is_discrete = self._is_discrete + if event_dim is NotImplemented: + event_dim = self._event_dim return _Dependent(is_discrete=is_discrete, event_dim=event_dim) def check(self, x): @@ -120,14 +147,23 @@ class Uniform(Distribution): def __init__(self, low, high): self.low = low self.high = high - @constraints.dependent_property + @constraints.dependent_property(is_discrete=False, event_dim=0) def support(self): return constraints.interval(self.low, self.high) + + Args: + fn (callable): The function to be decorated. + is_discrete (bool): Optional value of ``.is_discrete`` in case this + can be computed statically. If not provided, access to the + ``.is_discrete`` attribute will raise a NotImplementedError. + event_dim (int): Optional value of ``.event_dim`` in case this + can be computed statically. If not provided, access to the + ``.event_dim`` attribute will raise a NotImplementedError. """ - def __init__(self, fn=None, *, is_discrete=False, event_dim=0): - self.is_discrete = is_discrete - self.event_dim = event_dim + def __init__(self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented): super().__init__(fn) + self._is_discrete = is_discrete + self._event_dim = event_dim def __call__(self, fn): """ @@ -137,7 +173,7 @@ def __call__(self, fn): def support(self): ... 
""" - return _DependentProperty(fn, is_discrete=self.is_discrete, event_dim=self.event_dim) + return _DependentProperty(fn, is_discrete=self._is_discrete, event_dim=self._event_dim) class _IndependentConstraint(Constraint): @@ -171,6 +207,10 @@ def check(self, value): result = result.all(-1) return result + def __repr__(self): + return "{}({}, {})".format(self.__class__.__name__[1:], repr(self.base_constraint), + self.reinterpreted_batch_ndims) + class _Boolean(Constraint): """ diff --git a/torch/distributions/independent.py b/torch/distributions/independent.py index 0776ca6f67a7..61c7b7a03697 100644 --- a/torch/distributions/independent.py +++ b/torch/distributions/independent.py @@ -68,8 +68,10 @@ def has_enumerate_support(self): @constraints.dependent_property def support(self): - return constraints.independent(self.base_dist.support, - self.reinterpreted_batch_ndims) + result = self.base_dist.support + if self.reinterpreted_batch_ndims: + result = constraints.independent(result, self.reinterpreted_batch_ndims) + return result @property def mean(self): diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index ba7ba73d6063..b1da942b8e61 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -437,10 +437,7 @@ def _kl_transformed_transformed(p, q): raise NotImplementedError if p.event_shape != q.event_shape: raise NotImplementedError - # extra_event_dim = len(p.event_shape) - len(p.base_dist.event_shape) - extra_event_dim = len(p.event_shape) - base_kl_divergence = kl_divergence(p.base_dist, q.base_dist) - return _sum_rightmost(base_kl_divergence, extra_event_dim) + return kl_divergence(p.base_dist, q.base_dist) @register_kl(Uniform, Uniform) diff --git a/torch/distributions/logistic_normal.py b/torch/distributions/logistic_normal.py index 4fb259e20f3f..9cddf3c05290 100644 --- a/torch/distributions/logistic_normal.py +++ b/torch/distributions/logistic_normal.py @@ -1,4 +1,3 @@ -import torch from torch.distributions import constraints from torch.distributions.normal import Normal from torch.distributions.transformed_distribution import TransformedDistribution @@ -33,11 +32,11 @@ class LogisticNormal(TransformedDistribution): def __init__(self, loc, scale, validate_args=None): base_dist = Normal(loc, scale) + if not base_dist.batch_shape: + base_dist = base_dist.expand([1]) super(LogisticNormal, self).__init__(base_dist, StickBreakingTransform(), validate_args=validate_args) - # Adjust event shape since StickBreakingTransform adds 1 dimension - self._event_shape = torch.Size([s + 1 for s in self._event_shape]) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(LogisticNormal, _instance) @@ -45,8 +44,8 @@ def expand(self, batch_shape, _instance=None): @property def loc(self): - return self.base_dist.loc + return self.base_dist.base_dist.loc @property def scale(self): - return self.base_dist.scale + return self.base_dist.base_dist.scale diff --git a/torch/distributions/pareto.py b/torch/distributions/pareto.py index 4ac0d9b1f706..cd6571fa04bd 100644 --- a/torch/distributions/pareto.py +++ b/torch/distributions/pareto.py @@ -45,7 +45,7 @@ def variance(self): a = self.alpha.clamp(min=2) return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2)) - @constraints.dependent_property + @constraints.dependent_property(is_discrete=False, event_dim=0) def support(self): return constraints.greater_than(self.scale) diff --git a/torch/distributions/relaxed_categorical.py b/torch/distributions/relaxed_categorical.py index 
8e9a810cbe2b..5b3bb24b626e 100644 --- a/torch/distributions/relaxed_categorical.py +++ b/torch/distributions/relaxed_categorical.py @@ -30,8 +30,8 @@ class ExpRelaxedCategorical(Distribution): (Jang et al, 2017) """ arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real} - support = constraints.real + 'logits': constraints.real_vector} + support = constraints.real_vector # The true support is actually a submanifold of this. has_rsample = True def __init__(self, temperature, probs=None, logits=None, validate_args=None): @@ -104,7 +104,7 @@ class RelaxedOneHotCategorical(TransformedDistribution): logits (Tensor): the log probability of each event. """ arg_constraints = {'probs': constraints.simplex, - 'logits': constraints.real} + 'logits': constraints.real_vector} support = constraints.simplex has_rsample = True diff --git a/torch/distributions/transformed_distribution.py b/torch/distributions/transformed_distribution.py index d6bb4de75c6b..29f38dffa85a 100644 --- a/torch/distributions/transformed_distribution.py +++ b/torch/distributions/transformed_distribution.py @@ -1,7 +1,8 @@ import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution -from torch.distributions.transforms import Transform +from torch.distributions.independent import Independent +from torch.distributions.transforms import ComposeTransform, Transform from torch.distributions.utils import _sum_rightmost from typing import Dict @@ -42,7 +43,6 @@ class TransformedDistribution(Distribution): arg_constraints: Dict[str, constraints.Constraint] = {} def __init__(self, base_distribution, transforms, validate_args=None): - self.base_dist = base_distribution if isinstance(transforms, Transform): self.transforms = [transforms, ] elif isinstance(transforms, list): @@ -51,25 +51,54 @@ def __init__(self, base_distribution, transforms, validate_args=None): self.transforms = transforms else: raise ValueError("transforms must be a Transform or list, but was {}".format(transforms)) - shape = self.base_dist.batch_shape + self.base_dist.event_shape - event_dim = max([len(self.base_dist.event_shape)] + [t.event_dim for t in self.transforms]) - batch_shape = shape[:len(shape) - event_dim] - event_shape = shape[len(shape) - event_dim:] + + # Reshape base_distribution according to transforms. + base_shape = base_distribution.batch_shape + base_distribution.event_shape + base_event_dim = len(base_distribution.event_shape) + transform = ComposeTransform(self.transforms) + domain_event_dim = transform.domain.event_dim + if len(base_shape) < domain_event_dim: + raise ValueError("base_distribution needs to have shape with size at least {}, but got {}." + .format(domain_event_dim, base_shape)) + shape = transform.forward_shape(base_shape) + expanded_base_shape = transform.inverse_shape(shape) + if base_shape != expanded_base_shape: + base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim] + base_distribution = base_distribution.expand(base_batch_shape) + reinterpreted_batch_ndims = domain_event_dim - base_event_dim + if reinterpreted_batch_ndims > 0: + base_distribution = Independent(base_distribution, reinterpreted_batch_ndims) + self.base_dist = base_distribution + + # Compute shapes. 
+ event_dim = transform.codomain.event_dim + max(base_event_dim - domain_event_dim, 0) + assert len(shape) >= event_dim + cut = len(shape) - event_dim + batch_shape = shape[:cut] + event_shape = shape[cut:] super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(TransformedDistribution, _instance) batch_shape = torch.Size(batch_shape) - base_dist_batch_shape = batch_shape + self.base_dist.batch_shape[len(self.batch_shape):] - new.base_dist = self.base_dist.expand(base_dist_batch_shape) + shape = batch_shape + self.event_shape + for t in reversed(self.transforms): + shape = t.inverse_shape(shape) + base_batch_shape = shape[:len(shape) - len(self.base_dist.event_shape)] + new.base_dist = self.base_dist.expand(base_batch_shape) new.transforms = self.transforms super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False) new._validate_args = self._validate_args return new - @constraints.dependent_property + @constraints.dependent_property(is_discrete=False) def support(self): - return self.transforms[-1].codomain if self.transforms else self.base_dist.support + if not self.transforms: + return self.base_dist.support + support = self.transforms[-1].codomain + if len(self.event_shape) > support.event_dim: + support = constraints.independent(support, len(self.event_shape) - support.event_dim) + return support @property def has_rsample(self): @@ -110,8 +139,9 @@ def log_prob(self, value): y = value for transform in reversed(self.transforms): x = transform.inv(y) + event_dim += transform.domain.event_dim - transform.codomain.event_dim log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y), - event_dim - transform.event_dim) + event_dim - transform.domain.event_dim) y = x log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y), diff --git a/torch/distributions/transforms.py b/torch/distributions/transforms.py index 4181db799b28..8a1c989f5ac7 100644 --- a/torch/distributions/transforms.py +++ b/torch/distributions/transforms.py @@ -1,6 +1,9 @@ +import functools import math import numbers +import operator import weakref +from typing import List import torch import torch.nn.functional as F @@ -10,7 +13,6 @@ vec_to_tril_matrix) from torch.nn.functional import pad from torch.nn.functional import softplus -from typing import List __all__ = [ 'AbsTransform', @@ -19,8 +21,10 @@ 'ComposeTransform', 'CorrCholeskyTransform', 'ExpTransform', + 'IndependentTransform', 'LowerCholeskyTransform', 'PowerTransform', + 'ReshapeTransform', 'SigmoidTransform', 'TanhTransform', 'SoftmaxTransform', @@ -74,14 +78,10 @@ class Transform(object): sign (int or Tensor): For bijective univariate transforms, this should be +1 or -1 depending on whether transform is monotone increasing or decreasing. - event_dim (int): Number of dimensions that are correlated together in - the transform ``event_shape``. This should be 0 for pointwise - transforms, 1 for transforms that act jointly on vectors, 2 for - transforms that act jointly on matrices, etc. 
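A sketch of what the reworked constructor now does: when a transform's domain has more event dimensions than the base distribution, the base is reinterpreted with Independent and the event shape is inferred through forward_shape::

    import torch
    from torch.distributions import Normal, TransformedDistribution
    from torch.distributions.transforms import StickBreakingTransform

    base = Normal(torch.zeros(5, 3), torch.ones(5, 3))     # batch_shape (5, 3)
    d = TransformedDistribution(base, StickBreakingTransform())

    print(type(d.base_dist).__name__)   # Independent -- reinterpreted to match domain.event_dim
    print(d.batch_shape)                # torch.Size([5])
    print(d.event_shape)                # torch.Size([4]) -- via transform.forward_shape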
""" bijective = False + domain: constraints.Constraint codomain: constraints.Constraint - event_dim = 0 def __init__(self, cache_size=0): self._cache_size = cache_size @@ -95,12 +95,10 @@ def __init__(self, cache_size=0): super(Transform, self).__init__() @property - def input_event_dim(self): - return self.event_dim - - @property - def output_event_dim(self): - return self.event_dim + def event_dim(self): + if self.domain.event_dim == self.codomain.event_dim: + return self.domain.event_dim + raise ValueError("Please use either .domain.event_dim or .codomain.event_dim") @property def inv(self): @@ -185,36 +183,40 @@ def log_abs_det_jacobian(self, x, y): def __repr__(self): return self.__class__.__name__ + '()' + def forward_shape(self, shape): + """ + Infers the shape of the forward computation, given the input shape. + Defaults to preserving shape. + """ + return shape + + def inverse_shape(self, shape): + """ + Infers the shapes of the inverse computation, given the output shape. + Defaults to preserving shape. + """ + return shape + class _InverseTransform(Transform): """ Inverts a single :class:`Transform`. This class is private; please instead use the ``Transform.inv`` property. """ - def __init__(self, transform): + def __init__(self, transform: Transform): super(_InverseTransform, self).__init__(cache_size=transform._cache_size) - self._inv = transform + self._inv: Transform = transform - @constraints.dependent_property + @constraints.dependent_property(is_discrete=False) def domain(self): assert self._inv is not None return self._inv.codomain - @constraints.dependent_property + @constraints.dependent_property(is_discrete=False) def codomain(self): assert self._inv is not None return self._inv.domain - @property - def input_event_dim(self): - assert self._inv is not None - return self._inv.output_event_dim - - @property - def output_event_dim(self): - assert self._inv is not None - return self._inv.input_event_dim - @property def bijective(self): assert self._inv is not None @@ -225,11 +227,6 @@ def sign(self): assert self._inv is not None return self._inv.sign - @property - def event_dim(self): - assert self._inv is not None - return self._inv.event_dim - @property def inv(self): return self._inv @@ -244,6 +241,9 @@ def __eq__(self, other): assert self._inv is not None return self._inv == other._inv + def __repr__(self): + return f"{self.__class__.__name__}({repr(self._inv)})" + def __call__(self, x): assert self._inv is not None return self._inv._inv_call(x) @@ -252,6 +252,12 @@ def log_abs_det_jacobian(self, x, y): assert self._inv is not None return -self._inv.log_abs_det_jacobian(y, x) + def forward_shape(self, shape): + return self._inv.inverse_shape(shape) + + def inverse_shape(self, shape): + return self._inv.forward_shape(shape) + class ComposeTransform(Transform): """ @@ -263,7 +269,7 @@ class ComposeTransform(Transform): cache_size (int): Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported. 
""" - def __init__(self, parts, cache_size=0): + def __init__(self, parts: List[Transform], cache_size=0): if cache_size: parts = [part.with_cache(cache_size) for part in parts] super(ComposeTransform, self).__init__(cache_size=cache_size) @@ -274,17 +280,35 @@ def __eq__(self, other): return False return self.parts == other.parts - @constraints.dependent_property + @constraints.dependent_property(is_discrete=False) def domain(self): if not self.parts: return constraints.real - return self.parts[0].domain - - @constraints.dependent_property + domain = self.parts[0].domain + # Adjust event_dim to be maximum among all parts. + event_dim = self.parts[-1].codomain.event_dim + for part in reversed(self.parts): + event_dim += part.domain.event_dim - part.codomain.event_dim + event_dim = max(event_dim, part.domain.event_dim) + assert event_dim >= domain.event_dim + if event_dim > domain.event_dim: + domain = constraints.independent(domain, event_dim - domain.event_dim) + return domain + + @constraints.dependent_property(is_discrete=False) def codomain(self): if not self.parts: return constraints.real - return self.parts[-1].codomain + codomain = self.parts[-1].codomain + # Adjust event_dim to be maximum among all parts. + event_dim = self.parts[0].domain.event_dim + for part in self.parts: + event_dim += part.codomain.event_dim - part.domain.event_dim + event_dim = max(event_dim, part.codomain.event_dim) + assert event_dim >= codomain.event_dim + if event_dim > codomain.event_dim: + codomain = constraints.independent(codomain, event_dim - codomain.event_dim) + return codomain @lazy_property def bijective(self): @@ -297,10 +321,6 @@ def sign(self): sign = sign * p.sign return sign - @lazy_property - def event_dim(self): - return max(p.event_dim for p in self.parts) if self.parts else 0 - @property def inv(self): inv = None @@ -325,16 +345,30 @@ def __call__(self, x): def log_abs_det_jacobian(self, x, y): if not self.parts: return torch.zeros_like(x) - result = 0 + + # Compute intermediates. This will be free if parts[:-1] are all cached. + xs = [x] for part in self.parts[:-1]: - y_tmp = part(x) - result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y_tmp), - self.event_dim - part.event_dim) - x = y_tmp - part = self.parts[-1] - result = result + _sum_rightmost(part.log_abs_det_jacobian(x, y), - self.event_dim - part.event_dim) - return result + xs.append(part(xs[-1])) + xs.append(y) + + terms = [] + event_dim = self.domain.event_dim + for part, x, y in zip(self.parts, xs[:-1], xs[1:]): + terms.append(_sum_rightmost(part.log_abs_det_jacobian(x, y), + event_dim - part.domain.event_dim)) + event_dim += part.codomain.event_dim - part.domain.event_dim + return functools.reduce(operator.add, terms) + + def forward_shape(self, shape): + for part in self.parts: + shape = part.forward_shape(shape) + return shape + + def inverse_shape(self, shape): + for part in reversed(self.parts): + shape = part.inverse_shape(shape) + return shape def __repr__(self): fmt_string = self.__class__.__name__ + '(\n ' @@ -346,6 +380,136 @@ def __repr__(self): identity_transform = ComposeTransform([]) +class IndependentTransform(Transform): + """ + Wrapper around another transform to treat + ``reinterpreted_batch_ndims``-many extra of the right most dimensions as + dependent. This has no effect on the forward or backward transforms, but + does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions + in :meth:`log_abs_det_jacobian`. + + Args: + base_transform (:class:`Transform`): A base transform. 
+ reinterpreted_batch_ndims (int): The number of extra rightmost + dimensions to treat as dependent. + """ + def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0): + super().__init__(cache_size=cache_size) + self.base_transform = base_transform.with_cache(cache_size) + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return IndependentTransform(self.base_transform, + self.reinterpreted_batch_ndims, + cache_size=cache_size) + + @constraints.dependent_property(is_discrete=False) + def domain(self): + return constraints.independent(self.base_transform.domain, + self.reinterpreted_batch_ndims) + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + return constraints.independent(self.base_transform.codomain, + self.reinterpreted_batch_ndims) + + @property + def bijective(self): + return self.base_transform.bijective + + @property + def sign(self): + return self.base_transform.sign + + def _call(self, x): + if x.dim() < self.domain.event_dim: + raise ValueError("Too few dimensions on input") + return self.base_transform(x) + + def _inverse(self, y): + if y.dim() < self.codomain.event_dim: + raise ValueError("Too few dimensions on input") + return self.base_transform.inv(y) + + def log_abs_det_jacobian(self, x, y): + result = self.base_transform.log_abs_det_jacobian(x, y) + result = _sum_rightmost(result, self.reinterpreted_batch_ndims) + return result + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})" + + def forward_shape(self, shape): + return self.base_transform.forward_shape(shape) + + def inverse_shape(self, shape): + return self.base_transform.inverse_shape(shape) + + +class ReshapeTransform(Transform): + """ + Unit Jacobian transform to reshape the rightmost part of a tensor. + + Note that ``in_shape`` and ``out_shape`` must have the same number of + elements, just as for :meth:`torch.Tensor.reshape`. + + Arguments: + in_shape (torch.Size): The input event shape. + out_shape (torch.Size): The output event shape. 
+ """ + bijective = True + + def __init__(self, in_shape, out_shape, cache_size=0): + self.in_shape = torch.Size(in_shape) + self.out_shape = torch.Size(out_shape) + if self.in_shape.numel() != self.out_shape.numel(): + raise ValueError("in_shape, out_shape have different numbers of elements") + super().__init__(cache_size=cache_size) + + @constraints.dependent_property + def domain(self): + return constraints.independent(constraints.real, len(self.in_shape)) + + @constraints.dependent_property + def codomain(self): + return constraints.independent(constraints.real, len(self.out_shape)) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size) + + def _call(self, x): + batch_shape = x.shape[:x.dim() - len(self.in_shape)] + return x.reshape(batch_shape + self.out_shape) + + def _inverse(self, y): + batch_shape = y.shape[:y.dim() - len(self.out_shape)] + return y.reshape(batch_shape + self.in_shape) + + def log_abs_det_jacobian(self, x, y): + batch_shape = x.shape[:x.dim() - len(self.in_shape)] + return x.new_zeros(batch_shape) + + def forward_shape(self, shape): + if len(shape) < len(self.in_shape): + raise ValueError("Too few dimensions on input") + cut = len(shape) - len(self.in_shape) + if shape[cut:] != self.in_shape: + raise ValueError("Shape mismatch: expected {} but got {}".format(shape[cut:], self.in_shape)) + return shape[:cut] + self.out_shape + + def inverse_shape(self, shape): + if len(shape) < len(self.out_shape): + raise ValueError("Too few dimensions on input") + cut = len(shape) - len(self.out_shape) + if shape[cut:] != self.out_shape: + raise ValueError("Shape mismatch: expected {} but got {}".format(shape[cut:], self.out_shape)) + return shape[:cut] + self.in_shape + + class ExpTransform(Transform): r""" Transform via the mapping :math:`y = \exp(x)`. @@ -400,6 +564,12 @@ def _inverse(self, y): def log_abs_det_jacobian(self, x, y): return (self.exponent * y / x).abs().log() + def forward_shape(self, shape): + return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + + def inverse_shape(self, shape): + return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + def _clipped_sigmoid(x): finfo = torch.finfo(x.dtype) @@ -494,15 +664,29 @@ class AffineTransform(Transform): for univariate random variables, 1 for distributions over vectors, 2 for distributions over matrices, etc. 
""" - domain = constraints.real - codomain = constraints.real bijective = True def __init__(self, loc, scale, event_dim=0, cache_size=0): super(AffineTransform, self).__init__(cache_size=cache_size) self.loc = loc self.scale = scale - self.event_dim = event_dim + self._event_dim = event_dim + + @property + def event_dim(self): + return self._event_dim + + @constraints.dependent_property(is_discrete=False) + def domain(self): + if self.event_dim == 0: + return constraints.real + return constraints.independent(constraints.real, self.event_dim) + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + if self.event_dim == 0: + return constraints.real + return constraints.independent(constraints.real, self.event_dim) def with_cache(self, cache_size=1): if self._cache_size == cache_size: @@ -554,6 +738,16 @@ def log_abs_det_jacobian(self, x, y): shape = shape[:-self.event_dim] return result.expand(shape) + def forward_shape(self, shape): + return torch.broadcast_shapes(shape, + getattr(self.loc, "shape", ()), + getattr(self.scale, "shape", ())) + + def inverse_shape(self, shape): + return torch.broadcast_shapes(shape, + getattr(self.loc, "shape", ()), + getattr(self.scale, "shape", ())) + class CorrCholeskyTransform(Transform): r""" @@ -573,14 +767,8 @@ class :class:`StickBreakingTransform` to transform :math:`X_i` into a """ domain = constraints.real_vector codomain = constraints.corr_cholesky - input_event_dim = 1 - output_event_dim = 2 bijective = True - @property - def event_dim(self): - raise ValueError("Please use `.input_event_dim` or `.output_event_dim` instead.") - def _call(self, x): x = torch.tanh(x) eps = torch.finfo(x.dtype).eps @@ -622,6 +810,26 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.)).sum(dim=-1) return stick_breaking_logdet + tanh_logdet + def forward_shape(self, shape): + # Reshape from (..., N) to (..., D, D). + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + N = shape[-1] + D = round((0.25 + 2 * N) ** 0.5 + 0.5) + if D * (D - 1) // 2 != N: + raise ValueError("Input is not a flattend lower-diagonal number") + return shape[:-1] + (D, D) + + def inverse_shape(self, shape): + # Reshape from (..., D, D) to (..., N). + if len(shape) < 2: + raise ValueError("Too few dimensions on input") + if shape[-2] != shape[-1]: + raise ValueError("Input is not square") + D = shape[-1] + N = D * (D - 1) // 2 + return shape[:-2] + (N,) + class SoftmaxTransform(Transform): r""" @@ -632,9 +840,8 @@ class SoftmaxTransform(Transform): coordinate-wise (except for the final normalization), and thus is appropriate for coordinate-wise optimization algorithms. """ - domain = constraints.real + domain = constraints.real_vector codomain = constraints.simplex - event_dim = 1 def __eq__(self, other): return isinstance(other, SoftmaxTransform) @@ -648,6 +855,16 @@ def _inverse(self, y): probs = y return probs.log() + def forward_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape + + def inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape + class StickBreakingTransform(Transform): """ @@ -662,10 +879,9 @@ class StickBreakingTransform(Transform): This is bijective and appropriate for use in HMC; however it mixes coordinates together and is less appropriate for optimization. 
""" - domain = constraints.real + domain = constraints.real_vector codomain = constraints.simplex bijective = True - event_dim = 1 def __eq__(self, other): return isinstance(other, StickBreakingTransform) @@ -694,6 +910,16 @@ def log_abs_det_jacobian(self, x, y): detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1) return detJ + def forward_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] + 1,) + + def inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] - 1,) + class LowerCholeskyTransform(Transform): """ @@ -703,9 +929,8 @@ class LowerCholeskyTransform(Transform): This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization. """ - domain = constraints.real + domain = constraints.independent(constraints.real, 2) codomain = constraints.lower_cholesky - event_dim = 2 def __eq__(self, other): return isinstance(other, LowerCholeskyTransform) @@ -733,7 +958,6 @@ class CatTransform(Transform): """ def __init__(self, tseq, dim=0, lengths=None, cache_size=0): assert all(isinstance(t, Transform) for t in tseq) - self.event_dim = max(t.event_dim for t in tseq) if cache_size: tseq = [t.with_cache(cache_size) for t in tseq] super(CatTransform, self).__init__(cache_size=cache_size) @@ -744,6 +968,10 @@ def __init__(self, tseq, dim=0, lengths=None, cache_size=0): assert len(self.lengths) == len(self.transforms) self.dim = dim + @lazy_property + def event_dim(self): + return max(t.event_dim for t in self.transforms) + @lazy_property def length(self): return sum(self.lengths) diff --git a/torch/distributions/uniform.py b/torch/distributions/uniform.py index edaf5abf77a5..70a1b1023ac5 100644 --- a/torch/distributions/uniform.py +++ b/torch/distributions/uniform.py @@ -22,7 +22,8 @@ class Uniform(Distribution): high (float or Tensor): upper range (exclusive). """ # TODO allow (loc,scale) parameterization to allow independent constraints. - arg_constraints = {'low': constraints.dependent, 'high': constraints.dependent} + arg_constraints = {'low': constraints.dependent(is_discrete=False, event_dim=0), + 'high': constraints.dependent(is_discrete=False, event_dim=0)} has_rsample = True @property @@ -58,7 +59,7 @@ def expand(self, batch_shape, _instance=None): new._validate_args = self._validate_args return new - @constraints.dependent_property + @constraints.dependent_property(is_discrete=False, event_dim=0) def support(self): return constraints.interval(self.low, self.high) diff --git a/torch/jit/_script.py b/torch/jit/_script.py index bdf00e21c515..f9d6c33192f2 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -273,8 +273,8 @@ def __getattr__(self, attr): # which always throws an exception. class ScriptModule(with_metaclass(ScriptMeta, Module)): # type: ignore - """ - ``ScriptModule``s wrap a C++ ``torch::jit::Module``. ``ScriptModule``s + r""" + A wrapper around C++ ``torch::jit::Module``. ``ScriptModule``\s contain methods, attributes, parameters, and constants. These can be accessed the same as on a normal ``nn.Module``. 
""" @@ -1095,24 +1095,8 @@ def _recursive_compile_class(obj, loc): rcb = _jit_internal.createResolutionCallbackForClassMethods(obj) _compile_and_register_class(obj, rcb, _qual_name) - -class CompilationUnit(object): - def __init__(self, lang=None, _frames_up=0): - self._c = torch._C.CompilationUnit() - if lang is not None: - self.define(lang, _frames_up=_frames_up + 1) - - def define(self, lang, rcb=None, _frames_up=0): - if not rcb: - rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1) - self._c.define(lang, rcb) - - def __getattr__(self, attr): - r = self._c.find_function(attr) - if r is None: - raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr)) - return r - +CompilationUnit = torch._C.CompilationUnit +set_module(CompilationUnit, "torch.jit") def _unwrap_optional(x): assert x is not None, "Unwrapping null optional" diff --git a/torch/lib/c10d/reducer.cpp b/torch/lib/c10d/reducer.cpp index ad0497724cfa..d13c6c3e658a 100644 --- a/torch/lib/c10d/reducer.cpp +++ b/torch/lib/c10d/reducer.cpp @@ -48,7 +48,8 @@ Reducer::Reducer( has_rebuilt_bucket_(false), bucket_bytes_cap_(bucket_bytes_cap), divFactor_(kUnsetDivFactor), - comm_hook_(nullptr) { + comm_hook_(nullptr), + ddp_logging_data_(std::move(std::make_unique())) { C10_LOG_API_USAGE_ONCE("torch.distributed.ddp.reducer"); TORCH_CHECK(replicas_.size() >= 1, "Expected at least one model replica."); TORCH_CHECK(replicas_[0].size() >= 1, "Expected at least one parameter."); @@ -1465,6 +1466,27 @@ void Reducer::ensure_prior_reduction_finished() { } } +void Reducer::set_construction_logging_data( + const std::string& module_name, + const std::vector& device_ids, + int output_device, + bool broadcast_buffers +) { + ddp_logging_data_->module_name = module_name; + ddp_logging_data_->device_ids = device_ids; + ddp_logging_data_->output_device = output_device; + ddp_logging_data_->broadcast_buffers = broadcast_buffers; + ddp_logging_data_->world_size = process_group_->getSize(); + ddp_logging_data_->rank = process_group_->getRank(); + ddp_logging_data_->bucket_cap_mb = bucket_bytes_cap_ / (1024 * 1024); + ddp_logging_data_->find_unused_parameters = find_unused_parameters_; + ddp_logging_data_->gradient_as_bucket_view = gradient_as_bucket_view_; +} + +c10::DDPLoggingData Reducer::get_ddp_logging_data() { + return *ddp_logging_data_; +} + namespace { // Tensors may be coalesced into buckets. Buckets must contain tensors of diff --git a/torch/lib/c10d/reducer.hpp b/torch/lib/c10d/reducer.hpp index ada39844a9ca..ea06276a7955 100644 --- a/torch/lib/c10d/reducer.hpp +++ b/torch/lib/c10d/reducer.hpp @@ -105,6 +105,18 @@ class Reducer { // index has been used. std::vector get_local_used_maps_on_device() const; + // Set logging data that can be got during DistributedDataParallel + // construction time. + void set_construction_logging_data( + const std::string& module_name, + const std::vector& device_ids, + int output_device, + bool broadcast_buffers); + + // An Interface for users to get DDPLoggingData and log them + // in the applications. + c10::DDPLoggingData get_ddp_logging_data(); + protected: // Forward declaration. struct Bucket; @@ -358,6 +370,10 @@ class Reducer { private: // comm_hook_ is used to access the DDP communication hook if registered. std::unique_ptr comm_hook_; + + // ddp_logging_data_ is used to hold all the ddp related logging + // data fields. 
+ std::unique_ptr ddp_logging_data_; }; // This is equivalent to take_tensors but returns indices into the diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 32b5844ede53..1839c86c78a0 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -528,8 +528,9 @@ then the singular values of each matrix in the batch is returned in descending order. .. note:: The implementation of SVD on CPU uses the LAPACK routine `?gesdd` (a divide-and-conquer - algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the MAGMA routine - `gesdd` as well. + algorithm) instead of `?gesvd` for speed. Analogously, the SVD on GPU uses the cuSOLVER routines + `gesvdj` and `gesvdjBatched` on CUDA 10.1.243 and later, and uses the MAGMA routine `gesdd` + on earlier versions of CUDA. .. note:: The returned matrix `U` will be transposed, i.e. with strides :code:`U.contiguous().transpose(-2, -1).stride()`. @@ -543,6 +544,10 @@ .. note:: The `S` tensor can only be used to compute gradients if :attr:`compute_uv` is True. +.. note:: Since `U` and `V` of an SVD is not unique, each vector can be multiplied by + an arbitrary phase factor :math:`e^{i \phi}` while the SVD result is still correct. + Different platforms, like Numpy, or inputs on different device types, may produce different + `U` and `V` tensors. Args: input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 90023b7a4346..0e909de42986 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -1020,7 +1020,7 @@ def reset_parameters(self) -> None: # has_uninitialized_params is defined in parent class and it is using a protocol on self if not self.has_uninitialized_params() and self.in_channels != 0: # type: ignore[misc] # "type:ignore[..]" is required because mypy thinks that "reset_parameters" is undefined - # super class. Turns out that it is defined in _ConvND which is inherited by any class + # in super class. Turns out that it is defined in _ConvND which is inherited by any class # that also inherits _LazyConvXdMixin super().reset_parameters() # type: ignore[misc] @@ -1031,6 +1031,7 @@ def initialize_parameters(self, input) -> None: # type: ignore[override] self.in_channels = input.shape[1] if self.in_channels % self.groups != 0: raise ValueError('in_channels must be divisible by groups') + assert isinstance(self.weight, UninitializedParameter) if self.transposed: self.weight.materialize(( self.in_channels, self.out_channels // self.groups, *self.kernel_size)) diff --git a/torch/nn/modules/linear.py b/torch/nn/modules/linear.py index ea2c3a8f453b..fb3ce98dbae3 100644 --- a/torch/nn/modules/linear.py +++ b/torch/nn/modules/linear.py @@ -223,6 +223,6 @@ def initialize_parameters(self, input) -> None: # type: ignore if self.has_uninitialized_params(): with torch.no_grad(): self.in_features = input.shape[-1] - self.weight.materialize((self.out_features, self.in_features)) + self.weight.materialize((self.out_features, self.in_features)) # type: ignore self.reset_parameters() # TODO: PartialLinear - maybe in sparse? diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index bb58105ead59..876309f4589f 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -606,6 +606,14 @@ def produces_sparse_gradient(module): self.find_unused_parameters, self.gradient_as_bucket_view) + # Set logging data that can be got during construction time. 
+ dist._set_construction_logging_data( + self.reducer, + self.module.__class__.__name__, + [] if self.device_ids is None else self.device_ids, + -1 if self.output_device is None else self.output_device, + self.broadcast_buffers) + # passing a handle to torch.nn.SyncBatchNorm layer self._passing_sync_batchnorm_handle(self._module_copies) @@ -765,6 +773,9 @@ def train(self, mode=True): module.train(mode) return self + def get_ddp_logging_data(self): + return dist._get_ddp_logging_data(self.reducer) + # When running in join mode, schedules an allreduce to match the one in the # forward pass to determine the no. of currently active processes and whether # all processes have joined. diff --git a/torch/optim/functional.py b/torch/optim/functional.py index 956cd0df1b96..e1fb27f5b7c7 100644 --- a/torch/optim/functional.py +++ b/torch/optim/functional.py @@ -73,8 +73,6 @@ def adam(params: List[Tensor], exp_avg = exp_avgs[i] exp_avg_sq = exp_avg_sqs[i] step = state_steps[i] - if amsgrad: - max_exp_avg_sq = max_exp_avg_sqs[i] bias_correction1 = 1 - beta1 ** step bias_correction2 = 1 - beta2 ** step @@ -87,9 +85,9 @@ def adam(params: List[Tensor], exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) if amsgrad: # Maintains the maximum of all 2nd moment running avg. till now - torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) # Use the max. for normalizing running avg. of gradient - denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) + denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps) else: denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) diff --git a/torch/overrides.py b/torch/overrides.py index cdf4f307442b..1a5ebfb9a133 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -25,7 +25,7 @@ import collections import functools import types -from typing import Dict, Set, List, Any, Callable, Iterable +from typing import Dict, Set, List, Any, Callable, Iterable, Type import torch from torch._C import ( @@ -50,7 +50,7 @@ def get_ignored_functions() -> Set[Callable]: Returns ------- - Tuple[Callable] + Set[Callable] A tuple of functions that are publicly available in the torch API but cannot be overridden with ``__torch_function__``. Mostly this is because none of the arguments of these functions are tensors or tensor-likes. @@ -102,7 +102,6 @@ def get_ignored_functions() -> Set[Callable]: torch.has_cuda, torch.has_cudnn, torch.has_lapack, - torch.cpp, torch.device, torch.dtype, torch.finfo, @@ -163,8 +162,8 @@ def get_ignored_functions() -> Set[Callable]: torch.triu_indices, torch.vander, torch.zeros, + torch._jit_internal.boolean_dispatch, torch.nn.functional.assert_int_or_pair, - torch.nn.functional.boolean_dispatch, torch.nn.functional.upsample, torch.nn.functional.upsample_bilinear, torch.nn.functional.upsample_nearest, @@ -175,6 +174,8 @@ def get_ignored_functions() -> Set[Callable]: torch.nn.functional.sigmoid, torch.nn.functional.hardsigmoid, torch.nn.functional.tanh, + has_torch_function, + handle_torch_function, torch.set_autocast_enabled, torch.is_autocast_enabled, torch.clear_autocast_cache, @@ -242,7 +243,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: # function signatures for native kernels that can be consumed by inspect. # See Issue #28233. 
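A hedged usage sketch of the logging API introduced above (get_ddp_logging_data plus the construction-time fields); it assumes a single-process gloo group on CPU, mirroring the test_ddp_logging_data test added later in this patch, and that the pybind layer exposes the fields by name:

import os
import torch.nn as nn
import torch.distributed as dist

# single-process gloo group, just for illustration
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

ddp = nn.parallel.DistributedDataParallel(nn.Linear(4, 2))   # CPU model, no device_ids
info = ddp.get_ddp_logging_data()
print(info.module_name, info.world_size, info.bucket_cap_mb)  # Linear 1 25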
Tensor = torch.Tensor - ret = { + ret: Dict[Callable, Callable] = { torch.abs: lambda input, out=None: -1, torch.absolute: lambda input, out=None: -1, torch.adaptive_avg_pool1d: lambda input, output_size: -1, @@ -356,7 +357,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.deg2rad: lambda input, out=None: -1, torch.dequantize: lambda input: -1, torch.det: lambda input: -1, - torch.linalg.det: lambda input: -1, # alias for torch.det + torch.linalg.det: lambda input: -1, # alias for torch.det # type: ignore[attr-defined] torch.detach: lambda input: -1, torch.diag: lambda input, diagonal=0, out=None: -1, torch.diag_embed: lambda input, diagonal=0, out=None: -1, @@ -517,7 +518,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.less: lambda input, other, out=None: -1, torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1, torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1, - torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, + torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, # type: ignore[attr-defined] # noqa: B950 torch.masked_fill: lambda input, mask, value: -1, torch.masked_scatter: lambda input, mask, source: -1, torch.masked_select: lambda input, mask, out=None: -1, @@ -837,6 +838,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1, torch.tril: lambda input, diagonal=0, out=None: -1, torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, + size_average=None, reduce=None, reduction='mean': -1), torch.triu: lambda input, diagonal=0, out=None: -1, torch.true_divide: lambda input, other: -1, @@ -1123,8 +1125,8 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: https://numpy.org/neps/nep-0018-array-function-protocol.html """ # Runtime is O(num_arguments * num_unique_types) - overloaded_types = set() - overloaded_args = [] + overloaded_types: Set[Type] = set() + overloaded_args: List[Any] = [] for arg in relevant_args: arg_type = type(arg) # We only collect arguments if they have a unique type, which ensures @@ -1147,7 +1149,6 @@ def _get_overloaded_args(relevant_args: Iterable[Any]) -> List[Any]: else: overloaded_types = {arg_type} overloaded_args = [arg] - return overloaded_args diff --git a/torch/quantization/fake_quantize.py b/torch/quantization/fake_quantize.py index 460b1c277a93..46dba803a1ff 100644 --- a/torch/quantization/fake_quantize.py +++ b/torch/quantization/fake_quantize.py @@ -170,7 +170,23 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, key = prefix + name if key in state_dict: val = state_dict[key] - setattr(self, name, val) + # Custom handling to allow loading scale and zero_point + # of size N into uninitialized buffers of size 0. The + # buffers are resized here, and the values are copied in + # the default state_dict loading code of the parent. 
+ if name == 'scale': + self.scale.resize_(val.shape) + else: + assert name == 'zero_point' + self.zero_point.resize_(val.shape) + # For torchscript module we need to update the attributes here since we do not + # call the `_load_from_state_dict` function defined module.py + if torch.jit.is_scripting(): + if name == 'scale': + self.scale.copy_(val) + else: + assert name == 'zero_point' + self.zero_point.copy_(val) elif strict: missing_keys.append(key) super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index e104646af414..06f15240e761 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -221,6 +221,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, # Supported combinations are: # quant_type | activation (compute_type) | weight # static quint8 qint8 + # tuple (activation_dtype, weight_dtype, compute_dtype) supported_dtypes = [ (torch.quint8, torch.qint8, None), @@ -229,6 +230,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, # TODO: debug option for conv module qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) + # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: warnings.warn( "dtype combination: {} is not " @@ -357,6 +359,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, ] qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) + # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: warnings.warn( "dtype combination: {} is not " @@ -525,6 +528,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, emb_node = node qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) + # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: warnings.warn( "dtype combination: {} is not " @@ -568,6 +572,7 @@ def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable, assert node.op == 'call_module' qconfig = quantizer.qconfig_map[node.name] dtypes = get_qconfig_dtypes(qconfig) + # leave the op unquantized if the dtype combination is not supported if dtypes not in supported_dtypes: warnings.warn( "dtype combination: {} is not " diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index f0e8c40602c0..af6b27f322ff 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -6,6 +6,7 @@ from torch.testing._internal.common_utils import TEST_NUMBA import inspect import contextlib +from distutils.version import LooseVersion TEST_CUDA = torch.cuda.is_available() @@ -15,6 +16,10 @@ TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)) TEST_CUDNN_VERSION = torch.backends.cudnn.version() if TEST_CUDNN else 0 +CUDA11OrLater = torch.version.cuda and LooseVersion(torch.version.cuda) >= "11.0.0" +CUDA9 = torch.version.cuda and torch.version.cuda.startswith('9.') +SM53OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3) + TEST_MAGMA = TEST_CUDA if TEST_CUDA: torch.ones(1).cuda() # has_magma shows up after cuda is initialized diff --git a/torch/testing/_internal/common_device_type.py 
b/torch/testing/_internal/common_device_type.py index 7172da19a4b7..93bb04bae6f1 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -249,12 +249,25 @@ def instantiate_test_helper(cls, name, *, test, dtype, op): # op-specific decorators to the original test. # Test-sepcific decorators are applied to the original test, # however. - if op is not None and op.decorators is not None: + if op is not None: + active_decorators = [] + if op.should_skip(generic_cls.__name__, name, cls.device_type, dtype): + active_decorators.append(skipIf(True, "Skipped!")) + + if op.decorators is not None: + for decorator in op.decorators: + # Can't use isinstance as it would cause a circular import + if decorator.__class__.__name__ == 'DecorateInfo': + if decorator.is_active(generic_cls.__name__, name, cls.device_type, dtype): + active_decorators += decorator.decorators + else: + active_decorators.append(decorator) + @wraps(test) def test_wrapper(*args, **kwargs): return test(*args, **kwargs) - for decorator in op.decorators: + for decorator in active_decorators: test_wrapper = decorator(test_wrapper) test_fn = test_wrapper @@ -262,12 +275,8 @@ def test_wrapper(*args, **kwargs): test_fn = test # Constructs the test - @wraps(test) + @wraps(test_fn) def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op): - if op is not None and op.should_skip(generic_cls.__name__, name, - self.device_type, dtype): - self.skipTest("Skipped!") - device_arg: str = cls.get_primary_device() if hasattr(test_fn, 'num_required_devices'): device_arg = cls.get_all_devices() diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 39308bc58e4f..90aa1468bd4a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8,45 +8,70 @@ import numpy as np from torch._six import inf, istuple from torch.autograd import Variable +import collections.abc from typing import List, Tuple, Dict, Any from torch.testing import \ (make_non_contiguous, _dispatch_dtypes, floating_types, floating_types_and, floating_and_complex_types, floating_and_complex_types_and, - all_types_and_complex_and, all_types_and) + all_types_and_complex_and, all_types_and, all_types_and_complex) from torch.testing._internal.common_device_type import \ - (skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfRocm, - expectedAlertNondeterministic, precisionOverride) -from torch.testing._internal.common_cuda import tf32_is_not_fp32 + (skipIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl, + skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride) +from torch.testing._internal.common_cuda import CUDA11OrLater from torch.testing._internal.common_utils import \ (prod_single_zero, random_square_matrix_of_rank, random_symmetric_matrix, random_symmetric_psd_matrix, random_symmetric_pd_matrix, make_nonzero_det, random_fullrank_matrix_distinct_singular_value, set_rng_seed, TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY, - torch_to_numpy_dtype_dict, TEST_WITH_SLOW) + torch_to_numpy_dtype_dict, slowTest) from distutils.version import LooseVersion if TEST_SCIPY: import scipy.special -class SkipInfo(object): - """Describes which test, or type of tests, should be skipped when testing - an operator. Any test that matches all provided arguments will be skipped. 
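A small sketch of the CUDA11OrLater version gate added in common_cuda.py above and imported into common_methods_invocations.py; cuda_version stands in for torch.version.cuda, which is None on CPU-only builds:

from distutils.version import LooseVersion

cuda_version = "11.1"   # illustrative stand-in for torch.version.cuda
CUDA11OrLater = bool(cuda_version) and LooseVersion(cuda_version) >= "11.0.0"
assert CUDA11OrLater   # False for e.g. "10.2", and for CPU-only builds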
- The skip will only be checked if the active_if argument is True.""" - __slots__ = ['cls_name', 'test_name', 'device_type', 'dtypes', 'active_if'] +class DecorateInfo(object): + """Describes which test, or type of tests, should be wrapped in the given + decorators when testing an operator. Any test that matches all provided + arguments will be decorated. The decorators will only be applied if the + active_if argument is True.""" - def __init__(self, cls_name=None, test_name=None, *, + __slots__ = ['decorators', 'cls_name', 'test_name', 'device_type', 'dtypes', 'active_if'] + + def __init__(self, decorators, cls_name=None, test_name=None, *, device_type=None, dtypes=None, active_if=True): + self.decorators = list(decorators) if isinstance(decorators, collections.abc.Sequence) else [decorators] self.cls_name = cls_name self.test_name = test_name self.device_type = device_type self.dtypes = dtypes self.active_if = active_if + def is_active(self, cls_name, test_name, device_type, dtype): + return ( + self.active_if and + (self.cls_name is None or self.cls_name == cls_name) and + (self.test_name is None or self.test_name == test_name) and + (self.device_type is None or self.device_type == device_type) and + (self.dtypes is None or dtype in self.dtypes) + ) + + +class SkipInfo(DecorateInfo): + """Describes which test, or type of tests, should be skipped when testing + an operator. Any test that matches all provided arguments will be skipped. + The skip will only be checked if the active_if argument is True.""" + + def __init__(self, cls_name=None, test_name=None, *, + device_type=None, dtypes=None, active_if=True): + super().__init__(decorators=skipIf(True, "Skipped!"), cls_name=cls_name, + test_name=test_name, device_type=device_type, dtypes=dtypes, + active_if=active_if) + class SampleInput(object): """Represents sample inputs to a function.""" @@ -204,18 +229,8 @@ def sample_inputs(self, device, dtype, requires_grad=False): # Returns True if the test should be skipped and False otherwise def should_skip(self, cls_name, test_name, device_type, dtype): - for si in self.skips: - if not si.active_if: - continue - - cls_name_match = si.cls_name is None or cls_name == si.cls_name - name_match = si.test_name is None or test_name == si.test_name - device_type_match = si.device_type is None or device_type == si.device_type - dtype_match = si.dtypes is None or dtype in si.dtypes - if cls_name_match and name_match and device_type_match and dtype_match: - return True - - return False + return any(si.is_active(cls_name, test_name, device_type, dtype) + for si in self.skips) def supported_dtypes(self, device_type): if device_type == 'cpu': @@ -489,6 +504,11 @@ def sample_inputs_xlogy(self, device, dtype, requires_grad): low=0, high=None, requires_grad=requires_grad))),) +def sample_inputs_trace(self, device, dtype, requires_grad): + return (SampleInput((make_tensor((S, S), device, dtype, + low=None, high=None, + requires_grad=requires_grad))),) + def sample_inputs_linalg_inv(op_info, device, dtype, requires_grad=False): """ This function generates always invertible input for torch.linalg.inv using @@ -675,21 +695,18 @@ def __init__(self, ref=None, # Reference implementation (probably in np.fft namespace) dtypes=floating_and_complex_types(), ndimensional: bool, # Whether dim argument can be a tuple - skips=None, decorators=None, **kwargs): - skips = skips if skips is not None else [] - - # gradgrad is quite slow - if not TEST_WITH_SLOW: - skips.append(SkipInfo('TestGradients', 'test_fn_gradgrad')) - - 
decorators = decorators if decorators is not None else [] - decorators += [skipCPUIfNoMkl, skipCUDAIfRocm] + decorators = list(decorators) if decorators is not None else [] + decorators += [ + skipCPUIfNoMkl, + skipCUDAIfRocm, + # gradgrad is quite slow + DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'), + ] super().__init__(name=name, dtypes=dtypes, - skips=skips, decorators=decorators, **kwargs) self.ref = ref if ref is not None else _getattr_qual(np, name) @@ -841,16 +858,26 @@ def _sample_inputs_svd(op_info, device, dtype, requires_grad=False, is_linalg_sv """ from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value - # svd and linalg.svd returns V and V.T, respectively. So we need to slice + # svd and linalg.svd returns V and V.conj().T, respectively. So we need to slice # along different dimensions when needed (this is used by # test_cases2:wide_all and wide_all_batched below) if is_linalg_svd: def slice_V(v): return v[..., :(S - 2), :] + + def uv_loss(usv): + u00 = usv[0][0, 0] + v00_conj = usv[2][0, 0] + return u00 * v00_conj else: def slice_V(v): return v[..., :, :(S - 2)] + def uv_loss(usv): + u00 = usv[0][0, 0] + v00_conj = usv[2][0, 0].conj() + return u00 * v00_conj + test_cases1 = ( # some=True (default) # loss functions for complex-valued svd have to be "gauge invariant", # i.e. loss functions shouldn't change when sigh of the singular vectors change. @@ -861,12 +888,10 @@ def slice_V(v): lambda usv: abs(usv[0])), # 'check_grad_u' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), lambda usv: abs(usv[2])), # 'check_grad_v' - # TODO: replace lambda usv: usv[0][0, 0] * usv[2][0, 0] with lambda usv: usv[0][0, 0] * usv[2][0, 0].conj() - # once https://github.com/pytorch/pytorch/issues/45821 is resolved # this test is important as it checks the additional term that is non-zero only for complex-valued inputs # and when the loss function depends both on 'u' and 'v' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device), - lambda usv: usv[0][0, 0] * usv[2][0, 0]), # 'check_grad_uv' + uv_loss), # 'check_grad_uv' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:(S - 2)], lambda usv: (abs(usv[0]), usv[1], abs(usv[2][..., :, :(S - 2)]))), # 'wide' (random_fullrank_matrix_distinct_singular_value(S, dtype=dtype).to(device)[:, :(S - 2)], @@ -1010,8 +1035,9 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): OpInfo('addmm', dtypes=floating_types(), dtypesIfCPU=all_types_and_complex_and(torch.float16, torch.bfloat16), + # BFloat16 support on CUDA requires CUDA 11 and SM53 dtypesIfCUDA=floating_types_and(torch.float16, torch.complex64, torch.complex128, - *[torch.bfloat16] if tf32_is_not_fp32() else []), + *[torch.bfloat16] if CUDA11OrLater else []), dtypesIfROCM=floating_types_and(torch.half), assert_autodiffed=True, autodiff_nonfusible_nodes=['aten::add', 'aten::mm'], @@ -1579,13 +1605,16 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): test_inplace_grad=False, supports_tensor_out=False, sample_inputs_func=sample_inputs_svd, - decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], - skips=( + decorators=[ + skipCUDAIfNoMagma, + skipCPUIfNoLapack, # gradgrad checks are slow - SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), + DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'), + ], + skips=( # cuda gradchecks are very slow # see discussion 
https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 - SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)), OpInfo('linalg.svd', op=torch.linalg.svd, aten_name='linalg_svd', @@ -1593,13 +1622,16 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): test_inplace_grad=False, supports_tensor_out=False, sample_inputs_func=sample_inputs_linalg_svd, - decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], - skips=( + decorators=[ + skipCUDAIfNoMagma, + skipCPUIfNoLapack, # gradgrad checks are slow - SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), + DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'), + ], + skips=( # cuda gradchecks are very slow # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 - SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)), OpInfo('pinverse', op=torch.pinverse, dtypes=floating_and_complex_types(), @@ -1788,6 +1820,19 @@ def reference_sigmoid(x): supports_tensor_out=True, safe_casts_outputs=True, sample_inputs_func=sample_inputs_xlogy), + OpInfo('trace', + dtypes=all_types_and_complex(), + dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half), + test_inplace_grad=False, + supports_tensor_out=False, + # Reference: https://github.com/pytorch/pytorch/issues/50381 + test_complex_grad=False, + sample_inputs_func=sample_inputs_trace, + skips=( + SkipInfo('TestCommon', 'test_variant_consistency_jit', + dtypes=[torch.complex64, torch.complex128]), + SkipInfo('TestCommon', 'test_variant_consistency_eager', + dtypes=[torch.complex64, torch.complex128]))), ] op_db = op_db + op_db_scipy_reference @@ -2494,7 +2539,6 @@ def method_tests(): ('triu', (S, M, M), NO_ARGS, 'batched'), ('triu', (S, M, M), (2,), 'batched_idx'), ('triu', (3, 3, S, S), NO_ARGS, 'more_batched'), - ('trace', (M, M), NO_ARGS), ('cross', (S, 3), ((S, 3),)), ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'), ('index_add', (S, S), (0, index_variable(2, S), (2, S)), 'dim', (), [0]), diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py index 1430e6c67035..00c01e3bace8 100644 --- a/torch/testing/_internal/common_nn.py +++ b/torch/testing/_internal/common_nn.py @@ -5324,6 +5324,7 @@ def __init__(self, *args, **kwargs): self.test_cpu = kwargs.get('test_cpu', True) self.with_tf32 = kwargs.get('with_tf32', True) self.tf32_precision = kwargs.get('tf32_precision', 0.001) + self.check_batched_grad = kwargs.get('check_batched_grad', True) def __call__(self, test_case): module = self.constructor(*self.constructor_args) @@ -5356,10 +5357,10 @@ def apply_fn(input, target, *params): def apply_fn(input1, input2, target, *params): # type: ignore[misc] return module(input1, input2, target) - gradcheck(apply_fn, inputs) + gradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad) if self.check_gradgrad: - gradgradcheck(apply_fn, inputs) + gradgradcheck(apply_fn, inputs, check_batched_grad=self.check_batched_grad) def test_cuda(self, test_case, dtype, extra_args=None): def convert_dtype(obj, dtype, requires_grad=False): diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 162e43c9580f..11eb5c23b43a 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py 
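A hedged sketch of how the DecorateInfo matching used in the svd/linalg.svd entries above behaves, assuming the class lands as defined earlier in this patch; slowTest is the existing decorator from common_utils:

import torch
from torch.testing._internal.common_utils import slowTest
from torch.testing._internal.common_methods_invocations import DecorateInfo

di = DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad', device_type='cuda')
print(di.is_active('TestGradients', 'test_fn_gradgrad', 'cuda', torch.float64))  # True: slowTest wraps the test
print(di.is_active('TestGradients', 'test_fn_gradgrad', 'cpu', torch.float64))   # False: test left untouched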
@@ -17,6 +17,7 @@ import torch.cuda import torch.distributed as dist import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD +from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default from torch.utils.data.distributed import DistributedSampler from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars import torch.nn as nn @@ -2842,43 +2843,91 @@ def test_DistributedDataParallel_non_default_stream(self): msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}", ) + def _test_ddp_hook_parity(self, state, hook): + rank = self.rank + m = torch.nn.Linear(1, 5) + try: + process_group = state.process_group + except AttributeError: + process_group = state + + net_with_hook = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(m).to(rank), device_ids=[rank], process_group=process_group + ) + net_with_hook.register_comm_hook(state=state, hook=hook) + net_without_hook = torch.nn.parallel.DistributedDataParallel( + copy.deepcopy(m).to(rank), device_ids=[rank], process_group=process_group + ) + for i in range(100): + # Clear gradients manually. + for g in [net_without_hook.module.weight.grad, net_with_hook.module.weight.grad]: + if g is not None: + g.requires_grad_(False) + g.zero_() + # Forward + BW + batch = torch.tensor([rank]).float().cuda(rank) + loss = net_without_hook(batch).sum() + loss.backward() + # For each worker, the gradient on the weight should be worker_rank. + grad = net_without_hook.module.weight.grad + avg = grad.clone() + expected_grad = sum(i for i in range(dist.get_world_size())) / dist.get_world_size() + loss_hook = net_with_hook(batch).sum() + loss_hook.backward() + grad_hook = net_with_hook.module.weight.grad + avg_hook = grad_hook.clone() + # Verify hook grad with expected. + # Cannot use exact match here due to a very small accuracy loss, + # e.g. 1e-05, for powerSGD hook case. + assert_func = self.assertEqual if hook == default.allreduce_hook else torch.testing.assert_allclose + assert_func( + avg_hook[0, 0], + expected_grad, + msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}" + ) + # Verify hook grad with vanilla allreduce + assert_func( + avg_hook[0, 0], + avg[0, 0], + msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}" + ) + @unittest.skipIf( BACKEND != "nccl", "Only NCCL backend supports DDP communication hook", ) @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) @skip_if_rocm - def test_DistributedDataParallel_powerSGD_ddp_comm_hook(self): - rank = self.rank + def test_ddp_hook_parity_allreduce(self): + self._test_ddp_hook_parity(state=None, hook=default.allreduce_hook) + + @unittest.skipIf( + BACKEND != "nccl", + "Only NCCL backend supports DDP communication hook", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + @skip_if_rocm + def test_ddp_hook_parity_allreduce_process_group(self): + # process_group is passed in to both DDP and comm. 
hook + rank_to_GPU = self._init_multigpu_helper() + gpus = [rank_to_GPU[int(r)][0] for r in range(dist.get_world_size())] + process_group = torch.distributed.new_group(gpus) + self._test_ddp_hook_parity(state=process_group, hook=default.allreduce_hook) + + @unittest.skipIf( + BACKEND != "nccl", + "Only NCCL backend supports DDP communication hook", + ) + @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"])) + @skip_if_rocm + def test_ddp_hook_parity_powerSGD(self): for warm_start in [True, False]: - net = torch.nn.parallel.DistributedDataParallel( - torch.nn.Linear(1, 5).to(rank), device_ids=[rank] + powersgd_state = powerSGD.PowerSGDState( + process_group=None, + matrix_approximation_rank=1, + warm_start=warm_start, ) - state = powerSGD.PowerSGDState( - # Use the default process group (dist.group.WORLD) instead of creating a new one. - process_group=None, matrix_approximation_rank=1, warm_start=warm_start - ) - net.register_comm_hook(state=state, hook=powerSGD.powerSGD_hook) - # NOTE: batched_powerSGD_hook cannot pass the following test, because it has a much lower accuracy. - # E.g., after the compression of batched_powerSGD_hook, a gradient of 0.5 can become 0.8335. - for i in range(1000): - # Clear gradients manually. - grad = net.module.weight.grad - if grad is not None: - grad.requires_grad_(False) - grad.zero_() - # Forward + BW - batch = torch.tensor([rank]).float().cuda(rank) - loss = net(batch).sum() - loss.backward() - # For each worker, the gradient on the weight should be worker_rank. - grad = net.module.weight.grad - world_size = int(os.environ["WORLD_SIZE"]) - expected_grad = sum(i for i in range(world_size)) / world_size - # Cannot use exact match here due to a very small accuracy loss, e.g., 1e-05. - torch.testing.assert_allclose( - grad[0, 0], expected_grad, - msg=f"Expected gradient of {expected_grad} but got {grad} on rank {self.rank}") + self._test_ddp_hook_parity(state=powersgd_state, hook=powerSGD.powerSGD_hook) @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo', @@ -3152,6 +3201,25 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self): global_bs=global_bs, offset=bs_offset) + @unittest.skipIf( + BACKEND == "nccl", "nccl does not support DDP on CPU models" + ) + def test_ddp_logging_data(self): + model_DDP = copy.deepcopy(DDP_NET) + model_DDP = nn.parallel.DistributedDataParallel(model_DDP) + ddp_logging_data = model_DDP.get_ddp_logging_data() + self.assertEqual(ddp_logging_data.world_size, dist.get_world_size()) + self.assertEqual(ddp_logging_data.rank, dist.get_rank()) + self.assertEqual(ddp_logging_data.module_name, 'Net') + self.assertEqual(ddp_logging_data.device_ids, []) + # output_device is -1 in default if it is not set, e.g. + # output_device of CPU training is -1. 
+ self.assertEqual(ddp_logging_data.output_device, -1) + self.assertEqual(ddp_logging_data.broadcast_buffers, True) + self.assertEqual(ddp_logging_data.bucket_cap_mb, 25) + self.assertEqual(ddp_logging_data.find_unused_parameters, False) + self.assertEqual(ddp_logging_data.gradient_as_bucket_view, False) + @skipIfNoTorchVision def test_SyncBatchNorm_process_group(self): # When adopting `convert_sync_batchnorm` to convert a `nn.modules`, diff --git a/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py b/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py index b111ff614608..a8f953539c54 100644 --- a/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py @@ -137,13 +137,12 @@ def test_dist_optim_exception_on_constructor(self): OptimizerFailingOnConstructor, [remote_param1, remote_param2] ) - @dist_init() - def test_dist_optim(self): + def _test_dist_optim_base(self, optim_cls, *args, **kwargs): # local version module1 = MyModule() module2 = MyModule() params = [module1.get_w(), module2.get_w()] - local_optim = optim.SGD(params, lr=0.05) + local_optim = optim_cls(params, *args, **kwargs) old_w1 = module1.w.clone().detach() old_w2 = module2.w.clone().detach() @@ -175,7 +174,7 @@ def test_dist_optim(self): self.assertEqual(old_w2, remote_param2.to_here()) dist_optim = DistributedOptimizer( - optim.SGD, [remote_param1, remote_param2], lr=0.05 + optim_cls, [remote_param1, remote_param2], *args, **kwargs ) with dist_autograd.context() as context_id: @@ -199,65 +198,12 @@ def test_dist_optim(self): self.assertEqual(new_w1, module1.get_w()) self.assertEqual(new_w2, module2.get_w()) - - @dist_init - def test_dist_optim_functional(self): - # local version - module1 = MyModule() - module2 = MyModule() - params = [module1.get_w(), module2.get_w()] - local_optim = optim.Adagrad(params, lr=0.05) - - old_w1 = module1.w.clone().detach() - old_w2 = module2.w.clone().detach() - - g_cpu = torch.Generator() - g_cpu.manual_seed(0) - t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) - t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) - output1 = module1.forward(t2) - output2 = module2.forward(output1) - loss = torch.add(output2, t1).sum() - - loss.backward() - local_optim.step() - - # distributed version - owner1 = "worker%d" % ((self.rank + 1) % self.world_size) - owner2 = "worker%d" % ((self.rank + 2) % self.world_size) - - remote_module1 = rpc.remote(owner1, MyModule) - remote_module2 = rpc.remote(owner2, MyModule) - remote_param1 = remote_method(MyModule.get_w, remote_module1) - remote_param2 = remote_method(MyModule.get_w, remote_module2) - - old_w1_remote = remote_param1.to_here() - - # sanity check: local and remote initial weights should match - self.assertEqual(old_w1, remote_param1.to_here()) - self.assertEqual(old_w2, remote_param2.to_here()) - - dist_optim = DistributedOptimizer( - optim.Adagrad, [remote_param1, remote_param2], lr=0.05 - ) - - with dist_autograd.context() as context_id: - g_cpu.manual_seed(0) - t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) - t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) - output1 = rpc_async_method(MyModule.forward, remote_module1, t2) - output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait()) - loss = torch.add(output2.wait(), t1) - - dist_autograd.backward(context_id, [loss.sum()]) - dist_optim.step(context_id) - - new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait() - new_w2 = 
rpc_async_method(MyModule.get_w, remote_module2).wait() - - # ensure optimizer changed weights - self.assertNotEqual(old_w1, new_w1) - self.assertNotEqual(old_w2, new_w2) - # ensure local equals remote - self.assertEqual(new_w1, module1.get_w()) - self.assertEqual(new_w2, module2.get_w()) + @dist_init() + def test_dist_optim(self): + self._test_dist_optim_base(optim.Adagrad, lr=0.05) + self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True) + self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True) + self._test_dist_optim_base(optim.SGD, lr=0.05) + self._test_dist_optim_base(optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True) + self._test_dist_optim_base(optim.Adadelta, rho=0.95) + self._test_dist_optim_base(optim.RMSprop, lr=0.05) diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index 657cb5143ccf..297b776a2c6d 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -686,3 +686,8 @@ def warmup_backward(f, *args): f.backward(retain_graph=True) return results + +# TODO: Remove me once https://bugs.python.org/issue42666 is resolved +def make_global(*args): + for arg in args: + setattr(sys.modules[arg.__module__], arg.__name__, arg) diff --git a/torch/testing/_internal/mypy_wrapper.py b/torch/testing/_internal/mypy_wrapper.py new file mode 100755 index 000000000000..45e46ac16e7b --- /dev/null +++ b/torch/testing/_internal/mypy_wrapper.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 + +""" +This module serves two purposes: + +- it holds the config_files function, which defines the set of subtests + for the test_run_mypy test in test/test_type_hints.py +- it can be run as a script (see the docstring of main below) and passed + the filename of any Python file in this repo, to typecheck that file + using only the subset of our mypy configs that apply to it + +Since editors (e.g. VS Code) can be configured to use this wrapper +script in lieu of mypy itself, the idea is that this can be used to get +inline mypy results while developing, and have at least some degree of +assurance that those inline results match up with what you would get +from running the TestTypeHints test suite in CI. + +See also these wiki pages: + +- https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch +- https://github.com/pytorch/pytorch/wiki/Lint-as-you-type +""" + +import fnmatch +import re +import sys +from configparser import ConfigParser +from itertools import chain +from pathlib import Path, PurePath, PurePosixPath +from typing import List, Set + +# don't import any files that live in the PyTorch repo, since this is +# meant to work as a standalone script + +try: + import mypy.api +except ImportError: + # let test/test_type_hints.py import this even if mypy is absent + pass + + +def config_files() -> Set[str]: + """ + Return a set of the names of all the PyTorch mypy config files. + """ + return { + 'mypy.ini', + 'mypy-strict.ini', + } + + +def glob(*, pattern: str, filename: PurePosixPath) -> bool: + """ + Return True iff the filename matches the (mypy ini) glob pattern. + """ + return any( + fnmatch.fnmatchcase(str(prefix), pattern) + for prefix in chain([filename], filename.parents) + ) + + +def in_files(*, ini: str, py: str) -> bool: + """ + Return True iff the py file is included in the ini file's "files". 
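A standalone sketch of the glob() matching rule defined above: a mypy "files" pattern applies to a file if the file itself or any of its parent directories matches the pattern. The path and patterns below are illustrative:

import fnmatch
from itertools import chain
from pathlib import PurePosixPath

filename = PurePosixPath('torch/optim/functional.py')        # illustrative repo-relative path
print(any(fnmatch.fnmatchcase(str(p), 'torch/optim')
          for p in chain([filename], filename.parents)))      # True: a parent directory matches
print(any(fnmatch.fnmatchcase(str(p), 'caffe2/*')
          for p in chain([filename], filename.parents)))      # False: no prefix matches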
+ """ + config = ConfigParser() + repo_root = Path.cwd() + filename = PurePosixPath(PurePath(py).relative_to(repo_root).as_posix()) + config.read(repo_root / ini) + return any( + glob(pattern=pattern, filename=filename) + for pattern in re.split(r',\s*', config['mypy']['files'].strip()) + ) + + +def main(args: List[str]) -> None: + """ + Run mypy on one Python file using the correct config file(s). + + This function assumes the following preconditions hold: + + - the cwd is set to the root of this cloned repo + - args is a valid list of CLI arguments that could be passed to mypy + - last element of args is an absolute path to a file to typecheck + - all the other args are config flags for mypy, rather than files + + These assumptions hold, for instance, when mypy is run automatically + by VS Code's Python extension, so in your clone of this repository, + you could modify your .vscode/settings.json to look like this: + + { + "python.linting.enabled": true, + "python.linting.mypyEnabled": true, + "python.linting.mypyPath": + "${workspaceFolder}/torch/testing/_internal/mypy_wrapper.py" + } + + More generally, this should work for any editor sets the cwd to the + repo root, runs mypy on one file at a time via its absolute path, + and allows you to set the path to the mypy executable. + """ + if not args: + sys.exit('The PyTorch mypy wrapper must be passed exactly one file.') + configs = [f for f in config_files() if in_files(ini=f, py=args[-1])] + mypy_results = [ + mypy.api.run( + # insert right before args[-1] to avoid being overridden + # by existing flags in args[:-1] + args[:-1] + [ + # uniform, in case some configs set these and some don't + '--show-error-codes', + '--show-column-numbers', + # don't special-case the last line + '--no-error-summary', + f'--config-file={config}', + args[-1], + ] + ) + for config in configs + ] + mypy_issues = list(dict.fromkeys( # remove duplicates, retain order + item + # assume stderr is empty + # https://github.com/python/mypy/issues/1051 + for stdout, _, _ in mypy_results + for item in stdout.splitlines() + )) + for issue in mypy_issues: + print(issue) + # assume all mypy exit codes are nonnegative + # https://github.com/python/mypy/issues/6003 + sys.exit(max( + [exit_code for _, _, exit_code in mypy_results], + default=0, + )) + + +if __name__ == '__main__': + main(sys.argv[1:])