diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py
index 1d3611c7cba9..836d39e8f734 100644
--- a/python/ray/tests/test_gcs_fault_tolerance.py
+++ b/python/ray/tests/test_gcs_fault_tolerance.py
@@ -18,6 +18,8 @@
     wait_for_pid_to_exit,
     run_string_as_driver,
 )
+from ray.job_submission import JobSubmissionClient, JobStatus
+from ray._raylet import GcsClient
 
 import psutil
 
@@ -975,6 +977,77 @@ def test_redis_logs(external_redis):
         )
 
 
+@pytest.mark.parametrize(
+    "ray_start_cluster_head_with_external_redis",
+    [
+        generate_system_config_map(
+            gcs_failover_worker_reconnect_timeout=20,
+            gcs_rpc_server_reconnect_timeout_s=2,
+        )
+    ],
+    indirect=True,
+)
+def test_job_finished_after_head_node_restart(
+    ray_start_cluster_head_with_external_redis,
+):
+    cluster = ray_start_cluster_head_with_external_redis
+    head_node = cluster.head_node
+
+    # Submit a long-running job.
+    client = JobSubmissionClient(head_node.address)
+    submission_id = client.submit_job(
+        entrypoint="python -c 'import ray; ray.init(); print(ray.cluster_resources()); \
+            import time; time.sleep(1000)'"
+    )
+
+    def get_job_info(submission_id):
+        gcs_client = GcsClient(cluster.address)
+        all_job_info = gcs_client.get_all_job_info()
+
+        return list(
+            filter(
+                lambda job_info: "job_submission_id" in job_info.config.metadata
+                and job_info.config.metadata["job_submission_id"] == submission_id,
+                list(all_job_info.values()),
+            )
+        )
+
+    def _check_job_running(submission_id: str) -> bool:
+        job_infos = get_job_info(submission_id)
+        if len(job_infos) == 0:
+            return False
+        job_info = job_infos[0].job_info
+        return job_info.status == JobStatus.RUNNING
+
+    # Wait until the job info is written to Redis.
+    wait_for_condition(_check_job_running, submission_id=submission_id, timeout=10)
+
+    # Kill the head node.
+    ray.shutdown()
+    gcs_server_process = head_node.all_processes["gcs_server"][0].process
+    gcs_server_pid = gcs_server_process.pid
+
+    cluster.remove_node(head_node)
+
+    # Wait so that the gcs server process does not become a zombie.
+    gcs_server_process.wait()
+    wait_for_pid_to_exit(gcs_server_pid, 1000)
+
+    # Restart the head node.
+    cluster.add_node()
+    ray.init(cluster.address)
+
+    # Verify that the job is finished, i.e. it has been marked is_dead.
+    def _check_job_is_dead(submission_id: str) -> bool:
+        job_infos = get_job_info(submission_id)
+        if len(job_infos) == 0:
+            return False
+        job_info = job_infos[0]
+        return job_info.is_dead
+
+    wait_for_condition(_check_job_is_dead, submission_id=submission_id, timeout=10)
+
 
 if __name__ == "__main__":
     import pytest
diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc
index c0d32c5b5b4a..43550bf41e45 100644
--- a/src/ray/gcs/gcs_server/gcs_job_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc
@@ -204,11 +204,11 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
       job_data_key_to_indices[job_data_key].push_back(i);
     }
 
