From c9a19aeef438f00d998d0cb9dc493e00c3e3708b Mon Sep 17 00:00:00 2001
From: "Li-Huai (Allan) Lin"
Date: Tue, 18 Jul 2023 16:25:37 +0800
Subject: [PATCH] [VMAP] Add linspace and logspace batch rules

[ghstack-poisoned]
---
 aten/src/ATen/functorch/BatchRulesFactory.cpp | 58 +++++++++++++++++++
 aten/src/ATen/native/native_functions.yaml    |  4 +-
 test/functorch/test_vmap.py                   |  2 -
 test/test_decomp.py                           | 10 ++--
 4 files changed, 65 insertions(+), 9 deletions(-)

diff --git a/aten/src/ATen/functorch/BatchRulesFactory.cpp b/aten/src/ATen/functorch/BatchRulesFactory.cpp
index bde6842342b2d..c5159691f14e5 100644
--- a/aten/src/ATen/functorch/BatchRulesFactory.cpp
+++ b/aten/src/ATen/functorch/BatchRulesFactory.cpp
@@ -103,6 +103,62 @@ static std::tuple<Tensor, optional<int64_t>> _new_zeros_with_same_feature_meta_ba
   return std::make_tuple(result, 0);
 }
 
+static std::tuple<Tensor, optional<int64_t>> linspace_logspace_batch_rule_helper(
+    const at::Tensor& start, optional<int64_t> start_bdim,
+    const at::Tensor& end, optional<int64_t> end_bdim,
+    int64_t steps,
+    c10::optional<double> base,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory)
+{
+  auto batch_size = get_bdim_size2(start, start_bdim, end, end_bdim);
+  auto start_ = ensure_has_bdim(start, start_bdim.has_value(), batch_size);
+  auto end_ = ensure_has_bdim(end, end_bdim.has_value(), batch_size);
+  start_ = moveBatchDimToFront(start_, start_bdim);
+  end_ = moveBatchDimToFront(end_, end_bdim);
+
+  auto tensor_options = at::TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
+
+  Tensor result;
+  if (steps == 0){
+    result = at::full({batch_size, 0}, 0, tensor_options);
+  } else if (steps == 1){
+    result = at::empty({start_.size(0), 1}, tensor_options).copy_(start_.unsqueeze_(1));
+  } else {
+    result = (start_ + at::arange(0, steps, tensor_options).unsqueeze_(1) * (end_ - start_) / (steps - 1)).transpose(0, 1);
+  }
+
+  if (base){
+    result = at::pow(*base, result);
+  }
+  return std::make_tuple(result, 0);
+}
+
+static std::tuple<Tensor, optional<int64_t>> linspace_batch_rule(
+    const at::Tensor& start, optional<int64_t> start_bdim,
+    const at::Tensor& end, optional<int64_t> end_bdim,
+    int64_t steps,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory){
+  return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, c10::nullopt, dtype, layout, device, pin_memory);
+}
+
+static std::tuple<Tensor, optional<int64_t>> logspace_batch_rule(
+    const at::Tensor& start, optional<int64_t> start_bdim,
+    const at::Tensor& end, optional<int64_t> end_bdim,
+    int64_t steps,
+    double base,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory){
+  return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, c10::make_optional(base), dtype, layout, device, pin_memory);
+}
+
 static bool _has_same_storage_numel_batch_rule(const Tensor& a, const Tensor& b) {
   return true;
 }
@@ -119,6 +175,8 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT(new_zeros, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_zeros)));
   VMAP_SUPPORT(new_ones, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_ones)));
   VMAP_SUPPORT(new_full, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_full)));
+  VMAP_SUPPORT2(linspace, Tensor, linspace_batch_rule);
+  VMAP_SUPPORT2(logspace, Tensor, logspace_batch_rule);
   VMAP_SUPPORT(_new_zeros_with_same_feature_meta, _new_zeros_with_same_feature_meta_batch_rule);
   // Not sure how to add the ones with irregular args to the mix cleanly (i.e. randint takes an extra int parameter)
 }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 701f6a97cb04a..def2bfb279b4c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3263,7 +3263,7 @@
 - func: ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   tags: pointwise
 
-- func: linspace.Scalar(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+- func: linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   dispatch:
     CompositeExplicitAutograd: linspace
 
@@ -3272,7 +3272,7 @@
   dispatch:
     CompositeExplicitAutograd: linspace
 
-- func: linspace.Scalar_out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
+- func: linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, Meta: linspace_out
     CUDA: linspace_cuda_out
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index bd9aaf105421e..8a504d4dca5c2 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3736,8 +3736,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
     xfail('as_strided_scatter', ''),
     xfail('equal', ''),
     xfail('linalg.lu', ''),
-    xfail('linspace', ''),
-    xfail('logspace', ''),
     skip('linalg.ldl_solve', ''),
     skip('_softmax_backward_data'),
     # UBSAN: runtime error: shift exponent -1 is negative
diff --git a/test/test_decomp.py b/test/test_decomp.py
index 2dc0f807c056f..21530532d4285 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -219,11 +219,11 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
     # there's an off-by-one error. See
     # https://github.com/pytorch/pytorch/issues/81996
     # https://github.com/pytorch/pytorch/issues/82230
-    (torch.int8, torch.ops.aten.linspace.Scalar) : (0, 1),
-    (torch.uint8, torch.ops.aten.linspace.Scalar) : (0, 1),
-    (torch.int16, torch.ops.aten.linspace.Scalar) : (0, 1),
-    (torch.int32, torch.ops.aten.linspace.Scalar) : (0, 1),
-    (torch.int64, torch.ops.aten.linspace.Scalar) : (0, 1),
+    (torch.int8, torch.ops.aten.linspace.default) : (0, 1),
+    (torch.uint8, torch.ops.aten.linspace.default) : (0, 1),
+    (torch.int16, torch.ops.aten.linspace.default) : (0, 1),
+    (torch.int32, torch.ops.aten.linspace.default) : (0, 1),
+    (torch.int64, torch.ops.aten.linspace.default) : (0, 1),
     (torch.int8, torch.ops.aten.linspace.Tensor) : (0, 1),
     (torch.uint8, torch.ops.aten.linspace.Tensor) : (0, 1),
     (torch.int16, torch.ops.aten.linspace.Tensor) : (0, 1),
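
Note (reviewer sketch, not part of the patch): instead of looping over the batch,
linspace_logspace_batch_rule_helper materializes every batched linspace with one
broadcasted expression. For the steps >= 2 case it is equivalent to the Python
below, where batched_linspace is a hypothetical stand-in written only to
illustrate the arithmetic, not functorch's actual code path:

    import torch

    def batched_linspace(start, end, steps):
        # start, end: shape (B,). Mirrors the C++ expression:
        # (start_ + arange(0, steps).unsqueeze_(1) * (end_ - start_) / (steps - 1)).transpose(0, 1)
        ramp = torch.arange(0, steps, dtype=start.dtype).unsqueeze(1)  # (steps, 1)
        out = start + ramp * (end - start) / (steps - 1)               # broadcasts to (steps, B)
        return out.transpose(0, 1)                                     # (B, steps)

    start = torch.tensor([0.0, 1.0])
    end = torch.tensor([4.0, 9.0])
    expected = torch.stack([torch.linspace(s, e, 5)
                            for s, e in zip(start.tolist(), end.tolist())])
    assert torch.allclose(batched_linspace(start, end, 5), expected)

    # logspace reuses the same grid and exponentiates afterwards, matching
    # the at::pow(*base, result) step in the helper.
    expected_log = torch.stack([torch.logspace(s, e, 5, base=2.0)
                                for s, e in zip(start.tolist(), end.tolist())])
    assert torch.allclose(2.0 ** batched_linspace(start, end, 5), expected_log)

The transpose at the end is what keeps the batch dimension at the front, which
is why both batch rules can unconditionally return bdim 0.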