From 96f0fa104ca5c13f75742c39c7861fff16172ab4 Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 1 Sep 2022 11:18:28 -0700 Subject: [PATCH] Fixing back invariant on offsets --- aten/src/ATen/NestedTensorImpl.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/aten/src/ATen/NestedTensorImpl.cpp b/aten/src/ATen/NestedTensorImpl.cpp index 4cfac41b6e1b..855a85f6ba0b 100644 --- a/aten/src/ATen/NestedTensorImpl.cpp +++ b/aten/src/ATen/NestedTensorImpl.cpp @@ -23,9 +23,8 @@ inline void validate_nested_tensor_metadata( TORCH_INTERNAL_ASSERT(nested_sizes.sizes() == nested_strides.sizes()); TORCH_INTERNAL_ASSERT( (size_dim == 0 && (int64_t)offsets.empty()) || - (size_dim == 2 && nested_sizes.size(0) + 1 == (int64_t)offsets.size())); + (size_dim == 2 && nested_sizes.size(0) == (int64_t)offsets.size())); } - } // namespace namespace at { namespace native { @@ -93,9 +92,8 @@ inline at::Tensor construct_nested_stride_tensor(const at::Tensor& sizes) { * * This function iterates over the implicit ntensor outer dimension * populating a vector with the num_elements in each implicit tensor. - * The first element is always 0 and the length of the returned vector - * is n_tensor + 1. - * num_elements in ntensor[i] = offsets[i+1] - offsets[i] + * The first element is always 0 and the length of the returned vector + * is n_tensor. * * @return A vector of offsets */ @@ -105,7 +103,7 @@ inline std::vector construct_offsets(const at::Tensor& sizes) { return std::vector(); } int64_t ntensors = sizes.size(0), orig_dim = sizes.size(1); - std::vector offsets(ntensors + 1); + std::vector offsets(ntensors); // nesting scalars has easy offsets if (orig_dim == 0) { std::iota(offsets.begin(), offsets.end(), 0); @@ -113,7 +111,7 @@ inline std::vector construct_offsets(const at::Tensor& sizes) { } const int64_t* sizes_ptr = sizes.data_ptr(); offsets[0] = 0; - for (const auto i : c10::irange(ntensors)) { + for (const auto i : c10::irange(ntensors - 1)) { const int64_t row_product = std::accumulate(sizes_ptr, sizes_ptr + orig_dim, 1, std::multiplies()); offsets[i + 1] = offsets[i] + row_product; sizes_ptr += orig_dim;