Skip to content

Commit ce3ac98

Browse files
Added RayJobSubmission Client wrapper for RCS not created by CodeFlare
1 parent b5a14d9 commit ce3ac98

File tree

2 files changed

+197
-0
lines changed

2 files changed

+197
-0
lines changed

src/codeflare_sdk/job/ray_jobs.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2022 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
The ray_jobs sub-module contains methods needed to submit jobs and connect to Ray Clusters that were not created by CodeFlare.
17+
The SDK acts as a wrapper for the Ray Job Submission Client.
18+
"""
19+
from ray.job_submission import JobSubmissionClient
20+
from typing import Iterator, Optional, Dict, Any, Union
21+
22+
23+
class RayJobClient:
    """
    A thin wrapper around the Ray ``JobSubmissionClient``.

    The wrapped client is stored on ``self.rayJobClient``. Every public method
    forwards its arguments, by keyword, to the method of the same name on that
    client and returns the result untouched.
    """

    def __init__(
        self,
        address: Optional[str] = None,
        create_cluster_if_needed: bool = False,
        cookies: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, Any]] = None,
        verify: Optional[Union[str, bool]] = True,
    ):
        """
        Create the underlying ``JobSubmissionClient``.

        All arguments are passed through unchanged as keyword arguments of the
        same name; nothing is validated or defaulted here beyond the defaults
        in this signature.
        """
        self.rayJobClient = JobSubmissionClient(
            address=address,
            create_cluster_if_needed=create_cluster_if_needed,
            cookies=cookies,
            metadata=metadata,
            headers=headers,
            verify=verify,
        )

    def submit_job(
        self,
        entrypoint: str,
        job_id: Optional[str] = None,
        runtime_env: Optional[Dict[str, Any]] = None,
        metadata: Optional[Dict[str, str]] = None,
        submission_id: Optional[str] = None,
        entrypoint_num_cpus: Optional[Union[int, float]] = None,
        entrypoint_num_gpus: Optional[Union[int, float]] = None,
        entrypoint_resources: Optional[Dict[str, float]] = None,
    ) -> str:
        """
        Submit a job to the Ray Cluster and return the identifier produced by
        the wrapped client. ``entrypoint`` is the only mandatory argument.

        Every argument, including those left at ``None``, is forwarded by
        keyword to ``JobSubmissionClient.submit_job``.
        """
        # NOTE(review): upstream Ray appears to treat ``job_id`` as a
        # deprecated alias of ``submission_id`` -- confirm against the Ray
        # version in use before relying on it.
        return self.rayJobClient.submit_job(
            entrypoint=entrypoint,
            job_id=job_id,
            runtime_env=runtime_env,
            metadata=metadata,
            submission_id=submission_id,
            entrypoint_num_cpus=entrypoint_num_cpus,
            entrypoint_num_gpus=entrypoint_num_gpus,
            entrypoint_resources=entrypoint_resources,
        )

    def delete_job(self, job_id: str) -> bool:
        """
        Delete the job identified by ``job_id`` (mandatory) and return the
        wrapped client's result.
        """
        return self.rayJobClient.delete_job(job_id=job_id)

    def get_address(self) -> str:
        """
        Return the address reported by the wrapped client.
        """
        return self.rayJobClient.get_address()

    def get_job_info(self, job_id: str):
        """
        Return the job info for the job identified by ``job_id`` (mandatory),
        exactly as the wrapped client returns it.
        """
        return self.rayJobClient.get_job_info(job_id=job_id)

    def get_job_logs(self, job_id: str) -> str:
        """
        Return the logs of the job identified by ``job_id`` (mandatory).
        """
        return self.rayJobClient.get_job_logs(job_id=job_id)

    def get_job_status(self, job_id: str) -> str:
        """
        Return the status of the job identified by ``job_id`` (mandatory).
        """
        return self.rayJobClient.get_job_status(job_id=job_id)

    def list_jobs(self):
        """
        Return the wrapped client's listing of jobs in the Ray Cluster.
        """
        return self.rayJobClient.list_jobs()

    def stop_job(self, job_id: str) -> bool:
        """
        Stop the job identified by ``job_id`` (mandatory) and return the
        wrapped client's result.
        """
        return self.rayJobClient.stop_job(job_id=job_id)

    def tail_job_logs(self, job_id: str) -> Iterator[str]:
        """
        Return an iterator that follows the logs of the job identified by
        ``job_id`` (mandatory).
        """
        return self.rayJobClient.tail_job_logs(job_id=job_id)

tests/unit_test.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
from unittest.mock import MagicMock
104104
from pytest_mock import MockerFixture
105105
from ray.job_submission import JobSubmissionClient
106+
from codeflare_sdk.job.ray_jobs import RayJobClient
106107

107108
# For mocking openshift client results
108109
fake_res = openshift.Result("fake")
@@ -2846,6 +2847,85 @@ def test_gen_app_wrapper_with_oauth(mocker: MockerFixture):
28462847
)
28472848

28482849

2850+
"""
2851+
Ray Jobs tests
2852+
"""
2853+
# rjc == RayJobClient
2854+
@pytest.fixture
def ray_job_client(mocker):
    """Provide a RayJobClient built on top of a stubbed JobSubmissionClient."""
    # Replace the constructor so the real JobSubmissionClient.__init__ never runs.
    mocker.patch.object(JobSubmissionClient, "__init__", return_value=None)
    dashboard_url = (
        "https://ray-dashboard-unit-test-cluster-ns.apps.cluster.awsroute.org"
    )
    return RayJobClient(dashboard_url)
2861+
2862+
2863+
def test_rjc_submit_job(ray_job_client, mocker):
    """submit_job forwards every argument by keyword and returns the client's id."""
    mocked_submit_job = mocker.patch.object(
        JobSubmissionClient, "submit_job", return_value="mocked_submission_id"
    )
    # entrypoint is declared as ``str`` on RayJobClient.submit_job; the pip
    # requirements dict belongs in runtime_env, not in entrypoint.
    entrypoint = "python test.py"
    runtime_env = {"pip": ["numpy"]}

    submission_id = ray_job_client.submit_job(
        entrypoint=entrypoint, runtime_env=runtime_env
    )

    # Arguments that were not supplied must still be forwarded as explicit None.
    mocked_submit_job.assert_called_once_with(
        entrypoint=entrypoint,
        job_id=None,
        runtime_env=runtime_env,
        metadata=None,
        submission_id=None,
        entrypoint_num_cpus=None,
        entrypoint_num_gpus=None,
        entrypoint_resources=None,
    )

    assert submission_id == "mocked_submission_id"
