Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support agg resize callback #8078

Merged
merged 6 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dbms/src/Common/HashTable/TwoLevelHashTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class TwoLevelHashTable : private boost::noncopyable
/// NOTE Bad for hash tables with more than 2^32 cells.
static size_t getBucketFromHash(size_t hash_value) { return (hash_value >> (32 - BITS_FOR_BUCKET)) & MAX_BUCKET; }

void setResizeCallback(const ResizeCallback & resize_callback)
{
for (auto & impl : impls)
impl.setResizeCallback(resize_callback);
}

protected:
typename Impl::iterator beginOfNextNonEmptyBucket(size_t & bucket)
{
Expand Down
6 changes: 6 additions & 0 deletions dbms/src/Common/HashTable/TwoLevelStringHashTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class TwoLevelStringHashTable : private boost::noncopyable
});
}

void setResizeCallback(const ResizeCallback & resize_callback)
{
for (auto & impl : impls)
impl.setResizeCallback(resize_callback);
}

size_t operator()(const Key & x) const { return hash(x); }

/// NOTE Bad for hash tables with more than 2^32 cells.
Expand Down
2 changes: 2 additions & 0 deletions dbms/src/Core/CachedSpillHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ bool CachedSpillHandler::batchRead()
{
if unlikely (is_cancelled())
return false;
if unlikely (block.rows() == 0)
continue;
ret.push_back(std::move(block));
current_return_size += ret.back().estimateBytesForSpill();
if (bytes_threshold > 0 && current_return_size >= bytes_threshold)
Expand Down
105 changes: 90 additions & 15 deletions dbms/src/Interpreters/Aggregator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,35 @@ size_t AggregatedDataVariants::getBucketNumberForTwoLevelHashTable(Type type)
}
}

void AggregatedDataVariants::setResizeCallbackIfNeeded(size_t thread_num) const
{
if (aggregator)
{
auto agg_spill_context = aggregator->agg_spill_context;
if (agg_spill_context->isSpillEnabled() && agg_spill_context->isInAutoSpillMode())
{
auto resize_callback = [agg_spill_context, thread_num]() {
return !(
agg_spill_context->supportFurtherSpill()
&& agg_spill_context->isThreadMarkedForAutoSpill(thread_num));
};
#define M(NAME) \
case AggregationMethodType(NAME): \
{ \
ToAggregationMethodPtr(NAME, aggregation_method_impl)->data.setResizeCallback(resize_callback); \
break; \
}
switch (type)
{
APPLY_FOR_VARIANTS_TWO_LEVEL(M)
default:
throw Exception("Unknown aggregated data variant.", ErrorCodes::UNKNOWN_AGGREGATED_DATA_VARIANT);
}
#undef M
}
}
}

