Skip to content

Commit

Permalink
Fix the bug of eagerly clearing up input blocks (#31459)
Browse files Browse the repository at this point in the history
This PR is fixing the issue found in #31286. Previously we always eagerly clears up non-lazy input blocks (plan._in_blocks) when executing the plan. This is not safe as the input blocks might be used by downstream operations later.

Signed-off-by: Cheng Su <scnju13@gmail.com>
Co-authored-by: Clark Zinzow <clark@anyscale.com>
  • Loading branch information
2 people authored and AmeerHajAli committed Jan 12, 2023
1 parent 8ca829a commit 9d663a4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
5 changes: 0 additions & 5 deletions python/ray/data/_internal/plan.py
Expand Up @@ -452,11 +452,6 @@ def _get_source_blocks_and_stages(
# beginning.
blocks = self._in_blocks
stats = self._in_stats
if not self.has_lazy_input():
# If not a lazy datasource, unlink the input blocks from the plan so we
# can eagerly reclaim the input block memory after the first stage is
# done executing.
self._in_blocks = None
return blocks, stats, stages

def has_lazy_input(self) -> bool:
Expand Down
16 changes: 16 additions & 0 deletions python/ray/data/tests/test_optimize.py
Expand Up @@ -254,6 +254,22 @@ def inc(x):
# Test that first map is executed twice.
assert ray.get(map_counter.get.remote()) == 2 * 10 + 10 + 10

ray.get(map_counter.reset.remote())
# The source data shouldn't be cleared since it's non-lazy.
ds = ray.data.from_items(list(range(10)))
# Add extra transformation before being lazy.
ds = ds.map(inc)
ds = ds.lazy()
ds1 = ds.map(inc)
ds2 = ds.map(inc)
# Test content.
assert ds1.fully_executed().take() == list(range(2, 12))
assert ds2.fully_executed().take() == list(range(2, 12))
# Test that first map is executed twice, because ds1.fully_executed()
# clears up the previous snapshot blocks, and ds2.fully_executed()
# has to re-execute ds.map(inc) again.
assert ray.get(map_counter.get.remote()) == 2 * 10 + 10 + 10


def test_spread_hint_inherit(ray_start_regular_shared):
ds = ray.data.range(10).lazy()
Expand Down

0 comments on commit 9d663a4

Please sign in to comment.