2881+
2882+
2883+
def test_rjc_delete_job(ray_job_client, mocker):
    """delete_job forwards job_id by keyword and returns the client's result."""
    delete_spy = mocker.patch.object(
        JobSubmissionClient,
        "delete_job",
        return_value=True,
    )

    outcome = ray_job_client.delete_job(job_id="mocked_job_id")

    delete_spy.assert_called_once_with(job_id="mocked_job_id")
    assert outcome is True
2891+
2892+
2893+
def test_rjc_address(ray_job_client, mocker):
    """get_address returns whatever the wrapped client reports."""
    expected_address = (
        "https://ray-dashboard-unit-test-cluster-ns.apps.cluster.awsroute.org"
    )
    address_spy = mocker.patch.object(
        JobSubmissionClient, "get_address", return_value=expected_address
    )

    reported = ray_job_client.get_address()

    address_spy.assert_called_once()
    assert reported == expected_address
2906+
2907+
2908+
def test_rjc_get_job_logs(ray_job_client, mocker):
    """get_job_logs forwards job_id by keyword and returns the client's logs."""
    logs_spy = mocker.patch.object(
        JobSubmissionClient,
        "get_job_logs",
        return_value="Logs",
    )

    fetched = ray_job_client.get_job_logs(job_id="mocked_job_id")

    logs_spy.assert_called_once_with(job_id="mocked_job_id")
    assert fetched == "Logs"
2916+
2917+
2918+
def test_rjc_get_job_info(ray_job_client, mocker):
    """get_job_info forwards job_id by keyword and returns the client's value as-is."""
    # A plain string stands in for the client's return value; the wrapper is
    # only expected to hand it back unchanged.
    canned_details = "JobDetails(type=<JobType.SUBMISSION: 'SUBMISSION'>, job_id=None, submission_id='mocked_submission_id', driver_info=None, status=<JobStatus.PENDING: 'PENDING'>, entrypoint='python test.py', message='Job has not started yet. It may be waiting for the runtime environment to be set up.', error_type=None, start_time=1701271760641, end_time=None, metadata={}, runtime_env={'working_dir': 'gcs://_ray_pkg_67de6f0e60d43b19.zip', 'pip': {'packages': ['numpy'], 'pip_check': False}, '_ray_commit': 'b4bba4717f5ba04ee25580fe8f88eed63ef0c5dc'}, driver_agent_http_address=None, driver_node_id=None)"
    info_spy = mocker.patch.object(
        JobSubmissionClient,
        "get_job_info",
        return_value=canned_details,
    )

    returned = ray_job_client.get_job_info(job_id="mocked_job_id")

    info_spy.assert_called_once_with(job_id="mocked_job_id")
    assert returned == canned_details
2927+
2928+
28492929
# Make sure to always keep this function last
28502930
def test_cleanup():
28512931
os.remove(f"{aw_dir}unit-test-cluster.yaml")

0 commit comments

Comments
 (0)