Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core][Streaming Generator] Fix memory leak from the end of object stream object #38152

Merged
merged 7 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3238,6 +3238,9 @@ cdef class CoreWorker:
logger.warning("Local object store memory usage:\n{}\n".format(
message.decode("utf-8")))

def get_memory_store_size(self):
return CCoreWorkerProcess.GetCoreWorker().GetMemoryStoreSize()

cdef python_label_match_expressions_to_c(
self, python_expressions,
CLabelMatchExpressions *c_expressions):
Expand Down
1 change: 1 addition & 0 deletions python/ray/includes/libcoreworker.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
int64_t item_index,
uint64_t attempt_number)
c_string MemoryUsageString()
int GetMemoryStoreSize()

CWorkerContext &GetWorkerContext()
void YieldCurrentFiber(CFiberEvent &coroutine_done)
Expand Down
36 changes: 36 additions & 0 deletions python/ray/tests/test_streaming_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def assert_no_leak():
for rc in ref_counts.values():
assert rc["local"] == 0
assert rc["submitted"] == 0
assert core_worker.get_memory_store_size() == 0


class MockedWorker:
Expand Down Expand Up @@ -1132,6 +1133,41 @@ async def main():
assert 4.5 < time.time() - s < 6.5


def test_no_memory_store_obj_leak(shutdown_only):
"""Fixes https://github.com/ray-project/ray/issues/38089

Verify there's no leak from in-memory object store when
using a streaming generator.
"""
ray.init()

@ray.remote
def f():
for _ in range(10):
yield 1

for _ in range(10):
for ref in f.options(num_returns="streaming").remote():
del ref

time.sleep(0.2)

core_worker = ray._private.worker.global_worker.core_worker
assert core_worker.get_memory_store_size() == 0
assert_no_leak()

for _ in range(10):
for ref in f.options(num_returns="streaming").remote():
break

time.sleep(0.2)

del ref
core_worker = ray._private.worker.global_worker.core_worker
assert core_worker.get_memory_store_size() == 0
assert_no_leak()


if __name__ == "__main__":
import os

Expand Down
2 changes: 2 additions & 0 deletions src/ray/core_worker/core_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,8 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
}
}

int GetMemoryStoreSize() { return memory_store_->Size(); }

/// Returns a map of all ObjectIDs currently in scope with a pair of their
/// (local, submitted_task) reference counts. For debugging purposes.
std::unordered_map<ObjectID, std::pair<size_t, size_t>> GetAllReferenceCounts() const;
Expand Down
1 change: 1 addition & 0 deletions src/ray/core_worker/reference_count.cc
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ void ReferenceCounter::DeleteReferenceInternal(ReferenceTable::iterator it,
it->second.on_ref_removed(id);
it->second.on_ref_removed = nullptr;
}

PRINT_REF_COUNT(it);

