Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[vmap] symintify alias and squeeze #107577

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions aten/src/ATen/functorch/BatchRulesViews.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ std::tuple<Tensor, optional<int64_t>> squeeze_batch_rule(const Tensor& self, opt
// Manually calculate the output shape by eliding all dimensions of
// size 1 keeping track of where the batch index started and where it
// ended up moving to. We also ensure we do not drop the batch index.
auto shape = self.sizes();
DimVector squeezed_sizes;
auto shape = self.sym_sizes();
SymDimVector squeezed_sizes;
bool before_batch_idx = true;
int64_t new_batch_idx = 0;
int64_t original_idx = 0;
Expand All @@ -219,7 +219,7 @@ std::tuple<Tensor, optional<int64_t>> squeeze_batch_rule(const Tensor& self, opt
++original_idx;
}

auto result = self.view(squeezed_sizes);
auto result = self.view_symint(squeezed_sizes);
return std::make_tuple(std::move(result), c10::optional<int64_t>(new_batch_idx));
}

Expand Down Expand Up @@ -453,7 +453,7 @@ std::tuple<Tensor, optional<int64_t>> expand_batch_rule(
"must be greater or equal to the number of dimensions in the tensor (", static_cast<uint64_t>(self_dim - 1), ")");

auto self_ = moveBatchDimToFront(self, self_bdim);
auto self_sizes = self_.sizes();
auto self_sizes = self_.sym_sizes();
auto batch_size = self_sizes[0];

c10::SmallVector<c10::SymInt> size_(size.size() + 1);
Expand Down
25 changes: 24 additions & 1 deletion aten/src/ATen/native/TensorShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1602,6 +1602,29 @@ Tensor alias_with_sizes_and_strides(
return self_;
}

// specialization for symbolic shapes and strides.
// SymIntArrayRef/ArrayRef<c10::SymInt> and SmallVector<c10::SymInt>/SymDimVector
template <template <typename...> typename Container>
Tensor alias_with_sizes_and_strides(
const Tensor& self,
const Container<c10::SymInt>& sizes,
const Container<c10::SymInt>& strides) {
//caller should make sure that sizes and strides are valid for self
//(storage is sufficient, strides are non-negative, strides and sizes array size is the same)
Tensor self_;
if (self.is_quantized()) {
self_ = at::detail::make_tensor<QTensorImpl>(
c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype(), get_qtensorimpl(self)->quantizer());
self_.unsafeGetTensorImpl()->set_sizes_and_strides(sizes, strides, self.sym_storage_offset());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason why we can't change the ArrayRef<int64_t> and SmallVector<int64_t> template above to use this line?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just forgot about that! Thanks!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, that won't work because there are two overloads for set_sizes_and_strides

  1. void set_sizes_and_strides(SymIntArrayRef sizes, SymIntArrayRef strides, optional<c10::SymInt> storage_offset)
  2. void set_sizes_and_strides(IntArrayRef new_size, IntArrayRef new_stride, optional<int64_t> storage_offset)

So, all inputs have to be symbolic or concrete, we can't pass self.sym_storage_offset() to second overload.

} else {
self_ = at::detail::make_tensor<TensorImpl>(
c10::TensorImpl::VIEW, Storage(self.storage()), self.key_set(), self.dtype());
self_.unsafeGetTensorImpl()->set_sizes_and_strides(sizes, strides, self.sym_storage_offset());
}
namedinference::propagate_names(self_, self);
return self_;
}

Tensor reshape_symint(const Tensor& self, c10::SymIntArrayRef proposed_shape) {
if (self.is_sparse()) {
AT_ERROR("reshape is not implemented for sparse tensors");
Expand Down Expand Up @@ -3687,7 +3710,7 @@ Tensor view(const Tensor& self,
}

Tensor alias(const Tensor& self) {
return alias_with_sizes_and_strides(self, self.sizes(), self.strides());
return alias_with_sizes_and_strides(self, self.sym_sizes(), self.sym_strides());
}

Tensor detach(const Tensor& self) {
Expand Down