Skip to content

Commit

Permalink
[core] Fix performance regression in single_client_tasks_and_get_batch (
Browse files Browse the repository at this point in the history
#39362)

The single_client_tasks_and_get_batch benchmark saw a ~0.5-1k tasks/s average regression (2k tasks/s on a local machine) due to #38323, which changed some tracing logic to unconditionally change the signature of every remote function to accomodate tracing during _inject_tracing_into_function.

Make the signature change conditional again, but move it to the execution portion of RemoteFunction rather than the definition. Also make sure the injection only happens once even when the remote function is executed multiple times.
  • Loading branch information
vitsai committed Sep 8, 2023
1 parent 0f5b6f5 commit b6edccf
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 11 deletions.
28 changes: 24 additions & 4 deletions python/ray/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import uuid
from functools import wraps
from threading import Lock

import ray._private.signature
from ray import Language, cross_language
Expand Down Expand Up @@ -116,14 +117,14 @@ def __init__(
self._default_options["runtime_env"] = self._runtime_env

self._language = language
self._function = _inject_tracing_into_function(function)
self._function = function
self._function_signature = None
# Guards trace injection to enforce exactly once semantics
self._inject_lock = Lock()
self._function_name = function.__module__ + "." + function.__name__
self._function_descriptor = function_descriptor
self._is_cross_language = language != Language.PYTHON
self._decorator = getattr(function, "__ray_invocation_decorator__", None)
self._function_signature = ray._private.signature.extract_signature(
self._function
)
self._last_export_session_and_job = None
self._uuid = uuid.uuid4()

Expand All @@ -141,6 +142,16 @@ def __call__(self, *args, **kwargs):
f"try '{self._function_name}.remote()'."
)

# Lock is not picklable
def __getstate__(self):
attrs = self.__dict__.copy()
del attrs["_inject_lock"]
return attrs

def __setstate__(self, state):
self.__dict__.update(state)
self.__dict__["_inject_lock"] = Lock()

def options(self, **task_options):
"""Configures and overrides the task invocation parameters.
Expand Down Expand Up @@ -254,6 +265,15 @@ def _remote(self, args=None, kwargs=None, **task_options):
worker = ray._private.worker.global_worker
worker.check_connected()

# We cannot do this when the function is first defined, because we need
# ray.init() to have been called when this executes
with self._inject_lock:
if self._function_signature is None:
self._function = _inject_tracing_into_function(self._function)
self._function_signature = ray._private.signature.extract_signature(
self._function
)

# If this function was not exported in this session and job, we need to
# export this function again, because the current GCS doesn't have it.
if (
Expand Down
1 change: 0 additions & 1 deletion python/ray/tune/tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def mocked_run(*args, **kwargs):
default_kwargs.pop("run_or_experiment")
default_kwargs.pop("_remote")
default_kwargs.pop("progress_reporter")
default_kwargs.pop("_ray_trace_ctx") # automatically added for remote

self.assertDictEqual(kwargs, default_kwargs)

Expand Down
9 changes: 3 additions & 6 deletions python/ray/util/tracing/tracing_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,9 @@ def _inject_tracing_into_function(function):
future execution of that function will include tracing.
Use the provided trace context from kwargs.
"""
# Add _ray_trace_ctx to function signature
if not _is_tracing_enabled():
return function

setattr(
function,
"__signature__",
Expand All @@ -340,11 +342,6 @@ def _inject_tracing_into_function(function):
),
)

# Skip wrapping if tracing is disabled (still add _ray_trace_ctx however to make
# sure _ray_trace_ctx could be passed)
if not _is_tracing_enabled():
return function

@wraps(function)
def _function_with_tracing(
*args: Any,
Expand Down

0 comments on commit b6edccf

Please sign in to comment.