// Whether it is safe to unpin the value.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
absl::flat_hash_set<ObjectID> *plasma_ids_to_delete) {
absl::MutexLock lock(&mu_);
for (const auto &object_id : object_ids) {
RAY_LOG(DEBUG) << "Delete an object from a memory store. ObjectId: " << object_id;
auto it = objects_.find(object_id);
if (it != objects_.end()) {
if (it->second->IsInPlasmaError()) {
Expand All @@ -492,6 +493,7 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
absl::MutexLock lock(&mu_);
for (const auto &object_id : object_ids) {
RAY_LOG(DEBUG) << "Delete an object from a memory store. ObjectId: " << object_id;
auto it = objects_.find(object_id);
if (it != objects_.end()) {
OnDelete(it->second);
Expand Down
22 changes: 14 additions & 8 deletions src/ray/core_worker/task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,29 @@ const int64_t kTaskFailureThrottlingThreshold = 50;
// Throttle task failure logs to once this interval.
const int64_t kTaskFailureLoggingFrequencyMillis = 5000;

std::vector<ObjectID> ObjectRefStream::GetItemsUnconsumed() const {
std::vector<ObjectID> result;
absl::flat_hash_set<ObjectID> ObjectRefStream::GetItemsUnconsumed() const {
absl::flat_hash_set<ObjectID> result;
for (int64_t index = 0; index <= max_index_seen_; index++) {
const auto &object_id = GetObjectRefAtIndex(index);
if (refs_written_to_stream_.find(object_id) == refs_written_to_stream_.end()) {
continue;
}

if (index >= next_index_) {
result.push_back(object_id);
result.emplace(object_id);
}
}

if (end_of_stream_index_ != -1) {
// End of stream index is never consumed by a caller
// so we should add it here.
result.push_back(GetObjectRefAtIndex(end_of_stream_index_));
const auto &object_id = GetObjectRefAtIndex(end_of_stream_index_);
result.emplace(object_id);
}

// Temporarily owned refs are not consumed.
for (const auto &object_id : temporarily_owned_refs_) {
result.push_back(object_id);
result.emplace(object_id);
}
return result;
}
Expand Down Expand Up @@ -428,7 +429,7 @@ bool TaskManager::HandleTaskReturn(const ObjectID &object_id,

void TaskManager::DelObjectRefStream(const ObjectID &generator_id) {
RAY_LOG(DEBUG) << "Deleting an object ref stream of an id " << generator_id;
std::vector<ObjectID> object_ids_unconsumed;
absl::flat_hash_set<ObjectID> object_ids_unconsumed;

{
absl::MutexLock lock(&mu_);
Expand All @@ -441,12 +442,17 @@ void TaskManager::DelObjectRefStream(const ObjectID &generator_id) {
object_ids_unconsumed = stream.GetItemsUnconsumed();
object_ref_streams_.erase(generator_id);
}

// When calling RemoveLocalReference, we shouldn't hold a lock.
for (const auto &object_id : object_ids_unconsumed) {
std::vector<ObjectID> deleted;
RAY_LOG(INFO) << "Removing unconsume streaming ref " << object_id;
RAY_LOG(DEBUG) << "Removing unconsume streaming ref " << object_id;
reference_counter_->RemoveLocalReference(object_id, &deleted);
// TODO(sang): This is required because the reference counter
// cannot remove objects from the in memory store.
// Instead of doing this manually here, we should modify
// reference_count.h to automatically remove objects
// when the ref goes to 0.
in_memory_store_->Delete(deleted);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/task_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class ObjectRefStream {
/// Get all the ObjectIDs that are not read yet via TryReadNextItem.
///
/// \return A list of object IDs that are not read yet.
std::vector<ObjectID> GetItemsUnconsumed() const;
absl::flat_hash_set<ObjectID> GetItemsUnconsumed() const;

private:
ObjectID GetObjectRefAtIndex(int64_t generator_index) const;
Expand Down
15 changes: 9 additions & 6 deletions src/ray/core_worker/test/task_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,7 @@ TEST_F(TaskManagerTest, TestObjectRefStreamEndtoEnd) {

TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) {
/**
* Verify DEL cleans all references and ignore all future WRITE.
* Verify DEL cleans all references/objects and ignore all future WRITE.
*
* CREATE WRITE WRITE DEL (make sure no refs are leaked)
*/
Expand Down Expand Up @@ -1602,6 +1602,8 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) {

// NumObjectIDsInScope == Generator + 2 WRITE
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3);
// 2 in memory objects.
ASSERT_EQ(store_->Size(), 2);
std::vector<std::shared_ptr<RayObject>> results;
WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
RAY_CHECK_OK(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results));
Expand All @@ -1614,11 +1616,8 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) {
// DELETE. This should clean all references except generator id.
manager_.DelObjectRefStream(generator_id);
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
// Unfortunately, when the obj ref goes out of scope,
// this is called from the language frontend. We mimic this behavior
// by manually calling these APIs.
store_->Delete({dynamic_return_id});
store_->Delete({dynamic_return_id2});
// All the in memory objects should be cleaned up.
ASSERT_EQ(store_->Size(), 0);
ASSERT_TRUE(store_->Get({dynamic_return_id}, 1, 1, ctx, false, &results).IsTimedOut());
results.clear();
ASSERT_TRUE(store_->Get({dynamic_return_id2}, 1, 1, ctx, false, &results).IsTimedOut());
Expand All @@ -1640,6 +1639,8 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelCleanReferences) {
ASSERT_FALSE(manager_.HandleReportGeneratorItemReturns(req));
// The write should have been no op. No refs and no obj values except the generator id.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
// All the in memory objects should be cleaned up.
ASSERT_EQ(store_->Size(), 0);
ASSERT_TRUE(store_->Get({dynamic_return_id3}, 1, 1, ctx, false, &results).IsTimedOut());
results.clear();

Expand Down Expand Up @@ -1741,6 +1742,8 @@ TEST_F(TaskManagerTest, TestObjectRefStreamDelOutOfOrder) {

// There must be only a generator ID.
ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 1);
// All the objects should be cleaned up.
ASSERT_EQ(store_->Size(), 0);
CompletePendingStreamingTask(spec, caller_address, 0);
}

Expand Down
Loading