Skip to content

Commit

Permalink
support agg resize callback
Browse files Browse the repository at this point in the history
Signed-off-by: xufei <xufeixw@mail.ustc.edu.cn>
  • Loading branch information
windtalker committed Sep 13, 2023
1 parent e09ec61 commit b415837
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 29 deletions.
6 changes: 6 additions & 0 deletions dbms/src/Common/HashTable/TwoLevelHashTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class TwoLevelHashTable : private boost::noncopyable
/// NOTE Bad for hash tables with more than 2^32 cells.
static size_t getBucketFromHash(size_t hash_value) { return (hash_value >> (32 - BITS_FOR_BUCKET)) & MAX_BUCKET; }

void setResizeCallback(const ResizeCallback & resize_callback)
{
for (auto & impl : impls)
impl.setResizeCallback(resize_callback);
}

protected:
typename Impl::iterator beginOfNextNonEmptyBucket(size_t & bucket)
{
Expand Down
6 changes: 6 additions & 0 deletions dbms/src/Common/HashTable/TwoLevelStringHashTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class TwoLevelStringHashTable : private boost::noncopyable
});
}

void setResizeCallback(const ResizeCallback & resize_callback)
{
for (auto & impl : impls)
impl.setResizeCallback(resize_callback);
}

size_t operator()(const Key & x) const { return hash(x); }

/// NOTE Bad for hash tables with more than 2^32 cells.
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Core/CachedSpillHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ bool CachedSpillHandler::batchRead()
{
if unlikely (is_cancelled())
return false;
if unlikely (block.rows() == 0)
continue;
ret.push_back(std::move(block));
current_return_size += ret.back().estimateBytesForSpill();
if (bytes_threshold > 0 && current_return_size >= bytes_threshold)
Expand Down
124 changes: 95 additions & 29 deletions dbms/src/Interpreters/Aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,33 @@ size_t AggregatedDataVariants::getBucketNumberForTwoLevelHashTable(Type type)
}
}

/// Install a resize callback on the current (two-level) hash table variant so
/// that a resize can be vetoed once this thread has been marked for auto spill.
/// Only takes effect when spilling is enabled AND the aggregator runs in
/// auto-spill mode; otherwise (or when `aggregator` is not set) it is a no-op.
/// NOTE(review): the switch covers only APPLY_FOR_VARIANTS_TWO_LEVEL, so this
/// appears intended to be called only after the variant has been converted to
/// two-level — presumably other variant types would hit the throw below.
void AggregatedDataVariants::setResizeCallbackIfNeeded(size_t thread_num)
{
    if (aggregator)
    {
        auto agg_spill_context = aggregator->agg_spill_context;
        if (agg_spill_context->isSpillEnabled() && agg_spill_context->isInAutoSpillMode())
        {
            /// The callback returns false ("do not resize") as soon as the thread
            /// is marked for auto spill; the hash table then raises a resize
            /// exception instead of growing further (see the handlers in
            /// Aggregator::executeImplBatch).
            auto resize_callback = [agg_spill_context, thread_num]() {
                return !agg_spill_context->isThreadMarkedForAutoSpill(thread_num);
            };
/// M(NAME) expands to one `case` that forwards the callback to the concrete
/// two-level hash table of that variant type.
#define M(NAME) \
    case AggregationMethodType(NAME): \
    { \
        ToAggregationMethodPtr(NAME, aggregation_method_impl)->data.setResizeCallback(resize_callback); \
        break; \
    }
            switch (type)
            {
                APPLY_FOR_VARIANTS_TWO_LEVEL(M)
            default:
                throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT);
            }
#undef M
        }
    }
}

void AggregatedDataVariants::convertToTwoLevel()
{
switch (type)
Expand Down Expand Up @@ -645,9 +672,21 @@ ALWAYS_INLINE void Aggregator::executeImplBatch(
{
/// For all rows.
AggregateDataPtr place = aggregates_pool->alloc(0);
for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + agg_size; ++i)
state.emplaceKey(method.data, i, *aggregates_pool, sort_key_containers).setMapped(place);
agg_process_info.start_row += agg_size;
size_t processed_rows = std::numeric_limits<size_t>::max();
try
{
for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + agg_size; ++i)
{
state.emplaceKey(method.data, i, *aggregates_pool, sort_key_containers).setMapped(place);
processed_rows = i;
}
}
catch (ResizeException &)
{
LOG_INFO(log, "HashTable resize throw ResizeException since the data is already marked for spill");
}
if (processed_rows != std::numeric_limits<size_t>::max())
agg_process_info.start_row = processed_rows + 1;
return;
}

