C++ tensor multi-dim indexing: add index() and index_put_() overloads, simple indexing tests, merge with Python indexing path #32841
Conversation
💊 CircleCI build failures summary and remediations (as of commit 8e71c6c)

Detailed failure analysis: the probable reasons each build failed can be explored interactively on the Dr. CI website. 🕵️ 1 new failure recognized by patterns. The following build failure does not appear to be due to upstream breakage:
Force-pushed from 1bfe1e5 to 0d50419.
Claim: the C++ indexing API shouldn't handle tracing; this should be done entirely in the Python frontend. I'm open to arguing about this, but I'm going to tentatively make this claim.
Force-pushed from 8184c2a to 3f8cc18.
```cpp
    const at::Device& self_device,
    const IntArrayRef& self_sizes) {
  // TODO: implement negative step
  TORCH_CHECK_VALUE(step > 0, "step must be greater than zero");
```
It kind of looks like this error test might be redundant with whatever error test `self.slice` does.
(No action needed on this PR, I assume this was faithfully translated from the original).
I guess the idea is to give a better error message than what the underlying implementation would have given.
```cpp
  // over the shape of the `self` tensor, and we still want to record
  // the slice.
  int64_t length = (self_device == at::kCPU || self_device == at::kCUDA) ? self_sizes[dim] : self.size(dim);
  if (!ensure_view && start == 0 && stop == length && step == 1 && !is_tracing) {
```
This function has an irritatingly large number of parameters, so I looked into what was going on with `ensure_view`. What I think is going on with `ensure_view` is that it's a sort of trick to reduce the amount of code we have to write at downstream sites. The underlying motivation for `ensure_view` is to let you avoid writing code like this:

```cpp
Variable sliced = applySlicing(self_, holder.get(), variableIndices);
if (variableIndices.empty()) {
  if (sliced.is_same(self_)) {
    // ensure we return a shallow copy for things like x[...]
    sliced = at::alias(sliced);
  }
  return wrap(sliced);
}
```

for getters, as getters are guaranteed to always return a new view, no matter what (for setters, you can avoid returning a new view). So the shorter version of this code looks like:

```cpp
else if (PySlice_Check(index)) {
  return wrap(applySlice(self_, 0, index, true));
}
```

Here, we passed in `true` to say, "Hey, you always have to return a view here," and so we'll end up with a slice view. BTW, we can't (easily?) do this in the `applySlicing` case, because the actual call to `applySlice` is done inside a loop, where we only need to make sure an alias gets created at the very end (e.g., it is wasteful to create fresh aliases for each index in `x[:,:,:]`; we only need to do it at the end).

This original intent of the code got substantially garbled after your refactoring to pass in `is_tracing` as an argument, as we now have two boolean arguments which serve exactly the same purpose, but no clear semantic distinction between them. I'm not exactly sure what to do here (perhaps the booleans should be combined?), but it's worth thinking about. I started looking into this because I saw `is_tracing_and_1d_slice_or_Nd`, which is rubbing me the wrong way.
```cpp
}

// This mirrors `THPVariable_getitem` in torch/csrc/autograd/python_variable_indexing.cpp
static inline Tensor get_item(const Tensor& self, const ArrayRef<TensorIndex>& indices, bool is_tracing_and_1d_slice_or_Nd = false) {
```
I audited every call site here and I didn't see any use of the `1d_slice_or_Nd` sense of `is_tracing_and_1d_slice_or_Nd`.
But I think that here, the tracing argument is totally unnecessary. Here's why; look at the use site:

```cpp
/*ensure_view=*/true,
/*is_tracing=*/is_tracing_and_1d_slice_or_Nd,
```

When `ensure_view` is true, we NEVER directly return `self`. So the setting of `is_tracing_and_1d_slice_or_Nd` doesn't matter at all; you're always going to make a fresh object.

