Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VMAP] Add linspace and logspace batch rules #105451

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
c9a19ae
[VMAP] Add linspace and logspace batch rules
qqaatw Jul 18, 2023
f982f18
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Jul 18, 2023
f13a481
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Jul 18, 2023
a83c2bf
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Jul 18, 2023
1be0875
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 3, 2023
d2862f9
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 3, 2023
62d705e
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 11, 2023
4b16600
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 17, 2023
c36eaf3
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 17, 2023
a9f747e
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 17, 2023
d458871
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 21, 2023
47a9298
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 21, 2023
c2d17c8
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 25, 2023
5008184
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 30, 2023
eae99d5
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Aug 30, 2023
fba8d7a
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Sep 7, 2023
bc48db1
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Sep 7, 2023
ce97130
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Sep 7, 2023
55cddba
Update on "[VMAP] Add linspace and logspace batch rules"
qqaatw Sep 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
121 changes: 121 additions & 0 deletions aten/src/ATen/functorch/BatchRulesFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,121 @@ static std::tuple<Tensor,optional<int64_t>> _new_zeros_with_same_feature_meta_ba
return std::make_tuple(result, 0);
}

// Shared batch rule for linspace and logspace when at least one of
// start/end carries a vmap batch dimension.
//
// When `base` is nullopt this computes linspace; when it holds a value, the
// linspace result is used as the exponent of `base` (i.e. logspace).
//
// Returns a (batch_size, steps) tensor with the batch dim at position 0.
static std::tuple<Tensor,optional<int64_t>> linspace_logspace_batch_rule_helper(
    const at::Tensor& start, optional<int64_t> start_bdim,
    const at::Tensor& end, optional<int64_t> end_bdim,
    int64_t steps,
    c10::optional<double> base,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory)
{
  auto batch_size = get_bdim_size2(start, start_bdim, end, end_bdim);
  auto start_ = ensure_has_bdim(start, start_bdim.has_value(), batch_size);
  auto end_ = ensure_has_bdim(end, end_bdim.has_value(), batch_size);
  start_ = moveBatchDimToFront(start_, start_bdim);
  end_ = moveBatchDimToFront(end_, end_bdim);

  auto tensor_options = at::TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);

  Tensor result;
  if (steps == 0){
    // Each batch element gets an empty (0-length) row.
    result = at::full({batch_size, 0}, 0, tensor_options);
  } else if (steps == 1){
    // A single step yields exactly `start` for every batch element.
    result = start_.new_empty({batch_size}, tensor_options).copy_(start_).unsqueeze(1);
  } else {
    // start_/end_ broadcast against the (steps, 1) arange column, producing
    // a (steps, batch_size) grid that is transposed to (batch_size, steps).
    result = (start_ + at::arange(0, steps, tensor_options).unsqueeze_(1) * (end_ - start_) / (steps - 1)).transpose(0, 1);
  }

  if (base){
    // logspace: the linspace values are exponents of `base`.
    result = at::pow(*base, result);
  }

  if (dtype && result.scalar_type() != *dtype){
    // Intermediate arithmetic may have promoted the type; cast back to the
    // user-requested dtype.
    result = result.to(*dtype);
  }

  return std::make_tuple(result, 0);
}

// Batch rule for linspace.Tensor_Tensor: both endpoints are tensors, either
// (or both) of which may carry a batch dimension.
static std::tuple<Tensor,optional<int64_t>> linspace_Tensor_Tensor_batch_rule(
    const at::Tensor& start, optional<int64_t> start_bdim,
    const at::Tensor& end, optional<int64_t> end_bdim,
    int64_t steps,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory){
  // No base means the helper performs plain linspace (no pow step).
  const c10::optional<double> no_base = c10::nullopt;
  return linspace_logspace_batch_rule_helper(
      start, start_bdim, end, end_bdim, steps, no_base,
      dtype, layout, device, pin_memory);
}

// Batch rule for linspace.Tensor_Scalar: the end point is a plain scalar, so
// it is promoted to an unbatched tensor on start's device before delegating.
static std::tuple<Tensor,optional<int64_t>> linspace_Tensor_Scalar_batch_rule(
    const at::Tensor& start, optional<int64_t> start_bdim,
    const at::Scalar& end,
    int64_t steps,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory){
  const auto end_tensor = at::native::wrapped_scalar_tensor(end, start.device());
  // nullopt bdim for end (unbatched); nullopt base selects linspace.
  return linspace_logspace_batch_rule_helper(
      start, start_bdim, end_tensor, c10::nullopt, steps, c10::nullopt,
      dtype, layout, device, pin_memory);
}

