diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py
index 1e0ac6b99e81..26407708475d 100644
--- a/python/ray/_private/worker.py
+++ b/python/ray/_private/worker.py
@@ -2718,7 +2718,7 @@ def put(
 @PublicAPI
 @client_mode_hook
 def wait(
-    ray_waitables: Union["ObjectRef[R]", "ObjectRefGenerator[R]"],
+    ray_waitables: List[Union["ObjectRef[R]", "ObjectRefGenerator[R]"]],
     *,
     num_returns: int = 1,
     timeout: Optional[float] = None,
diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx
index e8ed289569ff..ae452b49e745 100644
--- a/python/ray/_raylet.pyx
+++ b/python/ray/_raylet.pyx
@@ -4832,6 +4832,11 @@ cdef class CoreWorker:
         return (CCoreWorkerProcess.GetCoreWorker().GetWorkerContext()
                 .CurrentActorMaxConcurrency())

+    def get_current_root_detached_actor_id(self) -> ActorID:
+        # This is only used in tests.
+        return ActorID(CCoreWorkerProcess.GetCoreWorker().GetWorkerContext()
+                       .GetRootDetachedActorID().Binary())
+
     def get_queued_future(self, task_id: Optional[TaskID]) -> ConcurrentFuture:
         """Get an asyncio.Future that's queued in the event loop."""
         with self._task_id_to_future_lock:
diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd
index 7a034d2cda04..1006299c2104 100644
--- a/python/ray/includes/libcoreworker.pxd
+++ b/python/ray/includes/libcoreworker.pxd
@@ -85,6 +85,7 @@ cdef extern from "ray/core_worker/context.h" nogil:
         c_bool CurrentActorIsAsync()
         const c_string &GetCurrentSerializedRuntimeEnv()
         int CurrentActorMaxConcurrency()
+        const CActorID &GetRootDetachedActorID()

 cdef extern from "ray/core_worker/generator_waiter.h" nogil:
     cdef cppclass CGeneratorBackpressureWaiter "ray::core::GeneratorBackpressureWaiter":  # noqa
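[Illustration, not part of the diff: a minimal sketch of how the test-only Cython accessor added above can be queried from inside a running worker. It relies only on APIs that appear in this patch.]

    import ray

    ray.init()

    @ray.remote
    def report_root_detached_actor_id():
        # Nil hex is expected here because this task's root ancestor is the driver.
        core_worker = ray._private.worker.global_worker.core_worker
        return core_worker.get_current_root_detached_actor_id().hex()

    assert ray.get(report_root_detached_actor_id.remote()) == ray._raylet.ActorID.nil().hex()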
+ """ + + ray.init() + + @ray.remote + def get_task_root_detached_actor_id(): + core_worker = ray._private.worker.global_worker.core_worker + return core_worker.get_current_root_detached_actor_id().hex() + + @ray.remote + class Actor: + def get_root_detached_actor_id(self): + core_worker = ray._private.worker.global_worker.core_worker + return core_worker.get_current_root_detached_actor_id().hex() + + @ray.remote(lifetime="detached") + class DetachedActor: + def check(self): + core_worker = ray._private.worker.global_worker.core_worker + assert ( + ray.get_runtime_context().get_actor_id() + == core_worker.get_current_root_detached_actor_id().hex() + ) + assert ray.get_runtime_context().get_actor_id() == ray.get( + get_task_root_detached_actor_id.remote() + ) + actor = Actor.remote() + assert ray.get_runtime_context().get_actor_id() == ray.get( + actor.get_root_detached_actor_id.remote() + ) + + assert ( + ray.get(get_task_root_detached_actor_id.remote()) + == ray._raylet.ActorID.nil().hex() + ) + actor = Actor.remote() + assert ( + ray.get(actor.get_root_detached_actor_id.remote()) + == ray._raylet.ActorID.nil().hex() + ) + detached_actor = DetachedActor.remote() + ray.get(detached_actor.check.remote()) + + +def test_no_process_leak_after_job_finishes(ray_start_cluster): + """Test to make sure when a job finishes, + all the worker processes belonging to it exit. + """ + cluster = ray_start_cluster + cluster.add_node(num_cpus=8) + ray.init(address=cluster.address) + + @ray.remote(num_cpus=0) + class PidActor: + def __init__(self): + self.pids = set() + self.pids.add(os.getpid()) + + def add_pid(self, pid): + self.pids.add(pid) + + def get_pids(self): + return self.pids + + @ray.remote + def child(pid_actor): + # child worker process should be forcibly killed + # when the job finishes. + ray.get(pid_actor.add_pid.remote(os.getpid())) + time.sleep(1000000) + + @ray.remote + def parent(pid_actor): + ray.get(pid_actor.add_pid.remote(os.getpid())) + child.remote(pid_actor) + + pid_actor = PidActor.remote() + ray.get(parent.remote(pid_actor)) + + wait_for_condition(lambda: len(ray.get(pid_actor.get_pids.remote())) == 3) + + pids = ray.get(pid_actor.get_pids.remote()) + + ray.shutdown() + # Job finishes at this point + + for pid in pids: + wait_for_pid_to_exit(pid) + + if __name__ == "__main__": # Make subprocess happy in bazel. diff --git a/python/ray/workflow/tests/test_workflow_queuing.py b/python/ray/workflow/tests/test_workflow_queuing.py index 297dd0ff3aaf..05c2ec8eb59d 100644 --- a/python/ray/workflow/tests/test_workflow_queuing.py +++ b/python/ray/workflow/tests/test_workflow_queuing.py @@ -1,6 +1,8 @@ +import os import pytest import ray from ray import workflow +from ray._private.test_utils import wait_for_condition from ray.tests.conftest import * # noqa @@ -146,6 +148,8 @@ def test_workflow_queuing_resume_all(shutdown_only, tmp_path): @ray.remote def long_running(x): + file_path = str(tmp_path / f".long_running_{x}") + open(file_path, "w") with filelock.FileLock(lock_path): return x @@ -156,6 +160,16 @@ def long_running(x): workflow.run_async(wfs[i], workflow_id=f"workflow_{i}") for i in range(4) ] + # Make sure workflow_0 and workflow_1 are running user code + # Otherwise it might run workflow code that contains + # ray.get() when ray.shutdown() + # is called and that can cause ray.get() to throw exception + # since raylet is stopped + # before worker process (this is a bug we should fix) + # and transition the workflow to FAILED status. 
diff --git a/python/ray/workflow/tests/test_workflow_queuing.py b/python/ray/workflow/tests/test_workflow_queuing.py
index 297dd0ff3aaf..05c2ec8eb59d 100644
--- a/python/ray/workflow/tests/test_workflow_queuing.py
+++ b/python/ray/workflow/tests/test_workflow_queuing.py
@@ -1,6 +1,8 @@
+import os
 import pytest
 import ray
 from ray import workflow
+from ray._private.test_utils import wait_for_condition
 from ray.tests.conftest import *  # noqa


@@ -146,6 +148,8 @@ def test_workflow_queuing_resume_all(shutdown_only, tmp_path):

     @ray.remote
     def long_running(x):
+        file_path = str(tmp_path / f".long_running_{x}")
+        open(file_path, "w")
         with filelock.FileLock(lock_path):
             return x

@@ -156,6 +160,16 @@ def long_running(x):
         workflow.run_async(wfs[i], workflow_id=f"workflow_{i}") for i in range(4)
     ]

+    # Make sure workflow_0 and workflow_1 are running user code.
+    # Otherwise they might be running workflow system code that calls
+    # ray.get() when ray.shutdown() is called; since the raylet is
+    # stopped before the worker process (this is a bug we should fix),
+    # that ray.get() can throw an exception and transition the
+    # workflow to FAILED status.
+    wait_for_condition(lambda: os.path.isfile(str(tmp_path / ".long_running_0")))
+    wait_for_condition(lambda: os.path.isfile(str(tmp_path / ".long_running_1")))
+
     assert sorted(x[0] for x in workflow.list_all({workflow.RUNNING})) == [
         "workflow_0",
         "workflow_1",
diff --git a/src/mock/ray/raylet/worker.h b/src/mock/ray/raylet/worker.h
index 43878348626e..aaf7de98506e 100644
--- a/src/mock/ray/raylet/worker.h
+++ b/src/mock/ray/raylet/worker.h
@@ -80,6 +80,7 @@ class MockWorkerInterface : public WorkerInterface {
   MOCK_METHOD(bool, IsRegistered, (), (override));
   MOCK_METHOD(rpc::CoreWorkerClientInterface *, rpc_client, (), (override));
   MOCK_METHOD(bool, SetJobId, (const JobID &job_id), (override));
+  MOCK_METHOD(const ActorID &, GetRootDetachedActorId, (), (override));
 };

 }  // namespace raylet
diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h
index 926a0bfb36bd..5f7ea3412eed 100644
--- a/src/ray/common/ray_config_def.h
+++ b/src/ray/common/ray_config_def.h
@@ -862,9 +862,6 @@ RAY_CONFIG(int64_t,
 RAY_CONFIG(bool, worker_core_dump_exclude_plasma_store, true)
 RAY_CONFIG(bool, raylet_core_dump_exclude_plasma_store, true)

-/// Whether to kill idle workers of a terminated job.
-RAY_CONFIG(bool, kill_idle_workers_of_terminated_job, true)
-
 // Instruct the Python default worker to preload the specified imports.
 // This is specified as a comma-separated list.
 // If left empty, no such attempt will be made.
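[Illustration, not part of the diff: the removed kill_idle_workers_of_terminated_job flag is superseded by the per-worker root-detached-actor checks later in this patch. The reason workers rooted in a detached actor must survive is that detached actors outlive the job that created them. A minimal sketch, assuming default namespace semantics:]

    import ray

    ray.init(namespace="example")

    @ray.remote(lifetime="detached")
    class Cache:
        def ping(self):
            return "alive"

    cache = Cache.options(name="cache").remote()
    ray.get(cache.ping.remote())  # Ensure the actor is up before the job exits.
    ray.shutdown()  # The creating job finishes here; the detached actor lives on.

    ray.init(namespace="example")
    assert ray.get(ray.get_actor("cache").ping.remote()) == "alive"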
diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc
index 218e565429b8..29cf326649a8 100644
--- a/src/ray/common/task/task_spec.cc
+++ b/src/ray/common/task/task_spec.cc
@@ -159,6 +159,13 @@ TaskID TaskSpecification::ParentTaskId() const {
   return TaskID::FromBinary(message_->parent_task_id());
 }

+ActorID TaskSpecification::RootDetachedActorId() const {
+  if (message_->root_detached_actor_id().empty() /* e.g., empty proto default */) {
+    return ActorID::Nil();
+  }
+  return ActorID::FromBinary(message_->root_detached_actor_id());
+}
+
 TaskID TaskSpecification::SubmitterTaskId() const {
   if (message_->submitter_task_id().empty() /* e.g., empty proto default */) {
     return TaskID::Nil();
@@ -198,7 +205,8 @@ int TaskSpecification::GetRuntimeEnvHash() const {
   WorkerCacheKey env = {SerializedRuntimeEnv(),
                         GetRequiredResources().GetResourceMap(),
                         IsActorCreationTask(),
-                        GetRequiredResources().Get(scheduling::ResourceID::GPU()) > 0};
+                        GetRequiredResources().Get(scheduling::ResourceID::GPU()) > 0,
+                        !(RootDetachedActorId().IsNil())};
   return env.IntHash();
 }

@@ -594,13 +602,15 @@ WorkerCacheKey::WorkerCacheKey(
     const std::string serialized_runtime_env,
     const absl::flat_hash_map<std::string, double> &required_resources,
     bool is_actor,
-    bool is_gpu)
+    bool is_gpu,
+    bool is_root_detached_actor)
     : serialized_runtime_env(serialized_runtime_env),
      required_resources(RayConfig::instance().worker_resource_limits_enabled()
                             ? required_resources
                             : absl::flat_hash_map<std::string, double>{}),
      is_actor(is_actor && RayConfig::instance().isolate_workers_across_task_types()),
      is_gpu(is_gpu && RayConfig::instance().isolate_workers_across_resource_types()),
+      is_root_detached_actor(is_root_detached_actor),
       hash_(CalculateHash()) {}

 std::size_t WorkerCacheKey::CalculateHash() const {
@@ -617,6 +627,7 @@ std::size_t WorkerCacheKey::CalculateHash() const {
     boost::hash_combine(hash, serialized_runtime_env);
     boost::hash_combine(hash, is_actor);
     boost::hash_combine(hash, is_gpu);
+    boost::hash_combine(hash, is_root_detached_actor);

     std::vector<std::pair<std::string, double>> resource_vars(required_resources.begin(),
                                                               required_resources.end());
@@ -637,7 +648,7 @@ bool WorkerCacheKey::operator==(const WorkerCacheKey &k) const {

 bool WorkerCacheKey::EnvIsEmpty() const {
   return IsRuntimeEnvEmpty(serialized_runtime_env) && required_resources.empty() &&
-         !is_gpu;
+         !is_gpu && !is_root_detached_actor;
 }

 std::size_t WorkerCacheKey::Hash() const { return hash_; }
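[Illustration, not part of the diff: folding is_root_detached_actor into the worker cache key means a driver-rooted task and a detached-actor-rooted task should never reuse the same worker process, even with identical runtime envs. A behavioral sketch:]

    import os
    import ray

    ray.init()

    @ray.remote
    def worker_pid():
        return os.getpid()

    @ray.remote(lifetime="detached")
    class Submitter:
        def submit(self):
            # This inner task's root detached actor id is the Submitter's id,
            # so it hashes to a different worker cache key than driver-rooted tasks.
            return ray.get(worker_pid.remote())

    driver_rooted_pid = ray.get(worker_pid.remote())
    submitter = Submitter.options(name="submitter").remote()
    actor_rooted_pid = ray.get(submitter.submit.remote())
    assert driver_rooted_pid != actor_rooted_pid  # Different worker processes.
    ray.kill(submitter)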
diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h
index cf6a0bc44fd6..b3ae0c11a5c6 100644
--- a/src/ray/common/task/task_spec.h
+++ b/src/ray/common/task/task_spec.h
@@ -287,6 +287,8 @@ class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {

   TaskID ParentTaskId() const;

+  ActorID RootDetachedActorId() const;
+
   TaskID SubmitterTaskId() const;

   size_t ParentCounter() const;

@@ -527,12 +529,17 @@ class WorkerCacheKey {
   /// worker. \param required_resources The required resource.
   /// worker. \param is_actor Whether the worker will be an actor. This is set when
   /// task type isolation between workers is enabled.
-  /// worker. \param iis_gpu Whether the worker will be using GPUs. This is set when
+  /// worker. \param is_gpu Whether the worker will be using GPUs. This is set when
   /// resource type isolation between workers is enabled.
+  /// worker. \param is_root_detached_actor Whether the worker will be running
+  /// tasks or actors whose root ancestor is a detached actor. This is set
+  /// to prevent worker reuse between tasks whose root is the driver process
+  /// and tasks whose root is a detached actor.
   WorkerCacheKey(const std::string serialized_runtime_env,
                  const absl::flat_hash_map<std::string, double> &required_resources,
                  bool is_actor,
-                 bool is_gpu);
+                 bool is_gpu,
+                 bool is_root_detached_actor);

   bool operator==(const WorkerCacheKey &k) const;

@@ -564,6 +571,9 @@ class WorkerCacheKey {
   const bool is_actor;
   /// Whether the worker is to use a GPU.
   const bool is_gpu;
+  /// Whether the worker is to run tasks or actors
+  /// whose root is a detached actor.
+  const bool is_root_detached_actor;
   /// The hash of the worker's environment. This is set to 0
   /// for unspecified or empty environments.
   const std::size_t hash_ = 0;
diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h
index 30460d6300ee..488c52069aa4 100644
--- a/src/ray/common/task/task_util.h
+++ b/src/ray/common/task/task_util.h
@@ -172,12 +172,16 @@ class TaskSpecBuilder {
       int max_retries,
       bool retry_exceptions,
       const std::string &serialized_retry_exception_allowlist,
-      const rpc::SchedulingStrategy &scheduling_strategy) {
+      const rpc::SchedulingStrategy &scheduling_strategy,
+      const ActorID root_detached_actor_id) {
     message_->set_max_retries(max_retries);
     message_->set_retry_exceptions(retry_exceptions);
     message_->set_serialized_retry_exception_allowlist(
         serialized_retry_exception_allowlist);
     message_->mutable_scheduling_strategy()->CopyFrom(scheduling_strategy);
+    if (!root_detached_actor_id.IsNil()) {
+      message_->set_root_detached_actor_id(root_detached_actor_id.Binary());
+    }
     return *this;
   }

@@ -230,7 +234,8 @@ class TaskSpecBuilder {
       bool is_asyncio = false,
       const std::vector<rpc::ConcurrencyGroup> &concurrency_groups = {},
       const std::string &extension_data = "",
-      bool execute_out_of_order = false) {
+      bool execute_out_of_order = false,
+      ActorID root_detached_actor_id = ActorID::Nil()) {
     message_->set_type(TaskType::ACTOR_CREATION_TASK);
     auto actor_creation_spec = message_->mutable_actor_creation_task_spec();
     actor_creation_spec->set_actor_id(actor_id.Binary());
@@ -258,6 +263,9 @@ class TaskSpecBuilder {
     }
     actor_creation_spec->set_execute_out_of_order(execute_out_of_order);
     message_->mutable_scheduling_strategy()->CopyFrom(scheduling_strategy);
+    if (!root_detached_actor_id.IsNil()) {
+      message_->set_root_detached_actor_id(root_detached_actor_id.Binary());
+    }
     return *this;
   }
diff --git a/src/ray/common/test/task_spec_test.cc b/src/ray/common/test/task_spec_test.cc
index 1a0bf3fbcf47..72cf8a1f6031 100644
--- a/src/ray/common/test/task_spec_test.cc
+++ b/src/ray/common/test/task_spec_test.cc
@@ -15,6 +15,7 @@
 #include "ray/common/task/task_spec.h"

 #include "gtest/gtest.h"
+#include "ray/common/task/task_util.h"

 namespace ray {
 TEST(TaskSpecTest, TestSchedulingClassDescriptor) {
@@ -146,6 +147,77 @@ TEST(TaskSpecTest, TestTaskSpecification) {
   ASSERT_TRUE(task_spec.GetNodeAffinitySchedulingStrategyNodeId() == node_id);
 }

+TEST(TaskSpecTest, TestRootDetachedActorId) {
+  ActorID actor_id =
+      ActorID::Of(JobID::FromInt(1), TaskID::FromRandom(JobID::FromInt(1)), 0);
+  TaskSpecification task_spec;
+  ASSERT_TRUE(task_spec.RootDetachedActorId().IsNil());
+  task_spec.GetMutableMessage().set_root_detached_actor_id(actor_id.Binary());
+  ASSERT_EQ(task_spec.RootDetachedActorId(), actor_id);
+}
+
+TEST(TaskSpecTest, TestTaskSpecBuilderRootDetachedActorId) {
+  TaskSpecBuilder task_spec_builder;
+  task_spec_builder.SetNormalTaskSpec(
+      0, false, "", rpc::SchedulingStrategy(), ActorID::Nil());
+  ASSERT_TRUE(task_spec_builder.Build().RootDetachedActorId().IsNil());
+  ActorID actor_id =
+      ActorID::Of(JobID::FromInt(1), TaskID::FromRandom(JobID::FromInt(1)), 0);
+  task_spec_builder.SetNormalTaskSpec(0, false, "", rpc::SchedulingStrategy(), actor_id);
+  ASSERT_EQ(task_spec_builder.Build().RootDetachedActorId(), actor_id);
+
+  TaskSpecBuilder actor_spec_builder;
+  actor_spec_builder.SetActorCreationTaskSpec(actor_id,
+                                              /*serialized_actor_handle=*/"",
+                                              rpc::SchedulingStrategy(),
+                                              /*max_restarts=*/0,
+                                              /*max_task_retries=*/0,
+                                              /*dynamic_worker_options=*/{},
+                                              /*max_concurrency=*/1,
+                                              /*is_detached=*/false,
+                                              /*name=*/"",
+                                              /*ray_namespace=*/"",
+                                              /*is_asyncio=*/false,
+                                              /*concurrency_groups=*/{},
+                                              /*extension_data=*/"",
+                                              /*execute_out_of_order=*/false,
+                                              /*root_detached_actor_id=*/ActorID::Nil());
+  ASSERT_TRUE(actor_spec_builder.Build().RootDetachedActorId().IsNil());
+  actor_spec_builder.SetActorCreationTaskSpec(actor_id,
+                                              /*serialized_actor_handle=*/"",
+                                              rpc::SchedulingStrategy(),
+                                              /*max_restarts=*/0,
+                                              /*max_task_retries=*/0,
+                                              /*dynamic_worker_options=*/{},
+                                              /*max_concurrency=*/1,
+                                              /*is_detached=*/true,
+                                              /*name=*/"",
+                                              /*ray_namespace=*/"",
+                                              /*is_asyncio=*/false,
+                                              /*concurrency_groups=*/{},
+                                              /*extension_data=*/"",
+                                              /*execute_out_of_order=*/false,
+                                              /*root_detached_actor_id=*/actor_id);
+  ASSERT_EQ(actor_spec_builder.Build().RootDetachedActorId(), actor_id);
+}
+
+TEST(TaskSpecTest, TestWorkerCacheKey) {
+  // Test that the TaskSpec calculates the correct WorkerCacheKey hash.
+  std::string serialized_runtime_env_A = "mock_env_A";
+  rpc::RuntimeEnvInfo runtime_env_info_A;
+  runtime_env_info_A.set_serialized_runtime_env(serialized_runtime_env_A);
+  TaskSpecification task_spec;
+  task_spec.GetMutableMessage().mutable_runtime_env_info()->CopyFrom(runtime_env_info_A);
+  const WorkerCacheKey key_A = {serialized_runtime_env_A, {}, false, false, false};
+  ASSERT_EQ(task_spec.GetRuntimeEnvHash(), key_A.IntHash());
+  ActorID actor_id =
+      ActorID::Of(JobID::FromInt(1), TaskID::FromRandom(JobID::FromInt(1)), 0);
+  task_spec.GetMutableMessage().set_root_detached_actor_id(actor_id.Binary());
+  ASSERT_NE(task_spec.GetRuntimeEnvHash(), key_A.IntHash());
+  const WorkerCacheKey key_B = {serialized_runtime_env_A, {}, false, false, true};
+  ASSERT_EQ(task_spec.GetRuntimeEnvHash(), key_B.IntHash());
+}
+
 TEST(TaskSpecTest, TestNodeLabelSchedulingStrategy) {
   rpc::SchedulingStrategy scheduling_strategy_1;
   auto expr_1 = scheduling_strategy_1.mutable_node_label_scheduling_strategy()
diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc
index 826c94cb2dc8..f463abf338a2 100644
--- a/src/ray/core_worker/context.cc
+++ b/src/ray/core_worker/context.cc
@@ -163,6 +163,7 @@ WorkerContext::WorkerContext(WorkerType worker_type,
       current_actor_placement_group_id_(PlacementGroupID::Nil()),
       placement_group_capture_child_tasks_(false),
       main_thread_id_(boost::this_thread::get_id()),
+      root_detached_actor_id_(ActorID::Nil()),
       mutex_() {
   // For worker main thread which initializes the WorkerContext,
   // set task_id according to whether current worker is a driver.
@@ -290,6 +291,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
     RAY_CHECK(current_job_id_ == task_spec.JobId());
   if (task_spec.IsNormalTask()) {
     current_task_is_direct_call_ = true;
+    root_detached_actor_id_ = task_spec.RootDetachedActorId();
   } else if (task_spec.IsActorCreationTask()) {
     if (!current_actor_id_.IsNil()) {
       RAY_CHECK(current_actor_id_ == task_spec.ActorCreationId());
@@ -301,6 +303,7 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
     is_detached_actor_ = task_spec.IsDetachedActor();
     current_actor_placement_group_id_ = task_spec.PlacementGroupBundleId().first;
     placement_group_capture_child_tasks_ = task_spec.PlacementGroupCaptureChildTasks();
+    root_detached_actor_id_ = task_spec.RootDetachedActorId();
   } else if (task_spec.IsActorTask()) {
     RAY_CHECK(current_actor_id_ == task_spec.ActorId());
   } else {
@@ -330,6 +333,11 @@ const ActorID &WorkerContext::GetCurrentActorID() const {
   return current_actor_id_;
 }

+const ActorID &WorkerContext::GetRootDetachedActorID() const {
+  absl::ReaderMutexLock lock(&mutex_);
+  return root_detached_actor_id_;
+}
+
 bool WorkerContext::CurrentThreadIsMain() const {
   return boost::this_thread::get_id() == main_thread_id_;
 }
diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h
index 26a926080c47..007ee4401d39 100644
--- a/src/ray/core_worker/context.h
+++ b/src/ray/core_worker/context.h
@@ -95,6 +95,8 @@ class WorkerContext {

   const ActorID &GetCurrentActorID() const ABSL_LOCKS_EXCLUDED(mutex_);

+  const ActorID &GetRootDetachedActorID() const ABSL_LOCKS_EXCLUDED(mutex_);
+
   /// Returns whether the current thread is the main worker thread.
   bool CurrentThreadIsMain() const;

@@ -156,6 +158,9 @@ class WorkerContext {
   /// for concurrent actor, or the main thread's task id for other cases.
   /// Used merely for observability purposes to track task hierarchy.
   TaskID main_thread_or_actor_creation_task_id_ ABSL_GUARDED_BY(mutex_);
+  /// If the current task or actor originates from a detached actor,
+  /// this contains that actor's id; otherwise it is nil.
+  ActorID root_detached_actor_id_ ABSL_GUARDED_BY(mutex_);

   // To protect access to mutable members;
   mutable absl::Mutex mutex_;
diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc
index 8e218901b229..795400be4878 100644
--- a/src/ray/core_worker/core_worker.cc
+++ b/src/ray/core_worker/core_worker.cc
@@ -2175,10 +2175,15 @@ std::vector<rpc::ObjectReference> CoreWorker::SubmitTask(
       /*generator_backpressure_num_objects*/
       task_options.generator_backpressure_num_objects,
       /*enable_task_event*/ task_options.enable_task_events);
+  ActorID root_detached_actor_id;
+  if (!worker_context_.GetRootDetachedActorID().IsNil()) {
+    root_detached_actor_id = worker_context_.GetRootDetachedActorID();
+  }
   builder.SetNormalTaskSpec(max_retries,
                             retry_exceptions,
                             serialized_retry_exception_allowlist,
-                            scheduling_strategy);
+                            scheduling_strategy,
+                            root_detached_actor_id);
   TaskSpecification task_spec = builder.Build();
   RAY_LOG(DEBUG) << "Submitting normal task " << task_spec.DebugString();
   std::vector<rpc::ObjectReference> returned_refs;
@@ -2283,6 +2288,12 @@ Status CoreWorker::CreateActor(const RayFunction &function,
                                  actor_creation_options.enable_task_events);
   std::string serialized_actor_handle;
   actor_handle->Serialize(&serialized_actor_handle);
+  ActorID root_detached_actor_id;
+  if (is_detached) {
+    root_detached_actor_id = actor_id;
+  } else if (!worker_context_.GetRootDetachedActorID().IsNil()) {
+    root_detached_actor_id = worker_context_.GetRootDetachedActorID();
+  }
   builder.SetActorCreationTaskSpec(actor_id,
                                    serialized_actor_handle,
                                    actor_creation_options.scheduling_strategy,
@@ -2296,7 +2307,8 @@ Status CoreWorker::CreateActor(const RayFunction &function,
                                    actor_creation_options.is_asyncio,
                                    actor_creation_options.concurrency_groups,
                                    extension_data,
-                                   actor_creation_options.execute_out_of_order);
+                                   actor_creation_options.execute_out_of_order,
+                                   root_detached_actor_id);
   // Add the actor handle before we submit the actor creation task, since the
   // actor handle must be in scope by the time the GCS sends the
   // WaitForActorOutOfScopeRequest.
@@ -4300,26 +4312,31 @@ void CoreWorker::HandleDeleteSpilledObjects(rpc::DeleteSpilledObjectsRequest req
 void CoreWorker::HandleExit(rpc::ExitRequest request,
                             rpc::ExitReply *reply,
                             rpc::SendReplyCallback send_reply_callback) {
-  bool own_objects = reference_counter_->OwnObjects();
-  int64_t pins_in_flight = local_raylet_client_->GetPinsInFlight();
+  const bool own_objects = reference_counter_->OwnObjects();
+  const size_t num_pending_tasks = task_manager_->NumPendingTasks();
+  const int64_t pins_in_flight = local_raylet_client_->GetPinsInFlight();
   // We consider the worker to be idle if it doesn't own any objects and it doesn't have
-  // any object pinning RPCs in flight.
-  bool is_idle = !own_objects && pins_in_flight == 0;
+  // any object pinning RPCs in flight and it doesn't have pending tasks.
+  bool is_idle = !own_objects && (pins_in_flight == 0) && (num_pending_tasks == 0);
   bool force_exit = request.force_exit();
   RAY_LOG(DEBUG) << "Exiting: is_idle: " << is_idle << " force_exit: " << force_exit;
   if (!is_idle && force_exit) {
     RAY_LOG(INFO) << "Force exiting worker that owns object. This may cause other "
                      "workers that depends on the object to lose it. "
" << "Own objects: " << own_objects - << " # Pins in flight: " << pins_in_flight; + << " # Pins in flight: " << pins_in_flight + << " # pending tasks: " << num_pending_tasks; } bool will_exit = is_idle || force_exit; reply->set_success(will_exit); send_reply_callback( Status::OK(), - [this, will_exit]() { + [this, will_exit, force_exit]() { // If the worker is idle, we exit. - if (will_exit) { + if (force_exit) { + ForceExit(rpc::WorkerExitType::INTENDED_SYSTEM_EXIT, + "Worker force exits because its job has finished"); + } else if (will_exit) { Exit(rpc::WorkerExitType::INTENDED_SYSTEM_EXIT, "Worker exits because it was idle (it doesn't have objects it owns while " "no task or actor has been scheduled) for a long time."); diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 81c7af97dcfc..05c148b3fb2b 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -509,6 +509,10 @@ message TaskSpec { int64 generator_backpressure_num_objects = 38; // Boolean if task events enabled, i.e tasks events would be reported. bool enable_task_events = 39; + // If this task is originated from a detached actor, + // this field contains the detached actor id. + // Otherwise it's empty and is originated from a driver. + bytes root_detached_actor_id = 40; } message TaskInfoEntry { diff --git a/src/ray/raylet/local_task_manager.cc b/src/ray/raylet/local_task_manager.cc index 1a663e655618..a3df5a229cca 100644 --- a/src/ray/raylet/local_task_manager.cc +++ b/src/ray/raylet/local_task_manager.cc @@ -522,6 +522,11 @@ bool LocalTaskManager::PoppedWorkerHandler( task_id, rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_RUNTIME_ENV_SETUP_FAILED, /*scheduling_failure_message*/ runtime_env_setup_error_message); + } else if (status == PopWorkerStatus::JobFinished) { + // The task job finished. + // Just remove the task from dispatch queue. + RAY_LOG(DEBUG) << "Call back to a job finished task, task id = " << task_id; + erase_from_dispatch_queue_fn(work, scheduling_class); } else { // In other cases, set the work status `WAITING` to make this task // could be re-dispatched. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index e86ea9ea6414..e354d9c828ce 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -616,6 +616,31 @@ void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_ void NodeManager::HandleJobFinished(const JobID &job_id, const JobTableData &job_data) { RAY_LOG(DEBUG) << "HandleJobFinished " << job_id; RAY_CHECK(job_data.is_dead()); + // Force kill all the worker processes belonging to the finished job + // so that no worker processes is leaked. + for (const auto &pair : leased_workers_) { + auto &worker = pair.second; + RAY_CHECK(!worker->GetAssignedJobId().IsNil()); + if (worker->GetRootDetachedActorId().IsNil() && + (worker->GetAssignedJobId() == job_id)) { + // Don't kill worker processes belonging to the detached actor + // since those are expected to outlive the job. + RAY_LOG(INFO) << "The leased worker " << worker->WorkerId() + << " is killed because the job " << job_id << " finished."; + rpc::ExitRequest request; + request.set_force_exit(true); + worker->rpc_client()->Exit( + request, [this, worker](const ray::Status &status, const rpc::ExitReply &r) { + if (!status.ok()) { + RAY_LOG(WARNING) << "Failed to send exit request to worker " + << worker->WorkerId() << ": " << status.ToString() + << ". Killing it using SIGKILL instead."; + // Just kill-9 as a last resort. 
diff --git a/src/ray/raylet/local_task_manager.cc b/src/ray/raylet/local_task_manager.cc
index 1a663e655618..a3df5a229cca 100644
--- a/src/ray/raylet/local_task_manager.cc
+++ b/src/ray/raylet/local_task_manager.cc
@@ -522,6 +522,11 @@ bool LocalTaskManager::PoppedWorkerHandler(
         task_id,
         rpc::RequestWorkerLeaseReply::SCHEDULING_CANCELLED_RUNTIME_ENV_SETUP_FAILED,
         /*scheduling_failure_message*/ runtime_env_setup_error_message);
+  } else if (status == PopWorkerStatus::JobFinished) {
+    // The task's job has finished.
+    // Just remove the task from the dispatch queue.
+    RAY_LOG(DEBUG) << "Callback for task " << task_id << " whose job has finished";
+    erase_from_dispatch_queue_fn(work, scheduling_class);
   } else {
     // In other cases, set the work status `WAITING` to make this task
     // could be re-dispatched.
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index e86ea9ea6414..e354d9c828ce 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -616,6 +616,31 @@ void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_
 void NodeManager::HandleJobFinished(const JobID &job_id, const JobTableData &job_data) {
   RAY_LOG(DEBUG) << "HandleJobFinished " << job_id;
   RAY_CHECK(job_data.is_dead());
+  // Force kill all the worker processes belonging to the finished job
+  // so that no worker processes are leaked.
+  for (const auto &pair : leased_workers_) {
+    auto &worker = pair.second;
+    RAY_CHECK(!worker->GetAssignedJobId().IsNil());
+    if (worker->GetRootDetachedActorId().IsNil() &&
+        (worker->GetAssignedJobId() == job_id)) {
+      // Don't kill worker processes rooted in a detached actor
+      // since those are expected to outlive the job.
+      RAY_LOG(INFO) << "The leased worker " << worker->WorkerId()
+                    << " is killed because the job " << job_id << " finished.";
+      rpc::ExitRequest request;
+      request.set_force_exit(true);
+      worker->rpc_client()->Exit(
+          request, [this, worker](const ray::Status &status, const rpc::ExitReply &r) {
+            if (!status.ok()) {
+              RAY_LOG(WARNING) << "Failed to send exit request to worker "
+                               << worker->WorkerId() << ": " << status.ToString()
+                               << ". Killing it using SIGKILL instead.";
+              // Just kill -9 as a last resort.
+              KillWorker(worker, /* force */ true);
+            }
+          });
+    }
+  }
   worker_pool_.HandleJobFinished(job_id);
 }

@@ -1053,8 +1078,6 @@ void NodeManager::HandleUnexpectedWorkerFailure(const rpc::WorkerDeltaData &data
     RAY_LOG(DEBUG) << "Lease " << worker->WorkerId() << " owned by " << owner_worker_id;
     RAY_CHECK(!owner_worker_id.IsNil() && !owner_node_id.IsNil());
     if (!worker->IsDetachedActor()) {
-      // TODO (Alex): Cancel all pending child tasks of the tasks whose owners have
-      // failed because the owner could've submitted lease requests before failing.
       if (!worker_id.IsNil()) {
         // If the failed worker was a leased worker's owner, then kill the leased worker.
         if (owner_worker_id == worker_id) {
diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc
index 149c95677c37..c4bad4e83048 100644
--- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc
+++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc
@@ -185,7 +185,7 @@ RayTask CreateTask(
     }
   }

-  spec_builder.SetNormalTaskSpec(0, false, "", scheduling_strategy);
+  spec_builder.SetNormalTaskSpec(0, false, "", scheduling_strategy, ActorID::Nil());

   return RayTask(spec_builder.Build());
 }
@@ -544,7 +544,7 @@ TEST_F(ClusterTaskManagerTest, DispatchQueueNonBlockingTest) {
   pool_.TriggerCallbacks();

   // Push a worker that can only run task A.
-  const WorkerCacheKey env_A = {serialized_runtime_env_A, {}, false, false};
+  const WorkerCacheKey env_A = {serialized_runtime_env_A, {}, false, false, false};
   const int runtime_env_hash_A = env_A.IntHash();
   std::shared_ptr<MockWorker> worker_A =
       std::make_shared<MockWorker>(WorkerID::FromRandom(), 1234, runtime_env_hash_A);
@@ -1130,6 +1130,21 @@ TEST_F(ClusterTaskManagerTest, NotOKPopWorkerTest) {
   ASSERT_TRUE(reply.canceled());
   ASSERT_EQ(reply.scheduling_failure_message(), runtime_env_error_msg);

+  // Test that the local task manager handles PopWorkerStatus::JobFinished correctly.
+  callback_called = false;
+  reply.Clear();
+  RayTask task3 = CreateTask({{ray::kCPU_ResourceLabel, 1}});
+  task_manager_.QueueAndScheduleTask(task3, false, false, &reply, callback);
+  ASSERT_EQ(NumTasksToDispatchWithStatus(internal::WorkStatus::WAITING_FOR_WORKER), 1);
+  ASSERT_EQ(NumTasksToDispatchWithStatus(internal::WorkStatus::WAITING), 0);
+  ASSERT_EQ(NumRunningTasks(), 1);
+  pool_.TriggerCallbacksWithNotOKStatus(PopWorkerStatus::JobFinished);
+  // The task should be removed from the dispatch queue.
+  ASSERT_FALSE(callback_called);
+  ASSERT_EQ(NumTasksToDispatchWithStatus(internal::WorkStatus::WAITING_FOR_WORKER), 0);
+  ASSERT_EQ(NumTasksToDispatchWithStatus(internal::WorkStatus::WAITING), 0);
+  ASSERT_EQ(NumRunningTasks(), 0);
+  AssertNoLeaks();
 }
diff --git a/src/ray/raylet/test/util.h b/src/ray/raylet/test/util.h
index d1ad7f5198fa..e9b4351af891 100644
--- a/src/ray/raylet/test/util.h
+++ b/src/ray/raylet/test/util.h
@@ -40,6 +40,7 @@ class MockWorker : public WorkerInterface {
   void SetAssignedTask(const RayTask &assigned_task) override {
     task_ = assigned_task;
     task_assign_time_ = absl::Now();
+    root_detached_actor_id_ = assigned_task.GetTaskSpecification().RootDetachedActorId();
   };

   absl::Time GetAssignedTaskTime() const override { return task_assign_time_; };
@@ -155,6 +156,10 @@ class MockWorker : public WorkerInterface {

   void SetJobId(const JobID &job_id) override { job_id_ = job_id; }

+  const ActorID &GetRootDetachedActorId() const override {
+    return root_detached_actor_id_;
+  }
+
 protected:
   void SetStartupToken(StartupToken startup_token) override {
     RAY_CHECK(false) << "Method unused";
@@ -175,6 +180,7 @@ class MockWorker : public WorkerInterface {
   int runtime_env_hash_;
   TaskID task_id_;
   JobID job_id_;
+  ActorID root_detached_actor_id_;
 };

 }  // namespace raylet
diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h
index 13ad92cc17d7..f45878568c34 100644
--- a/src/ray/raylet/worker.h
+++ b/src/ray/raylet/worker.h
@@ -112,6 +112,8 @@ class WorkerInterface {

   virtual void SetJobId(const JobID &job_id) = 0;

+  virtual const ActorID &GetRootDetachedActorId() const = 0;
+
 protected:
   virtual void SetStartupToken(StartupToken startup_token) = 0;

@@ -120,6 +122,7 @@ class WorkerInterface {
   FRIEND_TEST(WorkerPoolDriverRegisteredTest,
               TestWorkerCappingLaterNWorkersNotOwningObjects);
   FRIEND_TEST(WorkerPoolDriverRegisteredTest, TestJobFinishedForceKillIdleWorker);
+  FRIEND_TEST(WorkerPoolDriverRegisteredTest, TestJobFinishedForPopWorker);
   FRIEND_TEST(WorkerPoolDriverRegisteredTest,
               WorkerFromAliveJobDoesNotBlockWorkerFromDeadJobFromGettingKilled);
   FRIEND_TEST(WorkerPoolDriverRegisteredTest, TestWorkerCappingWithExitDelay);
@@ -205,6 +208,8 @@ class Worker : public WorkerInterface {
     lifetime_allocated_instances_ = allocated_instances;
   };

+  const ActorID &GetRootDetachedActorId() const { return root_detached_actor_id_; }
+
   std::shared_ptr<TaskResourceInstances> GetLifetimeAllocatedInstances() {
     return lifetime_allocated_instances_;
   };
@@ -216,6 +221,7 @@ class Worker : public WorkerInterface {
   void SetAssignedTask(const RayTask &assigned_task) {
     assigned_task_ = assigned_task;
     task_assign_time_ = absl::Now();
+    root_detached_actor_id_ = assigned_task.GetTaskSpecification().RootDetachedActorId();
   }

   absl::Time GetAssignedTaskTime() const { return task_assign_time_; };
@@ -271,6 +277,8 @@ class Worker : public WorkerInterface {
   const int runtime_env_hash_;
   /// The worker's actor ID. If this is nil, then the worker is not an actor.
   ActorID actor_id_;
+  /// Root detached actor ID for the worker's last assigned task.
+  ActorID root_detached_actor_id_;
   /// The worker's placement group bundle. It is used to detect if the worker is
   /// associated with a placement group bundle.
   BundleID bundle_id_;
diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc
index d2ab8539a800..74e682b7a52e 100644
--- a/src/ray/raylet/worker_pool.cc
+++ b/src/ray/raylet/worker_pool.cc
@@ -183,7 +183,8 @@ void WorkerPool::SetRuntimeEnvAgentClient(
   runtime_env_agent_client_ = runtime_env_agent_client;
 }

-void WorkerPool::PopWorkerCallbackAsync(const PopWorkerCallback &callback,
+void WorkerPool::PopWorkerCallbackAsync(const TaskSpecification &task_spec,
+                                        const PopWorkerCallback &callback,
                                         std::shared_ptr<WorkerInterface> worker,
                                         PopWorkerStatus status) {
   // This method shouldn't be invoked when runtime env creation has failed because
   // when runtime env is failed to be created, they are all
   // invoking the callback immediately.
   RAY_CHECK(status != PopWorkerStatus::RuntimeEnvCreationFailed);
   // Call back this function asynchronously to make sure executed in different stack.
   io_service_->post(
-      [this, callback, worker, status]() {
-        PopWorkerCallbackInternal(callback, worker, status);
+      [this, task_spec, callback, worker, status]() {
+        PopWorkerCallbackInternal(task_spec, callback, worker, status);
       },
       "WorkerPool.PopWorkerCallback");
 }

-void WorkerPool::PopWorkerCallbackInternal(const PopWorkerCallback &callback,
+void WorkerPool::PopWorkerCallbackInternal(const TaskSpecification &task_spec,
+                                           const PopWorkerCallback &callback,
                                            std::shared_ptr<WorkerInterface> worker,
                                            PopWorkerStatus status) {
   RAY_CHECK(callback);
-  auto used = callback(worker, status, /*runtime_env_setup_error_message*/ "");
+  auto used = false;
+  if (worker && finished_jobs_.contains(task_spec.JobId()) &&
+      task_spec.RootDetachedActorId().IsNil()) {
+    // When a job finishes, the node manager kills its leased workers once and
+    // the worker pool kills its idle workers periodically.
+    // The current worker has already been removed from the idle workers but
+    // hasn't been added to the leased workers yet since the callback hasn't run.
+    // We shouldn't add this worker to the leased workers: the one-time killing
+    // of leased workers for this finished job may have already happened and
+    // won't happen again, which would leak this worker's process.
+    // Instead, fail the PopWorker and add the worker back to the idle workers
+    // so it can be killed later.
+    RAY_CHECK(status == PopWorkerStatus::OK);
+    callback(nullptr, PopWorkerStatus::JobFinished, "");
+  } else {
+    used = callback(worker, status, /*runtime_env_setup_error_message*/ "");
+  }
   if (worker && !used) {
     // The invalid worker not used, restore it to worker pool.
     PushWorker(worker);
   }
 }
@@ -338,20 +356,18 @@ WorkerPool::BuildProcessCommandArgs(const Language &language,
     worker_command_args.push_back("--worker-launch-time-ms=" +
                                   std::to_string(current_sys_time_ms()));
     worker_command_args.push_back("--node-id=" + node_id_.Hex());
+    // TODO(jjyao) This should be renamed to worker cache key hash.
+    worker_command_args.push_back("--runtime-env-hash=" +
+                                  std::to_string(runtime_env_hash));
   } else if (language == Language::CPP) {
     worker_command_args.push_back("--startup_token=" +
                                   std::to_string(worker_startup_token_counter_));
+    worker_command_args.push_back("--ray_runtime_env_hash=" +
+                                  std::to_string(runtime_env_hash));
   }

   if (serialized_runtime_env_context != "{}" && !serialized_runtime_env_context.empty()) {
     worker_command_args.push_back("--language=" + Language_Name(language));
-    if (language == Language::CPP) {
-      worker_command_args.push_back("--ray_runtime_env_hash=" +
-                                    std::to_string(runtime_env_hash));
-    } else {
-      worker_command_args.push_back("--runtime-env-hash=" +
-                                    std::to_string(runtime_env_hash));
-    }
     worker_command_args.push_back("--serialized-runtime-env-context=" +
                                   serialized_runtime_env_context);
   } else if (language == Language::PYTHON && worker_command_args.size() >= 2 &&
@@ -561,14 +577,12 @@ void WorkerPool::MonitorStartingWorkerProcess(const Process &proc,
       process_failed_pending_registration_++;
       bool found;
      bool used;
-      TaskID task_id;
       InvokePopWorkerCallbackForProcess(state.starting_workers_to_tasks,
                                         proc_startup_token,
                                         nullptr,
                                         status,
                                         &found,
-                                        &used,
-                                        &task_id);
+                                        &used);
       DeleteRuntimeEnvIfPossible(it->second.runtime_env_info.serialized_runtime_env());
       RemoveWorkerProcess(state, proc_startup_token);
       if (IsIOWorkerType(worker_type)) {
@@ -972,21 +986,29 @@ void WorkerPool::InvokePopWorkerCallbackForProcess(
     const std::shared_ptr<WorkerInterface> &worker,
     const PopWorkerStatus &status,
     bool *found,
-    bool *worker_used,
-    TaskID *task_id) {
+    bool *worker_used) {
   *found = false;
   *worker_used = false;
   auto it = starting_workers_to_tasks.find(startup_token);
   if (it != starting_workers_to_tasks.end()) {
     *found = true;
-    *task_id = it->second.task_id;
     const auto &callback = it->second.callback;
     RAY_CHECK(callback);
     // This method shouldn't be invoked when runtime env creation has failed because
     // when runtime env is failed to be created, they are all
     // invoking the callback immediately.
     RAY_CHECK(status != PopWorkerStatus::RuntimeEnvCreationFailed);
-    *worker_used = callback(worker, status, /*runtime_env_setup_error_message*/ "");
+    if (worker && finished_jobs_.contains(it->second.task_spec.JobId()) &&
+        it->second.task_spec.RootDetachedActorId().IsNil()) {
+      // If the job has finished, we should fail the PopWorker callback
+      // and add the worker back to the idle workers so it can be killed later.
+      // This doesn't apply to a detached actor and its descendants
+      // since they can outlive the job.
+      RAY_CHECK(status == PopWorkerStatus::OK);
+      callback(nullptr, PopWorkerStatus::JobFinished, "");
+    } else {
+      *worker_used = callback(worker, status, /*runtime_env_setup_error_message*/ "");
+    }
     starting_workers_to_tasks.erase(it);
   }
 }

@@ -998,14 +1020,12 @@ void WorkerPool::PushWorker(const std::shared_ptr<WorkerInterface> &worker) {
   auto &state = GetStateForLanguage(worker->GetLanguage());
   bool found;
   bool used;
-  TaskID task_id;
   InvokePopWorkerCallbackForProcess(state.starting_workers_to_tasks,
                                     worker->GetStartupToken(),
                                     worker,
                                     PopWorkerStatus::OK,
                                     &found,
-                                    &used,
-                                    &task_id);
+                                    &used);
   RAY_LOG(DEBUG) << "PushWorker " << worker->WorkerId() << " used: " << used;
   if (!used) {
     // Put the worker to the idle pool.
@@ -1102,8 +1122,7 @@ void WorkerPool::KillIdleWorker(std::shared_ptr<WorkerInterface> idle_worker,
   RAY_CHECK(rpc_client);
   rpc::ExitRequest request;
   const auto &job_id = idle_worker->GetAssignedJobId();
-  if (finished_jobs_.contains(job_id) &&
-      RayConfig::instance().kill_idle_workers_of_terminated_job()) {
+  if (finished_jobs_.contains(job_id) && idle_worker->GetRootDetachedActorId().IsNil()) {
     RAY_LOG(INFO) << "Force exiting worker whose job has exited "
                   << idle_worker->WorkerId();
     request.set_force_exit(true);
@@ -1170,7 +1189,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec,
     if (status == PopWorkerStatus::OK) {
       RAY_CHECK(proc.IsValid());
       WarnAboutSize();
-      auto task_info = TaskWaitingForWorkerInfo{task_spec.TaskId(), callback};
+      auto task_info = TaskWaitingForWorkerInfo{task_spec, callback};
       state.starting_workers_to_tasks[startup_token] = std::move(task_info);
     } else if (status == PopWorkerStatus::TooManyStartingWorkerProcesses) {
       // TODO(jjyao) As an optimization, we don't need to delete the runtime env
@@ -1180,7 +1199,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec,
           PopWorkerRequest{task_spec, callback, allocated_instances_serialized_json});
     } else {
       DeleteRuntimeEnvIfPossible(task_spec.SerializedRuntimeEnv());
-      PopWorkerCallbackAsync(callback, nullptr, status);
+      PopWorkerCallbackAsync(task_spec, callback, nullptr, status);
     }
   };

@@ -1297,7 +1316,7 @@ void WorkerPool::PopWorker(const TaskSpecification &task_spec,
       RAY_LOG(DEBUG) << "Re-using worker " << worker->WorkerId() << " for task "
                      << task_spec.DebugString();
       stats::NumWorkersStartedFromCache.Record(1);
-      PopWorkerCallbackAsync(callback, worker);
+      PopWorkerCallbackAsync(task_spec, callback, worker);
     }
   }

@@ -1341,7 +1360,8 @@ void WorkerPool::PrestartDefaultCpuWorkers(ray::Language language, int64_t num_n
   static const WorkerCacheKey kDefaultCpuWorkerCacheKey{/*serialized_runtime_env*/ "",
                                                         {{"CPU", 1}},
                                                         /*is_actor*/ false,
-                                                        /*is_gpu*/ false};
+                                                        /*is_gpu*/ false,
+                                                        /*is_root_detached_actor*/ false};
   RAY_LOG(DEBUG) << "PrestartDefaultCpuWorkers " << num_needed;
   for (int i = 0; i < num_needed; i++) {
     PopWorkerStatus status;
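[Illustration, not part of the diff: rough Python-style pseudocode of the race that PopWorkerCallbackInternal above guards against. Names here are descriptive stand-ins, not real worker pool APIs.]

    def pop_worker_callback(task_spec, callback, worker, status):
        if (worker is not None
                and task_spec.job_id in finished_jobs
                and task_spec.root_detached_actor_id.is_nil()):
            # The one-shot kill of leased workers for this job may have already
            # run, so leasing this worker out now would leak its process.
            callback(None, "JobFinished")
            used = False
        else:
            used = callback(worker, status)
        if worker is not None and not used:
            # The periodic idle-worker killer will force-exit it later.
            push_to_idle_pool(worker)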
diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h
index b8e03198dfd2..dcd705addcd4 100644
--- a/src/ray/raylet/worker_pool.h
+++ b/src/ray/raylet/worker_pool.h
@@ -59,6 +59,9 @@ enum PopWorkerStatus {
   // Any fails of runtime env creation.
   // A nullptr worker will be returned with callback.
   RuntimeEnvCreationFailed = 4,
+  // The task's job has finished.
+  // A nullptr worker will be returned with callback.
+  JobFinished = 5,
 };

 /// \param[in] worker The started worker instance. Nullptr if worker is not started.
@@ -442,7 +445,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
   virtual void WarnAboutSize();

   /// Make this synchronized function for unit test.
-  void PopWorkerCallbackInternal(const PopWorkerCallback &callback,
+  void PopWorkerCallbackInternal(const TaskSpecification &task_spec,
+                                 const PopWorkerCallback &callback,
                                  std::shared_ptr<WorkerInterface> worker,
                                  PopWorkerStatus status);

@@ -486,8 +490,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
   };

   struct TaskWaitingForWorkerInfo {
-    /// The id of task.
-    TaskID task_id;
+    /// The spec of the task.
+    TaskSpecification task_spec;
     /// The callback function which should be called when worker registered.
     PopWorkerCallback callback;
   };

@@ -608,7 +612,8 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
   /// Call the `PopWorkerCallback` function asynchronously to make sure executed in
   /// different stack.
-  virtual void PopWorkerCallbackAsync(const PopWorkerCallback &callback,
+  virtual void PopWorkerCallbackAsync(const TaskSpecification &task_spec,
+                                      const PopWorkerCallback &callback,
                                       std::shared_ptr<WorkerInterface> worker,
                                       PopWorkerStatus status = PopWorkerStatus::OK);

@@ -623,15 +628,13 @@ class WorkerPool : public WorkerPoolInterface, public IOWorkerPoolInterface {
   /// \param found Whether the related task found or not.
   /// \param worker_used Whether the worker is used by the task, only valid when found is
   /// true.
-  /// \param task_id The related task id.
   void InvokePopWorkerCallbackForProcess(
       absl::flat_hash_map<StartupToken, TaskWaitingForWorkerInfo> &workers_to_tasks,
       StartupToken startup_token,
       const std::shared_ptr<WorkerInterface> &worker,
       const PopWorkerStatus &status,
       bool *found /* output */,
-      bool *worker_used /* output */,
-      TaskID *task_id /* output */);
+      bool *worker_used /* output */);

   /// We manage all runtime env resources locally by the two methods:
   /// `GetOrCreateRuntimeEnv` and `DeleteRuntimeEnvIfPossible`.
diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc
index fdfeeeff2d9f..ab194e595093 100644
--- a/src/ray/raylet/worker_pool_test.cc
+++ b/src/ray/raylet/worker_pool_test.cc
@@ -160,10 +160,11 @@ class WorkerPoolMock : public WorkerPool {
   using WorkerPool::PopWorkerCallbackInternal;

   // Mock `PopWorkerCallbackAsync` to synchronized function.
-  void PopWorkerCallbackAsync(const PopWorkerCallback &callback,
+  void PopWorkerCallbackAsync(const TaskSpecification &task_spec,
+                              const PopWorkerCallback &callback,
                               std::shared_ptr<WorkerInterface> worker,
                               PopWorkerStatus status = PopWorkerStatus::OK) override {
-    PopWorkerCallbackInternal(callback, worker, status);
+    PopWorkerCallbackInternal(task_spec, callback, worker, status);
   }

   Process StartProcess(const std::vector<std::string> &worker_command_args,
@@ -1377,6 +1378,86 @@ TEST_F(WorkerPoolDriverRegisteredTest, TestWorkerCappingWithExitDelay) {
   ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), workers.size());
 }

+TEST_F(WorkerPoolDriverRegisteredTest, TestJobFinishedForPopWorker) {
+  // Test to make sure that if a job finishes,
+  // PopWorker fails with PopWorkerStatus::JobFinished.
+
+  auto job_id = JOB_ID;
+
+  /// Add a worker to the pool.
+  PopWorkerStatus status;
+  auto [proc, token] = worker_pool_->StartWorkerProcess(
+      Language::PYTHON, rpc::WorkerType::WORKER, job_id, &status);
+  auto worker = worker_pool_->CreateWorker(Process(), Language::PYTHON, job_id);
+  worker->SetStartupToken(worker_pool_->GetStartupToken(proc));
+  RAY_CHECK_OK(worker_pool_->RegisterWorker(
+      worker, proc.GetId(), worker_pool_->GetStartupToken(proc), [](Status, int) {}));
+  worker_pool_->OnWorkerStarted(worker);
+  worker_pool_->PushWorker(worker);
+  ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 1);
+
+  auto mock_rpc_client_it = mock_worker_rpc_clients_.find(worker->WorkerId());
+  auto mock_rpc_client = mock_rpc_client_it->second;
+
+  // Finish the job.
+  worker_pool_->HandleJobFinished(job_id);
+
+  auto task_spec = ExampleTaskSpec(/*actor_id=*/ActorID::Nil(), Language::PYTHON, job_id);
+  PopWorkerStatus pop_worker_status;
+  // This PopWorker should fail since the job finished.
+  worker = worker_pool_->PopWorkerSync(task_spec, false, &pop_worker_status);
+  ASSERT_EQ(pop_worker_status, PopWorkerStatus::JobFinished);
+  ASSERT_FALSE(worker);
+  ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 1);
+
+  worker_pool_->TryKillingIdleWorkers();
+  ASSERT_EQ(mock_rpc_client->exit_count, 1);
+  ASSERT_EQ(mock_rpc_client->last_exit_forced, true);
+  mock_rpc_client->ExitReplySucceed();
+
+  job_id = JOB_ID2;
+  rpc::JobConfig job_config;
+  RegisterDriver(Language::PYTHON, job_id, job_config);
+  task_spec = ExampleTaskSpec(/*actor_id=*/ActorID::Nil(), Language::PYTHON, job_id);
+  pop_worker_status = PopWorkerStatus::OK;
+  // This will start a new worker.
+  worker_pool_->PopWorker(
+      task_spec,
+      [&](const std::shared_ptr<WorkerInterface> worker,
+          PopWorkerStatus status,
+          const std::string &runtime_env_setup_error_message) -> bool {
+        pop_worker_status = status;
+        return false;
+      });
+  auto process = worker_pool_->LastStartedWorkerProcess();
+  RAY_CHECK(process.IsValid());
+  ASSERT_EQ(1, worker_pool_->NumWorkersStarting());
+
+  worker = worker_pool_->CreateWorker(Process());
+  worker->SetStartupToken(worker_pool_->GetStartupToken(process));
+  RAY_CHECK_OK(worker_pool_->RegisterWorker(
+      worker, process.GetId(), worker_pool_->GetStartupToken(process), [](Status, int) {
+      }));
+  // Call `OnWorkerStarted` to emulate worker port announcement.
+  worker_pool_->OnWorkerStarted(worker);
+
+  mock_rpc_client_it = mock_worker_rpc_clients_.find(worker->WorkerId());
+  mock_rpc_client = mock_rpc_client_it->second;
+
+  // Finish the job.
+  worker_pool_->HandleJobFinished(job_id);
+
+  // This will trigger the PopWorker callback.
+  worker_pool_->PushWorker(worker);
+  ASSERT_EQ(pop_worker_status, PopWorkerStatus::JobFinished);
+  ASSERT_EQ(worker_pool_->GetIdleWorkerSize(), 1);
+
+  worker_pool_->TryKillingIdleWorkers();
+  ASSERT_EQ(mock_rpc_client->exit_count, 1);
+  ASSERT_EQ(mock_rpc_client->last_exit_forced, true);
+  mock_rpc_client->ExitReplySucceed();
+}
+
 TEST_F(WorkerPoolDriverRegisteredTest, TestJobFinishedForceKillIdleWorker) {
   auto job_id = JOB_ID;