Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core worker] Refactor CoreWorker member classes #5062

Merged
6 changes: 2 additions & 4 deletions src/ray/core_worker/common.h
Expand Up @@ -5,6 +5,7 @@

#include "ray/common/buffer.h"
#include "ray/common/id.h"
#include "ray/gcs/format/gcs_generated.h"
#include "ray/raylet/raylet_client.h"
#include "ray/raylet/task_spec.h"

Expand All @@ -13,13 +14,10 @@ namespace ray {
/// Type of this worker.
enum class WorkerType { WORKER, DRIVER };

/// Language of Ray tasks and workers.
enum class WorkerLanguage { PYTHON, JAVA };

/// Information about a remote function.
struct RayFunction {
/// Language of the remote function.
const WorkerLanguage language;
const Language language;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original intent of having a WorkerLanguage structure is that users don't need to look at the generated protobuf file for the Language definition, as that generated .h file is complex so we want to hide that complexity. But yes I agree that using Language directly makes the core worker code simpler, probably that's the right thing to do. Maybe we can consider separating Language definition to a different protobuf file say common.proto ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, yes, I think having a single Language type is cleaner. Sure, should I move Language to a new protobuf file now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

never mind, I could take care of it in my next PR:)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, I'm working on common.proto right now.

/// Function descriptor of the remote function.
const std::vector<std::string> function_descriptor;
};
Expand Down
41 changes: 9 additions & 32 deletions src/ray/core_worker/core_worker.cc
Expand Up @@ -3,46 +3,23 @@

namespace ray {

CoreWorker::CoreWorker(const enum WorkerType worker_type,
const enum WorkerLanguage language,
CoreWorker::CoreWorker(const enum WorkerType worker_type, const ::Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id)
: worker_type_(worker_type),
language_(language),
store_socket_(store_socket),
raylet_socket_(raylet_socket),
worker_context_(worker_type, job_id),
// TODO(zhijunfu): currently RayletClient would crash in its constructor
// if it cannot connect to Raylet after a number of retries, this needs
// to be changed so that the worker (java/python .etc) can retrieve and
// handle the error instead of crashing.
raylet_client_(raylet_socket_,
ClientID::FromBinary(worker_context_.GetWorkerID().Binary()),
(worker_type_ == ray::WorkerType::WORKER),
worker_context_.GetCurrentJobID(), ToTaskLanguage(language_)),
task_interface_(*this),
object_interface_(*this),
task_execution_interface_(*this) {
// TODO(zhijunfu): currently RayletClient would crash in its constructor if it cannot
// connect to Raylet after a number of retries, this needs to be changed
// so that the worker (java/python .etc) can retrieve and handle the error
// instead of crashing.
auto status = store_client_.Connect(store_socket_);
if (!status.ok()) {
RAY_LOG(ERROR) << "Connecting plasma store failed when trying to construct"
<< " core worker: " << status.message();
throw std::runtime_error(status.message());
}
}

::Language CoreWorker::ToTaskLanguage(WorkerLanguage language) {
switch (language) {
case ray::WorkerLanguage::JAVA:
return ::Language::JAVA;
break;
case ray::WorkerLanguage::PYTHON:
return ::Language::PYTHON;
break;
default:
RAY_LOG(FATAL) << "invalid language specified: " << static_cast<int>(language);
break;
}
}
worker_context_.GetCurrentJobID(), language_),
task_interface_(worker_context_, raylet_client_),
object_interface_(worker_context_, raylet_client_, store_socket),
task_execution_interface_(worker_context_, raylet_client_, object_interface_) {}

} // namespace ray
26 changes: 4 additions & 22 deletions src/ray/core_worker/core_worker.h
Expand Up @@ -7,6 +7,7 @@
#include "ray/core_worker/object_interface.h"
#include "ray/core_worker/task_execution.h"
#include "ray/core_worker/task_interface.h"
#include "ray/gcs/format/gcs_generated.h"
#include "ray/raylet/raylet_client.h"

