Use some new codegen functionality #2915
```diff
@@ -0,0 +1 @@
+#57510
```
```diff
@@ -359,6 +359,34 @@ at::Tensor AtenXlaType::_copy_from(const at::Tensor& self,
   return dst;
 }
 
+at::Tensor AtenXlaType::_copy_from_and_resize(const at::Tensor& self,
+                                              const at::Tensor& dst) {
+  XLA_FN_COUNTER("xla::");
+  auto dst_tensor = bridge::TryGetXlaTensor(dst);
+  auto self_tensor = bridge::TryGetXlaTensor(self);
+  if (!self_tensor) {
+    XLA_CHECK(dst_tensor);
+    dst_tensor->UpdateFromTensorOut(self);
+  } else if (!dst_tensor) {
+    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
+    at::Tensor typed_tensor =
+        CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
+    dst.resize_as_(typed_tensor).copy_(typed_tensor);
+  } else {
+    // at this point we know dst is an XLA tensor
```
Review comment (inline, on the line above):
Could you help me understand why we know dst is an XLA tensor here? I guess I am trying to understand the use case of this function. What does checking [...]?

Reply:
I wrote this function to be similar to _copy_from. The main difference is just that in the case that [...]. In the [...], this piece of logic gets hit inside of the generated [...].
```diff
+    XLATensorImpl* dest_impl =
+        dynamic_cast<XLATensorImpl*>(dst.unsafeGetTensorImpl());
+    dest_impl->tensor().UpdateFromTensorOut(*self_tensor);
+    dest_impl->force_refresh_sizes();
+  }
+  return dst;
+}
+
+std::vector<at::Tensor> AtenXlaType::_to_cpu(at::TensorList tensors) {
+  XLA_FN_COUNTER("xla::");
+  return bridge::XlaCreateTensorList(tensors);
+}
+
 at::Tensor& AtenXlaType::_index_put_impl_(
     at::Tensor& self, const c10::List<c10::optional<at::Tensor>>& indices,
     const at::Tensor& values, bool accumulate, bool /* unsafe */) {
```
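For context on the reply in the thread above ("this piece of logic gets hit inside of the generated [...]"): the conversation later in this page indicates that the codegen emits wrappers which run the functional kernel and then copy the result back into the caller's tensor via _copy_from / _copy_from_and_resize. The snippet below is a minimal hand-written sketch of that shape, not the actual generated code; the op choice (abs) and the function name are made up for illustration. It compiles against libtorch, but at::_copy_from_and_resize only has kernels registered for out-of-tree backends such as XLA, so it is only meaningful when the tensors involved live on such a backend.

```cpp
#include <ATen/ATen.h>

// Sketch only: roughly the shape of a codegen-emitted out= wrapper.
// The real generated code lives in the registration file produced by the codegen.
at::Tensor& sketch_generated_abs_out(const at::Tensor& self, at::Tensor& out) {
  // 1. Run the functional kernel (on the XLA backend, in this PR's case).
  at::Tensor tmp = at::abs(self);
  // 2. Copy the result back into the caller's `out`, resizing it if its shape
  //    does not match. This is the call that lands in
  //    AtenXlaType::_copy_from_and_resize above, whether `out` is a CPU tensor
  //    or already an XLA tensor.
  at::_copy_from_and_resize(tmp, out);
  return out;
}
```

The in-place variant of the same pattern (run the functional op, then copy back with _copy_from) is presumably what makes the hand-written add_(Tensor, Tensor, Scalar) overload removed in the next hunk unnecessary.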
```diff
@@ -489,18 +517,6 @@ at::Tensor AtenXlaType::add(const at::Tensor& self, const at::Scalar& other,
   });
 }
 
-at::Tensor& AtenXlaType::add_(at::Tensor& self, const at::Tensor& other,
-                              const at::Scalar& alpha) {
-  XLA_FN_COUNTER("xla::");
-  at::native::alpha_check(at::result_type(self, other), alpha);
-  CheckBinaryOpTypePromotion(self, self, other);
-  XLATensor self_tensor = bridge::GetXlaTensor(self);
-  XLATensor::add_(self_tensor,
-                  bridge::GetOrCreateXlaTensor(other, self_tensor.GetDevice()),
-                  alpha);
-  return self;
-}
-
 at::Tensor& AtenXlaType::add_(at::Tensor& self, const at::Scalar& other,
                               const at::Scalar& alpha) {
   XLA_FN_COUNTER("xla::");
```

Review comment (inline, on the removed add_ overload):
actually, I should have removed the code for [...]
Review comment:
qq, here you chose to call UpdateFromTensorOut, which also handles the view check and defaults sync_update to false. Is there a reason to do this instead of dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update); as in AtenXlaType::_copy_from?

Reply:
Yeah, that was actually my original motivation for writing this op. When I call dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update); instead (like in _copy_from), XLA complains if dst and self have different sizes. The next thing I tried was to explicitly resize the dest tensor before the call, but then I hit #2881. I got the idea to call UpdateFromTensorOut from the bridge:: API that the codegen used previously, which is here. The only reason I didn't pass in sync_update is that UpdateFromTensorOut doesn't accept that arg :) but the original codegen also didn't use that argument, so I figured this would be more in line with existing functionality.

Side note: implicitly resizing output tensors is currently allowed for in-tree kernels, but it's deprecated. So right now we want to allow that case, but eventually it will probably go away. That's out of scope for this function, though, since it can be fixed through the codegen (we can just call _copy_from instead of _copy_from_and_resize).

Let me know if that all clears things up.
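To make the side note concrete, here is a small standalone libtorch example (CPU only, nothing XLA-specific) of the deprecated implicit-resize behavior being described; abs_out is just an arbitrary in-tree op with an out= variant:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor self = at::randn({3, 3});
  at::Tensor out = at::empty({2, 2});  // deliberately the wrong shape
  // The CPU kernel resizes `out` to {3, 3} and emits a deprecation warning
  // instead of erroring out; this is the case _copy_from_and_resize has to
  // keep supporting for now.
  at::abs_out(out, self);
  std::cout << out.sizes() << std::endl;  // prints [3, 3]
  return 0;
}
```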
Reply:
Thanks for the explanation!