diff --git a/.gitignore b/.gitignore index 3b2ab07531a10..a380b4fb3c653 100644 --- a/.gitignore +++ b/.gitignore @@ -130,6 +130,7 @@ scripts/nodes.txt # Pytest Cache **/.pytest_cache +**/.cache .benchmarks # Vscode @@ -145,6 +146,9 @@ java/**/.classpath java/**/.project java/runtime/native_dependencies/ +# streaming/python +streaming/python/generated/ + # python virtual env venv diff --git a/.travis.yml b/.travis.yml index d59dbc004c3ad..8a589ad9e90a0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -34,6 +34,21 @@ matrix: - if [ $RAY_CI_JAVA_AFFECTED != "1" ]; then exit; fi - ./java/test.sh + - os: linux + env: BAZEL_PYTHON_VERSION=PY3 PYTHON=3.5 PYTHONWARNINGS=ignore TESTSUITE=streaming + install: + - python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py + - eval `python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py` + - if [ $RAY_CI_STREAMING_PYTHON_AFFECTED != "1" ]; then exit; fi + - ./ci/suppress_output ./ci/travis/install-bazel.sh + - ./ci/suppress_output ./ci/travis/install-dependencies.sh + - export PATH="$HOME/miniconda/bin:$PATH" + - ./ci/suppress_output ./ci/travis/install-ray.sh + script: + # Streaming cpp test. + - if [ $RAY_CI_STREAMING_CPP_AFFECTED == "1" ]; then ./ci/suppress_output bash streaming/src/test/run_streaming_queue_test.sh; fi + - if [ RAY_CI_STREAMING_PYTHON_AFFECTED == "1" ]; then python -m pytest -v --durations=5 --timeout=300 python/ray/streaming/tests/; fi + - os: linux env: LINT=1 PYTHONWARNINGS=ignore before_install: @@ -51,7 +66,7 @@ matrix: - sphinx-build -W -b html -d _build/doctrees source _build/html - cd .. # Run Python linting, ignore dict vs {} (C408), others are defaults - - flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 + - flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,streaming/python/generated,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 - ./ci/travis/format.sh --all # Make sure that the README is formatted properly. - cd python diff --git a/BUILD.bazel b/BUILD.bazel index cb99a93f1d0d0..73e60b334fe10 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -7,6 +7,28 @@ load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") load("@com_github_grpc_grpc//bazel:cython_library.bzl", "pyx_library") load("@rules_proto_grpc//python:defs.bzl", "python_grpc_compile") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") +load("//bazel:ray.bzl", "if_linux_x86_64") + +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "macos", + values = { + "apple_platform_type": "macos", + "cpu": "darwin", + }, + visibility = ["//visibility:public"], +) + +config_setting( + name = "linux_x86_64", + values = {"cpu": "k8"}, + visibility = ["//visibility:public"], +) # TODO(mehrdadn): (How to) support dynamic linking? PROPAGATED_WINDOWS_DEFINES = ["RAY_STATIC"] @@ -219,6 +241,7 @@ cc_library( includes = [ "@boost//:asio", ], + visibility = ["//visibility:public"], deps = [ ":common_cc_proto", ":gcs_cc_proto", @@ -327,6 +350,7 @@ cc_library( "-lpthread", ], }), + visibility = ["//streaming:__subpackages__"], deps = [ ":common_cc_proto", ":gcs", @@ -373,6 +397,7 @@ cc_library( "src/ray/core_worker/transport/*.h", ]), copts = COPTS, + visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -659,6 +684,7 @@ cc_library( includes = [ "src", ], + visibility = ["//visibility:public"], deps = [ ":sha256", "@com_github_google_glog//:glog", @@ -782,15 +808,51 @@ pyx_library( name = "_raylet", srcs = glob([ "python/ray/__init__.py", + "python/ray/_raylet.pxd", "python/ray/_raylet.pyx", "python/ray/includes/*.pxd", "python/ray/includes/*.pxi", ]), - copts = COPTS, + # Export ray ABI symbols, which can then be used by _streaming.so. + # We need to dlopen this lib with RTLD_GLOBAL to use ABI in this + # shared lib, see python/ray/__init__.py. + cc_kwargs = { + "linkstatic": 1, + # see https://github.com/tensorflow/tensorflow/blob/r2.1/tensorflow/lite/BUILD#L444 + "linkopts": select({ + "//:macos": [ + "-Wl,-exported_symbols_list,$(location //:src/ray/ray_exported_symbols.lds)", + ], + "//:windows": [], + "//conditions:default": [ + "-Wl,--version-script,$(location //:src/ray/ray_version_script.lds)", + ], + }), + }, + copts = COPTS + if_linux_x86_64(["-fno-gnu-unique"]), deps = [ "//:core_worker_lib", "//:raylet_lib", "//:serialization_cc_proto", + "//:src/ray/ray_exported_symbols.lds", + "//:src/ray/ray_version_script.lds", + ], +) + +pyx_library( + name = "_streaming", + srcs = glob([ + "python/ray/streaming/_streaming.pyx", + "python/ray/__init__.py", + "python/ray/_raylet.pxd", + "python/ray/includes/*.pxd", + "python/ray/includes/*.pxi", + "python/ray/streaming/__init__.pxd", + "python/ray/streaming/includes/*.pxd", + "python/ray/streaming/includes/*.pxi", + ]), + deps = [ + "//streaming:streaming_lib", ], ) @@ -922,6 +984,7 @@ genrule( name = "ray_pkg", srcs = [ "python/ray/_raylet.so", + "python/ray/streaming/_streaming.so", "//:python_sources", "//:all_py_proto", "//:redis-server", @@ -930,12 +993,14 @@ genrule( "//:raylet", "//:raylet_monitor", "@plasma//:plasma_store_server", + "//streaming:copy_streaming_py_proto", ], outs = ["ray_pkg.out"], cmd = """ set -x && WORK_DIR=$$(pwd) && cp -f $(location python/ray/_raylet.so) "$$WORK_DIR/python/ray" && + cp -f $(location python/ray/streaming/_streaming.so) $$WORK_DIR/python/ray/streaming && mkdir -p "$$WORK_DIR/python/ray/core/src/ray/thirdparty/redis/src/" && cp -f $(location //:redis-server) "$$WORK_DIR/python/ray/core/src/ray/thirdparty/redis/src/" && cp -f $(location //:redis-cli) "$$WORK_DIR/python/ray/core/src/ray/thirdparty/redis/src/" && diff --git a/bazel/ray.bzl b/bazel/ray.bzl index 8f41e05aad2dc..738a5f0d3f50a 100644 --- a/bazel/ray.bzl +++ b/bazel/ray.bzl @@ -64,3 +64,9 @@ def define_java_module( "{auto_gen_header}": "", }, ) + +def if_linux_x86_64(a): + return select({ + "//:linux_x86_64": a, + "//conditions:default": [], + }) diff --git a/ci/travis/bazel-format.sh b/ci/travis/bazel-format.sh index f3b4b1a9aad88..71b96357c8031 100755 --- a/ci/travis/bazel-format.sh +++ b/ci/travis/bazel-format.sh @@ -44,6 +44,7 @@ while [[ $# > 0 ]]; do done pushd $ROOT_DIR/../.. -BAZEL_FILES="bazel/BUILD bazel/BUILD.plasma bazel/ray.bzl BUILD.bazel WORKSPACE" +BAZEL_FILES="bazel/BUILD bazel/BUILD.plasma bazel/ray.bzl BUILD.bazel + streaming/BUILD.bazel WORKSPACE" buildifier -mode=$RUN_TYPE -diff_command="diff -u" $BAZEL_FILES popd diff --git a/ci/travis/determine_tests_to_run.py b/ci/travis/determine_tests_to_run.py index 4346302d41e63..40789b0a4c6dd 100644 --- a/ci/travis/determine_tests_to_run.py +++ b/ci/travis/determine_tests_to_run.py @@ -38,6 +38,8 @@ def list_changed_files(commit_range): RAY_CI_PYTHON_AFFECTED = 0 RAY_CI_LINUX_WHEELS_AFFECTED = 0 RAY_CI_MACOS_WHEELS_AFFECTED = 0 + RAY_CI_STREAMING_CPP_AFFECTED = 0 + RAY_CI_STREAMING_PYTHON_AFFECTED = 0 if os.environ["TRAVIS_EVENT_TYPE"] == "pull_request": @@ -71,6 +73,7 @@ def list_changed_files(commit_range): RAY_CI_PYTHON_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 + RAY_CI_STREAMING_PYTHON_AFFECTED = 1 elif changed_file.startswith("java/"): RAY_CI_JAVA_AFFECTED = 1 elif any( @@ -86,6 +89,13 @@ def list_changed_files(commit_range): RAY_CI_PYTHON_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 + RAY_CI_STREAMING_CPP_AFFECTED = 1 + RAY_CI_STREAMING_PYTHON_AFFECTED = 1 + elif changed_file.startswith("streaming/src"): + RAY_CI_STREAMING_CPP_AFFECTED = 1 + RAY_CI_STREAMING_PYTHON_AFFECTED = 1 + elif changed_file.startswith("streaming/python"): + RAY_CI_STREAMING_PYTHON_AFFECTED = 1 else: RAY_CI_TUNE_AFFECTED = 1 RAY_CI_RLLIB_AFFECTED = 1 @@ -94,6 +104,7 @@ def list_changed_files(commit_range): RAY_CI_PYTHON_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 + RAY_CI_STREAMING_CPP_AFFECTED = 1 else: RAY_CI_TUNE_AFFECTED = 1 RAY_CI_RLLIB_AFFECTED = 1 @@ -102,6 +113,7 @@ def list_changed_files(commit_range): RAY_CI_PYTHON_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 + RAY_CI_STREAMING_CPP_AFFECTED = 1 # Log the modified environment variables visible in console. for output_stream in [sys.stdout, sys.stderr]: @@ -116,3 +128,7 @@ def list_changed_files(commit_range): .format(RAY_CI_LINUX_WHEELS_AFFECTED)) _print("export RAY_CI_MACOS_WHEELS_AFFECTED={}" .format(RAY_CI_MACOS_WHEELS_AFFECTED)) + _print("export RAY_CI_STREAMING_CPP_AFFECTED={}" + .format(RAY_CI_STREAMING_CPP_AFFECTED)) + _print("export RAY_CI_STREAMING_PYTHON_AFFECTED={}" + .format(RAY_CI_STREAMING_PYTHON_AFFECTED)) diff --git a/ci/travis/format.sh b/ci/travis/format.sh index 616aa0d1adda5..19efdbb314c08 100755 --- a/ci/travis/format.sh +++ b/ci/travis/format.sh @@ -79,14 +79,14 @@ format_changed() { yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" if which flake8 >/dev/null; then git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ - flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 + flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,streaming/python/generated,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 fi fi if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then if which flake8 >/dev/null; then git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ - flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 + flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,streaming/python/generated,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 fi fi diff --git a/python/ray/__init__.py b/python/ray/__init__.py index a23e49ae4459c..3367040252750 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -3,6 +3,7 @@ from __future__ import print_function import os +from os.path import dirname import sys # MUST add pickle5 to the import path because it will be imported by some @@ -19,6 +20,14 @@ os.path.abspath(os.path.dirname(__file__)), "pickle5_files") sys.path.insert(0, pickle5_path) +# Expose ray ABI symbols which may be dependent by other shared +# libraries such as _streaming.so. See BUILD.bazel:_raylet +so_path = os.path.join(dirname(__file__), "_raylet.so") +if os.path.exists(so_path): + import ctypes + from ctypes import CDLL + CDLL(so_path, ctypes.RTLD_GLOBAL) + # MUST import ray._raylet before pyarrow to initialize some global variables. # It seems the library related to memory allocation in pyarrow will destroy the # initialization of grpc if we import pyarrow at first. diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd new file mode 100644 index 0000000000000..85b02de62ea68 --- /dev/null +++ b/python/ray/_raylet.pxd @@ -0,0 +1,70 @@ +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +from libcpp cimport bool as c_bool +from libcpp.string cimport string as c_string +from libcpp.vector cimport vector as c_vector +from libcpp.memory cimport ( + shared_ptr, + unique_ptr +) + +from ray.includes.common cimport ( + CBuffer, + CRayObject +) +from ray.includes.libcoreworker cimport CCoreWorker +from ray.includes.unique_ids cimport ( + CObjectID, + CActorID +) + +cdef class Buffer: + cdef: + shared_ptr[CBuffer] buffer + Py_ssize_t shape + Py_ssize_t strides + + @staticmethod + cdef make(const shared_ptr[CBuffer]& buffer) + +cdef class BaseID: + # To avoid the error of "Python int too large to convert to C ssize_t", + # here `cdef size_t` is required. + cdef size_t hash(self) + +cdef class ObjectID(BaseID): + cdef: + CObjectID data + object buffer_ref + # Flag indicating whether or not this object ID was added to the set + # of active IDs in the core worker so we know whether we should clean + # it up. + c_bool in_core_worker + + cdef CObjectID native(self) + +cdef class ActorID(BaseID): + cdef CActorID data + + cdef CActorID native(self) + + cdef size_t hash(self) + +cdef class CoreWorker: + cdef: + unique_ptr[CCoreWorker] core_worker + object async_thread + object async_event_loop + + cdef _create_put_buffer(self, shared_ptr[CBuffer] &metadata, + size_t data_size, ObjectID object_id, + CObjectID *c_object_id, shared_ptr[CBuffer] *data) + # TODO: handle noreturn better + cdef store_task_outputs( + self, worker, outputs, const c_vector[CObjectID] return_ids, + c_vector[shared_ptr[CRayObject]] *returns) + +cdef c_vector[c_string] string_vector_from_list(list string_list) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 8e9f6541d070e..45c47289ef608 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -41,6 +41,7 @@ from libcpp.vector cimport vector as c_vector from cython.operator import dereference, postincrement from ray.includes.common cimport ( + CBuffer, CAddress, CLanguage, CRayObject, @@ -346,13 +347,29 @@ cdef c_vector[c_string] string_vector_from_list(list string_list): return out +cdef: + c_string pickle_metadata_str = PICKLE_BUFFER_METADATA + shared_ptr[CBuffer] pickle_metadata = dynamic_pointer_cast[ + CBuffer, LocalMemoryBuffer]( + make_shared[LocalMemoryBuffer]( + (pickle_metadata_str.data()), + pickle_metadata_str.size(), True)) + c_string raw_meta_str = RAW_BUFFER_METADATA + shared_ptr[CBuffer] raw_metadata = dynamic_pointer_cast[ + CBuffer, LocalMemoryBuffer]( + make_shared[LocalMemoryBuffer]( + (raw_meta_str.data()), + raw_meta_str.size(), True)) + cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector): cdef: c_string pickled_str - c_string metadata_str = PICKLE_BUFFER_METADATA + const unsigned char[:] buffer + size_t size shared_ptr[CBuffer] arg_data shared_ptr[CBuffer] arg_metadata + # TODO be consistent with store_task_outputs for arg in args: if isinstance(arg, ObjectID): args_vector.push_back( @@ -360,23 +377,25 @@ cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector): elif not ray._raylet.check_simple_value(arg): args_vector.push_back( CTaskArg.PassByReference((ray.put(arg)).native())) + elif type(arg) is bytes: + buffer = arg + size = buffer.nbytes + arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer]( + make_shared[LocalMemoryBuffer]( + (&buffer[0]), size, True)) + args_vector.push_back( + CTaskArg.PassByValue( + make_shared[CRayObject](arg_data, raw_metadata))) else: - pickled_str = pickle.dumps( + buffer = pickle.dumps( arg, protocol=pickle.HIGHEST_PROTOCOL) - # TODO(edoakes): This makes a copy that could be avoided. + size = buffer.nbytes arg_data = dynamic_pointer_cast[CBuffer, LocalMemoryBuffer]( make_shared[LocalMemoryBuffer]( - (pickled_str.data()), - pickled_str.size(), - True)) - arg_metadata = dynamic_pointer_cast[ - CBuffer, LocalMemoryBuffer]( - make_shared[LocalMemoryBuffer]( - ( - metadata_str.data()), metadata_str.size(), True)) + (&buffer[0]), size, True)) args_vector.push_back( CTaskArg.PassByValue( - make_shared[CRayObject](arg_data, arg_metadata))) + make_shared[CRayObject](arg_data, pickle_metadata))) cdef class RayletClient: @@ -738,10 +757,6 @@ cdef write_serialized_object( cdef class CoreWorker: - cdef: - unique_ptr[CCoreWorker] core_worker - object async_thread - object async_event_loop def __cinit__(self, is_driver, store_socket, raylet_socket, JobID job_id, GcsClientOptions gcs_options, log_dir, @@ -1085,7 +1100,6 @@ cdef class CoreWorker: c_vector[shared_ptr[CRayObject]] *returns): cdef: c_vector[size_t] data_sizes - c_string metadata_str c_vector[shared_ptr[CBuffer]] metadatas if return_ids.size() == 0: diff --git a/python/ray/experimental/streaming/batched_queue.py b/python/ray/experimental/streaming/batched_queue.py deleted file mode 100644 index 9ba0c27957817..0000000000000 --- a/python/ray/experimental/streaming/batched_queue.py +++ /dev/null @@ -1,216 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import logging -import numpy as np -import threading -import time - -import ray -from ray.experimental import internal_kv - -logger = logging.getLogger(__name__) -logger.setLevel("INFO") - - -def plasma_prefetch(object_id): - """Tells plasma to prefetch the given object_id.""" - local_sched_client = ray.worker.global_worker.raylet_client - ray_obj_id = ray.ObjectID(object_id) - local_sched_client.fetch_or_reconstruct([ray_obj_id], True) - - -# TODO: doing the timer in Python land is a bit slow -class FlushThread(threading.Thread): - """A thread that flushes periodically to plasma. - - Attributes: - interval: The flush timeout per batch. - flush_fn: The flush function. - """ - - def __init__(self, interval, flush_fn): - threading.Thread.__init__(self) - self.interval = interval # Interval is the max_batch_time - self.flush_fn = flush_fn - self.daemon = True - - def run(self): - while True: - time.sleep(self.interval) # Flushing period - self.flush_fn() - - -class BatchedQueue(object): - """A batched queue for actor to actor communication. - - Attributes: - max_size (int): The maximum size of the queue in number of batches - (if exceeded, backpressure kicks in) - max_batch_size (int): The size of each batch in number of records. - max_batch_time (float): The flush timeout per batch. - prefetch_depth (int): The number of batches to prefetch from plasma. - background_flush (bool): Denotes whether a daemon flush thread should - be used (True) to flush batches to plasma. - base (ndarray): A unique signature for the queue. - read_ack_key (bytes): The signature of the queue in bytes. - prefetch_batch_offset (int): The number of the last read prefetched - batch. - read_batch_offset (int): The number of the last read batch. - read_item_offset (int): The number of the last read record inside a - batch. - write_batch_offset (int): The number of the last written batch. - write_item_offset (int): The numebr of the last written item inside a - batch. - write_buffer (list): The write buffer, i.e. an in-memory batch. - last_flush_time (float): The time the last flushing to plasma took - place. - cached_remote_offset (int): The number of the last read batch as - recorded by the writer after the previous flush. - flush_lock (RLock): A python lock used for flushing batches to plasma. - flush_thread (Threading): The python thread used for flushing batches - to plasma. - """ - - def __init__(self, - max_size=999999, - max_batch_size=99999, - max_batch_time=0.01, - prefetch_depth=10, - background_flush=True): - self.max_size = max_size - self.max_batch_size = max_batch_size - self.max_batch_time = max_batch_time - self.prefetch_depth = prefetch_depth - self.background_flush = background_flush - - # Common queue metadata -- This serves as the unique id of the queue - self.base = np.random.randint(0, 2**32 - 1, size=5, dtype="uint32") - self.base[-2] = 0 - self.base[-1] = 0 - self.read_ack_key = np.ndarray.tobytes(self.base) - - # Reader state - self.prefetch_batch_offset = 0 - self.read_item_offset = 0 - self.read_batch_offset = 0 - self.read_buffer = [] - - # Writer state - self.write_item_offset = 0 - self.write_batch_offset = 0 - self.write_buffer = [] - self.last_flush_time = 0.0 - self.cached_remote_offset = 0 - - self.flush_lock = threading.RLock() - self.flush_thread = FlushThread(self.max_batch_time, - self._flush_writes) - - def __getstate__(self): - state = dict(self.__dict__) - del state["flush_lock"] - del state["flush_thread"] - del state["write_buffer"] - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - # This is to enable writing functionality in - # case the queue is not created by the writer - # The reason is that python locks cannot be serialized - def enable_writes(self): - """Restores the state of the batched queue for writing.""" - self.write_buffer = [] - self.flush_lock = threading.RLock() - self.flush_thread = FlushThread(self.max_batch_time, - self._flush_writes) - - # Batch ids consist of a unique queue id used as prefix along with - # two numbers generated using the batch offset in the queue - def _batch_id(self, batch_offset): - oid = self.base.copy() - oid[-2] = batch_offset // 2**32 - oid[-1] = batch_offset % 2**32 - return np.ndarray.tobytes(oid) - - def _flush_writes(self): - with self.flush_lock: - if not self.write_buffer: - return - batch_id = self._batch_id(self.write_batch_offset) - ray.worker.global_worker.put_object(self.write_buffer, - ray.ObjectID(batch_id)) - logger.debug("[writer] Flush batch {} offset {} size {}".format( - self.write_batch_offset, self.write_item_offset, - len(self.write_buffer))) - self.write_buffer = [] - self.write_batch_offset += 1 - self._wait_for_reader() - self.last_flush_time = time.time() - - def _wait_for_reader(self): - """Checks for backpressure by the downstream reader.""" - if self.max_size <= 0: # Unlimited queue - return - if self.write_item_offset - self.cached_remote_offset <= self.max_size: - return # Hasn't reached max size - remote_offset = internal_kv._internal_kv_get(self.read_ack_key) - if remote_offset is None: - # logger.debug("[writer] Waiting for reader to start...") - while remote_offset is None: - time.sleep(0.01) - remote_offset = internal_kv._internal_kv_get(self.read_ack_key) - remote_offset = int(remote_offset) - if self.write_item_offset - remote_offset > self.max_size: - logger.debug( - "[writer] Waiting for reader to catch up {} to {} - {}".format( - remote_offset, self.write_item_offset, self.max_size)) - while self.write_item_offset - remote_offset > self.max_size: - time.sleep(0.01) - remote_offset = int( - internal_kv._internal_kv_get(self.read_ack_key)) - self.cached_remote_offset = remote_offset - - def _read_next_batch(self): - while (self.prefetch_batch_offset < - self.read_batch_offset + self.prefetch_depth): - plasma_prefetch(self._batch_id(self.prefetch_batch_offset)) - self.prefetch_batch_offset += 1 - self.read_buffer = ray.get( - ray.ObjectID(self._batch_id(self.read_batch_offset))) - self.read_batch_offset += 1 - logger.debug("[reader] Fetched batch {} offset {} size {}".format( - self.read_batch_offset, self.read_item_offset, - len(self.read_buffer))) - self._ack_reads(self.read_item_offset + len(self.read_buffer)) - - # Reader acks the key it reads so that writer knows reader's offset. - # This is to cap queue size and simulate backpressure - def _ack_reads(self, offset): - if self.max_size > 0: - internal_kv._internal_kv_put( - self.read_ack_key, offset, overwrite=True) - - def put_next(self, item): - with self.flush_lock: - if self.background_flush and not self.flush_thread.is_alive(): - logger.debug("[writer] Starting batch flush thread") - self.flush_thread.start() - self.write_buffer.append(item) - self.write_item_offset += 1 - if not self.last_flush_time: - self.last_flush_time = time.time() - delay = time.time() - self.last_flush_time - if (len(self.write_buffer) > self.max_batch_size - or delay > self.max_batch_time): - self._flush_writes() - - def read_next(self): - if not self.read_buffer: - self._read_next_batch() - assert self.read_buffer - self.read_item_offset += 1 - return self.read_buffer.pop(0) diff --git a/python/ray/experimental/streaming/benchmarks/micro/batched_queue_benchmark.py b/python/ray/experimental/streaming/benchmarks/micro/batched_queue_benchmark.py deleted file mode 100644 index 4017f7891cfa8..0000000000000 --- a/python/ray/experimental/streaming/benchmarks/micro/batched_queue_benchmark.py +++ /dev/null @@ -1,182 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import logging -import time - -import ray -from ray.experimental.streaming.batched_queue import BatchedQueue - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - -parser = argparse.ArgumentParser() -parser.add_argument( - "--rounds", default=10, help="the number of experiment rounds") -parser.add_argument( - "--num-queues", default=1, help="the number of queues in the chain") -parser.add_argument( - "--queue-size", default=10000, help="the queue size in number of batches") -parser.add_argument( - "--batch-size", default=1000, help="the batch size in number of elements") -parser.add_argument( - "--flush-timeout", default=0.001, help="the timeout to flush a batch") -parser.add_argument( - "--prefetch-depth", - default=10, - help="the number of batches to prefetch from plasma") -parser.add_argument( - "--background-flush", - default=False, - help="whether to flush in the backrgound or not") -parser.add_argument( - "--max-throughput", - default="inf", - help="maximum read throughput (elements/s)") - - -@ray.remote -class Node(object): - """An actor that reads from an input queue and writes to an output queue. - - Attributes: - id (int): The id of the actor. - queue (BatchedQueue): The input queue. - out_queue (BatchedQueue): The output queue. - max_reads_per_second (int): The max read throughput (default: inf). - num_reads (int): Number of elements read. - num_writes (int): Number of elements written. - """ - - def __init__(self, - id, - in_queue, - out_queue, - max_reads_per_second=float("inf")): - self.id = id - self.queue = in_queue - self.out_queue = out_queue - self.max_reads_per_second = max_reads_per_second - self.num_reads = 0 - self.num_writes = 0 - self.start = time.time() - - def read_write_forever(self): - debug_log = "[actor {}] Reads throttled to {} reads/s" - log = "" - if self.out_queue is not None: - self.out_queue.enable_writes() - log += "[actor {}] Reads/Writes per second {}" - else: # It's just a reader - log += "[actor {}] Reads per second {}" - # Start spinning - expected_value = 0 - while True: - start = time.time() - N = 100000 - for _ in range(N): - x = self.queue.read_next() - assert x == expected_value, (x, expected_value) - expected_value += 1 - self.num_reads += 1 - if self.out_queue is not None: - self.out_queue.put_next(x) - self.num_writes += 1 - while (self.num_reads / (time.time() - self.start) > - self.max_reads_per_second): - logger.debug( - debug_log.format(self.id, self.max_reads_per_second)) - time.sleep(0.1) - logger.info(log.format(self.id, N / (time.time() - start))) - # Flush any remaining elements - if self.out_queue is not None: - self.out_queue._flush_writes() - - -def test_max_throughput(rounds, - max_queue_size, - max_batch_size, - batch_timeout, - prefetch_depth, - background_flush, - num_queues, - max_reads_per_second=float("inf")): - assert num_queues >= 1 - first_queue = BatchedQueue( - max_size=max_queue_size, - max_batch_size=max_batch_size, - max_batch_time=batch_timeout, - prefetch_depth=prefetch_depth, - background_flush=background_flush) - previous_queue = first_queue - for i in range(num_queues): - # Construct the batched queue - in_queue = previous_queue - out_queue = None - if i < num_queues - 1: - out_queue = BatchedQueue( - max_size=max_queue_size, - max_batch_size=max_batch_size, - max_batch_time=batch_timeout, - prefetch_depth=prefetch_depth, - background_flush=background_flush) - - node = Node.remote(i, in_queue, out_queue, max_reads_per_second) - node.read_write_forever.remote() - previous_queue = out_queue - - value = 0 - # Feed the chain - for round in range(rounds): - logger.info("Round {}".format(round)) - N = 100000 - start = time.time() - for i in range(N): - first_queue.put_next(value) - value += 1 - log = "[writer] Puts per second {}" - logger.info(log.format(N / (time.time() - start))) - first_queue._flush_writes() - - -if __name__ == "__main__": - ray.init() - ray.register_custom_serializer(BatchedQueue, use_pickle=True) - - args = parser.parse_args() - - rounds = int(args.rounds) - max_queue_size = int(args.queue_size) - max_batch_size = int(args.batch_size) - batch_timeout = float(args.flush_timeout) - prefetch_depth = int(args.prefetch_depth) - background_flush = bool(args.background_flush) - num_queues = int(args.num_queues) - max_reads_per_second = float(args.max_throughput) - - logger.info("== Parameters ==") - logger.info("Rounds: {}".format(rounds)) - logger.info("Max queue size: {}".format(max_queue_size)) - logger.info("Max batch size: {}".format(max_batch_size)) - logger.info("Batch timeout: {}".format(batch_timeout)) - logger.info("Prefetch depth: {}".format(prefetch_depth)) - logger.info("Background flush: {}".format(background_flush)) - logger.info("Max read throughput: {}".format(max_reads_per_second)) - - # Estimate the ideal throughput - value = 0 - start = time.time() - for round in range(rounds): - N = 100000 - for _ in range(N): - value += 1 - logger.info("Ideal throughput: {}".format(value / (time.time() - start))) - - logger.info("== Testing max throughput ==") - start = time.time() - test_max_throughput(rounds, max_queue_size, max_batch_size, batch_timeout, - prefetch_depth, background_flush, num_queues, - max_reads_per_second) - logger.info("Elapsed time: {}".format(time.time() - start)) diff --git a/python/ray/experimental/streaming/communication.py b/python/ray/experimental/streaming/communication.py deleted file mode 100644 index f626d1b38c5d3..0000000000000 --- a/python/ray/experimental/streaming/communication.py +++ /dev/null @@ -1,359 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import hashlib -import logging -import sys - -from ray.experimental.streaming.operator import PStrategy -from ray.experimental.streaming.batched_queue import BatchedQueue - -logger = logging.getLogger(__name__) -logging.basicConfig(level=logging.INFO) - -# Forward and broadcast stream partitioning strategies -forward_broadcast_strategies = [PStrategy.Forward, PStrategy.Broadcast] - - -# Used to choose output channel in case of hash-based shuffling -def _hash(value): - if isinstance(value, int): - return value - try: - return int(hashlib.sha1(value.encode("utf-8")).hexdigest(), 16) - except AttributeError: - return int(hashlib.sha1(value).hexdigest(), 16) - - -# A data channel is a batched queue between two -# operator instances in a streaming environment -class DataChannel(object): - """A data channel for actor-to-actor communication. - - Attributes: - env (Environment): The environment the channel belongs to. - src_operator_id (UUID): The id of the source operator of the channel. - dst_operator_id (UUID): The id of the destination operator of the - channel. - src_instance_id (int): The id of the source instance. - dst_instance_id (int): The id of the destination instance. - queue (BatchedQueue): The batched queue used for data movement. - """ - - def __init__(self, env, src_operator_id, dst_operator_id, src_instance_id, - dst_instance_id): - self.env = env - self.src_operator_id = src_operator_id - self.dst_operator_id = dst_operator_id - self.src_instance_id = src_instance_id - self.dst_instance_id = dst_instance_id - self.queue = BatchedQueue( - max_size=self.env.config.queue_config.max_size, - max_batch_size=self.env.config.queue_config.max_batch_size, - max_batch_time=self.env.config.queue_config.max_batch_time, - prefetch_depth=self.env.config.queue_config.prefetch_depth, - background_flush=self.env.config.queue_config.background_flush) - - def __repr__(self): - return "({},{},{},{})".format( - self.src_operator_id, self.dst_operator_id, self.src_instance_id, - self.dst_instance_id) - - -# Pulls and merges data from multiple input channels -class DataInput(object): - """An input gate of an operator instance. - - The input gate pulls records from all input channels in a round-robin - fashion. - - Attributes: - input_channels (list): The list of input channels. - channel_index (int): The index of the next channel to pull from. - max_index (int): The number of input channels. - closed (list): A list of flags indicating whether an input channel - has been marked as 'closed'. - all_closed (bool): Denotes whether all input channels have been - closed (True) or not (False). - """ - - def __init__(self, channels): - self.input_channels = channels - self.channel_index = 0 - self.max_index = len(channels) - self.closed = [False] * len( - self.input_channels) # Tracks the channels that have been closed - self.all_closed = False - - # Fetches records from input channels in a round-robin fashion - # TODO (john): Make sure the instance is not blocked on any of its input - # channels - # TODO (john): In case of input skew, it might be better to pull from - # the largest queue more often - def _pull(self): - while True: - if self.max_index == 0: - # TODO (john): We should detect this earlier - return None - # Channel to pull from - channel = self.input_channels[self.channel_index] - self.channel_index += 1 - if self.channel_index == self.max_index: # Reset channel index - self.channel_index = 0 - if self.closed[self.channel_index - 1]: - continue # Channel has been 'closed', check next - record = channel.queue.read_next() - logger.debug("Actor ({},{}) pulled '{}'.".format( - channel.src_operator_id, channel.src_instance_id, record)) - if record is None: - # Mark channel as 'closed' and pull from the next open one - self.closed[self.channel_index - 1] = True - self.all_closed = True - for flag in self.closed: - if flag is False: - self.all_closed = False - break - if not self.all_closed: - continue - # Returns 'None' iff all input channels are 'closed' - return record - - -# Selects output channel(s) and pushes data -class DataOutput(object): - """An output gate of an operator instance. - - The output gate pushes records to output channels according to the - user-defined partitioning scheme. - - Attributes: - partitioning_schemes (dict): A mapping from destination operator ids - to partitioning schemes (see: PScheme in operator.py). - forward_channels (list): A list of channels to forward records. - shuffle_channels (list(list)): A list of output channels to shuffle - records grouped by destination operator. - shuffle_key_channels (list(list)): A list of output channels to - shuffle records by a key grouped by destination operator. - shuffle_exists (bool): A flag indicating that there exists at least - one shuffle_channel. - shuffle_key_exists (bool): A flag indicating that there exists at - least one shuffle_key_channel. - """ - - def __init__(self, channels, partitioning_schemes): - self.key_selector = None - self.round_robin_indexes = [0] - self.partitioning_schemes = partitioning_schemes - # Prepare output -- collect channels by type - self.forward_channels = [] # Forward and broadcast channels - slots = sum(1 for scheme in self.partitioning_schemes.values() - if scheme.strategy == PStrategy.RoundRobin) - self.round_robin_channels = [[]] * slots # RoundRobin channels - self.round_robin_indexes = [-1] * slots - slots = sum(1 for scheme in self.partitioning_schemes.values() - if scheme.strategy == PStrategy.Shuffle) - # Flag used to avoid hashing when there is no shuffling - self.shuffle_exists = slots > 0 - self.shuffle_channels = [[]] * slots # Shuffle channels - slots = sum(1 for scheme in self.partitioning_schemes.values() - if scheme.strategy == PStrategy.ShuffleByKey) - # Flag used to avoid hashing when there is no shuffling by key - self.shuffle_key_exists = slots > 0 - self.shuffle_key_channels = [[]] * slots # Shuffle by key channels - # Distinct shuffle destinations - shuffle_destinations = {} - # Distinct shuffle by key destinations - shuffle_by_key_destinations = {} - # Distinct round robin destinations - round_robin_destinations = {} - index_1 = 0 - index_2 = 0 - index_3 = 0 - for channel in channels: - p_scheme = self.partitioning_schemes[channel.dst_operator_id] - strategy = p_scheme.strategy - if strategy in forward_broadcast_strategies: - self.forward_channels.append(channel) - elif strategy == PStrategy.Shuffle: - pos = shuffle_destinations.setdefault(channel.dst_operator_id, - index_1) - self.shuffle_channels[pos].append(channel) - if pos == index_1: - index_1 += 1 - elif strategy == PStrategy.ShuffleByKey: - pos = shuffle_by_key_destinations.setdefault( - channel.dst_operator_id, index_2) - self.shuffle_key_channels[pos].append(channel) - if pos == index_2: - index_2 += 1 - elif strategy == PStrategy.RoundRobin: - pos = round_robin_destinations.setdefault( - channel.dst_operator_id, index_3) - self.round_robin_channels[pos].append(channel) - if pos == index_3: - index_3 += 1 - else: # TODO (john): Add support for other strategies - sys.exit("Unrecognized or unsupported partitioning strategy.") - # A KeyedDataStream can only be shuffled by key - assert not (self.shuffle_exists and self.shuffle_key_exists) - - # Flushes any remaining records in the output channels - # 'close' indicates whether we should also 'close' the channel (True) - # by propagating 'None' - # or just flush the remaining records to plasma (False) - def _flush(self, close=False): - """Flushes remaining output records in the output queues to plasma. - - None is used as special type of record that is propagated from sources - to sink to notify that the end of data in a stream. - - Attributes: - close (bool): A flag denoting whether the channel should be - also marked as 'closed' (True) or not (False) after flushing. - """ - for channel in self.forward_channels: - if close is True: - channel.queue.put_next(None) - channel.queue._flush_writes() - for channels in self.shuffle_channels: - for channel in channels: - if close is True: - channel.queue.put_next(None) - channel.queue._flush_writes() - for channels in self.shuffle_key_channels: - for channel in channels: - if close is True: - channel.queue.put_next(None) - channel.queue._flush_writes() - for channels in self.round_robin_channels: - for channel in channels: - if close is True: - channel.queue.put_next(None) - channel.queue._flush_writes() - # TODO (john): Add more channel types - - # Returns all destination actor ids - def _destination_actor_ids(self): - destinations = [] - for channel in self.forward_channels: - destinations.append((channel.dst_operator_id, - channel.dst_instance_id)) - for channels in self.shuffle_channels: - for channel in channels: - destinations.append((channel.dst_operator_id, - channel.dst_instance_id)) - for channels in self.shuffle_key_channels: - for channel in channels: - destinations.append((channel.dst_operator_id, - channel.dst_instance_id)) - for channels in self.round_robin_channels: - for channel in channels: - destinations.append((channel.dst_operator_id, - channel.dst_instance_id)) - # TODO (john): Add more channel types - return destinations - - # Pushes the record to the output - # Each individual output queue flushes batches to plasma periodically - # based on 'batch_max_size' and 'batch_max_time' - def _push(self, record): - # Forward record - for channel in self.forward_channels: - logger.debug("[writer] Push record '{}' to channel {}".format( - record, channel)) - channel.queue.put_next(record) - # Forward record - index = 0 - for channels in self.round_robin_channels: - self.round_robin_indexes[index] += 1 - if self.round_robin_indexes[index] == len(channels): - self.round_robin_indexes[index] = 0 # Reset index - channel = channels[self.round_robin_indexes[index]] - logger.debug("[writer] Push record '{}' to channel {}".format( - record, channel)) - channel.queue.put_next(record) - index += 1 - # Hash-based shuffling by key - if self.shuffle_key_exists: - key, _ = record - h = _hash(key) - for channels in self.shuffle_key_channels: - num_instances = len(channels) # Downstream instances - channel = channels[h % num_instances] - logger.debug( - "[key_shuffle] Push record '{}' to channel {}".format( - record, channel)) - channel.queue.put_next(record) - elif self.shuffle_exists: # Hash-based shuffling per destination - h = _hash(record) - for channels in self.shuffle_channels: - num_instances = len(channels) # Downstream instances - channel = channels[h % num_instances] - logger.debug("[shuffle] Push record '{}' to channel {}".format( - record, channel)) - channel.queue.put_next(record) - else: # TODO (john): Handle rescaling - pass - - # Pushes a list of records to the output - # Each individual output queue flushes batches to plasma periodically - # based on 'batch_max_size' and 'batch_max_time' - def _push_all(self, records): - # Forward records - for record in records: - for channel in self.forward_channels: - logger.debug("[writer] Push record '{}' to channel {}".format( - record, channel)) - channel.queue.put_next(record) - # Hash-based shuffling by key per destination - if self.shuffle_key_exists: - for record in records: - key, _ = record - h = _hash(key) - for channels in self.shuffle_channels: - num_instances = len(channels) # Downstream instances - channel = channels[h % num_instances] - logger.debug( - "[key_shuffle] Push record '{}' to channel {}".format( - record, channel)) - channel.queue.put_next(record) - elif self.shuffle_exists: # Hash-based shuffling per destination - for record in records: - h = _hash(record) - for channels in self.shuffle_channels: - num_instances = len(channels) # Downstream instances - channel = channels[h % num_instances] - logger.debug( - "[shuffle] Push record '{}' to channel {}".format( - record, channel)) - channel.queue.put_next(record) - else: # TODO (john): Handle rescaling - pass - - -# Batched queue configuration -class QueueConfig(object): - """The configuration of a batched queue. - - Attributes: - max_size (int): The maximum size of the queue in number of batches - (if exceeded, backpressure kicks in). - max_batch_size (int): The size of each batch in number of records. - max_batch_time (float): The flush timeout per batch. - prefetch_depth (int): The number of batches to prefetch from plasma. - background_flush (bool): Denotes whether a daemon flush thread should - be used (True) to flush batches to plasma. - """ - - def __init__(self, - max_size=999999, - max_batch_size=99999, - max_batch_time=0.01, - prefetch_depth=10, - background_flush=False): - self.max_size = max_size - self.max_batch_size = max_batch_size - self.max_batch_time = max_batch_time - self.prefetch_depth = prefetch_depth - self.background_flush = background_flush diff --git a/python/ray/experimental/streaming/operator_instance.py b/python/ray/experimental/streaming/operator_instance.py deleted file mode 100644 index ae36b143a1a10..0000000000000 --- a/python/ray/experimental/streaming/operator_instance.py +++ /dev/null @@ -1,365 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import logging -import sys -import time -import types - -import ray - -logger = logging.getLogger(__name__) -logger.setLevel("DEBUG") - -# -# Each Ray actor corresponds to an operator instance in the physical dataflow -# Actors communicate using batched queues as data channels (no standing TCP -# connections) -# Currently, batched queues are based on Eric's implementation (see: -# batched_queue.py) - - -def _identity(element): - return element - - -# TODO (john): Specify the interface of state keepers -class OperatorInstance(object): - """A streaming operator instance. - - Attributes: - instance_id (UUID): The id of the instance. - input (DataInput): The input gate that manages input channels of - the instance (see: DataInput in communication.py). - input (DataOutput): The output gate that manages output channels of - the instance (see: DataOutput in communication.py). - state_keepers (list): A list of actor handlers to query the state of - the operator instance. - """ - - def __init__(self, instance_id, input_gate, output_gate, - state_keeper=None): - self.key_index = None # Index for key selection - self.key_attribute = None # Attribute name for key selection - self.instance_id = instance_id - self.input = input_gate - self.output = output_gate - # Handle(s) to one or more user-defined actors - # that can retrieve actor's state - self.state_keeper = state_keeper - # Enable writes - for channel in self.output.forward_channels: - channel.queue.enable_writes() - for channels in self.output.shuffle_channels: - for channel in channels: - channel.queue.enable_writes() - for channels in self.output.shuffle_key_channels: - for channel in channels: - channel.queue.enable_writes() - for channels in self.output.round_robin_channels: - for channel in channels: - channel.queue.enable_writes() - # TODO (john): Add more channel types here - - # Registers actor's handle so that the actor can schedule itself - def register_handle(self, actor_handle): - self.this_actor = actor_handle - - # Used for index-based key extraction, e.g. for tuples - def index_based_selector(self, record): - return record[self.key_index] - - # Used for attribute-based key extraction, e.g. for classes - def attribute_based_selector(self, record): - return vars(record)[self.key_attribute] - - # Starts the actor - def start(self): - pass - - -# A source actor that reads a text file line by line -@ray.remote -class ReadTextFile(OperatorInstance): - """A source operator instance that reads a text file line by line. - - Attributes: - filepath (string): The path to the input file. - """ - - def __init__(self, - instance_id, - operator_metadata, - input_gate, - output_gate, - state_keepers=None): - OperatorInstance.__init__(self, instance_id, input_gate, output_gate, - state_keepers) - self.filepath = operator_metadata.other_args - # TODO (john): Handle possible exception here - self.reader = open(self.filepath, "r") - - # Read input file line by line - def start(self): - while True: - record = self.reader.readline() - # Reader returns empty string ('') on EOF - if not record: - # Flush any remaining records to plasma and close the file - self.output._flush(close=True) - self.reader.close() - return - self.output._push( - record[:-1]) # Push after removing newline characters - - -# Map actor -@ray.remote -class Map(OperatorInstance): - """A map operator instance that applies a user-defined - stream transformation. - - A map produces exactly one output record for each record in - the input stream. - - Attributes: - map_fn (function): The user-defined function. - """ - - def __init__(self, instance_id, operator_metadata, input_gate, - output_gate): - OperatorInstance.__init__(self, instance_id, input_gate, output_gate) - self.map_fn = operator_metadata.logic - - # Applies the mapper each record of the input stream(s) - # and pushes resulting records to the output stream(s) - def start(self): - start = time.time() - elements = 0 - while True: - record = self.input._pull() - if record is None: - self.output._flush(close=True) - logger.debug("[map {}] read/writes per second: {}".format( - self.instance_id, elements / (time.time() - start))) - return - self.output._push(self.map_fn(record)) - elements += 1 - - -# Flatmap actor -@ray.remote -class FlatMap(OperatorInstance): - """A map operator instance that applies a user-defined - stream transformation. - - A flatmap produces one or more output records for each record in - the input stream. - - Attributes: - flatmap_fn (function): The user-defined function. - """ - - def __init__(self, instance_id, operator_metadata, input_gate, - output_gate): - OperatorInstance.__init__(self, instance_id, input_gate, output_gate) - self.flatmap_fn = operator_metadata.logic - - # Applies the splitter to the records of the input stream(s) - # and pushes resulting records to the output stream(s) - def start(self): - while True: - record = self.input._pull() - if record is None: - self.output._flush(close=True) - return - self.output._push_all(self.flatmap_fn(record)) - - -# Filter actor -@ray.remote -class Filter(OperatorInstance): - """A filter operator instance that applies a user-defined filter to - each record of the stream. - - Output records are those that pass the filter, i.e. those for which - the filter function returns True. - - Attributes: - filter_fn (function): The user-defined boolean function. - """ - - def __init__(self, instance_id, operator_metadata, input_gate, - output_gate): - OperatorInstance.__init__(self, instance_id, input_gate, output_gate) - self.filter_fn = operator_metadata.logic - - # Applies the filter to the records of the input stream(s) - # and pushes resulting records to the output stream(s) - def start(self): - while True: - record = self.input._pull() - if record is None: # Close channel and return - self.output._flush(close=True) - return - if self.filter_fn(record): - self.output._push(record) - - -# Inspect actor -@ray.remote -class Inspect(OperatorInstance): - """A inspect operator instance that inspects the content of the stream. - - Inspect is useful for printing the records in the stream. - - Attributes: - inspect_fn (function): The user-defined inspect logic. - """ - - def __init__(self, instance_id, operator_metadata, input_gate, - output_gate): - OperatorInstance.__init__(self, instance_id, input_gate, output_gate) - self.inspect_fn = operator_metadata.logic - - # Applies the inspect logic (e.g. print) to the records of - # the input stream(s) - # and leaves stream unaffected by simply pushing the records to - # the output stream(s) - while True: - record = self.input._pull() - if record is None: - self.output._flush(close=True) - return - self.output._push(record) - self.inspect_fn(record) - - -# Reduce actor -@ray.remote -class Reduce(OperatorInstance): - """A reduce operator instance that combines a new value for a key - with the last reduced one according to a user-defined logic. - - Attributes: - reduce_fn (function): The user-defined reduce logic. - value_attribute (int): The index of the value to reduce - (assuming tuple records). - state (dict): A mapping from keys to values. - """ - - def __init__(self, instance_id, operator_metadata, input_gate, - output_gate): - OperatorInstance.__init__(self, instance_id, input_gate, output_gate, - operator_metadata.state_actor) - self.reduce_fn = operator_metadata.logic - # Set the attribute selector - self.attribute_selector = operator_metadata.other_args - if self.attribute_selector is None: - self.attribute_selector = _identity - elif isinstance(self.attribute_selector, int): - self.key_index = self.attribute_selector - self.attribute_selector = self.index_based_selector - elif isinstance(self.attribute_selector, str): - self.key_attribute = self.attribute_selector - self.attribute_selector = self.attribute_based_selector - elif not isinstance(self.attribute_selector, types.FunctionType): - sys.exit("Unrecognized or unsupported key selector.") - self.state = {} # key -> value - - # Combines the input value for a key with the last reduced - # value for that key to produce a new value. - # Outputs the result as (key,new value) - def start(self): - while True: - record = self.input._pull() - if record is None: - self.output._flush(close=True) - del self.state - return - key, rest = record - new_value = self.attribute_selector(rest) - # TODO (john): Is there a way to update state with - # a single dictionary lookup? - try: - old_value = self.state[key] - new_value = self.reduce_fn(old_value, new_value) - self.state[key] = new_value - except KeyError: # Key does not exist in state - self.state.setdefault(key, new_value) - self.output._push((key, new_value)) - - # Returns the state of the actor - def get_state(self): - return self.state - - -@ray.remote -class KeyBy(OperatorInstance): - """A key_by operator instance that physically partitions the - stream based on a key. - - Attributes: - key_attribute (int): The index of the value to reduce - (assuming tuple records). - """ - - def __init__(self, instance_id, operator_metadata, input_gate, - output_gate): - OperatorInstance.__init__(self, instance_id, input_gate, output_gate) - # Set the key selector - self.key_selector = operator_metadata.other_args - if isinstance(self.key_selector, int): - self.key_index = self.key_selector - self.key_selector = self.index_based_selector - elif isinstance(self.key_selector, str): - self.key_attribute = self.key_selector - self.key_selector = self.attribute_based_selector - elif not isinstance(self.key_selector, types.FunctionType): - sys.exit("Unrecognized or unsupported key selector.") - - # The actual partitioning is done by the output gate - def start(self): - while True: - record = self.input._pull() - if record is None: - self.output._flush(close=True) - return - key = self.key_selector(record) - self.output._push((key, record)) - - -# A custom source actor -@ray.remote -class Source(OperatorInstance): - def __init__(self, instance_id, operator_metadata, input_gate, - output_gate): - OperatorInstance.__init__(self, instance_id, input_gate, output_gate) - # The user-defined source with a get_next() method - self.source = operator_metadata.other_args - - # Starts the source by calling get_next() repeatedly - def start(self): - start = time.time() - elements = 0 - while True: - next = self.source.get_next() - if next is None: - self.output._flush(close=True) - logger.debug("[writer {}] puts per second: {}".format( - self.instance_id, elements / (time.time() - start))) - return - self.output._push(next) - elements += 1 - - -# TODO(john): Time window actor (uses system time) -@ray.remote -class TimeWindow(OperatorInstance): - def __init__(self, queue, width): - self.width = width # In milliseconds - - def time_window(self): - while True: - pass diff --git a/python/ray/includes/buffer.pxi b/python/ray/includes/buffer.pxi index ae272b0cdf296..6ac8b1db5e1d8 100644 --- a/python/ray/includes/buffer.pxi +++ b/python/ray/includes/buffer.pxi @@ -15,11 +15,6 @@ cdef class Buffer: See https://docs.python.org/3/c-api/buffer.html for details. """ - cdef: - shared_ptr[CBuffer] buffer - Py_ssize_t shape - Py_ssize_t strides - @staticmethod cdef make(const shared_ptr[CBuffer]& buffer): cdef Buffer self = Buffer.__new__(Buffer) diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 7170745e5ece4..edb280e689d1d 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -19,7 +19,7 @@ from ray.includes.unique_ids cimport ( CObjectID, CTaskID, CUniqueID, - CWorkerID, + CWorkerID ) import ray @@ -40,8 +40,6 @@ cdef extern from "ray/common/constants.h" nogil: cdef class BaseID: - # To avoid the error of "Python int too large to convert to C ssize_t", - # here `cdef size_t` is required. cdef size_t hash(self): pass @@ -129,13 +127,6 @@ cdef class UniqueID(BaseID): cdef class ObjectID(BaseID): - cdef: - CObjectID data - object buffer_ref - # Flag indicating whether or not this object ID was added to the set - # of active IDs in the core worker so we know whether we should clean - # it up. - c_bool in_core_worker def __init__(self, id): check_id(id) @@ -332,8 +323,6 @@ cdef class WorkerID(UniqueID): return self.data cdef class ActorID(BaseID): - cdef CActorID data - def __init__(self, id): check_id(id, CActorID.Size()) self.data = CActorID.FromBinary(id) diff --git a/python/ray/streaming b/python/ray/streaming new file mode 120000 index 0000000000000..c69ca3e077f02 --- /dev/null +++ b/python/ray/streaming @@ -0,0 +1 @@ +../../streaming/python/ \ No newline at end of file diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 7d0ed9fbb5c5b..cf7e7c8475b37 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -233,14 +233,6 @@ py_test( deps = ["//:ray_lib"], ) -py_test( - name = "test_logical_graph", - size = "small", - srcs = ["test_logical_graph.py"], - tags = ["exclusive"], - deps = ["//:ray_lib"], -) - py_test( name = "test_memory_limits", size = "medium", diff --git a/python/setup.py b/python/setup.py index b43b56e0968f2..2f172b766df0b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -177,6 +177,7 @@ def find_version(*filepath): "six >= 1.0.0", "faulthandler;python_version<'3.3'", "protobuf >= 3.8.0", + "cloudpickle", ] setup( diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 9e1461d4ea16b..0af3e0f2e5edf 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -603,11 +603,13 @@ Status CoreWorker::SubmitTask(const RayFunction &function, TaskID::ForNormalTask(worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), next_task_index); + const std::unordered_map required_resources; // TODO(ekl) offload task building onto a thread pool for performance BuildCommonTaskSpec( builder, worker_context_.GetCurrentJobID(), task_id, worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, - function, args, task_options.num_returns, task_options.resources, {}, + function, args, task_options.num_returns, task_options.resources, + required_resources, task_options.is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET, return_ids); TaskSpecification task_spec = builder.Build(); @@ -681,10 +683,11 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f const TaskID actor_task_id = TaskID::ForActorTask( worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), next_task_index, actor_handle->GetActorID()); + const std::unordered_map required_resources; BuildCommonTaskSpec(builder, actor_handle->CreationJobID(), actor_task_id, worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), rpc_address_, function, args, num_returns, task_options.resources, - {}, transport_type, return_ids); + required_resources, transport_type, return_ids); const ObjectID new_cursor = return_ids->back(); actor_handle->SetActorTaskSpec(builder, transport_type, new_cursor); diff --git a/src/ray/ray_exported_symbols.lds b/src/ray/ray_exported_symbols.lds new file mode 100644 index 0000000000000..e6bc669f00c07 --- /dev/null +++ b/src/ray/ray_exported_symbols.lds @@ -0,0 +1,27 @@ +# This file defines the C++ symbols that need to be exported (aka ABI, application binary interface). +# These symbols will be used by other libraries (e.g., streaming). +# Note: This file is used for macOS only, and should be kept in sync with `ray_version_script.lds`. +# Ray ABI is not finalized, the exact set of exported (C/C++) APIs is subject to change. +# common +*ray*Language* +*ray*RayObject* +*ray*Status* +*ray*RayFunction* +*ray*TaskArg* +*ray*TaskOptions* +*ray*Buffer* +*ray*LocalMemoryBuffer* +# util +*ray*RayLog* +*ray*RayLogLevel* +# id +*ray*MurmurHash64A* +*ray*JobID* +*ray*TaskID* +*ray*ActorID* +*ray*ObjectID* +# Others +*ray*CoreWorker* +*PyInit* +*init_raylet* +*Java* diff --git a/src/ray/ray_version_script.lds b/src/ray/ray_version_script.lds new file mode 100644 index 0000000000000..9021f7abb1ea9 --- /dev/null +++ b/src/ray/ray_version_script.lds @@ -0,0 +1,31 @@ +# This file defines the C++ symbols that need to be exported (aka ABI, application binary interface). +# These symbols will be used by other libraries (e.g., streaming). +# Note: This file is used for linux only, and should be kept in sync with `ray_exported_symbols.lds`. +# Ray ABI is not finalized, the exact set of exported (C/C++) APIs is subject to change. +VERSION_1.0 { + global: + # common + *ray*Language*; + *ray*RayObject*; + *ray*Status*; + *ray*RayFunction*; + *ray*TaskArg*; + *ray*TaskOptions*; + *ray*Buffer*; + *ray*LocalMemoryBuffer*; + # util + *ray*RayLog*; + *ray*RayLogLevel*; + # id + *ray*MurmurHash64A*; + *ray*JobID*; + *ray*TaskID*; + *ray*ActorID*; + *ray*ObjectID*; + # Others + *ray*CoreWorker*; + *PyInit*; + *init_raylet*; + *Java*; + local: *; +}; diff --git a/streaming/BUILD.bazel b/streaming/BUILD.bazel new file mode 100644 index 0000000000000..876b54da230cf --- /dev/null +++ b/streaming/BUILD.bazel @@ -0,0 +1,235 @@ +# Bazel build +# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html + +load("@com_github_grpc_grpc//bazel:cython_library.bzl", "pyx_library") +load("@rules_proto_grpc//python:defs.bzl", "python_proto_compile") + +proto_library( + name = "streaming_proto", + srcs = ["src/protobuf/streaming.proto"], + visibility = ["//visibility:public"], +) + +cc_proto_library( + name = "streaming_cc_proto", + deps = [":streaming_proto"], +) + +proto_library( + name = "streaming_queue_proto", + srcs = ["src/protobuf/streaming_queue.proto"], +) + +cc_proto_library( + name = "streaming_queue_cc_proto", + deps = ["streaming_queue_proto"], +) + +# Use `linkshared` to ensure ray related symbols are not packed into streaming libs +# to avoid duplicate symbols. In runtime we expose ray related symbols, which can +# be linked into streaming libs by dynamic linker. See bazel rule `//:_raylet` +cc_binary( + name = "ray_util.so", + linkshared = 1, + visibility = ["//visibility:public"], + deps = ["//:ray_util"], +) + +cc_binary( + name = "ray_common.so", + linkshared = 1, + visibility = ["//visibility:public"], + deps = ["//:ray_common"], +) + +cc_binary( + name = "core_worker_lib.so", + linkshared = 1, + deps = ["//:core_worker_lib"], +) + +cc_library( + name = "streaming_util", + srcs = glob([ + "src/util/*.cc", + ]), + hdrs = glob([ + "src/util/*.h", + ]), + includes = [ + "src", + ], + visibility = ["//visibility:public"], + deps = [ + "ray_util.so", + "@boost//:any", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "streaming_config", + srcs = glob([ + "src/config/*.cc", + ]), + hdrs = glob([ + "src/config/*.h", + ]), + deps = [ + "ray_common.so", + ":streaming_cc_proto", + ":streaming_util", + ], +) + +cc_library( + name = "streaming_message", + srcs = glob([ + "src/message/*.cc", + ]), + hdrs = glob([ + "src/message/*.h", + ]), + deps = [ + "ray_common.so", + ":streaming_config", + ":streaming_util", + ], +) + +cc_library( + name = "streaming_queue", + srcs = glob([ + "src/queue/*.cc", + ]), + hdrs = glob([ + "src/queue/*.h", + ]), + deps = [ + "core_worker_lib.so", + "ray_common.so", + "ray_util.so", + ":streaming_config", + ":streaming_message", + ":streaming_queue_cc_proto", + ":streaming_util", + "@boost//:asio", + "@boost//:thread", + ], +) + +cc_library( + name = "streaming_lib", + srcs = glob([ + "src/*.cc", + ]), + hdrs = glob([ + "src/*.h", + "src/queue/*.h", + "src/test/*.h", + ]), + includes = ["src"], + visibility = ["//visibility:public"], + deps = [ + "ray_common.so", + "ray_util.so", + ":streaming_config", + ":streaming_message", + ":streaming_queue", + ":streaming_util", + "@boost//:circular_buffer", + ], +) + +test_common_deps = [ + ":streaming_lib", + "//:ray_common", + "//:ray_util", + "//:core_worker_lib", +] + +# streaming queue mock actor binary +cc_binary( + name = "streaming_test_worker", + srcs = glob(["src/test/*.h"]) + [ + "src/test/mock_actor.cc", + ], + includes = [ + "streaming/src/test", + ], + deps = test_common_deps, +) + +# use src/test/run_streaming_queue_test.sh to run this test +cc_binary( + name = "streaming_queue_tests", + srcs = glob(["src/test/*.h"]) + [ + "src/test/streaming_queue_tests.cc", + ], + deps = test_common_deps, +) + +cc_test( + name = "streaming_message_ring_buffer_tests", + srcs = [ + "src/test/ring_buffer_tests.cc", + ], + includes = [ + "streaming/src/test", + ], + deps = test_common_deps, +) + +cc_test( + name = "streaming_message_serialization_tests", + srcs = [ + "src/test/message_serialization_tests.cc", + ], + deps = test_common_deps, +) + +cc_test( + name = "streaming_mock_transfer", + srcs = [ + "src/test/mock_transfer_tests.cc", + ], + deps = test_common_deps, +) + +cc_test( + name = "streaming_util_tests", + srcs = [ + "src/test/streaming_util_tests.cc", + ], + deps = test_common_deps, +) + +python_proto_compile( + name = "streaming_py_proto", + deps = ["//streaming:streaming_proto"], +) + +genrule( + name = "copy_streaming_py_proto", + srcs = [ + ":streaming_py_proto", + ], + outs = [ + "copy_streaming_py_proto.out", + ], + cmd = """ + set -e + set -x + WORK_DIR=$$(pwd) + # Copy generated files. + GENERATED_DIR=$$WORK_DIR/streaming/python/generated + rm -rf $$GENERATED_DIR + mkdir -p $$GENERATED_DIR + for f in $(locations //streaming:streaming_py_proto); do + cp $$f $$GENERATED_DIR + done + echo $$(date) > $@ + """, + local = 1, + visibility = ["//visibility:public"], +) diff --git a/streaming/README.md b/streaming/README.md new file mode 100644 index 0000000000000..d7091885fc50e --- /dev/null +++ b/streaming/README.md @@ -0,0 +1,28 @@ +# Ray Streaming + +1. Build streaming java + * build ray + * `sh build.sh -l java` + * `cd java && mvn clean install -Dmaven.test.skip=true` + * build streaming + * `cd ray/streaming/java && bazel build all_modules` + * `mvn clean install -Dmaven.test.skip=true` + +2. Build ray will build ray streaming python. + +3. Run examples +```bash +# c++ test +cd streaming/ && bazel test ... +sh src/test/run_streaming_queue_test.sh +cd .. + +# python test +cd python/ray/streaming/ +pushd examples +python simple.py --input-file toy.txt +popd +pushd tests +pytest . +popd +``` \ No newline at end of file diff --git a/python/ray/experimental/streaming/README.rst b/streaming/python/README.rst similarity index 100% rename from python/ray/experimental/streaming/README.rst rename to streaming/python/README.rst diff --git a/python/ray/experimental/streaming/__init__.py b/streaming/python/__init__.pxd similarity index 100% rename from python/ray/experimental/streaming/__init__.py rename to streaming/python/__init__.pxd diff --git a/streaming/python/__init__.py b/streaming/python/__init__.py new file mode 100644 index 0000000000000..4126425aaf2c3 --- /dev/null +++ b/streaming/python/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa +# Ray should be imported before streaming +import ray diff --git a/streaming/python/_streaming.pyx b/streaming/python/_streaming.pyx new file mode 100644 index 0000000000000..3d845ff341343 --- /dev/null +++ b/streaming/python/_streaming.pyx @@ -0,0 +1,6 @@ +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +include "includes/transfer.pxi" diff --git a/streaming/python/communication.py b/streaming/python/communication.py new file mode 100644 index 0000000000000..a5990165e769f --- /dev/null +++ b/streaming/python/communication.py @@ -0,0 +1,283 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import hashlib +import logging +import pickle +import sys +import time + +import ray +import ray.streaming.runtime.transfer as transfer +from ray.streaming.config import Config +from ray.streaming.operator import PStrategy +from ray.streaming.runtime.transfer import ChannelID + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + +# Forward and broadcast stream partitioning strategies +forward_broadcast_strategies = [PStrategy.Forward, PStrategy.Broadcast] + + +# Used to choose output channel in case of hash-based shuffling +def _hash(value): + if isinstance(value, int): + return value + try: + return int(hashlib.sha1(value.encode("utf-8")).hexdigest(), 16) + except AttributeError: + return int(hashlib.sha1(value).hexdigest(), 16) + + +class DataChannel(object): + """A data channel for actor-to-actor communication. + + Attributes: + env (Environment): The environment the channel belongs to. + src_operator_id (UUID): The id of the source operator of the channel. + src_instance_index (int): The id of the source instance. + dst_operator_id (UUID): The id of the destination operator of the + channel. + dst_instance_index (int): The id of the destination instance. + """ + + def __init__(self, src_operator_id, src_instance_index, dst_operator_id, + dst_instance_index, str_qid): + self.src_operator_id = src_operator_id + self.src_instance_index = src_instance_index + self.dst_operator_id = dst_operator_id + self.dst_instance_index = dst_instance_index + self.str_qid = str_qid + self.qid = ChannelID(str_qid) + + def __repr__(self): + return "(src({},{}),dst({},{}), qid({}))".format( + self.src_operator_id, self.src_instance_index, + self.dst_operator_id, self.dst_instance_index, self.str_qid) + + +_CLOSE_FLAG = b" " + + +# Pulls and merges data from multiple input channels +class DataInput(object): + """An input gate of an operator instance. + + The input gate pulls records from all input channels in a round-robin + fashion. + + Attributes: + input_channels (list): The list of input channels. + channel_index (int): The index of the next channel to pull from. + max_index (int): The number of input channels. + closed (list): A list of flags indicating whether an input channel + has been marked as 'closed'. + all_closed (bool): Denotes whether all input channels have been + closed (True) or not (False). + """ + + def __init__(self, env, channels): + assert len(channels) > 0 + self.env = env + self.reader = None # created in `init` method + self.input_channels = channels + self.channel_index = 0 + self.max_index = len(channels) + # Tracks the channels that have been closed. qid: close status + self.closed = {} + + def init(self): + channels = [c.str_qid for c in self.input_channels] + input_actors = [] + for c in self.input_channels: + actor = self.env.execution_graph.get_actor(c.src_operator_id, + c.src_instance_index) + input_actors.append(actor) + logger.info("DataInput input_actors %s", input_actors) + conf = { + Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context() + .current_driver_id, + Config.CHANNEL_TYPE: self.env.config.channel_type + } + self.reader = transfer.DataReader(channels, input_actors, conf) + + def pull(self): + # pull from channel + item = self.reader.read(100) + while item is None: + time.sleep(0.001) + item = self.reader.read(100) + msg_data = item.body() + if msg_data == _CLOSE_FLAG: + self.closed[item.channel_id] = True + if len(self.closed) == len(self.input_channels): + return None + else: + return self.pull() + else: + return pickle.loads(msg_data) + + def close(self): + self.reader.stop() + + +# Selects output channel(s) and pushes data +class DataOutput(object): + """An output gate of an operator instance. + + The output gate pushes records to output channels according to the + user-defined partitioning scheme. + + Attributes: + partitioning_schemes (dict): A mapping from destination operator ids + to partitioning schemes (see: PScheme in operator.py). + forward_channels (list): A list of channels to forward records. + shuffle_channels (list(list)): A list of output channels to shuffle + records grouped by destination operator. + shuffle_key_channels (list(list)): A list of output channels to + shuffle records by a key grouped by destination operator. + shuffle_exists (bool): A flag indicating that there exists at least + one shuffle_channel. + shuffle_key_exists (bool): A flag indicating that there exists at + least one shuffle_key_channel. + """ + + def __init__(self, env, channels, partitioning_schemes): + assert len(channels) > 0 + self.env = env + self.writer = None # created in `init` method + self.channels = channels + self.key_selector = None + self.round_robin_indexes = [0] + self.partitioning_schemes = partitioning_schemes + # Prepare output -- collect channels by type + self.forward_channels = [] # Forward and broadcast channels + slots = sum(1 for scheme in self.partitioning_schemes.values() + if scheme.strategy == PStrategy.RoundRobin) + self.round_robin_channels = [[]] * slots # RoundRobin channels + self.round_robin_indexes = [-1] * slots + slots = sum(1 for scheme in self.partitioning_schemes.values() + if scheme.strategy == PStrategy.Shuffle) + # Flag used to avoid hashing when there is no shuffling + self.shuffle_exists = slots > 0 + self.shuffle_channels = [[]] * slots # Shuffle channels + slots = sum(1 for scheme in self.partitioning_schemes.values() + if scheme.strategy == PStrategy.ShuffleByKey) + # Flag used to avoid hashing when there is no shuffling by key + self.shuffle_key_exists = slots > 0 + self.shuffle_key_channels = [[]] * slots # Shuffle by key channels + # Distinct shuffle destinations + shuffle_destinations = {} + # Distinct shuffle by key destinations + shuffle_by_key_destinations = {} + # Distinct round robin destinations + round_robin_destinations = {} + index_1 = 0 + index_2 = 0 + index_3 = 0 + for channel in channels: + p_scheme = self.partitioning_schemes[channel.dst_operator_id] + strategy = p_scheme.strategy + if strategy in forward_broadcast_strategies: + self.forward_channels.append(channel) + elif strategy == PStrategy.Shuffle: + pos = shuffle_destinations.setdefault(channel.dst_operator_id, + index_1) + self.shuffle_channels[pos].append(channel) + if pos == index_1: + index_1 += 1 + elif strategy == PStrategy.ShuffleByKey: + pos = shuffle_by_key_destinations.setdefault( + channel.dst_operator_id, index_2) + self.shuffle_key_channels[pos].append(channel) + if pos == index_2: + index_2 += 1 + elif strategy == PStrategy.RoundRobin: + pos = round_robin_destinations.setdefault( + channel.dst_operator_id, index_3) + self.round_robin_channels[pos].append(channel) + if pos == index_3: + index_3 += 1 + else: # TODO (john): Add support for other strategies + sys.exit("Unrecognized or unsupported partitioning strategy.") + # A KeyedDataStream can only be shuffled by key + assert not (self.shuffle_exists and self.shuffle_key_exists) + + def init(self): + """init DataOutput which creates DataWriter""" + channel_ids = [c.str_qid for c in self.channels] + to_actors = [] + for c in self.channels: + actor = self.env.execution_graph.get_actor(c.dst_operator_id, + c.dst_instance_index) + to_actors.append(actor) + logger.info("DataOutput output_actors %s", to_actors) + + conf = { + Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context() + .current_driver_id, + Config.CHANNEL_TYPE: self.env.config.channel_type + } + self.writer = transfer.DataWriter(channel_ids, to_actors, conf) + + def close(self): + """Close the channel (True) by propagating _CLOSE_FLAG + + _CLOSE_FLAG is used as special type of record that is propagated from + sources to sink to notify that the end of data in a stream. + """ + for c in self.channels: + self.writer.write(c.qid, _CLOSE_FLAG) + # must ensure DataWriter send None flag to peer actor + self.writer.stop() + + def push(self, record): + target_channels = [] + # Forward record + for c in self.forward_channels: + logger.debug("[writer] Push record '{}' to channel {}".format( + record, c)) + target_channels.append(c) + # Forward record + index = 0 + for channels in self.round_robin_channels: + self.round_robin_indexes[index] += 1 + if self.round_robin_indexes[index] == len(channels): + self.round_robin_indexes[index] = 0 # Reset index + c = channels[self.round_robin_indexes[index]] + logger.debug("[writer] Push record '{}' to channel {}".format( + record, c)) + target_channels.append(c) + index += 1 + # Hash-based shuffling by key + if self.shuffle_key_exists: + key, _ = record + h = _hash(key) + for channels in self.shuffle_key_channels: + num_instances = len(channels) # Downstream instances + c = channels[h % num_instances] + logger.debug( + "[key_shuffle] Push record '{}' to channel {}".format( + record, c)) + target_channels.append(c) + elif self.shuffle_exists: # Hash-based shuffling per destination + h = _hash(record) + for channels in self.shuffle_channels: + num_instances = len(channels) # Downstream instances + c = channels[h % num_instances] + logger.debug("[shuffle] Push record '{}' to channel {}".format( + record, c)) + target_channels.append(c) + else: # TODO (john): Handle rescaling + pass + + msg_data = pickle.dumps(record) + for c in target_channels: + # send data to channel + self.writer.write(c.qid, msg_data) + + def push_all(self, records): + for record in records: + self.push(record) diff --git a/streaming/python/config.py b/streaming/python/config.py new file mode 100644 index 0000000000000..8f7b5e941508f --- /dev/null +++ b/streaming/python/config.py @@ -0,0 +1,23 @@ +class Config: + STREAMING_JOB_NAME = "streaming.job.name" + STREAMING_OP_NAME = "streaming.op_name" + TASK_JOB_ID = "streaming.task_job_id" + STREAMING_WORKER_NAME = "streaming.worker_name" + # channel + CHANNEL_TYPE = "channel_type" + MEMORY_CHANNEL = "memory_channel" + NATIVE_CHANNEL = "native_channel" + CHANNEL_SIZE = "channel_size" + CHANNEL_SIZE_DEFAULT = 10**8 + IS_RECREATE = "streaming.is_recreate" + # return from StreamingReader.getBundle if only empty message read in this + # interval. + TIMER_INTERVAL_MS = "timer_interval_ms" + + STREAMING_RING_BUFFER_CAPACITY = "streaming.ring_buffer_capacity" + # write an empty message if there is no data to be written in this + # interval. + STREAMING_EMPTY_MESSAGE_INTERVAL = "streaming.empty_message_interval" + + # operator type + OPERATOR_TYPE = "operator_type" diff --git a/python/ray/experimental/streaming/examples/articles.txt b/streaming/python/examples/articles.txt similarity index 100% rename from python/ray/experimental/streaming/examples/articles.txt rename to streaming/python/examples/articles.txt diff --git a/python/ray/experimental/streaming/examples/key_selectors.py b/streaming/python/examples/key_selectors.py similarity index 84% rename from python/ray/experimental/streaming/examples/key_selectors.py rename to streaming/python/examples/key_selectors.py index 762a63cce27c8..4b1d2e7a371be 100644 --- a/python/ray/experimental/streaming/examples/key_selectors.py +++ b/streaming/python/examples/key_selectors.py @@ -7,9 +7,7 @@ import time import ray -from ray.experimental.streaming.streaming import Environment -from ray.experimental.streaming.batched_queue import BatchedQueue -from ray.experimental.streaming.operator import OpType, PStrategy +from ray.streaming.streaming import Environment logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -48,9 +46,6 @@ def as_tuple(record): ray.init() ray.register_custom_serializer(Record, use_dict=True) - ray.register_custom_serializer(BatchedQueue, use_pickle=True) - ray.register_custom_serializer(OpType, use_pickle=True) - ray.register_custom_serializer(PStrategy, use_pickle=True) # A Ray streaming environment with the default configuration env = Environment() diff --git a/python/ray/experimental/streaming/examples/simple.py b/streaming/python/examples/simple.py similarity index 61% rename from python/ray/experimental/streaming/examples/simple.py rename to streaming/python/examples/simple.py index 26272cdc94dc7..0e12317ada9d3 100644 --- a/python/ray/experimental/streaming/examples/simple.py +++ b/streaming/python/examples/simple.py @@ -7,9 +7,8 @@ import time import ray -from ray.experimental.streaming.streaming import Environment -from ray.experimental.streaming.batched_queue import BatchedQueue -from ray.experimental.streaming.operator import OpType, PStrategy +from ray.streaming.config import Config +from ray.streaming.streaming import Environment, Conf logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -33,27 +32,25 @@ def filter_fn(word): args = parser.parse_args() - ray.init() - ray.register_custom_serializer(BatchedQueue, use_pickle=True) - ray.register_custom_serializer(OpType, use_pickle=True) - ray.register_custom_serializer(PStrategy, use_pickle=True) + ray.init(local_mode=False) # A Ray streaming environment with the default configuration - env = Environment() + env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL)) # Stream represents the ouput of the filter and # can be forked into other dataflows stream = env.read_text_file(args.input_file) \ - .shuffle() \ - .flat_map(splitter) \ - .set_parallelism(4) \ - .filter(filter_fn) \ - .set_parallelism(2) \ - .inspect(print) # Prints the contents of the + .shuffle() \ + .flat_map(splitter) \ + .set_parallelism(2) \ + .filter(filter_fn) \ + .set_parallelism(2) \ + .inspect(lambda x: print("result", x)) # Prints the contents of the # stream to stdout start = time.time() env_handle = env.execute() ray.get(env_handle) # Stay alive until execution finishes + env.wait_finish() end = time.time() logger.info("Elapsed time: {} secs".format(end - start)) logger.debug("Output stream id: {}".format(stream.id)) diff --git a/python/ray/experimental/streaming/examples/toy.txt b/streaming/python/examples/toy.txt similarity index 100% rename from python/ray/experimental/streaming/examples/toy.txt rename to streaming/python/examples/toy.txt diff --git a/python/ray/experimental/streaming/examples/wordcount.py b/streaming/python/examples/wordcount.py similarity index 90% rename from python/ray/experimental/streaming/examples/wordcount.py rename to streaming/python/examples/wordcount.py index 9cc933ed44dc9..2062f2b7f5f64 100644 --- a/python/ray/experimental/streaming/examples/wordcount.py +++ b/streaming/python/examples/wordcount.py @@ -5,12 +5,10 @@ import argparse import logging import time -import wikipedia import ray -from ray.experimental.streaming.streaming import Environment -from ray.experimental.streaming.batched_queue import BatchedQueue -from ray.experimental.streaming.operator import OpType, PStrategy +import wikipedia +from ray.streaming.streaming import Environment logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -86,9 +84,6 @@ def attribute_selector(tuple): titles_file = str(args.titles_file) ray.init() - ray.register_custom_serializer(BatchedQueue, use_pickle=True) - ray.register_custom_serializer(OpType, use_pickle=True) - ray.register_custom_serializer(PStrategy, use_pickle=True) # A Ray streaming environment with the default configuration env = Environment() @@ -108,6 +103,7 @@ def attribute_selector(tuple): start = time.time() env_handle = env.execute() # Deploys and executes the dataflow ray.get(env_handle) # Stay alive until execution finishes + env.wait_finish() end = time.time() logger.info("Elapsed time: {} secs".format(end - start)) logger.debug("Output stream id: {}".format(stream.id)) diff --git a/streaming/python/includes/__init__.pxd b/streaming/python/includes/__init__.pxd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/streaming/python/includes/libstreaming.pxd b/streaming/python/includes/libstreaming.pxd new file mode 100644 index 0000000000000..0b1ad27c50de5 --- /dev/null +++ b/streaming/python/includes/libstreaming.pxd @@ -0,0 +1,153 @@ +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 +# flake8: noqa + +from libc.stdint cimport * +from libcpp cimport bool as c_bool +from libcpp.memory cimport shared_ptr +from libcpp.vector cimport vector as c_vector +from libcpp.list cimport list as c_list +from cpython cimport PyObject +cimport cpython + +cdef inline object PyObject_to_object(PyObject* o): + # Cast to "object" increments reference count + cdef object result = o + cpython.Py_DECREF(result) + return result + +from ray.includes.common cimport ( + CLanguage, + CRayObject, + CRayStatus, + CRayFunction +) + +from ray.includes.unique_ids cimport ( + CActorID, + CJobID, + CTaskID, + CObjectID, +) +from ray.includes.libcoreworker cimport CCoreWorker + +cdef extern from "status.h" namespace "ray::streaming" nogil: + cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus": + pass + cdef CStreamingStatus StatusOK "ray::streaming::StreamingStatus::OK" + cdef CStreamingStatus StatusReconstructTimeOut "ray::streaming::StreamingStatus::ReconstructTimeOut" + cdef CStreamingStatus StatusQueueIdNotFound "ray::streaming::StreamingStatus::QueueIdNotFound" + cdef CStreamingStatus StatusResubscribeFailed "ray::streaming::StreamingStatus::ResubscribeFailed" + cdef CStreamingStatus StatusEmptyRingBuffer "ray::streaming::StreamingStatus::EmptyRingBuffer" + cdef CStreamingStatus StatusFullChannel "ray::streaming::StreamingStatus::FullChannel" + cdef CStreamingStatus StatusNoSuchItem "ray::streaming::StreamingStatus::NoSuchItem" + cdef CStreamingStatus StatusInitQueueFailed "ray::streaming::StreamingStatus::InitQueueFailed" + cdef CStreamingStatus StatusGetBundleTimeOut "ray::streaming::StreamingStatus::GetBundleTimeOut" + cdef CStreamingStatus StatusSkipSendEmptyMessage "ray::streaming::StreamingStatus::SkipSendEmptyMessage" + cdef CStreamingStatus StatusInterrupted "ray::streaming::StreamingStatus::Interrupted" + cdef CStreamingStatus StatusWaitQueueTimeOut "ray::streaming::StreamingStatus::WaitQueueTimeOut" + cdef CStreamingStatus StatusOutOfMemory "ray::streaming::StreamingStatus::OutOfMemory" + cdef CStreamingStatus StatusInvalid "ray::streaming::StreamingStatus::Invalid" + cdef CStreamingStatus StatusUnknownError "ray::streaming::StreamingStatus::UnknownError" + cdef CStreamingStatus StatusTailStatus "ray::streaming::StreamingStatus::TailStatus" + + cdef cppclass CStreamingCommon "ray::streaming::StreamingCommon": + void SetConfig(const uint8_t *, uint32_t size) + + +cdef extern from "runtime_context.h" namespace "ray::streaming" nogil: + cdef cppclass CRuntimeContext "ray::streaming::RuntimeContext": + CRuntimeContext() + void SetConfig(const uint8_t *data, uint32_t size) + inline void MarkMockTest() + inline c_bool IsMockTest() + +cdef extern from "message/message.h" namespace "ray::streaming" nogil: + cdef cppclass CStreamingMessageType "ray::streaming::StreamingMessageType": + pass + cdef CStreamingMessageType MessageTypeBarrier "ray::streaming::StreamingMessageType::Barrier" + cdef CStreamingMessageType MessageTypeMessage "ray::streaming::StreamingMessageType::Message" + cdef cppclass CStreamingMessage "ray::streaming::StreamingMessage": + inline uint8_t *RawData() const + inline uint32_t GetDataSize() const + inline CStreamingMessageType GetMessageType() const + inline uint64_t GetMessageSeqId() const + +cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil: + cdef cppclass CStreamingMessageBundleType "ray::streaming::StreamingMessageBundleType": + pass + cdef CStreamingMessageBundleType BundleTypeEmpty "ray::streaming::StreamingMessageBundleType::Empty" + cdef CStreamingMessageBundleType BundleTypeBarrier "ray::streaming::StreamingMessageBundleType::Barrier" + cdef CStreamingMessageBundleType BundleTypeBundle "ray::streaming::StreamingMessageBundleType::Bundle" + + cdef cppclass CStreamingMessageBundleMeta "ray::streaming::StreamingMessageBundleMeta": + CStreamingMessageBundleMeta() + inline uint64_t GetMessageBundleTs() const + inline uint64_t GetLastMessageId() const + inline uint32_t GetMessageListSize() const + inline CStreamingMessageBundleType GetBundleType() const + inline c_bool IsBarrier() + inline c_bool IsBundle() + + ctypedef shared_ptr[CStreamingMessageBundleMeta] CStreamingMessageBundleMetaPtr + uint32_t kMessageBundleHeaderSize "ray::streaming::kMessageBundleHeaderSize" + cdef cppclass CStreamingMessageBundle "ray::streaming::StreamingMessageBundle"(CStreamingMessageBundleMeta): + @staticmethod + void GetMessageListFromRawData(const uint8_t *data, uint32_t size, uint32_t msg_nums, + c_list[shared_ptr[CStreamingMessage]] &msg_list); + +cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil: + cdef cppclass CReaderClient "ray::streaming::ReaderClient": + CReaderClient(CCoreWorker *core_worker, + CRayFunction &async_func, + CRayFunction &sync_func) + void OnReaderMessage(shared_ptr[CLocalMemoryBuffer] buffer); + shared_ptr[CLocalMemoryBuffer] OnReaderMessageSync(shared_ptr[CLocalMemoryBuffer] buffer); + + cdef cppclass CWriterClient "ray::streaming::WriterClient": + CWriterClient(CCoreWorker *core_worker, + CRayFunction &async_func, + CRayFunction &sync_func) + void OnWriterMessage(shared_ptr[CLocalMemoryBuffer] buffer); + shared_ptr[CLocalMemoryBuffer] OnWriterMessageSync(shared_ptr[CLocalMemoryBuffer] buffer); + + +cdef extern from "data_reader.h" namespace "ray::streaming" nogil: + cdef cppclass CDataBundle "ray::streaming::DataBundle": + uint8_t *data + uint32_t data_size + CObjectID c_from "from" + uint64_t seq_id + CStreamingMessageBundleMetaPtr meta + + cdef cppclass CDataReader "ray::streaming::DataReader"(CStreamingCommon): + CDataReader(shared_ptr[CRuntimeContext] &runtime_context) + void Init(const c_vector[CObjectID] &input_ids, + const c_vector[CActorID] &actor_ids, + const c_vector[uint64_t] &seq_ids, + const c_vector[uint64_t] &msg_ids, + int64_t timer_interval); + CStreamingStatus GetBundle(const uint32_t timeout_ms, + shared_ptr[CDataBundle] &message) + void Stop() + + +cdef extern from "data_writer.h" namespace "ray::streaming" nogil: + cdef cppclass CDataWriter "ray::streaming::DataWriter"(CStreamingCommon): + CDataWriter(shared_ptr[CRuntimeContext] &runtime_context) + CStreamingStatus Init(const c_vector[CObjectID] &channel_ids, + const c_vector[CActorID] &actor_ids, + const c_vector[uint64_t] &message_ids, + const c_vector[uint64_t] &queue_size_vec); + long WriteMessageToBufferRing( + const CObjectID &q_id, uint8_t *data, uint32_t data_size) + void Run() + void Stop() + + +cdef extern from "ray/common/buffer.h" nogil: + cdef cppclass CLocalMemoryBuffer "ray::LocalMemoryBuffer": + uint8_t *Data() const + size_t Size() const diff --git a/streaming/python/includes/transfer.pxi b/streaming/python/includes/transfer.pxi new file mode 100644 index 0000000000000..7830cf8abde94 --- /dev/null +++ b/streaming/python/includes/transfer.pxi @@ -0,0 +1,323 @@ +# flake8: noqa + +from libc.stdint cimport * +from libcpp cimport bool as c_bool +from libcpp.memory cimport shared_ptr, make_shared, dynamic_pointer_cast +from libcpp.string cimport string as c_string +from libcpp.vector cimport vector as c_vector +from libcpp.list cimport list as c_list + +from ray.includes.common cimport ( + CRayFunction, + LANGUAGE_PYTHON, + CBuffer +) + +from ray.includes.unique_ids cimport ( + CActorID, + CObjectID +) +from ray._raylet cimport ( + Buffer, + CoreWorker, + ActorID, + ObjectID, + string_vector_from_list +) + +from ray.includes.libcoreworker cimport CCoreWorker + +cimport ray.streaming.includes.libstreaming as libstreaming +from ray.streaming.includes.libstreaming cimport ( + CStreamingStatus, + CStreamingMessage, + CStreamingMessageBundle, + CRuntimeContext, + CDataBundle, + CDataWriter, + CDataReader, + CReaderClient, + CWriterClient, + CLocalMemoryBuffer, +) + +import logging +from ray.function_manager import FunctionDescriptor + + +channel_logger = logging.getLogger(__name__) + + +cdef class ReaderClient: + cdef: + CReaderClient *client + + def __cinit__(self, + CoreWorker worker, + async_func: FunctionDescriptor, + sync_func: FunctionDescriptor): + cdef: + CCoreWorker *core_worker = worker.core_worker.get() + CRayFunction async_native_func + CRayFunction sync_native_func + async_native_func = CRayFunction( + LANGUAGE_PYTHON, string_vector_from_list(async_func.get_function_descriptor_list())) + sync_native_func = CRayFunction( + LANGUAGE_PYTHON, string_vector_from_list(sync_func.get_function_descriptor_list())) + self.client = new CReaderClient(core_worker, async_native_func, sync_native_func) + + def __dealloc__(self): + del self.client + self.client = NULL + + def on_reader_message(self, const unsigned char[:] value): + cdef: + size_t size = value.nbytes + shared_ptr[CLocalMemoryBuffer] local_buf = \ + make_shared[CLocalMemoryBuffer]((&value[0]), size, True) + with nogil: + self.client.OnReaderMessage(local_buf) + + def on_reader_message_sync(self, const unsigned char[:] value): + cdef: + size_t size = value.nbytes + shared_ptr[CLocalMemoryBuffer] local_buf = \ + make_shared[CLocalMemoryBuffer]((&value[0]), size, True) + shared_ptr[CLocalMemoryBuffer] result_buffer + with nogil: + result_buffer = self.client.OnReaderMessageSync(local_buf) + return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer)) + + +cdef class WriterClient: + cdef: + CWriterClient * client + + def __cinit__(self, + CoreWorker worker, + async_func: FunctionDescriptor, + sync_func: FunctionDescriptor): + cdef: + CCoreWorker *core_worker = worker.core_worker.get() + CRayFunction async_native_func + CRayFunction sync_native_func + async_native_func = CRayFunction( + LANGUAGE_PYTHON, string_vector_from_list(async_func.get_function_descriptor_list())) + sync_native_func = CRayFunction( + LANGUAGE_PYTHON, string_vector_from_list(sync_func.get_function_descriptor_list())) + self.client = new CWriterClient(core_worker, async_native_func, sync_native_func) + + def __dealloc__(self): + del self.client + self.client = NULL + + def on_writer_message(self, const unsigned char[:] value): + cdef: + size_t size = value.nbytes + shared_ptr[CLocalMemoryBuffer] local_buf = \ + make_shared[CLocalMemoryBuffer]((&value[0]), size, True) + with nogil: + self.client.OnWriterMessage(local_buf) + + def on_writer_message_sync(self, const unsigned char[:] value): + cdef: + size_t size = value.nbytes + shared_ptr[CLocalMemoryBuffer] local_buf = \ + make_shared[CLocalMemoryBuffer]((&value[0]), size, True) + shared_ptr[CLocalMemoryBuffer] result_buffer + with nogil: + result_buffer = self.client.OnWriterMessageSync(local_buf) + return Buffer.make(dynamic_pointer_cast[CBuffer, CLocalMemoryBuffer](result_buffer)) + + +cdef class DataWriter: + cdef: + CDataWriter *writer + + def __init__(self): + raise Exception("use create() to create DataWriter") + + @staticmethod + def create(list py_output_channels, + list output_actor_ids: list[ActorID], + uint64_t queue_size, + list py_msg_ids, + bytes config_bytes, + c_bool is_mock): + cdef: + c_vector[CObjectID] channel_ids = bytes_list_to_qid_vec(py_output_channels) + c_vector[CActorID] actor_ids + c_vector[uint64_t] msg_ids + CDataWriter *c_writer + cdef const unsigned char[:] config_data + for actor_id in output_actor_ids: + actor_ids.push_back((actor_id).data) + for py_msg_id in py_msg_ids: + msg_ids.push_back(py_msg_id) + + cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]() + if is_mock: + ctx.get().MarkMockTest() + if config_bytes: + config_data = config_bytes + channel_logger.info("load config, config bytes size: %s", config_data.nbytes) + ctx.get().SetConfig((&config_data[0]), config_data.nbytes) + c_writer = new CDataWriter(ctx) + cdef: + c_vector[CObjectID] remain_id_vec + c_vector[uint64_t] queue_size_vec + for i in range(channel_ids.size()): + queue_size_vec.push_back(queue_size) + cdef CStreamingStatus status = c_writer.Init(channel_ids, actor_ids, msg_ids, queue_size_vec) + if remain_id_vec.size() != 0: + channel_logger.warning("failed queue amounts => %s", remain_id_vec.size()) + if status != libstreaming.StatusOK: + msg = "initialize writer failed, status={}".format(status) + channel_logger.error(msg) + del c_writer + import ray.streaming.runtime.transfer as transfer + raise transfer.ChannelInitException(msg, qid_vector_to_list(remain_id_vec)) + + c_writer.Run() + channel_logger.info("create native writer succeed") + cdef DataWriter writer = DataWriter.__new__(DataWriter) + writer.writer = c_writer + return writer + + def __dealloc__(self): + if self.writer != NULL: + del self.writer + channel_logger.info("deleted DataWriter") + self.writer = NULL + + def write(self, ObjectID qid, const unsigned char[:] value): + """support zero-copy bytes, bytearray, array of unsigned char""" + cdef: + CObjectID native_id = qid.data + uint64_t msg_id + uint8_t *data = (&value[0]) + uint32_t size = value.nbytes + with nogil: + msg_id = self.writer.WriteMessageToBufferRing(native_id, data, size) + return msg_id + + def stop(self): + self.writer.Stop() + channel_logger.info("stopped DataWriter") + + +cdef class DataReader: + cdef: + CDataReader *reader + readonly bytes meta + readonly bytes data + + def __init__(self): + raise Exception("use create() to create DataReader") + + @staticmethod + def create(list py_input_queues, + list input_actor_ids: list[ActorID], + list py_seq_ids, + list py_msg_ids, + int64_t timer_interval, + c_bool is_recreate, + bytes config_bytes, + c_bool is_mock): + cdef: + c_vector[CObjectID] queue_id_vec = bytes_list_to_qid_vec(py_input_queues) + c_vector[CActorID] actor_ids + c_vector[uint64_t] seq_ids + c_vector[uint64_t] msg_ids + CDataReader *c_reader + cdef const unsigned char[:] config_data + for actor_id in input_actor_ids: + actor_ids.push_back((actor_id).data) + for py_seq_id in py_seq_ids: + seq_ids.push_back(py_seq_id) + for py_msg_id in py_msg_ids: + msg_ids.push_back(py_msg_id) + cdef shared_ptr[CRuntimeContext] ctx = make_shared[CRuntimeContext]() + if config_bytes: + config_data = config_bytes + channel_logger.info("load config, config bytes size: %s", config_data.nbytes) + ctx.get().SetConfig((&(config_data[0])), config_data.nbytes) + if is_mock: + ctx.get().MarkMockTest() + c_reader = new CDataReader(ctx) + c_reader.Init(queue_id_vec, actor_ids, seq_ids, msg_ids, timer_interval) + channel_logger.info("create native reader succeed") + cdef DataReader reader = DataReader.__new__(DataReader) + reader.reader = c_reader + return reader + + def __dealloc__(self): + if self.reader != NULL: + del self.reader + channel_logger.info("deleted DataReader") + self.reader = NULL + + def read(self, uint32_t timeout_millis): + cdef: + shared_ptr[CDataBundle] bundle + CStreamingStatus status + with nogil: + status = self.reader.GetBundle(timeout_millis, bundle) + cdef uint32_t bundle_type = (bundle.get().meta.get().GetBundleType()) + if status != libstreaming.StatusOK: + if status == libstreaming.StatusInterrupted: + # avoid cyclic import + import ray.streaming.runtime.transfer as transfer + raise transfer.ChannelInterruptException("reader interrupted") + elif status == libstreaming.StatusInitQueueFailed: + raise Exception("init channel failed") + elif status == libstreaming.StatusWaitQueueTimeOut: + raise Exception("wait channel object timeout") + cdef: + uint32_t msg_nums + CObjectID queue_id + c_list[shared_ptr[CStreamingMessage]] msg_list + list msgs = [] + uint64_t timestamp + uint64_t msg_id + if bundle_type == libstreaming.BundleTypeBundle: + msg_nums = bundle.get().meta.get().GetMessageListSize() + CStreamingMessageBundle.GetMessageListFromRawData( + bundle.get().data + libstreaming.kMessageBundleHeaderSize, + bundle.get().data_size - libstreaming.kMessageBundleHeaderSize, + msg_nums, + msg_list) + timestamp = bundle.get().meta.get().GetMessageBundleTs() + for msg in msg_list: + msg_bytes = msg.get().RawData()[:msg.get().GetDataSize()] + qid_bytes = queue_id.Binary() + msg_id = msg.get().GetMessageSeqId() + msgs.append((msg_bytes, msg_id, timestamp, qid_bytes)) + return msgs + elif bundle_type == libstreaming.BundleTypeEmpty: + return [] + else: + raise Exception("Unsupported bundle type {}".format(bundle_type)) + + def stop(self): + self.reader.Stop() + channel_logger.info("stopped DataReader") + + +cdef c_vector[CObjectID] bytes_list_to_qid_vec(list py_queue_ids) except *: + assert len(py_queue_ids) > 0 + cdef: + c_vector[CObjectID] queue_id_vec + c_string q_id_data + for q_id in py_queue_ids: + q_id_data = q_id + assert q_id_data.size() == CObjectID.Size() + obj_id = CObjectID.FromBinary(q_id_data) + queue_id_vec.push_back(obj_id) + return queue_id_vec + +cdef c_vector[c_string] qid_vector_to_list(c_vector[CObjectID] queue_id_vec): + queues = [] + for obj_id in queue_id_vec: + queues.append(obj_id.Binary()) + return queues \ No newline at end of file diff --git a/streaming/python/jobworker.py b/streaming/python/jobworker.py new file mode 100644 index 0000000000000..03bfa1f97dfef --- /dev/null +++ b/streaming/python/jobworker.py @@ -0,0 +1,124 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import pickle +import threading + +import ray +import ray.streaming._streaming as _streaming +from ray.streaming.config import Config +from ray.function_manager import FunctionDescriptor +from ray.streaming.communication import DataInput, DataOutput + +logger = logging.getLogger(__name__) + + +@ray.remote +class JobWorker(object): + """A streaming job worker. + + Attributes: + worker_id: The id of the instance. + input_channels: The input gate that manages input channels of + the instance (see: DataInput in communication.py). + output_channels (DataOutput): The output gate that manages output + channels of the instance (see: DataOutput in communication.py). + the operator instance. + """ + + def __init__(self, worker_id, operator, input_channels, output_channels): + self.env = None + self.worker_id = worker_id + self.operator = operator + processor_name = operator.processor_class.__name__ + processor_instance = operator.processor_class(operator) + self.processor_name = processor_name + self.processor_instance = processor_instance + self.input_channels = input_channels + self.output_channels = output_channels + self.input_gate = None + self.output_gate = None + self.reader_client = None + self.writer_client = None + + def init(self, env): + """init streaming actor""" + env = pickle.loads(env) + self.env = env + logger.info("init operator instance %s", self.processor_name) + + if env.config.channel_type == Config.NATIVE_CHANNEL: + core_worker = ray.worker.global_worker.core_worker + reader_async_func = FunctionDescriptor( + __name__, self.on_reader_message.__name__, + self.__class__.__name__) + reader_sync_func = FunctionDescriptor( + __name__, self.on_reader_message_sync.__name__, + self.__class__.__name__) + self.reader_client = _streaming.ReaderClient( + core_worker, reader_async_func, reader_sync_func) + writer_async_func = FunctionDescriptor( + __name__, self.on_writer_message.__name__, + self.__class__.__name__) + writer_sync_func = FunctionDescriptor( + __name__, self.on_writer_message_sync.__name__, + self.__class__.__name__) + self.writer_client = _streaming.WriterClient( + core_worker, writer_async_func, writer_sync_func) + if len(self.input_channels) > 0: + self.input_gate = DataInput(env, self.input_channels) + self.input_gate.init() + if len(self.output_channels) > 0: + self.output_gate = DataOutput( + env, self.output_channels, + self.operator.partitioning_strategies) + self.output_gate.init() + logger.info("init operator instance %s succeed", self.processor_name) + return True + + # Starts the actor + def start(self): + self.t = threading.Thread(target=self.run, daemon=True) + self.t.start() + actor_id = ray.worker.global_worker.actor_id + logger.info("%s %s started, actor id %s", self.__class__.__name__, + self.processor_name, actor_id) + + def run(self): + logger.info("%s start running", self.processor_name) + self.processor_instance.run(self.input_gate, self.output_gate) + logger.info("%s finished running", self.processor_name) + self.close() + + def close(self): + if self.input_gate: + self.input_gate.close() + if self.output_gate: + self.output_gate.close() + + def is_finished(self): + return not self.t.is_alive() + + def on_reader_message(self, buffer: bytes): + """used in direct call mode""" + self.reader_client.on_reader_message(buffer) + + def on_reader_message_sync(self, buffer: bytes): + """used in direct call mode""" + if self.reader_client is None: + return b" " * 4 # special flag to indicate this actor not ready + result = self.reader_client.on_reader_message_sync(buffer) + return result.to_pybytes() + + def on_writer_message(self, buffer: bytes): + """used in direct call mode""" + self.writer_client.on_writer_message(buffer) + + def on_writer_message_sync(self, buffer: bytes): + """used in direct call mode""" + if self.writer_client is None: + return b" " * 4 # special flag to indicate this actor not ready + result = self.writer_client.on_writer_message_sync(buffer) + return result.to_pybytes() diff --git a/python/ray/experimental/streaming/operator.py b/streaming/python/operator.py similarity index 87% rename from python/ray/experimental/streaming/operator.py rename to streaming/python/operator.py index 9f70a6450c029..cb698ff9a1959 100644 --- a/python/ray/experimental/streaming/operator.py +++ b/streaming/python/operator.py @@ -5,6 +5,8 @@ import enum import logging +import cloudpickle + logger = logging.getLogger(__name__) logger.setLevel("DEBUG") @@ -52,16 +54,18 @@ class OpType(enum.Enum): class Operator(object): def __init__(self, id, - type, + op_type, + processor_class, name="", logic=None, num_instances=1, other=None, state_actor=None): self.id = id - self.type = type + self.type = op_type + self.processor_class = processor_class self.name = name - self.logic = logic # The operator's logic + self._logic = cloudpickle.dumps(logic) # The operator's logic self.num_instances = num_instances # One partitioning strategy per downstream operator (default: forward) self.partitioning_strategies = {} @@ -96,10 +100,14 @@ def _clean(self): self.partitioning_strategies = strategies def print(self): - log = "Operator<\nID = {}\nName = {}\nType = {}\n" + log = "Operator<\nID = {}\nName = {}\nprocessor_class = {}\n" log += "Logic = {}\nNumber_of_Instances = {}\n" log += "Partitioning_Scheme = {}\nOther_Args = {}>\n" logger.debug( - log.format(self.id, self.name, self.type, self.logic, + log.format(self.id, self.name, self.processor_class, self.logic, self.num_instances, self.partitioning_strategies, self.other_args)) + + @property + def logic(self): + return cloudpickle.loads(self._logic) diff --git a/streaming/python/processor.py b/streaming/python/processor.py new file mode 100644 index 0000000000000..4ffcf2cab772f --- /dev/null +++ b/streaming/python/processor.py @@ -0,0 +1,226 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import sys +import time +import types + +logger = logging.getLogger(__name__) +logger.setLevel("INFO") + + +def _identity(element): + return element + + +class ReadTextFile: + """A source operator instance that reads a text file line by line. + + Attributes: + filepath (string): The path to the input file. + """ + + def __init__(self, operator): + self.filepath = operator.other_args + # TODO (john): Handle possible exception here + self.reader = open(self.filepath, "r") + + # Read input file line by line + def run(self, input_gate, output_gate): + while True: + record = self.reader.readline() + # Reader returns empty string ('') on EOF + if not record: + self.reader.close() + return + output_gate.push( + record[:-1]) # Push after removing newline characters + + +class Map: + """A map operator instance that applies a user-defined + stream transformation. + + A map produces exactly one output record for each record in + the input stream. + + """ + + def __init__(self, operator): + self.map_fn = operator.logic + + # Applies the mapper each record of the input stream(s) + # and pushes resulting records to the output stream(s) + def run(self, input_gate, output_gate): + elements = 0 + while True: + record = input_gate.pull() + if record is None: + return + output_gate.push(self.map_fn(record)) + elements += 1 + + +class FlatMap: + """A map operator instance that applies a user-defined + stream transformation. + + A flatmap produces one or more output records for each record in + the input stream. + + Attributes: + flatmap_fn (function): The user-defined function. + """ + + def __init__(self, operator): + self.flatmap_fn = operator.logic + + # Applies the splitter to the records of the input stream(s) + # and pushes resulting records to the output stream(s) + def run(self, input_gate, output_gate): + while True: + record = input_gate.pull() + if record is None: + return + output_gate.push_all(self.flatmap_fn(record)) + + +class Filter: + """A filter operator instance that applies a user-defined filter to + each record of the stream. + + Output records are those that pass the filter, i.e. those for which + the filter function returns True. + + Attributes: + filter_fn (function): The user-defined boolean function. + """ + + def __init__(self, operator): + self.filter_fn = operator.logic + + # Applies the filter to the records of the input stream(s) + # and pushes resulting records to the output stream(s) + def run(self, input_gate, output_gate): + while True: + record = input_gate.pull() + if record is None: + return + if self.filter_fn(record): + output_gate.push(record) + + +class Inspect: + """A inspect operator instance that inspects the content of the stream. + Inspect is useful for printing the records in the stream. + """ + + def __init__(self, operator): + self.inspect_fn = operator.logic + + def run(self, input_gate, output_gate): + # Applies the inspect logic (e.g. print) to the records of + # the input stream(s) + # and leaves stream unaffected by simply pushing the records to + # the output stream(s) + while True: + record = input_gate.pull() + if record is None: + return + if output_gate: + output_gate.push(record) + self.inspect_fn(record) + + +class Reduce: + """A reduce operator instance that combines a new value for a key + with the last reduced one according to a user-defined logic. + """ + + def __init__(self, operator): + self.reduce_fn = operator.logic + # Set the attribute selector + self.attribute_selector = operator.other_args + if self.attribute_selector is None: + self.attribute_selector = _identity + elif isinstance(self.attribute_selector, int): + self.key_index = self.attribute_selector + self.attribute_selector =\ + lambda record: record[self.attribute_selector] + elif isinstance(self.attribute_selector, str): + self.attribute_selector =\ + lambda record: vars(record)[self.attribute_selector] + elif not isinstance(self.attribute_selector, types.FunctionType): + sys.exit("Unrecognized or unsupported key selector.") + self.state = {} # key -> value + + # Combines the input value for a key with the last reduced + # value for that key to produce a new value. + # Outputs the result as (key,new value) + def run(self, input_gate, output_gate): + while True: + record = input_gate.pull() + if record is None: + return + key, rest = record + new_value = self.attribute_selector(rest) + # TODO (john): Is there a way to update state with + # a single dictionary lookup? + try: + old_value = self.state[key] + new_value = self.reduce_fn(old_value, new_value) + self.state[key] = new_value + except KeyError: # Key does not exist in state + self.state.setdefault(key, new_value) + output_gate.push((key, new_value)) + + # Returns the state of the actor + def get_state(self): + return self.state + + +class KeyBy: + """A key_by operator instance that physically partitions the + stream based on a key. + """ + + def __init__(self, operator): + # Set the key selector + self.key_selector = operator.other_args + if isinstance(self.key_selector, int): + self.key_selector = lambda r: r[self.key_selector] + elif isinstance(self.key_selector, str): + self.key_selector = lambda record: vars(record)[self.key_selector] + elif not isinstance(self.key_selector, types.FunctionType): + sys.exit("Unrecognized or unsupported key selector.") + + # The actual partitioning is done by the output gate + def run(self, input_gate, output_gate): + while True: + record = input_gate.pull() + if record is None: + return + key = self.key_selector(record) + output_gate.push((key, record)) + + +# A custom source actor +class Source: + def __init__(self, operator): + # The user-defined source with a get_next() method + self.source = operator.logic + + # Starts the source by calling get_next() repeatedly + def run(self, input_gate, output_gate): + start = time.time() + elements = 0 + while True: + record = self.source.get_next() + if not record: + logger.debug("[writer] puts per second: {}".format( + elements / (time.time() - start))) + return + output_gate.push(record) + elements += 1 diff --git a/streaming/python/runtime/__init__.py b/streaming/python/runtime/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/streaming/python/runtime/transfer.py b/streaming/python/runtime/transfer.py new file mode 100644 index 0000000000000..4bb9b5f1d65fe --- /dev/null +++ b/streaming/python/runtime/transfer.py @@ -0,0 +1,291 @@ +import logging +import random +from queue import Queue +from typing import List + +import ray +import ray.streaming._streaming as _streaming +import ray.streaming.generated.streaming_pb2 as streaming_pb +from ray.actor import ActorHandle, ActorID +from ray.streaming.config import Config + +CHANNEL_ID_LEN = 20 + + +class ChannelID: + """ + ChannelID is used to identify a transfer channel between + a upstream worker and downstream worker. + """ + + def __init__(self, channel_id_str: str): + """ + Args: + channel_id_str: string representation of channel id + """ + self.channel_id_str = channel_id_str + self.object_qid = ray.ObjectID(channel_id_str_to_bytes(channel_id_str)) + + def __eq__(self, other): + if other is None: + return False + if type(other) is ChannelID: + return self.channel_id_str == other.channel_id_str + else: + return False + + def __hash__(self): + return hash(self.channel_id_str) + + def __repr__(self): + return self.channel_id_str + + @staticmethod + def gen_random_id(): + """Generate a random channel id string + """ + res = "" + for i in range(CHANNEL_ID_LEN * 2): + res += str(chr(random.randint(0, 5) + ord("A"))) + return res + + @staticmethod + def gen_id(from_index, to_index, ts): + """Generate channel id, which is 20 character""" + channel_id = bytearray(20) + for i in range(11, 7, -1): + channel_id[i] = ts & 0xff + ts >>= 8 + channel_id[16] = (from_index & 0xffff) >> 8 + channel_id[17] = (from_index & 0xff) + channel_id[18] = (to_index & 0xffff) >> 8 + channel_id[19] = (to_index & 0xff) + return channel_bytes_to_str(bytes(channel_id)) + + +def channel_id_str_to_bytes(channel_id_str): + """ + Args: + channel_id_str: string representation of channel id + + Returns: + bytes representation of channel id + """ + assert type(channel_id_str) in [str, bytes] + if isinstance(channel_id_str, bytes): + return channel_id_str + qid_bytes = bytes.fromhex(channel_id_str) + assert len(qid_bytes) == CHANNEL_ID_LEN + return qid_bytes + + +def channel_bytes_to_str(id_bytes): + """ + Args: + id_bytes: bytes representation of channel id + + Returns: + string representation of channel id + """ + assert type(id_bytes) in [str, bytes] + if isinstance(id_bytes, str): + return id_bytes + return bytes.hex(id_bytes) + + +class DataMessage: + """ + DataMessage represents data between upstream and downstream operator + """ + + def __init__(self, + body, + timestamp, + channel_id, + message_id_, + is_empty_message=False): + self.__body = body + self.__timestamp = timestamp + self.__channel_id = channel_id + self.__message_id = message_id_ + self.__is_empty_message = is_empty_message + + def __len__(self): + return len(self.__body) + + def body(self): + """Message data""" + return self.__body + + def timestamp(self): + """Get timestamp when item is written by upstream DataWriter + """ + return self.__timestamp + + def channel_id(self): + """Get string id of channel where data is coming from + """ + return self.__channel_id + + def is_empty_message(self): + """Whether this message is an empty message. + Upstream DataWriter will send an empty message when this is no data + in specified interval. + """ + return self.__is_empty_message + + @property + def message_id(self): + return self.__message_id + + +logger = logging.getLogger(__name__) + + +class DataWriter: + """Data Writer is a wrapper of streaming c++ DataWriter, which sends data + to downstream workers + """ + + def __init__(self, output_channels, to_actors: List[ActorHandle], + conf: dict): + """Get DataWriter of output channels + Args: + output_channels: output channels ids + to_actors: downstream output actors + Returns: + DataWriter + """ + assert len(output_channels) > 0 + py_output_channels = [ + channel_id_str_to_bytes(qid_str) for qid_str in output_channels + ] + output_actor_ids: List[ActorID] = [ + handle._ray_actor_id for handle in to_actors + ] + channel_size = conf.get(Config.CHANNEL_SIZE, + Config.CHANNEL_SIZE_DEFAULT) + py_msg_ids = [0 for _ in range(len(output_channels))] + config_bytes = _to_native_conf(conf) + is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL + self.writer = _streaming.DataWriter.create( + py_output_channels, output_actor_ids, channel_size, py_msg_ids, + config_bytes, is_mock) + + logger.info("create DataWriter succeed") + + def write(self, channel_id: ChannelID, item: bytes): + """Write data into native channel + Args: + channel_id: channel id + item: bytes data + Returns: + msg_id + """ + assert type(item) == bytes + msg_id = self.writer.write(channel_id.object_qid, item) + return msg_id + + def stop(self): + logger.info("stopping channel writer.") + self.writer.stop() + # destruct DataWriter + self.writer = None + + def close(self): + logger.info("closing channel writer.") + + +class DataReader: + """Data Reader is wrapper of streaming c++ DataReader, which read data + from channels of upstream workers + """ + + def __init__(self, input_channels: List, from_actors: List[ActorHandle], + conf: dict): + """Get DataReader of input channels + Args: + input_channels: input channels + from_actors: upstream input actors + Returns: + DataReader + """ + assert len(input_channels) > 0 + py_input_channels = [ + channel_id_str_to_bytes(qid_str) for qid_str in input_channels + ] + input_actor_ids: List[ActorID] = [ + handle._ray_actor_id for handle in from_actors + ] + py_seq_ids = [0 for _ in range(len(input_channels))] + py_msg_ids = [0 for _ in range(len(input_channels))] + timer_interval = int(conf.get(Config.TIMER_INTERVAL_MS, -1)) + is_recreate = bool(conf.get(Config.IS_RECREATE, False)) + config_bytes = _to_native_conf(conf) + self.__queue = Queue(10000) + is_mock = conf[Config.CHANNEL_TYPE] == Config.MEMORY_CHANNEL + self.reader = _streaming.DataReader.create( + py_input_channels, input_actor_ids, py_seq_ids, py_msg_ids, + timer_interval, is_recreate, config_bytes, is_mock) + logger.info("create DataReader succeed") + + def read(self, timeout_millis): + """Read data from channel + Args: + timeout_millis: timeout millis when there is no data in channel + for this duration + Returns: + channel item + """ + if self.__queue.empty(): + msgs = self.reader.read(timeout_millis) + for msg in msgs: + msg_bytes, msg_id, timestamp, qid_bytes = msg + data_msg = DataMessage(msg_bytes, timestamp, + channel_bytes_to_str(qid_bytes), msg_id) + self.__queue.put(data_msg) + if self.__queue.empty(): + return None + return self.__queue.get() + + def stop(self): + logger.info("stopping Data Reader.") + self.reader.stop() + # destruct DataReader + self.reader = None + + def close(self): + logger.info("closing Data Reader.") + + +def _to_native_conf(conf): + config = streaming_pb.StreamingConfig() + if Config.STREAMING_JOB_NAME in conf: + config.job_name = conf[Config.STREAMING_JOB_NAME] + if Config.TASK_JOB_ID in conf: + job_id = conf[Config.TASK_JOB_ID] + config.task_job_id = job_id.hex() + if Config.STREAMING_WORKER_NAME in conf: + config.worker_name = conf[Config.STREAMING_WORKER_NAME] + if Config.STREAMING_OP_NAME in conf: + config.op_name = conf[Config.STREAMING_OP_NAME] + # TODO set operator type + if Config.STREAMING_RING_BUFFER_CAPACITY in conf: + config.ring_buffer_capacity = \ + conf[Config.STREAMING_RING_BUFFER_CAPACITY] + if Config.STREAMING_EMPTY_MESSAGE_INTERVAL in conf: + config.empty_message_interval = \ + conf[Config.STREAMING_EMPTY_MESSAGE_INTERVAL] + logger.info("conf: %s", str(config)) + return config.SerializeToString() + + +class ChannelInitException(Exception): + def __init__(self, msg, abnormal_channels): + self.abnormal_channels = abnormal_channels + self.msg = msg + + +class ChannelInterruptException(Exception): + def __init__(self, msg=None): + self.msg = msg diff --git a/python/ray/experimental/streaming/streaming.py b/streaming/python/streaming.py similarity index 70% rename from python/ray/experimental/streaming/streaming.py rename to streaming/python/streaming.py index f9e4241fe9a1a..6f93322bf69d3 100644 --- a/python/ray/experimental/streaming/streaming.py +++ b/streaming/python/streaming.py @@ -3,26 +3,24 @@ from __future__ import print_function import logging +import pickle import sys -import uuid +import time import networkx as nx - -from ray.experimental.streaming.communication import DataChannel, DataInput -from ray.experimental.streaming.communication import DataOutput, QueueConfig -from ray.experimental.streaming.operator import Operator, OpType -from ray.experimental.streaming.operator import PScheme, PStrategy -import ray.experimental.streaming.operator_instance as operator_instance +import ray +import ray.streaming.processor as processor +import ray.streaming.runtime.transfer as transfer +from ray.streaming.communication import DataChannel +from ray.streaming.config import Config +from ray.streaming.jobworker import JobWorker +from ray.streaming.operator import Operator, OpType +from ray.streaming.operator import PScheme, PStrategy logger = logging.getLogger(__name__) logger.setLevel("INFO") -# Generates UUIDs -def _generate_uuid(): - return uuid.uuid4() - - # Rolling sum's logic def _sum(value_1, value_2): return value_1 + value_2 @@ -36,137 +34,55 @@ def _sum(value_1, value_2): # Environment configuration -class Config(object): +class Conf(object): """Environment configuration. This class includes all information about the configuration of the streaming environment. - - Attributes: - queue_config (QueueConfig): Batched Queue configuration - (see: communication.py) - A batched queue configuration includes the max queue size, - the size of each batch (in number of elements), the batch flush - timeout, and the number of batches to prefetch from plasma - parallelism (int): The number of isntances (actors) for each logical - dataflow operator (default: 1) """ - def __init__(self, parallelism=1): - self.queue_config = QueueConfig() + def __init__(self, parallelism=1, channel_type=Config.MEMORY_CHANNEL): self.parallelism = parallelism + self.channel_type = channel_type # ... -# The execution environment for a streaming job -class Environment(object): - """A streaming environment. - - This class is responsible for constructing the logical and the - physical dataflow. - - Attributes: - logical_topo (DiGraph): The user-defined logical topology in - NetworkX DiGRaph format. - (See: https://networkx.github.io) - physical_topo (DiGraph): The physical topology in NetworkX - DiGRaph format. The physical dataflow is constructed by the - environment based on logical_topo. - operators (dict): A mapping from operator ids to operator metadata - (See: Operator in operator.py). - config (Config): The environment's configuration. - topo_cleaned (bool): A flag that indicates whether the logical - topology is garbage collected (True) or not (False). - actor_handles (list): A list of all Ray actor handles that execute - the streaming dataflow. - """ - - def __init__(self, config=Config()): - self.logical_topo = nx.DiGraph() # DAG +class ExecutionGraph: + def __init__(self, env): + self.env = env self.physical_topo = nx.DiGraph() # DAG - self.operators = {} # operator id --> operator object - self.config = config # Environment's configuration - self.topo_cleaned = False # Handles to all actors in the physical dataflow self.actor_handles = [] + # (op_id, op_instance_index) -> ActorID + self.actors_map = {} + # execution graph build time: milliseconds since epoch + self.build_time = 0 + self.task_id_counter = 0 + self.task_ids = {} + self.input_channels = {} # operator id -> input channels + self.output_channels = {} # operator id -> output channels # Constructs and deploys a Ray actor of a specific type # TODO (john): Actor placement information should be specified in # the environment's configuration - def __generate_actor(self, instance_id, operator, input, output): + def __generate_actor(self, instance_index, operator, input_channels, + output_channels): """Generates an actor that will execute a particular instance of the logical operator Attributes: - instance_id (UUID): The id of the instance the actor will execute. - operator (Operator): The metadata of the logical operator. - input (DataInput): The input gate that manages input channels of - the instance (see: DataInput in communication.py). - input (DataOutput): The output gate that manages output channels - of the instance (see: DataOutput in communication.py). + instance_index: The index of the instance the actor will execute. + operator: The metadata of the logical operator. + input_channels: The input channels of the instance. + output_channels The output channels of the instance. """ - actor_id = (operator.id, instance_id) + worker_id = (operator.id, instance_index) # Record the physical dataflow graph (for debugging purposes) - self.__add_channel(actor_id, input, output) - # Select actor to construct - if operator.type == OpType.Source: - source = operator_instance.Source.remote(actor_id, operator, input, - output) - source.register_handle.remote(source) - return source.start.remote() - elif operator.type == OpType.Map: - map = operator_instance.Map.remote(actor_id, operator, input, - output) - map.register_handle.remote(map) - return map.start.remote() - elif operator.type == OpType.FlatMap: - flatmap = operator_instance.FlatMap.remote(actor_id, operator, - input, output) - flatmap.register_handle.remote(flatmap) - return flatmap.start.remote() - elif operator.type == OpType.Filter: - filter = operator_instance.Filter.remote(actor_id, operator, input, - output) - filter.register_handle.remote(filter) - return filter.start.remote() - elif operator.type == OpType.Reduce: - reduce = operator_instance.Reduce.remote(actor_id, operator, input, - output) - reduce.register_handle.remote(reduce) - return reduce.start.remote() - elif operator.type == OpType.TimeWindow: - pass - elif operator.type == OpType.KeyBy: - keyby = operator_instance.KeyBy.remote(actor_id, operator, input, - output) - keyby.register_handle.remote(keyby) - return keyby.start.remote() - elif operator.type == OpType.Sum: - sum = operator_instance.Reduce.remote(actor_id, operator, input, - output) - # Register target handle at state actor - state_actor = operator.state_actor - if state_actor is not None: - state_actor.register_target.remote(sum) - # Register own handle - sum.register_handle.remote(sum) - return sum.start.remote() - elif operator.type == OpType.Sink: - pass - elif operator.type == OpType.Inspect: - inspect = operator_instance.Inspect.remote(actor_id, operator, - input, output) - inspect.register_handle.remote(inspect) - return inspect.start.remote() - elif operator.type == OpType.ReadTextFile: - # TODO (john): Colocate the source with the input file - read = operator_instance.ReadTextFile.remote( - actor_id, operator, input, output) - read.register_handle.remote(read) - return read.start.remote() - else: # TODO (john): Add support for other types of operators - sys.exit("Unrecognized or unsupported {} operator type.".format( - operator.type)) + self.__add_channel(worker_id, output_channels) + # Note direct_call only support pass by value + return JobWorker._remote( + args=[worker_id, operator, input_channels, output_channels], + is_direct_call=True) # Constructs and deploys a Ray actor for each instance of # the given operator @@ -185,33 +101,24 @@ def __generate_actors(self, operator, upstream_channels, num_instances = operator.num_instances logger.info("Generating {} actors of type {}...".format( num_instances, operator.type)) - in_channels = upstream_channels.pop( - operator.id) if upstream_channels else [] handles = [] for i in range(num_instances): # Collect input and output channels for the particular instance - ip = [ - channel for channel in in_channels - if channel.dst_instance_id == i - ] if in_channels else [] - op = [ - channel for channels_list in downstream_channels.values() - for channel in channels_list if channel.src_instance_id == i - ] + ip = [c for c in upstream_channels if c.dst_instance_index == i] + op = [c for c in downstream_channels if c.src_instance_index == i] log = "Constructed {} input and {} output channels " log += "for the {}-th instance of the {} operator." logger.debug(log.format(len(ip), len(op), i, operator.type)) - input_gate = DataInput(ip) - output_gate = DataOutput(op, operator.partitioning_strategies) - handle = self.__generate_actor(i, operator, input_gate, - output_gate) + handle = self.__generate_actor(i, operator, ip, op) if handle: handles.append(handle) + self.actors_map[(operator.id, i)] = handle return handles # Adds a channel/edge to the physical dataflow graph - def __add_channel(self, actor_id, input, output): - for dest_actor_id in output._destination_actor_ids(): + def __add_channel(self, actor_id, output_channels): + for c in output_channels: + dest_actor_id = (c.dst_operator_id, c.dst_instance_index) self.physical_topo.add_edge(actor_id, dest_actor_id) # Generates all required data channels between an operator @@ -232,26 +139,155 @@ def _generate_channels(self, operator): channels = {} # destination operator id -> channels strategies = operator.partitioning_strategies for dst_operator, p_scheme in strategies.items(): - num_dest_instances = self.operators[dst_operator].num_instances + num_dest_instances = self.env.operators[dst_operator].num_instances entry = channels.setdefault(dst_operator, []) if p_scheme.strategy == PStrategy.Forward: for i in range(operator.num_instances): # ID of destination instance to connect id = i % num_dest_instances - channel = DataChannel(self, operator.id, dst_operator, i, - id) - entry.append(channel) + qid = self._gen_str_qid(operator.id, i, dst_operator, id) + c = DataChannel(operator.id, i, dst_operator, id, qid) + entry.append(c) elif p_scheme.strategy in all_to_all_strategies: for i in range(operator.num_instances): for j in range(num_dest_instances): - channel = DataChannel(self, operator.id, dst_operator, - i, j) - entry.append(channel) + qid = self._gen_str_qid(operator.id, i, dst_operator, + j) + c = DataChannel(operator.id, i, dst_operator, j, qid) + entry.append(c) else: # TODO (john): Add support for other partitioning strategies sys.exit("Unrecognized or unsupported partitioning strategy.") return channels + def _gen_str_qid(self, src_operator_id, src_instance_index, + dst_operator_id, dst_instance_index): + from_task_id = self.env.execution_graph.get_task_id( + src_operator_id, src_instance_index) + to_task_id = self.env.execution_graph.get_task_id( + dst_operator_id, dst_instance_index) + return transfer.ChannelID.gen_id(from_task_id, to_task_id, + self.build_time) + + def _gen_task_id(self): + task_id = self.task_id_counter + self.task_id_counter += 1 + return task_id + + def get_task_id(self, op_id, op_instance_id): + return self.task_ids[(op_id, op_instance_id)] + + def get_actor(self, op_id, op_instance_id): + return self.actors_map[(op_id, op_instance_id)] + + # Prints the physical dataflow graph + def print_physical_graph(self): + logger.info("===================================") + logger.info("======Physical Dataflow Graph======") + logger.info("===================================") + # Print all data channels between operator instances + log = "(Source Operator ID,Source Operator Name,Source Instance ID)" + log += " --> " + log += "(Destination Operator ID,Destination Operator Name," + log += "Destination Instance ID)" + logger.info(log) + for src_actor_id, dst_actor_id in self.physical_topo.edges: + src_operator_id, src_instance_index = src_actor_id + dst_operator_id, dst_instance_index = dst_actor_id + logger.info("({},{},{}) --> ({},{},{})".format( + src_operator_id, self.env.operators[src_operator_id].name, + src_instance_index, dst_operator_id, + self.env.operators[dst_operator_id].name, dst_instance_index)) + + def build_graph(self): + self.build_channels() + + # to support cyclic reference serialization + try: + ray.register_custom_serializer(Environment, use_pickle=True) + ray.register_custom_serializer(ExecutionGraph, use_pickle=True) + ray.register_custom_serializer(OpType, use_pickle=True) + ray.register_custom_serializer(PStrategy, use_pickle=True) + except Exception: + # local mode can't use pickle + pass + + # Each operator instance is implemented as a Ray actor + # Actors are deployed in topological order, as we traverse the + # logical dataflow from sources to sinks. + for node in nx.topological_sort(self.env.logical_topo): + operator = self.env.operators[node] + # Instantiate Ray actors + handles = self.__generate_actors( + operator, self.input_channels.get(node, []), + self.output_channels.get(node, [])) + if handles: + self.actor_handles.extend(handles) + + def build_channels(self): + self.build_time = int(time.time() * 1000) + # gen auto-incremented unique task id for every operator instance + for node in nx.topological_sort(self.env.logical_topo): + operator = self.env.operators[node] + for i in range(operator.num_instances): + operator_instance_id = (operator.id, i) + self.task_ids[operator_instance_id] = self._gen_task_id() + channels = {} + for node in nx.topological_sort(self.env.logical_topo): + operator = self.env.operators[node] + # Generate downstream data channels + downstream_channels = self._generate_channels(operator) + channels[node] = downstream_channels + # op_id -> channels + input_channels = {} + output_channels = {} + for op_id, all_downstream_channels in channels.items(): + for dst_op_channels in all_downstream_channels.values(): + for c in dst_op_channels: + dst = input_channels.setdefault(c.dst_operator_id, []) + dst.append(c) + src = output_channels.setdefault(c.src_operator_id, []) + src.append(c) + self.input_channels = input_channels + self.output_channels = output_channels + + +# The execution environment for a streaming job +class Environment(object): + """A streaming environment. + + This class is responsible for constructing the logical and the + physical dataflow. + + Attributes: + logical_topo (DiGraph): The user-defined logical topology in + NetworkX DiGRaph format. + (See: https://networkx.github.io) + physical_topo (DiGraph): The physical topology in NetworkX + DiGRaph format. The physical dataflow is constructed by the + environment based on logical_topo. + operators (dict): A mapping from operator ids to operator metadata + (See: Operator in operator.py). + config (Config): The environment's configuration. + topo_cleaned (bool): A flag that indicates whether the logical + topology is garbage collected (True) or not (False). + actor_handles (list): A list of all Ray actor handles that execute + the streaming dataflow. + """ + + def __init__(self, config=Conf()): + self.logical_topo = nx.DiGraph() # DAG + self.operators = {} # operator id --> operator object + self.config = config # Environment's configuration + self.topo_cleaned = False + self.operator_id_counter = 0 + self.execution_graph = None # set when executed + + def gen_operator_id(self): + op_id = self.operator_id_counter + self.operator_id_counter += 1 + return op_id + # An edge denotes a flow of data between logical operators # and may correspond to multiple data channels in the physical dataflow def _add_edge(self, source, destination): @@ -275,19 +311,15 @@ def _set_parallelism(self, operator_id, level_of_parallelism): def set_parallelism(self, parallelism): self.config.parallelism = parallelism - # Sets batched queue configuration for the environment - def set_queue_config(self, queue_config): - self.config.queue_config = queue_config - # Creates and registers a user-defined data source # TODO (john): There should be different types of sources, e.g. sources # reading from Kafka, text files, etc. # TODO (john): Handle case where environment parallelism is set def source(self, source): - source_id = _generate_uuid() + source_id = self.gen_operator_id() source_stream = DataStream(self, source_id) self.operators[source_id] = Operator( - source_id, OpType.Source, "Source", other=source) + source_id, OpType.Source, processor.Source, "Source", logic=source) return source_stream # Creates and registers a new data source that reads a @@ -296,10 +328,14 @@ def source(self, source): # e.g. sources reading from Kafka, text files, etc. # TODO (john): Handle case where environment parallelism is set def read_text_file(self, filepath): - source_id = _generate_uuid() + source_id = self.gen_operator_id() source_stream = DataStream(self, source_id) self.operators[source_id] = Operator( - source_id, OpType.ReadTextFile, "Read Text File", other=filepath) + source_id, + OpType.ReadTextFile, + processor.ReadTextFile, + "Read Text File", + other=filepath) return source_stream # Constructs and deploys the physical dataflow @@ -312,24 +348,27 @@ def execute(self): # upstream instances, some of the downstream instances will not be # used at all - # Each operator instance is implemented as a Ray actor - # Actors are deployed in topological order, as we traverse the - # logical dataflow from sources to sinks. At each step, data - # producers wait for acknowledge from consumers before starting - # generating data. - upstream_channels = {} - for node in nx.topological_sort(self.logical_topo): - operator = self.operators[node] - # Generate downstream data channels - downstream_channels = self._generate_channels(operator) - # Instantiate Ray actors - handles = self.__generate_actors(operator, upstream_channels, - downstream_channels) - if handles: - self.actor_handles.extend(handles) - upstream_channels.update(downstream_channels) - logger.debug("Running...") - return self.actor_handles + self.execution_graph = ExecutionGraph(self) + self.execution_graph.build_graph() + logger.info("init...") + # init + init_waits = [] + for actor_handle in self.execution_graph.actor_handles: + init_waits.append(actor_handle.init.remote(pickle.dumps(self))) + for wait in init_waits: + assert ray.get(wait) is True + logger.info("running...") + # start + exec_handles = [] + for actor_handle in self.execution_graph.actor_handles: + exec_handles.append(actor_handle.start.remote()) + + return exec_handles + + def wait_finish(self): + for actor_handle in self.execution_graph.actor_handles: + while not ray.get(actor_handle.is_finished.remote()): + time.sleep(1) # Prints the logical dataflow graph def print_logical_graph(self): @@ -349,25 +388,6 @@ def print_logical_graph(self): for downstream_node in downstream_neighbors: self.operators[downstream_node].print() - # Prints the physical dataflow graph - def print_physical_graph(self): - logger.info("===================================") - logger.info("======Physical Dataflow Graph======") - logger.info("===================================") - # Print all data channels between operator instances - log = "(Source Operator ID,Source Operator Name,Source Instance ID)" - log += " --> " - log += "(Destination Operator ID,Destination Operator Name," - log += "Destination Instance ID)" - logger.info(log) - for src_actor_id, dst_actor_id in self.physical_topo.edges: - src_operator_id, src_instance_id = src_actor_id - dst_operator_id, dst_instance_id = dst_actor_id - logger.info("({},{},{}) --> ({},{},{})".format( - src_operator_id, self.operators[src_operator_id].name, - src_instance_id, dst_operator_id, - self.operators[dst_operator_id].name, dst_instance_id)) - # TODO (john): We also need KeyedDataStream and WindowedDataStream as # subclasses of DataStream to prevent ill-defined logical dataflows @@ -389,14 +409,16 @@ class DataStream(object): is_partitioned (bool): Denotes if there is a partitioning strategy (e.g. shuffle) for the stream or not (default stategy: Forward). """ + stream_id_counter = 0 def __init__(self, environment, source_id=None, dest_id=None, is_partitioned=False): - self.id = _generate_uuid() self.env = environment + self.id = DataStream.stream_id_counter + DataStream.stream_id_counter += 1 self.src_operator_id = source_id self.dst_operator_id = dest_id # True if a partitioning strategy for this stream exists, @@ -448,17 +470,17 @@ def __register(self, operator): src_operator = self.env.operators[self.src_operator_id] if self.is_partitioned is True: partitioning, _ = src_operator._get_partition_strategy(self.id) - src_operator._set_partition_strategy(_generate_uuid(), - partitioning, operator.id) + src_operator._set_partition_strategy(self.id, partitioning, + operator.id) elif src_operator.type == OpType.KeyBy: # Set the output partitioning strategy to shuffle by key partitioning = PScheme(PStrategy.ShuffleByKey) - src_operator._set_partition_strategy(_generate_uuid(), - partitioning, operator.id) + src_operator._set_partition_strategy(self.id, partitioning, + operator.id) else: # No partitioning strategy has been defined - set default partitioning = PScheme(PStrategy.Forward) - src_operator._set_partition_strategy(_generate_uuid(), - partitioning, operator.id) + src_operator._set_partition_strategy(self.id, partitioning, + operator.id) return self.__expand() # Sets the level of parallelism for an operator, i.e. its total @@ -525,8 +547,9 @@ def map(self, map_fn, name="Map"): map_fn (function): The user-defined logic of the map. """ op = Operator( - _generate_uuid(), + self.env.gen_operator_id(), OpType.Map, + processor.Map, name, map_fn, num_instances=self.env.config.parallelism) @@ -541,8 +564,9 @@ def flat_map(self, flatmap_fn): (e.g. split()). """ op = Operator( - _generate_uuid(), + self.env.gen_operator_id(), OpType.FlatMap, + processor.FlatMap, "FlatMap", flatmap_fn, num_instances=self.env.config.parallelism) @@ -558,8 +582,9 @@ def key_by(self, key_selector): (assuming tuple records). """ op = Operator( - _generate_uuid(), + self.env.gen_operator_id(), OpType.KeyBy, + processor.KeyBy, "KeyBy", other=key_selector, num_instances=self.env.config.parallelism) @@ -574,8 +599,9 @@ def reduce(self, reduce_fn): (assuming tuple records). """ op = Operator( - _generate_uuid(), + self.env.gen_operator_id(), OpType.Reduce, + processor.Reduce, "Sum", reduce_fn, num_instances=self.env.config.parallelism) @@ -590,8 +616,9 @@ def sum(self, attribute_selector, state_keeper=None): (assuming tuple records). """ op = Operator( - _generate_uuid(), + self.env.gen_operator_id(), OpType.Sum, + processor.Reduce, "Sum", _sum, other=attribute_selector, @@ -608,13 +635,7 @@ def time_window(self, window_width_ms): Attributes: window_width_ms (int): The length of the window in ms. """ - op = Operator( - _generate_uuid(), - OpType.TimeWindow, - "TimeWindow", - num_instances=self.env.config.parallelism, - other=window_width_ms) - return self.__register(op) + raise Exception("time_window is unsupported") # Registers filter operator to the environment def filter(self, filter_fn): @@ -624,8 +645,9 @@ def filter(self, filter_fn): filter_fn (function): The user-defined filter function. """ op = Operator( - _generate_uuid(), + self.env.gen_operator_id(), OpType.Filter, + processor.Filter, "Filter", filter_fn, num_instances=self.env.config.parallelism) @@ -634,8 +656,9 @@ def filter(self, filter_fn): # TODO (john): Registers window join operator to the environment def window_join(self, other_stream, join_attribute, window_width): op = Operator( - _generate_uuid(), + self.env.gen_operator_id(), OpType.WindowJoin, + processor.WindowJoin, "WindowJoin", num_instances=self.env.config.parallelism) return self.__register(op) @@ -648,8 +671,9 @@ def inspect(self, inspect_logic): inspect_logic (function): The user-defined inspect function. """ op = Operator( - _generate_uuid(), + self.env.gen_operator_id(), OpType.Inspect, + processor.Inspect, "Inspect", inspect_logic, num_instances=self.env.config.parallelism) @@ -661,8 +685,9 @@ def inspect(self, inspect_logic): def sink(self): """Closes the stream with a sink operator.""" op = Operator( - _generate_uuid(), + self.env.gen_operator_id(), OpType.Sink, + processor.Sink, "Sink", num_instances=self.env.config.parallelism) return self.__register(op) diff --git a/streaming/python/tests/__init__.py b/streaming/python/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/streaming/python/tests/test_direct_transfer.py b/streaming/python/tests/test_direct_transfer.py new file mode 100644 index 0000000000000..42321769f2aad --- /dev/null +++ b/streaming/python/tests/test_direct_transfer.py @@ -0,0 +1,127 @@ +import pickle +import threading +import time + +import ray +import ray.streaming._streaming as _streaming +import ray.streaming.runtime.transfer as transfer +from ray.function_manager import FunctionDescriptor +from ray.streaming.config import Config + + +@ray.remote +class Worker: + def __init__(self): + core_worker = ray.worker.global_worker.core_worker + writer_async_func = FunctionDescriptor( + __name__, self.on_writer_message.__name__, self.__class__.__name__) + writer_sync_func = FunctionDescriptor( + __name__, self.on_writer_message_sync.__name__, + self.__class__.__name__) + self.writer_client = _streaming.WriterClient( + core_worker, writer_async_func, writer_sync_func) + reader_async_func = FunctionDescriptor( + __name__, self.on_reader_message.__name__, self.__class__.__name__) + reader_sync_func = FunctionDescriptor( + __name__, self.on_reader_message_sync.__name__, + self.__class__.__name__) + self.reader_client = _streaming.ReaderClient( + core_worker, reader_async_func, reader_sync_func) + self.writer = None + self.output_channel_id = None + self.reader = None + + def init_writer(self, output_channel, reader_actor): + conf = { + Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context() + .current_driver_id, + Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL + } + self.writer = transfer.DataWriter([output_channel], + [pickle.loads(reader_actor)], conf) + self.output_channel_id = transfer.ChannelID(output_channel) + + def init_reader(self, input_channel, writer_actor): + conf = { + Config.TASK_JOB_ID: ray.runtime_context._get_runtime_context() + .current_driver_id, + Config.CHANNEL_TYPE: Config.NATIVE_CHANNEL + } + self.reader = transfer.DataReader([input_channel], + [pickle.loads(writer_actor)], conf) + + def start_write(self, msg_nums): + self.t = threading.Thread( + target=self.run_writer, args=[msg_nums], daemon=True) + self.t.start() + + def run_writer(self, msg_nums): + for i in range(msg_nums): + self.writer.write(self.output_channel_id, pickle.dumps(i)) + print("WriterWorker done.") + + def start_read(self, msg_nums): + self.t = threading.Thread( + target=self.run_reader, args=[msg_nums], daemon=True) + self.t.start() + + def run_reader(self, msg_nums): + count = 0 + msg = None + while count != msg_nums: + item = self.reader.read(100) + if item is None: + time.sleep(0.01) + else: + msg = pickle.loads(item.body()) + count += 1 + assert msg == msg_nums - 1 + print("ReaderWorker done.") + + def is_finished(self): + return not self.t.is_alive() + + def on_reader_message(self, buffer: bytes): + """used in direct call mode""" + self.reader_client.on_reader_message(buffer) + + def on_reader_message_sync(self, buffer: bytes): + """used in direct call mode""" + if self.reader_client is None: + return b" " * 4 # special flag to indicate this actor not ready + result = self.reader_client.on_reader_message_sync(buffer) + return result.to_pybytes() + + def on_writer_message(self, buffer: bytes): + """used in direct call mode""" + self.writer_client.on_writer_message(buffer) + + def on_writer_message_sync(self, buffer: bytes): + """used in direct call mode""" + if self.writer_client is None: + return b" " * 4 # special flag to indicate this actor not ready + result = self.writer_client.on_writer_message_sync(buffer) + return result.to_pybytes() + + +def test_queue(): + ray.init() + writer = Worker._remote(is_direct_call=True) + reader = Worker._remote(is_direct_call=True) + channel_id_str = transfer.ChannelID.gen_random_id() + inits = [ + writer.init_writer.remote(channel_id_str, pickle.dumps(reader)), + reader.init_reader.remote(channel_id_str, pickle.dumps(writer)) + ] + ray.get(inits) + msg_nums = 1000 + print("start read/write") + reader.start_read.remote(msg_nums) + writer.start_write.remote(msg_nums) + while not ray.get(reader.is_finished.remote()): + time.sleep(0.1) + ray.shutdown() + + +if __name__ == "__main__": + test_queue() diff --git a/python/ray/tests/test_logical_graph.py b/streaming/python/tests/test_logical_graph.py similarity index 93% rename from python/ray/tests/test_logical_graph.py rename to streaming/python/tests/test_logical_graph.py index 1cfe3f2323674..863cfc6daf3d6 100644 --- a/python/ray/tests/test_logical_graph.py +++ b/streaming/python/tests/test_logical_graph.py @@ -2,8 +2,8 @@ from __future__ import division from __future__ import print_function -from ray.experimental.streaming.streaming import Environment -from ray.experimental.streaming.operator import OpType, PStrategy +from ray.streaming.streaming import Environment, ExecutionGraph +from ray.streaming.operator import OpType, PStrategy def test_parallelism(): @@ -169,18 +169,20 @@ def _test_channels(environment, expected_channels): if operator.type == OpType.Map: map_id = id # Collect channels + environment.execution_graph = ExecutionGraph(environment) + environment.execution_graph.build_channels() channels_per_destination = [] for operator in environment.operators.values(): channels_per_destination.append( - environment._generate_channels(operator)) + environment.execution_graph._generate_channels(operator)) # Check actual connectivity actual = [] for destination in channels_per_destination: for channels in destination.values(): for channel in channels: - src_instance_id = channel.src_instance_id - dst_instance_id = channel.dst_instance_id - connection = (src_instance_id, dst_instance_id) + src_instance_index = channel.src_instance_index + dst_instance_index = channel.dst_instance_index + connection = (src_instance_index, dst_instance_index) assert channel.dst_operator_id == map_id, ( channel.dst_operator_id, map_id) actual.append(connection) @@ -205,6 +207,4 @@ def test_wordcount(): if __name__ == "__main__": - import pytest - import sys - sys.exit(pytest.main(["-v", __file__])) + test_channel_generation() diff --git a/streaming/python/tests/test_word_count.py b/streaming/python/tests/test_word_count.py new file mode 100644 index 0000000000000..9fc2f2e11b6fd --- /dev/null +++ b/streaming/python/tests/test_word_count.py @@ -0,0 +1,20 @@ +import ray +from ray.streaming.config import Config +from ray.streaming.streaming import Environment, Conf + + +def test_word_count(): + ray.init() + env = Environment(config=Conf(channel_type=Config.NATIVE_CHANNEL)) + env.read_text_file(__file__) \ + .set_parallelism(1) \ + .filter(lambda x: "word" in x) \ + .inspect(lambda x: print("result", x)) + env_handle = env.execute() + ray.get(env_handle) # Stay alive until execution finishes + env.wait_finish() + ray.shutdown() + + +if __name__ == "__main__": + test_word_count() diff --git a/streaming/src/channel.cc b/streaming/src/channel.cc new file mode 100644 index 0000000000000..de7c99f8e0da8 --- /dev/null +++ b/streaming/src/channel.cc @@ -0,0 +1,274 @@ +#include "channel.h" +#include +namespace ray { +namespace streaming { + +ProducerChannel::ProducerChannel(std::shared_ptr &transfer_config, + ProducerChannelInfo &p_channel_info) + : transfer_config_(transfer_config), channel_info(p_channel_info) {} + +ConsumerChannel::ConsumerChannel(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info) + : transfer_config_(transfer_config), channel_info(c_channel_info) {} + +StreamingQueueProducer::StreamingQueueProducer(std::shared_ptr &transfer_config, + ProducerChannelInfo &p_channel_info) + : ProducerChannel(transfer_config, p_channel_info) { + STREAMING_LOG(INFO) << "Producer Init"; +} + +StreamingQueueProducer::~StreamingQueueProducer() { + STREAMING_LOG(INFO) << "Producer Destory"; +} + +StreamingStatus StreamingQueueProducer::CreateTransferChannel() { + CreateQueue(); + + uint64_t queue_last_seq_id = 0; + uint64_t last_message_id_in_queue = 0; + + if (!last_message_id_in_queue) { + if (last_message_id_in_queue < channel_info.current_message_id) { + STREAMING_LOG(WARNING) << "last message id in queue : " << last_message_id_in_queue + << " is less than message checkpoint loaded id : " + << channel_info.current_message_id + << ", an old queue object " << channel_info.channel_id + << " was fond in store"; + } + last_message_id_in_queue = channel_info.current_message_id; + } + if (queue_last_seq_id == static_cast(-1)) { + queue_last_seq_id = 0; + } + channel_info.current_seq_id = queue_last_seq_id; + + STREAMING_LOG(WARNING) << "existing last message id => " << last_message_id_in_queue + << ", message id in channel => " + << channel_info.current_message_id << ", queue last seq id => " + << queue_last_seq_id; + + channel_info.message_last_commit_id = last_message_id_in_queue; + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::CreateQueue() { + STREAMING_LOG(INFO) << "CreateQueue qid: " << channel_info.channel_id + << " data_size: " << channel_info.queue_size; + auto upstream_handler = ray::streaming::UpstreamQueueMessageHandler::GetService(); + if (upstream_handler->UpstreamQueueExists(channel_info.channel_id)) { + RAY_LOG(INFO) << "StreamingQueueWriter::CreateQueue duplicate!!!"; + return StreamingStatus::OK; + } + + upstream_handler->SetPeerActorID(channel_info.channel_id, channel_info.actor_id); + queue_ = upstream_handler->CreateUpstreamQueue( + channel_info.channel_id, channel_info.actor_id, channel_info.queue_size); + STREAMING_CHECK(queue_ != nullptr); + + std::vector queue_ids, failed_queues; + queue_ids.push_back(channel_info.channel_id); + upstream_handler->WaitQueues(queue_ids, 10 * 1000, failed_queues); + + STREAMING_LOG(INFO) << "q id => " << channel_info.channel_id << ", queue size => " + << channel_info.queue_size; + + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::DestroyTransferChannel() { + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::ClearTransferCheckpoint( + uint64_t checkpoint_id, uint64_t checkpoint_offset) { + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::NotifyChannelConsumed(uint64_t channel_offset) { + queue_->SetQueueEvictionLimit(channel_offset); + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueProducer::ProduceItemToChannel(uint8_t *data, + uint32_t data_size) { + Status status = + PushQueueItem(channel_info.current_seq_id + 1, data, data_size, current_time_ms()); + + if (status.code() != StatusCode::OK) { + STREAMING_LOG(DEBUG) << channel_info.channel_id << " => Queue is full" + << " meesage => " << status.message(); + + // Assume that only status OutOfMemory and OK are acceptable. + // OutOfMemory means queue is full at that moment. + STREAMING_CHECK(status.code() == StatusCode::OutOfMemory) + << "status => " << status.message() + << ", perhaps data block is so large that it can't be stored in" + << ", data block size => " << data_size; + + return StreamingStatus::FullChannel; + } + return StreamingStatus::OK; +} + +Status StreamingQueueProducer::PushQueueItem(uint64_t seq_id, uint8_t *data, + uint32_t data_size, uint64_t timestamp) { + STREAMING_LOG(INFO) << "StreamingQueueProducer::PushQueueItem:" + << " qid: " << channel_info.channel_id << " seq_id: " << seq_id + << " data_size: " << data_size; + Status status = queue_->Push(seq_id, data, data_size, timestamp, false); + if (status.IsOutOfMemory()) { + status = queue_->TryEvictItems(); + if (!status.ok()) { + STREAMING_LOG(INFO) << "Evict fail."; + return status; + } + + status = queue_->Push(seq_id, data, data_size, timestamp, false); + } + + queue_->Send(); + return status; +} + +StreamingQueueConsumer::StreamingQueueConsumer(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info) + : ConsumerChannel(transfer_config, c_channel_info) { + STREAMING_LOG(INFO) << "Consumer Init"; +} + +StreamingQueueConsumer::~StreamingQueueConsumer() { + STREAMING_LOG(INFO) << "Consumer Destroy"; +} + +StreamingStatus StreamingQueueConsumer::CreateTransferChannel() { + auto downstream_handler = ray::streaming::DownstreamQueueMessageHandler::GetService(); + STREAMING_LOG(INFO) << "GetQueue qid: " << channel_info.channel_id + << " start_seq_id: " << channel_info.current_seq_id + 1; + if (downstream_handler->DownstreamQueueExists(channel_info.channel_id)) { + RAY_LOG(INFO) << "StreamingQueueReader::GetQueue duplicate!!!"; + return StreamingStatus::OK; + } + + downstream_handler->SetPeerActorID(channel_info.channel_id, channel_info.actor_id); + STREAMING_LOG(INFO) << "Create ReaderQueue " << channel_info.channel_id + << " pull from start_seq_id: " << channel_info.current_seq_id + 1; + queue_ = downstream_handler->CreateDownstreamQueue(channel_info.channel_id, + channel_info.actor_id); + + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueConsumer::DestroyTransferChannel() { + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueConsumer::ClearTransferCheckpoint( + uint64_t checkpoint_id, uint64_t checkpoint_offset) { + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueConsumer::ConsumeItemFromChannel(uint64_t &offset_id, + uint8_t *&data, + uint32_t &data_size, + uint32_t timeout) { + STREAMING_LOG(INFO) << "GetQueueItem qid: " << channel_info.channel_id; + STREAMING_CHECK(queue_ != nullptr); + QueueItem item = queue_->PopPendingBlockTimeout(timeout * 1000); + if (item.SeqId() == QUEUE_INVALID_SEQ_ID) { + STREAMING_LOG(INFO) << "GetQueueItem timeout."; + data = nullptr; + data_size = 0; + offset_id = QUEUE_INVALID_SEQ_ID; + return StreamingStatus::OK; + } + + data = item.Buffer()->Data(); + offset_id = item.SeqId(); + data_size = item.Buffer()->Size(); + + STREAMING_LOG(DEBUG) << "GetQueueItem qid: " << channel_info.channel_id + << " seq_id: " << offset_id << " msg_id: " << item.MaxMsgId() + << " data_size: " << data_size; + return StreamingStatus::OK; +} + +StreamingStatus StreamingQueueConsumer::NotifyChannelConsumed(uint64_t offset_id) { + STREAMING_CHECK(queue_ != nullptr); + queue_->OnConsumed(offset_id); + return StreamingStatus::OK; +} + +// For mock queue transfer +struct MockQueueItem { + uint64_t seq_id; + uint32_t data_size; + std::shared_ptr data; +}; + +struct MockQueue { + std::unordered_map>> + message_buffer_; + std::unordered_map>> + consumed_buffer_; +}; +static MockQueue mock_queue; + +StreamingStatus MockProducer::CreateTransferChannel() { + mock_queue.message_buffer_[channel_info.channel_id] = + std::make_shared>(500); + mock_queue.consumed_buffer_[channel_info.channel_id] = + std::make_shared>(500); + return StreamingStatus::OK; +} + +StreamingStatus MockProducer::DestroyTransferChannel() { + mock_queue.message_buffer_.erase(channel_info.channel_id); + mock_queue.consumed_buffer_.erase(channel_info.channel_id); + return StreamingStatus::OK; +} + +StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) { + auto &ring_buffer = mock_queue.message_buffer_[channel_info.channel_id]; + if (ring_buffer->Full()) { + return StreamingStatus::OutOfMemory; + } + MockQueueItem item; + item.seq_id = channel_info.current_seq_id + 1; + item.data.reset(new uint8_t[data_size]); + item.data_size = data_size; + std::memcpy(item.data.get(), data, data_size); + ring_buffer->Push(item); + return StreamingStatus::OK; +} + +StreamingStatus MockConsumer::ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, + uint32_t &data_size, + uint32_t timeout) { + auto &channel_id = channel_info.channel_id; + if (mock_queue.message_buffer_.find(channel_id) == mock_queue.message_buffer_.end()) { + return StreamingStatus::NoSuchItem; + } + + if (mock_queue.message_buffer_[channel_id]->Empty()) { + return StreamingStatus::NoSuchItem; + } + MockQueueItem item = mock_queue.message_buffer_[channel_id]->Front(); + mock_queue.message_buffer_[channel_id]->Pop(); + mock_queue.consumed_buffer_[channel_id]->Push(item); + offset_id = item.seq_id; + data = item.data.get(); + data_size = item.data_size; + return StreamingStatus::OK; +} + +StreamingStatus MockConsumer::NotifyChannelConsumed(uint64_t offset_id) { + auto &channel_id = channel_info.channel_id; + auto &ring_buffer = mock_queue.consumed_buffer_[channel_id]; + while (!ring_buffer->Empty() && ring_buffer->Front().seq_id <= offset_id) { + ring_buffer->Pop(); + } + return StreamingStatus::OK; +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/channel.h b/streaming/src/channel.h new file mode 100644 index 0000000000000..507002481a1b7 --- /dev/null +++ b/streaming/src/channel.h @@ -0,0 +1,176 @@ +#ifndef RAY_CHANNEL_H +#define RAY_CHANNEL_H + +#include "config/streaming_config.h" +#include "queue/queue_handler.h" +#include "ring_buffer.h" +#include "status.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { + +struct StreamingQueueInfo { + uint64_t first_seq_id = 0; + uint64_t last_seq_id = 0; + uint64_t target_seq_id = 0; + uint64_t consumed_seq_id = 0; +}; + +/// PrducerChannelinfo and ConsumerChannelInfo contains channel information and +/// its metrics that help us to debug or show important messages in logging. +struct ProducerChannelInfo { + ObjectID channel_id; + StreamingRingBufferPtr writer_ring_buffer; + uint64_t current_message_id; + uint64_t current_seq_id; + uint64_t message_last_commit_id; + StreamingQueueInfo queue_info; + uint32_t queue_size; + int64_t message_pass_by_ts; + ActorID actor_id; +}; + +struct ConsumerChannelInfo { + ObjectID channel_id; + uint64_t current_message_id; + uint64_t current_seq_id; + uint64_t barrier_id; + uint64_t partial_barrier_id; + + StreamingQueueInfo queue_info; + + uint64_t last_queue_item_delay; + uint64_t last_queue_item_latency; + uint64_t last_queue_target_diff; + uint64_t get_queue_item_times; + ActorID actor_id; +}; + +/// Two types of channel are presented: +/// * ProducerChannel is supporting all writing operations for upperlevel. +/// * ConsumerChannel is for all reader operations. +/// They share similar interfaces: +/// * ClearTransferCheckpoint(it's empty and unsupported now, we will add +/// implementation in next PR) +/// * NotifychannelConsumed (notify owner of channel which range data should +// be release to avoid out of memory) +/// but some differences in read/write function.(named ProduceItemTochannel and +/// ConsumeItemFrom channel) +class ProducerChannel { + public: + explicit ProducerChannel(std::shared_ptr &transfer_config, + ProducerChannelInfo &p_channel_info); + virtual ~ProducerChannel() = default; + virtual StreamingStatus CreateTransferChannel() = 0; + virtual StreamingStatus DestroyTransferChannel() = 0; + virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) = 0; + virtual StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) = 0; + virtual StreamingStatus NotifyChannelConsumed(uint64_t channel_offset) = 0; + + protected: + std::shared_ptr transfer_config_; + ProducerChannelInfo &channel_info; +}; + +class ConsumerChannel { + public: + explicit ConsumerChannel(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info); + virtual ~ConsumerChannel() = default; + virtual StreamingStatus CreateTransferChannel() = 0; + virtual StreamingStatus DestroyTransferChannel() = 0; + virtual StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) = 0; + virtual StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, + uint32_t &data_size, + uint32_t timeout) = 0; + virtual StreamingStatus NotifyChannelConsumed(uint64_t offset_id) = 0; + + protected: + std::shared_ptr transfer_config_; + ConsumerChannelInfo &channel_info; +}; + +class StreamingQueueProducer : public ProducerChannel { + public: + explicit StreamingQueueProducer(std::shared_ptr &transfer_config, + ProducerChannelInfo &p_channel_info); + ~StreamingQueueProducer() override; + StreamingStatus CreateTransferChannel() override; + StreamingStatus DestroyTransferChannel() override; + StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) override; + StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) override; + StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; + + private: + StreamingStatus CreateQueue(); + Status PushQueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size, + uint64_t timestamp); + + private: + std::shared_ptr queue_; +}; + +class StreamingQueueConsumer : public ConsumerChannel { + public: + explicit StreamingQueueConsumer(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info); + ~StreamingQueueConsumer() override; + StreamingStatus CreateTransferChannel() override; + StreamingStatus DestroyTransferChannel() override; + StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) override; + StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, + uint32_t &data_size, uint32_t timeout) override; + StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; + + private: + std::shared_ptr queue_; +}; + +/// MockProducer and Mockconsumer are independent implementation of channels that +/// conduct a very simple memory channel for unit tests or intergation test. +class MockProducer : public ProducerChannel { + public: + explicit MockProducer(std::shared_ptr &transfer_config, + ProducerChannelInfo &channel_info) + : ProducerChannel(transfer_config, channel_info){}; + StreamingStatus CreateTransferChannel() override; + + StreamingStatus DestroyTransferChannel() override; + + StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) override { + return StreamingStatus::OK; + } + + StreamingStatus ProduceItemToChannel(uint8_t *data, uint32_t data_size) override; + + StreamingStatus NotifyChannelConsumed(uint64_t channel_offset) override { + return StreamingStatus::OK; + } +}; + +class MockConsumer : public ConsumerChannel { + public: + explicit MockConsumer(std::shared_ptr &transfer_config, + ConsumerChannelInfo &c_channel_info) + : ConsumerChannel(transfer_config, c_channel_info){}; + StreamingStatus CreateTransferChannel() override { return StreamingStatus::OK; } + StreamingStatus DestroyTransferChannel() override { return StreamingStatus::OK; } + StreamingStatus ClearTransferCheckpoint(uint64_t checkpoint_id, + uint64_t checkpoint_offset) override { + return StreamingStatus::OK; + } + StreamingStatus ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, + uint32_t &data_size, uint32_t timeout) override; + StreamingStatus NotifyChannelConsumed(uint64_t offset_id) override; +}; + +} // namespace streaming +} // namespace ray + +#endif // RAY_CHANNEL_H diff --git a/streaming/src/config/streaming_config.cc b/streaming/src/config/streaming_config.cc new file mode 100644 index 0000000000000..094463ddf4c6c --- /dev/null +++ b/streaming/src/config/streaming_config.cc @@ -0,0 +1,89 @@ +#include + +#include "streaming_config.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +uint64_t StreamingConfig::TIME_WAIT_UINT = 1; +uint32_t StreamingConfig::DEFAULT_RING_BUFFER_CAPACITY = 500; +uint32_t StreamingConfig::DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL = 20; +// Time to force clean if barrier in queue, default 0ms +const uint32_t StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE = 2048; + +void StreamingConfig::FromProto(const uint8_t *data, uint32_t size) { + proto::StreamingConfig config; + STREAMING_CHECK(config.ParseFromArray(data, size)) << "Parse streaming conf failed"; + if (!config.job_name().empty()) { + SetJobName(config.job_name()); + } + if (!config.task_job_id().empty()) { + STREAMING_CHECK(config.task_job_id().size() == 2 * JobID::Size()); + SetTaskJobId(config.task_job_id()); + } + if (!config.worker_name().empty()) { + SetWorkerName(config.worker_name()); + } + if (!config.op_name().empty()) { + SetOpName(config.op_name()); + } + if (config.role() != proto::OperatorType::UNKNOWN) { + SetOperatorType(config.role()); + } + if (config.ring_buffer_capacity() != 0) { + SetRingBufferCapacity(config.ring_buffer_capacity()); + } + if (config.empty_message_interval() != 0) { + SetEmptyMessageTimeInterval(config.empty_message_interval()); + } +} + +uint32_t StreamingConfig::GetRingBufferCapacity() const { return ring_buffer_capacity_; } + +void StreamingConfig::SetRingBufferCapacity(uint32_t ring_buffer_capacity) { + StreamingConfig::ring_buffer_capacity_ = + std::min(ring_buffer_capacity, StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE); +} + +uint32_t StreamingConfig::GetEmptyMessageTimeInterval() const { + return empty_message_time_interval_; +} + +void StreamingConfig::SetEmptyMessageTimeInterval(uint32_t empty_message_time_interval) { + StreamingConfig::empty_message_time_interval_ = empty_message_time_interval; +} + +streaming::proto::OperatorType StreamingConfig::GetOperatorType() const { + return operator_type_; +} + +void StreamingConfig::SetOperatorType(streaming::proto::OperatorType type) { + StreamingConfig::operator_type_ = type; +} + +const std::string &StreamingConfig::GetJobName() const { return job_name_; } + +void StreamingConfig::SetJobName(const std::string &job_name) { + StreamingConfig::job_name_ = job_name; +} + +const std::string &StreamingConfig::GetOpName() const { return op_name_; } + +void StreamingConfig::SetOpName(const std::string &op_name) { + StreamingConfig::op_name_ = op_name; +} + +const std::string &StreamingConfig::GetWorkerName() const { return worker_name_; } +void StreamingConfig::SetWorkerName(const std::string &worker_name) { + StreamingConfig::worker_name_ = worker_name; +} + +const std::string &StreamingConfig::GetTaskJobId() const { return task_job_id_; } + +void StreamingConfig::SetTaskJobId(const std::string &task_job_id) { + StreamingConfig::task_job_id_ = task_job_id; +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/config/streaming_config.h b/streaming/src/config/streaming_config.h new file mode 100644 index 0000000000000..cd38b562bcc0c --- /dev/null +++ b/streaming/src/config/streaming_config.h @@ -0,0 +1,69 @@ +#ifndef RAY_STREAMING_CONFIG_H +#define RAY_STREAMING_CONFIG_H + +#include +#include + +#include "protobuf/streaming.pb.h" +#include "ray/common/id.h" + +namespace ray { +namespace streaming { + +class StreamingConfig { + public: + static uint64_t TIME_WAIT_UINT; + static uint32_t DEFAULT_RING_BUFFER_CAPACITY; + static uint32_t DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL; + static const uint32_t MESSAGE_BUNDLE_MAX_SIZE; + + private: + uint32_t ring_buffer_capacity_ = DEFAULT_RING_BUFFER_CAPACITY; + + uint32_t empty_message_time_interval_ = DEFAULT_EMPTY_MESSAGE_TIME_INTERVAL; + + streaming::proto::OperatorType operator_type_ = + streaming::proto::OperatorType::TRANSFORM; + + std::string job_name_ = "DEFAULT_JOB_NAME"; + + std::string op_name_ = "DEFAULT_OP_NAME"; + + std::string worker_name_ = "DEFAULT_WORKER_NAME"; + + std::string task_job_id_ = "ffffffff"; + + public: + void FromProto(const uint8_t *, uint32_t size); + + const std::string &GetTaskJobId() const; + + void SetTaskJobId(const std::string &task_job_id); + + const std::string &GetWorkerName() const; + + void SetWorkerName(const std::string &worker_name); + + const std::string &GetOpName() const; + + void SetOpName(const std::string &op_name); + + uint32_t GetEmptyMessageTimeInterval() const; + + void SetEmptyMessageTimeInterval(uint32_t empty_message_time_interval); + + uint32_t GetRingBufferCapacity() const; + + void SetRingBufferCapacity(uint32_t ring_buffer_capacity); + + streaming::proto::OperatorType GetOperatorType() const; + + void SetOperatorType(streaming::proto::OperatorType type); + + const std::string &GetJobName() const; + + void SetJobName(const std::string &job_name); +}; +} // namespace streaming +} // namespace ray +#endif // RAY_STREAMING_CONFIG_H diff --git a/streaming/src/data_reader.cc b/streaming/src/data_reader.cc new file mode 100644 index 0000000000000..71afdcc8b354c --- /dev/null +++ b/streaming/src/data_reader.cc @@ -0,0 +1,297 @@ +#include +#include +#include +#include +#include +#include + +#include "ray/util/logging.h" +#include "ray/util/util.h" + +#include "data_reader.h" +#include "message/message_bundle.h" + +namespace ray { +namespace streaming { + +const uint32_t DataReader::kReadItemTimeout = 1000; + +void DataReader::Init(const std::vector &input_ids, + const std::vector &actor_ids, + const std::vector &queue_seq_ids, + const std::vector &streaming_msg_ids, + int64_t timer_interval) { + Init(input_ids, actor_ids, timer_interval); + for (size_t i = 0; i < input_ids.size(); ++i) { + auto &q_id = input_ids[i]; + channel_info_map_[q_id].current_seq_id = queue_seq_ids[i]; + channel_info_map_[q_id].current_message_id = streaming_msg_ids[i]; + } +} + +void DataReader::Init(const std::vector &input_ids, + const std::vector &actor_ids, int64_t timer_interval) { + STREAMING_LOG(INFO) << input_ids.size() << " queue to init."; + + transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, input_ids); + + last_fetched_queue_item_ = nullptr; + timer_interval_ = timer_interval; + last_message_ts_ = 0; + input_queue_ids_ = input_ids; + last_message_latency_ = 0; + last_bundle_unit_ = 0; + + for (size_t i = 0; i < input_ids.size(); ++i) { + ObjectID q_id = input_ids[i]; + STREAMING_LOG(INFO) << "[Reader] Init queue id: " << q_id; + auto &channel_info = channel_info_map_[q_id]; + channel_info.channel_id = q_id; + channel_info.actor_id = actor_ids[i]; + channel_info.last_queue_item_delay = 0; + channel_info.last_queue_item_latency = 0; + channel_info.last_queue_target_diff = 0; + channel_info.get_queue_item_times = 0; + } + + /// Make the input id location stable. + sort(input_queue_ids_.begin(), input_queue_ids_.end(), + [](const ObjectID &a, const ObjectID &b) { return a.Hash() < b.Hash(); }); + std::copy(input_ids.begin(), input_ids.end(), std::back_inserter(unready_queue_ids_)); + InitChannel(); +} + +StreamingStatus DataReader::InitChannel() { + STREAMING_LOG(INFO) << "[Reader] Getting queues. total queue num " + << input_queue_ids_.size() << ", unready queue num => " + << unready_queue_ids_.size(); + + for (const auto &input_channel : unready_queue_ids_) { + auto &channel_info = channel_info_map_[input_channel]; + std::shared_ptr channel; + if (runtime_context_->IsMockTest()) { + channel = std::make_shared(transfer_config_, channel_info); + } else { + channel = std::make_shared(transfer_config_, channel_info); + } + + channel_map_.emplace(input_channel, channel); + StreamingStatus status = channel->CreateTransferChannel(); + if (StreamingStatus::OK != status) { + STREAMING_LOG(ERROR) << "Initialize queue failed, id => " << input_channel; + } + } + runtime_context_->SetRuntimeStatus(RuntimeStatus::Running); + STREAMING_LOG(INFO) << "[Reader] Reader construction done!"; + return StreamingStatus::OK; +} + +StreamingStatus DataReader::InitChannelMerger() { + STREAMING_LOG(INFO) << "[Reader] Initializing queue merger."; + // Init reader merger by given comparator when it's first created. + StreamingReaderMsgPtrComparator comparator; + if (!reader_merger_) { + reader_merger_.reset( + new PriorityQueue, StreamingReaderMsgPtrComparator>( + comparator)); + } + + // An old item in merger vector must be evicted before new queue item has been + // pushed. + if (!unready_queue_ids_.empty() && last_fetched_queue_item_) { + STREAMING_LOG(INFO) << "pop old item from => " << last_fetched_queue_item_->from; + RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_)) + last_fetched_queue_item_.reset(); + } + // Create initial heap for priority queue. + for (auto &input_queue : unready_queue_ids_) { + std::shared_ptr msg = std::make_shared(); + RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info_map_[input_queue], msg)) + channel_info_map_[msg->from].current_seq_id = msg->seq_id; + channel_info_map_[msg->from].current_message_id = msg->meta->GetLastMessageId(); + reader_merger_->push(msg); + } + STREAMING_LOG(INFO) << "[Reader] Initializing merger done."; + return StreamingStatus::OK; +} + +StreamingStatus DataReader::GetMessageFromChannel(ConsumerChannelInfo &channel_info, + std::shared_ptr &message) { + auto &qid = channel_info.channel_id; + last_read_q_id_ = qid; + STREAMING_LOG(DEBUG) << "[Reader] send get request queue seq id => " << qid; + while (RuntimeStatus::Running == runtime_context_->GetRuntimeStatus() && + !message->data) { + auto status = channel_map_[channel_info.channel_id]->ConsumeItemFromChannel( + message->seq_id, message->data, message->data_size, kReadItemTimeout); + channel_info.get_queue_item_times++; + if (!message->data) { + STREAMING_LOG(DEBUG) << "[Reader] Queue " << qid << " status " << status + << " get item timeout, resend notify " + << channel_info.current_seq_id; + // TODO(lingxuan.zlx): notify consumed when it's timeout. + } + } + if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) { + return StreamingStatus::Interrupted; + } + STREAMING_LOG(DEBUG) << "[Reader] recevied queue seq id => " << message->seq_id + << ", queue id => " << qid; + + message->from = qid; + message->meta = StreamingMessageBundleMeta::FromBytes(message->data); + return StreamingStatus::OK; +} + +StreamingStatus DataReader::StashNextMessage(std::shared_ptr &message) { + // Push new message into priority queue and record the channel metrics in + // channel info. + std::shared_ptr new_msg = std::make_shared(); + auto &channel_info = channel_info_map_[message->from]; + reader_merger_->pop(); + int64_t cur_time = current_time_ms(); + RETURN_IF_NOT_OK(GetMessageFromChannel(channel_info, new_msg)) + reader_merger_->push(new_msg); + channel_info.last_queue_item_delay = + new_msg->meta->GetMessageBundleTs() - message->meta->GetMessageBundleTs(); + channel_info.last_queue_item_latency = current_time_ms() - cur_time; + return StreamingStatus::OK; +} + +StreamingStatus DataReader::GetMergedMessageBundle(std::shared_ptr &message, + bool &is_valid_break) { + int64_t cur_time = current_time_ms(); + if (last_fetched_queue_item_) { + RETURN_IF_NOT_OK(StashNextMessage(last_fetched_queue_item_)) + } + message = reader_merger_->top(); + last_fetched_queue_item_ = message; + auto &offset_info = channel_info_map_[message->from]; + + uint64_t cur_queue_previous_msg_id = offset_info.current_message_id; + STREAMING_LOG(DEBUG) << "[Reader] [Bundle] from q_id =>" << message->from << "cur => " + << cur_queue_previous_msg_id << ", message list size" + << message->meta->GetMessageListSize() << ", lst message id =>" + << message->meta->GetLastMessageId() << ", q seq id => " + << message->seq_id << ", last barrier id => " << message->data_size + << ", " << message->meta->GetMessageBundleTs(); + + if (message->meta->IsBundle()) { + last_message_ts_ = cur_time; + is_valid_break = true; + } else if (timer_interval_ != -1 && cur_time - last_message_ts_ > timer_interval_) { + // Throw empty message when reaching timer_interval. + last_message_ts_ = cur_time; + is_valid_break = true; + } + + offset_info.current_message_id = message->meta->GetLastMessageId(); + offset_info.current_seq_id = message->seq_id; + last_bundle_ts_ = message->meta->GetMessageBundleTs(); + + STREAMING_LOG(DEBUG) << "[Reader] [Bundle] message type =>" + << static_cast(message->meta->GetBundleType()) + << " from id => " << message->from << ", queue seq id =>" + << message->seq_id << ", message id => " + << message->meta->GetLastMessageId(); + return StreamingStatus::OK; +} + +StreamingStatus DataReader::GetBundle(const uint32_t timeout_ms, + std::shared_ptr &message) { + // Notify consumed every item in this mode. + if (last_fetched_queue_item_) { + NotifyConsumedItem(channel_info_map_[last_fetched_queue_item_->from], + last_fetched_queue_item_->seq_id); + } + + /// DataBundle will be returned to the upper layer in the following cases: + /// a batch of data is returned when the real data is read, or an empty message + /// is returned to the upper layer when the given timeout period is reached to + /// avoid blocking for too long. + auto start_time = current_time_ms(); + bool is_valid_break = false; + uint32_t empty_bundle_cnt = 0; + while (!is_valid_break) { + if (RuntimeStatus::Interrupted == runtime_context_->GetRuntimeStatus()) { + return StreamingStatus::Interrupted; + } + auto cur_time = current_time_ms(); + auto dur = cur_time - start_time; + if (dur > timeout_ms) { + return StreamingStatus::GetBundleTimeOut; + } + if (!unready_queue_ids_.empty()) { + StreamingStatus status = InitChannel(); + switch (status) { + case StreamingStatus::InitQueueFailed: + break; + case StreamingStatus::WaitQueueTimeOut: + STREAMING_LOG(ERROR) + << "Wait upstream queue timeout, maybe some actors in deadlock"; + break; + default: + STREAMING_LOG(INFO) << "Init reader queue in GetBundle"; + } + if (StreamingStatus::OK != status) { + return status; + } + RETURN_IF_NOT_OK(InitChannelMerger()) + unready_queue_ids_.clear(); + auto &merge_vec = reader_merger_->getRawVector(); + for (auto &bundle : merge_vec) { + STREAMING_LOG(INFO) << "merger vector item => " << bundle->from; + } + } + RETURN_IF_NOT_OK(GetMergedMessageBundle(message, is_valid_break)); + if (!is_valid_break) { + empty_bundle_cnt++; + NotifyConsumedItem(channel_info_map_[message->from], message->seq_id); + } + } + last_message_latency_ += current_time_ms() - start_time; + if (message->meta->GetMessageListSize() > 0) { + last_bundle_unit_ = message->data_size * 1.0 / message->meta->GetMessageListSize(); + } + return StreamingStatus::OK; +} + +void DataReader::GetOffsetInfo( + std::unordered_map *&offset_map) { + offset_map = &channel_info_map_; + for (auto &offset_info : channel_info_map_) { + STREAMING_LOG(INFO) << "[Reader] [GetOffsetInfo], q id " << offset_info.first + << ", seq id => " << offset_info.second.current_seq_id + << ", message id => " << offset_info.second.current_message_id; + } +} + +void DataReader::NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset) { + channel_map_[channel_info.channel_id]->NotifyChannelConsumed(offset); + if (offset == channel_info.queue_info.last_seq_id) { + STREAMING_LOG(DEBUG) << "notify seq id equal to last seq id => " << offset; + } +} + +DataReader::DataReader(std::shared_ptr &runtime_context) + : transfer_config_(new Config()), runtime_context_(runtime_context) {} + +DataReader::~DataReader() { STREAMING_LOG(INFO) << "Streaming reader deconstruct."; } + +void DataReader::Stop() { + runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted); +} + +bool StreamingReaderMsgPtrComparator::operator()(const std::shared_ptr &a, + const std::shared_ptr &b) { + STREAMING_CHECK(a->meta); + // We use hash value of id for stability of message in sorting. + if (a->meta->GetMessageBundleTs() == b->meta->GetMessageBundleTs()) { + return a->from.Hash() > b->from.Hash(); + } + return a->meta->GetMessageBundleTs() > b->meta->GetMessageBundleTs(); +} + +} // namespace streaming + +} // namespace ray diff --git a/streaming/src/data_reader.h b/streaming/src/data_reader.h new file mode 100644 index 0000000000000..79793c41cfad6 --- /dev/null +++ b/streaming/src/data_reader.h @@ -0,0 +1,127 @@ +#ifndef RAY_DATA_READER_H +#define RAY_DATA_READER_H + +#include +#include +#include +#include +#include +#include + +#include "channel.h" +#include "message/message_bundle.h" +#include "message/priority_queue.h" +#include "runtime_context.h" + +namespace ray { +namespace streaming { + +/// Databundle is super-bundle that contains channel information (upstream +/// channel id & bundle meta data) and raw buffer pointer. +struct DataBundle { + uint8_t *data = nullptr; + uint32_t data_size; + ObjectID from; + uint64_t seq_id; + StreamingMessageBundleMetaPtr meta; +}; + +/// This is implementation of merger policy in StreamingReaderMsgPtrComparator. +struct StreamingReaderMsgPtrComparator { + StreamingReaderMsgPtrComparator() = default; + bool operator()(const std::shared_ptr &a, + const std::shared_ptr &b); +}; + +/// DataReader will fetch data bundles from channels of upstream workers, once +/// invoked by user thread. Firstly put them into a priority queue ordered by bundle +/// comparator that's related meta-data, then pop out the top bunlde to user +/// thread every time, so that the order of the message can be guranteed, which +/// will also facilitate our future implementation of fault tolerance. Finally +/// user thread can extract messages from the bundle and process one by one. +class DataReader { + private: + std::vector input_queue_ids_; + + std::vector unready_queue_ids_; + + std::unique_ptr< + PriorityQueue, StreamingReaderMsgPtrComparator>> + reader_merger_; + + std::shared_ptr last_fetched_queue_item_; + + int64_t timer_interval_; + int64_t last_bundle_ts_; + int64_t last_message_ts_; + int64_t last_message_latency_; + int64_t last_bundle_unit_; + + ObjectID last_read_q_id_; + + static const uint32_t kReadItemTimeout; + + protected: + std::unordered_map channel_info_map_; + std::unordered_map> channel_map_; + std::shared_ptr transfer_config_; + std::shared_ptr runtime_context_; + + public: + explicit DataReader(std::shared_ptr &runtime_context); + virtual ~DataReader(); + + /// During initialization, only the channel parameters and necessary member properties + /// are assigned. All channels will be connected in the first reading operation. + /// \param input_ids + /// \param actor_ids + /// \param channel_seq_ids + /// \param msg_ids + /// \param timer_interval + void Init(const std::vector &input_ids, const std::vector &actor_ids, + const std::vector &channel_seq_ids, + const std::vector &msg_ids, int64_t timer_interval); + + void Init(const std::vector &input_ids, const std::vector &actor_ids, + int64_t timer_interval); + + /// Get latest message from input queues. + /// \param timeout_ms + /// \param message, return the latest message + StreamingStatus GetBundle(uint32_t timeout_ms, std::shared_ptr &message); + + /// Get offset information about channels for checkpoint. + /// \param offset_map (return value) + void GetOffsetInfo(std::unordered_map *&offset_map); + + void Stop(); + + /// Notify input queues to clear data whose seq id is equal or less than offset. + /// It's used when checkpoint is done. + /// \param channel_info + /// \param offset + /// + void NotifyConsumedItem(ConsumerChannelInfo &channel_info, uint64_t offset); + + private: + /// Create channels and connect to all upstream. + StreamingStatus InitChannel(); + + /// One item from every channel will be popped out, then collecting + /// them to a merged queue. High prioprity items will be fetched one by one. + /// When item pop from one channel where must produce new item for placeholder + /// in merged queue. + StreamingStatus InitChannelMerger(); + + StreamingStatus StashNextMessage(std::shared_ptr &message); + + StreamingStatus GetMessageFromChannel(ConsumerChannelInfo &channel_info, + std::shared_ptr &message); + + /// Get top item from prioprity queue. + StreamingStatus GetMergedMessageBundle(std::shared_ptr &message, + bool &is_valid_break); +}; +} // namespace streaming +} // namespace ray +#endif // RAY_DATA_READER_H diff --git a/streaming/src/data_writer.cc b/streaming/src/data_writer.cc new file mode 100644 index 0000000000000..e9a0c889fd195 --- /dev/null +++ b/streaming/src/data_writer.cc @@ -0,0 +1,310 @@ +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include "data_writer.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { + +void DataWriter::WriterLoopForward() { + STREAMING_CHECK(RuntimeStatus::Running == runtime_context_->GetRuntimeStatus()); + while (true) { + int64_t min_passby_message_ts = std::numeric_limits::max(); + uint32_t empty_messge_send_count = 0; + + for (auto &output_queue : output_queue_ids_) { + if (RuntimeStatus::Running != runtime_context_->GetRuntimeStatus()) { + return; + } + ProducerChannelInfo &channel_info = channel_info_map_[output_queue]; + bool is_push_empty_message = false; + StreamingStatus write_status = + WriteChannelProcess(channel_info, &is_push_empty_message); + int64_t current_ts = current_time_ms(); + if (StreamingStatus::OK == write_status) { + channel_info.message_pass_by_ts = current_ts; + if (is_push_empty_message) { + min_passby_message_ts = + std::min(channel_info.message_pass_by_ts, min_passby_message_ts); + empty_messge_send_count++; + } + } else if (StreamingStatus::FullChannel == write_status) { + } else { + if (StreamingStatus::EmptyRingBuffer != write_status) { + STREAMING_LOG(DEBUG) << "write buffer status => " + << static_cast(write_status) + << ", is push empty message => " << is_push_empty_message; + } + } + } + + if (empty_messge_send_count == output_queue_ids_.size()) { + // Sleep if empty message was sent in all channel. + uint64_t sleep_time_ = current_time_ms() - min_passby_message_ts; + // Sleep_time can be bigger than time interval because of network jitter. + if (sleep_time_ <= runtime_context_->GetConfig().GetEmptyMessageTimeInterval()) { + std::this_thread::sleep_for(std::chrono::milliseconds( + runtime_context_->GetConfig().GetEmptyMessageTimeInterval() - sleep_time_)); + } + } + } +} + +StreamingStatus DataWriter::WriteChannelProcess(ProducerChannelInfo &channel_info, + bool *is_empty_message) { + // No message in buffer, empty message will be sent to downstream queue. + uint64_t buffer_remain = 0; + StreamingStatus write_queue_flag = WriteBufferToChannel(channel_info, buffer_remain); + int64_t current_ts = current_time_ms(); + if (write_queue_flag == StreamingStatus::EmptyRingBuffer && + current_ts - channel_info.message_pass_by_ts >= + runtime_context_->GetConfig().GetEmptyMessageTimeInterval()) { + write_queue_flag = WriteEmptyMessage(channel_info); + *is_empty_message = true; + STREAMING_LOG(DEBUG) << "send empty message bundle in q_id =>" + << channel_info.channel_id; + } + return write_queue_flag; +} + +StreamingStatus DataWriter::WriteBufferToChannel(ProducerChannelInfo &channel_info, + uint64_t &buffer_remain) { + StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer; + if (!IsMessageAvailableInBuffer(channel_info)) { + return StreamingStatus::EmptyRingBuffer; + } + + // Flush transient buffer to queue first. + if (buffer_ptr->IsTransientAvaliable()) { + return WriteTransientBufferToChannel(channel_info); + } + + STREAMING_CHECK(CollectFromRingBuffer(channel_info, buffer_remain)) + << "empty data in ringbuffer, q id => " << channel_info.channel_id; + + return WriteTransientBufferToChannel(channel_info); +} + +void DataWriter::Run() { + STREAMING_LOG(INFO) << "WriterLoopForward start"; + loop_thread_ = std::make_shared(&DataWriter::WriterLoopForward, this); +} + +/// Since every memory ring buffer's size is limited, when the writing buffer is +/// full, the user thread will be blocked, which will cause backpressure +/// naturally. +uint64_t DataWriter::WriteMessageToBufferRing(const ObjectID &q_id, uint8_t *data, + uint32_t data_size, + StreamingMessageType message_type) { + STREAMING_LOG(DEBUG) << "WriteMessageToBufferRing q_id: " << q_id + << " data_size: " << data_size; + // TODO(lingxuan.zlx): currently, unsafe in multithreads + ProducerChannelInfo &channel_info = channel_info_map_[q_id]; + // Write message id stands for current lastest message id and differs from + // channel.current_message_id if it's barrier message. + uint64_t &write_message_id = channel_info.current_message_id; + write_message_id++; + auto &ring_buffer_ptr = channel_info.writer_ring_buffer; + while (ring_buffer_ptr->IsFull() && + runtime_context_->GetRuntimeStatus() == RuntimeStatus::Running) { + std::this_thread::sleep_for( + std::chrono::milliseconds(StreamingConfig::TIME_WAIT_UINT)); + } + if (runtime_context_->GetRuntimeStatus() != RuntimeStatus::Running) { + STREAMING_LOG(WARNING) << "stop in write message to ringbuffer"; + return 0; + } + ring_buffer_ptr->Push(std::make_shared( + data, data_size, write_message_id, message_type)); + + return write_message_id; +} + +StreamingStatus DataWriter::InitChannel(const ObjectID &q_id, const ActorID &actor_id, + uint64_t channel_message_id, + uint64_t queue_size) { + ProducerChannelInfo &channel_info = channel_info_map_[q_id]; + channel_info.current_message_id = channel_message_id; + channel_info.channel_id = q_id; + channel_info.actor_id = actor_id; + channel_info.queue_size = queue_size; + STREAMING_LOG(WARNING) << " Init queue [" << q_id << "]"; + channel_info.writer_ring_buffer = std::make_shared( + runtime_context_->GetConfig().GetRingBufferCapacity(), + StreamingRingBufferType::SPSC); + channel_info.message_pass_by_ts = current_time_ms(); + std::shared_ptr channel; + + if (runtime_context_->IsMockTest()) { + channel = std::make_shared(transfer_config_, channel_info); + } else { + channel = std::make_shared(transfer_config_, channel_info); + } + + channel_map_.emplace(q_id, channel); + RETURN_IF_NOT_OK(channel->CreateTransferChannel()) + return StreamingStatus::OK; +} + +StreamingStatus DataWriter::Init(const std::vector &queue_id_vec, + const std::vector &actor_ids, + const std::vector &channel_message_id_vec, + const std::vector &queue_size_vec) { + STREAMING_CHECK(!queue_id_vec.empty() && !channel_message_id_vec.empty()); + + ray::JobID job_id = + JobID::FromBinary(Util::Hexqid2str(runtime_context_->GetConfig().GetTaskJobId())); + + STREAMING_LOG(INFO) << "Job name => " << runtime_context_->GetConfig().GetJobName() + << ", job id => " << job_id; + + output_queue_ids_ = queue_id_vec; + transfer_config_->Set(ConfigEnum::QUEUE_ID_VECTOR, queue_id_vec); + + for (size_t i = 0; i < queue_id_vec.size(); ++i) { + StreamingStatus status = InitChannel(queue_id_vec[i], actor_ids[i], + channel_message_id_vec[i], queue_size_vec[i]); + if (status != StreamingStatus::OK) { + return status; + } + } + runtime_context_->SetRuntimeStatus(RuntimeStatus::Running); + return StreamingStatus::OK; +} + +DataWriter::DataWriter(std::shared_ptr &runtime_context) + : transfer_config_(new Config()), runtime_context_(runtime_context) {} + +DataWriter::~DataWriter() { + // Return if fail to init streaming writer + if (runtime_context_->GetRuntimeStatus() == RuntimeStatus::Init) { + return; + } + runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted); + if (loop_thread_->joinable()) { + STREAMING_LOG(INFO) << "Writer loop thread waiting for join"; + loop_thread_->join(); + } + STREAMING_LOG(INFO) << "Writer client queue disconnect."; +} + +bool DataWriter::IsMessageAvailableInBuffer(ProducerChannelInfo &channel_info) { + return channel_info.writer_ring_buffer->IsTransientAvaliable() || + !channel_info.writer_ring_buffer->IsEmpty(); +} + +StreamingStatus DataWriter::WriteEmptyMessage(ProducerChannelInfo &channel_info) { + auto &q_id = channel_info.channel_id; + if (channel_info.message_last_commit_id < channel_info.current_message_id) { + // Abort to send empty message if ring buffer is not empty now. + STREAMING_LOG(DEBUG) << "q_id =>" << q_id << " abort to send empty, last commit id =>" + << channel_info.message_last_commit_id << ", channel max id => " + << channel_info.current_message_id; + return StreamingStatus::SkipSendEmptyMessage; + } + + // Make an empty bundle, use old ts from reloaded meta if it's not nullptr. + StreamingMessageBundlePtr bundle_ptr = std::make_shared( + channel_info.current_message_id, current_time_ms()); + auto &q_ringbuffer = channel_info.writer_ring_buffer; + q_ringbuffer->ReallocTransientBuffer(bundle_ptr->ClassBytesSize()); + bundle_ptr->ToBytes(q_ringbuffer->GetTransientBufferMutable()); + + StreamingStatus status = channel_map_[q_id]->ProduceItemToChannel( + const_cast(q_ringbuffer->GetTransientBuffer()), + q_ringbuffer->GetTransientBufferSize()); + STREAMING_LOG(DEBUG) << "q_id =>" << q_id << " send empty message, meta info =>" + << bundle_ptr->ToString(); + + q_ringbuffer->FreeTransientBuffer(); + RETURN_IF_NOT_OK(status) + channel_info.current_seq_id++; + channel_info.message_pass_by_ts = current_time_ms(); + return StreamingStatus::OK; +} + +StreamingStatus DataWriter::WriteTransientBufferToChannel( + ProducerChannelInfo &channel_info) { + StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer; + StreamingStatus status = channel_map_[channel_info.channel_id]->ProduceItemToChannel( + buffer_ptr->GetTransientBufferMutable(), buffer_ptr->GetTransientBufferSize()); + RETURN_IF_NOT_OK(status) + channel_info.current_seq_id++; + auto transient_bundle_meta = + StreamingMessageBundleMeta::FromBytes(buffer_ptr->GetTransientBuffer()); + bool is_barrier_bundle = transient_bundle_meta->IsBarrier(); + // Force delete to avoid super block memory isn't released so long + // if it's barrier bundle. + buffer_ptr->FreeTransientBuffer(is_barrier_bundle); + channel_info.message_last_commit_id = transient_bundle_meta->GetLastMessageId(); + return StreamingStatus::OK; +} + +bool DataWriter::CollectFromRingBuffer(ProducerChannelInfo &channel_info, + uint64_t &buffer_remain) { + StreamingRingBufferPtr &buffer_ptr = channel_info.writer_ring_buffer; + auto &q_id = channel_info.channel_id; + + std::list message_list; + uint64_t bundle_buffer_size = 0; + const uint32_t max_queue_item_size = channel_info.queue_size; + while (message_list.size() < runtime_context_->GetConfig().GetRingBufferCapacity() && + !buffer_ptr->IsEmpty()) { + StreamingMessagePtr &message_ptr = buffer_ptr->Front(); + uint32_t message_total_size = message_ptr->ClassBytesSize(); + if (!message_list.empty() && + bundle_buffer_size + message_total_size >= max_queue_item_size) { + STREAMING_LOG(DEBUG) << "message total size " << message_total_size + << " max queue item size => " << max_queue_item_size; + break; + } + if (!message_list.empty() && + message_list.back()->GetMessageType() != message_ptr->GetMessageType()) { + break; + } + // ClassBytesSize = DataSize + MetaDataSize + // bundle_buffer_size += message_ptr->GetDataSize(); + bundle_buffer_size += message_total_size; + message_list.push_back(message_ptr); + buffer_ptr->Pop(); + buffer_remain = buffer_ptr->Size(); + } + + if (bundle_buffer_size >= channel_info.queue_size) { + STREAMING_LOG(ERROR) << "bundle buffer is too large to store q id => " << q_id + << ", bundle size => " << bundle_buffer_size + << ", queue size => " << channel_info.queue_size; + } + + StreamingMessageBundlePtr bundle_ptr; + bundle_ptr = std::make_shared( + std::move(message_list), current_time_ms(), message_list.back()->GetMessageSeqId(), + StreamingMessageBundleType::Bundle, bundle_buffer_size); + buffer_ptr->ReallocTransientBuffer(bundle_ptr->ClassBytesSize()); + bundle_ptr->ToBytes(buffer_ptr->GetTransientBufferMutable()); + + STREAMING_CHECK(bundle_ptr->ClassBytesSize() == buffer_ptr->GetTransientBufferSize()); + return true; +} + +void DataWriter::Stop() { + for (auto &output_queue : output_queue_ids_) { + ProducerChannelInfo &channel_info = channel_info_map_[output_queue]; + while (!channel_info.writer_ring_buffer->IsEmpty()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + runtime_context_->SetRuntimeStatus(RuntimeStatus::Interrupted); +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/data_writer.h b/streaming/src/data_writer.h new file mode 100644 index 0000000000000..92b53b9b418e7 --- /dev/null +++ b/streaming/src/data_writer.h @@ -0,0 +1,115 @@ +#ifndef RAY_DATA_WRITER_H +#define RAY_DATA_WRITER_H + +#include +#include +#include +#include +#include + +#include "channel.h" +#include "config/streaming_config.h" +#include "message/message_bundle.h" +#include "runtime_context.h" + +namespace ray { +namespace streaming { + +/// DataWriter is designed for data transporting between upstream and downstream. +/// After the user sends the data, it does not immediately send the data to +/// downstream, but caches it in the corresponding memory ring buffer. There is +/// a spearate transfer thread (setup in WriterLoopForward function) to collect +/// the messages from all the ringbuffers, and write them to the corresponding +/// transmission channels, which is backed by StreamingQueue. Actually, the +/// advantage is that the user thread will not be affected by the transmission +/// speed during the data transfer. And also the transfer thread can automatically +/// batch the catched data from memory buffer into a data bundle to reduce +/// transmission overhead. In addtion, when there is no data in the ringbuffer, +/// it will also send an empty bundle, so downstream can know that and process +/// accordingly. It will sleep for a short interval to save cpu if all ring +/// buffers have no data in that moment. +class DataWriter { + private: + std::shared_ptr loop_thread_; + // One channel have unique identity. + std::vector output_queue_ids_; + + protected: + // ProducerTransfer is middle broker for data transporting. + std::unordered_map channel_info_map_; + std::unordered_map> channel_map_; + std::shared_ptr transfer_config_; + std::shared_ptr runtime_context_; + + private: + bool IsMessageAvailableInBuffer(ProducerChannelInfo &channel_info); + + /// This function handles two scenarios. When there is data in the transient + /// buffer, the existing data is written into the channel first, otherwise a + /// certain amount of message is first collected from the buffer and serialized + /// into the transient buffer, and finally written to the channel. + /// \\param channel_info + /// \\param buffer_remain + StreamingStatus WriteBufferToChannel(ProducerChannelInfo &channel_info, + uint64_t &buffer_remain); + + /// Start the loop forward thread for collecting messages from all channels. + /// Invoking stack: + /// WriterLoopForward + /// -- WriteChannelProcess + /// -- WriteBufferToChannel + /// -- CollectFromRingBuffer + /// -- WriteTransientBufferToChannel + /// -- WriteEmptyMessage(if WriteChannelProcess return empty state) + void WriterLoopForward(); + + /// Push empty message when no valid message or bundle was produced each time + /// interval. + /// \param channel_info + StreamingStatus WriteEmptyMessage(ProducerChannelInfo &channel_info); + + /// Flush all data from transient buffer to channel for transporting. + /// \param channel_info + StreamingStatus WriteTransientBufferToChannel(ProducerChannelInfo &channel_info); + + bool CollectFromRingBuffer(ProducerChannelInfo &channel_info, uint64_t &buffer_remain); + + StreamingStatus WriteChannelProcess(ProducerChannelInfo &channel_info, + bool *is_empty_message); + + StreamingStatus InitChannel(const ObjectID &q_id, const ActorID &actor_id, + uint64_t channel_message_id, uint64_t queue_size); + + public: + explicit DataWriter(std::shared_ptr &runtime_context); + virtual ~DataWriter(); + + /// Streaming writer client initialization. + /// \param queue_id_vec queue id vector + /// \param channel_message_id_vec channel seq id is related with message checkpoint + /// \param queue_size queue size (memory size not length) + StreamingStatus Init(const std::vector &channel_ids, + const std::vector &actor_ids, + const std::vector &channel_message_id_vec, + const std::vector &queue_size_vec); + + /// To increase throughout, we employed an output buffer for message transformation, + /// which means we merge a lot of message to a message bundle and no message will be + /// pushed into queue directly util daemon thread does this action. + /// Additionally, writing will block when buffer ring is full intentionly. + /// \param q_id + /// \param data + /// \param data_size + /// \param message_type + /// \return message seq iq + uint64_t WriteMessageToBufferRing( + const ObjectID &q_id, uint8_t *data, uint32_t data_size, + StreamingMessageType message_type = StreamingMessageType::Message); + + void Run(); + + void Stop(); +}; +} // namespace streaming +} // namespace ray +#endif // RAY_DATA_WRITER_H diff --git a/streaming/src/message/message.cc b/streaming/src/message/message.cc new file mode 100644 index 0000000000000..ca0de652bc4d9 --- /dev/null +++ b/streaming/src/message/message.cc @@ -0,0 +1,90 @@ +#include + +#include +#include + +#include "message.h" +#include "ray/common/status.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +StreamingMessage::StreamingMessage(std::shared_ptr &data, uint32_t data_size, + uint64_t seq_id, StreamingMessageType message_type) + : message_data_(data), + data_size_(data_size), + message_type_(message_type), + message_id_(seq_id) {} + +StreamingMessage::StreamingMessage(std::shared_ptr &&data, uint32_t data_size, + uint64_t seq_id, StreamingMessageType message_type) + : message_data_(data), + data_size_(data_size), + message_type_(message_type), + message_id_(seq_id) {} + +StreamingMessage::StreamingMessage(const uint8_t *data, uint32_t data_size, + uint64_t seq_id, StreamingMessageType message_type) + : data_size_(data_size), message_type_(message_type), message_id_(seq_id) { + message_data_.reset(new uint8_t[data_size], std::default_delete()); + std::memcpy(message_data_.get(), data, data_size_); +} + +StreamingMessage::StreamingMessage(const StreamingMessage &msg) { + data_size_ = msg.data_size_; + message_data_ = msg.message_data_; + message_id_ = msg.message_id_; + message_type_ = msg.message_type_; +} + +StreamingMessagePtr StreamingMessage::FromBytes(const uint8_t *bytes, + bool verifer_check) { + uint32_t byte_offset = 0; + uint32_t data_size = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(data_size); + + uint64_t seq_id = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(seq_id); + + StreamingMessageType msg_type = + *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(msg_type); + + auto buf = new uint8_t[data_size]; + std::memcpy(buf, bytes + byte_offset, data_size); + auto data_ptr = std::shared_ptr(buf, std::default_delete()); + return std::make_shared(data_ptr, data_size, seq_id, msg_type); +} + +void StreamingMessage::ToBytes(uint8_t *serlizable_data) { + uint32_t byte_offset = 0; + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&data_size_), + sizeof(data_size_)); + byte_offset += sizeof(data_size_); + + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&message_id_), + sizeof(message_id_)); + byte_offset += sizeof(message_id_); + + std::memcpy(serlizable_data + byte_offset, reinterpret_cast(&message_type_), + sizeof(message_type_)); + byte_offset += sizeof(message_type_); + + std::memcpy(serlizable_data + byte_offset, + reinterpret_cast(message_data_.get()), data_size_); + + byte_offset += data_size_; + + STREAMING_CHECK(byte_offset == this->ClassBytesSize()); +} + +bool StreamingMessage::operator==(const StreamingMessage &message) const { + return GetDataSize() == message.GetDataSize() && + GetMessageSeqId() == message.GetMessageSeqId() && + GetMessageType() == message.GetMessageType() && + !std::memcmp(RawData(), message.RawData(), data_size_); +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/message/message.h b/streaming/src/message/message.h new file mode 100644 index 0000000000000..086faeb258080 --- /dev/null +++ b/streaming/src/message/message.h @@ -0,0 +1,93 @@ +#ifndef RAY_MESSAGE_H +#define RAY_MESSAGE_H + +#include + +namespace ray { +namespace streaming { + +class StreamingMessage; + +typedef std::shared_ptr StreamingMessagePtr; + +enum class StreamingMessageType : uint32_t { + Barrier = 1, + Message = 2, + MIN = Barrier, + MAX = Message +}; + +constexpr uint32_t kMessageHeaderSize = + sizeof(uint32_t) + sizeof(uint64_t) + sizeof(StreamingMessageType); + +/// All messages should be wrapped by this protocol. +// DataSize means length of raw data, message id is increasing from [1, +INF]. +// MessageType will be used for barrier transporting and checkpoint. +/// +----------------+ +/// | DataSize=U32 | +/// +----------------+ +/// | MessageId=U64 | +/// +----------------+ +/// | MessageType=U32| +/// +----------------+ +/// | Data=var | +/// +----------------+ + +class StreamingMessage { + private: + std::shared_ptr message_data_; + uint32_t data_size_; + StreamingMessageType message_type_; + uint64_t message_id_; + + public: + /// Copy raw data from outside shared buffer. + /// \param data raw data from user buffer + /// \param data_size raw data size + /// \param seq_id message id + /// \param message_type + StreamingMessage(std::shared_ptr &data, uint32_t data_size, uint64_t seq_id, + StreamingMessageType message_type); + + /// Move outsite raw data to message data. + /// \param data raw data from user buffer + /// \param data_size raw data size + /// \param seq_id message id + /// \param message_type + StreamingMessage(std::shared_ptr &&data, uint32_t data_size, uint64_t seq_id, + StreamingMessageType message_type); + + /// Copy raw data from outside buffer. + /// \param data raw data from user buffer + /// \param data_size raw data size + /// \param seq_id message id + /// \param message_type + StreamingMessage(const uint8_t *data, uint32_t data_size, uint64_t seq_id, + StreamingMessageType message_type); + + StreamingMessage(const StreamingMessage &); + + StreamingMessage operator=(const StreamingMessage &) = delete; + + virtual ~StreamingMessage() = default; + + inline uint8_t *RawData() const { return message_data_.get(); } + + inline uint32_t GetDataSize() const { return data_size_; } + inline StreamingMessageType GetMessageType() const { return message_type_; } + inline uint64_t GetMessageSeqId() const { return message_id_; } + inline bool IsMessage() { return StreamingMessageType::Message == message_type_; } + inline bool IsBarrier() { return StreamingMessageType::Barrier == message_type_; } + + bool operator==(const StreamingMessage &) const; + + virtual void ToBytes(uint8_t *data); + static StreamingMessagePtr FromBytes(const uint8_t *data, bool verifer_check = true); + + inline virtual uint32_t ClassBytesSize() { return kMessageHeaderSize + data_size_; } +}; + +} // namespace streaming +} // namespace ray + +#endif // RAY_MESSAGE_H diff --git a/streaming/src/message/message_bundle.cc b/streaming/src/message/message_bundle.cc new file mode 100644 index 0000000000000..af2f8882f1154 --- /dev/null +++ b/streaming/src/message/message_bundle.cc @@ -0,0 +1,236 @@ +#include +#include + +#include "ray/common/status.h" + +#include "config/streaming_config.h" +#include "message_bundle.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { +StreamingMessageBundle::StreamingMessageBundle(uint64_t last_offset_seq_id, + uint64_t message_bundle_ts) + : StreamingMessageBundleMeta(message_bundle_ts, last_offset_seq_id, 0, + StreamingMessageBundleType::Empty) { + this->raw_bundle_size_ = 0; +} + +StreamingMessageBundleMeta::StreamingMessageBundleMeta( + uint64_t message_bundle_ts, uint64_t last_offset_seq_id, uint32_t message_list_size, + StreamingMessageBundleType bundle_type) + : message_bundle_ts_(message_bundle_ts), + last_message_id_(last_offset_seq_id), + message_list_size_(message_list_size), + bundle_type_(bundle_type) { + STREAMING_CHECK(message_list_size <= StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE); +} + +void StreamingMessageBundleMeta::ToBytes(uint8_t *bytes) { + uint32_t byte_offset = 0; + + uint32_t magicNum = StreamingMessageBundleMeta::StreamingMessageBundleMagicNum; + std::memcpy(bytes + byte_offset, reinterpret_cast(&magicNum), + sizeof(uint32_t)); + byte_offset += sizeof(uint32_t); + + std::memcpy(bytes + byte_offset, reinterpret_cast(&message_bundle_ts_), + sizeof(uint64_t)); + byte_offset += sizeof(uint64_t); + + std::memcpy(bytes + byte_offset, reinterpret_cast(&last_message_id_), + sizeof(uint64_t)); + byte_offset += sizeof(uint64_t); + + std::memcpy(bytes + byte_offset, reinterpret_cast(&message_list_size_), + sizeof(uint32_t)); + byte_offset += sizeof(uint32_t); + + std::memcpy(bytes + byte_offset, reinterpret_cast(&bundle_type_), + sizeof(StreamingMessageBundleType)); + byte_offset += sizeof(StreamingMessageBundleType); +} + +StreamingMessageBundleMetaPtr StreamingMessageBundleMeta::FromBytes(const uint8_t *bytes, + bool check) { + STREAMING_CHECK(bytes); + + uint32_t byte_offset = 0; + const uint32_t magic_num = *reinterpret_cast(bytes + byte_offset); + + if (magic_num != StreamingMessageBundleMagicNum) { + STREAMING_LOG(INFO) << "Magic Number => " << magic_num; + } + + STREAMING_CHECK(magic_num == StreamingMessageBundleMagicNum); + byte_offset += sizeof(uint32_t); + + uint64_t message_bundle_ts = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(uint64_t); + + uint64_t last_message_id = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(uint64_t); + + uint32_t messageListSize = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(uint32_t); + STREAMING_LOG(DEBUG) << "ts => " << message_bundle_ts << " last message id => " + << last_message_id << " message size => " << messageListSize; + + STREAMING_CHECK(messageListSize <= StreamingConfig::MESSAGE_BUNDLE_MAX_SIZE); + + StreamingMessageBundleType messageBundleType = + *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(StreamingMessageBundleType); + + auto result = std::make_shared( + message_bundle_ts, last_message_id, messageListSize, messageBundleType); + STREAMING_CHECK(byte_offset == result->ClassBytesSize()); + return result; +} + +bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta &meta) const { + return this->message_list_size_ == meta.GetMessageListSize() && + this->message_bundle_ts_ == meta.GetMessageBundleTs() && + this->bundle_type_ == meta.GetBundleType() && + this->last_message_id_ == meta.GetLastMessageId(); +} + +bool StreamingMessageBundleMeta::operator==(StreamingMessageBundleMeta *meta) const { + return operator==(*meta); +} + +StreamingMessageBundleMeta::StreamingMessageBundleMeta( + StreamingMessageBundleMeta *meta_ptr) { + bundle_type_ = meta_ptr->bundle_type_; + last_message_id_ = meta_ptr->last_message_id_; + message_bundle_ts_ = meta_ptr->message_bundle_ts_; + message_list_size_ = meta_ptr->message_list_size_; +} + +StreamingMessageBundleMeta::StreamingMessageBundleMeta() + : bundle_type_(StreamingMessageBundleType::Empty) {} + +StreamingMessageBundle::StreamingMessageBundle( + std::list &&message_list, uint64_t message_ts, + uint64_t last_offset_seq_id, StreamingMessageBundleType bundle_type, + uint32_t raw_data_size) + : StreamingMessageBundleMeta(message_ts, last_offset_seq_id, message_list.size(), + bundle_type), + raw_bundle_size_(raw_data_size), + message_list_(message_list) { + if (bundle_type_ != StreamingMessageBundleType::Empty) { + if (!raw_bundle_size_) { + raw_bundle_size_ = std::accumulate( + message_list_.begin(), message_list_.end(), 0, + [](uint32_t x, StreamingMessagePtr &y) { return x + y->ClassBytesSize(); }); + } + } +} + +StreamingMessageBundle::StreamingMessageBundle( + std::list &message_list, uint64_t message_ts, + uint64_t last_offset_seq_id, StreamingMessageBundleType bundle_type, + uint32_t raw_data_size) + : StreamingMessageBundle(std::list(message_list), message_ts, + last_offset_seq_id, bundle_type, raw_data_size) {} + +StreamingMessageBundle::StreamingMessageBundle(StreamingMessageBundle &bundle) { + message_bundle_ts_ = bundle.message_bundle_ts_; + message_list_size_ = bundle.message_list_size_; + raw_bundle_size_ = bundle.raw_bundle_size_; + bundle_type_ = bundle.bundle_type_; + last_message_id_ = bundle.last_message_id_; + message_list_ = bundle.message_list_; +} + +void StreamingMessageBundle::ToBytes(uint8_t *bytes) { + uint32_t byte_offset = 0; + StreamingMessageBundleMeta::ToBytes(bytes + byte_offset); + + byte_offset += StreamingMessageBundleMeta::ClassBytesSize(); + + std::memcpy(bytes + byte_offset, reinterpret_cast(&raw_bundle_size_), + sizeof(uint32_t)); + byte_offset += sizeof(uint32_t); + + if (raw_bundle_size_ > 0) { + ConvertMessageListToRawData(message_list_, raw_bundle_size_, bytes + byte_offset); + } +} + +StreamingMessageBundlePtr StreamingMessageBundle::FromBytes(const uint8_t *bytes, + bool verifer_check) { + uint32_t byte_offset = 0; + StreamingMessageBundleMetaPtr meta_ptr = + StreamingMessageBundleMeta::FromBytes(bytes + byte_offset); + byte_offset += meta_ptr->ClassBytesSize(); + + uint32_t raw_data_size = *reinterpret_cast(bytes + byte_offset); + byte_offset += sizeof(uint32_t); + + std::list message_list; + // only message bundle own raw data + if (meta_ptr->GetBundleType() != StreamingMessageBundleType::Empty) { + GetMessageListFromRawData(bytes + byte_offset, raw_data_size, + meta_ptr->GetMessageListSize(), message_list); + byte_offset += raw_data_size; + } + auto result = std::make_shared( + message_list, meta_ptr->GetMessageBundleTs(), meta_ptr->GetLastMessageId(), + meta_ptr->GetBundleType()); + STREAMING_CHECK(byte_offset == result->ClassBytesSize()); + return result; +} + +void StreamingMessageBundle::GetMessageListFromRawData( + const uint8_t *bytes, uint32_t byte_size, uint32_t message_list_size, + std::list &message_list) { + uint32_t byte_offset = 0; + // only message bundle own raw data + for (size_t i = 0; i < message_list_size; ++i) { + StreamingMessagePtr item = StreamingMessage::FromBytes(bytes + byte_offset); + message_list.push_back(item); + byte_offset += item->ClassBytesSize(); + } + STREAMING_CHECK(byte_offset == byte_size); +} + +void StreamingMessageBundle::GetMessageList( + std::list &message_list) { + message_list = message_list_; +} + +void StreamingMessageBundle::ConvertMessageListToRawData( + const std::list &message_list, uint32_t raw_data_size, + uint8_t *raw_data) { + uint32_t byte_offset = 0; + for (auto &message : message_list) { + message->ToBytes(raw_data + byte_offset); + byte_offset += message->ClassBytesSize(); + } + STREAMING_CHECK(byte_offset == raw_data_size); +} + +bool StreamingMessageBundle::operator==(StreamingMessageBundle &bundle) const { + if (!(StreamingMessageBundleMeta::operator==(&bundle) && + this->GetRawBundleSize() == bundle.GetRawBundleSize() && + this->GetMessageListSize() == bundle.GetMessageListSize())) { + return false; + } + auto it1 = message_list_.begin(); + auto it2 = bundle.message_list_.begin(); + while (it1 != message_list_.end() && it2 != bundle.message_list_.end()) { + if (!((*it1).get()->operator==(*(*it2).get()))) { + return false; + } + it1++; + it2++; + } + return true; +} + +bool StreamingMessageBundle::operator==(StreamingMessageBundle *bundle) const { + return this->operator==(*bundle); +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/message/message_bundle.h b/streaming/src/message/message_bundle.h new file mode 100644 index 0000000000000..323197c0b66d0 --- /dev/null +++ b/streaming/src/message/message_bundle.h @@ -0,0 +1,164 @@ +#ifndef RAY_MESSAGE_BUNDLE_H +#define RAY_MESSAGE_BUNDLE_H + +#include +#include +#include + +#include "message.h" + +namespace ray { +namespace streaming { + +enum class StreamingMessageBundleType : uint32_t { + Empty = 1, + Barrier = 2, + Bundle = 3, + MIN = Empty, + MAX = Bundle +}; + +class StreamingMessageBundleMeta; +class StreamingMessageBundle; + +typedef std::shared_ptr StreamingMessageBundlePtr; +typedef std::shared_ptr StreamingMessageBundleMetaPtr; + +constexpr uint32_t kMessageBundleMetaHeaderSize = sizeof(uint32_t) + sizeof(uint32_t) + + sizeof(uint64_t) + sizeof(uint64_t) + + sizeof(StreamingMessageBundleType); + +constexpr uint32_t kMessageBundleHeaderSize = + kMessageBundleMetaHeaderSize + sizeof(uint32_t); + +class StreamingMessageBundleMeta { + public: + static const uint32_t StreamingMessageBundleMagicNum = 0xCAFEBABA; + + protected: + uint64_t message_bundle_ts_; + + uint64_t last_message_id_; + + uint32_t message_list_size_; + + StreamingMessageBundleType bundle_type_; + + public: + explicit StreamingMessageBundleMeta(uint64_t, uint64_t, uint32_t, + StreamingMessageBundleType); + + explicit StreamingMessageBundleMeta(StreamingMessageBundleMeta *); + + explicit StreamingMessageBundleMeta(); + + virtual ~StreamingMessageBundleMeta(){}; + + bool operator==(StreamingMessageBundleMeta &) const; + + bool operator==(StreamingMessageBundleMeta *) const; + + inline uint64_t GetMessageBundleTs() const { return message_bundle_ts_; } + + inline uint64_t GetLastMessageId() const { return last_message_id_; } + + inline uint32_t GetMessageListSize() const { return message_list_size_; } + + inline StreamingMessageBundleType GetBundleType() const { return bundle_type_; } + + inline bool IsBarrier() { return StreamingMessageBundleType::Barrier == bundle_type_; } + inline bool IsBundle() { return StreamingMessageBundleType::Bundle == bundle_type_; } + + virtual void ToBytes(uint8_t *data); + static StreamingMessageBundleMetaPtr FromBytes(const uint8_t *data, + bool verifer_check = true); + inline virtual uint32_t ClassBytesSize() { return kMessageBundleMetaHeaderSize; } + + std::string ToString() { + return std::to_string(last_message_id_) + "," + std::to_string(message_list_size_) + + "," + std::to_string(message_bundle_ts_) + "," + + std::to_string(static_cast(bundle_type_)); + } +}; + +/// StreamingMessageBundle inherits from metadata class (StreamingMessageBundleMeta) with +/// the following protocol: +/// MagicNum = 0xcafebaba +/// Timestamp 64bits timestamp (milliseconds from 1970) +/// LastMessageId( the last id of bundle) (0,INF] +/// MessageListSize(bundle len of message) +/// BundleType(a. bundle = 3 , b. barrier =2, c. empty = 1) +/// RawBundleSize(binary length of data) +/// RawData ( binary data) +/// +/// +--------------------+ +/// | MagicNum=U32 | +/// +--------------------+ +/// | BundleTs=U64 | +/// +--------------------+ +/// | LastMessageId=U64 | +/// +--------------------+ +/// | MessageListSize=U32| +/// +--------------------+ +/// | BundleType=U32 | +/// +--------------------+ +/// | RawBundleSize=U32 | +/// +--------------------+ +/// | RawData=var(N*Msg) | +/// +--------------------+ +/// It should be noted that StreamingMessageBundle and StreamingMessageBundleMeta share +/// almost same protocol but the last two fields (RawBundleSize and RawData). +class StreamingMessageBundle : public StreamingMessageBundleMeta { + private: + uint32_t raw_bundle_size_; + + // Lazy serlization/deserlization. + std::list message_list_; + + public: + explicit StreamingMessageBundle(std::list &&message_list, + uint64_t bundle_ts, uint64_t offset, + StreamingMessageBundleType bundle_type, + uint32_t raw_data_size = 0); + + // Duplicated copy if left reference in constructor. + explicit StreamingMessageBundle(std::list &message_list, + uint64_t bundle_ts, uint64_t offset, + StreamingMessageBundleType bundle_type, + uint32_t raw_data_size = 0); + + // New a empty bundle by passing last message id and timestamp. + explicit StreamingMessageBundle(uint64_t, uint64_t); + + explicit StreamingMessageBundle(StreamingMessageBundle &bundle); + + virtual ~StreamingMessageBundle() = default; + + inline uint32_t GetRawBundleSize() const { return raw_bundle_size_; } + + bool operator==(StreamingMessageBundle &bundle) const; + + bool operator==(StreamingMessageBundle *bundle_ptr) const; + + void GetMessageList(std::list &message_list); + + const std::list &GetMessageList() const { return message_list_; } + + virtual void ToBytes(uint8_t *data); + static StreamingMessageBundlePtr FromBytes(const uint8_t *data, + bool verifer_check = true); + inline virtual uint32_t ClassBytesSize() { + return kMessageBundleHeaderSize + raw_bundle_size_; + }; + + static void GetMessageListFromRawData(const uint8_t *bytes, uint32_t bytes_size, + uint32_t message_list_size, + std::list &message_list); + static void ConvertMessageListToRawData( + const std::list &message_list, uint32_t raw_data_size, + uint8_t *raw_data); +}; +} // namespace streaming +} // namespace ray + +#endif // RAY_MESSAGE_BUNDLE_H diff --git a/streaming/src/message/priority_queue.h b/streaming/src/message/priority_queue.h new file mode 100644 index 0000000000000..f49fad7c34677 --- /dev/null +++ b/streaming/src/message/priority_queue.h @@ -0,0 +1,53 @@ +#ifndef RAY_PRIORITY_QUEUE_H +#define RAY_PRIORITY_QUEUE_H + +#include +#include +#include +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +template + +class PriorityQueue { + private: + std::vector merge_vec_; + C comparator_; + + public: + PriorityQueue(C &comparator) : comparator_(comparator){}; + + inline void push(T &&item) { + merge_vec_.push_back(std::forward(item)); + std::push_heap(merge_vec_.begin(), merge_vec_.end(), comparator_); + } + + inline void push(const T &item) { + merge_vec_.push_back(item); + std::push_heap(merge_vec_.begin(), merge_vec_.end(), comparator_); + } + + inline void pop() { + STREAMING_CHECK(!isEmpty()); + std::pop_heap(merge_vec_.begin(), merge_vec_.end(), comparator_); + merge_vec_.pop_back(); + } + + inline void makeHeap() { + std::make_heap(merge_vec_.begin(), merge_vec_.end(), comparator_); + } + + inline T &top() { return merge_vec_.front(); } + + inline uint32_t size() { return merge_vec_.size(); } + + inline bool isEmpty() { return merge_vec_.empty(); } + + std::vector &getRawVector() { return merge_vec_; } +}; +} // namespace streaming +} // namespace ray + +#endif // RAY_PRIORITY_QUEUE_H diff --git a/streaming/src/protobuf/streaming.proto b/streaming/src/protobuf/streaming.proto new file mode 100644 index 0000000000000..2b4a9a4cd02af --- /dev/null +++ b/streaming/src/protobuf/streaming.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; + +package ray.streaming.proto; + +option java_package = "org.ray.streaming.runtime.generated"; + +enum OperatorType { + UNKNOWN = 0; + TRANSFORM = 1; + SOURCE = 2; + SINK = 3; +} + +// all string in this message is ASCII string +message StreamingConfig { + string job_name = 1; + string task_job_id = 2; + string worker_name = 3; + string op_name = 4; + OperatorType role = 5; + uint32 ring_buffer_capacity = 6; + uint32 empty_message_interval = 7; +} diff --git a/streaming/src/protobuf/streaming_queue.proto b/streaming/src/protobuf/streaming_queue.proto new file mode 100644 index 0000000000000..d0eea2c2c9bce --- /dev/null +++ b/streaming/src/protobuf/streaming_queue.proto @@ -0,0 +1,70 @@ +syntax = "proto3"; + +package ray.streaming.queue.protobuf; + +enum StreamingQueueMessageType { + StreamingQueueDataMsgType = 0; + StreamingQueueCheckMsgType = 1; + StreamingQueueCheckRspMsgType = 2; + StreamingQueueNotificationMsgType = 3; + StreamingQueueTestInitMsgType = 4; + StreamingQueueTestCheckStatusRspMsgType = 5; +} + +enum StreamingQueueError { + OK = 0; + QUEUE_NOT_EXIST = 1; + NO_VALID_DATA_TO_PULL = 2; +} + +message StreamingQueueDataMsg { + bytes src_actor_id = 1; + bytes dst_actor_id = 2; + bytes queue_id = 3; + uint64 seq_id = 4; + uint64 length = 5; + bool raw = 6; +} + +message StreamingQueueCheckMsg { + bytes src_actor_id = 1; + bytes dst_actor_id = 2; + bytes queue_id = 3; +} + +message StreamingQueueCheckRspMsg { + bytes src_actor_id = 1; + bytes dst_actor_id = 2; + bytes queue_id = 3; + StreamingQueueError err_code = 4; +} + +message StreamingQueueNotificationMsg { + bytes src_actor_id = 1; + bytes dst_actor_id = 2; + bytes queue_id = 3; + uint64 seq_id = 4; +} + +// for test +enum StreamingQueueTestRole { + WRITER = 0; + READER = 1; +} + +message StreamingQueueTestInitMsg { + StreamingQueueTestRole role = 1; + bytes src_actor_id = 2; + bytes dst_actor_id = 3; + bytes actor_handle = 4; + repeated bytes queue_ids = 5; + repeated bytes rescale_queue_ids = 6; + string test_suite_name = 7; + string test_name = 8; + uint64 param = 9; +} + +message StreamingQueueTestCheckStatusRspMsg { + string test_name = 1; + bool status = 2; +} \ No newline at end of file diff --git a/streaming/src/queue/message.cc b/streaming/src/queue/message.cc new file mode 100644 index 0000000000000..9c3f4fa4b61f0 --- /dev/null +++ b/streaming/src/queue/message.cc @@ -0,0 +1,240 @@ +#include "message.h" + +namespace ray { +namespace streaming { +const uint32_t Message::MagicNum = 0xBABA0510; + +std::unique_ptr Message::ToBytes() { + uint8_t *bytes = nullptr; + + std::string pboutput; + ToProtobuf(&pboutput); + int64_t fbs_length = pboutput.length(); + + queue::protobuf::StreamingQueueMessageType type = Type(); + size_t total_len = + sizeof(Message::MagicNum) + sizeof(type) + sizeof(fbs_length) + fbs_length; + if (buffer_ != nullptr) { + total_len += buffer_->Size(); + } + bytes = new uint8_t[total_len]; + STREAMING_CHECK(bytes != nullptr) << "allocate bytes fail."; + + uint8_t *p_cur = bytes; + memcpy(p_cur, &Message::MagicNum, sizeof(Message::MagicNum)); + + p_cur += sizeof(Message::MagicNum); + memcpy(p_cur, &type, sizeof(type)); + + p_cur += sizeof(type); + memcpy(p_cur, &fbs_length, sizeof(fbs_length)); + + p_cur += sizeof(fbs_length); + uint8_t *fbs_bytes = (uint8_t *)pboutput.data(); + memcpy(p_cur, fbs_bytes, fbs_length); + p_cur += fbs_length; + + if (buffer_ != nullptr) { + memcpy(p_cur, buffer_->Data(), buffer_->Size()); + } + + // COPY + std::unique_ptr buffer = + std::unique_ptr(new LocalMemoryBuffer(bytes, total_len, true)); + delete bytes; + return buffer; +} + +void DataMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueDataMsg msg; + msg.set_src_actor_id(actor_id_.Binary()); + msg.set_dst_actor_id(peer_actor_id_.Binary()); + msg.set_queue_id(queue_id_.Binary()); + msg.set_seq_id(seq_id_); + msg.set_length(buffer_->Size()); + msg.set_raw(raw_); + msg.SerializeToString(output); +} + +std::shared_ptr DataMessage::FromBytes(uint8_t *bytes) { + bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); + uint64_t *fbs_length = (uint64_t *)bytes; + bytes += sizeof(uint64_t); + + std::string inputpb(reinterpret_cast(bytes), *fbs_length); + queue::protobuf::StreamingQueueDataMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.queue_id()); + uint64_t seq_id = message.seq_id(); + uint64_t length = message.length(); + bool raw = message.raw(); + bytes += *fbs_length; + + /// Copy data and create a new buffer for streaming queue. + std::shared_ptr buffer = + std::make_shared(bytes, (size_t)length, true); + std::shared_ptr data_msg = std::make_shared( + src_actor_id, dst_actor_id, queue_id, seq_id, buffer, raw); + + return data_msg; +} + +void NotificationMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueNotificationMsg msg; + msg.set_src_actor_id(actor_id_.Binary()); + msg.set_dst_actor_id(peer_actor_id_.Binary()); + msg.set_queue_id(queue_id_.Binary()); + msg.set_seq_id(seq_id_); + msg.SerializeToString(output); +} + +std::shared_ptr NotificationMessage::FromBytes(uint8_t *bytes) { + bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); + uint64_t *length = (uint64_t *)bytes; + bytes += sizeof(uint64_t); + + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueueNotificationMsg message; + message.ParseFromString(inputpb); + STREAMING_LOG(INFO) << "message.src_actor_id: " << message.src_actor_id(); + ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.queue_id()); + uint64_t seq_id = message.seq_id(); + + std::shared_ptr notify_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id, seq_id); + + return notify_msg; +} + +void CheckMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueCheckMsg msg; + msg.set_src_actor_id(actor_id_.Binary()); + msg.set_dst_actor_id(peer_actor_id_.Binary()); + msg.set_queue_id(queue_id_.Binary()); + msg.SerializeToString(output); +} + +std::shared_ptr CheckMessage::FromBytes(uint8_t *bytes) { + bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); + uint64_t *length = (uint64_t *)bytes; + bytes += sizeof(uint64_t); + + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueueCheckMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.queue_id()); + + std::shared_ptr check_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id); + + return check_msg; +} + +void CheckRspMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueCheckRspMsg msg; + msg.set_src_actor_id(actor_id_.Binary()); + msg.set_dst_actor_id(peer_actor_id_.Binary()); + msg.set_queue_id(queue_id_.Binary()); + msg.set_err_code(err_code_); + msg.SerializeToString(output); +} + +std::shared_ptr CheckRspMessage::FromBytes(uint8_t *bytes) { + bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); + uint64_t *length = (uint64_t *)bytes; + bytes += sizeof(uint64_t); + + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueueCheckRspMsg message; + message.ParseFromString(inputpb); + ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); + ObjectID queue_id = ObjectID::FromBinary(message.queue_id()); + queue::protobuf::StreamingQueueError err_code = message.err_code(); + + std::shared_ptr check_rsp_msg = + std::make_shared(src_actor_id, dst_actor_id, queue_id, err_code); + + return check_rsp_msg; +} + +void TestInitMessage::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueTestInitMsg msg; + msg.set_role(role_); + msg.set_src_actor_id(actor_id_.Binary()); + msg.set_dst_actor_id(peer_actor_id_.Binary()); + msg.set_actor_handle(actor_handle_serialized_); + for (auto &queue_id : queue_ids_) { + msg.add_queue_ids(queue_id.Binary()); + } + for (auto &queue_id : rescale_queue_ids_) { + msg.add_rescale_queue_ids(queue_id.Binary()); + } + msg.set_test_suite_name(test_suite_name_); + msg.set_test_name(test_name_); + msg.set_param(param_); + msg.SerializeToString(output); +} + +std::shared_ptr TestInitMessage::FromBytes(uint8_t *bytes) { + bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); + uint64_t *length = (uint64_t *)bytes; + bytes += sizeof(uint64_t); + + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueueTestInitMsg message; + message.ParseFromString(inputpb); + queue::protobuf::StreamingQueueTestRole role = message.role(); + ActorID src_actor_id = ActorID::FromBinary(message.src_actor_id()); + ActorID dst_actor_id = ActorID::FromBinary(message.dst_actor_id()); + std::string actor_handle_serialized = message.actor_handle(); + std::vector queue_ids; + for (int i = 0; i < message.queue_ids_size(); i++) { + queue_ids.push_back(ObjectID::FromBinary(message.queue_ids(i))); + } + std::vector rescale_queue_ids; + for (int i = 0; i < message.rescale_queue_ids_size(); i++) { + rescale_queue_ids.push_back(ObjectID::FromBinary(message.rescale_queue_ids(i))); + } + std::string test_suite_name = message.test_suite_name(); + std::string test_name = message.test_name(); + uint64_t param = message.param(); + + std::shared_ptr test_init_msg = std::make_shared( + role, src_actor_id, dst_actor_id, actor_handle_serialized, queue_ids, + rescale_queue_ids, test_suite_name, test_name, param); + + return test_init_msg; +} + +void TestCheckStatusRspMsg::ToProtobuf(std::string *output) { + queue::protobuf::StreamingQueueTestCheckStatusRspMsg msg; + msg.set_test_name(test_name_); + msg.set_status(status_); + msg.SerializeToString(output); +} + +std::shared_ptr TestCheckStatusRspMsg::FromBytes(uint8_t *bytes) { + bytes += sizeof(uint32_t) + sizeof(queue::protobuf::StreamingQueueMessageType); + uint64_t *length = (uint64_t *)bytes; + bytes += sizeof(uint64_t); + + std::string inputpb(reinterpret_cast(bytes), *length); + queue::protobuf::StreamingQueueTestCheckStatusRspMsg message; + message.ParseFromString(inputpb); + std::string test_name = message.test_name(); + bool status = message.status(); + + std::shared_ptr test_check_msg = + std::make_shared(test_name, status); + + return test_check_msg; +} +} // namespace streaming +} // namespace ray \ No newline at end of file diff --git a/streaming/src/queue/message.h b/streaming/src/queue/message.h new file mode 100644 index 0000000000000..795d2393e50a6 --- /dev/null +++ b/streaming/src/queue/message.h @@ -0,0 +1,235 @@ +#ifndef _STREAMING_QUEUE_MESSAGE_H_ +#define _STREAMING_QUEUE_MESSAGE_H_ + +#include "protobuf/streaming_queue.pb.h" +#include "ray/common/buffer.h" +#include "ray/common/id.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +/// Base class of all message classes. +/// All payloads transferred through direct actor call are packed into a unified package, +/// consisting of protobuf-formatted metadata and data, including data and control +/// messages. These message classes wrap the package defined in +/// protobuf/streaming_queue.proto respectively. +class Message { + public: + /// Construct a Message instance. + /// \param[in] actor_id ActorID of message sender. + /// \param[in] peer_actor_id ActorID of message receiver. + /// \param[in] queue_id queue id to identify which queue the message is sent to. + /// \param[in] buffer an optional param, a chunk of data to send. + Message(const ActorID &actor_id, const ActorID &peer_actor_id, const ObjectID &queue_id, + std::shared_ptr buffer = nullptr) + : actor_id_(actor_id), + peer_actor_id_(peer_actor_id), + queue_id_(queue_id), + buffer_(buffer) {} + Message() {} + virtual ~Message() {} + ActorID ActorId() { return actor_id_; } + ActorID PeerActorId() { return peer_actor_id_; } + ObjectID QueueId() { return queue_id_; } + std::shared_ptr Buffer() { return buffer_; } + + /// Serialize all meta data and data to a LocalMemoryBuffer, which can be sent through + /// direct actor call. \return serialized buffer . + std::unique_ptr ToBytes(); + + /// Get message type. + /// \return message type. + virtual queue::protobuf::StreamingQueueMessageType Type() = 0; + + /// All subclasses should implement `ToProtobuf` to serialize its own protobuf data. + virtual void ToProtobuf(std::string *output) = 0; + + protected: + ActorID actor_id_; + ActorID peer_actor_id_; + ObjectID queue_id_; + std::shared_ptr buffer_; + + public: + /// A magic number to identify a valid message. + static const uint32_t MagicNum; +}; + +/// Wrap StreamingQueueDataMsg in streaming_queue.proto. +/// DataMessage encapsulates the memory buffer of QueueItem, a one-to-one relationship +/// exists between DataMessage and QueueItem. +class DataMessage : public Message { + public: + DataMessage(const ActorID &actor_id, const ActorID &peer_actor_id, ObjectID queue_id, + uint64_t seq_id, std::shared_ptr buffer, bool raw) + : Message(actor_id, peer_actor_id, queue_id, buffer), seq_id_(seq_id), raw_(raw) {} + virtual ~DataMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + uint64_t SeqId() { return seq_id_; } + bool IsRaw() { return raw_; } + queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + uint64_t seq_id_; + bool raw_; + + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType; +}; + +/// Wrap StreamingQueueNotificationMsg in streaming_queue.proto. +/// NotificationMessage, downstream queues sends to upstream queues, for the data reader +/// to inform the data writer of the consumed offset. +class NotificationMessage : public Message { + public: + NotificationMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id, uint64_t seq_id) + : Message(actor_id, peer_actor_id, queue_id), seq_id_(seq_id) {} + + virtual ~NotificationMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + + uint64_t SeqId() { return seq_id_; } + queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + uint64_t seq_id_; + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType; +}; + +/// Wrap StreamingQueueCheckMsg in streaming_queue.proto. +/// CheckMessage, upstream queues sends to downstream queues, fot the data writer to check +/// whether the corresponded downstream queue is read or not. +class CheckMessage : public Message { + public: + CheckMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id) + : Message(actor_id, peer_actor_id, queue_id) {} + virtual ~CheckMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + + queue::protobuf::StreamingQueueMessageType Type() { return type_; } + + private: + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType; +}; + +/// Wrap StreamingQueueCheckRspMsg in streaming_queue.proto. +/// CheckRspMessage, downstream queues sends to upstream queues, the response message to +/// CheckMessage to indicate whether downstream queue is ready or not. +class CheckRspMessage : public Message { + public: + CheckRspMessage(const ActorID &actor_id, const ActorID &peer_actor_id, + const ObjectID &queue_id, queue::protobuf::StreamingQueueError err_code) + : Message(actor_id, peer_actor_id, queue_id), err_code_(err_code) {} + virtual ~CheckRspMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + queue::protobuf::StreamingQueueMessageType Type() { return type_; } + queue::protobuf::StreamingQueueError Error() { return err_code_; } + + private: + queue::protobuf::StreamingQueueError err_code_; + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType; +}; + +/// Wrap StreamingQueueTestInitMsg in streaming_queue.proto. +/// TestInitMessage, used for test, driver sends to test workers to init test suite. +class TestInitMessage : public Message { + public: + TestInitMessage(const queue::protobuf::StreamingQueueTestRole role, + const ActorID &actor_id, const ActorID &peer_actor_id, + const std::string actor_handle_serialized, + const std::vector &queue_ids, + const std::vector &rescale_queue_ids, + std::string test_suite_name, std::string test_name, uint64_t param) + : Message(actor_id, peer_actor_id, queue_ids[0]), + actor_handle_serialized_(actor_handle_serialized), + queue_ids_(queue_ids), + rescale_queue_ids_(rescale_queue_ids), + role_(role), + test_suite_name_(test_suite_name), + test_name_(test_name), + param_(param) {} + virtual ~TestInitMessage() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + queue::protobuf::StreamingQueueMessageType Type() { return type_; } + std::string ActorHandleSerialized() { return actor_handle_serialized_; } + queue::protobuf::StreamingQueueTestRole Role() { return role_; } + std::vector QueueIds() { return queue_ids_; } + std::vector RescaleQueueIds() { return rescale_queue_ids_; } + std::string TestSuiteName() { return test_suite_name_; } + std::string TestName() { return test_name_; } + uint64_t Param() { return param_; } + + std::string ToString() { + std::ostringstream os; + os << "actor_handle_serialized: " << actor_handle_serialized_; + os << " actor_id: " << ActorId(); + os << " peer_actor_id: " << PeerActorId(); + os << " queue_ids:["; + for (auto &qid : queue_ids_) { + os << qid << ","; + } + os << "], rescale_queue_ids:["; + for (auto &qid : rescale_queue_ids_) { + os << qid << ","; + } + os << "],"; + os << " role:" << queue::protobuf::StreamingQueueTestRole_Name(role_); + os << " suite_name: " << test_suite_name_; + os << " test_name: " << test_name_; + os << " param: " << param_; + return os.str(); + } + + private: + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueTestInitMsgType; + std::string actor_handle_serialized_; + std::vector queue_ids_; + std::vector rescale_queue_ids_; + queue::protobuf::StreamingQueueTestRole role_; + std::string test_suite_name_; + std::string test_name_; + uint64_t param_; +}; + +/// Wrap StreamingQueueTestCheckStatusRspMsg in streaming_queue.proto. +/// TestCheckStatusRspMsg, used for test, driver sends to test workers to check +/// whether test has completed or failed. +class TestCheckStatusRspMsg : public Message { + public: + TestCheckStatusRspMsg(const std::string test_name, bool status) + : test_name_(test_name), status_(status) {} + virtual ~TestCheckStatusRspMsg() {} + + static std::shared_ptr FromBytes(uint8_t *bytes); + virtual void ToProtobuf(std::string *output); + queue::protobuf::StreamingQueueMessageType Type() { return type_; } + std::string TestName() { return test_name_; } + bool Status() { return status_; } + + private: + const queue::protobuf::StreamingQueueMessageType type_ = + queue::protobuf::StreamingQueueMessageType::StreamingQueueTestCheckStatusRspMsgType; + std::string test_name_; + bool status_; +}; + +} // namespace streaming +} // namespace ray +#endif diff --git a/streaming/src/queue/queue.cc b/streaming/src/queue/queue.cc new file mode 100644 index 0000000000000..d2a9814eb4231 --- /dev/null +++ b/streaming/src/queue/queue.cc @@ -0,0 +1,211 @@ +#include "queue.h" +#include +#include +#include "queue_handler.h" +#include "util/streaming_util.h" + +namespace ray { +namespace streaming { + +bool Queue::Push(QueueItem item) { + std::unique_lock lock(mutex_); + if (max_data_size_ < item.DataSize() + data_size_) return false; + + buffer_queue_.push_back(item); + data_size_ += item.DataSize(); + readable_cv_.notify_one(); + return true; +} + +QueueItem Queue::FrontProcessed() { + std::unique_lock lock(mutex_); + STREAMING_CHECK(buffer_queue_.size() != 0) << "WriterQueue Pop fail"; + + if (watershed_iter_ == buffer_queue_.begin()) { + return InvalidQueueItem(); + } + + QueueItem item = buffer_queue_.front(); + return item; +} + +QueueItem Queue::PopProcessed() { + std::unique_lock lock(mutex_); + STREAMING_CHECK(buffer_queue_.size() != 0) << "WriterQueue Pop fail"; + + if (watershed_iter_ == buffer_queue_.begin()) { + return InvalidQueueItem(); + } + + QueueItem item = buffer_queue_.front(); + buffer_queue_.pop_front(); + data_size_ -= item.DataSize(); + data_size_sent_ -= item.DataSize(); + return item; +} + +QueueItem Queue::PopPending() { + std::unique_lock lock(mutex_); + auto it = std::next(watershed_iter_); + QueueItem item = *it; + data_size_sent_ += it->DataSize(); + buffer_queue_.splice(watershed_iter_, buffer_queue_, it, std::next(it)); + return item; +} + +QueueItem Queue::PopPendingBlockTimeout(uint64_t timeout_us) { + std::unique_lock lock(mutex_); + std::chrono::system_clock::time_point point = + std::chrono::system_clock::now() + std::chrono::microseconds(timeout_us); + if (readable_cv_.wait_until(lock, point, [this] { + return std::next(watershed_iter_) != buffer_queue_.end(); + })) { + auto it = std::next(watershed_iter_); + QueueItem item = *it; + data_size_sent_ += it->DataSize(); + buffer_queue_.splice(watershed_iter_, buffer_queue_, it, std::next(it)); + return item; + + } else { + uint8_t data[1]; + return QueueItem(QUEUE_INVALID_SEQ_ID, data, 1, 0, true); + } +} + +QueueItem Queue::BackPending() { + std::unique_lock lock(mutex_); + if (std::next(watershed_iter_) == buffer_queue_.end()) { + uint8_t data[1]; + return QueueItem(QUEUE_INVALID_SEQ_ID, data, 1, 0, true); + } + return buffer_queue_.back(); +} + +bool Queue::IsPendingEmpty() { + std::unique_lock lock(mutex_); + return std::next(watershed_iter_) == buffer_queue_.end(); +} + +bool Queue::IsPendingFull(uint64_t data_size) { + std::unique_lock lock(mutex_); + return max_data_size_ < data_size + data_size_; +} + +size_t Queue::ProcessedCount() { + std::unique_lock lock(mutex_); + if (watershed_iter_ == buffer_queue_.begin()) return 0; + + auto begin = buffer_queue_.begin(); + auto end = std::prev(watershed_iter_); + + return end->SeqId() + 1 - begin->SeqId(); +} + +size_t Queue::PendingCount() { + std::unique_lock lock(mutex_); + if (std::next(watershed_iter_) == buffer_queue_.end()) return 0; + + auto begin = std::next(watershed_iter_); + auto end = std::prev(buffer_queue_.end()); + + return begin->SeqId() - end->SeqId() + 1; +} + +Status WriterQueue::Push(uint64_t seq_id, uint8_t *data, uint32_t data_size, + uint64_t timestamp, bool raw) { + if (IsPendingFull(data_size)) { + return Status::OutOfMemory("Queue Push OutOfMemory"); + } + + while (is_pulling_) { + STREAMING_LOG(INFO) << "This queue is sending pull data, wait."; + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + QueueItem item(seq_id, data, data_size, timestamp, raw); + Queue::Push(item); + return Status::OK(); +} + +void WriterQueue::Send() { + while (!IsPendingEmpty()) { + // FIXME: front -> send -> pop + QueueItem item = PopPending(); + DataMessage msg(actor_id_, peer_actor_id_, queue_id_, item.SeqId(), item.Buffer(), + item.IsRaw()); + std::unique_ptr buffer = msg.ToBytes(); + STREAMING_CHECK(transport_ != nullptr); + transport_->Send(std::move(buffer), + DownstreamQueueMessageHandler::peer_async_function_); + } +} + +Status WriterQueue::TryEvictItems() { + STREAMING_LOG(INFO) << "TryEvictItems"; + QueueItem item = FrontProcessed(); + uint64_t first_seq_id = item.SeqId(); + STREAMING_LOG(INFO) << "TryEvictItems first_seq_id: " << first_seq_id + << " min_consumed_id_: " << min_consumed_id_ + << " eviction_limit_: " << eviction_limit_; + if (min_consumed_id_ == QUEUE_INVALID_SEQ_ID || first_seq_id > min_consumed_id_) { + return Status::OutOfMemory("The queue is full and some reader doesn't consume"); + } + + if (eviction_limit_ == QUEUE_INVALID_SEQ_ID || first_seq_id > eviction_limit_) { + return Status::OutOfMemory("The queue is full and eviction limit block evict"); + } + + uint64_t evict_target_seq_id = std::min(min_consumed_id_, eviction_limit_); + + while (item.SeqId() <= evict_target_seq_id) { + PopProcessed(); + STREAMING_LOG(INFO) << "TryEvictItems directly " << item.SeqId(); + item = FrontProcessed(); + } + return Status::OK(); +} + +void WriterQueue::OnNotify(std::shared_ptr notify_msg) { + STREAMING_LOG(INFO) << "OnNotify target seq_id: " << notify_msg->SeqId(); + min_consumed_id_ = notify_msg->SeqId(); +} + +void ReaderQueue::OnConsumed(uint64_t seq_id) { + STREAMING_LOG(INFO) << "OnConsumed: " << seq_id; + QueueItem item = FrontProcessed(); + while (item.SeqId() <= seq_id) { + PopProcessed(); + item = FrontProcessed(); + } + Notify(seq_id); +} + +void ReaderQueue::Notify(uint64_t seq_id) { + std::vector task_args; + CreateNotifyTask(seq_id, task_args); + // SubmitActorTask + + NotificationMessage msg(actor_id_, peer_actor_id_, queue_id_, seq_id); + std::unique_ptr buffer = msg.ToBytes(); + + transport_->Send(std::move(buffer), UpstreamQueueMessageHandler::peer_async_function_); +} + +void ReaderQueue::CreateNotifyTask(uint64_t seq_id, std::vector &task_args) {} + +void ReaderQueue::OnData(QueueItem &item) { + if (item.SeqId() != expect_seq_id_) { + STREAMING_LOG(WARNING) << "OnData ignore seq_id: " << item.SeqId() + << " expect_seq_id_: " << expect_seq_id_; + return; + } + + last_recv_seq_id_ = item.SeqId(); + STREAMING_LOG(DEBUG) << "ReaderQueue::OnData seq_id: " << last_recv_seq_id_; + + Push(item); + expect_seq_id_++; +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/queue.h b/streaming/src/queue/queue.h new file mode 100644 index 0000000000000..e6b09771146e7 --- /dev/null +++ b/streaming/src/queue/queue.h @@ -0,0 +1,213 @@ +#ifndef _STREAMING_QUEUE_H_ +#define _STREAMING_QUEUE_H_ +#include +#include +#include + +#include "ray/common/id.h" +#include "ray/util/util.h" + +#include "queue_item.h" +#include "transport.h" +#include "util/streaming_logging.h" +#include "utils.h" + +namespace ray { +namespace streaming { + +using ray::ObjectID; + +enum QueueType { UPSTREAM = 0, DOWNSTREAM }; + +/// A queue-like data structure, which does not delete its items after poped. +/// The lifecycle of each item is: +/// - Pending, an item is pushed into a queue, but has not been processed (sent out or +/// consumed), +/// - Processed, has been handled by the user, but should not be deleted. +/// - Evicted, useless to the user, should be poped and destroyed. +/// At present, this data structure is implemented with one std::list, +/// using a watershed iterator to divided. +class Queue { + public: + /// \param[in] queue_id the unique identification of a pair of queues (upstream and + /// downstream). \param[in] size max size of the queue in bytes. \param[in] transport + /// transport to send items to peer. + Queue(ObjectID queue_id, uint64_t size, std::shared_ptr transport) + : queue_id_(queue_id), max_data_size_(size), data_size_(0), data_size_sent_(0) { + buffer_queue_.push_back(InvalidQueueItem()); + watershed_iter_ = buffer_queue_.begin(); + } + + virtual ~Queue() {} + + /// Push an item into the queue. + /// \param[in] item the QueueItem object to be send to peer. + /// \return false if the queue is full. + bool Push(QueueItem item); + + /// Get the front of item which in processed state. + QueueItem FrontProcessed(); + + /// Pop the front of item which in processed state. + QueueItem PopProcessed(); + + /// Pop the front of item which in pending state, the item + /// will not be evicted at this moment, its state turn to + /// processed. + QueueItem PopPending(); + + /// PopPending with timeout in microseconds. + QueueItem PopPendingBlockTimeout(uint64_t timeout_us); + + /// Return the last item in pending state. + QueueItem BackPending(); + + bool IsPendingEmpty(); + bool IsPendingFull(uint64_t data_size = 0); + + /// Return the size in bytes of all items in queue. + uint64_t QueueSize() { return data_size_; } + + /// Return the size in bytes of all items in pending state. + uint64_t PendingDataSize() { return data_size_ - data_size_sent_; } + + /// Return the size in bytes of all items in processed state. + uint64_t ProcessedDataSize() { return data_size_sent_; } + + /// Return item count of the queue. + size_t Count() { return buffer_queue_.size(); } + + /// Return item count in pending state. + size_t PendingCount(); + + /// Return item count in processed state. + size_t ProcessedCount(); + + protected: + ObjectID queue_id_; + std::list buffer_queue_; + std::list::iterator watershed_iter_; + + /// max data size in bytes + uint64_t max_data_size_; + uint64_t data_size_; + uint64_t data_size_sent_; + + std::mutex mutex_; + std::condition_variable readable_cv_; +}; + +/// Queue in upstream. +class WriterQueue : public Queue { + public: + /// \param queue_id, the unique ObjectID to identify a queue + /// \param actor_id, the actor id of upstream worker + /// \param peer_actor_id, the actor id of downstream worker + /// \param size, max data size in bytes + /// \param transport, transport + WriterQueue(const ObjectID &queue_id, const ActorID &actor_id, + const ActorID &peer_actor_id, uint64_t size, + std::shared_ptr transport) + : Queue(queue_id, size, transport), + actor_id_(actor_id), + peer_actor_id_(peer_actor_id), + eviction_limit_(QUEUE_INVALID_SEQ_ID), + min_consumed_id_(QUEUE_INVALID_SEQ_ID), + peer_last_msg_id_(0), + peer_last_seq_id_(QUEUE_INVALID_SEQ_ID), + transport_(transport), + is_pulling_(false) {} + + /// Push a continuous buffer into queue. + /// NOTE: the buffer should be copied. + Status Push(uint64_t seq_id, uint8_t *data, uint32_t data_size, uint64_t timestamp, + bool raw = false); + + /// Callback function, will be called when downstream queue notifies + /// it has consumed some items. + /// NOTE: this callback function is called in queue thread. + void OnNotify(std::shared_ptr notify_msg); + + /// Send items through direct call. + void Send(); + + /// Called when user pushs item into queue. The count of items + /// can be evicted, determined by eviction_limit_ and min_consumed_id_. + Status TryEvictItems(); + + void SetQueueEvictionLimit(uint64_t eviction_limit) { + eviction_limit_ = eviction_limit; + } + + uint64_t EvictionLimit() { return eviction_limit_; } + + uint64_t GetMinConsumedSeqID() { return min_consumed_id_; } + + void SetPeerLastIds(uint64_t msg_id, uint64_t seq_id) { + peer_last_msg_id_ = msg_id; + peer_last_seq_id_ = seq_id; + } + + uint64_t GetPeerLastMsgId() { return peer_last_msg_id_; } + + uint64_t GetPeerLastSeqId() { return peer_last_seq_id_; } + + private: + ActorID actor_id_; + ActorID peer_actor_id_; + uint64_t eviction_limit_; + uint64_t min_consumed_id_; + uint64_t peer_last_msg_id_; + uint64_t peer_last_seq_id_; + std::shared_ptr transport_; + + std::atomic is_pulling_; +}; + +/// Queue in downstream. +class ReaderQueue : public Queue { + public: + /// \param queue_id, the unique ObjectID to identify a queue + /// \param actor_id, the actor id of upstream worker + /// \param peer_actor_id, the actor id of downstream worker + /// \param transport, transport + /// NOTE: we do not restrict queue size of ReaderQueue + ReaderQueue(const ObjectID &queue_id, const ActorID &actor_id, + const ActorID &peer_actor_id, std::shared_ptr transport) + : Queue(queue_id, std::numeric_limits::max(), transport), + actor_id_(actor_id), + peer_actor_id_(peer_actor_id), + min_consumed_id_(QUEUE_INVALID_SEQ_ID), + last_recv_seq_id_(QUEUE_INVALID_SEQ_ID), + expect_seq_id_(1), + transport_(transport) {} + + /// Delete processed items whose seq id <= seq_id, + /// then notify upstream queue. + void OnConsumed(uint64_t seq_id); + + void OnData(QueueItem &item); + + uint64_t GetMinConsumedSeqID() { return min_consumed_id_; } + + uint64_t GetLastRecvSeqId() { return last_recv_seq_id_; } + + void SetExpectSeqId(uint64_t expect) { expect_seq_id_ = expect; } + + private: + void Notify(uint64_t seq_id); + void CreateNotifyTask(uint64_t seq_id, std::vector &task_args); + + private: + ActorID actor_id_; + ActorID peer_actor_id_; + uint64_t min_consumed_id_; + uint64_t last_recv_seq_id_; + uint64_t expect_seq_id_; + std::shared_ptr promise_for_pull_; + std::shared_ptr transport_; +}; + +} // namespace streaming +} // namespace ray +#endif diff --git a/streaming/src/queue/queue_client.cc b/streaming/src/queue/queue_client.cc new file mode 100644 index 0000000000000..b0518066a0882 --- /dev/null +++ b/streaming/src/queue/queue_client.cc @@ -0,0 +1,25 @@ +#include "queue_client.h" + +namespace ray { +namespace streaming { + +void WriterClient::OnWriterMessage(std::shared_ptr buffer) { + upstream_handler_->DispatchMessageAsync(buffer); +} + +std::shared_ptr WriterClient::OnWriterMessageSync( + std::shared_ptr buffer) { + return upstream_handler_->DispatchMessageSync(buffer); +} + +void ReaderClient::OnReaderMessage(std::shared_ptr buffer) { + downstream_handler_->DispatchMessageAsync(buffer); +} + +std::shared_ptr ReaderClient::OnReaderMessageSync( + std::shared_ptr buffer) { + return downstream_handler_->DispatchMessageSync(buffer); +} + +} // namespace streaming +} // namespace ray \ No newline at end of file diff --git a/streaming/src/queue/queue_client.h b/streaming/src/queue/queue_client.h new file mode 100644 index 0000000000000..a7d5171ca5c20 --- /dev/null +++ b/streaming/src/queue/queue_client.h @@ -0,0 +1,62 @@ +#ifndef _STREAMING_QUEUE_CLIENT_H_ +#define _STREAMING_QUEUE_CLIENT_H_ +#include "queue_handler.h" +#include "transport.h" + +namespace ray { +namespace streaming { + +/// The interface of the streaming queue for DataReader. +/// A ReaderClient should be created before DataReader created in Cython/Jni, and hold by +/// Jobworker. When DataReader receive a buffer from upstream DataWriter (DataReader's +/// raycall function is called), it calls `OnReaderMessage` to pass the buffer to its own +/// downstream queue, or `OnReaderMessageSync` to wait for handle result. +class ReaderClient { + public: + /// Construct a ReaderClient object. + /// \param[in] core_worker CoreWorker C++ pointer of current actor + /// \param[in] async_func DataReader's raycall function descriptor to be called by + /// DataWriter, asynchronous semantics \param[in] sync_func DataReader's raycall + /// function descriptor to be called by DataWriter, synchronous semantics + ReaderClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func) + : core_worker_(core_worker) { + DownstreamQueueMessageHandler::peer_async_function_ = async_func; + DownstreamQueueMessageHandler::peer_sync_function_ = sync_func; + downstream_handler_ = ray::streaming::DownstreamQueueMessageHandler::CreateService( + core_worker_, core_worker_->GetWorkerContext().GetCurrentActorID()); + } + + /// Post buffer to downstream queue service, asynchronously. + void OnReaderMessage(std::shared_ptr buffer); + /// Post buffer to downstream queue service, synchronously. + /// \return handle result. + std::shared_ptr OnReaderMessageSync( + std::shared_ptr buffer); + + private: + CoreWorker *core_worker_; + std::shared_ptr downstream_handler_; +}; + +/// Interface of streaming queue for DataWriter. Similar to ReaderClient. +class WriterClient { + public: + WriterClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func) + : core_worker_(core_worker) { + UpstreamQueueMessageHandler::peer_async_function_ = async_func; + UpstreamQueueMessageHandler::peer_sync_function_ = sync_func; + upstream_handler_ = ray::streaming::UpstreamQueueMessageHandler::CreateService( + core_worker, core_worker_->GetWorkerContext().GetCurrentActorID()); + } + + void OnWriterMessage(std::shared_ptr buffer); + std::shared_ptr OnWriterMessageSync( + std::shared_ptr buffer); + + private: + CoreWorker *core_worker_; + std::shared_ptr upstream_handler_; +}; +} // namespace streaming +} // namespace ray +#endif \ No newline at end of file diff --git a/streaming/src/queue/queue_handler.cc b/streaming/src/queue/queue_handler.cc new file mode 100644 index 0000000000000..d3e6cdbb7baa9 --- /dev/null +++ b/streaming/src/queue/queue_handler.cc @@ -0,0 +1,358 @@ +#include "queue_handler.h" +#include "util/streaming_util.h" +#include "utils.h" + +namespace ray { +namespace streaming { + +constexpr uint64_t COMMON_SYNC_CALL_TIMEOUTT_MS = 5 * 1000; + +std::shared_ptr + UpstreamQueueMessageHandler::upstream_handler_ = nullptr; +std::shared_ptr + DownstreamQueueMessageHandler::downstream_handler_ = nullptr; + +RayFunction UpstreamQueueMessageHandler::peer_sync_function_; +RayFunction UpstreamQueueMessageHandler::peer_async_function_; +RayFunction DownstreamQueueMessageHandler::peer_sync_function_; +RayFunction DownstreamQueueMessageHandler::peer_async_function_; + +std::shared_ptr QueueMessageHandler::ParseMessage( + std::shared_ptr buffer) { + uint8_t *bytes = buffer->Data(); + uint8_t *p_cur = bytes; + uint32_t *magic_num = (uint32_t *)p_cur; + STREAMING_CHECK(*magic_num == Message::MagicNum) + << *magic_num << " " << Message::MagicNum; + + p_cur += sizeof(Message::MagicNum); + queue::protobuf::StreamingQueueMessageType *type = + (queue::protobuf::StreamingQueueMessageType *)p_cur; + + std::shared_ptr message = nullptr; + switch (*type) { + case queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType: + message = NotificationMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType: + message = DataMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType: + message = CheckMessage::FromBytes(bytes); + break; + case queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType: + message = CheckRspMessage::FromBytes(bytes); + break; + default: + STREAMING_CHECK(false) << "nonsupport message type: " + << queue::protobuf::StreamingQueueMessageType_Name(*type); + break; + } + + return message; +} + +void QueueMessageHandler::DispatchMessageAsync( + std::shared_ptr buffer) { + queue_service_.post( + boost::bind(&QueueMessageHandler::DispatchMessageInternal, this, buffer, nullptr)); +} + +std::shared_ptr QueueMessageHandler::DispatchMessageSync( + std::shared_ptr buffer) { + std::shared_ptr result = nullptr; + std::shared_ptr promise = std::make_shared(); + queue_service_.post( + boost::bind(&QueueMessageHandler::DispatchMessageInternal, this, buffer, + [&promise, &result](std::shared_ptr rst) { + result = rst; + promise->Notify(ray::Status::OK()); + })); + Status st = promise->Wait(); + STREAMING_CHECK(st.ok()); + + return result; +} + +std::shared_ptr QueueMessageHandler::GetOutTransport( + const ObjectID &queue_id) { + auto it = out_transports_.find(queue_id); + if (it == out_transports_.end()) return nullptr; + + return it->second; +} + +void QueueMessageHandler::SetPeerActorID(const ObjectID &queue_id, + const ActorID &actor_id) { + actors_.emplace(queue_id, actor_id); + out_transports_.emplace( + queue_id, std::make_shared(core_worker_, actor_id)); +} + +ActorID QueueMessageHandler::GetPeerActorID(const ObjectID &queue_id) { + auto it = actors_.find(queue_id); + STREAMING_CHECK(it != actors_.end()); + return it->second; +} + +void QueueMessageHandler::Release() { + actors_.clear(); + out_transports_.clear(); +} + +void QueueMessageHandler::Start() { + queue_thread_ = std::thread(&QueueMessageHandler::QueueThreadCallback, this); +} + +void QueueMessageHandler::Stop() { + STREAMING_LOG(INFO) << "QueueMessageHandler Stop."; + queue_service_.stop(); + if (queue_thread_.joinable()) { + queue_thread_.join(); + } +} + +std::shared_ptr UpstreamQueueMessageHandler::CreateService( + CoreWorker *core_worker, const ActorID &actor_id) { + if (nullptr == upstream_handler_) { + upstream_handler_ = + std::make_shared(core_worker, actor_id); + } + return upstream_handler_; +} + +std::shared_ptr UpstreamQueueMessageHandler::GetService() { + return upstream_handler_; +} + +std::shared_ptr UpstreamQueueMessageHandler::CreateUpstreamQueue( + const ObjectID &queue_id, const ActorID &peer_actor_id, uint64_t size) { + STREAMING_LOG(INFO) << "CreateUpstreamQueue: " << queue_id << " " << actor_id_ << "->" + << peer_actor_id; + std::shared_ptr queue = GetUpQueue(queue_id); + if (queue != nullptr) { + STREAMING_LOG(WARNING) << "Duplicate to create up queue." << queue_id; + return queue; + } + + queue = std::unique_ptr(new streaming::WriterQueue( + queue_id, actor_id_, peer_actor_id, size, GetOutTransport(queue_id))); + upstream_queues_[queue_id] = queue; + + return queue; +} + +bool UpstreamQueueMessageHandler::UpstreamQueueExists(const ObjectID &queue_id) { + return nullptr != GetUpQueue(queue_id); +} + +std::shared_ptr UpstreamQueueMessageHandler::GetUpQueue( + const ObjectID &queue_id) { + auto it = upstream_queues_.find(queue_id); + if (it == upstream_queues_.end()) return nullptr; + + return it->second; +} + +bool UpstreamQueueMessageHandler::CheckQueueSync(const ObjectID &queue_id) { + ActorID peer_actor_id = GetPeerActorID(queue_id); + STREAMING_LOG(INFO) << "CheckQueueSync queue_id: " << queue_id + << " peer_actor_id: " << peer_actor_id; + + CheckMessage msg(actor_id_, peer_actor_id, queue_id); + std::unique_ptr buffer = msg.ToBytes(); + + auto transport_it = GetOutTransport(queue_id); + STREAMING_CHECK(transport_it != nullptr); + std::shared_ptr result_buffer = transport_it->SendForResultWithRetry( + std::move(buffer), DownstreamQueueMessageHandler::peer_sync_function_, 10, + COMMON_SYNC_CALL_TIMEOUTT_MS); + if (result_buffer == nullptr) { + return false; + } + + std::shared_ptr result_msg = ParseMessage(result_buffer); + STREAMING_CHECK( + result_msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType); + std::shared_ptr check_rsp_msg = + std::dynamic_pointer_cast(result_msg); + STREAMING_LOG(INFO) << "CheckQueueSync return queue_id: " << check_rsp_msg->QueueId(); + STREAMING_CHECK(check_rsp_msg->PeerActorId() == actor_id_); + + return queue::protobuf::StreamingQueueError::OK == check_rsp_msg->Error(); +} + +void UpstreamQueueMessageHandler::WaitQueues(const std::vector &queue_ids, + int64_t timeout_ms, + std::vector &failed_queues) { + failed_queues.insert(failed_queues.begin(), queue_ids.begin(), queue_ids.end()); + uint64_t start_time_us = current_time_ms(); + uint64_t current_time_us = start_time_us; + while (!failed_queues.empty() && current_time_us < start_time_us + timeout_ms * 1000) { + for (auto it = failed_queues.begin(); it != failed_queues.end();) { + if (CheckQueueSync(*it)) { + STREAMING_LOG(INFO) << "Check queue: " << *it << " return, ready."; + it = failed_queues.erase(it); + } else { + STREAMING_LOG(INFO) << "Check queue: " << *it << " return, not ready."; + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + it++; + } + } + current_time_us = current_time_ms(); + } +} + +void UpstreamQueueMessageHandler::DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback) { + std::shared_ptr msg = ParseMessage(buffer); + STREAMING_LOG(DEBUG) << "QueueMessageHandler::DispatchMessageInternal: " + << " qid: " << msg->QueueId() << " actorid " << msg->ActorId() + << " peer actorid: " << msg->PeerActorId() << " type: " + << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()); + + if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueNotificationMsgType) { + OnNotify(std::dynamic_pointer_cast(msg)); + } else if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckRspMsgType) { + STREAMING_CHECK(false) << "Should not receive StreamingQueueCheckRspMsg"; + } else { + STREAMING_CHECK(false) << "message type should be added: " + << queue::protobuf::StreamingQueueMessageType_Name( + msg->Type()); + } +} + +void UpstreamQueueMessageHandler::OnNotify( + std::shared_ptr notify_msg) { + auto queue = GetUpQueue(notify_msg->QueueId()); + if (queue == nullptr) { + STREAMING_LOG(WARNING) << "Can not find queue for " + << queue::protobuf::StreamingQueueMessageType_Name( + notify_msg->Type()) + << ", maybe queue has been destroyed, ignore it." + << " seq id: " << notify_msg->SeqId(); + return; + } + queue->OnNotify(notify_msg); +} + +void UpstreamQueueMessageHandler::ReleaseAllUpQueues() { + STREAMING_LOG(INFO) << "ReleaseAllUpQueues"; + upstream_queues_.clear(); + Release(); +} + +std::shared_ptr +DownstreamQueueMessageHandler::CreateService(CoreWorker *core_worker, + const ActorID &actor_id) { + if (nullptr == downstream_handler_) { + downstream_handler_ = + std::make_shared(core_worker, actor_id); + } + return downstream_handler_; +} + +std::shared_ptr +DownstreamQueueMessageHandler::GetService() { + return downstream_handler_; +} + +bool DownstreamQueueMessageHandler::DownstreamQueueExists(const ObjectID &queue_id) { + return nullptr != GetDownQueue(queue_id); +} + +std::shared_ptr DownstreamQueueMessageHandler::CreateDownstreamQueue( + const ObjectID &queue_id, const ActorID &peer_actor_id) { + STREAMING_LOG(INFO) << "CreateDownstreamQueue: " << queue_id << " " << peer_actor_id + << "->" << actor_id_; + auto it = downstream_queues_.find(queue_id); + if (it != downstream_queues_.end()) { + STREAMING_LOG(WARNING) << "Duplicate to create down queue!!!! " << queue_id; + return it->second; + } + + std::shared_ptr queue = + std::unique_ptr(new streaming::ReaderQueue( + queue_id, actor_id_, peer_actor_id, GetOutTransport(queue_id))); + downstream_queues_[queue_id] = queue; + return queue; +} + +std::shared_ptr DownstreamQueueMessageHandler::GetDownQueue( + const ObjectID &queue_id) { + auto it = downstream_queues_.find(queue_id); + if (it == downstream_queues_.end()) return nullptr; + + return it->second; +} + +std::shared_ptr DownstreamQueueMessageHandler::OnCheckQueue( + std::shared_ptr check_msg) { + queue::protobuf::StreamingQueueError err_code = + queue::protobuf::StreamingQueueError::OK; + + auto down_queue = downstream_queues_.find(check_msg->QueueId()); + if (down_queue == downstream_queues_.end()) { + STREAMING_LOG(WARNING) << "OnCheckQueue " << check_msg->QueueId() << " not found."; + err_code = queue::protobuf::StreamingQueueError::QUEUE_NOT_EXIST; + } + + CheckRspMessage msg(check_msg->PeerActorId(), check_msg->ActorId(), + check_msg->QueueId(), err_code); + std::shared_ptr buffer = msg.ToBytes(); + + return buffer; +} + +void DownstreamQueueMessageHandler::ReleaseAllDownQueues() { + STREAMING_LOG(INFO) << "ReleaseAllDownQueues size: " << downstream_queues_.size(); + downstream_queues_.clear(); + Release(); +} + +void DownstreamQueueMessageHandler::DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback) { + std::shared_ptr msg = ParseMessage(buffer); + STREAMING_LOG(DEBUG) << "QueueMessageHandler::DispatchMessageInternal: " + << " qid: " << msg->QueueId() << " actorid " << msg->ActorId() + << " peer actorid: " << msg->PeerActorId() << " type: " + << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()); + + if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueDataMsgType) { + OnData(std::dynamic_pointer_cast(msg)); + } else if (msg->Type() == + queue::protobuf::StreamingQueueMessageType::StreamingQueueCheckMsgType) { + std::shared_ptr check_result = + this->OnCheckQueue(std::dynamic_pointer_cast(msg)); + if (callback != nullptr) { + callback(check_result); + } + } else { + STREAMING_CHECK(false) << "message type should be added: " + << queue::protobuf::StreamingQueueMessageType_Name( + msg->Type()); + } +} + +void DownstreamQueueMessageHandler::OnData(std::shared_ptr msg) { + auto queue = GetDownQueue(msg->QueueId()); + if (queue == nullptr) { + STREAMING_LOG(WARNING) << "Can not find queue for " + << queue::protobuf::StreamingQueueMessageType_Name(msg->Type()) + << ", maybe queue has been destroyed, ignore it." + << " seq id: " << msg->SeqId(); + return; + } + + QueueItem item(msg); + queue->OnData(item); +} + +} // namespace streaming +} // namespace ray \ No newline at end of file diff --git a/streaming/src/queue/queue_handler.h b/streaming/src/queue/queue_handler.h new file mode 100644 index 0000000000000..0563b564b4b0d --- /dev/null +++ b/streaming/src/queue/queue_handler.h @@ -0,0 +1,194 @@ +#ifndef _QUEUE_SERVICE_H_ +#define _QUEUE_SERVICE_H_ + +#include +#include +#include +#include + +#include "queue.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +/// Base class of UpstreamQueueMessageHandler and DownstreamQueueMessageHandler. +/// A queue service manages a group of queues, upstream queues or downstream queues of +/// the current actor. Each queue service holds a boost.asio io_service, to handle +/// messages asynchronously. When a message received by Writer/Reader in ray call thread, +/// the message was delivered to +/// UpstreamQueueMessageHandler/DownstreamQueueMessageHandler, then the ray call thread +/// returns immediately. The queue service parses meta infomation from the message, +/// including queue_id actor_id, etc, and dispatchs message to queue according to +/// queue_id. +class QueueMessageHandler { + public: + /// Construct a QueueMessageHandler instance. + /// \param[in] core_worker CoreWorker C++ pointer of current actor, used to call Core + /// Worker's api. + /// For Python worker, the pointer can be obtained from + /// ray.worker.global_worker.core_worker; For Java worker, obtained from + /// RayNativeRuntime object through java reflection. + /// \param[in] actor_id actor id of current actor. + QueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id) + : core_worker_(core_worker), + actor_id_(actor_id), + queue_dummy_work_(queue_service_) { + Start(); + } + + virtual ~QueueMessageHandler() { Stop(); } + + /// Dispatch message buffer to asio service. + /// \param[in] buffer serialized message received from peer actor. + void DispatchMessageAsync(std::shared_ptr buffer); + + /// Dispatch message buffer to asio service synchronously, and wait for handle result. + /// \param[in] buffer serialized message received from peer actor. + /// \return handle result. + std::shared_ptr DispatchMessageSync( + std::shared_ptr buffer); + + /// Get transport to a peer actor specified by actor_id. + /// \param[in] actor_id actor id of peer actor + /// \return transport + std::shared_ptr GetOutTransport(const ObjectID &actor_id); + + /// The actual function where message being dispatched, called by DispatchMessageAsync + /// and DispatchMessageSync. + /// \param[in] buffer serialized message received from peer actor. + /// \param[in] callback the callback function used by DispatchMessageSync, called + /// after message processed complete. The std::shared_ptr + /// parameter is the return value. + virtual void DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback) = 0; + + /// Save actor_id of the peer actor specified by queue_id. For a upstream queue, the + /// peer actor refer specifically to the actor in current ray cluster who has a + /// downstream queue with same queue_id, and vice versa. + /// \param[in] queue_id queue id of current queue. + /// \param[in] actor_id actor_id actor id of corresponded peer actor. + void SetPeerActorID(const ObjectID &queue_id, const ActorID &actor_id); + + /// Obtain the actor id of the peer actor specified by queue_id. + /// \return actor id + ActorID GetPeerActorID(const ObjectID &queue_id); + + /// Release all queues in current queue service. + void Release(); + + private: + /// Start asio service + void Start(); + /// Stop asio service + void Stop(); + /// The callback function of internal thread. + void QueueThreadCallback() { queue_service_.run(); } + + protected: + /// CoreWorker C++ pointer of current actor + CoreWorker *core_worker_; + /// actor_id actor id of current actor + ActorID actor_id_; + /// Helper function, parse message buffer to Message object. + std::shared_ptr ParseMessage(std::shared_ptr buffer); + + private: + /// Map from queue id to a actor id of the queue's peer actor. + std::unordered_map actors_; + /// Map from queue id to a transport of the queue's peer actor. + std::unordered_map> out_transports_; + /// The internal thread which asio service run with. + std::thread queue_thread_; + /// The internal asio service. + boost::asio::io_service queue_service_; + /// The asio work which keeps queue_service_ alive. + boost::asio::io_service::work queue_dummy_work_; +}; + +/// UpstreamQueueMessageHandler holds and manages all upstream queues of current actor. +class UpstreamQueueMessageHandler : public QueueMessageHandler { + public: + /// Construct a UpstreamQueueMessageHandler instance. + UpstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id) + : QueueMessageHandler(core_worker, actor_id) {} + /// Create a upstream queue. + /// \param[in] queue_id queue id of the queue to be created. + /// \param[in] peer_actor_id actor id of peer actor. + /// \param[in] size the max memory size of the queue. + std::shared_ptr CreateUpstreamQueue(const ObjectID &queue_id, + const ActorID &peer_actor_id, + uint64_t size); + /// Check whether the upstream queue specified by queue_id exists or not. + bool UpstreamQueueExists(const ObjectID &queue_id); + /// Wait all queues in queue_ids vector ready, until timeout. + /// \param[in] queue_ids a group of queues. + /// \param[in] timeout_ms max timeout time interval for wait all queues. + /// \param[out] failed_queues a group of queues which are not ready when timeout. + void WaitQueues(const std::vector &queue_ids, int64_t timeout_ms, + std::vector &failed_queues); + /// Handle notify message from corresponded downstream queue. + void OnNotify(std::shared_ptr notify_msg); + /// Obtain upstream queue specified by queue_id. + std::shared_ptr GetUpQueue(const ObjectID &queue_id); + /// Release all upstream queues + void ReleaseAllUpQueues(); + + virtual void DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback) override; + + static std::shared_ptr CreateService( + CoreWorker *core_worker, const ActorID &actor_id); + static std::shared_ptr GetService(); + + static RayFunction peer_sync_function_; + static RayFunction peer_async_function_; + + private: + bool CheckQueueSync(const ObjectID &queue_ids); + + private: + std::unordered_map> upstream_queues_; + static std::shared_ptr upstream_handler_; +}; + +/// UpstreamQueueMessageHandler holds and manages all downstream queues of current actor. +class DownstreamQueueMessageHandler : public QueueMessageHandler { + public: + DownstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id) + : QueueMessageHandler(core_worker, actor_id) {} + std::shared_ptr CreateDownstreamQueue(const ObjectID &queue_id, + const ActorID &peer_actor_id); + bool DownstreamQueueExists(const ObjectID &queue_id); + + void UpdateDownActor(const ObjectID &queue_id, const ActorID &actor_id); + + std::shared_ptr OnCheckQueue( + std::shared_ptr check_msg); + + std::shared_ptr GetDownQueue(const ObjectID &queue_id); + + void ReleaseAllDownQueues(); + + void OnData(std::shared_ptr msg); + virtual void DispatchMessageInternal( + std::shared_ptr buffer, + std::function)> callback); + + static std::shared_ptr CreateService( + CoreWorker *core_worker, const ActorID &actor_id); + static std::shared_ptr GetService(); + static RayFunction peer_sync_function_; + static RayFunction peer_async_function_; + + private: + std::unordered_map> + downstream_queues_; + static std::shared_ptr downstream_handler_; +}; + +} // namespace streaming +} // namespace ray +#endif diff --git a/streaming/src/queue/queue_item.h b/streaming/src/queue/queue_item.h new file mode 100644 index 0000000000000..cfff3c7e57ac0 --- /dev/null +++ b/streaming/src/queue/queue_item.h @@ -0,0 +1,109 @@ +#ifndef _STREAMING_QUEUE_ITEM_H_ +#define _STREAMING_QUEUE_ITEM_H_ +#include +#include +#include +#include + +#include "ray/common/id.h" + +#include "message.h" +#include "message/message_bundle.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +using ray::ObjectID; +const uint64_t QUEUE_INVALID_SEQ_ID = std::numeric_limits::max(); + +/// QueueItem is the element stored in `Queue`. Actually, when DataWriter pushes a message +/// bundle into a queue, the bundle is packed into one QueueItem, so a one-to-one +/// relationship exists between message bundle and QueueItem. Meanwhile, the QueueItem is +/// also the minimum unit to send through direct actor call. Each QueueItem holds a +/// LocalMemoryBuffer shared_ptr, which will be sent out by Transport. +class QueueItem { + public: + /// Construct a QueueItem object. + /// \param[in] seq_id the sequential id assigned by DataWriter for a message bundle and + /// QueueItem. + /// \param[in] data the data buffer to be stored in this QueueItem. + /// \param[in] data_size the data size in bytes. + /// \param[in] timestamp the time when this QueueItem created. + /// \param[in] raw whether the data content is raw bytes, only used in some tests. + QueueItem(uint64_t seq_id, uint8_t *data, uint32_t data_size, uint64_t timestamp, + bool raw = false) + : seq_id_(seq_id), + timestamp_(timestamp), + raw_(raw), + /*COPY*/ buffer_(std::make_shared(data, data_size, true)) {} + + QueueItem(uint64_t seq_id, std::shared_ptr buffer, + uint64_t timestamp, bool raw = false) + : seq_id_(seq_id), timestamp_(timestamp), raw_(raw), buffer_(buffer) {} + + QueueItem(std::shared_ptr data_msg) + : seq_id_(data_msg->SeqId()), + raw_(data_msg->IsRaw()), + buffer_(data_msg->Buffer()) {} + + QueueItem(const QueueItem &&item) { + buffer_ = item.buffer_; + seq_id_ = item.seq_id_; + timestamp_ = item.timestamp_; + raw_ = item.raw_; + } + + QueueItem(const QueueItem &item) { + buffer_ = item.buffer_; + seq_id_ = item.seq_id_; + timestamp_ = item.timestamp_; + raw_ = item.raw_; + } + + QueueItem &operator=(const QueueItem &item) { + buffer_ = item.buffer_; + seq_id_ = item.seq_id_; + timestamp_ = item.timestamp_; + raw_ = item.raw_; + return *this; + } + + virtual ~QueueItem() = default; + + uint64_t SeqId() { return seq_id_; } + bool IsRaw() { return raw_; } + uint64_t TimeStamp() { return timestamp_; } + size_t DataSize() { return buffer_->Size(); } + std::shared_ptr Buffer() { return buffer_; } + + /// Get max message id in this item. + /// \return max message id. + uint64_t MaxMsgId() { + if (raw_) { + return 0; + } + auto message_bundle = StreamingMessageBundleMeta::FromBytes(buffer_->Data()); + return message_bundle->GetLastMessageId(); + } + + protected: + uint64_t seq_id_; + uint64_t timestamp_; + bool raw_; + + std::shared_ptr buffer_; +}; + +class InvalidQueueItem : public QueueItem { + public: + InvalidQueueItem() : QueueItem(QUEUE_INVALID_SEQ_ID, data_, 1, 0) {} + + private: + uint8_t data_[1]; +}; +typedef std::shared_ptr QueueItemPtr; + +} // namespace streaming +} // namespace ray +#endif diff --git a/streaming/src/queue/transport.cc b/streaming/src/queue/transport.cc new file mode 100644 index 0000000000000..18c1cd39595d0 --- /dev/null +++ b/streaming/src/queue/transport.cc @@ -0,0 +1,94 @@ +#include "transport.h" +#include "utils.h" + +namespace ray { +namespace streaming { + +static constexpr int TASK_OPTION_RETURN_NUM_0 = 0; +static constexpr int TASK_OPTION_RETURN_NUM_1 = 1; + +void Transport::SendInternal(std::shared_ptr buffer, + RayFunction &function, int return_num, + std::vector &return_ids) { + std::unordered_map resources; + TaskOptions options{return_num, true, resources}; + + char meta_data[3] = {'R', 'A', 'W'}; + std::shared_ptr meta = + std::make_shared((uint8_t *)meta_data, 3, true); + + std::vector args; + if (function.GetLanguage() == Language::PYTHON) { + auto dummy = "__RAY_DUMMY__"; + std::shared_ptr dummyBuffer = + std::make_shared((uint8_t *)dummy, 13, true); + args.emplace_back(TaskArg::PassByValue( + std::make_shared(std::move(dummyBuffer), meta, true))); + } + args.emplace_back( + TaskArg::PassByValue(std::make_shared(std::move(buffer), meta, true))); + + STREAMING_CHECK(core_worker_ != nullptr); + std::vector> results; + ray::Status st = + core_worker_->SubmitActorTask(peer_actor_id_, function, args, options, &return_ids); + if (!st.ok()) { + STREAMING_LOG(ERROR) << "SubmitActorTask failed. " << st; + } +} + +void Transport::Send(std::shared_ptr buffer, RayFunction &function) { + STREAMING_LOG(INFO) << "Transport::Send buffer size: " << buffer->Size(); + std::vector return_ids; + SendInternal(std::move(buffer), function, TASK_OPTION_RETURN_NUM_0, return_ids); +} + +std::shared_ptr Transport::SendForResult( + std::shared_ptr buffer, RayFunction &function, + int64_t timeout_ms) { + std::vector return_ids; + SendInternal(buffer, function, TASK_OPTION_RETURN_NUM_1, return_ids); + + std::vector> results; + Status get_st = core_worker_->Get(return_ids, timeout_ms, &results); + if (!get_st.ok()) { + STREAMING_LOG(ERROR) << "Get fail."; + return nullptr; + } + STREAMING_CHECK(results.size() >= 1); + if (results[0]->IsException()) { + STREAMING_LOG(ERROR) << "peer actor may has exceptions, should retry."; + return nullptr; + } + STREAMING_CHECK(results[0]->HasData()); + if (results[0]->GetData()->Size() == 4) { + STREAMING_LOG(WARNING) << "peer actor may not ready yet, should retry."; + return nullptr; + } + + std::shared_ptr result_buffer = results[0]->GetData(); + std::shared_ptr return_buffer = std::make_shared( + result_buffer->Data(), result_buffer->Size(), true); + return return_buffer; +} + +std::shared_ptr Transport::SendForResultWithRetry( + std::shared_ptr buffer, RayFunction &function, int retry_cnt, + int64_t timeout_ms) { + STREAMING_LOG(INFO) << "SendForResultWithRetry retry_cnt: " << retry_cnt + << " timeout_ms: " << timeout_ms + << " function: " << function.GetFunctionDescriptor()[0]; + std::shared_ptr buffer_shared = std::move(buffer); + for (int cnt = 0; cnt < retry_cnt; cnt++) { + auto result = SendForResult(buffer_shared, function, timeout_ms); + if (result != nullptr) { + return result; + } + } + + STREAMING_LOG(WARNING) << "SendForResultWithRetry fail after retry."; + return nullptr; +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/queue/transport.h b/streaming/src/queue/transport.h new file mode 100644 index 0000000000000..3f26754a48a7d --- /dev/null +++ b/streaming/src/queue/transport.h @@ -0,0 +1,63 @@ +#ifndef _STREAMING_QUEUE_TRANSPORT_H_ +#define _STREAMING_QUEUE_TRANSPORT_H_ + +#include "ray/common/id.h" +#include "ray/core_worker/core_worker.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +/// Transport is the transfer endpoint to a specific actor, buffers can be sent to peer +/// through direct actor call. +class Transport { + public: + /// Construct a Transport object. + /// \param[in] core_worker CoreWorker C++ pointer of current actor, which we call direct + /// actor call interface with. + /// \param[in] peer_actor_id actor id of peer actor. + Transport(CoreWorker *core_worker, const ActorID &peer_actor_id) + : core_worker_(core_worker), peer_actor_id_(peer_actor_id) {} + virtual ~Transport() = default; + + /// Send buffer asynchronously, peer's `function` will be called. + /// \param[in] buffer buffer to be sent. + /// \param[in] function the function descriptor of peer's function. + virtual void Send(std::shared_ptr buffer, RayFunction &function); + /// Send buffer synchronously, peer's `function` will be called, and return the peer + /// function's return value. + /// \param[in] buffer buffer to be sent. + /// \param[in] function the function descriptor of peer's function. + /// \param[in] timeout_ms max time to wait for result. + /// \return peer function's result. + virtual std::shared_ptr SendForResult( + std::shared_ptr buffer, RayFunction &function, + int64_t timeout_ms); + /// Send buffer and get result with retry. + /// return value. + /// \param[in] buffer buffer to be sent. + /// \param[in] function the function descriptor of peer's function. + /// \param[in] max retry count + /// \param[in] timeout_ms max time to wait for result. + /// \return peer function's result. + std::shared_ptr SendForResultWithRetry( + std::shared_ptr buffer, RayFunction &function, int retry_cnt, + int64_t timeout_ms); + + private: + /// Send buffer internal + /// \param[in] buffer buffer to be sent. + /// \param[in] function the function descriptor of peer's function. + /// \param[in] return_num return value number of the call. + /// \param[out] return_ids return ids from SubmitActorTask. + virtual void SendInternal(std::shared_ptr buffer, + RayFunction &function, int return_num, + std::vector &return_ids); + + private: + CoreWorker *core_worker_; + ActorID peer_actor_id_; +}; +} // namespace streaming +} // namespace ray +#endif diff --git a/streaming/src/queue/utils.h b/streaming/src/queue/utils.h new file mode 100644 index 0000000000000..38021faef43fc --- /dev/null +++ b/streaming/src/queue/utils.h @@ -0,0 +1,50 @@ +#ifndef _STREAMING_QUEUE_UTILS_H_ +#define _STREAMING_QUEUE_UTILS_H_ +#include +#include +#include +#include "ray/util/util.h" + +namespace ray { +namespace streaming { + +/// Helper class encapulate std::future to help multithread async wait. +class PromiseWrapper { + public: + Status Wait() { + std::future fut = promise_.get_future(); + fut.get(); + return status_; + } + + Status WaitFor(uint64_t timeout_ms) { + std::future fut = promise_.get_future(); + std::future_status status; + do { + status = fut.wait_for(std::chrono::milliseconds(timeout_ms)); + if (status == std::future_status::deferred) { + } else if (status == std::future_status::timeout) { + return Status::Invalid("timeout"); + } else if (status == std::future_status::ready) { + return status_; + } + } while (status == std::future_status::deferred); + + return status_; + } + + void Notify(Status status) { + status_ = status; + promise_.set_value(true); + } + + Status GetResultStatus() { return status_; } + + private: + std::promise promise_; + Status status_; +}; + +} // namespace streaming +} // namespace ray +#endif diff --git a/streaming/src/ring_buffer.cc b/streaming/src/ring_buffer.cc new file mode 100644 index 0000000000000..e2b179177d696 --- /dev/null +++ b/streaming/src/ring_buffer.cc @@ -0,0 +1,82 @@ +#include "ring_buffer.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +StreamingRingBuffer::StreamingRingBuffer(size_t buf_size, + StreamingRingBufferType buffer_type) { + switch (buffer_type) { + case StreamingRingBufferType::SPSC: + message_buffer_ = + std::make_shared>(buf_size); + break; + case StreamingRingBufferType::SPSC_LOCK: + default: + message_buffer_ = + std::make_shared>(buf_size); + } +} + +bool StreamingRingBuffer::Push(const StreamingMessagePtr &msg) { + message_buffer_->Push(msg); + return true; +} + +bool StreamingRingBuffer::Push(StreamingMessagePtr &&msg) { + message_buffer_->Push(std::forward(msg)); + return true; +} + +StreamingMessagePtr &StreamingRingBuffer::Front() { + STREAMING_CHECK(!message_buffer_->Empty()); + return message_buffer_->Front(); +} + +void StreamingRingBuffer::Pop() { + STREAMING_CHECK(!message_buffer_->Empty()); + message_buffer_->Pop(); +} + +bool StreamingRingBuffer::IsFull() { return message_buffer_->Full(); } + +bool StreamingRingBuffer::IsEmpty() { return message_buffer_->Empty(); } + +size_t StreamingRingBuffer::Size() { return message_buffer_->Size(); }; + +size_t StreamingRingBuffer::Capacity() const { return message_buffer_->Capacity(); } + +size_t StreamingRingBuffer::GetTransientBufferSize() { + return transient_buffer_.GetTransientBufferSize(); +}; + +void StreamingRingBuffer::SetTransientBufferSize(uint32_t new_transient_buffer_size) { + return transient_buffer_.SetTransientBufferSize(new_transient_buffer_size); +} + +size_t StreamingRingBuffer::GetMaxTransientBufferSize() const { + return transient_buffer_.GetMaxTransientBufferSize(); +} + +const uint8_t *StreamingRingBuffer::GetTransientBuffer() const { + return transient_buffer_.GetTransientBuffer(); +} + +uint8_t *StreamingRingBuffer::GetTransientBufferMutable() const { + return transient_buffer_.GetTransientBufferMutable(); +} + +void StreamingRingBuffer::ReallocTransientBuffer(uint32_t size) { + transient_buffer_.ReallocTransientBuffer(size); +} + +bool StreamingRingBuffer::IsTransientAvaliable() { + return transient_buffer_.IsTransientAvaliable(); +} + +void StreamingRingBuffer::FreeTransientBuffer(bool is_force) { + transient_buffer_.FreeTransientBuffer(is_force); +} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/ring_buffer.h b/streaming/src/ring_buffer.h new file mode 100644 index 0000000000000..64fdb4eef7f1c --- /dev/null +++ b/streaming/src/ring_buffer.h @@ -0,0 +1,233 @@ +#ifndef RAY_RING_BUFFER_H +#define RAY_RING_BUFFER_H + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "message/message.h" +#include "ray/common/status.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +/// Because the data cannot be successfully written to the channel every time, in +/// order not to serialize the message repeatedly, we designed a temporary buffer +/// area so that when the downstream is backpressured or the channel is blocked +/// due to memory limitations, it can be cached first and waited for the next use. +class StreamingTransientBuffer { + private: + std::shared_ptr transient_buffer_; + // BufferSize is length of last serialization data. + uint32_t transient_buffer_size_ = 0; + uint32_t max_transient_buffer_size_ = 0; + bool transient_flag_ = false; + + public: + inline size_t GetTransientBufferSize() const { return transient_buffer_size_; } + + inline void SetTransientBufferSize(uint32_t new_transient_buffer_size) { + transient_buffer_size_ = new_transient_buffer_size; + } + + inline size_t GetMaxTransientBufferSize() const { return max_transient_buffer_size_; } + + inline const uint8_t *GetTransientBuffer() const { return transient_buffer_.get(); } + + inline uint8_t *GetTransientBufferMutable() const { return transient_buffer_.get(); } + + /// To reuse transient buffer, we will realloc buffer memory if size of needed + /// message bundle raw data is greater-than original buffer size. + /// \param size buffer size + /// + inline void ReallocTransientBuffer(uint32_t size) { + transient_buffer_size_ = size; + transient_flag_ = true; + if (max_transient_buffer_size_ > size) { + return; + } + max_transient_buffer_size_ = size; + transient_buffer_.reset(new uint8_t[size], std::default_delete()); + } + + inline bool IsTransientAvaliable() { return transient_flag_; } + + inline void FreeTransientBuffer(bool is_force = false) { + transient_buffer_size_ = 0; + transient_flag_ = false; + + // Transient buffer always holds max size buffer among all messages, which is + // wasteful. So expiration time is considerable idea to release large buffer if this + // transient buffer pointer hold it in long time. + + if (is_force) { + max_transient_buffer_size_ = 0; + transient_buffer_.reset(); + } + } + + virtual ~StreamingTransientBuffer() = default; +}; + +template +class AbstractRingBufferImpl { + public: + virtual void Push(T &&) = 0; + virtual void Push(const T &) = 0; + virtual void Pop() = 0; + virtual T &Front() = 0; + virtual bool Empty() = 0; + virtual bool Full() = 0; + virtual size_t Size() = 0; + virtual size_t Capacity() = 0; +}; + +template +class RingBufferImplThreadSafe : public AbstractRingBufferImpl { + private: + boost::shared_mutex ring_buffer_mutex_; + boost::circular_buffer buffer_; + + public: + RingBufferImplThreadSafe(size_t size) : buffer_(size) {} + virtual ~RingBufferImplThreadSafe() = default; + void Push(T &&t) { + boost::unique_lock lock(ring_buffer_mutex_); + buffer_.push_back(t); + } + void Push(const T &t) { + boost::unique_lock lock(ring_buffer_mutex_); + buffer_.push_back(t); + } + void Pop() { + boost::unique_lock lock(ring_buffer_mutex_); + buffer_.pop_front(); + } + T &Front() { + boost::shared_lock lock(ring_buffer_mutex_); + return buffer_.front(); + } + bool Empty() { + boost::shared_lock lock(ring_buffer_mutex_); + return buffer_.empty(); + } + bool Full() { + boost::shared_lock lock(ring_buffer_mutex_); + return buffer_.full(); + } + size_t Size() { + boost::shared_lock lock(ring_buffer_mutex_); + return buffer_.size(); + } + size_t Capacity() { return buffer_.capacity(); } +}; + +template +class RingBufferImplLockFree : public AbstractRingBufferImpl { + private: + std::vector buffer_; + std::atomic capacity_; + std::atomic read_index_; + std::atomic write_index_; + + public: + RingBufferImplLockFree(size_t size) + : buffer_(size, nullptr), capacity_(size), read_index_(0), write_index_(0) {} + virtual ~RingBufferImplLockFree() = default; + + void Push(T &&t) { + STREAMING_CHECK(!Full()); + buffer_[write_index_] = t; + write_index_ = IncreaseIndex(write_index_); + } + + void Push(const T &t) { + STREAMING_CHECK(!Full()); + buffer_[write_index_] = t; + write_index_ = IncreaseIndex(write_index_); + } + + void Pop() { + STREAMING_CHECK(!Empty()); + read_index_ = IncreaseIndex(read_index_); + } + + T &Front() { + STREAMING_CHECK(!Empty()); + return buffer_[read_index_]; + } + + bool Empty() { return write_index_ == read_index_; } + + bool Full() { return IncreaseIndex(write_index_) == read_index_; } + + size_t Size() { return (write_index_ + capacity_ - read_index_) % capacity_; } + + size_t Capacity() { return capacity_; } + + private: + size_t IncreaseIndex(size_t index) const { return (index + 1) % capacity_; } +}; + +enum class StreamingRingBufferType : uint8_t { SPSC_LOCK, SPSC }; + +/// StreamingRinggBuffer is factory to generate two different buffers. In data +/// writer, we use lock-free single producer single consumer (SPSC) ring buffer +/// to hold messages from user thread because SPSC has much better performance +/// than lock style. Since the SPSC_LOCK is useful to our event-driver model( +/// we will use that buffer to optimize our thread model in the future), so +/// it cann't be removed currently. +class StreamingRingBuffer { + private: + std::shared_ptr> message_buffer_; + + StreamingTransientBuffer transient_buffer_; + + public: + explicit StreamingRingBuffer(size_t buf_size, StreamingRingBufferType buffer_type = + StreamingRingBufferType::SPSC_LOCK); + + bool Push(StreamingMessagePtr &&msg); + + bool Push(const StreamingMessagePtr &msg); + + StreamingMessagePtr &Front(); + + void Pop(); + + bool IsFull(); + + bool IsEmpty(); + + size_t Size(); + + size_t Capacity() const; + + size_t GetTransientBufferSize(); + + void SetTransientBufferSize(uint32_t new_transient_buffer_size); + + size_t GetMaxTransientBufferSize() const; + + const uint8_t *GetTransientBuffer() const; + + uint8_t *GetTransientBufferMutable() const; + + void ReallocTransientBuffer(uint32_t size); + + bool IsTransientAvaliable(); + + void FreeTransientBuffer(bool is_force = false); +}; + +typedef std::shared_ptr StreamingRingBufferPtr; +} // namespace streaming +} // namespace ray + +#endif // RAY_RING_BUFFER_H diff --git a/streaming/src/runtime_context.cc b/streaming/src/runtime_context.cc new file mode 100644 index 0000000000000..38f1bd0d9cfa9 --- /dev/null +++ b/streaming/src/runtime_context.cc @@ -0,0 +1,32 @@ +#include "ray/common/id.h" +#include "ray/protobuf/common.pb.h" +#include "ray/util/util.h" + +#include "runtime_context.h" +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +void RuntimeContext::SetConfig(const StreamingConfig &streaming_config) { + STREAMING_CHECK(runtime_status_ == RuntimeStatus::Init) + << "set config must be at beginning"; + config_ = streaming_config; +} + +void RuntimeContext::SetConfig(const uint8_t *data, uint32_t size) { + STREAMING_CHECK(runtime_status_ == RuntimeStatus::Init) + << "set config must be at beginning"; + if (!data) { + STREAMING_LOG(WARNING) << "buffer pointer is null, but len is => " << size; + return; + } + config_.FromProto(data, size); +} + +RuntimeContext::~RuntimeContext() {} + +RuntimeContext::RuntimeContext() : runtime_status_(RuntimeStatus::Init) {} + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/runtime_context.h b/streaming/src/runtime_context.h new file mode 100644 index 0000000000000..fa6075ba768c0 --- /dev/null +++ b/streaming/src/runtime_context.h @@ -0,0 +1,42 @@ +#ifndef RAY_STREAMING_H +#define RAY_STREAMING_H +#include + +#include "config/streaming_config.h" +#include "status.h" + +namespace ray { +namespace streaming { + +enum class RuntimeStatus : uint8_t { Init = 0, Running = 1, Interrupted = 2 }; + +#define RETURN_IF_NOT_OK(STATUS_EXP) \ + { \ + StreamingStatus state = STATUS_EXP; \ + if (StreamingStatus::OK != state) { \ + return state; \ + } \ + } + +class RuntimeContext { + public: + RuntimeContext(); + virtual ~RuntimeContext(); + inline const StreamingConfig &GetConfig() const { return config_; }; + void SetConfig(const StreamingConfig &config); + void SetConfig(const uint8_t *data, uint32_t buffer_len); + inline RuntimeStatus GetRuntimeStatus() { return runtime_status_; } + inline void SetRuntimeStatus(RuntimeStatus status) { runtime_status_ = status; } + inline void MarkMockTest() { is_mock_test_ = true; } + inline bool IsMockTest() { return is_mock_test_; } + + private: + StreamingConfig config_; + RuntimeStatus runtime_status_; + bool is_mock_test_ = false; +}; + +} // namespace streaming +} // namespace ray + +#endif // RAY_STREAMING_H diff --git a/streaming/src/status.h b/streaming/src/status.h new file mode 100644 index 0000000000000..10095d6438042 --- /dev/null +++ b/streaming/src/status.h @@ -0,0 +1,47 @@ +#ifndef RAY_STREAMING_STATUS_H +#define RAY_STREAMING_STATUS_H +#include +#include +#include + +namespace ray { +namespace streaming { + +enum class StreamingStatus : uint32_t { + OK = 0, + ReconstructTimeOut = 1, + QueueIdNotFound = 3, + ResubscribeFailed = 4, + EmptyRingBuffer = 5, + FullChannel = 6, + NoSuchItem = 7, + InitQueueFailed = 8, + GetBundleTimeOut = 9, + SkipSendEmptyMessage = 10, + Interrupted = 11, + WaitQueueTimeOut = 12, + OutOfMemory = 13, + Invalid = 14, + UnknownError = 15, + TailStatus = 999, + MIN = OK, + MAX = TailStatus +}; + +static inline std::ostream &operator<<(std::ostream &os, const StreamingStatus &status) { + os << static_cast::type>(status); + return os; +} + +#define RETURN_IF_NOT_OK(STATUS_EXP) \ + { \ + StreamingStatus state = STATUS_EXP; \ + if (StreamingStatus::OK != state) { \ + return state; \ + } \ + } + +} // namespace streaming +} // namespace ray + +#endif // RAY_STREAMING_STATUS_H diff --git a/streaming/src/test/message_serialization_tests.cc b/streaming/src/test/message_serialization_tests.cc new file mode 100644 index 0000000000000..349a44155e31e --- /dev/null +++ b/streaming/src/test/message_serialization_tests.cc @@ -0,0 +1,176 @@ +#include +#include + +#include "gtest/gtest.h" + +#include "message/message.h" +#include "message/message_bundle.h" + +using namespace ray; +using namespace ray::streaming; + +TEST(StreamingSerializationTest, streaming_message_serialization_test) { + uint8_t data[] = {9, 1, 3}; + StreamingMessagePtr message = + std::make_shared(data, 3, 7, StreamingMessageType::Message); + uint32_t message_length = message->ClassBytesSize(); + uint8_t *bytes = new uint8_t[message_length]; + message->ToBytes(bytes); + StreamingMessagePtr new_message = StreamingMessage::FromBytes(bytes); + EXPECT_EQ(std::memcmp(new_message->RawData(), data, 3), 0); + delete[] bytes; +} + +TEST(StreamingSerializationTest, streaming_message_empty_bundle_serialization_test) { + for (int i = 0; i < 10; ++i) { + StreamingMessageBundle bundle(i, i); + uint64_t bundle_size = bundle.ClassBytesSize(); + uint8_t *bundle_bytes = new uint8_t[bundle_size]; + bundle.ToBytes(bundle_bytes); + StreamingMessageBundlePtr bundle_ptr = + StreamingMessageBundle::FromBytes(bundle_bytes); + + EXPECT_EQ(bundle.ClassBytesSize(), bundle_ptr->ClassBytesSize()); + EXPECT_EQ(bundle.GetMessageListSize(), bundle_ptr->GetMessageListSize()); + EXPECT_EQ(bundle.GetBundleType(), bundle_ptr->GetBundleType()); + EXPECT_EQ(bundle.GetLastMessageId(), bundle_ptr->GetLastMessageId()); + std::list s_message_list; + bundle_ptr->GetMessageList(s_message_list); + std::list b_message_list; + bundle.GetMessageList(b_message_list); + EXPECT_EQ(b_message_list.size(), 0); + EXPECT_EQ(s_message_list.size(), 0); + + delete[] bundle_bytes; + } +} +TEST(StreamingSerializationTest, streaming_message_barrier_bundle_serialization_test) { + for (int i = 0; i < 10; ++i) { + uint8_t data[] = {1, 2, 3, 4}; + uint32_t data_size = 4; + uint32_t head_size = sizeof(uint64_t); + uint64_t checkpoint_id = 777; + std::shared_ptr ptr(new uint8_t[data_size + head_size], + std::default_delete()); + // move checkpint_id in head of barrier data + std::memcpy(ptr.get(), &checkpoint_id, head_size); + std::memcpy(ptr.get() + head_size, data, data_size); + StreamingMessagePtr message = std::make_shared( + data, head_size + data_size, i, StreamingMessageType::Barrier); + std::list message_list; + message_list.push_back(message); + // message list will be moved to bundle member + std::list message_list_cpy(message_list); + + StreamingMessageBundle bundle(message_list_cpy, i, i, + StreamingMessageBundleType::Barrier); + uint64_t bundle_size = bundle.ClassBytesSize(); + uint8_t *bundle_bytes = new uint8_t[bundle_size]; + bundle.ToBytes(bundle_bytes); + StreamingMessageBundlePtr bundle_ptr = + StreamingMessageBundle::FromBytes(bundle_bytes); + + EXPECT_TRUE(bundle.ClassBytesSize() == bundle_ptr->ClassBytesSize()); + EXPECT_TRUE(bundle.GetMessageListSize() == bundle_ptr->GetMessageListSize()); + EXPECT_TRUE(bundle.GetBundleType() == bundle_ptr->GetBundleType()); + EXPECT_TRUE(bundle.GetLastMessageId() == bundle_ptr->GetLastMessageId()); + std::list s_message_list; + bundle_ptr->GetMessageList(s_message_list); + EXPECT_TRUE(s_message_list.size() == message_list.size()); + auto m_item = message_list.back(); + auto s_item = s_message_list.back(); + EXPECT_TRUE(s_item->ClassBytesSize() == m_item->ClassBytesSize()); + EXPECT_TRUE(s_item->GetMessageType() == m_item->GetMessageType()); + EXPECT_TRUE(s_item->GetMessageSeqId() == m_item->GetMessageSeqId()); + EXPECT_TRUE(s_item->GetDataSize() == m_item->GetDataSize()); + EXPECT_TRUE( + std::memcmp(s_item->RawData(), m_item->RawData(), m_item->GetDataSize()) == 0); + EXPECT_TRUE(*(s_item.get()) == (*(m_item.get()))); + + delete[] bundle_bytes; + } +} + +TEST(StreamingSerializationTest, streaming_message_bundle_serialization_test) { + for (int k = 0; k <= 1000; k++) { + std::list message_list; + + for (int i = 0; i < 100; ++i) { + uint8_t *data = new uint8_t[i + 1]; + data[0] = i; + StreamingMessagePtr message = std::make_shared( + data, i + 1, i + 1, StreamingMessageType::Message); + message_list.push_back(message); + delete[] data; + } + StreamingMessageBundle messageBundle(message_list, 0, 1, + StreamingMessageBundleType::Bundle); + size_t message_length = messageBundle.ClassBytesSize(); + uint8_t *bytes = new uint8_t[message_length]; + messageBundle.ToBytes(bytes); + + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(bytes); + EXPECT_EQ(bundle_ptr->ClassBytesSize(), message_length); + std::list s_message_list; + bundle_ptr->GetMessageList(s_message_list); + EXPECT_TRUE(bundle_ptr->operator==(messageBundle)); + StreamingMessageBundleMetaPtr bundle_meta_ptr = + StreamingMessageBundleMeta::FromBytes(bytes); + + EXPECT_EQ(bundle_meta_ptr->GetBundleType(), bundle_ptr->GetBundleType()); + EXPECT_EQ(bundle_meta_ptr->GetLastMessageId(), bundle_ptr->GetLastMessageId()); + EXPECT_EQ(bundle_meta_ptr->GetMessageBundleTs(), bundle_ptr->GetMessageBundleTs()); + EXPECT_EQ(bundle_meta_ptr->GetMessageListSize(), bundle_ptr->GetMessageListSize()); + delete[] bytes; + } +} + +TEST(StreamingSerializationTest, streaming_message_bundle_equal_test) { + std::list message_list; + std::list message_list_same; + std::list message_list_cpy; + for (int i = 0; i < 100; ++i) { + uint8_t *data = new uint8_t[i + 1]; + for (int j = 0; j < i + 1; ++j) { + data[j] = i; + } + StreamingMessagePtr message = std::make_shared( + data, i + 1, i + 1, StreamingMessageType::Message); + message_list.push_back(message); + message_list_cpy.push_front(message); + delete[] data; + } + for (int i = 0; i < 100; ++i) { + uint8_t *data = new uint8_t[i + 1]; + for (int j = 0; j < i + 1; ++j) { + data[j] = i; + } + StreamingMessagePtr message = std::make_shared( + data, i + 1, i + 1, StreamingMessageType::Message); + message_list_same.push_back(message); + delete[] data; + } + StreamingMessageBundle message_bundle(message_list, 0, 1, + StreamingMessageBundleType::Bundle); + StreamingMessageBundle message_bundle_same(message_list_same, 0, 1, + StreamingMessageBundleType::Bundle); + StreamingMessageBundle message_bundle_reverse(message_list_cpy, 0, 1, + StreamingMessageBundleType::Bundle); + EXPECT_TRUE(message_bundle_same == message_bundle); + EXPECT_FALSE(message_bundle_reverse == message_bundle); + size_t message_length = message_bundle.ClassBytesSize(); + uint8_t *bytes = new uint8_t[message_length]; + message_bundle.ToBytes(bytes); + + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(bytes); + EXPECT_EQ(bundle_ptr->ClassBytesSize(), message_length); + std::list s_message_list; + bundle_ptr->GetMessageList(s_message_list); + EXPECT_TRUE(bundle_ptr->operator==(message_bundle)); + delete[] bytes; +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc new file mode 100644 index 0000000000000..0effacf835426 --- /dev/null +++ b/streaming/src/test/mock_actor.cc @@ -0,0 +1,439 @@ +#define BOOST_BIND_NO_PLACEHOLDERS +#include "ray/core_worker/context.h" +#include "ray/core_worker/core_worker.h" +#include "src/ray/util/test_util.h" + +#include "data_reader.h" +#include "data_writer.h" +#include "message/message.h" +#include "message/message_bundle.h" +#include "queue/queue_client.h" +#include "ring_buffer.h" +#include "status.h" + +#include "gtest/gtest.h" +using namespace std::placeholders; + +const uint32_t MESSAGE_BOUND_SIZE = 10000; +const uint32_t DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE = 1000; + +namespace ray { +namespace streaming { + +class StreamingQueueTestSuite { + public: + StreamingQueueTestSuite(std::shared_ptr core_worker, ActorID &peer_actor_id, + std::vector queue_ids, + std::vector rescale_queue_ids) + : core_worker_(core_worker), + peer_actor_id_(peer_actor_id), + queue_ids_(queue_ids), + rescale_queue_ids_(rescale_queue_ids) {} + + virtual void ExecuteTest(std::string test_name) { + auto it = test_func_map_.find(test_name); + STREAMING_CHECK(it != test_func_map_.end()); + current_test_ = test_name; + status_ = false; + auto func = it->second; + executor_thread_ = std::make_shared(func); + executor_thread_->detach(); + } + + virtual std::shared_ptr CheckCurTestStatus() { + TestCheckStatusRspMsg msg(current_test_, status_); + return msg.ToBytes(); + } + + virtual bool TestDone() { return status_; } + + virtual ~StreamingQueueTestSuite() {} + + protected: + std::unordered_map> test_func_map_; + std::string current_test_; + bool status_; + std::shared_ptr executor_thread_; + std::shared_ptr core_worker_; + ActorID peer_actor_id_; + std::vector queue_ids_; + std::vector rescale_queue_ids_; +}; + +class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { + public: + StreamingQueueWriterTestSuite(std::shared_ptr core_worker, + ActorID &peer_actor_id, std::vector queue_ids, + std::vector rescale_queue_ids) + : StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids, + rescale_queue_ids) { + test_func_map_ = { + {"streaming_writer_exactly_once_test", + std::bind(&StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest, + this)}}; + } + + private: + void TestWriteMessageToBufferRing(std::shared_ptr writer_client, + std::vector &q_list) { + // const uint8_t temp_data[] = {1, 2, 4, 5}; + + uint32_t i = 1; + while (i <= MESSAGE_BOUND_SIZE) { + for (auto &q_id : q_list) { + uint64_t buffer_len = (i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE); + uint8_t *data = new uint8_t[buffer_len]; + for (uint32_t j = 0; j < buffer_len; ++j) { + data[j] = j % 128; + } + + writer_client->WriteMessageToBufferRing(q_id, data, buffer_len, + StreamingMessageType::Message); + } + ++i; + } + + // Wait a while + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + } + + void StreamingWriterStrategyTest(StreamingConfig &config) { + for (auto &queue_id : queue_ids_) { + STREAMING_LOG(INFO) << "queue_id: " << queue_id; + } + std::vector actor_ids(queue_ids_.size(), peer_actor_id_); + STREAMING_LOG(INFO) << "writer actor_ids size: " << actor_ids.size() + << " actor_id: " << peer_actor_id_; + + std::shared_ptr runtime_context(new RuntimeContext()); + runtime_context->SetConfig(config); + + std::shared_ptr streaming_writer_client(new DataWriter(runtime_context)); + uint64_t queue_size = 10 * 1000 * 1000; + std::vector channel_seq_id_vec(queue_ids_.size(), 0); + streaming_writer_client->Init(queue_ids_, actor_ids, channel_seq_id_vec, + std::vector(queue_ids_.size(), queue_size)); + STREAMING_LOG(INFO) << "streaming_writer_client Init done"; + + streaming_writer_client->Run(); + std::thread test_loop_thread( + &StreamingQueueWriterTestSuite::TestWriteMessageToBufferRing, this, + streaming_writer_client, std::ref(queue_ids_)); + // test_loop_thread.detach(); + if (test_loop_thread.joinable()) { + test_loop_thread.join(); + } + } + + void StreamingWriterExactlyOnceTest() { + StreamingConfig config; + StreamingWriterStrategyTest(config); + + STREAMING_LOG(INFO) + << "StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest"; + status_ = true; + } +}; + +class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { + public: + StreamingQueueReaderTestSuite(std::shared_ptr core_worker, + ActorID peer_actor_id, std::vector queue_ids, + std::vector rescale_queue_ids) + : StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids, + rescale_queue_ids) { + test_func_map_ = { + {"streaming_writer_exactly_once_test", + std::bind(&StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest, + this)}}; + } + + private: + void ReaderLoopForward(std::shared_ptr reader_client, + std::shared_ptr writer_client, + std::vector &queue_id_vec) { + uint64_t recevied_message_cnt = 0; + std::unordered_map queue_last_cp_id; + + for (auto &q_id : queue_id_vec) { + queue_last_cp_id[q_id] = 0; + } + STREAMING_LOG(INFO) << "Start read message bundle"; + while (true) { + std::shared_ptr msg; + StreamingStatus st = reader_client->GetBundle(100, msg); + + if (st != StreamingStatus::OK || !msg->data) { + STREAMING_LOG(DEBUG) << "read bundle timeout, status = " << (int)st; + continue; + } + + STREAMING_CHECK(msg.get() && msg->meta.get()) + << "read null pointer message, queue id => " << msg->from.Hex(); + + if (msg->meta->GetBundleType() == StreamingMessageBundleType::Barrier) { + STREAMING_LOG(DEBUG) << "barrier message recevied => " + << msg->meta->GetMessageBundleTs(); + std::unordered_map *offset_map; + reader_client->GetOffsetInfo(offset_map); + + for (auto &q_id : queue_id_vec) { + reader_client->NotifyConsumedItem((*offset_map)[q_id], + (*offset_map)[q_id].current_seq_id); + } + // writer_client->ClearCheckpoint(msg->last_barrier_id); + + continue; + } else if (msg->meta->GetBundleType() == StreamingMessageBundleType::Empty) { + STREAMING_LOG(DEBUG) << "empty message recevied => " + << msg->meta->GetMessageBundleTs(); + continue; + } + + StreamingMessageBundlePtr bundlePtr; + bundlePtr = StreamingMessageBundle::FromBytes(msg->data); + std::list message_list; + bundlePtr->GetMessageList(message_list); + STREAMING_LOG(INFO) << "message size => " << message_list.size() + << " from queue id => " << msg->from.Hex() + << " last message id => " << msg->meta->GetLastMessageId(); + + recevied_message_cnt += message_list.size(); + for (auto &item : message_list) { + uint64_t i = item->GetMessageSeqId(); + + uint32_t buff_len = i % DEFAULT_STREAMING_MESSAGE_BUFFER_SIZE; + if (i > MESSAGE_BOUND_SIZE) break; + + EXPECT_EQ(buff_len, item->GetDataSize()); + uint8_t *compared_data = new uint8_t[buff_len]; + for (uint32_t j = 0; j < item->GetDataSize(); ++j) { + compared_data[j] = j % 128; + } + EXPECT_EQ(std::memcmp(compared_data, item->RawData(), item->GetDataSize()), 0); + delete[] compared_data; + } + STREAMING_LOG(DEBUG) << "Received message count => " << recevied_message_cnt; + if (recevied_message_cnt == queue_id_vec.size() * MESSAGE_BOUND_SIZE) { + STREAMING_LOG(INFO) << "recevied message count => " << recevied_message_cnt + << ", break"; + break; + } + } + } + + void StreamingReaderStrategyTest(StreamingConfig &config) { + std::vector actor_ids(queue_ids_.size(), peer_actor_id_); + STREAMING_LOG(INFO) << "reader actor_ids size: " << actor_ids.size() + << " actor_id: " << peer_actor_id_; + std::shared_ptr runtime_context(new RuntimeContext()); + runtime_context->SetConfig(config); + std::shared_ptr reader(new DataReader(runtime_context)); + + reader->Init(queue_ids_, actor_ids, -1); + ReaderLoopForward(reader, nullptr, queue_ids_); + + STREAMING_LOG(INFO) << "Reader exit"; + } + + void StreamingWriterExactlyOnceTest() { + STREAMING_LOG(INFO) + << "StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest"; + StreamingConfig config; + + StreamingReaderStrategyTest(config); + status_ = true; + } +}; + +class TestSuiteFactory { + public: + static std::shared_ptr CreateTestSuite( + std::shared_ptr worker, std::shared_ptr message) { + std::shared_ptr test_suite = nullptr; + std::string suite_name = message->TestSuiteName(); + queue::protobuf::StreamingQueueTestRole role = message->Role(); + const std::vector &queue_ids = message->QueueIds(); + const std::vector &rescale_queue_ids = message->RescaleQueueIds(); + ActorID peer_actor_id = message->PeerActorId(); + + if (role == queue::protobuf::StreamingQueueTestRole::WRITER) { + if (suite_name == "StreamingWriterTest") { + test_suite = std::make_shared( + worker, peer_actor_id, queue_ids, rescale_queue_ids); + } else { + STREAMING_CHECK(false) << "unsurported suite_name: " << suite_name; + } + } else { + if (suite_name == "StreamingWriterTest") { + test_suite = std::make_shared( + worker, peer_actor_id, queue_ids, rescale_queue_ids); + } else { + STREAMING_CHECK(false) << "unsupported suite_name: " << suite_name; + } + } + + return test_suite; + } +}; + +class StreamingWorker { + public: + StreamingWorker(const std::string &store_socket, const std::string &raylet_socket, + int node_manager_port, const gcs::GcsClientOptions &gcs_options) + : test_suite_(nullptr), peer_actor_handle_(nullptr) { + worker_ = std::make_shared( + WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket, + JobID::FromInt(1), gcs_options, "", "127.0.0.1", node_manager_port, + std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7)); + + RayFunction reader_async_call_func{ray::Language::PYTHON, {"reader_async_call_func"}}; + RayFunction reader_sync_call_func{ray::Language::PYTHON, {"reader_sync_call_func"}}; + RayFunction writer_async_call_func{ray::Language::PYTHON, {"writer_async_call_func"}}; + RayFunction writer_sync_call_func{ray::Language::PYTHON, {"writer_sync_call_func"}}; + + reader_client_ = std::make_shared(worker_.get(), reader_async_call_func, + reader_sync_call_func); + writer_client_ = std::make_shared(worker_.get(), writer_async_call_func, + writer_sync_call_func); + STREAMING_LOG(INFO) << "StreamingWorker constructor"; + } + + void StartExecutingTasks() { + // Start executing tasks. + worker_->StartExecutingTasks(); + } + + private: + Status ExecuteTask(TaskType task_type, const RayFunction &ray_function, + const std::unordered_map &required_resources, + const std::vector> &args, + const std::vector &arg_reference_ids, + const std::vector &return_ids, + std::vector> *results) { + // Only one arg param used in streaming. + STREAMING_CHECK(args.size() >= 1) << "args.size() = " << args.size(); + + std::vector function_descriptor = ray_function.GetFunctionDescriptor(); + STREAMING_LOG(INFO) << "StreamingWorker::ExecuteTask " << function_descriptor[0]; + + std::string func_name = function_descriptor[0]; + if (func_name == "init") { + std::shared_ptr local_buffer = + std::make_shared(args[0]->GetData()->Data(), + args[0]->GetData()->Size(), true); + HandleInitTask(local_buffer); + } else if (func_name == "execute_test") { + STREAMING_LOG(INFO) << "Test name: " << function_descriptor[1]; + test_suite_->ExecuteTest(function_descriptor[1]); + } else if (func_name == "check_current_test_status") { + results->push_back( + std::make_shared(test_suite_->CheckCurTestStatus(), nullptr)); + } else if (func_name == "reader_sync_call_func") { + if (test_suite_->TestDone()) { + STREAMING_LOG(WARNING) << "Test has done!!"; + return Status::OK(); + } + std::shared_ptr local_buffer = + std::make_shared(args[1]->GetData()->Data(), + args[1]->GetData()->Size(), true); + auto result_buffer = reader_client_->OnReaderMessageSync(local_buffer); + results->push_back(std::make_shared(result_buffer, nullptr)); + } else if (func_name == "reader_async_call_func") { + if (test_suite_->TestDone()) { + STREAMING_LOG(WARNING) << "Test has done!!"; + return Status::OK(); + } + std::shared_ptr local_buffer = + std::make_shared(args[1]->GetData()->Data(), + args[1]->GetData()->Size(), true); + reader_client_->OnReaderMessage(local_buffer); + } else if (func_name == "writer_sync_call_func") { + if (test_suite_->TestDone()) { + STREAMING_LOG(WARNING) << "Test has done!!"; + return Status::OK(); + } + std::shared_ptr local_buffer = + std::make_shared(args[1]->GetData()->Data(), + args[1]->GetData()->Size(), true); + auto result_buffer = writer_client_->OnWriterMessageSync(local_buffer); + results->push_back(std::make_shared(result_buffer, nullptr)); + } else if (func_name == "writer_async_call_func") { + if (test_suite_->TestDone()) { + STREAMING_LOG(WARNING) << "Test has done!!"; + return Status::OK(); + } + std::shared_ptr local_buffer = + std::make_shared(args[1]->GetData()->Data(), + args[1]->GetData()->Size(), true); + writer_client_->OnWriterMessage(local_buffer); + } else { + STREAMING_LOG(WARNING) << "Invalid function name " << func_name; + } + + return Status::OK(); + } + + private: + void HandleInitTask(std::shared_ptr buffer) { + uint8_t *bytes = buffer->Data(); + uint8_t *p_cur = bytes; + uint32_t *magic_num = (uint32_t *)p_cur; + STREAMING_CHECK(*magic_num == Message::MagicNum); + + p_cur += sizeof(Message::MagicNum); + queue::protobuf::StreamingQueueMessageType *type = + (queue::protobuf::StreamingQueueMessageType *)p_cur; + STREAMING_CHECK( + *type == + queue::protobuf::StreamingQueueMessageType::StreamingQueueTestInitMsgType); + std::shared_ptr message = TestInitMessage::FromBytes(bytes); + + STREAMING_LOG(INFO) << "Init message: " << message->ToString(); + std::string actor_handle_serialized = message->ActorHandleSerialized(); + worker_->DeserializeAndRegisterActorHandle(actor_handle_serialized); + std::shared_ptr actor_handle(new ActorHandle(actor_handle_serialized)); + STREAMING_CHECK(actor_handle != nullptr); + STREAMING_LOG(INFO) << " actor id from handle: " << actor_handle->GetActorID(); + ; + + // STREAMING_LOG(INFO) << "actor_handle_serialized: " << actor_handle_serialized; + // peer_actor_handle_ = + // std::make_shared(actor_handle_serialized); + + STREAMING_LOG(INFO) << "HandleInitTask queues:"; + for (auto qid : message->QueueIds()) { + STREAMING_LOG(INFO) << "queue: " << qid; + } + for (auto qid : message->RescaleQueueIds()) { + STREAMING_LOG(INFO) << "rescale queue: " << qid; + } + + test_suite_ = TestSuiteFactory::CreateTestSuite(worker_, message); + STREAMING_CHECK(test_suite_ != nullptr); + } + + private: + std::shared_ptr worker_; + std::shared_ptr reader_client_; + std::shared_ptr writer_client_; + std::shared_ptr test_thread_; + std::shared_ptr test_suite_; + std::shared_ptr peer_actor_handle_; +}; + +} // namespace streaming +} // namespace ray + +int main(int argc, char **argv) { + RAY_CHECK(argc == 4); + auto store_socket = std::string(argv[1]); + auto raylet_socket = std::string(argv[2]); + auto node_manager_port = std::stoi(std::string(argv[3])); + + ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, ""); + ray::streaming::StreamingWorker worker(store_socket, raylet_socket, node_manager_port, + gcs_options); + worker.StartExecutingTasks(); + return 0; +} diff --git a/streaming/src/test/mock_transfer_tests.cc b/streaming/src/test/mock_transfer_tests.cc new file mode 100644 index 0000000000000..0dea14e2a2dff --- /dev/null +++ b/streaming/src/test/mock_transfer_tests.cc @@ -0,0 +1,136 @@ +#include "data_reader.h" +#include "data_writer.h" +#include "gtest/gtest.h" + +using namespace ray; +using namespace ray::streaming; + +TEST(StreamingMockTransfer, mock_produce_consume) { + std::shared_ptr transfer_config; + ObjectID channel_id = ObjectID::FromRandom(); + ProducerChannelInfo producer_channel_info; + producer_channel_info.channel_id = channel_id; + producer_channel_info.current_seq_id = 0; + MockProducer producer(transfer_config, producer_channel_info); + + ConsumerChannelInfo consumer_channel_info; + consumer_channel_info.channel_id = channel_id; + MockConsumer consumer(transfer_config, consumer_channel_info); + + producer.CreateTransferChannel(); + uint8_t data[3] = {1, 2, 3}; + producer.ProduceItemToChannel(data, 3); + uint8_t *data_consumed; + uint32_t data_size_consumed; + uint64_t data_seq_id; + consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1); + EXPECT_EQ(data_size_consumed, 3); + EXPECT_EQ(data_seq_id, 1); + EXPECT_EQ(std::memcmp(data_consumed, data, 3), 0); + consumer.NotifyChannelConsumed(1); + + auto status = + consumer.ConsumeItemFromChannel(data_seq_id, data_consumed, data_size_consumed, -1); + EXPECT_EQ(status, StreamingStatus::NoSuchItem); +} + +class StreamingTransferTest : public ::testing::Test { + public: + StreamingTransferTest() { + std::shared_ptr runtime_context(new RuntimeContext()); + runtime_context->MarkMockTest(); + writer = std::make_shared(runtime_context); + reader = std::make_shared(runtime_context); + } + virtual ~StreamingTransferTest() = default; + void InitTransfer(int channel_num = 1) { + for (int i = 0; i < channel_num; ++i) { + queue_vec.push_back(ObjectID::FromRandom()); + } + std::vector channel_id_vec(queue_vec.size(), 0); + std::vector queue_size_vec(queue_vec.size(), 10000); + // actor ids are not used in this test, so we can just use Nil. + std::vector actor_id_vec(queue_vec.size(), + ActorID::NilFromJob(JobID::FromInt(0))); + writer->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec); + reader->Init(queue_vec, actor_id_vec, channel_id_vec, queue_size_vec, -1); + } + void DestroyTransfer() { + writer.reset(); + reader.reset(); + } + + protected: + std::shared_ptr writer; + std::shared_ptr reader; + std::vector queue_vec; +}; + +TEST_F(StreamingTransferTest, exchange_single_channel_test) { + InitTransfer(); + writer->Run(); + uint8_t data[4] = {1, 2, 3, 0xff}; + uint32_t data_size = 4; + writer->WriteMessageToBufferRing(queue_vec[0], data, data_size); + std::shared_ptr msg; + reader->GetBundle(5000, msg); + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data); + auto &message_list = bundle_ptr->GetMessageList(); + auto &message = message_list.front(); + EXPECT_EQ(std::memcmp(message->RawData(), data, data_size), 0); +} + +TEST_F(StreamingTransferTest, exchange_multichannel_test) { + int channel_num = 4; + InitTransfer(4); + writer->Run(); + for (int i = 0; i < channel_num; ++i) { + uint8_t data[4] = {1, 2, 3, (uint8_t)i}; + uint32_t data_size = 4; + writer->WriteMessageToBufferRing(queue_vec[i], data, data_size); + std::shared_ptr msg; + reader->GetBundle(5000, msg); + EXPECT_EQ(msg->from, queue_vec[i]); + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data); + auto &message_list = bundle_ptr->GetMessageList(); + auto &message = message_list.front(); + EXPECT_EQ(std::memcmp(message->RawData(), data, data_size), 0); + } +} + +TEST_F(StreamingTransferTest, exchange_consumed_test) { + InitTransfer(); + writer->Run(); + uint32_t data_size = 8196; + std::shared_ptr data(new uint8_t[data_size]); + auto func = [data, data_size](int index) { std::fill_n(data.get(), data_size, index); }; + + int num = 10000; + std::thread write_thread([this, data, data_size, &func, num]() { + for (uint32_t i = 0; i < num; ++i) { + func(i); + writer->WriteMessageToBufferRing(queue_vec[0], data.get(), data_size); + } + }); + + std::list read_message_list; + while (read_message_list.size() < num) { + std::shared_ptr msg; + reader->GetBundle(5000, msg); + StreamingMessageBundlePtr bundle_ptr = StreamingMessageBundle::FromBytes(msg->data); + auto &message_list = bundle_ptr->GetMessageList(); + std::copy(message_list.begin(), message_list.end(), + std::back_inserter(read_message_list)); + } + int index = 0; + for (auto &message : read_message_list) { + func(index++); + EXPECT_EQ(std::memcmp(message->RawData(), data.get(), data_size), 0); + } + write_thread.join(); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h new file mode 100644 index 0000000000000..4d70af01cbd73 --- /dev/null +++ b/streaming/src/test/queue_tests_base.h @@ -0,0 +1,313 @@ +namespace ray { +namespace streaming { + +ray::ObjectID RandomObjectID() { return ObjectID::FromRandom(); } + +static void flushall_redis(void) { + redisContext *context = redisConnect("127.0.0.1", 6379); + freeReplyObject(redisCommand(context, "FLUSHALL")); + freeReplyObject(redisCommand(context, "SET NumRedisShards 1")); + freeReplyObject(redisCommand(context, "LPUSH RedisShards 127.0.0.1:6380")); + redisFree(context); +} +/// Base class for real-world tests with streaming queue +class StreamingQueueTestBase : public ::testing::TestWithParam { + public: + StreamingQueueTestBase(int num_nodes, std::string raylet_exe, std::string store_exe, + int port, std::string actor_exe) + : gcs_options_("127.0.0.1", 6379, ""), + raylet_executable_(raylet_exe), + store_executable_(store_exe), + actor_executable_(actor_exe), + node_manager_port_(port) { + // flush redis first. + flushall_redis(); + + RAY_CHECK(num_nodes >= 0); + if (num_nodes > 0) { + raylet_socket_names_.resize(num_nodes); + raylet_store_socket_names_.resize(num_nodes); + } + + // start plasma store. + for (auto &store_socket : raylet_store_socket_names_) { + store_socket = StartStore(); + } + + // start raylet on each node. Assign each node with different resources so that + // a task can be scheduled to the desired node. + for (int i = 0; i < num_nodes; i++) { + raylet_socket_names_[i] = + StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", node_manager_port_ + i, + "127.0.0.1", "\"CPU,4.0,resource" + std::to_string(i) + ",10\""); + } + } + + ~StreamingQueueTestBase() { + STREAMING_LOG(INFO) << "Stop raylet store and actors"; + for (const auto &raylet_socket : raylet_socket_names_) { + StopRaylet(raylet_socket); + } + + for (const auto &store_socket : raylet_store_socket_names_) { + StopStore(store_socket); + } + } + + JobID NextJobId() const { + static uint32_t job_counter = 1; + return JobID::FromInt(job_counter++); + } + + std::string StartStore() { + std::string store_socket_name = "/tmp/store" + RandomObjectID().Hex(); + std::string store_pid = store_socket_name + ".pid"; + std::string plasma_command = store_executable_ + " -m 10000000 -s " + + store_socket_name + + " 1> /dev/null 2> /dev/null & echo $! > " + store_pid; + RAY_LOG(DEBUG) << plasma_command; + RAY_CHECK(system(plasma_command.c_str()) == 0); + usleep(200 * 1000); + return store_socket_name; + } + + void StopStore(std::string store_socket_name) { + std::string store_pid = store_socket_name + ".pid"; + std::string kill_9 = "kill -9 `cat " + store_pid + "`"; + RAY_LOG(DEBUG) << kill_9; + ASSERT_EQ(system(kill_9.c_str()), 0); + ASSERT_EQ(system(("rm -rf " + store_socket_name).c_str()), 0); + ASSERT_EQ(system(("rm -rf " + store_socket_name + ".pid").c_str()), 0); + } + + std::string StartRaylet(std::string store_socket_name, std::string node_ip_address, + int port, std::string redis_address, std::string resource) { + std::string raylet_socket_name = "/tmp/raylet" + RandomObjectID().Hex(); + std::string ray_start_cmd = raylet_executable_; + ray_start_cmd.append(" --raylet_socket_name=" + raylet_socket_name) + .append(" --store_socket_name=" + store_socket_name) + .append(" --object_manager_port=0 --node_manager_port=" + std::to_string(port)) + .append(" --node_ip_address=" + node_ip_address) + .append(" --redis_address=" + redis_address) + .append(" --redis_port=6379") + .append(" --num_initial_workers=1") + .append(" --maximum_startup_concurrency=10") + .append(" --static_resource_list=" + resource) + .append(" --python_worker_command=\"" + actor_executable_ + " " + + store_socket_name + " " + raylet_socket_name + " " + + std::to_string(port) + "\"") + .append(" --config_list=initial_reconstruction_timeout_milliseconds,2000") + .append(" & echo $! > " + raylet_socket_name + ".pid"); + + RAY_LOG(DEBUG) << "Ray Start command: " << ray_start_cmd; + RAY_CHECK(system(ray_start_cmd.c_str()) == 0); + usleep(200 * 1000); + return raylet_socket_name; + } + + void StopRaylet(std::string raylet_socket_name) { + std::string raylet_pid = raylet_socket_name + ".pid"; + std::string kill_9 = "kill -9 `cat " + raylet_pid + "`"; + RAY_LOG(DEBUG) << kill_9; + ASSERT_TRUE(system(kill_9.c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + raylet_socket_name).c_str()) == 0); + ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0); + } + + void InitWorker(CoreWorker &driver, ActorID &self_actor_id, ActorID &peer_actor_id, + const queue::protobuf::StreamingQueueTestRole role, + const std::vector &queue_ids, + const std::vector &rescale_queue_ids, std::string suite_name, + std::string test_name, uint64_t param) { + std::string forked_serialized_str; + Status st = driver.SerializeActorHandle(peer_actor_id, &forked_serialized_str); + STREAMING_CHECK(st.ok()); + STREAMING_LOG(INFO) << "forked_serialized_str: " << forked_serialized_str; + TestInitMessage msg(role, self_actor_id, peer_actor_id, forked_serialized_str, + queue_ids, rescale_queue_ids, suite_name, test_name, param); + + std::vector args; + args.emplace_back( + TaskArg::PassByValue(std::make_shared(msg.ToBytes(), nullptr, true))); + std::unordered_map resources; + TaskOptions options{0, true, resources}; + std::vector return_ids; + RayFunction func{ray::Language::PYTHON, {"init"}}; + + RAY_CHECK_OK(driver.SubmitActorTask(self_actor_id, func, args, options, &return_ids)); + } + + void SubmitTestToActor(CoreWorker &driver, ActorID &actor_id, const std::string test) { + uint8_t data[8]; + auto buffer = std::make_shared(data, 8, true); + std::vector args; + args.emplace_back( + TaskArg::PassByValue(std::make_shared(buffer, nullptr, true))); + std::unordered_map resources; + TaskOptions options{0, true, resources}; + std::vector return_ids; + RayFunction func{ray::Language::PYTHON, {"execute_test", test}}; + + RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); + } + + bool CheckCurTest(CoreWorker &driver, ActorID &actor_id, const std::string test_name) { + uint8_t data[8]; + auto buffer = std::make_shared(data, 8, true); + std::vector args; + args.emplace_back( + TaskArg::PassByValue(std::make_shared(buffer, nullptr, true))); + std::unordered_map resources; + TaskOptions options{1, true, resources}; + std::vector return_ids; + RayFunction func{ray::Language::PYTHON, {"check_current_test_status"}}; + + RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); + + std::vector wait_results; + std::vector> results; + Status wait_st = driver.Wait(return_ids, 1, 5 * 1000, &wait_results); + if (!wait_st.ok()) { + STREAMING_LOG(ERROR) << "Wait fail."; + return false; + } + STREAMING_CHECK(wait_results.size() >= 1); + if (!wait_results[0]) { + STREAMING_LOG(WARNING) << "Wait direct call fail."; + return false; + } + + Status get_st = driver.Get(return_ids, -1, &results); + if (!get_st.ok()) { + STREAMING_LOG(ERROR) << "Get fail."; + return false; + } + STREAMING_CHECK(results.size() >= 1); + if (results[0]->IsException()) { + STREAMING_LOG(INFO) << "peer actor may has exceptions."; + return false; + } + STREAMING_CHECK(results[0]->HasData()); + STREAMING_LOG(DEBUG) << "SendForResult result[0] DataSize: " << results[0]->GetSize(); + + const std::shared_ptr result_buffer = results[0]->GetData(); + std::shared_ptr return_buffer = + std::make_shared(result_buffer->Data(), result_buffer->Size(), + true); + + uint8_t *bytes = result_buffer->Data(); + uint8_t *p_cur = bytes; + uint32_t *magic_num = (uint32_t *)p_cur; + STREAMING_CHECK(*magic_num == Message::MagicNum); + + p_cur += sizeof(Message::MagicNum); + queue::protobuf::StreamingQueueMessageType *type = + (queue::protobuf::StreamingQueueMessageType *)p_cur; + STREAMING_CHECK(*type == queue::protobuf::StreamingQueueMessageType:: + StreamingQueueTestCheckStatusRspMsgType); + std::shared_ptr message = + TestCheckStatusRspMsg::FromBytes(bytes); + STREAMING_CHECK(message->TestName() == test_name); + return message->Status(); + } + + ActorID CreateActorHelper(CoreWorker &worker, + const std::unordered_map &resources, + bool is_direct_call, uint64_t max_reconstructions) { + std::unique_ptr actor_handle; + + // Test creating actor. + uint8_t array[] = {1, 2, 3}; + auto buffer = std::make_shared(array, sizeof(array)); + + RayFunction func{ray::Language::PYTHON, {"actor creation task"}}; + std::vector args; + args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); + + ActorCreationOptions actor_options{ + max_reconstructions, is_direct_call, + /*max_concurrency*/ 1, resources, resources, {}, + /*is_detached*/ false, /*is_asyncio*/ false}; + + // Create an actor. + ActorID actor_id; + RAY_CHECK_OK(worker.CreateActor(func, args, actor_options, &actor_id)); + return actor_id; + } + + void SubmitTest(uint32_t queue_num, std::string suite_name, std::string test_name, + uint64_t timeout_ms) { + std::vector queue_id_vec; + std::vector rescale_queue_id_vec; + for (uint32_t i = 0; i < queue_num; ++i) { + ObjectID queue_id = ray::ObjectID::FromRandom(); + queue_id_vec.emplace_back(queue_id); + } + + // One scale id + ObjectID rescale_queue_id = ray::ObjectID::FromRandom(); + rescale_queue_id_vec.emplace_back(rescale_queue_id); + + std::vector channel_seq_id_vec(queue_num, 0); + + for (size_t i = 0; i < queue_id_vec.size(); ++i) { + STREAMING_LOG(INFO) << " qid hex => " << queue_id_vec[i].Hex(); + } + for (auto &qid : rescale_queue_id_vec) { + STREAMING_LOG(INFO) << " rescale qid hex => " << qid.Hex(); + } + STREAMING_LOG(INFO) << "Sub process: writer."; + + CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], + raylet_socket_names_[0], NextJobId(), gcs_options_, "", "", + node_manager_port_, nullptr); + + // Create writer and reader actors + std::unordered_map resources; + auto actor_id_writer = CreateActorHelper(driver, resources, true, 0); + auto actor_id_reader = CreateActorHelper(driver, resources, true, 0); + + InitWorker(driver, actor_id_writer, actor_id_reader, + queue::protobuf::StreamingQueueTestRole::WRITER, queue_id_vec, + rescale_queue_id_vec, suite_name, test_name, GetParam()); + InitWorker(driver, actor_id_reader, actor_id_writer, + queue::protobuf::StreamingQueueTestRole::READER, queue_id_vec, + rescale_queue_id_vec, suite_name, test_name, GetParam()); + + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + + SubmitTestToActor(driver, actor_id_writer, test_name); + SubmitTestToActor(driver, actor_id_reader, test_name); + + uint64_t slept_time_ms = 0; + while (slept_time_ms < timeout_ms) { + std::this_thread::sleep_for(std::chrono::milliseconds(5 * 1000)); + STREAMING_LOG(INFO) << "Check test status."; + if (CheckCurTest(driver, actor_id_writer, test_name) && + CheckCurTest(driver, actor_id_reader, test_name)) { + STREAMING_LOG(INFO) << "Test Success, Exit."; + return; + } + slept_time_ms += 5 * 1000; + } + + EXPECT_TRUE(false); + STREAMING_LOG(INFO) << "Test Timeout, Exit."; + } + + void SetUp() {} + + void TearDown() {} + + protected: + std::vector raylet_socket_names_; + std::vector raylet_store_socket_names_; + gcs::GcsClientOptions gcs_options_; + std::string raylet_executable_; + std::string store_executable_; + std::string actor_executable_; + int node_manager_port_; +}; + +} // namespace streaming +} // namespace ray diff --git a/streaming/src/test/ring_buffer_tests.cc b/streaming/src/test/ring_buffer_tests.cc new file mode 100644 index 0000000000000..c419e4f3d73eb --- /dev/null +++ b/streaming/src/test/ring_buffer_tests.cc @@ -0,0 +1,93 @@ +#include "gtest/gtest.h" +#include "ray/util/logging.h" + +#include +#include +#include +#include +#include "message/message.h" +#include "ring_buffer.h" + +using namespace ray; +using namespace ray::streaming; + +size_t data_n = 1000000; +TEST(StreamingRingBufferTest, streaming_message_ring_buffer_test) { + for (int k = 0; k < 10000; ++k) { + StreamingRingBuffer ring_buffer(3, StreamingRingBufferType::SPSC_LOCK); + for (int i = 0; i < 5; ++i) { + uint8_t data[] = {1, 1, 3}; + data[0] = i; + StreamingMessagePtr message = + std::make_shared(data, 3, i, StreamingMessageType::Message); + EXPECT_EQ(ring_buffer.Push(message), true); + size_t ith = i >= 3 ? 3 : (i + 1); + EXPECT_EQ(ring_buffer.Size(), ith); + } + int th = 2; + + while (!ring_buffer.IsEmpty()) { + StreamingMessagePtr message_ptr = ring_buffer.Front(); + ring_buffer.Pop(); + EXPECT_EQ(message_ptr->GetDataSize(), 3); + EXPECT_EQ(*(message_ptr->RawData()), th++); + } + } +} + +TEST(StreamingRingBufferTest, spsc_test) { + size_t m_num = 1000; + StreamingRingBuffer ring_buffer(m_num, StreamingRingBufferType::SPSC); + std::thread thread([&ring_buffer]() { + for (size_t j = 0; j < data_n; ++j) { + StreamingMessagePtr message = std::make_shared( + reinterpret_cast(&j), sizeof(size_t), j, + StreamingMessageType::Message); + while (ring_buffer.IsFull()) { + } + ring_buffer.Push(message); + } + }); + size_t count = 0; + while (count < data_n) { + while (ring_buffer.IsEmpty()) { + } + auto &msg = ring_buffer.Front(); + EXPECT_EQ(std::memcmp(msg->RawData(), &count, sizeof(size_t)), 0); + ring_buffer.Pop(); + count++; + } + thread.join(); + EXPECT_EQ(count, data_n); +} + +TEST(StreamingRingBufferTest, mutex_test) { + size_t m_num = data_n; + StreamingRingBuffer ring_buffer(m_num, StreamingRingBufferType::SPSC_LOCK); + std::thread thread([&ring_buffer]() { + for (size_t j = 0; j < data_n; ++j) { + StreamingMessagePtr message = std::make_shared( + reinterpret_cast(&j), sizeof(size_t), j, + StreamingMessageType::Message); + while (ring_buffer.IsFull()) { + } + ring_buffer.Push(message); + } + }); + size_t count = 0; + while (count < data_n) { + while (ring_buffer.IsEmpty()) { + } + auto msg = ring_buffer.Front(); + EXPECT_EQ(std::memcmp(msg->RawData(), &count, sizeof(size_t)), 0); + ring_buffer.Pop(); + count++; + } + thread.join(); + EXPECT_EQ(count, data_n); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/run_streaming_queue_test.sh b/streaming/src/test/run_streaming_queue_test.sh new file mode 100644 index 0000000000000..a2771e83039d0 --- /dev/null +++ b/streaming/src/test/run_streaming_queue_test.sh @@ -0,0 +1,69 @@ +#!/usr/bin/env bash + +# Run all streaming c++ tests using streaming queue, instead of plasma queue +# This needs to be run in the root directory. + +# Try to find an unused port for raylet to use. +PORTS="2000 2001 2002 2003 2004 2005 2006 2007 2008 2009" +RAYLET_PORT=0 +for port in $PORTS; do + nc -z localhost $port + if [[ $? != 0 ]]; then + RAYLET_PORT=$port + break + fi +done + +if [[ $RAYLET_PORT == 0 ]]; then + echo "WARNING: Could not find unused port for raylet to use. Exiting without running tests." + exit +fi + +# Cause the script to exit if a single command fails. +set -e +set -x +export STREAMING_METRICS_MODE=DEV + +# Get the directory in which this script is executing. +SCRIPT_DIR="`dirname \"$0\"`" + +# Get the directory in which this script is executing. +SCRIPT_DIR="`dirname \"$0\"`" +RAY_ROOT="$SCRIPT_DIR/../../.." +# Makes $RAY_ROOT an absolute path. +RAY_ROOT="`( cd \"$RAY_ROOT\" && pwd )`" +if [ -z "$RAY_ROOT" ] ; then + exit 1 +fi + +bazel build "//:core_worker_test" "//:mock_worker" "//:raylet" "//:libray_redis_module.so" "@plasma//:plasma_store_server" +bazel build //streaming:streaming_test_worker +bazel build //streaming:streaming_queue_tests + +# Ensure we're in the right directory. +if [ ! -d "$RAY_ROOT/python" ]; then + echo "Unable to find root Ray directory. Has this script moved?" + exit 1 +fi + +REDIS_MODULE="./bazel-bin/libray_redis_module.so" +LOAD_MODULE_ARGS="--loadmodule ${REDIS_MODULE}" +STORE_EXEC="./bazel-bin/external/plasma/plasma_store_server" +RAYLET_EXEC="./bazel-bin/raylet" +STREAMING_TEST_WORKER_EXEC="./bazel-bin/streaming/streaming_test_worker" + +# Allow cleanup commands to fail. +bazel run //:redis-cli -- -p 6379 shutdown || true +sleep 1s +bazel run //:redis-cli -- -p 6380 shutdown || true +sleep 1s +bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6379 & +sleep 2s +bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 & +sleep 2s +# Run tests. +./bazel-bin/streaming/streaming_queue_tests $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $STREAMING_TEST_WORKER_EXEC +sleep 1s +bazel run //:redis-cli -- -p 6379 shutdown +bazel run //:redis-cli -- -p 6380 shutdown +sleep 1s diff --git a/streaming/src/test/streaming_queue_tests.cc b/streaming/src/test/streaming_queue_tests.cc new file mode 100644 index 0000000000000..e5eb87b13a410 --- /dev/null +++ b/streaming/src/test/streaming_queue_tests.cc @@ -0,0 +1,65 @@ +#define BOOST_BIND_NO_PLACEHOLDERS +#include +#include "gtest/gtest.h" +#include "queue/queue_client.h" +#include "ray/core_worker/core_worker.h" + +#include "data_reader.h" +#include "data_writer.h" +#include "message/message.h" +#include "message/message_bundle.h" +#include "ring_buffer.h" + +#include "queue_tests_base.h" + +using namespace std::placeholders; +namespace ray { +namespace streaming { + +static std::string store_executable; +static std::string raylet_executable; +static std::string actor_executable; +static int node_manager_port; + +class StreamingWriterTest : public StreamingQueueTestBase { + public: + StreamingWriterTest() + : StreamingQueueTestBase(1, raylet_executable, store_executable, node_manager_port, + actor_executable) {} +}; + +class StreamingExactlySameTest : public StreamingQueueTestBase { + public: + StreamingExactlySameTest() + : StreamingQueueTestBase(1, raylet_executable, store_executable, node_manager_port, + actor_executable) {} +}; + +TEST_P(StreamingWriterTest, streaming_writer_exactly_once_test) { + STREAMING_LOG(INFO) << "StreamingWriterTest.streaming_writer_exactly_once_test"; + + uint32_t queue_num = 1; + + STREAMING_LOG(INFO) << "Streaming Strategy => EXACTLY ONCE"; + SubmitTest(queue_num, "StreamingWriterTest", "streaming_writer_exactly_once_test", + 60 * 1000); +} + +INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingWriterTest, testing::Values(0)); + +INSTANTIATE_TEST_CASE_P(StreamingTest, StreamingExactlySameTest, + testing::Values(0, 1, 5, 9)); + +} // namespace streaming +} // namespace ray + +int main(int argc, char **argv) { + // set_streaming_log_config("streaming_writer_test", StreamingLogLevel::INFO, 0); + ::testing::InitGoogleTest(&argc, argv); + RAY_CHECK(argc == 5); + ray::streaming::store_executable = std::string(argv[1]); + ray::streaming::raylet_executable = std::string(argv[2]); + ray::streaming::node_manager_port = std::stoi(std::string(argv[3])); + ray::streaming::actor_executable = std::string(argv[4]); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/test/streaming_util_tests.cc b/streaming/src/test/streaming_util_tests.cc new file mode 100644 index 0000000000000..a633065e51e07 --- /dev/null +++ b/streaming/src/test/streaming_util_tests.cc @@ -0,0 +1,24 @@ +#include "gtest/gtest.h" + +#include "util/streaming_util.h" + +using namespace ray; +using namespace ray::streaming; + +TEST(StreamingUtilTest, test_Byte2hex) { + const uint8_t data[2] = {0x11, 0x07}; + EXPECT_TRUE(Util::Byte2hex(data, 2) == "1107"); + EXPECT_TRUE(Util::Byte2hex(data, 2) != "1108"); +} + +TEST(StreamingUtilTest, test_Hex2str) { + const uint8_t data[2] = {0x11, 0x07}; + EXPECT_TRUE(std::memcmp(Util::Hexqid2str("1107").c_str(), data, 2) == 0); + const uint8_t data2[2] = {0x10, 0x0f}; + EXPECT_TRUE(std::memcmp(Util::Hexqid2str("100f").c_str(), data2, 2) == 0); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/streaming/src/util/streaming_logging.cc b/streaming/src/util/streaming_logging.cc new file mode 100644 index 0000000000000..c4c9c6c49e6f0 --- /dev/null +++ b/streaming/src/util/streaming_logging.cc @@ -0,0 +1,12 @@ +#include +#include +#include + +#include "glog/log_severity.h" +#include "glog/logging.h" + +#include "streaming_logging.h" + +namespace ray { +namespace streaming {} // namespace streaming +} // namespace ray diff --git a/streaming/src/util/streaming_logging.h b/streaming/src/util/streaming_logging.h new file mode 100644 index 0000000000000..ba6b0b45d69b5 --- /dev/null +++ b/streaming/src/util/streaming_logging.h @@ -0,0 +1,11 @@ +#ifndef RAY_STREAMING_LOGGING_H +#define RAY_STREAMING_LOGGING_H +#include "ray/util/logging.h" + +#define STREAMING_LOG RAY_LOG +#define STREAMING_CHECK RAY_CHECK +namespace ray { +namespace streaming {} // namespace streaming +} // namespace ray + +#endif // RAY_STREAMING_LOGGING_H diff --git a/streaming/src/util/streaming_util.cc b/streaming/src/util/streaming_util.cc new file mode 100644 index 0000000000000..4ed99a190f0ae --- /dev/null +++ b/streaming/src/util/streaming_util.cc @@ -0,0 +1,42 @@ +#include + +#include "streaming_util.h" +namespace ray { +namespace streaming { + +boost::any &Config::Get(ConfigEnum key) const { + auto item = config_map_.find(key); + STREAMING_CHECK(item != config_map_.end()); + return item->second; +} + +boost::any Config::Get(ConfigEnum key, boost::any default_value) const { + auto item = config_map_.find(key); + if (item == config_map_.end()) { + return default_value; + } + return item->second; +} + +std::string Util::Byte2hex(const uint8_t *data, uint32_t data_size) { + constexpr char hex[] = "0123456789abcdef"; + std::string result; + for (uint32_t i = 0; i < data_size; i++) { + unsigned short val = data[i]; + result.push_back(hex[val >> 4]); + result.push_back(hex[val & 0xf]); + } + return result; +} + +std::string Util::Hexqid2str(const std::string &q_id_hex) { + std::string result; + for (uint32_t i = 0; i < q_id_hex.size(); i += 2) { + std::string byte = q_id_hex.substr(i, 2); + char chr = static_cast(std::strtol(byte.c_str(), nullptr, 16)); + result.push_back(chr); + } + return result; +} +} // namespace streaming +} // namespace ray diff --git a/streaming/src/util/streaming_util.h b/streaming/src/util/streaming_util.h new file mode 100644 index 0000000000000..a6665f4fe4fbe --- /dev/null +++ b/streaming/src/util/streaming_util.h @@ -0,0 +1,99 @@ +#ifndef RAY_STREAMING_UTIL_H +#define RAY_STREAMING_UTIL_H +#include +#include +#include + +#include "util/streaming_logging.h" + +namespace ray { +namespace streaming { + +enum class ConfigEnum : uint32_t { + QUEUE_ID_VECTOR = 0, + RECONSTRUCT_RETRY_TIMES, + RECONSTRUCT_TIMEOUT_PER_MB, + CURRENT_DRIVER_ID, + /// For direct call + CORE_WORKER, + SYNC_FUNCTION, + ASYNC_FUNCTION, + TRANSFER_MIN = QUEUE_ID_VECTOR, + TRANSFER_MAX = ASYNC_FUNCTION +}; +} // namespace streaming +} // namespace ray + +namespace std { +template <> +struct hash<::ray::streaming::ConfigEnum> { + size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { + return static_cast(config_enum_key); + } +}; + +template <> +struct hash { + size_t operator()(const ::ray::streaming::ConfigEnum &config_enum_key) const { + return static_cast(config_enum_key); + } +}; +} // namespace std + +namespace ray { +namespace streaming { + +class Config { + public: + template + inline void Set(ConfigEnum key, const ValueType &any) { + config_map_.emplace(key, any); + } + + template + inline void Set(ConfigEnum key, ValueType &&any) { + config_map_.emplace(key, any); + } + + template + inline boost::any &GetOrDefault(ConfigEnum key, ValueType &&any) { + auto item = config_map_.find(key); + if (item != config_map_.end()) { + return item->second; + } + Set(key, any); + return any; + } + + boost::any &Get(ConfigEnum key) const; + + boost::any Get(ConfigEnum key, boost::any default_value) const; + + inline uint32_t GetInt32(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline uint64_t GetInt64(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline double GetDouble(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline bool GetBool(ConfigEnum key) { return boost::any_cast(Get(key)); } + + inline std::string GetString(ConfigEnum key) { + return boost::any_cast(Get(key)); + } + + virtual ~Config() = default; + + protected: + mutable std::unordered_map config_map_; +}; + +class Util { + public: + static std::string Byte2hex(const uint8_t *data, uint32_t data_size); + + static std::string Hexqid2str(const std::string &q_id_hex); +}; +} // namespace streaming +} // namespace ray + +#endif // RAY_STREAMING_UTIL_H diff --git a/thirdparty/patches/grpc-cython-copts.patch b/thirdparty/patches/grpc-cython-copts.patch index cdb7cffea2932..5c5e4b1427fb6 100644 --- a/thirdparty/patches/grpc-cython-copts.patch +++ b/thirdparty/patches/grpc-cython-copts.patch @@ -1,30 +1,49 @@ diff --git bazel/cython_library.bzl bazel/cython_library.bzl -index 48b41d74e8..6084734f59 100644 +index 48b41d74e8..a9bc168e5d 100644 --- bazel/cython_library.bzl +++ bazel/cython_library.bzl -@@ -7,7 +7,7 @@ +@@ -7,18 +7,20 @@ # been written at cython/cython and tensorflow/tensorflow. We branch from # Tensorflow's version as it is more actively maintained and works for gRPC # Python's needs. -def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs): -+def pyx_library(name, deps=[], copts=[], py_deps=[], srcs=[], **kwargs): ++def pyx_library(name, deps=[], copts=[], cc_kwargs={}, py_deps=[], srcs=[], **kwargs): """Compiles a group of .pyx / .pxd / .py files. - + First runs Cython to create .cpp files for each input .pyx or .py + .pxd -@@ -19,6 +19,7 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs): +- pair. Then builds a shared object for each, passing "deps" to each cc_binary +- rule (includes Python headers by default). Finally, creates a py_library rule +- with the shared objects and any pure Python "srcs", with py_deps as its +- dependencies; the shared objects can be imported like normal Python files. ++ pair. Then builds a shared object for each, passing "deps" and `**cc_kwargs` ++ to each cc_binary rule (includes Python headers by default). Finally, creates ++ a py_library rule with the shared objects and any pure Python "srcs", with py_deps ++ as its dependencies; the shared objects can be imported like normal Python files. + Args: name: Name for the rule. deps: C/C++ dependencies of the Cython (e.g. Numpy headers). + copts: C/C++ compiler options for Cython ++ cc_kwargs: cc_binary extra arguments such as linkstatic, linkopts, features py_deps: Pure Python dependencies of the final library. srcs: .py, .pyx, or .pxd files to either compile or pass through. **kwargs: Extra keyword arguments passed to the py_library. -@@ -58,6 +59,7 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs): +@@ -57,9 +59,11 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs): + shared_object_name = stem + ".so" native.cc_binary( name=shared_object_name, - srcs=[stem + ".cpp"], +- srcs=[stem + ".cpp"], ++ srcs=[stem + ".cpp"] + cc_kwargs.pop("srcs", []), + copts=copts, deps=deps + ["@local_config_python//:python_headers"], linkshared=1, ++ **cc_kwargs ) --- + shared_objects.append(shared_object_name) + +@@ -72,3 +76,4 @@ def pyx_library(name, deps=[], py_deps=[], srcs=[], **kwargs): + data=shared_objects, + **kwargs) + ++ +--