
C++ tensor multi-dim indexing: add index() and index_put_() overloads, simple indexing tests, merge with Python indexing path #32841


Closed

Conversation

yf225
Contributor

@yf225 yf225 commented Jan 30, 2020

This PR adds the following items:

  • 1st item: ArrayRef<TensorIndex> and std::initializer_list<TensorIndex> overloads for Tensor::index and Tensor::index_put_, to be used specifically for multi-dim indexing purposes (a usage sketch follows this list).

Design rationale:

  • C++ Tensor::index and Tensor::index_put_ are both existing tensor APIs, and they currently (before this PR) only accept a list of tensors (i.e. ArrayRef<Tensor>) as indices. If we change their signatures to also accept non-tensors as indices (i.e. ArrayRef<TensorIndex>, and TensorIndex is convertible from Tensor / Slice / None / Ellipsis), it would slow down the original code path (since now it has to go through more steps), which is undesirable.

    To get around this problem, the proposed solution is to keep the original ArrayRef<Tensor> overload, and add ArrayRef<TensorIndex> and std::initializer_list<TensorIndex> overloads to Tensor::index and Tensor::index_put_. This way, the original code path won’t be affected, and the tensor multi-dim indexing API is only used when the user explicitly passes an ArrayRef<TensorIndex> or a braced-init-list of TensorIndex-convertible types to Tensor::index and Tensor::index_put_.

    Note that the above proposed solution would still affect perf for the user’s original Tensor::index or Tensor::index_put_ call sites that use a braced-init-list of tensors as input, e.g. tensor.index({...}) or tensor.index_put_({...}, value), since now such function calls would take the multi-dim indexing path instead of the original advanced indexing path. However, there are only two instances of this in our codebase (one in an ATen cpp test, one in a C++ API nn init function), and they can be easily changed to explicitly use ArrayRef<Tensor> as input (I changed them in this PR). For external users’ code, since this is part of the C++ frontend, which is still considered experimental, we will only mention this change in the release notes, and ask users to switch to using ArrayRef<Tensor> explicitly if they want to keep using the original advanced indexing code path.

  • 2nd item: Mechanisms for parsing ArrayRef<TensorIndex> indices and performing indexing operations (mirroring the functions in torch/csrc/autograd/python_variable_indexing.cpp).
  • 3rd item: Simple tests to demonstrate that the Tensor::index() and Tensor::index_put_() APIs work. I will add more tests after the first few PRs are reviewed.
  • 4th item: Merge Python/C++ indexing code paths, for code simplicity. I tested locally and found that there is no perf regression resulting from the merge. I will get more concrete numbers for common use cases when we settle on the overall design.
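
For illustration, here is a minimal usage sketch of the two call paths described above. This is a sketch only: it assumes the Slice / None / Ellipsis helpers that TensorIndex is convertible from (shown here under torch::indexing); the exact helper names, headers, and namespaces at the time of this PR may differ slightly.

  #include <torch/torch.h>
  #include <vector>

  // Sketch only: assumes the TensorIndex-convertible helpers (Slice, None,
  // Ellipsis) live under torch::indexing, per the description above.
  void indexing_sketch() {
    using namespace torch::indexing;

    torch::Tensor t = torch::arange(24, torch::kFloat).reshape({2, 3, 4});

    // Multi-dim indexing path: a braced-init-list of TensorIndex-convertible
    // values, roughly equivalent to Python `t[0, 1:3, ...]`.
    torch::Tensor a = t.index({0, Slice(1, 3), Ellipsis});

    // Roughly equivalent to Python `t[:, 1, 2] = 0`.
    t.index_put_({Slice(), 1, 2}, torch::zeros({2}));

    // Original advanced indexing path (kept by this PR): pass ArrayRef<Tensor>
    // explicitly so the call still resolves to the tensor-list overload.
    std::vector<torch::Tensor> idx = {torch::arange(2)};
    torch::Tensor b = t.index(at::ArrayRef<torch::Tensor>(idx));
  }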

This PR supersedes #30425.

@kostmo
Member

kostmo commented Jan 31, 2020

💊 CircleCI build failures summary and remediations

As of commit 8e71c6c:

  • 1/3 broken upstream at merge base 97da60d since Feb 25

    Please rebase on the viable/strict branch.

    If your commit is newer than viable/strict, you can try basing on an older, stable commit:

    git fetch origin viable/strict
    git rebase --onto viable/strict $(git merge-base origin/master HEAD)
    

    If your commit is older than viable/strict:

    git fetch origin viable/strict
    git rebase viable/strict
    

    Check out the recency history of this "viable master" tracking branch.

  • 1/3 failures introduced in this PR

  • 1/3 recognized as flaky ❄️


Detailed failure analysis

One may explore the probable reasons each build failed interactively on the Dr. CI website.

🕵️ 1 new failure recognized by patterns

The following build failures do not appear to be due to upstream breakage:

See CircleCI build pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_test (1/1)

Step: "Test" (full log | pattern match details)

Feb 25 03:43:14 RuntimeError: test_quantization failed!
Feb 25 03:43:14 Ran 36 tests in 61.931s 
Feb 25 03:43:14  
Feb 25 03:43:14 FAILED (errors=1, skipped=1) 
Feb 25 03:43:14  
Feb 25 03:43:14 Generating XML reports... 
Feb 25 03:43:14 Traceback (most recent call last): 
Feb 25 03:43:14   File "test/run_test.py", line 486, in <module> 
Feb 25 03:43:14     main() 
Feb 25 03:43:14   File "test/run_test.py", line 479, in main 
Feb 25 03:43:14     raise RuntimeError(message) 
Feb 25 03:43:14 RuntimeError: test_quantization failed! 
Feb 25 03:43:15 + cleanup 
Feb 25 03:43:15 + retcode=1 
Feb 25 03:43:15 + set +x 
Feb 25 03:43:15 =================== sccache compilation log =================== 
Feb 25 03:43:15 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 25 03:43:15 Compile requests                32 
Feb 25 03:43:15 Compile requests executed       11 
Feb 25 03:43:15 Cache hits                       1 
Feb 25 03:43:15 Cache misses                    10 
Feb 25 03:43:15 Cache timeouts                   0 

❄️ 1 failure recognized as flaky

The following build failures have been detected as flaky and may not be your fault:

See CircleCI build pytorch_xla_linux_xenial_py3_6_clang7_build (1/1)

Step: "Build" (full log | pattern match details) ❄️

Feb 25 02:02:03 gpg: no valid OpenPGP data found.
Feb 25 02:02:01 Hit:6 http://archive.ubuntu.com/ubuntu xenial-backports InRelease 
Feb 25 02:02:02 Reading package lists... 
Feb 25 02:02:02  
Feb 25 02:02:02 ## Confirming "xenial" is supported... 
Feb 25 02:02:02  
Feb 25 02:02:02 + curl -sLf -o /dev/null 'https://deb.nodesource.com/node_6.x/dists/xenial/Release' 
Feb 25 02:02:02  
Feb 25 02:02:02 ## Adding the NodeSource signing key to your keyring... 
Feb 25 02:02:02  
Feb 25 02:02:02 + curl -s https://deb.nodesource.com/gpgkey/nodesource.gpg.key | apt-key add - 
Feb 25 02:02:03 gpg: no valid OpenPGP data found. 
Feb 25 02:02:03 Error executing command, exiting 
Feb 25 02:02:03 + cleanup 
Feb 25 02:02:03 + retcode=1 
Feb 25 02:02:03 + set +x 
Feb 25 02:02:03 =================== sccache compilation log =================== 
Feb 25 02:02:03 =========== If your build fails, please take a look at the log above for possible reasons =========== 
Feb 25 02:02:03 Compile requests               2567 
Feb 25 02:02:03 Compile requests executed      2306 
Feb 25 02:02:03 Cache hits                     2244 
Feb 25 02:02:03 Cache misses                     50 

🚧 1 upstream failure recognized by patterns:

These builds matched patterns, but were probably caused by upstream breakages:


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions on the GitHub issue tracker.


@yf225 yf225 force-pushed the cpp_tensor_default_dtype_and_advanced_indexing_new_exp1 branch 10 times, most recently from 1bfe1e5 to 0d50419 on January 31, 2020 19:08
@yf225 yf225 requested review from ezyang and removed request for ebetica, apaszke and goldsborough January 31, 2020 19:45
@yf225 yf225 changed the title from "[WIP] Cpp tensor default dtype and advanced indexing new exp1" to "C++ tensor multi-dim indexing: add index() and index_put_() overloads, simple indexing tests, merge with Python indexing path" on Jan 31, 2020
Contributor

@ezyang ezyang left a comment

Claim: the C++ indexing API shouldn't handle tracing; this should be done entirely in the Python frontend. I'm open to arguing about this, but I'm going to tentatively make this claim.

@yf225 yf225 force-pushed the cpp_tensor_default_dtype_and_advanced_indexing_new_exp1 branch 2 times, most recently from 8184c2a to 3f8cc18 on January 31, 2020 22:22
@yf225 yf225 requested a review from ezyang February 24, 2020 19:40
const at::Device& self_device,
const IntArrayRef& self_sizes) {
// TODO: implement negative step
TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");
Contributor

It kind of looks like this error test might be redundant with whatever error test self.slice does.

Contributor

(No action needed on this PR, I assume this was faithfully translated from the original).

Contributor

I guess the idea is to give a better error message than what the underlying implementation would have given.

// over the shape of the `self` tensor, and we still want to record
// the slice.
int64_t length = (self_device == at::kCPU || self_device == at::kCUDA) ? self_sizes[dim] : self.size(dim);
if (!ensure_view && start == 0 && stop == length && step == 1 && !is_tracing) {
Contributor

This function has an irritatingly large number of parameters, so I looked into what was going on with ensure_view. What I think is going on with ensure_view is that it's a sort of trick to reduce the amount of code we have to write at downstream sites. The underlying motivation for ensure_view is to let you avoid writing code like this:

  Variable sliced = applySlicing(self_, holder.get(), variableIndices);
  if (variableIndices.empty()) {
    if (sliced.is_same(self_)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return wrap(sliced);
  }

for getters, as getters are guaranteed to always return a new view, no matter what (for setters, you can avoid returning a new view). So the shorter version of this code looks like:

 else if (PySlice_Check(index)) {
    return wrap(applySlice(self_, 0, index, true));
  }

Here, we passed in true to say, "Hey, you always have to return a view here", and so we'll end up with a slice view here. BTW, we can't (easily?) do this in the applySlicing case, because the actual call to applySlice is done inside some sort of loop, where we only need to make sure an alias gets created at the very end (e.g., it is wasteful to create fresh aliases for each index in x[:,:,:]; we only need to do it at the end).

The original intent of the code got substantially garbled after your refactoring to pass in is_tracing as an argument, as we now have two boolean arguments which serve exactly the same purpose but have no clear semantic distinction between them.

I'm not exactly sure what to do here (perhaps the booleans should be combined?), but it's worth thinking about. I started looking into this because I saw is_tracing_and_1d_slice_or_Nd which is rubbing me the wrong way.

}

// This mirrors `THPVariable_getitem` in torch/csrc/autograd/python_variable_indexing.cpp
static inline Tensor get_item(const Tensor& self, const ArrayRef<TensorIndex>& indices, bool is_tracing_and_1d_slice_or_Nd = false) {
Contributor

I audited every call site here and I didn't see any use of the 1d_slice_or_Nd sense of is_tracing_and_1d_slice_or_Nd.

Contributor

@ezyang ezyang Feb 24, 2020

But I think that here, the tracing argument is totally unnecessary. Here's why; look at the use site:

        /*ensure_view=*/true,
        /*is_tracing=*/is_tracing_and_1d_slice_or_Nd,

When ensure_view is true, we NEVER directly return self. So the setting of is_tracing_and_1d_slice_or_Nd doesn't matter at all; you're always going to make a fresh object.

There is one more use site for multi-dimensional indexing, so I'm wrong here.

index.slice().stop(),
index.slice().step(),
/*ensure_view=*/false,
/*is_tracing=*/is_tracing,
Contributor

On the other hand, this is_tracing matters! (Because it is setters that attempt to apply the optimization)

Contributor

/* ensure_view */is_tracing

index.slice().stop(),
index.slice().step(),
/*ensure_view=*/true,
/*is_tracing=*/is_tracing_and_1d_slice_or_Nd,
Contributor

/*ensure_view=*/ true

index.slice().stop(),
index.slice().step(),
/*ensure_view=*/false,
/*is_tracing=*/is_tracing_and_1d_slice_or_Nd,
Contributor

/*ensure_view=*/ is_tracing

}
return THPVariable_Wrap(
at::indexing::get_item(
self_, {at::indexing::TensorIndex({start, stop, step})}, /*is_tracing_and_1d_slice_or_Nd=*/is_tracing));
Contributor

"Why is this the only one that gets is_tracing" (Ans: because that's what it's called) (Q: But why?!)

Contributor

@facebook-github-bot facebook-github-bot left a comment

@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@yf225 yf225 added the module: cpp (Related to C++ API) label on Feb 26, 2020
hczhu pushed a commit that referenced this pull request Feb 28, 2020
…, simple indexing tests, merge with Python indexing path (#32841)

Summary: (same as the PR description above)
Pull Request resolved: #32841

Differential Revision: D19919692

Pulled By: yf225

fbshipit-source-id: 7467e64f97fc0e407624809dd183c95ea16b1482
ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
…, simple indexing tests, merge with Python indexing path (pytorch#32841)

Summary: (same as the PR description above)
Pull Request resolved: pytorch#32841

Differential Revision: D19919692

Pulled By: yf225

fbshipit-source-id: 7467e64f97fc0e407624809dd183c95ea16b1482
@facebook-github-bot facebook-github-bot deleted the cpp_tensor_default_dtype_and_advanced_indexing_new_exp1 branch July 13, 2020 17:54