[MPS] Register index.Tensor_out #82507
Conversation
✅ No Failures (5 Pending) as of commit 132ba61 (more details on the Dr. CI page). 💚 Looks good so far! There are no failures yet. 💚 This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.
Lots of nits, but two major things:
- `_mtl_indexing_library` must be set to `nil` in the MTLDevice constructor
- Please extract the dispatch logic into a standalone PR
I have a question about the design of this PR. Static typing incurs significantly more compile-time overhead with minimal improvement to runtime execution speed. It can increase the number of shader objects by a factor of 10 and flood the GPU instruction cache. Have you tried a dynamically typed approach that requires only one Metal shader, and benchmarked its performance?
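For concreteness, a minimal sketch of the dynamically typed approach being suggested: a single kernel that treats elements as raw bytes and takes the element size from the host, so one compiled shader covers every dtype. The kernel name and buffer layout here are hypothetical, not from this PR:

```metal
// Hypothetical single-shader alternative: copy element-by-element as raw
// bytes; `elem_size` comes from the host, so one pipeline serves all dtypes.
kernel void copy_elements_untyped(constant const uchar * input     [[buffer(0)]],
                                  device uchar *         output    [[buffer(1)]],
                                  constant const uint &  elem_size [[buffer(2)]],
                                  uint tid [[thread_position_in_grid]]) {
  for (uint b = 0; b < elem_size; ++b) {
    output[tid * elem_size + b] = input[tid * elem_size + b];
  }
}
```

The trade-off is scalar byte traffic instead of vectorized typed loads, which is presumably what a benchmark would need to settle.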
Implement bitwise operators as Metal kernels. Dynamically compile the Metal library for a triplet of input and output tensor types. Use `dispatchThreads:threadsPerThreadgroup:` to dispatch work (relies on the fact that the MPS device is at least `MTLGPUFamilyMac2`, which will be explicitly checked in #82507).

Perf improvements: add support for non-contiguous tensors and broadcasting.

Test Plan: already tested in `test_mps.py`, for example by `TestConsistencyCPU.test_output_match_bitwise_xor_cpu_uint8`

Pull Request resolved: #82307
Approved by: https://github.com/albanD

Summary (internal import of the same change):
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/0377615b6855c7e306669a16e74bbfc01ab86c1c
Reviewed By: kit1980
Differential Revision: D38505765
Pulled By: malfet
fbshipit-source-id: f1265a4d43f0a0b52af622838c8873b32dffcbfb
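For reference, a hedged sketch of the non-uniform dispatch pattern the commit message describes; variable names are illustrative, and only `dispatchThreads:threadsPerThreadgroup:` and the `MTLGPUFamilyMac2` requirement come from the text above:

```objc
// Sketch only: encode a compute pass and dispatch one thread per element.
// dispatchThreads: (non-uniform threadgroup sizes) requires MTLGPUFamilyMac2,
// hence the explicit family check promised for #82507.
TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support");
id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
[encoder setComputePipelineState:pso];
[encoder setBuffer:inputBuffer  offset:0 atIndex:0];
[encoder setBuffer:outputBuffer offset:0 atIndex:1];
NSUInteger maxThreads = pso.maxTotalThreadsPerThreadgroup;
[encoder dispatchThreads:MTLSizeMake(numElements, 1, 1)
   threadsPerThreadgroup:MTLSizeMake(MIN(maxThreads, numElements), 1, 1)];
[encoder endEncoding];
```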
@malfet, please take a look at the PR.
@kulinseth, at the very least this PR needs a rebase (as the dispatch code has already landed in #82612) and a fix for the linter; leaving a few more comments right now.
aten/src/ATen/mps/MPSDevice.mm (outdated)

```objc
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
MTLLanguageVersion languageVersion;

#if defined(__MAC_13_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_13_0
languageVersion = MTLLanguageVersion3_0;
#elif defined(__MAC_12_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_12_0
languageVersion = MTLLanguageVersion2_4;
#elif
#error "Metal is not available on the current platform."
#endif
```
Code as it's currently written will fail to compile on macOS 14.0 when it is eventually released.
Hardcoding the version would also prevent us from running into runtime errors if newer MTL Language Standards make some incompatible changes.
(And the applied compiler optimizations are independent of MTLLanguageVersion, aren't they?)
Suggested change:

```diff
 // MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
-MTLLanguageVersion languageVersion;
-#if defined(__MAC_13_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_13_0
-languageVersion = MTLLanguageVersion3_0;
-#elif defined(__MAC_12_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_12_0
-languageVersion = MTLLanguageVersion2_4;
-#elif
-#error "Metal is not available on the current platform."
-#endif
+// host_name attribute needs at least Metal 2.2
+MTLLanguageVersion languageVersion = MTLLanguageVersion2_2;
```
> Code as it's currently written will fail to compile on macOS 14.0 when it is eventually released.

This shouldn't fail to compile on newer macOS. macOS Ventura is 13.0, and according to the macro logic it would fall into this branch:

```objc
#if defined(__MAC_13_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_13_0
languageVersion = MTLLanguageVersion3_0;
```

For higher version numbers it would fall into the same branch (`__MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_13_0` checks for the number to be >= 13, not strictly equal to 13).
> Hardcoding the version would also prevent us from running into runtime errors if newer MTL Language Standards make some incompatible changes.

These are backwards compatible, and generally we don't even need to update the language version unless we have to use a Metal feature from it. And if we do need some new language feature, we can bump the logic here.
Hardcoding to 2_2 as per suggestion.
```objc
~MPSDevice();

 private:
  static MPSDevice* _device;
  MTLDevice_t _mtl_device;
  MTLLibrary_t _mtl_indexing_library;
```
Why does it need to be part of MPSDevice rather than an implementation detail in Indexing.mm?
This is cached on the MPSDevice itself the very first time index.Tensor_out is called; all subsequent calls use the cached version of the library directly.
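Roughly, the caching pattern described here looks like the following — a simplified sketch, not the exact PR code; the method name and `kIndexingMetalSource` are placeholders:

```objc
// Compile the indexing library on first use, then reuse the cached handle.
MTLLibrary_t MPSDevice::getMetalIndexingLibrary() {
  if (!_mtl_indexing_library) {
    NSError* error = nil;
    MTLCompileOptions* options = [MTLCompileOptions new];
    [options setLanguageVersion:MTLLanguageVersion2_2];
    _mtl_indexing_library =
        [_mtl_device newLibraryWithSource:[NSString stringWithUTF8String:kIndexingMetalSource]
                                  options:options
                                    error:&error];
    TORCH_CHECK(_mtl_indexing_library, "Failed to compile indexing library, error: ",
                [[error description] UTF8String]);
  }
  return _mtl_indexing_library;
}
```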
@malfet, are you worried about adding overhead during device initialization? I was thinking it would be better to load the kernels at device init time, rather than taking the hit when we need the indexing operation.
No, just from a code-clarity point of view: `_mtl_indexing_library` is an implementation detail of the indexing operator and should not leak into MPSDevice. We could implement an RAII mechanism for registering libraries with MPSDevice, but individual implementations IMO do not belong here.
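One possible shape for the RAII registration idea — purely illustrative; none of these MPSDevice methods exist in the PR:

```objc
// Hypothetical RAII handle: an operator owns its compiled library and
// registers it with MPSDevice only for the handle's lifetime.
class MTLLibraryHandle {
 public:
  MTLLibraryHandle(MPSDevice& device, const char* source)
      : _device(device), _library(device.compileLibrary(source)) { // assumed API
    _device.registerLibrary(_library);                             // assumed API
  }
  ~MTLLibraryHandle() { _device.unregisterLibrary(_library); }     // assumed API
  MTLLibraryHandle(const MTLLibraryHandle&) = delete;
  MTLLibraryHandle& operator=(const MTLLibraryHandle&) = delete;
  MTLLibrary_t library() const { return _library; }
 private:
  MPSDevice&   _device;
  MTLLibrary_t _library;
};
```

Indexing.mm would then hold a static MTLLibraryHandle instead of MPSDevice owning `_mtl_indexing_library`.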
```cpp
case ScalarType::Float:
  res = "float"; break;
case ScalarType::Half:
  res = "half"; break;
case ScalarType::Long:
  res = "long"; break;
case ScalarType::Int:
  res = "int"; break;
case ScalarType::Short:
  res = "short"; break;
case ScalarType::Char:
  res = "char"; break;
case ScalarType::Byte:
  res = "uchar"; break;
case ScalarType::Bool:
  res = "bool"; break;
```
You need one kernel per type size, don't you?
Suggested change:

```diff
-case ScalarType::Float:
-  res = "float"; break;
-case ScalarType::Half:
-  res = "half"; break;
-case ScalarType::Long:
-  res = "long"; break;
-case ScalarType::Int:
-  res = "int"; break;
-case ScalarType::Short:
-  res = "short"; break;
-case ScalarType::Char:
-  res = "char"; break;
-case ScalarType::Byte:
-  res = "uchar"; break;
-case ScalarType::Bool:
-  res = "bool"; break;
+case ScalarType::Long:
+  res = "64bit"; break;
+case ScalarType::Float:
+case ScalarType::Int:
+  res = "32bit"; break;
+case ScalarType::Half:
+case ScalarType::Short:
+  res = "16bit"; break;
+case ScalarType::Char:
+case ScalarType::Byte:
+case ScalarType::Bool:
+  res = "8bit"; break;
```
```metal
template
[[host_name("index_select_float")]]
kernel void index_select<float>(constant const IndexAB & indexAB [[buffer(0)]],
                                constant const void * indexSizes [[buffer(1)]],
                                constant const void * indexStrides [[buffer(2)]],
                                constant const uint3 * offsets [[buffer(3)]],
                                constant const void * inputData [[buffer(4)]],
                                device void * outputData [[buffer(5)]],
                                uint thread_index [[thread_position_in_grid]]);
// ... identical [[host_name("index_select_<type>")]] instantiations follow for
// half, long, int, short, char, uchar, and bool ...
```
One doesn't need a template per data type, but rather a template per element size (unless there is a memcpy-like function in Metal, in which case templates aren't needed at all).
Suggested change: replace the per-type instantiations above with one per element size:

```diff
-template
-[[host_name("index_select_float")]]
-kernel void index_select<float>(/* ... */);
-// ... (and the per-type instantiations for half, long, int, short, char, uchar, bool) ...
+template
+[[host_name("index_select_32bit")]]
+kernel void index_select<int>(constant const IndexAB & indexAB [[buffer(0)]],
+                              constant const void * indexSizes [[buffer(1)]],
+                              constant const void * indexStrides [[buffer(2)]],
+                              constant const uint3 * offsets [[buffer(3)]],
+                              constant const void * inputData [[buffer(4)]],
+                              device void * outputData [[buffer(5)]],
+                              uint thread_index [[thread_position_in_grid]]);
+// ... analogous instantiations: index_select_16bit -> index_select<short>,
+// index_select_64bit -> index_select<long>, index_select_8bit -> index_select<char> ...
```
The 64/32/16/8-bit templates make sense for indexing; I'll update that. Regarding the use of templates, we'd still need them: there is no memcpy in Metal.
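To illustrate why element size is all a copy needs (a hypothetical simplification of the kernel body, not the PR's code):

```metal
// A per-size template can stand in for every dtype of that width, because
// the copy reinterprets memory rather than computing on typed values.
template <typename T>
inline void copy_element(device void* dst, constant const void* src,
                         uint dst_idx, uint src_idx) {
  ((device T*)dst)[dst_idx] = ((constant const T*)src)[src_idx];
}
```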
@malfet, can we follow this up in the next PR, with index_put? We didn't want to club these two together, as it would become a huge PR.
Sure, can you please file a followup issue?
Rebase and lint issues are fixed.
@malfet, unrelated to this change, but we are unable to find the lint issues locally.
std::string res = ""; | ||
switch (scalar_type) { | ||
case ScalarType::Float: | ||
res = "float"; break; |
Nit:

```diff
-    res = "float"; break;
+    return "float";
```
```cpp
namespace native {
namespace mps {

std::string getMetalScalarType(ScalarType scalar_type) {
```
Nit (or return `const std::string&`), i.e. please return a string-literal type to avoid an unnecessary copy:

```diff
-std::string getMetalScalarType(ScalarType scalar_type) {
+const char* getMetalScalarType(ScalarType scalar_type) {
```
```cpp
AT_ASSERT(num_indices == iter.ntensors() - 2);
const Tensor& inputTensor = iter.tensor(1);

TORCH_CHECK(c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true),
```
This would return false for `Float` and `Half`, which I believe is not the intended behavior. (Also, is it covered by tests right now? If not, then IMO it should be.)

Suggested change:

```diff
-TORCH_CHECK(c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true),
+TORCH_CHECK(inputTensor.scalar_type() == ScalarType::Float ||
+            inputTensor.scalar_type() == ScalarType::Half ||
+            c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true),
```
Yes, float and half are covered in the test cases (the latest commit should contain this change):

```cpp
TORCH_CHECK(c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true) ||
            inputTensor.scalar_type() == ScalarType::Float ||
            inputTensor.scalar_type() == ScalarType::Half,
```
This error seems unrelated to this PR.

@pytorchbot rebase

@pytorchbot successfully started a rebase job. Check the current status here.
* Add more tests from test_indexing into test_mps
* Cache the indexing library on the MPSDevice
Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
Successfully rebased 16106d9 to 132ba61.
@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here.

Hey @kulinseth.
Summary:
* Add more tests from test_indexing into test_mps
* Cache the indexing library on the MPSDevice

Pull Request resolved: #82507
Approved by: https://github.com/malfet
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ce7177f88a8c76351087bd06520681e60591ff50
Reviewed By: atalman
Differential Revision: D38830978
Pulled By: atalman
fbshipit-source-id: 69eb9e0a5779cf4d0b0d0c492e1ba210d9ae59bf