diff --git a/python/pyspark/resource/tests/test_connect_resources.py b/python/pyspark/resource/tests/test_connect_resources.py index 40c68029a1535..1529a33cb0ad0 100644 --- a/python/pyspark/resource/tests/test_connect_resources.py +++ b/python/pyspark/resource/tests/test_connect_resources.py @@ -18,10 +18,13 @@ from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests, ExecutorResourceRequests from pyspark.sql import SparkSession -from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +from pyspark.testing.connectutils import ( + should_test_connect, + connect_requirement_message, +) -@unittest.skipIf(not have_pandas, pandas_requirement_message) +@unittest.skipIf(not should_test_connect, connect_requirement_message) class ResourceProfileTests(unittest.TestCase): def test_profile_before_sc_for_connect(self): rpb = ResourceProfileBuilder() diff --git a/python/pyspark/resource/tests/test_resources.py b/python/pyspark/resource/tests/test_resources.py index 6f61d5af2d926..e29a77ed36dda 100644 --- a/python/pyspark/resource/tests/test_resources.py +++ b/python/pyspark/resource/tests/test_resources.py @@ -15,10 +15,16 @@ # limitations under the License. # import unittest +from typing import cast from pyspark.resource import ExecutorResourceRequests, ResourceProfileBuilder, TaskResourceRequests from pyspark.sql import SparkSession -from pyspark.testing.sqlutils import have_pandas, pandas_requirement_message +from pyspark.testing.sqlutils import ( + have_pandas, + have_pyarrow, + pandas_requirement_message, + pyarrow_requirement_message, +) class ResourceProfileTests(unittest.TestCase): @@ -72,7 +78,10 @@ def assert_request_contents(exec_reqs, task_reqs): assert_request_contents(rp3.executorResources, rp3.taskResources) sc.stop() - @unittest.skipIf(not have_pandas, pandas_requirement_message) + @unittest.skipIf( + not have_pandas or not have_pyarrow, + cast(str, pandas_requirement_message or pyarrow_requirement_message), + ) def test_profile_before_sc_for_sql(self): rpb = ResourceProfileBuilder() treqs = TaskResourceRequests().cpus(2)