diff --git a/dbms/src/Core/AutoSpillTrigger.h b/dbms/src/Core/AutoSpillTrigger.h
new file mode 100644
index 00000000000..aba6d0f048f
--- /dev/null
+++ b/dbms/src/Core/AutoSpillTrigger.h
@@ -0,0 +1,70 @@
+// Copyright 2023 PingCAP, Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <Common/Logger.h>
+#include <Common/MemoryTracker.h>
+#include <Core/QueryOperatorSpillContexts.h>
+
+namespace DB
+{
+class AutoSpillTrigger
+{
+public:
+    AutoSpillTrigger(
+        const MemoryTrackerPtr & memory_tracker_,
+        const std::shared_ptr<QueryOperatorSpillContexts> & query_operator_spill_contexts_,
+        float auto_memory_revoke_trigger_threshold,
+        float auto_memory_revoke_target_threshold)
+        : memory_tracker(memory_tracker_)
+        , query_operator_spill_contexts(query_operator_spill_contexts_)
+    {
+        RUNTIME_CHECK_MSG(memory_tracker->getLimit() > 0, "Memory limit must be set for auto spill trigger");
+        RUNTIME_CHECK_MSG(
+            auto_memory_revoke_target_threshold >= 0 && auto_memory_revoke_trigger_threshold > 0,
+            "Invalid auto trigger threshold {} or invalid auto target threshold {}",
+            auto_memory_revoke_trigger_threshold,
+            auto_memory_revoke_target_threshold);
+        if (unlikely(auto_memory_revoke_trigger_threshold < auto_memory_revoke_target_threshold))
+        {
+            LOG_WARNING(
+                query_operator_spill_contexts->getLogger(),
+                "Auto trigger threshold {} is less than auto target threshold {}, not valid, use default values instead",
+                auto_memory_revoke_trigger_threshold,
+                auto_memory_revoke_target_threshold);
+            /// invalid values, fall back to the default values
+            auto_memory_revoke_trigger_threshold = 0.7;
+            auto_memory_revoke_target_threshold = 0.5;
+        }
+        trigger_threshold = static_cast<Int64>(memory_tracker->getLimit() * auto_memory_revoke_trigger_threshold);
+        target_threshold = static_cast<Int64>(memory_tracker->getLimit() * auto_memory_revoke_target_threshold);
+    }
+
+    void triggerAutoSpill()
+    {
+        auto current_memory_usage = memory_tracker->get();
+        if (current_memory_usage > trigger_threshold)
+        {
+            query_operator_spill_contexts->triggerAutoSpill(current_memory_usage - target_threshold);
+        }
+    }
+
+private:
+    MemoryTrackerPtr memory_tracker;
+    std::shared_ptr<QueryOperatorSpillContexts> query_operator_spill_contexts;
+    Int64 trigger_threshold;
+    Int64 target_threshold;
+};
+} // namespace DB
diff --git a/dbms/src/Core/OperatorSpillContext.h b/dbms/src/Core/OperatorSpillContext.h
index abd4a660c92..c0ced5b8d40 100644
--- a/dbms/src/Core/OperatorSpillContext.h
+++ b/dbms/src/Core/OperatorSpillContext.h
@@ -20,10 +20,12 @@
 namespace DB
 {
-enum class SpillStatus
+enum class AutoSpillStatus
 {
-    NOT_SPILL,
-    SPILL,
+    /// auto spill is not needed, or the current auto spill has already finished
+    NO_NEED_AUTO_SPILL,
+    /// auto spill is needed
+    NEED_AUTO_SPILL,
 };
 
 class OperatorSpillContext
@@ -31,7 +33,7 @@ class OperatorSpillContext
 protected:
     UInt64 operator_spill_threshold;
     std::atomic<bool> in_spillable_stage{true};
-    std::atomic<SpillStatus> spill_status{SpillStatus::NOT_SPILL};
+    std::atomic<bool> is_spilled{false};
     bool enable_spill = true;
    String op_name;
    LoggerPtr log;

    virtual Int64 getTotalRevocableMemoryImpl() = 0;
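+    /// NOTE: getTotalRevocableMemoryImpl() above is the per-operator hook: each concrete
+    /// spill context (sort/agg/join, etc.) reports how much memory a spill could currently
+    /// release, and the auto spill check skips candidates below MIN_SPILL_THRESHOLD.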
 
 public:
+    /// minimum revocable operator memory that will trigger a spill
+    const static Int64 MIN_SPILL_THRESHOLD = 10ULL * 1024 * 1024;
     OperatorSpillContext(UInt64 operator_spill_threshold_, const String op_name_, const LoggerPtr & log_)
         : operator_spill_threshold(operator_spill_threshold_)
         , op_name(op_name_)
         , log(log_)
     {}
     virtual ~OperatorSpillContext() = default;
-    bool isSpillEnabled() const { return enable_spill && operator_spill_threshold > 0; }
+    bool isSpillEnabled() const { return enable_spill && (supportAutoTriggerSpill() || operator_spill_threshold > 0); }
     void disableSpill() { enable_spill = false; }
     void finishSpillableStage() { in_spillable_stage = false; }
+    bool spillableStageFinished() const { return !in_spillable_stage; }
     Int64 getTotalRevocableMemory()
     {
         if (in_spillable_stage)
@@ -56,14 +61,21 @@ class OperatorSpillContext
         return 0;
     }
     UInt64 getOperatorSpillThreshold() const { return operator_spill_threshold; }
-    void markSpill()
+    void markSpilled()
     {
-        SpillStatus init_value = SpillStatus::NOT_SPILL;
-        if (spill_status.compare_exchange_strong(init_value, SpillStatus::SPILL, std::memory_order_relaxed))
+        bool init_value = false;
+        if (is_spilled.compare_exchange_strong(init_value, true, std::memory_order_relaxed))
         {
             LOG_INFO(log, "Begin spill in {}", op_name);
         }
     }
-    bool isSpilled() const { return spill_status != SpillStatus::NOT_SPILL; }
+    bool isSpilled() const { return is_spilled; }
+    /// supporting auto trigger spill means the operator will spill automatically under the constraint of the
+    /// query/global level memory threshold, so the user does not need to set operator_spill_threshold explicitly
+    virtual bool supportAutoTriggerSpill() const { return false; }
+    virtual Int64 triggerSpill(Int64 expected_released_memories) = 0;
 };
+
+using OperatorSpillContextPtr = std::shared_ptr<OperatorSpillContext>;
+using RegisterOperatorSpillContext = std::function<void(const OperatorSpillContextPtr & ptr)>;
 } // namespace DB
diff --git a/dbms/src/Core/QueryOperatorSpillContexts.h b/dbms/src/Core/QueryOperatorSpillContexts.h
new file mode 100644
index 00000000000..f899dd095ee
--- /dev/null
+++ b/dbms/src/Core/QueryOperatorSpillContexts.h
@@ -0,0 +1,92 @@
+// Copyright 2023 PingCAP, Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <Core/TaskOperatorSpillContexts.h>
+#include <Flash/Mpp/MPPTaskId.h>
+
+namespace DB
+{
+class QueryOperatorSpillContexts
+{
+public:
+    explicit QueryOperatorSpillContexts(const MPPQueryId & query_id)
+        : log(Logger::get(query_id.toString()))
+    {}
+
+    Int64 triggerAutoSpill(Int64 expected_released_memories)
+    {
+        std::unique_lock lock(mutex, std::try_to_lock);
+        /// use the mutex to avoid concurrent checks; todo: maybe a minimum check interval (like 100ms) is needed here?
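+        /// NOTE: try_to_lock makes concurrent callers skip the check instead of blocking on it,
+        /// so at most one thread walks the task list at a time while the others simply return
+        /// their expectation unreduced.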
+        if (lock.owns_lock())
+        {
+            if unlikely (!first_check)
+            {
+                first_check = true;
+                LOG_INFO(log, "Query memory usage exceeded threshold, trigger auto spill check");
+            }
+            /// vector of <revocable_memory, task_operator_spill_contexts> pairs
+            std::vector<std::pair<Int64, TaskOperatorSpillContexts *>> revocable_memories;
+            revocable_memories.reserve(task_operator_spill_contexts_list.size());
+            for (auto it = task_operator_spill_contexts_list.begin(); it != task_operator_spill_contexts_list.end();)
+            {
+                if ((*it)->isFinished())
+                {
+                    it = task_operator_spill_contexts_list.erase(it);
+                }
+                else
+                {
+                    revocable_memories.emplace_back((*it)->totalRevocableMemories(), (*it).get());
+                    ++it;
+                }
+            }
+            std::sort(revocable_memories.begin(), revocable_memories.end(), [](const auto & a, const auto & b) {
+                return a.first > b.first;
+            });
+            for (auto & pair : revocable_memories)
+            {
+                if (pair.first < OperatorSpillContext::MIN_SPILL_THRESHOLD)
+                    break;
+                expected_released_memories = pair.second->triggerAutoSpill(expected_released_memories);
+                if (expected_released_memories <= 0)
+                    break;
+            }
+            return expected_released_memories;
+        }
+        return expected_released_memories;
+    }
+
+    void registerTaskOperatorSpillContexts(
+        const std::shared_ptr<TaskOperatorSpillContexts> & task_operator_spill_contexts)
+    {
+        std::unique_lock lock(mutex);
+        task_operator_spill_contexts_list.push_back(task_operator_spill_contexts);
+    }
+
+    /// used for test
+    size_t getTaskOperatorSpillContextsCount() const
+    {
+        std::unique_lock lock(mutex);
+        return task_operator_spill_contexts_list.size();
+    }
+
+    const LoggerPtr & getLogger() const { return log; }
+
+private:
+    std::list<std::shared_ptr<TaskOperatorSpillContexts>> task_operator_spill_contexts_list;
+    bool first_check = false;
+    LoggerPtr log;
+    mutable std::mutex mutex;
+};
+
+} // namespace DB
diff --git a/dbms/src/Core/TaskOperatorSpillContexts.h b/dbms/src/Core/TaskOperatorSpillContexts.h
new file mode 100644
index 00000000000..ae33e5176f9
--- /dev/null
+++ b/dbms/src/Core/TaskOperatorSpillContexts.h
@@ -0,0 +1,111 @@
+// Copyright 2023 PingCAP, Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
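+
+/// TaskOperatorSpillContexts collects the spill contexts of the auto-spill-capable operators
+/// in one MPP task; QueryOperatorSpillContexts then dispatches a query level memory
+/// revocation request to the tasks with the most revocable memory first.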
+
+#pragma once
+
+#include <Core/OperatorSpillContext.h>
+
+namespace DB
+{
+class TaskOperatorSpillContexts
+{
+public:
+    Int64 triggerAutoSpill(Int64 expected_released_memories)
+    {
+        if (isFinished())
+            return expected_released_memories;
+        appendAdditionalOperatorSpillContexts();
+        bool has_finished_operator_spill_contexts = false;
+        for (auto & operator_spill_context : operator_spill_contexts)
+        {
+            assert(operator_spill_context->supportAutoTriggerSpill());
+            if (operator_spill_context->spillableStageFinished())
+            {
+                has_finished_operator_spill_contexts = true;
+                continue;
+            }
+            expected_released_memories = operator_spill_context->triggerSpill(expected_released_memories);
+            if (expected_released_memories <= 0)
+                break;
+        }
+        if (has_finished_operator_spill_contexts)
+        {
+            /// clean up the finished spill contexts
+            operator_spill_contexts.erase(
+                std::remove_if(
+                    operator_spill_contexts.begin(),
+                    operator_spill_contexts.end(),
+                    [](const auto & context) { return context->spillableStageFinished(); }),
+                operator_spill_contexts.end());
+        }
+        return expected_released_memories;
+    }
+
+    void registerOperatorSpillContext(const OperatorSpillContextPtr & operator_spill_context)
+    {
+        if (operator_spill_context->supportAutoTriggerSpill())
+        {
+            std::unique_lock lock(mutex);
+            additional_operator_spill_contexts.push_back(operator_spill_context);
+            has_additional_operator_spill_contexts = true;
+        }
+    }
+
+    /// for tests
+    size_t operatorSpillContextCount()
+    {
+        appendAdditionalOperatorSpillContexts();
+        return operator_spill_contexts.size();
+    }
+
+    /// for tests
+    size_t additionalOperatorSpillContextCount() const
+    {
+        std::unique_lock lock(mutex);
+        return additional_operator_spill_contexts.size();
+    }
+
+    Int64 totalRevocableMemories()
+    {
+        if unlikely (isFinished())
+            return 0;
+        appendAdditionalOperatorSpillContexts();
+        Int64 ret = 0;
+        for (const auto & operator_spill_context : operator_spill_contexts)
+            ret += operator_spill_context->getTotalRevocableMemory();
+        return ret;
+    }
+
+    bool isFinished() const { return is_task_finished; }
+
+    void finish() { is_task_finished = true; }
+
+private:
+    void appendAdditionalOperatorSpillContexts()
+    {
+        if (has_additional_operator_spill_contexts)
+        {
+            std::unique_lock lock(mutex);
+            operator_spill_contexts.splice(operator_spill_contexts.end(), additional_operator_spill_contexts);
+            has_additional_operator_spill_contexts = false;
+            additional_operator_spill_contexts.clear();
+        }
+    }
+    /// access to operator_spill_contexts is thread safe
+    std::list<OperatorSpillContextPtr> operator_spill_contexts;
+    mutable std::mutex mutex;
+    /// access to additional_operator_spill_contexts needs to acquire the lock first
+    std::list<OperatorSpillContextPtr> additional_operator_spill_contexts;
+    std::atomic<bool> has_additional_operator_spill_contexts{false};
+    std::atomic<bool> is_task_finished{false};
+};
+
+} // namespace DB
diff --git a/dbms/src/DataStreams/IProfilingBlockInputStream.cpp b/dbms/src/DataStreams/IProfilingBlockInputStream.cpp
index 82ea4aa41ce..e38ec02f5ba 100644
--- a/dbms/src/DataStreams/IProfilingBlockInputStream.cpp
+++ b/dbms/src/DataStreams/IProfilingBlockInputStream.cpp
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include
+#include <Core/AutoSpillTrigger.h>
 #include
 #include
 #include
@@ -88,7 +89,8 @@ Block IProfilingBlockInputStream::read(FilterPtr & res_filter, bool return_filter)
     if (enabled_extremes)
         updateExtremes(res);
 
-    if (limits.mode == LIMITS_CURRENT && !limits.size_limits.check(info.rows, info.bytes, "result", ErrorCodes::TOO_MANY_ROWS_OR_BYTES))
+    if (limits.mode == LIMITS_CURRENT
+        && !limits.size_limits.check(info.rows, info.bytes, "result", ErrorCodes::TOO_MANY_ROWS_OR_BYTES))
         limit_exceeded_need_break = true;
 
     if (quota != nullptr)
@@ -116,6 +118,8 @@ Block IProfilingBlockInputStream::read(FilterPtr & res_filter, bool return_filter)
     }
 #endif
 
+    if (auto_spill_trigger != nullptr)
+        auto_spill_trigger->triggerAutoSpill();
     info.updateExecutionTime(info.total_stopwatch.elapsed() - start_time);
     return res;
 }
@@ -231,10 +235,11 @@ bool IProfilingBlockInputStream::checkTimeLimit() const
 {
     if (limits.max_execution_time != 0
         && info.total_stopwatch.elapsed() > static_cast<UInt64>(limits.max_execution_time.totalMicroseconds()) * 1000)
-        return handleOverflowMode(limits.timeout_overflow_mode,
-                                  "Timeout exceeded: elapsed " + toString(info.total_stopwatch.elapsedSeconds())
-                                      + " seconds, maximum: " + toString(limits.max_execution_time.totalMicroseconds() / 1000000.0),
-                                  ErrorCodes::TIMEOUT_EXCEEDED);
+        return handleOverflowMode(
+            limits.timeout_overflow_mode,
+            "Timeout exceeded: elapsed " + toString(info.total_stopwatch.elapsedSeconds())
+                + " seconds, maximum: " + toString(limits.max_execution_time.totalMicroseconds() / 1000000.0),
+            ErrorCodes::TIMEOUT_EXCEEDED);
 
     return true;
 }
@@ -294,13 +299,15 @@ void IProfilingBlockInputStream::progressImpl(const Progress & value)
     case OverflowMode::THROW:
     {
         if (limits.size_limits.max_rows && total_rows_estimate > limits.size_limits.max_rows)
-            throw Exception("Limit for rows to read exceeded: " + toString(total_rows_estimate)
-                                + " rows read (or to read), maximum: " + toString(limits.size_limits.max_rows),
-                            ErrorCodes::TOO_MANY_ROWS);
+            throw Exception(
+                "Limit for rows to read exceeded: " + toString(total_rows_estimate)
+                    + " rows read (or to read), maximum: " + toString(limits.size_limits.max_rows),
+                ErrorCodes::TOO_MANY_ROWS);
         else
-            throw Exception("Limit for (uncompressed) bytes to read exceeded: " + toString(progress.bytes)
-                                + " bytes read, maximum: " + toString(limits.size_limits.max_bytes),
-                            ErrorCodes::TOO_MANY_BYTES);
+            throw Exception(
+                "Limit for (uncompressed) bytes to read exceeded: " + toString(progress.bytes)
+                    + " bytes read, maximum: " + toString(limits.size_limits.max_bytes),
+                ErrorCodes::TOO_MANY_BYTES);
         break;
     }
 
@@ -330,22 +337,26 @@ void IProfilingBlockInputStream::progressImpl(const Progress & value)
             if (total_elapsed > limits.timeout_before_checking_execution_speed.totalMicroseconds() / 1000000.0)
             {
                 if (limits.min_execution_speed && progress.rows / total_elapsed < limits.min_execution_speed)
-                    throw Exception("Query is executing too slow: " + toString(progress.rows / total_elapsed)
-                                        + " rows/sec., minimum: " + toString(limits.min_execution_speed),
-                                    ErrorCodes::TOO_SLOW);
+                    throw Exception(
+                        "Query is executing too slow: " + toString(progress.rows / total_elapsed)
+                            + " rows/sec., minimum: " + toString(limits.min_execution_speed),
+                        ErrorCodes::TOO_SLOW);
 
                 size_t total_rows = progress.total_rows;
 
                 /// If the predicted execution time is longer than `max_execution_time`.
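+                /// e.g. 2s elapsed after reading 1M of an estimated 10M total rows
+                /// => estimated total time = 2 * (10M / 1M) = 20s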
                 if (limits.max_execution_time != 0 && total_rows)
                 {
-                    double estimated_execution_time_seconds = total_elapsed * (static_cast<double>(total_rows) / progress.rows);
+                    double estimated_execution_time_seconds
+                        = total_elapsed * (static_cast<double>(total_rows) / progress.rows);
 
                     if (estimated_execution_time_seconds > limits.max_execution_time.totalSeconds())
-                        throw Exception("Estimated query execution time (" + toString(estimated_execution_time_seconds) + " seconds)"
-                                            + " is too long. Maximum: " + toString(limits.max_execution_time.totalSeconds())
-                                            + ". Estimated rows to process: " + toString(total_rows),
-                                        ErrorCodes::TOO_SLOW);
+                        throw Exception(
+                            "Estimated query execution time (" + toString(estimated_execution_time_seconds)
+                                + " seconds)"
+                                + " is too long. Maximum: " + toString(limits.max_execution_time.totalSeconds())
+                                + ". Estimated rows to process: " + toString(total_rows),
+                            ErrorCodes::TOO_SLOW);
                 }
             }
         }
@@ -391,10 +402,33 @@ bool IProfilingBlockInputStream::isCancelledOrThrowIfKilled() const
 
 void IProfilingBlockInputStream::setProgressCallback(const ProgressCallback & callback)
 {
+    std::unordered_set<IProfilingBlockInputStream *> visited_nodes;
+    setProgressCallbackImpl(callback, visited_nodes);
+}
+
+void IProfilingBlockInputStream::setProgressCallbackImpl(
+    const ProgressCallback & callback,
+    std::unordered_set<IProfilingBlockInputStream *> & visited_nodes)
+{
+    if (visited_nodes.find(this) != visited_nodes.end())
+        return;
+    visited_nodes.insert(this);
     progress_callback = callback;
 
     forEachProfilingChild([&](IProfilingBlockInputStream & child) {
-        child.setProgressCallback(callback);
+        child.setProgressCallbackImpl(callback, visited_nodes);
+        return false;
+    });
+}
+
+void IProfilingBlockInputStream::setAutoSpillTrigger(AutoSpillTrigger * auto_spill_trigger_)
+{
+    if (auto_spill_trigger == auto_spill_trigger_)
+        return;
+    auto_spill_trigger = auto_spill_trigger_;
+
+    forEachProfilingChild([&](IProfilingBlockInputStream & child) {
+        child.setAutoSpillTrigger(auto_spill_trigger_);
         return false;
     });
 }
@@ -402,6 +436,8 @@ void IProfilingBlockInputStream::setProgressCallback(const ProgressCallback & callback)
 
 void IProfilingBlockInputStream::setProcessListElement(ProcessListElement * elem)
 {
+    if (process_list_elem == elem)
+        return;
     process_list_elem = elem;
 
     forEachProfilingChild([&](IProfilingBlockInputStream & child) {
diff --git a/dbms/src/DataStreams/IProfilingBlockInputStream.h b/dbms/src/DataStreams/IProfilingBlockInputStream.h
index 2ed5c1e43d1..b57cb4f293a 100644
--- a/dbms/src/DataStreams/IProfilingBlockInputStream.h
+++ b/dbms/src/DataStreams/IProfilingBlockInputStream.h
@@ -36,6 +36,7 @@ extern const int QUERY_WAS_CANCELLED;
 class QuotaForIntervals;
 class ProcessListElement;
 class IProfilingBlockInputStream;
+class AutoSpillTrigger;
 
 using ProfilingBlockInputStreamPtr = std::shared_ptr<IProfilingBlockInputStream>;
 
@@ -86,6 +87,12 @@ class IProfilingBlockInputStream : public IBlockInputStream
      */
     void setProgressCallback(const ProgressCallback & callback);
 
+    /** Set the auto spill trigger; it triggers auto spill based on
+      * the query level or global level memory threshold.
+      * @param auto_spill_trigger_
+      */
+    void setAutoSpillTrigger(AutoSpillTrigger * auto_spill_trigger_);
+
 
     /** In this method:
      * - the progress callback is called;
@@ -158,23 +165,14 @@ class IProfilingBlockInputStream : public IBlockInputStream
     };
 
     /** Set limitations that are checked on each block.
      */
-    void setLimits(const LocalLimits & limits_)
-    {
-        limits = limits_;
-    }
+    void setLimits(const LocalLimits & limits_) { limits = limits_; }
 
-    const LocalLimits & getLimits() const
-    {
-        return limits;
-    }
+    const LocalLimits & getLimits() const { return limits; }
 
     /** Set the quota. If you set a quota on the amount of raw data,
       * then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits.
       */
-    void setQuota(QuotaForIntervals & quota_)
-    {
-        quota = &quota_;
-    }
+    void setQuota(QuotaForIntervals & quota_) { quota = &quota_; }
 
     /// Enable calculation of minimums and maximums by the result columns.
     void enableExtremes() { enabled_extremes = true; }
@@ -185,6 +183,7 @@ class IProfilingBlockInputStream : public IBlockInputStream
     std::atomic<bool> is_killed{false};
     ProgressCallback progress_callback;
     ProcessListElement * process_list_elem = nullptr;
+    AutoSpillTrigger * auto_spill_trigger = nullptr;
 
     /// Additional information that can be generated during the work process.
 
@@ -247,6 +246,7 @@ class IProfilingBlockInputStream : public IBlockInputStream
             if (f(*p_child))
                 return;
     }
+    void setProgressCallbackImpl(const ProgressCallback & callback, std::unordered_set<IProfilingBlockInputStream *> & visited_nodes);
 };
 
 } // namespace DB
diff --git a/dbms/src/DataStreams/MergeSortingBlockInputStream.cpp b/dbms/src/DataStreams/MergeSortingBlockInputStream.cpp
index 9809af0e580..a97a3e76c59 100644
--- a/dbms/src/DataStreams/MergeSortingBlockInputStream.cpp
+++ b/dbms/src/DataStreams/MergeSortingBlockInputStream.cpp
@@ -32,7 +32,8 @@ MergeSortingBlockInputStream::MergeSortingBlockInputStream(
     size_t limit_,
     size_t max_bytes_before_external_sort,
     const SpillConfig & spill_config,
-    const String & req_id)
+    const String & req_id,
+    const RegisterOperatorSpillContext & register_operator_spill_context)
     : description(description_)
     , max_merged_block_size(max_merged_block_size_)
     , limit(limit_)
@@ -46,8 +47,27 @@ MergeSortingBlockInputStream::MergeSortingBlockInputStream(
     sort_spill_context = std::make_shared<SortSpillContext>(spill_config, max_bytes_before_external_sort, log);
     if (sort_spill_context->isSpillEnabled())
         sort_spill_context->buildSpiller(header_without_constants);
+    if (register_operator_spill_context != nullptr)
+        register_operator_spill_context(sort_spill_context);
 }
 
+void MergeSortingBlockInputStream::spillCurrentBlocks()
+{
+    sort_spill_context->markSpilled();
+    auto block_in = std::make_shared<MergeSortingBlocksBlockInputStream>(
+        blocks,
+        description,
+        log->identifier(),
+        std::max<size_t>(1, max_merged_block_size / 10),
+        limit);
+    auto is_cancelled_pred = [this]() {
+        return this->isCancelled();
+    };
+    sort_spill_context->getSpiller()->spillBlocksUsingBlockInputStream(block_in, 0, is_cancelled_pred);
+    sort_spill_context->finishOneSpill();
+    blocks.clear();
+    sum_bytes_in_blocks = 0;
+}
+
 Block MergeSortingBlockInputStream::readImpl()
 {
@@ -79,26 +99,29 @@ Block MergeSortingBlockInputStream::readImpl()
              */
             if (sort_spill_context->updateRevocableMemory(sum_bytes_in_blocks))
             {
-                sort_spill_context->markSpill();
-                auto block_in = std::make_shared<MergeSortingBlocksBlockInputStream>(blocks, description, log->identifier(), max_merged_block_size, limit);
-                auto is_cancelled_pred = [this]() {
-                    return this->isCancelled();
-                };
-                sort_spill_context->getSpiller()->spillBlocksUsingBlockInputStream(block_in, 0, is_cancelled_pred);
-                blocks.clear();
+                spillCurrentBlocks();
                 if (is_cancelled)
                     break;
-                sum_bytes_in_blocks = 0;
             }
         }
 
+        sort_spill_context->finishSpillableStage();
+        if (!blocks.empty() && sort_spill_context->needFinalSpill())
+        {
+            spillCurrentBlocks();
+        }
+
         if (isCancelledOrThrowIfKilled() || (blocks.empty() && !hasSpilledData()))
             return Block();
 
-        sort_spill_context->finishSpillableStage();
-
         if (!hasSpilledData())
         {
-            impl = std::make_unique<MergeSortingBlocksBlockInputStream>(blocks, description, log->identifier(), max_merged_block_size, limit);
+            impl = std::make_unique<MergeSortingBlocksBlockInputStream>(
+                blocks,
+                description,
+                log->identifier(),
+                max_merged_block_size,
+                limit);
         }
         else
         {
@@ -120,7 +143,11 @@ Block MergeSortingBlockInputStream::readImpl()
                     limit));
 
             /// Will merge those sorted streams.
-            impl = std::make_unique<MergingSortedBlockInputStream>(inputs_to_merge, description, max_merged_block_size, limit);
+            impl = std::make_unique<MergingSortedBlockInputStream>(
+                inputs_to_merge,
+                description,
+                max_merged_block_size,
+                limit);
         }
     }
diff --git a/dbms/src/DataStreams/MergeSortingBlockInputStream.h b/dbms/src/DataStreams/MergeSortingBlockInputStream.h
index c057984d93f..99abbf3e6ac 100644
--- a/dbms/src/DataStreams/MergeSortingBlockInputStream.h
+++ b/dbms/src/DataStreams/MergeSortingBlockInputStream.h
@@ -41,7 +41,8 @@ class MergeSortingBlockInputStream : public IProfilingBlockInputStream
         size_t limit_,
         size_t max_bytes_before_external_sort_,
         const SpillConfig & spill_config_,
-        const String & req_id);
+        const String & req_id,
+        const RegisterOperatorSpillContext & register_operator_spill_context);
 
     String getName() const override { return NAME; }
 
@@ -56,10 +57,9 @@ class MergeSortingBlockInputStream : public IProfilingBlockInputStream
     void appendInfo(FmtBuffer & buffer) const override;
 
 private:
-    bool hasSpilledData() const
-    {
-        return sort_spill_context->hasSpilledData();
-    }
+    bool hasSpilledData() const { return sort_spill_context->hasSpilledData(); }
+
+    void spillCurrentBlocks();
 
     SortDescription description;
     size_t max_merged_block_size;
     size_t limit;
diff --git a/dbms/src/Flash/Coprocessor/DAGContext.cpp b/dbms/src/Flash/Coprocessor/DAGContext.cpp
index 51802b2b52b..f5c0a880f1b 100644
--- a/dbms/src/Flash/Coprocessor/DAGContext.cpp
+++ b/dbms/src/Flash/Coprocessor/DAGContext.cpp
@@ -42,17 +42,25 @@ bool strictSqlMode(UInt64 sql_mode)
 }
 
 // for non-mpp(cop/batchCop)
-DAGContext::DAGContext(tipb::DAGRequest & dag_request_, TablesRegionsInfo && tables_regions_info_, KeyspaceID keyspace_id_, const String & tidb_host_, bool is_batch_cop_, LoggerPtr log_)
+DAGContext::DAGContext(
+    tipb::DAGRequest & dag_request_,
+    TablesRegionsInfo && tables_regions_info_,
+    KeyspaceID keyspace_id_,
+    const String & tidb_host_,
+    bool is_batch_cop_,
+    LoggerPtr log_)
     : dag_request(&dag_request_)
     , dummy_query_string(dag_request->DebugString())
     , dummy_ast(makeDummyQuery())
     , tidb_host(tidb_host_)
-    , collect_execution_summaries(dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries())
+    , collect_execution_summaries(
+          dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries())
     , is_mpp_task(false)
     , is_root_mpp_task(false)
     , is_batch_cop(is_batch_cop_)
     , tables_regions_info(std::move(tables_regions_info_))
     , log(std::move(log_))
+    , operator_spill_contexts(std::make_shared<TaskOperatorSpillContexts>())
     , flags(dag_request->flags())
     , sql_mode(dag_request->sql_mode())
     , max_recorded_error_count(getMaxErrorCount(*dag_request))
@@ -68,9 +76,11 @@ DAGContext::DAGContext(tipb::DAGRequest & dag_request_, const mpp::TaskMeta & meta_, bool is_root_mpp_task_)
     : dag_request(&dag_request_)
     , dummy_query_string(dag_request->DebugString())
     , dummy_ast(makeDummyQuery())
-    , collect_execution_summaries(dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries())
+    , collect_execution_summaries(
+          dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries())
     , is_mpp_task(true)
     , is_root_mpp_task(is_root_mpp_task_)
+    , operator_spill_contexts(std::make_shared<TaskOperatorSpillContexts>())
     , flags(dag_request->flags())
     , sql_mode(dag_request->sql_mode())
     , mpp_task_meta(meta_)
@@ -86,18 +96,25 @@ DAGContext::DAGContext(tipb::DAGRequest & dag_request_, const mpp::TaskMeta & meta_, bool is_root_mpp_task_)
 }
 
 // for disaggregated task on write node
-DAGContext::DAGContext(tipb::DAGRequest & dag_request_, const disaggregated::DisaggTaskMeta & task_meta_, TablesRegionsInfo && tables_regions_info_, const String & compute_node_host_, LoggerPtr log_)
+DAGContext::DAGContext(
+    tipb::DAGRequest & dag_request_,
+    const disaggregated::DisaggTaskMeta & task_meta_,
+    TablesRegionsInfo && tables_regions_info_,
+    const String & compute_node_host_,
+    LoggerPtr log_)
     : dag_request(&dag_request_)
     , dummy_query_string(dag_request->DebugString())
     , dummy_ast(makeDummyQuery())
     , tidb_host(compute_node_host_)
-    , collect_execution_summaries(dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries())
+    , collect_execution_summaries(
+          dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries())
     , is_mpp_task(false)
     , is_root_mpp_task(false)
     , is_batch_cop(false)
     , is_disaggregated_task(true)
     , tables_regions_info(std::move(tables_regions_info_))
     , log(std::move(log_))
+    , operator_spill_contexts(std::make_shared<TaskOperatorSpillContexts>())
     , flags(dag_request->flags())
     , sql_mode(dag_request->sql_mode())
     , disaggregated_id(std::make_unique<DM::DisaggTaskId>(task_meta_))
@@ -116,6 +133,7 @@ DAGContext::DAGContext(UInt64 max_error_count_)
     , collect_execution_summaries(false)
     , is_mpp_task(false)
     , is_root_mpp_task(false)
+    , operator_spill_contexts(std::make_shared<TaskOperatorSpillContexts>())
     , flags(0)
     , sql_mode(0)
     , max_recorded_error_count(max_error_count_)
@@ -129,10 +147,12 @@ DAGContext::DAGContext(tipb::DAGRequest & dag_request_, String log_identifier, size_t concurrency)
     , dummy_query_string(dag_request->DebugString())
     , dummy_ast(makeDummyQuery())
     , initialize_concurrency(concurrency)
-    , collect_execution_summaries(dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries())
+    , collect_execution_summaries(
+          dag_request->has_collect_execution_summaries() && dag_request->collect_execution_summaries())
     , is_mpp_task(false)
     , is_root_mpp_task(false)
     , log(Logger::get(log_identifier))
+    , operator_spill_contexts(std::make_shared<TaskOperatorSpillContexts>())
     , flags(dag_request->flags())
     , sql_mode(dag_request->sql_mode())
     , max_recorded_error_count(getMaxErrorCount(*dag_request))
@@ -142,6 +162,11 @@ DAGContext::DAGContext(tipb::DAGRequest & dag_request_, String log_identifier, size_t concurrency)
     initOutputInfo();
 }
 
+DAGContext::~DAGContext()
+{
+    operator_spill_contexts->finish();
+}
+
 void DAGContext::initOutputInfo()
 {
     output_field_types = collectOutputFieldTypes(*dag_request);
@@ -152,12 +177,17 @@ void DAGContext::initOutputInfo()
             output_offsets.push_back(i);
             if (unlikely(i >= output_field_types.size()))
                 throw TiFlashException(
-                    fmt::format("{}: Invalid output offset(schema has {} columns, access index {}", __PRETTY_FUNCTION__, output_field_types.size(), i),
+                    fmt::format(
+                        "{}: Invalid output offset(schema has {} columns, access index {})",
+                        __PRETTY_FUNCTION__,
+                        output_field_types.size(),
+                        i),
                     Errors::Coprocessor::BadRequest);
             result_field_types.push_back(output_field_types[i]);
         }
         encode_type = analyzeDAGEncodeType(*this);
-        keep_session_timezone_info = encode_type == tipb::EncodeType::TypeChunk || encode_type == tipb::EncodeType::TypeCHBlock;
+        keep_session_timezone_info
+            = encode_type == tipb::EncodeType::TypeChunk || encode_type == tipb::EncodeType::TypeCHBlock;
 }
 
 bool DAGContext::allowZeroInDate() const
@@ -187,7 +217,10 @@ std::unordered_map<String, OperatorProfileInfos> & DAGContext::getOperatorProfileInfosMap()
     return operator_profile_infos_map;
 }
 
-void DAGContext::addOperatorProfileInfos(const String & executor_id, OperatorProfileInfos && profile_infos, bool is_append)
+void DAGContext::addOperatorProfileInfos(
+    const String & executor_id,
+    OperatorProfileInfos && profile_infos,
+    bool is_append)
 {
     if (profile_infos.empty())
         return;
@@ -206,7 +239,10 @@ void DAGContext::addOperatorProfileInfos(const String & executor_id, OperatorProfileInfos && profile_infos, bool is_append)
     }
 }
 
-void DAGContext::addInboundIOProfileInfos(const String & executor_id, IOProfileInfos && io_profile_infos, bool is_append)
+void DAGContext::addInboundIOProfileInfos(
+    const String & executor_id,
+    IOProfileInfos && io_profile_infos,
+    bool is_append)
 {
     if (io_profile_infos.empty())
         return;
@@ -312,7 +348,8 @@ void DAGContext::handleInvalidTime(const String & msg, const TiFlashError & error)
         throw TiFlashException(msg, error);
     }
     handleTruncateError(msg);
-    if (strictSqlMode(sql_mode) && (flags & TiDBSQLFlags::IN_INSERT_STMT || flags & TiDBSQLFlags::IN_UPDATE_OR_DELETE_STMT))
+    if (strictSqlMode(sql_mode)
+        && (flags & TiDBSQLFlags::IN_INSERT_STMT || flags & TiDBSQLFlags::IN_UPDATE_OR_DELETE_STMT))
     {
         throw TiFlashException(msg, error);
     }
diff --git a/dbms/src/Flash/Coprocessor/DAGContext.h b/dbms/src/Flash/Coprocessor/DAGContext.h
index 3f1b09c22c4..07e9c32562e 100644
--- a/dbms/src/Flash/Coprocessor/DAGContext.h
+++ b/dbms/src/Flash/Coprocessor/DAGContext.h
@@ -24,6 +24,8 @@
 #include
 #include
+#include <Core/QueryOperatorSpillContexts.h>
+#include <Core/TaskOperatorSpillContexts.h>
 #include
 #include
 #include
@@ -52,6 +54,8 @@ using MPPReceiverSetPtr = std::shared_ptr<MPPReceiverSet>;
 class CoprocessorReader;
 using CoprocessorReaderPtr = std::shared_ptr<CoprocessorReader>;
 
+class AutoSpillTrigger;
+
 struct JoinProfileInfo;
 using JoinProfileInfoPtr = std::shared_ptr<JoinProfileInfo>;
 struct JoinExecuteInfo
@@ -138,13 +142,24 @@ class DAGContext
 {
 public:
     // for non-mpp(cop/batchCop)
-    DAGContext(tipb::DAGRequest & dag_request_, TablesRegionsInfo && tables_regions_info_, KeyspaceID keyspace_id_, const String & tidb_host_, bool is_batch_cop_, LoggerPtr log_);
+    DAGContext(
+        tipb::DAGRequest & dag_request_,
+        TablesRegionsInfo && tables_regions_info_,
+        KeyspaceID keyspace_id_,
+        const String & tidb_host_,
+        bool is_batch_cop_,
+        LoggerPtr log_);
 
     // for mpp
     DAGContext(tipb::DAGRequest & dag_request_, const mpp::TaskMeta & meta_, bool is_root_mpp_task_);
 
     // for disaggregated task on write node
-    DAGContext(tipb::DAGRequest & dag_request_, const disaggregated::DisaggTaskMeta & task_meta_, TablesRegionsInfo && tables_regions_info_, const String & compute_node_host_, LoggerPtr log_);
+    DAGContext(
+        tipb::DAGRequest & dag_request_,
+        const disaggregated::DisaggTaskMeta & task_meta_,
+        TablesRegionsInfo && tables_regions_info_,
+        const String & compute_node_host_,
+        LoggerPtr log_);
 
     // for test
     explicit DAGContext(UInt64 max_error_count_);
@@ -152,11 +167,16 @@ class DAGContext
     // for tests that need to run query tasks.
     DAGContext(tipb::DAGRequest & dag_request_, String log_identifier, size_t concurrency);
 
+    ~DAGContext();
+
     std::unordered_map<String, BlockInputStreams> & getProfileStreamsMap();
 
     std::unordered_map<String, OperatorProfileInfos> & getOperatorProfileInfosMap();
-    void addOperatorProfileInfos(const String & executor_id, OperatorProfileInfos && profile_infos, bool is_append = false);
+    void addOperatorProfileInfos(
+        const String & executor_id,
+        OperatorProfileInfos && profile_infos,
+        bool is_append = false);
 
     std::unordered_map<String, std::vector<String>> & getExecutorIdToJoinIdMap();
@@ -166,7 +186,10 @@ class DAGContext
 
     std::unordered_map<String, IOProfileInfos> & getInboundIOProfileInfosMap();
 
-    void addInboundIOProfileInfos(const String & executor_id, IOProfileInfos && io_profile_infos, bool is_append = false);
+    void addInboundIOProfileInfos(
+        const String & executor_id,
+        IOProfileInfos && io_profile_infos,
+        bool is_append = false);
 
     void handleTruncateError(const String & msg);
     void handleOverflowError(const String & msg, const TiFlashError & error);
@@ -208,14 +231,8 @@ class DAGContext
     bool isMPPTask() const { return is_mpp_task; }
     /// root mpp task means the mpp task that sends data back to TiDB
     bool isRootMPPTask() const { return is_root_mpp_task; }
-    const MPPTaskId & getMPPTaskId() const
-    {
-        return mpp_task_id;
-    }
-    const std::unique_ptr<DM::DisaggTaskId> & getDisaggTaskId() const
-    {
-        return disaggregated_id;
-    }
+    const MPPTaskId & getMPPTaskId() const { return mpp_task_id; }
+    const std::unique_ptr<DM::DisaggTaskId> & getDisaggTaskId() const { return disaggregated_id; }
 
     std::pair<bool, double> getTableScanThroughput();
 
@@ -223,55 +240,22 @@ class DAGContext
 
     bool containsRegionsInfoForTable(Int64 table_id) const;
 
-    UInt64 getFlags() const
-    {
-        return flags;
-    }
-    void setFlags(UInt64 f)
-    {
-        flags = f;
-    }
-    void addFlag(UInt64 f)
-    {
-        flags |= f;
-    }
-    void delFlag(UInt64 f)
-    {
-        flags &= (~f);
-    }
-    bool hasFlag(UInt64 f) const
-    {
-        return (flags & f);
-    }
+    UInt64 getFlags() const { return flags; }
+    void setFlags(UInt64 f) { flags = f; }
+    void addFlag(UInt64 f) { flags |= f; }
+    void delFlag(UInt64 f) { flags &= (~f); }
+    bool hasFlag(UInt64 f) const { return (flags & f); }
 
-    UInt64 getSQLMode() const
-    {
-        return sql_mode;
-    }
-    void setSQLMode(UInt64 f)
-    {
-        sql_mode = f;
-    }
-    void addSQLMode(UInt64 f)
-    {
-        sql_mode |= f;
-    }
-    void delSQLMode(UInt64 f)
-    {
-        sql_mode &= (~f);
-    }
-    bool hasSQLMode(UInt64 f) const
-    {
-        return sql_mode & f;
-    }
+    UInt64 getSQLMode() const { return sql_mode; }
+    void setSQLMode(UInt64 f) { sql_mode = f; }
+    void addSQLMode(UInt64 f) { sql_mode |= f; }
+    void delSQLMode(UInt64 f) { sql_mode &= (~f); }
+    bool hasSQLMode(UInt64 f) const { return sql_mode & f; }
 
     void updateFinalConcurrency(size_t cur_streams_size, size_t streams_upper_limit);
 
     ExchangeReceiverPtr getMPPExchangeReceiver(const String & executor_id) const;
-    void setMPPReceiverSet(const MPPReceiverSetPtr & receiver_set)
-    {
-        mpp_receiver_set = receiver_set;
-    }
+    void setMPPReceiverSet(const MPPReceiverSetPtr & receiver_set) { mpp_receiver_set = receiver_set; }
     void addCoprocessorReader(const CoprocessorReaderPtr & coprocessor_reader);
     std::vector<CoprocessorReaderPtr> & getCoprocessorReaders();
     void setDisaggregatedComputeExchangeReceiver(const String & executor_id, const ExchangeReceiverPtr & receiver)
@@ -287,8 +271,21 @@ class DAGContext
     void addSubquery(const String & subquery_id, SubqueryForSet && subquery);
     bool hasSubquery() const { return !subqueries.empty(); }
     std::vector<SubqueriesForSets> && moveSubqueries() { return std::move(subqueries); }
-    void setProcessListEntry(std::shared_ptr<ProcessListEntry> entry) { process_list_entry = entry; }
+    void setProcessListEntry(const std::shared_ptr<ProcessListEntry> & entry) { process_list_entry = entry; }
     std::shared_ptr<ProcessListEntry> getProcessListEntry() const { return process_list_entry; }
+    void setQueryOperatorSpillContexts(
+        const std::shared_ptr<QueryOperatorSpillContexts> & query_operator_spill_contexts_)
+    {
+        query_operator_spill_contexts = query_operator_spill_contexts_;
+    }
+    std::shared_ptr<QueryOperatorSpillContexts> & getQueryOperatorSpillContexts()
+    {
+        return query_operator_spill_contexts;
+    }
+    void setAutoSpillTrigger(const std::shared_ptr<AutoSpillTrigger> & auto_spill_trigger_)
+    {
+        auto_spill_trigger = auto_spill_trigger_;
+    }
 
     void addTableLock(const TableLockHolder & lock) { table_locks.push_back(lock); }
 
@@ -308,6 +305,16 @@ class DAGContext
     }
     ExecutionMode getExecutionMode() const { return execution_mode; }
 
+    void registerOperatorSpillContext(const OperatorSpillContextPtr & operator_spill_context)
+    {
+        operator_spill_contexts->registerOperatorSpillContext(operator_spill_context);
+    }
+
+    void registerTaskOperatorSpillContexts()
+    {
+        query_operator_spill_contexts->registerTaskOperatorSpillContexts(operator_spill_contexts);
+    }
+
 public:
     DAGRequest dag_request;
     /// Some existing code inherited from Clickhouse assumes that each query must have a valid query string and query ast,
@@ -360,6 +367,9 @@ class DAGContext
 
 private:
     std::shared_ptr<ProcessListEntry> process_list_entry;
+    std::shared_ptr<TaskOperatorSpillContexts> operator_spill_contexts;
+    std::shared_ptr<QueryOperatorSpillContexts> query_operator_spill_contexts;
+    std::shared_ptr<AutoSpillTrigger> auto_spill_trigger;
     /// Holding the table lock to make sure that the table wouldn't be dropped during the lifetime of this query, even if there are no local regions.
     /// TableLockHolders need to be released after the BlockInputStream is destroyed to prevent data read exceptions.
     TableLockHolders table_locks;
diff --git a/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp b/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp
index 038beae9200..d4553739963 100644
--- a/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp
+++ b/dbms/src/Flash/Coprocessor/InterpreterUtils.cpp
@@ -43,16 +43,15 @@ using UnionWithBlock = UnionBlockInputStream<>;
 using UnionWithoutBlock = UnionBlockInputStream;
 } // namespace
 
-void restoreConcurrency(
-    DAGPipeline & pipeline,
-    size_t concurrency,
-    Int64 max_buffered_bytes,
-    const LoggerPtr & log)
+void restoreConcurrency(DAGPipeline & pipeline, size_t concurrency, Int64 max_buffered_bytes, const LoggerPtr & log)
 {
     if (concurrency > 1 && pipeline.streams.size() == 1)
     {
-        BlockInputStreamPtr shared_query_block_input_stream
-            = std::make_shared<SharedQueryBlockInputStream>(concurrency * 5, max_buffered_bytes, pipeline.firstStream(), log->identifier());
+        BlockInputStreamPtr shared_query_block_input_stream = std::make_shared<SharedQueryBlockInputStream>(
+            concurrency * 5,
+            max_buffered_bytes,
+            pipeline.firstStream(),
+            log->identifier());
         shared_query_block_input_stream->setExtraInfo("restore concurrency");
         pipeline.streams.assign(concurrency, shared_query_block_input_stream);
     }
@@ -70,9 +69,19 @@ void executeUnion(
 {
     BlockInputStreamPtr stream;
     if (ignore_block)
-        stream = std::make_shared<UnionWithoutBlock>(pipeline.streams, BlockInputStreams{}, max_streams, max_buffered_bytes, log->identifier());
+        stream = std::make_shared<UnionWithoutBlock>(
+            pipeline.streams,
+            BlockInputStreams{},
+            max_streams,
+            max_buffered_bytes,
+            log->identifier());
     else
-        stream = std::make_shared<UnionWithBlock>(pipeline.streams, BlockInputStreams{}, max_streams, max_buffered_bytes, log->identifier());
+        stream = std::make_shared<UnionWithBlock>(
+            pipeline.streams,
+            BlockInputStreams{},
+            max_streams,
+            max_buffered_bytes,
+            log->identifier());
 
     stream->setExtraInfo(extra_info);
     pipeline.streams.resize(1);
@@ -97,7 +106,8 @@ void restoreConcurrency(
         auto cur_header = group_builder.getCurrentHeader();
         group_builder.addGroup();
         for (size_t i = 0; i < concurrency; ++i)
-            group_builder.addConcurrency(std::make_unique<SharedQueueSourceOp>(exec_context, log->identifier(), cur_header, shared_queue));
+            group_builder.addConcurrency(
+                std::make_unique<SharedQueueSourceOp>(exec_context, log->identifier(), cur_header, shared_queue));
     }
 }
 
@@ -115,7 +125,8 @@ void executeUnion(
         });
         auto cur_header = group_builder.getCurrentHeader();
         group_builder.addGroup();
-        group_builder.addConcurrency(std::make_unique<SharedQueueSourceOp>(exec_context, log->identifier(), cur_header, shared_queue));
+        group_builder.addConcurrency(
+            std::make_unique<SharedQueueSourceOp>(exec_context, log->identifier(), cur_header, shared_queue));
     }
 }
 
@@ -155,7 +166,8 @@ void executeExpression(
     if (expr_actions && !expr_actions->getActions().empty())
     {
         group_builder.transform([&](auto & builder) {
-            builder.appendTransformOp(std::make_unique<ExpressionTransformOp>(exec_context, log->identifier(), expr_actions));
+            builder.appendTransformOp(
+                std::make_unique<ExpressionTransformOp>(exec_context, log->identifier(), expr_actions));
         });
     }
 }
@@ -188,8 +200,20 @@ void orderStreams(
                 settings.max_block_size,
                 limit,
                 getAverageThreshold(settings.max_bytes_before_external_sort, pipeline.streams.size()),
-                SpillConfig(context.getTemporaryPath(), fmt::format("{}_sort", log->identifier()), settings.max_cached_data_bytes_in_spiller, settings.max_spilled_rows_per_file, settings.max_spilled_bytes_per_file, context.getFileProvider()),
-                log->identifier());
+                SpillConfig(
+                    context.getTemporaryPath(),
+                    fmt::format("{}_sort", log->identifier()),
+                    settings.max_cached_data_bytes_in_spiller,
+                    settings.max_spilled_rows_per_file,
+                    settings.max_spilled_bytes_per_file,
+                    context.getFileProvider()),
+                log->identifier(),
+                [&](const OperatorSpillContextPtr & operator_spill_context) {
+                    if (context.getDAGContext() != nullptr)
+                    {
+                        context.getDAGContext()->registerOperatorSpillContext(operator_spill_context);
+                    }
+                });
             stream->setExtraInfo(String(enableFineGrainedShuffleExtraInfo));
         });
     }
@@ -206,8 +230,20 @@ void orderStreams(
                 limit,
                 settings.max_bytes_before_external_sort,
                 // todo use identifier_executor_id as the spill id
-                SpillConfig(context.getTemporaryPath(), fmt::format("{}_sort", log->identifier()), settings.max_cached_data_bytes_in_spiller, settings.max_spilled_rows_per_file, settings.max_spilled_bytes_per_file, context.getFileProvider()),
-                log->identifier());
+                SpillConfig(
+                    context.getTemporaryPath(),
+                    fmt::format("{}_sort", log->identifier()),
+                    settings.max_cached_data_bytes_in_spiller,
+                    settings.max_spilled_rows_per_file,
+                    settings.max_spilled_bytes_per_file,
+                    context.getFileProvider()),
+                log->identifier(),
+                [&](const OperatorSpillContextPtr & operator_spill_context) {
+                    if (context.getDAGContext() != nullptr)
+                    {
+                        context.getDAGContext()->registerOperatorSpillContext(operator_spill_context);
+                    }
+                });
     }
 }
 
@@ -227,7 +263,8 @@ void executeLocalSort(
     {
         group_builder.transform([&](auto & builder) {
             auto local_limit = std::make_shared<LocalLimitTransformAction>(input_header, *limit);
-            builder.appendTransformOp(std::make_unique<LimitTransformOp<LocalLimitPtr>>(exec_context, log->identifier(), local_limit));
+            builder.appendTransformOp(
+                std::make_unique<LimitTransformOp<LocalLimitPtr>>(exec_context, log->identifier(), local_limit));
         });
     }
     // For order by const without limit, do nothing here.
@@ -242,7 +279,8 @@ void executeLocalSort(
                 limit.value_or(0))); // 0 means that no limit in PartialSortTransformOp.
         });
 
         const Settings & settings = context.getSettingsRef();
-        size_t max_bytes_before_external_sort = getAverageThreshold(settings.max_bytes_before_external_sort, group_builder.concurrency());
+        size_t max_bytes_before_external_sort
+            = getAverageThreshold(settings.max_bytes_before_external_sort, group_builder.concurrency());
         SpillConfig spill_config{
             context.getTemporaryPath(),
             fmt::format("{}_sort", log->identifier()),
@@ -279,7 +317,8 @@ void executeFinalSort(
     {
         auto global_limit = std::make_shared<GlobalLimitTransformAction>(input_header, *limit);
         group_builder.transform([&](auto & builder) {
-            builder.appendTransformOp(std::make_unique<LimitTransformOp<GlobalLimitPtr>>(exec_context, log->identifier(), global_limit));
+            builder.appendTransformOp(
+                std::make_unique<LimitTransformOp<GlobalLimitPtr>>(exec_context, log->identifier(), global_limit));
         });
     }
     // For order by const without limit, do nothing here.
@@ -318,23 +357,43 @@ void executeFinalSort(
     }
 }
 
-void executeCreatingSets(
-    DAGPipeline & pipeline,
-    const Context & context,
-    size_t max_streams,
-    const LoggerPtr & log)
+void executeCreatingSets(DAGPipeline & pipeline, const Context & context, size_t max_streams, const LoggerPtr & log)
 {
     DAGContext & dag_context = *context.getDAGContext();
     /// add union to run in parallel if needed
     if (unlikely(context.isExecutorTest() || context.isInterpreterTest()))
-        executeUnion(pipeline, max_streams, context.getSettingsRef().max_buffered_bytes_in_executor, log, /*ignore_block=*/false, "for test");
+        executeUnion(
+            pipeline,
+            max_streams,
+            context.getSettingsRef().max_buffered_bytes_in_executor,
+            log,
+            /*ignore_block=*/false,
+            "for test");
     else if (context.isMPPTest())
-        executeUnion(pipeline, max_streams, context.getSettingsRef().max_buffered_bytes_in_executor, log, /*ignore_block=*/true, "for mpp test");
+        executeUnion(
+            pipeline,
+            max_streams,
+            context.getSettingsRef().max_buffered_bytes_in_executor,
+            log,
+            /*ignore_block=*/true,
+            "for mpp test");
     else if (dag_context.isMPPTask())
         /// MPPTask does not need the returned blocks.
-        executeUnion(pipeline, max_streams, context.getSettingsRef().max_buffered_bytes_in_executor, log, /*ignore_block=*/true, "for mpp");
+        executeUnion(
+            pipeline,
+            max_streams,
+            context.getSettingsRef().max_buffered_bytes_in_executor,
+            log,
+            /*ignore_block=*/true,
+            "for mpp");
     else
-        executeUnion(pipeline, max_streams, context.getSettingsRef().max_buffered_bytes_in_executor, log, /*ignore_block=*/false, "for non mpp");
+        executeUnion(
+            pipeline,
+            max_streams,
+            context.getSettingsRef().max_buffered_bytes_in_executor,
+            log,
+            /*ignore_block=*/false,
+            "for non mpp");
     if (dag_context.hasSubquery())
     {
         const Settings & settings = context.getSettingsRef();
@@ -377,7 +436,8 @@ void executePushedDownFilter(
     LoggerPtr log,
     DAGPipeline & pipeline)
 {
-    auto [before_where, filter_column_name, project_after_where] = ::DB::buildPushDownFilter(filter_conditions.conditions, analyzer);
+    auto [before_where, filter_column_name, project_after_where]
+        = ::DB::buildPushDownFilter(filter_conditions.conditions, analyzer);
 
     for (auto & stream : pipeline.streams)
     {
@@ -397,15 +457,22 @@ void executePushedDownFilter(
     DAGExpressionAnalyzer & analyzer,
     LoggerPtr log)
 {
-    auto [before_where, filter_column_name, project_after_where] = ::DB::buildPushDownFilter(filter_conditions.conditions, analyzer);
+    auto [before_where, filter_column_name, project_after_where]
+        = ::DB::buildPushDownFilter(filter_conditions.conditions, analyzer);
 
     auto input_header = group_builder.getCurrentHeader();
     for (size_t i = 0; i < group_builder.concurrency(); ++i)
     {
         auto & builder = group_builder.getCurBuilder(i);
-        builder.appendTransformOp(std::make_unique<FilterTransformOp>(exec_context, log->identifier(), input_header, before_where, filter_column_name));
+        builder.appendTransformOp(std::make_unique<FilterTransformOp>(
+            exec_context,
+            log->identifier(),
+            input_header,
+            before_where,
+            filter_column_name));
         // after filter, do project action to keep the schema of local transforms and remote transforms the same.
-        builder.appendTransformOp(std::make_unique<ExpressionTransformOp>(exec_context, log->identifier(), project_after_where));
+        builder.appendTransformOp(
+            std::make_unique<ExpressionTransformOp>(exec_context, log->identifier(), project_after_where));
     }
 }
 
@@ -417,7 +484,10 @@ void executeGeneratedColumnPlaceholder(
     if (generated_column_infos.empty())
         return;
     pipeline.transform([&](auto & stream) {
-        stream = std::make_shared<GeneratedColumnPlaceholderBlockInputStream>(stream, generated_column_infos, log->identifier());
+        stream = std::make_shared<GeneratedColumnPlaceholderBlockInputStream>(
+            stream,
+            generated_column_infos,
+            log->identifier());
         stream->setExtraInfo("generated column placeholder above table scan");
     });
 }
@@ -433,7 +503,11 @@ void executeGeneratedColumnPlaceholder(
 
     auto input_header = group_builder.getCurrentHeader();
     group_builder.transform([&](auto & builder) {
-        builder.appendTransformOp(std::make_unique<GeneratedColumnPlaceHolderTransformOp>(exec_context, log->identifier(), input_header, generated_column_infos));
+        builder.appendTransformOp(std::make_unique<GeneratedColumnPlaceHolderTransformOp>(
+            exec_context,
+            log->identifier(),
+            input_header,
+            generated_column_infos));
     });
 }
diff --git a/dbms/src/Flash/Executor/PipelineExecutor.cpp b/dbms/src/Flash/Executor/PipelineExecutor.cpp
index 4dfcfa1d339..7a35c0434b7 100644
--- a/dbms/src/Flash/Executor/PipelineExecutor.cpp
+++ b/dbms/src/Flash/Executor/PipelineExecutor.cpp
@@ -23,6 +23,8 @@ namespace DB
 {
 PipelineExecutor::PipelineExecutor(
     const MemoryTrackerPtr & memory_tracker_,
+    AutoSpillTrigger * auto_spill_trigger,
+    const RegisterOperatorSpillContext & register_operator_spill_context,
     Context & context_,
     const String & req_id)
     : QueryExecutor(memory_tracker_, context_, req_id)
@@ -31,7 +33,9 @@ PipelineExecutor::PipelineExecutor(
           // But for cop/batchCop, there is no such unique identifier, so an empty value is given here, indicating that the query id of PipelineExecutor is invalid.
           /*query_id=*/context.getDAGContext()->is_mpp_task ? context.getDAGContext()->getMPPTaskId().toString() : "",
           req_id,
-          memory_tracker_)
+          memory_tracker_,
+          auto_spill_trigger,
+          register_operator_spill_context)
 {
     PhysicalPlan physical_plan{context, log->identifier()};
     physical_plan.build(context.getDAGContext()->dag_request());
diff --git a/dbms/src/Flash/Executor/PipelineExecutor.h b/dbms/src/Flash/Executor/PipelineExecutor.h
index 6cb0d11f7af..3665a96e1ac 100644
--- a/dbms/src/Flash/Executor/PipelineExecutor.h
+++ b/dbms/src/Flash/Executor/PipelineExecutor.h
@@ -26,6 +26,8 @@ class Pipeline;
 using PipelinePtr = std::shared_ptr<Pipeline>;
 using Pipelines = std::vector<PipelinePtr>;
 
+class AutoSpillTrigger;
+
 /**
  * PipelineExecutor is the implementation of the pipeline-based execution model.
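+ * With auto spill enabled, the executor is constructed with the query level AutoSpillTrigger
+ * and a RegisterOperatorSpillContext callback, so spill-capable operators built for this
+ * pipeline can register themselves and later be asked to revoke memory.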
  *
@@ -54,6 +56,8 @@ class PipelineExecutor : public QueryExecutor
 public:
     PipelineExecutor(
         const MemoryTrackerPtr & memory_tracker_,
+        AutoSpillTrigger * auto_spill_trigger,
+        const RegisterOperatorSpillContext & register_operator_spill_context,
         Context & context_,
         const String & req_id);
diff --git a/dbms/src/Flash/Executor/PipelineExecutorContext.h b/dbms/src/Flash/Executor/PipelineExecutorContext.h
index 26dec35af63..ac9e0c412ac 100644
--- a/dbms/src/Flash/Executor/PipelineExecutorContext.h
+++ b/dbms/src/Flash/Executor/PipelineExecutorContext.h
@@ -16,6 +16,7 @@
 
 #include
 #include
+#include <Core/AutoSpillTrigger.h>
 #include
 #include
 #include
@@ -27,6 +28,8 @@
 
 namespace DB
 {
+class OperatorSpillContext;
+using RegisterOperatorSpillContext = std::function<void(const std::shared_ptr<OperatorSpillContext> & ptr)>;
 class PipelineExecutorContext : private boost::noncopyable
 {
 public:
@@ -38,10 +41,17 @@ class PipelineExecutorContext : private boost::noncopyable
         , mem_tracker(nullptr)
     {}
 
-    PipelineExecutorContext(const String & query_id_, const String & req_id, const MemoryTrackerPtr & mem_tracker_)
+    PipelineExecutorContext(
+        const String & query_id_,
+        const String & req_id,
+        const MemoryTrackerPtr & mem_tracker_,
+        AutoSpillTrigger * auto_spill_trigger_ = nullptr,
+        const RegisterOperatorSpillContext & register_operator_spill_context_ = nullptr)
         : query_id(query_id_)
         , log(Logger::get(req_id))
         , mem_tracker(mem_tracker_)
+        , auto_spill_trigger(auto_spill_trigger_)
+        , register_operator_spill_context(register_operator_spill_context_)
     {}
 
     ExecutionResult toExecutionResult();
@@ -128,31 +138,28 @@ class PipelineExecutorContext : private boost::noncopyable
 
     void cancel();
 
-    ALWAYS_INLINE bool isCancelled()
-    {
-        return is_cancelled.load(std::memory_order_acquire);
-    }
+    ALWAYS_INLINE bool isCancelled() { return is_cancelled.load(std::memory_order_acquire); }
 
     ResultQueuePtr toConsumeMode(size_t queue_size);
 
-    void update(const TaskProfileInfo & task_profile_info)
-    {
-        query_profile_info.merge(task_profile_info);
-    }
+    void update(const TaskProfileInfo & task_profile_info) { query_profile_info.merge(task_profile_info); }
 
-    const QueryProfileInfo & getQueryProfileInfo() const
-    {
-        return query_profile_info;
-    }
+    const QueryProfileInfo & getQueryProfileInfo() const { return query_profile_info; }
+
+    const String & getQueryId() const { return query_id; }
+
+    const MemoryTrackerPtr & getMemoryTracker() const { return mem_tracker; }
 
-    const String & getQueryId() const
+    void triggerAutoSpill() const
     {
-        return query_id;
+        if (auto_spill_trigger != nullptr)
+            auto_spill_trigger->triggerAutoSpill();
     }
 
-    const MemoryTrackerPtr & getMemoryTracker() const
+    void registerOperatorSpillContext(const std::shared_ptr<OperatorSpillContext> & operator_spill_context)
     {
-        return mem_tracker;
+        if (register_operator_spill_context != nullptr)
+            register_operator_spill_context(operator_spill_context);
     }
 
 private:
@@ -183,5 +190,9 @@ class PipelineExecutorContext : private boost::noncopyable
     std::optional<ResultQueuePtr> result_queue;
 
     QueryProfileInfo query_profile_info;
+
+    AutoSpillTrigger * auto_spill_trigger;
+
+    RegisterOperatorSpillContext register_operator_spill_context;
 };
 } // namespace DB
diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp
index ea00dd1b8ea..a479f93cf89 100644
--- a/dbms/src/Flash/Mpp/MPPTask.cpp
+++ b/dbms/src/Flash/Mpp/MPPTask.cpp
@@ -350,6 +350,13 @@ void MPPTask::unregisterTask()
         LOG_WARNING(log, "task failed to unregister, reason: {}", reason);
 }
 
+void MPPTask::initQueryOperatorSpillContexts(
+    const std::shared_ptr<QueryOperatorSpillContexts> & mpp_query_operator_spill_contexts)
+{
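+    /// every MPPTask of the same MPP query is handed the same QueryOperatorSpillContexts
+    /// instance by MPPTaskManager::registerTask, so auto spill decisions are made across
+    /// the whole query rather than per task.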
+    assert(mpp_query_operator_spill_contexts != nullptr);
+    dag_context->setQueryOperatorSpillContexts(mpp_query_operator_spill_contexts);
+}
+
 void MPPTask::initProcessListEntry(const std::shared_ptr<ProcessListEntry> & query_process_list_entry)
 {
     /// all the mpp tasks of the same mpp query share the same process list entry
@@ -524,6 +531,7 @@ void MPPTask::runImpl()
         throw Exception("task not in running state, may be cancelled");
     }
     mpp_task_statistics.start();
+    dag_context->registerTaskOperatorSpillContexts();
 
 #ifndef NDEBUG
     if (isRootMPPTask())
diff --git a/dbms/src/Flash/Mpp/MPPTask.h b/dbms/src/Flash/Mpp/MPPTask.h
index 15cbdcbfb39..b2860538a01 100644
--- a/dbms/src/Flash/Mpp/MPPTask.h
+++ b/dbms/src/Flash/Mpp/MPPTask.h
@@ -40,6 +40,7 @@ class MPPTaskManager;
 using MPPTaskManagerPtr = std::shared_ptr<MPPTaskManager>;
 class DAGContext;
 class ProcessListEntry;
+class QueryOperatorSpillContexts;
 
 enum class AbortType
 {
@@ -127,6 +128,9 @@ class MPPTask
 
     void initProcessListEntry(const std::shared_ptr<ProcessListEntry> & query_process_list_entry);
 
+    void initQueryOperatorSpillContexts(
+        const std::shared_ptr<QueryOperatorSpillContexts> & mpp_query_operator_spill_contexts);
+
     void initExchangeReceivers();
 
     String getErrString() const;
diff --git a/dbms/src/Flash/Mpp/MPPTaskManager.cpp b/dbms/src/Flash/Mpp/MPPTaskManager.cpp
index 448231ccfb2..2dd83656b5e 100644
--- a/dbms/src/Flash/Mpp/MPPTaskManager.cpp
+++ b/dbms/src/Flash/Mpp/MPPTaskManager.cpp
@@ -78,7 +78,7 @@ MPPTaskManager::~MPPTaskManager()
 
 MPPQueryPtr MPPTaskManager::addMPPQuery(const MPPQueryId & query_id, bool has_meaningful_gather_id)
 {
-    auto ptr = std::make_shared<MPPQuery>(has_meaningful_gather_id);
+    auto ptr = std::make_shared<MPPQuery>(query_id, has_meaningful_gather_id);
     mpp_query_map.insert({query_id, ptr});
     GET_METRIC(tiflash_mpp_task_manager, type_mpp_query_count).Set(mpp_query_map.size());
     return ptr;
@@ -334,6 +334,7 @@ std::pair<bool, String> MPPTaskManager::registerTask(MPPTask * task)
     }
     gather_task_set->registerTask(task->id);
     task->initProcessListEntry(query->process_list_entry);
+    task->initQueryOperatorSpillContexts(query->mpp_query_operator_spill_contexts);
     return {true, ""};
 }
diff --git a/dbms/src/Flash/Mpp/MPPTaskManager.h b/dbms/src/Flash/Mpp/MPPTaskManager.h
index 38d6acc3b63..19b2c0a75f4 100644
--- a/dbms/src/Flash/Mpp/MPPTaskManager.h
+++ b/dbms/src/Flash/Mpp/MPPTaskManager.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <Core/QueryOperatorSpillContexts.h>
 #include
 #include
 #include
@@ -73,14 +74,16 @@ using MPPGatherTaskSetPtr = std::shared_ptr<MPPGatherTaskSet>;
 
 struct MPPQuery
 {
-    explicit MPPQuery(bool has_meaningful_gather_id_)
-        : has_meaningful_gather_id(has_meaningful_gather_id_)
+    MPPQuery(const MPPQueryId & mpp_query_id, bool has_meaningful_gather_id_)
+        : mpp_query_operator_spill_contexts(std::make_shared<QueryOperatorSpillContexts>(mpp_query_id))
+        , has_meaningful_gather_id(has_meaningful_gather_id_)
     {}
     MPPGatherTaskSetPtr addMPPGatherTaskSet(const MPPGatherId & gather_id);
     ~MPPQuery();
 
     std::shared_ptr<ProcessListEntry> process_list_entry;
     std::unordered_map<MPPGatherId, MPPGatherTaskSetPtr, MPPGatherIdHash> mpp_gathers;
+    std::shared_ptr<QueryOperatorSpillContexts> mpp_query_operator_spill_contexts;
     bool has_meaningful_gather_id;
 };
 using MPPQueryPtr = std::shared_ptr<MPPQuery>;
diff --git a/dbms/src/Flash/executeQuery.cpp b/dbms/src/Flash/executeQuery.cpp
index 1556445492b..615352ecb10 100644
--- a/dbms/src/Flash/executeQuery.cpp
+++ b/dbms/src/Flash/executeQuery.cpp
@@ -15,7 +15,9 @@
 #include
 #include
 #include
+#include <Core/AutoSpillTrigger.h>
 #include
+#include <Core/QueryOperatorSpillContexts.h>
 #include
 #include
 #include
@@ -96,6 +98,7 @@ QueryExecutorPtr doExecuteAsBlockIO(IQuerySource & dag, Context & context, bool internal)
     FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_interpreter_failpoint);
     auto interpreter = dag.interpreter(context, QueryProcessingStage::Complete);
     BlockIO res = interpreter->execute();
+    /// query level memory tracker
     MemoryTrackerPtr memory_tracker;
     if (likely(process_list_entry))
     {
@@ -106,6 +109,24 @@ QueryExecutorPtr doExecuteAsBlockIO(IQuerySource & dag, Context & context, bool internal)
     /// Hold element of process list till end of query execution.
     res.process_list_entry = process_list_entry;
 
+    auto auto_spill_trigger_threshold = context.getSettingsRef().auto_memory_revoke_trigger_threshold.get();
+    auto auto_spill_target_threshold = context.getSettingsRef().auto_memory_revoke_target_threshold.get();
+    /// if the query level memory tracker has a limit, then set up the auto spill trigger
+    if likely (memory_tracker != nullptr)
+    {
+        if (memory_tracker->getLimit() != 0 && auto_spill_trigger_threshold > 0)
+        {
+            auto auto_spill_trigger = std::make_shared<AutoSpillTrigger>(
+                memory_tracker,
+                dag_context.getQueryOperatorSpillContexts(),
+                auto_spill_trigger_threshold,
+                auto_spill_target_threshold);
+            dag_context.setAutoSpillTrigger(auto_spill_trigger);
+            auto * stream = dynamic_cast<IProfilingBlockInputStream *>(res.in.get());
+            RUNTIME_ASSERT(stream != nullptr);
+            stream->setAutoSpillTrigger(auto_spill_trigger.get());
+        }
+    }
 
     if (likely(!internal))
         logQueryPipeline(logger, res.in);
@@ -151,7 +172,33 @@ std::optional<QueryExecutorPtr> executeAsPipeline(Context & context, bool internal)
         memory_tracker = (*process_list_entry)->getMemoryTrackerPtr();
     FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_interpreter_failpoint);
 
-    auto executor = std::make_unique<PipelineExecutor>(memory_tracker, context, logger->identifier());
+    std::unique_ptr<PipelineExecutor> executor;
+    /// if the query level memory tracker has a limit, then set up the auto spill trigger
+    if (memory_tracker != nullptr && memory_tracker->getLimit() != 0
+        && context.getSettingsRef().auto_memory_revoke_trigger_threshold.get() > 0)
+    {
+        auto register_operator_spill_context = [&context](const OperatorSpillContextPtr & operator_spill_context) {
+            context.getDAGContext()->registerOperatorSpillContext(operator_spill_context);
+        };
+        auto auto_spill_trigger_threshold = context.getSettingsRef().auto_memory_revoke_trigger_threshold.get();
+        auto auto_spill_target_threshold = context.getSettingsRef().auto_memory_revoke_target_threshold.get();
+        auto auto_spill_trigger = std::make_shared<AutoSpillTrigger>(
+            memory_tracker,
+            dag_context.getQueryOperatorSpillContexts(),
+            auto_spill_trigger_threshold,
+            auto_spill_target_threshold);
+        dag_context.setAutoSpillTrigger(auto_spill_trigger);
+        executor = std::make_unique<PipelineExecutor>(
+            memory_tracker,
+            auto_spill_trigger.get(),
+            register_operator_spill_context,
+            context,
+            logger->identifier());
+    }
+    else
+    {
+        executor = std::make_unique<PipelineExecutor>(memory_tracker, nullptr, nullptr, context, logger->identifier());
+    }
 
     if (likely(!internal))
         LOG_INFO(logger, fmt::format("Query pipeline:\n{}", executor->toString()));
     dag_context.switchToPipelineMode();
diff --git a/dbms/src/Interpreters/AggSpillContext.cpp b/dbms/src/Interpreters/AggSpillContext.cpp
index 6615ce7102e..26594cdf842 100644
--- a/dbms/src/Interpreters/AggSpillContext.cpp
+++ b/dbms/src/Interpreters/AggSpillContext.cpp
@@ -17,7 +17,11 @@
 
 namespace DB
 {
-AggSpillContext::AggSpillContext(size_t concurrency, const SpillConfig & spill_config_, UInt64 operator_spill_threshold_, const LoggerPtr & log)
+AggSpillContext::AggSpillContext(
+    size_t concurrency,
+    const SpillConfig & spill_config_,
+    UInt64 operator_spill_threshold_,
+    const LoggerPtr & log)
     : OperatorSpillContext(operator_spill_threshold_, "aggregator", log)
     , per_thread_revocable_memories(concurrency)
     , spill_config(spill_config_)
@@ -54,4 +58,9 @@ Int64 AggSpillContext::getTotalRevocableMemoryImpl()
     ret += x;
     return ret;
 }
+
+Int64 AggSpillContext::triggerSpill(Int64)
+{
+    throw Exception("Not supported yet");
+}
 } // namespace DB
diff --git a/dbms/src/Interpreters/AggSpillContext.h b/dbms/src/Interpreters/AggSpillContext.h
index f4db01844c4..cc7313db69d 100644
--- a/dbms/src/Interpreters/AggSpillContext.h
+++ b/dbms/src/Interpreters/AggSpillContext.h
@@ -29,12 +29,17 @@ class AggSpillContext final : public OperatorSpillContext
     UInt64 per_thread_spill_threshold;
 
 public:
-    AggSpillContext(size_t concurrency, const SpillConfig & spill_config_, UInt64 operator_spill_threshold_, const LoggerPtr & log);
+    AggSpillContext(
+        size_t concurrency,
+        const SpillConfig & spill_config_,
+        UInt64 operator_spill_threshold_,
+        const LoggerPtr & log);
     void buildSpiller(const Block & input_schema);
     SpillerPtr & getSpiller() { return spiller; }
-    bool hasSpilledData() const { return spill_status != SpillStatus::NOT_SPILL && spiller->hasSpilledData(); }
+    bool hasSpilledData() const { return isSpilled() && spiller->hasSpilledData(); }
     bool updatePerThreadRevocableMemory(Int64 new_value, size_t thread_num);
     Int64 getTotalRevocableMemoryImpl() override;
+    Int64 triggerSpill(Int64 expected_released_memories) override;
 };
 
 using AggSpillContextPtr = std::shared_ptr<AggSpillContext>;
diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp
index 98e2d40236a..56b30776507 100644
--- a/dbms/src/Interpreters/Aggregator.cpp
+++ b/dbms/src/Interpreters/Aggregator.cpp
@@ -78,7 +78,7 @@ bool AggregatedDataVariants::tryMarkNeedSpill()
         convertToTwoLevel();
     }
     need_spill = true;
-    aggregator->agg_spill_context->markSpill();
+    aggregator->agg_spill_context->markSpilled();
     return true;
 }
 
@@ -158,10 +158,11 @@ void AggregatedDataVariants::convertToTwoLevel()
     case AggregationMethodType(NAME):                                       \
     {                                                                       \
         if (aggregator)                                                     \
-            LOG_TRACE(aggregator->log,                                      \
-                      "Converting aggregation data type `{}` to `{}`.",     \
-                      getMethodName(AggregationMethodType(NAME)),           \
-                      getMethodName(AggregationMethodTypeTwoLevel(NAME)));  \
+            LOG_TRACE(                                                      \
+                aggregator->log,                                            \
+                "Converting aggregation data type `{}` to `{}`.",           \
+                getMethodName(AggregationMethodType(NAME)),                 \
+                getMethodName(AggregationMethodTypeTwoLevel(NAME)));        \
         auto ori_ptr = ToAggregationMethodPtr(NAME, aggregation_method_impl); \
         auto two_level = std::make_unique(*ori_ptr);                        \
         delete ori_ptr;                                                     \
@@ -226,7 +227,10 @@ Block Aggregator::Params::getHeader(
         if (final)
             type = aggregate.function->getReturnType();
         else
-            type = std::make_shared<DataTypeAggregateFunction>(aggregate.function, argument_types, aggregate.parameters);
+            type = std::make_shared<DataTypeAggregateFunction>(
+                aggregate.function,
+                argument_types,
+                aggregate.parameters);
 
         res.insert({type, aggregate.column_name});
     }
@@ -272,7 +276,8 @@ Aggregator::Aggregator(const Params & params_, const String & req_id, size_t concurrency)
 
             /// Extend total_size to next alignment requirement
             /// Add padding by rounding up 'total_size_of_aggregate_states' to be a multiplier of alignment_of_next_state.
- total_size_of_aggregate_states = (total_size_of_aggregate_states + alignment_of_next_state - 1) / alignment_of_next_state * alignment_of_next_state; + total_size_of_aggregate_states = (total_size_of_aggregate_states + alignment_of_next_state - 1) + / alignment_of_next_state * alignment_of_next_state; } if (!params.aggregates[i].function->hasTrivialDestructor()) @@ -281,7 +286,11 @@ Aggregator::Aggregator(const Params & params_, const String & req_id, size_t con method_chosen = chooseAggregationMethod(); RUNTIME_CHECK_MSG(method_chosen != AggregatedDataVariants::Type::EMPTY, "Invalid aggregation method"); - agg_spill_context = std::make_shared(concurrency, params.spill_config, params.getMaxBytesBeforeExternalGroupBy(), log); + agg_spill_context = std::make_shared( + concurrency, + params.spill_config, + params.getMaxBytesBeforeExternalGroupBy(), + log); if (agg_spill_context->isSpillEnabled()) { /// init spiller if needed @@ -297,7 +306,9 @@ Aggregator::Aggregator(const Params & params_, const String & req_id, size_t con { params.setMaxBytesBeforeExternalGroupBy(0); agg_spill_context->disableSpill(); - LOG_WARNING(log, "Aggregation does not support spill because aggregator hash table does not support two level"); + LOG_WARNING( + log, + "Aggregation does not support spill because aggregator hash table does not support two level"); } } } @@ -366,7 +377,10 @@ AggregatedDataVariants::Type ChooseAggregationMethodTwoKeys(const AggFastPathTyp } // return AggregatedDataVariants::Type::serialized if can NOT determine fast path. -AggregatedDataVariants::Type ChooseAggregationMethodFastPath(size_t keys_size, const DataTypes & types_not_null, const TiDB::TiDBCollators & collators) +AggregatedDataVariants::Type ChooseAggregationMethodFastPath( + size_t keys_size, + const DataTypes & types_not_null, + const TiDB::TiDBCollators & collators) { std::array fast_path_types{}; @@ -434,7 +448,8 @@ AggregatedDataVariants::Type Aggregator::chooseAggregationMethod() for (const auto & pos : params.keys) { - const auto & type = (params.src_header ? params.src_header : params.intermediate_header).safeGetByPosition(pos).type; + const auto & type + = (params.src_header ? params.src_header : params.intermediate_header).safeGetByPosition(pos).type; if (type->isNullable()) { @@ -457,7 +472,8 @@ AggregatedDataVariants::Type Aggregator::chooseAggregationMethod() { if (types_removed_nullable[j]->isValueUnambiguouslyRepresentedInContiguousMemoryRegion()) { - if (types_removed_nullable[j]->isValueUnambiguouslyRepresentedInFixedSizeContiguousMemoryRegion() && (params.collators.empty() || params.collators[j] == nullptr)) + if (types_removed_nullable[j]->isValueUnambiguouslyRepresentedInFixedSizeContiguousMemoryRegion() + && (params.collators.empty() || params.collators[j] == nullptr)) { ++num_fixed_contiguous_keys; key_sizes[j] = types_removed_nullable[j]->getSizeOfValueInMemory(); @@ -504,7 +520,9 @@ AggregatedDataVariants::Type Aggregator::chooseAggregationMethod() return AggregatedDataVariants::Type::keys256; if (size_of_field == sizeof(Decimal256)) return AggregatedDataVariants::Type::key_int256; - throw Exception("Logical error: numeric column has sizeOfField not in 1, 2, 4, 8, 16, 32.", ErrorCodes::LOGICAL_ERROR); + throw Exception( + "Logical error: numeric column has sizeOfField not in 1, 2, 4, 8, 16, 32.", + ErrorCodes::LOGICAL_ERROR); } /// If all keys fits in N bits, will use hash table with all keys packed (placed contiguously) to single N-bit key. 
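The padding computation reflowed in the Aggregator constructor hunk above only changes line breaks: it still rounds total_size_of_aggregate_states up to the next multiple of the following state's alignment with the usual (size + align - 1) / align * align idiom. A minimal standalone sketch of that arithmetic, using illustrative names that are not the Aggregator's actual members:

#include <cassert>
#include <cstddef>

/// Round `size` up to the next multiple of `alignment` (alignment must be non-zero).
static size_t alignUp(size_t size, size_t alignment)
{
    return (size + alignment - 1) / alignment * alignment;
}

int main()
{
    assert(alignUp(13, 8) == 16); // 13 bytes of state, next state aligned to 8: pad to 16
    assert(alignUp(16, 8) == 16); // already aligned, no padding added
    assert(alignUp(1, 4) == 4);
    return 0;
}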
@@ -634,7 +652,8 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( reinterpret_cast(method.data.data()), inst->state_offset, [&](AggregateDataPtr & aggregate_data) { - aggregate_data = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); + aggregate_data + = aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); createAggregateStates(aggregate_data); }, state.getKeyData(), @@ -675,7 +694,13 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst) { if (inst->offsets) - inst->batch_that->addBatchArray(rows, places.get(), inst->state_offset, inst->batch_arguments, inst->offsets, aggregates_pool); + inst->batch_that->addBatchArray( + rows, + places.get(), + inst->state_offset, + inst->batch_arguments, + inst->offsets, + aggregates_pool); else inst->batch_that->addBatch(rows, places.get(), inst->state_offset, inst->batch_arguments, aggregates_pool); } @@ -702,7 +727,11 @@ void NO_INLINE Aggregator::executeWithoutKeyImpl( } -void Aggregator::prepareAggregateInstructions(Columns columns, AggregateColumns & aggregate_columns, Columns & materialized_columns, AggregateFunctionInstructions & aggregate_functions_instructions) +void Aggregator::prepareAggregateInstructions( + Columns columns, + AggregateColumns & aggregate_columns, + Columns & materialized_columns, + AggregateFunctionInstructions & aggregate_functions_instructions) { for (size_t i = 0; i < params.aggregates_size; ++i) aggregate_columns[i].resize(params.aggregates[i].arguments.size()); @@ -796,7 +825,8 @@ bool Aggregator::executeOnBlock( if (result.type == AggregatedDataVariants::Type::without_key && !result.without_key) { - AggregateDataPtr place = result.aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); + AggregateDataPtr place + = result.aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); createAggregateStates(place); result.without_key = place; } @@ -806,15 +836,25 @@ bool Aggregator::executeOnBlock( /// For the case when there are no keys (all aggregate into one row). if (result.type == AggregatedDataVariants::Type::without_key) { - executeWithoutKeyImpl(result.without_key, num_rows, aggregate_functions_instructions.data(), result.aggregates_pool); + executeWithoutKeyImpl( + result.without_key, + num_rows, + aggregate_functions_instructions.data(), + result.aggregates_pool); } else { -#define M(NAME, IS_TWO_LEVEL) \ - case AggregationMethodType(NAME): \ - { \ - executeImpl(*ToAggregationMethodPtr(NAME, result.aggregation_method_impl), result.aggregates_pool, num_rows, key_columns, params.collators, aggregate_functions_instructions.data()); \ - break; \ +#define M(NAME, IS_TWO_LEVEL) \ + case AggregationMethodType(NAME): \ + { \ + executeImpl( \ + *ToAggregationMethodPtr(NAME, result.aggregation_method_impl), \ + result.aggregates_pool, \ + num_rows, \ + key_columns, \ + params.collators, \ + aggregate_functions_instructions.data()); \ + break; \ } switch (result.type) @@ -833,8 +873,8 @@ bool Aggregator::executeOnBlock( /// worth_convert_to_two_level is set to true if /// 1. some other threads already convert to two level /// 2. 
the result size exceeds threshold - bool worth_convert_to_two_level - = use_two_level_hash_table || (group_by_two_level_threshold && result_size >= group_by_two_level_threshold) + bool worth_convert_to_two_level = use_two_level_hash_table + || (group_by_two_level_threshold && result_size >= group_by_two_level_threshold) || (group_by_two_level_threshold_bytes && result_size_bytes >= group_by_two_level_threshold_bytes); /** Converting to a two-level data structure. @@ -869,19 +909,19 @@ BlockInputStreams Aggregator::restoreSpilledData() void Aggregator::initThresholdByAggregatedDataVariantsSize(size_t aggregated_data_variants_size) { group_by_two_level_threshold = params.getGroupByTwoLevelThreshold(); - group_by_two_level_threshold_bytes = getAverageThreshold(params.getGroupByTwoLevelThresholdBytes(), aggregated_data_variants_size); + group_by_two_level_threshold_bytes + = getAverageThreshold(params.getGroupByTwoLevelThresholdBytes(), aggregated_data_variants_size); } void Aggregator::spill(AggregatedDataVariants & data_variants) { assert(data_variants.need_spill); /// Flush only two-level data and possibly overflow data. -#define M(NAME) \ - case AggregationMethodType(NAME): \ - { \ - spillImpl(data_variants, \ - *ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl)); \ - break; \ +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + spillImpl(data_variants, *ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl)); \ + break; \ } switch (data_variants.type) @@ -909,9 +949,24 @@ Block Aggregator::convertOneBucketToBlock( bool final, size_t bucket) const { - Block block = prepareBlockAndFill(data_variants, final, method.data.impls[bucket].size(), [bucket, &method, arena, this](MutableColumns & key_columns, AggregateColumnsData & aggregate_columns, MutableColumns & final_aggregate_columns, bool final_) { - convertToBlockImpl(method, method.data.impls[bucket], key_columns, aggregate_columns, final_aggregate_columns, arena, final_); - }); + Block block = prepareBlockAndFill( + data_variants, + final, + method.data.impls[bucket].size(), + [bucket, &method, arena, this]( + MutableColumns & key_columns, + AggregateColumnsData & aggregate_columns, + MutableColumns & final_aggregate_columns, + bool final_) { + convertToBlockImpl( + method, + method.data.impls[bucket], + key_columns, + aggregate_columns, + final_aggregate_columns, + arena, + final_); + }); block.info.bucket_num = bucket; return block; @@ -925,9 +980,24 @@ BlocksList Aggregator::convertOneBucketToBlocks( bool final, size_t bucket) const { - BlocksList blocks = prepareBlocksAndFill(data_variants, final, method.data.impls[bucket].size(), [bucket, &method, arena, this](std::vector & key_columns_vec, std::vector & aggregate_columns_vec, std::vector & final_aggregate_columns_vec, bool final_) { - convertToBlocksImpl(method, method.data.impls[bucket], key_columns_vec, aggregate_columns_vec, final_aggregate_columns_vec, arena, final_); - }); + BlocksList blocks = prepareBlocksAndFill( + data_variants, + final, + method.data.impls[bucket].size(), + [bucket, &method, arena, this]( + std::vector & key_columns_vec, + std::vector & aggregate_columns_vec, + std::vector & final_aggregate_columns_vec, + bool final_) { + convertToBlocksImpl( + method, + method.data.impls[bucket], + key_columns_vec, + aggregate_columns_vec, + final_aggregate_columns_vec, + arena, + final_); + }); for (auto & block : blocks) { @@ -939,11 +1009,11 @@ BlocksList Aggregator::convertOneBucketToBlocks( template -void 
Aggregator::spillImpl( - AggregatedDataVariants & data_variants, - Method & method) +void Aggregator::spillImpl(AggregatedDataVariants & data_variants, Method & method) { - RUNTIME_ASSERT(agg_spill_context->getSpiller() != nullptr, "spiller must not be nullptr in Aggregator when spilling"); + RUNTIME_ASSERT( + agg_spill_context->getSpiller() != nullptr, + "spiller must not be nullptr in Aggregator when spilling"); size_t max_temporary_block_size_rows = 0; size_t max_temporary_block_size_bytes = 0; @@ -972,7 +1042,11 @@ void Aggregator::spillImpl( /// `data_variants` will not destroy them in the destructor, they are now owned by ColumnAggregateFunction objects. data_variants.aggregator = nullptr; - LOG_TRACE(log, "Max size of temporary bucket blocks: {} rows, {:.3f} MiB.", max_temporary_block_size_rows, (max_temporary_block_size_bytes / 1048576.0)); + LOG_TRACE( + log, + "Max size of temporary bucket blocks: {} rows, {:.3f} MiB.", + max_temporary_block_size_rows, + (max_temporary_block_size_bytes / 1048576.0)); } @@ -1153,8 +1227,7 @@ inline void Aggregator::insertAggregatesIntoColumns( { /// If ownership was not transferred to ColumnAggregateFunction. if (!(destroy_i < insert_i && aggregate_functions[destroy_i]->isState())) - aggregate_functions[destroy_i]->destroy( - mapped + offsets_of_aggregate_states[destroy_i]); + aggregate_functions[destroy_i]->destroy(mapped + offsets_of_aggregate_states[destroy_i]); } /// Mark the cell as destroyed so it will not be destroyed in destructor. @@ -1173,7 +1246,11 @@ struct AggregatorMethodInitKeyColumnHelper {} ALWAYS_INLINE inline void initAggKeys(size_t, std::vector &) {} template - ALWAYS_INLINE inline void insertKeyIntoColumns(const Key & key, std::vector & key_columns, const Sizes & sizes, const TiDB::TiDBCollators & collators) + ALWAYS_INLINE inline void insertKeyIntoColumns( + const Key & key, + std::vector & key_columns, + const Sizes & sizes, + const TiDB::TiDBCollators & collators) { method.insertKeyIntoColumns(key, key_columns, sizes, collators); } @@ -1196,7 +1273,11 @@ struct AggregatorMethodInitKeyColumnHelper(rows, key_columns[1]); index = 0; } - ALWAYS_INLINE inline void insertKeyIntoColumns(const StringRef & key, std::vector & key_columns, const Sizes &, const TiDB::TiDBCollators &) + ALWAYS_INLINE inline void insertKeyIntoColumns( + const StringRef & key, + std::vector & key_columns, + const Sizes &, + const TiDB::TiDBCollators &) { method.insertKeyIntoColumns(key, key_columns, index); ++index; @@ -1219,7 +1300,11 @@ struct AggregatorMethodInitKeyColumnHelper & key_columns, const Sizes &, const TiDB::TiDBCollators &) + ALWAYS_INLINE inline void insertKeyIntoColumns( + const StringRef & key, + std::vector & key_columns, + const Sizes &, + const TiDB::TiDBCollators &) { method.insertKeyIntoColumns(key, key_columns, index); ++index; @@ -1249,7 +1334,10 @@ void NO_INLINE Aggregator::convertToBlockImplFinal( namespace { template -std::optional shuffleKeyColumnsForKeyColumnsVec(Method & method, std::vector> & key_columns_vec, const Sizes & key_sizes) +std::optional shuffleKeyColumnsForKeyColumnsVec( + Method & method, + std::vector> & key_columns_vec, + const Sizes & key_sizes) { auto shuffled_key_sizes = method.shuffleKeyColumns(key_columns_vec[0], key_sizes); for (size_t i = 1; i < key_columns_vec.size(); ++i) @@ -1260,7 +1348,11 @@ std::optional shuffleKeyColumnsForKeyColumnsVec(Method & method, std::vec return shuffled_key_sizes; } template -std::vector>> initAggKeysForKeyColumnsVec(Method & method, std::vector> & key_columns_vec, 
size_t max_block_size, size_t total_row_count) +std::vector>> initAggKeysForKeyColumnsVec( + Method & method, + std::vector> & key_columns_vec, + size_t max_block_size, + size_t total_row_count) { std::vector>> agg_keys_helpers; size_t block_row_count = max_block_size; @@ -1293,7 +1385,8 @@ void NO_INLINE Aggregator::convertToBlocksImplFinal( size_t data_index = 0; data.forEachValue([&](const auto & key, auto & mapped) { size_t key_columns_vec_index = data_index / params.max_block_size; - agg_keys_helpers[key_columns_vec_index]->insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); + agg_keys_helpers[key_columns_vec_index] + ->insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); insertAggregatesIntoColumns(mapped, final_aggregate_columns_vec[key_columns_vec_index], arena); ++data_index; }); @@ -1338,7 +1431,8 @@ void NO_INLINE Aggregator::convertToBlocksImplNotFinal( size_t data_index = 0; data.forEachValue([&](const auto & key, auto & mapped) { size_t key_columns_vec_index = data_index / params.max_block_size; - agg_keys_helpers[key_columns_vec_index]->insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); + agg_keys_helpers[key_columns_vec_index] + ->insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); /// reserved, so push_back does not throw exceptions for (size_t i = 0; i < params.aggregates_size; ++i) @@ -1350,11 +1444,8 @@ void NO_INLINE Aggregator::convertToBlocksImplNotFinal( } template -Block Aggregator::prepareBlockAndFill( - AggregatedDataVariants & data_variants, - bool final, - size_t rows, - Filler && filler) const +Block Aggregator::prepareBlockAndFill(AggregatedDataVariants & data_variants, bool final, size_t rows, Filler && filler) + const { MutableColumns key_columns(params.keys_size); MutableColumns aggregate_columns(params.aggregates_size); @@ -1393,7 +1484,8 @@ Block Aggregator::prepareBlockAndFill( if (aggregate_functions[i]->isState()) { /// The ColumnAggregateFunction column captures the shared ownership of the arena with aggregate function states. - if (auto * column_aggregate_func = typeid_cast(final_aggregate_columns[i].get())) + if (auto * column_aggregate_func + = typeid_cast(final_aggregate_columns[i].get())) for (auto & pool : data_variants.aggregates_pools) column_aggregate_func->addArena(pool); } @@ -1489,7 +1581,8 @@ BlocksList Aggregator::prepareBlocksAndFill( if (aggregate_functions[i]->isState()) { /// The ColumnAggregateFunction column captures the shared ownership of the arena with aggregate function states. 
- if (auto * column_aggregate_func = typeid_cast(final_aggregate_columns[i].get())) + if (auto * column_aggregate_func + = typeid_cast(final_aggregate_columns[i].get())) for (auto & pool : data_variants.aggregates_pools) column_aggregate_func->addArena(pool); } @@ -1613,36 +1706,33 @@ BlocksList Aggregator::prepareBlocksAndFillSingleLevel(AggregatedDataVariants & template -void NO_INLINE Aggregator::mergeDataImpl( - Table & table_dst, - Table & table_src, - Arena * arena) const +void NO_INLINE Aggregator::mergeDataImpl(Table & table_dst, Table & table_src, Arena * arena) const { - table_src.mergeToViaEmplace(table_dst, - [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted) { - if (!inserted) - { - for (size_t i = 0; i < params.aggregates_size; ++i) - aggregate_functions[i]->merge( - dst + offsets_of_aggregate_states[i], - src + offsets_of_aggregate_states[i], - arena); - - for (size_t i = 0; i < params.aggregates_size; ++i) - aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]); - } - else - { - dst = src; - } - - src = nullptr; - }); + table_src.mergeToViaEmplace( + table_dst, + [&](AggregateDataPtr & __restrict dst, AggregateDataPtr & __restrict src, bool inserted) { + if (!inserted) + { + for (size_t i = 0; i < params.aggregates_size; ++i) + aggregate_functions[i]->merge( + dst + offsets_of_aggregate_states[i], + src + offsets_of_aggregate_states[i], + arena); + + for (size_t i = 0; i < params.aggregates_size; ++i) + aggregate_functions[i]->destroy(src + offsets_of_aggregate_states[i]); + } + else + { + dst = src; + } + + src = nullptr; + }); table_src.clearAndShrink(); } -void NO_INLINE Aggregator::mergeWithoutKeyDataImpl( - ManyAggregatedDataVariants & non_empty_data) const +void NO_INLINE Aggregator::mergeWithoutKeyDataImpl(ManyAggregatedDataVariants & non_empty_data) const { AggregatedDataVariantsPtr & res = non_empty_data[0]; @@ -1653,7 +1743,10 @@ void NO_INLINE Aggregator::mergeWithoutKeyDataImpl( AggregatedDataWithoutKey & current_data = non_empty_data[result_num]->without_key; for (size_t i = 0; i < params.aggregates_size; ++i) - aggregate_functions[i]->merge(res_data + offsets_of_aggregate_states[i], current_data + offsets_of_aggregate_states[i], res->aggregates_pool); + aggregate_functions[i]->merge( + res_data + offsets_of_aggregate_states[i], + current_data + offsets_of_aggregate_states[i], + res->aggregates_pool); for (size_t i = 0; i < params.aggregates_size; ++i) aggregate_functions[i]->destroy(current_data + offsets_of_aggregate_states[i]); @@ -1664,8 +1757,7 @@ void NO_INLINE Aggregator::mergeWithoutKeyDataImpl( template -void NO_INLINE Aggregator::mergeSingleLevelDataImpl( - ManyAggregatedDataVariants & non_empty_data) const +void NO_INLINE Aggregator::mergeSingleLevelDataImpl(ManyAggregatedDataVariants & non_empty_data) const { AggregatedDataVariantsPtr & res = non_empty_data[0]; @@ -1691,10 +1783,7 @@ APPLY_FOR_VARIANTS_SINGLE_LEVEL(M) #undef M template -void NO_INLINE Aggregator::mergeBucketImpl( - ManyAggregatedDataVariants & data, - Int32 bucket, - Arena * arena) const +void NO_INLINE Aggregator::mergeBucketImpl(ManyAggregatedDataVariants & data, Int32 bucket, Arena * arena) const { /// We merge all aggregation results to the first. AggregatedDataVariantsPtr & res = data[0]; @@ -1735,9 +1824,12 @@ MergingBucketsPtr Aggregator::mergeAndConvertToBlocks( if (non_empty_data.size() > 1) { /// Sort the states in descending order so that the merge is more efficient (since all states are merged into the first). 
- std::sort(non_empty_data.begin(), non_empty_data.end(), [](const AggregatedDataVariantsPtr & lhs, const AggregatedDataVariantsPtr & rhs) { - return lhs->size() > rhs->size(); - }); + std::sort( + non_empty_data.begin(), + non_empty_data.end(), + [](const AggregatedDataVariantsPtr & lhs, const AggregatedDataVariantsPtr & rhs) { + return lhs->size() > rhs->size(); + }); } /// If at least one of the options is two-level, then convert all the options into two-level ones, if there are not such. @@ -1763,14 +1855,17 @@ MergingBucketsPtr Aggregator::mergeAndConvertToBlocks( for (size_t i = 1, size = non_empty_data.size(); i < size; ++i) { if (unlikely(first->type != non_empty_data[i]->type)) - throw Exception("Cannot merge different aggregated data variants.", ErrorCodes::CANNOT_MERGE_DIFFERENT_AGGREGATED_DATA_VARIANTS); + throw Exception( + "Cannot merge different aggregated data variants.", + ErrorCodes::CANNOT_MERGE_DIFFERENT_AGGREGATED_DATA_VARIANTS); /** Elements from the remaining sets can be moved to the first data set. * Therefore, it must own all the arenas of all other sets. */ - first->aggregates_pools.insert(first->aggregates_pools.end(), - non_empty_data[i]->aggregates_pools.begin(), - non_empty_data[i]->aggregates_pools.end()); + first->aggregates_pools.insert( + first->aggregates_pools.end(), + non_empty_data[i]->aggregates_pools.begin(), + non_empty_data[i]->aggregates_pools.end()); } // for single level merge, concurrency must be 1. @@ -1795,7 +1890,8 @@ void NO_INLINE Aggregator::mergeStreamsImplCase( for (size_t i = 0; i < params.aggregates_size; ++i) { const auto & aggregate_column_name = params.aggregates[i].column_name; - aggregate_columns[i] = &typeid_cast(*block.getByName(aggregate_column_name).column).getData(); + aggregate_columns[i] + = &typeid_cast(*block.getByName(aggregate_column_name).column).getData(); } std::vector sort_key_containers; @@ -1845,19 +1941,13 @@ void NO_INLINE Aggregator::mergeStreamsImplCase( } template -void NO_INLINE Aggregator::mergeStreamsImpl( - Block & block, - Arena * aggregates_pool, - Method & method, - Table & data) const +void NO_INLINE Aggregator::mergeStreamsImpl(Block & block, Arena * aggregates_pool, Method & method, Table & data) const { mergeStreamsImplCase(block, aggregates_pool, method, data); } -void NO_INLINE Aggregator::mergeWithoutKeyStreamsImpl( - Block & block, - AggregatedDataVariants & result) const +void NO_INLINE Aggregator::mergeWithoutKeyStreamsImpl(Block & block, AggregatedDataVariants & result) const { AggregateColumnsConstData aggregate_columns(params.aggregates_size); @@ -1865,13 +1955,15 @@ void NO_INLINE Aggregator::mergeWithoutKeyStreamsImpl( for (size_t i = 0; i < params.aggregates_size; ++i) { const auto & aggregate_column_name = params.aggregates[i].column_name; - aggregate_columns[i] = &typeid_cast(*block.getByName(aggregate_column_name).column).getData(); + aggregate_columns[i] + = &typeid_cast(*block.getByName(aggregate_column_name).column).getData(); } AggregatedDataWithoutKey & res = result.without_key; if (!res) { - AggregateDataPtr place = result.aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); + AggregateDataPtr place + = result.aggregates_pool->alignedAlloc(total_size_of_aggregate_states, align_aggregate_states); createAggregateStates(place); res = place; } @@ -1880,7 +1972,10 @@ void NO_INLINE Aggregator::mergeWithoutKeyStreamsImpl( { /// Adding Values for (size_t i = 0; i < params.aggregates_size; ++i) - aggregate_functions[i]->merge(res + 
offsets_of_aggregate_states[i], (*aggregate_columns[i])[0], result.aggregates_pool); + aggregate_functions[i]->merge( + res + offsets_of_aggregate_states[i], + (*aggregate_columns[i])[0], + result.aggregates_pool); } /// Early release memory. @@ -1893,7 +1988,11 @@ BlocksList Aggregator::vstackBlocks(BlocksList & blocks, bool final) auto bucket_num = blocks.front().info.bucket_num; - LOG_TRACE(log, "Merging partially aggregated blocks (bucket = {}). Original method `{}`.", bucket_num, AggregatedDataVariants::getMethodName(method_chosen)); + LOG_TRACE( + log, + "Merging partially aggregated blocks (bucket = {}). Original method `{}`.", + bucket_num, + AggregatedDataVariants::getMethodName(method_chosen)); Stopwatch watch; /** If possible, change 'method' to some_hash64. Otherwise, leave as is. @@ -1945,14 +2044,15 @@ BlocksList Aggregator::vstackBlocks(BlocksList & blocks, bool final) if (result.type == AggregatedDataVariants::Type::without_key) mergeWithoutKeyStreamsImpl(block, result); -#define M(NAME, IS_TWO_LEVEL) \ - case AggregationMethodType(NAME): \ - { \ - mergeStreamsImpl(block, \ - result.aggregates_pool, \ - *ToAggregationMethodPtr(NAME, result.aggregation_method_impl), \ - ToAggregationMethodPtr(NAME, result.aggregation_method_impl)->data); \ - break; \ +#define M(NAME, IS_TWO_LEVEL) \ + case AggregationMethodType(NAME): \ + { \ + mergeStreamsImpl( \ + block, \ + result.aggregates_pool, \ + *ToAggregationMethodPtr(NAME, result.aggregation_method_impl), \ + ToAggregationMethodPtr(NAME, result.aggregation_method_impl)->data); \ + break; \ } switch (result.type) { @@ -1989,7 +2089,8 @@ BlocksList Aggregator::vstackBlocks(BlocksList & blocks, bool final) double elapsed_seconds = watch.elapsedSeconds(); LOG_TRACE( log, - "Merged partially aggregated blocks. Return {} rows in {} blocks, {:.3f} MiB. in {:.3f} sec. ({:.3f} rows/sec., {:.3f} MiB/sec.)", + "Merged partially aggregated blocks. Return {} rows in {} blocks, {:.3f} MiB. in {:.3f} sec. ({:.3f} " + "rows/sec., {:.3f} MiB/sec.)", rows, return_blocks.size(), bytes / 1048576.0, @@ -2090,12 +2191,11 @@ Blocks Aggregator::convertBlockToTwoLevel(const Block & block) size_t num_buckets = 0; -#define M(NAME) \ - case AggregationMethodType(NAME): \ - { \ - num_buckets \ - = ToAggregationMethodPtr(NAME, data.aggregation_method_impl)->data.NUM_BUCKETS; \ - break; \ +#define M(NAME) \ + case AggregationMethodType(NAME): \ + { \ + num_buckets = ToAggregationMethodPtr(NAME, data.aggregation_method_impl)->data.NUM_BUCKETS; \ + break; \ } switch (data.type) @@ -2204,7 +2304,11 @@ void Aggregator::setCancellationHook(CancellationHook cancellation_hook) is_cancelled = cancellation_hook; } -MergingBuckets::MergingBuckets(const Aggregator & aggregator_, const ManyAggregatedDataVariants & data_, bool final_, size_t concurrency_) +MergingBuckets::MergingBuckets( + const Aggregator & aggregator_, + const ManyAggregatedDataVariants & data_, + bool final_, + size_t concurrency_) : log(Logger::get(aggregator_.log ? 
aggregator_.log->identifier() : "")) , aggregator(aggregator_) , data(data_) @@ -2270,9 +2374,7 @@ Block MergingBuckets::getDataForSingleLevel() if (first->type == AggregatedDataVariants::Type::without_key) { aggregator.mergeWithoutKeyDataImpl(data); - single_level_blocks = aggregator.prepareBlocksAndFillWithoutKey( - *first, - final); + single_level_blocks = aggregator.prepareBlocksAndFillWithoutKey(*first, final); } else { diff --git a/dbms/src/Interpreters/HashJoinSpillContext.cpp b/dbms/src/Interpreters/HashJoinSpillContext.cpp index 20fb72c90c1..00ffa0537c1 100644 --- a/dbms/src/Interpreters/HashJoinSpillContext.cpp +++ b/dbms/src/Interpreters/HashJoinSpillContext.cpp @@ -16,21 +16,27 @@ namespace DB { -HashJoinSpillContext::HashJoinSpillContext(const SpillConfig & build_spill_config_, const SpillConfig & probe_spill_config_, UInt64 operator_spill_threshold, const LoggerPtr & log) +HashJoinSpillContext::HashJoinSpillContext( + const SpillConfig & build_spill_config_, + const SpillConfig & probe_spill_config_, + UInt64 operator_spill_threshold, + const LoggerPtr & log) : OperatorSpillContext(operator_spill_threshold, "join", log) , build_spill_config(build_spill_config_) , probe_spill_config(probe_spill_config_) - , max_cached_bytes(std::max(build_spill_config.max_cached_data_bytes_in_spiller, probe_spill_config.max_cached_data_bytes_in_spiller)) + , max_cached_bytes(std::max( + build_spill_config.max_cached_data_bytes_in_spiller, + probe_spill_config.max_cached_data_bytes_in_spiller)) {} void HashJoinSpillContext::init(size_t partition_num) { partition_revocable_memories = std::make_unique>>(partition_num); - partition_spill_status = std::make_unique>>(partition_num); + partition_is_spilled = std::make_unique>>(partition_num); for (auto & memory : *partition_revocable_memories) memory = 0; - for (auto & status : *partition_spill_status) - status = SpillStatus::NOT_SPILL; + for (auto & status : *partition_is_spilled) + status = false; } Int64 HashJoinSpillContext::getTotalRevocableMemoryImpl() @@ -43,18 +49,28 @@ Int64 HashJoinSpillContext::getTotalRevocableMemoryImpl() void HashJoinSpillContext::buildBuildSpiller(const Block & input_schema) { - build_spiller = std::make_unique(build_spill_config, false, (*partition_revocable_memories).size(), input_schema, log); + build_spiller = std::make_unique( + build_spill_config, + false, + (*partition_revocable_memories).size(), + input_schema, + log); } void HashJoinSpillContext::buildProbeSpiller(const Block & input_schema) { - probe_spiller = std::make_unique(probe_spill_config, false, (*partition_revocable_memories).size(), input_schema, log); + probe_spiller = std::make_unique( + probe_spill_config, + false, + (*partition_revocable_memories).size(), + input_schema, + log); } -void HashJoinSpillContext::markPartitionSpill(size_t partition_index) +void HashJoinSpillContext::markPartitionSpilled(size_t partition_index) { - markSpill(); - (*partition_spill_status)[partition_index] = SpillStatus::SPILL; + markSpilled(); + (*partition_is_spilled)[partition_index] = true; } bool HashJoinSpillContext::updatePartitionRevocableMemory(size_t partition_id, Int64 new_value) @@ -62,9 +78,10 @@ bool HashJoinSpillContext::updatePartitionRevocableMemory(size_t partition_id, I (*partition_revocable_memories)[partition_id] = new_value; /// this function only trigger spill if current partition is already chosen to spill /// the new partition to spill is chosen in getPartitionsToSpill - if ((*partition_spill_status)[partition_id] == SpillStatus::NOT_SPILL) + 
if (!(*partition_is_spilled)[partition_id]) return false; - auto force_spill = operator_spill_threshold > 0 && getTotalRevocableMemoryImpl() > static_cast(operator_spill_threshold); + auto force_spill + = operator_spill_threshold > 0 && getTotalRevocableMemoryImpl() > static_cast(operator_spill_threshold); if (force_spill || (max_cached_bytes > 0 && (*partition_revocable_memories)[partition_id] > max_cached_bytes)) { (*partition_revocable_memories)[partition_id] = 0; @@ -75,11 +92,23 @@ bool HashJoinSpillContext::updatePartitionRevocableMemory(size_t partition_id, I SpillConfig HashJoinSpillContext::createBuildSpillConfig(const String & spill_id) const { - return SpillConfig(build_spill_config.spill_dir, spill_id, build_spill_config.max_cached_data_bytes_in_spiller, build_spill_config.max_spilled_rows_per_file, build_spill_config.max_spilled_bytes_per_file, build_spill_config.file_provider); + return SpillConfig( + build_spill_config.spill_dir, + spill_id, + build_spill_config.max_cached_data_bytes_in_spiller, + build_spill_config.max_spilled_rows_per_file, + build_spill_config.max_spilled_bytes_per_file, + build_spill_config.file_provider); } SpillConfig HashJoinSpillContext::createProbeSpillConfig(const String & spill_id) const { - return SpillConfig(probe_spill_config.spill_dir, spill_id, build_spill_config.max_cached_data_bytes_in_spiller, build_spill_config.max_spilled_rows_per_file, build_spill_config.max_spilled_bytes_per_file, build_spill_config.file_provider); + return SpillConfig( + probe_spill_config.spill_dir, + spill_id, + build_spill_config.max_cached_data_bytes_in_spiller, + build_spill_config.max_spilled_rows_per_file, + build_spill_config.max_spilled_bytes_per_file, + build_spill_config.file_provider); } std::vector HashJoinSpillContext::getPartitionsToSpill() @@ -111,4 +140,8 @@ std::vector HashJoinSpillContext::getPartitionsToSpill() return ret; } +Int64 HashJoinSpillContext::triggerSpill(Int64) +{ + throw Exception("Not supported yet"); +} } // namespace DB diff --git a/dbms/src/Interpreters/HashJoinSpillContext.h b/dbms/src/Interpreters/HashJoinSpillContext.h index 66b669c2d3e..06ee373a053 100644 --- a/dbms/src/Interpreters/HashJoinSpillContext.h +++ b/dbms/src/Interpreters/HashJoinSpillContext.h @@ -23,7 +23,7 @@ namespace DB class HashJoinSpillContext final : public OperatorSpillContext { private: - std::unique_ptr>> partition_spill_status; + std::unique_ptr>> partition_is_spilled; std::unique_ptr>> partition_revocable_memories; SpillConfig build_spill_config; SpillerPtr build_spiller; @@ -32,19 +32,24 @@ class HashJoinSpillContext final : public OperatorSpillContext Int64 max_cached_bytes; public: - HashJoinSpillContext(const SpillConfig & build_spill_config_, const SpillConfig & probe_spill_config_, UInt64 operator_spill_threshold_, const LoggerPtr & log); + HashJoinSpillContext( + const SpillConfig & build_spill_config_, + const SpillConfig & probe_spill_config_, + UInt64 operator_spill_threshold_, + const LoggerPtr & log); void init(size_t partition_num); void buildBuildSpiller(const Block & input_schema); void buildProbeSpiller(const Block & input_schema); SpillerPtr & getBuildSpiller() { return build_spiller; } SpillerPtr & getProbeSpiller() { return probe_spiller; } - bool isPartitionSpilled(size_t partition_index) const { return (*partition_spill_status)[partition_index] != SpillStatus::NOT_SPILL; } - void markPartitionSpill(size_t partition_index); + bool isPartitionSpilled(size_t partition_index) const { return (*partition_is_spilled)[partition_index]; } + 
void markPartitionSpilled(size_t partition_index); bool updatePartitionRevocableMemory(size_t partition_id, Int64 new_value); Int64 getTotalRevocableMemoryImpl() override; SpillConfig createBuildSpillConfig(const String & spill_id) const; SpillConfig createProbeSpillConfig(const String & spill_id) const; std::vector getPartitionsToSpill(); + Int64 triggerSpill(Int64 expected_released_memories) override; }; using HashJoinSpillContextPtr = std::shared_ptr; diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index 3b7361979c1..a1b27b34f05 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -165,7 +165,8 @@ void InterpreterSelectQuery::init(const Names & required_result_column_names) else if (table_expression && typeid_cast(table_expression.get())) { /// Read from subquery. - source_columns = InterpreterSelectWithUnionQuery::getSampleBlock(table_expression, context).getNamesAndTypesList(); + source_columns + = InterpreterSelectWithUnionQuery::getSampleBlock(table_expression, context).getNamesAndTypesList(); } else if (table_expression && typeid_cast(table_expression.get())) { @@ -199,10 +200,15 @@ void InterpreterSelectQuery::init(const Names & required_result_column_names) throw Exception("Illegal SAMPLE: table doesn't support sampling", ErrorCodes::SAMPLING_NOT_SUPPORTED); if (query.final() && (input || !storage || !storage->supportsFinal())) - throw Exception((!input && storage) ? "Storage " + storage->getName() + " doesn't support FINAL" : "Illegal FINAL", ErrorCodes::ILLEGAL_FINAL); + throw Exception( + (!input && storage) ? "Storage " + storage->getName() + " doesn't support FINAL" : "Illegal FINAL", + ErrorCodes::ILLEGAL_FINAL); if (query.prewhere_expression && (input || !storage || !storage->supportsPrewhere())) - throw Exception((!input && storage) ? "Storage " + storage->getName() + " doesn't support PREWHERE" : "Illegal PREWHERE", ErrorCodes::ILLEGAL_PREWHERE); + throw Exception( + (!input && storage) ? 
"Storage " + storage->getName() + " doesn't support PREWHERE" + : "Illegal PREWHERE", + ErrorCodes::ILLEGAL_PREWHERE); /// Save the new temporary tables in the query context for (const auto & it : query_analyzer->getExternalTables()) @@ -226,7 +232,10 @@ void InterpreterSelectQuery::getAndLockStorageWithSchemaVersion(const String & d context.getTMTContext().getSchemaSyncerManager()->syncSchemas(context, NullspaceID); auto storage_tmp = context.getTable(database_name, table_name); auto managed_storage = std::dynamic_pointer_cast(storage_tmp); - if (!managed_storage || !(managed_storage->engineType() == ::TiDB::StorageEngine::DT || managed_storage->engineType() == ::TiDB::StorageEngine::TMT)) + if (!managed_storage + || !( + managed_storage->engineType() == ::TiDB::StorageEngine::DT + || managed_storage->engineType() == ::TiDB::StorageEngine::TMT)) { LOG_DEBUG(log, "{}.{} is not ManageableStorage", database_name, table_name); storage = storage_tmp; @@ -234,8 +243,12 @@ void InterpreterSelectQuery::getAndLockStorageWithSchemaVersion(const String & d return; } - context.getTMTContext().getSchemaSyncerManager()->syncTableSchema(context, NullspaceID, managed_storage->getTableInfo().id); - auto schema_sync_cost = std::chrono::duration_cast(Clock::now() - start_time).count(); + context.getTMTContext().getSchemaSyncerManager()->syncTableSchema( + context, + NullspaceID, + managed_storage->getTableInfo().id); + auto schema_sync_cost + = std::chrono::duration_cast(Clock::now() - start_time).count(); LOG_DEBUG(log, "Table {} schema sync cost {}ms.", qualified_name, schema_sync_cost); table_lock = storage_tmp->lockForShare(context.getCurrentQueryId()); @@ -311,11 +324,11 @@ InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpression AnalysisResult res; /// Do I need to perform the first part of the pipeline - running on remote servers during distributed processing. - res.first_stage = from_stage < QueryProcessingStage::WithMergeableState - && to_stage >= QueryProcessingStage::WithMergeableState; + res.first_stage + = from_stage < QueryProcessingStage::WithMergeableState && to_stage >= QueryProcessingStage::WithMergeableState; /// Do I need to execute the second part of the pipeline - running on the initiating server during distributed processing. - res.second_stage = from_stage <= QueryProcessingStage::WithMergeableState - && to_stage > QueryProcessingStage::WithMergeableState; + res.second_stage + = from_stage <= QueryProcessingStage::WithMergeableState && to_stage > QueryProcessingStage::WithMergeableState; /** First we compose a chain of actions and remember the necessary steps from it. * Regardless of from_stage and to_stage, we will compose a complete sequence of actions to perform optimization and @@ -361,7 +374,8 @@ InterpreterSelectQuery::AnalysisResult InterpreterSelectQuery::analyzeExpression /// If there is aggregation, we execute expressions in SELECT and ORDER BY on the initiating server, otherwise on the source servers. query_analyzer->appendSelect(chain, res.need_aggregate ? !res.second_stage : !res.first_stage); res.selected_columns = chain.getLastStep().required_output; - res.has_order_by = query_analyzer->appendOrderBy(chain, res.need_aggregate ? !res.second_stage : !res.first_stage); + res.has_order_by + = query_analyzer->appendOrderBy(chain, res.need_aggregate ? 
!res.second_stage : !res.first_stage); res.before_order_and_select = chain.getLastActions(); chain.addStep(); @@ -413,7 +427,11 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt throw Exception("Distributed on Distributed is not supported", ErrorCodes::NOT_IMPLEMENTED); if (!dry_run) - LOG_TRACE(log, "{} -> {}", QueryProcessingStage::toString(from_stage), QueryProcessingStage::toString(to_stage)); + LOG_TRACE( + log, + "{} -> {}", + QueryProcessingStage::toString(from_stage), + QueryProcessingStage::toString(to_stage)); AnalysisResult expressions = analyzeExpressions(from_stage); @@ -431,7 +449,8 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt if (expressions.has_join) { for (auto & stream : pipeline.streams) - stream = std::make_shared(stream, expressions.before_join, /*req_id=*/""); + stream + = std::make_shared(stream, expressions.before_join, /*req_id=*/""); } if (expressions.has_where) @@ -502,14 +521,13 @@ void InterpreterSelectQuery::executeImpl(Pipeline & pipeline, const BlockInputSt /** Optimization - if there are several sources and there is LIMIT, then first apply the preliminary LIMIT, * limiting the number of rows in each up to `offset + limit`. */ - if (query.limit_length && pipeline.hasMoreThanOneStream() && !query.distinct && !expressions.has_limit_by && !settings.extremes) + if (query.limit_length && pipeline.hasMoreThanOneStream() && !query.distinct && !expressions.has_limit_by + && !settings.extremes) { executePreLimit(pipeline); } - if (need_second_distinct_pass - || query.limit_length - || query.limit_by_expression_list) + if (need_second_distinct_pass || query.limit_length || query.limit_by_expression_list) { need_merge_streams = true; } @@ -584,7 +602,8 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline { const auto default_it = column_defaults.find(column); if (default_it != std::end(column_defaults) && default_it->second.kind == ColumnDefaultKind::Alias) - required_columns_expr_list->children.emplace_back(setAlias(default_it->second.expression->clone(), column)); + required_columns_expr_list->children.emplace_back( + setAlias(default_it->second.expression->clone(), column)); else required_columns_expr_list->children.emplace_back(std::make_shared(column)); } @@ -643,16 +662,9 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline * then as the block size we will use limit + offset (not to read more from the table than requested), * and also set the number of threads to 1. 
*/ - if (!query.distinct - && !query.prewhere_expression - && !query.where_expression - && !query.group_expression_list - && !query.having_expression - && !query.order_expression_list - && !query.limit_by_expression_list - && query.limit_length - && !query_analyzer->hasAggregation() - && limit_length + limit_offset < max_block_size) + if (!query.distinct && !query.prewhere_expression && !query.where_expression && !query.group_expression_list + && !query.having_expression && !query.order_expression_list && !query.limit_by_expression_list + && query.limit_length && !query_analyzer->hasAggregation() && limit_length + limit_offset < max_block_size) { max_block_size = limit_length + limit_offset; max_streams = 1; @@ -672,7 +684,8 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline if (!dry_run) pipeline.streams = interpreter_subquery->executeWithMultipleStreams(); else - pipeline.streams.emplace_back(std::make_shared(interpreter_subquery->getSampleBlock())); + pipeline.streams.emplace_back( + std::make_shared(interpreter_subquery->getSampleBlock())); } else if (storage) { @@ -687,7 +700,8 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline query_info.query = query_ptr; query_info.sets = query_analyzer->getPreparedSets(); auto scan_context = std::make_shared(); - query_info.mvcc_query_info = std::make_unique(settings.resolve_locks, settings.read_tso, scan_context); + query_info.mvcc_query_info + = std::make_unique(settings.resolve_locks, settings.read_tso, scan_context); const String & request_str = settings.regions; @@ -728,7 +742,9 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline } if (query_info.mvcc_query_info->regions_query_info.empty()) - throw Exception("[InterpreterSelectQuery::executeFetchColumns] no region query", ErrorCodes::LOGICAL_ERROR); + throw Exception( + "[InterpreterSelectQuery::executeFetchColumns] no region query", + ErrorCodes::LOGICAL_ERROR); } /// PARTITION SELECT only supports MergeTree family now. 
@@ -759,12 +775,14 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline if (likely(!select_query->no_kvstore)) { auto table_info = managed_storage->getTableInfo(); - learner_read_snapshot = doLearnerRead(table_info.id, *query_info.mvcc_query_info, false, context, log); + learner_read_snapshot + = doLearnerRead(table_info.id, *query_info.mvcc_query_info, false, context, log); } } } - pipeline.streams = storage->read(required_columns, query_info, context, from_stage, max_block_size, max_streams); + pipeline.streams + = storage->read(required_columns, query_info, context, from_stage, max_block_size, max_streams); if (!learner_read_snapshot.empty()) { @@ -773,11 +791,10 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline } if (pipeline.streams.empty()) - pipeline.streams.emplace_back(std::make_shared(storage->getSampleBlockForColumns(required_columns))); + pipeline.streams.emplace_back( + std::make_shared(storage->getSampleBlockForColumns(required_columns))); - pipeline.transform([&](auto & stream) { - stream->addTableLock(table_lock); - }); + pipeline.transform([&](auto & stream) { stream->addTableLock(table_lock); }); } else throw Exception("Logical error in InterpreterSelectQuery: nowhere to read", ErrorCodes::LOGICAL_ERROR); @@ -797,12 +814,19 @@ QueryProcessingStage::Enum InterpreterSelectQuery::executeFetchColumns(Pipeline void InterpreterSelectQuery::executeWhere(Pipeline & pipeline, const ExpressionActionsPtr & expression) { pipeline.transform([&](auto & stream) { - stream = std::make_shared(stream, expression, query.where_expression->getColumnName(), /*req_id=*/""); + stream = std::make_shared( + stream, + expression, + query.where_expression->getColumnName(), + /*req_id=*/""); }); } -void InterpreterSelectQuery::executeAggregation(Pipeline & pipeline, const ExpressionActionsPtr & expression, bool final) +void InterpreterSelectQuery::executeAggregation( + Pipeline & pipeline, + const ExpressionActionsPtr & expression, + bool final) { pipeline.transform([&](auto & stream) { stream = std::make_shared(stream, expression, /*req_id=*/""); @@ -827,7 +851,8 @@ void InterpreterSelectQuery::executeAggregation(Pipeline & pipeline, const Expre * 1. Parallel aggregation is done, and the results should be merged in parallel. * 2. An aggregation is done with store of temporary data on the disk, and they need to be merged in a memory efficient way. */ - bool allow_to_use_two_level_group_by = pipeline.streams.size() > 1 || settings.max_bytes_before_external_group_by != 0; + bool allow_to_use_two_level_group_by + = pipeline.streams.size() > 1 || settings.max_bytes_before_external_group_by != 0; SpillConfig spill_config( context.getTemporaryPath(), @@ -836,7 +861,16 @@ void InterpreterSelectQuery::executeAggregation(Pipeline & pipeline, const Expre settings.max_spilled_rows_per_file, settings.max_spilled_bytes_per_file, context.getFileProvider()); - Aggregator::Params params(header, keys, aggregates, allow_to_use_two_level_group_by ? settings.group_by_two_level_threshold : SettingUInt64(0), allow_to_use_two_level_group_by ? settings.group_by_two_level_threshold_bytes : SettingUInt64(0), settings.max_bytes_before_external_group_by, false, spill_config, settings.max_block_size); + Aggregator::Params params( + header, + keys, + aggregates, + allow_to_use_two_level_group_by ? settings.group_by_two_level_threshold : SettingUInt64(0), + allow_to_use_two_level_group_by ? 
settings.group_by_two_level_threshold_bytes : SettingUInt64(0), + settings.max_bytes_before_external_group_by, + false, + spill_config, + settings.max_block_size); /// If there are several sources, then we perform parallel aggregation if (pipeline.streams.size() > 1) @@ -902,14 +936,27 @@ void InterpreterSelectQuery::executeMergeAggregated(Pipeline & pipeline, bool fi const Settings & settings = context.getSettingsRef(); - Aggregator::Params params(header, keys, aggregates, SpillConfig(context.getTemporaryPath(), "aggregation", settings.max_cached_data_bytes_in_spiller, settings.max_spilled_rows_per_file, settings.max_spilled_bytes_per_file, context.getFileProvider()), settings.max_block_size); + Aggregator::Params params( + header, + keys, + aggregates, + SpillConfig( + context.getTemporaryPath(), + "aggregation", + settings.max_cached_data_bytes_in_spiller, + settings.max_spilled_rows_per_file, + settings.max_spilled_bytes_per_file, + context.getFileProvider()), + settings.max_block_size); pipeline.firstStream() = std::make_shared( pipeline.streams, params, final, max_streams, - settings.aggregation_memory_efficient_merge_threads ? static_cast(settings.aggregation_memory_efficient_merge_threads) : static_cast(settings.max_threads), + settings.aggregation_memory_efficient_merge_threads + ? static_cast(settings.aggregation_memory_efficient_merge_threads) + : static_cast(settings.max_threads), /*req_id=*/""); pipeline.streams.resize(1); @@ -919,7 +966,11 @@ void InterpreterSelectQuery::executeMergeAggregated(Pipeline & pipeline, bool fi void InterpreterSelectQuery::executeHaving(Pipeline & pipeline, const ExpressionActionsPtr & expression) { pipeline.transform([&](auto & stream) { - stream = std::make_shared(stream, expression, query.having_expression->getColumnName(), /*req_id=*/""); + stream = std::make_shared( + stream, + expression, + query.having_expression->getColumnName(), + /*req_id=*/""); }); } @@ -942,7 +993,8 @@ static SortDescription getSortDescription(ASTSelectQuery & query) std::shared_ptr collator; if (order_by_elem.collation) - collator = std::make_shared(typeid_cast(*order_by_elem.collation).value.get()); + collator = std::make_shared( + typeid_cast(*order_by_elem.collation).value.get()); order_descr.emplace_back(name, order_by_elem.direction, order_by_elem.nulls_direction, collator); } @@ -987,8 +1039,15 @@ void InterpreterSelectQuery::executeOrder(Pipeline & pipeline) settings.max_block_size, limit, settings.max_bytes_before_external_sort, - SpillConfig(context.getTemporaryPath(), "sort", settings.max_cached_data_bytes_in_spiller, settings.max_spilled_rows_per_file, settings.max_spilled_bytes_per_file, context.getFileProvider()), - /*req_id=*/""); + SpillConfig( + context.getTemporaryPath(), + "sort", + settings.max_cached_data_bytes_in_spiller, + settings.max_spilled_rows_per_file, + settings.max_spilled_bytes_per_file, + context.getFileProvider()), + /*req_id=*/"", + [](const OperatorSpillContextPtr &) {}); } @@ -1005,12 +1064,14 @@ void InterpreterSelectQuery::executeMergeSorted(Pipeline & pipeline) /** MergingSortedBlockInputStream reads the sources sequentially. * To make the data on the remote servers prepared in parallel, we wrap it in AsynchronousBlockInputStream. */ - pipeline.transform([&](auto & stream) { - stream = std::make_shared(stream); - }); + pipeline.transform([&](auto & stream) { stream = std::make_shared(stream); }); /// Merge the sorted sources into one sorted source. 
- pipeline.firstStream() = std::make_shared(pipeline.streams, order_descr, settings.max_block_size, limit); + pipeline.firstStream() = std::make_shared( + pipeline.streams, + order_descr, + settings.max_block_size, + limit); pipeline.streams.resize(1); } } @@ -1041,7 +1102,10 @@ void InterpreterSelectQuery::executeDistinct(Pipeline & pipeline, bool before_or limit_for_distinct = limit_length + limit_offset; pipeline.transform([&](auto & stream) { - SizeLimits limits(settings.max_rows_in_distinct, settings.max_bytes_in_distinct, settings.distinct_overflow_mode); + SizeLimits limits( + settings.max_rows_in_distinct, + settings.max_bytes_in_distinct, + settings.distinct_overflow_mode); if (stream->isGroupedOutput()) stream = std::make_shared(stream, limits, limit_for_distinct, columns); @@ -1088,7 +1152,11 @@ void InterpreterSelectQuery::executePreLimit(Pipeline & pipeline) if (limit_length) { pipeline.transform([&](auto & stream) { - stream = std::make_shared(stream, limit_length + limit_offset, /* offset */ 0, /*req_id=*/""); + stream = std::make_shared( + stream, + limit_length + limit_offset, + /* offset */ 0, + /*req_id=*/""); }); } } @@ -1122,7 +1190,9 @@ void InterpreterSelectQuery::executeExtremes(Pipeline & pipeline) } -void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins(Pipeline & pipeline, SubqueriesForSets & subqueries_for_sets) +void InterpreterSelectQuery::executeSubqueriesInSetsAndJoins( + Pipeline & pipeline, + SubqueriesForSets & subqueries_for_sets) { const Settings & settings = context.getSettingsRef(); diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index 58bc6b1edfd..2187a580708 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -65,7 +65,11 @@ ColumnRawPtrs getKeyColumns(const Names & key_names, const Block & block) return key_columns; } -size_t getRestoreJoinBuildConcurrency(size_t total_partitions, size_t spilled_partitions, Int64 join_restore_concurrency, size_t total_concurrency) +size_t getRestoreJoinBuildConcurrency( + size_t total_partitions, + size_t spilled_partitions, + Int64 join_restore_concurrency, + size_t total_concurrency) { if (join_restore_concurrency < 0) { @@ -83,12 +87,14 @@ size_t getRestoreJoinBuildConcurrency(size_t total_partitions, size_t spilled_pa size_t unspilled_partitions = total_partitions - spilled_partitions; /// try to restore at most (unspilled_partitions - 1) partitions at a time size_t max_concurrent_restore_partition = unspilled_partitions <= 1 ? 1 : unspilled_partitions - 1; - size_t restore_times = (spilled_partitions + max_concurrent_restore_partition - 1) / max_concurrent_restore_partition; + size_t restore_times + = (spilled_partitions + max_concurrent_restore_partition - 1) / max_concurrent_restore_partition; size_t restore_build_concurrency = (restore_times * total_concurrency) / spilled_partitions; return std::max(2, restore_build_concurrency); } } -std::pair getDataAndNullMapVectorFromFilterColumn(ColumnPtr & filter_column) +std::pair getDataAndNullMapVectorFromFilterColumn( + ColumnPtr & filter_column) { if (filter_column->isColumnConst()) filter_column = filter_column->convertToFullColumnIfConst(); @@ -155,7 +161,9 @@ Join::Join( , max_block_size(max_block_size_) , runtime_filter_list(runtime_filter_list_) , join_restore_concurrency(join_restore_concurrency_) - , shallow_copy_cross_probe_threshold(shallow_copy_cross_probe_threshold_ > 0 ? 
shallow_copy_cross_probe_threshold_ : std::max(1, max_block_size / 10)) + , shallow_copy_cross_probe_threshold( + shallow_copy_cross_probe_threshold_ > 0 ? shallow_copy_cross_probe_threshold_ + : std::max(1, max_block_size / 10)) + , tidb_output_column_names(tidb_output_column_names_) + , is_test(is_test_) + , log(Logger::get(req_id)) @@ -183,7 +191,11 @@ Join::Join( if (unlikely(!err.empty())) throw Exception(fmt::format("Validate join conditions error: {}", err)); - hash_join_spill_context = std::make_shared(build_spill_config, probe_spill_config, max_bytes_before_external_join, log); + hash_join_spill_context = std::make_shared( + build_spill_config, + probe_spill_config, + max_bytes_before_external_join, + log); size_t max_restore_round = 4; #ifdef DBMS_PUBLIC_GTEST max_restore_round = MAX_RESTORE_ROUND_IN_GTEST; #endif @@ -195,7 +207,11 @@ Join::Join( hash_join_spill_context->disableSpill(); } - LOG_DEBUG(log, "FineGrainedShuffle flag {}, stream count {}", enable_fine_grained_shuffle, fine_grained_shuffle_count); + LOG_DEBUG( + log, + "FineGrainedShuffle flag {}, stream count {}", + enable_fine_grained_shuffle, + fine_grained_shuffle_count); } void Join::meetError(const String & error_message_) @@ -274,7 +290,9 @@ size_t Join::getPeakBuildBytesUsage() void Join::setBuildConcurrencyAndInitJoinPartition(size_t build_concurrency_) { if (unlikely(build_concurrency > 0)) - throw Exception("Logical error: `setBuildConcurrencyAndInitJoinPartition` shouldn't be called more than once", ErrorCodes::LOGICAL_ERROR); + throw Exception( + "Logical error: `setBuildConcurrencyAndInitJoinPartition` shouldn't be called more than once", + ErrorCodes::LOGICAL_ERROR); /// do not set active_build_threads because in the compile stage, `joinBlock` will be called to generate the header; if active_build_threads /// is set here, `joinBlock` will hang when used to get the header build_concurrency = std::max(1, build_concurrency_); @@ -282,7 +300,15 @@ void Join::setBuildConcurrencyAndInitJoinPartition(size_t build_concurrency_) partitions.reserve(build_concurrency); for (size_t i = 0; i < getBuildConcurrency(); ++i) { - partitions.push_back(std::make_unique(join_map_method, kind, strictness, i, max_block_size, hash_join_spill_context, log, has_other_condition)); + partitions.push_back(std::make_unique( + join_map_method, + kind, + strictness, + i, + max_block_size, + hash_join_spill_context, + log, + has_other_condition)); } } @@ -333,8 +359,10 @@ std::shared_ptr Join::createRestoreJoin(size_t max_bytes_before_external_j false, 0, max_bytes_before_external_join_, - hash_join_spill_context->createBuildSpillConfig(fmt::format("{}_hash_join_{}_build", log->identifier(), restore_round + 1)), - hash_join_spill_context->createProbeSpillConfig(fmt::format("{}_hash_join_{}_probe", log->identifier(), restore_round + 1)), + hash_join_spill_context->createBuildSpillConfig( + fmt::format("{}_hash_join_{}_build", log->identifier(), restore_round + 1)), + hash_join_spill_context->createProbeSpillConfig( + fmt::format("{}_hash_join_{}_probe", log->identifier(), restore_round + 1)), join_restore_concurrency, tidb_output_column_names, collators, @@ -519,7 +547,9 @@ void Join::insertFromBlock(const Block & block, size_t stream_index) auto ret = hash_join_spill_context->updatePartitionRevocableMemory(i, join_partition->revocableBytes()); if (ret) { - RUNTIME_CHECK_MSG(hash_join_spill_context->isPartitionSpilled(i), "Join spill should not triggered here"); + RUNTIME_CHECK_MSG( + hash_join_spill_context->isPartitionSpilled(i), + "Join spill should not be triggered
here"); blocks_to_spill = join_partition->trySpillBuildPartition(partition_lock); } else @@ -640,7 +670,18 @@ void Join::insertFromBlockInternal(Block * stored_block, size_t stream_index) if (enable_join_spill) assert(partitions[stream_index]->getPartitionPool() != nullptr); /// Fill the hash table. - JoinPartition::insertBlockIntoMaps(partitions, rows, key_columns, key_sizes, collators, stored_block, null_map, stream_index, getBuildConcurrency(), enable_fine_grained_shuffle, enable_join_spill); + JoinPartition::insertBlockIntoMaps( + partitions, + rows, + key_columns, + key_sizes, + collators, + stored_block, + null_map, + stream_index, + getBuildConcurrency(), + enable_fine_grained_shuffle, + enable_join_spill); } // generator in runtime filter @@ -674,7 +715,11 @@ void Join::cancelRuntimeFilter(const String & reason) } } -void mergeNullAndFilterResult(Block & block, ColumnVector::Container & filter_column, const String & filter_column_name, bool null_as_true) +void mergeNullAndFilterResult( + Block & block, + ColumnVector::Container & filter_column, + const String & filter_column_name, + bool null_as_true) { if (filter_column_name.empty()) return; @@ -712,7 +757,11 @@ void mergeNullAndFilterResult(Block & block, ColumnVector::Container & fi * @param left_table_columns * @param right_table_columns */ -void Join::handleOtherConditions(Block & block, std::unique_ptr & anti_filter, std::unique_ptr & offsets_to_replicate, const std::vector & right_table_columns) const +void Join::handleOtherConditions( + Block & block, + std::unique_ptr & anti_filter, + std::unique_ptr & offsets_to_replicate, + const std::vector & right_table_columns) const { non_equal_conditions.other_cond_expr->execute(block); @@ -727,18 +776,24 @@ void Join::handleOtherConditions(Block & block, std::unique_ptr { const auto helper_pos = block.getPositionByName(match_helper_name); - const auto * old_match_nullable = checkAndGetColumn(block.safeGetByPosition(helper_pos).column.get()); - const auto & old_match_vec = static_cast *>(old_match_nullable->getNestedColumnPtr().get())->getData(); + const auto * old_match_nullable + = checkAndGetColumn(block.safeGetByPosition(helper_pos).column.get()); + const auto & old_match_vec + = static_cast *>(old_match_nullable->getNestedColumnPtr().get())->getData(); { /// we assume there is no null value in the `match-helper` column after adder<>(). 
- if (!mem_utils::memoryIsZero(old_match_nullable->getNullMapData().data(), old_match_nullable->getNullMapData().size())) + if (!mem_utils::memoryIsZero( + old_match_nullable->getNullMapData().data(), + old_match_nullable->getNullMapData().size())) throw Exception("There shouldn't be null before merging other conditions.", ErrorCodes::LOGICAL_ERROR); } const auto rows = offsets_to_replicate->size(); if (old_match_vec.size() != rows) - throw Exception("Size of column match-helper must be equal to column size of left block.", ErrorCodes::LOGICAL_ERROR); + throw Exception( + "Size of column match-helper must be equal to column size of left block.", + ErrorCodes::LOGICAL_ERROR); auto match_col = ColumnInt8::create(rows, 0); auto & match_vec = match_col->getData(); @@ -795,7 +850,8 @@ void Join::handleOtherConditions(Block & block, std::unique_ptr for (size_t i = 0; i < block.columns(); ++i) if (i != helper_pos) block.getByPosition(i).column = block.getByPosition(i).column->filter(row_filter, -1); - block.safeGetByPosition(helper_pos).column = ColumnNullable::create(std::move(match_col), std::move(match_nullmap)); + block.safeGetByPosition(helper_pos).column + = ColumnNullable::create(std::move(match_col), std::move(match_nullmap)); return; } @@ -805,7 +861,8 @@ void Join::handleOtherConditions(Block & block, std::unique_ptr /// be returned, if other_eq_filter_from_in_column returns true or null, this row should not be returned. mergeNullAndFilterResult(block, filter, non_equal_conditions.other_eq_cond_from_in_name, isAntiJoin(kind)); - if ((isInnerJoin(kind) && original_strictness == ASTTableJoin::Strictness::All) || isNecessaryKindToUseRowFlaggedHashMap(kind)) + if ((isInnerJoin(kind) && original_strictness == ASTTableJoin::Strictness::All) + || isNecessaryKindToUseRowFlaggedHashMap(kind)) { /// inner | rightSemi | rightAnti | rightOuter join, just use other_filter_column to filter result for (size_t i = 0; i < block.columns(); ++i) @@ -858,7 +915,8 @@ void Join::handleOtherConditions(Block & block, std::unique_ptr for (size_t right_table_column : right_table_columns) { auto & column = block.getByPosition(right_table_column); - auto full_column = column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column; + auto full_column + = column.column->isColumnConst() ?
column.column->convertToFullColumnIfConst() : column.column; if (!full_column->isColumnNullable()) { throw Exception("Should not reach here, the right table column for left join must be nullable"); @@ -930,7 +988,8 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & if (matched_row_count_in_current_block > 0) { for (size_t i = 0; i < block.columns(); ++i) - block.safeGetByPosition(i).column = block.safeGetByPosition(i).column->filter(filter, matched_row_count_in_current_block); + block.safeGetByPosition(i).column + = block.safeGetByPosition(i).column->filter(filter, matched_row_count_in_current_block); } else { @@ -945,7 +1004,8 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & if (matched_row_count_in_current_block > 0) { for (size_t i = 0; i < block.columns(); ++i) - block.safeGetByPosition(i).column = block.safeGetByPosition(i).column->filter(filter, matched_row_count_in_current_block); + block.safeGetByPosition(i).column + = block.safeGetByPosition(i).column->filter(filter, matched_row_count_in_current_block); } else if (probe_process_info.isCurrentProbeRowFinished() && !probe_process_info.has_row_matched) { @@ -956,7 +1016,8 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & for (size_t right_table_column : probe_process_info.right_column_index) { auto & column = block.getByPosition(right_table_column); - auto full_column = column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column; + auto full_column + = column.column->isColumnConst() ? column.column->convertToFullColumnIfConst() : column.column; if (!full_column->isColumnNullable()) { throw Exception("Should not reach here, the right table column for left join must be nullable"); @@ -1022,7 +1083,8 @@ void Join::handleOtherConditionsForOneProbeRow(Block & block, ProbeProcessInfo & match_vec[0] = 1; else if (probe_process_info.has_row_null) match_nullmap_vec[0] = 1; - block.getByName(match_helper_name).column = ColumnNullable::create(std::move(match_col), std::move(match_nullmap)); + block.getByName(match_helper_name).column + = ColumnNullable::create(std::move(match_col), std::move(match_nullmap)); probe_process_info.finishCurrentProbeRow(); } else @@ -1072,7 +1134,10 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info) const for (size_t i = 0; i < num_columns_to_add; ++i) { const ColumnWithTypeAndName & src_column = sample_block_with_columns_to_add.getByPosition(i); - RUNTIME_CHECK_MSG(!block.has(src_column.name), "block from probe side has a column with the same name: {} as a column in sample_block_with_columns_to_add", src_column.name); + RUNTIME_CHECK_MSG( + !block.has(src_column.name), + "block from probe side has a column with the same name: {} as a column in sample_block_with_columns_to_add", + src_column.name); added_columns.push_back(src_column.column->cloneEmpty()); if (src_column.type && src_column.type->haveMaximumSizeOfValue()) @@ -1098,8 +1163,28 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info) const auto & offsets_to_replicate = probe_process_info.offsets_to_replicate; bool enable_spill_join = isEnableSpill(); - JoinBuildInfo join_build_info{enable_fine_grained_shuffle, fine_grained_shuffle_count, enable_spill_join, hash_join_spill_context->isSpilled(), build_concurrency, restore_round}; - JoinPartition::probeBlock(partitions, rows, probe_process_info.key_columns, key_sizes, added_columns, probe_process_info.null_map, filter, current_offset, 
offsets_to_replicate, right_indexes, collators, join_build_info, probe_process_info, flag_mapped_entry_helper_column); + JoinBuildInfo join_build_info{ + enable_fine_grained_shuffle, + fine_grained_shuffle_count, + enable_spill_join, + hash_join_spill_context->isSpilled(), + build_concurrency, + restore_round}; + JoinPartition::probeBlock( + partitions, + rows, + probe_process_info.key_columns, + key_sizes, + added_columns, + probe_process_info.null_map, + filter, + current_offset, + offsets_to_replicate, + right_indexes, + collators, + join_build_info, + probe_process_info, + flag_mapped_entry_helper_column); FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_join_prob_failpoint); /// For RIGHT_SEMI/RIGHT_ANTI join without other conditions, hash table has been marked already, just return empty build table header if (isRightSemiFamily(kind) && !flag_mapped_entry_helper_column) @@ -1113,7 +1198,10 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info) const block.insert(ColumnWithTypeAndName(std::move(added_columns[i]), sample_col.type, sample_col.name)); } if (flag_mapped_entry_helper_column) - block.insert(ColumnWithTypeAndName(std::move(flag_mapped_entry_helper_column), flag_mapped_entry_helper_type, flag_mapped_entry_helper_name)); + block.insert(ColumnWithTypeAndName( + std::move(flag_mapped_entry_helper_column), + flag_mapped_entry_helper_type, + flag_mapped_entry_helper_name)); size_t process_rows = probe_process_info.end_row - probe_process_info.start_row; @@ -1134,7 +1222,10 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info) const { for (size_t i = 0; i < existing_columns; ++i) { - block.safeGetByPosition(i).column = block.safeGetByPosition(i).column->replicateRange(probe_process_info.start_row, probe_process_info.end_row, *offsets_to_replicate); + block.safeGetByPosition(i).column = block.safeGetByPosition(i).column->replicateRange( + probe_process_info.start_row, + probe_process_info.end_row, + *offsets_to_replicate); } if (rows != process_rows) @@ -1144,7 +1235,9 @@ Block Join::doJoinBlockHash(ProbeProcessInfo & probe_process_info) const auto helper_col = block.getByName(match_helper_name).column; helper_col = helper_col->cut(probe_process_info.start_row, probe_process_info.end_row); } - offsets_to_replicate->assign(offsets_to_replicate->begin() + probe_process_info.start_row, offsets_to_replicate->begin() + probe_process_info.end_row); + offsets_to_replicate->assign( + offsets_to_replicate->begin() + probe_process_info.start_row, + offsets_to_replicate->begin() + probe_process_info.end_row); } } } @@ -1217,7 +1310,8 @@ Block Join::joinBlockHash(ProbeProcessInfo & probe_process_info) const /// exit the while loop if /// 1. probe_process_info.all_rows_joined_finish is true, which means all the rows in the current block are processed /// 2.
the block may be expanded after join and result_rows exceeds the min_result_block_size - if (probe_process_info.all_rows_joined_finish || (may_probe_side_expanded_after_join && result_rows >= probe_process_info.min_result_block_size)) + if (probe_process_info.all_rows_joined_finish + || (may_probe_side_expanded_after_join && result_rows >= probe_process_info.min_result_block_size)) break; } assert(!result_blocks.empty()); @@ -1239,14 +1333,19 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const { probe_process_info.cutFilterAndOffsetVector(probe_process_info.start_row, probe_process_info.end_row); } - handleOtherConditions(block, probe_process_info.filter, probe_process_info.offsets_to_replicate, probe_process_info.right_column_index); + handleOtherConditions( + block, + probe_process_info.filter, + probe_process_info.offsets_to_replicate, + probe_process_info.right_column_index); } return block; } else if (cross_probe_mode == CrossProbeMode::SHALLOW_COPY_RIGHT_BLOCK) { probe_process_info.updateStartRow(); - auto [block, is_matched_rows] = crossProbeBlockShallowCopyRightBlock(kind, strictness, probe_process_info, original_blocks); + auto [block, is_matched_rows] + = crossProbeBlockShallowCopyRightBlock(kind, strictness, probe_process_info, original_blocks); if (is_matched_rows) { if (non_equal_conditions.other_cond_expr != nullptr) @@ -1263,15 +1362,21 @@ Block Join::doJoinBlockCross(ProbeProcessInfo & probe_process_info) const } if (isLeftOuterSemiFamily(kind)) { - auto helper_index = probe_process_info.block.columns() + probe_process_info.right_column_index.size() - 1; + auto helper_index + = probe_process_info.block.columns() + probe_process_info.right_column_index.size() - 1; if (block.getByPosition(helper_index).column->isColumnConst()) - block.getByPosition(helper_index).column = block.getByPosition(helper_index).column->convertToFullColumnIfConst(); + block.getByPosition(helper_index).column + = block.getByPosition(helper_index).column->convertToFullColumnIfConst(); } } else if (non_equal_conditions.other_cond_expr != nullptr) { probe_process_info.cutFilterAndOffsetVector(0, block.rows()); - handleOtherConditions(block, probe_process_info.filter, probe_process_info.offsets_to_replicate, probe_process_info.right_column_index); + handleOtherConditions( + block, + probe_process_info.filter, + probe_process_info.offsets_to_replicate, + probe_process_info.right_column_index); } return block; } @@ -1302,7 +1407,8 @@ Block Join::joinBlockCross(ProbeProcessInfo & probe_process_info) const block = removeUselessColumn(block); result_rows += block.rows(); result_blocks.push_back(std::move(block)); - if (probe_process_info.all_rows_joined_finish || (may_probe_side_expanded_after_join && result_rows >= probe_process_info.min_result_block_size)) + if (probe_process_info.all_rows_joined_finish + || (may_probe_side_expanded_after_join && result_rows >= probe_process_info.min_result_block_size)) break; } @@ -1346,24 +1452,63 @@ Block Join::joinBlockNullAware(ProbeProcessInfo & probe_process_info) const for (size_t i = 0; i < num_columns_to_add; ++i) { const ColumnWithTypeAndName & src_column = sample_block_with_columns_to_add.getByPosition(i); - RUNTIME_CHECK_MSG(!block.has(src_column.name), "block from probe side has a column with the same name: {} as a column in sample_block_with_columns_to_add", src_column.name); + RUNTIME_CHECK_MSG( + !block.has(src_column.name), + "block from probe side has a column with the same name: {} as a column in sample_block_with_columns_to_add", 
+ src_column.name); block.insert(src_column); } using enum ASTTableJoin::Strictness; using enum ASTTableJoin::Kind; if (kind == NullAware_Anti && strictness == All) - joinBlockNullAwareImpl(block, existing_columns, key_columns, null_map, filter_map, all_key_null_map); + joinBlockNullAwareImpl( + block, + existing_columns, + key_columns, + null_map, + filter_map, + all_key_null_map); else if (kind == NullAware_Anti && strictness == Any) - joinBlockNullAwareImpl(block, existing_columns, key_columns, null_map, filter_map, all_key_null_map); + joinBlockNullAwareImpl( + block, + existing_columns, + key_columns, + null_map, + filter_map, + all_key_null_map); else if (kind == NullAware_LeftOuterSemi && strictness == All) - joinBlockNullAwareImpl(block, existing_columns, key_columns, null_map, filter_map, all_key_null_map); + joinBlockNullAwareImpl( + block, + existing_columns, + key_columns, + null_map, + filter_map, + all_key_null_map); else if (kind == NullAware_LeftOuterSemi && strictness == Any) - joinBlockNullAwareImpl(block, existing_columns, key_columns, null_map, filter_map, all_key_null_map); + joinBlockNullAwareImpl( + block, + existing_columns, + key_columns, + null_map, + filter_map, + all_key_null_map); else if (kind == NullAware_LeftOuterAnti && strictness == All) - joinBlockNullAwareImpl(block, existing_columns, key_columns, null_map, filter_map, all_key_null_map); + joinBlockNullAwareImpl( + block, + existing_columns, + key_columns, + null_map, + filter_map, + all_key_null_map); else if (kind == NullAware_LeftOuterAnti && strictness == Any) - joinBlockNullAwareImpl(block, existing_columns, key_columns, null_map, filter_map, all_key_null_map); + joinBlockNullAwareImpl( + block, + existing_columns, + key_columns, + null_map, + filter_map, + all_key_null_map); else throw Exception("Logical error: unknown combination of JOIN", ErrorCodes::LOGICAL_ERROR); @@ -1390,7 +1535,11 @@ void Join::joinBlockNullAwareImpl( null_rows[i] = partitions[i]->getRowsNotInsertedToMap(); NALeftSideInfo left_side_info(null_map, filter_map, all_key_null_map); - NARightSideInfo right_side_info(right_has_all_key_null_row.load(std::memory_order_relaxed), right_table_is_empty.load(std::memory_order_relaxed), null_key_check_all_blocks_directly, null_rows); + NARightSideInfo right_side_info( + right_has_all_key_null_row.load(std::memory_order_relaxed), + right_table_is_empty.load(std::memory_order_relaxed), + null_key_check_all_blocks_directly, + null_rows); auto [res, res_list] = JoinPartition::probeBlockNullAware( partitions, block, @@ -1406,14 +1555,8 @@ void Join::joinBlockNullAwareImpl( if (!res_list.empty()) { - NASemiJoinHelper helper( - block, - left_columns, - right_columns, - blocks, - null_rows, - max_block_size, - non_equal_conditions); + NASemiJoinHelper + helper(block, left_columns, right_columns, blocks, null_rows, max_block_size, non_equal_conditions); helper.joinResult(res_list); @@ -1433,7 +1576,8 @@ void Join::joinBlockNullAwareImpl( PaddedPODArray * left_semi_column_data = nullptr; PaddedPODArray * left_semi_null_map = nullptr; - if constexpr (KIND == ASTTableJoin::Kind::NullAware_LeftOuterSemi || KIND == ASTTableJoin::Kind::NullAware_LeftOuterAnti) + if constexpr ( + KIND == ASTTableJoin::Kind::NullAware_LeftOuterSemi || KIND == ASTTableJoin::Kind::NullAware_LeftOuterAnti) { auto * left_semi_column = typeid_cast(added_columns[right_columns - 1].get()); left_semi_column_data = &typeid_cast &>(left_semi_column->getNestedColumn()).getData(); @@ -1559,7 +1703,8 @@ void 
Join::workAfterBuildFinish(size_t stream_index) if (isSpilled()) { // TODO support runtime filter with spill. - cancelRuntimeFilter("Currently runtime filter is not compatible with join spill, so cancel runtime filter here."); + cancelRuntimeFilter( + "Currently runtime filter is not compatible with join spill, so cancel runtime filter here."); } else { @@ -1573,10 +1718,9 @@ void Join::workAfterBuildFinish(size_t stream_index) spilled_partition_index, partitions[spilled_partition_index]->trySpillBuildPartition(), stream_index); - has_build_data_in_memory = std::any_of( - partitions.cbegin(), - partitions.cend(), - [](const auto & p) { return !p->isSpill() && p->hasBuildData(); }); + has_build_data_in_memory = std::any_of(partitions.cbegin(), partitions.cend(), [](const auto & p) { + return !p->isSpill() && p->hasBuildData(); + }); } else { @@ -1612,7 +1756,8 @@ void Join::finalizeNullAwareSemiFamilyBuild() if (unlikely(is_test)) null_key_check_all_blocks_directly = false; else - null_key_check_all_blocks_directly = static_cast(null_rows_size) > static_cast(total_input_build_rows) / 3.0; + null_key_check_all_blocks_directly + = static_cast(null_rows_size) > static_cast(total_input_build_rows) / 3.0; } void Join::finalizeCrossJoinBuild() @@ -1626,7 +1771,8 @@ void Join::finalizeCrossJoinBuild() if (strictness == ASTTableJoin::Strictness::Any) { /// for cross any join, at most 1 row is added - right_rows_to_be_added_when_matched_for_cross_join = std::min(right_rows_to_be_added_when_matched_for_cross_join, 1); + right_rows_to_be_added_when_matched_for_cross_join + = std::min(right_rows_to_be_added_when_matched_for_cross_join, 1); } else if (blocks.size() > 1 && right_rows_to_be_added_when_matched_for_cross_join <= max_block_size) { @@ -1674,9 +1820,7 @@ void Join::workAfterProbeFinish(size_t stream_index) void Join::waitUntilAllBuildFinished() const { std::unique_lock lock(build_probe_mutex); - build_cv.wait(lock, [&]() { - return build_finished || meet_error || skip_wait; - }); + build_cv.wait(lock, [&]() { return build_finished || meet_error || skip_wait; }); if (meet_error) throw Exception(error_message); } @@ -1710,9 +1854,7 @@ void Join::finalizeProbe() void Join::waitUntilAllProbeFinished() const { std::unique_lock lock(build_probe_mutex); - probe_cv.wait(lock, [&]() { - return probe_finished || meet_error || skip_wait; - }); + probe_cv.wait(lock, [&]() { return probe_finished || meet_error || skip_wait; }); if (meet_error) throw Exception(error_message); } @@ -1777,8 +1919,10 @@ Block Join::joinBlock(ProbeProcessInfo & probe_process_info, bool dry_run) const /// for (cartesian)antiLeftSemi join, the meaning of "match-helper" is `non-matched` instead of `matched`. 
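The branch below implements that inversion: only the non-null helper values are flipped, and the null map is moved over unchanged, so a NULL helper (keys matched but the other condition evaluated to NULL) stays NULL. A toy rendition of the same transformation on plain vectors (illustrative values only):

#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<std::int8_t> matched = {1, 0, 0};   // nested data of the match-helper column
    std::vector<std::uint8_t> null_map = {0, 0, 1}; // 1 marks a NULL helper value
    std::vector<std::int8_t> non_matched(matched.size());
    for (std::size_t i = 0; i < matched.size(); ++i)
        non_matched[i] = !matched[i]; // row 2 flips too, but null_map keeps it NULL downstream
    return 0;
}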
if (kind == LeftOuterAnti || kind == Cross_LeftOuterAnti) { - const auto * nullable_column = checkAndGetColumn(block.getByName(match_helper_name).column.get()); - const auto & vec_matched = static_cast *>(nullable_column->getNestedColumnPtr().get())->getData(); + const auto * nullable_column + = checkAndGetColumn(block.getByName(match_helper_name).column.get()); + const auto & vec_matched + = static_cast *>(nullable_column->getNestedColumnPtr().get())->getData(); auto col_non_matched = ColumnInt8::create(vec_matched.size()); auto & vec_non_matched = col_non_matched->getData(); @@ -1786,15 +1930,25 @@ Block Join::joinBlock(ProbeProcessInfo & probe_process_info, bool dry_run) const for (size_t i = 0; i < vec_matched.size(); ++i) vec_non_matched[i] = !vec_matched[i]; - block.getByName(match_helper_name).column = ColumnNullable::create(std::move(col_non_matched), std::move(nullable_column->getNullMapColumnPtr())); + block.getByName(match_helper_name).column + = ColumnNullable::create(std::move(col_non_matched), std::move(nullable_column->getNullMapColumnPtr())); } return block; } -BlockInputStreamPtr Join::createScanHashMapAfterProbeStream(const Block & left_sample_block, size_t index, size_t step, size_t max_block_size_) const +BlockInputStreamPtr Join::createScanHashMapAfterProbeStream( + const Block & left_sample_block, + size_t index, + size_t step, + size_t max_block_size_) const { - return std::make_shared(*this, left_sample_block, index, step, max_block_size_); + return std::make_shared( + *this, + left_sample_block, + index, + step, + max_block_size_); } Blocks Join::dispatchBlock(const Strings & key_columns_names, const Block & from_block) @@ -1916,17 +2070,27 @@ void Join::spillMostMemoryUsedPartitionIfNeed(size_t stream_index) #ifdef DBMS_PUBLIC_GTEST // for join spill to disk gtest - if (restore_round == std::max(2, MAX_RESTORE_ROUND_IN_GTEST) - 1 && spilled_partition_indexes.size() >= partitions.size() / 2) + if (restore_round == std::max(2, MAX_RESTORE_ROUND_IN_GTEST) - 1 + && spilled_partition_indexes.size() >= partitions.size() / 2) return; #endif for (const auto & partition_to_be_spilled : hash_join_spill_context->getPartitionsToSpill()) { - RUNTIME_CHECK_MSG(build_concurrency > 1, "spilling is not is not supported when stream size = 1, please increase max_threads or set max_bytes_before_external_join = 0."); - LOG_INFO(log, fmt::format("Join with restore round: {}, used {} bytes, will spill partition: {}.", restore_round, getTotalByteCount(), partition_to_be_spilled)); + RUNTIME_CHECK_MSG( + build_concurrency > 1, + "spilling is not supported when stream size = 1, please increase max_threads or set " + "max_bytes_before_external_join = 0."); + LOG_INFO( + log, + fmt::format( + "Join with restore round: {}, used {} bytes, will spill partition: {}.", + restore_round, + getTotalByteCount(), + partition_to_be_spilled)); std::unique_lock partition_lock = partitions[partition_to_be_spilled]->lockPartition(); - hash_join_spill_context->markPartitionSpill(partition_to_be_spilled); + hash_join_spill_context->markPartitionSpilled(partition_to_be_spilled); partitions[partition_to_be_spilled]->releasePartitionPoolAndHashMap(partition_lock); auto blocks_to_spill = partitions[partition_to_be_spilled]->trySpillBuildPartition(partition_lock); spilled_partition_indexes.push_back(partition_to_be_spilled); @@ -1980,23 +2144,45 @@ std::optional Join::getOneRestoreStream(size_t max_block_size_) // build new restore infos.
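The restore path below first derives how many build threads the restore join gets. The formula from getRestoreJoinBuildConcurrency is easy to sanity-check with concrete numbers; a standalone rendition of the automatic branch shown earlier (hypothetical wrapper name, illustrative inputs in main):

#include <algorithm>
#include <cassert>
#include <cstddef>

std::size_t restoreBuildConcurrency(std::size_t total_partitions, std::size_t spilled_partitions, std::size_t total_concurrency)
{
    std::size_t unspilled_partitions = total_partitions - spilled_partitions;
    std::size_t max_concurrent_restore_partition = unspilled_partitions <= 1 ? 1 : unspilled_partitions - 1;
    // ceil division: how many rounds are needed to restore every spilled partition
    std::size_t restore_times = (spilled_partitions + max_concurrent_restore_partition - 1) / max_concurrent_restore_partition;
    std::size_t restore_build_concurrency = restore_times * total_concurrency / spilled_partitions;
    return std::max<std::size_t>(2, restore_build_concurrency); // at least 2 so a restore join can spill again
}

int main()
{
    assert(restoreBuildConcurrency(16, 4, 8) == 2);   // few spills: ceil(4/11) = 1 round, 1 * 8 / 4 = 2
    assert(restoreBuildConcurrency(16, 15, 16) == 16); // almost all spilled: 15 rounds, 15 * 16 / 15 = 16
    return 0;
}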
auto spilled_partition_index = spilled_partition_indexes.front(); - RUNTIME_CHECK_MSG(hash_join_spill_context->isPartitionSpilled(spilled_partition_index), "should not restore unspilled partition."); + RUNTIME_CHECK_MSG( + hash_join_spill_context->isPartitionSpilled(spilled_partition_index), + "should not restore unspilled partition."); if (restore_join_build_concurrency <= 0) - restore_join_build_concurrency = getRestoreJoinBuildConcurrency(partitions.size(), spilled_partition_indexes.size(), join_restore_concurrency, probe_concurrency); + restore_join_build_concurrency = getRestoreJoinBuildConcurrency( + partitions.size(), + spilled_partition_indexes.size(), + join_restore_concurrency, + probe_concurrency); /// for restore join we make sure that the restore_join_build_concurrency is at least 2, so it can be spilled again. /// And restore_join_build_concurrency should not be greater than probe_concurrency; otherwise some restore_stream will never be executed. RUNTIME_CHECK_MSG( - 2 <= restore_join_build_concurrency && restore_join_build_concurrency <= static_cast(probe_concurrency), + 2 <= restore_join_build_concurrency + && restore_join_build_concurrency <= static_cast(probe_concurrency), "restore_join_build_concurrency must be in [2, {}], but the current value is {}", probe_concurrency, restore_join_build_concurrency); - LOG_INFO(log, "Begin restore data from disk for hash join, partition {}, restore round {}, build concurrency {}.", spilled_partition_index, restore_round, restore_join_build_concurrency); + LOG_INFO( + log, + "Begin restore data from disk for hash join, partition {}, restore round {}, build concurrency {}.", + spilled_partition_index, + restore_round, + restore_join_build_concurrency); - auto restore_build_streams = hash_join_spill_context->getBuildSpiller()->restoreBlocks(spilled_partition_index, restore_join_build_concurrency, true); - RUNTIME_CHECK_MSG(restore_build_streams.size() == static_cast(restore_join_build_concurrency), "restore streams size must equal to restore_join_build_concurrency"); - auto restore_probe_streams = hash_join_spill_context->getProbeSpiller()->restoreBlocks(spilled_partition_index, restore_join_build_concurrency, true); - auto new_max_bytes_before_external_join = static_cast(hash_join_spill_context->getOperatorSpillThreshold() * (static_cast(restore_join_build_concurrency) / build_concurrency)); + auto restore_build_streams = hash_join_spill_context->getBuildSpiller()->restoreBlocks( + spilled_partition_index, + restore_join_build_concurrency, + true); + RUNTIME_CHECK_MSG( + restore_build_streams.size() == static_cast(restore_join_build_concurrency), + "restore streams size must be equal to restore_join_build_concurrency"); + auto restore_probe_streams = hash_join_spill_context->getProbeSpiller()->restoreBlocks( + spilled_partition_index, + restore_join_build_concurrency, + true); + auto new_max_bytes_before_external_join = static_cast( + hash_join_spill_context->getOperatorSpillThreshold() + * (static_cast(restore_join_build_concurrency) / build_concurrency)); restore_join = createRestoreJoin(std::max(1, new_max_bytes_before_external_join)); restore_join->initBuild(build_sample_block, restore_join_build_concurrency); restore_join->setInitActiveBuildThreads(); @@ -2007,11 +2193,20 @@ std::optional Join::getOneRestoreStream(size_t max_block_size_) { auto header = restore_probe_streams.back()->getHeader(); for (Int64 i = 0; i < restore_join_build_concurrency; ++i) - restore_scan_hash_map_streams[i] =
restore_join->createScanHashMapAfterProbeStream(header, i, restore_join_build_concurrency, max_block_size_); + restore_scan_hash_map_streams[i] = restore_join->createScanHashMapAfterProbeStream( + header, + i, + restore_join_build_concurrency, + max_block_size_); } for (Int64 i = 0; i < restore_join_build_concurrency; ++i) { - restore_infos.emplace_back(restore_join, i, std::move(restore_scan_hash_map_streams[i]), std::move(restore_build_streams[i]), std::move(restore_probe_streams[i])); + restore_infos.emplace_back( + restore_join, + i, + std::move(restore_scan_hash_map_streams[i]), + std::move(restore_build_streams[i]), + std::move(restore_probe_streams[i])); } } } diff --git a/dbms/src/Interpreters/Settings.h b/dbms/src/Interpreters/Settings.h index 725de74b5a6..0e403430b29 100644 --- a/dbms/src/Interpreters/Settings.h +++ b/dbms/src/Interpreters/Settings.h @@ -294,7 +294,7 @@ struct Settings M(SettingUInt64, manual_compact_more_until_ms, 60000, "Continuously compact more segments until reaching specified elapsed time. If 0 is specified, only one segment will be compacted each round.") \ M(SettingUInt64, max_bytes_before_external_join, 0, "max bytes used by join before spill, 0 as the default value, 0 means no limit") \ M(SettingInt64, join_restore_concurrency, 0, "join restore concurrency, negative value means restore join serially, 0 means TiFlash chooses restore concurrency automatically, 0 as the default value") \ - M(SettingUInt64, max_cached_data_bytes_in_spiller, 1024ULL * 1024 * 100, "Max cached data bytes in spiller before spilling, 100MB as the default value, 0 means no limit") \ + M(SettingUInt64, max_cached_data_bytes_in_spiller, 1024ULL * 1024 * 10, "Max cached data bytes in spiller before spilling, 10 MB as the default value, 0 means no limit") \ M(SettingUInt64, max_spilled_rows_per_file, 200000, "Max spilled data rows per spill file, 200000 as the default value, 0 means no limit.") \ M(SettingUInt64, max_spilled_bytes_per_file, 0, "Max spilled data bytes per spill file, 0 as the default value, 0 means no limit.") \ M(SettingBool, enable_planner, true, "Enable planner") \ @@ -311,7 +311,9 @@ struct Settings M(SettingUInt64, shallow_copy_cross_probe_threshold, 0, "minimum right rows to use shallow copy probe mode for cross join, default is max(1, max_block_size/10)") \ M(SettingInt64, max_buffered_bytes_in_executor, 200LL * 1024 * 1024, "The max buffered size in each executor, 0 means unlimited, use 200MB as the default value") \ M(SettingUInt64, ddl_sync_interval_seconds, 60, "The interval of background DDL sync schema in seconds") \ - M(SettingUInt64, ddl_restart_wait_seconds, 180, "The wait time for sync schema in seconds when restart") + M(SettingUInt64, ddl_restart_wait_seconds, 180, "The wait time for sync schema in seconds when restart") \ + M(SettingFloat, auto_memory_revoke_trigger_threshold, 0.7, "Trigger auto memory revocation when the memory usage is above this percentage.") \ + M(SettingFloat, auto_memory_revoke_target_threshold, 0.5, "When auto revoking memory, try to revoke enough that the memory usage falls below the target percentage afterwards.") // clang-format on @@ -355,4 +357,4 @@ struct Settings }; -} // namespace DB \ No newline at end of file +} // namespace DB diff --git a/dbms/src/Interpreters/SortSpillContext.cpp b/dbms/src/Interpreters/SortSpillContext.cpp index 3bd64e9f5f0..d20a03c882b 100644 --- a/dbms/src/Interpreters/SortSpillContext.cpp +++ b/dbms/src/Interpreters/SortSpillContext.cpp @@ -16,7 +16,10 @@ namespace DB {
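SortSpillContext below is the first operator spill context to implement the new triggerSpill contract: consume as much of the requested revocation as the operator can (all or nothing per operator, since a sort spills its whole in-memory state) and return the remainder for the next context in line. A standalone sketch of that cascade, ignoring the MIN_SPILL_THRESHOLD guard and the atomic status handshake:

#include <algorithm>
#include <cstdint>
#include <vector>

// Each element models one operator's revocable bytes; this mirrors how the
// task-level contexts thread the leftover request through their children.
std::int64_t cascadeTriggerSpill(std::vector<std::int64_t> & revocable, std::int64_t expected_released_memories)
{
    for (auto & mem : revocable)
    {
        if (expected_released_memories <= 0)
            break;
        expected_released_memories = std::max<std::int64_t>(expected_released_memories - mem, 0);
        mem = 0; // the operator's whole revocable memory is scheduled for spilling
    }
    return expected_released_memories; // leftover the caller still wants revoked
}

The `== MIN_SPILL_THRESHOLD * 0.5`-style assertions in the new tests below check exactly this leftover value.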
-SortSpillContext::SortSpillContext(const SpillConfig & spill_config_, UInt64 operator_spill_threshold_, const LoggerPtr & log) +SortSpillContext::SortSpillContext( + const SpillConfig & spill_config_, + UInt64 operator_spill_threshold_, + const LoggerPtr & log) : OperatorSpillContext(operator_spill_threshold_, "sort", log) , spill_config(spill_config_) {} @@ -31,11 +34,35 @@ bool SortSpillContext::updateRevocableMemory(Int64 new_value) if (!in_spillable_stage) return false; revocable_memory = new_value; - if (enable_spill && operator_spill_threshold > 0 && revocable_memory > static_cast(operator_spill_threshold)) + if (auto_spill_status == AutoSpillStatus::NEED_AUTO_SPILL + || (enable_spill && operator_spill_threshold > 0 + && revocable_memory > static_cast(operator_spill_threshold))) { revocable_memory = 0; return true; } return false; } + +Int64 SortSpillContext::triggerSpill(Int64 expected_released_memories) +{ + if (!in_spillable_stage || !isSpillEnabled()) + return expected_released_memories; + auto total_revocable_memory = getTotalRevocableMemory(); + if (total_revocable_memory >= MIN_SPILL_THRESHOLD) + { + AutoSpillStatus old_value = AutoSpillStatus::NO_NEED_AUTO_SPILL; + if (auto_spill_status.compare_exchange_strong(old_value, AutoSpillStatus::NEED_AUTO_SPILL)) + { + expected_released_memories = std::max(expected_released_memories - total_revocable_memory, 0); + revocable_memory = 0; + } + } + return expected_released_memories; +} + +void SortSpillContext::finishOneSpill() +{ + auto_spill_status = AutoSpillStatus::NO_NEED_AUTO_SPILL; +} } // namespace DB diff --git a/dbms/src/Interpreters/SortSpillContext.h b/dbms/src/Interpreters/SortSpillContext.h index dbfac78b36a..4b7bcba618b 100644 --- a/dbms/src/Interpreters/SortSpillContext.h +++ b/dbms/src/Interpreters/SortSpillContext.h @@ -24,6 +24,7 @@ class SortSpillContext final : public OperatorSpillContext { private: std::atomic revocable_memory; + std::atomic auto_spill_status{AutoSpillStatus::NO_NEED_AUTO_SPILL}; SpillConfig spill_config; SpillerPtr spiller; @@ -31,9 +32,13 @@ class SortSpillContext final : public OperatorSpillContext SortSpillContext(const SpillConfig & spill_config_, UInt64 operator_spill_threshold_, const LoggerPtr & log); void buildSpiller(const Block & input_schema); SpillerPtr & getSpiller() { return spiller; } - bool hasSpilledData() const { return spill_status != SpillStatus::NOT_SPILL && spiller->hasSpilledData(); } + void finishOneSpill(); + bool hasSpilledData() const { return isSpilled() && spiller->hasSpilledData(); } bool updateRevocableMemory(Int64 new_value); Int64 getTotalRevocableMemoryImpl() override { return revocable_memory; }; + Int64 triggerSpill(Int64 expected_released_memories) override; + bool needFinalSpill() const { return auto_spill_status == AutoSpillStatus::NEED_AUTO_SPILL; } + bool supportAutoTriggerSpill() const override { return true; } }; using SortSpillContextPtr = std::shared_ptr; diff --git a/dbms/src/Interpreters/tests/gtest_operator_spill_context.cpp b/dbms/src/Interpreters/tests/gtest_operator_spill_context.cpp index 23e02c0d91d..07619dce0a4 100644 --- a/dbms/src/Interpreters/tests/gtest_operator_spill_context.cpp +++ b/dbms/src/Interpreters/tests/gtest_operator_spill_context.cpp @@ -57,7 +57,7 @@ try { auto spill_context = std::make_shared(1, *spill_config_ptr, 1000, logger); ASSERT_TRUE(spill_context->isSpilled() == false); - spill_context->markSpill(); + spill_context->markSpilled(); ASSERT_TRUE(spill_context->isSpilled() == true); } CATCH @@ -96,7 +96,7 @@ try { auto 
spill_context = std::make_shared(*spill_config_ptr, 1000, logger); ASSERT_TRUE(spill_context->isSpilled() == false); - spill_context->markSpill(); + spill_context->markSpilled(); ASSERT_TRUE(spill_context->isSpilled() == true); } CATCH @@ -129,7 +129,7 @@ try auto spill_context = std::make_shared(*spill_config_ptr, *spill_config_ptr, 1000, logger); spill_context->init(10); ASSERT_FALSE(spill_context->isSpilled()); - spill_context->markPartitionSpill(0); + spill_context->markPartitionSpilled(0); ASSERT_TRUE(spill_context->isSpilled()); } CATCH diff --git a/dbms/src/Interpreters/tests/gtest_query_operator_spill_contexts.cpp b/dbms/src/Interpreters/tests/gtest_query_operator_spill_contexts.cpp new file mode 100644 index 00000000000..d28bc1eab1a --- /dev/null +++ b/dbms/src/Interpreters/tests/gtest_query_operator_spill_contexts.cpp @@ -0,0 +1,130 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace tests +{ +class TestQueryOperatorSpillContexts : public ::testing::Test +{ +protected: + void SetUp() override + { + logger = Logger::get("operator_spill_context_test"); + Poco::File spiller_dir(spill_dir); + auto key_manager = std::make_shared(false); + auto file_provider = std::make_shared(key_manager, false); + spill_config_ptr = std::make_shared(spill_dir, "test", 1024ULL * 1024 * 1024, 0, 0, file_provider); + } + void TearDown() override + { + Poco::File spiller_dir(spill_dir); + /// remove spiller dir if exists + if (spiller_dir.exists()) + spiller_dir.remove(true); + } + static String spill_dir; + std::shared_ptr spill_config_ptr; + LoggerPtr logger; +}; + +String TestQueryOperatorSpillContexts::spill_dir + = DB::tests::TiFlashTestEnv::getTemporaryPath("operator_spill_context_test"); + +TEST_F(TestQueryOperatorSpillContexts, TestRegisterTaskOperatorSpillContext) +try +{ + /// currently only sort_spill_context support auto spill + auto sort_spill_context = std::make_shared(*spill_config_ptr, 1000, logger); + std::shared_ptr task_operator_spill_contexts + = std::make_shared(); + task_operator_spill_contexts->registerOperatorSpillContext(sort_spill_context); + QueryOperatorSpillContexts query_operator_spill_contexts(MPPQueryId(0, 0, 0, 0)); + ASSERT_TRUE(query_operator_spill_contexts.getTaskOperatorSpillContextsCount() == 0); + query_operator_spill_contexts.registerTaskOperatorSpillContexts(task_operator_spill_contexts); + ASSERT_TRUE(query_operator_spill_contexts.getTaskOperatorSpillContextsCount() == 1); + query_operator_spill_contexts.registerTaskOperatorSpillContexts(task_operator_spill_contexts); + ASSERT_TRUE(query_operator_spill_contexts.getTaskOperatorSpillContextsCount() == 2); +} +CATCH + +TEST_F(TestQueryOperatorSpillContexts, TestTriggerSpill) +try +{ + auto sort_spill_context_1 = std::make_shared(*spill_config_ptr, 0, logger); + auto sort_spill_context_2 = std::make_shared(*spill_config_ptr, 0, logger); + auto 
sort_spill_context_3 = std::make_shared(*spill_config_ptr, 0, logger); + std::shared_ptr task_operator_spill_contexts_1 + = std::make_shared(); + std::shared_ptr task_operator_spill_contexts_2 + = std::make_shared(); + task_operator_spill_contexts_1->registerOperatorSpillContext(sort_spill_context_1); + task_operator_spill_contexts_2->registerOperatorSpillContext(sort_spill_context_2); + task_operator_spill_contexts_2->registerOperatorSpillContext(sort_spill_context_3); + + QueryOperatorSpillContexts query_operator_spill_contexts(MPPQueryId(0, 0, 0, 0)); + query_operator_spill_contexts.registerTaskOperatorSpillContexts(task_operator_spill_contexts_1); + query_operator_spill_contexts.registerTaskOperatorSpillContexts(task_operator_spill_contexts_2); + + /// trigger spill for all task_operator_spill_contexts + sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + sort_spill_context_3->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + ASSERT_TRUE(query_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD * 3) == 0); + ASSERT_TRUE(sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_TRUE(sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_TRUE(sort_spill_context_3->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + sort_spill_context_1->finishOneSpill(); + sort_spill_context_2->finishOneSpill(); + sort_spill_context_3->finishOneSpill(); + + /// trigger spill only for the task_operator_spill_contexts that use more revocable memory + sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD * 3); + sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD * 2); + sort_spill_context_3->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD * 2); + ASSERT_TRUE(query_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD * 4) == 0); + ASSERT_FALSE(sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_TRUE(sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_TRUE(sort_spill_context_3->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + sort_spill_context_1->finishOneSpill(); + sort_spill_context_2->finishOneSpill(); + sort_spill_context_3->finishOneSpill(); + + /// finished task_operator_spill_contexts are cleaned up automatically + sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD * 3); + sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD * 2); + sort_spill_context_3->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD * 2); + task_operator_spill_contexts_2->finish(); + ASSERT_TRUE( + query_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD * 4) + == OperatorSpillContext::MIN_SPILL_THRESHOLD); + ASSERT_TRUE(sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_FALSE(sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_FALSE(sort_spill_context_3->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_TRUE(query_operator_spill_contexts.getTaskOperatorSpillContextsCount() == 1); +}
+CATCH +} // namespace tests +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Interpreters/tests/gtest_task_operator_spill_contexts.cpp b/dbms/src/Interpreters/tests/gtest_task_operator_spill_contexts.cpp new file mode 100644 index 00000000000..8373a5ed8d4 --- /dev/null +++ b/dbms/src/Interpreters/tests/gtest_task_operator_spill_contexts.cpp @@ -0,0 +1,171 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace tests +{ +class TestTaskOperatorSpillContexts : public ::testing::Test +{ +protected: + void SetUp() override + { + logger = Logger::get("operator_spill_context_test"); + Poco::File spiller_dir(spill_dir); + auto key_manager = std::make_shared(false); + auto file_provider = std::make_shared(key_manager, false); + spill_config_ptr = std::make_shared(spill_dir, "test", 1024ULL * 1024 * 1024, 0, 0, file_provider); + } + void TearDown() override + { + Poco::File spiller_dir(spill_dir); + /// remove spiller dir if exists + if (spiller_dir.exists()) + spiller_dir.remove(true); + } + static String spill_dir; + std::shared_ptr spill_config_ptr; + LoggerPtr logger; +}; + +String TestTaskOperatorSpillContexts::spill_dir + = DB::tests::TiFlashTestEnv::getTemporaryPath("operator_spill_context_test"); + +TEST_F(TestTaskOperatorSpillContexts, TestRegisterOperatorSpillContext) +try +{ + /// currently only sort_spill_context support auto spill + auto agg_spill_context = std::make_shared(1, *spill_config_ptr, 1000, logger); + auto sort_spill_context = std::make_shared(*spill_config_ptr, 1000, logger); + auto join_spill_context + = std::make_shared(*spill_config_ptr, *spill_config_ptr, 1000, logger); + join_spill_context->init(10); + TaskOperatorSpillContexts task_operator_spill_contexts; + ASSERT_TRUE(task_operator_spill_contexts.operatorSpillContextCount() == 0); + task_operator_spill_contexts.registerOperatorSpillContext(agg_spill_context); + ASSERT_TRUE(task_operator_spill_contexts.operatorSpillContextCount() == 0); + task_operator_spill_contexts.registerOperatorSpillContext(sort_spill_context); + /// register will first add spill context to additional_operator_spill_contexts + ASSERT_TRUE(task_operator_spill_contexts.additionalOperatorSpillContextCount() == 1); + ASSERT_TRUE(task_operator_spill_contexts.operatorSpillContextCount() == 1); + /// additional_operator_spill_contexts has been merged to operator_spill_contexts + ASSERT_TRUE(task_operator_spill_contexts.additionalOperatorSpillContextCount() == 0); + task_operator_spill_contexts.registerOperatorSpillContext(join_spill_context); + ASSERT_TRUE(task_operator_spill_contexts.operatorSpillContextCount() == 1); + task_operator_spill_contexts.registerOperatorSpillContext(sort_spill_context); + ASSERT_TRUE(task_operator_spill_contexts.operatorSpillContextCount() == 2); +} +CATCH + +TEST_F(TestTaskOperatorSpillContexts, TestSpillAutoTrigger) +try +{ + 
auto sort_spill_context_1 = std::make_shared(*spill_config_ptr, 0, logger); + auto sort_spill_context_2 = std::make_shared(*spill_config_ptr, 0, logger); + TaskOperatorSpillContexts task_operator_spill_contexts; + task_operator_spill_contexts.registerOperatorSpillContext(sort_spill_context_1); + task_operator_spill_contexts.registerOperatorSpillContext(sort_spill_context_2); + + /// memory usage under OperatorSpillContext::MIN_SPILL_THRESHOLD will not trigger spill + sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD - 1); + sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD - 1); + ASSERT_TRUE( + task_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD / 2) + == OperatorSpillContext::MIN_SPILL_THRESHOLD / 2); + auto spill_1 = sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1); + auto spill_2 = sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1); + ASSERT_TRUE(!spill_1 && !spill_2); + + /// only one spill_context will trigger spill + sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + ASSERT_TRUE(task_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD / 2) <= 0); + spill_1 = sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1); + spill_2 = sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1); + ASSERT_TRUE(spill_1 ^ spill_2); + if (spill_1) + sort_spill_context_1->finishOneSpill(); + if (spill_2) + sort_spill_context_2->finishOneSpill(); + + /// both spill_contexts will be triggered to spill + sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + ASSERT_TRUE(task_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD * 1.5) <= 0); + spill_1 = sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1); + spill_2 = sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1); + ASSERT_TRUE(spill_1 && spill_2); + sort_spill_context_1->finishOneSpill(); + sort_spill_context_2->finishOneSpill(); + + /// total revocable memory is less than the expected released memory + sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + ASSERT_TRUE( + task_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD * 2.5) + == OperatorSpillContext::MIN_SPILL_THRESHOLD * 0.5); + spill_1 = sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1); + spill_2 = sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1); + ASSERT_TRUE(spill_1 && spill_2); + sort_spill_context_1->finishOneSpill(); + sort_spill_context_2->finishOneSpill(); + + /// one spill_context is no longer in the spillable stage + sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + sort_spill_context_1->finishSpillableStage(); + ASSERT_TRUE(
task_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD * 2.5) + == OperatorSpillContext::MIN_SPILL_THRESHOLD * 1.5); + ASSERT_FALSE(sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_TRUE(sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + sort_spill_context_1->finishOneSpill(); + + /// all spill_contexts are no longer in the spillable stage + sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + sort_spill_context_1->finishSpillableStage(); + sort_spill_context_2->finishSpillableStage(); + ASSERT_TRUE( + task_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD * 2.5) + == OperatorSpillContext::MIN_SPILL_THRESHOLD * 2.5); + ASSERT_FALSE(sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_FALSE(sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + + /// add a new spill_context at runtime + auto sort_spill_context_3 = std::make_shared(*spill_config_ptr, 0, logger); + task_operator_spill_contexts.registerOperatorSpillContext(sort_spill_context_3); + sort_spill_context_3->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD); + ASSERT_TRUE( + task_operator_spill_contexts.triggerAutoSpill(OperatorSpillContext::MIN_SPILL_THRESHOLD * 2.5) + == OperatorSpillContext::MIN_SPILL_THRESHOLD * 1.5); + ASSERT_FALSE(sort_spill_context_1->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_FALSE(sort_spill_context_2->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); + ASSERT_TRUE(sort_spill_context_3->updateRevocableMemory(OperatorSpillContext::MIN_SPILL_THRESHOLD + 1)); +} +CATCH +} // namespace tests +} // namespace DB \ No newline at end of file diff --git a/dbms/src/Operators/MergeSortTransformOp.cpp b/dbms/src/Operators/MergeSortTransformOp.cpp index 2f7da1e4b67..a73685c177a 100644 --- a/dbms/src/Operators/MergeSortTransformOp.cpp +++ b/dbms/src/Operators/MergeSortTransformOp.cpp @@ -105,9 +105,14 @@ OperatorStatus MergeSortTransformOp::fromPartialToSpill() // convert to spill phase. status = MergeSortStatus::SPILL; assert(!cached_handler); - sort_spill_context->markSpill(); + sort_spill_context->markSpilled(); cached_handler = sort_spill_context->getSpiller()->createCachedSpillHandler( - std::make_shared(sorted_blocks, order_desc, log->identifier(), max_block_size, limit), + std::make_shared( + sorted_blocks, + order_desc, + log->identifier(), + std::max(1, max_block_size / 10), + limit), /*partition_id=*/0, [&]() { return exec_context.isCancelled(); }); // fallback to partial phase. @@ -124,6 +129,7 @@ OperatorStatus MergeSortTransformOp::fromSpillToPartial() sum_bytes_in_blocks = 0; sorted_blocks.clear(); status = MergeSortStatus::PARTIAL; + sort_spill_context->finishOneSpill(); return OperatorStatus::NEED_INPUT; } @@ -136,9 +142,9 @@ OperatorStatus MergeSortTransformOp::transformImpl(Block & block) if unlikely (!block) { sort_spill_context->finishSpillableStage(); - return hasSpilledData() - ? fromPartialToRestore() - : fromPartialToMerge(block); + if (!sorted_blocks.empty() && sort_spill_context->needFinalSpill()) + return fromPartialToSpill(); + return hasSpilledData() ? fromPartialToRestore() : fromPartialToMerge(block); } // store the sorted block in `sorted_blocks`.
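With needFinalSpill in play, the end-of-input decision above becomes a three-way branch: flush the tail of sorted_blocks if an auto spill is still pending, otherwise restore from disk if anything was spilled, otherwise merge purely in memory. A compact sketch of that transition logic (phases named after the MergeSortStatus values used above; IO states omitted):

enum class Phase
{
    SPILL,   // flush the remaining sorted_blocks to disk first
    RESTORE, // merge spilled runs back from disk
    MERGE    // pure in-memory merge
};

// Mirrors the branch in transformImpl once the input stream is exhausted.
Phase onInputFinished(bool has_cached_blocks, bool need_final_spill, bool has_spilled_data)
{
    if (has_cached_blocks && need_final_spill)
        return Phase::SPILL;
    return has_spilled_data ? Phase::RESTORE : Phase::MERGE;
}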
@@ -165,9 +171,7 @@ OperatorStatus MergeSortTransformOp::tryOutputImpl(Block & block) case MergeSortStatus::SPILL: { assert(cached_handler); - return cached_handler->batchRead() - ? OperatorStatus::IO_OUT - : fromSpillToPartial(); + return cached_handler->batchRead() ? OperatorStatus::IO_OUT : fromSpillToPartial(); } case MergeSortStatus::MERGE: { @@ -209,9 +213,7 @@ OperatorStatus MergeSortTransformOp::executeIOImpl() } } -void MergeSortTransformOp::transformHeaderImpl(Block &) -{ -} +void MergeSortTransformOp::transformHeaderImpl(Block &) {} bool MergeSortTransformOp::RestoredResult::hasData() const { diff --git a/dbms/src/Operators/MergeSortTransformOp.h b/dbms/src/Operators/MergeSortTransformOp.h index 2bab3b1cc7f..4291f296a47 100644 --- a/dbms/src/Operators/MergeSortTransformOp.h +++ b/dbms/src/Operators/MergeSortTransformOp.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -39,12 +40,10 @@ class MergeSortTransformOp : public TransformOp , max_block_size(max_block_size_) { sort_spill_context = std::make_shared(spill_config, max_bytes_before_external_sort, log); + exec_context.registerOperatorSpillContext(sort_spill_context); } - String getName() const override - { - return "MergeSortTransformOp"; - } + String getName() const override { return "MergeSortTransformOp"; } protected: void operatePrefixImpl() override; diff --git a/dbms/src/Operators/Operator.cpp b/dbms/src/Operators/Operator.cpp index 7419f86b3af..c7742867246 100644 --- a/dbms/src/Operators/Operator.cpp +++ b/dbms/src/Operators/Operator.cpp @@ -67,7 +67,10 @@ OperatorStatus Operator::await() // [non-waiting, waiting, waiting, waiting, .., waiting, non-waiting] if (op_status != OperatorStatus::WAITING) + { + exec_context.triggerAutoSpill(); profile_info.update(); + } return op_status; } @@ -79,6 +82,7 @@ OperatorStatus Operator::executeIO() #ifndef NDEBUG assertOperatorStatus(op_status, {OperatorStatus::FINISHED, OperatorStatus::NEED_INPUT, OperatorStatus::HAS_OUTPUT}); #endif + exec_context.triggerAutoSpill(); profile_info.update(); FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_pipeline_model_operator_run_failpoint); return op_status; @@ -98,6 +102,7 @@ OperatorStatus SourceOp::read(Block & block) } assertOperatorStatus(op_status, {OperatorStatus::HAS_OUTPUT}); #endif + exec_context.triggerAutoSpill(); if (op_status == OperatorStatus::HAS_OUTPUT) profile_info.update(block); else @@ -119,6 +124,7 @@ OperatorStatus TransformOp::transform(Block & block) } assertOperatorStatus(op_status, {OperatorStatus::NEED_INPUT, OperatorStatus::HAS_OUTPUT}); #endif + exec_context.triggerAutoSpill(); if (op_status == OperatorStatus::HAS_OUTPUT) profile_info.update(block); else @@ -141,6 +147,7 @@ OperatorStatus TransformOp::tryOutput(Block & block) } assertOperatorStatus(op_status, {OperatorStatus::NEED_INPUT, OperatorStatus::HAS_OUTPUT}); #endif + exec_context.triggerAutoSpill(); if (op_status == OperatorStatus::HAS_OUTPUT) profile_info.update(block); else @@ -177,6 +184,7 @@ OperatorStatus SinkOp::write(Block && block) #ifndef NDEBUG assertOperatorStatus(op_status, {OperatorStatus::FINISHED, OperatorStatus::NEED_INPUT}); #endif + exec_context.triggerAutoSpill(); profile_info.update(); FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::random_pipeline_model_operator_run_failpoint); return op_status;
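Taken together, the Operator.cpp changes centralize the revocation check: every non-waiting await, read, transform, tryOutput, write, and executeIO step now ends with exec_context.triggerAutoSpill(), so memory is re-examined at operator-step granularity instead of at a few hand-picked call sites. A self-contained toy model of that design choice (all names hypothetical, not the real executor API):

#include <cstdint>
#include <functional>

// The executor, not each operator, owns the memory check; any step becomes a
// potential revocation point without the operator knowing about spilling at all.
struct StepRunner
{
    std::function<void()> trigger_auto_spill; // stands in for exec_context.triggerAutoSpill()

    template <typename Step>
    auto run(Step && step)
    {
        auto status = step();
        trigger_auto_spill(); // called right after the step, as in Operator.cpp above
        return status;
    }
};

int main()
{
    std::int64_t usage = 0;
    StepRunner runner{[&] {
        if (usage > 100)
            usage = 50; // pretend the spill contexts revoked memory
    }};
    for (int i = 0; i < 5; ++i)
        runner.run([&] { usage += 40; return 0; }); // each step allocates a bit
    return usage <= 100 ? 0 : 1;
}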