
Commit 1a57b39

Add torch._foreach_maximum(TensorList, TensorList) & torch._foreach_minimum(TensorList, TensorList) APIs (#45692)

Summary:
- Adding torch._foreach_maximum(TensorList, TensorList) API
- Adding torch._foreach_minimum(TensorList, TensorList) API
- Updated Adam/AdamW optimizers

Tested via unit tests
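
A minimal usage sketch (illustration only, not part of the commit message): the new foreach variants apply an element-wise max/min across two equal-length tensor lists in a single call, matching a per-tensor Python loop over torch.max/torch.min, as checked by the unit tests below.

```python
import torch

# Hypothetical example of the APIs added in this PR; the foreach variants take
# two equal-length tensor lists and return a list of result tensors.
a = [torch.randn(3) for _ in range(4)]
b = [torch.randn(3) for _ in range(4)]

res_max = torch._foreach_maximum(a, b)  # == [torch.max(x, y) for x, y in zip(a, b)]
res_min = torch._foreach_minimum(a, b)  # == [torch.min(x, y) for x, y in zip(a, b)]
```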

Pull Request resolved: #45692

Reviewed By: anjali411

Differential Revision: D24142464

Pulled By: izdeby

fbshipit-source-id: 6a4fc343a1613cb1e26c8398450ac9cea0a2eb51
Iurii Zdebskyi authored and facebook-github-bot committed Oct 13, 2020
1 parent 5741de8 commit 1a57b39
Showing 8 changed files with 220 additions and 6 deletions.
16 changes: 16 additions & 0 deletions aten/src/ATen/native/ForeachOpsKernels.cpp
@@ -148,4 +148,20 @@ FOREACH_UNARY_OP(exp);
FOREACH_POINTWISE_OP(addcdiv);
FOREACH_POINTWISE_OP(addcmul);

#define FOREACH_MAXIMUM_MINIMUM_OP(NAME) \
std::vector<Tensor> foreach_tensor_##NAME##_slow(TensorList tensors1, TensorList tensors2) { \
check_foreach_api_restrictions(tensors1, tensors2); \
\
std::vector<Tensor> result; \
result.reserve(tensors1.size()); \
for (int i = 0; i < tensors1.size(); i++) { \
result.emplace_back(at::NAME(tensors1[i], tensors2[i])); \
} \
\
return result; \
} \

FOREACH_MAXIMUM_MINIMUM_OP(maximum)
FOREACH_MAXIMUM_MINIMUM_OP(minimum)

}} // namespace at::native
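
For context (not part of the diff): the slow path above simply reuses the existing binary ops tensor-by-tensor after validating the lists. A rough Python analogue of what FOREACH_MAXIMUM_MINIMUM_OP(maximum) expands to:

```python
import torch

# Rough Python analogue of the CPU slow path above (sketch only):
# check that the lists are non-empty and of matching length, then apply
# the existing per-tensor binary op pairwise.
def foreach_maximum_slow(tensors1, tensors2):
    assert len(tensors1) > 0 and len(tensors1) == len(tensors2)
    return [torch.max(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
```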
68 changes: 68 additions & 0 deletions aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -656,6 +656,74 @@ struct PointwiseOpFunctor {
}
};

template<typename T>
struct BinaryOpListFunctor {
using opmath_t = typename get_opmath_t<T>::opmath_t;
template<typename Op> __device__ __forceinline__ void operator() (
int chunk_size,
TensorListMetadata<3>& tl,
Op op) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

T* x = (T*)tl.addresses[0][tensor_loc];
x += chunk_idx * chunk_size;

T* y = (T*)tl.addresses[1][tensor_loc];
y += chunk_idx * chunk_size;

T* out = (T*)tl.addresses[2][tensor_loc];
out += chunk_idx * chunk_size;

n -= chunk_idx * chunk_size;

T r_x[kILP];
T r_y[kILP];

// to make things simple, we put aligned case in a different code path
if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out)) {
for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
// load
load_store(r_x, x, 0 , i_start);
load_store(r_y, y, 0 , i_start);
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = static_cast<T>(op(static_cast<opmath_t>(r_x[ii]),
static_cast<opmath_t>(r_y[ii])));
}
// store
load_store(out, r_x, i_start , 0);
}
}
else {
for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = 0;
r_y[ii] = 0;
int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size) {
r_x[ii] = x[i];
r_y[ii] = y[i];
}
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_x[ii] = static_cast<T>(op(static_cast<opmath_t>(r_x[ii]),
static_cast<opmath_t>(r_y[ii])));
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if(i < n && i < chunk_size)
out[i] = r_x[ii];
}
}
}
}
};

} // namespace

}} // namespace at::native
46 changes: 43 additions & 3 deletions aten/src/ATen/native/cuda/ForeachPointwiseOp.cu
@@ -1,6 +1,7 @@
#include <ATen/Dispatch.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/NumericUtils.h>

namespace at { namespace native {

@@ -45,7 +46,7 @@ void foreach_pointwise_op_(TensorList input, TensorList tensors1, TensorList ten
});
}

#define FOREACH_UNARY_OP(NAME, OP) \
#define FOREACH_POINTWISE_OP(NAME, OP) \
std::vector<Tensor> foreach_tensor_##NAME##_cuda(TensorList input, TensorList tensors1, TensorList tensors2, Scalar scalar) { \
TORCH_CHECK(input.size() > 0, "Tensor list must have at least one tensor."); \
TORCH_CHECK(input.size() == tensors1.size(), "Tensor lists must be of the same length."); \
@@ -74,7 +75,46 @@ void foreach_tensor_##NAME##_cuda_(TensorList input, TensorList tensors1, Tensor
foreach_pointwise_op_<OP>(input, tensors1, tensors2, scalar); \
}

FOREACH_UNARY_OP(addcmul, std::multiplies);
FOREACH_UNARY_OP(addcdiv, std::divides);
FOREACH_POINTWISE_OP(addcmul, std::multiplies);
FOREACH_POINTWISE_OP(addcdiv, std::divides);

#define FOREACH_MAXIMUM_MINIMUM_OP(NAME, OP) \
std::vector<Tensor> foreach_tensor_##NAME##_cuda(TensorList tensors1, TensorList tensors2) { \
TORCH_CHECK(tensors1.size() > 0, "Tensor list must have at least one tensor."); \
TORCH_CHECK(tensors1.size() == tensors2.size(), "Tensor lists must be of the same length."); \
\
if (!can_use_fast_route(tensors1, tensors2)) { \
return at::native::foreach_tensor_##NAME##_slow(tensors1, tensors2); \
} \
\
std::vector<std::vector<at::Tensor>> tensor_lists; \
std::vector<at::Tensor> vec_res; \
vec_res.reserve(tensors1.size()); \
for (const auto& t: tensors1) { \
vec_res.emplace_back(at::native::empty_like(t)); \
} \
\
tensor_lists.emplace_back(tensors1.vec()); \
tensor_lists.emplace_back(tensors2.vec()); \
tensor_lists.emplace_back(std::move(vec_res)); \
\
AT_DISPATCH_ALL_TYPES_AND(kHalf, tensors1[0].scalar_type(), "foreach_maximum_minimum_op_cuda", [&]() { \
using opmath_t = get_opmath_t<scalar_t>::opmath_t; \
auto op = [] GPU_LAMBDA (opmath_t a, opmath_t b) -> opmath_t { \
opmath_t c = a OP b ? a : b; \
if (_isnan(a)) { \
c = a; \
} \
return c;}; \
multi_tensor_apply<3>(tensor_lists, \
BinaryOpListFunctor<scalar_t>(), \
op); \
}); \
\
return tensor_lists[2]; \
} \

FOREACH_MAXIMUM_MINIMUM_OP(maximum, >)
FOREACH_MAXIMUM_MINIMUM_OP(minimum, <)

}} // namespace at::native
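
A note on the NaN handling in the GPU_LAMBDA above (illustration, not part of the diff): NaN comparisons evaluate to false, so the expression a OP b ? a : b already yields b when only b is NaN, and the explicit _isnan(a) check covers the case where a is NaN; a NaN in either input therefore propagates to the result, matching the torch.max/torch.min behaviour the tests below compare against.

```python
import torch

# Sketch of the intended NaN-propagation behaviour (mirrors what
# test_max_min_inf_nan below verifies against torch.max/torch.min).
a = [torch.tensor([float('nan')]), torch.tensor([1.0])]
b = [torch.tensor([2.0]),          torch.tensor([float('nan')])]
print(torch._foreach_maximum(a, b))  # expected: [tensor([nan]), tensor([nan])]
```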
16 changes: 16 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -6623,6 +6623,22 @@
CPU: foreach_tensor_addcmul_slow
CUDA: foreach_tensor_addcmul_cuda

- func: _foreach_maximum.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_maximum_slow
CUDA: foreach_tensor_maximum_cuda

- func: _foreach_minimum.List(Tensor[] tensors1, Tensor[] tensors2) -> Tensor[]
use_c10_dispatcher: full
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_minimum_slow
CUDA: foreach_tensor_minimum_cuda

- func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor)
use_c10_dispatcher: full
dispatch:
72 changes: 72 additions & 0 deletions test/test_foreach.py
@@ -2,6 +2,7 @@
import unittest
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_SLOW
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, skipCUDAIfRocm
from torch._six import inf, nan

N_values = [20] if not TEST_WITH_SLOW else [30, 300]

@@ -171,6 +172,77 @@ def test_addcdiv(self, device, dtype):
return
self._test_pointwise_op(device, dtype, torch._foreach_addcdiv, torch._foreach_addcdiv_, torch.addcdiv)

@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
def test_min_max(self, device, dtype):
for N in N_values:
tensors1 = self._get_test_data(device, dtype, N)
tensors2 = self._get_test_data(device, dtype, N)

# Mimics cuda kernel dtype flow. With fp16/bf16 input, runs in fp32 and casts output back to fp16/bf16.
control_dtype = torch.float32 if (self.device_type == 'cuda' and
(dtype is torch.float16 or dtype is torch.bfloat16)) else dtype

expected_max = [torch.max(tensors1[i].to(dtype=control_dtype),
tensors2[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)]

expected_min = [torch.min(tensors1[i].to(dtype=control_dtype),
tensors2[i].to(dtype=control_dtype)).to(dtype=dtype) for i in range(N)]

res_max = torch._foreach_maximum(tensors1, tensors2)
self.assertEqual(res_max, expected_max)

res_min = torch._foreach_minimum(tensors1, tensors2)
self.assertEqual(res_min, expected_min)


@dtypes(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False)))
def test_max_min_float_inf_nan(self, device, dtype):
a = [
torch.tensor([float('inf')], device=device, dtype=dtype),
torch.tensor([-float('inf')], device=device, dtype=dtype),
torch.tensor([float('nan')], device=device, dtype=dtype),
torch.tensor([float('nan')], device=device, dtype=dtype)
]

b = [
torch.tensor([-float('inf')], device=device, dtype=dtype),
torch.tensor([float('inf')], device=device, dtype=dtype),
torch.tensor([float('inf')], device=device, dtype=dtype),
torch.tensor([float('nan')], device=device, dtype=dtype)
]

expected = [torch.max(a1, b1) for a1, b1 in zip(a, b)]
res = torch._foreach_maximum(a, b)
self.assertEqual(expected, res)

expected = [torch.min(a1, b1) for a1, b1 in zip(a, b)]
res = torch._foreach_minimum(a, b)
self.assertEqual(expected, res)

@dtypes(*(torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False)))
def test_max_min_inf_nan(self, device, dtype):
a = [
torch.tensor([inf], device=device, dtype=dtype),
torch.tensor([-inf], device=device, dtype=dtype),
torch.tensor([nan], device=device, dtype=dtype),
torch.tensor([nan], device=device, dtype=dtype)
]

b = [
torch.tensor([-inf], device=device, dtype=dtype),
torch.tensor([inf], device=device, dtype=dtype),
torch.tensor([inf], device=device, dtype=dtype),
torch.tensor([nan], device=device, dtype=dtype)
]

expected_max = [torch.max(a1, b1) for a1, b1 in zip(a, b)]
res_max = torch._foreach_maximum(a, b)
self.assertEqual(expected_max, res_max)

expected_min = [torch.min(a1, b1) for a1, b1 in zip(a, b)]
res_min = torch._foreach_minimum(a, b)
self.assertEqual(expected_min, res_min)

#
# Ops with scalar
#
2 changes: 1 addition & 1 deletion test/test_optim.py
@@ -334,7 +334,7 @@ def test_multi_tensor_optimizers(self):
((optim.Adadelta, optim._multi_tensor.Adadelta), dict(weight_decay=1)),
]

kIterations = 1001
kIterations = 11
device = 'cuda'

for optimizers, params in optimizer_pairs_with_flags:
3 changes: 2 additions & 1 deletion torch/optim/_multi_tensor/adam.py
@@ -122,7 +122,8 @@ def step(self, closure=None):

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
[torch.max(a, b, out=a) for a, b in zip(max_exp_avg_sq, exp_avg_sq)]
max_exp_avg_sq = torch._foreach_maximum(max_exp_avg_sq, exp_avg_sq)

# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq)
bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
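
The optimizer change above (and the matching one in adamw.py below) swaps the per-tensor Python loop of torch.max(a, b, out=a) calls for a single torch._foreach_maximum call that returns a new list of tensors. A rough before/after sketch with stand-in state (names mirror adam.py/adamw.py; illustration only):

```python
import torch

# Stand-in optimizer state for illustration.
max_exp_avg_sq = [torch.rand(3), torch.rand(3)]
exp_avg_sq = [torch.rand(3), torch.rand(3)]

# before this commit: per-tensor in-place maximum
# [torch.max(a, b, out=a) for a, b in zip(max_exp_avg_sq, exp_avg_sq)]

# after this commit: one out-of-place foreach call over the whole list
max_exp_avg_sq = torch._foreach_maximum(max_exp_avg_sq, exp_avg_sq)
```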
3 changes: 2 additions & 1 deletion torch/optim/_multi_tensor/adamw.py
@@ -124,7 +124,8 @@ def step(self, closure=None):

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
[torch.max(a, b, out=a) for a, b in zip(max_exp_avg_sq, exp_avg_sq)]
max_exp_avg_sq = torch._foreach_maximum(max_exp_avg_sq, exp_avg_sq)

# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sq)
bias_correction_sqrt = [math.sqrt(bc) for bc in bias_correction2]
