Skip to content

Commit

Permalink
[tracing] Fix issue where actor/task is defined before ray.init is …
Browse files Browse the repository at this point in the history
…called (ray-project#38323)

Fixes an issue where the `_ray_trace_ctx` kwarg isn't injected to the function signature if `ray.init` is called w/ a tracing hook _after_ defining the function (see issue for repro).

The issue was we were checking `_is_tracing_enabled` at function definition time and selectively injecting the kwarg, but this variable isn't set until `ray.init` is called. I modified it to always inject the kwarg (matching the existing behavior for actor methods).

I've updated the tests to not explicitly call `ray.init` before defining the task.

Signed-off-by: Victor <vctr.y.m@example.com>
  • Loading branch information
edoakes authored and Victor committed Oct 11, 2023
1 parent 58b73ef commit 35c2953
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 45 deletions.
5 changes: 3 additions & 2 deletions python/ray/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ray import ActorClassID, Language, cross_language
from ray._private import ray_option_utils
from ray._private.async_compat import is_async_func
from ray._private.auto_init_hook import auto_init_ray
from ray._private.auto_init_hook import wrap_auto_init
from ray._private.client_mode_hook import (
client_mode_convert_actor,
client_mode_hook,
Expand Down Expand Up @@ -163,6 +163,7 @@ def remote(self, *args, **kwargs):

return FuncWrapper()

@wrap_auto_init
@_tracing_actor_method_invocation
def _remote(
self, args=None, kwargs=None, name="", num_returns=None, concurrency_group=None
Expand Down Expand Up @@ -661,6 +662,7 @@ class or functions.

return ActorOptionWrapper()

@wrap_auto_init
@_tracing_actor_creation
def _remote(self, args=None, kwargs=None, **actor_options):
"""Create an actor.
Expand Down Expand Up @@ -764,7 +766,6 @@ def _remote(self, args=None, kwargs=None, **actor_options):
if actor_options.get("max_concurrency") is None:
actor_options["max_concurrency"] = 1000 if is_asyncio else 1

auto_init_ray()
if client_mode_should_convert():
return client_mode_convert_actor(self, args, kwargs, **actor_options)

Expand Down
4 changes: 2 additions & 2 deletions python/ray/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import ray._private.signature
from ray import Language, cross_language
from ray._private import ray_option_utils
from ray._private.auto_init_hook import auto_init_ray
from ray._private.auto_init_hook import wrap_auto_init
from ray._private.client_mode_hook import (
client_mode_convert_function,
client_mode_should_convert,
Expand Down Expand Up @@ -241,13 +241,13 @@ class or functions.

return FuncWrapper()

@wrap_auto_init
@_tracing_task_invocation
def _remote(self, args=None, kwargs=None, **task_options):
"""Submit the remote function for execution."""
# We pop the "max_calls" coming from "@ray.remote" here. We no longer need
# it in "_remote()".
task_options.pop("max_calls", None)
auto_init_ray()
if client_mode_should_convert():
return client_mode_convert_function(self, args, kwargs, **task_options)

Expand Down
26 changes: 11 additions & 15 deletions python/ray/tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def ray_start_cli_tracing(scope="function"):
check_call_ray(
["start", "--head", "--tracing-startup-hook", setup_tracing_path],
)
ray.init(address="auto")
yield
ray.shutdown()
check_call_ray(["stop", "--force"])
Expand Down Expand Up @@ -101,7 +100,7 @@ def f(value):
ray.get(obj_ref)

span_list = get_span_list()
assert len(span_list) == 2
assert len(span_list) == 2, span_list

# The spans could show up in a different order, so just check that
# all spans are as expected
Expand All @@ -112,7 +111,7 @@ def f(value):
}


def sync_actor_helper(connect_to_cluster: bool = False):
def sync_actor_helper():
"""Run a Ray sync actor and check the spans produced."""

@ray.remote
Expand All @@ -124,9 +123,6 @@ def increment(self):
self.value += 1
return self.value

if connect_to_cluster:
ray.init(address="auto")

# Create an actor from this class.
counter = Counter.remote()
obj_ref = counter.increment.remote()
Expand All @@ -138,12 +134,12 @@ def increment(self):
# The spans could show up in a different order, so just check that
# all spans are as expected
span_names = get_span_dict(span_list)
return span_names == {
assert span_names == {
"sync_actor_helper.<locals>.Counter.__init__ ray.remote": 1,
"sync_actor_helper.<locals>.Counter.increment ray.remote": 1,
"Counter.__init__ ray.remote_worker": 1,
"Counter.increment ray.remote_worker": 1,
}
}, span_names


def async_actor_helper():
Expand All @@ -165,12 +161,12 @@ async def run_concurrent(self):
# The spans could show up in a different order, so just check that
# all spans are as expected
span_names = get_span_dict(span_list)
return span_names == {
assert span_names == {
"async_actor_helper.<locals>.AsyncActor.__init__ ray.remote": 1,
"async_actor_helper.<locals>.AsyncActor.run_concurrent ray.remote": 4,
"AsyncActor.__init__ ray.remote_worker": 1,
"AsyncActor.run_concurrent ray.remote_worker": 4,
}
}, span_names


def test_tracing_task_init_workflow(cleanup_dirs, ray_start_init_tracing):
Expand All @@ -182,23 +178,23 @@ def test_tracing_task_start_workflow(cleanup_dirs, ray_start_cli_tracing):


def test_tracing_sync_actor_init_workflow(cleanup_dirs, ray_start_init_tracing):
assert sync_actor_helper()
sync_actor_helper()


def test_tracing_sync_actor_start_workflow(cleanup_dirs, ray_start_cli_tracing):
assert sync_actor_helper()
sync_actor_helper()


def test_tracing_async_actor_init_workflow(cleanup_dirs, ray_start_init_tracing):
assert async_actor_helper()
async_actor_helper()


def test_tracing_async_actor_start_workflow(cleanup_dirs, ray_start_cli_tracing):
assert async_actor_helper()
async_actor_helper()


def test_tracing_predefined_actor(cleanup_dirs, ray_start_cli_predefined_actor_tracing):
assert sync_actor_helper(connect_to_cluster=True)
sync_actor_helper()


def test_wrapping(ray_start_init_tracing):
Expand Down
53 changes: 27 additions & 26 deletions python/ray/util/tracing/tracing_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
cast,
)

import ray
import ray._private.worker
from ray._private.inspect_util import (
is_class_method,
Expand Down Expand Up @@ -84,7 +85,6 @@ def _try_import(self, module):
)


_nameable = Union[str, Callable[..., Any]]
_global_is_tracing_enabled = False
_opentelemetry = None

Expand Down Expand Up @@ -192,14 +192,14 @@ def _use_context(
_opentelemetry.context.detach(token)


def _function_hydrate_span_args(func: Callable[..., Any]):
def _function_hydrate_span_args(function_name: str):
"""Get the Attributes of the function that will be reported as attributes
in the trace."""
runtime_context = get_runtime_context()

span_args = {
"ray.remote": "function",
"ray.function": func,
"ray.function": function_name,
"ray.pid": str(os.getpid()),
"ray.job_id": runtime_context.get_job_id(),
"ray.node_id": runtime_context.get_node_id(),
Expand All @@ -220,21 +220,18 @@ def _function_hydrate_span_args(func: Callable[..., Any]):

def _function_span_producer_name(func: Callable[..., Any]) -> str:
"""Returns the function span name that has span kind of producer."""
args = _function_hydrate_span_args(func)
name = args["ray.function"]

return f"{name} ray.remote"
return f"{func} ray.remote"


def _function_span_consumer_name(func: Callable[..., Any]) -> str:
"""Returns the function span name that has span kind of consumer."""
args = _function_hydrate_span_args(func)
name = args["ray.function"]

return f"{name} ray.remote_worker"
return f"{func} ray.remote_worker"


def _actor_hydrate_span_args(class_: _nameable, method: _nameable):
def _actor_hydrate_span_args(
class_: Union[str, Callable[..., Any]],
method: Union[str, Callable[..., Any]],
):
"""Get the Attributes of the actor that will be reported as attributes
in the trace."""
if callable(class_):
Expand All @@ -243,7 +240,6 @@ def _actor_hydrate_span_args(class_: _nameable, method: _nameable):
method = method.__name__

runtime_context = get_runtime_context()

span_args = {
"ray.remote": "actor",
"ray.actor_class": class_,
Expand All @@ -268,22 +264,30 @@ def _actor_hydrate_span_args(class_: _nameable, method: _nameable):
return span_args


def _actor_span_producer_name(class_: _nameable, method: _nameable) -> str:
def _actor_span_producer_name(
class_: Union[str, Callable[..., Any]],
method: Union[str, Callable[..., Any]],
) -> str:
"""Returns the actor span name that has span kind of producer."""
args = _actor_hydrate_span_args(class_, method)
assert args is not None
name = args["ray.function"]
if not isinstance(class_, str):
class_ = class_.__name__
if not isinstance(method, str):
method = method.__name__

return f"{name} ray.remote"
return f"{class_}.{method} ray.remote"


def _actor_span_consumer_name(class_: _nameable, method: _nameable) -> str:
def _actor_span_consumer_name(
class_: Union[str, Callable[..., Any]],
method: Union[str, Callable[..., Any]],
) -> str:
"""Returns the actor span name that has span kind of consumer."""
args = _actor_hydrate_span_args(class_, method)
assert args is not None
name = args["ray.function"]
if not isinstance(class_, str):
class_ = class_.__name__
if not isinstance(method, str):
method = method.__name__

return f"{name} ray.remote_worker"
return f"{class_}.{method} ray.remote_worker"


def _tracing_task_invocation(method):
Expand Down Expand Up @@ -325,9 +329,6 @@ def _inject_tracing_into_function(function):
Use the provided trace context from kwargs.
"""
# Add _ray_trace_ctx to function signature
if not _is_tracing_enabled():
return function

setattr(
function,
"__signature__",
Expand Down

0 comments on commit 35c2953

Please sign in to comment.