void AggregatedDataVariants::convertToTwoLevel()
{
switch (type)
Expand Down Expand Up @@ -625,6 +654,24 @@ void NO_INLINE Aggregator::executeImpl(
executeImplBatch(method, state, aggregates_pool, agg_process_info);
}

template <typename Method>
std::optional<typename Method::EmplaceResult> Aggregator::emplaceKey(
Method & method,
typename Method::State & state,
size_t index,
Arena & aggregates_pool,
std::vector<std::string> & sort_key_containers) const
{
try
{
return state.emplaceKey(method.data, index, aggregates_pool, sort_key_containers);
}
catch (ResizeException &)
{
return {};
}
}

template <typename Method>
ALWAYS_INLINE void Aggregator::executeImplBatch(
Method & method,
Expand All @@ -645,9 +692,21 @@ ALWAYS_INLINE void Aggregator::executeImplBatch(
{
/// For all rows.
AggregateDataPtr place = aggregates_pool->alloc(0);
for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + agg_size; ++i)
state.emplaceKey(method.data, i, *aggregates_pool, sort_key_containers).setMapped(place);
agg_process_info.start_row += agg_size;
for (size_t i = 0; i < agg_size; ++i)
{
auto emplace_result_hold
= emplaceKey(method, state, agg_process_info.start_row, *aggregates_pool, sort_key_containers);
if likely (emplace_result_hold.has_value())
{
emplace_result_hold.value().setMapped(place);
++agg_process_info.start_row;
}
else
{
LOG_INFO(log, "HashTable resize throw ResizeException since the data is already marked for spill");
break;
}
}
return;
}

Expand Down Expand Up @@ -678,12 +737,20 @@ ALWAYS_INLINE void Aggregator::executeImplBatch(
/// Generic case.

std::unique_ptr<AggregateDataPtr[]> places(new AggregateDataPtr[agg_size]);
std::optional<size_t> processed_rows;

for (size_t i = agg_process_info.start_row; i < agg_process_info.start_row + agg_size; ++i)
{
AggregateDataPtr aggregate_data = nullptr;

auto emplace_result = state.emplaceKey(method.data, i, *aggregates_pool, sort_key_containers);
auto emplace_result_holder = emplaceKey(method, state, i, *aggregates_pool, sort_key_containers);
if unlikely (!emplace_result_holder.has_value())
{
LOG_INFO(log, "HashTable resize throw ResizeException since the data is already marked for spill");
break;
}

auto & emplace_result = emplace_result_holder.value();

/// If a new key is inserted, initialize the states of the aggregate functions, and possibly something related to the key.
if (emplace_result.isInserted())
Expand All @@ -700,21 +767,25 @@ ALWAYS_INLINE void Aggregator::executeImplBatch(
aggregate_data = emplace_result.getMapped();

places[i - agg_process_info.start_row] = aggregate_data;
processed_rows = i;
}

/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = agg_process_info.aggregate_functions_instructions.data(); inst->that;
++inst)
if (processed_rows)
{
inst->batch_that->addBatch(
agg_process_info.start_row,
agg_size,
places.get(),
inst->state_offset,
inst->batch_arguments,
aggregates_pool);
/// Add values to the aggregate functions.
for (AggregateFunctionInstruction * inst = agg_process_info.aggregate_functions_instructions.data(); inst->that;
++inst)
{
inst->batch_that->addBatch(
agg_process_info.start_row,
*processed_rows - agg_process_info.start_row + 1,
places.get(),
inst->state_offset,
inst->batch_arguments,
aggregates_pool);
}
agg_process_info.start_row = *processed_rows + 1;
}
agg_process_info.start_row += agg_size;
}

void NO_INLINE
Expand Down Expand Up @@ -896,7 +967,10 @@ bool Aggregator::executeOnBlock(AggProcessInfo & agg_process_info, AggregatedDat
* It allows you to make, in the subsequent, an effective merge - either economical from memory or parallel.
*/
if (result.isConvertibleToTwoLevel() && worth_convert_to_two_level)
{
result.convertToTwoLevel();
result.setResizeCallbackIfNeeded(thread_num);
}

/** Flush data to disk if too much RAM is consumed.
*/
Expand Down Expand Up @@ -953,6 +1027,7 @@ void Aggregator::spill(AggregatedDataVariants & data_variants, size_t thread_num

/// NOTE Instead of freeing up memory and creating new hash tables and arenas, you can re-use the old ones.
data_variants.init(data_variants.type);
data_variants.setResizeCallbackIfNeeded(thread_num);
data_variants.need_spill = false;
data_variants.aggregates_pools = Arenas(1, std::make_shared<Arena>());
data_variants.aggregates_pool = data_variants.aggregates_pools.back().get();
Expand Down
20 changes: 20 additions & 0 deletions dbms/src/Interpreters/Aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ struct AggregationMethodOneNumber
/// To use one `Method` in different threads, use different `State`.
using State = ColumnsHashing::
HashMethodOneNumber<typename Data::value_type, Mapped, FieldType, consecutive_keys_optimization>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

/// Shuffle key columns before `insertKeyIntoColumns` call if needed.
std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> &, const Sizes &) { return {}; }
Expand Down Expand Up @@ -166,6 +167,7 @@ struct AggregationMethodString
{}

using State = ColumnsHashing::HashMethodString<typename Data::value_type, Mapped>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> &, const Sizes &) { return {}; }

Expand Down Expand Up @@ -198,6 +200,7 @@ struct AggregationMethodStringNoCache

// Remove last zero byte.
using State = ColumnsHashing::HashMethodString<typename Data::value_type, Mapped, true, false>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> &, const Sizes &) { return {}; }

Expand Down Expand Up @@ -229,6 +232,7 @@ struct AggregationMethodOneKeyStringNoCache
{}

using State = ColumnsHashing::HashMethodStringBin<typename Data::value_type, Mapped, bin_padding>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> &, const Sizes &) { return {}; }

