
Commit

Update base for Update on "Make CI error on inductor fallback when decomp is available"


Fixes #99446 

Remove the warning, as that annoyed end-users who don't know what to do about it.

Instead, try to hold the line by preventing any decomp from being added without making
the corresponding change to inductor's fallbacks.

Note: we probably still need to better document how to update inductor's decomps; for now it's pretty much "go ask the inductor team for advice".
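
For illustration only (not part of this commit), a minimal sketch of the kind of CI guard this implies, assuming a hypothetical `fallbacks` set exposed by inductor's lowering module:

# Hypothetical CI guard (illustrative names): fail if an op that has a
# registered decomposition is still handled via an inductor fallback.
from torch._decomp import decomposition_table       # real: maps ops -> decomps
from torch._inductor.lowering import fallbacks      # assumed: set of fallback ops

def test_no_fallback_when_decomp_exists():
    overlap = sorted(str(op) for op in fallbacks if op in decomposition_table)
    assert not overlap, (
        "These ops have decompositions but inductor still falls back on them: "
        f"{overlap}. Either drop the decomp or add the matching inductor lowering."
    )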

cc soumith voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 desertfire

[ghstack-poisoned]
wconstab committed Apr 20, 2023
2 parents cfc7e82 + 4721553 commit 6bb2149
Showing 59 changed files with 2,721 additions and 542 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/common/install_onnx.sh
@@ -27,7 +27,7 @@ pip_install \
transformers==4.25.1

# TODO: change this when onnx-script is on testPypi
pip_install "onnx-script@git+https://github.com/microsoft/onnx-script@1e8d764a9be04323d7171e4d5f511332790cb809"
pip_install "onnx-script@git+https://github.com/microsoft/onnxscript@1e8d764a9be04323d7171e4d5f511332790cb809"

# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
2 changes: 1 addition & 1 deletion aten/src/ATen/core/Generator.h
@@ -1,12 +1,12 @@
#pragma once

#include <stdint.h>
#include <mutex>
#include <deque>
#include <atomic>
#include <typeinfo>
#include <utility>
#include <cstddef>
#include <cstdint>

#include <c10/util/Exception.h>
#include <c10/util/C++17.h>
2 changes: 2 additions & 0 deletions aten/src/ATen/functorch/BatchRulesDecompositions.cpp
@@ -178,8 +178,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(matrix_H);
OP_DECOMPOSE(matrix_power);
OP_DECOMPOSE2(max, other );
OP_DECOMPOSE(max_pool1d);
OP_DECOMPOSE(max_pool1d_with_indices);
OP_DECOMPOSE(max_pool2d);
OP_DECOMPOSE(max_pool3d);
OP_DECOMPOSE(meshgrid);
OP_DECOMPOSE2(meshgrid, indexing);
OP_DECOMPOSE(mH);
41 changes: 30 additions & 11 deletions aten/src/ATen/functorch/BatchRulesPooling.cpp
@@ -11,30 +11,48 @@

namespace at { namespace functorch {

template <typename Func>
std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
max_pool2d_with_indices_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
IntArrayRef kernel_size, IntArrayRef stride,
IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
max_pool_with_indices_batch_rule_helper(
const Tensor& self, optional<int64_t> self_bdim,
IntArrayRef kernel_size, IntArrayRef stride,
IntArrayRef padding, IntArrayRef dilation, bool ceil_mode, int64_t n, Func pooling_fn) {

auto logical_rank = rankWithoutBatchDim(self, self_bdim);
TORCH_INTERNAL_ASSERT(logical_rank == 3 || logical_rank == 4);
// Tensor[B, C, H, W] -> just call max_pool2d
if (logical_rank == 3) {
TORCH_INTERNAL_ASSERT(logical_rank == n + 1 || logical_rank == n + 2);
// Tensor[B, logical_rank...] -> just call max_poolnd
if (logical_rank == n + 1) {
auto self_ = moveBatchDimToFront(self, self_bdim);
auto result = at::max_pool2d_with_indices(
auto result = pooling_fn(
self_, kernel_size, stride, padding, dilation, ceil_mode);
return std::make_tuple(std::move(std::get<0>(result)), 0, std::move(std::get<1>(result)), 0);
}
// Tensor[B, N, C, H, W] -> Tensor[B * N, C, H, W]
// Tensor[B, N, logical_rank...] -> Tensor[B * N, logical_rank...]
auto bdim_size = self.size(*self_bdim);
auto self_ = reshape_dim_into(*self_bdim, 0, self);
auto result = at::max_pool2d_with_indices(
auto result = pooling_fn(
self_, kernel_size, stride, padding, dilation, ceil_mode);
return std::make_tuple(
reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0,
reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0);
}

std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
max_pool3d_with_indices_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
IntArrayRef kernel_size, IntArrayRef stride,
IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
return max_pool_with_indices_batch_rule_helper(self, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, 3, at::max_pool3d_with_indices);
}

std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
max_pool2d_with_indices_batch_rule(
const Tensor& self, optional<int64_t> self_bdim,
IntArrayRef kernel_size, IntArrayRef stride,
IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
return max_pool_with_indices_batch_rule_helper(self, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, 2, at::max_pool2d_with_indices);
}

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
EXISTING_BDIM(_adaptive_avg_pool2d);
EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward);
@@ -48,9 +66,10 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
EXISTING_BDIM_ALL_BOXED(adaptive_max_pool3d);
ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, adaptive_max_pool2d_backward, 2);
ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(4, adaptive_max_pool3d_backward, 2);

VMAP_SUPPORT(max_pool2d_with_indices, max_pool2d_with_indices_batch_rule);
VMAP_SUPPORT(max_pool3d_with_indices, max_pool3d_with_indices_batch_rule);
ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(3, max_pool2d_with_indices_backward, 2);
ALL_TENSORS_HAVE_OPTIONAL_BDIM_BOXED_CONTIG1(4, max_pool3d_with_indices_backward, 2);
}

}}
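
Usage illustration (not part of the diff): with the batch rule above registered, vmap over 3-D max pooling works directly; a minimal sketch using standard torch APIs, with shapes chosen for illustration:

import torch
import torch.nn.functional as F

x = torch.randn(4, 2, 8, 8, 8)  # [B, C, D, H, W]; each sample is pooled independently
pooled = torch.vmap(lambda t: F.max_pool3d(t, kernel_size=2))(x)
print(pooled.shape)             # expected: torch.Size([4, 2, 4, 4, 4])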
9 changes: 5 additions & 4 deletions aten/src/ATen/functorch/BatchRulesReduceOps.cpp
@@ -327,6 +327,7 @@ std::tuple<Tensor,optional<int64_t>> searchsorted_batch_rule(
const c10::optional<Tensor>& sorter,
c10::optional<int64_t> sorter_bdim) {
auto buckets_logical_rank = rankWithoutBatchDim(sorted_sequence, sorted_sequence_bdim);
auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);

// Preprocess sorter and sorted_sequence.
// If they both exist, and only one has a bdim, then we need to make sure both do.
@@ -382,18 +383,18 @@
// BD, B* -> BD, B flat(*)
if (buckets_bdim.has_value() && self_bdim.has_value()) {
auto self_ = moveBatchDimToFront(self, self_bdim);
self_ = self_.flatten(1);
self_ = self_logical_rank == 0 ? self_.unsqueeze(-1) : self_.flatten(1);
auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_);
result = result.view(self_.sizes());
result = result.view(self_logical_rank == 0 ? IntArrayRef(self_.sizes().begin(), self_.sizes().end() - 1) : self_.sizes());
return std::make_tuple(std::move(result), 0);
}
// BD, * -> BD, flat(*) -> BD, B flat(*)
if (buckets_bdim.has_value() && !self_bdim.has_value()) {
auto bdim_size = buckets.size(*buckets_bdim);
auto self_ = ensure_has_bdim(self, false, bdim_size);
self_ = self_.flatten(1);
self_ = self_logical_rank == 0 ? self_.unsqueeze(-1) : self_.flatten(1);
auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_);
result = result.view(self_.sizes());
result = result.view(self_logical_rank == 0 ? IntArrayRef(self_.sizes().begin(), self_.sizes().end() - 1) : self_.sizes());
return std::make_tuple(std::move(result), 0);
}
// D, B* -> no change
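
Usage illustration (not part of the diff): the new `self_logical_rank == 0` branches above cover the case where each vmapped call sees a scalar query; a minimal sketch, with values chosen for illustration:

import torch

sorted_seq = torch.tensor([1., 3., 5., 7.])
queries = torch.tensor([2., 6., 4.])  # one 0-dim query per batch element
buckets = torch.vmap(lambda q: torch.searchsorted(sorted_seq, q))(queries)
print(buckets)                        # expected: tensor([1, 3, 2])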
97 changes: 58 additions & 39 deletions aten/src/ATen/functorch/BatchRulesScatterOps.cpp
@@ -1018,6 +1018,42 @@ std::tuple<Tensor,optional<int64_t>> masked_fill_scalar_batch_rule(
return std::make_tuple(result, 0);
}

std::tuple<Tensor,optional<int64_t>> index_fill_batch_rule_helper(
int64_t batch_size,
int64_t self_logical_rank,
int64_t index_logical_rank,
Tensor & self_,
int64_t dim,
Tensor & index_,
const Scalar & value
){
if (self_logical_rank != 0){
auto index_offset = at::arange(
batch_size,
at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
);
if (index_logical_rank == 0){
index_ = index_.unsqueeze(-1);
}
index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
index_ = reshape_dim_into(0, 0, index_);
self_ = reshape_dim_into(0, dim, self_);
self_.index_fill_(dim, index_, value);
self_ = reshape_dim_outof(dim, batch_size, self_);
return std::make_tuple(self_, dim);
}

// If self_logical_rank == 0, the batch dim is certainly 0, and we must apply batched indices to each row.
if (index_logical_rank != 0){
index_ = reshape_dim_into(0, 0, index_);
}
self_.unsqueeze_(-1);
self_.index_fill_(dim + 1, index_, value);
self_.squeeze_(-1);

return std::make_tuple(self_, 0);
}

std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(
Tensor & self, optional<int64_t> self_bdim,
int64_t dim,
@@ -1051,7 +1087,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(

if (inplace) {
// Do for-loop for in-place because we cannot reshape
// `self_` having an incompatible stride without copying
// `self_` having an incompatible stride without copying.
for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_.select(0, i);
const auto& index_slice = index_.select(0, i);
@@ -1066,31 +1102,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule_impl(

self_ = self_bdim.has_value() ? self_ : self_.clone();

if (self_logical_rank != 0){
auto index_offset = at::arange(
batch_size,
at::TensorOptions().dtype(index_.scalar_type()).device(index_.device())
);
if (index_logical_rank == 0){
index_ = index_.unsqueeze(-1);
}
index_ = index_.add(index_offset.unsqueeze(-1), self_.size(dim + 1));
index_ = reshape_dim_into(0, 0, index_);
self_ = reshape_dim_into(0, dim, self_);
self_.index_fill_(dim, index_, value);
self_ = reshape_dim_outof(dim, batch_size, self_);
return std::make_tuple(self_, dim);
}

// If self_logical_rank == 0, the batch dim is certainly 0, and we must apply batched indices to each row.
if (index_logical_rank != 0){
index_ = reshape_dim_into(0, 0, index_);
}
self_.unsqueeze_(-1);
self_.index_fill_(dim + 1, index_, value);
self_.squeeze_(-1);

return std::make_tuple(self_, 0);
return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value);
}

std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
@@ -1100,6 +1112,7 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
const Tensor & value, optional<int64_t> value_bdim,
const bool inplace) {
const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
const auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
Tensor self_ = moveBatchDimToFront(self, self_bdim);
Tensor index_ = moveBatchDimToFront(index, index_bdim);
Tensor value_ = moveBatchDimToFront(value, value_bdim);
@@ -1123,22 +1136,28 @@ std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule_impl(
auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
value_ = ensure_has_bdim(value_, value_bdim.has_value(), batch_size);

self_ = self_bdim.has_value() ? self_ : self_.clone();

for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_.select(0, i);
const auto& index_slice = index_.select(0, i);
const auto& value_slice = value_.select(0, i);
self_slice.index_fill_(
dim,
index_slice,
value_slice
);
if (inplace || value_bdim.has_value()) {
// Do for-loop for in-place because we cannot reshape
// `self_` having an incompatible stride without copying.
// If value has a batch dim, we do for-loop as well because
// index_fill_ supports 1-element tensor only.
for (const auto i : c10::irange(0, batch_size)) {
const auto& self_slice = self_.select(0, i);
const auto& index_slice = index_.select(0, i);
self_slice.index_fill_(
dim,
index_slice,
value_bdim.has_value() ? value_.select(0, i) : value_
);
}
return std::make_tuple(self_, 0);
}

return std::make_tuple(self_, 0);
self_ = self_bdim.has_value() ? self_ : self_.clone();

// calling .item() on value is safe here because value is guaranteed to not be a batched tensor.
return index_fill_batch_rule_helper(batch_size, self_logical_rank, index_logical_rank, self_, dim, index_, value.item());
}
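
Usage illustration (not part of the diff): the batched-`value` loop above corresponds to calls where each sample supplies its own fill value; a minimal sketch, with shapes chosen for illustration:

import torch

x = torch.zeros(3, 4)
idx = torch.tensor([[0], [2], [1]])  # one index per sample
val = torch.tensor([1., 2., 3.])     # one 0-dim fill value per sample
out = torch.vmap(lambda row, i, v: row.index_fill(0, i, v))(x, idx, val)
# expected:
# tensor([[1., 0., 0., 0.],
#         [0., 0., 2., 0.],
#         [0., 3., 0., 0.]])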

void index_fill__int_scalar_batch_rule(
11 changes: 11 additions & 0 deletions aten/src/ATen/functorch/BatchRulesUnaryOps.cpp
@@ -167,6 +167,17 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
UNARY_POINTWISE(special_i1);
UNARY_POINTWISE(special_i1e);
UNARY_POINTWISE(special_ndtri);
POINTWISE_BOXED(special_bessel_j0);
POINTWISE_BOXED(special_spherical_bessel_j0);
POINTWISE_BOXED(special_bessel_j1);
POINTWISE_BOXED(special_modified_bessel_i0);
POINTWISE_BOXED(special_modified_bessel_i1);
POINTWISE_BOXED(special_scaled_modified_bessel_k0);
POINTWISE_BOXED(special_modified_bessel_k0);
POINTWISE_BOXED(special_scaled_modified_bessel_k1);
POINTWISE_BOXED(special_modified_bessel_k1);
POINTWISE_BOXED(special_bessel_y0);
POINTWISE_BOXED(special_bessel_y1);

// Activation functions (from https://pytorch.org/docs/stable/nn.html#non-linear-activations-weighted-sum-nonlinearity)
UNARY_POINTWISE_ALL(elu);
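
Usage illustration (not part of the diff): registering these special functions as pointwise ops gives them vmap support; a minimal sketch:

import torch

x = torch.linspace(0.1, 5.0, steps=8).reshape(2, 4)
y = torch.vmap(torch.special.bessel_j0)(x)           # batch dim passes straight through
assert torch.allclose(y, torch.special.bessel_j0(x))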
