min/max support for SymInt/Floats, finish as_strided/scatter/squeeze() backward symint support #86643

Closed · wants to merge 5 commits · Changes from 2 commits
7 changes: 3 additions & 4 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -16,7 +16,6 @@
 
 #include <c10/cuda/CUDAMathCompat.h>
 
-#include <ATen/native/NonSymbolicBC.h>
 #include <ATen/native/nested/NestedTensorUtils.h>
 #include <ATen/native/nested/NestedTensorTransformerFunctions.h>
 
@@ -368,8 +367,8 @@ __global__ void transform_bias_rescale_qkv_add_padding_kernel(
 }
 
 Tensor collapse_dims_1_and_2(const Tensor& sizes) {
-  auto sizes_dim1 = at::native::narrow(sizes, 1, 0, 1);
-  auto sizes_dim2 = at::native::narrow(sizes, 1, 1, 1);
+  auto sizes_dim1 = at::native::narrow_symint(sizes, 1, 0, 1);
+  auto sizes_dim2 = at::native::narrow_symint(sizes, 1, 1, 1);
 
   return (sizes_dim1 * sizes_dim2).contiguous();
 }
@@ -451,7 +450,7 @@ __host__ std::tuple<Tensor, Tensor, Tensor> transform_bias_rescale_qkv_cuda(
   auto sizes = collapse_dims_1_and_2(nt_qkv->get_nested_size_tensor());
   auto offsets =
       NestedTensor_batch_offsets_from_size_tensor(sizes, sizes.numel());
-  at::native::narrow(offsets, 0, sizes.numel() + 1, sizes.numel())
+  at::native::narrow_symint(offsets, 0, sizes.numel() + 1, sizes.numel())
       .copy_(sizes.reshape({-1}));
   auto metadata = offsets.to(at::Device(kCUDA), at::kInt, true, true);
   const auto offsets_ptr = metadata.data_ptr<int>();
15 changes: 15 additions & 0 deletions c10/core/SymInt.cpp
@@ -136,6 +136,21 @@ bool SymInt::operator>=(SymInt sci) const {
   return res[0]->ge(res[1])->bool_();
 }
 
+SymInt SymInt::min(SymInt sci) const {
+  if (!is_symbolic() && !sci.is_symbolic()) {
+    return std::min(data_, sci.data_);
+  }
+  auto res = normalize_symints(*this, sci);
+  return SymInt::toSymInt(res[0]->min(res[1]));
+}
+SymInt SymInt::max(SymInt sci) const {
+  if (!is_symbolic() && !sci.is_symbolic()) {
+    return std::max(data_, sci.data_);
+  }
+  auto res = normalize_symints(*this, sci);
+  return SymInt::toSymInt(res[0]->max(res[1]));
+}
+
 void SymInt::operator*=(SymInt sci) {
   *this = *this * sci;
 }
3 changes: 3 additions & 0 deletions c10/core/SymInt.h
@@ -170,6 +170,9 @@ class C10_API SymInt {
   void operator*=(SymInt sci);
   void operator+=(SymInt sci);
 
+  SymInt min(SymInt sci) const;
+  SymInt max(SymInt sci) const;
+
   SymInt operator*(int64_t sci) const;
   bool operator<(int64_t sci) const;
   bool operator==(int64_t sci) const;
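To make the new API concrete, the sketch below mirrors the intended semantics of the SymInt::min/SymInt::max added above: constant-fold when both operands are plain integers, otherwise normalize to nodes and defer to SymIntNode. The FakeSymInt class is purely illustrative and is not the real c10 type.

# Illustrative only: a toy stand-in for c10::SymInt, not the real API.
class FakeSymInt:
    def __init__(self, value=None, node=None):
        self.value = value  # concrete int when node is None
        self.node = node    # stand-in for a SymIntNode expression

    def is_symbolic(self):
        return self.node is not None

    def min(self, other):
        if not self.is_symbolic() and not other.is_symbolic():
            # fast path, mirrors std::min(data_, sci.data_)
            return FakeSymInt(value=min(self.value, other.value))
        # symbolic path, mirrors normalize_symints + SymIntNode::min
        return FakeSymInt(node=("min", self.node or self.value, other.node or other.value))

print(FakeSymInt(value=3).min(FakeSymInt(value=5)).value)   # 3
print(FakeSymInt(node="s0").min(FakeSymInt(value=5)).node)  # ('min', 's0', 5)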
6 changes: 6 additions & 0 deletions c10/core/SymIntNodeImpl.h
@@ -63,6 +63,12 @@ class C10_API SymIntNodeImpl : public c10::intrusive_ptr_target {
   virtual SymIntNode ceil() {
     TORCH_CHECK(false, "NYI");
   };
+  virtual SymIntNode min(const SymIntNode& other) {
+    TORCH_CHECK(false, "NYI");
+  };
+  virtual SymIntNode max(const SymIntNode& other) {
+    TORCH_CHECK(false, "NYI");
+  };
   virtual SymIntNode clone() {
     TORCH_CHECK(false, "NYI");
   };
8 changes: 4 additions & 4 deletions tools/autograd/derivatives.yaml
@@ -1493,19 +1493,19 @@
   result: auto_element_wise
 
 - name: squeeze(Tensor(a) self) -> Tensor(a)
-  self: unsqueeze_to(grad, self.sizes())
+  self: unsqueeze_to(grad, self.sym_sizes())
   result: auto_linear
 
 - name: squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)
-  self: unsqueeze_to(grad, dim, self.sizes())
+  self: unsqueeze_to(grad, dim, self.sym_sizes())
   result: auto_linear
 
 - name: squeeze_(Tensor(a!) self) -> Tensor(a!)
-  self: unsqueeze_to(grad, self.sizes())
+  self: unsqueeze_to(grad, self.sym_sizes())
   result: auto_linear
 
 - name: squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)
-  self: unsqueeze_to(grad, dim, self.sizes())
+  self: unsqueeze_to(grad, dim, self.sym_sizes())
   result: auto_linear
 
 - name: std.correction(Tensor self, int[1]? dim, *, int? correction, bool keepdim=False) -> Tensor
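For context, these formulas say that the backward of squeeze re-inserts the squeezed size-1 dimensions into the incoming gradient; switching self.sizes() to self.sym_sizes() lets unsqueeze_to do that when the shape is symbolic. A quick eager-mode check of the concrete behavior (plain PyTorch, no symbolic shapes involved):

import torch

x = torch.randn(2, 1, 3, requires_grad=True)
y = x.squeeze(1)        # shape (2, 3)
y.sum().backward()
print(x.grad.shape)     # torch.Size([2, 1, 3]): the backward unsqueezed dim 1 back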
1 change: 1 addition & 0 deletions torch/_subclasses/fake_tensor.py
@@ -885,6 +885,7 @@ def wrap(e, device=None):
     def functions_with_cpp_meta_impl_that_support_symint(self):
         return [
             aten.empty_strided.default,
+            aten.as_strided_scatter.default,
             aten.as_strided.default,
             aten.zeros.default,
             aten.detach.default,
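The entry added above lets FakeTensor route aten.as_strided_scatter through its C++ meta implementation when sizes and strides are symbolic. For reference, the op itself in concrete eager mode (assuming a PyTorch build that exposes torch.as_strided_scatter):

import torch

base = torch.zeros(4, 4)
src = torch.ones(2, 2)
# scatter src into the strided view of base described by size/stride/offset
out = torch.as_strided_scatter(base, src, size=(2, 2), stride=(4, 1), storage_offset=0)
print(out[:2, :2])      # the scattered 2x2 block of ones; the rest of out stays zero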
35 changes: 22 additions & 13 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -848,23 +848,26 @@ Tensor unbind_backward(const variable_list& grads, int64_t dim) {
   return at::stack(grads_tensors, dim);
 }
 
-Tensor unsqueeze_to(const Tensor& self, IntArrayRef sizes) {
+Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
   auto result = self;
 
-  int64_t nDims = sizes.size();
+  int64_t nDims = sym_sizes.size();
   for (const auto dim : c10::irange(nDims)) {
-    if (sizes[dim] == 1) {
+    if (sym_sizes[dim] == 1) {
       result = result.unsqueeze(dim);
     }
   }
   return result;
 }
 
-Tensor unsqueeze_to(const Tensor& self, int64_t dim, IntArrayRef sizes) {
-  dim = at::maybe_wrap_dim(dim, sizes.size());
+Tensor unsqueeze_to(
+    const Tensor& self,
+    int64_t dim,
+    c10::SymIntArrayRef sym_sizes) {
+  dim = at::maybe_wrap_dim(dim, sym_sizes.size());
   // in NumPy it's not an error to unsqueeze a scalar, but we still need to
   // avoided unsqueezing in the backward.
-  if (sizes.size() > 0 && sizes[dim] == 1) {
+  if (sym_sizes.size() > 0 && sym_sizes[dim] == 1) {
     return self.unsqueeze(dim);
   }
   return self;
@@ -2836,21 +2839,27 @@ Tensor as_strided_backward(
 
   // Step (1): create underlying tensor as "storage"
   auto shared_offset =
-      std::min(input_geometry.sym_storage_offset(), sym_storage_offset);
+      // TODO: symint-ify. Do we need a min() and max() for SymInts?
+      input_geometry.sym_storage_offset().min(sym_storage_offset);
   auto inp_effective_offset =
       input_geometry.sym_storage_offset() - shared_offset;
   auto out_effective_offset = sym_storage_offset - shared_offset;
-  auto base_size = std::max(
-      _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset),
-      _min_storage_size(out_sizes_, out_strides_, out_effective_offset));
-  auto storage = grad.new_empty_symint(c10::SymIntArrayRef(base_size));
-  storage.zero_();
+  auto base_size1 =
+      _min_storage_size(inp_sizes_, inp_strides_, inp_effective_offset);
+  auto base_size2 =
+      _min_storage_size(out_sizes_, out_strides_, out_effective_offset);
+  auto base_size = base_size1.max(base_size2);
+  auto storage = grad.new_zeros_symint(c10::SymIntArrayRef(base_size));
 
   // prepare indices tensor if we will do index_add_ later
   c10::optional<at::Tensor> flatten_full_indices;
   if (inp_maybe_overlap || out_maybe_overlap) {
     flatten_full_indices =
-        at::arange(0, base_size, grad.options().dtype(at::kLong));
+        // TODO: should we symint-ify arange? Need SymScalar.
+        at::arange(
+            0,
+            base_size.guard_int(__FILE__, __LINE__),
+            grad.options().dtype(at::kLong));
   }
 
   // Step (2): use output geometry to scatter gradients into storage
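The reshuffling above is needed because std::min/std::max cannot be applied to SymInt; the shared storage offset and the base storage size are now computed with the new SymInt::min/SymInt::max. The sketch below shows roughly what _min_storage_size contributes to that max; it follows the usual "offset of the largest addressable element plus one" rule and is illustrative, not the exact C++ helper:

# Illustrative sketch: smallest storage (in elements) that a view with the
# given sizes/strides/storage_offset can touch. as_strided_backward takes the
# max of this quantity over the input and output geometries.
def min_storage_size(sizes, strides, storage_offset):
    storage_size = storage_offset + 1
    for size, stride in zip(sizes, strides):
        if size == 0:
            return storage_offset
        storage_size += (size - 1) * stride
    return storage_size

inp = min_storage_size([2, 3], [3, 1], 0)   # 6
out = min_storage_size([3, 2], [1, 3], 1)   # 1 + 1 + 2 + 3 = 7
print(max(inp, out))                        # 7: size of the shared "storage" buffer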
4 changes: 2 additions & 2 deletions torch/csrc/autograd/FunctionsManual.h
@@ -215,11 +215,11 @@ at::Tensor logcumsumexp_backward(
     at::Tensor result,
     int64_t dim);
 at::Tensor unbind_backward(const variable_list& grads, int64_t dim);
-at::Tensor unsqueeze_to(const at::Tensor& self, at::IntArrayRef sizes);
+at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
 at::Tensor unsqueeze_to(
     const at::Tensor& self,
     int64_t dim,
-    at::IntArrayRef sizes);
+    c10::SymIntArrayRef sym_sizes);
 std::vector<at::Tensor> cat_tensors_backward(
     const at::Tensor& grad,
     const std::vector<std::vector<c10::SymInt>>& sizes,
19 changes: 19 additions & 0 deletions torch/csrc/jit/python/init.cpp
@@ -242,6 +242,13 @@ class PythonSymIntNodeImpl : public c10::SymIntNodeImpl {
     return dispatch_common_(__FUNCTION__, other);
   }
 
+  virtual SymIntNode min(const SymIntNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+  virtual SymIntNode max(const SymIntNode& other) override {
+    return dispatch_common_(__FUNCTION__, other);
+  }
+
   virtual SymIntNode ceil() override {
     return dispatch_common_(__FUNCTION__);
   }
@@ -1481,6 +1488,18 @@ void initJITBindings(PyObject* module) {
       .def(
          "__ceil__",
          [](c10::SymIntNode a) -> c10::SymIntNode { return a->ceil(); })
+      .def(
+          "__min__",
+          [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
+            auto snb = toSymIntNode(a, b);
+            return a->min(snb);
+          })
+      .def(
+          "__max__",
+          [](c10::SymIntNode a, py::object b) -> c10::SymIntNode {
+            auto snb = toSymIntNode(a, b);
+            return a->max(snb);
+          })
       .def("__bool__", [](c10::SymIntNode a) { return a->bool_(); })
       .def("__int__", [](c10::SymIntNode a) { return a->int_(); })
       // Intentionally don't set file line, as the Python backtrace matters
20 changes: 14 additions & 6 deletions torch/fx/experimental/symbolic_shapes.py
@@ -205,7 +205,7 @@ def eval(cls, a):
     'mul': lambda a, b: a * b,
     'mod': lambda a, b: a % b,
     'truediv': lambda a, b: a / b,
-    'floordiv': lambda a, b: FloorDiv(a, b)
+    'floordiv': lambda a, b: FloorDiv(a, b),
 }
 
 magic_methods = {
@@ -215,7 +215,9 @@ def eval(cls, a):
     'lt': lambda a, b: sympy.Lt(a, b),
     'le': lambda a, b: sympy.Le(a, b),
     'ge': lambda a, b: sympy.Ge(a, b),
-    'ceil': lambda a: Ceil(a)
+    'ceil': lambda a: Ceil(a),
+    'min': lambda a, b: sympy.Min(a, b),
+    'max': lambda a, b: sympy.Max(a, b),
 }
 
 unary_magic_methods = {
@@ -228,14 +230,19 @@ def _make_magic(method, func, py_type):
     func = lru_cache(256)(func)
 
     def magic_impl(self, other):
+        if method in ["min", "max"]:
+            # op = getattr(builtins, method)
+            return self
+        else:
+            op = getattr(operator, method)
         if SYM_FUNCTION_MODE:
-            return _handle_sym_dispatch(getattr(operator, method), (self, other), {})
+            return _handle_sym_dispatch(op, (self, other), {})
         if isinstance(other, py_type):
-            other = other.expr
+            other_expr = other.expr
         # TODO: consider constant prop here
         expr = self.shape_env.replace(self.expr)
-        other = self.shape_env.replace(other)
-        out = func(expr, other)
+        other_expr = self.shape_env.replace(other_expr)
+        out = func(expr, other_expr)
         out = sympy.expand(out)
         if method in ["truediv"]:
             return PySymFloat(out, self.shape_env)
@@ -246,6 +253,7 @@ def magic_impl(self, other):
 
     def unary_magic_impl(self):
         if SYM_FUNCTION_MODE:
+            # TODO: Should this if/else be moved outside of SYM_FUNCTION_MODE ?
             if method in ["ceil", "floor"]:
                 op = getattr(math, method)
             else:
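The new 'min'/'max' entries in magic_methods lower to sympy.Min/sympy.Max, which constant-fold when both arguments are concrete and stay symbolic otherwise. A standalone sympy illustration (sympy only, no torch):

import sympy

s0 = sympy.Symbol("s0", positive=True, integer=True)
print(sympy.Min(3, 7))               # 3: both sides concrete, folds to a constant
print(sympy.Max(s0, 5))              # Max(5, s0): stays symbolic until s0 is known
print(sympy.Max(s0, 5).subs(s0, 9))  # 9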