diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 48685938a146..b9eda365de0c 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -d72e90640ec8514e0369b5419d7f3b74a387b1d7 +deba056203d009fec6b58afb9fa211f6ee3328c8 diff --git a/.github/ci_commit_pins/xla.txt b/.github/ci_commit_pins/xla.txt index 957272e8578b..6e29f8ee3c31 100644 --- a/.github/ci_commit_pins/xla.txt +++ b/.github/ci_commit_pins/xla.txt @@ -1 +1 @@ -08121e41079319cd369f82f523f5a714a0563f9d +dd9b67ff0d6ba4da6a46ca1b22e35c98dbed0d77 diff --git a/aten/src/ATen/InferSize.h b/aten/src/ATen/InferSize.h index 594b87373a20..111c7eb8f5fc 100644 --- a/aten/src/ATen/InferSize.h +++ b/aten/src/ATen/InferSize.h @@ -80,7 +80,7 @@ inline at::SymDimVector infer_size_dv( c10::SymInt numel) { auto res = at::SymDimVector(shape); infer_size_impl( - shape, numel, res); + shape, std::move(numel), res); return res; } diff --git a/aten/src/ATen/core/Formatting.cpp b/aten/src/ATen/core/Formatting.cpp index 875b9ef3d042..4537adff5aa4 100644 --- a/aten/src/ATen/core/Formatting.cpp +++ b/aten/src/ATen/core/Formatting.cpp @@ -13,7 +13,7 @@ std::ostream& operator<<(std::ostream & out, Backend b) { return out << toString(b); } -std::ostream& operator<<(std::ostream & out, Scalar s) { +std::ostream& operator<<(std::ostream & out, const Scalar& s) { if (s.isFloatingPoint()) { return out << s.toDouble(); } @@ -35,7 +35,7 @@ std::ostream& operator<<(std::ostream & out, Scalar s) { throw std::logic_error("Unknown type in Scalar"); } -std::string toString(Scalar s) { +std::string toString(const Scalar& s) { std::stringstream out; out << s; return out.str(); diff --git a/aten/src/ATen/core/Formatting.h b/aten/src/ATen/core/Formatting.h index 6dcfc6c7b3cd..9dcd14e1902e 100644 --- a/aten/src/ATen/core/Formatting.h +++ b/aten/src/ATen/core/Formatting.h @@ -8,8 +8,8 @@ namespace c10 { TORCH_API std::ostream& operator<<(std::ostream& out, Backend b); -TORCH_API std::ostream& operator<<(std::ostream & out, Scalar s); -TORCH_API std::string toString(Scalar s); +TORCH_API std::ostream& operator<<(std::ostream & out, const Scalar& s); +TORCH_API std::string toString(const Scalar& s); } namespace at { diff --git a/aten/src/ATen/core/List_test.cpp b/aten/src/ATen/core/List_test.cpp index e16e26b6042e..f37f3c008493 100644 --- a/aten/src/ATen/core/List_test.cpp +++ b/aten/src/ATen/core/List_test.cpp @@ -1118,7 +1118,7 @@ TEST(ListTest, canAccessStringByReference) { List list({"one", "two"}); const auto& listRef = list; static_assert(std::is_same::value, - "const List acccess should be by const reference"); + "const List access should be by const reference"); std::string str = list[1]; const std::string& strRef = listRef[1]; EXPECT_EQ("two", str); @@ -1130,7 +1130,7 @@ TEST(ListTest, canAccessOptionalStringByReference) { const auto& listRef = list; static_assert( std::is_same>>::value, - "List> acccess should be by const reference"); + "List> access should be by const reference"); c10::optional str1 = list[1]; c10::optional str2 = list[2]; decltype(auto) strRef1 = listRef[1]; diff --git a/aten/src/ATen/core/PythonFallbackKernel.cpp b/aten/src/ATen/core/PythonFallbackKernel.cpp index e16874a83f96..2d8834afe59e 100644 --- a/aten/src/ATen/core/PythonFallbackKernel.cpp +++ b/aten/src/ATen/core/PythonFallbackKernel.cpp @@ -74,10 +74,13 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { (*interpreter)->dispatch(op, stack); return; } - } else if (ivalue.isTensorList() 
|| (ivalue.isOptionalTensorList() && !ivalue.isNone())) { + } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) { // NB: use toListRef as it doesn't induce refcount bumps (toTensorListRef // is not a thing) for (const auto& nv : ivalue.toListRef()) { + if (nv.isNone()) { + continue; + } auto* interpreter = nv.unsafeToTensorImpl()->pyobj_interpreter(); if (interpreter) { (*interpreter)->dispatch(op, stack); diff --git a/aten/src/ATen/core/class_type.cpp b/aten/src/ATen/core/class_type.cpp index 9d7b38d4d67b..2478bde034bc 100644 --- a/aten/src/ATen/core/class_type.cpp +++ b/aten/src/ATen/core/class_type.cpp @@ -86,7 +86,7 @@ std::string ClassType::getForwardPreHookErrorMessage(int pre_hook_idx) const { std::string pre_hook_schema = pre_hook_name + "(self, input: Tuple[" + input_types + "])"; std::string return_string = - "This error occured while scripting the forward pre-hook '" + + "This error occurred while scripting the forward pre-hook '" + pre_hook_name + "' on module '" + name()->name() + "'. If you did not want to script this pre-hook remove it from the " "original NN module before scripting. Pre-hooks for module '" + @@ -111,7 +111,7 @@ std::string ClassType::getForwardHookErrorMessage(int hook_idx) const { std::string hook_schema = hook_name + "(self, input: Tuple[" + input_types + "], output: " + output_types + ")"; std::string return_string = - "This error occured while scripting the forward hook '" + "This error occurred while scripting the forward hook '" + hook_name + "' on module " + name()->name() + ". If you did not want to script this hook remove it from" + " the original NN module before scripting. This hook was" + diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index 0cc36d590019..7737f4a0037c 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -1133,6 +1133,201 @@ inline Vectorized Vectorized::le(const Vectorized& other return (*this <= other) & Vectorized(1); } +template +Vectorized inline shift_256_16(const Vectorized& a, const Vectorized& b) { + // No vector instruction for shifting int16_t, so emulating it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 16-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. + __m256i ctl_0_1 = _mm256_set_epi8(29, 28, 0x80, 0x80, 25, 24, 0x80, 0x80, + 21, 20, 0x80, 0x80, 17, 16, 0x80, 0x80, + 13, 12, 0x80, 0x80, 9, 8, 0x80, 0x80, + 5, 4, 0x80, 0x80, 1, 0, 0x80, 0x80); + __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 31, 30, 0x80, 0x80, 27, 26, + 0x80, 0x80, 23, 22, 0x80, 0x80, 19, 18, + 0x80, 0x80, 15, 14, 0x80, 0x80, 11, 10, + 0x80, 0x80, 7, 6, 0x80, 0x80, 3, 2); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 16-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. 
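  // Editorial sketch (not part of this patch, helper name hypothetical): per
  // 16-bit lane, and assuming the left_shift case, the widen-shift-extract
  // sequence below is equivalent to
  //
  //   uint16_t emulate_shl_16(uint16_t a, uint16_t b) {
  //     uint32_t widened = static_cast<uint32_t>(a) << 16;   // value in the upper half
  //     uint32_t shifted = b < 32 ? widened << b : 0;        // sllv semantics: count >= 32 yields 0
  //     return static_cast<uint16_t>(shifted >> 16);         // == (a << b) & 0xFFFF
  //   }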
+ __m256i keep_0 = _mm256_set1_epi32(0xFFFF); + __m256i keep_1 = _mm256_set1_epi32(0xFFFF0000); + + // Take each 16-bit element with idx%2==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 16 + // bits will be proper result of shifting original 16-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_1); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. + __m256i a1 = _mm256_and_si256(a, keep_1); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + c1 = _mm256_and_si256(c1, keep_1); + + // Merge partial results into the final result. + __m256i c = _mm256_or_si256(c0, c1); + + return c; +} + +template +Vectorized inline shift_256_8(const Vectorized& a, const Vectorized& b) { + // No vector instruction for shifting int8_t, so emulating it instead. + + // Control masks for shuffle operation, treating 256 bits as an + // array of 8-bit elements, and considering quadruples of + // neighboring elements. Specifially, a mask named "ctl_M_N" (M,N + // in [0,1,2,3], and M!=N) is set so that shuffle will move element + // with index M from input quadruple into element with index N in + // output quadruple, and other elements in output quadruple will be + // set to all 0s. 
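  // Editorial worked example (not from the patch): with the masks below, an
  // 8-bit element a = 0x05 at idx%4==0 and shift count b = 3 are handled as
  //   shuffle a into the top byte of its 32-bit lane  -> 0x05000000
  //   _mm256_sllv_epi32 by 3                          -> 0x28000000
  //   shuffle the top byte back to position 0         -> 0x28 == (0x05 << 3) & 0xFF
  // i.e. the same widen-shift-extract trick as shift_256_16 above, applied per byte.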
+ __m256i ctl_0_3 = _mm256_set_epi8(28, 0x80, 0x80, 0x80, 24, 0x80, 0x80, 0x80, + 20, 0x80, 0x80, 0x80, 16, 0x80, 0x80, 0x80, + 12, 0x80, 0x80, 0x80, 8, 0x80, 0x80, 0x80, + 4, 0x80, 0x80, 0x80, 0, 0x80, 0x80, 0x80); + __m256i ctl_1_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 29, 0x80, 0x80, 0x80, 25, + 0x80, 0x80, 0x80, 21, 0x80, 0x80, 0x80, 17, + 0x80, 0x80, 0x80, 13, 0x80, 0x80, 0x80, 9, + 0x80, 0x80, 0x80, 5, 0x80, 0x80, 0x80, 1); + __m256i ctl_1_3 = _mm256_set_epi8(29, 0x80, 0x80, 0x80, 25, 0x80, 0x80, 0x80, + 21, 0x80, 0x80, 0x80, 17, 0x80, 0x80, 0x80, + 13, 0x80, 0x80, 0x80, 9, 0x80, 0x80, 0x80, + 5, 0x80, 0x80, 0x80, 1, 0x80, 0x80, 0x80); + __m256i ctl_2_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 30, 0x80, 0x80, 0x80, 26, + 0x80, 0x80, 0x80, 22, 0x80, 0x80, 0x80, 18, + 0x80, 0x80, 0x80, 14, 0x80, 0x80, 0x80, 10, + 0x80, 0x80, 0x80, 6, 0x80, 0x80, 0x80, 2); + __m256i ctl_2_3 = _mm256_set_epi8(30, 0x80, 0x80, 0x80, 26, 0x80, 0x80, 0x80, + 22, 0x80, 0x80, 0x80, 18, 0x80, 0x80, 0x80, + 14, 0x80, 0x80, 0x80, 10, 0x80, 0x80, 0x80, + 6, 0x80, 0x80, 0x80, 2, 0x80, 0x80, 0x80); + __m256i ctl_3_0 = _mm256_set_epi8(0x80, 0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, + 0x80, 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, + 0x80, 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, + 0x80, 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3); + __m256i ctl_3_1 = _mm256_set_epi8(0x80, 0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, + 0x80, 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, + 0x80, 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, + 0x80, 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80); + __m256i ctl_3_2 = _mm256_set_epi8(0x80, 31, 0x80, 0x80, 0x80, 27, 0x80, 0x80, + 0x80, 23, 0x80, 0x80, 0x80, 19, 0x80, 0x80, + 0x80, 15, 0x80, 0x80, 0x80, 11, 0x80, 0x80, + 0x80, 7, 0x80, 0x80, 0x80, 3, 0x80, 0x80); + + // Masks for bitwise and operation, treating 256 bits as an array of + // 8-bit elements, and considering them in quadruples of neighboring + // elements. A mask named "keep_M" (M in [0,1,2,3]) is set so that + // bitwise and will copy element with index M from input quadruple + // into element with the same index in output quadruple, while the + // other elements in output quadruple will be set to all 0s. + __m256i keep_0 = _mm256_set1_epi32(0xFF); + __m256i keep_3 = _mm256_set1_epi32(0xFF000000); + + // Take each 8-bit element with idx%4==0 from input array to be + // shifted and extend it to 32 bits so that 0s are added to the + // right. Then, perform shifting on this 32-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%4!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 32 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m256i a0 = _mm256_shuffle_epi8(a, ctl_0_3); + __m256i b0 = _mm256_and_si256(b, keep_0); + __m256i c0; + if (left_shift) + c0 = _mm256_sllv_epi32(a0, b0); + c0 = _mm256_shuffle_epi8(c0, ctl_3_0); + + // Peform shifting the same way for input array elements with + // idx%4==1. 
+ __m256i a1 = _mm256_shuffle_epi8(a, ctl_1_3); + __m256i b1 = _mm256_shuffle_epi8(b, ctl_1_0); + __m256i c1; + if (left_shift) + c1 = _mm256_sllv_epi32(a1, b1); + c1 = _mm256_shuffle_epi8(c1, ctl_3_1); + + // Peform shifting the same way for input array elements with + // idx%4==2. + __m256i a2 = _mm256_shuffle_epi8(a, ctl_2_3); + __m256i b2 = _mm256_shuffle_epi8(b, ctl_2_0); + __m256i c2; + if (left_shift) + c2 = _mm256_sllv_epi32(a2, b2); + c2 = _mm256_shuffle_epi8(c2, ctl_3_2); + + // Peform shifting the same way for input array elements with + // idx%4==3. + __m256i a3 = _mm256_and_si256(a, keep_3); + __m256i b3 = _mm256_shuffle_epi8(b, ctl_3_0); + __m256i c3; + if (left_shift) + c3 = _mm256_sllv_epi32(a3, b3); + c3 = _mm256_and_si256(c3, keep_3); + + // Merge partial results into the final result. + __m256i c01 = _mm256_or_si256(c0, c1); + __m256i c23 = _mm256_or_si256(c2, c3); + __m256i c = _mm256_or_si256(c01, c23); + + return c; +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm256_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm256_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_256_16(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_256_8(a, b); +} + #endif }}} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h index c2cbc0b1d7f9..590c3254e379 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_int.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -1163,6 +1163,99 @@ inline Vectorized Vectorized::le(const Vectorized& other return (*this <= other) & Vectorized(1); } +template +Vectorized inline shift_512_8(const Vectorized& a, const Vectorized& b) { + // No vector instruction for shifting int8_t, so emulating it instead. + + // Control masks for shuffle operation, treating 512 bits as an + // array of 8-bit elements, and considering pairs of neighboring + // elements. Specifially, a mask named "ctl_M_N" (M,N in [0,1], and + // M!=N) is set so that shuffle will move element with index M from + // input pair into element with index N in output pair, and element + // with index M in output pair will be set to all 0s. + __m512i ctl_0_1 = _mm512_set_epi8(62, 0x80, 60, 0x80, 58, 0x80, 56, 0x80, + 54, 0x80, 52, 0x80, 50, 0x80, 48, 0x80, + 46, 0x80, 44, 0x80, 42, 0x80, 40, 0x80, + 38, 0x80, 36, 0x80, 34, 0x80, 32, 0x80, + 30, 0x80, 28, 0x80, 26, 0x80, 24, 0x80, + 22, 0x80, 20, 0x80, 18, 0x80, 16, 0x80, + 14, 0x80, 12, 0x80, 10, 0x80, 8, 0x80, + 6, 0x80, 4, 0x80, 2, 0x80, 0, 0x80); + __m512i ctl_1_0 = _mm512_set_epi8(0x80, 63, 0x80, 61, 0x80, 59, 0x80, 57, + 0x80, 55, 0x80, 53, 0x80, 51, 0x80, 49, + 0x80, 47, 0x80, 45, 0x80, 43, 0x80, 41, + 0x80, 39, 0x80, 37, 0x80, 35, 0x80, 33, + 0x80, 31, 0x80, 29, 0x80, 27, 0x80, 25, + 0x80, 23, 0x80, 21, 0x80, 19, 0x80, 17, + 0x80, 15, 0x80, 13, 0x80, 11, 0x80, 9, + 0x80, 7, 0x80, 5, 0x80, 3, 0x80, 1); + + // Masks for bitwise and operation, treating 512 bits as an array of + // 8-bit elements, and considering them in pairs of neighboring + // elements. A mask named "keep_M" (M in [0,1]) is set so that + // bitwise and will copy element with index M from input pair into + // element with the same index in output pair, while the other + // element in output pair will be set to all 0s. 
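  // Editorial worked example (not from the patch): here each byte is widened to
  // 16 bits rather than 32, since AVX-512BW provides _mm512_sllv_epi16. For an
  // even-indexed element a = 0x81 with shift count b = 1:
  //   shuffle a into the upper byte of its 16-bit lane -> 0x8100
  //   _mm512_sllv_epi16 by 1                           -> 0x0200 (bit 15 falls off)
  //   move the upper byte back to its position         -> 0x02 == (0x81 << 1) & 0xFF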
+ __m512i keep_0 = _mm512_set1_epi16(0xFF); + __m512i keep_1 = _mm512_set1_epi16(0xFF00); + + // Take each 8-bit element with idx%2==0 from input array to be + // shifted and extend it to 16 bits so that 0s are added to the + // right. Then, perform shifting on this 16-bit number. Upper 8 + // bits will be proper result of shifting original 8-bit number, so + // write them to result array, into the same position from which + // corresponding input element is taken. Also, make sure that + // result array elements with idx%2!=0 are set to all 0s. + // + // Note that number of bits to shift for is extended to 16 bits by + // adding 0s to the left. That means this number is not properly + // sign-extended for negative values. However, number of bits to + // shift is treated as an unsigned integer by respective shift + // intrinsics anyway so if negative then either with or without + // proper sign extension, it will be interpreted as a number greater + // than 32, and the shifting result will be the same. + __m512i a0 = _mm512_shuffle_epi8(a, ctl_0_1); + __m512i b0 = _mm512_and_si512(b, keep_0); + __m512i c0; + if (left_shift) + c0 = _mm512_sllv_epi16(a0, b0); + c0 = _mm512_shuffle_epi8(c0, ctl_1_0); + + // Peform shifting the same way for input array elements with + // idx%2==1. + __m512i a1 = _mm512_and_si512(a, keep_1); + __m512i b1 = _mm512_shuffle_epi8(b, ctl_1_0); + __m512i c1; + if (left_shift) + c1 = _mm512_sllv_epi16(a1, b1); + c1 = _mm512_and_si512(c1, keep_1); + + // Merge partial results into the final result. + __m512i c = _mm512_or_si512(c0, c1); + + return c; +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm512_sllv_epi64(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm512_sllv_epi32(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return _mm512_sllv_epi16(a, b); +} + +template <> +Vectorized inline operator<<(const Vectorized& a, const Vectorized& b) { + return shift_512_8(a, b); +} + #endif }}} diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index b9b3745e99d5..f045437ac368 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -799,6 +799,13 @@ inline Vectorized operator~(const Vectorized& a) { return a ^ ones; } +template Vectorized inline operator<<(const Vectorized &a, const Vectorized &b) { + Vectorized c; + for (int i = 0; i != Vectorized::size(); i++) { + c[i] = a[i] << b[i]; + } + return c; +} template inline Vectorized& operator += (Vectorized& a, const Vectorized& b) { @@ -826,6 +833,12 @@ inline Vectorized& operator *= (Vectorized& a, const Vectorized& b) { return a; } +template +inline Vectorized& operator <<= (Vectorized& a, const Vectorized& b) { + a = a << b; + return a; +} + template inline Vectorized fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { return a * b + c; diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index b5e685dac65f..25e4c2b44fa9 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -82,7 +82,7 @@ void CUDAHooks::initCUDA() const { at::cuda::detail::init_p2p_access_cache(num_devices); #if AT_MAGMA_ENABLED() - TORCH_INTERNAL_ASSERT(magma_init_fn != nullptr, "Cannot initilaize magma, init routine not set"); + TORCH_INTERNAL_ASSERT(magma_init_fn != nullptr, "Cannot initialize magma, init routine not 
set"); magma_init_fn(); #endif } diff --git a/aten/src/ATen/cudnn/Descriptors.cpp b/aten/src/ATen/cudnn/Descriptors.cpp index f954bbf5623a..0e739a49bb33 100644 --- a/aten/src/ATen/cudnn/Descriptors.cpp +++ b/aten/src/ATen/cudnn/Descriptors.cpp @@ -164,7 +164,7 @@ void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_fo filter_format = CUDNN_TENSOR_NHWC; break; default: - TORCH_INTERNAL_ASSERT(false, "unsurpported memory_format for cuDNN filters"); + TORCH_INTERNAL_ASSERT(false, "unsupported memory_format for cuDNN filters"); } set(getDataType(t), (int) dim, size, filter_format); } diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index 5eecbedd93e7..fc51e9d74409 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -928,6 +928,11 @@ Tensor index_copy_decomp( return at::scatter(self, dim, index_, source); ; } +// Note [Fix vmap slice_scatter] +// registers a decomposition for `slice_scatter` that calls into `slice.src` +// *_scatter operators have some special semantics though, that we can't easily +// through a decomposition: slice_scatter's output needs to have the same +// size, size, strides and storage_offset as the input. Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src, int64_t dim, c10::optional start, c10::optional end, int64_t step) diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index a44f39c5bb2e..c6b82426d3bf 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -220,6 +220,18 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking) return at::metal::metal_copy_(self, src); } + // Exit early if self and src are views of the same data + const bool is_same_data = ( + self.is_alias_of(src) && + self.storage_offset() == src.storage_offset() && + self.strides().equals(src.strides()) && + self.sizes().equals(src.sizes()) && + self.scalar_type() == src.scalar_type() + ); + if (is_same_data) { + return self; + } + auto iter = TensorIteratorConfig() .add_output(self) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 8c5a6fc8f195..c21bc4b47531 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -877,7 +877,7 @@ std::vector> matrix_chain_order(TensorList tensors) { /** * @brief Recursively multiplies the tensors i...j using the given order * - * @param tensors matrices to multiply togther + * @param tensors matrices to multiply together * @param order optimal chain multiplication order from #matrix_chain_order * @param i index of first tensor to be multiplied * @param j index of last tensor to be multiplied diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 0acc3506cf51..e08e17af4d08 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -1053,13 +1053,13 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const optional ho if (onesided) { if (n_fft / 2 + 1 != fft_size) { std::ostringstream ss; - REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onsided=True, but got " << fft_size; + REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft / 2 + 1 when onesided=True, but got " << fft_size; AT_ERROR(ss.str()); } } else { if (n_fft != fft_size) { 
std::ostringstream ss; - REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onsided=False, but got " << fft_size; + REPR(ss) << ": expected the frequency dimension (3rd to the last) of the input tensor to match n_fft when onesided=False, but got " << fft_size; AT_ERROR(ss.str()); } } diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index deb9b949aa5d..c44f3a921afc 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -204,10 +205,11 @@ #include #endif +#include #include #include +#include #include -#include namespace at { namespace meta { @@ -416,7 +418,7 @@ Tensor& set_storage_meta__symint(Tensor& result, Storage storage, c10::SymInt st const auto itemsize = result.dtype().itemsize(); c10::SymInt size_bytes = at::detail::computeStorageNbytes( size, stride, itemsize, storage_offset); - storage.set_nbytes(size_bytes); + storage.set_nbytes(std::move(size_bytes)); } return result; } @@ -1196,6 +1198,8 @@ Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t return self.narrow(dim, start, length).clone(at::MemoryFormat::Contiguous); } +// Should just use narrow_copy_out, but this API is used internally at Meta: +// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){ auto output = at::empty_like(self); return narrow_copy_dense_cpu_out(self, dim, start, length, output); @@ -1205,9 +1209,10 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_ int64_t allDim = self.dim(); int64_t end = start+length; TORCH_CHECK(allDim > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); TORCH_CHECK(dim >= 0 && dim < allDim, "Dimension ", dim, " out of range. Expecting 0 <= dim < ", allDim, "."); - TORCH_CHECK(start >= 0 && length >= 0 && end <= self.size(dim), + TORCH_CHECK(start >= 0 && end <= self.size(dim), "Invalid range to narrow. range(start, start+length) must be a subset of range(0, ", self.size(dim), ").") Tensor indices = self._indices(); int64_t sparse_dim = self.sparse_dim(); @@ -1235,6 +1240,8 @@ Tensor narrow_copy_sparse(const Tensor& self, int64_t dim, int64_t start, int64_ return newTensor._coalesced_(self.is_coalesced()); } +// Should just use narrow_copy_out, but this API is used internally at Meta: +// https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561 Tensor& narrow_copy_dense_cpu_out( const Tensor& self, int64_t dim, int64_t start, int64_t length, Tensor& output ) { @@ -1318,22 +1325,24 @@ Tensor& narrow_copy_dense_cpu_out( Tensor narrow(const Tensor& self, int64_t dim, int64_t start, int64_t length) { TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); auto cur_size = self.size(dim); if (start != cur_size) { // start being the end is valid, but not a valid dim specification. 
start = maybe_wrap_dim(start, cur_size); } - TORCH_CHECK(length >= 0 && start <= cur_size - length, + TORCH_CHECK(start <= cur_size - length, "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ")."); return at::slice(self, dim, start, start + length, 1); } Tensor narrow_symint(const Tensor& self, int64_t dim, SymInt start, SymInt length) { TORCH_CHECK(self.dim() > 0, "narrow() cannot be applied to a 0-dim tensor."); + TORCH_CHECK(length >= 0, "narrow(): length must be non-negative."); auto cur_size = self.sym_size(dim); if (start != cur_size) { // start being the end is valid, but not a valid dim specification. start = maybe_wrap_dim(start, cur_size); } - TORCH_CHECK(length >= 0 && start <= cur_size - length, + TORCH_CHECK(start <= cur_size - length, "start (", start, ") + length (", length, ") exceeds dimension size (", cur_size, ")."); return at::slice_symint(self, dim, start, start + length, 1); } @@ -1565,7 +1574,7 @@ Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) { // // We need to do the checks here instead of in `native_functions.yaml` // to preserve backwards compatibility. - if (!self.is_xla() && !self.is_lazy() && !self.is_ipu()) { + if (!self.is_xla() && !self.is_lazy() && !self.is_ipu() && !at::isTensorSubclassLike(self)) { return self._reshape_alias_symint(shape, stride.value()); } else { return self.view_symint(shape); @@ -1581,7 +1590,7 @@ Tensor _reshape_copy_symint(const Tensor& self, c10::SymIntArrayRef proposed_sha c10::SymDimVector shape = infer_size_dv(proposed_shape, self.sym_numel()); if (self.is_mkldnn()) { - TORCH_CHECK(0, "_reshape_copy not implemented for mkldnn tesnors"); + TORCH_CHECK(0, "_reshape_copy not implemented for mkldnn tensors"); } if (self.is_contiguous()) { diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp index c5fa0210cd58..d367dbe01103 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp @@ -209,7 +209,7 @@ PackedLinearWeightQnnp::PackedLinearWeightQnnp( std::get(serialized); TORCH_CHECK( serialization_version <= SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION, - "Attemped to deserialize sparse qlinear packed params with an ", + "Attempted to deserialize sparse qlinear packed params with an ", "incompatible serialization version (", serialization_version, " > ", diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp index a430e8185451..64cab80790a9 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp @@ -45,7 +45,7 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( const auto cols_input = static_cast(input.size(input.dim() - 1)); TORCH_CHECK( cols_input == input_channels_, - "quantized_sparse_lienar: Input tensor's last and weight tensor's" + "quantized_sparse_linear: Input tensor's last and weight tensor's" " second dimension must match."); // On empty input, no output data will be generated, diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index a5dde16024ab..c2497a6949f1 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -314,10 +314,13 @@ void 
bitwise_xor_kernel(TensorIteratorBase& iter) { void lshift_kernel(TensorIteratorBase& iter) { AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() { - cpu_kernel(iter, - [](scalar_t a, scalar_t b) -> scalar_t { - return static_cast>(a) << b; - }); + cpu_kernel_vec(iter, + [](scalar_t a, scalar_t b) -> scalar_t { + return static_cast>(a) << b; + }, + [](Vectorized a, Vectorized b) { + return a << b; + }); }); } diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index c08c5d26b63c..426243392b6f 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -70,7 +70,7 @@ Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_see c10::optional device, c10::optional pin_memory) { // See [Note: hacky wrapper removal for TensorOptions] - TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); + TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory); AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support"); } diff --git a/aten/src/ATen/native/mkldnn/Prelu.cpp b/aten/src/ATen/native/mkldnn/Prelu.cpp index acc78211d83c..dc7d239da7b6 100644 --- a/aten/src/ATen/native/mkldnn/Prelu.cpp +++ b/aten/src/ATen/native/mkldnn/Prelu.cpp @@ -17,7 +17,7 @@ std::tuple mkldnn_prelu_backward(const Tensor& grad_output, cons }} -#else // AT_MKLDNN_EBABLED +#else // AT_MKLDNN_ENABLED #include #include @@ -76,4 +76,4 @@ std::tuple mkldnn_prelu_backward(const Tensor& grad_output, cons } }} -#endif // AT_MKLDNN_EBABLED +#endif // AT_MKLDNN_ENABLED diff --git a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp index 8444f9ca615b..58a7036bdd7e 100644 --- a/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp @@ -36,10 +36,10 @@ namespace { inline void check_inputs(const Tensor& qa, const Tensor& qb) { TORCH_CHECK( qa.qscheme() == kPerTensorAffine, - "Only per tensor quantization is suported in Add."); + "Only per tensor quantization is supported in Add."); TORCH_CHECK( qa.qscheme() == qb.qscheme(), - "Both inputs to Add must have the same quantization shceme."); + "Both inputs to Add must have the same quantization scheme."); TORCH_CHECK( qa.scalar_type() == qb.scalar_type(), "Add operands should have same data type."); diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 2cd7cd81b903..b6fa57b9e3ed 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -130,7 +130,7 @@ at::SmallVector MakeDeConvOutputShape( ", output padding: ", output_padding[idx], ", dilation: ", dilation[idx]) TORCH_CHECK(output_shape[idx + 2] < kReasonableMaxDim, - "Output dimension is beyound reasonable maximum for ", idx, + "Output dimension is beyond reasonable maximum for ", idx, " axis;" " kernel: ", kernel[idx], ", stride: ", stride[idx], diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index 2250e84ad7a6..9d2f1a96c31b 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -1,4 +1,5 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include #include @@ -444,7 +445,7 @@ c10::intrusive_ptr> PackedConvWeightsOnednn< exp_wgt.init(w_desc); exp_wgt.set_scale(wgt_scales); // Also for feed_from() 
exp_wgt.feed_from(wgt, transpose); // expect wgt to be in [OC IC KH KW] format - ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt); + ideep::tensor * packed_weight_p = new ideep::tensor(std::move(exp_wgt)); packed_weight_p->set_scale(wgt_scales); packed_weight_p->set_zero_point(wgt_zero_points); std::unique_ptr weight_ptr(packed_weight_p); diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index dda600e9b41c..36523bbd1b9b 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp @@ -23,6 +23,7 @@ #include #include +#include #include int register_linear_params(); @@ -249,7 +250,7 @@ c10::intrusive_ptr PackedLinearWeightsOnednn::prepack( dnnl::memory::data_type::u8); ideep::tensor exp_wgt(w_desc); exp_wgt.feed_from(wgt); - ideep::tensor * packed_weight_p = new ideep::tensor(exp_wgt); + ideep::tensor * packed_weight_p = new ideep::tensor(std::move(exp_wgt)); packed_weight_p->set_scale(wgt_scales); packed_weight_p->set_zero_point(wgt_zero_points); std::unique_ptr weight_ptr(packed_weight_p); diff --git a/aten/src/ATen/native/quantized/cpu/qmatmul.cpp b/aten/src/ATen/native/quantized/cpu/qmatmul.cpp index c1e5041a5734..4da714e0bcf0 100644 --- a/aten/src/ATen/native/quantized/cpu/qmatmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmatmul.cpp @@ -21,7 +21,7 @@ inline void check_inputs(const Tensor& qa, const Tensor& qb) { "MatMul operands should have same data type."); TORCH_CHECK( qa.qscheme() == kPerTensorAffine || qa.qscheme() == kPerTensorSymmetric, - "Only per-tensor quantization is suported in Matmul."); + "Only per-tensor quantization is supported in Matmul."); TORCH_CHECK( qa.qscheme() == qb.qscheme(), "Both inputs to Matmul must have the same quantization scheme."); @@ -45,7 +45,7 @@ Tensor qmatmul( " and ", b_num_dims, " provided)"); TORCH_CHECK( num_dims >= 2, - "Quantized Matmul currently only suports operands which are at least 2-dimensional. (", + "Quantized Matmul currently only supports operands which are at least 2-dimensional. (", num_dims, " provided)"); const int64_t m = qa.size(num_dims - 2); diff --git a/aten/src/ATen/native/quantized/cpu/qmul.cpp b/aten/src/ATen/native/quantized/cpu/qmul.cpp index 35d2139c6c14..aa6ad0e724f5 100644 --- a/aten/src/ATen/native/quantized/cpu/qmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmul.cpp @@ -40,7 +40,7 @@ inline void check_inputs(const Tensor& qa, const Tensor& qb) { TORCH_CHECK(qa.scalar_type() == qb.scalar_type(), "Mul operands should have same data type."); TORCH_CHECK(qa.qscheme() == qb.qscheme(), - "Both inputs to Mul must have the same quantization shceme."); + "Both inputs to Mul must have the same quantization scheme."); } // Note: out is assumed to be the same size as self and other. 
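The quantized Add/Mul/Matmul checks touched above all require per-tensor affine inputs with matching quantization scheme and dtype (and at least 2 dimensions for Matmul). As a rough illustration only — assuming ATen's quantize_per_tensor API and not part of this patch — inputs such as the following would satisfy those checks:

  #include <ATen/ATen.h>
  // Two per-tensor-affine quantized tensors with the same dtype and >= 2 dims,
  // as required by the check_inputs helpers above.
  at::Tensor qa = at::quantize_per_tensor(at::rand({4, 8}), /*scale=*/0.1, /*zero_point=*/0, at::kQUInt8);
  at::Tensor qb = at::quantize_per_tensor(at::rand({4, 8}), /*scale=*/0.2, /*zero_point=*/0, at::kQUInt8);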
@@ -314,7 +314,7 @@ class QMulScalarTensor final { static Tensor run(Tensor qa, Tensor b) { TORCH_CHECK(qa.qscheme() == kPerTensorAffine || qa.qscheme() == kPerTensorSymmetric, - "Only per tensor quantization is suported in Mul."); + "Only per tensor quantization is supported in Mul."); auto qc = at::empty_like(qa, qa.suggest_memory_format()); return _mul_scalar_out(qc, qa, b.item()); } diff --git a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp index d9abd8bcfc79..fbb46b4b0174 100644 --- a/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp +++ b/aten/src/ATen/native/quantized/cudnn/BinaryOps.cpp @@ -71,10 +71,10 @@ std::unordered_map> PackedConvWeightCudnn< int64_t groups, bool transpose) { // TODO: need to check out to implement groups for conv operator in Conv.cpp - TORCH_CHECK(groups == 1, "Quantized cudnn conv2d is currenty limited to groups = 1; received groups =", groups); + TORCH_CHECK(groups == 1, "Quantized cudnn conv2d is currently limited to groups = 1; received groups =", groups); TORCH_CHECK(weight.qscheme() == c10::kPerTensorAffine, "Unsupported qscheme: ", toString(weight.qscheme())); TORCH_CHECK( kSpatialDim == 2, // 1D is packed as 2d, hence we don't need other checks diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index 2bcbe00a8720..ef205c5673ae 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -129,7 +129,7 @@ void _validate_sparse_compressed_tensor_args_worker(const Tensor& compressed_ind // 3.1 TORCH_CHECK( static_cast(size.size()) == batch_ndim + base_ndim + dense_ndim, - "tensor dimensionality must be sum of batch, base, and dense dimensionalites (=", + "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=", batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size()); // For CSR/CSC formats, we define blocksize=(1, 1) so that checking @@ -380,7 +380,7 @@ DimVector _estimate_sparse_compressed_tensor_size( } TORCH_CHECK( static_cast(size.size()) == batch_ndim + base_ndim + dense_ndim, - "tensor dimensionality must be sum of batch, base, and dense dimensionalites (=", + "tensor dimensionality must be sum of batch, base, and dense dimensionalities (=", batch_ndim, " + ", base_ndim, " + ", dense_ndim, ") but got ", size.size()); return size; } @@ -559,13 +559,13 @@ Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocki "torch.copy_: expected shapes of self and src to match along dimension ", self_compressed_dim, " for ", self.layout(), " layout but the corresponding dimensions of self and src are ", - self_compressed_dims, " and ", src_compressed_dims, ", respecitvely."); + self_compressed_dims, " and ", src_compressed_dims, ", respectively."); } else { TORCH_CHECK(self_compressed_dims == src_compressed_dims, "torch.copy_: expected shapes of self and src to match along dimensions ", self_compressed_dim, " and ", src_compressed_dim, ", respectively, for ", self.layout(), " layout but the corresponding dimensions of self and src are ", - self_compressed_dims, " and ", src_compressed_dims, ", respecitvely."); + self_compressed_dims, " and ", src_compressed_dims, ", respectively."); } AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{}, @@ -576,7 +576,7 @@ Tensor& copy_sparse_compressed_(Tensor& self, const Tensor& src, bool non_blocki auto src_blocksize = 
DimVector(src_values.sizes().slice(src_values.dim()-2, 2)); TORCH_CHECK(self_blocksize == src_blocksize, "torch.copy_: copy of sparse compressed tensors having different block sizes is not supported.", - " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectivly."); + " self and src block sizes are ", self_blocksize, " and ", src_blocksize, ", respectively."); }); AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "copy_sparse_compressed_", [&]{ diff --git a/aten/src/ATen/native/vulkan/api/Adapter.cpp b/aten/src/ATen/native/vulkan/api/Adapter.cpp index 311648b6894e..176236611c1d 100644 --- a/aten/src/ATen/native/vulkan/api/Adapter.cpp +++ b/aten/src/ATen/native/vulkan/api/Adapter.cpp @@ -195,7 +195,7 @@ std::string get_device_type_str(const VkPhysicalDeviceType type) { case VK_PHYSICAL_DEVICE_TYPE_CPU: return "CPU"; default: - return "UNKOWN"; + return "UNKNOWN"; } } diff --git a/aten/src/ATen/native/vulkan/ops/Clone.cpp b/aten/src/ATen/native/vulkan/ops/Clone.cpp index de353a10cb93..2601d785ddb5 100644 --- a/aten/src/ATen/native/vulkan/ops/Clone.cpp +++ b/aten/src/ATen/native/vulkan/ops/Clone.cpp @@ -21,7 +21,7 @@ Tensor clone( TORCH_CHECK( (c10::MemoryFormat::Preserve == memory_format) || (c10::MemoryFormat::Contiguous == memory_format), - "Vulkan supports Preserve and Contiguous memory foramts"); + "Vulkan supports Preserve and Contiguous memory formats"); Tensor self; if (memory_format == MemoryFormat::Preserve) { diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 758f4396b5b1..198877e0313d 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -156,6 +156,7 @@ "hrnet_w18", # accuracy "lcnet_0500", # accuracy "levit_128", # levit_128 + "poolformer_m36", "rexnet_100", # accuracy "swin_base_patch4_window7_224", "twins_pcpvt_base", # time out diff --git a/benchmarks/dynamo/dist_util.py b/benchmarks/dynamo/dist_util.py index 9e2f086ca8b7..d0267cbca307 100644 --- a/benchmarks/dynamo/dist_util.py +++ b/benchmarks/dynamo/dist_util.py @@ -13,13 +13,16 @@ CheckpointImpl, ) from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy try: from .torchbench import setup_torchbench_cwd except ImportError: from torchbench import setup_torchbench_cwd +from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead +from transformers.models.t5.modeling_t5 import T5Block + def setup(rank, world_size): os.environ["MASTER_ADDR"] = "localhost" @@ -122,26 +125,22 @@ def check_fn(submodule): ) -# from transformers.models.t5.modeling_t5 import T5Block - MODEL_FSDP_WRAP = { - ToyModel: (MyModule,) - # TODO T5: (T5Block,) + "toy_model": (MyModule,), + "hf_Bert": (BertLayer, BertLMPredictionHead), + "hf_T5": (T5Block,), } -def apply_fsdp(model, use_checkpointing=False, use_wrap_policy=True): - blocks = MODEL_FSDP_WRAP[model.__class__] - +def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True): wrap_policy = None + blocks = MODEL_FSDP_WRAP[ + "toy_model" if model.__class__ is ToyModel else args.torchbench_model + ] if use_wrap_policy: - # transformer policy is really a generic policy that wraps modules of specified classes - wrap_policy = functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls=blocks - ) + wrap_policy = ModuleWrapPolicy(blocks) - model = FSDP(model, auto_wrap_policy=wrap_policy) + model = FSDP(model, 
auto_wrap_policy=wrap_policy, use_orig_params=True) if use_checkpointing: fsdp_checkpointing_base(model, blocks) - return model diff --git a/benchmarks/dynamo/distributed.py b/benchmarks/dynamo/distributed.py index c2db15563348..32e3b544d87d 100644 --- a/benchmarks/dynamo/distributed.py +++ b/benchmarks/dynamo/distributed.py @@ -50,6 +50,7 @@ def move_tensor(maybe_tensor): if args.fsdp: model = apply_fsdp( + args, model, use_checkpointing=args.fsdp_checkpoint, use_wrap_policy=args.fsdp_wrap, @@ -160,7 +161,9 @@ def experiment(fn, key, world_size, results): ) args = parser.parse_args() - model_name = "ToyModel" if args.toy_model else args.torchbench_model + model_name = args.torchbench_model + if args.toy_model: + model_name = "ToyModel" model, inputs = get_model(args) fn = partial(run_model, args, model, inputs) diff --git a/benchmarks/dynamo/test.py b/benchmarks/dynamo/test.py index 317e8e4ea50e..438218462030 100644 --- a/benchmarks/dynamo/test.py +++ b/benchmarks/dynamo/test.py @@ -5,8 +5,17 @@ from .torchbench import setup_torchbench_cwd, TorchBenchmarkRunner +try: + # fbcode only + from aiplatform.utils.sanitizer_status import is_asan_or_tsan +except ImportError: + + def is_asan_or_tsan(): + return False + class TestDynamoBenchmark(unittest.TestCase): + @unittest.skipIf(is_asan_or_tsan(), "ASAN/TSAN not supported") def test_benchmark_infra_runs(self) -> None: """ Basic smoke test that TorchBench runs. diff --git a/build_variables.bzl b/build_variables.bzl index e476341b9ac0..473ed1c1de1b 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1489,3 +1489,33 @@ aten_cuda_with_sort_by_key_source_list = [ aten_cuda_cu_with_sort_by_key_source_list = [ "aten/src/ATen/native/cuda/Unique.cu", ] + +# Followings are source code for xnnpack delegate + +xnnpack_delegate_serializer_header = [ + "torch/csrc/jit/backends/xnnpack/serialization/serializer.h", +] + +xnnpack_delegate_serializer_source_list = [ + "torch/csrc/jit/backends/xnnpack/serialization/serializer.cpp", +] + +xnnpack_delegate_core_source_list = [ + "torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp", +] + +xnnpack_delegate_core_header = [ + "torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.h", + "torch/csrc/jit/backends/xnnpack/executor/xnn_executor.h", +] + +xnnpack_backend_header = [ + "torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h", +] + xnnpack_delegate_core_header + +xnnpack_backend_source_list = [ + "torch/csrc/jit/backends/xnnpack/compiler/xnn_compiler.cpp", + "torch/csrc/jit/backends/xnnpack/xnnpack_backend_lib.cpp", + "torch/csrc/jit/backends/xnnpack/xnnpack_backend_preprocess.cpp", + "torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp", +] + xnnpack_delegate_core_source_list diff --git a/c10/core/Storage.h b/c10/core/Storage.h index a89a0039fdfe..09c5920b5649 100644 --- a/c10/core/Storage.h +++ b/c10/core/Storage.h @@ -76,7 +76,7 @@ struct C10_API Storage { } void set_nbytes(c10::SymInt size_bytes) const { - storage_impl_.get()->set_nbytes(size_bytes); + storage_impl_.get()->set_nbytes(std::move(size_bytes)); } bool resizable() const { diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index bbf080384253..1d80daed871a 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -112,7 +112,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { } void set_nbytes(c10::SymInt size_bytes) { - size_bytes_ = size_bytes; + size_bytes_ = std::move(size_bytes); } bool resizable() const { diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 
d2f3aafaad8b..fcec452821d7 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -85,9 +85,6 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode clone() { TORCH_CHECK(false, "NYI"); }; - virtual SymNode sym_int() { - TORCH_CHECK(false, "NYI"); - } virtual SymNode sym_float() { TORCH_CHECK(false, "NYI"); } diff --git a/c10/core/WrapDimMinimal.cpp b/c10/core/WrapDimMinimal.cpp index 6703f0638901..2375dc3ac5cf 100644 --- a/c10/core/WrapDimMinimal.cpp +++ b/c10/core/WrapDimMinimal.cpp @@ -14,7 +14,8 @@ T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) { "Dimension specified as ", dim, " but tensor has no dimensions"); - return c10::maybe_wrap_dim(dim, /*dim_post_expr=*/1, /*wrap_scalar=*/false); + return c10::maybe_wrap_dim( + std::move(dim), /*dim_post_expr=*/1, /*wrap_scalar=*/false); } T min = dim_post_expr * -1; diff --git a/c10/core/WrapDimMinimal.h b/c10/core/WrapDimMinimal.h index 0f5949f65082..dda01fbe18f0 100644 --- a/c10/core/WrapDimMinimal.h +++ b/c10/core/WrapDimMinimal.h @@ -38,7 +38,7 @@ inline c10::SymInt maybe_wrap_dim( c10::SymInt dim, c10::SymInt dim_post_expr, bool wrap_scalar = true) { - return _maybe_wrap_dim(dim, dim_post_expr, wrap_scalar); + return _maybe_wrap_dim(std::move(dim), std::move(dim_post_expr), wrap_scalar); } } // namespace c10 diff --git a/caffe2/perfkernels/batch_box_cox_avx2.cc b/caffe2/perfkernels/batch_box_cox_avx2.cc index 8b93293646db..6171b5bfd032 100644 --- a/caffe2/perfkernels/batch_box_cox_avx2.cc +++ b/caffe2/perfkernels/batch_box_cox_avx2.cc @@ -1,3 +1,4 @@ +#include #ifdef CAFFE2_PERF_USE_MKL #include #include @@ -5,30 +6,68 @@ #include "vectorizer.h" -#ifndef VECTORIZED_KERNEL +// Enable compiler vectorized version only if numerical consistency is not +// required between dev and opt versions - disabled for now +#ifndef FAST_VECTORIZED_KERNEL #define CPU_CAPABILITY_AVX2 #include namespace at::vec { +// Implements the vectorized version of std::max() operation, +// which DOESNOT propagates NaN for second argument template Vectorized max(const Vectorized& a, const Vectorized& b); -// Implements the vectorized version of std::max() operation, -// which DOESNOT propagates NaN for second argument template <> Vectorized max(const Vectorized& a, const Vectorized& b) { // std::max(NaN, nonNan) -> NaN return _mm256_max_pd(b, a); } - template <> Vectorized max(const Vectorized& a, const Vectorized& b) { // std::max(NaN, nonNan) -> NaN return _mm256_max_ps(b, a); } +// Implements recieprocal method based on newton-rapson method +// 1. user RCP approximiation +// 2. 
update with RCP = RCP * (2 - X * RCP) +template +Vectorized fast_recieprocal(const Vectorized& b); +template +scalar_t fast_recieprocal(scalar_t b); + +template<> +Vectorized fast_recieprocal(const Vectorized& b) { + auto minus2 = _mm256_set1_ps(-2.f); + auto rcp = _mm256_rcp_ps(b); + rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); + rcp = _mm256_mul_ps(rcp, _mm256_fnmsub_ps(rcp, b, minus2)); + return rcp; +} + +template <> +float fast_recieprocal(float b) { + auto minus2 = _mm_set_ss(-2.f); + auto b_reg = _mm_set_ss(b); + auto rcp = _mm_rcp_ss(b_reg); + rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); + rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); + return _mm_cvtss_f32(rcp); +} + +template<> +Vectorized fast_recieprocal(const Vectorized& b) { + return b.reciprocal(); +} + +template <> +double fast_recieprocal(double b) { + return 1./b; +} + } #endif @@ -45,14 +84,6 @@ template void PackV(const int N, const T* a, const int* ia, T* y); template void UnpackV(const int N, const T* a, T* y, const int* iy); -template -void Pow(const int N, const T* a, const T* b, T* y); -template -void Add(const int N, const T* a, const T* b, T* y); -template -void Div(const int N, const T* a, const T* b, T* y); -template -void Ln(const int N, const T* a, T* y); #define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ template <> \ @@ -72,29 +103,7 @@ DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) #undef DELEGATE_UNPACKV_FUNCTION -#define DELEGATE_SIMPLE_BINARY_FUNCTION(T, Funcname, OriginalFunc) \ - template <> \ - void Funcname(const int N, const T* a, const T* b, T* y) { \ - OriginalFunc(N, a, b, y); \ - } -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Pow, vsPow) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Pow, vdPow) -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd) -DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv) -DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv) -#undef DELEGATE_SIMPLE_BINARY_FUNCTION - -#define DELEGATE_SIMPLE_UNARY_FUNCTION(T, Funcname, OriginalFunc) \ - template <> \ - void Funcname(const int N, const T* a, T* y) { \ - OriginalFunc(N, a, y); \ - } -DELEGATE_SIMPLE_UNARY_FUNCTION(float, Ln, vsLn) -DELEGATE_SIMPLE_UNARY_FUNCTION(double, Ln, vdLn) -#undef DELEGATE_SIMPLE_UNARY_FUNCTION - -#ifndef VECTORIZED_KERNEL +#ifndef FAST_VECTORIZED_KERNEL template void box_cox_zero_lambda( size_t D, @@ -140,7 +149,7 @@ void box_cox_nonzero_lambda( auto sum = data + lambda2; auto max = at::vec::max(sum, k_eps_vec); auto lambda1 = Vec::loadu(lambda1_ptr + j); - auto lambda_over_1 = lambda1.reciprocal(); + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); auto pow = max.pow(lambda1); auto res = at::vec::fmsub(pow, lambda_over_1, lambda_over_1); res.store(out + j); @@ -148,7 +157,7 @@ void box_cox_nonzero_lambda( for ( ;j < D; ++j) { auto sum = data_ptr[j] + lambda2_ptr[j]; auto max = std::max(sum, k_eps); - auto lambda_over_1 = 1 / lambda1_ptr[j]; + auto lambda_over_1 = at::vec::fast_recieprocal(lambda1_ptr[j]); auto pow = std::pow(max, lambda1_ptr[j]); out[j] = pow * lambda_over_1 - lambda_over_1; } @@ -181,12 +190,16 @@ void box_cox_nonzero_lambda( FAST_MATH auto sum = data_ptr[j] + lambda2_ptr[j]; auto max = std::max(sum, k_eps); - auto lambda_over_1 = 1 / lambda1_ptr[j]; - auto pow = std::pow(max, lambda1_ptr[j]); + auto lamda1 = lambda1_ptr[j]; + auto lambda_over_1 = 1 / lamda1; + if constexpr (std::is_same::value) { + lambda_over_1 = lambda_over_1 * (T{2} - 
lambda_over_1 * lamda1); + lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); + } + auto pow = std::pow(max, lamda1); out[j] = pow * lambda_over_1 - lambda_over_1; } } - #endif template diff --git a/docs/source/data.rst b/docs/source/data.rst index de2d44920f57..b44096d10196 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -441,9 +441,6 @@ Example:: .. autoclass:: torch.utils.data.distributed.DistributedSampler -.. This module is experimental and should be private, adding it here for now -.. py:module:: torch.utils.data.communication - .. These modules are documented as part of torch/data listing them here for .. now until we have a clearer fix .. py:module:: torch.utils.data.datapipes diff --git a/functorch/.flake8 b/functorch/.flake8 deleted file mode 100644 index a6d73773e3b5..000000000000 --- a/functorch/.flake8 +++ /dev/null @@ -1,20 +0,0 @@ -[flake8] -select = B,C,E,F,P,T4,W,B9 -max-line-length = 120 -# C408 ignored because we like the dict keyword argument syntax -# E501 is not flexible enough, we're using B950 instead -ignore = - E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303, - # shebang has extra meaning in fbcode lints, so I think it's not worth trying - # to line this up with executable bit - EXE001, - # these ignores are from flake8-bugbear; please fix! - B007,B008, - # these ignores are from flake8-comprehensions; please fix! - C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415 -exclude = - ./.git, - ./benchmarks, - ./docs, - ./examples, - ./notebooks diff --git a/functorch/_src/compilers.py b/functorch/_src/compilers.py index 3f52fede57eb..55de63e5c344 100644 --- a/functorch/_src/compilers.py +++ b/functorch/_src/compilers.py @@ -19,6 +19,8 @@ draw_graph, min_cut_rematerialization_partition, ) +import torch.utils._pytree as pytree + # These canonicalizations are needed here (and not decompositions), as the ops @@ -113,6 +115,34 @@ def nop(fx_g: fx.GraphModule, _) -> Callable: """ return fx_g +class DebugInterpreter(fx.Interpreter): + def run_node(self, n): + # TODO: This will fail once we start caching in AOTAutograd + # again, because we need to remap SymInts to their new values + # in the presence of dynamism + r = super().run_node(n) + if 'val' in n.meta: + n_vals, n_spec = pytree.tree_flatten(n.meta['val']) + r_vals, r_spec = pytree.tree_flatten(r) + assert n_spec == r_spec, f"{n_spec} != {r_spec}" + assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" + for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals): + if not isinstance(rv, torch.Tensor): + continue + assert nv.size() == rv.size(), f"output {i}: {nv.size()} != {rv.size()}" + assert nv.dtype == rv.dtype, f"output {i}: {nv.dtype} != {rv.dtype}" + assert torch._prims_common.check_significant_strides(nv, rv), f"output {i}: {nv.stride()} != {rv.stride()}" + return r + + +@make_boxed_compiler +def debug_nop(fx_g: fx.GraphModule, _) -> Callable: + """ + Returns a (slow) interpreter over the FX graph module that also checks + various debugging properties (e.g., that tracing strides matched real + strides.) + """ + return DebugInterpreter(fx_g).run @make_boxed_compiler def simple_ts_compile(fx_g, _): diff --git a/functorch/packaging/build_wheel.sh b/functorch/packaging/build_wheel.sh deleted file mode 100644 index 074e7dde7714..000000000000 --- a/functorch/packaging/build_wheel.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -set -ex - -script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" -. 
"$script_dir/pkg_helpers.bash" - -export BUILD_TYPE=wheel -setup_env 0.2.0 -setup_wheel_python -pip_install numpy pyyaml future ninja -pip_install --upgrade setuptools -setup_pip_pytorch_version -python setup.py clean - -if [[ "$OSTYPE" == "msys" ]]; then - "$script_dir/windows/internal/vc_env_helper.bat" python setup.py bdist_wheel -else - python setup.py bdist_wheel -fi diff --git a/functorch/packaging/pkg_helpers.bash b/functorch/packaging/pkg_helpers.bash deleted file mode 100644 index 329891a07216..000000000000 --- a/functorch/packaging/pkg_helpers.bash +++ /dev/null @@ -1,414 +0,0 @@ -# A set of useful bash functions for common functionality we need to do in -# many build scripts - - -# Setup CUDA environment variables, based on CU_VERSION -# -# Inputs: -# CU_VERSION (cpu, cu92, cu100) -# NO_CUDA_PACKAGE (bool) -# BUILD_TYPE (conda, wheel) -# -# Outputs: -# VERSION_SUFFIX (e.g., "") -# PYTORCH_VERSION_SUFFIX (e.g., +cpu) -# WHEEL_DIR (e.g., cu100/) -# CUDA_HOME (e.g., /usr/local/cuda-9.2, respected by torch.utils.cpp_extension) -# FORCE_CUDA (respected by torchvision setup.py) -# NVCC_FLAGS (respected by torchvision setup.py) -# -# Precondition: CUDA versions are installed in their conventional locations in -# /usr/local/cuda-* -# -# NOTE: Why VERSION_SUFFIX versus PYTORCH_VERSION_SUFFIX? If you're building -# a package with CUDA on a platform we support CUDA on, VERSION_SUFFIX == -# PYTORCH_VERSION_SUFFIX and everyone is happy. However, if you are building a -# package with only CPU bits (e.g., torchaudio), then VERSION_SUFFIX is always -# empty, but PYTORCH_VERSION_SUFFIX is +cpu (because that's how you get a CPU -# version of a Python package. But that doesn't apply if you're on OS X, -# since the default CU_VERSION on OS X is cpu. -setup_cuda() { - - # First, compute version suffixes. 
By default, assume no version suffixes - export VERSION_SUFFIX="" - export PYTORCH_VERSION_SUFFIX="" - export WHEEL_DIR="" - # Wheel builds need suffixes (but not if they're on OS X, which never has suffix) - if [[ "$BUILD_TYPE" == "wheel" ]] && [[ "$(uname)" != Darwin ]]; then - export PYTORCH_VERSION_SUFFIX="+$CU_VERSION" - # Match the suffix scheme of pytorch, unless this package does not have - # CUDA builds (in which case, use default) - if [[ -z "$NO_CUDA_PACKAGE" ]]; then - export VERSION_SUFFIX="$PYTORCH_VERSION_SUFFIX" - export WHEEL_DIR="$CU_VERSION/" - fi - fi - - # Now work out the CUDA settings - case "$CU_VERSION" in - cu115) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.5" - else - export CUDA_HOME=/usr/local/cuda-11.5/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu113) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.3" - else - export CUDA_HOME=/usr/local/cuda-11.3/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu112) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.2" - else - export CUDA_HOME=/usr/local/cuda-11.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu111) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.1" - else - export CUDA_HOME=/usr/local/cuda-11.1/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" - ;; - cu110) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.0" - else - export CUDA_HOME=/usr/local/cuda-11.0/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0" - ;; - cu102) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.2" - else - export CUDA_HOME=/usr/local/cuda-10.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu101) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.1" - else - export CUDA_HOME=/usr/local/cuda-10.1/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu100) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v10.0" - else - export CUDA_HOME=/usr/local/cuda-10.0/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" - ;; - cu92) - if [[ "$OSTYPE" == "msys" ]]; then - export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v9.2" - else - export CUDA_HOME=/usr/local/cuda-9.2/ - fi - export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0" - ;; - cpu) - ;; - rocm*) - export FORCE_CUDA=1 - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - if [[ -n "$CUDA_HOME" ]]; then - # Adds nvcc binary to the search path so that CMake's `find_package(CUDA)` will pick the right one - export PATH="$CUDA_HOME/bin:$PATH" - export FORCE_CUDA=1 - fi -} - -# Populate build version if necessary, and add version suffix -# -# Inputs: -# BUILD_VERSION (e.g., 0.2.0 or empty) -# VERSION_SUFFIX (e.g., +cpu) -# -# Outputs: -# BUILD_VERSION (e.g., 0.2.0.dev20190807+cpu) -# -# Fill BUILD_VERSION if it doesn't exist already with a nightly string -# Usage: setup_build_version 0.2.0 -setup_build_version() { - if [[ -z "$BUILD_VERSION" ]]; then - 
export BUILD_VERSION="$1.dev$(date "+%Y%m%d")$VERSION_SUFFIX" - else - export BUILD_VERSION="$BUILD_VERSION$VERSION_SUFFIX" - fi - - # Set build version based on tag if on tag - if [[ -n "${CIRCLE_TAG}" ]]; then - # Strip tag - export BUILD_VERSION="$(echo "${CIRCLE_TAG}" | sed -e 's/^v//' -e 's/-.*$//')${VERSION_SUFFIX}" - fi -} - -# Set some useful variables for OS X, if applicable -setup_macos() { - if [[ "$(uname)" == Darwin ]]; then - export MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ - fi -} - - -# Top-level entry point for things every package will need to do -# -# Usage: setup_env 0.2.0 -setup_env() { - setup_cuda - setup_build_version "$1" - setup_macos -} - -# Function to retry functions that sometimes timeout or have flaky failures -retry () { - $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) -} - -# Inputs: -# PYTHON_VERSION (3.7, 3.8, 3.9) -# UNICODE_ABI (bool) -# -# Outputs: -# PATH modified to put correct Python version in PATH -# -# Precondition: If Linux, you are in a soumith/manylinux-cuda* Docker image -setup_wheel_python() { - if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then - eval "$(conda shell.bash hook)" - conda env remove -n "env$PYTHON_VERSION" || true - conda create ${CONDA_CHANNEL_FLAGS} -yn "env$PYTHON_VERSION" python="$PYTHON_VERSION" - conda activate "env$PYTHON_VERSION" - # Install libpng from Anaconda (defaults) - conda install ${CONDA_CHANNEL_FLAGS} libpng "jpeg<=9b" -y - else - # Install native CentOS libJPEG, freetype and GnuTLS - yum install -y libjpeg-turbo-devel freetype gnutls - case "$PYTHON_VERSION" in - 3.7) python_abi=cp37-cp37m ;; - 3.8) python_abi=cp38-cp38 ;; - 3.9) python_abi=cp39-cp39 ;; - 3.10) python_abi=cp310-cp310 ;; - *) - echo "Unrecognized PYTHON_VERSION=$PYTHON_VERSION" - exit 1 - ;; - esac - # Download all the dependencies required to compile image and video_reader - # extensions - - mkdir -p ext_libraries - pushd ext_libraries - popd - export PATH="/opt/python/$python_abi/bin:$(pwd)/ext_libraries/bin:$PATH" - fi -} - -# Install with pip a bit more robustly than the default -pip_install() { - retry pip install --progress-bar off "$@" -} - -# Install torch with pip, respecting PYTORCH_VERSION, and record the installed -# version into PYTORCH_VERSION, if applicable -setup_pip_pytorch_version() { - if [[ -z "$PYTORCH_VERSION" ]]; then - # Install latest prerelease version of torch, per our nightlies, consistent - # with the requested cuda version - pip_install --pre torch -f "https://download.pytorch.org/whl/nightly/${WHEEL_DIR}torch_nightly.html" - if [[ "$CUDA_VERSION" == "cpu" ]]; then - # CUDA and CPU are ABI compatible on the CPU-only parts, so strip - # in this case - export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//' | sed 's/+.\+//')" - else - export PYTORCH_VERSION="$(pip show torch | grep ^Version: | sed 's/Version: *//')" - fi - else - pip_install "torch==$PYTORCH_VERSION$PYTORCH_VERSION_SUFFIX" \ - -f "https://download.pytorch.org/whl/${CU_VERSION}/torch_stable.html" \ - -f "https://download.pytorch.org/whl/${UPLOAD_CHANNEL}/${CU_VERSION}/torch_${UPLOAD_CHANNEL}.html" - fi -} - -# Fill PYTORCH_VERSION with the latest conda nightly version, and -# CONDA_CHANNEL_FLAGS with appropriate flags to retrieve these versions -# -# You MUST have populated PYTORCH_VERSION_SUFFIX before hand. 
-setup_conda_pytorch_constraint() { - if [[ -z "$PYTORCH_VERSION" ]]; then - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly -c pytorch" - export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | \ - python -c "import os, sys, json, re; cuver = os.environ.get('CU_VERSION'); \ - cuver_1 = cuver.replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ - cuver_2 = (cuver[:-1] + '.' + cuver[-1]).replace('cu', 'cuda') if cuver != 'cpu' else cuver; \ - print(re.sub(r'\\+.*$', '', \ - [x['version'] for x in json.load(sys.stdin)['pytorch'] \ - if (x['platform'] == 'darwin' or cuver_1 in x['fn'] or cuver_2 in x['fn']) \ - and 'py' + os.environ['PYTHON_VERSION'] in x['fn']][-1]))")" - if [[ -z "$PYTORCH_VERSION" ]]; then - echo "PyTorch version auto detection failed" - echo "No package found for CU_VERSION=$CU_VERSION and PYTHON_VERSION=$PYTHON_VERSION" - exit 1 - fi - else - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch -c pytorch-${UPLOAD_CHANNEL}" - fi - if [[ "$CU_VERSION" == cpu ]]; then - export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==$PYTORCH_VERSION${PYTORCH_VERSION_SUFFIX}" - export CONDA_PYTORCH_CONSTRAINT="- pytorch==$PYTORCH_VERSION" - else - export CONDA_PYTORCH_BUILD_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" - export CONDA_PYTORCH_CONSTRAINT="- pytorch==${PYTORCH_VERSION}${PYTORCH_VERSION_SUFFIX}" - fi - if [[ "$OSTYPE" == msys && "$CU_VERSION" == cu92 ]]; then - export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c defaults -c numba/label/dev" - fi -} - -# Translate CUDA_VERSION into CUDA_CUDATOOLKIT_CONSTRAINT -setup_conda_cudatoolkit_constraint() { - export CONDA_BUILD_VARIANT="cuda" - if [[ "$(uname)" == Darwin ]]; then - export CONDA_BUILD_VARIANT="cpu" - else - case "$CU_VERSION" in - cu115) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.5,<11.6 # [not osx]" - ;; - cu113) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.3,<11.4 # [not osx]" - ;; - cu112) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.2,<11.3 # [not osx]" - ;; - cu111) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.1,<11.2 # [not osx]" - ;; - cu110) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.0,<11.1 # [not osx]" - ;; - cu102) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.2,<10.3 # [not osx]" - ;; - cu101) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.1,<10.2 # [not osx]" - ;; - cu100) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=10.0,<10.1 # [not osx]" - ;; - cu92) - export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=9.2,<9.3 # [not osx]" - ;; - cpu) - export CONDA_CUDATOOLKIT_CONSTRAINT="" - export CONDA_BUILD_VARIANT="cpu" - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - fi -} - -setup_conda_cudatoolkit_plain_constraint() { - export CONDA_BUILD_VARIANT="cuda" - export CMAKE_USE_CUDA=1 - if [[ "$(uname)" == Darwin ]]; then - export CONDA_BUILD_VARIANT="cpu" - export CMAKE_USE_CUDA=0 - else - case "$CU_VERSION" in - cu115) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.5" - ;; - cu113) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.3" - ;; - cu112) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.2" - ;; - cu111) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=11.1" - ;; - cu102) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.2" - ;; - cu101) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.1" - ;; - cu100) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=10.0" - ;; - 
cu92) - export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit=9.2" - ;; - cpu) - export CONDA_CUDATOOLKIT_CONSTRAINT="" - export CONDA_BUILD_VARIANT="cpu" - export CMAKE_USE_CUDA=0 - ;; - *) - echo "Unrecognized CU_VERSION=$CU_VERSION" - exit 1 - ;; - esac - fi -} - -# Build the proper compiler package before building the final package -setup_visual_studio_constraint() { - if [[ "$OSTYPE" == "msys" ]]; then - export VSTOOLCHAIN_PACKAGE=vs$VC_YEAR - conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload packaging/$VSTOOLCHAIN_PACKAGE - cp packaging/$VSTOOLCHAIN_PACKAGE/conda_build_config.yaml packaging/torchvision/conda_build_config.yaml - fi -} - -setup_junit_results_folder() { - if [[ "$CI" == "true" ]]; then - export CONDA_PYTORCH_BUILD_RESULTS_DIRECTORY="${SOURCE_ROOT_DIR}/build_results/results.xml" - fi -} - - -download_copy_ffmpeg() { - if [[ "$OSTYPE" == "msys" ]]; then - # conda install -yq ffmpeg=4.2 -c pytorch - # curl -L -q https://anaconda.org/pytorch/ffmpeg/4.3/download/win-64/ffmpeg-4.3-ha925a31_0.tar.bz2 --output ffmpeg-4.3-ha925a31_0.tar.bz2 - # bzip2 --decompress --stdout ffmpeg-4.3-ha925a31_0.tar.bz2 | tar -x --file=- - # cp Library/bin/*.dll ../torchvision - echo "FFmpeg is disabled currently on Windows" - else - if [[ "$(uname)" == Darwin ]]; then - conda install -yq ffmpeg=4.2 -c pytorch - conda install -yq wget - else - # pushd ext_libraries - # wget -q https://anaconda.org/pytorch/ffmpeg/4.2/download/linux-64/ffmpeg-4.2-hf484d3e_0.tar.bz2 - # tar -xjvf ffmpeg-4.2-hf484d3e_0.tar.bz2 - # rm -rf ffmpeg-4.2-hf484d3e_0.tar.bz2 - # ldconfig - # which ffmpeg - # popd - echo "FFmpeg is disabled currently on Linux" - fi - fi -} diff --git a/functorch/packaging/windows/internal/cuda_install.bat b/functorch/packaging/windows/internal/cuda_install.bat deleted file mode 100644 index 41960224ebae..000000000000 --- a/functorch/packaging/windows/internal/cuda_install.bat +++ /dev/null @@ -1,264 +0,0 @@ -@echo on - -if "%CU_VERSION%" == "cpu" ( - echo Skipping for CPU builds - exit /b 0 -) - -set SRC_DIR=%~dp0\.. 
- -if not exist "%SRC_DIR%\temp_build" mkdir "%SRC_DIR%\temp_build" - -rem in unit test workflow, we get CUDA_VERSION, for example 11.1 -if defined CUDA_VERSION ( - set CUDA_VER=%CUDA_VERSION:.=% -) else ( - set CUDA_VER=%CU_VERSION:cu=% -) - -set /a CUDA_VER=%CU_VERSION:cu=% -set CUDA_VER_MAJOR=%CUDA_VER:~0,-1% -set CUDA_VER_MINOR=%CUDA_VER:~-1,1% -set CUDA_VERSION_STR=%CUDA_VER_MAJOR%.%CUDA_VER_MINOR% - - -if %CUDA_VER% EQU 92 goto cuda92 -if %CUDA_VER% EQU 100 goto cuda100 -if %CUDA_VER% EQU 101 goto cuda101 -if %CUDA_VER% EQU 102 goto cuda102 -if %CUDA_VER% EQU 110 goto cuda110 -if %CUDA_VER% EQU 111 goto cuda111 -if %CUDA_VER% EQU 112 goto cuda112 -if %CUDA_VER% EQU 113 goto cuda113 -if %CUDA_VER% EQU 115 goto cuda115 - - -echo CUDA %CUDA_VERSION_STR% is not supported -exit /b 1 - -:cuda92 -if not exist "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_9.2.148_win10.exe --output "%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_9.2.148_win10.exe" - set "ARGS=nvcc_9.2 cuobjdump_9.2 nvprune_9.2 cupti_9.2 cublas_9.2 cublas_dev_9.2 cudart_9.2 cufft_9.2 cufft_dev_9.2 curand_9.2 curand_dev_9.2 cusolver_9.2 cusolver_dev_9.2 cusparse_9.2 cusparse_dev_9.2 nvgraph_9.2 nvgraph_dev_9.2 npp_9.2 npp_dev_9.2 nvrtc_9.2 nvrtc_dev_9.2 nvml_dev_9.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-9.2-windows10-x64-v7.2.1.38.zip --output "%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-9.2-windows10-x64-v7.2.1.38.zip" -) - -goto cuda_common - -:cuda100 - -if not exist "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cuda_10.0.130_411.31_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.0.130_411.31_win10.exe" - set "ARGS=nvcc_10.0 cuobjdump_10.0 nvprune_10.0 cupti_10.0 cublas_10.0 cublas_dev_10.0 cudart_10.0 cufft_10.0 cufft_dev_10.0 curand_10.0 curand_dev_10.0 cusolver_10.0 cusolver_dev_10.0 cusparse_10.0 cusparse_dev_10.0 nvgraph_10.0 nvgraph_dev_10.0 npp_10.0 npp_dev_10.0 nvrtc_10.0 nvrtc_dev_10.0 nvml_dev_10.0" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/win2016/cudnn-10.0-windows10-x64-v7.4.1.5.zip --output "%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.0-windows10-x64-v7.4.1.5.zip" -) - -goto cuda_common - -:cuda101 - -if not exist "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.1.243_426.00_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.1.243_426.00_win10.exe" - set "ARGS=nvcc_10.1 cuobjdump_10.1 nvprune_10.1 cupti_10.1 cublas_10.1 cublas_dev_10.1 cudart_10.1 cufft_10.1 cufft_dev_10.1 curand_10.1 curand_dev_10.1 cusolver_10.1 cusolver_dev_10.1 cusparse_10.1 cusparse_dev_10.1 nvgraph_10.1 nvgraph_dev_10.1 npp_10.1 npp_dev_10.1 nvjpeg_10.1 nvjpeg_dev_10.1 nvrtc_10.1 nvrtc_dev_10.1 nvml_dev_10.1" -) - -if not exist 
"%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.1-windows10-x64-v7.6.4.38.zip --output "%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.1-windows10-x64-v7.6.4.38.zip" -) - -goto cuda_common - -:cuda102 - -if not exist "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_10.2.89_441.22_win10.exe --output "%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_10.2.89_441.22_win10.exe" - set "ARGS=nvcc_10.2 cuobjdump_10.2 nvprune_10.2 cupti_10.2 cublas_10.2 cublas_dev_10.2 cudart_10.2 cufft_10.2 cufft_dev_10.2 curand_10.2 curand_dev_10.2 cusolver_10.2 cusolver_dev_10.2 cusparse_10.2 cusparse_dev_10.2 nvgraph_10.2 nvgraph_dev_10.2 npp_10.2 npp_dev_10.2 nvjpeg_10.2 nvjpeg_dev_10.2 nvrtc_10.2 nvrtc_dev_10.2 nvml_dev_10.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-10.2-windows10-x64-v7.6.5.32.zip --output "%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-10.2-windows10-x64-v7.6.5.32.zip" -) - -rem The below only for cu102, if it's used in other version, e.g. cu111, torch.cuda.is_availabe() would be False. -if not exist "%SRC_DIR%\temp_build\gpu_driver_dlls.7z" ( - curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "%SRC_DIR%\temp_build\gpu_driver_dlls.zip" - if errorlevel 1 exit /b 1 -) - -echo Installing GPU driver DLLs -7z x %SRC_DIR%\temp_build\gpu_driver_dlls.zip -aoa -o"C:\Windows\System32" - -goto cuda_common - -:cuda110 - -if not exist "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.0.2_451.48_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.0.2_451.48_win10.exe" - set "ARGS=nvcc_11.0 cuobjdump_11.0 nvprune_11.0 nvprof_11.0 cupti_11.0 cublas_11.0 cublas_dev_11.0 cudart_11.0 cufft_11.0 cufft_dev_11.0 curand_11.0 curand_dev_11.0 cusolver_11.0 cusolver_dev_11.0 cusparse_11.0 cusparse_dev_11.0 npp_11.0 npp_dev_11.0 nvjpeg_11.0 nvjpeg_dev_11.0 nvrtc_11.0 nvrtc_dev_11.0 nvml_dev_11.0" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.0-windows-x64-v8.0.4.30.zip --output "%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.0-windows-x64-v8.0.4.30.zip" -) - -goto cuda_common - -:cuda111 - -if not exist "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.1.1_456.81_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.1.1_456.81_win10.exe" - set "ARGS=nvcc_11.1 cuobjdump_11.1 nvprune_11.1 nvprof_11.1 cupti_11.1 cublas_11.1 cublas_dev_11.1 cudart_11.1 cufft_11.1 cufft_dev_11.1 curand_11.1 curand_dev_11.1 cusolver_11.1 cusolver_dev_11.1 cusparse_11.1 cusparse_dev_11.1 npp_11.1 npp_dev_11.1 nvjpeg_11.1 nvjpeg_dev_11.1 nvrtc_11.1 nvrtc_dev_11.1 nvml_dev_11.1" -) 
- -if not exist "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cudnn-11.1-windows-x64-v8.0.5.39.zip --output "%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.1-windows-x64-v8.0.5.39.zip" -) - -goto cuda_common - -:cuda112 - -if not exist "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" ( - curl -k -L https://ossci-windows.s3.amazonaws.com/cuda_11.2.0_460.89_win10.exe --output "%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\cuda_11.2.0_460.89_win10.exe" - set "ARGS=nvcc_11.2 cuobjdump_11.2 nvprune_11.2 nvprof_11.2 cupti_11.2 cublas_11.2 cublas_dev_11.2 cudart_11.2 cufft_11.2 cufft_dev_11.2 curand_11.2 curand_dev_11.2 cusolver_11.2 cusolver_dev_11.2 cusparse_11.2 cusparse_dev_11.2 npp_11.2 npp_dev_11.2 nvjpeg_11.2 nvjpeg_dev_11.2 nvrtc_11.2 nvrtc_dev_11.2 nvml_dev_11.2" -) - -if not exist "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" ( - curl -k -L http://s3.amazonaws.com/ossci-windows/cudnn-11.2-windows-x64-v8.1.0.77.zip --output "%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\cudnn-11.2-windows-x64-v8.1.0.77.zip" -) - -goto cuda_common - -:cuda113 - -set CUDA_INSTALL_EXE=cuda_11.3.0_465.89_win10.exe -if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( - curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=thrust_11.3 nvcc_11.3 cuobjdump_11.3 nvprune_11.3 nvprof_11.3 cupti_11.3 cublas_11.3 cublas_dev_11.3 cudart_11.3 cufft_11.3 cufft_dev_11.3 curand_11.3 curand_dev_11.3 cusolver_11.3 cusolver_dev_11.3 cusparse_11.3 cusparse_dev_11.3 npp_11.3 npp_dev_11.3 nvjpeg_11.3 nvjpeg_dev_11.3 nvrtc_11.3 nvrtc_dev_11.3 nvml_dev_11.3" - -) - -set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip -if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( - curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" -) - -goto cuda_common - -:cuda115 - -set CUDA_INSTALL_EXE=cuda_11.5.0_496.13_win10.exe -if not exist "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" ( - curl -k -L "https://ossci-windows.s3.amazonaws.com/%CUDA_INSTALL_EXE%" --output "%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - if errorlevel 1 exit /b 1 - set "CUDA_SETUP_FILE=%SRC_DIR%\temp_build\%CUDA_INSTALL_EXE%" - set "ARGS=thrust_11.5 nvcc_11.5 cuobjdump_11.5 nvprune_11.5 nvprof_11.5 cupti_11.5 cublas_11.5 cublas_dev_11.5 cudart_11.5 cufft_11.5 cufft_dev_11.5 curand_11.5 curand_dev_11.5 cusolver_11.5 cusolver_dev_11.5 cusparse_11.5 cusparse_dev_11.5 npp_11.5 npp_dev_11.5 nvrtc_11.5 nvrtc_dev_11.5 nvml_dev_11.5" -) - -set CUDNN_INSTALL_ZIP=cudnn-11.3-windows-x64-v8.2.0.53.zip -if not exist "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" ( - curl -k -L "http://s3.amazonaws.com/ossci-windows/%CUDNN_INSTALL_ZIP%" --output "%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" - if errorlevel 1 exit /b 1 - set "CUDNN_SETUP_FILE=%SRC_DIR%\temp_build\%CUDNN_INSTALL_ZIP%" -) - -goto cuda_common - -:cuda_common - -if not exist "%SRC_DIR%\temp_build\NvToolsExt.7z" ( - curl -k -L 
https://www.dropbox.com/s/9mcolalfdj4n979/NvToolsExt.7z?dl=1 --output "%SRC_DIR%\temp_build\NvToolsExt.7z" - if errorlevel 1 exit /b 1 -) - -echo Installing CUDA toolkit... -7z x %CUDA_SETUP_FILE% -o"%SRC_DIR%\temp_build\cuda" -pushd "%SRC_DIR%\temp_build\cuda" -sc config wuauserv start= disabled -sc stop wuauserv -sc query wuauserv - -start /wait setup.exe -s %ARGS% -loglevel:6 -log:"%cd%/cuda_install_logs" -echo %errorlevel% - -popd - -echo Installing VS integration... -rem It's for VS 2019 -if "%CUDA_VER_MAJOR%" == "10" ( - xcopy /Y "%SRC_DIR%\temp_build\cuda\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" -) -if "%CUDA_VER_MAJOR%" == "11" ( - xcopy /Y "%SRC_DIR%\temp_build\cuda\visual_studio_integration\CUDAVisualStudioIntegration\extras\visual_studio_integration\MSBuildExtensions\*.*" "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations" -) - -echo Installing NvToolsExt... -7z x %SRC_DIR%\temp_build\NvToolsExt.7z -o"%SRC_DIR%\temp_build\NvToolsExt" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" -mkdir "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\bin\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\include\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\include" -xcopy /Y "%SRC_DIR%\temp_build\NvToolsExt\lib\x64\*.*" "%ProgramFiles%\NVIDIA Corporation\NvToolsExt\lib\x64" - -echo Setting up environment... -set "PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin;%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\libnvvp;%PATH%" -set "CUDA_PATH=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" -set "CUDA_PATH_V%CUDA_VER_MAJOR%_%CUDA_VER_MINOR%=%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%" -set "NVTOOLSEXT_PATH=%ProgramFiles%\NVIDIA Corporation\NvToolsExt\bin\x64" - -if not exist "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin\nvcc.exe" ( - echo CUDA %CUDA_VERSION_STR% installed failed. - echo --------- RunDll32.exe.log - type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.RunDll32.exe.log" - echo --------- setup.exe.log ------- - type "%SRC_DIR%\temp_build\cuda\cuda_install_logs\LOG.setup.exe.log" - exit /b 1 -) - -echo Installing cuDNN... 
-7z x %CUDNN_SETUP_FILE% -o"%SRC_DIR%\temp_build\cudnn" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\bin\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\bin" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\lib\x64\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\lib\x64" -xcopy /Y "%SRC_DIR%\temp_build\cudnn\cuda\include\*.*" "%ProgramFiles%\NVIDIA GPU Computing Toolkit\CUDA\v%CUDA_VERSION_STR%\include" - -echo Cleaning temp files -rd /s /q "%SRC_DIR%\temp_build" || ver > nul diff --git a/functorch/packaging/windows/internal/driver_update.bat b/functorch/packaging/windows/internal/driver_update.bat deleted file mode 100644 index 00b43affc01c..000000000000 --- a/functorch/packaging/windows/internal/driver_update.bat +++ /dev/null @@ -1,25 +0,0 @@ -set "DRIVER_DOWNLOAD_LINK=https://ossci-windows.s3.amazonaws.com/461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe" -curl --retry 3 -kL %DRIVER_DOWNLOAD_LINK% --output 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -if errorlevel 1 exit /b 1 - -start /wait 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe -s -noreboot -if errorlevel 1 exit /b 1 - -del 461.09-data-center-tesla-desktop-winserver-2019-2016-international.exe || ver > NUL - -setlocal EnableDelayedExpansion -set NVIDIA_GPU_EXISTS=0 -for /F "delims=" %%i in ('wmic path win32_VideoController get name') do ( - set GPUS=%%i - if not "x!GPUS:NVIDIA=!" == "x!GPUS!" ( - SET NVIDIA_GPU_EXISTS=1 - goto gpu_check_end - ) -) -:gpu_check_end -endlocal & set NVIDIA_GPU_EXISTS=%NVIDIA_GPU_EXISTS% - -if "%NVIDIA_GPU_EXISTS%" == "0" ( - echo "CUDA Driver installation Failed" - exit /b 1 -) diff --git a/functorch/packaging/windows/internal/vc_env_helper.bat b/functorch/packaging/windows/internal/vc_env_helper.bat deleted file mode 100644 index e85a372f93d5..000000000000 --- a/functorch/packaging/windows/internal/vc_env_helper.bat +++ /dev/null @@ -1,43 +0,0 @@ -@echo on - -set VC_VERSION_LOWER=16 -set VC_VERSION_UPPER=17 -if "%VC_YEAR%" == "2017" ( - set VC_VERSION_LOWER=15 - set VC_VERSION_UPPER=16 -) - -for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( - if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( - set "VS15INSTALLDIR=%%i" - set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" - goto vswhere - ) -) - -:vswhere -if "%VSDEVCMD_ARGS%" == "" ( - call "%VS15VCVARSALL%" x64 || exit /b 1 -) else ( - call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 -) - -@echo on - -set DISTUTILS_USE_SDK=1 - -set args=%1 -shift -:start -if [%1] == [] goto done -set args=%args% %1 -shift -goto start - -:done -if "%args%" == "" ( - echo Usage: vc_env_helper.bat [command] [args] - echo e.g. 
vc_env_helper.bat cl /c test.cpp -) - -%args% || exit /b 1 diff --git a/functorch/packaging/windows/internal/vc_install_helper.sh b/functorch/packaging/windows/internal/vc_install_helper.sh deleted file mode 100644 index cdae18065b9f..000000000000 --- a/functorch/packaging/windows/internal/vc_install_helper.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -set -ex - -if [[ "$CU_VERSION" == "cu92" ]]; then - export VC_YEAR=2017 - export VSDEVCMD_ARGS="-vcvars_ver=14.13" - powershell packaging/windows/internal/vs2017_install.ps1 -elif [[ "$CU_VERSION" == "cu100" ]]; then - export VC_YEAR=2017 - export VSDEVCMD_ARGS="" - powershell packaging/windows/internal/vs2017_install.ps1 -else - export VC_YEAR=2019 - export VSDEVCMD_ARGS="" -fi diff --git a/test/distributed/_composable/test_fully_shard.py b/test/distributed/_composable/test_fully_shard.py index 27e0fb855fba..ba08deeafcdf 100644 --- a/test/distributed/_composable/test_fully_shard.py +++ b/test/distributed/_composable/test_fully_shard.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: distributed"] import copy -import functools import sys from typing import Any, Tuple @@ -12,7 +11,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp._common_utils import _is_fsdp_flattened from torch.distributed.fsdp._runtime_utils import _root_pre_forward -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest from torch.testing._internal.common_utils import ( @@ -62,10 +61,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return z @staticmethod - def auto_wrap_policy(): - return functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls={SubModel} - ) + def policy(): + return ModuleWrapPolicy({SubModel}) def get_input(self, device=torch.device) -> Tuple[Any, ...]: return (torch.randn((8, 5), device=device),) @@ -85,13 +82,13 @@ def test_auto_wrap_policy(self): local_model = Model(device=torch.device("cuda")) fsdp_wrapped_model = FSDP( copy.deepcopy(local_model), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), use_orig_params=True, ) composable_module = copy.deepcopy(local_model) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), ) # Check that the composable module has the same names as the local @@ -138,7 +135,7 @@ def test_device_id(self): assert param.device == cpu_device fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), device_id=self.rank, ) for param in composable_module.parameters(): @@ -157,12 +154,12 @@ def test_sync_module_states(self): param.zero_() fsdp_wrapped_model = FSDP( copy.deepcopy(local_model), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), use_orig_params=True, ) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), sync_module_states=True, ) for (composable_param, fsdp_wrapped_param) in zip( @@ -197,13 +194,13 @@ def _param_init_fn(module: nn.Module): composable_module = Model(device="meta") fsdp_wrapped_model = FSDP( Model(device="meta"), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), param_init_fn=_param_init_fn, use_orig_params=True, ) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + 
policy=Model.policy(), param_init_fn=_param_init_fn, ) for (composable_param, fsdp_wrapped_param) in zip( @@ -227,13 +224,13 @@ def test_training(self): local_model = Model(device=device) fsdp_wrapped_model = FSDP( copy.deepcopy(local_model), - auto_wrap_policy=Model.auto_wrap_policy(), + auto_wrap_policy=Model.policy(), use_orig_params=True, ) composable_module = copy.deepcopy(local_model) fully_shard( composable_module, - auto_wrap_policy=Model.auto_wrap_policy(), + policy=Model.policy(), ) del local_model # not needed anymore LR = 1e-2 diff --git a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py index ddba50a9e456..1a742da889ac 100644 --- a/test/distributed/fsdp/test_fsdp_clip_grad_norm.py +++ b/test/distributed/fsdp/test_fsdp_clip_grad_norm.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import functools import itertools import sys from typing import Union @@ -8,11 +7,12 @@ import torch import torch.nn as nn from torch import distributed as dist +from torch.distributed.fsdp import ShardingStrategy from torch.distributed.fsdp.fully_sharded_data_parallel import ( CPUOffload, FullyShardedDataParallel as FSDP, ) -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -43,10 +43,6 @@ class TestClipGradNorm(FSDPTest): """Tests :meth:`FullyShardedDataParallel.clip_grad_norm_`.""" - @property - def world_size(self) -> int: - return 2 - @skip_if_lt_x_gpu(2) def test_non_root(self): """ @@ -81,6 +77,11 @@ def test_ddp_parity(self): { "max_norm": [1, 2.5], "norm_type": [1, 2, float("inf")], + "sharding_strategy": [ + ShardingStrategy.FULL_SHARD, + ShardingStrategy.NO_SHARD, + "mixed_strategy", + ], "use_orig_params": [False, True], "offload_params": [False, True], }, @@ -91,8 +92,9 @@ def _test_ddp_parity( self, max_norm: Union[float, int], norm_type: Union[float, int], - offload_params: bool, + sharding_strategy: Union[ShardingStrategy, str], use_orig_params: bool, + offload_params: bool, ): local_model = TransformerWithSharedParams.init( self.process_group, @@ -102,23 +104,52 @@ def _test_ddp_parity( ) ddp_model = DDP(local_model, device_ids=[self.rank]) fsdp_kwargs = { - "auto_wrap_policy": functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - TransformerEncoderLayer, - TransformerDecoderLayer, - }, - ), "cpu_offload": CPUOffload(offload_params=offload_params), "use_orig_params": use_orig_params, } - fsdp_model = TransformerWithSharedParams.init( - self.process_group, - FSDPInitMode.RECURSIVE, - CUDAInitMode.CUDA_BEFORE, - deterministic=True, - fsdp_kwargs=fsdp_kwargs, - ) + if sharding_strategy == "mixed_strategy": + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.NO_FSDP, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + ) + # Apply `NO_SHARD` to the encoder + fsdp_model.transformer.encoder = FSDP( + fsdp_model.transformer.encoder, + sharding_strategy=ShardingStrategy.NO_SHARD, + **fsdp_kwargs, + ) + # Apply `FULL_SHARD` to the decoder + fsdp_model.transformer.decoder = FSDP( + fsdp_model.transformer.decoder, + sharding_strategy=ShardingStrategy.FULL_SHARD, + **fsdp_kwargs, + ) + # TODO: FSDP's `clip_grad_norm_()` is not a static method, so we + # must make the root module an FSDP instance + 
fsdp_model = FSDP( + fsdp_model, sharding_strategy=ShardingStrategy.FULL_SHARD, **fsdp_kwargs + ) + else: + fsdp_kwargs.update( + { + "sharding_strategy": sharding_strategy, + "auto_wrap_policy": ModuleWrapPolicy( + { + TransformerEncoderLayer, + TransformerDecoderLayer, + } + ), + } + ) + fsdp_model = TransformerWithSharedParams.init( + self.process_group, + FSDPInitMode.RECURSIVE, + CUDAInitMode.CUDA_BEFORE, + deterministic=True, + fsdp_kwargs=fsdp_kwargs, + ) LR = 1e-2 ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=LR) fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR) @@ -127,7 +158,10 @@ def _test_ddp_parity( inp = ddp_model.module.get_input(device) for model in (ddp_model, fsdp_model): out = model(*inp) - loss = model.module.get_loss(inp, out) + if isinstance(model, (DDP, FSDP)): + loss = model.module.get_loss(inp, out) + else: + loss = model.get_loss(inp, out) loss.backward() # Multiply gradients by a large factor to ensure that gradients will diff --git a/test/distributed/fsdp/test_fsdp_misc.py b/test/distributed/fsdp/test_fsdp_misc.py index 79ed6da6240f..8c972f851563 100644 --- a/test/distributed/fsdp/test_fsdp_misc.py +++ b/test/distributed/fsdp/test_fsdp_misc.py @@ -15,7 +15,11 @@ FullyShardedDataParallel as FSDP, ShardingStrategy, ) -from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ( + always_wrap_policy, + ModuleWrapPolicy, + transformer_auto_wrap_policy, +) from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( @@ -211,10 +215,20 @@ def forward(self, x, y): def test_device_id_auto_wrap(self): """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all nested FSDP instances.""" - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + self.run_subtests( + {"use_callable": [False, True]}, + self._test_device_id_auto_wrap, ) + + def _test_device_id_auto_wrap(self, use_callable: bool): + module_classes = {TransformerEncoderLayer, TransformerDecoderLayer} + if use_callable: + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=module_classes, + ) + else: + auto_wrap_policy = ModuleWrapPolicy(module_classes) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, "device_id": torch.cuda.current_device(), diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index ba51ae66ed1b..6fafc8e8fdf4 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -26,7 +26,7 @@ ) from torch.distributed.fsdp._shard_utils import _gather_state_dict from torch.distributed.fsdp._unshard_param_utils import FLAT_PARAM -from torch.distributed.fsdp.wrap import enable_wrap, transformer_auto_wrap_policy, wrap +from torch.distributed.fsdp.wrap import enable_wrap, ModuleWrapPolicy, wrap from torch.nn import Linear, Module, TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel import DistributedDataParallel from torch.optim import SGD @@ -350,9 +350,8 @@ def test_state_dict_with_manual_ac_wrapper( @skip_if_lt_x_gpu(2) @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) def test_state_dict_with_shared_parameters(self, state_dict_type): - auto_wrap_policy = partial( - transformer_auto_wrap_policy, - 
transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} ) model_creator = partial( TransformerWithSharedParams.init, @@ -377,9 +376,8 @@ def test_state_dict_rank0_offload_save_load_flow(self, use_orig_params: bool): """Tests saving a model checkpoint only on rank 0 and loading it only on rank 0 with ``sync_module_states=True`` to emulate the workflow to avoid redundant CPU memory usage.""" - auto_wrap_policy = partial( - transformer_auto_wrap_policy, - transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} ) fsdp_kwargs = { "auto_wrap_policy": auto_wrap_policy, diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index 24829ff408d9..0f5ffa564c2d 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -15,7 +15,7 @@ ShardingStrategy, ) from torch.distributed.fsdp._common_utils import clean_tensor_name -from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -117,12 +117,11 @@ def _get_fsdp_transformer_and_optim( # combination with the parameter group construction, ensures different # hyperparameter settings within one `FlatParameter` fsdp_kwargs = { - "auto_wrap_policy": functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ + "auto_wrap_policy": ModuleWrapPolicy( + { TransformerEncoderLayer, TransformerDecoderLayer, - }, + } ), "use_orig_params": True, "sharding_strategy": sharding_strategy, diff --git a/test/distributed/fsdp/test_utils.py b/test/distributed/fsdp/test_utils.py index e797325ccbc9..37c52547e847 100644 --- a/test/distributed/fsdp/test_utils.py +++ b/test/distributed/fsdp/test_utils.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import functools import random import sys import unittest @@ -14,7 +13,7 @@ from torch import distributed as dist from torch.distributed.fsdp._utils import _apply_to_tensors from torch.distributed.fsdp._wrap_utils import _get_submodule_to_states -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.distributed.utils import _replace_by_prefix from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -173,9 +172,7 @@ def test_module_wrap_policy(self): # Compute the mapping from submodule to states according to a logical # module wrap policy module_classes = (nn.Sequential,) - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, transformer_layer_cls=set(module_classes) - ) + auto_wrap_policy = ModuleWrapPolicy(set(module_classes)) submodule_to_states = _get_submodule_to_states( model, auto_wrap_policy, set(), set() ) diff --git a/test/distributed/fsdp/test_wrap.py b/test/distributed/fsdp/test_wrap.py index cd0d11ba9b4b..e157f041ae1b 100644 --- a/test/distributed/fsdp/test_wrap.py +++ b/test/distributed/fsdp/test_wrap.py @@ -5,6 +5,7 @@ import tempfile import unittest from enum import auto, Enum +from typing import Callable, 
Union import torch import torch.nn as nn @@ -15,10 +16,12 @@ FullyShardedDataParallel as FSDP, ) from torch.distributed.fsdp.wrap import ( + _FSDPPolicy, _or_policy, _wrap_batchnorm_individually, always_wrap_policy, enable_wrap, + ModuleWrapPolicy, size_based_auto_wrap_policy, transformer_auto_wrap_policy, wrap, @@ -373,6 +376,19 @@ def test_transformer_auto_wrap_policy(self): transformer_auto_wrap_policy, transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer}, ) + self._test_transformer_wrapping(auto_wrap_policy) + + @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs") + def test_module_wrap_policy(self): + """Tests the ``ModuleWrapPolicy``.""" + auto_wrap_policy = ModuleWrapPolicy( + {TransformerEncoderLayer, TransformerDecoderLayer} + ) + self._test_transformer_wrapping(auto_wrap_policy) + + def _test_transformer_wrapping( + self, auto_wrap_policy: Union[Callable, _FSDPPolicy] + ): fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy} fsdp_model = TransformerWithSharedParams.init( self.process_group, diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index cf46f89b353c..77ee7487a0af 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1503,6 +1503,21 @@ def _test_collectives(self, backend): with self.subTest(collective=collective, args=args): self._call_collective_with_varying_tensors(backend, collective, *args) + def _test_allreduce_coalesced(self, backend): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend, + world_size=self.world_size, + rank=self.rank, + store=store, + ) + # TODO: this will be updated in the future to not be backend specific + device = "cuda" if backend == "nccl" else "cpu" + tensors = [torch.ones(10, 10, device=torch.device(device))] + dist.all_reduce_coalesced(tensors, dist.ReduceOp.SUM) + for tensor in tensors: + self.assertEqual(tensor, torch.ones(10, 10) * self.world_size) + class CompilerTest(MultiProcessTestCase): def setUp(self): super(CompilerTest, self).setUp() diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index e0c7c64f7b83..ba214a02696f 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2363,6 +2363,10 @@ class GlooProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro def test_collectives(self): self._test_collectives(backend="gloo") + @requires_gloo() + def test_allreduce_coalesced(self): + self._test_allreduce_coalesced(backend="gloo") + class CompilerTest(test_c10d_common.CompilerTest): @property diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 5d412dd3fb1b..c514ea4ab31f 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2953,6 +2953,28 @@ class NcclProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro def test_collectives(self): self._test_collectives(backend="nccl") + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_allreduce_coalesced(self): + self._test_allreduce_coalesced(backend="nccl") + + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_allgather_base(self): + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + device = "cuda" + tensor = torch.ones(10, 10, device=torch.device(device)) + output_tensor = torch.zeros(10, 10, device=torch.device(device)) + 
dist.all_gather_into_tensor(output_tensor, tensor) + self.assertEqual(output_tensor, tensor) + + if __name__ == "__main__": assert ( not torch.cuda._initialized diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 3dd3c5de7725..21550a0120e4 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1,4 +1,6 @@ # Owner(s): ["module: dynamo"] +import copy +import functools import logging import os import random @@ -16,7 +18,9 @@ from torch._dynamo.utils import same from torch._dynamo.testing import collect_results from torch._inductor.utils import has_triton +from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.nn.parallel import DistributedDataParallel as DDP +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.testing._internal.common_distributed import ( MultiProcessTestCase, import_transformers_or_skip, @@ -175,6 +179,7 @@ def test_ddp_baseline_aot_eager_multiprocess(self): @skip_if_lt_x_gpu(2) @import_transformers_or_skip() + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") @patch.object(config, "optimize_ddp", True) @patch.object(torch._inductor.config, "fallback_random", True) def test_hf_bert_ddp(self): @@ -199,6 +204,108 @@ def test_hf_bert_ddp(self): opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat) self.assertTrue(same(correct_results, opt_results)) + + @skip_if_lt_x_gpu(1) + # TODO(whc) delete aot_eager test, if inductor test lands stably + def test_fsdp_aot_eager(self): + with _per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # Test with recursive wrapping, nested FSDP around each Linear + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP( + m, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, ) + ), + use_orig_params=True + ) + fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @skip_if_lt_x_gpu(1) + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + def test_fsdp_inductor(self): + with _per_rank_init(self.rank, self.world_size): + # Test with basic FSDP wrapping (outer wrap around whole model) + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP(m, use_orig_params=True) + fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + # Test with recursive wrapping, nested FSDP around each Linear + m, inputs, correct_outputs = get_model(f"cuda:{self.rank}") + fsdp_m = FSDP( + m, + auto_wrap_policy=functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear, ) + ), + use_orig_params=True + ) + fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) + outputs = fsdp_m(inputs) + self.assertTrue(same(correct_outputs, outputs)) + + @import_transformers_or_skip() + @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") + # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert + 
@patch.object(torch._inductor.config.triton, "cudagraphs", False) + @patch.object(torch._inductor.config, "fallback_random", True) + # TODO(voz): Flaky on CI failure, consistent failure on local master. + @unittest.skipIf(True, "Flaky on CI failure, consistent failure on local master") + def test_hf_bert_fsdp(self): + from transformers.models.bert.modeling_bert import BertLayer + + def apply_fsdp(model, wrap_policy): + model = FSDP( + copy.deepcopy(model), + auto_wrap_policy=wrap_policy, + use_orig_params=True + ) + return model + + with _per_rank_init(self.rank, self.world_size): + for (wrap_policy, test_instance) in ( + ( + None, + "FSDP without recursive wrapping" + ), + ( + functools.partial( + transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer, ) + ), + "FSDP with recursive wrapping BertLayer instances" + ) + ): + print(f"Running hf_bert test for {test_instance}") + model, inputs = get_hf_bert(self.rank) + reset_rng_state() + eager_model = apply_fsdp(model, wrap_policy) + correct_outputs = eager_model(**inputs) + correct_loss = correct_outputs.loss + correct_loss.backward() + + reset_rng_state() + opt_model = apply_fsdp(model, wrap_policy) + + opt_model = torch._dynamo.optimize("inductor")(opt_model) + opt_outputs = opt_model(**inputs) + opt_loss = opt_outputs.loss + opt_loss.backward() + + inputs_flat = [inputs[k] for k in inputs] + correct_results = collect_results(eager_model, correct_outputs.logits, correct_loss, inputs_flat) + opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat) + self.assertTrue(same(correct_results, opt_results)) + + @requires_nccl() class TestDistributed(torch._dynamo.test_case.TestCase): """ @@ -257,32 +364,6 @@ def test_ddp_baseline_inductor(self): outputs = ddp_m(inputs) self.assertTrue(same(correct_outputs, outputs)) - # TODO(whc) move these tests to 'distributed' shard to get nccl, or see if it's available already in pytorch CI? 
- @unittest.skip( - "can't run with gloo (no support for _allgather_base) and nccl not available in CI" - ) - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_aot_eager(self): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torch._dynamo.optimize("aot_eager")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - - @unittest.skip("hangs/crashes with inductor currently") - @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch") - @patch.object(config, "optimize_ddp", False) - def test_fsdp_baseline_inductor(self): - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - - m, inputs, correct_outputs = self.get_model() - fsdp_m = FSDP(m, device_id=self.device_ids[0] if self.device_ids else None) - fsdp_m = torch._dynamo.optimize("inductor")(fsdp_m) - outputs = fsdp_m(inputs) - self.assertTrue(same(correct_outputs, outputs)) - @patch.object(config, "optimize_ddp", True) def test_graph_split(self): """ diff --git a/test/dynamo/test_aot_cudagraphs.py b/test/dynamo/test_aot_cudagraphs.py index cb1d2a0e601f..fdb7c88762b8 100644 --- a/test/dynamo/test_aot_cudagraphs.py +++ b/test/dynamo/test_aot_cudagraphs.py @@ -71,7 +71,6 @@ def fn(x, y): y = torch.randn(3, device="cuda") fn(x, y) - @patch("torch._dynamo.config.suppress_errors", True) @patch_all() def test_dtoh(self): def model(x, y): @@ -105,7 +104,6 @@ def fn(x, y): y = torch.randn((), device="cpu") fn(x, y) - @patch("torch._dynamo.config.suppress_errors", True) @patch("functorch._src.config.use_functionalize", True) @patch_all(ok=False) # input mutation not supported yet def test_mutate_input(self): @@ -145,7 +143,6 @@ def fn(x, y): y = torch.randn(1, device="cuda") fn(x, y) - @patch("torch._dynamo.config.suppress_errors", True) @patch_all() def test_factory(self): def model(y): diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index d82cc6925fe9..294ea9e54952 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -51,22 +51,6 @@ def make_dynamic_cls(cls): ) -# DynamicShapesReproTests -unittest.expectedFailure( - DynamicShapesReproTests.test_reformer_eval_dynamic_shapes - # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer -) - -unittest.expectedFailure( - DynamicShapesReproTests.test_reformer_train_dynamic_shapes - # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer -) - -unittest.expectedFailure( - DynamicShapesReproTests.test_issue175_dynamic_shapes - # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer -) - unittest.expectedFailure( DynamicShapesReproTests.test_do_paste_mask_dynamic_shapes # aten.min.dim - couldn't find symbolic meta function/decomposition @@ -77,97 +61,66 @@ def make_dynamic_cls(cls): # Could not infer dtype of torch._C.SymIntNode ) -unittest.expectedFailure( - DynamicShapesReproTests.test_ellipsis_dynamic_shapes - # Cannot call sizes() on tensor with symbolic sizes/strides -) - unittest.expectedFailure( DynamicShapesReproTests.test_hf_t5_forward_dynamic_shapes # Cannot call sizes() on tensor with symbolic sizes/strides ) +# DynamicShapesExportTests unittest.expectedFailure( - DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes - # Unable to cast Python instance to C++ type -) - -unittest.expectedFailure( - 
DynamicShapesReproTests.test_guard_fail_tensor_bool_dynamic_shapes - # RuntimeError: aten.allclose.default - couldn't find symbolic meta function/decomposition + DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes ) - -# DynamicShapesMiscTests unittest.expectedFailure( - DynamicShapesMiscTests.test_unsupported_fake_tensor_dynamic_shapes - # aten.quantize_per_tensor.default - couldn't find symbolic meta function/decomposition + DynamicShapesExportTests.test_export_with_constant_list_nonzero_free_function_dynamic_shapes ) unittest.expectedFailure( - DynamicShapesMiscTests.test_module_deepcopy_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decompositio + DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes ) - -# DynamicShapesUnspecTests unittest.expectedFailure( - DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes - # float() argument must be a string or a real number, not 'torch._C.SymIntNode' + DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes ) -# DynamicShapesNNModuleTests -unittest.expectedFailure( - DynamicShapesNNModuleTests.test_unsupportedmethod_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition -) - +# DynamicShapesSubGraphTests unittest.expectedFailure( - DynamicShapesNNModuleTests.test_unsupportedmodule_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition + DynamicShapesSubGraphTests.test_enumerate_not_break_graph_dynamic_shapes ) +unittest.expectedFailure(DynamicShapesSubGraphTests.test_restore_state_dynamic_shapes) +# DynamicShapesUnspecTests +# Missing decomp +# RuntimeError: Failed running call_function +# (*(FakeTensor(FakeTensor(..., device='meta', size=(5, 1, 28, 28)), cpu), +# FakeTensor(FakeTensor(..., device='meta', size=(1,)), cpu), +# FakeTensor(FakeTensor(..., device='meta', size=(1,)), cpu), +# FakeTensor(Parameter(FakeTensor(..., device='meta', size=(1,), +# requires_grad=True)), cpu), +# FakeTensor(Parameter(FakeTensor(..., device='meta', size=(1,), +# requires_grad=True)), cpu), False, 0.1, +# FakeTensor(FakeTensor(..., device='meta', size=()), cpu)), **{}): +# aten._local_scalar_dense.default +unittest.expectedFailure(test_unspec.UnspecReproTests.test_batch_norm_act_unspec) + +# SymIntArrayRef expected to contain only concrete integers unittest.expectedFailure( - DynamicShapesNNModuleTests.test_self_mutating1_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition + DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes ) +# DynamicShapesReproTests unittest.expectedFailure( - DynamicShapesNNModuleTests.test_call_fn_with_non_const_inputs_safe_dynamic_shapes - # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition + DynamicShapesReproTests.test_reformer_eval_dynamic_shapes + # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer ) - -# DynamicShapesExportTests -unittest.expectedFailure( - DynamicShapesExportTests.test_export_compare_optimize_with_make_fx_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_export_with_constant_list_nonzero_free_function_dynamic_shapes -) unittest.expectedFailure( - DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes -) -unittest.expectedFailure( - 
DynamicShapesExportTests.test_export_with_stack_trace_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass_dynamic_shapes -) -unittest.expectedFailure( - DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_permute_dynamic_shapes + DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes + # Unable to cast Python instance to C++ type ) - -# DynamicShapesSubGraphTests unittest.expectedFailure( - DynamicShapesSubGraphTests.test_enumerate_not_break_graph_dynamic_shapes + DynamicShapesReproTests.test_reformer_train_dynamic_shapes + # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer ) -unittest.expectedFailure(DynamicShapesSubGraphTests.test_restore_state_dynamic_shapes) if __name__ == "__main__": diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index a157926422c8..21c0d2004bb9 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -71,6 +71,32 @@ def func(x): self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + @patch.object(torch._dynamo.config, "dynamic_shapes", True) + def test_export_shape_control_flow_1(self): + def func(x): + if x.shape[0] > 10: + return x.cos() + return x.sin() + + opt_func = torch._dynamo.optimize("eager")(func) + real_result = opt_func(torch.ones(6, 4)) + + torch._dynamo.reset() + + exported = torch._dynamo.export(func, torch.ones(6, 4)) + out_graph, out_guards = exported + + dynamo_result = out_graph(torch.ones(6, 4)) + + self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result)) + hit = False + for guard in out_guards: + if guard.name == "symbolic_shape_expression": + hit = True + self.assertTrue("x.size()[0] <= 10" in guard.code_list) + + self.assertTrue(hit) + def test_export_graph_bypass(self): inp = [ torch.tensor([0.1, 0.1]), diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 4df7153b8fb2..e270852fc526 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -579,8 +579,6 @@ def fn(count): self.assertEqual(cnts.frame_count, 0) self.assertEqual(cnts.op_count, 0) - # KeyError: '__name__' - @patch.object(torch._dynamo.config, "suppress_errors", True) def test_user_getattr1(self): class MyConfig(dict): def __getattr__(self, name): @@ -1146,6 +1144,7 @@ def fn(x): torch._dynamo.run()(fn2)(torch.randn(4)) self.assertEqual(cnts2.frame_count, 0) + @patch.object(torch._dynamo.config, "suppress_errors", True) def test_nested_disable_decorator(self): cnts = torch._dynamo.testing.CompileCounter() @@ -1618,6 +1617,7 @@ def fn(x, func): self.assertEqual(cnts.op_count, 1) @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) + @patch.object(torch._dynamo.config, "suppress_errors", True) def test_unsupported_fake_tensor(self): def f(x): return torch.quantize_per_tensor(x, 0.1, 10, torch.quint8) diff --git a/test/dynamo/test_no_fake_tensors.py b/test/dynamo/test_no_fake_tensors.py index df511f1affd5..f7943c1d7ab9 100644 --- a/test/dynamo/test_no_fake_tensors.py +++ b/test/dynamo/test_no_fake_tensors.py @@ -1,6 +1,4 @@ # Owner(s): ["module: dynamo"] -import unittest - from torch._dynamo.testing import make_test_cls_with_patches try: @@ -25,9 +23,6 @@ def make_no_fake_cls(cls): NoFakeTensorsNNModuleTests = make_no_fake_cls(test_modules.NNModuleTests) NoFakeTensorsUnspecTests = make_no_fake_cls(test_unspec.UnspecTests) 
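As background for the expected-failure bookkeeping in these generated test classes: unittest.expectedFailure simply sets __unittest_expecting_failure__ = True on the function object, which is why it can be applied to generated methods after the fact and also flipped back off for individual tests, as done a few lines below. A minimal sketch of the pattern, using hypothetical class and test names rather than anything from this patch:

import unittest

class BaseTests(unittest.TestCase):
    def test_add(self):
        self.assertEqual(1 + 1, 2)

def make_patched_cls(cls, suffix="Patched"):
    # Stamp out a renamed copy of the test class so both variants are collected.
    # Note: in this simplified sketch the method object is inherited and shared
    # with BaseTests, so the flag set below would apply to both classes.
    return type(cls.__name__ + suffix, (cls,), {})

PatchedTests = make_patched_cls(BaseTests)

# Mark a generated test as an expected failure, then opt it back in, mirroring
# the __unittest_expecting_failure__ = False assignment below.
unittest.expectedFailure(PatchedTests.test_add)
PatchedTests.test_add.__unittest_expecting_failure__ = False
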
-unittest.expectedFailure( - NoFakeTensorsReproTests.test_guard_fail_tensor_bool_no_fake_tensors -) NoFakeTensorsReproTests.test_numpy_list_no_fake_tensors.__unittest_expecting_failure__ = ( False ) diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 92b163b76d6d..2f204a7a1199 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -1,7 +1,6 @@ # Owner(s): ["module: dynamo"] import inspect -import sys import unittest import torch @@ -126,7 +125,7 @@ def training_iter_fn(batch, model, optimizer): batch = {"x": input1, "y": input2} for _ in range(2): opt_training_iter_fn(batch, net, optimizer) - self.assertEqual(cnts.frame_count, (2 if sys.version_info < (3, 8) else 6)) + self.assertEqual(cnts.frame_count, 2) if __name__ == "__main__": diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 2103e075fffc..fd0fcf9e08bc 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -11,6 +11,8 @@ from typing import List from unittest.mock import patch +import functorch._src.config + import numpy as np import torch @@ -803,7 +805,6 @@ def test_do_paste_mask(self): ) self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3) - # Graph break because of dynamic slicing self.assertEqual( torch._dynamo.utils.counters["frames"]["total"], torch._dynamo.utils.counters["frames"]["ok"] + 1, @@ -961,7 +962,7 @@ def test_maml_item_capture(self): self.assertEqual(cnt.frame_count, ifdyn(3, 2)) # TODO(jansel): figure out why op count depends on imports - self.assertIn(cnt.op_count, (36, 35, 29, 28)) + self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27)) # see: https://github.com/pytorch/pytorch/issues/80067 @patch.object(torch._dynamo.config, "fake_tensor_propagation", False) @@ -980,7 +981,7 @@ def test_maml_no_item_capture(self): self.assertEqual(cnt.frame_count, ifdyn(5, 4)) # TODO(jansel): figure out why op count depends on imports - self.assertIn(cnt.op_count, (31, 36, 35, 29, 28)) + self.assertIn(cnt.op_count, (31, 36, 35, 34, 29, 28)) def test_hf_model_output(self): ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10)) @@ -1316,6 +1317,7 @@ def blah(self, x): self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 3) self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 3) + @patch.object(torch._dynamo.config, "suppress_errors", True) def test_guard_fail_tensor_bool(self): @torch._dynamo.skip def fn(): @@ -1402,8 +1404,17 @@ def fn(x): self.assertTrue(same(ref1, res1)) @unittest.skipIf(not HAS_REFS, "requires recent PT version") - @unittest.expectedFailure def test_primtorch(self): + @torch._dynamo.optimize("eager") + def fn(x): + torch._refs.abs(x) + + fn(torch.randn(3)) + + @unittest.skipIf(not HAS_REFS, "requires recent PT version") + @unittest.expectedFailure + # inline_call [('inline in skipfiles: bind ...python3.10/inspect.py', 1)] + def test_primtorch_no_graph_break(self): @torch._dynamo.optimize("eager", nopython=True) def fn(x): torch._refs.abs(x) @@ -1456,14 +1467,14 @@ def fn(x): fn(torch.randn(3)) - # AssertionError: ABCMeta + # Bug with storage meta - torch.BoolStorage is becoming torch.storage._LegacyStorageMeta @unittest.expectedFailure def test_isinstance_storage(self): @torch._dynamo.optimize("eager") def fn(x): f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40]) bools = torch.BoolStorage.from_buffer(f, "big") - self.assertTrue(isinstance(bools, torch.BoolStorage)) + assert isinstance(bools, torch.BoolStorage) return x 
fn(torch.randn(3)) @@ -1662,6 +1673,21 @@ def fn(x): opt_fn(x) self.assertEqual(cnt.frame_count, 1) + @patch.object(functorch._src.config, "use_dynamic_shapes", True) + def test_bigbird_unsqueeze_inplace(self): + def fn(reshape_2): + view_2 = reshape_2.clone() + view_2.unsqueeze_(2) + cat_11 = torch.cat([view_2], dim=2) + view_13 = cat_11.view((2, 12, 64, -1)) + return (view_13,) + + x = torch.randn(2, 12, 64, 64, requires_grad=True) + ref = fn(x) + opt_fn = torch._dynamo.optimize("aot_eager")(fn) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + # This doesn't work without fake tensors but I don't care @patch.object(torch._dynamo.config, "fake_tensor_propagation", True) def test_issue1466_size_aot_autograd(self): @@ -1792,6 +1818,111 @@ def fn(x): res = opt_fn(a) self.assertTrue(same(ref, res)) + def test_tokenization(self): + from collections import UserDict + + class BatchEncoding(UserDict): + """ + Copied from tokenization + """ + + def __init__( + self, + data, + ): + super().__init__(data) + + def __getattr__(self, item: str): + try: + return self.data[item] + except KeyError: + raise AttributeError + + def tokenization(x): + encoding = BatchEncoding({"key": x}) + return encoding["key"] + + opt_fn = torch._dynamo.optimize("eager")(tokenization) + x = torch.rand((1, 4)) + ref = tokenization(x) + res = opt_fn(x) + self.assertTrue(same(ref, res)) + + def test_modules(self): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc = torch.nn.Linear(4, 3) + + def forward(self, inp): + res = torch.zeros(3, 3) + for mod in self.modules(): + res += self.fc(inp) + return res + + mod = Foo() + args = (torch.ones(3, 4),) + cnt = torch._dynamo.testing.CompileCounter() + opt_mod = torch._dynamo.optimize(cnt, nopython=True)(mod) + self.assertTrue(same(mod(*args), opt_mod(*args))) + self.assertEqual(cnt.op_count, 5) + self.assertEqual(cnt.frame_count, 1) + + def test_for_loop_graph_break(self): + def inner(x): + return torch.sin(x) + + def fn(x): + for _ in range(100): + inner(x) + torch._dynamo.graph_break() + return x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.randn(4) + opt_fn(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 1) + + def test_for_loop_graph_break_before(self): + # Checks that the backedge is calculated correctly + def inner(x): + return torch.sin(x) + + def fn(x): + torch._dynamo.graph_break() + for _ in range(100): + inner(x) + return x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.randn(4) + opt_fn(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 100) + + def test_while_loop_graph_break(self): + # Repro of tacotron2 cache_size_recompilation + def inner(x): + return torch.sin(x) + + def fn(x): + i = 20 + while i > 10: + x = inner(x) + i -= 1 + torch._dynamo.graph_break() + return x + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(cnt)(fn) + x = torch.randn(4) + opt_fn(x) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index fd5396981b74..e46d79208de0 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -50,6 +50,8 @@ class UnspecTest(cls): UnspecReproTests = make_unspec_cls(test_repros.ReproTests) UnspecNNModuleTests = 
make_unspec_cls(test_modules.NNModuleTests) +unittest.expectedFailure(UnspecReproTests.test_batch_norm_act_unspec) + @patch.object(torch._dynamo.config, "specialize_int_float", False) class UnspecTests(torch._dynamo.test_case.TestCase): diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index f4782b8a595d..e31ac58039ec 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -973,6 +973,9 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('cholesky'), xfail('linalg.cholesky'), + # Given input size: (s0xs1x2). Calculated output size: ... + skip('max_pool2d_with_indices_backward'), + # Misc xfail('to_sparse'), xfail('corrcoef'), @@ -1095,8 +1098,6 @@ def assert_compiler(gm: torch.fx.GraphModule, _): xfail('masked_fill', ''), # could not find kernel xfail('masked.logaddexp', ''), # aten.logaddexp.default - couldn't find symbolic meta function/decomposi... xfail('masked.logsumexp', ''), # Cannot call sizes() on tensor with symbolic sizes/strides - # Seems flaky: https://github.com/pytorch/pytorch/issues/88883 - skip('masked.median', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked.prod', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_scatter', ''), # Cannot call sizes() on tensor with symbolic sizes/strides xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decompos... diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 26b64c5e70cc..ff69ed9df6e6 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -3130,13 +3130,16 @@ def normalize_devices(fx_g): return fx_g class TestFunctionalize(TestCase): - def _check_functionalize_correctness(self, f, inpt): + def _check_functionalize_correctness(self, f, inpt, *, skip_vmap=False): inpt1 = inpt.clone() inpt2 = inpt.clone() inpt3 = inpt.clone() expected_outputs = f(inpt1) - actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze() + if skip_vmap: + actual_outputs = functionalize(f)(inpt2) + else: + actual_outputs = vmap(functionalize(f))(inpt2.unsqueeze(0))[0].squeeze() # Right now the flavor of functionalize that also removes view ops # isn't being used with vmap # That's because {view}_copy ops don't have batching rules yet @@ -3206,7 +3209,8 @@ def f(x: torch.Tensor) -> torch.Tensor: z2, z3 = z1.split(2) z2.add_(tmp) return x - self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device)) + # See Note [Fix vmap slice_scatter] + self._check_functionalize_correctness(f, torch.zeros(4, 2, device=device), skip_vmap=True) # Ensure functionalize works with List[Optional[Tensor]] arguments. 
# See the fix / discussion at https://github.com/pytorch/pytorch/pull/76085 diff --git a/test/fx/test_common_passes.py b/test/fx/test_common_passes.py index 9c59abce4da6..407e707db879 100644 --- a/test/fx/test_common_passes.py +++ b/test/fx/test_common_passes.py @@ -73,10 +73,15 @@ def MutationMetadata(x): if torch.cuda.is_available(): Devices.append("cuda") + +def name_fn(common_pass, f, device): + """Names parameterized test cases.""" + return f'{type(common_pass()).__name__}_{f.__name__}_{device}' + @instantiate_parametrized_tests class TestCommonPass(TestCase): - @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices)) + @parametrize("common_pass,f,device", itertools.product(Passes, Test_Cases, Devices), name_fn) def test_correctness(self, common_pass, f, device): inp = torch.randn(10, device=device) @@ -94,7 +99,7 @@ def test_correctness(self, common_pass, f, device): self.assertEqual(result, expected) - @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices)) + @parametrize("common_pass,f,device", itertools.product(Passes, Factory_Test_Cases, Devices), name_fn) def test_correctness_factory(self, common_pass, f, device): inp = torch.randn(10, device=device) traced_m = make_fx(f)(inp, device) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index aea8013bdfac..ba1f9032d97f 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -20,7 +20,6 @@ from torch.fx.experimental.proxy_tensor import make_fx from torch.nn import functional as F from torch.testing._internal.common_utils import ( - IS_FBCODE, TEST_WITH_ASAN, TEST_WITH_ROCM, TestCase as TorchTestCase, @@ -41,7 +40,7 @@ from torch._inductor.compile_fx import compile_fx, complex_memory_overlap from torch._inductor.ir import IndexingDiv, ModularIndexing from torch._inductor.sizevars import SizeVarAllocator - from torch._inductor.utils import has_torchvision_roi_align, has_triton, timed + from torch._inductor.utils import has_torchvision_roi_align, timed # This will only pass on pytorch builds newer than roughly 5/15/2022 assert get_decompositions([torch.ops.aten.trace]) @@ -53,25 +52,10 @@ sys.exit(0) raise unittest.SkipTest("requires sympy/functorch/filelock") -HAS_CPU = False -try: - from subprocess import CalledProcessError - - from torch._inductor.codecache import CppCodeCache - - CppCodeCache.load("") - HAS_CPU = not IS_FBCODE -except ( - CalledProcessError, - OSError, - torch._inductor.exc.InvalidCxxCompiler, - torch._inductor.exc.CppCompileError, -): - pass +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA aten = torch.ops.aten -HAS_CUDA = has_triton() requires_cuda = functools.partial(unittest.skipIf, not HAS_CUDA, "requires cuda") torch._inductor.config.triton.autotune = False # too slow @@ -1400,6 +1384,7 @@ def test_conv2d_unary(self): [1, 3], [1, 2], [1, 4], + ["same", 0], test_memory_format, ) @@ -1409,6 +1394,7 @@ def test_conv2d_unary(self): kernel_size, dilation, groups, + padding, memory_format, ) in options: oC = 32 * groups @@ -1419,6 +1405,7 @@ def test_conv2d_unary(self): iC, oC, kernel_size=kernel_size, + padding=padding, dilation=dilation, groups=groups, bias=bias, @@ -1448,7 +1435,9 @@ def __init__( out_channels, dilation, groups, + padding, bias, + has_relu, **kwargs, ): super(M, self).__init__() @@ -1457,6 +1446,7 @@ def __init__( out_channels, dilation=dilation, groups=groups, + padding=padding, bias=bias, **kwargs, ) @@ -1466,40 +1456,54 @@ def 
__init__( out_channels, dilation=dilation, groups=groups, + padding=padding, bias=bias, **kwargs, ) ) self.binary_fn = binary_fn + self.relu = torch.nn.ReLU() if has_relu else torch.nn.Identity() def forward(self, x): x1 = self.conv1(x) x2 = self.conv2(x) - return self.binary_fn(x1, x2) + return self.relu(self.binary_fn(x1, x2)) test_memory_format = [torch.contiguous_format, torch.channels_last] options = itertools.product( binary_list, [True, False], + [True, False], [1, 3], [1, 2], [1, 4], + ["same", 0], test_memory_format, ) for ( binary_fn, + has_relu, bias, kernel_size, dilation, groups, + padding, memory_format, ) in options: oC = 32 * groups iC = 3 * groups x_shape = (1, iC, 112, 112) mod = M( - binary_fn, iC, oC, dilation, groups, bias, kernel_size=kernel_size + binary_fn, + iC, + oC, + dilation, + groups, + padding, + bias, + has_relu, + kernel_size=kernel_size, ).eval() mod = mod.to(memory_format=memory_format) # TODO: add bf16 test @@ -4787,6 +4791,9 @@ def forward(self, x): for param in model_opt.parameters(): param.add_(1.0) + # Probably fails due to the symint math issue caught while adding + # max_pool2d_with_indices_backward + @unittest.skip("Accuracy failure, needs debugging") def test_accuracy_issue1(self): class Repro(torch.nn.Module): def __init__(self): diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 5cee29920b77..3880b87c082c 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -16,20 +16,22 @@ onlyNativeDeviceTypes, OpDTypes, ops, + skipCPUIf, + skipCUDAIf, ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( dtype_abbrs, run_tests, skipCUDAMemoryLeakCheckIf, + skipIfCrossRef, + skipIfTorchDynamo, suppress_warnings, - TEST_WITH_ROCM, TestCase, ) +from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA try: - from torch._inductor.utils import has_triton - try: from .test_torchinductor import check_model, check_model_cuda except ImportError: @@ -120,6 +122,7 @@ def process(device_type): inductor_skips["cpu"] = { "linalg.ldl_solve": {b8, f16, f32, f64, i32, i64}, # segfault + "linalg.ldl_factor": {f32, f64}, # flaky "__rdiv__": {b8, f16, f32, f64, i32, i64}, # flaky } @@ -169,6 +172,9 @@ def process(device_type): "argwhere": {b8, f16, f32, f64, i32, i64}, "bernoulli": {f32, f64}, "bincount": {i32, i64}, + "bucketize": {b8, f16, f32, f64, i32, i64}, + "cdouble": {b8, f16, f32, f64, i32, i64}, + "cfloat": {b8, f16, f32, f64, i32, i64}, "chalf": {b8, f16, f32, f64, i32, i64}, "cholesky": {f32, f64}, "combinations": {b8, f16, f32, f64, i32, i64}, @@ -209,11 +215,10 @@ def process(device_type): "linalg.lstsq.grad_oriented": {f32, f64}, "linalg.matrix_rank": {f32, f64}, "linalg.matrix_rank.hermitian": {f32, f64}, - "linalg.lu_solve": {f32, f64}, - "lu_solve": {f32, f64}, - "lu_unpack": {f32, f64}, + "linalg.pinv.singular": {f32, f64}, "logdet": {f32, f64}, "masked.norm": {f16}, + "masked.normalize": {f16}, "masked_fill": {f16}, "masked_scatter": {f16, f32, f64}, "masked_select": {b8, f16, f32, f64, i32, i64}, @@ -225,8 +230,8 @@ def process(device_type): "nan_to_num": {f16}, "nanquantile": {f32, f64}, "nn.functional.avg_pool1d": {i64}, - "nn.functional.avg_pool2d": {i64}, - "nn.functional.adaptive_avg_pool2d": {f16}, + "nn.functional.avg_pool2d": {i64, f64}, + "nn.functional.adaptive_avg_pool2d": {f16, f64}, "nn.functional.ctc_loss": {f32, f64}, "nn.functional.gaussian_nll_loss": 
{f32, f64}, "nn.functional.gelu": {f64}, @@ -243,6 +248,7 @@ def process(device_type): "quantile": {f32, f64}, "rand_like": {f16, f32, f64}, "randint_like": {f16, f32, f64, i32, i64}, + "randint": {f16, f32, f64, i32, i64}, "randn_like": {f16, f32, f64}, "repeat_interleave": {b8, f16, f32, f64, i32, i64}, "scatter_add": {f16}, @@ -279,6 +285,7 @@ def process(device_type): "baddbmm": {f16}, "bernoulli": {f16, f32, f64}, "bincount": {i32, i64}, + "bucketize": {b8, f16, f32, f64, i32, i64}, "chalf": {b8, f16, f32, f64, i32, i64}, "cholesky": {f32, f64}, "combinations": {b8, f16, f32, f64, i32, i64}, @@ -365,7 +372,6 @@ def process(device_type): "asin": {f16}, "cumprod": {f16}, "linalg.vector_norm": {f64, f64}, - "linalg.householder_product": {f32}, "kron": {f16}, "nanquantile": {f32, f64}, "native_batch_norm": {f16, f32, f64}, @@ -454,6 +460,10 @@ class TestInductorOpInfo(TestCase): @skipCUDAMemoryLeakCheckIf( True ) # inductor kernels failing this test intermittently + @skipCUDAIf(not HAS_CUDA, "Skipped! Triton not found") + @skipCPUIf(not HAS_CPU, "Skipped! Supported CPU compiler not found") + @skipIfTorchDynamo("Test uses dynamo already") + @skipIfCrossRef @_ops(op_db[START:END]) @patch("torch._dynamo.config.raise_on_unsafe_aot_autograd", True) def test_comprehensive(self, device, dtype, op): @@ -598,5 +608,4 @@ def fn(*args, **kwargs): instantiate_device_type_tests(TestInductorOpInfo, globals()) if __name__ == "__main__": - if has_triton() and not TEST_WITH_ROCM: - run_tests() + run_tests() diff --git a/test/onnx/internal/test_diagnostics.py b/test/onnx/internal/test_diagnostics.py index ea9a789e91c1..884b7cb1c388 100644 --- a/test/onnx/internal/test_diagnostics.py +++ b/test/onnx/internal/test_diagnostics.py @@ -19,7 +19,7 @@ def _assert_has_diagnostics( rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]], ): sarif_log = engine.sarif_log() - unseen_pairs = {(rule.id, level.value) for rule, level in rule_level_pairs} + unseen_pairs = {(rule.id, level.name.lower()) for rule, level in rule_level_pairs} actual_results = [] for run in sarif_log.runs: if run.results is None: diff --git a/test/test_autograd.py b/test/test_autograd.py index e08047860e42..33cf188af065 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -6776,6 +6776,20 @@ def inplace_double(x): # not leaf, not output test(lambda: (1 + torch.randn(5, requires_grad=True)), False) + def test_saved_variable_saved_original_inplace_detach(self): + # Detaching a tensor that is saved input raises + a = torch.tensor(1., requires_grad=True).clone() + b = a.sin() + a.detach_() + with self.assertRaisesRegex(RuntimeError, "Trying to use a saved tensor that has been detached"): + b.backward() + + # Detaching a tensor that is saved as output is OK + a = torch.tensor(1., requires_grad=True).clone() + b = a.exp() + a.detach_() + b.backward() + def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self): # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks # The saved_original / did_not_save_original distinction corresponds to the `save_original` diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 270ca89764ed..6a7ff90527d3 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -20,19 +20,16 @@ ChainDataset, ConcatDataset, DataLoader, - DataLoader2, Dataset, IterableDataset, IterDataPipe, Subset, TensorDataset, - communication, _utils ) from torch.utils.data._utils import MP_STATUS_CHECK_INTERVAL from torch.utils.data.dataset import 
random_split from torch.utils.data.datapipes.iter import IterableWrapper -from torch.utils.data.datapipes.map import SequenceWrapper from torch._utils import ExceptionWrapper from torch.testing._internal.common_utils import (TestCase, run_tests, TEST_NUMPY, IS_WINDOWS, IS_CI, NO_MULTIPROCESSING_SPAWN, skipIfRocm, slowTest, @@ -2222,114 +2219,6 @@ def test_excessive_thread_creation_warning(self): r"excessive worker creation might get DataLoader running slow or even freeze"): dataloader = DataLoader(self.dataset, batch_size=2, num_workers=1000) -# Define a global function for testing purposes since local functions cannot be pickled -def identity(x): - return x - -@unittest.skipIf( - TEST_WITH_TSAN, - "Fails with TSAN with the following error: starting new threads after multi-threaded " - "fork is not supported. Dying (set die_after_fork=0 to override)") -class TestDataLoader2(TestCase): - @skipIfNoDill - def test_basics(self): - # TODO(VitalyFedyunin): This test will start breaking if we remove guaranteed order - # of traversing workers - dp = IterableWrapper(list(range(1000))).sharding_filter() - dl = DataLoader(dp, batch_size=3, collate_fn=identity, num_workers=2) - dl2 = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2) - dl2_threading = DataLoader2(dp, batch_size=3, collate_fn=identity, num_workers=2, parallelism_mode='thread') - self.assertEqual(list(dl), list(dl2)) - self.assertEqual(list(dl), list(dl2_threading)) - - class Sorter(IterDataPipe): - def __init__(self, datapipe): - self.datapipe = datapipe - - def __iter__(self): - return iter(sorted(self.datapipe)) - - def test_shuffle(self): - items = list(range(1000)) - dp = IterableWrapper(items).sharding_filter().shuffle() - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=False) - self.assertEqual(items, list(dl)) - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True) - self.assertNotEqual(items, list(dl)) - self.assertEqual(items, sorted(list(dl))) - - dl = DataLoader2(dp, batch_size=None, num_workers=2, shuffle=True) - self.assertNotEqual(items, list(dl)) - self.assertEqual(items, sorted(list(dl))) - - dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True) - self.assertEqual(list(dl), items) - - dl = DataLoader2(self.Sorter(dp), batch_size=None, num_workers=2, shuffle=True) - self.assertEqual(list(dl), items) - - -@unittest.skipIf( - TEST_WITH_TSAN, - "Fails with TSAN with the following error: starting new threads after multi-threaded " - "fork is not supported. 
Dying (set die_after_fork=0 to override)") -class TestDataLoader2_EventLoop(TestCase): - @skipIfNoDill - def test_basic_threading(self): - def clean_me(process, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - process.join() - - it = list(range(100)) - numbers_dp = IterableWrapper(it) - (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline(numbers_dp) - - process.start() - local_datapipe = communication.iter.QueueWrapper( - communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)) - - actual = list(local_datapipe) - clean_me(process, req_queue, res_queue) - - self.assertEqual(list(range(100)), actual) - - @skipIfNoDill - def test_basic_mapdatapipe_threading(self): - def clean_me(process, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - process.join() - - input_len = 100 - it = list(range(input_len)) - numbers_dp = SequenceWrapper(it) - (process, req_queue, res_queue, _thread_local_datapipe) = communication.eventloop.SpawnThreadForDataPipeline( - numbers_dp) - - process.start() - - # Functional Test: Ensure that you can retrieve every element from the Queue and DataPipe - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - actual = list(local_datapipe) - self.assertEqual([(x, x) for x in range(100)], actual) - - # Functional Test: raise Error when input - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - with self.assertRaisesRegex(IndexError, "out of bound"): - local_datapipe[1000] - - # __len__ Test: Ensure that the correct length is returned - local_datapipe = communication.map.QueueWrapperForMap( - communication.protocol.MapDataPipeQueueProtocolClient(req_queue, res_queue)) - self.assertEqual(input_len, len(local_datapipe)) - - clean_me(process, req_queue, res_queue) - class IntegrationTestDataLoaderDataPipe(TestCase): r""" diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 0f1f49d2e6ea..3a8e31151bf3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -20,7 +20,7 @@ from torch.utils._pytree import tree_map from torch.fx.experimental import symbolic_shapes from torch.fx.experimental.proxy_tensor import make_fx -from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int +from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node from torch.utils._python_dispatch import TorchDispatchMode from torch import SymInt @@ -478,9 +478,9 @@ def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): def get_sym_inp(inp): if isinstance(inp, int): - return torch.SymInt(seed_node.to_node(inp)) + return torch.SymInt(to_node(seed_node, inp)) else: - return torch.SymFloat(seed_node.to_node(inp)) + return torch.SymFloat(to_node(seed_node, inp)) def maybe_xfail(inp1, inp2): key = (fn, type(inp1).__name__, type(inp2).__name__) diff --git a/test/test_functionalization.py b/test/test_functionalization.py index c6c3d991771b..c5330664d1e8 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -147,17 +147,17 @@ def forward(self, a_1): sum_1 = torch.ops.aten.sum.default(relu) ones_like = torch.ops.aten.ones_like.default(sum_1, dtype = torch.float32, layout = torch.strided, 
device = device(type='cpu'), pin_memory = False, memory_format = torch.preserve_format); sum_1 = None expand_copy = torch.ops.aten.expand_copy.default(ones_like, [16, 64, 128, 128]); ones_like = None - _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(expand_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]); expand_copy = None - new_empty_strided = torch.ops.aten.new_empty_strided.default(_reshape_alias_copy, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) - view_copy_3 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128]) - view_copy_4 = torch.ops.aten.view_copy.default(_reshape_alias_copy, [16, 64, 128, 128]) - clone_1 = torch.ops.aten.clone.default(view_copy_4, memory_format = torch.contiguous_format); view_copy_4 = None + view_copy_3 = torch.ops.aten.view_copy.default(expand_copy, [1, 1024, 128, 128]); expand_copy = None + new_empty_strided = torch.ops.aten.new_empty_strided.default(view_copy_3, [1, 1024, 128, 128], [16777216, 16384, 128, 1]) + view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]) + view_copy_5 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]) + clone_1 = torch.ops.aten.clone.default(view_copy_5, memory_format = torch.contiguous_format); view_copy_5 = None threshold_backward = torch.ops.aten.threshold_backward.default(clone_1, relu, 0); clone_1 = relu = None - _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(_reshape_alias_copy, [16, 64, 128, 128], [1048576, 16384, 128, 1]); _reshape_alias_copy = None - detach_copy = torch.ops.aten.detach_copy.default(_reshape_alias_copy_1); _reshape_alias_copy_1 = None - view_copy_5 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None - _reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_5, [16, 64, 128, 128], [1048576, 16384, 128, 1]); view_copy_5 = None - detach_copy_1 = torch.ops.aten.detach_copy.default(_reshape_alias_copy_2); _reshape_alias_copy_2 = None + view_copy_6 = torch.ops.aten.view_copy.default(view_copy_3, [16, 64, 128, 128]); view_copy_3 = None + detach_copy = torch.ops.aten.detach_copy.default(view_copy_6); view_copy_6 = None + view_copy_7 = torch.ops.aten.view_copy.default(threshold_backward, [1, 1024, 128, 128]); threshold_backward = None + view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [16, 64, 128, 128]); view_copy_7 = None + detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_8); view_copy_8 = None return detach_copy_1 """) # noqa: B950 @@ -710,40 +710,40 @@ def forward(self, a_1): ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None view_copy = torch.ops.aten.view_copy.default(add, [8]) - _reshape_alias_copy = torch.ops.aten._reshape_alias_copy.default(view_copy, [2, 4], [4, 1]); view_copy = None - transpose_copy = torch.ops.aten.transpose_copy.int(_reshape_alias_copy, 1, 0) + view_copy_1 = torch.ops.aten.view_copy.default(view_copy, [2, 4]); view_copy = None + transpose_copy = torch.ops.aten.transpose_copy.int(view_copy_1, 1, 0) unsqueeze_copy = torch.ops.aten.unsqueeze_copy.default(transpose_copy, 0); transpose_copy = None squeeze_copy = torch.ops.aten.squeeze_copy.default(unsqueeze_copy); unsqueeze_copy = None split_copy = torch.ops.aten.split_copy.Tensor(squeeze_copy, 2); squeeze_copy = None getitem = split_copy[0] getitem_1 = split_copy[1]; split_copy = None add_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = 
ones = None - select_copy = torch.ops.aten.select_copy.int(_reshape_alias_copy, 0, 0); _reshape_alias_copy = None - _reshape_alias_copy_1 = torch.ops.aten._reshape_alias_copy.default(add_1, [4], [1]) - view_copy_1 = torch.ops.aten.view_copy.default(add, [8]); add = None - _reshape_alias_copy_2 = torch.ops.aten._reshape_alias_copy.default(view_copy_1, [2, 4], [4, 1]); view_copy_1 = None - transpose_copy_1 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_2, 1, 0); _reshape_alias_copy_2 = None + select_copy = torch.ops.aten.select_copy.int(view_copy_1, 0, 0); view_copy_1 = None + view_copy_2 = torch.ops.aten.view_copy.default(add_1, [4]) + view_copy_3 = torch.ops.aten.view_copy.default(add, [8]); add = None + view_copy_4 = torch.ops.aten.view_copy.default(view_copy_3, [2, 4]); view_copy_3 = None + transpose_copy_1 = torch.ops.aten.transpose_copy.int(view_copy_4, 1, 0); view_copy_4 = None unsqueeze_copy_1 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_1, 0); transpose_copy_1 = None squeeze_copy_1 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_1); unsqueeze_copy_1 = None slice_scatter = torch.ops.aten.slice_scatter.default(squeeze_copy_1, add_1, 0, 0, 2); squeeze_copy_1 = None unsqueeze_copy_2 = torch.ops.aten.unsqueeze_copy.default(slice_scatter, 0); slice_scatter = None squeeze_copy_2 = torch.ops.aten.squeeze_copy.dim(unsqueeze_copy_2, 0); unsqueeze_copy_2 = None transpose_copy_2 = torch.ops.aten.transpose_copy.int(squeeze_copy_2, 1, 0); squeeze_copy_2 = None - _reshape_alias_copy_3 = torch.ops.aten._reshape_alias_copy.default(transpose_copy_2, [8], [1]); transpose_copy_2 = None - view_copy_2 = torch.ops.aten.view_copy.default(_reshape_alias_copy_3, [4, 2]); _reshape_alias_copy_3 = None - view_copy_3 = torch.ops.aten.view_copy.default(view_copy_2, [8]) - _reshape_alias_copy_4 = torch.ops.aten._reshape_alias_copy.default(view_copy_3, [2, 4], [4, 1]); view_copy_3 = None - select_copy_1 = torch.ops.aten.select_copy.int(_reshape_alias_copy_4, 0, 0); _reshape_alias_copy_4 = None - view_copy_4 = torch.ops.aten.view_copy.default(view_copy_2, [8]); view_copy_2 = None - _reshape_alias_copy_5 = torch.ops.aten._reshape_alias_copy.default(view_copy_4, [2, 4], [4, 1]); view_copy_4 = None - transpose_copy_3 = torch.ops.aten.transpose_copy.int(_reshape_alias_copy_5, 1, 0); _reshape_alias_copy_5 = None + view_copy_5 = torch.ops.aten.view_copy.default(transpose_copy_2, [8]); transpose_copy_2 = None + view_copy_6 = torch.ops.aten.view_copy.default(view_copy_5, [4, 2]); view_copy_5 = None + view_copy_7 = torch.ops.aten.view_copy.default(view_copy_6, [8]) + view_copy_8 = torch.ops.aten.view_copy.default(view_copy_7, [2, 4]); view_copy_7 = None + select_copy_1 = torch.ops.aten.select_copy.int(view_copy_8, 0, 0); view_copy_8 = None + view_copy_9 = torch.ops.aten.view_copy.default(view_copy_6, [8]); view_copy_6 = None + view_copy_10 = torch.ops.aten.view_copy.default(view_copy_9, [2, 4]); view_copy_9 = None + transpose_copy_3 = torch.ops.aten.transpose_copy.int(view_copy_10, 1, 0); view_copy_10 = None unsqueeze_copy_3 = torch.ops.aten.unsqueeze_copy.default(transpose_copy_3, 0); transpose_copy_3 = None squeeze_copy_3 = torch.ops.aten.squeeze_copy.default(unsqueeze_copy_3); unsqueeze_copy_3 = None split_copy_1 = torch.ops.aten.split_copy.Tensor(squeeze_copy_3, 2); squeeze_copy_3 = None getitem_2 = split_copy_1[0] getitem_3 = split_copy_1[1]; split_copy_1 = None - _reshape_alias_copy_6 = torch.ops.aten._reshape_alias_copy.default(getitem_2, [4], [1]); getitem_2 = None - add_2 = 
torch.ops.aten.add.Tensor(select_copy_1, _reshape_alias_copy_6); select_copy_1 = _reshape_alias_copy_6 = None + view_copy_11 = torch.ops.aten.view_copy.default(getitem_2, [4]); getitem_2 = None + add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_11); select_copy_1 = view_copy_11 = None return add_1 """) # noqa: B950 @@ -756,30 +756,30 @@ def forward(self, a_1): ones = torch.ops.aten.ones.default([2, 2], device = device(type='cpu'), pin_memory = False) add = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None view = torch.ops.aten.view.default(add, [8]) - _reshape_alias = torch.ops.aten._reshape_alias.default(view, [2, 4], [4, 1]); view = None - transpose = torch.ops.aten.transpose.int(_reshape_alias, 1, 0) + view_1 = torch.ops.aten.view.default(view, [2, 4]); view = None + transpose = torch.ops.aten.transpose.int(view_1, 1, 0) unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0); transpose = None squeeze = torch.ops.aten.squeeze.default(unsqueeze); unsqueeze = None split = torch.ops.aten.split.Tensor(squeeze, 2); squeeze = None getitem = split[0] getitem_1 = split[1]; split = None add_1 = torch.ops.aten.add_.Tensor(getitem, ones); ones = None - select = torch.ops.aten.select.int(_reshape_alias, 0, 0); _reshape_alias = None + select = torch.ops.aten.select.int(view_1, 0, 0); view_1 = None clone = torch.ops.aten.clone.default(getitem, memory_format = torch.contiguous_format) _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [4]); clone = None - view_1 = torch.ops.aten.view.default(add, [8]); add = None - _reshape_alias_1 = torch.ops.aten._reshape_alias.default(view_1, [2, 4], [4, 1]); view_1 = None - transpose_1 = torch.ops.aten.transpose.int(_reshape_alias_1, 1, 0); _reshape_alias_1 = None + view_2 = torch.ops.aten.view.default(add, [8]); add = None + view_3 = torch.ops.aten.view.default(view_2, [2, 4]); view_2 = None + transpose_1 = torch.ops.aten.transpose.int(view_3, 1, 0); view_3 = None unsqueeze_1 = torch.ops.aten.unsqueeze.default(transpose_1, 0); transpose_1 = None squeeze_1 = torch.ops.aten.squeeze.default(unsqueeze_1); unsqueeze_1 = None unsqueeze_2 = torch.ops.aten.unsqueeze.default(squeeze_1, 0); squeeze_1 = None squeeze_2 = torch.ops.aten.squeeze.dim(unsqueeze_2, 0); unsqueeze_2 = None transpose_2 = torch.ops.aten.transpose.int(squeeze_2, 1, 0); squeeze_2 = None - _reshape_alias_2 = torch.ops.aten._reshape_alias.default(transpose_2, [8], [1]); transpose_2 = None - view_2 = torch.ops.aten.view.default(_reshape_alias_2, [4, 2]); _reshape_alias_2 = None - view_3 = torch.ops.aten.view.default(view_2, [8]); view_2 = None - _reshape_alias_3 = torch.ops.aten._reshape_alias.default(view_3, [2, 4], [4, 1]); view_3 = None - select_1 = torch.ops.aten.select.int(_reshape_alias_3, 0, 0); _reshape_alias_3 = None + view_4 = torch.ops.aten.view.default(transpose_2, [8]); transpose_2 = None + view_5 = torch.ops.aten.view.default(view_4, [4, 2]); view_4 = None + view_6 = torch.ops.aten.view.default(view_5, [8]); view_5 = None + view_7 = torch.ops.aten.view.default(view_6, [2, 4]); view_6 = None + select_1 = torch.ops.aten.select.int(view_7, 0, 0); view_7 = None add_2 = torch.ops.aten.add.Tensor(select_1, _unsafe_view); select_1 = _unsafe_view = None return getitem """) diff --git a/test/test_meta.py b/test/test_meta.py index ef25d184c842..ae248a90cffb 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -745,7 +745,6 @@ def run_meta_crossref( } meta_function_device_skips['cpu'] = { - torch.narrow_copy: {b8, bf16, c128, c32, c64, f16, f32, f64, i16, i32, i64, i8, u8}, 
torch.native_batch_norm: {f32, f64}, } diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 86beb651cb2d..42ecc3d376ab 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -721,7 +721,6 @@ def deco(cls): @xfail_inherited_tests([ "test_mode_tracing_factory_function", "test_make_fx_overloads", - "test_resnet18_backward_trace", "test_trace_subclasses", ]) class TestGenericProxyTensorSymbolic(TestGenericProxyTensor): @@ -1229,6 +1228,7 @@ def f(a, b, c, d, e): xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition + xfail('max_pool2d_with_indices_backward', ''), # (symint math failure) Given input size: (s0xs1x2). Calculated ... xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct... xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl... diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 380f85f568f7..33465217bbbc 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -390,6 +390,24 @@ def test_produce_real_type(self) -> None: $4 = torch._ops.aten.select.int($3, 1, 1) $5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)''') + def test_optional_tensor_list(self) -> None: + def weird(xs): + print("woof") + return torch.empty(()) + + my_lib = Library("my_lib", "DEF") + my_lib.define("weird(Tensor?[] self) -> Tensor") + my_lib.impl("weird", weird, "CPU") + with capture_logs() as logs: + x = LoggingTensor(torch.ones(2, 2)) + log_input("x", x) + torch.ops.my_lib.weird.default([None, x]) + + self.assertExpectedInline('\n'.join(logs), '''\ +$0 = input('x') +$1 = torch._ops.my_lib.weird.default([None, LoggingTensor(tensor([[1., 1.], + [1., 1.]]))])''') + def test_list_ret(self) -> None: # test all sequence types are permissible returns for list_type in (list, tuple): diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index cc5044da0bd5..d2e3c5fc3851 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -710,7 +710,7 @@ def _generate_invalid_input(self, layout, device): shape((2, 3)), 'compressed_indices must have dimensionality >= 1 but got 0') - yield ('compressed/plain_indices mismatch of dimensionalites', + yield ('compressed/plain_indices mismatch of dimensionalities', tensor([[0, 2, 4]]), tensor([0, 1, 0, 2]), values([1, 2, 3, 4]), @@ -718,14 +718,14 @@ def _generate_invalid_input(self, layout, device): 'compressed_indices and plain_indices dimensionalities must be equal but got 2 and 1, respectively') if layout in {torch.sparse_csr, torch.sparse_csc}: - yield ('indices and values mismatch of dimensionalites', + yield ('indices and values mismatch of dimensionalities', tensor([[0, 2, 4]]), tensor([[0, 1, 0, 2]]), values([1, 2, 3, 4]), shape((2, 3)), r'values must have dimensionality > sum of batch and block dimensionalities \(=1 \+ 0\) but got 1') else: - yield ('indices and values mismatch of dimensionalites', + yield ('indices and values mismatch of dimensionalities', tensor([[0, 2, 4]]), tensor([[0, 1, 0, 2]]), values([1, 2, 3, 4]), @@ -737,7 +737,7 @@ def _generate_invalid_input(self, layout, 
device): tensor([0, 1, 0, 2]), values([1, 2, 3, 4]), (2,), - r'tensor dimensionality must be sum of batch, base, and dense dimensionalites \(=0 \+ 2 \+ 0\) but got 1') + r'tensor dimensionality must be sum of batch, base, and dense dimensionalities \(=0 \+ 2 \+ 0\) but got 1') yield ('invalid batchsize', tensor([[0, 2, 4]]), diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 2d20da2a04f3..5833d7d7f2a4 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -972,11 +972,14 @@ class AggregationType(Enum): AVG = 1 class FileCheck(object): - # TODO (add more FileCheck signature) - def check_source_highlighted(self, highlight: str) -> 'FileCheck': ... def run(self, test_string: str) -> None: ... def check(self, test_string: str) -> 'FileCheck': ... def check_not(self, test_string: str) -> 'FileCheck': ... + def check_same(self, test_string: str) -> 'FileCheck': ... + def check_next(self, test_string: str) -> 'FileCheck': ... + def check_count(self, test_string: str, count: _int, exactly: _bool = False) -> 'FileCheck': ... + def check_dag(self, test_string: str) -> 'FileCheck': ... + def check_source_highlighted(self, test_string: str) -> 'FileCheck': ... ... # Defined in torch/csrc/jit/python/init.cpp diff --git a/torch/__init__.py b/torch/__init__.py index 19be59282cca..6def80d1dc59 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -204,8 +204,6 @@ class SymInt: """ def __init__(self, node): - from torch.fx.experimental.symbolic_shapes import SymNode - assert isinstance(node, SymNode) # This field MUST be named node; C++ binding code assumes that this # class has a field named node that stores SymNode self.node = node diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index fe63e0db007a..1a2d332e99fd 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -2261,9 +2261,7 @@ def matmul(tensor1, tensor2): t2_is_matrix = t2.dim() == 2 if t2_is_matrix: output_shape.append(t2.shape[1]) - # HACK: We need reshape with symint support - t1 = t1.contiguous() - t1_folded = t1.view(folded_dim1, sizes_1[-1]) + t1_folded = t1.reshape(folded_dim1, sizes_1[-1]) if t2_is_matrix: # FIXME This path always does an unnecessary copy when transpose == True as the returned # result from BLAS is already C-transposed @@ -2296,15 +2294,11 @@ def matmul(tensor1, tensor2): expand_batch_product = prod(expand_batch_portion) # HACK: We need reshape with symint support - tensor1_expanded = ( - tensor1.expand(tensor1_expand_size) - .contiguous() - .view(expand_batch_product, n, m1) + tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape( + expand_batch_product, n, m1 ) - tensor2_expanded = ( - tensor2.expand(tensor2_expand_size) - .contiguous() - .view(expand_batch_product, m2, p) + tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape( + expand_batch_product, m2, p ) output_shape = expand_batch_portion diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 2ba29981c366..e469ce02ebd6 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -14,6 +14,7 @@ from .variables.base import VariableTracker from .variables.nn_module import NNModuleVariable from .variables.tensor import ( + DynamicShapeVariable, TensorVariable, TensorWithTFOverrideVariable, UnspecializedNumpyVariable, @@ -95,6 +96,7 @@ def __call__(self, value, allow_cache=True): value, ( TensorVariable, + DynamicShapeVariable, TensorWithTFOverrideVariable, UnspecializedNumpyVariable, UnspecializedPythonVariable, diff 
--git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index f1ce83727a19..c612fe3c167d 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -156,7 +156,11 @@ def has_tensor(obj): seen_ids[obj_id] = any([has_tensor(v) for v in obj]) return seen_ids[obj_id] elif istype(obj, dict): - seen_ids[obj_id] = any([has_tensor(v) for v in obj.values()]) + # Some packages like pytest can be updated during runtime. So, make a + # copy of values to avoid issues like "RuntimeError: dictionary + # changed size during iteration" + values = list(obj.values()) + seen_ids[obj_id] = any([has_tensor(v) for v in values]) return seen_ids[obj_id] elif istype(obj, (str, int, float, type(None), bool)): seen_ids[obj_id] = False @@ -164,9 +168,6 @@ def has_tensor(obj): elif is_namedtuple(obj): seen_ids[obj_id] = any([has_tensor(getattr(obj, v)) for v in obj._fields]) return seen_ids[obj_id] - elif not is_allowed(obj) and hasattr(obj, "__dict__") and len(obj.__dict__): - seen_ids[obj_id] = any([has_tensor(v) for v in obj.__dict__.values()]) - return seen_ids[obj_id] else: # if config.debug: # print( @@ -302,6 +303,7 @@ def _convert_frame_assert(frame: types.FrameType, cache_size: int): # setattr could be tricky to handle generally, # but also not likely useful to compile- skip the whole frame return None + # Check if the frame is generated by an exec builtin call # TODO - Running exec generated frame seems propagates f_globals to the # next frames. diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index 9edd6f60560d..9cbcb93fcc5c 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -92,7 +92,7 @@ def __hash__(self): def sort_key(self): return ( - self.source.value, + self.source.value if self.source else -1, len(self.name), self.name, self.create_fn.__code__.co_firstlineno, @@ -101,13 +101,38 @@ def sort_key(self): def __lt__(self, other): return self.sort_key() < other.sort_key() + @staticmethod + def weakref_to_str(obj_weakref): + """ + This is a workaround of a Python weakref bug. + + `obj_weakref` is instance returned by `weakref.ref`, + `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g: + + class MyConfig(dict): + def __getattr__(self, x): + return self[x] + + obj = MyConfig(offset=5) + obj_weakref = weakref.ref(obj) + str(obj_weakref) # raise error: KeyError: '__name__' + """ + if isinstance(obj_weakref, weakref.ReferenceType): + obj = obj_weakref() + if obj is not None: + return f"" + else: + return f"" + else: + return str(obj_weakref) + def __str__(self): s = f""" - {self.source.name.lower()} {repr(self.name)} {self.create_fn.__name__} + {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.create_fn.__name__} {{ 'guard_types': {self.guard_types}, 'code': {self.code_list}, - 'obj_weakref': {self.obj_weakref} + 'obj_weakref': {self.weakref_to_str(self.obj_weakref)} 'guarded_class': {self.guarded_class_weakref} }} """ @@ -413,6 +438,13 @@ def GRAD_MODE(self, guard: Guard): code = "not ___is_grad_enabled()" self._produce_guard_code(guard, [code]) + # This is a bit of a crutch for export case for symbolic shape guards. + # SYMBOL_MATCH is only ever, and must only ever, be used for setting this value on + # the create_fn field for tracking guards in export. 
+ @staticmethod + def SYMBOL_MATCH(): + pass + def TENSOR_MATCH(self, guard: Guard): if guard.is_nn_module(): self.ID_MATCH(guard) @@ -512,18 +544,22 @@ def tensor_ref_as_str(tensor_ref, id_to_name_map): return f"{id_to_name_map[tensor_ref.ref_id]}.{tensor_ref.kind}()[{tensor_ref.idx}]" return f"{id_to_name_map[tensor_ref.ref_id]}.{tensor_ref.kind}()" - def __init__(self, expr_to_tensor_ref, id_to_name_map): + def __init__( + self, expr_to_tensor_ref, id_to_name_map, shape_env, intermediary_symbols + ): super().__init__() self.expr_to_tensor_ref = expr_to_tensor_ref self.id_to_name_map = id_to_name_map + self.shape_env = shape_env + self.intermediary_symbols = intermediary_symbols def _print_Symbol(self, expr) -> str: - assert isinstance(expr, sympy.core.symbol.Symbol) + assert isinstance(expr, sympy.Symbol) if expr == 0: return "0" if expr == 1: return "1" - assert expr in self.expr_to_tensor_ref, f"Unknown expression {expr}" + assert expr in (self.expr_to_tensor_ref) or (expr in self.intermediary_symbols) refs = self.expr_to_tensor_ref[expr] if len(refs) == 0: return super()._print_Symbol(expr) @@ -574,7 +610,7 @@ def combine_scopes(left, right): if not config.guard_nn_modules and guard.is_nn_module(): continue guard.create(local_builder, global_builder) - self.check_fn = self.compile_check_fn(local_builder, global_builder) + self.check_fn = self.compile_check_fn(local_builder, global_builder, guards) self._seen_ids.clear() """ @@ -607,7 +643,12 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids return None expr_to_tensor_ref = {} - guard_printer = DynamoGuardPrinter(expr_to_tensor_ref, id_to_name_map) + guard_printer = DynamoGuardPrinter( + expr_to_tensor_ref, + id_to_name_map, + self.output_graph.shape_env, + self.output_graph.intermediary_symbols, + ) # tensor_check_names is the primary tensor association mechanism in dynamo. # All other guards installations are driven off of it, so these ones will too. @@ -624,7 +665,6 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids if obj_expr not in expr_to_tensor_ref: expr_to_tensor_ref[obj_expr] = {} expr_to_tensor_ref[obj_expr][tensor_ref] = "" - finished_expressions.append(f"isinstance({name}, torch.Tensor)") guard_expression = self.output_graph.shape_env.get_guard_expr() expr_as_str = guard_printer.doprint(guard_expression) @@ -643,7 +683,6 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids if len(equality_candidates) > 1: equality_expr = " == ".join(equality_candidates) - # breakpoint() finished_expressions.append(equality_expr) # Redundant with code_parts, but allows us to wrap it with parens nicely. 
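To make the intent of threading guards_out through compile_check_fn concrete, here is a rough sketch of how the exported symbolic shape guard can surface to a caller, mirroring the new test_export_shape_control_flow_1 test earlier in this patch. This is illustrative only: it assumes dynamic shapes are enabled via torch._dynamo.config (as the test does with patch.object) and that the guard attributes behave as exercised in that test, not that this is a stable API.

import torch
import torch._dynamo

torch._dynamo.config.dynamic_shapes = True  # the test patches this flag instead

def func(x):
    # Branching on a size is what produces the symbolic shape guard.
    if x.shape[0] > 10:
        return x.cos()
    return x.sin()

graph, guards = torch._dynamo.export(func, torch.ones(6, 4))
for guard in guards:
    # The synthetic guard added in compile_check_fn has source=None, uses
    # GuardBuilder.SYMBOL_MATCH as its create_fn, and carries the printed
    # expression (e.g. "x.size()[0] <= 10") in code_list.
    if guard.name == "symbolic_shape_expression":
        print(guard.code_list)
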
@@ -653,7 +692,7 @@ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids expression = " and ".join(finished_expressions) return f"({expression})" - def compile_check_fn(self, local_builder, global_builder): + def compile_check_fn(self, local_builder, global_builder, guards_out): assert not (set(local_builder.argnames) & set(global_builder.argnames)) # see parallel handling of ".0" / "___implicit0" in _eval_frame.c args = [a for a in local_builder.scope.keys() if a == "___implicit0"] @@ -682,10 +721,6 @@ def compile_check_fn(self, local_builder, global_builder): symbolic_shape_expression = self._parse_symbolic_shape_expressions( tensor_check_names, tensor_check_ids ) - if symbolic_shape_expression: - code_parts.append(symbolic_shape_expression) - verbose_code_parts.append(symbolic_shape_expression) - tensor_check_examples = ( local_builder.tensor_check_examples + global_builder.tensor_check_examples @@ -700,6 +735,17 @@ def compile_check_fn(self, local_builder, global_builder): tensor_check_names + ["tensor_check_names=tensor_check_names"] ) verbose_code_parts.append(f"___check_tensors_verbose({verbose_args})") + if symbolic_shape_expression: + code_parts.append(symbolic_shape_expression) + verbose_code_parts.append(symbolic_shape_expression) + guards_out.add( + Guard( + name="symbolic_shape_expression", + source=None, + create_fn=GuardBuilder.SYMBOL_MATCH, + code_list=symbolic_shape_expression, + ) + ) def direct_equality(a, b): return a == b @@ -714,6 +760,8 @@ def direct_negation(a, b): ("___check_tensors", check_tensors_fn), ("___check_tensors_verbose", check_tensors_verbose_fn), ("tensor_check_names", tensor_check_names), + ("floor", math.floor), + ("ceiling", math.ceil), ("Eq", direct_equality), ("Ne", direct_negation), ("Mod", sympy.Mod), diff --git a/torch/_dynamo/optimizations/analysis.py b/torch/_dynamo/optimizations/analysis.py index b3f6ed79eb06..c4ed04ca8c39 100644 --- a/torch/_dynamo/optimizations/analysis.py +++ b/torch/_dynamo/optimizations/analysis.py @@ -15,7 +15,7 @@ if fake_tensors_available: from torch._subclasses import FakeTensorMode # noqa: F401 - from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor + from ..utils import deepcopy_to_fake_tensor class ShapeAliasingAndMutationProp(ShapeProp): @@ -122,9 +122,26 @@ def has_mutation(gm, example_inputs, inputs_only=False): # TODO - moco gives bad accuracy with Aliasing. gm is getting mutated in a bad way. if fake_tensors_available and config.fake_tensor_propagation: - with FakeTensorMode() as fake_mode: - pass - fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=fake_mode) + + def _wrap_to_fake_tensor(t, *, f_mode): + if type(t) in (torch.Tensor, torch.nn.Parameter): + static_shapes_ = config.dynamic_shapes is False + return fake_mode.from_tensor( + t, static_shapes=config.dynamic_shapes is not False + ) + else: + return t + + # Our analysis pass should use dynamic shape tensor inputs + # when dynamic shapes are enabled. + # We don't actually care about the guards that are created + # on those shapes though, so just create a fresh ShapeEnv here. 
+ from torch.fx.experimental.symbolic_shapes import ShapeEnv + + fake_mode = FakeTensorMode( + shape_env=ShapeEnv() if config.dynamic_shapes else None + ) + fake_wrapper = functools.partial(_wrap_to_fake_tensor, f_mode=fake_mode) example_inputs = tree_map(fake_wrapper, example_inputs) new_gm = deepcopy_to_fake_tensor(gm, fake_mode) with fake_mode.restore() if hasattr(fake_mode, "restore") else fake_mode: diff --git a/torch/_dynamo/optimizations/training.py b/torch/_dynamo/optimizations/training.py index 49f9a4397dd9..a56a74ad5aea 100644 --- a/torch/_dynamo/optimizations/training.py +++ b/torch/_dynamo/optimizations/training.py @@ -140,9 +140,13 @@ class AotNop(AotAutogradStrategy): """Useful for debugging purpose""" def candidate(self): + from functorch._src.compilers import debug_nop from functorch.compile import nop - return BACKENDS["aot_autograd"](self.gm, self.example_inputs, fw_compiler=nop) + DEBUG = False + return BACKENDS["aot_autograd"]( + self.gm, self.example_inputs, fw_compiler=debug_nop if DEBUG else nop + ) aot_eager = AotNop.compile_fn diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 9dd9a713a25c..ee5079581be7 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -6,7 +6,7 @@ import re import traceback from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch.nn from torch import fx @@ -15,7 +15,7 @@ from . import config, logging as torchdynamo_logging, variables from .bytecode_transformation import create_instruction, Instruction, unique_id from .codegen import PyCodegen -from .exc import BackendCompilerFailed, unimplemented +from .exc import BackendCompilerFailed from .guards import GuardBuilder from .mutation_guard import is_dynamic_nn_module from .side_effects import SideEffects @@ -27,9 +27,10 @@ fake_tensors_available, format_graph_tabular, ) -from .variables.builder import VariableBuilder +from .variables.builder import VariableBuilder, wrap_fx_proxy from .variables.nn_module import NNModuleVariable from .variables.tensor import ( + DynamicShapeVariable, TensorVariable, UnspecializedNumpyVariable, UnspecializedPythonVariable, @@ -93,7 +94,7 @@ def __init__( self.side_effects = SideEffects() self.code_options = dict(code_options) self.output_instructions = [] - # Node => computed real value (see TensorVariable.get_real_value) + # Node => computed real value (see utils.get_real_value) self.real_value_cache = {} # Not checkpointed @@ -107,6 +108,7 @@ def __init__( self.unspec_variable_map = {} self.shape_env = ShapeEnv() if config.dynamic_shapes else None self.tensor_id_to_sym_shape_ref = {} + self.intermediary_symbols = {} @property def output(self): @@ -194,43 +196,63 @@ def update_co_names(self, name): name, ) - def register_attr_or_module(self, mod: torch.nn.Module, *names, **options): - if is_dynamic_nn_module(mod): - return variables.UnspecializedNNModuleVariable(mod, **options) + def register_attr_or_module( + self, target: Union[torch.nn.Module, torch.Tensor, Any], *names, **options + ): + if is_dynamic_nn_module(target): + return variables.UnspecializedNNModuleVariable(target, **options) options = dict(options) options["guards"] = set(options.get("guards", [])) source: Source = options.get("source", None) - if isinstance(mod, torch.Tensor): + if isinstance(target, torch.Tensor): if source: options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH)) def wrap_name(module_key): - return 
TensorVariable.create( + return wrap_fx_proxy( self, self.create_proxy("get_attr", module_key, tuple(), {}), - example_value=mod, + example_value=target, **options, ) - elif isinstance(mod, torch.nn.Module): - assert isinstance(mod, torch.nn.Module) + elif isinstance(target, torch.nn.Module): + assert isinstance(target, torch.nn.Module) options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE)) def wrap_name(module_key): - return NNModuleVariable(type(mod), module_key, **options) + return NNModuleVariable(type(target), module_key, **options) + + elif isinstance(target, (torch.SymInt, torch.SymFloat)): + # HACKY CODE REGION BEGIN + # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS + # This ultimately gets written to self.nn_modules, which is unfortunate + # Attrs that are tenors and symints and such need to be migrated to have their + # own storage + # alas, this is like this for now + self.intermediary_symbols.update({target.get_pyobj().expr: None}) + + def wrap_name(module_key): + return DynamicShapeVariable.create( + self, + self.create_proxy("get_attr", module_key, tuple(), {}), + dyn_shape=target, + **options, + ) + # HACKY CODE REGION END else: def wrap_name(module_key): self.output.update_co_names(module_key) - self.root_globals[module_key] = mod + self.root_globals[module_key] = target return VariableBuilder(self, ConstantSource(source_name=module_key))( - mod + target ) for k, v in self.nn_modules.items(): - if v is mod: + if v is target: # it already exists return wrap_name(k) @@ -246,7 +268,7 @@ def wrap_name(module_key): base = name for i in itertools.count(): if name not in self.nn_modules: - self.nn_modules[name] = mod + self.nn_modules[name] = target return wrap_name(name) name = f"{base}_{i}" diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index e06f62a6bf62..d707bee930ee 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -55,7 +55,7 @@ istype, ) from .variables.base import MutableLocal, typestr, VariableTracker -from .variables.builder import VariableBuilder +from .variables.builder import VariableBuilder, wrap_fx_proxy from .variables.builtin import BuiltinVariable from .variables.constant import ConstantVariable from .variables.dicts import ConstDictVariable @@ -81,7 +81,7 @@ WithExitFunctionVariable, ) from .variables.nn_module import NNModuleVariable -from .variables.tensor import TensorVariable +from .variables.tensor import DynamicShapeVariable, TensorVariable from .variables.torch import TorchVariable from .variables.user_defined import UserDefinedVariable @@ -129,8 +129,17 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if truth_fn(value.as_python_constant()): push and self.push(value) self.jump(inst) - elif isinstance(value, TensorVariable) and self.should_compile_partial_graph(): + elif ( + isinstance(value, (TensorVariable)) and self.should_compile_partial_graph() + ): # compile a partial subgraph prefix then jump into user code + if self.has_backedge(): + msg = ( + "Skipping frame because there is a graph break in a for/while loop" + ) + log.debug(msg) + raise exc.SkipFrame(msg) + self.push(value) self.output.compile_subgraph( self, @@ -155,6 +164,11 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if truth_fn(len(value.unpack_var_sequence(self))): push and self.push(value) self.jump(inst) + elif isinstance(value, DynamicShapeVariable): + eval_result = value.evaluate_expr(self.output) + if truth_fn(eval_result): + push and 
self.push(value) + self.jump(inst) else: unimplemented(f"generic_jump {typestr(value)}") @@ -172,10 +186,15 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): reason = None try: return inner_fn(self, inst) - except Unsupported as exc: + except Unsupported as excp: + if self.has_backedge(): + msg = "Skipping frame because there is a graph break in a for/while loop" + log.debug(msg) + raise exc.SkipFrame(msg) + if not self.should_compile_partial_graph(): raise - user_stack = [self.frame_summary()] + list(reversed(exc.real_stack)) + user_stack = [self.frame_summary()] + list(reversed(excp.real_stack)) user_stack_formatted = "".join(traceback.format_list(user_stack)) frame_loc = (user_stack[-1].filename, user_stack[-1].lineno) # torch._dynamo.explain() formats this a little nicer, and presents a slightly @@ -186,12 +205,12 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): and graph_break_dup_warning_checker.add(frame_loc) ): log.warning( - f"Graph break: {exc} from user code at {user_stack_formatted}" + f"Graph break: {excp} from user code at {user_stack_formatted}" ) - exc.remove_from_stats() - exc.add_to_stats("graph_break") - reason = GraphCompileReason(exc.msg, user_stack) + excp.remove_from_stats() + excp.add_to_stats("graph_break") + reason = GraphCompileReason(excp.msg, user_stack) self.restore_graphstate(state) self.output.compile_subgraph(self, reason=reason) self.popn(push - dis.stack_effect(inst.opcode, inst.arg)) @@ -230,6 +249,19 @@ def wrapper(self: "InstructionTranslatorBase", inst: Instruction): class InstructionTranslatorBase(object): + def has_backedge(self): + cur_offset = self.current_instruction.offset + for inst in self.instructions[self.instruction_pointer :]: + if inst.opname in ( + "JUMP_ABSOLUTE", + "POP_JUMP_IF_TRUE", + "POP_JUMP_IF_FALSE", + ): + jump_offset = inst.argval + if jump_offset < cur_offset: + return True + return False + def cell_and_freevars(self): if not hasattr(self, "_cell_and_freevars"): self._cell_and_freevars = tuple( @@ -700,6 +732,7 @@ def COMPARE_OP(self, inst): left, ( TensorVariable, + DynamicShapeVariable, NNModuleVariable, BaseListVariable, UserDefinedVariable, @@ -717,16 +750,6 @@ def COMPARE_OP(self, inst): supported_is_const[op](object(), right.value), **options ) ) - elif ( - isinstance(left, TensorVariable) or isinstance(right, TensorVariable) - ) and op in supported_tensors: - self.push( - TensorVariable.create( - self, - supported_tensors[op](left.as_proxy(), right.as_proxy()), - **options, - ) - ) elif ( left.is_python_constant() and right.is_python_constant() @@ -741,6 +764,28 @@ def COMPARE_OP(self, inst): **options, ) ) + elif ( + isinstance(left, TensorVariable) or isinstance(right, TensorVariable) + ) and op in supported_tensors: + self.push( + wrap_fx_proxy( + self, + supported_tensors[op](left.as_proxy(), right.as_proxy()), + **options, + ) + ) + elif ( + isinstance(left, DynamicShapeVariable) + or isinstance(right, DynamicShapeVariable) + ) and op in supported_tensors: + self.push( + DynamicShapeVariable.create( + self, + supported_tensors[op](left.as_proxy(), right.as_proxy()), + dyn_shape=None, + **options, + ) + ) elif op in ("in", "not in"): self.push(right.call_method(self, "__contains__", [left], {})) if op == "not in": @@ -1029,12 +1074,12 @@ def UNPACK_SEQUENCE(self, inst): elif isinstance(seq, TensorVariable): proxy = seq.as_proxy() for i in reversed(range(inst.argval)): - self.push(TensorVariable.create(self, proxy[i], **options)) + self.push(wrap_fx_proxy(self, proxy[i], 
**options)) elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable): # x, y = a.shape proxy = getattr(seq.obj.as_proxy(), seq.name) for i in reversed(range(inst.argval)): - self.push(TensorVariable.create(self, proxy[i], **options)) + self.push(wrap_fx_proxy(self, proxy[i], **options)) else: unimplemented(f"UNPACK_SEQUENCE {seq}") @@ -1109,7 +1154,8 @@ def FORMAT_VALUE(self, inst): fmt_spec = ConstantVariable("") value = self.pop() - + if isinstance(value, DynamicShapeVariable): + value = ConstantVariable(str(value.dyn_shape)) if (flags & 0x03) == 0x01: value = BuiltinVariable(str).call_function(self, [value], {}) elif (flags & 0x03) == 0x02: diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 067a80807374..0b87be7393b5 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -29,7 +29,9 @@ import torch from torch import fx +from torch._dispatch.python import enable_python_dispatcher from torch.nn.modules.lazy import LazyModuleMixin +from torch.utils._pytree import tree_map from . import config, logging as torchdynamo_logging @@ -679,10 +681,8 @@ def rename_implicit(v): UnsupportedFakeTensorException, ) - def make_fake_tensor(e, fake_mode, tx=None): - fake_tensor = fake_mode.from_tensor( - e, static_shapes=config.dynamic_shapes is False - ) + def make_fake_tensor(e, fake_mode, static_shapes=False, tx=None): + fake_tensor = fake_mode.from_tensor(e, static_shapes=static_shapes) if tx is not None: from torch._dynamo.guards import TensorReference @@ -728,13 +728,23 @@ def wrap_fake_exception(fn): def wrap_to_fake_tensor(e, fake_mode): if type(e) in (torch.Tensor, torch.nn.Parameter): - return wrap_fake_exception(lambda: make_fake_tensor(e, fake_mode)) + return wrap_fake_exception( + lambda: make_fake_tensor( + e, fake_mode, static_shapes=config.dynamic_shapes is False + ) + ) else: return e def wrap_to_fake_tensor_and_record(e, tx): if type(e) in (torch.Tensor, torch.nn.Parameter): - return wrap_fake_exception(lambda: make_fake_tensor(e, tx.fake_mode, tx)) + static_shapes = config.dynamic_shapes is False + if type(e) is torch.nn.Parameter: + # Always static for params + static_shapes = True + return wrap_fake_exception( + lambda: make_fake_tensor(e, tx.fake_mode, static_shapes, tx) + ) else: return e @@ -997,3 +1007,116 @@ def _get_debug_dir(root_dir): def get_debug_dir(): debug_root = config.debug_dir_root return _get_debug_dir(debug_root) + + +def get_fake_value(node, tx): + """ + Run the computation represented by `node` using fake tensors and return the result. 
+ """ + from .exc import TorchRuntimeError, unimplemented, Unsupported + + op = node.op + fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx) + + def visit(n: torch.fx.Node): + return n.meta["example_value"] + + args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit) + args = tree_map(fake_wrapper, args) + kwargs = tree_map(fake_wrapper, kwargs) + + nnmodule = None + if op == "call_module": + nnmodule = tx.output.nn_modules[node.target] + + if not is_lazy_module(nnmodule): + nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) + + if op == "call_module" and is_lazy_module(nnmodule): + assert nnmodule is not None + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nnmodule(*args, **kwargs) + try: + with tx.fake_mode, enable_python_dispatcher(): + return wrap_fake_exception( + lambda: run_node(tx.output, node, args, kwargs, nnmodule) + ) + except Unsupported: + raise + except RuntimeError as e: + if isinstance(e, torch._subclasses.fake_tensor.DataDependentOutputException): + if config.capture_scalar_outputs and node.target == "item": + return torch.zeros(size=(), dtype=args[0].dtype).item() + else: + unimplemented(f"data dependent operator: {e.func}") + elif isinstance(e, torch._subclasses.fake_tensor.DynamicOutputShapeException): + unimplemented(f"dynamic shape operator: {e.func}") + raise TorchRuntimeError() from e + + +def run_node(output_graph, node, args, kwargs, nnmodule): + """ + Runs a given node, with the given args and kwargs. + + Behavior is dicatated by a node's op. + + run_node is useful for extracting real values out of nodes. + See get_real_value for more info on common usage. + + Note: The output_graph arg is only used for 'get_attr' ops + Note: The nnmodule arg is only used for 'call_module' ops + + Nodes that are not call_function, call_method, call_module, or get_attr will + raise an AssertionError. + """ + op = node.op + try: + if op == "call_function": + return node.target(*args, **kwargs) + elif op == "call_method": + return getattr(args[0], node.target)(*args[1:], **kwargs) + elif op == "call_module": + assert nnmodule is not None + return nnmodule(*args, **kwargs) + elif op == "get_attr": + return output_graph.get_submodule(node.target) + except Exception as e: + raise RuntimeError( + f"Failed running {op} {node.target}(*{args}, **{kwargs}):\n{e}\n(scroll up for backtrace)" + ) from e + raise AssertionError(op) + + +def get_real_value(node, output_graph): + """ + Run the actual computation represented by `node` and return the result. + This will execute any dependent nodes in the graph as well. 
+ """ + cache = output_graph.real_value_cache + if node in cache: + return cache[node] + + op = node.op + args, kwargs = torch.fx.node.map_arg( + (node.args, node.kwargs), + lambda n: get_real_value(n, output_graph), + ) + + if op == "call_module": + nn_module = output_graph.nn_modules[node.target] + if not is_lazy_module(nn_module): + nn_module = copy.deepcopy(nn_module) + else: + # In the case of a lazy module, we want to run + # the pre-hooks which initialize it + nn_module(*args, **kwargs) + else: + nn_module = None + + try: + real_value = run_node(output_graph, node, args, kwargs, nn_module) + cache[node] = real_value + except RuntimeError as e: + raise TorchRuntimeError() from e + return real_value diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 8c80557e3fd0..2305afc226ac 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -35,6 +35,7 @@ ) from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable from .tensor import ( + DynamicShapeVariable, FakeItemVariable, TensorVariable, UnspecializedNumpyVariable, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index d3c5140fa4a9..67e506b5b435 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -3,15 +3,19 @@ import enum import functools import inspect +import math +import numbers +import operator import re import types from abc import ABCMeta -from typing import Any, List +from typing import Any, List, Union import numpy as np from functorch.experimental.ops import PyOperator import torch +from torch.fx.immutable_collections import immutable_list from .. import config, mutation_guard, replay_record, skipfiles from ..allowed_functions import is_allowed, is_builtin_callable, is_numpy @@ -31,6 +35,10 @@ TupleIteratorGetItemSource, ) from ..utils import ( + clone_input, + fake_tensors_available, + get_fake_value, + get_real_value, getfile, global_key_name, is_namedtuple, @@ -38,11 +46,14 @@ istensor, istype, odict_values, + preserve_rng_state, tuple_iterator, tuple_iterator_getitem, tuple_iterator_len, + wrap_to_fake_tensor_and_record, ) -from .base import MutableLocal + +from .base import MutableLocal, typestr from .builtin import BuiltinVariable from .constant import ConstantVariable, EnumVariable from .dicts import ( @@ -57,6 +68,7 @@ ListVariable, NamedTupleVariable, RangeVariable, + SizeVariable, SliceVariable, TupleVariable, ) @@ -72,6 +84,7 @@ ) from .nn_module import UnspecializedNNModuleVariable from .tensor import ( + DynamicShapeVariable, TensorVariable, TensorWithTFOverrideVariable, UnspecializedNumpyVariable, @@ -86,6 +99,10 @@ from .user_defined import UserDefinedClassVariable, UserDefinedObjectVariable +class _missing: + pass + + @dataclasses.dataclass class GraphArg: source: Source @@ -187,6 +204,8 @@ def make_guards(self, *guards): def _wrap(self, value): make_guards = self.make_guards + if istype(value, (torch.SymInt, torch.SymFloat)): + return self.wrap_sym(value) if istensor(value): return self.wrap_tensor(value) elif istype(value, (tuple, list, odict_values)) or is_namedtuple(value): @@ -490,6 +509,26 @@ def tensor_should_specialize(self): ) ) + def wrap_sym(self, value: Union[torch.SymInt, torch.SymFloat]): + if not is_constant_source(self.get_source()): + self.tx.output.graphargs.append(GraphArg(self.get_source(), value, False)) + elif is_constant_source(self.get_source()): + return self.tx.output.register_attr_or_module( + value, + 
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), + source=None, + dyn_shape=value + # shape Guards live their own rich life via shape_env + ) + return DynamicShapeVariable.create( + tx=self.tx, + proxy=self.tx.output.create_graph_input( + re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value) + ), + dyn_shape=value + # shape Guards live their own rich life via shape_env + ) + def wrap_tensor(self, value: torch.Tensor): if self.get_source().guard_source().is_nn_module(): return self.tx.output.register_attr_or_module( @@ -514,7 +553,7 @@ def wrap_tensor(self, value: torch.Tensor): source=None, # Guards are added inside register_attr_or_module ) - tensor_variable = TensorVariable.create( + tensor_variable = wrap_fx_proxy( tx=self.tx, proxy=self.tx.output.create_graph_input( re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value) @@ -556,14 +595,16 @@ def wrap_unspecialized_primitive(self, value): ) if isinstance(value, np.number): - unspec_var = UnspecializedNumpyVariable.create( + unspec_var = wrap_fx_proxy_cls( + UnspecializedNumpyVariable, tx=self.tx, proxy=proxy, example_value=wrapped_value, **options, ) else: - unspec_var = UnspecializedPythonVariable.create( + unspec_var = wrap_fx_proxy_cls( + UnspecializedPythonVariable, tx=self.tx, proxy=proxy, example_value=wrapped_value, @@ -589,3 +630,190 @@ def _dataclasses_fields_lambda(obj): ) items.append(UserDefinedObjectVariable(field, source=source).add_options(obj)) return TupleVariable(items).add_options(obj) + + +def wrap_fx_proxy(tx, proxy, example_value=None, **options): + return wrap_fx_proxy_cls( + target_cls=TensorVariable, + tx=tx, + proxy=proxy, + example_value=example_value, + **options, + ) + + +# Note: Unfortunate split due to some gross classes existing that subclass TensorVariable +# Should be compositional instead +def wrap_fx_proxy_cls(target_cls, tx, proxy, example_value=None, **options): + if "guards" in options and options["guards"] is not None: + tx.output.guards.update(options["guards"]) + + assert "example_value" not in proxy.node.meta + if not config.dynamic_propagation: + if isinstance(example_value, torch.Tensor): + options.update(target_cls.specialize(example_value)) + return target_cls(proxy, **options) + + use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation + + initial_example_value = example_value + + def _clone_input(value): + if isinstance(value, torch.Tensor): + use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation + # tensor subclasses will not be converted to FakeTensors and need to be cloned + if not use_fake_tensors or not isinstance( + value, torch._subclasses.fake_tensor.FakeTensor + ): + # NB: ensure strides are preserved + value = clone_input(value) + + return value + + with preserve_rng_state(): + if example_value is None: + if use_fake_tensors: + example_value = get_fake_value(proxy.node, tx) + else: + example_value = get_real_value(proxy.node, tx.output) + + else: + proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value) + if use_fake_tensors: + fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx) + example_value = fake_wrapper(example_value) + + if isinstance(example_value, torch.Tensor): + is_parameter = isinstance(example_value, torch.nn.Parameter) + should_specialize = options.pop("should_specialize", False) + if is_parameter or should_specialize: + specialized_value = initial_example_value + else: + specialized_value = None + + example_value = _clone_input(example_value) + proxy.node.meta["example_value"] = example_value 
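The "example_value" convention used throughout wrap_fx_proxy_cls is plain FX node metadata: each node produced during tracing carries a representative value so later code can read shapes and dtypes without re-running the op. A small sketch of the mechanism follows; a real tensor stands in for the FakeTensor Dynamo actually stores.

import torch
import torch.fx as fx

g = fx.Graph()
x = g.placeholder("x")
y = g.call_function(torch.relu, (x,))
# Dynamo stores a FakeTensor here; a real tensor illustrates the mechanism.
y.meta["example_value"] = torch.empty(2, 3)
print(y.meta["example_value"].shape)  # torch.Size([2, 3])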
+ specialized_props = target_cls.specialize(example_value) + if use_fake_tensors and isinstance( + example_value, torch._subclasses.fake_tensor.FakeTensor + ): + specialized_props["class_type"] = ( + torch.nn.Parameter if is_parameter else torch.Tensor + ) + + specialized_props["specialized_value"] = specialized_value + + options.update(specialized_props) + return target_cls(proxy, **options) + elif ( + hasattr(proxy.node.target, "__name__") + and proxy.node.target.__name__ == "set_state" + and isinstance(proxy.node.target.__self__, torch._C.Generator) + or proxy.node.target == torch.random.set_rng_state + ): + from . import TorchVariable + + return TorchVariable(proxy.node.target) + elif ( + proxy.node.target == torch._C._DisableFuncTorch + or proxy.node.target == torch.cuda._is_in_bad_fork + ): + from . import UserDefinedObjectVariable + + return UserDefinedObjectVariable(example_value) + elif istype(example_value, (int, bool, float)) and config.dynamic_shapes: + proxy.node.meta["example_value"] = example_value + return DynamicShapeVariable.create(tx, proxy, example_value, **options) + elif istype(example_value, torch.Size) and config.dynamic_shapes: + proxy.node.meta["example_value"] = example_value + sizes = [] + for i, v in enumerate(example_value): + proxy_i = proxy[i] + sizes.append(DynamicShapeVariable.create(tx, proxy_i, v, **options)) + return SizeVariable(sizes, proxy, **options) + elif istype(example_value, int) and proxy.node.target in ( + torch.seed, + operator.mod, + # some mac builds are missing torch.distributed.get_rank() + getattr(torch.distributed, "get_rank", _missing), + getattr(torch.distributed, "get_world_size", _missing), + ): + if config.dynamic_shapes: + proxy.node.meta["example_value"] = example_value + return DynamicShapeVariable.create(tx, proxy, example_value, **options) + else: + return ConstantVariable(example_value, **options) + elif istype(example_value, torch.Size) and all( + [isinstance(x, int) for x in example_value] + ): + sizes = [ConstantVariable(x) for x in example_value] + return SizeVariable(sizes, **options) + elif isinstance(example_value, (tuple, list)): + unpacked = [] + for i, val in enumerate(example_value): + if val is None: + # nn.MultiheadAttention() can return None, see issue #175 + unpacked.append( + ConstantVariable(None, **options), + ) + else: + unpacked.append( + wrap_fx_proxy( + tx, + proxy.tracer.create_proxy( + "call_function", operator.getitem, (proxy, i), {} + ), + example_value=val, + **options, + ) + ) + if istype(example_value, tuple): + return TupleVariable(unpacked, **options) + elif istype(example_value, (list, immutable_list)): + return ListVariable(unpacked, mutable_local=MutableLocal(), **options) + else: + assert ( + example_value.__class__.__module__ == "torch.return_types" + or hasattr(example_value, "_fields") + ), ("namedtuple?") + return NamedTupleVariable(unpacked, example_value.__class__, **options) + elif example_value is None or proxy.node.target is torch.manual_seed: + return ConstantVariable(None, **options) + elif ( + isinstance(example_value, int) + and proxy.node.target is torch._utils._element_size + ): + proxy.node.meta["example_value"] = example_value + return ConstantVariable(example_value, **options) + elif ( + isinstance(example_value, numbers.Number) + and (proxy.node.target == "item" or proxy.node.target in {math.sqrt, math.pow}) + and config.capture_scalar_outputs + ): + if use_fake_tensors: + # item raw value should not be accessed + return wrap_fx_proxy_cls( + FakeItemVariable, + tx=tx, + 
proxy=proxy, + example_value=torch.tensor(example_value), + **options, + ) + else: + return wrap_fx_proxy_cls( + UnspecializedPythonVariable, + tx=tx, + proxy=proxy, + example_value=torch.tensor(example_value), + raw_value=None if use_fake_tensors else example_value, + need_unwrap=False, + **options, + ) + elif isinstance(example_value, (torch.SymInt, torch.SymFloat)): + proxy.node.meta["example_value"] = example_value + return DynamicShapeVariable(proxy, example_value, **options) + else: + raise AssertionError( + "torch.* op returned non-Tensor " + + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" + ) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 5a88f375c9c2..904ed8a49f81 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -10,6 +10,7 @@ import numpy as np import torch +from torch.fx.experimental.symbolic_shapes import sym_float, sym_int from .. import config, variables from ..allowed_functions import is_allowed @@ -26,7 +27,7 @@ ) from .base import MutableLocal, VariableTracker from .dicts import ConstDictVariable -from .tensor import DynamicShapeVariable, FakeItemVariable +from .tensor import DynamicShapeVariable, FakeItemVariable, UnspecializedPythonVariable log = logging.getLogger(__name__) @@ -226,6 +227,7 @@ def unwrap_unspec_args_kwargs(args, kwargs): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls constant_args = check_constant_args(args, kwargs) tensor_args = self.tensor_args(*args, **kwargs) @@ -234,7 +236,7 @@ def call_function( has_constant_handler = self.can_constant_fold_through() and ( constant_args or unspec_python_args ) - assert isinstance(args, list) + assert isinstance(args, (list, tuple)) assert isinstance(kwargs, dict) if ( @@ -274,7 +276,8 @@ def call_function( "call_function", fn, *proxy_args_kwargs(args, kwargs), current_tx=tx ) if any([isinstance(arg, FakeItemVariable) for arg in args]): - return variables.FakeItemVariable.create( + return wrap_fx_proxy_cls( + FakeItemVariable, tx, proxy, **options, @@ -282,7 +285,8 @@ def call_function( elif self.unspec_numpy_args(*args, **kwargs): _args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs) raw_value = self.fn(*_args, **_kwargs) - return variables.UnspecializedNumpyVariable.create( + return wrap_fx_proxy_cls( + variables.UnspecializedNumpyVariable, tx, proxy, raw_value=raw_value, @@ -298,7 +302,8 @@ def call_function( if isinstance(x, variables.UnspecializedPythonVariable) ) - return variables.UnspecializedPythonVariable.create( + return wrap_fx_proxy_cls( + UnspecializedPythonVariable, tx, proxy, raw_value=raw_value, @@ -312,14 +317,27 @@ def call_function( args[0], variables.UnspecializedPythonVariable ): args[0] = args[0].convert_to_constant(tx) - return variables.TensorVariable.create(tx, proxy, **options) + return wrap_fx_proxy(tx, proxy, **options) except NotImplementedError: unimplemented(f"partial tensor op: {self} {args} {kwargs}") # Handle cases like int(torch.seed()) - if self.fn is int and isinstance(args[0], DynamicShapeVariable): - return args[0] + # Also handle sym_float to sym_int cases + if self.fn in (int, float) and isinstance(args[0], DynamicShapeVariable): + fn_ = sym_int if self.fn is int else sym_float + out = wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + fn_, + (args[0].as_proxy(),), + {}, + current_tx=tx, + ), + **options, + ) + 
return out handler = getattr(self, f"call_{self.fn.__name__}", None) if handler: @@ -353,7 +371,6 @@ def call_function( ), **options, ) - return super().call_function(tx, args, kwargs) def _call_min_max(self, tx, a, b): @@ -368,7 +385,9 @@ def _call_min_max(self, tx, a, b): # Dynamic input does not get resolved, rather, gets stored as call_function if isinstance(a, DynamicShapeVariable): - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -437,7 +456,13 @@ def _call_min_max(self, tx, a, b): return variables.ConstantVariable(max(a.value, b.value)) else: return variables.ConstantVariable(min(a.value, b.value)) + elif isinstance(a, DynamicShapeVariable) or isinstance(b, DynamicShapeVariable): + proxy = tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs([a, b], {}) + ) + return DynamicShapeVariable.create(tx, proxy, None) else: + unimplemented(f"unsupported min / max over args {str(a)}, {str(b)}") call_min = _call_min_max @@ -454,11 +479,48 @@ def call_range(self, tx, *args, **kwargs): **{k: v.value for k, v in kwargs.items()}, ), ) + elif self._dynamic_args(*args, **kwargs): + assert len(kwargs) == 0 + + def guard_if_dyn(arg): + if isinstance(arg, DynamicShapeVariable): + return arg.evaluate_expr(tx.output) + return arg + + args = [guard_if_dyn(arg) for arg in args] + value = self.fn(*args) + return variables.RangeVariable(value=value) + # None no-ops this handler and lets the driving function proceed + return None + + def _dynamic_args(self, *args, **kwargs): + return any([isinstance(x, DynamicShapeVariable) for x in args]) or any( + [isinstance(x, DynamicShapeVariable) for x in kwargs.values()] + ) def call_slice(self, tx, *args): return variables.SliceVariable(args) - def _call_iter_tuple_list(self, tx, obj=None): + def _dyn_proxy(self, tx, *args, **kwargs): + assert self._dynamic_args(*args, **kwargs) + from .builder import wrap_fx_proxy + + options = VariableTracker.propagate(self, args, kwargs.values()) + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_function", self.fn, *proxy_args_kwargs(args, kwargs) + ), + **options, + ) + + def call_mod(self, tx, *args, **kwargs): + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) + + def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs): + if self._dynamic_args(*args, **kwargs): + return self._dyn_proxy(tx, *args, **kwargs) cls = variables.BaseListVariable.cls_for(self.fn) if obj is None: return cls( @@ -551,6 +613,7 @@ def call_getitem(self, tx, *args, **kwargs): def call_isinstance(self, tx, arg, isinstance_type): arg_type = arg.python_type() + isinstance_type = isinstance_type.as_python_constant() if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index d3366448e379..63eed37ccbec 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -13,6 +13,8 @@ class ConstantVariable(VariableTracker): def __init__(self, value, **kwargs): super(ConstantVariable, self).__init__(**kwargs) assert not isinstance(value, torch.Tensor) + assert not isinstance(value, torch.SymInt) + assert not isinstance(value, torch.SymFloat) self.value = value def as_proxy(self): @@ -70,6 +72,8 @@ def call_method( args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": + from .tensor import DynamicShapeVariable + 
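For the int()/float() handling above: calling the Python builtins on a symbolic value would force a concrete answer (and a guard), so they are redirected to sym_int/sym_float, which stay symbolic for symbolic inputs and behave like the builtins otherwise. A tiny sketch, using the import path as it appears in this diff (newer torch releases also expose torch.sym_int / torch.sym_float):

from torch.fx.experimental.symbolic_shapes import sym_float, sym_int

def half_rounded(n):
    # n may be a plain int or a SymInt coming off tensor.size();
    # sym_int/sym_float keep the result symbolic instead of specializing.
    return sym_int(sym_float(n) / 2.0)

print(half_rounded(7))  # 3 for a plain int input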
options = VariableTracker.propagate(self, args, kwargs.values()) if istype(self.value, tuple): @@ -78,6 +82,20 @@ def call_method( items=self.unpack_var_sequence(tx), source=self.source, **options ).call_method(tx, name, args, kwargs) + if any([isinstance(x, DynamicShapeVariable) for x in args]): + # NOTE! DANGER! THIS ONLY WORKS FOR COMMUTATIVE OPS + # we are relying on add to have arg[0] be a DynamicShapeVariable + # because we are in ConstantVariable land + # This transforms + # constant + dynamic + # into + # dynamic + constant + # Which already has infra built for writing to the graph + if name == "__add__": + assert len(args) == 1 + return args[0].call_method(tx, name, [self], {}) + # Unfortunate constant + return super(ConstantVariable, self).call_method(tx, name, args, kwargs) try: const_args = [a.as_python_constant() for a in args] const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()} @@ -98,7 +116,19 @@ def has_arith_binop(num_ty): return ConstantVariable(method(*const_args, **const_kwargs), **options) elif has_arith_binop(int) or has_arith_binop(float): op = getattr(operator, name) - return ConstantVariable(op(self.value, const_args[0]), **options) + add_target = const_args[0] + if isinstance(add_target, (torch.SymInt, torch.SymFloat)): + from .tensor import DynamicShapeVariable + + # Addition between a non sym and sym makes a sym + # dyn_shape = tx.output.register_attr_or_module( + # add_target, f"sym_shape_{add_target}", source=None + # ) + proxy = tx.output.create_proxy( + "call_function", op, (self.value, add_target), {} + ) + return DynamicShapeVariable.create(tx, proxy, add_target, **options) + return ConstantVariable(op(self.value, add_target), **options) elif name == "__len__" and not (args or kwargs): return ConstantVariable(len(self.value), **options) elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant(): diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index f63283819f35..151619d0e4ab 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -7,7 +7,7 @@ from ..bytecode_transformation import create_instruction from ..exc import unimplemented from ..source import GetItemSource -from ..utils import namedtuple_fields +from ..utils import namedtuple_fields, proxy_args_kwargs from .base import MutableLocal, VariableTracker from .constant import ConstantVariable @@ -308,6 +308,58 @@ def reconstruct(self, codegen): ] return build_torch_size + def unpack_var_sequence(self, tx): + return [x.add_options(self) for x in self.items] + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + options = VariableTracker.propagate(self, args, kwargs.values()) + if name == "__getitem__": + assert not kwargs and len(args) == 1 + if config.dynamic_shapes: + out = self.get_item_dyn(tx, args[0]) + else: + out = self.getitem_const(args[0]) + return out + return super(SizeVariable, self).call_method(tx, name, args, kwargs) + + def get_item_dyn(self, tx, arg: VariableTracker): + from .tensor import DynamicShapeVariable + + index = arg.as_python_constant() + if isinstance(index, slice): + + def _dynamo_get_item_lambda(target, index): + return torch.Size.__getitem__(target, index) + + parent_proxy = self.as_proxy() + proxy = tx.output.create_proxy( + "call_function", + _dynamo_get_item_lambda, + *proxy_args_kwargs([self, arg], {}), + current_tx=tx, + ) + items = self.items[index] + + def 
_unpack_into_example(item): + if isinstance(item, DynamicShapeVariable): + return item.dyn_shape + return item.as_python_constant() + + # Mirror the indexing into example_value for downstream correctness + proxy.node.meta["example_value"] = parent_proxy.node.meta["example_value"][ + index + ] + return SizeVariable(items, proxy=proxy).add_options(arg, self) + else: + assert isinstance(index, int) + return self.items[index].add_options(arg, self) + class ShapeVariable(TupleVariable): """ @@ -349,13 +401,20 @@ def call_hasattr(self, tx, name: str) -> "VariableTracker": class SliceVariable(BaseListVariable): def __init__(self, items, **kwargs): + from .tensor import DynamicShapeVariable + + if any([isinstance(x, DynamicShapeVariable) for x in items]): + unimplemented("Dynamic slicing not supported") + + items_to_map = items start, stop, step = [variables.ConstantVariable(None)] * 3 - if len(items) == 1: - (stop,) = items - elif len(items) == 2: - start, stop = items - elif len(items) == 3: - start, stop, step = items + + if len(items_to_map) == 1: + (stop,) = items_to_map + elif len(items_to_map) == 2: + start, stop = items_to_map + elif len(items_to_map) == 3: + start, stop, step = items_to_map else: raise AssertionError() @@ -366,7 +425,7 @@ def __init__(self, items, **kwargs): # more complete support for breaking on data dependent operators. if not config.capture_scalar_outputs: for limit in (start, stop, step): - if isinstance(limit, variables.TensorVariable): + if isinstance(limit, (variables.TensorVariable, DynamicShapeVariable)): unimplemented("Dynamic slicing not supported") super().__init__([start, stop, step], **kwargs) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index da327122a6a7..5d7336cefeae 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -513,6 +513,7 @@ def reconstruct(self, codegen): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": + from .builder import wrap_fx_proxy # This variable is True when it corresponds to user code such as # @@ -530,7 +531,7 @@ def call_function( if is_original_tensor_torch_function: # Instead of tracing inside torch.Tensor.__torch_function__, # record the `call_function` or `call_method` call into the graph. - from . import TensorVariable, TorchVariable + from . import TorchVariable original_torch_or_getattr_variable = args[0] new_args = args[2].items @@ -540,7 +541,7 @@ def call_function( # example tensor from going into the override. 
with torch._C.DisableTorchFunction(): if isinstance(args[0], TorchVariable): - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -551,7 +552,7 @@ def call_function( **options, ) elif isinstance(args[0], GetAttrVariable): - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_method", diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 6f7c2ff28737..848f022525d9 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -197,8 +197,9 @@ def record_nn_module_stack(): # The module type will change after it is called if is_lazy: self.module_type = mod.cls_to_become + from .builder import wrap_fx_proxy - return variables.TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_module", @@ -337,6 +338,8 @@ def named_embed(name, obj): ): result.append(named_embed(name, submod)) return ListIteratorVariable(result, mutable_local=MutableLocal(), **options) + elif name == "modules": + return wrap_values(module.named_modules()) elif name == "parameters": return wrap_values(module.named_parameters(**get_kwargs("recurse"))) elif name == "values": @@ -452,7 +455,9 @@ def make_attr(name): proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_method", diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index e87b1d87bac9..8867f7e6cc93 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1,161 +1,28 @@ -import copy -import functools import itertools -import math -import numbers import operator from typing import Dict, List import torch.fx import torch.random -from ..utils import fake_tensors_available - -if fake_tensors_available: - from torch._subclasses import FakeTensor - from torch._subclasses.fake_tensor import ( - DataDependentOutputException, - DynamicOutputShapeException, - ) - from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor_and_record - -import torch.utils._python_dispatch as py_dispatch -from torch.fx.immutable_collections import immutable_list -from torch.utils._pytree import tree_map - from .. import config, variables -from ..exc import TorchRuntimeError, unimplemented, Unsupported +from ..exc import unimplemented from ..guards import GuardBuilder from ..source import AttrSource + from ..utils import ( - clone_input, - is_lazy_module, - istype, - preserve_rng_state, + fake_tensors_available, + get_fake_value, + get_real_value, product, proxy_args_kwargs, tensortype_to_dtype, ) -from .base import MutableLocal, typestr, VariableTracker +from .base import VariableTracker from .constant import ConstantVariable from .lists import ShapeVariable, SizeVariable -class _missing: - pass - - -def _run_node(output_graph, node, args, kwargs, nnmodule): - op = node.op - if op == "call_function": - return node.target(*args, **kwargs) - elif op == "call_method": - return getattr(args[0], node.target)(*args[1:], **kwargs) - elif op == "call_module": - assert nnmodule is not None - return nnmodule(*args, **kwargs) - elif op == "get_attr": - return output_graph.get_submodule(node.target) - raise AssertionError(op) - - -def _get_real_value(node, output_graph): - """ - Run the actual computation represented by `node` and return the result. 
- This will execute any dependent nodes in the graph as well. - """ - cache = output_graph.real_value_cache - if node in cache: - return cache[node] - - op = node.op - args, kwargs = torch.fx.node.map_arg( - (node.args, node.kwargs), - lambda n: _get_real_value(n, output_graph), - ) - - if op == "call_module": - nn_module = output_graph.nn_modules[node.target] - if not is_lazy_module(nn_module): - nn_module = copy.deepcopy(nn_module) - else: - # In the case of a lazy module, we want to run - # the pre-hooks which initialize it - nn_module(*args, **kwargs) - else: - nn_module = None - - try: - real_value = _run_node(output_graph, node, args, kwargs, nn_module) - cache[node] = real_value - except RuntimeError as e: - raise TorchRuntimeError() from e - return real_value - - -def _get_fake_value(node, tx): - """ - Run the computation represented by `node` using fake tensors and return the result. - """ - op = node.op - fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx) - from ..utils import wrap_fake_exception - - def visit(n: torch.fx.Node): - return n.meta["example_value"] - - args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit) - args = tree_map(fake_wrapper, args) - kwargs = tree_map(fake_wrapper, kwargs) - - nnmodule = None - if op == "call_module": - nnmodule = tx.output.nn_modules[node.target] - - if not is_lazy_module(nnmodule): - nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) - - def context(): - if hasattr(py_dispatch, "enable_torch_dispatch_mode"): - return py_dispatch.enable_torch_dispatch_mode(tx.fake_mode) - else: - return tx.fake_mode - - if op == "call_module" and is_lazy_module(nnmodule): - assert nnmodule is not None - # In the case of a lazy module, we want to run - # the pre-hooks which initialize it - nnmodule(*args, **kwargs) - try: - with context(): - return wrap_fake_exception( - lambda: _run_node(tx.output, node, args, kwargs, nnmodule) - ) - except Unsupported: - raise - except RuntimeError as e: - if isinstance(e, DataDependentOutputException): - if config.capture_scalar_outputs and node.target == "item": - return torch.zeros(size=(), dtype=args[0].dtype).item() - else: - unimplemented(f"data dependent operator: {e.func}") - elif isinstance(e, DynamicOutputShapeException): - unimplemented(f"dynamic shape operator: {e.func}") - else: - raise TorchRuntimeError() from e - - -def _clone_input(value): - if isinstance(value, torch.Tensor): - use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation - # tensor subclasses will not be converted to FakeTensors and need to be cloned - if not use_fake_tensors or not isinstance(value, FakeTensor): - # NB: ensure strides are preserved - value = clone_input(value) - - return value - - class TensorVariable(VariableTracker): """A torch.Tensor input or an intermediate value in the FX graph""" @@ -178,173 +45,7 @@ def get_real_value(self): NOTE: this runs actual tensor computation and may be slow and memory-intensive. 
""" - return _get_real_value(self.proxy.node, self.proxy.tracer) - - @classmethod - def create(cls, tx, proxy, example_value=None, **options): - if "guards" in options and options["guards"] is not None: - tx.output.guards.update(options["guards"]) - - assert "example_value" not in proxy.node.meta - if not config.dynamic_propagation: - if isinstance(example_value, torch.Tensor): - options.update(cls.specialize(example_value)) - return cls(proxy, **options) - - use_fake_tensors = fake_tensors_available and config.fake_tensor_propagation - - initial_example_value = example_value - - with preserve_rng_state(): - if example_value is None: - if use_fake_tensors: - example_value = _get_fake_value(proxy.node, tx) - else: - example_value = _get_real_value(proxy.node, tx.output) - - else: - proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value) - if use_fake_tensors: - fake_wrapper = functools.partial( - wrap_to_fake_tensor_and_record, tx=tx - ) - example_value = fake_wrapper(example_value) - - if isinstance(example_value, torch.Tensor): - is_parameter = isinstance(example_value, torch.nn.Parameter) - should_specialize = options.pop("should_specialize", False) - if is_parameter or should_specialize: - specialized_value = initial_example_value - else: - specialized_value = None - - example_value = _clone_input(example_value) - proxy.node.meta["example_value"] = example_value - specialized_props = cls.specialize(example_value) - if use_fake_tensors and isinstance(example_value, FakeTensor): - specialized_props["class_type"] = ( - torch.nn.Parameter if is_parameter else torch.Tensor - ) - - specialized_props["specialized_value"] = specialized_value - - options.update(specialized_props) - return cls(proxy, **options) - elif ( - hasattr(proxy.node.target, "__name__") - and proxy.node.target.__name__ == "set_state" - and isinstance(proxy.node.target.__self__, torch._C.Generator) - or proxy.node.target == torch.random.set_rng_state - ): - from . 
import TorchVariable - - return TorchVariable(proxy.node.target) - elif istype(example_value, (int, bool, float)) and config.dynamic_shapes: - proxy.node.meta["example_value"] = example_value - return DynamicShapeVariable(proxy, example_value, **options) - elif istype(example_value, torch.Size) and config.dynamic_shapes: - proxy.node.meta["example_value"] = example_value - sizes = [] - for i, v in enumerate(example_value): - proxy_i = proxy[i] - proxy_i.node.meta["example_value"] = v - sizes.append(DynamicShapeVariable(proxy_i, v)) - return SizeVariable(sizes, proxy, **options) - elif istype(example_value, int) and proxy.node.target in ( - torch.seed, - operator.mod, - # some mac builds are missing torch.distributed.get_rank() - getattr(torch.distributed, "get_rank", _missing), - getattr(torch.distributed, "get_world_size", _missing), - ): - proxy.node.meta["example_value"] = example_value - return DynamicShapeVariable(proxy, example_value, **options) - elif istype(example_value, torch.Size) and all( - [isinstance(x, int) for x in example_value] - ): - sizes = [variables.ConstantVariable(x) for x in example_value] - return SizeVariable(sizes, **options) - elif isinstance(example_value, (tuple, list)): - unpacked = [] - for i, val in enumerate(example_value): - if val is None: - # nn.MultiheadAttention() can return None, see issue #175 - unpacked.append( - variables.ConstantVariable(None, **options), - ) - else: - unpacked.append( - cls.create( - tx, - proxy.tracer.create_proxy( - "call_function", operator.getitem, (proxy, i), {} - ), - example_value=val, - **options, - ) - ) - if istype(example_value, tuple): - return variables.TupleVariable(unpacked, **options) - elif istype(example_value, (list, immutable_list)): - return variables.ListVariable( - unpacked, mutable_local=MutableLocal(), **options - ) - else: - assert ( - example_value.__class__.__module__ == "torch.return_types" - or hasattr(example_value, "_fields") - ), "namedtuple?" - return variables.NamedTupleVariable( - unpacked, example_value.__class__, **options - ) - elif example_value is None or proxy.node.target is torch.manual_seed: - return variables.ConstantVariable(None, **options) - elif ( - isinstance(example_value, int) - and proxy.node.target is torch._utils._element_size - ): - proxy.node.meta["example_value"] = example_value - return variables.ConstantVariable(example_value, **options) - elif ( - isinstance(example_value, numbers.Number) - and ( - proxy.node.target == "item" - or proxy.node.target in {math.sqrt, math.pow} - ) - and config.capture_scalar_outputs - ): - if use_fake_tensors: - # item raw value should not be accessed - return FakeItemVariable.create( - tx=tx, - proxy=proxy, - example_value=torch.tensor(example_value), - **options, - ) - else: - return UnspecializedPythonVariable.create( - tx=tx, - proxy=proxy, - example_value=torch.tensor(example_value), - raw_value=None if use_fake_tensors else example_value, - need_unwrap=False, - **options, - ) - elif ( - proxy.node.target == torch._C._DisableFuncTorch - or proxy.node.target == torch.cuda._is_in_bad_fork - ): - from . 
import UserDefinedObjectVariable - - return UserDefinedObjectVariable(example_value) - elif isinstance(example_value, torch.SymInt): - proxy.node.meta["example_value"] = example_value - return cls(proxy, **options) - else: - raise AssertionError( - "torch.* op returned non-Tensor " - + f"{typestr(example_value)} {proxy.node.op} {proxy.node.target}" - ) + return get_real_value(self.proxy.node, self.proxy.tracer) def __init__( self, @@ -482,15 +183,26 @@ def call_method( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": from . import ConstantVariable, TupleVariable + from .builder import wrap_fx_proxy kwargs = dict(kwargs) - options = VariableTracker.propagate(self, args, kwargs.values()) if name == "stride" and self.stride is not None: constant_result = ConstantVariable(self.stride, **options) elif name == "size" and self.size is not None: sizes = [variables.ConstantVariable(x) for x in self.size] constant_result = SizeVariable(sizes, **options) + elif name == "size" and self.size is None and config.dynamic_shapes: + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self] + args, kwargs), + current_tx=tx, + ), + **options, + ) elif name == "numel" and self.size is not None: constant_result = ConstantVariable(product(self.size), **options) elif name in ("ndimension", "dim") and self.ndim is not None: @@ -531,11 +243,19 @@ def call_method( unimplemented(f"Tensor.{name}") elif name == "item": if config.capture_scalar_outputs: - return self.__class__.create( + use_fake_tensors = ( + fake_tensors_available and config.fake_tensor_propagation + ) + if use_fake_tensors: + example_value = get_fake_value(self.proxy.node, tx) + else: + example_value = get_real_value(self.proxy.node, tx.output).item() + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", "item", (self.as_proxy(),), {}, current_tx=tx ), + example_value=example_value, **options, ) else: @@ -545,7 +265,7 @@ def call_method( assert not config.dynamic_shapes return ConstantVariable(self.size[0], **options) else: - return self.__class__.create( + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_function", len, (self.as_proxy(),), {}, current_tx=tx @@ -584,7 +304,7 @@ def call_method( self.ndim = args[0].ndim self.is_contiguous = (memory_format,) - return self.__class__.create( + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", @@ -604,8 +324,7 @@ def call_method( and not config.dynamic_shapes ): name = "new_empty" - - return self.__class__.create( + return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", @@ -617,13 +336,23 @@ def call_method( ) -class DynamicShapeVariable(TensorVariable): +class DynamicShapeVariable(VariableTracker): """ Represents a symbolic size, e.g., as returned by tensor.size(0) """ + @classmethod + def create(cls, tx, proxy, dyn_shape, **options): + if "example_value" in proxy.node.meta: + assert proxy.node.meta["example_value"] == dyn_shape + if dyn_shape is None: + dyn_shape = get_fake_value(proxy.node, tx) + proxy.node.meta["example_value"] = dyn_shape + return DynamicShapeVariable(proxy, dyn_shape, **options) + def __init__(self, proxy, dyn_shape, **kwargs): - super(DynamicShapeVariable, self).__init__(proxy, **kwargs) + super(DynamicShapeVariable, self).__init__(**kwargs) + self.proxy = proxy self.dyn_shape = dyn_shape def python_type(self): @@ -632,6 +361,36 @@ def python_type(self): def unpack_var_sequence(self, tx): super(DynamicShapeVariable, self).unpack_var_sequence(tx) + def as_proxy(self): + return 
self.proxy + + def evaluate_expr(self, output_graph): + if not isinstance(self.dyn_shape, torch.SymInt): + return self.dyn_shape + return output_graph.shape_env.evaluate_expr(self.dyn_shape.get_pyobj().expr) + + def call_method( + self, + tx, + name, + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + from .builder import wrap_fx_proxy + + options = VariableTracker.propagate(self, args, kwargs.values()) + + return wrap_fx_proxy( + tx, + tx.output.create_proxy( + "call_method", + name, + *proxy_args_kwargs([self] + list(args), kwargs), + current_tx=tx, + ), + **options, + ) + class TensorWithTFOverrideVariable(VariableTracker): """ diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c55a64cff50c..0debfe9e9f3c 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -1,4 +1,6 @@ import logging + +import math import re import types from typing import Dict, List @@ -170,7 +172,15 @@ def can_constant_fold_through(self): def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": - from . import ConstantVariable, GradModeVariable, TensorVariable + from . import ( + ConstantVariable, + DynamicShapeVariable, + GradModeVariable, + TensorVariable, + ) + + # print("CALLING ON TORCH", self.value) + from .builder import wrap_fx_proxy constant_args = check_constant_args(args, kwargs) unspec_python_args = check_unspec_python_args(args, kwargs) @@ -302,7 +312,7 @@ def call_function( def get_state_from_generator(): return self.value() - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -338,7 +348,7 @@ def get_state_from_generator(): example_value = args[0].proxy.node.meta["example_value"] self.value.__module__ = self.__module__ - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -357,7 +367,7 @@ def get_state_from_generator(): ): # TODO(voz): This is rewritten as a call_method because # torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not - return TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_method", @@ -380,11 +390,21 @@ def get_state_from_generator(): if isinstance(x.value, numpy.generic): x.value = x.value.item() - tensor_variable = TensorVariable.create( + # TODO(voz): Replace w/ dynamic shape rewrite table. + # Ideally, we would be able to do this at ctor time, but alas we need a combination + # of value + args to determine this. 
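The TODO above describes a per-op rewrite table keyed on whether any argument is symbolic. A hypothetical version of that table is sketched below; resolve_target, SYM_REWRITES, and is_symbolic are illustrative names, and the sym_sqrt import path is the one used in this diff.

import math
from torch.fx.experimental.symbolic_shapes import sym_sqrt  # path per this diff

# Hypothetical rewrite table: consulted only when at least one argument is
# symbolic; otherwise the original callable is used unchanged.
SYM_REWRITES = {math.sqrt: sym_sqrt}

def resolve_target(fn, args, is_symbolic):
    if any(is_symbolic(a) for a in args):
        return SYM_REWRITES.get(fn, fn)
    return fn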
+ fn_ = self.value + if any([isinstance(x, DynamicShapeVariable) for x in args]): + if self.value == math.sqrt: + from torch.fx.experimental.symbolic_shapes import sym_sqrt + + fn_ = sym_sqrt + + tensor_variable = wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", - self.value, + fn_, *proxy_args_kwargs(args, kwargs), current_tx=tx, ), @@ -450,7 +470,9 @@ def _call_softmax(self, tx, args, kwargs, options): dim = args[0] if args else kwargs.get("dim", variables.ConstantVariable(None)) def fake_softmax(input): - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -502,7 +524,9 @@ def normalize_args( ) = normalize_args(*args, **kwargs) def fake_cross_entropy_loss(input, target): - return variables.TensorVariable.create( + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", @@ -577,6 +601,7 @@ def call_function( self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" ) -> "VariableTracker": from . import ListVariable, TensorVariable, UserFunctionVariable + from .builder import wrap_fx_proxy assert kwargs is None or len(kwargs) == 0, "kwargs are not supported, yet" @@ -688,7 +713,7 @@ def register_as_subgraph(fn, name, args): p_args[2] = false_node # Store the invocation as a call - return variables.TensorVariable.create( + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( "call_function", diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 87e2793782be..8f9f2c4f461d 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -92,6 +92,7 @@ class cpp: "g++-10", "clang++", "g++", + "g++.par", ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 240c196a73b6..8a2e26ee9b94 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1552,11 +1552,23 @@ def loader(index): @dataclasses.dataclass class Layout(IRNode): - device: torch.device - dtype: torch.dtype - size: List[Expr] - stride: List[Expr] - offset: Expr = Integer(0) + def __init__( + self, + device: torch.device, + dtype: torch.dtype, + size: List[Expr], + stride: List[Expr], + offset: Expr = Integer(0), + ): + self.device = device + self.dtype = dtype + self.size = size + self._stride = stride + self.offset = offset + + @property + def stride(self): + return self._stride def __str__(self): offset = "" @@ -1772,6 +1784,15 @@ def __init__(self, target: IRNode): ) self.target = target + @Layout.stride.getter + def stride(self): + return self.real_layout().stride + + def real_layout(self): + if isinstance(self.target, MutationLayout): + return self.target.real_layout() + return self.target.data.layout + @classmethod def realize_into(cls, src, dst): dst.realize() @@ -2467,6 +2488,16 @@ def require_stride_order(cls, x, order): x.get_layout(), FixedLayout ) and x.get_layout().is_stride_ordered(order): return x + elif isinstance(x.get_layout(), MutationLayout): + if isinstance(x.get_layout().real_layout(), FlexibleLayout): + raise AssertionError( + "the MutationLayout's real layout shouldn't be FlexibleLayout" + ) + elif isinstance( + x.get_layout().real_layout(), FixedLayout + ) and x.get_layout().real_layout().is_stride_ordered(order): + return x + # TODO - Storage to InputBuffer if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): return x @@ -3099,7 +3130,7 @@ def create( sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in 
weight.get_size() ] - _, _, *kernel_size = weight.get_size() + _, _, *kernel_size = weight_shape # choose runtime kernel config_conv = config.triton.convolution @@ -3324,50 +3355,28 @@ def _prepare_convolution_fusion_create( padding = tuple(padding_) dilation = tuple(dilation_) assert isinstance(groups, int) - + with FakeTensorMode(): + output, *_ = cls.process_kernel( + torch.ops.aten.convolution, + x, + weight, + bias, + stride, + padding, + dilation, + False, + [0, 0], + groups, + ) + + output_size = output.shape weight_shape = [ sympy.Integer(V.graph.sizevars.guard_static_shape(s)) for s in weight.get_size() ] - - out_channels, in_channels1, *kernel_size = weight_shape - in_channels1 = in_channels1 * groups - assert len(x.get_size()) == 2 + len(kernel_size) - batch, in_channels2, *input_size = x.get_size() - output_size = [batch] - V.graph.sizevars.guard_equals(in_channels1, in_channels2) - - output_size.append(out_channels) - assert ( - len(stride) - == len(padding) - == len(dilation) - == len(kernel_size) - == len(input_size) + _, _, *kernel_size = weight_shape + output_layout_str = ( + "torch.contiguous_format" if output.is_contiguous() else "torch.channels_last" ) - for i in range(len(stride)): - output_size.append( - IndexingDiv( - input_size[i] - + 2 * padding[i] - - dilation[i] * (kernel_size[i] - 1) - - 1 - + stride[i], - stride[i], - ) - ) - output_size[-1] = sympy.Integer( - V.graph.sizevars.guard_static_shape(output_size[-1]) - ) - - output_layout_str = "torch.contiguous_format" - # If x or weight have one channels_last(2d or 3d) format, it will call channels_last path, - # which align with aten.convolutuion path(cpu only support 2d case now). - # TODO: after cpu 3d convolution support channels_last path, the size check can be removed. 
- if len(x.get_size()) == 4 and ( - x.get_layout().is_channels_last_stride_ordered() - or weight.get_layout().is_channels_last_stride_ordered() - ): - output_layout_str = "torch.channels_last" if output_layout_str == "torch.channels_last": stride_order = [0] + list(reversed(range(1, len(kernel_size) + 1))) @@ -3409,6 +3418,8 @@ def codegen(self, wrapper): wrapper.writeline( f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" ) + if isinstance(self.layout, Layout): + self.codegen_size_asserts(wrapper) @classmethod def create( @@ -3513,6 +3524,77 @@ def apply_constraint(self): self.freeze_layout_with_stride_order(self.layout.preferred_stride_order) +class ConvolutionBinaryInplace(ExternKernelAlloc): + kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" + + def __init__( + self, + kernel_layout, + inputs_layout, + inputs, + constant_args=(), + kernel="torch.ops.mkldnn._convolution_pointwise_.binary", + ): + super().__init__(kernel_layout, inputs, constant_args) + self.kernel = kernel + self.inputs_layout = inputs_layout + + def codegen(self, wrapper): + wrapper.writeline( + f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" + ) + + def get_mutation_names(self): + assert isinstance(self.layout, MutationLayout) + return (self.layout.target.get_name(),) + + @classmethod + def create( + cls, + x: "TensorBox", + other: "TensorBox", + weight: "TensorBox", + bias: "TensorBox", + padding_: List[int], + stride_: List[int], + dilation_: List[int], + groups: int, + binary_attr: str, + binary_alpha: Optional[float], + unary_attr: Optional[str], + unary_scalars: Optional[List], + unary_algorithm: Optional[str], + ): + kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" + (inputs, constant_args, inputs_layout,) = _prepare_convolution_fusion_create( + cls, x, weight, bias, padding_, stride_, dilation_, groups + ) + other = cls.realize_input(other) + V.graph.realize_users_of(other.get_name()) + inputs.insert(1, other) + constant_args = constant_args + [ + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ] + return ConvolutionBinaryInplace( + kernel_layout=MutationLayout(inputs[1]), + inputs_layout=inputs_layout, + inputs=inputs, + constant_args=constant_args, + kernel=kernel, + ) + + def apply_constraint(self): + x = self.inputs[0] + # FixedLayout of input + x = self.require_stride_order(x, self.inputs_layout.preferred_stride_order) + self.inputs[0] = x + self.freeze_layout_with_stride_order(self.inputs_layout.preferred_stride_order) + + class LinearUnary(ExternKernelAlloc): kernel = "torch.ops.mkldnn._linear_pointwise" diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index dedd39cd91c4..9924396075f6 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -960,6 +960,40 @@ def convolution_binary( ) ) + @register_lowering(torch.ops.mkldnn._convolution_pointwise_.binary) + def convolution_binary_inplace( + x: TensorBox, + other: TensorBox, + weight: TensorBox, + bias: TensorBox, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ): + return TensorBox.create( + ir.ConvolutionBinaryInplace.create( + x, + other, + weight, + bias, + padding, + stride, + dilation, + groups, + binary_attr, + binary_alpha, + unary_attr, + unary_scalars, + unary_algorithm, + ) + ) + @register_lowering(torch.ops.mkldnn._linear_pointwise) def linear_unary( x: TensorBox, w: TensorBox, b: TensorBox, attr, scalars, algorithm diff --git 
a/torch/_inductor/overrides.py b/torch/_inductor/overrides.py index d89ee82674dd..3a95aa7ce880 100644 --- a/torch/_inductor/overrides.py +++ b/torch/_inductor/overrides.py @@ -157,6 +157,11 @@ def _update_module_params(self, conv, binary_op_name): self.unary_scalars = [] self.unary_algorithm = None + def _update_unary_params(self, unary): + self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[ + unary.__class__ + ](unary) + def _conv_forward(self, input, other, weight, bias): if self.padding_mode != "zeros": return torch.ops.mkldnn._convolution_pointwise( @@ -196,6 +201,79 @@ def forward(self, input, other): return self._conv_forward(input, other, self.weight, self.bias) +class ConvBinaryInplace2d(nn.Conv2d): + def __init__( + self, + conv: nn.Module, + binary_op_name: str, + ): + super(ConvBinaryInplace2d, self).__init__( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + conv.stride, + conv.padding, + conv.dilation, + conv.groups, + conv.bias is not None, + conv.padding_mode, + conv.weight.device, + conv.weight.dtype, + ) + self._update_module_params(conv, binary_op_name) + + def _update_module_params(self, conv, binary_op_name): + self.__dict__ = copy.deepcopy(conv.__dict__) + self.binary_attr = binary_op_name + self.binary_alpha = None + self.unary_attr = None + self.unary_scalars = [] + self.unary_algorithm = None + + def _update_unary_params(self, unary): + self.unary_attr, self.unary_scalars, self.unary_algorithm = unary_modules_map[ + unary.__class__ + ](unary) + + def _conv_forward(self, input, other, weight, bias): + if self.padding_mode != "zeros": + return torch.ops.mkldnn._convolution_pointwise_( + F.pad( + input, self._reversed_padding_repeated_twice, mode=self.padding_mode + ), + other, + weight, + bias, + _pair(0), + self.stride, + self.dilation, + self.groups, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, + ) + return torch.ops.mkldnn._convolution_pointwise_( + input, + other, + weight, + bias, + self.padding, + self.stride, + self.dilation, + self.groups, + self.binary_attr, + self.binary_alpha, + self.unary_attr, + self.unary_scalars, + self.unary_algorithm, + ) + + def forward(self, input, other): + return self._conv_forward(input, other, self.weight, self.bias) + + class LinearUnary(nn.Linear): def __init__( self, @@ -263,6 +341,21 @@ def fused_conv_binary_eval(conv: nn.Module, binary_op_name: str): ) +def fused_conv_binary_inplace_eval(conv: nn.Module, binary_op_name: str): + assert not (conv.training), "Fusion only for eval!" + return ConvBinaryInplace2d( + conv, + binary_op_name, + ) + + +def fused_binary_unary_eval(conv_binary: nn.Module, unary: nn.Module): + assert not (conv_binary.training), "Fusion only for eval!" + # reuse origin conv module, and just update its' unary attr. + conv_binary._update_unary_params(unary) + return conv_binary + + def is_bfloat16_module(m): weight_is_bf16 = m.weight.dtype == torch.bfloat16 bias_is_bf16 = m.bias is None or m.bias.dtype == torch.bfloat16 @@ -312,6 +405,25 @@ def check_node_is_binary(node): ) +def check_binary_op_kwargs_is_default(node): + # For binary op, we hope the kwargs values are the default value: + # torch.sub(add)(input, other, *, alpha=1, out=None). 
+ if len(node.args) > 2: + return False + if len(node.kwargs) > 0: + if "out" in node.kwargs and node.kwargs["out"] is not None: + return False + if "alpha" in node.kwargs and node.kwargs["alpha"] != 1.0: + return False + return True + + +def check_node_is_add_inplace(node): + return (node.op == "call_function" and node.target in [operator.iadd]) or ( + node.op == "call_method" and node.target in ["add_"] + ) + + def fuse_fx(gm: torch.fx.GraphModule, example_inputs): # make sure the autograd is disabled. if torch.is_grad_enabled(): @@ -328,7 +440,11 @@ def fuse_fx(gm: torch.fx.GraphModule, example_inputs): # the binary inputs have same tensor info(device, dtype, and layout). ShapeProp(gm).propagate(*example_inputs) gm = fuse_unary(gm) + gm = fuse_binary_inplace(gm) gm = fuse_binary(gm) + # why re-run fuse_unary? we want to enable conv+binary+unary fusion, + # such as conv+add+relu for vision model. + gm = fuse_unary(gm) return gm @@ -383,7 +499,11 @@ def fuse_unary(gm: torch.fx.GraphModule): eval_mode = all(not n.training for n in [computation_node, unary_node]) if not eval_mode: continue - + # TODO: support padding str input("valid", "same"). + if type(computation_node) in [nn.Conv2d] and isinstance( + computation_node.padding, str + ): + continue # only fuse for linear when the dtype is bf16 if type(computation_node) in [nn.Linear] and not is_bfloat16_module( computation_node @@ -419,26 +539,31 @@ def replace_and_fuse_for_binary( node.replace_all_uses_with(node.args[index_node]) +def binary_inputs_meta_is_same(binary_node): + tensor0_meta = binary_node.args[0].meta.get("tensor_meta") + tensor1_meta = binary_node.args[1].meta.get("tensor_meta") + if not tensor0_meta or not tensor1_meta: + return False + if ( + tensor0_meta.shape != tensor1_meta.shape + or tensor0_meta.stride != tensor1_meta.stride + or tensor0_meta.dtype != tensor1_meta.dtype + ): + return False + + return True + + def fuse_binary(gm: torch.fx.GraphModule): modules = dict(gm.named_modules()) for node in gm.graph.nodes: - if check_node_is_binary(node) and ( - len(node.kwargs) != 2 or node.kwargs["alpha"] == 1.0 - ): + if check_node_is_binary(node) and check_binary_op_kwargs_is_default(node): for node_kind, fuse_func in computation_op_binary_op_fusion_map.items(): if not isinstance(node.args[0], torch.fx.Node) or not isinstance( node.args[1], torch.fx.Node ): continue - tensor0_meta = node.args[0].meta.get("tensor_meta") - tensor1_meta = node.args[1].meta.get("tensor_meta") - if not tensor0_meta or not tensor1_meta: - continue - if ( - tensor0_meta.shape != tensor1_meta.shape - or tensor0_meta.stride != tensor1_meta.stride - or tensor0_meta.dtype != tensor1_meta.dtype - ): + if not binary_inputs_meta_is_same(node): continue attr = binary_attr[node.target] index_list = supported_index_list[attr] @@ -449,6 +574,11 @@ def fuse_binary(gm: torch.fx.GraphModule): if len(node.args[index_node].users) > 1: continue computation_node = modules[node.args[index_node].target] + # TODO: support padding str input("valid", "same"). 
+ if type(computation_node) in [nn.Conv2d] and isinstance( + computation_node.padding, str + ): + continue # only fuse for linear when the dtype is bf16 if type(computation_node) in [ nn.Linear @@ -473,6 +603,51 @@ def fuse_binary(gm: torch.fx.GraphModule): return gm +def fuse_binary_inplace(gm: torch.fx.GraphModule): + modules = dict(gm.named_modules()) + for node in gm.graph.nodes: + if check_node_is_add_inplace(node) and check_binary_op_kwargs_is_default(node): + for ( + node_kind, + fuse_func, + ) in computation_op_binary_op_fusion_inplace_map.items(): + if not isinstance(node.args[0], torch.fx.Node) or not isinstance( + node.args[1], torch.fx.Node + ): + continue + if not binary_inputs_meta_is_same(node): + continue + if check_node_kind(node.args[1], modules, node_kind): + if len(node.args[1].users) > 1: + continue + # make sure the output and input are not same tensor. + if node.args[1].args[0] == node.args[0]: + continue + computation_node = modules[node.args[1].target] + # TODO: support padding str input("valid", "same"). + if type(computation_node) in [nn.Conv2d] and isinstance( + computation_node.padding, str + ): + continue + replace_and_fuse_for_binary( + computation_node, + node, + fuse_func, + "add", + modules, + 1, # conv module index + 0, # binary op index + ) + # Make sure the fused node is post node of node's inputs nodes. + node.append(node.args[1]) + gm.graph.erase_node(node) + gm.graph.lint() + break + + gm.recompile() + return gm + + philox_rand_like = _prims._make_prim( schema="philox_rand_like(Tensor input, Tensor seed, int offset) -> Tensor", return_type=_prims.RETURN_TYPE.NEW, @@ -595,6 +770,8 @@ def rand_like(x, **kwargs): computation_op_unary_op_fusion_map = { nn.Conv2d: fused_conv_unary_eval, nn.Linear: fused_linear_unary_eval, + ConvBinary2d: fused_binary_unary_eval, + ConvBinaryInplace2d: fused_binary_unary_eval, } @@ -629,6 +806,10 @@ def rand_like(x, **kwargs): } +computation_op_binary_op_fusion_inplace_map = { + nn.Conv2d: fused_conv_binary_inplace_eval, +} + # For add: we support conv/linear + other and other + conv # For sub/add_/sub_, we only support conv/linear - other # or conv/linear +(-)= other diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 5d583de67d19..be7370e344f0 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1349,9 +1349,8 @@ def pool2d_shape_check( ) -@register_meta(aten.max_pool2d_with_indices.default) -def meta_max_pool2d_with_indices( - input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False +def max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode ): # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp def unpack(name, val): @@ -1376,6 +1375,9 @@ def unpack(name, val): padH, padW = unpack("padding", padding) dilationH, dilationW = unpack("dilation", dilation) + nInputPlane = input.size(-3) + inputHeight = input.size(-2) + inputWidth = input.size(-1) memory_format = utils.suggest_memory_format(input) if memory_format == torch.channels_last: @@ -1394,11 +1396,6 @@ def unpack(name, val): lambda: "Unsupport memory format. 
Supports only ChannelsLast, Contiguous", ) - nbatch = input.size(-4) if input.dim() == 4 else 1 - nInputPlane = input.size(-3) - inputHeight = input.size(-2) - inputWidth = input.size(-1) - outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) @@ -1420,6 +1417,49 @@ def unpack(name, val): memory_format, ) + return nInputPlane, outputHeight, outputWidth + + +@register_meta(aten.max_pool2d_with_indices_backward.default) +def meta_max_pool2d_with_indices_backward( + grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices +): + nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape( + self, kernel_size, stride, padding, dilation, ceil_mode + ) + + check( + self.dtype == grad_output.dtype, + lambda: "expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}", + ) + + nOutputPlane = nInputPlane + ndim = self.ndim + + def _check_dim_size(t): + check_dim_size(t, ndim, ndim - 3, nOutputPlane) + check_dim_size(t, ndim, ndim - 2, outputHeight) + check_dim_size(t, ndim, ndim - 1, outputWidth) + + _check_dim_size(grad_output) + _check_dim_size(indices) + + memory_format = utils.suggest_memory_format(self) + return torch.empty( + self.shape, dtype=self.dtype, device=self.device, memory_format=memory_format + ) + + +@register_meta(aten.max_pool2d_with_indices.default) +def meta_max_pool2d_with_indices( + input, kernel_size, stride=(), padding=(0,), dilation=(1,), ceil_mode=False +): + nInputPlane, outputHeight, outputWidth = max_pool2d_checks_and_compute_shape( + input, kernel_size, stride, padding, dilation, ceil_mode + ) + + nbatch = input.size(-4) if input.dim() == 4 else 1 + memory_format = utils.suggest_memory_format(input) if input.dim() == 3: size = [nInputPlane, outputHeight, outputWidth] else: diff --git a/torch/_prims/context.py b/torch/_prims/context.py index 203d73fd948e..b9f6e634bb49 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -68,7 +68,8 @@ def torch_to_refs_map(): # Support conversions for s in torch._refs._conversions.__all__: - r[getattr(torch.Tensor, s)] = torch._refs._conversions.__dict__.get(s) + tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s) + r[tensor_attr] = torch._refs._conversions.__dict__.get(s) return r diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 43b0c74192de..a1de9a438d77 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -122,7 +122,6 @@ "bitwise_right_shift", "bitwise_xor", "clamp_min", - # "complex", "copysign", "div", "eq", @@ -2750,19 +2749,39 @@ def flipud(a: TensorLikeType) -> TensorLikeType: # CompositeImplicitAutograd - don't register decomp -def narrow(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType: +def narrow( + a: TensorLikeType, dim: int, start: Union[int, TensorLikeType], length: int +) -> TensorLikeType: + # Supports Tensor overload that was added for XLA: + # https://github.com/pytorch/pytorch/issues/31558 + if isinstance(start, TensorLike): + check( + start.dim() == 0 and utils.is_integer_dtype(start.dtype), + lambda: "start must be an 0-dim integral Tensor.", + ) + start = start.item() # type: ignore[assignment] + check(a.dim() > 0, lambda: "narrow() cannot be applied to a 0-dim tensor.") + check(length >= 0, lambda: "narrow(): length must be non-negative.") dim = utils.canonicalize_dim(a.ndim, dim) + dim_length = a.size(dim) + # Start being the end is usually invalid 
since it's out of bounds. So it's + # not allowed by canonicalize_dim. But for narrow it's valid as long as + # the length is 0, which is handled by the check below. + if start != dim_length: + # Negative start means indexing from the end of dim. + # Note: a dimension isn't being canonicalized here, this reuses + # canonicalize_dim because the semantics are similar. + start = utils.canonicalize_dim(dim_length, start) # type: ignore[arg-type] + check( + start <= dim_length - length, # type: ignore[arg-type] + lambda: f"start ({start}) + length ({length}) exceeds dimension size ({dim_length}).", + ) return prims.slice_in_dim(a, start, start + length, axis=dim) -@register_decomposition(torch.ops.aten.narrow_copy) -@out_wrapper() -def narrow_copy(a: TensorLikeType, dim: int, start: int, length: int) -> TensorLikeType: - # TODO: This must return a sparse tensor if the input is sparse, but refs - # have no sparse support. See narrow_copy_sparse in core. - if a.is_sparse: - raise NotImplementedError("narrow_copy ref doesn't support sparse tensors") - return torch.clone(torch.narrow(a=a, dim=dim, start=start, length=length)) # type: ignore[call-overload] +# TODO: This must return a sparse tensor if the input is sparse, but refs have +# no sparse support. See narrow_copy_sparse in core. +narrow_copy = _make_copy_from_view(narrow) def _normalize( diff --git a/torch/_refs/_conversions.py b/torch/_refs/_conversions.py index 11657f7058bd..abcd5729818d 100644 --- a/torch/_refs/_conversions.py +++ b/torch/_refs/_conversions.py @@ -1,6 +1,12 @@ import torch +import torch._prims_common as utils -from torch._prims_common import TensorLikeType +# Utilities should come BEFORE this import +from torch._decomp import register_decomposition + +from torch._prims_common import check, TensorLikeType +from torch._prims_common.wrappers import out_wrapper +from torch._refs import _broadcast_shapes # Data conversion references. # @@ -10,6 +16,7 @@ # (like int). __all__ = [ + # dtypes "bfloat16", "bool", "byte", @@ -23,6 +30,8 @@ "int", "long", "short", + # misc + "complex", ] @@ -61,3 +70,37 @@ def fn( long = _make_conversion_method("long", torch.long) short = _make_conversion_method("short", torch.short) + + +@register_decomposition(torch.ops.aten.complex) +# Note: complex has type promotion tests disabled due to different semantics. +# exact_dtype is for compat with complex_check_dtype from core. 
+@out_wrapper(exact_dtype=True) +def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType: + allowed_dtypes = (torch.float32, torch.float64, torch.float16) + check( + real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes, + lambda: ( + f"Expected both inputs to be Half, Float or Double tensors but got " + f"{real.dtype} and {imag.dtype}" + ), + ) + check( + real.dtype == imag.dtype, + lambda: ( + f"Expected object of scalar type {real.dtype} but got " + f"scalar type {imag.dtype} for second argument" + ), + ) + result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type] + common_shape = _broadcast_shapes(real.shape, imag.shape) + result = real.new_empty( + common_shape, + dtype=result_dtype, + layout=real.layout, + device=real.device, + # pin_memory=real.is_pinned(), # NYI + ) + result.real = real + result.imag = imag + return result diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 14f5cd2de0a7..65f571f93ec0 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1081,7 +1081,9 @@ def __torch_function__(self, func, types, args=(), kwargs=None): # clone will get called in Parameter deepcopy if func == torch._C._TensorBase.clone: - return func(self.fake_mode.from_tensor(args[0]), **kwargs) + return func( + self.fake_mode.from_tensor(args[0], static_shapes=True), **kwargs + ) elif func == torch.Tensor.__deepcopy__: assert len(args) == 2 and len(kwargs) == 0 tensor, memo = args @@ -1089,7 +1091,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): if id(tensor) in memo: return memo[id(tensor)] - out = self.fake_mode.from_tensor(tensor) + out = self.fake_mode.from_tensor(tensor, static_shapes=True) memo[id(tensor)] = out return out else: diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 8c734a1f3774..726ae5137e6a 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -3436,18 +3436,7 @@ def callable(a, b) -> number r""" narrow(dimension, start, length) -> Tensor -See :func:`torch.narrow` - -Example:: - - >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - >>> x.narrow(0, 0, 2) - tensor([[ 1, 2, 3], - [ 4, 5, 6]]) - >>> x.narrow(1, 1, 2) - tensor([[ 2, 3], - [ 5, 6], - [ 8, 9]]) +See :func:`torch.narrow`. """, ) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 40375bae3e27..2ff2e9be315d 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -7980,8 +7980,10 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to narrow dim (int): the dimension along which to narrow - start (Tensor or int): the starting dimension - length (int): the distance to the ending dimension + start (int or Tensor): index of the element to start the narrowed dimension + from. Can be negative, which means indexing from the end of `dim`. If + `Tensor`, it must be an 0-dim integral `Tensor` (bools not allowed) + length (int): length of the narrowed dimension, must be weakly positive Example:: @@ -7993,6 +7995,10 @@ def merge_dicts(*dicts): tensor([[ 2, 3], [ 5, 6], [ 8, 9]]) + >>> torch.narrow(x, -1, torch.tensor(-1), 1) + tensor([[3], + [6], + [9]]) """, ) @@ -8008,8 +8014,9 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to narrow dim (int): the dimension along which to narrow - start (int): the starting offset - length (int): the distance to the ending dimension + start (int): index of the element to start the narrowed dimension from. 
Can + be negative, which means indexing from the end of `dim` + length (int): length of the narrowed dimension, must be weakly positive Keyword args: {out} @@ -8027,13 +8034,13 @@ def merge_dicts(*dicts): >>> s = torch.arange(16).reshape(2, 2, 2, 2).to_sparse(2) >>> torch.narrow_copy(s, 0, 0, 1) tensor(indices=tensor([[0, 0], - [0, 1]]), - values=tensor([[[0, 1], - [2, 3]], + [0, 1]]), + values=tensor([[[0, 1], + [2, 3]], - [[4, 5], - [6, 7]]]), - size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) + [[4, 5], + [6, 7]]]), + size=(1, 2, 2, 2), nnz=2, layout=torch.sparse_coo) .. seealso:: diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 61bb2cdc1b03..a5a989ec2148 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -183,7 +183,7 @@ def get_quantize_node_info( if hasattr(activation_post_process, "compute_dtype"): compute_dtype = activation_post_process.compute_dtype # type: ignore[attr-defined] quantize_op : Optional[Union[Callable, str]] = None - if dtype in [torch.quint8, torch.qint8] and \ + if dtype in [torch.quint8, torch.qint8, torch.qint32] and \ not hasattr(activation_post_process, 'compute_dtype'): node_type = "call_function" scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined] diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 920d0e7344b5..002b904d4072 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -305,10 +305,6 @@ void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor) // THPVariable_clear). // 2. We are decref-ing some other Python object. We don't do // PyObject resurrection on non-Tensors, so we just carry on as usual - if (is_tensor) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - !c10::impl::HermeticPyObjectTLS::get_state()); - } if (is_tensor && Py_REFCNT(pyobj) > 1) { // It's still alive! This can happen if a weak ref resurrected // the PyObject without flipping ownership. At this point it is diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index a2e0f05b6394..d438205e8947 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -144,7 +144,16 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { : grad_fn_; if (!is_leaf_ && !grad_fn) { - TORCH_INTERNAL_ASSERT(saved_for, "No grad_fn for non-leaf saved tensor"); + // This issue was introduced when we added logic to save the original + // because now we rely on data_.grad_fn(), but can be unreliable if the + // autograd_meta of that saved tensor is cleared with an in-place detach. + // As a simple fix, we choose to disallow that behavior here even though + // it makes behavior inconsistent depending on whether you are saving + // input or output. + TORCH_CHECK( + saved_for, + "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_()." 
+ "This is not supported, please use out-of-place `.detach()` instead"); grad_fn = std::move(saved_for); } diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index ea77bb337b4a..f825afca2a1d 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -40,6 +40,19 @@ std::tuple, c10::intrusive_ptr> allreduce_( std::move(tensor_vec), work); } +c10::intrusive_ptr allreduce_coalesced_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + c10::intrusive_ptr reduce_( at::TensorList tensors, const c10::intrusive_ptr& process_group, @@ -75,6 +88,13 @@ allgather_( output_tensors, work); } +c10::intrusive_ptr _allgather_base_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + return process_group->_allgather_base(output_tensor, input_tensor); +} + std::tuple, c10::intrusive_ptr> reduce_scatter_( const std::vector& output_tensors, const std::vector>& input_tensors, @@ -177,9 +197,16 @@ TORCH_LIBRARY(c10d, m) { m.def( "allreduce_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, allreduce_)); + m.def( + "allreduce_coalesced_", + dispatch( + c10::DispatchKey::CompositeExplicitAutograd, allreduce_coalesced_)); m.def( "allgather_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_)); + m.def( + "_allgather_base_", + dispatch(c10::DispatchKey::CompositeExplicitAutograd, _allgather_base_)); m.def( "reduce_scatter_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_)); @@ -249,6 +276,25 @@ c10::intrusive_ptr allreduce( opts.timeout.count())); } +c10::intrusive_ptr allreduce_coalesced( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceCoalescedOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::allreduce_coalesced_", "") + .typed( + at::TensorList, + const c10::intrusive_ptr<::c10d::ProcessGroup>&, + const c10::intrusive_ptr<::c10d::ReduceOp>&, + int64_t)>(); + + return op.call( + tensors, + process_group, + c10::make_intrusive(opts.reduceOp), + opts.timeout.count()); +} + c10::intrusive_ptr allgather( const c10::intrusive_ptr& process_group, const std::vector>& output_tensors, @@ -267,6 +313,21 @@ c10::intrusive_ptr allgather( output_tensors, input_tensors, process_group, opts.timeout.count())); } +c10::intrusive_ptr _allgather_base( + const c10::intrusive_ptr& process_group, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("c10d::_allgather_base_", "") + .typed( + at::Tensor&, + at::Tensor&, + const c10::intrusive_ptr<::c10d::ProcessGroup>&)>(); + + return op.call(output_tensor, input_tensor, process_group); +} + c10::intrusive_ptr reduce_scatter( const c10::intrusive_ptr& process_group, const std::vector& output_tensors, diff --git a/torch/csrc/distributed/c10d/Ops.hpp b/torch/csrc/distributed/c10d/Ops.hpp index adc64066a885..72f09e341d7d 100644 --- a/torch/csrc/distributed/c10d/Ops.hpp +++ b/torch/csrc/distributed/c10d/Ops.hpp @@ -21,12 +21,23 @@ TORCH_API c10::intrusive_ptr allreduce( at::TensorList tensors, const 
AllreduceOptions& opts = {}); +TORCH_API c10::intrusive_ptr allreduce_coalesced( + const c10::intrusive_ptr& process_group, + at::TensorList tensors, + const AllreduceCoalescedOptions& opts = {}); + TORCH_API c10::intrusive_ptr allgather( const c10::intrusive_ptr& process_group, const std::vector>& output_tensors, const std::vector& input_tensors, const AllgatherOptions& opts = {}); +TORCH_API c10::intrusive_ptr _allgather_base( + const c10::intrusive_ptr& process_group, + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const AllgatherOptions& opts = {}); + TORCH_API c10::intrusive_ptr reduce_scatter( const c10::intrusive_ptr& process_group, const std::vector& output_tensors, diff --git a/torch/csrc/distributed/c10d/OpsImpl.cpp b/torch/csrc/distributed/c10d/OpsImpl.cpp index 03ec6892857e..78e26c9656d8 100644 --- a/torch/csrc/distributed/c10d/OpsImpl.cpp +++ b/torch/csrc/distributed/c10d/OpsImpl.cpp @@ -149,6 +149,32 @@ std::tuple, c10::intrusive_ptr> allreduce_cuda_( std::move(tensor_vec), work); } +c10::intrusive_ptr allreduce_coalesced_cpu_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + +c10::intrusive_ptr allreduce_coalesced_cuda_( + at::TensorList tensors, + const c10::intrusive_ptr& process_group, + const c10::intrusive_ptr& reduce_op, + int64_t timeout) { + auto tensor_vec = tensors.vec(); + AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{}; + opts.reduceOp = *reduce_op.get(); + opts.timeout = std::chrono::milliseconds(timeout); + + return process_group->allreduce_coalesced(tensor_vec, opts); +} + std::tuple>, c10::intrusive_ptr> allgather_cpu_( const std::vector>& output_tensors, @@ -185,6 +211,20 @@ allgather_cuda_( output_tensors, work); } +c10::intrusive_ptr _allgather_base_cpu_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + return process_group->_allgather_base(output_tensor, input_tensor); +} + +c10::intrusive_ptr _allgather_base_cuda_( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const c10::intrusive_ptr& process_group) { + return process_group->_allgather_base(output_tensor, input_tensor); +} + std::tuple, c10::intrusive_ptr> reduce_scatter_cpu_( const std::vector& output_tensors, @@ -367,6 +407,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("allreduce_", allreduce_cuda_); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("allreduce_coalesced_", allreduce_coalesced_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("allreduce_coalesced_", allreduce_coalesced_cuda_); +} + TORCH_LIBRARY_IMPL(c10d, CPU, m) { m.impl("allgather_", allgather_cpu_); } @@ -375,6 +423,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) { m.impl("allgather_", allgather_cuda_); } +TORCH_LIBRARY_IMPL(c10d, CPU, m) { + m.impl("_allgather_base_", _allgather_base_cpu_); +} + +TORCH_LIBRARY_IMPL(c10d, CUDA, m) { + m.impl("_allgather_base_", _allgather_base_cuda_); +} + TORCH_LIBRARY_IMPL(c10d, CPU, m) { m.impl("reduce_scatter_", reduce_scatter_cpu_); } diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 6515a3d9a87d..2424506eef0f 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -1134,10 +1134,10 @@ that 
adds a prefix to each key inserted to the store. .def( "allreduce_coalesced", - [](::c10d::ProcessGroup& self, - std::vector& xs, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + const std::vector& xs, ::c10d::AllreduceCoalescedOptions opts) { - return self.allreduce_coalesced(xs, opts); + return ::c10d::ops::allreduce_coalesced(self, xs, opts); }, py::arg("tensors"), py::arg("opts") = ::c10d::AllreduceCoalescedOptions(), @@ -1187,7 +1187,13 @@ that adds a prefix to each key inserted to the store. .def( "_allgather_base", - &::c10d::ProcessGroup::_allgather_base, + [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self, + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const ::c10d::AllgatherOptions& opts) { + return ::c10d::ops::_allgather_base( + self, output_tensor, input_tensor, opts); + }, py::arg("output"), py::arg("input"), py::arg("opts") = ::c10d::AllgatherOptions(), diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a72a8a2c1150..7ee48635cdff 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1148,38 +1148,65 @@ void initJITBindings(PyObject* module) { // NB: This isn't actually used for regular PyTorch symbolic tracing; // XLA is what needs this #define SYMNODE_UNARY(n) .def(#n, [](c10::SymNode a) { return a->n(); }) -#define SYMNODE_UNARY2(n2, n) .def(#n2, [](c10::SymNode a) { return a->n(); }) #define SYMNODE_BINARY(n) \ .def(#n, [](c10::SymNode a, c10::SymNode b) { return a->n(b); }) auto symnode_class = py::class_(m, "_SymNode") + // clang-format off // These DO NOT install magic methods; the SymInt/SymFloat wrapper in // Python is responsible for this SYMNODE_UNARY(clone) - // Named these for consistency with inner python class, but maybe - // should change the python side - SYMNODE_UNARY2(__bool__, bool_) SYMNODE_UNARY2(__int__, int_) - SYMNODE_UNARY2(__sym_int__, sym_int) SYMNODE_UNARY2( - __sym_float__, sym_float) SYMNODE_BINARY(add) SYMNODE_BINARY(sub) - SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) SYMNODE_BINARY(pow) - SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY( - eq) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt) - SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(min) - SYMNODE_BINARY(max) SYMNODE_UNARY(ceil) - SYMNODE_UNARY(floor) SYMNODE_UNARY(neg) - // Intentionally don't set file line, as the - // Python backtrace matters more here - .def( - "guard_int", - [](c10::SymNode a) { - return a->guard_int(nullptr, 0); - }) - .def( - "__str__", - [](c10::SymNode a) { return a->str(); }) - .def("__repr__", [](c10::SymNode a) { - return a->str(); - }); + SYMNODE_UNARY(is_int) + SYMNODE_UNARY(is_float) + SYMNODE_UNARY(bool_) + SYMNODE_UNARY(int_) + SYMNODE_UNARY(sym_float) + SYMNODE_BINARY(add) + SYMNODE_BINARY(sub) + SYMNODE_BINARY(mul) + SYMNODE_BINARY(truediv) + SYMNODE_BINARY(pow) + SYMNODE_BINARY(floordiv) + SYMNODE_BINARY(mod) + SYMNODE_BINARY(eq) + SYMNODE_BINARY(gt) + SYMNODE_BINARY(lt) + SYMNODE_BINARY(le) + SYMNODE_BINARY(ge) + SYMNODE_BINARY(min) + SYMNODE_BINARY(max) + SYMNODE_UNARY(ceil) + SYMNODE_UNARY(floor) + SYMNODE_UNARY(neg) + // Intentionally don't set file line, as the + // Python backtrace matters more here + .def( + "guard_int", + [](c10::SymNode a) { + return a->guard_int(nullptr, 0); + }) + .def( + "guard_float", + [](c10::SymNode a) { + return a->guard_float(nullptr, 0); + }) + .def( + "wrap_int", + [](c10::SymNode a, int64_t b) { + return a->wrap_int(b); + }) + .def( + "wrap_float", + [](c10::SymNode a, double b) { + return a->wrap_float(b); + }) + 
.def( + "__str__", + [](c10::SymNode a) { return a->str(); }) + .def("__repr__", [](c10::SymNode a) { + return a->str(); + }); + // clang-format on // NOLINTNEXTLINE(bugprone-unused-raii) py::class_(m, "CompleteArgumentSpec") diff --git a/torch/csrc/utils/pybind.cpp b/torch/csrc/utils/pybind.cpp index 37e37a873774..4cd148fdfa91 100644 --- a/torch/csrc/utils/pybind.cpp +++ b/torch/csrc/utils/pybind.cpp @@ -25,11 +25,19 @@ py::handle type_caster::cast( return_value_policy /* policy */, handle /* parent */) { if (si.is_symbolic()) { - // TODO: generalize this to work with C++ backed class auto* py_node = dynamic_cast(si.toSymNodeImpl().get()); - TORCH_INTERNAL_ASSERT(py_node); - return torch::get_symint_class()(py_node->getPyObj()).release(); + if (py_node) { + // Return the Python directly (unwrap) + return torch::get_symint_class()(py_node->getPyObj()).release(); + } else { + // Wrap the C++ into Python + auto inner = py::cast(si.toSymNodeImpl()); + if (!inner) { + throw python_error(); + } + return torch::get_symint_class()(inner).release(); + } } else { return py::cast(si.as_int_unchecked()).release(); } diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index be402e4d5439..3a9fa79d37d6 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -164,10 +164,6 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__FUNCTION__); } - c10::SymNode sym_int() override { - return dispatch_common_(__FUNCTION__); - } - c10::SymNode sym_float() override { return dispatch_common_(__FUNCTION__); } diff --git a/torch/distributed/_composable/fully_shard.py b/torch/distributed/_composable/fully_shard.py index 2d9e9329795b..174b2ca89a78 100644 --- a/torch/distributed/_composable/fully_shard.py +++ b/torch/distributed/_composable/fully_shard.py @@ -24,6 +24,7 @@ MixedPrecision, ShardingStrategy, ) +from torch.distributed.fsdp.wrap import _FSDPPolicy @contract @@ -32,7 +33,7 @@ def fully_shard( process_group: Optional[dist.ProcessGroup] = None, mixed_precision: Optional[MixedPrecision] = None, cpu_offload: Optional[CPUOffload] = None, - auto_wrap_policy: Optional[Callable] = None, + policy: Optional[_FSDPPolicy] = None, ignored_modules: Optional[Iterable[torch.nn.Module]] = None, device_id: Optional[Union[int, torch.device]] = None, param_init_fn: Optional[Callable[[nn.Module], None]] = None, @@ -41,6 +42,9 @@ def fully_shard( """ Applies ``FullyShardedDataParallel` (FSDP) semantics to ``module``. 
""" + # Enforce the new auto wrap policy + if policy is not None and not isinstance(policy, _FSDPPolicy): + raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}") state = fully_shard.state(module) state = _init_ignored_module_states(state, module, ignored_modules) state = _init_process_group_state(state, process_group) @@ -64,7 +68,7 @@ def fully_shard( state = _init_param_handles_from_module( state, module, - auto_wrap_policy, + policy, device_id, param_init_fn, sync_module_states, diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index 324a3442dea9..b1bffdb25a0e 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -11,4 +11,3 @@ ShardingStrategy, StateDictType, ) -from .wrap import ParamExecOrderWrapPolicy diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 1265ee3578ed..7e128251fcc4 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -47,6 +47,7 @@ HandleConfig, HandleShardingStrategy, ) +from torch.distributed.fsdp.wrap import _FSDPPolicy from torch.distributed.utils import _sync_params_and_buffers from torch.utils.hooks import RemovableHandle @@ -262,7 +263,7 @@ def _init_param_handle_from_module( def _init_param_handles_from_module( state: _FSDPState, root_module: nn.Module, - auto_wrap_policy: Callable, + policy: _FSDPPolicy, device_id: Optional[Union[int, torch.device]], param_init_fn: Optional[Callable[[nn.Module], None]], sync_module_states: bool, @@ -273,7 +274,7 @@ def _init_param_handles_from_module( """ submodule_to_states = _get_submodule_to_states( root_module, - auto_wrap_policy, + policy, state._ignored_modules, state._ignored_params, ) diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 530a8480d552..70fb4156d537 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -22,9 +22,11 @@ import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file import torch.nn as nn from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.fsdp._common_utils import _get_param_to_fqns from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor -from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed +from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed, _lazy_init from torch.distributed.fsdp._shard_utils import _gather_state_dict +from torch.distributed.fsdp.api import ShardingStrategy from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle @@ -185,7 +187,7 @@ def _communicate_optim_state( # we take the target rank's value if ( fsdp_module.world_size == 1 - or fsdp_module.sharding_strategy == fsdp_file.ShardingStrategy.NO_SHARD + or fsdp_module.sharding_strategy == ShardingStrategy.NO_SHARD ): tensor_state[state_name] = value continue @@ -293,7 +295,7 @@ def _flatten_optim_state_dict( '"param_groups" to be a valid optimizer state dict' ) flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model) - param_to_fqns = fsdp_file._get_param_to_fqns(model) + param_to_fqns = _get_param_to_fqns(model) # Construct the "state" part flat_osd_state: Dict[_OptimStateKey, Any] = {} @@ -897,7 +899,7 @@ def _rekey_sharded_optim_state_dict( if using_optim_input else _get_param_to_param_id(optim) ) - param_to_fqns = fsdp_file._get_param_to_fqns(model) + param_to_fqns = _get_param_to_fqns(model) # All parameter keys in 
`param_to_flat_param_id` should be in # `param_to_fqns` -- strict inequality follows when not all parameters are # passed to the optimizer @@ -951,7 +953,7 @@ def _get_flat_param_to_fsdp_module(model: torch.nn.Module): flat_param_to_fsdp_module = {} for module in model.modules(): if isinstance(module, fsdp_file.FullyShardedDataParallel): - fsdp_file._lazy_init(module, module) + _lazy_init(module, module) for param in module.params: # may have none flat_param_to_fsdp_module[param] = module return flat_param_to_fsdp_module @@ -1165,9 +1167,7 @@ def _optim_state_dict( # Construct the local mapping between unflattened parameter names # (`_OptimStateKey`s) and parameter IDs and broadcast rank 0's mapping - param_to_fqns: Dict[torch.nn.Parameter, List[str]] = fsdp_file._get_param_to_fqns( - model - ) + param_to_fqns: Dict[torch.nn.Parameter, List[str]] = _get_param_to_fqns(model) flat_param_id_to_param: List[torch.nn.Parameter] = ( _get_param_id_to_param_from_optim_input(model, optim_input) if using_optim_input diff --git a/torch/distributed/fsdp/_wrap_utils.py b/torch/distributed/fsdp/_wrap_utils.py index 34d1c9c1ac24..cdda065df199 100644 --- a/torch/distributed/fsdp/_wrap_utils.py +++ b/torch/distributed/fsdp/_wrap_utils.py @@ -1,7 +1,7 @@ import collections import functools import warnings -from typing import Any, Callable, Deque, Dict, List, NamedTuple, Set, Tuple +from typing import Any, Deque, Dict, List, NamedTuple, Set, Tuple import torch import torch.nn as nn @@ -10,6 +10,7 @@ _override_batchnorm_mixed_precision, ) from torch.distributed.fsdp.wrap import ( + _FSDPPolicy, _or_policy, _recursive_wrap, _wrap_batchnorm_individually, @@ -45,6 +46,9 @@ def _auto_wrap( ``fsdp_kwargs`` contains all FSDP arguments except ``module``. """ auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"] + # Support new way to pass an auto wrap policy + if isinstance(auto_wrap_policy, _FSDPPolicy): + auto_wrap_policy = auto_wrap_policy.policy root_module = auto_wrap_kwargs["module"] assert auto_wrap_policy is not None # For auto wrapping, submodules should not already be wrapped with FSDP @@ -68,13 +72,13 @@ def _auto_wrap( "instances with mixed precision disabled since some batch norm " "kernels do not support low precision." ) - auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy + auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy _recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs) def _get_submodule_to_states( root_module: nn.Module, - auto_wrap_policy: Callable, + auto_wrap_policy: _FSDPPolicy, ignored_modules: Set[nn.Module], ignored_params: Set[nn.Parameter], ) -> Dict[nn.Module, SubmoduleState]: @@ -99,7 +103,7 @@ def _get_submodule_to_states( wrapper_cls = functools.partial(_record_module_wrapper_cls, wrapped_modules) _recursive_wrap( root_module, - auto_wrap_policy=auto_wrap_policy, + auto_wrap_policy=auto_wrap_policy.policy, wrapper_cls=wrapper_cls, ignored_modules=ignored_modules, ignored_params=ignored_params, @@ -158,8 +162,9 @@ def _record_module_wrapper_cls( **kwargs, ) -> nn.Module: """ - This defines a wrapper class to be passed to ``_recursive_wrap()`` that - records the wrapped module to the input ``wrapped_modules``. + This defines a pseudo-wrapper class to be passed to ``_recursive_wrap()`` + that records the wrapped module to the input ``wrapped_modules`` without + actually wrapping with a class. 
""" wrapped_modules.append(module) return module diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py index 0978f0875a28..b5892bca683a 100644 --- a/torch/distributed/fsdp/flat_param.py +++ b/torch/distributed/fsdp/flat_param.py @@ -838,7 +838,8 @@ def needs_unshard(self) -> bool: return False unsharded_flat_param = self._get_padded_unsharded_flat_param() already_unsharded = ( - unsharded_flat_param._typed_storage()._size() == unsharded_flat_param.numel() + unsharded_flat_param._typed_storage()._size() + == unsharded_flat_param.numel() ) return not already_unsharded @@ -1306,6 +1307,8 @@ def _use_unsharded_views(self, as_params: bool) -> None: assert tensor is not None # mypy param_var = tensor setattr(module, param_name, param_var) + if self._use_orig_params and self._training_state == HandleTrainingState.FORWARD: + module._parameters[param_name] = param_var # type: ignore[assignment] for i, ( param_name, module, @@ -1336,6 +1339,8 @@ def _use_unsharded_views(self, as_params: bool) -> None: module.register_parameter(param_name, prim_param) else: setattr(module, param_name, prim_param) + if self._use_orig_params and self._training_state == HandleTrainingState.FORWARD: + module._parameters[param_name] = prim_param # type: ignore[assignment] def _use_unsharded_grad_views(self) -> None: """ diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 510f90de2023..3e84315a4e11 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -96,14 +96,6 @@ ) from ._utils import p_assert from .flat_param import FlatParameter, FlatParamHandle -from .wrap import ParamExecOrderWrapPolicy - - -_TORCH_FX_AVAIL = True -if not hasattr(torch, "fx"): - _TORCH_FX_AVAIL = False -if _TORCH_FX_AVAIL: - from ._symbolic_trace import _init_execution_info, _patch_tracer, TracingConfig __all__ = [ @@ -207,37 +199,36 @@ class FullyShardedDataParallel(nn.Module): This configures CPU offloading. If this is set to ``None``, then no CPU offloading happens. See :class:`CPUOffload` for details. (Default: ``None``) - auto_wrap_policy (Optional[Callable[[nn.Module, bool, int], bool]]): - A callable specifying a policy to recursively wrap layers with FSDP. - Note that this policy currently will only apply to child modules of - the passed in module. The remainder modules are always wrapped in - the returned FSDP root instance. - ``size_based_auto_wrap_policy`` written in ``torch.distributed.fsdp.wrap`` is - an example of ``auto_wrap_policy`` callable, this policy wraps layers - with the number of parameters larger than 100M. ``transformer_auto_wrap_policy`` - written in ``torch.distributed.fsdp.wrap`` is an example of ``auto_wrap_policy`` - callable for transformer-like model architectures. Users can supply the customized - ``auto_wrap_policy`` callable that should accept following arguments: - ``module: nn.Module``, ``recurse: bool``, ``unwrapped_params: int``, and return - a ``bool`` specifying whether the passed in ``module``` should be wrapped - (if ``recurse=False``) or whether we should recurse down the subgraph of ``module`` - children (if ``recurse=True``). Extra customized arguments could be added to - the customized ``auto_wrap_policy`` callable as well. It is a good practice to - print out the sharded model and check whether the sharded model is what - the application wants and then adjust accordingly. 
+ auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], _FSDPPolicy]]): + This is either ``None``, an ``_FSDPPolicy``, or a callable of + a fixed signature. If it is ``None``, then ``module`` is wrapped + with only a top-level FSDP instance without any nested wrapping. If + it is an ``_FSDPPolicy``, then the wrapping follows the given + policy. ``ModuleWrapPolicy`` in ``torch.distributed.fsdp.wrap.py`` + is an example. If it is a callable, then it should take in three + arguments ``module: nn.Module``, ``recurse: bool``, and + ``nonwrapped_numel: int`` and should return a ``bool`` specifying + whether the passed-in ``module`` should be wrapped if + ``recurse=False`` or if the traversal should continue down the + subtree if ``recurse=True``. Additional custom arguments may be + added to the callable. The ``size_based_auto_wrap_policy`` in + ``torch.distributed.fsdp.wrap.py`` gives an example callable that + wraps a module if the parameters in its subtree exceed 100M numel. + A good practice is to print the model after wrapping and adjust as + needed. Example:: >>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, - >>> unwrapped_params: int, - >>> # These are customizable for this policy function. + >>> nonwrapped_numel: int, + >>> # Additional custom arguments >>> min_num_params: int = int(1e8), >>> ) -> bool: - >>> return unwrapped_params >= min_num_params - >>> # Configure a custom min_num_params - >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=1e5) + >>> return nonwrapped_numel >= min_num_params + >>> # Configure a custom `min_num_params` + >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5)) backward_prefetch (Optional[BackwardPrefetch]): This configures explicit backward prefetching of all-gathers. 
See @@ -337,25 +328,6 @@ def __init__( limit_all_gathers: bool = False, use_orig_params: bool = False, ): - if isinstance(auto_wrap_policy, ParamExecOrderWrapPolicy): - self._init_param_exec_order_wrap_policy( - module=module, - process_group=process_group, - sharding_strategy=sharding_strategy, - cpu_offload=cpu_offload, - auto_wrap_policy=auto_wrap_policy, - backward_prefetch=backward_prefetch, - mixed_precision=mixed_precision, - ignored_modules=ignored_modules, - param_init_fn=param_init_fn, - device_id=device_id, - sync_module_states=sync_module_states, - forward_prefetch=forward_prefetch, - limit_all_gathers=limit_all_gathers, - use_orig_params=use_orig_params, - ) - return - torch._C._log_api_usage_once("torch.distributed.fsdp") super().__init__() @@ -1189,23 +1161,45 @@ def clip_grad_norm_( self._streams["unshard"], self._streams["pre_unshard"], ) - max_norm = float(max_norm) norm_type = float(norm_type) - # Compute the local gradient norm (only including this rank's shard - # of the gradients) - local_norm = _get_grad_norm(self.parameters(), norm_type).to( + # Perform local gradient norm computation, where sharded and + # non-sharded parameters must be handled separately + sharded_params = set() + nonsharded_params = set() # `NO_SHARD` or not FSDP-managed + for handle in FullyShardedDataParallel._fsdp_handles(self): + target_set = ( + sharded_params if handle.uses_sharded_strategy else nonsharded_params + ) + if handle._use_orig_params: + for param in handle.flat_param._params: + target_set.add(param) + else: + target_set.add(handle.flat_param) + for param in self.parameters(): + not_fsdp_managed = ( + param not in sharded_params and param not in nonsharded_params + ) + if not_fsdp_managed: + nonsharded_params.add(param) + local_sharded_norm = _get_grad_norm(sharded_params, norm_type).to( + self.compute_device + ) + local_nonsharded_norm = _get_grad_norm(nonsharded_params, norm_type).to( self.compute_device ) # Reconstruct the total gradient norm depending on the norm type if norm_type == math.inf: - total_norm = local_norm + total_norm = torch.maximum(local_sharded_norm, local_nonsharded_norm) dist.all_reduce( total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group ) else: - total_norm = local_norm**norm_type + total_norm = local_sharded_norm**norm_type dist.all_reduce(total_norm, group=self.process_group) + # All-reducing the local non-sharded norm would count it an extra + # world-size-many times + total_norm += local_nonsharded_norm**norm_type total_norm = total_norm ** (1.0 / norm_type) if self.cpu_offload.offload_params: total_norm = total_norm.cpu() @@ -1815,92 +1809,9 @@ def register_comm_hook(self, state: object, hook: callable): submodule._communication_hook_state = state submodule._communication_hook = hook - def _init_param_exec_order_wrap_policy(self, *args, **kwargs) -> None: - auto_wrap_policy = kwargs["auto_wrap_policy"] - module = kwargs["module"] - assert hasattr(auto_wrap_policy, "tracing_config") - if not _TORCH_FX_AVAIL: - assert ( - auto_wrap_policy.tracing_config is None - ), "tracing_config should be None when torch.fx is not enabled" - elif isinstance(auto_wrap_policy.tracing_config, TracingConfig): - tracer = auto_wrap_policy.tracing_config.tracer - execution_info = _init_execution_info(module) - - for m in module.modules(): - assert not isinstance( - m, FullyShardedDataParallel - ), "The input module of _patch_tracer should not contain FSDP modules" - - with _patch_tracer( - tracer=tracer, - root_module=module, - 
execution_info=execution_info, - ): - try: - tracer.trace(module, auto_wrap_policy.tracing_config.concrete_args) - except BaseException as e: - raise RuntimeError( - "tracer.trace failed inside _init_param_exec_order_wrap_policy" - f" with the error: {e}." - ) - else: - assert ( - auto_wrap_policy.tracing_config is None - ), "tracing_config should either be an instance of TracingConfig or be None" - # The initial FSDP wrapping is done with auto_wrap_policy.init_policy - kwargs["auto_wrap_policy"] = auto_wrap_policy.init_policy - self.__init__(*args, **kwargs) - self._param_exec_order_policy: bool = True - # self._param_exec_order_prep_stage is set to True before we get the execution order - self._param_exec_order_prep_stage: bool = True - # A list that stores the flatten parameters and its name based on the parameter execution order - self._fsdp_params_exec_order: List[FlatParameter] = [] - if _TORCH_FX_AVAIL and isinstance( - auto_wrap_policy.tracing_config, TracingConfig - ): - # Initialize a dict that maps each module to its parent FSDP wrap - module_to_fsdp: Dict[nn.Module, FullyShardedDataParallel] = dict() - for wrap in self.fsdp_modules(self): - module_to_fsdp[wrap.module] = wrap - # Set self._fsdp_params_exec_order based on execution_info.module_forward_order. - # TODO (linjianma): self._fsdp_params_exec_order will be set based on - # the parameter execution order rather than module_forward_order, - # once the non-recursive wrapping policy is fully implemented. - for m in execution_info.module_forward_order: - if m in module_to_fsdp: - for flat_param in module_to_fsdp[m].params: - self._fsdp_params_exec_order.append(flat_param) - self._param_exec_order_prep_stage = False - - for m in self.modules(): - if m is not self and isinstance(m, FullyShardedDataParallel): - # Assignment by reference, so each children FSDP wrap has access to - # the _fsdp_params_exec_order of the root module - m._fsdp_params_exec_order = self._fsdp_params_exec_order - m._param_exec_order_policy = self._param_exec_order_policy - m._param_exec_order_prep_stage = self._param_exec_order_prep_stage - - def _use_param_exec_order_policy(self) -> bool: - return ( - hasattr(self, "_param_exec_order_policy") and self._param_exec_order_policy - ) - - def _is_param_exec_order_prep_stage(self) -> bool: - is_prep_stage = ( - hasattr(self, "_param_exec_order_prep_stage") - and self._param_exec_order_prep_stage - ) - if not is_prep_stage: - for p in self.parameters(): - assert not hasattr( - p, "_params_exec_order_hook_handle" - ), "When not in execution order prep stage, all _params_exec_order_hook_handle should be removed." - return is_prep_stage - def _get_grad_norm( - params: List[nn.Parameter], + params: Iterable[nn.Parameter], norm_type: float, ) -> torch.Tensor: """ diff --git a/torch/distributed/fsdp/wrap.py b/torch/distributed/fsdp/wrap.py index c529bcde8c85..e20c07f18d13 100644 --- a/torch/distributed/fsdp/wrap.py +++ b/torch/distributed/fsdp/wrap.py @@ -4,7 +4,8 @@ # LICENSE file in the root directory of this source tree. import contextlib -from dataclasses import dataclass +import functools +from abc import ABC, abstractmethod from typing import Any, Callable, cast, Dict, Generator, Optional, Set, Tuple, Type import torch.nn as nn @@ -17,22 +18,84 @@ "size_based_auto_wrap_policy", "enable_wrap", "wrap", - "ParamExecOrderWrapPolicy", + "ModuleWrapPolicy", ] def always_wrap_policy(*args, **kwargs) -> bool: """ - A simple wrapper policy that always returns ``True``, - i.e. 
when passed as the `auto_wrap_policy` into FSDP, - this will result in all submodules being wrapped as - distinct FSDP instances. + A simple recursive wrap policy that always returns ``True``. This means + that every submodule is wrapped by the wrapper class in + :func:`_recursive_wrap`. """ return True +class _FSDPPolicy(ABC): + """ + This defines an abstract base class that represents an FSDP policy for + constructing ``FlatParameter`` s. + """ + + # The motivation for this abstract base class is to hide the interface + # expected by `_recursive_wrap()` from users (i.e. the `recurse` argument). + def __init__(self): + ... + + @property + @abstractmethod + def policy(self) -> Callable: + ... + + +def _module_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + module_classes: Set[Type[nn.Module]], +) -> bool: + """ + This auto wrap policy wraps every module that is an instance of any type in + ``module_classes`` as its own FSDP instance. The root module given by + ``module`` is always wrapped as an FSDP instance regardless. Since the + wrapping proceeds bottom up, each FSDP instance manages the parameters in + its subtree excluding any already managed by a child FSDP instance. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + module_classes (Set[Type[nn.Module]]): Set of module classes that are + wrapped as FSDP instances. + + Returns: + ``True`` if ``recurse=True``, and whether ``module`` should be wrapped + if ``recurse=False``. + """ + if recurse: + return True # always recurse + return isinstance(module, tuple(module_classes)) + + +class ModuleWrapPolicy(_FSDPPolicy): + """This is a wrapper around :func:`_module_wrap_policy`.""" + + def __init__(self, module_classes: Set[Type[nn.Module]]): + self._policy: Callable = functools.partial( + _module_wrap_policy, + module_classes=module_classes, + ) + + @property + def policy(self): + return self._policy + + def lambda_auto_wrap_policy( - module: nn.Module, recurse: bool, unwrapped_params: int, lambda_fn: Callable + module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable ) -> bool: """ A convenient auto wrap policy to wrap submodules based on an arbitrary user @@ -44,70 +107,34 @@ def lambda_auto_wrap_policy( The first three parameters are required by :func:`_recursive_wrap`. Args: - module (nn.Module): - The module to be considered in this decision. - recurse (bool): - Indicate if this is called to make a decision on whether we - should recurse down a subgraph of the module structure. - If False, it means this function is called to make a decision - on whether we should wrap the said module. - unwrapped_params (int): - The number of parameters yet to be wrapped in this module. - - lambda_fn (Callable[nn.Module] -> bool): - If this returns ``True``, this module will be wrapped by - wrapper_cls individually. + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. 
+ + lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then + this module will be wrapped. """ if recurse: - # always recurse - return True - else: - # if not recursing, decide whether we should wrap for the leaf node or reminder - return lambda_fn(module) + return True # always recurse + return lambda_fn(module) def transformer_auto_wrap_policy( module: nn.Module, recurse: bool, - unwrapped_params: int, + nonwrapped_numel: int, transformer_layer_cls: Set[Type[nn.Module]], ) -> bool: """ - A convenient auto wrap policy for transformer models. If the submodule - is an instance of transformer_layer_cls, the submodule will be wrapped - as a FSDP unit. Otherwise, all the other remainder submodules are wrapped - by the outermost FSDP unit. Right now, FSDP requires submodules that share - weights to be wrapped in the same FSDP unit, this auto wrap policy can - conviniently wrap the shared embeddings into the same FSDP unit for transformer - models. In the near future, FSDP will support submodules that share weights - to be wrapped in the separated FSDP units. - - Return if a module should be wrapped during FSDP auto wrapping. - - The first three parameters are required by :func:`_recursive_wrap`. - - - Args: - module (nn.Module): - The module to be considered in this decision. - recurse (bool): - Indicate if this is called to make a decision on whether we - should recurse down a subgraph of the module structure. - If False, it means this function is called to make a decision - on whether we should wrap the said module. - unwrapped_params (int): - The number of parameters yet to be wrapped in this module. - - transformer_layer_cls (int): - Submodules with one of the `transformer_layer_cls` names - will be wrapped as separated FSDP units + See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the + same as ``module_classes``. Note that shared parameters must be wrapped in + the same FSDP instance, so this auto wrap policy can help wrap shared + embeddings into the same FSDP instance for transformer models. """ - if recurse: - # always recurse - return True - else: - # if not recursing, decide whether we should wrap for the leaf node or reminder - return isinstance(module, tuple(transformer_layer_cls)) + return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls) def _wrap_batchnorm_individually( @@ -117,7 +144,7 @@ def _wrap_batchnorm_individually( **kwargs, ) -> bool: """ - A policy that wraps ``BatchNorm`` instances in their own FSDP unit. + A policy that wraps ``BatchNorm`` instances in their own FSDP instance. """ if recurse: # always recurse @@ -131,52 +158,46 @@ def _wrap_batchnorm_individually( def _or_policy( module: nn.Module, recurse: bool, - unwrapped_params: int, + nonwrapped_numel: int, policies, ) -> bool: """ A policy that wraps ``module`` if any policy in the passed in iterable of ``policies`` returns ``True``. """ - return any(policy(module, recurse, unwrapped_params) for policy in policies) + return any(policy(module, recurse, nonwrapped_numel) for policy in policies) def size_based_auto_wrap_policy( module: nn.Module, recurse: bool, - unwrapped_params: int, - # These are customizable for this policy function. + nonwrapped_numel: int, + # Additional custom arguments min_num_params: int = int(1e8), force_leaf_modules: Optional[Set[Type[nn.Module]]] = None, exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None, ) -> bool: - """A size based auto_wrap_policy function for FSDP API. 
- - Return if a module should be wrapped during FSDP auto wrapping. - - The first three parameters are used by :func:`_recursive_wrap`. If - you write a custom version of this policy function, your version - needs to at least accept the first three parameters and free - to do whatever you want in the function. + """ + A size-based auto wrap policy. Args: - module (nn.Module): - The module to be considered in this decision. - recurse (bool): - Indicate if this is called to make a decision on whether we - should recurse down a subgraph of the module structure. - If False, it means this function is called to make a decision - on whether we should wrap the said module. - unwrapped_params (int): - The number of parameters yet to be wrapped in this module. - - min_num_params (int): - Customizable policy input. It controls the size threshold - on how big should a module be to be considered wrapped. - force_leaf_modules (Set[Type[nn.Module]]): set of module types to - keep as leaves, i.e., their children will never be wrapped. - exclude_wrap_modules (Set[Type[nn.Module]]): - Customizable set of module types to be excluded in wrapping. + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + min_num_params (int): Customizable policy input that controls the size + threshold over which a module is ready to be wrapped. This is in + units of numel. + force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep + as leaves, i.e. their children will never be wrapped. + exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be + excluded in wrapping. + + Returns: + Whether ``module`` should be wrapped. """ force_leaf_modules = ( size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined] @@ -189,7 +210,10 @@ def size_based_auto_wrap_policy( else exclude_wrap_modules ) - is_large = unwrapped_params >= min_num_params + # Keep the argument `min_num_params` for BC for now, but it represents the + # minimum non-wrapped *numel* before triggering a wrapping + min_nonwrapped_numel = min_num_params + is_large = nonwrapped_numel >= min_nonwrapped_numel if recurse: # We should recurse if the module is big enough but not in force_leaf_modules list. return is_large and not isinstance(module, tuple(force_leaf_modules)) @@ -276,56 +300,6 @@ def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: return module -@dataclass -class ParamExecOrderWrapPolicy: - """ - This is the class used for the wrapping policy that wraps parameters and performs - the communication scheduling based on the parameter execution order in the forward pass - (also called non-recursive wrapping policy). - - The policy contains multiple wraps. Each wrap contains original parameters that will be executed together, - and the wrap transfers these parameters into one ``FlattenParameter``. In both forward and the backward passes, - the sharded parameters in each wrap will be gathered just before these parameters are used in the passes. - These parameters will then be reshaded once they have been used. - - TODO (linjianma): For now, the parameters contained in each wrap of ``ParamExecOrderWrapPolicy`` - are the parameters in each wrap of the ``init_policy`` (a recursive wrapping policy). 
- Later we will wrap parameters based on bucket size. - - Args: - init_policy (Callable): - The initial recursive wrapping policy used to guide the wrapping of - this policy. If tracing_config is none, in the first forward and - backward iteration, ``init_policy`` is used to record parameter - execution order. Otherwise, init_policy is only used in FSDP - constructor for module level wrapping. - - The default ``always_wrap_policy`` might not be the best choice for every model. For example, for - transformer based models, setting ``transformer_auto_wrap_policy`` as the ``init_policy`` will guarantee - wrapping each transformer layer into one FSDP unit, and can be easily combined with checkpointing - within each transformer layer. - - tracing_config (Optional[TracingConfig]): - The configuration used to perform symbolic tracing at FSDP - constructor to get the module and parameter execution order. The - type of ``tracing_config`` needs to be either ``None`` or - ``TracingConfig``. If set as ``None``, then symbolic tracing is not - enabled, and one forward as well as backward iteration are needed to - get the parameter execution order. - - ..warning :: Note that not all modules can be successfully traced when - ``tracing_config`` is not None and symbolic tracing is enabled. The two - cases below may be unable to trace: 1. when there is a data-dependent - branch, 2. when the forward pass contains operators that don't support - ``torch.fx.Proxy`` as the input type (e.g. ``arange``, ``zeros``, ``ones``, - ``full``, ``full_like``, ``eye``, ``empty``, ``tensor``). For those cases, - users can set ``tracing_config = None`` to disable symbolic tracing. - """ - - init_policy: Callable = always_wrap_policy - tracing_config: Any = None - - def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: assert wrapper_cls is not None if hasattr(module, "_wrap_overrides"): @@ -349,13 +323,13 @@ def _recursive_wrap( **kwargs: Any, ) -> Tuple[nn.Module, int]: """ - Automatically wrap child modules of *module* that meet the given - criteria with :func:`auto_wrap`. Does not rely on _ConfigAutoWrap. + Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns + ``True`` with ``wrapper_cls``. + Args: - module (nn.Module): - module to recursively wrap - auto_wrap_policy (Callable): - A callable specifying a policy to recursively wrap layers with FSDP. + module (nn.Module): Module to recursively wrap. + auto_wrap_policy (Callable): A callable representing a policy that + determines which modules to recursively wrap with ``wrapper_cls``. ignored_modules (Set[torch.nn.Module]): Modules to ignore when wrapping. ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when @@ -363,7 +337,7 @@ def _recursive_wrap( in ``ignored_modules``. Returns: (nn.Module, int): - Wrapped module and the number parameters wrapped recursively. + ``module`` after wrapping and the numel recursively wrapped. """ assert auto_wrap_policy is not None, "Must specify auto_wrap_policy." assert wrapper_cls is not None, "Must specify wrapper_cls" @@ -378,11 +352,13 @@ def _recursive_wrap( pass # We count all params, assuming none of them are already wrapped. 
- num_params = sum(p.numel() for p in module.parameters() if p not in ignored_params) + nonwrapped_numel = sum( + p.numel() for p in module.parameters() if p not in ignored_params + ) assert auto_wrap_policy is not None - if auto_wrap_policy(module=module, recurse=True, unwrapped_params=num_params): - total_wrapped_params = 0 + if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel): + total_wrapped_numel = 0 # Iterate through the children, recursively wrap if necessary for name, child in module.named_children(): if child in ignored_modules: @@ -397,17 +373,17 @@ def _recursive_wrap( ) setattr(module, name, wrapped_child) # Keep track of how many parameters have been wrapped - total_wrapped_params += num_wrapped_params + total_wrapped_numel += num_wrapped_params # decide if we need to wrap the current module, # since the left over parameters exceed the number of params to wrap - remainder = num_params - total_wrapped_params + remainder = nonwrapped_numel - total_wrapped_numel if not only_wrap_children and auto_wrap_policy( - module=module, recurse=False, unwrapped_params=remainder + module=module, recurse=False, nonwrapped_numel=remainder ): # Leaf node or final wrapping of the remainder both happen here. - return _wrap(module, wrapper_cls, **kwargs), num_params + return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel else: - return module, total_wrapped_params + return module, total_wrapped_numel return module, 0 diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index d9b0a8fc2019..ae4427e2320e 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -126,6 +126,18 @@ def sym_int(a): return sym_floor(a) if a > 0 else sym_ceil(a) return int(a) +def to_node(self, num): + if isinstance(num, (SymInt, SymFloat)): + return num.node + elif isinstance(num, int): + return self.wrap_int(num) + elif isinstance(num, float): + return self.wrap_float(num) + else: + # NotImplemented is important so that Python tries the + # other magic method + return NotImplemented + # TODO: An incomplete list # 1. Set variables to be equal when we do equality # 2. 
Specialize on 0/1 when we do subtraction @@ -148,18 +160,6 @@ def expr(self): def _update_expr(self): self._expr = self.shape_env.replace(self._expr) - def to_node(self, num): - if isinstance(num, (SymInt, SymFloat)): - return num.node - elif isinstance(num, int): - return self.wrap_int(num) - elif isinstance(num, float): - return self.wrap_float(num) - else: - # NotImplemented is important so that Python tries the - # other magic method - return NotImplemented - def is_int(self): return self.pytype is int @@ -297,16 +297,15 @@ def _nyi(): always_bool_magic_methods = {"eq", "gt", "lt", "le", "ge"} def wrap_node(x): - if not isinstance(x, SymNode): - return x - if x.constant is not None: + # TODO: let C++ also take advantage of this + if isinstance(x, SymNode) and x.constant is not None: return x.constant - if x.pytype is int: + if x.is_int(): return SymInt(x) - elif x.pytype is float: + elif x.is_float(): return SymFloat(x) else: - raise AssertionError(f"unrecognized return type {x.pytype}") + raise AssertionError(f"unrecognized return type {x}") def _make_node_magic(method, func): func = lru_cache(256)(func) @@ -378,13 +377,13 @@ def unary_magic_impl(self): return wrap_node(getattr(self.node, method)()) def binary_magic_impl(self, other): - other_node = self.node.to_node(other) + other_node = to_node(self.node, other) if other_node is NotImplemented: return NotImplemented return wrap_node(getattr(self.node, method)(other_node)) def rbinary_magic_impl(self, other): - other_node = self.node.to_node(other) + other_node = to_node(self.node, other) if other_node is NotImplemented: return NotImplemented return wrap_node(getattr(other_node, method)(self.node)) @@ -457,7 +456,6 @@ def create_symbolic_sizes_strides(self, ex: torch.Tensor): We try our best to express stride in terms of the sizes, so as to not introduce new symbolic variables. """ - size = [self.create_symbol(i) for i in ex.size()] stride: List[Optional[sympy.Expr]] = [None] * len(size) for i, val in enumerate(ex.stride()): diff --git a/torch/onnx/_internal/diagnostics/_diagnostic.py b/torch/onnx/_internal/diagnostics/_diagnostic.py index 21e44f2b4467..efe5c0e34911 100644 --- a/torch/onnx/_internal/diagnostics/_diagnostic.py +++ b/torch/onnx/_internal/diagnostics/_diagnostic.py @@ -74,22 +74,6 @@ def record_cpp_call_stack(self, frames_to_skip) -> None: self.with_stack(stack) self.cpp_call_stack = stack - def with_model_source_location( - self: _ExportDiagnostic, - ) -> _ExportDiagnostic: - # TODO: Implement this. - # self.locations.append(...) - raise NotImplementedError() - return self - - def with_export_source_location( - self: _ExportDiagnostic, - ) -> _ExportDiagnostic: - # TODO: Implement this. - # self.locations.append(...) - raise NotImplementedError() - return self - class ExportDiagnosticEngine(infra.DiagnosticEngine): """PyTorch ONNX Export diagnostic engine. 
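Note on the `to_node` refactor in torch/fx/experimental/symbolic_shapes.py above: the helper returns `NotImplemented` precisely so that Python falls back to the other operand's reflected magic method. The following is a minimal, self-contained sketch of that protocol using a hypothetical toy class (it is not the actual SymInt/SymNode implementation, only an illustration of the language mechanism the comment relies on):

class Meters:
    """Toy wrapper illustrating the NotImplemented fallback protocol."""
    def __init__(self, value: float) -> None:
        self.value = value

    def __add__(self, other):
        # Only handle operand types we recognize; otherwise defer to the
        # other operand by returning NotImplemented.
        if isinstance(other, (int, float)):
            return Meters(self.value + other)
        return NotImplemented

    def __radd__(self, other):
        # Python calls this when the left operand's __add__ returned
        # NotImplemented (e.g. int.__add__ does not know about Meters).
        if isinstance(other, (int, float)):
            return Meters(other + self.value)
        return NotImplemented

print((Meters(2.0) + 3).value)  # __add__ handles the int directly -> 5.0
print((3 + Meters(2.0)).value)  # int.__add__ defers, so __radd__ runs -> 5.0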
@@ -115,7 +99,6 @@ def __init__(self) -> None: name="torch.onnx", version=torch.__version__, diagnostic_type=ExportDiagnostic, - options=None, ) @property @@ -150,6 +133,7 @@ def create_export_diagnostic_context(): try: yield context finally: + context.pretty_print(context.options.log_verbose, context.options.log_level) context = engine.background_context diff --git a/torch/onnx/_internal/diagnostics/infra/_infra.py b/torch/onnx/_internal/diagnostics/infra/_infra.py index b8a4c5032f52..3414574cce73 100644 --- a/torch/onnx/_internal/diagnostics/infra/_infra.py +++ b/torch/onnx/_internal/diagnostics/infra/_infra.py @@ -17,10 +17,10 @@ class Level(enum.Enum): please use infra.Tag instead. """ - NONE = "none" - NOTE = "note" - WARNING = "warning" - ERROR = "error" + NONE = enum.auto() + NOTE = enum.auto() + WARNING = enum.auto() + ERROR = enum.auto() levels = Level @@ -107,6 +107,9 @@ def format_message(self, *args, **kwargs) -> str: """ return self.message_default_template.format(*args, **kwargs) + def pretty_print(self): + pass + @dataclasses.dataclass class Location: @@ -134,6 +137,25 @@ def sarif(self) -> sarif.Location: else None, ) + def pretty_print(self): + """Prints the location in a human-readable format.""" + location_strs = ["frame:"] + if self.snippet is not None: + location_strs.append(self.snippet) + if self.uri is not None: + line_strs = [self.uri] + line_strs.append(str(self.line)) if self.line is not None else "-1" + line_strs.append( + str(self.start_column) + ) if self.start_column is not None else "-1" + line_strs.append( + str(self.end_column) + ) if self.end_column is not None else "-1" + location_strs.append(":".join(line_strs)) + if self.message is not None: + location_strs.append(f"({self.message})") + print(" ".join(location_strs)) + @dataclasses.dataclass class StackFrame: @@ -143,6 +165,10 @@ def sarif(self) -> sarif.StackFrame: """Returns the SARIF representation of this stack frame.""" return sarif.StackFrame(location=self.location.sarif()) + def pretty_print(self): + """Prints the stack frame in a human-readable format.""" + self.location.pretty_print() + @dataclasses.dataclass class Stack: @@ -158,6 +184,12 @@ def sarif(self) -> sarif.Stack: else None, ) + def pretty_print(self): + """Prints the stack in a human-readable format.""" + formatter.pretty_print_title(f"Stack: {self.message}", fill_char="-") + for frame in self.frames: + frame.pretty_print() + # This is a workaround for mypy not supporting Self from typing_extensions. _Diagnostic = TypeVar("_Diagnostic", bound="Diagnostic") @@ -182,6 +214,9 @@ def sarif(self) -> sarif.Graph: properties=PatchedPropertyBag(name=self.name, description=self.description), ) + def pretty_print(self): + pass + @dataclasses.dataclass class Diagnostic: @@ -201,7 +236,7 @@ def sarif(self) -> sarif.Result: message = f"{message}\n{self.additional_message}" sarif_result = sarif.Result( message=sarif.Message(text=message), - level=self.level.value, + level=self.level.name.lower(), # type: ignore[arg-type] rule_id=self.rule.id, ) sarif_result.locations = [location.sarif() for location in self.locations] @@ -235,6 +270,31 @@ def with_additional_message(self: _Diagnostic, message: str) -> _Diagnostic: self.additional_message = f"{self.additional_message}\n{message}" return self + def pretty_print(self, verbose: bool = False, log_level: Level = Level.ERROR): + """Prints the diagnostics in a human-readable format. + + Args: + verbose: If True, prints all information. E.g. stack frames, graphs, etc. 
+ Otherwise, only prints compact information. E.g., rule name and display message. + level: The minimum level of diagnostics to print. + """ + if self.level.value < log_level.value: + return + formatter.pretty_print_item_title(f"{self.level.name}: {self.rule.name}") + print(self.message) + + if not verbose: + print("\n") + return + + for location in self.locations: + location.pretty_print() + for stack in self.stacks: + stack.pretty_print() + for graph in self.graphs: + graph.pretty_print() + print() + @dataclasses.dataclass class RuleCollection: @@ -284,12 +344,15 @@ class DiagnosticOptions: Options for diagnostic context. """ + log_verbose: bool = dataclasses.field(default=False) + log_level: Level = dataclasses.field(default=Level.ERROR) + @dataclasses.dataclass class DiagnosticContext: name: str version: str - options: Optional[DiagnosticOptions] = None + options: DiagnosticOptions = dataclasses.field(default_factory=DiagnosticOptions) diagnostic_type: Type[Diagnostic] = dataclasses.field(default=Diagnostic) diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list) _invocation: Invocation = dataclasses.field(init=False) @@ -350,3 +413,38 @@ def diagnose( diagnostic = self.diagnostic_type(rule, level, message, **kwargs) self.add_diagnostic(diagnostic) return diagnostic + + def pretty_print( + self, verbose: bool = False, log_level: Level = Level.ERROR + ) -> None: + """Prints the diagnostics in a human-readable format. + + Args: + verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print. + level: The minimum level of diagnostics to print. + """ + formatter.pretty_print_title( + f"Diagnostic Run {self.name} version {self.version}" + ) + print(f"verbose: {verbose}, log level: {log_level}") + diagnostic_stats = {level: 0 for level in Level} + for diagnostic in self.diagnostics: + diagnostic_stats[diagnostic.level] += 1 + formatter.pretty_print_title( + " ".join(f"{diagnostic_stats[level]} {level.name}" for level in Level) + ) + + for diagnostic in self.diagnostics: + diagnostic.pretty_print(verbose, log_level) + + unprinted_diagnostic_stats = [ + (level, count) + for level, count in diagnostic_stats.items() + if count > 0 and level.value < log_level.value + ] + if unprinted_diagnostic_stats: + print( + f"{' '.join(f'{count} {level.name}' for level, count in unprinted_diagnostic_stats)} " + "were not printed due to the log level." + ) + print() diff --git a/torch/onnx/_internal/diagnostics/infra/engine.py b/torch/onnx/_internal/diagnostics/infra/engine.py index 2678268fbaf9..51a6057565bb 100644 --- a/torch/onnx/_internal/diagnostics/infra/engine.py +++ b/torch/onnx/_internal/diagnostics/infra/engine.py @@ -85,8 +85,23 @@ def create_diagnostic_context( Returns: A new diagnostic context. """ + if options is None: + options = infra.DiagnosticOptions() context = infra.DiagnosticContext( name, version, options, diagnostic_type=diagnostic_type ) self.contexts.append(context) return context + + def pretty_print( + self, verbose: bool = False, level: infra.Level = infra.Level.ERROR + ) -> None: + """Pretty prints all diagnostics in the diagnostic contexts. + + Args: + verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print. + level: The minimum level of diagnostics to print. 
+ """ + formatter.pretty_print_title(f"{len(self.contexts)} Diagnostic Run") + for context in self.contexts: + context.pretty_print(verbose, level) diff --git a/torch/onnx/_internal/diagnostics/infra/formatter.py b/torch/onnx/_internal/diagnostics/infra/formatter.py index 2f35489f8d45..292a2b6a47a5 100644 --- a/torch/onnx/_internal/diagnostics/infra/formatter.py +++ b/torch/onnx/_internal/diagnostics/infra/formatter.py @@ -57,3 +57,21 @@ def sarif_to_json(attr_cls_obj: _SarifClass) -> str: dict = dataclasses.asdict(attr_cls_obj) dict = _convert_key(dict, _camel_case_to_snake_case) return json.dumps(dict, indent=4) + + +def pretty_print_title(title: str, width: int = 80, fill_char: str = "=") -> None: + """Pretty prints title in below format: + + ==================== title ==================== + """ + print(f" {title} ".center(width, fill_char)) + + +def pretty_print_item_title(title: str, fill_char: str = "=") -> None: + """Pretty prints title in below format: + + title + ===== + """ + print(title) + print(fill_char * len(title)) diff --git a/torch/onnx/_internal/diagnostics/infra/utils.py b/torch/onnx/_internal/diagnostics/infra/utils.py index c32de1c6b8ad..6a85df910463 100644 --- a/torch/onnx/_internal/diagnostics/infra/utils.py +++ b/torch/onnx/_internal/diagnostics/infra/utils.py @@ -6,7 +6,7 @@ def python_frame(frame: inspect.FrameInfo) -> _infra.StackFrame: """Returns a StackFrame for the given inspect.FrameInfo.""" snippet = ( - frame.code_context[frame.index] + frame.code_context[frame.index].strip() if frame.code_context is not None and frame.index is not None else None ) diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index 0dca22f48092..b4650adff569 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: distributed"] -import functools import itertools import sys from abc import ABC, abstractmethod @@ -21,11 +20,7 @@ ShardingStrategy, ) from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from torch.distributed.fsdp.wrap import ( - always_wrap_policy, - transformer_auto_wrap_policy, - wrap, -) +from torch.distributed.fsdp.wrap import always_wrap_policy, ModuleWrapPolicy, wrap from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS @@ -285,8 +280,8 @@ def init( fsdp_init_mode (FSDPInitMode): If ``NO_FSDP``, then does not wrap any modules with FSDP. If ``RECURSIVE``, then wraps with top-level FSDP. By default, the top-level FSDP uses the - ``transformer_auto_wrap_policy()`` for encoder and decoder - layers, but a different auto wrap policy may be specified via + ``ModuleWrapPolicy`` for encoder and decoder layers, but a + different auto wrap policy may be specified via ``fsdp_kwargs``. cuda_init_mode (CUDAInitMode): Determines model movement to CUDA. 
fsdp_kwargs (Optional[Dict[str, Any]]): Optional keyword arguments @@ -302,14 +297,13 @@ def init( group, cuda_init_mode, add_bn, deterministic ) elif fsdp_init_mode == FSDPInitMode.RECURSIVE: - # Default to the `transformer_auto_wrap_policy()` + # Default to the `ModuleWrapPolicy` if "auto_wrap_policy" not in fsdp_kwargs: - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ + auto_wrap_policy = ModuleWrapPolicy( + { TransformerEncoderLayer, TransformerDecoderLayer, - }, + } ) else: auto_wrap_policy = fsdp_kwargs.pop("auto_wrap_policy") diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 5178ec978bd1..62c9b4750ae9 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2925,6 +2925,7 @@ def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs): 'nn.functional.max_pool1d': _TestParamsMaxPool1d, 'nn.functional.max_pool2d': _TestParamsMaxPool2d, 'nn.functional.max_pool3d': _TestParamsMaxPool3d, + 'max_pool2d_with_indices_backward': _TestParamsMaxPool2d, } params_generator = params_generator_type_dict[op_info.name]() @@ -2932,6 +2933,15 @@ def sample_inputs_max_pool(op_info, device, dtype, requires_grad, **kwargs): arg = make_arg(shape).to(memory_format=memory_format).requires_grad_(requires_grad) yield SampleInput(arg, kwargs=kwargs) +def max_pool2d_backward(*args, kernel_size=(), stride=(), padding=(0,), dilation=(1,), ceil_mode=False, **kwargs): + out, indices = torch.nn.functional.max_pool2d_with_indices( + *args, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, ceil_mode=ceil_mode, return_indices=True) + grad_out = torch.ones_like(out) + if stride is None: + stride = kernel_size + out_b = torch.ops.aten.max_pool2d_with_indices_backward.default( + grad_out, *args, kernel_size, stride, padding, dilation, ceil_mode, indices) + return out_b def error_inputs_max_pool1d(op_info, device, **kwargs): # Toggle requires_grad because `max_pool1d` has different path @@ -4391,29 +4401,127 @@ def sample_repeat_tile(op_info, device, dtype, requires_grad, **kwargs): yield SampleInput(make_arg(shape), rep_dim) -def sample_inputs_narrow_copy(op_info, device, dtype, requires_grad, **kwargs): +def sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): shapes_and_args = ( - ((S, S, S), (1, 2, 2)), - ((S, S, S), (-1, 2, 2)), - ((S, S, S), (1, 0, 0)), - ((S, S, S), (-1, 0, 0)), - ((S, S, S), (2, 1, 2)), + ((S, S, S), 1, 2, 2), + ((S, S, S), -1, 2, 2), + ((S, S, S), 1, 0, 0), + ((S, S, S), -1, 0, 0), + ((S, S, S), 2, 1, 2), ) - for shape, args in shapes_and_args: + for shape, dim, start, length in shapes_and_args: tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) - yield SampleInput(tensor, args=args) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) +def reference_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, *, is_narrow, **kwargs): + yield from sample_inputs_narrow_narrow_copy(op_info, device, dtype, requires_grad, is_narrow=is_narrow, **kwargs) -def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs): - ''' - sample_inputs_narrow accepts the same inputs as narrow_copy, in addition - narrow also accepts 
`start` argument to be a Tensor. - ''' - for sample in sample_inputs_narrow_copy(op_info, device, dtype, requires_grad, **kwargs): - yield sample - yield SampleInput(sample.input, args=(sample.args[0], torch.tensor(sample.args[1]), sample.args[2])) + shapes_and_args = ( + # 1-dim + ((M,), 0, 0, 0), # 0 elems from the left + ((M,), -1, -1, 0), # 0 elems from the right + ((M,), 0, 5, 3), # 3 elems from the left + ((M,), 0, -5, 2), # 2 elems from the right + ((M,), -1, 0, M), # M elems from the left + ((M,), 0, -M, M), # M elems from the right + + # 2-dim + ((M, S), 1, 0, 0), # dim 1, 0 elems from the left + ((S, M), -2, -1, 0), # dim 0, 0 elems from the right + ((L, S), 1, 2, 3), # dim 1, 3 elems from the left + ((L, S), -1, 3, 2), # dim 1, 2 elems from the left + ((M, L), 0, 0, M), # dim 0, M elems from the left + ((M, L), -1, -L, L), # dim 1, L elems from the right + + # 3-dim + ((L, M, S), 2, 0, 0), # dim 2, 0 elems from the left + ((M, S, L), -1, -1, 0), # dim 2, 0 elems from the right + ((S, L, M), 2, 0, M), # dim 2, M elems from the left + ((L, S, M), -1, -M, M), # dim 2, M elems from the right + ((S, L, M), 1, 0, 0), # dim 1, 0 elems from the left + ((S, L, M), 0, 2, 1), # dim 0, 1 elem from the left + ((M, S, M), -1, -5, 4), # dim 2, 4 elems from the right + ) + + for shape, dim, start, length in shapes_and_args: + tensor = make_tensor(shape, dtype=dtype, device=device, low=None, high=None, + requires_grad=requires_grad) + yield SampleInput(tensor, dim, start, length) + # narrow also accepts the start argument being a Tensor + if is_narrow: + yield SampleInput(tensor, dim, torch.tensor(start), length) + +def error_inputs_narrow_narrow_copy(op_info, device, *, is_narrow, is_ref): + make_arg = partial(make_tensor, device=device, dtype=torch.float32) + + # 0-dim + yield ErrorInput(SampleInput(make_arg(()), 0, 0, 1), + error_type=RuntimeError, + error_regex=r"narrow\(\) cannot be applied to a 0-dim tensor\.") + + # out of bounds dim + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=RuntimeError, + error_regex=r"Expected dim < static_cast\(self_sizes.size\(\)\) to be true, but got false\.") + else: + yield ErrorInput(SampleInput(make_arg((M, S, L)), 3, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got 3\)") + # out of bounds dim (negative) + yield ErrorInput(SampleInput(make_arg((L, S, M)), -4, 0, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-3, 2\], but got -4\)") + + # out of bounds start + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), + error_type=RuntimeError, + error_regex=r"start \(11\) \+ length \(0\) exceeds dimension size \(10\)\.") + else: + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, M + 1, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got 11\)") + # out of bounds start (negative) + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, -M - 1, 0), + error_type=IndexError, + error_regex=r"Dimension out of range \(expected to be in range of \[-10, 9\], but got -11\)") + + # out of bounds length + yield ErrorInput(SampleInput(make_arg((S, L, M)), 2, 0, M + 1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(11\) exceeds 
dimension size \(10\)\.") + # out of bounds length (negative) + if not is_narrow and not is_ref and torch.device(device).type == 'cpu': + # narrow_copy_dense_cpu_out + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"start \(0\) \+ length \(-1\) exceeds dimension size \(10\)\.") + else: + yield ErrorInput(SampleInput(make_arg((M,)), 0, 0, -1), + error_type=RuntimeError, + error_regex=r"narrow\(\): length must be non-negative\.") + + # Test Tensor overload that was added for XLA. Start must be an 0-dim + # integral Tensor. narrow_copy doesn't have this overload. + # https://github.com/pytorch/pytorch/issues/31558 + if is_narrow: + # *1-dim* integral Tensor + yield ErrorInput(SampleInput(make_arg((L, M, S)), 1, make_arg(S, dtype=torch.int), 2), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") + + # 0-dim *bool* Tensor (bools are not allowed) + yield ErrorInput(SampleInput(make_arg((L, M, S)), -3, make_arg((), dtype=torch.bool), 3), + error_type=RuntimeError, + error_regex=r"start must be an 0-dim integral Tensor\.") def sample_trapezoid(op_info, device, dtype, requires_grad, **kwargs): @@ -5131,6 +5239,28 @@ def sample_inputs_view_as_real(op_info, device, dtype, requires_grad, **kwargs): sizes = ((S, S), ()) return (SampleInput(make_arg(size)) for size in sizes) +def error_inputs_complex(op_info, device, is_ref=False, **kwargs): + make_arg = partial(make_tensor, dtype=torch.float32, device=device) + + if is_ref: + error_float = "Expected both inputs to be Half, Float or Double tensors but got torch.float32 and torch.int32" + error_dtype = "Expected object of scalar type torch.float32 but got scalar type torch.float64 for second argument" + error_out = "Expected out tensor to have dtype torch.complex128 but got torch.complex64 instead" + else: + error_float = "Expected both inputs to be Half, Float or Double tensors but got Float and Int" + error_dtype = "Expected object of scalar type Float but got scalar type Double for second argument" + error_out = "Expected object of scalar type ComplexDouble but got scalar type ComplexFloat for argument 'out'" + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.int)), + error_type=RuntimeError, error_regex=error_float) + + yield ErrorInput(SampleInput(make_arg(M, S), make_arg(M, S, dtype=torch.float64)), + error_type=RuntimeError, error_regex=error_dtype) + + yield ErrorInput(SampleInput(make_arg(M, S, dtype=torch.float64), make_arg(M, S, dtype=torch.float64), + out=make_arg(M, S, dtype=torch.complex64)), + error_type=RuntimeError, error_regex=error_out) + def sample_inputs_prod(op_info, device, dtype, requires_grad, **kwargs): def make_arg(shape): # shrink values to be in the interval [-1, +1] for better precision in gradgradcheck @@ -8989,6 +9119,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_forward_ad=True, supports_fwgrad_bwgrad=True, supports_rhs_python_scalar=False, + error_inputs_func=error_inputs_complex, skips=( # Test doesn't account for complex's type promotion semantics DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), @@ -11469,6 +11600,31 @@ def reference_flatten(input, start_dim=0, end_dim=-1): dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), error_inputs_func=error_inputs_max_pool2d, sample_inputs_func=sample_inputs_max_pool), + OpInfo('max_pool2d_with_indices_backward', + op=max_pool2d_backward, + # We've defined a custom op, so there's no corresponding aten 
op + aten_name=None, + method_variant=None, + inplace_variant=None, + operator_variant=None, + inplace_operator_variant=None, + check_batched_gradgrad=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + check_batched_forward_grad=False, + assert_jit_shape_analysis=False, + dtypes=floating_types_and(torch.bfloat16), + dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16), + sample_inputs_func=sample_inputs_max_pool, + skips=( + # We've defined a custom op here, and we don't handle the case where we receive an out kwarg + DecorateInfo(unittest.skip("Skipped!"), "TestCommon", "test_out"), + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), + # FX failed to normalize op - add the op to the op_skip list. + DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'), + # object has no attribute max_pool2d_with_indices_backward (It's not available on torch -- so expected) + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit') + )), OpInfo('nn.functional.max_pool3d', aten_name='max_pool3d', # Runs very slowly on slow gradcheck - alternatively reduce input sizes @@ -12407,7 +12563,9 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - sample_inputs_func=sample_inputs_narrow, + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=True), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=True), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=False), skips=( # Use of .item() DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'), @@ -12423,15 +12581,16 @@ def reference_flatten(input, start_dim=0, end_dim=-1): supports_fwgrad_bwgrad=False, supports_autograd=False, # https://github.com/pytorch/pytorch/issues/86931 - sample_inputs_func=sample_inputs_narrow_copy, + sample_inputs_func=partial(sample_inputs_narrow_narrow_copy, is_narrow=False), + reference_inputs_func=partial(reference_inputs_narrow_narrow_copy, is_narrow=False), + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=False), skips=( # https://github.com/pytorch/pytorch/issues/84577 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), - # Not implemented - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_outplace', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_outplace', device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_symbolic_meta', device_type='cuda'), + # Lazy tensor failures: mutating and aliasing ops should all have codegen'd kernels + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'), + DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'), )), UnaryUfuncInfo('neg', aliases=('negative', ), @@ -17797,6 +17956,17 @@ def reference_flatten(input, start_dim=0, end_dim=-1): DecorateInfo(unittest.skip('Overflow when downcasting signed type is undefined'), 'TestCommon', 'test_compare_cpu'), ) ), + ElementwiseBinaryPythonRefInfo( + "_refs._conversions.complex", + torch_opinfo_name="complex", + error_inputs_func=partial(error_inputs_complex, is_ref=True), + # prims.empty_strided.default does not support nvfuser + supports_nvfuser=False, + skips=( + # Test 
doesn't account for complex's type promotion semantics + DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), + ) + ), ElementwiseUnaryPythonRefInfo( "_refs._conversions.double", torch_opinfo_name="double", @@ -18061,22 +18231,20 @@ def reference_flatten(input, start_dim=0, end_dim=-1): "_refs.narrow", torch_opinfo_name="narrow", supports_nvfuser=False, - skips=( - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'), - ) - ), - PythonRefInfo( - "_refs.nn.functional.group_norm", - torch_opinfo_name="nn.functional.group_norm", - supports_nvfuser=False, - validate_view_consistency=False, + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=True, is_ref=True), ), PythonRefInfo( "_refs.narrow_copy", torch_opinfo_name="narrow_copy", supports_out=True, supports_nvfuser=False, + error_inputs_func=partial(error_inputs_narrow_narrow_copy, is_narrow=False, is_ref=True), + ), + PythonRefInfo( + "_refs.nn.functional.group_norm", + torch_opinfo_name="nn.functional.group_norm", + supports_nvfuser=False, + validate_view_consistency=False, ), PythonRefInfo( "_refs.native_layer_norm", diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py new file mode 100644 index 000000000000..84750a2de3ee --- /dev/null +++ b/torch/testing/_internal/inductor_utils.py @@ -0,0 +1,23 @@ +from subprocess import CalledProcessError + +from torch._inductor.codecache import CppCodeCache +from torch._inductor.utils import has_triton +from torch.testing._internal.common_utils import ( + IS_FBCODE, + TEST_WITH_ROCM, +) +import torch + +HAS_CPU = False +try: + CppCodeCache.load("") + HAS_CPU = not IS_FBCODE +except ( + CalledProcessError, + OSError, + torch._inductor.exc.InvalidCxxCompiler, + torch._inductor.exc.CppCompileError, +): + pass + +HAS_CUDA = has_triton() and not TEST_WITH_ROCM diff --git a/torch/utils/data/__init__.py b/torch/utils/data/__init__.py index 6fe6147ddc54..bc054a947069 100644 --- a/torch/utils/data/__init__.py +++ b/torch/utils/data/__init__.py @@ -39,8 +39,6 @@ runtime_validation, runtime_validation_disabled, ) -from torch.utils.data.dataloader_experimental import DataLoader2 -from torch.utils.data import communication __all__ = ['BatchSampler', 'ChainDataset', @@ -48,7 +46,6 @@ 'DFIterDataPipe', 'DataChunk', 'DataLoader', - 'DataLoader2', 'Dataset', 'DistributedSampler', 'IterDataPipe', @@ -63,8 +60,6 @@ 'WeightedRandomSampler', '_DatasetKind', 'argument_validation', - 'collate', - 'communication', 'default_collate', 'default_convert', 'functional_datapipe', diff --git a/torch/utils/data/communication/__init__.py b/torch/utils/data/communication/__init__.py deleted file mode 100644 index 1b9cae401189..000000000000 --- a/torch/utils/data/communication/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from . import eventloop -from . import iter -from . import map -from . import messages -from . import protocol -from . import queue diff --git a/torch/utils/data/communication/eventloop.py b/torch/utils/data/communication/eventloop.py deleted file mode 100644 index 9bf241d334df..000000000000 --- a/torch/utils/data/communication/eventloop.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import threading -import pickle - -from torch.utils.data import IterDataPipe, communication, MapDataPipe - -try: - import dill - # XXX: By default, dill writes the Pickler dispatch table to inject its - # own logic there. 
This globally affects the behavior of the standard library - # pickler for any user who transitively depends on this module! - # Undo this extension to avoid altering the behavior of the pickler globally. - dill.extend(use_dill=False) - HAS_DILL = True -except ImportError: - HAS_DILL = False - -__all__ = [ - "DataPipeToQueuesLoop", - "SpawnProcessForDataPipeline", - "SpawnThreadForDataPipeline", -] - -def DataPipeToQueuesLoop(source_datapipe, req_queue, res_queue): - if isinstance(source_datapipe, IterDataPipe): - pipe_type = communication.iter - protocol_type = communication.protocol.IterDataPipeQueueProtocolServer - elif isinstance(source_datapipe, MapDataPipe): - pipe_type = communication.map # type: ignore[misc] - protocol_type = communication.protocol.MapDataPipeQueueProtocolServer # type: ignore[assignment] - else: - raise Exception('Only supports IterDataPipe or MapDataPipe, got', source_datapipe) - - torch.set_num_threads(1) - for _ in pipe_type.DataPipeBehindQueues(source_datapipe, protocol_type(req_queue, res_queue), - blocking_request_get=True): - pass - - -def SpawnProcessForDataPipeline(multiprocessing_ctx, datapipe): - req_queue = multiprocessing_ctx.Queue() - res_queue = multiprocessing_ctx.Queue() - process = multiprocessing_ctx.Process( - target=DataPipeToQueuesLoop, args=(datapipe, req_queue, res_queue)) - return process, req_queue, res_queue - - -def SpawnThreadForDataPipeline(datapipe): - r""" - Given a DataPipe, creates a copy of the DataPipe, starts a new Thread with DataPipeToQueuesLoop as target, - and return the process, req_queue, res_queue, thread_local_datapipe. - """ - req_queue = communication.queue.ThreadingQueue() - res_queue = communication.queue.ThreadingQueue() - - try: - new_datapipe = pickle.loads(pickle.dumps(datapipe)) - except Exception as pe: - if HAS_DILL: - try: - new_datapipe = dill.loads(dill.dumps(datapipe)) - except Exception as de: - raise Exception('Unable to dill DataPipe to make thread local copy', de) - - else: - raise Exception('Unable to pickle DataPipe to make thread local copy (consider installing `dill`)', pe) - - process = threading.Thread(target=DataPipeToQueuesLoop, args=( - new_datapipe, req_queue, res_queue), daemon=True) - return process, req_queue, res_queue, new_datapipe diff --git a/torch/utils/data/communication/iter.py b/torch/utils/data/communication/iter.py deleted file mode 100644 index 94f7cd2ec703..000000000000 --- a/torch/utils/data/communication/iter.py +++ /dev/null @@ -1,181 +0,0 @@ -import time -import types - -from torch.utils.data import IterDataPipe, communication - -DEFAULT_NON_BLOCKING_SLEEP = 0.001 - -__all__ = [ - "DataPipeBehindQueues", - "EnsureNonBlockingDataPipe", - "InvalidStateResetRequired", - "NonBlocking", - "NotAvailable", - "QueueWrapper", - "default_not_available_hook", -] - - -def default_not_available_hook(): - time.sleep(DEFAULT_NON_BLOCKING_SLEEP) - - -class NotAvailable(Exception): - pass - - -class InvalidStateResetRequired(Exception): - """ - Returned by DataPipe when it is expecting to get reset request, - for example RouterDataPipe expecting all workers to request reset' - """ - pass - - -class NonBlocking(IterDataPipe): - not_available_hook = default_not_available_hook - - def __iter__(self): - self.reset_iterator() - return self - - def __next__(self): - while True: - try: - return self.nonblocking_next() - except StopIteration: - raise StopIteration - except NotAvailable: - if NonBlocking.not_available_hook is not None: - NonBlocking.not_available_hook() - - def 
nonblocking_next(self): - raise NotImplementedError( - "nonblocking_next is not implemented for %s" % self.__class__) - - def reset_iterator(self): - raise NotImplementedError( - "reset_iterator is not implemented for %s" % self.__class__) - - @staticmethod - def register_not_available_hook(hook_function): - NonBlocking.not_available_hook = hook_function - - -def EnsureNonBlockingDataPipe(validated_datapipe): - if not isinstance(validated_datapipe, IterDataPipe): - raise Exception('Not Iterable DataPipe ' + - str(validated_datapipe.__class__)) - if isinstance(validated_datapipe, NonBlocking): - return validated_datapipe - if not hasattr(validated_datapipe, '_as_iterator'): - validated_datapipe._as_iterator = None # type: ignore[attr-defined] - if not hasattr(validated_datapipe, 'nonblocking_next'): - def nonblocking_next(self): - if self._as_iterator is None: - self._as_iterator = iter(self) - return next(self._as_iterator) - validated_datapipe.nonblocking_next = types.MethodType( # type: ignore[attr-defined] - nonblocking_next, validated_datapipe) - if not hasattr(validated_datapipe, 'reset_iterator'): - def reset_iterator(self): - self._as_iterator = None - validated_datapipe.reset_iterator = types.MethodType( # type: ignore[attr-defined] - reset_iterator, validated_datapipe) - return validated_datapipe - - -def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False): - """ - Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue - If raise_stop is true, raises exception when StopIteration received from the source_datapipe - """ - if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolServer): - raise Exception('Expecting IterDataPipeQueueProtocolServer, got', protocol) - source_datapipe = EnsureNonBlockingDataPipe(source_datapipe) - forever = True - while forever: - try: - # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround - request = protocol.get_new_request(block=blocking_request_get) - except communication.protocol.EmptyQueue: - yield True - continue - - if isinstance(request, communication.messages.ResetIteratorRequest): - source_datapipe.reset_iterator() - protocol.response_reset_iterator() - - elif isinstance(request, communication.messages.TerminateRequest): - forever = False - protocol.response_terminate() - - elif isinstance(request, communication.messages.GetNextRequest): - while forever: - try: - value = source_datapipe.nonblocking_next() - except NotAvailable: - yield True - continue - except StopIteration: - protocol.response_stop_iteration() - if full_stop: - forever = False - else: - yield True - break - except InvalidStateResetRequired: - protocol.response_invalid_state() - if full_stop: - forever = False - else: - yield True - break - protocol.response_next(value) - yield True # Returns control - break - else: - raise Exception('Unrecognized type of request received', request) - - -class QueueWrapper(NonBlocking): - """ - Creates iter.DataPipe which reads data from the DataLoader.Queue - """ - - def __init__(self, protocol, response_wait_time=0.00001): - if not isinstance(protocol, communication.protocol.IterDataPipeQueueProtocolClient): - raise Exception('Got', protocol) - self.protocol = protocol - self.counter = 0 - self._stop_iteration = False - self._response_wait_time = response_wait_time - - def reset_iterator(self): - self._stop_iteration = False - self.counter = 0 - self.protocol.request_reset_iterator() - while True: - try: - 
self.protocol.get_response_reset_iterator() - break - except communication.protocol.EmptyQueue: - if NonBlocking.not_available_hook is not None: - NonBlocking.not_available_hook() - - def nonblocking_next(self): - if self._stop_iteration: - raise Exception( - '`next` or `nonblocking_next` called after receiving StopIteration') - if self.protocol.can_take_request(): - self.protocol.request_next() - try: - response = self.protocol.get_response_next(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - if isinstance(response, communication.messages.StopIterationResponse): - self._stop_iteration = True - raise StopIteration - if isinstance(response, communication.messages.InvalidStateResponse): - raise NotAvailable - return response.value diff --git a/torch/utils/data/communication/map.py b/torch/utils/data/communication/map.py deleted file mode 100644 index 8af63bf0c73e..000000000000 --- a/torch/utils/data/communication/map.py +++ /dev/null @@ -1,159 +0,0 @@ -import time -import types - -from torch.utils.data import communication, MapDataPipe - -DEFAULT_NON_BLOCKING_SLEEP = 0.001 - -__all__ = [ - "DataPipeBehindQueues", - "EnsureNonBlockingMapDataPipe", - "NonBlockingMap", - "NotAvailable", - "QueueWrapperForMap", - "default_not_available_hook", -] - - -def default_not_available_hook(): - time.sleep(DEFAULT_NON_BLOCKING_SLEEP) - - -class NotAvailable(Exception): - pass - - -class NonBlockingMap(MapDataPipe): - not_available_hook = default_not_available_hook - - def __getitem__(self, index): - while True: - try: - return self.nonblocking_getitem(index) - except NotAvailable: - if NonBlockingMap.not_available_hook is not None: - NonBlockingMap.not_available_hook() - - def __len__(self): - try: - return self.nonblocking_len() - except NotAvailable: - if NonBlockingMap.not_available_hook is not None: - NonBlockingMap.not_available_hook() - - def nonblocking_len(self): - raise NotImplementedError( - "nonblocking_len is not implemented for %s" % self.__class__) - - def nonblocking_getitem(self, index): - raise NotImplementedError( - "nonblocking_getitem is not implemented for %s" % self.__class__) - - @staticmethod - def register_not_available_hook(hook_function): - NonBlockingMap.not_available_hook = hook_function - - -def EnsureNonBlockingMapDataPipe(validated_datapipe): - if not isinstance(validated_datapipe, MapDataPipe): - raise Exception(f'Not Map DataPipe - got {validated_datapipe.__class__}') - if isinstance(validated_datapipe, NonBlockingMap): - return validated_datapipe - if not hasattr(validated_datapipe, 'nonblocking_len'): - def nonblocking_len(self): - return self.__len__() - validated_datapipe.nonblocking_len = types.MethodType( # type: ignore[attr-defined] - nonblocking_len, validated_datapipe) - if not hasattr(validated_datapipe, 'nonblocking_getitem'): - def nonblocking_getitem(self, index): - return self.__getitem__(index) - validated_datapipe.nonblocking_getitem = types.MethodType( # type: ignore[attr-defined] - nonblocking_getitem, validated_datapipe) - return validated_datapipe - - -def DataPipeBehindQueues(source_datapipe, protocol, full_stop=False, blocking_request_get=False): - """ - Indefinitely iterates over req_queue and passing values from source_datapipe to res_queue - If raise_stop is true, raises exception when StopIteration received from the source_datapipe - """ - if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolServer): - raise Exception('Expecting MapDataPipeQueueProtocolServer, 
got', protocol) - source_datapipe = EnsureNonBlockingMapDataPipe(source_datapipe) - forever = True - while forever: - try: - # Non-blocking call is Extremely slow here for python.mp, need to figure out a good workaround - request = protocol.get_new_request(block=blocking_request_get) - except communication.protocol.EmptyQueue: - yield True - continue - - if isinstance(request, communication.messages.TerminateRequest): - forever = False - protocol.response_terminate() - - elif isinstance(request, communication.messages.LenRequest): - size = source_datapipe.nonblocking_len() - protocol.response_len(size) - - elif isinstance(request, communication.messages.GetItemRequest): - while forever: - try: - value = source_datapipe.nonblocking_getitem(request.key) - except NotAvailable: - yield True - continue - except IndexError as e: - # Alternatively, we can just allow the underlying DataPipe to throw an exception? - protocol.response_index_out_of_bound() - if full_stop: - forever = False - else: - yield True - break - protocol.response_item(request.key, value) - yield True # Returns control - break - else: - raise Exception('Unrecognized type of request received', request) - - -class QueueWrapperForMap(NonBlockingMap): - """ - Creates map.DataPipe which reads data from the DataLoader.Queue - """ - def __init__(self, protocol, response_wait_time=0.00001): - if not isinstance(protocol, communication.protocol.MapDataPipeQueueProtocolClient): - raise Exception('Got', protocol) - self.protocol = protocol - self.counter = 0 - self._stop_iteration = False - self._response_wait_time = response_wait_time - - def nonblocking_getitem(self, index): - if self._stop_iteration: - raise Exception( - '`getitem` or `nonblocking_getitem` called after receiving StopIteration') - if self.protocol.can_take_request(): - self.protocol.request_item(index) - try: - response = self.protocol.get_response_item(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - if isinstance(response, communication.messages.StopIterationResponse): - self._stop_iteration = True - raise IndexError(f"Index {index} is out of bound.") - return response.key, response.value - - def nonblocking_len(self): - if self._stop_iteration: - raise Exception( - '`len` or `nonblocking_len` called after receiving StopIteration') - if self.protocol.can_take_request(): - self.protocol.request_len() - try: - response = self.protocol.get_response_len(block=True, timeout=self._response_wait_time) - except communication.protocol.EmptyQueue: - raise NotAvailable - return response.len diff --git a/torch/utils/data/communication/messages.py b/torch/utils/data/communication/messages.py deleted file mode 100644 index 449cf23cfc01..000000000000 --- a/torch/utils/data/communication/messages.py +++ /dev/null @@ -1,75 +0,0 @@ -class DataLoaderQueueMessage(object): - pass - - -class Request(DataLoaderQueueMessage): - pass - - -class Response(DataLoaderQueueMessage): - pass - - -class ResetIteratorRequest(Request): - pass - - -class ResetIteratorResponse(Response): - pass - - -class TerminateRequest(Request): - pass - - -class TerminateResponse(Response): - pass - - -class LenRequest(Request): - pass - - -class LenResponse(Response): - __slots__ = ('len') - - def __init__(self, len): - self.len = len - - -class GetItemRequest(Request): - __slots__ = ('key') - - def __init__(self, key): - self.key = key - - -class GetItemResponse(Response): - __slots__ = ('key', 'value') - - def __init__(self, key, value): - self.key = 
key - self.value = value - - -class GetNextRequest(Request): - pass - - -class GetNextResponse(Response): - __slots__ = ('value') - - def __init__(self, value): - self.value = value - - -class StopIterationResponse(Response): - pass - - -class InvalidStateResponse(Response): - """ - Returned by DataPipe when it is expecting to get reset request, - for example RouterDataPipe expecting all workers to request reset' - """ - pass diff --git a/torch/utils/data/communication/protocol.py b/torch/utils/data/communication/protocol.py deleted file mode 100644 index 5bf5fe1af062..000000000000 --- a/torch/utils/data/communication/protocol.py +++ /dev/null @@ -1,205 +0,0 @@ -from torch.utils.data import communication - - -class Protocol(object): - __slots__ = ('request_queue', 'response_queue') - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - - -class ProtocolClient(Protocol): - """ - ProtocolClient takes charge of putting requests into req_queue and returning results from res_queue. - """ - _req_sent = None - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - self._req_sent = None - - def can_take_request(self): - return self._req_sent is None - - def waiting_for_response(self): - return self._req_sent is not None - - def request_sent(self, request=True): - if not self.can_take_request(): - raise Exception('Protocol only supports one request in the Queue') - self._req_sent = request - - def request_served(self, result=None): - if not self.waiting_for_response(): - raise Exception( - 'Expected no peding requests, but something got served', result) - self._req_sent = None - - -class ProtocolServer(Protocol): - """ - ProtocolServer takes charge of getting requests from req_queue and fetching data from source datapipe. 
- """ - _req_received = None - - def __init__(self, request_queue, response_queue): - self.request_queue = request_queue - self.response_queue = response_queue - self._req_received = None - - def have_pending_request(self): - return self._req_received is not None - - def get_new_request(self, block=False): - if self.have_pending_request(): - raise Exception( - 'Trying to get next request, while having one unserved') - try: - response = self.request_queue.get(block=block) - except Exception as e: # TODO: Catch only timeout exceptions - raise EmptyQueue('queue is empty') - self._req_received = response - return response - # TODO: Validate supported requests - - def response_terminate(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.TerminateRequest): - raise Exception( - "Replaying with terminate status to other type of message") - self.response_queue.put(communication.messages.TerminateResponse()) - self._req_received = None - - -class MapDataPipeQueueProtocolServer(ProtocolServer): - def response_item(self, key, value): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.GetItemResponse(key, value)) - self._req_received = None - - def response_len(self, size): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.LenResponse(size)) - self._req_received = None - - def response_index_out_of_bound(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.StopIterationResponse()) - self._req_received = None - -class MapDataPipeQueueProtocolClient(ProtocolClient): - def request_len(self): - if not self.can_take_request(): - raise Exception('Can not request len while we are still waiting response for previous request') - request = communication.messages.LenRequest() - self.request_queue.put(request) - self.request_sent(request) - - def request_item(self, index): - if not self.can_take_request(): - raise Exception('Can not request item while we are still waiting response for previous request') - request = communication.messages.GetItemRequest(index) - self.request_queue.put(request) - self.request_sent(request) - - def get_response_len(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception('Can not expect any response without submitted request') - try: - response = self.response_queue.get(block=block, timeout=timeout) - except TimeoutError: - raise EmptyQueue('queue is empty') - self.request_served(response) - if not isinstance(response, communication.messages.LenResponse): - raise Exception('Invalid response received') - return response - - def get_response_item(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception('Can not expect any response without submitted request') - try: - response = self.response_queue.get(block=block, timeout=timeout) - except TimeoutError: - raise EmptyQueue('queue is empty') - self.request_served(response) - # if not isinstance(response, communication.messages.GetItemResponse): - # raise Exception('Invalid response received') - return response - - -class EmptyQueue(Exception): - pass - - -class IterDataPipeQueueProtocolServer(ProtocolServer): - def response_reset_iterator(self): - if not 
self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - if not isinstance(self._req_received, communication.messages.ResetIteratorRequest): - raise Exception( - "Replaying with reset status to other type of message") - self.response_queue.put(communication.messages.ResetIteratorResponse()) - self._req_received = None - - def response_next(self, value): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.GetNextResponse(value)) - self._req_received = None - - def response_stop_iteration(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.StopIterationResponse()) - self._req_received = None - - def response_invalid_state(self): - if not self.have_pending_request(): - raise Exception("Attempting to reply with pending request") - self.response_queue.put(communication.messages.InvalidStateResponse()) - self._req_received = None - - -class IterDataPipeQueueProtocolClient(ProtocolClient): - def request_reset_iterator(self): - if not self.can_take_request(): - raise Exception('Can not reset while we are still waiting response for previous request') - request = communication.messages.ResetIteratorRequest() - self.request_queue.put(request) - self.request_sent(request) - - def request_next(self): - if not self.can_take_request(): - raise Exception('Can not request next item while we are still waiting response for previous request') - request = communication.messages.GetNextRequest() - self.request_queue.put(request) - self.request_sent(request) - - def get_response_reset_iterator(self, block=False): - try: - response = self.response_queue.get(block=block) - except Exception as e: # TODO: Catch only timeout exceptions - raise EmptyQueue('queue is empty') - self.request_served(response) - - if not isinstance(response, communication.messages.ResetIteratorResponse): - raise Exception('Invalid response received') - - def get_response_next(self, block=False, timeout=None): - if not self.waiting_for_response(): - raise Exception( - 'Can not expect any response without submitted request') - try: - response = self.response_queue.get(block=block, timeout=timeout) - except Exception as e: # TODO: Catch only timeout exceptions - raise EmptyQueue('queue is empty') - self.request_served(response) - - # TODO(VitalyFedyunin): Add possible response types validation here - return response diff --git a/torch/utils/data/communication/queue.py b/torch/utils/data/communication/queue.py deleted file mode 100644 index 85c33d4799cd..000000000000 --- a/torch/utils/data/communication/queue.py +++ /dev/null @@ -1,51 +0,0 @@ -import threading -import time - - -class LocalQueue(): - ops = 0 - stored = 0 - uid = 0 - empty = 0 - - def __init__(self, name='unnamed'): - self.items = [] - self.name = name - self.uid = LocalQueue.uid - LocalQueue.uid += 1 - - def put(self, item, block=True): - LocalQueue.ops += 1 - LocalQueue.stored += 1 - self.items.append(item) - - def get(self, block=True, timeout=0): - # TODO(VitalyFedyunin): Add support of block and timeout arguments - LocalQueue.ops += 1 - if not len(self.items): - LocalQueue.empty += 1 - raise Exception('LocalQueue is empty') - LocalQueue.stored -= 1 - return self.items.pop() - - -class ThreadingQueue(): - def __init__(self, name='unnamed'): - self.lock = threading.Lock() - self.items = [] - self.name = name - - def put(self, item, block=True): 
- with self.lock: - self.items.append(item) - - def get(self, block=True, timeout=0): - # TODO(VitalyFedyunin): Add support of block and timeout arguments - while True: - with self.lock: - if len(self.items) > 0: - return self.items.pop() - if not block: - raise Exception("Not available") - # TODO(VitalyFedyunin): Figure out what to do if nothing in the queue - time.sleep(0.000001) diff --git a/torch/utils/data/dataloader_experimental.py b/torch/utils/data/dataloader_experimental.py deleted file mode 100644 index 8a8d536b7985..000000000000 --- a/torch/utils/data/dataloader_experimental.py +++ /dev/null @@ -1,150 +0,0 @@ -import time - -from typing import Any, List - -import torch.utils.data.backward_compatibility - -import torch.utils.data.graph_settings -from torch.utils.data import DataLoader, IterDataPipe, communication -from torch.utils.data.datapipes.iter import IterableWrapper - -__all__ = [ - "DataLoader2", -] - - -class _ThreadingDataLoader2: - - def __init__(self, datapipe, num_workers=0, collate_fn=None): - self.threads = [] - self.datapipes = [] - self.collate_fn = collate_fn - for worker_id in range(num_workers): - (thread, req_queue, res_queue, thread_localdatapipe) = communication.eventloop.SpawnThreadForDataPipeline(datapipe) - torch.utils.data.graph_settings.apply_sharding(thread_localdatapipe, num_workers, worker_id) - thread.start() - self.threads.append((thread, req_queue, res_queue)) # These queues are independent - local_datapipe = communication.iter.QueueWrapper( - communication.protocol.IterDataPipeQueueProtocolClient(req_queue, res_queue)) - self.datapipes.append(local_datapipe) - - def __iter__(self): - not_available = False - forever = True - exclude_datapipes: List[Any] = [] - while len(exclude_datapipes) < len(self.datapipes): - for dp in self.datapipes: - if dp not in exclude_datapipes: - try: - value = dp.nonblocking_next() - yield value - except StopIteration: - exclude_datapipes.append(dp) - except communication.iter.NotAvailable: - not_available = True - if not_available: - time.sleep(0.001) - - def __del__(self): - self._cleanup_all_threads() - - def _cleanup_all_threads(self): - def clean_me(thread, req_queue, res_queue): - req_queue.put(communication.messages.TerminateRequest()) - _ = res_queue.get() - thread.join() - - for thread, req_queue, res_queue in self.threads: - clean_me(thread, req_queue, res_queue) - -class DataLoader2: - def __new__(cls, - dataset, - batch_size=1, - shuffle=None, - sampler=None, - batch_sampler=None, - num_workers=0, - collate_fn=None, - pin_memory=False, - drop_last=False, - timeout=0, - worker_init_fn=None, - *, - prefetch_factor=2, - persistent_workers=False, - batch_outside_worker=False, - parallelism_mode='mp'): - if isinstance(dataset, IterDataPipe): - data_loader: Any = None - if batch_sampler is not None: - raise Exception( - 'batch_sampler is not yet supported by DataPipes') - if sampler is not None: - raise Exception( - 'sampler is not yet supported by DataPipes') - datapipe = dataset - datapipe = torch.utils.data.graph_settings.apply_shuffle_settings(datapipe, shuffle=shuffle) # type: ignore[assignment] - if batch_outside_worker and pin_memory: - raise Exception( - 'pin_memory is not yet compatible with batch_outside_worker') - if not batch_outside_worker: - if batch_size is not None: - datapipe = datapipe.batch(batch_size, drop_last=drop_last) - if collate_fn is None: - collate_fn = torch.utils.data._utils.collate.default_collate - - # Note: It is safe to pass shuffle=True to the old DataLoader, as shuffle does 
nothing - # for Iterable, but required to set Pipes correctly. - data_loader = DataLoader(datapipe, - batch_size=None, # Replaced by .batch DataPipe - shuffle=shuffle, - sampler=None, - batch_sampler=None, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=False, # Replaced by .batch DataPipe - timeout=timeout, - worker_init_fn=worker_init_fn, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers) - elif parallelism_mode == 'thread': - if collate_fn is not None and not batch_outside_worker: - datapipe = datapipe.map(collate_fn) - if pin_memory: - raise Exception( - 'pin_memory is not yet supported by DataPipes with Threading') - if worker_init_fn is not None: - raise Exception( - 'worker_init_fn is not yet supported by DataPipes with Threading') - data_loader = _ThreadingDataLoader2(datapipe, - num_workers=num_workers, - collate_fn=collate_fn) - else: - raise Exception('Unsupported parallelism mode', parallelism_mode) - if not batch_outside_worker: - return data_loader - else: - if collate_fn is None: - collate_fn = torch.utils.data._utils.collate.default_collate - datapipe = IterableWrapper(data_loader).batch( - batch_size, drop_last=drop_last).map(collate_fn) - return datapipe - else: - if parallelism_mode == 'thread': - raise Exception( - 'thread parallelism mode is not supported for old DataSets') - return DataLoader(dataset, - batch_size=batch_size, - shuffle=shuffle, - sampler=sampler, - batch_sampler=batch_sampler, - num_workers=num_workers, - collate_fn=collate_fn, - pin_memory=pin_memory, - drop_last=drop_last, - timeout=timeout, - worker_init_fn=worker_init_fn, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers)
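
The removed communication.iter module pairs a cooperative server loop (DataPipeBehindQueues, a generator that yields True after every step, "Returns control", presumably so an outer event loop can interleave several of them) with a polling client (QueueWrapper.nonblocking_next, which turns an empty response queue into NotAvailable). The sketch below reproduces that control flow in isolation; serve and MiniQueueWrapper are hypothetical names, and plain strings stand in for the request/response message classes.

import queue

class NotAvailable(Exception):
    pass

def serve(source, req_q, res_q):
    """Cooperative server: handle at most one request per step, then yield
    control, in the spirit of DataPipeBehindQueues."""
    it = iter(source)
    while True:
        try:
            request = req_q.get_nowait()
        except queue.Empty:
            yield True
            continue
        if request == 'terminate':
            res_q.put('terminated')
            return
        try:                      # request == 'next'
            res_q.put(next(it))
        except StopIteration:
            res_q.put('stop')
        yield True

class MiniQueueWrapper:
    """Client side in the spirit of QueueWrapper: at most one outstanding
    request, and an empty response queue surfaces as NotAvailable."""
    def __init__(self, req_q, res_q, response_wait_time=0.00001):
        self.req_q, self.res_q = req_q, res_q
        self.response_wait_time = response_wait_time
        self._pending = False

    def nonblocking_next(self):
        if not self._pending:
            self.req_q.put('next')
            self._pending = True
        try:
            value = self.res_q.get(timeout=self.response_wait_time)
        except queue.Empty:
            raise NotAvailable
        self._pending = False
        if value == 'stop':
            raise StopIteration
        return value

# Single-threaded usage: interleave server steps with client polls.
req_q, res_q = queue.Queue(), queue.Queue()
loop = serve(['a', 'b'], req_q, res_q)
client = MiniQueueWrapper(req_q, res_q)
out = []
while True:
    next(loop)                    # one cooperative step of the server
    try:
        out.append(client.nonblocking_next())
    except NotAvailable:
        continue                  # nothing ready yet; poll again next tick
    except StopIteration:
        break
assert out == ['a', 'b']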
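
NonBlockingMap above converts a possibly-not-ready nonblocking_getitem into a blocking __getitem__ by looping on NotAvailable and calling the registered not_available_hook (by default a 1 ms sleep) between attempts; note that the removed __len__, unlike __getitem__, runs the hook at most once and then falls through, returning None. A minimal sketch of the same retry pattern against a hypothetical LazyDict backend (none of these names exist in PyTorch):

import time

DEFAULT_NON_BLOCKING_SLEEP = 0.001   # same value as the removed constant

class NotAvailable(Exception):
    pass

class LazyDict:
    """Hypothetical backend: items only become readable after a deadline."""
    def __init__(self, data, ready_at):
        self._data = data
        self._ready_at = ready_at

    def nonblocking_getitem(self, key):
        if time.monotonic() < self._ready_at:
            raise NotAvailable
        return self._data[key]

    def __getitem__(self, key):
        # Same busy-wait loop as NonBlockingMap.__getitem__ in the removed module.
        while True:
            try:
                return self.nonblocking_getitem(key)
            except NotAvailable:
                time.sleep(DEFAULT_NON_BLOCKING_SLEEP)

# Usage: blocks for roughly 50 ms, then returns the value.
d = LazyDict({'x': 42}, ready_at=time.monotonic() + 0.05)
assert d['x'] == 42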
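
The removed communication.protocol and communication.messages classes implement a strict one-outstanding-request protocol over a pair of queues: the client enqueues a request and records it in _req_sent, the server dequeues it, replies exactly once, and the client clears its pending state when the response arrives. That single-request invariant is what lets responses go untagged. Below is a minimal, self-contained sketch of the pattern; MiniProtocolClient and mini_server are invented names and are not part of this diff or of PyTorch.

import queue
import threading

class MiniProtocolClient:
    """At most one request may be in flight, mirroring ProtocolClient._req_sent."""
    def __init__(self, request_queue, response_queue):
        self.request_queue = request_queue
        self.response_queue = response_queue
        self._req_sent = None

    def request(self, payload):
        if self._req_sent is not None:
            raise RuntimeError('Protocol only supports one request in the queue')
        self.request_queue.put(payload)
        self._req_sent = payload

    def get_response(self, timeout=None):
        if self._req_sent is None:
            raise RuntimeError('No response expected without a submitted request')
        response = self.response_queue.get(timeout=timeout)  # queue.Empty on timeout
        self._req_sent = None
        return response

def mini_server(request_queue, response_queue, data):
    """Answer index requests until a 'terminate' request arrives."""
    while True:
        request = request_queue.get()
        if request == 'terminate':
            response_queue.put('terminated')
            return
        response_queue.put(data[request])    # request is an index into `data`

# Usage: run the server in a worker thread and issue two requests.
req_q, res_q = queue.Queue(), queue.Queue()
client = MiniProtocolClient(req_q, res_q)
worker = threading.Thread(target=mini_server, args=(req_q, res_q, ['a', 'b', 'c']))
worker.start()
client.request(1)
assert client.get_response() == 'b'
client.request('terminate')
assert client.get_response() == 'terminated'
worker.join()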
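
LocalQueue and ThreadingQueue above are stop-gaps: both pop from the end of a plain list (so they behave as LIFO stacks rather than FIFO queues), and both leave the block/timeout arguments as TODOs, with ThreadingQueue.get polling under a lock every microsecond. The standard library's queue.Queue already provides blocking gets, timeouts and FIFO order; if a hand-rolled variant were still wanted, a condition variable removes the polling loop. A sketch under that assumption (CondQueue is hypothetical and not part of the diff):

import threading

class CondQueue:
    """Minimal FIFO blocking queue built on a condition variable instead of polling."""
    def __init__(self):
        self._items = []
        self._cond = threading.Condition()

    def put(self, item, block=True):
        with self._cond:
            self._items.append(item)
            self._cond.notify()

    def get(self, block=True, timeout=None):
        with self._cond:
            if not block:
                if not self._items:
                    raise Exception('Not available')   # same contract as ThreadingQueue
                return self._items.pop(0)
            # wait_for releases the lock while waiting and re-checks the predicate.
            if not self._cond.wait_for(lambda: self._items, timeout=timeout):
                raise Exception('Not available')       # timed out
            return self._items.pop(0)

In practice queue.Queue is the simpler choice; the sketch only illustrates what implementing the block/timeout TODOs would involve.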
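
DataLoader2.__new__ above acts as a factory: depending on the dataset type, parallelism_mode and batch_outside_worker it hands back a stock DataLoader, a _ThreadingDataLoader2, or a re-wrapped DataPipe rather than a DataLoader2 instance. The stripped-down sketch below shows only that dispatch mechanism; Loader and _ThreadLoader are invented names, and the real removed class forwards many more arguments.

class _ThreadLoader:
    """Stand-in for _ThreadingDataLoader2."""
    def __init__(self, data):
        self.data = data

class Loader:
    """Factory via __new__, in the spirit of DataLoader2: the call may return
    a different type entirely, in which case Loader.__init__ is never run."""
    def __new__(cls, data, parallelism_mode='mp'):
        if parallelism_mode == 'thread':
            return _ThreadLoader(data)           # not a Loader; __init__ is skipped
        if parallelism_mode != 'mp':
            raise Exception('Unsupported parallelism mode', parallelism_mode)
        self = super().__new__(cls)              # normal path: build a Loader
        self.data = data
        return self

assert isinstance(Loader([1, 2], parallelism_mode='thread'), _ThreadLoader)
assert isinstance(Loader([1, 2]), Loader)

A consequence of this pattern, visible in the removed class as well, is that the object a caller gets back is not guaranteed to satisfy isinstance(obj, DataLoader2).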