Skip to content

Commit

Permalink
[core][state] make get_task() accept ObjectRef (#43507)
Browse files Browse the repository at this point in the history
Signed-off-by: max-509 <123456vershinin@gmail.com>
Signed-off-by: hongchaodeng <hongchaodeng1@gmail.com>
Co-authored-by: max-509 <123456vershinin@gmail.com>
  • Loading branch information
hongchaodeng and max-509 committed Mar 7, 2024
1 parent 192b728 commit e9a661d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
17 changes: 14 additions & 3 deletions python/ray/tests/test_state_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2340,9 +2340,14 @@ def g(dep):
def impossible():
pass

out = [f.options(name=f"f_{i}").remote() for i in range(2)] # noqa
g_out = g.remote(f.remote()) # noqa
im = impossible.remote() # noqa
f_refs = [f.options(name=f"f_{i}").remote() for i in range(2)] # noqa
g_ref = g.remote(f.remote()) # noqa
im_ref = impossible.remote() # noqa

def verify_task_from_objectref(task, job_id, tasks):
assert task["job_id"] == job_id
assert task["actor_id"] is None
assert any(task["task_id"] == t["task_id"] for t in tasks)

def verify():
tasks = list_tasks()
Expand All @@ -2352,6 +2357,12 @@ def verify():
for task in tasks:
assert task["actor_id"] is None

# Test get_task by objectRef
for ref in f_refs:
verify_task_from_objectref(get_task(ref), job_id, tasks)
verify_task_from_objectref(get_task(g_ref), job_id, tasks)
verify_task_from_objectref(get_task(im_ref), job_id, tasks)

waiting_for_execution = len(
list(
filter(
Expand Down
12 changes: 9 additions & 3 deletions python/ray/util/state/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import requests

import ray
from ray.dashboard.modules.dashboard_sdk import SubmissionClient
from ray.dashboard.utils import (
get_address_for_submission_client,
Expand Down Expand Up @@ -710,15 +711,15 @@ def get_worker(

@DeveloperAPI
def get_task(
id: str,
id: Union[str, "ray.ObjectRef"],
address: Optional[str] = None,
timeout: int = DEFAULT_RPC_TIMEOUT,
_explain: bool = False,
) -> Optional[TaskState]:
"""Get task attempts of a task by id.
Args:
id: Id of the task
id: String id of the task or ObjectRef that corresponds to task
address: Ray bootstrap address, could be `auto`, `localhost:6379`.
If None, it will be resolved automatically from an initialized ray.
timeout: Max timeout value for the state APIs requests made.
Expand All @@ -734,9 +735,14 @@ def get_task(
Exceptions: :class:`RayStateApiException <ray.util.state.exception.RayStateApiException>` if the CLI
failed to query the data.
""" # noqa: E501
str_id: str
if isinstance(id, str):
str_id = id
else:
str_id = id.task_id().hex()
return StateApiClient(address=address).get(
StateResource.TASKS,
id,
str_id,
GetApiOptions(timeout=timeout),
_explain=_explain,
)
Expand Down

0 comments on commit e9a661d

Please sign in to comment.