Skip to content

Commit

Permalink
addressing code-review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
deven-amd committed Dec 13, 2019
1 parent e762347 commit 5d1ccc1
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def testLargeBatchSparseMatrixAddGrad(self):
return

if test.is_built_with_rocm():
# sparse-matrix-add op is not yet supported on the ROCm platform
self.skipTest("sparse-matrix-add op not supported on ROCm")

sparsify = lambda m: m * (m > 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@ def testSparseMatrixAdd(self):
return

if test.is_built_with_rocm():
# sparse-matrix-add op is not yet supported on the ROCm platform
self.skipTest("sparse-matrix-add op not supported on ROCm")

a_indices = np.array([[0, 0], [2, 3]])
Expand Down Expand Up @@ -474,7 +473,6 @@ def testLargeBatchSparseMatrixAdd(self):
return

if test.is_built_with_rocm():
# sparse-matrix-add op is not yet supported on the ROCm platform
self.skipTest("sparse-matrix-add op not supported on ROCm")

sparsify = lambda m: m * (m > 0)
Expand Down Expand Up @@ -520,7 +518,6 @@ def testSparseMatrixMatMul(self):
@test_util.run_in_graph_and_eager_modes
def testSparseMatrixMatMulConjugateOutput(self):
if test.is_built_with_rocm():
# complex types are not yet supported on the ROCm platform
self.skipTest("complex type not supported on ROCm")

for shapes in [[(5, 6), (6, 1)], [(5, 6), (6, 2)]]:
Expand Down Expand Up @@ -552,6 +549,8 @@ def testLargeBatchSparseMatrixMatMul(self):

if test.is_built_with_rocm():
# TODO(rocm): fix this
# This test is currently failing on the ROCm platform
# Ren-enable it once the fix is available
self.skipTest("hipSPARSE all failure on the ROCm platform")

sparsify = lambda m: m * (m > 0)
Expand Down Expand Up @@ -612,6 +611,8 @@ def testLargeBatchSparseMatrixMatMulTransposed(self):

if test.is_built_with_rocm():
# TODO(rocm): fix this
# This test is currently failing on the ROCm platform
# Ren-enable it once the fix is available
self.skipTest("hipSPARSE all failure on the ROCm platform")

sparsify = lambda m: m * (m > 0)
Expand Down
4 changes: 2 additions & 2 deletions third_party/gpus/rocm_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _rocm_include_path(repository_ctx, rocm_config):

return inc_dirs

def enable_rocm(repository_ctx):
def _enable_rocm(repository_ctx):
if "TF_NEED_ROCM" in repository_ctx.os.environ:
enable_rocm = repository_ctx.os.environ["TF_NEED_ROCM"].strip()
if enable_rocm == "1":
Expand Down Expand Up @@ -895,7 +895,7 @@ def _create_remote_rocm_repository(repository_ctx, remote_config_repo):

def _rocm_autoconf_impl(repository_ctx):
"""Implementation of the rocm_autoconf repository rule."""
if not enable_rocm(repository_ctx):
if not _enable_rocm(repository_ctx):
_create_dummy_repository(repository_ctx)
elif _TF_ROCM_CONFIG_REPO in repository_ctx.os.environ:
_create_remote_rocm_repository(
Expand Down

0 comments on commit 5d1ccc1

Please sign in to comment.