diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp
index 7a5104847ec3..1df54a668542 100644
--- a/torch_xla/csrc/aten_xla_type.cpp
+++ b/torch_xla/csrc/aten_xla_type.cpp
@@ -866,12 +866,23 @@ AtenXlaType::convolution_backward_overrideable(
 at::Tensor& AtenXlaType::copy_(at::Tensor& self, const at::Tensor& src,
                                bool non_blocking) {
-  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  c10::optional<XLATensor> self_tensor = bridge::TryGetXlaTensor(self);
   c10::optional<XLATensor> src_tensor = bridge::TryGetXlaTensor(src);
-  if (src_tensor) {
-    XLATensor::copy_(self_tensor, *src_tensor);
+
+  if (!src_tensor) {
+    XLA_CHECK(self_tensor);
+    self_tensor->SetTensor(CopyTensor(src, self.scalar_type()));
+  } else if (!self_tensor) {
+    // TODO: Is self_tensor good enough? I don't think so... therefore
+    // the hack below:
+    std::vector<at::Tensor> tensors = {src};
+    auto xla_tensors = bridge::XlaCreateTensorList(tensors);
+    // Hack in an overwrite of a const tensor.
+    at::Tensor t = CopyTensor(xla_tensors.front(), self.scalar_type());
+    const_cast<at::Tensor&>(self).unsafeGetTensorImpl()->shallow_copy_from(
+        t.getIntrusivePtr());
   } else {
-    self_tensor.SetTensor(CopyTensor(src, self.scalar_type()));
+    XLATensor::copy_(*self_tensor, *src_tensor);
   }
   return self;
 }
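
Note on the hunk above: the rewritten copy_ reduces to a three-way dispatch on whether self and src are XLA-backed (CPU src into XLA self, XLA src into non-XLA self via the shallow_copy_from hack, or XLA-to-XLA via XLATensor::copy_). The standalone sketch below mirrors only that branch structure; FakeTensor, try_get_xla, and copy_dispatch are invented stand-ins for illustration and are not part of the torch_xla API.

#include <iostream>
#include <optional>
#include <string>

// Illustration only: a stand-in for at::Tensor, tracking just whether
// the tensor is XLA-backed.
struct FakeTensor { bool is_xla; };

// Stand-in for bridge::TryGetXlaTensor: empty optional means "not XLA".
std::optional<std::string> try_get_xla(const FakeTensor& t) {
  if (t.is_xla) return std::string("xla");
  return std::nullopt;
}

void copy_dispatch(const FakeTensor& self, const FakeTensor& src) {
  auto self_x = try_get_xla(self);
  auto src_x = try_get_xla(src);
  if (!src_x) {
    // Case 1: src is a CPU tensor; self must be XLA-backed (the diff
    // asserts this with XLA_CHECK), so the CPU data is materialized
    // into it via SetTensor(CopyTensor(...)).
    std::cout << "CPU src -> XLA self\n";
  } else if (!self_x) {
    // Case 2: src is XLA-backed but self is not; the diff pulls the data
    // back to the host (XlaCreateTensorList) and overwrites self's
    // TensorImpl in place via shallow_copy_from.
    std::cout << "XLA src -> non-XLA self (shallow_copy_from hack)\n";
  } else {
    // Case 3: both are XLA tensors; delegate to XLATensor::copy_.
    std::cout << "XLA src -> XLA self\n";
  }
}

int main() {
  copy_dispatch({true}, {false});  // case 1
  copy_dispatch({false}, {true});  // case 2
  copy_dispatch({true}, {true});   // case 3
}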