torch.xlogy: Use wrapped_scalar_tensor / gpu_kernel_with_scalars to speed up GPU kernel. (#49926)

Summary:
Pull Request resolved: #49926

While investigating #49758, I changed the xlogy kernel to use the recommended wrapped_scalar_tensor pattern instead of moving the scalar to the GPU as a tensor.
While this doesn't avoid a synchronization (there is no synchronization in the move, since it's done via fill), it does significantly speed up the GPU kernel (by almost 50%; benchmark in the PR comments).
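
As a rough illustration only (not the benchmark referenced in the PR comments), a difference like this could be measured from Python with torch.utils.benchmark; the tensor size, the scalar value 3.0, and the assumption of an available CUDA device are arbitrary choices for this sketch:

import torch
from torch.utils import benchmark

x = torch.rand(10_000_000, device="cuda")
y = torch.tensor(3.0, device="cuda")  # explicit 0-dim GPU tensor, for comparison

# Scalar second argument: after this change it is passed as a wrapped scalar
# rather than being materialized as a GPU tensor.
t_scalar = benchmark.Timer(
    stmt="torch.xlogy(x, 3.0)",
    globals={"torch": torch, "x": x},
).blocked_autorange()

# Tensor-tensor call, for comparison.
t_tensor = benchmark.Timer(
    stmt="torch.xlogy(x, y)",
    globals={"torch": torch, "x": x, "y": y},
).blocked_autorange()

print(t_scalar)
print(t_tensor)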

From looking at the nvprof output, it appears this code path avoids broadcasting. Aside: that seems unnecessary, since from the point of view of broadcasting there is nothing special about whether the Tensor is ()-sized or marked as a wrapped scalar. Still, this is a useful change to make, as it avoids the extra kernel launches and dispatches needed to create and fill the tensor.
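
The wrapped-scalar path also interacts with type promotion, which is what the new test_xlogy_scalar_type_promotion test below checks: a Python number does not participate in promotion at the same priority as a 0-dim tensor. A small illustration of that distinction (my own example, not part of the PR; the value 5. is arbitrary):

import torch

t = torch.randn((), dtype=torch.float32)  # 0-dim tensor, as in the new test

# A Python float is a wrapped scalar: it does not outrank the tensor in type
# promotion, so the result keeps t's dtype.
print(torch.xlogy(t, 5.).dtype)  # torch.float32

# A genuine 0-dim float64 tensor promotes on equal footing with t,
# so the result becomes float64.
print(torch.xlogy(t, torch.tensor(5., dtype=torch.float64)).dtype)  # torch.float64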

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D25724215

Pulled By: gchanan

fbshipit-source-id: 4adcd5d8b3297502672ffeafc77e8af80592f460
gchanan authored and facebook-github-bot committed Jan 4, 2021
1 parent 483670f commit 74dcb6d
Showing 3 changed files with 20 additions and 9 deletions.
aten/src/ATen/native/BinaryOps.cpp (10 changes: 5 additions & 5 deletions)
@@ -1109,11 +1109,11 @@ Tensor& xlogy_out(Tensor& result, const Tensor& self, const Tensor& other) {
 }
 
 Tensor& xlogy_out(Tensor& result, Scalar self, const Tensor& other) {
-  return at::xlogy_out(result, c10::scalar_to_tensor(self, other.device()), other);
+  return at::xlogy_out(result, wrapped_scalar_tensor(self), other);
 }
 
 Tensor& xlogy_out(Tensor& result, const Tensor& self, Scalar other) {
-  return at::xlogy_out(result, self, c10::scalar_to_tensor(other, self.device()));
+  return at::xlogy_out(result, self, wrapped_scalar_tensor(other));
 }
 
 Tensor xlogy(const Tensor& x, const Tensor& y) {
@@ -1124,19 +1124,19 @@ Tensor xlogy(const Tensor& x, const Tensor& y) {
 }
 
 Tensor xlogy(Scalar x, const Tensor& y) {
-  return at::xlogy(c10::scalar_to_tensor(x, y.device()), y);
+  return at::xlogy(wrapped_scalar_tensor(x), y);
 }
 
 Tensor xlogy(const Tensor& x, Scalar y) {
-  return at::xlogy(x, c10::scalar_to_tensor(y, x.device()));
+  return at::xlogy(x, wrapped_scalar_tensor(y));
 }
 
 Tensor& xlogy_(Tensor& x, const Tensor& y) {
   return at::xlogy_out(x, x, y);
 }
 
 Tensor& xlogy_(Tensor& x, Scalar y) {
-  return at::xlogy_out(x, x, c10::scalar_to_tensor(y, x.device()));
+  return at::xlogy_out(x, x, wrapped_scalar_tensor(y));
 }
 
 } // namespace native
aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu (2 changes: 1 addition & 1 deletion)
@@ -32,7 +32,7 @@ void mse_kernel_cuda(TensorIterator& iter) {
 
 void xlogy_kernel_cuda(TensorIterator& iter) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "xlogy_cuda", [&]() {
-    gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t {
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t x, scalar_t y) -> scalar_t {
       if (at::_isnan(y)){
         return NAN;
       }
test/test_binary_ufuncs.py (17 changes: 14 additions & 3 deletions)
@@ -1082,13 +1082,13 @@ def test_maximum_minimum_cross_device(self, device):
         ops = (torch.maximum, torch.minimum)
 
         for torch_op in ops:
-            with self.assertRaisesRegex(RuntimeError,
+            with self.assertRaisesRegex(RuntimeError,
                                         "Expected all tensors to be on the same device"):
                 torch_op(a, b)
 
-            with self.assertRaisesRegex(RuntimeError,
+            with self.assertRaisesRegex(RuntimeError,
                                         "Expected all tensors to be on the same device"):
-            torch_op(b, a)
+                torch_op(b, a)
 
         # test cuda tensor and cpu scalar
         ops = ((torch.maximum, np.maximum), (torch.minimum, np.minimum))
@@ -2560,6 +2560,17 @@ def inplace_variant_helper(x, y):
             self.compare_with_numpy(torch_fn, reference_fn, t, exact_dtype=False)
             out_variant_helper(torch.xlogy, 0, t)
 
+    def test_xlogy_scalar_type_promotion(self, device):
+        # Test that python numbers don't participate in type promotion at the same
+        # priority level as 0-dim tensors
+        t = torch.randn((), dtype=torch.float32, device=device)
+
+        self.assertEqual(t.dtype, torch.xlogy(t, 5).dtype)
+        self.assertEqual(t.dtype, torch.xlogy(t, 5.).dtype)
+
+        self.assertEqual(t.dtype, torch.xlogy(5, t).dtype)
+        self.assertEqual(t.dtype, torch.xlogy(5., t).dtype)
+
     @skipIf(not TEST_SCIPY, "Scipy required for the test.")
     def test_xlogy_bfloat16(self, device):
         def _compare_helper(x, y):