Expand Down Expand Up @@ -262,6 +266,7 @@ struct AggregationMethodMultiStringNoCache
{}

using State = ColumnsHashing::HashMethodMultiString<typename Data::value_type, Mapped>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> &, const Sizes &) { return {}; }

Expand Down Expand Up @@ -292,6 +297,7 @@ struct AggregationMethodFastPathTwoKeysNoCache

using State
= ColumnsHashing::HashMethodFastPathTwoKeysSerialized<Key1Desc, Key2Desc, typename Data::value_type, Mapped>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> &, const Sizes &) { return {}; }

Expand Down Expand Up @@ -386,6 +392,7 @@ struct AggregationMethodFixedString
{}

using State = ColumnsHashing::HashMethodFixedString<typename Data::value_type, Mapped>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> &, const Sizes &) { return {}; }

Expand Down Expand Up @@ -417,6 +424,7 @@ struct AggregationMethodFixedStringNoCache
{}

using State = ColumnsHashing::HashMethodFixedString<typename Data::value_type, Mapped, true, false>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> &, const Sizes &) { return {}; }

Expand Down Expand Up @@ -451,6 +459,7 @@ struct AggregationMethodKeysFixed

using State
= ColumnsHashing::HashMethodKeysFixed<typename Data::value_type, Key, Mapped, has_nullable_keys, use_cache>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> & key_columns, const Sizes & key_sizes)
{
Expand Down Expand Up @@ -538,6 +547,7 @@ struct AggregationMethodSerialized
{}

using State = ColumnsHashing::HashMethodSerialized<typename Data::value_type, Mapped>;
using EmplaceResult = ColumnsHashing::columns_hashing_impl::EmplaceResultImpl<Mapped>;

std::optional<Sizes> shuffleKeyColumns(std::vector<IColumn *> &, const Sizes &) { return {}; }

Expand Down Expand Up @@ -938,6 +948,8 @@ struct AggregatedDataVariants : private boost::noncopyable

void convertToTwoLevel();

void setResizeCallbackIfNeeded(size_t thread_num) const;

#define APPLY_FOR_VARIANTS_TWO_LEVEL(M) \
M(key32_two_level) \
M(key64_two_level) \
Expand Down Expand Up @@ -1266,6 +1278,14 @@ class Aggregator
Arena * aggregates_pool,
AggProcessInfo & agg_process_info) const;

template <typename Method>
std::optional<typename Method::EmplaceResult> emplaceKey(
Method & method,
typename Method::State & state,
size_t index,
Arena & aggregates_pool,
std::vector<std::string> & sort_key_containers) const;

/// For case when there are no keys (all aggregate into one row).
static void executeWithoutKeyImpl(AggregatedDataWithoutKey & res, AggProcessInfo & agg_process_info, Arena * arena);

Expand Down
4 changes: 3 additions & 1 deletion dbms/src/Interpreters/JoinPartition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ void JoinPartition::setResizeCallbackIfNeeded()
if (hash_join_spill_context->isSpillEnabled() && hash_join_spill_context->isInAutoSpillMode())
{
auto resize_callback = [this]() {
return !hash_join_spill_context->isPartitionMarkedForAutoSpill(partition_index);
return !(
hash_join_spill_context->supportFurtherSpill()
&& hash_join_spill_context->isPartitionMarkedForAutoSpill(partition_index));
};
assert(pool != nullptr);
pool->setResizeCallback(resize_callback);
Expand Down