[data][train] Fix deadlocks caused by streaming_split (#42601)
Fix a deadlock issue for training jobs. The issue happens in the following situation:
* The output blocks of `streaming_split` are assigned to multiple splits (`output_split_idx`).
* When one split has finished reading all of its blocks, it won't stop iterating until all the other splits have also finished, because of [this](https://github.com/ray-project/ray/blob/fae8d2ff814377eb027d63d73a23d5c5bf3b02bd/python/ray/data/_internal/execution/streaming_executor_state.py#L288).
* This is usually fine. But when the unfinished splits are themselves waiting for the finished splits (e.g., at a gradient synchronization), the circular dependency results in a deadlock.

This PR lets finished splits end their iteration immediately, without waiting for the others. A sketch of the problematic pattern is shown below.
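For illustration, the following is a minimal sketch of the pattern that can deadlock (the `Signal` actor, the `consume` task, the split count, and the block counts are illustrative, not taken from a real training job; the regression test added below exercises essentially the same scenario with actors). With the old behavior, worker 1 could not leave `iter_batches()` after reading all of its blocks because the executor kept it waiting for worker 0, while worker 0 was waiting for worker 1 to finish.

import threading

import ray

# Two equal splits over 20 blocks, mirroring the scenario in the description.
ds = ray.data.range(20, parallelism=20)
it0, it1 = ds.streaming_split(2, equal=True)


@ray.remote(max_concurrency=2)
class Signal:
    """Stands in for a synchronization point (e.g., a gradient sync)."""

    def __init__(self):
        self._event = threading.Event()

    def wait(self):
        self._event.wait()

    def set(self):
        self._event.set()


@ray.remote
def consume(it, signal, rank):
    for i, _ in enumerate(it.iter_batches(batch_size=None)):
        if rank == 0 and i == 5:
            # Worker 0 pauses mid-iteration until worker 1 has finished.
            ray.get(signal.wait.remote())
    if rank == 1:
        # Worker 1 must be able to finish its split independently;
        # otherwise neither worker can make progress.
        ray.get(signal.set.remote())


signal = Signal.remote()
ray.get([consume.remote(it0, signal, 0), consume.remote(it1, signal, 1)])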

---------

Signed-off-by: Hao Chen <chenh1024@gmail.com>
raulchen committed Jan 26, 2024
1 parent f6da38f commit df3dd96
Showing 2 changed files with 136 additions and 51 deletions.
122 changes: 71 additions & 51 deletions python/ray/data/_internal/execution/streaming_executor_state.py
@@ -106,62 +106,94 @@ class DownstreamMemoryInfo:
    object_store_memory: float


-class RefBundleDeque(deque):
-    """Thread-safe wrapper around collections.deque that stores current stats."""
+class OpBufferQueue:
+    """A FIFO queue to buffer RefBundles between upstream and downstream operators.
+
+    This class is thread-safe.
+    """

    def __init__(self):
        self._memory_usage = 0
        self._num_blocks = 0
+        self._queue = deque()
+        self._num_per_split = defaultdict(int)
        self._lock = threading.Lock()
-        super().__init__()

    @property
    def memory_usage(self) -> int:
        """The total memory usage of the queue in bytes."""
        with self._lock:
            return self._memory_usage

    @property
    def num_blocks(self) -> int:
        """The total number of blocks in the queue."""
        with self._lock:
            return self._num_blocks

-    def append(self, ref: RefBundle):
-        with self._lock:
-            self._memory_usage += ref.size_bytes()
-            self._num_blocks += len(ref.blocks)
-        super().append(ref)
-
-    def appendleft(self, ref: RefBundle):
-        with self._lock:
-            self._memory_usage += ref.size_bytes()
-            self._num_blocks += len(ref.blocks)
-        super().appendleft(ref)
-
-    def pop(self) -> RefBundle:
-        ref = super().pop()
-        with self._lock:
-            self._memory_usage -= ref.size_bytes()
-            self._num_blocks -= len(ref.blocks)
-        return ref
-
-    def popleft(self) -> RefBundle:
-        ref = super().popleft()
-        with self._lock:
-            self._memory_usage -= ref.size_bytes()
-            self._num_blocks -= len(ref.blocks)
-        return ref
-
-    def remove(self, ref: RefBundle):
-        super().remove(ref)
-        with self._lock:
-            self._memory_usage -= ref.size_bytes()
-            self._num_blocks -= len(ref.blocks)
+    def __len__(self):
+        return len(self._queue)
+
+    def has_next(self, output_split_idx: Optional[int] = None) -> bool:
+        """Whether next RefBundle is available.
+
+        Args:
+            output_split_idx: If specified, only check ref bundles with the
+                given output split.
+        """
+        if output_split_idx is None:
+            return len(self._queue) > 0
+        else:
+            with self._lock:
+                return self._num_per_split[output_split_idx] > 0
+
+    def append(self, ref: RefBundle):
+        """Append a RefBundle to the queue."""
+        self._queue.append(ref)
+        with self._lock:
+            self._memory_usage += ref.size_bytes()
+            self._num_blocks += len(ref.blocks)
+            if ref.output_split_idx is not None:
+                self._num_per_split[ref.output_split_idx] += 1
+
+    def pop(self, output_split_idx: Optional[int] = None) -> Optional[RefBundle]:
+        """Pop a RefBundle from the queue.
+
+        Args:
+            output_split_idx: If specified, only pop a RefBundle
+                with the given output split.
+
+        Returns:
+            A RefBundle if available, otherwise None.
+        """
+        ret = None
+        if output_split_idx is None:
+            try:
+                ret = self._queue.popleft()
+            except IndexError:
+                pass
+        else:
+            # TODO(hchen): Index the queue by output_split_idx to
+            # avoid linear scan.
+            for i in range(len(self._queue)):
+                ref = self._queue[i]
+                if ref.output_split_idx == output_split_idx:
+                    ret = ref
+                    del self._queue[i]
+                    break
+        if ret is None:
+            return None
+        with self._lock:
+            self._memory_usage -= ret.size_bytes()
+            self._num_blocks -= len(ret.blocks)
+            if ret.output_split_idx is not None:
+                self._num_per_split[ret.output_split_idx] -= 1
+        return ret

    def clear(self):
-        super().clear()
        with self._lock:
+            self._queue.clear()
            self._memory_usage = 0
            self._num_blocks = 0
+            self._num_per_split.clear()

class OpState:
@@ -174,17 +206,17 @@ class OpState:
operator queues to be shared across threads.
"""

-    def __init__(self, op: PhysicalOperator, inqueues: List[RefBundleDeque]):
+    def __init__(self, op: PhysicalOperator, inqueues: List[OpBufferQueue]):
# Each inqueue is connected to another operator's outqueue.
assert len(inqueues) == len(op.input_dependencies), (op, inqueues)
-        self.inqueues: List[RefBundleDeque] = inqueues
+        self.inqueues: List[OpBufferQueue] = inqueues
# The outqueue is connected to another operator's inqueue (they physically
# share the same Python list reference).
#
# Note: this queue is also accessed concurrently from the consumer thread.
# (in addition to the streaming executor thread). Hence, it must be a
# thread-safe type such as `deque`.
-        self.outqueue: RefBundleDeque = RefBundleDeque()
+        self.outqueue: OpBufferQueue = OpBufferQueue()
self.op = op
self.progress_bar = None
self.num_completed_tasks = 0
@@ -266,8 +298,9 @@ def summary_str(self) -> str:
def dispatch_next_task(self) -> None:
"""Move a bundle from the operator inqueue to the operator itself."""
for i, inqueue in enumerate(self.inqueues):
-            if inqueue:
-                self.op.add_input(inqueue.popleft(), input_index=i)
+            ref = inqueue.pop()
+            if ref is not None:
+                self.op.add_input(ref, input_index=i)
return
assert False, "Nothing to dispatch"

@@ -285,24 +318,11 @@ def get_output_blocking(self, output_split_idx: Optional[int]) -> RefBundle:
# Check if StreamingExecutor has caught an exception or is done execution.
if self._exception is not None:
raise self._exception
-        elif self._finished and len(self.outqueue) == 0:
+        elif self._finished and not self.outqueue.has_next(output_split_idx):
raise StopIteration()
-        try:
-            # Non-split output case.
-            if output_split_idx is None:
-                return self.outqueue.popleft()
-
-            # Scan the queue and look for outputs tagged for the given index.
-            for i in range(len(self.outqueue)):
-                bundle = self.outqueue[i]
-                if bundle.output_split_idx == output_split_idx:
-                    self.outqueue.remove(bundle)
-                    return bundle
-
-            # Didn't find any outputs matching this index, repeat the loop until
-            # we find one or hit a None.
-        except IndexError:
-            pass
+        ref = self.outqueue.pop(output_split_idx)
+        if ref is not None:
+            return ref
time.sleep(0.01)

def inqueue_memory_usage(self) -> int:
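The practical effect of the new queue can be seen in isolation with a small stand-in (a sketch only; `Bundle` and `SplitBufferQueue` are illustrative names, and the locking and memory accounting of the real `OpBufferQueue` are omitted): popping by `output_split_idx` never consumes bundles that belong to other splits, so a split whose own bundles are exhausted observes `has_next(idx) == False` and can stop, regardless of how much output the other splits still have buffered.

from collections import defaultdict, deque, namedtuple
from typing import Optional

# Illustrative stand-in for a RefBundle; only the split index matters here.
Bundle = namedtuple("Bundle", ["data", "output_split_idx"])


class SplitBufferQueue:
    """Toy FIFO buffer with per-split pop, mirroring the semantics above."""

    def __init__(self):
        self._queue = deque()
        self._num_per_split = defaultdict(int)

    def append(self, bundle):
        self._queue.append(bundle)
        if bundle.output_split_idx is not None:
            self._num_per_split[bundle.output_split_idx] += 1

    def has_next(self, idx: Optional[int] = None) -> bool:
        if idx is None:
            return len(self._queue) > 0
        return self._num_per_split[idx] > 0

    def pop(self, idx: Optional[int] = None):
        if idx is None:
            return self._queue.popleft() if self._queue else None
        for i, bundle in enumerate(self._queue):
            if bundle.output_split_idx == idx:
                del self._queue[i]
                self._num_per_split[idx] -= 1
                return bundle
        return None


q = SplitBufferQueue()
q.append(Bundle("a", 0))
q.append(Bundle("b", 1))
assert q.pop(0).data == "a"
assert not q.has_next(0)  # Split 0 is drained and can stop iterating now,
assert q.has_next(1)      # even though split 1 still has buffered output.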
65 changes: 65 additions & 0 deletions python/ray/data/tests/test_streaming_integration.py
@@ -276,6 +276,71 @@ def consume(x, times):
)


def test_streaming_split_independent_finish(ray_start_10_cpus_shared):
"""Test that stream_split iterators can finish independently without
waiting for other iterators to finish. Otherwise, this would cause
deadlocks.
"""
num_blocks_per_split = 10
num_splits = 2
ds = ray.data.range(
num_splits * num_blocks_per_split,
parallelism=num_splits * num_blocks_per_split,
)
(
i1,
i2,
) = ds.streaming_split(num_splits, equal=True)

@ray.remote(max_concurrency=2)
class SignalActor:
def __init__(self):
self._event = threading.Event()

def wait(self):
self._event.wait()

def set(self):
self._event.set()

@ray.remote
class Consumer:
def consume(self, it, signal_actor, split_index):
for i, _ in enumerate(it.iter_batches(batch_size=None, prefetch_batches=0)):
if i == num_blocks_per_split // 2 and split_index == 0:
# The first consumer waits for the second consumer to
# finish first in the middle of the iteration.
print("before wait")
ray.get(signal_actor.wait.remote())
print("after wait")
if split_index == 1:
# The second consumer sends a signal to unblock the
# first consumer. It should finish the iteration independently.
# Otherwise, there will be a deadlock.
print("before set")
# Sleep some time to make sure the other
# consume calls wait first.
time.sleep(2)
ray.get(signal_actor.set.remote())
print("after set")
pass

signal_actor = SignalActor.remote()
consumer1 = Consumer.remote()
consumer2 = Consumer.remote()

ready, _ = ray.wait(
[
consumer1.consume.remote(i1, signal_actor, 0),
consumer2.consume.remote(i2, signal_actor, 1),
],
num_returns=2,
timeout=20,
)

assert len(ready) == 2


@pytest.mark.skip(
reason="Incomplete implementation of _validate_dag causes other errors, so we "
"remove DAG validation for now; see https://github.com/ray-project/ray/pull/37829"