// Batch rule for linspace.Scalar_Tensor: the start point is a plain scalar, so
// it is promoted to an unbatched tensor on end's device before delegating.
static std::tuple<Tensor,optional<int64_t>> linspace_Scalar_Tensor_batch_rule(
    const at::Scalar& start,
    const at::Tensor& end, optional<int64_t> end_bdim,
    int64_t steps,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory){
  const auto start_tensor = at::native::wrapped_scalar_tensor(start, end.device());
  // nullopt bdim for start (unbatched); nullopt base selects linspace.
  return linspace_logspace_batch_rule_helper(
      start_tensor, c10::nullopt, end, end_bdim, steps, c10::nullopt,
      dtype, layout, device, pin_memory);
}

// Batch rule for logspace.Tensor_Tensor: both endpoints are tensors; the
// helper raises `base` to the batched linspace grid.
static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Tensor_batch_rule(
    const at::Tensor& start, optional<int64_t> start_bdim,
    const at::Tensor& end, optional<int64_t> end_bdim,
    int64_t steps,
    double base,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory){
  // Passing a concrete base makes the helper apply pow(base, ...) (logspace).
  return linspace_logspace_batch_rule_helper(
      start, start_bdim, end, end_bdim, steps, c10::make_optional(base),
      dtype, layout, device, pin_memory);
}

// Batch rule for logspace.Tensor_Scalar: the end point is a plain scalar, so
// it is promoted to an unbatched tensor on start's device before delegating.
static std::tuple<Tensor,optional<int64_t>> logspace_Tensor_Scalar_batch_rule(
    const at::Tensor& start, optional<int64_t> start_bdim,
    const at::Scalar& end,
    int64_t steps,
    double base,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory){
  const auto end_tensor = at::native::wrapped_scalar_tensor(end, start.device());
  // nullopt bdim for end (unbatched); a concrete base selects logspace.
  return linspace_logspace_batch_rule_helper(
      start, start_bdim, end_tensor, c10::nullopt, steps, c10::make_optional(base),
      dtype, layout, device, pin_memory);
}

// Batch rule for logspace.Scalar_Tensor: the start point is a plain scalar, so
// it is promoted to an unbatched tensor on end's device before delegating.
static std::tuple<Tensor,optional<int64_t>> logspace_Scalar_Tensor_batch_rule(
    const at::Scalar& start,
    const at::Tensor& end, optional<int64_t> end_bdim,
    int64_t steps,
    double base,
    c10::optional<at::ScalarType> dtype,
    c10::optional<at::Layout> layout,
    c10::optional<at::Device> device,
    c10::optional<bool> pin_memory){
  const auto start_tensor = at::native::wrapped_scalar_tensor(start, end.device());
  // nullopt bdim for start (unbatched); a concrete base selects logspace.
  return linspace_logspace_batch_rule_helper(
      start_tensor, c10::nullopt, end, end_bdim, steps, c10::make_optional(base),
      dtype, layout, device, pin_memory);
}

// Batch rule for the internal _has_same_storage_numel check.
// Under vmap this check is unconditionally satisfied — the visible code
// always returns true regardless of the inputs; presumably because a
// BatchedTensor's physical storage is not meaningful to compare here
// (NOTE(review): rationale inferred from context — confirm against the
// unbatched implementation).
static bool _has_same_storage_numel_batch_rule(const Tensor& a, const Tensor& b) {
return true;
}
Expand All @@ -119,6 +234,12 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
VMAP_SUPPORT(new_zeros, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_zeros)));
VMAP_SUPPORT(new_ones, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_ones)));
VMAP_SUPPORT(new_full, NEW_BLAH_BATCH_RULE_SYMINT(ATEN_FN(new_full)));
// Register one batch rule per linspace/logspace overload (Tensor_Tensor,
// Tensor_Scalar, Scalar_Tensor); the Scalar_Scalar overloads need no rule
// since they take no tensor that could carry a batch dimension.
VMAP_SUPPORT2(linspace, Tensor_Tensor, linspace_Tensor_Tensor_batch_rule);
VMAP_SUPPORT2(linspace, Tensor_Scalar, linspace_Tensor_Scalar_batch_rule);
VMAP_SUPPORT2(linspace, Scalar_Tensor, linspace_Scalar_Tensor_batch_rule);
VMAP_SUPPORT2(logspace, Tensor_Tensor, logspace_Tensor_Tensor_batch_rule);
VMAP_SUPPORT2(logspace, Tensor_Scalar, logspace_Tensor_Scalar_batch_rule);
VMAP_SUPPORT2(logspace, Scalar_Tensor, logspace_Scalar_Tensor_batch_rule);
VMAP_SUPPORT(_new_zeros_with_same_feature_meta, _new_zeros_with_same_feature_meta_batch_rule);
// Not sure how to add the ones with irregular args to the mix cleanly (i.e. randint takes an extra int parameter)
}
Expand Down
2 changes: 0 additions & 2 deletions test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3739,8 +3739,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('as_strided_scatter', ''),
xfail('equal', ''),
xfail('linalg.lu', ''),
xfail('linspace', 'tensor_overload'),
xfail('logspace', 'tensor_overload'),
skip('linalg.ldl_solve', ''),
skip('_softmax_backward_data'),
# https://github.com/pytorch/pytorch/issues/96560
Expand Down