diff --git a/functorch/csrc/BatchedTensorImpl.cpp b/functorch/csrc/BatchedTensorImpl.cpp index 3b1ef0bb5..487df2900 100644 --- a/functorch/csrc/BatchedTensorImpl.cpp +++ b/functorch/csrc/BatchedTensorImpl.cpp @@ -28,7 +28,7 @@ BatchedTensorImpl::BatchedTensorImpl(Tensor value, int64_t bdim, int64_t level) TORCH_INTERNAL_ASSERT(false); TORCH_INTERNAL_ASSERT(value_.defined()); set_storage_access_should_throw(); - set_has_contiguity_policy(HasContiguityPolicy::CustomBehavior); + set_sizes_strides_policy(SizesStridesPolicy::CustomStrides); checkInvariants(); const auto public_dims = value_.dim() - 1; @@ -57,7 +57,7 @@ BatchedTensorImpl::BatchedTensorImpl(DispatchKeySet key_set, Tensor value, int64 { TORCH_INTERNAL_ASSERT(value_.defined()); set_storage_access_should_throw(); - set_has_contiguity_policy(HasContiguityPolicy::CustomBehavior); + set_sizes_strides_policy(SizesStridesPolicy::CustomStrides); checkInvariants(); refreshTensorMetadata(); } @@ -119,6 +119,13 @@ void BatchedTensorImpl::checkInvariants() const { } // The following are publically exposed as methods of Tensor + +IntArrayRef BatchedTensorImpl::strides_custom() const { + return strides_default(); +} + +// TODO: implement proper contiguity on batched tensor, then put +// sizes_strides_policy back to Default bool BatchedTensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const { TORCH_CHECK(memory_format == MemoryFormat::Contiguous, "NYI: querying is_contiguous inside of vmap for memory_format ", diff --git a/functorch/csrc/BatchedTensorImpl.h b/functorch/csrc/BatchedTensorImpl.h index c5c81e56f..45999e0db 100644 --- a/functorch/csrc/BatchedTensorImpl.h +++ b/functorch/csrc/BatchedTensorImpl.h @@ -66,6 +66,8 @@ struct BatchedTensorImpl : public c10::TensorImpl { // bt.actualDim(3) -> Error int64_t actualDim(int64_t dim, bool wrap_dim = true) const; + // We have to override this because we opted into CustomStrides + IntArrayRef strides_custom() const override; // Override a bunch of methods inherited from TensorImpl to return error messages. bool is_contiguous_custom(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const override; void set_size(int64_t dim, int64_t new_size) override;