There is one more use site for multi-dimensional indexing, so I'm wrong here.
```cpp
        index.slice().stop(),
        index.slice().step(),
        /*ensure_view=*/false,
        /*is_tracing=*/is_tracing,
```
On the other hand, this `is_tracing` matters! (Because it is setters that attempt to apply the optimization.)
`/*ensure_view=*/is_tracing`
```cpp
        index.slice().stop(),
        index.slice().step(),
        /*ensure_view=*/true,
        /*is_tracing=*/is_tracing_and_1d_slice_or_Nd,
```
`/*ensure_view=*/true`
```cpp
        index.slice().stop(),
        index.slice().step(),
        /*ensure_view=*/false,
        /*is_tracing=*/is_tracing_and_1d_slice_or_Nd,
```
`/*ensure_view=*/is_tracing`
```cpp
  }
  return THPVariable_Wrap(
      at::indexing::get_item(
          self_, {at::indexing::TensorIndex({start, stop, step})}, /*is_tracing_and_1d_slice_or_Nd=*/is_tracing));
```
"Why is this the only one that gets `is_tracing`?" (Ans: because that's what it's called.) (Q: But why?!)
@yf225 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
…, simple indexing tests, merge with Python indexing path (#32841)

Summary: This PR adds the following items:

- **1st item**: `ArrayRef<TensorIndex>` and `std::initializer_list<TensorIndex>` overloads for `Tensor::index` and `Tensor::index_put_`, to be used specifically for multi-dim indexing.

  Design rationale: C++ `Tensor::index` and `Tensor::index_put_` are both existing tensor APIs, and they currently (before this PR) only accept a list of tensors (i.e. `ArrayRef<Tensor>`) as indices. If we changed their signatures to also accept non-tensors as indices (i.e. `ArrayRef<TensorIndex>`, where `TensorIndex` is convertible from `Tensor` / `Slice` / `None` / `Ellipsis`), it would slow down the original code path (since it would now have to go through more steps), which is undesirable. To get around this problem, the proposed solution is to keep the original `ArrayRef<Tensor>` overload and add `ArrayRef<TensorIndex>` and `std::initializer_list<TensorIndex>` overloads to `Tensor::index` and `Tensor::index_put_`. This way, the original code path isn't affected, and the tensor multi-dim indexing API is only used when the user explicitly passes an `ArrayRef<TensorIndex>` or a braced-init-list of `TensorIndex`-convertible types to `Tensor::index` or `Tensor::index_put_`.

  Note that this solution still affects perf for existing user call sites of `Tensor::index` or `Tensor::index_put_` that pass a braced-init-list of tensors, e.g. `tensor.index({...})` or `tensor.index_put_({...}, value)`, since such calls would now take the multi-dim indexing path instead of the original advanced indexing path. However, there are only two instances of this in our codebase (one in an ATen cpp test, one in a C++ API nn init function), and they can easily be changed to explicitly use `ArrayRef<Tensor>` as input (I changed them in this PR). For external users' code, since this is part of the C++ frontend which is still considered experimental, we will only mention this change in the release notes and ask users to switch to `ArrayRef<Tensor>` explicitly if they want to keep using the original advanced indexing code path.

- **2nd item**: Mechanisms for parsing `ArrayRef<TensorIndex>` indices and performing indexing operations (mirroring the functions in `torch/csrc/autograd/python_variable_indexing.cpp`).

- **3rd item**: Simple tests to demonstrate that the `Tensor::index()` and `Tensor::index_put_()` APIs work. I will add more tests after the first few PRs are reviewed.

- **4th item**: Merge the Python/C++ indexing code paths, for code simplicity. I tested locally and found no perf regression resulting from the merge. I will get more concrete numbers for common use cases once we settle on the overall design.

This PR supersedes #30425.

Pull Request resolved: #32841
Differential Revision: D19919692
Pulled By: yf225
fbshipit-source-id: 7467e64f97fc0e407624809dd183c95ea16b1482