namespace ray {
Expand All @@ -22,15 +23,15 @@ class CoreWorker {
/// \param[in] langauge Language of this worker.
///
/// NOTE(zhijunfu): the constructor would throw if a failure happens.
CoreWorker(const WorkerType worker_type, const WorkerLanguage language,
CoreWorker(const WorkerType worker_type, const ::Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id = JobID::Nil());

/// Type of this worker.
enum WorkerType WorkerType() const { return worker_type_; }

/// Language of this worker.
enum WorkerLanguage Language() const { return language_; }
::Language Language() const { return language_; }

/// Return the `CoreWorkerTaskInterface` that contains the methods related to task
/// submisson.
Expand All @@ -45,33 +46,18 @@ class CoreWorker {
CoreWorkerTaskExecutionInterface &Execution() { return task_execution_interface_; }

private:
/// Translate from WorkLanguage to Language type (required by raylet client).
///
/// \param[in] language Language for a task.
/// \return Translated task language.
::Language ToTaskLanguage(WorkerLanguage language);

/// Type of this worker.
const enum WorkerType worker_type_;

/// Language of this worker.
const enum WorkerLanguage language_;

/// Plasma store socket name.
const std::string store_socket_;
const ::Language language_;

/// raylet socket name.
const std::string raylet_socket_;

/// Worker context.
WorkerContext worker_context_;

/// Plasma store client.
plasma::PlasmaClient store_client_;

/// Mutex to protect store_client_.
std::mutex store_client_mutex_;

/// Raylet client.
RayletClient raylet_client_;

Expand All @@ -83,10 +69,6 @@ class CoreWorker {

/// The `CoreWorkerTaskExecutionInterface` instance.
CoreWorkerTaskExecutionInterface task_execution_interface_;

friend class CoreWorkerTaskInterface;
friend class CoreWorkerObjectInterface;
friend class CoreWorkerTaskExecutionInterface;
};

} // namespace ray
Expand Down
24 changes: 11 additions & 13 deletions src/ray/core_worker/object_interface.cc
@@ -1,23 +1,22 @@
#include "ray/core_worker/object_interface.h"
#include "ray/common/ray_config.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/store_provider/plasma_store_provider.h"

namespace ray {

CoreWorkerObjectInterface::CoreWorkerObjectInterface(CoreWorker &core_worker)
: core_worker_(core_worker) {
CoreWorkerObjectInterface::CoreWorkerObjectInterface(WorkerContext &worker_context,
RayletClient &raylet_client,
const std::string &store_socket)
: worker_context_(worker_context), raylet_client_(raylet_client) {
store_providers_.emplace(
static_cast<int>(StoreProviderType::PLASMA),
std::unique_ptr<CoreWorkerStoreProvider>(new CoreWorkerPlasmaStoreProvider(
core_worker_.store_client_, core_worker_.store_client_mutex_,
core_worker_.raylet_client_)));
std::unique_ptr<CoreWorkerStoreProvider>(
new CoreWorkerPlasmaStoreProvider(store_socket, raylet_client_)));
}

Status CoreWorkerObjectInterface::Put(const RayObject &object, ObjectID *object_id) {
ObjectID put_id = ObjectID::ForPut(core_worker_.worker_context_.GetCurrentTaskID(),
core_worker_.worker_context_.GetNextPutIndex());
ObjectID put_id = ObjectID::ForPut(worker_context_.GetCurrentTaskID(),
worker_context_.GetNextPutIndex());
*object_id = put_id;
return Put(object, put_id);
}
Expand All @@ -32,17 +31,16 @@ Status CoreWorkerObjectInterface::Get(const std::vector<ObjectID> &ids,
int64_t timeout_ms,
std::vector<std::shared_ptr<RayObject>> *results) {
auto type = static_cast<int>(StoreProviderType::PLASMA);
return store_providers_[type]->Get(
ids, timeout_ms, core_worker_.worker_context_.GetCurrentTaskID(), results);
return store_providers_[type]->Get(ids, timeout_ms, worker_context_.GetCurrentTaskID(),
results);
}

Status CoreWorkerObjectInterface::Wait(const std::vector<ObjectID> &object_ids,
int num_objects, int64_t timeout_ms,
std::vector<bool> *results) {
auto type = static_cast<int>(StoreProviderType::PLASMA);
return store_providers_[type]->Wait(object_ids, num_objects, timeout_ms,
core_worker_.worker_context_.GetCurrentTaskID(),
results);
worker_context_.GetCurrentTaskID(), results);
}

Status CoreWorkerObjectInterface::Delete(const std::vector<ObjectID> &object_ids,
Expand Down
10 changes: 7 additions & 3 deletions src/ray/core_worker/object_interface.h
Expand Up @@ -6,6 +6,7 @@
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/store_provider/store_provider.h"

namespace ray {
Expand All @@ -16,7 +17,8 @@ class CoreWorkerStoreProvider;
/// The interface that contains all `CoreWorker` methods that are related to object store.
class CoreWorkerObjectInterface {
public:
CoreWorkerObjectInterface(CoreWorker &core_worker);
CoreWorkerObjectInterface(WorkerContext &worker_context, RayletClient &raylet_client,
const std::string &store_socket);

/// Put an object into object store.
///
Expand Down Expand Up @@ -62,8 +64,10 @@ class CoreWorkerObjectInterface {
bool delete_creating_tasks);

private:
/// Reference to the parent CoreWorker instance.
CoreWorker &core_worker_;
/// Reference to the parent CoreWorker's context.
WorkerContext &worker_context_;
/// Reference to the parent CoreWorker's raylet client.
RayletClient &raylet_client_;

/// All the store providers supported.
std::unordered_map<int, std::unique_ptr<CoreWorkerStoreProvider>> store_providers_;
Expand Down
14 changes: 9 additions & 5 deletions src/ray/core_worker/store_provider/plasma_store_provider.cc
Expand Up @@ -7,11 +7,15 @@
namespace ray {

CoreWorkerPlasmaStoreProvider::CoreWorkerPlasmaStoreProvider(
plasma::PlasmaClient &store_client, std::mutex &store_client_mutex,
RayletClient &raylet_client)
: store_client_(store_client),
store_client_mutex_(store_client_mutex),
raylet_client_(raylet_client) {}
const std::string &store_socket, RayletClient &raylet_client)
: raylet_client_(raylet_client) {
auto status = store_client_.Connect(store_socket);
if (!status.ok()) {
RAY_LOG(ERROR) << "Connecting plasma store failed when trying to construct"
<< " core worker: " << status.message();
throw std::runtime_error(status.message());
}
}

Status CoreWorkerPlasmaStoreProvider::Put(const RayObject &object,
const ObjectID &object_id) {
Expand Down
7 changes: 3 additions & 4 deletions src/ray/core_worker/store_provider/plasma_store_provider.h
Expand Up @@ -17,8 +17,7 @@ class CoreWorker;
/// local and remote store, remote access is done via raylet.
class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider {
public:
CoreWorkerPlasmaStoreProvider(plasma::PlasmaClient &store_client,
std::mutex &store_client_mutex,
CoreWorkerPlasmaStoreProvider(const std::string &store_socket,
RayletClient &raylet_client);

/// Put an object with specified ID into object store.
Expand Down Expand Up @@ -62,10 +61,10 @@ class CoreWorkerPlasmaStoreProvider : public CoreWorkerStoreProvider {

private:
/// Plasma store client.
plasma::PlasmaClient &store_client_;
plasma::PlasmaClient store_client_;

/// Mutex to protect store_client_.
std::mutex &store_client_mutex_;
std::mutex store_client_mutex_;

/// Raylet client.
RayletClient &raylet_client_;
Expand Down
21 changes: 9 additions & 12 deletions src/ray/core_worker/task_execution.cc
Expand Up @@ -6,12 +6,12 @@
namespace ray {

CoreWorkerTaskExecutionInterface::CoreWorkerTaskExecutionInterface(
CoreWorker &core_worker)
: core_worker_(core_worker) {
task_receivers.emplace(
static_cast<int>(TaskTransportType::RAYLET),
std::unique_ptr<CoreWorkerRayletTaskReceiver>(
new CoreWorkerRayletTaskReceiver(core_worker_.raylet_client_)));
WorkerContext &worker_context, RayletClient &raylet_client,
CoreWorkerObjectInterface &object_interface)
: worker_context_(worker_context), object_interface_(object_interface) {
task_receivers.emplace(static_cast<int>(TaskTransportType::RAYLET),
std::unique_ptr<CoreWorkerRayletTaskReceiver>(
new CoreWorkerRayletTaskReceiver(raylet_client)));
}

Status CoreWorkerTaskExecutionInterface::Run(const TaskExecutor &executor) {
Expand All @@ -27,12 +27,9 @@ Status CoreWorkerTaskExecutionInterface::Run(const TaskExecutor &executor) {

for (const auto &task : tasks) {
const auto &spec = task.GetTaskSpecification();
core_worker_.worker_context_.SetCurrentTask(spec);
worker_context_.SetCurrentTask(spec);

WorkerLanguage language = (spec.GetLanguage() == ::Language::JAVA)
? WorkerLanguage::JAVA
: WorkerLanguage::PYTHON;
RayFunction func{language, spec.FunctionDescriptor()};
RayFunction func{spec.GetLanguage(), spec.FunctionDescriptor()};

std::vector<std::shared_ptr<RayObject>> args;
RAY_CHECK_OK(BuildArgsForExecutor(spec, &args));
Expand Down Expand Up @@ -92,7 +89,7 @@ Status CoreWorkerTaskExecutionInterface::BuildArgsForExecutor(
}

std::vector<std::shared_ptr<RayObject>> results;
auto status = core_worker_.object_interface_.Get(object_ids_to_fetch, -1, &results);
auto status = object_interface_.Get(object_ids_to_fetch, -1, &results);
if (status.ok()) {
for (size_t i = 0; i < results.size(); i++) {
(*args)[indices[i]] = results[i];
Expand Down
14 changes: 10 additions & 4 deletions src/ray/core_worker/task_execution.h
Expand Up @@ -4,7 +4,8 @@
#include "ray/common/buffer.h"
#include "ray/common/status.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/store_provider/store_provider.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/object_interface.h"
#include "ray/core_worker/transport/transport.h"

namespace ray {
Expand All @@ -19,7 +20,10 @@ class TaskSpecification;
/// execution.
class CoreWorkerTaskExecutionInterface {
public:
CoreWorkerTaskExecutionInterface(CoreWorker &core_worker);
CoreWorkerTaskExecutionInterface(WorkerContext &worker_context,
RayletClient &raylet_client,
CoreWorkerObjectInterface &object_interface);

/// The callback provided app-language workers that executes tasks.
///
/// \param ray_function[in] Information about the function to execute.
Expand All @@ -46,8 +50,10 @@ class CoreWorkerTaskExecutionInterface {
Status BuildArgsForExecutor(const raylet::TaskSpecification &spec,
std::vector<std::shared_ptr<RayObject>> *args);

/// Reference to the parent CoreWorker instance.
CoreWorker &core_worker_;
/// Reference to the parent CoreWorker's context.
WorkerContext &worker_context_;
/// Reference to the parent CoreWorker's objects interface.
CoreWorkerObjectInterface &object_interface_;

/// All the task task receivers supported.
std::unordered_map<int, std::unique_ptr<CoreWorkerTaskReceiver>> task_receivers;
Expand Down