Skip to content

Commit

Permalink
[Core worker] Serialize ActorHandle in core worker. Make ActorHandle …
Browse files Browse the repository at this point in the history
…thread safe. (ray-project#5034)

* Serialize ActorHandle in core worker. Make ActorHandle thread safe.

* Address comments

* Address comments

* Address comments

* Address comments

* lint

* Address comments

* Address comments

* Address comments

* Address comments

* Minor update

* Address comments

* lint
  • Loading branch information
kfstorm authored and jovany-wang committed Jul 2, 2019
1 parent 904dcf0 commit 1cf7728
Show file tree
Hide file tree
Showing 12 changed files with 304 additions and 82 deletions.
11 changes: 11 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ cc_proto_library(
deps = ["object_manager_proto"],
)

proto_library(
name = "core_worker_proto",
srcs = ["src/ray/protobuf/core_worker.proto"],
)

cc_proto_library(
name = "core_worker_cc_proto",
deps = ["core_worker_proto"],
)

# === End of protobuf definitions ===

# === Begin of rpc definitions ===
Expand Down Expand Up @@ -239,6 +249,7 @@ cc_library(
]),
copts = COPTS,
deps = [
":core_worker_cc_proto",
":ray_common",
":ray_util",
":raylet_lib",
Expand Down
16 changes: 16 additions & 0 deletions src/ray/common/id.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,22 @@ const JobID ComputeJobId(const WorkerID &driver_id) {
return JobID(driver_id);
}

const ActorHandleID ComputeNextActorHandleId(const ActorHandleID &actor_handle_id,
int64_t num_forks) {
// Compute hashes.
SHA256_CTX ctx;
sha256_init(&ctx);
sha256_update(&ctx, reinterpret_cast<const BYTE *>(actor_handle_id.Data()),
actor_handle_id.Size());
sha256_update(&ctx, reinterpret_cast<const BYTE *>(&num_forks), sizeof(num_forks));

// Compute the final actor handle ID from the hash.
BYTE buff[DIGEST_SIZE];
sha256_final(&ctx, buff);
RAY_CHECK(DIGEST_SIZE >= ActorHandleID::Size());
return ActorHandleID::FromBinary(std::string(buff, buff + ActorHandleID::Size()));
}

#define ID_OSTREAM_OPERATOR(id_type) \
std::ostream &operator<<(std::ostream &os, const id_type &id) { \
if (id.IsNil()) { \
Expand Down
8 changes: 8 additions & 0 deletions src/ray/common/id.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ std::ostream &operator<<(std::ostream &os, const ObjectID &id);
const TaskID GenerateTaskId(const JobID &job_id, const TaskID &parent_task_id,
int parent_task_counter);

/// Compute the next actor handle ID of a new actor handle during a fork operation.
///
/// \param actor_handle_id The actor handle ID of original actor.
/// \param num_forks The count of forks of original actor.
/// \return The next actor handle ID generated from the given info.
const ActorHandleID ComputeNextActorHandleId(const ActorHandleID &actor_handle_id,
int64_t num_forks);

template <typename T>
BaseID<T>::BaseID() {
// Using const_cast to directly change data is dangerous. The cached
Expand Down
48 changes: 44 additions & 4 deletions src/ray/core_worker/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "ray/common/buffer.h"
#include "ray/common/id.h"
#include "ray/protobuf/gcs.pb.h"
#include "ray/raylet/raylet_client.h"
#include "ray/raylet/task_spec.h"

Expand All @@ -13,13 +14,10 @@ namespace ray {
/// Type of this worker.
enum class WorkerType { WORKER, DRIVER };

/// Language of Ray tasks and workers.
enum class WorkerLanguage { PYTHON, JAVA };

/// Information about a remote function.
struct RayFunction {
/// Language of the remote function.
const WorkerLanguage language;
const ray::rpc::Language language;
/// Function descriptor of the remote function.
const std::vector<std::string> function_descriptor;
};
Expand Down Expand Up @@ -109,6 +107,48 @@ enum class StoreProviderType { PLASMA };

enum class TaskTransportType { RAYLET };

/// Translate from ray::rpc::Language to Language type (required by raylet client).
///
/// \param[in] language Language for a task.
/// \return Translated task language.
inline ::Language ToRayletTaskLanguage(ray::rpc::Language language) {
switch (language) {
case ray::rpc::Language::JAVA:
return ::Language::JAVA;
break;
case ray::rpc::Language::PYTHON:
return ::Language::PYTHON;
break;
case ray::rpc::Language::CPP:
return ::Language::CPP;
break;
default:
RAY_LOG(FATAL) << "Invalid language specified: " << static_cast<int>(language);
break;
}
}

/// Translate from Language to ray::rpc::Language type (required by core worker).
///
/// \param[in] language Language for a task.
/// \return Translated task language.
inline ray::rpc::Language ToRpcTaskLanguage(::Language language) {
switch (language) {
case Language::JAVA:
return ray::rpc::Language::JAVA;
break;
case Language::PYTHON:
return ray::rpc::Language::PYTHON;
break;
case Language::CPP:
return ray::rpc::Language::CPP;
break;
default:
RAY_LOG(FATAL) << "Invalid language specified: " << static_cast<int>(language);
break;
}
}

} // namespace ray

#endif // RAY_CORE_WORKER_COMMON_H
21 changes: 3 additions & 18 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
namespace ray {

CoreWorker::CoreWorker(const enum WorkerType worker_type,
const enum WorkerLanguage language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id)
const ray::rpc::Language language, const std::string &store_socket,
const std::string &raylet_socket, const JobID &job_id)
: worker_type_(worker_type),
language_(language),
store_socket_(store_socket),
Expand All @@ -15,7 +14,7 @@ CoreWorker::CoreWorker(const enum WorkerType worker_type,
raylet_client_(raylet_socket_,
ClientID::FromBinary(worker_context_.GetWorkerID().Binary()),
(worker_type_ == ray::WorkerType::WORKER),
worker_context_.GetCurrentJobID(), ToTaskLanguage(language_)),
worker_context_.GetCurrentJobID(), ToRayletTaskLanguage(language_)),
task_interface_(*this),
object_interface_(*this),
task_execution_interface_(*this) {
Expand All @@ -31,18 +30,4 @@ CoreWorker::CoreWorker(const enum WorkerType worker_type,
}
}

::Language CoreWorker::ToTaskLanguage(WorkerLanguage language) {
switch (language) {
case ray::WorkerLanguage::JAVA:
return ::Language::JAVA;
break;
case ray::WorkerLanguage::PYTHON:
return ::Language::PYTHON;
break;
default:
RAY_LOG(FATAL) << "invalid language specified: " << static_cast<int>(language);
break;
}
}

} // namespace ray
12 changes: 3 additions & 9 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ class CoreWorker {
/// \param[in] langauge Language of this worker.
///
/// NOTE(zhijunfu): the constructor would throw if a failure happens.
CoreWorker(const WorkerType worker_type, const WorkerLanguage language,
CoreWorker(const WorkerType worker_type, const ray::rpc::Language language,
const std::string &store_socket, const std::string &raylet_socket,
const JobID &job_id = JobID::Nil());

/// Type of this worker.
enum WorkerType WorkerType() const { return worker_type_; }

/// Language of this worker.
enum WorkerLanguage Language() const { return language_; }
ray::rpc::Language Language() const { return language_; }

/// Return the `CoreWorkerTaskInterface` that contains the methods related to task
/// submisson.
Expand All @@ -45,17 +45,11 @@ class CoreWorker {
CoreWorkerTaskExecutionInterface &Execution() { return task_execution_interface_; }

private:
/// Translate from WorkLanguage to Language type (required by raylet client).
///
/// \param[in] language Language for a task.
/// \return Translated task language.
::Language ToTaskLanguage(WorkerLanguage language);

/// Type of this worker.
const enum WorkerType worker_type_;

/// Language of this worker.
const enum WorkerLanguage language_;
const ray::rpc::Language language_;

/// Plasma store socket name.
const std::string store_socket_;
Expand Down
53 changes: 43 additions & 10 deletions src/ray/core_worker/core_worker_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class CoreWorkerTest : public ::testing::Test {
void TearDown() {}

void TestNormalTask(const std::unordered_map<std::string, double> &resources) {
CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON,
CoreWorker driver(WorkerType::DRIVER, ray::rpc::Language::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
JobID::FromRandom());

Expand All @@ -134,7 +134,7 @@ class CoreWorkerTest : public ::testing::Test {

auto buffer1 = std::make_shared<LocalMemoryBuffer>(array1, sizeof(array1));

RayFunction func{ray::WorkerLanguage::PYTHON, {}};
RayFunction func{ray::rpc::Language::PYTHON, {}};
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer1));

Expand Down Expand Up @@ -165,7 +165,7 @@ class CoreWorkerTest : public ::testing::Test {
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByReference(object_id));

RayFunction func{ray::WorkerLanguage::PYTHON, {}};
RayFunction func{ray::rpc::Language::PYTHON, {}};
TaskOptions options;

std::vector<ObjectID> return_ids;
Expand All @@ -184,7 +184,7 @@ class CoreWorkerTest : public ::testing::Test {
}

void TestActorTask(const std::unordered_map<std::string, double> &resources) {
CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON,
CoreWorker driver(WorkerType::DRIVER, ray::rpc::Language::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
JobID::FromRandom());

Expand All @@ -195,7 +195,7 @@ class CoreWorkerTest : public ::testing::Test {
uint8_t array[] = {1, 2, 3};
auto buffer = std::make_shared<LocalMemoryBuffer>(array, sizeof(array));

RayFunction func{ray::WorkerLanguage::PYTHON, {}};
RayFunction func{ray::rpc::Language::PYTHON, {}};
std::vector<TaskArg> args;
args.emplace_back(TaskArg::PassByValue(buffer));

Expand Down Expand Up @@ -223,7 +223,7 @@ class CoreWorkerTest : public ::testing::Test {

TaskOptions options{1, resources};
std::vector<ObjectID> return_ids;
RayFunction func{ray::WorkerLanguage::PYTHON, {}};
RayFunction func{ray::rpc::Language::PYTHON, {}};
RAY_CHECK_OK(driver.Tasks().SubmitActorTask(*actor_handle, func, args, options,
&return_ids));
RAY_CHECK(return_ids.size() == 1);
Expand Down Expand Up @@ -302,8 +302,41 @@ TEST_F(ZeroNodeTest, TestWorkerContext) {
ASSERT_EQ(context.GetNextPutIndex(), 3);
}

TEST_F(ZeroNodeTest, TestActorHandle) {
ActorHandle handle1(ActorID::FromRandom(), ActorHandleID::FromRandom(),
ray::rpc::Language::JAVA,
{"org.ray.exampleClass", "exampleMethod", "exampleSignature"});

auto forkedHandle1 = handle1.Fork();
ASSERT_EQ(1, handle1.NumForks());
ASSERT_EQ(handle1.ActorID(), forkedHandle1.ActorID());
ASSERT_NE(handle1.ActorHandleID(), forkedHandle1.ActorHandleID());
ASSERT_EQ(handle1.ActorLanguage(), forkedHandle1.ActorLanguage());
ASSERT_EQ(handle1.ActorCreationTaskFunctionDescriptor(),
forkedHandle1.ActorCreationTaskFunctionDescriptor());
ASSERT_EQ(handle1.ActorCursor(), forkedHandle1.ActorCursor());
ASSERT_EQ(0, forkedHandle1.TaskCounter());
ASSERT_EQ(0, forkedHandle1.NumForks());
auto forkedHandle2 = handle1.Fork();
ASSERT_EQ(2, handle1.NumForks());
ASSERT_EQ(0, forkedHandle2.TaskCounter());
ASSERT_EQ(0, forkedHandle2.NumForks());

std::string buffer;
handle1.Serialize(&buffer);
auto handle2 = ActorHandle::Deserialize(buffer);
ASSERT_EQ(handle1.ActorID(), handle2.ActorID());
ASSERT_EQ(handle1.ActorHandleID(), handle2.ActorHandleID());
ASSERT_EQ(handle1.ActorLanguage(), handle2.ActorLanguage());
ASSERT_EQ(handle1.ActorCreationTaskFunctionDescriptor(),
handle2.ActorCreationTaskFunctionDescriptor());
ASSERT_EQ(handle1.ActorCursor(), handle2.ActorCursor());
ASSERT_EQ(handle1.TaskCounter(), handle2.TaskCounter());
ASSERT_EQ(handle1.NumForks(), handle2.NumForks());
}

TEST_F(SingleNodeTest, TestObjectInterface) {
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON,
CoreWorker core_worker(WorkerType::DRIVER, ray::rpc::Language::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
JobID::FromRandom());

Expand Down Expand Up @@ -367,11 +400,11 @@ TEST_F(SingleNodeTest, TestObjectInterface) {
}

TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) {
CoreWorker worker1(WorkerType::DRIVER, WorkerLanguage::PYTHON,
CoreWorker worker1(WorkerType::DRIVER, ray::rpc::Language::PYTHON,
raylet_store_socket_names_[0], raylet_socket_names_[0],
JobID::FromRandom());

CoreWorker worker2(WorkerType::DRIVER, WorkerLanguage::PYTHON,
CoreWorker worker2(WorkerType::DRIVER, ray::rpc::Language::PYTHON,
raylet_store_socket_names_[1], raylet_socket_names_[1],
JobID::FromRandom());

Expand Down Expand Up @@ -458,7 +491,7 @@ TEST_F(TwoNodeTest, TestActorTaskCrossNodes) {

TEST_F(SingleNodeTest, TestCoreWorkerConstructorFailure) {
try {
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "",
CoreWorker core_worker(WorkerType::DRIVER, ray::rpc::Language::PYTHON, "",
raylet_socket_names_[0], JobID::FromRandom());
} catch (const std::exception &e) {
std::cout << "Caught exception when constructing core worker: " << e.what();
Expand Down
4 changes: 2 additions & 2 deletions src/ray/core_worker/mock_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ namespace ray {
class MockWorker {
public:
MockWorker(const std::string &store_socket, const std::string &raylet_socket)
: worker_(WorkerType::WORKER, WorkerLanguage::PYTHON, store_socket, raylet_socket,
JobID::FromRandom()) {}
: worker_(WorkerType::WORKER, ray::rpc::Language::PYTHON, store_socket,
raylet_socket, JobID::FromRandom()) {}

void Run() {
auto executor_func = [this](const RayFunction &ray_function,
Expand Down
4 changes: 1 addition & 3 deletions src/ray/core_worker/task_execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ Status CoreWorkerTaskExecutionInterface::Run(const TaskExecutor &executor) {
const auto &spec = task.GetTaskSpecification();
core_worker_.worker_context_.SetCurrentTask(spec);

WorkerLanguage language = (spec.GetLanguage() == ::Language::JAVA)
? WorkerLanguage::JAVA
: WorkerLanguage::PYTHON;
ray::rpc::Language language = ToRpcTaskLanguage(spec.GetLanguage());
RayFunction func{language, spec.FunctionDescriptor()};

std::vector<std::shared_ptr<RayObject>> args;
Expand Down
Loading

0 comments on commit 1cf7728

Please sign in to comment.