Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Jobs] Fix race condition on submitting multiple jobs with the same id #33259

Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions dashboard/modules/job/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,26 @@ def __init__(self, gcs_aio_client: GcsAioClient):
self._gcs_aio_client = gcs_aio_client
assert _internal_kv_initialized()

async def put_info(self, job_id: str, job_info: JobInfo):
await self._gcs_aio_client.internal_kv_put(
async def put_info(
self, job_id: str, job_info: JobInfo, overwrite: bool = True
) -> bool:
"""Put job info to the internal kv store.

Args:
job_id: The job id.
job_info: The job info.
overwrite: Whether to overwrite the existing job info.

Returns:
True if a new key is added.
"""
added_num = await self._gcs_aio_client.internal_kv_put(
self.JOB_DATA_KEY.format(job_id=job_id).encode(),
json.dumps(job_info.to_json()).encode(),
True,
overwrite,
namespace=ray_constants.KV_NAMESPACE_JOB,
)
return added_num == 1

async def get_info(self, job_id: str, timeout: int = 30) -> Optional[JobInfo]:
serialized_info = await self._gcs_aio_client.internal_kv_get(
Expand Down
11 changes: 8 additions & 3 deletions dashboard/modules/job/job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,6 @@ async def submit_job(
entrypoint_num_gpus = 0
if submission_id is None:
submission_id = generate_job_id()
elif await self._job_info_client.get_status(submission_id) is not None:
raise RuntimeError(f"Job {submission_id} already exists.")

logger.info(f"Starting job with submission_id: {submission_id}")
job_info = JobInfo(
Expand All @@ -816,7 +814,14 @@ async def submit_job(
entrypoint_num_gpus=entrypoint_num_gpus,
entrypoint_resources=entrypoint_resources,
)
await self._job_info_client.put_info(submission_id, job_info)
new_key_added = await self._job_info_client.put_info(
submission_id, job_info, overwrite=False
)
if not new_key_added:
raise ValueError(
f"Job with submission_id {submission_id} already exists. "
"Please use a different submission_id."
)

# Wait for the actor to start up asynchronously so this call always
# returns immediately and we can catch errors with the actor starting
Expand Down
33 changes: 33 additions & 0 deletions dashboard/modules/job/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,39 @@ async def test_pass_job_id(job_manager):
)


@pytest.mark.asyncio
async def test_simultaneous_submit_job(job_manager):
"""Test that we can submit multiple jobs at once."""
job_ids = await asyncio.gather(
job_manager.submit_job(entrypoint="echo hello"),
job_manager.submit_job(entrypoint="echo hello"),
job_manager.submit_job(entrypoint="echo hello"),
)

for job_id in job_ids:
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id=job_id
)


@pytest.mark.asyncio
async def test_simultaneous_with_same_id(job_manager):
"""Test that we can submit multiple jobs at once with the same id.

The second job should raise a friendly error.
"""
with pytest.raises(ValueError) as excinfo:
await asyncio.gather(
job_manager.submit_job(entrypoint="echo hello", submission_id="1"),
job_manager.submit_job(entrypoint="echo hello", submission_id="1"),
)
assert "Job with submission_id 1 already exists" in str(excinfo.value)
# Check that the (first) job can still succeed.
await async_wait_for_condition_async_predicate(
check_job_succeeded, job_manager=job_manager, job_id="1"
)


@pytest.mark.asyncio
class TestShellScriptExecution:
async def test_submit_basic_echo(self, job_manager):
Expand Down
14 changes: 14 additions & 0 deletions python/ray/_private/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,20 @@ async def internal_kv_put(
namespace: Optional[bytes],
timeout: Optional[float] = None,
) -> int:
"""Put a key-value pair into the GCS.

Args:
key: The key to put.
value: The value to put.
overwrite: Whether to overwrite the value if the key already exists.
namespace: The namespace to put the key-value pair into.
timeout: The timeout in seconds.

Returns:
The number of keys added. If overwrite is True, this will be 1 if the
key was added and 0 if the key was updated. If overwrite is False,
this will be 1 if the key was added and 0 if the key already exists.
"""
logger.debug(f"internal_kv_put {key!r} {value!r} {overwrite} {namespace!r}")
req = gcs_service_pb2.InternalKVPutRequest(
namespace=namespace,
Expand Down