[MPS] Register index.Tensor_out #82507

Closed · 10 commits
Conversation

@kulinseth commented Jul 29, 2022

  • Add more tests from test_indexing into test_mps
  • Cache the indexing library on the MPSDevice

@facebook-github-bot commented Jul 29, 2022

✅ No Failures (5 Pending)

As of commit 132ba61 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@mikaylagawarecki added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 1, 2022
@malfet malfet left a comment


Lots of nits, but two major things:

  • _mtl_indexing_library must be set to nil in the MTLDevice constructor
  • Please extract the dispatch logic into a standalone PR


@philipturner commented Aug 4, 2022

I have a question about the design of this PR. Static typing incurs much greater compile-time overhead with minimal improvement in runtime execution speed. It can increase the number of shader objects by a factor of 10 and flood the GPU instruction cache. Have you tried a dynamically typed approach that requires only one Metal shader, and benchmarked its performance?

pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2022
Implement bitwise operators as metal kernels
Dynamically compile metal library for a triplet of input and output tensor types.
Use `dispatchThreads:threadsPerThreadgroup:` to dispatch work (this relies on the MPS device being at least `MTLGPUFamilyMac2`, which will be explicitly checked in #82507).

Perf improvements: Add support for non-contiguous tensors and broadcasting

Test Plan:
Already tested in  `test_mps.py`, for example by `TestConsistencyCPU.test_output_match_bitwise_xor_cpu_uint8`
Pull Request resolved: #82307
Approved by: https://github.com/albanD
facebook-github-bot pushed a commit that referenced this pull request Aug 9, 2022
Summary:
Implement bitwise operators as metal kernels
Dynamically compile metal library for a triplet of input and output tensor types.
Use `dispatchThreads:threadsPerThreadgroup:` to dispatch work (this relies on the MPS device being at least `MTLGPUFamilyMac2`, which will be explicitly checked in #82507).

Perf improvements: Add support for non-contiguous tensors and broadcasting

Pull Request resolved: #82307
Approved by: https://github.com/albanD

Test Plan:
contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/0377615b6855c7e306669a16e74bbfc01ab86c1c

Test plan from GitHub:
Already tested in  `test_mps.py`, for example by `TestConsistencyCPU.test_output_match_bitwise_xor_cpu_uint8`

Reviewed By: kit1980

Differential Revision: D38505765

Pulled By: malfet

fbshipit-source-id: f1265a4d43f0a0b52af622838c8873b32dffcbfb
@kulinseth commented Aug 15, 2022

@malfet , please take a look at the PR.

@philipturner:

`constant const` on these buffer bindings is redundant. The Metal compiler already describes `constant` as `const constant` in its error messages. The semantic meaning of the `constant` address space is `const device`, with the additional statement that the data will likely fall into the uniform registers. Please remove the `const` in the buffer bindings.

@malfet commented Aug 16, 2022

@kulinseth at the very least this PR needs a rebase (the dispatch code has already landed in #82612) and a fix for the linter; leaving a few more comments right now.

Comment on lines 15 to 24
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
MTLLanguageVersion languageVersion;

#if defined(__MAC_13_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_13_0
languageVersion = MTLLanguageVersion3_0;
#elif defined(__MAC_12_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_12_0
languageVersion = MTLLanguageVersion2_4;
#elif
#error "Metal is not available on the current platform."
#endif
Contributor:

Code as it's currently written will fail to compile on macOS 14.0, when it is eventually released.
And hardcoding the version would also prevent us from encountering runtime errors if newer MTL language standards were to make incompatible changes.
(And the compiler optimizations applied are independent of MTLLanguageVersion, aren't they?)

Suggested change
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
MTLLanguageVersion languageVersion;
#if defined(__MAC_13_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_13_0
languageVersion = MTLLanguageVersion3_0;
#elif defined(__MAC_12_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_12_0
languageVersion = MTLLanguageVersion2_4;
#elif
#error "Metal is not available on the current platform."
#endif
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
// host_name attribute needs at least Metal 2.2
MTLLanguageVersion languageVersion = MTLLanguageVersion2_2;

@DenisVieriu97 commented Aug 16, 2022

Code as it's currently written will fail to compile on macOS 14.0, when it is eventually released.

This shouldn't fail to compile on newer macOS. macOS Ventura is 13.0, and according to the macro logic it would fall into this branch:

#if defined(__MAC_13_0) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_13_0
  languageVersion = MTLLanguageVersion3_0;

For higher versions it would fall into the same branch (the check __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_13_0 tests for the version being >= 13, not strictly equal to 13).
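The branch selection described above can be restated as ordinary runtime logic. This is a hypothetical plain-C++ helper, with integers standing in for the MTLLanguageVersion constants: a `>=` comparison picks the newest branch the deployment target satisfies, so a 14.0 target still compiles and simply takes the 13.0 branch.

```cpp
#include <cassert>

// Hypothetical restatement of the preprocessor branch selection:
// the ">=" comparison on the deployment target selects the newest
// branch it satisfies, so 14 still maps to the 13.0 branch rather
// than falling through to the error case.
int pickLanguageVersion(int min_required_major) {
  if (min_required_major >= 13) return 30;  // stands in for MTLLanguageVersion3_0
  if (min_required_major >= 12) return 24;  // stands in for MTLLanguageVersion2_4
  return -1;                                // Metal unavailable on this platform
}
```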

Collaborator Author:

And hardcoding the version would also prevent us from encountering runtime errors if newer MTL language standards were to make incompatible changes.

These are backwards compatible, and generally we don't even need to update the language version unless we have to use a Metal feature from it. And if we do need some new language feature, we can bump the logic here.

Collaborator Author:

Hardcoding to 2_2 as per the suggestion.

~MPSDevice();

private:
static MPSDevice* _device;
MTLDevice_t _mtl_device;
MTLLibrary_t _mtl_indexing_library;
Contributor:

Why does it need to be a part of MPSDevice rather than implementation detail in Indexing.mm?

Collaborator:

This is cached on the MPSDevice itself the very first time index.Tensor_out is called; all subsequent calls directly use the cached version of the library.
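That first-call caching pattern can be sketched in plain C++ (all names hypothetical; `Library` stands in for an MTLLibrary handle): the member starts out null, the expensive compile happens exactly once, and every later call reuses the cached handle.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-in for an MTLLibrary handle.
struct Library { std::string name; };

class Device {
 public:
  // Returns the cached indexing library, compiling it on first use only.
  Library* indexingLibrary() {
    if (!_indexing_library) {
      ++_compile_count;  // the expensive compile happens only once
      _indexing_library = std::make_unique<Library>(Library{"index_kernels"});
    }
    return _indexing_library.get();
  }
  int compileCount() const { return _compile_count; }

 private:
  std::unique_ptr<Library> _indexing_library;  // starts null, mirroring "set to nil"
  int _compile_count = 0;
};
```

The member starting out null mirrors the earlier review note that `_mtl_indexing_library` must be set to nil in the constructor.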

Collaborator Author:

@malfet, are you worried about adding overhead during device initialization? I was thinking it would be better to load the kernels at device-init time rather than taking the hit when we first need the indexing operation.

Contributor:

No, just from a code clarity point of view: mtl_indexing_library is an implementation detail of the indexing operator and should not leak into MPSDevice. We could implement an RAII mechanism for registering libraries with MPSDevice, but individual implementations IMO do not belong here.
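The RAII registration idea could look roughly like this (all names hypothetical; the library handle is an opaque pointer): a guard object registers a library with a device-owned registry on construction and unregisters it on destruction, so the device type never has to name any individual operator's library.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical device-owned registry of compiled libraries.
struct LibraryRegistry {
  std::map<std::string, const void*> libs;  // name -> opaque library handle
};

// RAII guard: an operator's .mm file would own one of these, keeping
// MPSDevice free of per-operator members like _mtl_indexing_library.
class RegisteredLibrary {
 public:
  RegisteredLibrary(LibraryRegistry& reg, std::string name, const void* handle)
      : reg_(reg), name_(std::move(name)) {
    reg_.libs[name_] = handle;                     // register on construction
  }
  ~RegisteredLibrary() { reg_.libs.erase(name_); } // unregister on destruction

 private:
  LibraryRegistry& reg_;
  std::string name_;
};
```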

Comment on lines +25 to +36
case ScalarType::Float:
res = "float"; break;
case ScalarType::Half:
res = "half"; break;
case ScalarType::Long:
res = "long"; break;
case ScalarType::Int:
res = "int"; break;
case ScalarType::Short:
res = "short"; break;
case ScalarType::Char:
res = "char"; break;
case ScalarType::Byte:
res = "uchar"; break;
case ScalarType::Bool:
res = "bool"; break;
Contributor:

You only need one kernel per type size, don't you?

Suggested change
case ScalarType::Float:
res = "float"; break;
case ScalarType::Half:
res = "half"; break;
case ScalarType::Long:
res = "long"; break;
case ScalarType::Int:
res = "int"; break;
case ScalarType::Short:
res = "short"; break;
case ScalarType::Char:
res = "char"; break;
case ScalarType::Byte:
res = "uchar"; break;
case ScalarType::Bool:
res = "bool"; break;
case ScalarType::Long:
res = "64bit"; break;
case ScalarType::Float:
case ScalarType::Int:
res = "32bit"; break;
case ScalarType::Half:
case ScalarType::Short:
res = "16bit"; break;
case ScalarType::Char:
case ScalarType::Byte:
case ScalarType::Bool:
res = "8bit"; break;

Comment on lines +42 to +114
template
[[host_name("index_select_float")]]
kernel void index_select<float>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_half")]]
kernel void index_select<half>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_long")]]
kernel void index_select<long>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_int")]]
kernel void index_select<int>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_short")]]
kernel void index_select<short>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_char")]]
kernel void index_select<char>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_uchar")]]
kernel void index_select<uchar>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);

template
[[host_name("index_select_bool")]]
kernel void index_select<bool>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
Contributor:

One doesn't need a template per datatype, but rather a template per element size (unless there is a memcpy-like function in Metal, in which case templates aren't needed at all).

Suggested change
template
[[host_name("index_select_float")]]
kernel void index_select<float>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_half")]]
kernel void index_select<half>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_long")]]
kernel void index_select<long>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_int")]]
kernel void index_select<int>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_short")]]
kernel void index_select<short>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_char")]]
kernel void index_select<char>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_uchar")]]
kernel void index_select<uchar>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_bool")]]
kernel void index_select<bool>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_32bit")]]
kernel void index_select<int>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_16bit")]]
kernel void index_select<short>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_64bit")]]
kernel void index_select<long>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);
template
[[host_name("index_select_8bit")]]
kernel void index_select<char>(constant const IndexAB & indexAB [[buffer(0)]],
constant const void * indexSizes [[buffer(1)]],
constant const void * indexStrides [[buffer(2)]],
constant const uint3 * offsets [[buffer(3)]],
constant const void * inputData [[buffer(4)]],
device void * outputData [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]);

@DenisVieriu97 commented Aug 16, 2022

The 64/32/8-bit templates make sense for indexing; I'll update that. Regarding the use of templates, we'd still need them: there is no memcpy in Metal.

Collaborator Author:

@malfet, can we follow this up in the next PR, together with index_put? We didn't want to combine the two, as it would become a huge PR.

Contributor:

Sure, can you please file a followup issue?

@kulinseth:

@kulinseth at the very least this PR needs a rebase (the dispatch code has already landed in #82612) and a fix for the linter.

Rebase and lint issues are fixed.

@kulinseth:

@malfet, unrelated to this change, but `lintrunner -m master` on macOS has become unusable. There are so many warnings:

         76  |  if (str == "torch.Tensor") {

  Error (CLANGTIDY) [cppcoreguidelines-init-variables,-warnings-as-errors]
    variable 'ret' is not initialized

        106  |}
        107  |
        108  |std::vector<std::pair<Backend, ScalarType>> all_declared_types() {
    >>> 109  |  std::vector<std::pair<Backend, ScalarType>> ret;
        110  |
        111  |  // NOTE: Do not add more types here. This list controls the creation
        112  |  // of legacy tensor types e.g. torch.cuda.FloatTensor which are

  Error (CLANGTIDY) [cppcoreguidelines-init-variables,-warnings-as-errors]
    variable 'backends' is not initialized

        111  |  // NOTE: Do not add more types here. This list controls the creation
        112  |  // of legacy tensor types e.g. torch.cuda.FloatTensor which are
        113  |  // maintained for backwards-compatibility only.
    >>> 114  |  std::vector<Backend> backends = {
        115  |      Backend::CPU, Backend::CUDA, Backend::SparseCPU, Backend::SparseCUDA};
        116  |  std::vector<ScalarType> scalar_types = {
        117  |      ScalarType::Byte,

  Error (CLANGTIDY) [cppcoreguidelines-init-variables,-warnings-as-errors]
    variable 'scalar_types' is not initialized

        113  |  // maintained for backwards-compatibility only.
        114  |  std::vector<Backend> backends = {
        115  |      Backend::CPU, Backend::CUDA, Backend::SparseCPU, Backend::SparseCUDA};
    >>> 116  |  std::vector<ScalarType> scalar_types = {
        117  |      ScalarType::Byte,
        118  |      ScalarType::Char,
        119  |      ScalarType::Double,



>>> Lint for torch/csrc/utils/throughput_benchmark.cpp:

  Error (CLANGTIDY) [bugprone-branch-clone,-warnings-as-errors]
    if with identical then and else branches

         16  |
         17  |void ThroughputBenchmark::addInput(py::args args, py::kwargs kwargs) {
         18  |  CHECK(script_module_.initialized() ^ module_.initialized());
    >>>  19  |  if (script_module_.initialized()) {
         20  |    script_module_.addInput(std::move(args), std::move(kwargs));
         21  |  } else {
         22  |    CHECK(module_.initialized());

  Error (CLANGTIDY) [cppcoreguidelines-pro-type-member-init,-warnings-as-errors]
    constructor does not initialize these fields: module_

         39  |  }
         40  |}
         41  |
    >>>  42  |ThroughputBenchmark::ThroughputBenchmark(jit::Module script_module)
         43  |    : script_module_(script_module) {}
         44  |
         45  |ThroughputBenchmark::ThroughputBenchmark(py::object module)

  Error (CLANGTIDY) [cppcoreguidelines-pro-type-member-init,-warnings-as-errors]
    constructor does not initialize these fields: module_

         42  |ThroughputBenchmark::ThroughputBenchmark(jit::Module script_module)
         43  |    : script_module_(script_module) {}
         44  |
    >>>  45  |ThroughputBenchmark::ThroughputBenchmark(py::object module)
         46  |    : module_(std::move(module)) {}
         47  |
         48  |BenchmarkExecutionStats ThroughputBenchmark::benchmark(



>>> Lint for torch/utils/jit/__init__.py:

  Warning (FLAKE8) W391
    blank line at end of file
    See https://www.flake8rules.com/rules/W391.html

    >>> 1  |



>>> Lint for torch/utils/data/dataloader.py:

  Error (MYPY) [attr-defined]
    Module has no attribute "sched_getaffinity"

         541  |        cpuset_checked = False
         542  |        if hasattr(os, 'sched_getaffinity'):
         543  |            try:
    >>>  544  |                max_num_worker_suggest = len(os.sched_getaffinity(0))
         545  |                cpuset_checked = True
         546  |            except Exception:
         547  |                pass

that we are unable to find the relevant lint issues locally.

@kulinseth added the ciflow/trunk (Trigger trunk jobs on your pull request) and ciflow/mps (Run MPS tests, subset of trunk) labels Aug 16, 2022

std::string res = "";
switch (scalar_type) {
case ScalarType::Float:
res = "float"; break;
Contributor:

Nit

Suggested change
res = "float"; break;
return "float";

namespace native {
namespace mps {

std::string getMetalScalarType(ScalarType scalar_type) {
Contributor:

Nit: please return a string literal type (const char*, or alternatively const std::string&) to avoid an unnecessary copy.

Suggested change
std::string getMetalScalarType(ScalarType scalar_type) {
const char* getMetalScalarType(ScalarType scalar_type) {

AT_ASSERT(num_indices == iter.ntensors() - 2);
const Tensor& inputTensor = iter.tensor(1);

TORCH_CHECK(c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true),
Contributor:

This would return false for Float and Half, which I believe is not the intended behavior. (Also, is it covered by tests right now? If not, then IMO it should be.)

Suggested change
TORCH_CHECK(c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true),
TORCH_CHECK(inputTensor.scalar_type() == ScalarType::Float ||
inputTensor.scalar_type() == ScalarType::Half ||
c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true),

Collaborator:

Yes, float and half are covered in the test cases (the latest commit contains this change):

   TORCH_CHECK(c10::isIntegralType(inputTensor.scalar_type(), /*includesBool=*/true) ||
                inputTensor.scalar_type() == ScalarType::Float ||
                inputTensor.scalar_type() == ScalarType::Half,
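The dtype gate settled on in this thread can be restated as a self-contained sketch (the enum here is a local stand-in, not the real c10::ScalarType): the integral family including Bool is accepted, plus Float and Half, and everything else is rejected.

```cpp
#include <cassert>

// Local stand-in for c10::ScalarType, for illustration only.
enum class ScalarType { Float, Half, Double, ComplexFloat,
                        Long, Int, Short, Char, Byte, Bool };

bool isSupportedIndexingType(ScalarType t) {
  switch (t) {
    case ScalarType::Long:
    case ScalarType::Int:
    case ScalarType::Short:
    case ScalarType::Char:
    case ScalarType::Byte:
    case ScalarType::Bool:   // the integral family, includesBool=true
    case ScalarType::Float:
    case ScalarType::Half:   // the two floating types allowed by the check
      return true;
    default:
      return false;          // e.g. Double, ComplexFloat are not handled
  }
}
```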

@kulinseth:

======================================================================
ERROR [0.002s]: test_single_output (__main__.TestAOTAutograd)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/functorch/test/test_pythonkey.py", line 223, in test_single_output
    self.verify_aot_autograd(f, inp)
  File "/var/lib/jenkins/workspace/functorch/test/test_pythonkey.py", line 216, in verify_aot_autograd
    self.assertEqual(ref_out, test_out)
  File "/opt/conda/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2391, in assertEqual
    y = torch.as_tensor(y, dtype=x.dtype, device=x.device)
ValueError: only one element tensors can be converted to Python scalars

This error seems unrelated to this PR.

@kulinseth:

@pytorchbot rebase

@pytorchmergebot:

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot:

Successfully rebased indexing onto refs/remotes/origin/master, please pull locally before adding more changes (for example, via git checkout indexing && git pull --rebase)

@kulinseth:

@pytorchbot merge

@pytorchmergebot:

@pytorchbot successfully started a merge job. Check the current status here.
The merge job was triggered without a flag. This means that your change will be merged once all checks on your PR have passed (ETA: 0-4 Hours). If this is not the intended behavior, feel free to use some of the other merge options in the wiki.
Please reach out to the PyTorch DevX Team with feedback or questions!

@github-actions:

Hey @kulinseth.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

facebook-github-bot pushed a commit that referenced this pull request Aug 19, 2022
Summary:
* Add more tests from test_indexing into test_mps
* Cache the indexing library on the MPSDevice

Pull Request resolved: #82507
Approved by: https://github.com/malfet

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/ce7177f88a8c76351087bd06520681e60591ff50

Reviewed By: atalman

Differential Revision: D38830978

Pulled By: atalman

fbshipit-source-id: 69eb9e0a5779cf4d0b0d0c492e1ba210d9ae59bf
Labels
ciflow/mps (Run MPS tests, subset of trunk) · ciflow/trunk (Trigger trunk jobs on your pull request) · cla signed · Merged · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)