diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp index 17aad41a03d0..e739f9ca6ecb 100644 --- a/torch_xla/csrc/tensor_impl.cpp +++ b/torch_xla/csrc/tensor_impl.cpp @@ -59,16 +59,23 @@ c10::intrusive_ptr XLATensorImpl::shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const { auto impl = c10::make_intrusive(tensor_); - impl->is_wrapped_number_ = is_wrapped_number_; - impl->reserved_ = reserved_; - impl->set_version_counter(version_counter); - impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + copy_tensor_data( + /*src_impl=*/this, + /*dest_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); return impl; } void XLATensorImpl::shallow_copy_from( const c10::intrusive_ptr& impl) { XLATensorImpl* xla_impl = dynamic_cast(impl.get()); + copy_tensor_data( + /*src_impl=*/xla_impl, + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + tensor_ = XLATensor::clone(xla_impl->tensor_); generation_ = 0; }