Skip to content

Commit

Permalink
Revert "Push rocm to slow path (#46216)" (#46728)
Browse files Browse the repository at this point in the history
Summary:
This reverts commit bc1ce58.

Fixes: no issue referenced (the "#{issue number}" template placeholder was left unfilled in the original commit message)

Pull Request resolved: #46728

Reviewed By: cpuhrsch

Differential Revision: D24482783

Pulled By: izdeby

fbshipit-source-id: 619b710a8e790b9878e7317f672b4947e7b88145
  • Loading branch information
Iurii Zdebskyi authored and facebook-github-bot committed Oct 22, 2020
1 parent 9ccf85b commit c57c560
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 25 deletions.
21 changes: 0 additions & 21 deletions aten/src/ATen/native/ForeachUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ void check_foreach_api_restrictions(TensorList tensors, ArrayRef<double> scalars
// - Resulting tensor must have the same dtype as the input one
bool can_use_fast_route(TensorList tensors, Scalar scalar) {
TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");

#ifdef __HIP_PLATFORM_HCC__
return false;
#else
auto expected_device = tensors[0].device();

for (auto t : tensors) {
Expand Down Expand Up @@ -98,13 +94,9 @@ bool can_use_fast_route(TensorList tensors, Scalar scalar) {
}

return true;
#endif
}

bool can_use_fast_route(TensorList tensors) {
#ifdef __HIP_PLATFORM_HCC__
return false;
#else
TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
auto expected_device = tensors[0].device();

Expand All @@ -123,13 +115,9 @@ bool can_use_fast_route(TensorList tensors) {
}

return true;
#endif
}

bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
#ifdef __HIP_PLATFORM_HCC__
return false;
#else
auto expected_device = tensors1[0].device();

for (int64_t i = 0; i < tensors1.size(); i++) {
Expand Down Expand Up @@ -161,13 +149,9 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2) {
}

return true;
#endif
}

bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3) {
#ifdef __HIP_PLATFORM_HCC__
return false;
#else
auto expected_device = tensors1[0].device();

for (int64_t i = 0; i < tensors1.size(); i++) {
Expand Down Expand Up @@ -208,7 +192,6 @@ bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList ten
}

return true;
#endif
}

bool can_use_fast_route(TensorList tensors1, TensorList tensors2, TensorList tensors3, ArrayRef<double> scalars) {
Expand All @@ -221,11 +204,7 @@ bool can_use_fast_route(TensorList tensors, ArrayRef<double> scalars) {
TORCH_CHECK(scalars.size() > 0, "Scalars list must have at least one value.");
TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");

#ifdef __HIP_PLATFORM_HCC__
return false;
#else
return can_use_fast_route(tensors);
#endif
}

}
Expand Down
5 changes: 1 addition & 4 deletions test/test_foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def test_exp(self, device, dtype):
#
# Pointwise ops
#
@skipCUDAIfRocm
@dtypes(*torch.testing.get_all_dtypes(include_bfloat16=False, include_bool=False, include_complex=False))
def test_addcmul(self, device, dtype):
if self.device_type == 'cpu':
Expand Down Expand Up @@ -261,7 +262,6 @@ def test_max_min_inf_nan(self, device, dtype):
#
# Ops with scalar
#
@skipCUDAIfRocm
@dtypes(*torch.testing.get_all_dtypes())
def test_int_scalar(self, device, dtype):
for N in N_values:
Expand Down Expand Up @@ -303,7 +303,6 @@ def test_int_scalar(self, device, dtype):
# We need to update codegen to correctly handle function overloads with float[] and int[].
# As optimizers work with float tensors, the result will always be torch.float32 for now.
# Current schema is using 'float[]' as scalar list type.
@skipCUDAIfRocm
@dtypes(*torch.testing.get_all_dtypes())
def test_int_scalarlist(self, device, dtype):
for N in N_values:
Expand Down Expand Up @@ -452,7 +451,6 @@ def test_float_scalarlist(self, device, dtype):
else:
self.assertEqual(tensors, expected)

@skipCUDAIfRocm
@dtypes(*torch.testing.get_all_dtypes())
def test_complex_scalar(self, device, dtype):
for N in N_values:
Expand Down Expand Up @@ -516,7 +514,6 @@ def test_complex_scalarlist(self, device, dtype):
with self.assertRaisesRegex(TypeError, "argument 'scalars' must be tuple of floats"):
foreach_bin_op_(tensors, scalars)

@skipCUDAIfRocm
@dtypes(*torch.testing.get_all_dtypes())
def test_bool_scalar(self, device, dtype):
for N in N_values:
Expand Down

0 comments on commit c57c560

Please sign in to comment.