Expand All @@ -657,6 +696,7 @@ ALWAYS_INLINE void Aggregator::executeImplBatch(
for (AggregateFunctionInstruction * inst = agg_process_info.aggregate_functions_instructions.data(); inst->that;
++inst)
{
/// no resize will happen for this kind of hash table, so don't catch resize exception
inst->batch_that->addBatchLookupTable8(
agg_process_info.start_row,
agg_size,
Expand All @@ -678,43 +718,65 @@ ALWAYS_INLINE void Aggregator::executeImplBatch(
/// Generic case.

std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[agg_size]);
size_t processed_rows = std::numeric_limits<size_t>::max();
bool allow_exception = false;

for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + agg_size; ++i)
try
{
AggregateDataPtr aggregate_data = nullptr;
for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + agg_size; ++i)
{
AggregateDataPtr aggregate_data = nullptr;

auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool, sort_key_containers);
allow_exception = true;

/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
if (emplace_result.isInserted())
{
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
emplace_result.setMapped(nullptr);
auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool, sort_key_containers);

aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
createAggregateStates(aggregate_data);
allow_exception = false;

emplace_result.setMapped(aggregate_data);
}
else
aggregate_data = emplace_result.getMapped();
/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
if (emplace_result.isInserted())
{
/// exception-safety - if you can not allocate memory or create states, then destructors will not be called.
emplace_result.setMapped(nullptr);

aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states);
createAggregateStates(aggregate_data);

emplace_result.setMapped(aggregate_data);
}
else
aggregate_data = emplace_result.getMapped();

places[i - agg_process_info.start_row] = aggregate_data;
places[i - agg_process_info.start_row] = aggregate_data;
processed_rows = i;
}
}
catch (ResizeException &)
{
LOG_INFO(
log,
"HashTable resize throw ResizeException since the data is already marked for spill, allow_exception: {}",
allow_exception);
if unlikely (!allow_exception)
throw;
}

/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = agg_process_info.aggregate_functions_instructions.data(); inst->that;
++inst)
if (processed_rows != std::numeric_limits<size_t>::max())
{
inst->batch_that->addBatch(
agg_process_info.start_row,
agg_size,
places.get(),
inst->state_offset,
inst->batch_arguments,
aggregates_pool);
/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = agg_process_info.aggregate_functions_instructions.data(); inst->that;
++inst)
{
inst->batch_that->addBatch(
agg_process_info.start_row,
processed_rows - agg_process_info.start_row + 1,
places.get(),
inst->state_offset,
inst->batch_arguments,
aggregates_pool);
}
agg_process_info.start_row = processed_rows + 1;
}
agg_process_info.start_row += agg_size;
}

void NO_INLINE
Expand Down Expand Up @@ -896,7 +958,10 @@ bool Aggregator::executeOnBlock(AggProcessInfo & agg_process_info, AggregatedDat
* It allows you to make, in the subsequent, an effective merge - either economical from memory or parallel.
*/
if (result.isConvertibleToTwoLevel() && worth_convert_to_two_level)
{
result.convertToTwoLevel();
result.setResizeCallbackIfNeeded(thread_num);
}

/** Flush data to disk if too much RAM is consumed.
*/
Expand Down Expand Up @@ -953,6 +1018,7 @@ void Aggregator::spill(AggregatedDataVariants & data_variants, size_t thread_num

/// NOTE Instead of freeing up memory and creating new hash tables and arenas, you can re-use the old ones.
data_variants.init(data_variants.type);
data_variants.setResizeCallbackIfNeeded(thread_num);
data_variants.need_spill = false;
data_variants.aggregates_pools = Arenas(1, std::make_shared<Arena>());
data_variants.aggregates_pool = data_variants.aggregates_pools.back().get();
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Interpreters/Aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,8 @@ struct AggregatedDataVariants : private boost::noncopyable

void convertToTwoLevel();

void setResizeCallbackIfNeeded(size_t thread_num);

#define APPLY_FOR_VARIANTS_TWO_LEVEL(M) \
M(key32_two_level) \
M(key64_two_level) \
Expand Down

0 comments on commit b415837

Please sign in to comment.