2 changes: 2 additions & 0 deletions .gitignore
@@ -14,6 +14,8 @@ torch_xla/csrc/version.cpp
torch_xla/csrc/aten_xla_type.h
torch_xla/csrc/aten_xla_type_default.h
torch_xla/csrc/aten_xla_type_default.cpp
torch_xla/csrc/RegisterXLA.cpp
torch_xla/csrc/RegisterAutogradXLA.cpp

# BEGIN NOT-CLEAN-FILES (setup.py handles this marker. Do not change.)
#
1 change: 1 addition & 0 deletions torch_patches/.torch_pin
@@ -0,0 +1 @@
#57510
40 changes: 28 additions & 12 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -359,6 +359,34 @@ at::Tensor AtenXlaType::_copy_from(const at::Tensor& self,
  return dst;
}

at::Tensor AtenXlaType::_copy_from_and_resize(const at::Tensor& self,
                                              const at::Tensor& dst) {
  XLA_FN_COUNTER("xla::");
  auto dst_tensor = bridge::TryGetXlaTensor(dst);
  auto self_tensor = bridge::TryGetXlaTensor(self);
  if (!self_tensor) {
    XLA_CHECK(dst_tensor);
    dst_tensor->UpdateFromTensorOut(self);
Collaborator:
qq: here you chose to call UpdateFromTensorOut, which also handles the view check and defaults sync-update to false. Is there a reason to do this instead of dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update); as in AtenXlaType::_copy_from?

Collaborator Author:
Yeah, that was actually my original motivation for writing this op. When I call dst_tensor->UpdateFromTensor(self, /*sync=*/sync_update); instead (like in _copy_from), XLA complains if dst and self have different sizes. The next thing I tried was to explicitly resize the dest tensor before the call, but then I hit #2881. I got the idea to call UpdateFromTensorOut from the bridge:: API that the codegen used previously, which is here. The only reason I didn't pass in sync_update is that UpdateFromTensorOut doesn't accept that arg :) but the original codegen didn't use that argument either, so I figured this would be more in line with existing functionality.

Side note: implicitly resizing output tensors is currently allowed for in-tree kernels, but it's deprecated. So right now we want to allow that case, but eventually it'll probably go away. That's out of scope for this function though, since it can be fixed through the codegen (we can just call _copy_from instead of _copy_from_and_resize).

Let me know if that all clears things up.

Collaborator:
Thanks for the explanation!
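
For anyone following along, here is a minimal sketch (not part of this PR; it assumes the standard ATen C++ API on CPU) of the kind of out= call that needs the resize behavior. On an XLA device the same call would route through the generated out wrapper and _copy_from_and_resize:

#include <ATen/ATen.h>

void resize_example() {
  at::Tensor a = at::ones({4, 4});
  at::Tensor b = at::ones({4, 4});
  // `out` deliberately has the wrong shape; the out= kernel is expected to
  // resize it to the result shape, which is why a strict same-size
  // UpdateFromTensor update is not enough here.
  at::Tensor out = at::empty({2, 2});
  at::add_out(out, a, b);  // out is implicitly resized to {4, 4}
}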

  } else if (!dst_tensor) {
    at::Tensor tensor = self_tensor->ToTensor(/*detached=*/true);
    at::Tensor typed_tensor =
        CopyTensor(tensor, dst.scalar_type(), /*copy=*/false);
    dst.resize_as_(typed_tensor).copy_(typed_tensor);
  } else {
    // at this point we know dst is an XLA tensor
Collaborator:
Could you help me understand why we know dst is an XLA tensor here? I guess I am trying to understand the use case of this function. What does checking self_tensor mean? Are we checking whether they are empty tensors or CPU tensors?

Collaborator Author (@bdhirsh, May 13, 2021):
I wrote this function to be similar to AtenXlaType::_copy_from, in that it handles all 3 cases:

  • self is on CPU and dst is on XLA
  • self is on XLA and dst is on CPU
  • they're both on XLA

The main difference is just that, in the case where dst is on XLA, it handles resizing properly. I got the resizing logic mostly by looking at the function that the codegen used to call.

Could you help me understand why we know dst is an XLA tensor here?

In the } else { section here, we know that both bridge::TryGetXlaTensor(self) and bridge::TryGetXlaTensor(dst) returned valid XLATensor objects, which I took to mean that they both live on XLA. So this piece of code handles the case where both self and dst live on XLA, and I took the corresponding logic from here to figure out how to copy and resize the dst XLA tensor.

This piece of logic gets hit inside the generated out wrappers. I pasted an example codegen'd out wrapper for add below:

at::Tensor & wrapper_out_add_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) {
  XLA_FN_TRACK(3);
  TF_VLOG(3) << "XLA wrapper_out_add_out :" << " self=" << self.toString() << " other=" << other.toString();

  auto wrapper_out_add_out_tmp = wrapper_Tensor_add(self, other, alpha);
  at::_copy_from_and_resize(wrapper_out_add_out_tmp, out);
  return out;
}
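
To make the three cases above concrete, a hedged sketch (again not from this PR; it assumes the C++ frontend with an XLA device available and registered as at::kXLA) of the three device placements that _copy_from and _copy_from_and_resize have to handle:

#include <ATen/ATen.h>

void copy_directions() {
  // Hypothetical illustration; at::kXLA is assumed to be usable here.
  at::Tensor cpu_src = at::ones({2, 3});
  // self on CPU, dst on XLA.
  at::Tensor xla_dst = at::empty({2, 3}, at::device(at::kXLA));
  xla_dst.copy_(cpu_src);
  // self on XLA, dst on CPU.
  at::Tensor cpu_dst = at::empty({2, 3});
  cpu_dst.copy_(xla_dst);
  // both self and dst on XLA.
  at::Tensor xla_other = at::zeros({2, 3}, at::device(at::kXLA));
  xla_other.copy_(xla_dst);
}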

    XLATensorImpl* dest_impl =
        dynamic_cast<XLATensorImpl*>(dst.unsafeGetTensorImpl());
    dest_impl->tensor().UpdateFromTensorOut(*self_tensor);
    dest_impl->force_refresh_sizes();
  }
  return dst;
}

std::vector<at::Tensor> AtenXlaType::_to_cpu(at::TensorList tensors) {
  XLA_FN_COUNTER("xla::");
  return bridge::XlaCreateTensorList(tensors);
}

at::Tensor& AtenXlaType::_index_put_impl_(
    at::Tensor& self, const c10::List<c10::optional<at::Tensor>>& indices,
    const at::Tensor& values, bool accumulate, bool /* unsafe */) {
@@ -489,18 +517,6 @@ at::Tensor AtenXlaType::add(const at::Tensor& self, const at::Scalar& other,
  });
}

at::Tensor& AtenXlaType::add_(at::Tensor& self, const at::Tensor& other,
                              const at::Scalar& alpha) {
  XLA_FN_COUNTER("xla::");
  at::native::alpha_check(at::result_type(self, other), alpha);
  CheckBinaryOpTypePromotion(self, self, other);
  XLATensor self_tensor = bridge::GetXlaTensor(self);
  XLATensor::add_(self_tensor,
Collaborator Author:
Actually, I should have removed the code for XLATensor::add_ in this PR too. I'm planning a follow-up PR where I remove most of the other in-place ops, so I'll remove it there.

                  bridge::GetOrCreateXlaTensor(other, self_tensor.GetDevice()),
                  alpha);
  return self;
}

at::Tensor& AtenXlaType::add_(at::Tensor& self, const at::Scalar& other,
                              const at::Scalar& alpha) {
  XLA_FN_COUNTER("xla::");
3 changes: 2 additions & 1 deletion xla_native_functions.yaml
@@ -6,7 +6,6 @@ supported:
- acos
- acos_
- add.Tensor
- add_.Tensor
- add.Scalar
- add_.Scalar
- all.dim
@@ -52,6 +51,8 @@ supported:
- convolution_overrideable
- convolution_backward_overrideable
- _copy_from
- _copy_from_and_resize
- _to_cpu
- cos
- cos_
- cosh