diff --git a/aten/src/ATen/functorch/BatchRulesFactory.cpp b/aten/src/ATen/functorch/BatchRulesFactory.cpp
index bde6842342b2d..09430ce5f2483 100644
--- a/aten/src/ATen/functorch/BatchRulesFactory.cpp
+++ b/aten/src/ATen/functorch/BatchRulesFactory.cpp
@@ -103,6 +103,121 @@ static std::tuple<Tensor, optional<int64_t>> _new_zeros_with_same_feature_meta_batch_rule(
   return std::make_tuple(result, 0);
 }
 
+static std::tuple<Tensor, optional<int64_t>> linspace_logspace_batch_rule_helper(
+    const at::Tensor& start, optional<int64_t> start_bdim,
+    const at::Tensor& end, optional<int64_t> end_bdim,
+    int64_t steps,
+    c10::optional<double> base,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory)
+{
+  auto batch_size = get_bdim_size2(start, start_bdim, end, end_bdim);
+  auto start_ = ensure_has_bdim(start, start_bdim.has_value(), batch_size);
+  auto end_ = ensure_has_bdim(end, end_bdim.has_value(), batch_size);
+  start_ = moveBatchDimToFront(start_, start_bdim);
+  end_ = moveBatchDimToFront(end_, end_bdim);
+
+  auto tensor_options = at::TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
+
+  Tensor result;
+  if (steps == 0){
+    result = at::full({batch_size, 0}, 0, tensor_options);
+  } else if (steps == 1){
+    result = start_.new_empty({batch_size}, tensor_options).copy_(start_).unsqueeze(1);
+  } else {
+    result = (start_ + at::arange(0, steps, tensor_options).unsqueeze_(1) * (end_ - start_) / (steps - 1)).transpose(0, 1);
+  }
+
+  if (base){
+    result = at::pow(*base, result);
+  }
+
+  if (dtype && result.scalar_type() != *dtype){
+    result = result.to(*dtype);
+  }
+
+  return std::make_tuple(result, 0);
+}
+
+static std::tuple<Tensor, optional<int64_t>> linspace_Tensor_Tensor_batch_rule(
+    const at::Tensor& start, optional<int64_t> start_bdim,
+    const at::Tensor& end, optional<int64_t> end_bdim,
+    int64_t steps,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory){
+  return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, c10::nullopt, dtype, layout, device, pin_memory);
+}
+
+static std::tuple<Tensor, optional<int64_t>> linspace_Tensor_Scalar_batch_rule(
+    const at::Tensor& start, optional<int64_t> start_bdim,
+    const at::Scalar& end,
+    int64_t steps,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory){
+
+  auto end_t = at::native::wrapped_scalar_tensor(end, start.device());
+  return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, c10::nullopt, steps, c10::nullopt, dtype, layout, device, pin_memory);
+}
+
+static std::tuple<Tensor, optional<int64_t>> linspace_Scalar_Tensor_batch_rule(
+    const at::Scalar& start,
+    const at::Tensor& end, optional<int64_t> end_bdim,
+    int64_t steps,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory){
+
+  auto start_t = at::native::wrapped_scalar_tensor(start, end.device());
+  return linspace_logspace_batch_rule_helper(start_t, c10::nullopt, end, end_bdim, steps, c10::nullopt, dtype, layout, device, pin_memory);
+}
+
+static std::tuple<Tensor, optional<int64_t>> logspace_Tensor_Tensor_batch_rule(
+    const at::Tensor& start, optional<int64_t> start_bdim,
+    const at::Tensor& end, optional<int64_t> end_bdim,
+    int64_t steps,
+    double base,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory){
+  return linspace_logspace_batch_rule_helper(start, start_bdim, end, end_bdim, steps, c10::make_optional(base), dtype, layout, device, pin_memory);
+}
+
+static std::tuple<Tensor, optional<int64_t>> logspace_Tensor_Scalar_batch_rule(
+    const at::Tensor& start, optional<int64_t> start_bdim,
+    const at::Scalar& end,
+    int64_t steps,
+    double base,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory){
+
+  auto end_t = at::native::wrapped_scalar_tensor(end, start.device());
+  return linspace_logspace_batch_rule_helper(start, start_bdim, end_t, c10::nullopt, steps, c10::make_optional(base), dtype, layout, device, pin_memory);
+}
+
+static std::tuple<Tensor, optional<int64_t>> logspace_Scalar_Tensor_batch_rule(
+    const at::Scalar& start,
+    const at::Tensor& end, optional<int64_t> end_bdim,
+    int64_t steps,
+    double base,
+    c10::optional<at::ScalarType> dtype,
+    c10::optional<at::Layout> layout,
+    c10::optional<at::Device> device,
+    c10::optional<bool> pin_memory){
+
+  auto start_t = at::native::wrapped_scalar_tensor(start, end.device());
+  return linspace_logspace_batch_rule_helper(start_t, c10::nullopt, end, end_bdim, steps, c10::make_optional(base), dtype, layout, device, pin_memory);
+}
+
 static bool _has_same_storage_numel_batch_rule(const Tensor& a, const Tensor& b) {
   return true;
 }
@@ -119,6 +234,12 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   VMAP_SUPPORT(new_zeros, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_zeros)));
   VMAP_SUPPORT(new_ones, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_ones)));
   VMAP_SUPPORT(new_full, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_full)));
+  VMAP_SUPPORT2(linspace, Tensor_Tensor, linspace_Tensor_Tensor_batch_rule);
+  VMAP_SUPPORT2(linspace, Tensor_Scalar, linspace_Tensor_Scalar_batch_rule);
+  VMAP_SUPPORT2(linspace, Scalar_Tensor, linspace_Scalar_Tensor_batch_rule);
+  VMAP_SUPPORT2(logspace, Tensor_Tensor, logspace_Tensor_Tensor_batch_rule);
+  VMAP_SUPPORT2(logspace, Tensor_Scalar, logspace_Tensor_Scalar_batch_rule);
+  VMAP_SUPPORT2(logspace, Scalar_Tensor, logspace_Scalar_Tensor_batch_rule);
   VMAP_SUPPORT(_new_zeros_with_same_feature_meta, _new_zeros_with_same_feature_meta_batch_rule);
   // Not sure how to add the ones with irregular args to the mix cleanly (i.e. randint takes an extra int parameter)
 }
diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py
index f055dc49979e2..d19308b79295e 100644
--- a/test/functorch/test_vmap.py
+++ b/test/functorch/test_vmap.py
@@ -3739,8 +3739,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('as_strided_scatter', ''),
         xfail('equal', ''),
         xfail('linalg.lu', ''),
-        xfail('linspace', 'tensor_overload'),
-        xfail('logspace', 'tensor_overload'),
         skip('linalg.ldl_solve', ''),
         skip('_softmax_backward_data'),  # https://github.com/pytorch/pytorch/issues/96560
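
Usage sketch (not part of the diff; the shapes and the lambda below are illustrative assumptions, not taken from the PR): with these rules registered, torch.func.vmap over torch.linspace/torch.logspace with batched tensor endpoints should hit the new Tensor_Tensor batch rule rather than the fallback path, which is what removing the 'tensor_overload' xfails in test_vmap_exhaustive exercises.

import torch
from torch.func import vmap

# One (start, end) pair per batch element.
start = torch.tensor([0.0, 1.0, 2.0])
end = torch.tensor([10.0, 20.0, 30.0])

# Inside vmap, start/end are 0-dim batched tensors, so torch.linspace
# dispatches to the Tensor_Tensor overload and the new batch rule applies.
out = vmap(lambda s, e: torch.linspace(s, e, steps=5))(start, end)

# Expected to match a per-sample loop over the unbatched op.
expected = torch.stack([torch.linspace(s, e, steps=5)
                        for s, e in zip(start, end)])
assert torch.allclose(out, expected)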