
Commit 8f8738c

[vmap] implement batching rules for clamp, clamp_min and clamp_max (#48449)

Summary:
Fixes #47754

- This PR implements batching rules for the `clamp`, `clamp_min`, and `clamp_max` operators; a short usage sketch follows this summary.
- Test cases are added to `test/test_vmap.py`.
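
For context, a minimal usage sketch of what these rules enable, assuming the prototype `torch.vmap` entry point available in nightlies around the time of this commit:

    import torch

    # Hedged sketch: clamp ops can now go through a dedicated batching rule
    # when called under vmap, instead of hitting the missing-rule path.
    x = torch.randn(3, 5)
    y = torch.vmap(lambda t: t.clamp(min=-0.5, max=0.5))(x)
    print(y.shape)  # torch.Size([3, 5])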

Pull Request resolved: #48449

Reviewed By: ejguan

Differential Revision: D25219360

Pulled By: zou3519

fbshipit-source-id: 0b7e1b00f5553b4578f15a6cc440640e506b4918
RockingJavaBean authored and facebook-github-bot committed Nov 30, 2020
1 parent b514951 commit 8f8738c
Showing 2 changed files with 34 additions and 0 deletions.
23 changes: 23 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -187,6 +187,24 @@ std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
  return result;
}

Tensor clamp_batching_rule(const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp(self_physical.tensor(), min, max);
  return self_physical.newLogicalFromPhysical(result);
}

Tensor clamp_min_batching_rule(const Tensor& self, Scalar min) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_min(self_physical.tensor(), min);
  return self_physical.newLogicalFromPhysical(result);
}

Tensor clamp_max_batching_rule(const Tensor& self, Scalar max) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_max(self_physical.tensor(), max);
  return self_physical.newLogicalFromPhysical(result);
}

std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);

@@ -984,6 +1002,11 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("view", view_batching_rule);
m.impl("view_as", native::view_as); // composite wrt autograd

  // clamp operations
  m.impl("clamp", clamp_batching_rule);
  m.impl("clamp_min", clamp_min_batching_rule);
  m.impl("clamp_max", clamp_max_batching_rule);

  // unary pointwise, out-of-place, no additional arguments.
#define UNARY_POINTWISE(op) m.impl(#op, \
    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
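The three new rules share one pattern: lower the logical (batched) tensor to its physical representation, run the regular op, and re-wrap the result. A rough Python analogue of that pattern, with a hypothetical (tensor, batch-dim) pair standing in for the C++ VmapPhysicalView machinery:

    from typing import Optional
    import torch

    # Hypothetical sketch only: a "logical" batched tensor is modeled as a
    # (physical tensor, batch dim) pair rather than PyTorch's real BatchedTensor.
    def clamp_batching_rule(physical: torch.Tensor, bdim: int,
                            min: Optional[float] = None,
                            max: Optional[float] = None):
        # clamp is pointwise, so it can run directly on the physical tensor;
        # the batch-dim bookkeeping passes through unchanged.
        return physical.clamp(min=min, max=max), bdim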
11 changes: 11 additions & 0 deletions test/test_vmap.py
@@ -1300,6 +1300,17 @@ def test_chunk(self):
        test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
             (torch.rand(B1, 2, B0, 64, B2),), in_dims=2)

    def test_clamp(self):
        clamp_cases = (
            (lambda t: t.clamp(min=-0.5), TensorFactory.randn),
            (lambda t: t.clamp(max=0.5), TensorFactory.randn),
            (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn),
            (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn),
            (lambda t: t.clamp_max(max=0.5), TensorFactory.randn),
        )
        for op, getter in clamp_cases:
            self._test_unary(op, getter, 'cpu')

    def test_diagonal(self):
        tensor = torch.randn(3, 5, 7, 11, 13)
        test = self._vmap_view_test
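Roughly, `_test_unary` asserts the core vmap property for each case: applying the op under `vmap` matches applying it slice-by-slice and stacking. A hedged, self-contained sketch of that property (the real harness also varies `in_dims` and tensor shapes):

    import torch

    # Property sketch, assuming the prototype torch.vmap API from late 2020:
    # the batched result equals a hand-written per-slice loop.
    op = lambda t: t.clamp_min(-0.5)
    x = torch.randn(4, 3)

    batched = torch.vmap(op)(x)
    looped = torch.stack([op(row) for row in x])
    assert torch.allclose(batched, looped)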
