From df3dd96c1e54df5428b6a6dc3f085103421cfc89 Mon Sep 17 00:00:00 2001 From: Hao Chen Date: Fri, 26 Jan 2024 15:03:26 -0800 Subject: [PATCH] [data][train] Fix deadlocks caused by streaming_split (#42601) Fix a deadlock issue for training jobs. The issue happens in the following situation: * The output blocks of `streaming_split` are assigned to multiple splits (`output_split_idx`). * When one split has finished reading all blocks, it won't stop the iteration until all the other splits have all finished, because of [this](https://github.com/ray-project/ray/blob/fae8d2ff814377eb027d63d73a23d5c5bf3b02bd/python/ray/data/_internal/execution/streaming_executor_state.py#L288). * This is usually fine. But when the unfinished splits are waiting for the finished splits (e.g., there is a gradient synchronization), there will be a dead lock due to circular dependencies. This PR makes the finished splits can finish iteration immediately without waiting for others. --------- Signed-off-by: Hao Chen --- .../execution/streaming_executor_state.py | 122 ++++++++++-------- .../data/tests/test_streaming_integration.py | 65 ++++++++++ 2 files changed, 136 insertions(+), 51 deletions(-) diff --git a/python/ray/data/_internal/execution/streaming_executor_state.py b/python/ray/data/_internal/execution/streaming_executor_state.py index 96e0cb66f89f4..88108ba86847e 100644 --- a/python/ray/data/_internal/execution/streaming_executor_state.py +++ b/python/ray/data/_internal/execution/streaming_executor_state.py @@ -106,62 +106,94 @@ class DownstreamMemoryInfo: object_store_memory: float -class RefBundleDeque(deque): - """Thread-safe wrapper around collections.deque that stores current stats.""" +class OpBufferQueue: + """A FIFO queue to buffer RefBundles between upstream and downstream operators. + This class is thread-safe. + """ def __init__(self): self._memory_usage = 0 self._num_blocks = 0 + self._queue = deque() + self._num_per_split = defaultdict(int) self._lock = threading.Lock() super().__init__() @property def memory_usage(self) -> int: + """The total memory usage of the queue in bytes.""" with self._lock: return self._memory_usage @property def num_blocks(self) -> int: + """The total number of blocks in the queue.""" with self._lock: return self._num_blocks - def append(self, ref: RefBundle): - with self._lock: - self._memory_usage += ref.size_bytes() - self._num_blocks += len(ref.blocks) - super().append(ref) + def __len__(self): + return len(self._queue) - def appendleft(self, ref: RefBundle): - with self._lock: - self._memory_usage += ref.size_bytes() - self._num_blocks += len(ref.blocks) - super().appendleft(ref) + def has_next(self, output_split_idx: Optional[int] = None) -> bool: + """Whether next RefBundle is available. - def pop(self) -> RefBundle: - ref = super().pop() - with self._lock: - self._memory_usage -= ref.size_bytes() - self._num_blocks -= len(ref.blocks) - return ref + Args: + output_split_idx: If specified, only check ref bundles with the + given output split. + """ + if output_split_idx is None: + return len(self._queue) > 0 + else: + with self._lock: + return self._num_per_split[output_split_idx] > 0 - def popleft(self) -> RefBundle: - ref = super().popleft() + def append(self, ref: RefBundle): + """Append a RefBundle to the queue.""" + self._queue.append(ref) with self._lock: - self._memory_usage -= ref.size_bytes() - self._num_blocks -= len(ref.blocks) - return ref - - def remove(self, ref: RefBundle): - super().remove(ref) + self._memory_usage += ref.size_bytes() + self._num_blocks += len(ref.blocks) + if ref.output_split_idx is not None: + self._num_per_split[ref.output_split_idx] += 1 + + def pop(self, output_split_idx: Optional[int] = None) -> Optional[RefBundle]: + """Pop a RefBundle from the queue. + Args: + output_split_idx: If specified, only pop a RefBundle + with the given output split. + Returns: + A RefBundle if available, otherwise None. + """ + ret = None + if output_split_idx is None: + try: + ret = self._queue.popleft() + except IndexError: + pass + else: + # TODO(hchen): Index the queue by output_split_idx to + # avoid linear scan. + for i in range(len(self._queue)): + ref = self._queue[i] + if ref.output_split_idx == output_split_idx: + ret = ref + del self._queue[i] + break + if ret is None: + return None with self._lock: - self._memory_usage -= ref.size_bytes() - self._num_blocks -= len(ref.blocks) + self._memory_usage -= ret.size_bytes() + self._num_blocks -= len(ret.blocks) + if ret.output_split_idx is not None: + self._num_per_split[ret.output_split_idx] -= 1 + return ret def clear(self): - super().clear() with self._lock: + self._queue.clear() self._memory_usage = 0 self._num_blocks = 0 + self._num_per_split.clear() class OpState: @@ -174,17 +206,17 @@ class OpState: operator queues to be shared across threads. """ - def __init__(self, op: PhysicalOperator, inqueues: List[RefBundleDeque]): + def __init__(self, op: PhysicalOperator, inqueues: List[OpBufferQueue]): # Each inqueue is connected to another operator's outqueue. assert len(inqueues) == len(op.input_dependencies), (op, inqueues) - self.inqueues: List[RefBundleDeque] = inqueues + self.inqueues: List[OpBufferQueue] = inqueues # The outqueue is connected to another operator's inqueue (they physically # share the same Python list reference). # # Note: this queue is also accessed concurrently from the consumer thread. # (in addition to the streaming executor thread). Hence, it must be a # thread-safe type such as `deque`. - self.outqueue: RefBundleDeque = RefBundleDeque() + self.outqueue: OpBufferQueue = OpBufferQueue() self.op = op self.progress_bar = None self.num_completed_tasks = 0 @@ -266,8 +298,9 @@ def summary_str(self) -> str: def dispatch_next_task(self) -> None: """Move a bundle from the operator inqueue to the operator itself.""" for i, inqueue in enumerate(self.inqueues): - if inqueue: - self.op.add_input(inqueue.popleft(), input_index=i) + ref = inqueue.pop() + if ref is not None: + self.op.add_input(ref, input_index=i) return assert False, "Nothing to dispatch" @@ -285,24 +318,11 @@ def get_output_blocking(self, output_split_idx: Optional[int]) -> RefBundle: # Check if StreamingExecutor has caught an exception or is done execution. if self._exception is not None: raise self._exception - elif self._finished and len(self.outqueue) == 0: + elif self._finished and not self.outqueue.has_next(output_split_idx): raise StopIteration() - try: - # Non-split output case. - if output_split_idx is None: - return self.outqueue.popleft() - - # Scan the queue and look for outputs tagged for the given index. - for i in range(len(self.outqueue)): - bundle = self.outqueue[i] - if bundle.output_split_idx == output_split_idx: - self.outqueue.remove(bundle) - return bundle - - # Didn't find any outputs matching this index, repeat the loop until - # we find one or hit a None. - except IndexError: - pass + ref = self.outqueue.pop(output_split_idx) + if ref is not None: + return ref time.sleep(0.01) def inqueue_memory_usage(self) -> int: diff --git a/python/ray/data/tests/test_streaming_integration.py b/python/ray/data/tests/test_streaming_integration.py index 0aeee09925263..f2317920ef755 100644 --- a/python/ray/data/tests/test_streaming_integration.py +++ b/python/ray/data/tests/test_streaming_integration.py @@ -276,6 +276,71 @@ def consume(x, times): ) +def test_streaming_split_independent_finish(ray_start_10_cpus_shared): + """Test that stream_split iterators can finish independently without + waiting for other iterators to finish. Otherwise, this would cause + deadlocks. + """ + num_blocks_per_split = 10 + num_splits = 2 + ds = ray.data.range( + num_splits * num_blocks_per_split, + parallelism=num_splits * num_blocks_per_split, + ) + ( + i1, + i2, + ) = ds.streaming_split(num_splits, equal=True) + + @ray.remote(max_concurrency=2) + class SignalActor: + def __init__(self): + self._event = threading.Event() + + def wait(self): + self._event.wait() + + def set(self): + self._event.set() + + @ray.remote + class Consumer: + def consume(self, it, signal_actor, split_index): + for i, _ in enumerate(it.iter_batches(batch_size=None, prefetch_batches=0)): + if i == num_blocks_per_split // 2 and split_index == 0: + # The first consumer waits for the second consumer to + # finish first in the middle of the iteration. + print("before wait") + ray.get(signal_actor.wait.remote()) + print("after wait") + if split_index == 1: + # The second consumer sends a signal to unblock the + # first consumer. It should finish the iteration independently. + # Otherwise, there will be a deadlock. + print("before set") + # Sleep some time to make sure the other + # consume calls wait first. + time.sleep(2) + ray.get(signal_actor.set.remote()) + print("after set") + pass + + signal_actor = SignalActor.remote() + consumer1 = Consumer.remote() + consumer2 = Consumer.remote() + + ready, _ = ray.wait( + [ + consumer1.consume.remote(i1, signal_actor, 0), + consumer2.consume.remote(i2, signal_actor, 1), + ], + num_returns=2, + timeout=20, + ) + + assert len(ready) == 2 + + @pytest.mark.skip( reason="Incomplete implementation of _validate_dag causes other errors, so we " "remove DAG validation for now; see https://github.com/ray-project/ray/pull/37829"