+    WorkerID worker_id = WorkerID::FromBinary(data.second.driver_address().worker_id());
+
     // If job is not dead, get is_running_tasks from the core worker for the driver.
     if (data.second.is_dead()) {
       reply->mutable_job_info_list(i)->set_is_running_tasks(false);
-      WorkerID worker_id =
-          WorkerID::FromBinary(data.second.driver_address().worker_id());
       core_worker_clients_.Disconnect(worker_id);
       (*num_processed_jobs)++;
       try_send_reply();
@@ -218,11 +218,13 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request,
       auto client = core_worker_clients_.GetOrConnect(data.second.driver_address());
       std::unique_ptr<rpc::NumPendingTasksRequest> request(
           new rpc::NumPendingTasksRequest());
+      RAY_LOG(DEBUG) << "Send NumPendingTasksRequest to worker " << worker_id;
       client->NumPendingTasks(
           std::move(request),
-          [reply, i, num_processed_jobs, try_send_reply](
+          [worker_id, reply, i, num_processed_jobs, try_send_reply](
              const Status &status,
              const rpc::NumPendingTasksReply &num_pending_tasks_reply) {
+            RAY_LOG(DEBUG) << "Received NumPendingTasksReply from worker " << worker_id;
            if (!status.ok()) {
              RAY_LOG(WARNING) << "Failed to get is_running_tasks from core worker: "
                               << status.ToString();
@@ -297,5 +299,28 @@ std::shared_ptr<rpc::JobConfig> GcsJobManager::GetJobConfig(const JobID &job_id)
   return it->second;
 }
 
+void GcsJobManager::OnNodeDead(const NodeID &node_id) {
+  RAY_LOG(INFO) << "Node " << node_id
+                << " failed, mark all jobs from this node as finished";
+
+  auto on_done = [this, node_id](
+                     const absl::flat_hash_map<JobID, rpc::JobTableData> &result) {
+    // If the job is not dead and its driver ran on the failed node, mark it finished.
+    for (auto &data : result) {
+      if (!data.second.is_dead() &&
+          NodeID::FromBinary(data.second.driver_address().raylet_id()) == node_id) {
+        RAY_LOG(DEBUG) << "Marking job: " << data.first << " as finished";
+        MarkJobAsFinished(data.second, [data](Status status) {
+          if (!status.ok()) {
+            RAY_LOG(WARNING) << "Failed to mark job as finished. Status: " << status;
+          }
+        });
+      }
+    }
+  };
+
+  // Mark all jobs whose driver was running on the dead node as finished.
+  RAY_CHECK_OK(gcs_table_storage_->JobTable().GetAll(on_done));
+}
+
 }  // namespace gcs
 }  // namespace ray
diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.h b/src/ray/gcs/gcs_server/gcs_job_manager.h
index fbc9ba3c96b7..75c29a6789fd 100644
--- a/src/ray/gcs/gcs_server/gcs_job_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_job_manager.h
@@ -79,6 +79,12 @@ class GcsJobManager : public rpc::JobInfoHandler {
 
   std::shared_ptr<rpc::JobConfig> GetJobConfig(const JobID &job_id) const;
 
+  /// Handle a node death. This marks all jobs whose driver was running on
+  /// the specified node as finished.
+  ///
+  /// \param node_id The id of the dead node.
+  void OnNodeDead(const NodeID &node_id);
+
  private:
   std::shared_ptr<GcsTableStorage> gcs_table_storage_;
   std::shared_ptr<GcsPublisher> gcs_publisher_;
diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc
index 0d393a78901e..48fe3b4b4399 100644
--- a/src/ray/gcs/gcs_server/gcs_server.cc
+++ b/src/ray/gcs/gcs_server/gcs_server.cc
@@ -740,6 +740,7 @@ void GcsServer::InstallEventListeners() {
         gcs_resource_manager_->OnNodeDead(node_id);
         gcs_placement_group_manager_->OnNodeDead(node_id);
         gcs_actor_manager_->OnNodeDead(node_id, node_ip_address);
+        gcs_job_manager_->OnNodeDead(node_id);
         raylet_client_pool_->Disconnect(node_id);
         gcs_healthcheck_manager_->RemoveNode(node_id);
         pubsub_handler_->RemoveSubscriberFrom(node_id.Binary());
diff --git a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc
index d331506db8a7..81e5c152ce9c 100644
--- a/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_job_manager_test.cc
@@ -553,6 +553,87 @@ TEST_F(GcsJobManagerTest, TestPreserveDriverInfo) {
   ASSERT_EQ(data.driver_pid(), 8264);
 }
 
+TEST_F(GcsJobManagerTest, TestNodeFailure) {
+  gcs::GcsJobManager gcs_job_manager(gcs_table_storage_,
+                                     gcs_publisher_,
+                                     runtime_env_manager_,
+                                     *function_manager_,
+                                     *fake_kv_,
+                                     client_factory_);
+
+  auto job_id1 = JobID::FromInt(1);
+  auto job_id2 = JobID::FromInt(2);
+  gcs::GcsInitData gcs_init_data(gcs_table_storage_);
+  gcs_job_manager.Initialize(/*init_data=*/gcs_init_data);
+
+  rpc::AddJobReply empty_reply;
+  std::promise<bool> promise1;
+  std::promise<bool> promise2;
+
+  auto add_job_request1 = Mocker::GenAddJobRequest(job_id1, "namespace_1");
+  gcs_job_manager.HandleAddJob(
+      *add_job_request1,
+      &empty_reply,
+      [&promise1](Status, std::function<void()>, std::function<void()>) {
+        promise1.set_value(true);
+      });
+  promise1.get_future().get();
+
+  auto add_job_request2 = Mocker::GenAddJobRequest(job_id2, "namespace_2");
+  gcs_job_manager.HandleAddJob(
+      *add_job_request2,
+      &empty_reply,
+      [&promise2](Status, std::function<void()>, std::function<void()>) {
+        promise2.set_value(true);
+      });
+  promise2.get_future().get();
+
+  rpc::GetAllJobInfoRequest all_job_info_request;
+  rpc::GetAllJobInfoReply all_job_info_reply;
+  std::promise<bool> all_job_info_promise;
+
+  // Check that no job is marked as dead yet.
+  gcs_job_manager.HandleGetAllJobInfo(
+      all_job_info_request,
+      &all_job_info_reply,
+      [&all_job_info_promise](Status, std::function<void()>, std::function<void()>) {
+        all_job_info_promise.set_value(true);
+      });
+  all_job_info_promise.get_future().get();
+  for (auto job_info : all_job_info_reply.job_info_list()) {
+    ASSERT_TRUE(!job_info.is_dead());
+  }
+
+  // Remove the node and then check that its job is marked as dead.
+  auto address = all_job_info_reply.job_info_list().Get(0).driver_address();
+  auto node_id = NodeID::FromBinary(address.raylet_id());
+  gcs_job_manager.OnNodeDead(node_id);
+
+  // Get all jobs again and check that jobs from the dead node are marked as finished.
+  auto condition = [&gcs_job_manager, node_id]() -> bool {
+    rpc::GetAllJobInfoRequest all_job_info_request2;
+    rpc::GetAllJobInfoReply all_job_info_reply2;
+    std::promise<bool> all_job_info_promise2;
+    gcs_job_manager.HandleGetAllJobInfo(
+        all_job_info_request2,
+        &all_job_info_reply2,
+        [&all_job_info_promise2](Status, std::function<void()>, std::function<void()>) {
+          all_job_info_promise2.set_value(true);
+        });
+    all_job_info_promise2.get_future().get();
+
+    bool job_condition = true;
+    // job1, whose driver ran on the dead node, should be dead, while job2 should
+    // still be alive.
+    for (auto job_info : all_job_info_reply2.job_info_list()) {
+      auto job_node_id = NodeID::FromBinary(job_info.driver_address().raylet_id());
+      job_condition = job_condition && (job_info.is_dead() == (job_node_id == node_id));
+    }
+    return job_condition;
+  };
+
+  EXPECT_TRUE(WaitForCondition(condition, 2000));
+}
+
 int main(int argc, char **argv) {
   ::testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();