diff --git a/test/cpp/test_aten_xla_tensor.cpp b/test/cpp/test_aten_xla_tensor.cpp
index ac408caf13fe..0f1bda03a101 100644
--- a/test/cpp/test_aten_xla_tensor.cpp
+++ b/test/cpp/test_aten_xla_tensor.cpp
@@ -910,6 +910,16 @@ TEST_F(AtenXlaTensorTest, TestDim) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestContiguous) {
+  at::Tensor input = GetTestTensor({2, 3});
+  at::Tensor output = at::native::contiguous(input);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_input = bridge::CreateXlaTensor(input, device);
+    at::Tensor xla_output = at::native::contiguous(xla_input);
+    AllClose(output, xla_output);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestAvgPool2DBackward) {
   int kernel_size = 2;
   for (int stride = 1; stride <= 2; ++stride) {
diff --git a/torch_xla/csrc/tensor_impl.cpp b/torch_xla/csrc/tensor_impl.cpp
index 4f63cd9580e6..2ef6e8d420e1 100644
--- a/torch_xla/csrc/tensor_impl.cpp
+++ b/torch_xla/csrc/tensor_impl.cpp
@@ -74,6 +74,12 @@ c10::intrusive_ptr<c10::TensorImpl> XLATensorImpl::shallow_copy_and_detach()
   return impl;
 }
 
+bool XLATensorImpl::is_contiguous() const {
+  // Only check that the storage is already contiguous.
+  XLA_CHECK(is_contiguous_) << "Non-contiguous storage for XLA tensor";
+  return true;
+}
+
 void XLATensorImpl::SetupSizeProperties() {
   // Fill up the basic dimension data members which the base class
   // implementation uses in its APIs.
diff --git a/torch_xla/csrc/tensor_impl.h b/torch_xla/csrc/tensor_impl.h
index 4e2dc1395277..119f83e20cba 100644
--- a/torch_xla/csrc/tensor_impl.h
+++ b/torch_xla/csrc/tensor_impl.h
@@ -20,6 +20,8 @@ class XLATensorImpl : public c10::TensorImpl {
 
   c10::intrusive_ptr<c10::TensorImpl> shallow_copy_and_detach() const override;
 
+  bool is_contiguous() const override;
+
  private:
   void SetupSizeProperties();
 