Push rocm to slow path (#46216)
Summary:
Push rocm to slow path

Pull Request resolved: #46216

Reviewed By: bwasti

Differential Revision: D24263731

Pulled By: izdeby

fbshipit-source-id: 98ede2478b8f075ceed44a9e4f2aa292f523b8e2
Iurii Zdebskyi authored and facebook-github-bot committed Oct 22, 2020
1 parent 3526b60 · commit bc1ce58
Showing 2 changed files with 25 additions and 1 deletion.
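
The pattern this commit applies to every can_use_fast_route overload in ForeachUtils.h is a single preprocessor guard: HIP (ROCm) builds define __HIP_PLATFORM_HCC__, and under that macro each overload returns false unconditionally, so foreach operations always fall back to the slow per-tensor path on ROCm. A minimal standalone sketch of the idiom (illustrative only; the real overloads take TensorList arguments and run device/dtype checks before returning true):

    // Sketch of the guard added below; not the literal PyTorch code.
    bool can_use_fast_route_sketch() {
    #ifdef __HIP_PLATFORM_HCC__
      // HIP/ROCm build: report the fused fast path as unavailable,
      // forcing callers onto the slow per-tensor fallback.
      return false;
    #else
      // Other builds keep the normal eligibility checks (elided here).
      return true;
    #endif
    }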
aten/src/ATen/native/ForeachUtils.h (21 additions, 0 deletions)
@@ -58,6 +58,10 @@ void check_foreach_api_restrictions(TensorList tensors, ArrayRef<double> scalars) {
 // - Resulting tensor must have the same dtype as the input one
 bool can_use_fast_route(TensorList tensors, Scalar scalar) {
   TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
+
+#ifdef __HIP_PLATFORM_HCC__
+  return false;
+#else
   auto expected_device = tensors[0].device();
 
   for (auto t : tensors) {
@@ -94,9 +98,13 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
   }
 
   return true;
+#endif
 }
 
 bool can_use_fast_route(TensorList tensors) {
+#ifdef __HIP_PLATFORM_HCC__
+  return false;
+#else
   TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
   auto expected_device = tensors[0].device();
 
@@ -115,9 +123,13 @@ bool can_use_fast_route(TensorList tensors) {
   }
 
   return true;
+#endif
 }
 
 bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
+#ifdef __HIP_PLATFORM_HCC__
+  return false;
+#else
   auto expected_device = tensors1[0].device();
 
   for (int64_t i = 0; i < tensors1.size(); i++) {
@@ -149,9 +161,13 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
   }
 
   return true;
+#endif
 }
 
 bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
+#ifdef __HIP_PLATFORM_HCC__
+  return false;
+#else
   auto expected_device = tensors1[0].device();
 
   for (int64_t i = 0; i < tensors1.size(); i++) {
@@ -192,6 +208,7 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
   }
 
   return true;
+#endif
 }
 
 bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
@@ -204,7 +221,11 @@ bool can_use_fast_route(TensorList tensors, ArrayRef<double> scalars) {
   TORCH_CHECK(scalars.size() > 0, "Scalars list must have at least one value.");
   TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");
 
+#ifdef __HIP_PLATFORM_HCC__
+  return false;
+#else
   return can_use_fast_route(tensors);
+#endif
 }
 
 }
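Not part of this diff, but for context on what "slow path" means here: call sites of can_use_fast_route typically branch between a fused multi-tensor kernel and a plain per-tensor loop. The sketch below is an assumption about that call-site shape; fused_add_scalar and the surrounding code are illustrative, not actual PyTorch internals.

    // Hypothetical caller: how a foreach op might consume can_use_fast_route.
    std::vector<Tensor> foreach_add(TensorList tensors, Scalar scalar) {
      if (can_use_fast_route(tensors, scalar)) {
        // Fast path: one fused kernel over the whole list. After this change
        // it is never taken on ROCm, since can_use_fast_route returns false.
        return fused_add_scalar(tensors, scalar);
      }
      // Slow path: one dispatch per tensor.
      std::vector<Tensor> result;
      result.reserve(tensors.size());
      for (const auto& t : tensors) {
        result.push_back(t.add(scalar));
      }
      return result;
    }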
test/test_foreach.py (4 additions, 1 deletion)
@@ -160,7 +160,6 @@ def test_exp(self, device, dtype):
     #
     # Pointwise ops
     #
-    @skipCUDAIfRocm
     @dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
     def test_addcmul(self, device, dtype):
         if self.device_type == 'cpu':
@@ -262,6 +261,7 @@ def test_max_min_inf_nan(self, device, dtype):
     #
     # Ops with scalar
     #
+    @skipCUDAIfRocm
     @dtypes(*torch.testing.get_all_dtypes())
     def test_int_scalar(self, device, dtype):
         for N in N_values:
@@ -303,6 +303,7 @@ def test_int_scalar(self, device, dtype):
     # We need to update codegen to correctly handle function overloads with float[] and int[].
     # As optimizers work with float tensors, the result will always be torch.float32 for now.
     # Current schema is using 'float[]' as scalar list type.
+    @skipCUDAIfRocm
     @dtypes(*torch.testing.get_all_dtypes())
     def test_int_scalarlist(self, device, dtype):
         for N in N_values:
@@ -451,6 +452,7 @@ def test_float_scalarlist(self, device, dtype):
             else:
                 self.assertEqual(tensors, expected)
 
+    @skipCUDAIfRocm
     @dtypes(*torch.testing.get_all_dtypes())
     def test_complex_scalar(self, device, dtype):
         for N in N_values:
@@ -514,6 +516,7 @@ def test_complex_scalarlist(self, device, dtype):
            with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"):
                foreach_bin_op_(tensors, scalars)
 
+    @skipCUDAIfRocm
     @dtypes(*torch.testing.get_all_dtypes())
     def test_bool_scalar(self, device, dtype):
         for N in N_values:
