diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index 36643ce286f19..c7f100aade5ac 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -41,7 +41,7 @@ ctypedef pair[c_vector[CObjectID], c_vector[CObjectID]] WaitResultPair cdef extern from "ray/raylet/raylet_client.h" nogil: - cdef cppclass CRayletClient "RayletClient": + cdef cppclass CRayletClient "ray::raylet::RayletClient": CRayletClient(const c_string &raylet_socket, const CWorkerID &worker_id, c_bool is_worker, const CJobID &job_id, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 7b43ee81c7963..9e1461d4ea16b 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -128,7 +128,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, auto grpc_client = rpc::NodeManagerWorkerClient::make( node_ip_address, node_manager_port, *client_call_manager_); ClientID local_raylet_id; - local_raylet_client_ = std::shared_ptr(new RayletClient( + local_raylet_client_ = std::shared_ptr(new raylet::RayletClient( std::move(grpc_client), raylet_socket, WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()), (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), @@ -210,8 +210,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, [this](const rpc::Address &address) { auto grpc_client = rpc::NodeManagerWorkerClient::make( address.ip_address(), address.port(), *client_call_manager_); - return std::shared_ptr( - new RayletClient(std::move(grpc_client))); + return std::shared_ptr( + new raylet::RayletClient(std::move(grpc_client))); }, memory_store_, task_manager_, local_raylet_id, RayConfig::instance().worker_lease_timeout_milliseconds())); diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 16c024fe1e466..dcd4aa76e9307 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -89,7 +89,7 @@ class CoreWorker { WorkerContext &GetWorkerContext() { return worker_context_; } - RayletClient &GetRayletClient() { return *local_raylet_client_; } + raylet::RayletClient &GetRayletClient() { return *local_raylet_client_; } const TaskID &GetCurrentTaskId() const { return worker_context_.GetCurrentTaskID(); } @@ -525,7 +525,7 @@ class CoreWorker { // shared_ptr for direct calls because we can lease multiple workers through // one client, and we need to keep the connection alive until we return all // of the workers. - std::shared_ptr local_raylet_client_; + std::shared_ptr local_raylet_client_; // Thread that runs a boost::asio service to process IO events. std::thread io_thread_; diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index dcb92064afbc1..b634ee4e98615 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -109,7 +109,7 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { CoreWorkerMemoryStore::CoreWorkerMemoryStore( std::function store_in_plasma, std::shared_ptr counter, - std::shared_ptr raylet_client) + std::shared_ptr raylet_client) : store_in_plasma_(store_in_plasma), ref_counter_(counter), raylet_client_(raylet_client) {} diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index d847d5bfac2db..09f08297f8266 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -29,7 +29,7 @@ class CoreWorkerMemoryStore { CoreWorkerMemoryStore( std::function store_in_plasma = nullptr, std::shared_ptr counter = nullptr, - std::shared_ptr raylet_client = nullptr); + std::shared_ptr raylet_client = nullptr); ~CoreWorkerMemoryStore(){}; /// Put an object with specified ID into object store. @@ -124,7 +124,7 @@ class CoreWorkerMemoryStore { std::shared_ptr ref_counter_ = nullptr; // If set, this will be used to notify worker blocked / unblocked on get calls. - std::shared_ptr raylet_client_ = nullptr; + std::shared_ptr raylet_client_ = nullptr; /// Protects the data structures below. absl::Mutex mu_; diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.cc b/src/ray/core_worker/store_provider/plasma_store_provider.cc index ef3a1dcf16d4d..11b9f22df8fbe 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.cc +++ b/src/ray/core_worker/store_provider/plasma_store_provider.cc @@ -7,7 +7,8 @@ namespace ray { CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider( - const std::string &store_socket, const std::shared_ptr raylet_client, + const std::string &store_socket, + const std::shared_ptr raylet_client, std::function check_signals) : raylet_client_(raylet_client) { check_signals_ = check_signals; @@ -128,7 +129,7 @@ Status CoreWorkerPlasmaStoreProvider::FetchAndGetFromPlasmaStore( return Status::OK(); } -Status UnblockIfNeeded(const std::shared_ptr &client, +Status UnblockIfNeeded(const std::shared_ptr &client, const WorkerContext &ctx) { if (ctx.CurrentTaskIsDirectCall()) { if (ctx.ShouldReleaseResourcesOnBlockingCalls()) { diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index c6cac42128f5b..1b545ac7ddb78 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -20,7 +20,7 @@ namespace ray { class CoreWorkerPlasmaStoreProvider { public: CoreWorkerPlasmaStoreProvider(const std::string &store_socket, - const std::shared_ptr raylet_client, + const std::shared_ptr raylet_client, std::function check_signals); ~CoreWorkerPlasmaStoreProvider(); @@ -83,7 +83,7 @@ class CoreWorkerPlasmaStoreProvider { static void WarnIfAttemptedTooManyTimes(int num_attempts, const absl::flat_hash_set &remaining); - const std::shared_ptr raylet_client_; + const std::shared_ptr raylet_client_; plasma::PlasmaClient store_client_; std::mutex store_client_mutex_; std::function check_signals_; diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index 07c89c6e4ee84..6a380dc7586df 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -151,7 +151,7 @@ CoreWorkerDirectTaskReceiver::CoreWorkerDirectTaskReceiver( exit_handler_(exit_handler), task_main_io_service_(main_io_service) {} -void CoreWorkerDirectTaskReceiver::Init(RayletClient &raylet_client) { +void CoreWorkerDirectTaskReceiver::Init(raylet::RayletClient &raylet_client) { waiter_.reset(new DependencyWaiterImpl(raylet_client)); } diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 134c5691ca873..7e6825496fd67 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -167,7 +167,8 @@ class DependencyWaiter { class DependencyWaiterImpl : public DependencyWaiter { public: - DependencyWaiterImpl(RayletClient &raylet_client) : raylet_client_(raylet_client) {} + DependencyWaiterImpl(raylet::RayletClient &raylet_client) + : raylet_client_(raylet_client) {} void Wait(const std::vector &dependencies, std::function on_dependencies_available) override { @@ -187,7 +188,7 @@ class DependencyWaiterImpl : public DependencyWaiter { private: int64_t next_request_id_ = 0; std::unordered_map> requests_; - RayletClient &raylet_client_; + raylet::RayletClient &raylet_client_; }; /// Wraps a thread-pool to block posts until the pool has free slots. This is used @@ -436,7 +437,7 @@ class CoreWorkerDirectTaskReceiver { } /// Initialize this receiver. This must be called prior to use. - void Init(RayletClient &client); + void Init(raylet::RayletClient &client); /// Handle a `PushTask` request. /// diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index 3c9982260c151..941531621346d 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -6,7 +6,7 @@ namespace ray { CoreWorkerRayletTaskReceiver::CoreWorkerRayletTaskReceiver( - const WorkerID &worker_id, std::shared_ptr &raylet_client, + const WorkerID &worker_id, std::shared_ptr &raylet_client, const TaskHandler &task_handler, const std::function &exit_handler) : worker_id_(worker_id), raylet_client_(raylet_client), diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index 0ced28c6e5cef..3a392e573fccf 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -17,7 +17,7 @@ class CoreWorkerRayletTaskReceiver { std::vector> *return_objects)>; CoreWorkerRayletTaskReceiver(const WorkerID &worker_id, - std::shared_ptr &raylet_client, + std::shared_ptr &raylet_client, const TaskHandler &task_handler, const std::function &exit_handler); @@ -37,7 +37,7 @@ class CoreWorkerRayletTaskReceiver { WorkerID worker_id_; /// Reference to the core worker's raylet client. This is a pointer ref so that it /// can be initialized by core worker after this class is constructed. - std::shared_ptr &raylet_client_; + std::shared_ptr &raylet_client_; /// The callback function to process a task. TaskHandler task_handler_; /// The callback function to exit the worker. diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 80d8888efd297..dd145de5cf68e 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -92,8 +92,10 @@ int write_bytes(Socket &conn, uint8_t *cursor, size_t length) { return 0; } -RayletConnection::RayletConnection(const std::string &raylet_socket, int num_retries, - int64_t timeout) { +namespace ray { + +raylet::RayletConnection::RayletConnection(const std::string &raylet_socket, + int num_retries, int64_t timeout) { // Pick the default values if the user did not specify. if (num_retries < 0) { num_retries = RayConfig::instance().num_connect_attempts(); @@ -122,9 +124,9 @@ RayletConnection::RayletConnection(const std::string &raylet_socket, int num_ret } } -ray::Status RayletConnection::Disconnect() { +Status raylet::RayletConnection::Disconnect() { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateDisconnectClient(fbb); + auto message = protocol::CreateDisconnectClient(fbb); fbb.Finish(message); auto status = WriteMessage(MessageType::IntentionalDisconnectClient, &fbb); // Don't be too strict for disconnection errors. @@ -133,11 +135,11 @@ ray::Status RayletConnection::Disconnect() { RAY_LOG(ERROR) << status.ToString() << " [RayletClient] Failed to disconnect from raylet."; } - return ray::Status::OK(); + return Status::OK(); } -ray::Status RayletConnection::ReadMessage(MessageType type, - std::unique_ptr &message) { +Status raylet::RayletConnection::ReadMessage(MessageType type, + std::unique_ptr &message) { int64_t cookie; int64_t type_field; int64_t length; @@ -159,26 +161,26 @@ ray::Status RayletConnection::ReadMessage(MessageType type, length = 0; } if (type_field == static_cast(MessageType::DisconnectClient)) { - return ray::Status::IOError("[RayletClient] Raylet connection closed."); + return Status::IOError("[RayletClient] Raylet connection closed."); } if (type_field != static_cast(type)) { - return ray::Status::TypeError( + return Status::TypeError( std::string("[RayletClient] Raylet connection corrupted. ") + "Expected message type: " + std::to_string(static_cast(type)) + "; got message type: " + std::to_string(type_field) + ". Check logs or dmesg for previous errors."); } - return ray::Status::OK(); + return Status::OK(); } -ray::Status RayletConnection::WriteMessage(MessageType type, - flatbuffers::FlatBufferBuilder *fbb) { +Status raylet::RayletConnection::WriteMessage(MessageType type, + flatbuffers::FlatBufferBuilder *fbb) { std::unique_lock guard(write_mutex_); int64_t cookie = RayConfig::instance().ray_cookie(); int64_t length = fbb ? fbb->GetSize() : 0; uint8_t *bytes = fbb ? fbb->GetBufferPointer() : nullptr; int64_t type_field = static_cast(type); - auto io_error = ray::Status::IOError("[RayletClient] Connection closed unexpectedly."); + auto io_error = Status::IOError("[RayletClient] Connection closed unexpectedly."); int closed; closed = write_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie)); if (closed) return io_error; @@ -188,10 +190,10 @@ ray::Status RayletConnection::WriteMessage(MessageType type, if (closed) return io_error; closed = write_bytes(conn_, bytes, length * sizeof(char)); if (closed) return io_error; - return ray::Status::OK(); + return Status::OK(); } -ray::Status RayletConnection::AtomicRequestReply( +Status raylet::RayletConnection::AtomicRequestReply( MessageType request_type, MessageType reply_type, std::unique_ptr &reply_message, flatbuffers::FlatBufferBuilder *fbb) { std::unique_lock guard(mutex_); @@ -200,19 +202,21 @@ ray::Status RayletConnection::AtomicRequestReply( return ReadMessage(reply_type, reply_message); } -RayletClient::RayletClient(std::shared_ptr grpc_client) +raylet::RayletClient::RayletClient( + std::shared_ptr grpc_client) : grpc_client_(std::move(grpc_client)) {} -RayletClient::RayletClient(std::shared_ptr grpc_client, - const std::string &raylet_socket, const WorkerID &worker_id, - bool is_worker, const JobID &job_id, const Language &language, - ClientID *raylet_id, int port) +raylet::RayletClient::RayletClient( + std::shared_ptr grpc_client, + const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, + const JobID &job_id, const Language &language, ClientID *raylet_id, int port) : grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) { // For C++14, we could use std::make_unique - conn_ = std::unique_ptr(new RayletConnection(raylet_socket, -1, -1)); + conn_ = std::unique_ptr( + new raylet::RayletConnection(raylet_socket, -1, -1)); flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateRegisterClientRequest( + auto message = protocol::CreateRegisterClientRequest( fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id), language, port); fbb.Finish(message); @@ -222,12 +226,11 @@ RayletClient::RayletClient(std::shared_ptr gr auto status = conn_->AtomicRequestReply(MessageType::RegisterClientRequest, MessageType::RegisterClientReply, reply, &fbb); RAY_CHECK_OK_PREPEND(status, "[RayletClient] Unable to register worker with raylet."); - auto reply_message = - flatbuffers::GetRoot(reply.get()); + auto reply_message = flatbuffers::GetRoot(reply.get()); *raylet_id = ClientID::FromBinary(reply_message->raylet_id()->str()); } -ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) { +Status raylet::RayletClient::SubmitTask(const TaskSpecification &task_spec) { for (size_t i = 0; i < task_spec.NumArgs(); i++) { if (task_spec.ArgByRef(i)) { for (size_t j = 0; j < task_spec.ArgIdCount(i); j++) { @@ -237,58 +240,57 @@ ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) { } } flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateSubmitTaskRequest( - fbb, fbb.CreateString(task_spec.Serialize())); + auto message = + protocol::CreateSubmitTaskRequest(fbb, fbb.CreateString(task_spec.Serialize())); fbb.Finish(message); return conn_->WriteMessage(MessageType::SubmitTask, &fbb); } -ray::Status RayletClient::TaskDone() { +Status raylet::RayletClient::TaskDone() { return conn_->WriteMessage(MessageType::TaskDone); } -ray::Status RayletClient::FetchOrReconstruct(const std::vector &object_ids, - bool fetch_only, bool mark_worker_blocked, - const TaskID ¤t_task_id) { +Status raylet::RayletClient::FetchOrReconstruct(const std::vector &object_ids, + bool fetch_only, bool mark_worker_blocked, + const TaskID ¤t_task_id) { flatbuffers::FlatBufferBuilder fbb; auto object_ids_message = to_flatbuf(fbb, object_ids); - auto message = ray::protocol::CreateFetchOrReconstruct( - fbb, object_ids_message, fetch_only, mark_worker_blocked, - to_flatbuf(fbb, current_task_id)); + auto message = protocol::CreateFetchOrReconstruct(fbb, object_ids_message, fetch_only, + mark_worker_blocked, + to_flatbuf(fbb, current_task_id)); fbb.Finish(message); auto status = conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb); return status; } -ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { +Status raylet::RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { flatbuffers::FlatBufferBuilder fbb; - auto message = - ray::protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id)); + auto message = protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id)); fbb.Finish(message); return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb); } -ray::Status RayletClient::NotifyDirectCallTaskBlocked() { +Status raylet::RayletClient::NotifyDirectCallTaskBlocked() { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateNotifyDirectCallTaskBlocked(fbb); + auto message = protocol::CreateNotifyDirectCallTaskBlocked(fbb); fbb.Finish(message); return conn_->WriteMessage(MessageType::NotifyDirectCallTaskBlocked, &fbb); } -ray::Status RayletClient::NotifyDirectCallTaskUnblocked() { +Status raylet::RayletClient::NotifyDirectCallTaskUnblocked() { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateNotifyDirectCallTaskUnblocked(fbb); + auto message = protocol::CreateNotifyDirectCallTaskUnblocked(fbb); fbb.Finish(message); return conn_->WriteMessage(MessageType::NotifyDirectCallTaskUnblocked, &fbb); } -ray::Status RayletClient::Wait(const std::vector &object_ids, int num_returns, - int64_t timeout_milliseconds, bool wait_local, - bool mark_worker_blocked, const TaskID ¤t_task_id, - WaitResultPair *result) { +Status raylet::RayletClient::Wait(const std::vector &object_ids, + int num_returns, int64_t timeout_milliseconds, + bool wait_local, bool mark_worker_blocked, + const TaskID ¤t_task_id, WaitResultPair *result) { // Write request. flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateWaitRequest( + auto message = protocol::CreateWaitRequest( fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local, mark_worker_blocked, to_flatbuf(fbb, current_task_id)); fbb.Finish(message); @@ -297,7 +299,7 @@ ray::Status RayletClient::Wait(const std::vector &object_ids, int num_ MessageType::WaitReply, reply, &fbb); if (!status.ok()) return status; // Parse the flatbuffer object. - auto reply_message = flatbuffers::GetRoot(reply.get()); + auto reply_message = flatbuffers::GetRoot(reply.get()); auto found = reply_message->found(); for (size_t i = 0; i < found->size(); i++) { ObjectID object_id = ObjectID::FromBinary(found->Get(i)->str()); @@ -308,22 +310,23 @@ ray::Status RayletClient::Wait(const std::vector &object_ids, int num_ ObjectID object_id = ObjectID::FromBinary(remaining->Get(i)->str()); result->second.push_back(object_id); } - return ray::Status::OK(); + return Status::OK(); } -ray::Status RayletClient::WaitForDirectActorCallArgs( +Status raylet::RayletClient::WaitForDirectActorCallArgs( const std::vector &object_ids, int64_t tag) { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateWaitForDirectActorCallArgsRequest( + auto message = protocol::CreateWaitForDirectActorCallArgsRequest( fbb, to_flatbuf(fbb, object_ids), tag); fbb.Finish(message); return conn_->WriteMessage(MessageType::WaitForDirectActorCallArgsRequest, &fbb); } -ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type, - const std::string &error_message, double timestamp) { +Status raylet::RayletClient::PushError(const JobID &job_id, const std::string &type, + const std::string &error_message, + double timestamp) { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreatePushErrorRequest( + auto message = protocol::CreatePushErrorRequest( fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), fbb.CreateString(error_message), timestamp); fbb.Finish(message); @@ -331,7 +334,7 @@ ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb); } -ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) { +Status raylet::RayletClient::PushProfileEvents(const ProfileTableData &profile_events) { flatbuffers::FlatBufferBuilder fbb; auto message = fbb.CreateString(profile_events.SerializeAsString()); fbb.Finish(message); @@ -342,13 +345,13 @@ ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_even RAY_LOG(ERROR) << status.ToString() << " [RayletClient] Failed to push profile events."; } - return ray::Status::OK(); + return Status::OK(); } -ray::Status RayletClient::FreeObjects(const std::vector &object_ids, - bool local_only, bool delete_creating_tasks) { +Status raylet::RayletClient::FreeObjects(const std::vector &object_ids, + bool local_only, bool delete_creating_tasks) { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateFreeObjectsRequest( + auto message = protocol::CreateFreeObjectsRequest( fbb, local_only, delete_creating_tasks, to_flatbuf(fbb, object_ids)); fbb.Finish(message); @@ -356,11 +359,11 @@ ray::Status RayletClient::FreeObjects(const std::vector &object_i return status; } -ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, - ActorCheckpointID &checkpoint_id) { +Status raylet::RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, + ActorCheckpointID &checkpoint_id) { flatbuffers::FlatBufferBuilder fbb; auto message = - ray::protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id)); + protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id)); fbb.Finish(message); std::unique_ptr reply; @@ -369,57 +372,58 @@ ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, MessageType::PrepareActorCheckpointReply, reply, &fbb); if (!status.ok()) return status; auto reply_message = - flatbuffers::GetRoot(reply.get()); + flatbuffers::GetRoot(reply.get()); checkpoint_id = ActorCheckpointID::FromBinary(reply_message->checkpoint_id()->str()); - return ray::Status::OK(); + return Status::OK(); } -ray::Status RayletClient::NotifyActorResumedFromCheckpoint( +Status raylet::RayletClient::NotifyActorResumedFromCheckpoint( const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateNotifyActorResumedFromCheckpoint( + auto message = protocol::CreateNotifyActorResumedFromCheckpoint( fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, checkpoint_id)); fbb.Finish(message); return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &fbb); } -ray::Status RayletClient::SetResource(const std::string &resource_name, - const double capacity, - const ray::ClientID &client_Id) { +Status raylet::RayletClient::SetResource(const std::string &resource_name, + const double capacity, + const ClientID &client_Id) { flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateSetResourceRequest( - fbb, fbb.CreateString(resource_name), capacity, to_flatbuf(fbb, client_Id)); + auto message = protocol::CreateSetResourceRequest(fbb, fbb.CreateString(resource_name), + capacity, to_flatbuf(fbb, client_Id)); fbb.Finish(message); return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb); } -ray::Status RayletClient::ReportActiveObjectIDs( +Status raylet::RayletClient::ReportActiveObjectIDs( const std::unordered_set &object_ids) { flatbuffers::FlatBufferBuilder fbb; - auto message = - ray::protocol::CreateReportActiveObjectIDs(fbb, to_flatbuf(fbb, object_ids)); + auto message = protocol::CreateReportActiveObjectIDs(fbb, to_flatbuf(fbb, object_ids)); fbb.Finish(message); return conn_->WriteMessage(MessageType::ReportActiveObjectIDs, &fbb); } -ray::Status RayletClient::RequestWorkerLease( - const ray::TaskSpecification &resource_spec, - const ray::rpc::ClientCallback &callback) { - ray::rpc::WorkerLeaseRequest request; +Status raylet::RayletClient::RequestWorkerLease( + const TaskSpecification &resource_spec, + const rpc::ClientCallback &callback) { + rpc::WorkerLeaseRequest request; request.mutable_resource_spec()->CopyFrom(resource_spec.GetMessage()); return grpc_client_->RequestWorkerLease(request, callback); } -ray::Status RayletClient::ReturnWorker(int worker_port, bool disconnect_worker) { - ray::rpc::ReturnWorkerRequest request; +Status raylet::RayletClient::ReturnWorker(int worker_port, bool disconnect_worker) { + rpc::ReturnWorkerRequest request; request.set_worker_port(worker_port); request.set_disconnect_worker(disconnect_worker); return grpc_client_->ReturnWorker( - request, [](const ray::Status &status, const ray::rpc::ReturnWorkerReply &reply) { + request, [](const Status &status, const rpc::ReturnWorkerReply &reply) { if (!status.ok()) { RAY_LOG(INFO) << "Error returning worker: " << status; } }); } + +} // namespace ray diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index f1e48c4de10aa..e45e58abd307a 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -30,6 +30,29 @@ using ResourceMappingType = using Socket = boost::asio::detail::socket_holder; using WaitResultPair = std::pair, std::vector>; +namespace ray { + +/// Interface for leasing workers. Abstract for testing. +class WorkerLeaseInterface { + public: + /// Requests a worker from the raylet. The callback will be sent via gRPC. + /// \param resource_spec Resources that should be allocated for the worker. + /// \return ray::Status + virtual ray::Status RequestWorkerLease( + const ray::TaskSpecification &resource_spec, + const ray::rpc::ClientCallback &callback) = 0; + + /// Returns a worker to the raylet. + /// \param worker_port The local port of the worker on the raylet node. + /// \param disconnect_worker Whether the raylet should disconnect the worker. + /// \return ray::Status + virtual ray::Status ReturnWorker(int worker_port, bool disconnect_worker) = 0; + + virtual ~WorkerLeaseInterface(){}; +}; + +namespace raylet { + class RayletConnection { public: /// Connect to the raylet. @@ -49,9 +72,12 @@ class RayletConnection { /// /// \return ray::Status. ray::Status Disconnect(); + ray::Status ReadMessage(MessageType type, std::unique_ptr &message); + ray::Status WriteMessage(MessageType type, flatbuffers::FlatBufferBuilder *fbb = nullptr); + ray::Status AtomicRequestReply(MessageType request_type, MessageType reply_type, std::unique_ptr &reply_message, flatbuffers::FlatBufferBuilder *fbb = nullptr); @@ -65,25 +91,6 @@ class RayletConnection { std::mutex write_mutex_; }; -/// Interface for leasing workers. Abstract for testing. -class WorkerLeaseInterface { - public: - /// Requests a worker from the raylet. The callback will be sent via gRPC. - /// \param resource_spec Resources that should be allocated for the worker. - /// \return ray::Status - virtual ray::Status RequestWorkerLease( - const ray::TaskSpecification &resource_spec, - const ray::rpc::ClientCallback &callback) = 0; - - /// Returns a worker to the raylet. - /// \param worker_port The local port of the worker on the raylet node. - /// \param disconnect_worker Whether the raylet should disconnect the worker. - /// \return ray::Status - virtual ray::Status ReturnWorker(int worker_port, bool disconnect_worker) = 0; - - virtual ~WorkerLeaseInterface(){}; -}; - class RayletClient : public WorkerLeaseInterface { public: /// Connect to the raylet. @@ -258,4 +265,8 @@ class RayletClient : public WorkerLeaseInterface { std::unique_ptr conn_; }; +} // namespace raylet + +} // namespace ray + #endif