Spill in pipeline model not driven by input block (#8115)
ref #7738
windtalker committed Sep 20, 2023
1 parent d8fa1d1 commit 45f9bb1
Showing 9 changed files with 42 additions and 6 deletions.
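Taken together, the changes below follow one pattern: in the pipeline execution model, an operator now checks in prepareImpl / tryOutputImpl whether a spill has already been marked and returns IO_OUT in that case, instead of relying on writeImpl being called with a fresh input block. A minimal sketch of that pattern, assuming a stripped-down operator interface (SpillAwareSink, SpillState and the driver in main() are hypothetical stand-ins, not TiFlash code):

    // Minimal sketch of "spill not driven by input block"; not TiFlash code.
    #include <cstdio>

    enum class OperatorStatus
    {
        NEED_INPUT,
        IO_OUT,
    };

    struct SpillState
    {
        bool spill_marked = false;
        void flush()
        {
            spill_marked = false;
            std::puts("spilled to disk");
        }
    };

    class SpillAwareSink
    {
    public:
        explicit SpillAwareSink(SpillState & state_)
            : state(state_)
        {}

        // Before this change the spill decision was only taken here, i.e. it was
        // driven by the arrival of an input block.
        OperatorStatus writeImpl(/* Block && block */)
        {
            state.spill_marked = true; // pretend this block pushed memory over the limit
            return OperatorStatus::IO_OUT;
        }

        // After this change prepareImpl also reports IO_OUT when a spill is already
        // pending, so the spill runs even if no new block arrives.
        OperatorStatus prepareImpl()
        {
            return state.spill_marked ? OperatorStatus::IO_OUT : OperatorStatus::NEED_INPUT;
        }

        OperatorStatus executeIOImpl()
        {
            state.flush();
            return OperatorStatus::NEED_INPUT;
        }

    private:
        SpillState & state;
    };

    int main()
    {
        SpillState state;
        state.spill_marked = true; // e.g. another thread marked this partition for spill
        SpillAwareSink sink(state);
        if (sink.prepareImpl() == OperatorStatus::IO_OUT)
            sink.executeIOImpl(); // spill happens without waiting for another input block
    }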
17 changes: 14 additions & 3 deletions dbms/src/Interpreters/Join.cpp
@@ -502,7 +502,18 @@ void Join::flushProbeSideMarkedSpillData(size_t stream_index)
     data.clear();
 }
 
-void Join::checkAndMarkPartitionSpilledIfNeeded(
+void Join::checkAndMarkPartitionSpilledIfNeeded(size_t stream_index)
+{
+    /// todo need to check more partitions if partition_size is not equal to total stream size
+    size_t partition_index = stream_index;
+    const auto & join_partition = partitions[partition_index];
+    auto partition_lock = join_partition->tryLockPartition();
+    if (partition_lock)
+        checkAndMarkPartitionSpilledIfNeededInternal(*join_partition, partition_lock, partition_index, stream_index);
+    /// if someone else already holds the lock, the lock holder will do the spill check
+}
+
+void Join::checkAndMarkPartitionSpilledIfNeededInternal(
     JoinPartition & join_partition,
     std::unique_lock<std::mutex> & partition_lock,
     size_t partition_index,
@@ -576,7 +587,7 @@ void Join::insertFromBlock(const Block & block, size_t stream_index)
         auto partition_lock = join_partition->lockPartition();
         join_partition->insertBlockForBuild(std::move(dispatch_blocks[i]));
         /// to release memory before insert if already marked spill
-        checkAndMarkPartitionSpilledIfNeeded(*join_partition, partition_lock, i, stream_index);
+        checkAndMarkPartitionSpilledIfNeededInternal(*join_partition, partition_lock, i, stream_index);
         if (!hash_join_spill_context->isPartitionSpilled(i))
         {
             bool meet_resize_exception = false;
@@ -591,7 +602,7 @@ void Join::insertFromBlock(const Block & block, size_t stream_index)
                 LOG_DEBUG(log, "Meet resize exception when insert into partition {}", i);
             }
             /// double check here to release memory
-            checkAndMarkPartitionSpilledIfNeeded(*join_partition, partition_lock, i, stream_index);
+            checkAndMarkPartitionSpilledIfNeededInternal(*join_partition, partition_lock, i, stream_index);
             if (meet_resize_exception)
                 RUNTIME_CHECK_MSG(
                     hash_join_spill_context->isPartitionSpilled(i),
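The new checkAndMarkPartitionSpilledIfNeeded wrapper deliberately uses a non-blocking tryLockPartition: if another thread already holds the partition lock, that thread is inside insertFromBlock and runs the same check itself before releasing the lock, so blocking here would only add contention. A short sketch of the try-lock idiom, assuming a plain std::mutex in place of JoinPartition's lock (Partition and checkAndMarkSpill are hypothetical stand-ins):

    #include <mutex>

    struct Partition
    {
        std::mutex mtx;
        bool spill_marked = false;

        // Owning lock on success, non-owning lock if the mutex is currently held.
        std::unique_lock<std::mutex> tryLock() { return std::unique_lock<std::mutex>(mtx, std::try_to_lock); }

        void checkAndMarkSpill() { spill_marked = true; /* threshold check elided */ }
    };

    void checkPartition(Partition & partition)
    {
        auto lock = partition.tryLock();
        if (lock) // lock acquired: run the spill check ourselves
            partition.checkAndMarkSpill();
        // otherwise the current lock holder performs the same check, so nothing is missed
    }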
4 changes: 3 additions & 1 deletion dbms/src/Interpreters/Join.h
@@ -234,7 +234,9 @@ class Join
     /// The peak build bytes usage, if spill is not enabled, the same as getTotalByteCount
     size_t getPeakBuildBytesUsage();
 
-    void checkAndMarkPartitionSpilledIfNeeded(
+    void checkAndMarkPartitionSpilledIfNeeded(size_t stream_index);
+
+    void checkAndMarkPartitionSpilledIfNeededInternal(
         JoinPartition & join_partition,
         std::unique_lock<std::mutex> & partition_lock,
         size_t partition_index,
2 changes: 1 addition & 1 deletion dbms/src/Operators/AggregateBuildSinkOp.cpp
@@ -25,7 +25,7 @@ OperatorStatus AggregateBuildSinkOp::prepareImpl()
         if (agg_context->needSpill(index))
             return OperatorStatus::IO_OUT;
     }
-    return OperatorStatus::NEED_INPUT;
+    return agg_context->isTaskMarkedForSpill(index) ? OperatorStatus::IO_OUT : OperatorStatus::NEED_INPUT;
 }
 
 OperatorStatus AggregateBuildSinkOp::writeImpl(Block && block)
11 changes: 11 additions & 0 deletions dbms/src/Operators/AggregateContext.cpp
@@ -54,6 +54,17 @@ void AggregateContext::buildOnLocalData(size_t task_index)
     }
 }
 
+bool AggregateContext::isTaskMarkedForSpill(size_t task_index)
+{
+    if (needSpill(task_index))
+        return true;
+    if (getAggSpillContext()->updatePerThreadRevocableMemory(many_data[task_index]->revocableBytes(), task_index))
+    {
+        return many_data[task_index]->tryMarkNeedSpill();
+    }
+    return false;
+}
+
 bool AggregateContext::hasLocalDataToBuild(size_t task_index)
 {
     return !threads_data[task_index]->agg_process_info.allBlockDataHandled();
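isTaskMarkedForSpill first answers from the existing needSpill flag; only when that is clear does it report the task's current revocable bytes to the spill context and, if the per-thread threshold is crossed, try to mark the shared aggregate data for spilling. A compact sketch of that decision order (SpillContext and AggData are hypothetical stand-ins for AggSpillContext and the per-task aggregate data):

    #include <cstddef>

    struct SpillContext
    {
        size_t per_thread_threshold = 1024;
        // Record the task's revocable memory; true when it exceeds the threshold.
        bool updateRevocableMemory(size_t bytes, size_t /*task_index*/) const { return bytes >= per_thread_threshold; }
    };

    struct AggData
    {
        size_t revocable_bytes = 0;
        bool need_spill = false;
        bool tryMarkNeedSpill()
        {
            need_spill = true;
            return true;
        }
    };

    bool isTaskMarkedForSpill(SpillContext & spill_context, AggData & data, size_t task_index, bool already_needs_spill)
    {
        if (already_needs_spill) // a spill was requested earlier, nothing more to decide
            return true;
        if (spill_context.updateRevocableMemory(data.revocable_bytes, task_index))
            return data.tryMarkNeedSpill(); // threshold crossed: try to mark this task's data
        return false; // below threshold, keep aggregating in memory
    }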
2 changes: 2 additions & 0 deletions dbms/src/Operators/AggregateContext.h
@@ -78,6 +78,8 @@ class AggregateContext
 
     void buildOnLocalData(size_t task_index);
 
+    bool isTaskMarkedForSpill(size_t task_index);
+
     size_t getTotalBuildRows(size_t task_index) { return threads_data[task_index]->src_rows; }
 
 private:
6 changes: 6 additions & 0 deletions dbms/src/Operators/HashJoinBuildSink.cpp
@@ -35,6 +35,12 @@ OperatorStatus HashJoinBuildSink::writeImpl(Block && block)
     return join_ptr->hasBuildSideMarkedSpillData(op_index) ? OperatorStatus::IO_OUT : OperatorStatus::NEED_INPUT;
 }
 
+OperatorStatus HashJoinBuildSink::prepareImpl()
+{
+    join_ptr->checkAndMarkPartitionSpilledIfNeeded(op_index);
+    return join_ptr->hasBuildSideMarkedSpillData(op_index) ? OperatorStatus::IO_OUT : OperatorStatus::NEED_INPUT;
+}
+
 OperatorStatus HashJoinBuildSink::executeIOImpl()
 {
     join_ptr->flushBuildSideMarkedSpillData(op_index);
2 changes: 2 additions & 0 deletions dbms/src/Operators/HashJoinBuildSink.h
@@ -39,6 +39,8 @@ class HashJoinBuildSink : public SinkOp
 protected:
     OperatorStatus writeImpl(Block && block) override;
 
+    OperatorStatus prepareImpl() override;
+
     OperatorStatus executeIOImpl() override;
 
 private:
2 changes: 1 addition & 1 deletion dbms/src/Operators/LocalAggregateTransform.cpp
@@ -117,7 +117,7 @@ OperatorStatus LocalAggregateTransform::tryOutputImpl(Block & block)
             if (tryFromBuildToSpill() == OperatorStatus::IO_OUT)
                 return OperatorStatus::IO_OUT;
         }
-        return OperatorStatus::NEED_INPUT;
+        return agg_context.isTaskMarkedForSpill(task_index) ? tryFromBuildToSpill() : OperatorStatus::NEED_INPUT;
     case LocalAggStatus::convergent:
         block = agg_context.readForConvergent(task_index);
         return OperatorStatus::HAS_OUTPUT;
2 changes: 2 additions & 0 deletions dbms/src/Operators/MergeSortTransformOp.cpp
@@ -167,6 +167,8 @@ OperatorStatus MergeSortTransformOp::tryOutputImpl(Block & block)
     switch (status)
     {
     case MergeSortStatus::PARTIAL:
+        if (sort_spill_context->updateRevocableMemory(sum_bytes_in_blocks))
+            return fromPartialToSpill();
         return OperatorStatus::NEED_INPUT;
     case MergeSortStatus::SPILL:
     {
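In the PARTIAL state the merge-sort operator now reports the bytes accumulated so far to the sort spill context and switches to spilling as soon as the threshold is crossed, instead of returning NEED_INPUT unconditionally. A sketch of that branch under the same assumptions as the earlier sketches (SortSpillContext and onPartialBlocks are hypothetical stand-ins):

    #include <cstddef>

    enum class OperatorStatus
    {
        NEED_INPUT,
        IO_OUT,
    };

    struct SortSpillContext
    {
        size_t max_bytes_before_spill = 64 * 1024 * 1024;
        // True when the buffered bytes make a spill necessary.
        bool updateRevocableMemory(size_t bytes) const { return bytes >= max_bytes_before_spill; }
    };

    OperatorStatus onPartialBlocks(SortSpillContext & ctx, size_t sum_bytes_in_blocks)
    {
        if (ctx.updateRevocableMemory(sum_bytes_in_blocks))
            return OperatorStatus::IO_OUT; // stands in for fromPartialToSpill()
        return OperatorStatus::NEED_INPUT; // keep buffering sorted blocks in memory
    }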
