diff --git a/Cargo.lock b/Cargo.lock index cbf73b2a20c..2e43deb764b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9618,6 +9618,7 @@ dependencies = [ "paste", "reqwest 0.13.3", "rstest", + "static_assertions", "tempfile", "tracing", "tracing-subscriber", diff --git a/vortex-duckdb/Cargo.toml b/vortex-duckdb/Cargo.toml index b6bb0109861..0f5e6b4797e 100644 --- a/vortex-duckdb/Cargo.toml +++ b/vortex-duckdb/Cargo.toml @@ -35,6 +35,7 @@ num-traits = { workspace = true } object_store = { workspace = true, features = ["aws"] } parking_lot = { workspace = true } paste = { workspace = true } +static_assertions = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } url = { workspace = true } diff --git a/vortex-duckdb/build.rs b/vortex-duckdb/build.rs index 868b417bf6c..4564327af85 100644 --- a/vortex-duckdb/build.rs +++ b/vortex-duckdb/build.rs @@ -355,6 +355,7 @@ fn cpp(duckdb_include_dir: &Path) { .flags(["-Wall", "-Wextra", "-Wpedantic"]) .cpp(true) .include(duckdb_include_dir) + .include("include") .include("cpp/include") .files(SOURCE_FILES) .compile("vortex-duckdb-extras"); diff --git a/vortex-duckdb/cbindgen.toml b/vortex-duckdb/cbindgen.toml index a9be94109db..ac21467f881 100644 --- a/vortex-duckdb/cbindgen.toml +++ b/vortex-duckdb/cbindgen.toml @@ -1,4 +1,5 @@ language = "C" +cpp_compat = true header = """ // SPDX-License-Identifier: Apache-2.0 @@ -23,3 +24,6 @@ trailer = """ // clang-format on """ + +[fn] +prefix = "extern" diff --git a/vortex-duckdb/cpp/CMakeLists.txt b/vortex-duckdb/cpp/CMakeLists.txt index 9671d93dd6d..db6dd8aeaa1 100644 --- a/vortex-duckdb/cpp/CMakeLists.txt +++ b/vortex-duckdb/cpp/CMakeLists.txt @@ -39,7 +39,7 @@ else() ) endif() -include_directories(include ${DUCKDB_INCLUDE}) +include_directories(include ${DUCKDB_INCLUDE} ../include) # Auto-discover C++ source files file(GLOB CPP_SOURCES "*.cpp") diff --git a/vortex-duckdb/cpp/include/duckdb_vx/table_function.h 
b/vortex-duckdb/cpp/include/duckdb_vx/table_function.h index 4b60207a036..4299d3a3b00 100644 --- a/vortex-duckdb/cpp/include/duckdb_vx/table_function.h +++ b/vortex-duckdb/cpp/include/duckdb_vx/table_function.h @@ -1,17 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -/** - * We redefine a C API for DuckDB Table Functions in order to expose the full functionality of the C++ API. - * - * Since this C API has no stability requirements (it's versioned lock-step with the Rust bindings), we can - * take a transparent vtable struct to populate the C++ Table Function vtable. - */ #pragma once - -#include "error.h" -#include "table_filter.h" #include "duckdb_vx/data.h" +#include "duckdb_vx/error.h" +#include "duckdb_vx/expr.h" +#include "table_filter.h" #include #ifdef __cplusplus @@ -73,9 +67,7 @@ typedef struct { // Result data returned from the cardinality callback. typedef struct { idx_t estimated_cardinality; - idx_t max_cardinality; bool has_estimated_cardinality; - bool has_max_cardinality; } duckdb_vx_node_statistics; typedef struct { @@ -98,46 +90,7 @@ typedef struct { size_t file_index; } duckdb_vx_partition_data; -// vtable mimicking subset of TableFunction. 
-// See duckdb/include/function/tfunc.hpp -typedef struct { - const char *name; - const duckdb_logical_type *parameters; - size_t parameter_count; - - duckdb_vx_data (*bind)(duckdb_client_context ctx, - duckdb_vx_tfunc_bind_input input, - duckdb_vx_tfunc_bind_result result, - duckdb_vx_error *error_out); - - duckdb_vx_data (*bind_data_clone)(const void *bind_data, duckdb_vx_error *error_out); - - duckdb_vx_data (*init_global)(const duckdb_vx_tfunc_init_input *input, duckdb_vx_error *error_out); - - duckdb_vx_data (*init_local)(void *init_global_data); - - void (*function)(void *init_global_data, - void *init_local_data, - duckdb_data_chunk data_chunk_out, - duckdb_vx_error *error_out); - - bool (*statistics)(const void *bind_data, size_t column_index, duckdb_column_statistics *stats_out); - - void (*cardinality)(void *bind_data, duckdb_vx_node_statistics *node_stats_out); - - bool (*pushdown_complex_filter)(void *bind_data, duckdb_vx_expr expr, duckdb_vx_error *error_out); - - void (*to_string)(void *bind_data, duckdb_vx_string_map map); - - double (*table_scan_progress)(void *global_state); - - void (*get_partition_data)(void *init_global_data, - void *init_local_data, - duckdb_vx_partition_data *partition_data_out); -} duckdb_vx_tfunc_vtab_t; - -// A single function for configuring the DuckDB table function vtable. 
-duckdb_state duckdb_vx_tfunc_register(duckdb_database ffi_db, const duckdb_vx_tfunc_vtab_t *vtab); +duckdb_state duckdb_vx_register_table_functions(duckdb_database ffi_db); #ifdef __cplusplus } diff --git a/vortex-duckdb/cpp/table_function.cpp b/vortex-duckdb/cpp/table_function.cpp index 14edeb2dd70..0adba325452 100644 --- a/vortex-duckdb/cpp/table_function.cpp +++ b/vortex-duckdb/cpp/table_function.cpp @@ -5,6 +5,7 @@ #include "duckdb_vx/duckdb_diagnostics.h" #include "duckdb_vx/error.hpp" #include "duckdb_vx/table_function.h" +#include "vortex.h" DUCKDB_INCLUDES_BEGIN #include "duckdb.h" @@ -15,58 +16,42 @@ DUCKDB_INCLUDES_BEGIN #include "duckdb/main/capi/capi_internal.hpp" #include "duckdb/main/connection.hpp" #include "duckdb/parser/parsed_data/create_table_function_info.hpp" -#include "duckdb/planner/expression/bound_operator_expression.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/expression/bound_between_expression.hpp" -#include "duckdb/planner/expression/bound_conjunction_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" DUCKDB_INCLUDES_END +using namespace std::string_literals; using namespace duckdb; using vortex::CData; using vortex::IntoErrString; constexpr column_t COLUMN_IDENTIFIER_FILE_INDEX = MultiFileReader::COLUMN_IDENTIFIER_FILE_INDEX; constexpr column_t COLUMN_IDENTIFIER_FILE_ROW_NUMBER = MultiFileReader::COLUMN_IDENTIFIER_FILE_ROW_NUMBER; -struct CTableFunctionInfo final : TableFunctionInfo { - explicit CTableFunctionInfo(const duckdb_vx_tfunc_vtab_t &vtab) : vtab(vtab) { - } - - const duckdb_vx_tfunc_vtab_t vtab; -}; - struct CTableBindData final : FunctionData { - CTableBindData(const CTableFunctionInfo &info, - unique_ptr ffi_data_p, - const vector &types) - : info(info), ffi_data(std::move(ffi_data_p)), types(types) { + CTableBindData(unique_ptr ffi_data_p, const vector &types) + : ffi_data(std::move(ffi_data_p)), types(types) { } unique_ptr Copy() 
const override; bool Equals(const FunctionData &other_base) const override; - // Table function info lives for as long as TableFunction is alive as it's - // stored inside TableFunction, so it's safe to store a reference. - const CTableFunctionInfo &info; unique_ptr ffi_data; vector types; }; unique_ptr CTableBindData::Copy() const { duckdb_vx_error error_out = nullptr; - const auto copied_ffi_data = info.vtab.bind_data_clone(ffi_data->DataPtr(), &error_out); + const auto copied_ffi_data = duckdb_table_function_bind_data_clone(ffi_data->DataPtr(), &error_out); if (error_out) { throw BinderException(IntoErrString(error_out)); } auto ffi_data_p = unique_ptr(reinterpret_cast(copied_ffi_data)); - return make_uniq(info, std::move(ffi_data_p), types); + return make_uniq(std::move(ffi_data_p), types); } bool CTableBindData::Equals(const FunctionData &other_base) const { const CTableBindData &other = other_base.Cast(); // if "types" are different, "ffi_data" would also be different as it // contains types inside, so omit "types" from comparison. 
- return &info == &other.info && ffi_data.get() == other.ffi_data.get(); + return ffi_data.get() == other.ffi_data.get(); } struct CTableGlobalData final : GlobalTableFunctionState { @@ -87,12 +72,10 @@ struct CTableLocalData final : LocalTableFunctionState { unique_ptr ffi_data; }; -double table_scan_progress(ClientContext &, - const FunctionData *bind_data, - const GlobalTableFunctionState *global_state) { - auto &bind = bind_data->Cast(); +double +table_scan_progress(ClientContext &, const FunctionData *, const GlobalTableFunctionState *global_state) { void *const c_global_state = global_state->Cast().ffi_data->DataPtr(); - return bind.info.vtab.table_scan_progress(c_global_state); + return duckdb_table_function_scan_progress(c_global_state); } static Value &UnwrapValue(duckdb_value value) { @@ -152,7 +135,7 @@ unique_ptr statistics(ClientContext &, const FunctionData *bind_ void *const ffi_bind = bind.ffi_data->DataPtr(); duckdb_column_statistics statistics = {}; - if (!bind.info.vtab.statistics(ffi_bind, column_index, &statistics)) { + if (!duckdb_table_function_statistics(ffi_bind, column_index, &statistics)) { return {}; } @@ -204,21 +187,20 @@ unique_ptr c_bind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { - const auto &info = input.table_function.function_info->Cast(); CTableBindResult result = {return_types, names}; duckdb_vx_error error_out = nullptr; auto ctx = reinterpret_cast(&context); - auto ffi_bind_data = info.vtab.bind(ctx, - reinterpret_cast(&input), - reinterpret_cast(&result), - &error_out); + auto ffi_bind_data = duckdb_table_function_bind(ctx, + reinterpret_cast(&input), + reinterpret_cast(&result), + &error_out); if (error_out) { throw BinderException(IntoErrString(error_out)); } auto cdata = unique_ptr(reinterpret_cast(ffi_bind_data)); - return make_uniq(info, std::move(cdata), return_types); + return make_uniq(std::move(cdata), return_types); } unique_ptr c_init_global(ClientContext 
&context, TableFunctionInitInput &input) { @@ -235,7 +217,7 @@ unique_ptr c_init_global(ClientContext &context, Table }; duckdb_vx_error error_out = nullptr; - duckdb_vx_data ffi_global_data = bind.info.vtab.init_global(&ffi_input, &error_out); + duckdb_vx_data ffi_global_data = duckdb_table_function_init_global(&ffi_input, &error_out); if (error_out) { throw BinderException(IntoErrString(error_out)); } @@ -245,24 +227,21 @@ unique_ptr c_init_global(ClientContext &context, Table } unique_ptr -init_local(ExecutionContext &, TableFunctionInitInput &input, GlobalTableFunctionState *global_state) { - const auto &bind = input.bind_data->Cast(); +init_local(ExecutionContext &, TableFunctionInitInput &, GlobalTableFunctionState *global_state) { void *const ffi_global = global_state->Cast().ffi_data->DataPtr(); - duckdb_vx_data ffi_local_data = bind.info.vtab.init_local(ffi_global); + duckdb_vx_data ffi_local_data = duckdb_table_function_init_local(ffi_global); auto cdata = unique_ptr(reinterpret_cast(ffi_local_data)); return make_uniq(std::move(cdata)); } void function(ClientContext &, TableFunctionInput &input, DataChunk &output) { - const auto &bind = input.bind_data->Cast(); - void *const ffi_global = input.global_state->Cast().ffi_data->DataPtr(); void *const ffi_local = input.local_state->Cast().ffi_data->DataPtr(); duckdb_data_chunk chunk = reinterpret_cast(&output); duckdb_vx_error error_out = nullptr; - bind.info.vtab.function(ffi_global, ffi_local, chunk, &error_out); + duckdb_table_function_scan(ffi_global, ffi_local, chunk, &error_out); if (error_out) { throw InvalidInputException(IntoErrString(error_out)); } @@ -290,7 +269,7 @@ void pushdown_complex_filter(const FunctionData &bind_data, FilterVec &filters) for (auto iter = filters.begin(); iter != filters.end();) { duckdb_vx_expr ffi_expr = reinterpret_cast(iter->get()); - const bool pushed = bind.info.vtab.pushdown_complex_filter(ffi_bind, ffi_expr, &error_out); + const bool pushed = 
duckdb_table_function_pushdown_complex_filter(ffi_bind, ffi_expr, &error_out); if (error_out) { throw BinderException(IntoErrString(error_out)); } @@ -302,13 +281,12 @@ unique_ptr c_cardinality(ClientContext &, const FunctionData *bi auto &bind = bind_data->Cast(); duckdb_vx_node_statistics stats = {}; - bind.info.vtab.cardinality(bind.ffi_data->DataPtr(), &stats); + duckdb_table_function_cardinality(bind.ffi_data->DataPtr(), &stats); auto out = make_uniq(); out->has_estimated_cardinality = stats.has_estimated_cardinality; out->estimated_cardinality = stats.estimated_cardinality; - out->has_max_cardinality = stats.has_max_cardinality; - out->max_cardinality = stats.max_cardinality; + out->has_max_cardinality = false; return out; } @@ -357,11 +335,10 @@ TablePartitionInfo get_partition_info(ClientContext &, TableFunctionPartitionInp * each partition ~ exported array file_index is constant. */ OperatorPartitionData get_partition_data(ClientContext &, TableFunctionGetPartitionInput &input) { - auto &bind = input.bind_data->Cast(); void *const ffi_global = input.global_state->Cast().ffi_data->DataPtr(); void *const ffi_local = input.local_state->Cast().ffi_data->DataPtr(); duckdb_vx_partition_data partition_data; - bind.info.vtab.get_partition_data(ffi_global, ffi_local, &partition_data); + duckdb_table_function_get_partition_data(ffi_global, ffi_local, &partition_data); OperatorPartitionData out(partition_data.partition_index); @@ -390,21 +367,12 @@ InsertionOrderPreservingMap c_to_string(TableFunctionToStringInput &inpu InsertionOrderPreservingMap result; duckdb_vx_string_map ffi_map = reinterpret_cast(&result); void *const ffi_bind = input.bind_data->Cast().ffi_data->DataPtr(); - const auto &info = static_cast(*input.table_function.function_info); - info.vtab.to_string(ffi_bind, ffi_map); + duckdb_table_function_to_string(ffi_bind, ffi_map); return result; } -// pushdown_expression misses FunctionData so we can't place it in vtab -extern "C" bool 
duckdb_vx_pushdown_expression(duckdb_vx_expr expr); - -extern "C" duckdb_state duckdb_vx_tfunc_register(duckdb_database ffi_db, const duckdb_vx_tfunc_vtab_t *vtab) { - D_ASSERT(ffi_db); - D_ASSERT(vtab); - - const DatabaseWrapper &wrapper = *reinterpret_cast(ffi_db); - DatabaseInstance &db = *wrapper.database->instance; - TableFunction tf(vtab->name, {}, function, c_bind, c_init_global, init_local); +duckdb_state register_table_function(DatabaseInstance &db, LogicalType parameter, const std::string &name) { + TableFunction tf(name, {}, function, c_bind, c_init_global, init_local); tf.projection_pushdown = true; tf.filter_pushdown = true; @@ -412,7 +380,7 @@ extern "C" duckdb_state duckdb_vx_tfunc_register(duckdb_database ffi_db, const d tf.sampling_pushdown = false; tf.pushdown_expression = [](auto &, const auto &, Expression &expression) { - return duckdb_vx_pushdown_expression(reinterpret_cast(&expression)); + return duckdb_table_function_pushdown_expression(reinterpret_cast(&expression)); }; tf.pushdown_complex_filter = [](auto &, auto &, FunctionData *bind_data, FilterVec &filters) { pushdown_complex_filter(*bind_data, filters); @@ -442,12 +410,8 @@ extern "C" duckdb_state duckdb_vx_tfunc_register(duckdb_database ffi_db, const d }; }; - tf.arguments.resize(vtab->parameter_count); - for (size_t i = 0; i < vtab->parameter_count; i++) { - tf.arguments[i] = *reinterpret_cast(vtab->parameters[i]); - } - - tf.function_info = make_shared_ptr(*vtab); + tf.arguments.resize(1); + tf.arguments[0] = parameter; try { auto &system_catalog = Catalog::GetSystemCatalog(db); @@ -462,3 +426,18 @@ extern "C" duckdb_state duckdb_vx_tfunc_register(duckdb_database ffi_db, const d } return DuckDBSuccess; } + +extern "C" duckdb_state duckdb_vx_register_table_functions(duckdb_database ffi_db) { + D_ASSERT(ffi_db); + const DatabaseWrapper &wrapper = *reinterpret_cast(ffi_db); + DatabaseInstance &db = *wrapper.database->instance; + + for (LogicalType type : 
{LogicalType(LogicalType::VARCHAR), LogicalType::LIST(LogicalType::VARCHAR)}) { + for (const std::string &name : {"read_vortex"s, "vortex_scan"s}) { + if (register_table_function(db, type, name) == DuckDBError) { + return DuckDBError; + } + } + } + return DuckDBSuccess; +} diff --git a/vortex-duckdb/include/vortex.h b/vortex-duckdb/include/vortex.h index eb23b13fed4..5ac2f7afe77 100644 --- a/vortex-duckdb/include/vortex.h +++ b/vortex-duckdb/include/vortex.h @@ -14,6 +14,10 @@ extern "C" { #include "duckdb.h" +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + /** * Global symbol visibility in the Vortex extension: * - Rust functions use C ABI with "_rust" suffix (e.g., vortex_init_rust) @@ -25,20 +29,69 @@ extern "C" { * * The DuckDB extension ABI initialization function. */ -void vortex_init_rust(duckdb_database db); +extern void vortex_init_rust(duckdb_database db); /** * The DuckDB extension ABI version function. * This function returns the version of the DuckDB library the extension is built against. */ -const char *vortex_version_rust(void); +extern const char *vortex_version_rust(void); /** * An additional function we export to expose the version of the extension itself to C++ code. 
*/ -const char *vortex_extension_version_rust(void); +extern const char *vortex_extension_version_rust(void); + +extern void duckdb_table_function_to_string(void *bind_data, duckdb_vx_string_map map); + +extern +bool duckdb_table_function_statistics(const void *bind_data, + uintptr_t column_index, + duckdb_column_statistics *stats_out); + +extern double duckdb_table_function_scan_progress(void *global_state); + +extern +void duckdb_table_function_get_partition_data(void *global_init_data, + void *local_init_data, + duckdb_vx_partition_data *partition_data_out); + +extern +bool duckdb_table_function_pushdown_complex_filter(void *bind_data, + duckdb_vx_expr expr, + duckdb_vx_error *error_out); -bool duckdb_vx_pushdown_expression(duckdb_vx_expr expr); +extern +void duckdb_table_function_scan(void *global_init_data, + void *local_init_data, + duckdb_data_chunk output, + duckdb_vx_error *error_out); + +extern bool duckdb_table_function_pushdown_expression(duckdb_vx_expr expr); + +extern +void duckdb_table_function_cardinality(void *bind_data, + duckdb_vx_node_statistics *node_stats_out); + +extern +duckdb_vx_data duckdb_table_function_init_global(const duckdb_vx_tfunc_init_input *init_input, + duckdb_vx_error *error_out); + +extern duckdb_vx_data duckdb_table_function_init_local(void *global_init_data); + +extern +duckdb_vx_data duckdb_table_function_bind(duckdb_client_context ctx, + duckdb_vx_tfunc_bind_input bind_input, + duckdb_vx_tfunc_bind_result bind_result, + duckdb_vx_error *error_out); + +extern +duckdb_vx_data duckdb_table_function_bind_data_clone(const void *bind_data, + duckdb_vx_error *error_out); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus #ifdef __cplusplus } diff --git a/vortex-duckdb/src/column_statistics.rs b/vortex-duckdb/src/column_statistics.rs new file mode 100644 index 00000000000..ccc71eeade1 --- /dev/null +++ b/vortex-duckdb/src/column_statistics.rs @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: Apache-2.0 +// 
SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::stats::StatsSet; +use vortex::dtype::DType; +use vortex::error::VortexExpect as _; +use vortex::expr::stats::Precision; +use vortex::expr::stats::Stat; +use vortex::scalar::Scalar; +use vortex::scalar::ScalarValue; + +use crate::convert::ToDuckDBScalar as _; +use crate::duckdb::Value; + +#[derive(Debug, Default)] +pub struct ColumnStatistics { + pub min: Option, + pub max: Option, + pub max_string_length: u64, + pub has_null: bool, +} + +impl ColumnStatistics { + pub fn from(stats: &ColumnStatisticsAggregate, dtype: DType) -> Self { + let min = stats.min.as_ref().map(|value| { + let value = value.clone(); + Scalar::try_new(dtype.clone(), Some(value)) + .vortex_expect("scalar dtype and value are incompatible") + .try_to_duckdb_scalar() + .vortex_expect("can't convert Scalar to duckdb Value") + }); + let max = stats.max.as_ref().map(|value| { + Scalar::try_new(dtype.clone(), Some(value.clone())) + .vortex_expect("scalar dtype and value are incompatible") + .try_to_duckdb_scalar() + .vortex_expect("can't convert Scalar to duckdb Value") + }); + + let max_string_length = stats + .max_string_length + .map_or(0, |len| (1u64 << 63) | (len as u64)); + + // Useful estimate if we didn't get null count stats + let has_null = stats.has_null && dtype.is_nullable(); + + Self { + min, + max, + max_string_length, + has_null, + } + } +} + +#[derive(Default)] +pub struct ColumnStatisticsAggregate { + pub min: Option, + pub max: Option, + pub max_string_length: Option, + /// May be true if null count stat isn't present + pub has_null: bool, +} + +impl ColumnStatisticsAggregate { + pub fn new(stats: &StatsSet) -> Self { + let min = match stats.get(Stat::Min) { + Precision::Exact(min) => Some(min), + _ => None, + }; + let max = match stats.get(Stat::Max) { + Precision::Exact(max) => Some(max), + _ => None, + }; + + let max_string_length = + if let Precision::Exact(value) = 
stats.get(Stat::UncompressedSizeInBytes) { + // DuckDB's string length is u32; a size that does not fit is reported as + // unknown (None) instead of being silently truncated to a bogus bound. + u32::try_from(value.as_primitive().as_u64().vortex_expect("not a u64")).ok() + } else { + None + }; + + let has_null = match stats.get(Stat::NullCount) { + Precision::Exact(cnt) => cnt.as_primitive().as_u64().vortex_expect("not a u64") > 0, + _ => true, + }; + + Self { + min, + max, + max_string_length, + has_null, + } + } +} diff --git a/vortex-duckdb/src/convert/expr.rs b/vortex-duckdb/src/convert/expr.rs index 1adb64db181..2fdd6d9c033 100644 --- a/vortex-duckdb/src/convert/expr.rs +++ b/vortex-duckdb/src/convert/expr.rs @@ -117,6 +117,20 @@ pub(super) fn try_from_bound_expression_with_col_sub( try_from_expression_inner(value, Some(col_sub)) } +/* + * Called before pushdown_complex_filter or a table filter expression call. + * As we support complex filter pushdown, Duckdb pushes expressions to Vortex. + * However, it doesn't know what type of expressions we can handle. Here we list + * all expressions that are guaranteed to be converted to Vortex expressions. + * + * If we return true here, and expression is in the list for + * pushdown_complex_filter, we must handle it, or query engine will break. + * + * Example: we don't support substr() expression so we tell Duckdb we can't + * push it. + * Example: optional filters may fail to parse on our side (we return + * Ok(None)), so we don't allow pushing these. + */ pub fn can_push_expression(value: &duckdb::ExpressionRef) -> bool { let Some(value) = value.as_class() else { return false; diff --git a/vortex-duckdb/src/datasource.rs b/vortex-duckdb/src/datasource.rs deleted file mode 100644 index 36618bc8913..00000000000 --- a/vortex-duckdb/src/datasource.rs +++ /dev/null @@ -1,906 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Reusable logic for driving a [`DataSourceRef`] scan through DuckDB's table function interface. -//! -//! 
Table functions that resolve to a [`DataSourceRef`] can implement [`DataSourceTableFunction`] -//! to get a blanket [`TableFunction`] implementation covering init, scan, progress, filter -//! pushdown, cardinality, and partitioning. - -use std::cmp::max; -use std::fmt::Debug; -use std::ops::Range; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::sync::atomic::AtomicU64; -use std::sync::atomic::Ordering; - -use custom_labels::CURRENT_LABELSET; -use futures::StreamExt; -use itertools::Itertools; -use num_traits::AsPrimitive; -use tracing::debug; -use vortex::array::ArrayRef; -use vortex::array::Canonical; -use vortex::array::VortexSessionExecute; -use vortex::array::arrays::ScalarFn; -use vortex::array::arrays::Struct; -use vortex::array::arrays::StructArray; -use vortex::array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex::array::optimizer::ArrayOptimizer; -use vortex::array::stats::StatsSet; -use vortex::dtype::DType; -use vortex::dtype::FieldNames; -use vortex::error::VortexExpect; -use vortex::error::VortexResult; -use vortex::error::vortex_err; -use vortex::expr::Expression; -use vortex::expr::and_collect; -use vortex::expr::col; -use vortex::expr::merge; -use vortex::expr::pack; -use vortex::expr::root; -use vortex::expr::select; -use vortex::expr::stats::Precision; -use vortex::expr::stats::Stat; -use vortex::file::v2::FileStatsLayoutReader; -use vortex::io::kanal_ext::KanalExt; -use vortex::io::runtime::BlockingRuntime; -use vortex::io::runtime::current::ThreadSafeIterator; -use vortex::layout::layouts::row_idx::row_idx; -use vortex::layout::scan::multi::MultiLayoutChild; -use vortex::layout::scan::multi::MultiLayoutDataSource; -use vortex::metrics::tracing::get_global_labels; -use vortex::scalar::Scalar; -use vortex::scalar::ScalarValue; -use vortex::scalar_fn::fns::binary::Binary; -use vortex::scalar_fn::fns::operators::Operator; -use vortex::scalar_fn::fns::pack::Pack; -use vortex::scan::DataSource; -use vortex::scan::ScanRequest; 
-use vortex::scan::selection::Selection; -use vortex_utils::aliases::hash_set::HashSet; -use vortex_utils::parallelism::get_available_parallelism; - -use crate::RUNTIME; -use crate::SESSION; -use crate::convert::ToDuckDBScalar; -use crate::convert::try_from_bound_expression; -use crate::convert::try_from_table_filter; -use crate::convert::try_from_virtual_column_filter; -use crate::duckdb::BindInputRef; -use crate::duckdb::BindResultRef; -use crate::duckdb::Cardinality; -use crate::duckdb::ClientContextRef; -use crate::duckdb::ColumnStatistics; -use crate::duckdb::DataChunkRef; -use crate::duckdb::DuckdbStringMapRef; -use crate::duckdb::ExpressionRef; -use crate::duckdb::LogicalType; -use crate::duckdb::PartitionData; -use crate::duckdb::TableFilterClass; -use crate::duckdb::TableFilterSetRef; -use crate::duckdb::TableFunction; -use crate::duckdb::TableInitInput; -use crate::duckdb::Value; -use crate::exporter::ArrayExporter; -use crate::exporter::ConversionCache; - -// See MultiFileReader for constants - -/// "file_index" virtual column -static FILE_INDEX_COLUMN_IDX: u64 = 9223372036854775810; -/// "file_row_number" virtual column -static FILE_ROW_NUMBER_COLUMN_IDX: u64 = 9223372036854775809; - -/// See duckdb/src/common/constants.cpp -fn is_virtual_column(id: u64) -> bool { - id >= 9223372036854775808u64 -} - -/// A trait for table functions that resolve to a [`DataSourceRef`]. -/// -/// Implementors only need to define how parameters are declared and how binding produces a -/// data source. All other [`TableFunction`] methods (init, scan, progress, filter pushdown, -/// cardinality, partitioning) are provided by a blanket implementation. -pub(crate) trait DataSourceTableFunction: Sized + Debug { - /// Positional parameters - fn parameters() -> Vec; - - /// Bind the table function and return a [`DataSourceRef`]. 
- fn bind(ctx: &ClientContextRef, input: &BindInputRef) -> VortexResult; -} - -#[derive(Debug, Clone)] -struct DuckdbField { - name: String, - logical_type: LogicalType, - dtype: DType, -} - -/// Bind data produced by a [`DataSourceTableFunction`]. -pub struct DataSourceBindData { - data_source: Arc, - filter_exprs: Vec, - column_fields: Vec, - // There exists at least one non-optional table filter or at least one - // complex filter is pushed down. - has_non_optional_filter: AtomicBool, -} - -impl Clone for DataSourceBindData { - fn clone(&self) -> Self { - Self { - data_source: Arc::clone(&self.data_source), - // filter_exprs are consumed once in `init_global`. - filter_exprs: vec![], - column_fields: self.column_fields.clone(), - has_non_optional_filter: AtomicBool::new( - self.has_non_optional_filter.load(Ordering::Relaxed), - ), - } - } -} - -impl Debug for DataSourceBindData { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("DataSourceBindData") - .field("column_fields", &self.column_fields) - .field( - "filter_exprs", - &self - .filter_exprs - .iter() - .map(|e| e.to_string()) - .collect::>(), - ) - .finish() - } -} - -type DataSourceIterator = ThreadSafeIterator)>>; - -/// Global scan state for driving a `DataSource` scan through DuckDB. -pub struct DataSourceGlobal { - iterator: DataSourceIterator, - batch_id: AtomicU64, - bytes_total: Arc, - bytes_read: AtomicU64, - file_index_column_pos: Option, - file_row_number_column_pos: Option, -} - -/// Per-thread local scan state. -pub struct DataSourceLocal { - iterator: DataSourceIterator, - exporter: Option, - partition_index: u64, - file_index: usize, -} - -/// Returns scan progress as a percentage (0.0–100.0). -fn progress(bytes_read: &AtomicU64, bytes_total: &AtomicU64) -> f64 { - let read = bytes_read.load(Ordering::Relaxed); - let mut total = bytes_total.load(Ordering::Relaxed); - total += (total == 0) as u64; - read as f64 / total as f64 * 100. 
-} - -impl ColumnStatistics { - fn from(stats: &ColumnStatisticsAggregate, dtype: DType) -> Self { - let min = stats.min.as_ref().map(|value| { - let value = value.clone(); - Scalar::try_new(dtype.clone(), Some(value)) - .vortex_expect("scalar dtype and value are incompatible") - .try_to_duckdb_scalar() - .vortex_expect("can't convert Scalar to duckdb Value") - }); - let max = stats.max.as_ref().map(|value| { - Scalar::try_new(dtype.clone(), Some(value.clone())) - .vortex_expect("scalar dtype and value are incompatible") - .try_to_duckdb_scalar() - .vortex_expect("can't convert Scalar to duckdb Value") - }); - - let max_string_length = stats - .max_string_length - .map_or(0, |len| (1u64 << 63) | (len as u64)); - - // Useful estimate if we didn't get null count stats - let has_null = stats.has_null && dtype.is_nullable(); - - Self { - min, - max, - max_string_length, - has_null, - } - } -} - -#[derive(Default)] -pub struct ColumnStatisticsAggregate { - pub min: Option, - pub max: Option, - pub max_string_length: Option, - /// May be true if null count stat isn't present - pub has_null: bool, -} - -impl ColumnStatisticsAggregate { - pub fn new(stats: &StatsSet) -> Self { - let min = match stats.get(Stat::Min) { - Precision::Exact(min) => Some(min), - _ => None, - }; - let max = match stats.get(Stat::Max) { - Precision::Exact(max) => Some(max), - _ => None, - }; - - let max_string_length = - if let Precision::Exact(value) = stats.get(Stat::UncompressedSizeInBytes) { - // DuckDB's string length is u32 - #[allow(clippy::cast_possible_truncation)] - Some(value.as_primitive().as_u64().vortex_expect("not a u64") as u32) - } else { - None - }; - - let has_null = match stats.get(Stat::NullCount) { - Precision::Exact(cnt) => cnt.as_primitive().as_u64().vortex_expect("not a u64") > 0, - _ => true, - }; - - Self { - min, - max, - max_string_length, - has_null, - } - } -} - -// Duckdb requires post-filter cardinality estimates, otherwise join -// planner may flip join sides 
which is a huge regression for some -// queries i.e. 1000x for tpcds 85. -// -// See duckdb/src/optimizer/join_order/relation_statistics_helper.cpp -// As we don't report distinct values (same as Parquet), the only heuristic -// duckdb uses is a 0.2 filter if there is any non-optional filter. We mimic -// it here. -const DEFAULT_SELECTIVITY: f64 = 0.2; -fn postfilter_cardinality(initial_cardinality: u64, has_non_optional_filter: bool) -> u64 { - if has_non_optional_filter { - let post_cardinality = initial_cardinality as f64 * DEFAULT_SELECTIVITY; - // Clamp intentionally - let post_cardinality: u64 = post_cardinality.as_(); - max(1, post_cardinality) - } else { - initial_cardinality - } -} - -impl TableFunction for T { - type BindData = DataSourceBindData; - type GlobalState = DataSourceGlobal; - type LocalState = DataSourceLocal; - - fn parameters() -> Vec { - T::parameters() - } - - fn bind( - ctx: &ClientContextRef, - input: &BindInputRef, - result: &mut BindResultRef, - ) -> VortexResult { - let data_source = T::bind(ctx, input)?; - let column_fields = extract_schema_from_dtype(data_source.dtype())?; - for fields in &column_fields { - result.add_result_column(&fields.name, &fields.logical_type); - } - Ok(DataSourceBindData { - data_source: Arc::new(data_source), - filter_exprs: vec![], - column_fields, - has_non_optional_filter: AtomicBool::new(false), - }) - } - - fn init_global(init_input: &TableInitInput) -> VortexResult { - debug!(input=?init_input, "table function global input"); - - let bind_data = init_input.bind_data(); - let column_ids = init_input.column_ids(); - let projection_ids = init_input.projection_ids(); - - let ProjectionWithVirtualColumns { - projection, - file_index_column_pos, - file_row_number_column_pos, - } = extract_projection_expr(projection_ids, column_ids, &bind_data.column_fields); - - let FilterWithVirtualColumns { - filter, - row_selection, - row_range, - file_selection, - file_range, - has_non_optional_filter, - } = 
extract_table_filter_expr( - init_input.table_filter_set(), - column_ids, - &bind_data.column_fields, - &bind_data.filter_exprs, - bind_data.data_source.dtype(), - )?; - - if has_non_optional_filter { - init_input - .bind_data() - .has_non_optional_filter - .store(true, Ordering::Relaxed); - } - - debug!( - %projection, - filter = filter - .as_ref() - .map_or_else(|| "true".to_string(), |f| f.to_string()), - ?row_selection, - ?row_range, - ?file_selection, - ?file_range, - "table function scan input" - ); - - let request = ScanRequest { - projection, - filter, - ordered: file_row_number_column_pos.is_some(), - selection: row_selection, - row_range, - partition_selection: file_selection, - partition_range: file_range, - limit: None, - }; - - let scan = RUNTIME.block_on(bind_data.data_source.scan(request))?; - - let num_workers = get_available_parallelism().unwrap_or(1); - - // We create an async bounded channel so that all thread-local workers can pull the next - // available array chunk regardless of which partition it came from. - let (tx, rx) = kanal::bounded_async(num_workers * 2); - - // We drive one partition per worker thread. Each partition is driven as a spawned task - // that pushes array chunks into the shared channel as they are produced. This spawning - // allows all worker threads to drive the polling of all partitions, and then return the - // first available array chunk. 
- let stream = scan - .partitions() - .map(move |partition| { - let tx = tx.clone(); - RUNTIME.handle().spawn(async move { - let partition = match partition { - Ok(partition) => partition, - Err(e) => { - let _ = tx.send(Err(e)).await; - return; - } - }; - - let cache = Arc::new(ConversionCache { - file_index: partition.index(), - ..Default::default() - }); - - let mut stream = match partition.execute() { - Ok(s) => s, - Err(e) => { - let _ = tx.send(Err(e)).await; - return; - } - }; - while let Some(item) = stream.next().await { - if tx - .send(item.map(|a| (a, Arc::clone(&cache)))) - .await - .is_err() - { - // Exit early if the receiver has been dropped, which happens when the - // scan is complete or if an error has occurred in another partition. - return; - } - } - }) - }) - .buffer_unordered(num_workers); - - // Spawn a task to drive the partition stream and push array chunks into the channel. - RUNTIME.handle().spawn(stream.collect::<()>()).detach(); - - let iterator = RUNTIME.block_on_stream_thread_safe(|_handle| rx.into_stream()); - - Ok(DataSourceGlobal { - iterator, - batch_id: AtomicU64::new(0), - bytes_total: Arc::new(AtomicU64::new(0)), - bytes_read: AtomicU64::new(0), - file_index_column_pos, - file_row_number_column_pos, - }) - } - - fn init_local(global: &Self::GlobalState) -> Self::LocalState { - unsafe { - use custom_labels::sys; - - if sys::current().is_null() { - let ls = sys::new(0); - sys::replace(ls); - }; - } - - let global_labels = get_global_labels(); - - for (key, value) in global_labels { - CURRENT_LABELSET.set(key, value); - } - - DataSourceLocal { - iterator: global.iterator.clone(), - exporter: None, - partition_index: 0, - file_index: 0, - } - } - - fn scan( - local_state: &mut Self::LocalState, - global_state: &Self::GlobalState, - chunk: &mut DataChunkRef, - ) -> VortexResult<()> { - loop { - if local_state.exporter.is_none() { - let mut ctx = SESSION.create_execution_ctx(); - let Some(result) = local_state.iterator.next() else { 
- return Ok(()); - }; - let (array_result, conversion_cache) = result?; - let array_result = array_result.optimize_recursive(ctx.session())?; - local_state.file_index = conversion_cache.file_index; - - let array_result: StructArray = if let Some(array) = array_result.as_opt::() - { - array.into_owned() - } else if let Some(array) = array_result.as_opt::() - && let Some(pack_options) = array.scalar_fn().as_opt::() - { - StructArray::new( - pack_options.names.clone(), - array.children(), - array.len(), - pack_options.nullability.into(), - ) - } else { - array_result.execute::(&mut ctx)?.into_struct() - }; - - local_state.exporter = Some(ArrayExporter::try_new( - &array_result, - &conversion_cache, - ctx, - )?); - // Relaxed since there is no intra-instruction ordering required. - local_state.partition_index = global_state.batch_id.fetch_add(1, Ordering::Relaxed); - } - - let exporter = local_state - .exporter - .as_mut() - .vortex_expect("error: exporter missing"); - let has_more_data = exporter.export( - chunk, - global_state.file_index_column_pos, - global_state.file_row_number_column_pos, - )?; - - global_state - .bytes_read - .fetch_add(chunk.len(), Ordering::Relaxed); - - if !has_more_data { - // This exporter is fully consumed. - local_state.exporter = None; - local_state.partition_index = 0; - } else { - break; - } - } - - assert!(!chunk.is_empty()); - - if let Some(pos) = global_state.file_index_column_pos { - chunk - .get_vector_mut(pos) - .reference_value(&Value::from(local_state.file_index as u64)); - } - - Ok(()) - } - - fn table_scan_progress(global_state: &Self::GlobalState) -> f64 { - progress(&global_state.bytes_read, &global_state.bytes_total) - } - - fn pushdown_complex_filter( - bind_data: &mut Self::BindData, - expr: &ExpressionRef, - ) -> VortexResult { - debug!(%expr, "pushing down expression"); - - let Some(expr) = try_from_bound_expression(expr)? 
else { - debug!(%expr, "failed to push down expression"); - return Ok(false); - }; - - // Duckdb calls pushdown_complex_filter during planning phase. - // If all filters are pushed down, duckdb enables a LEFT_DELIM_JOIN -> - // COMPARISON_JOIN (HASH_JOIN) optimization: - // duckdb/src/optimizer/deliminator.cpp: Deliminator::HasSelection, - // Deliminator::Optimize. - // - // This leads to a massive regression on tpch sf=10 q17 and other - // benchmarks. - // - // This bug is reported to Duckdb - // https://github.com/duckdb/duckdb/issues/22669 - // - // As a hack, report equality filters as not pushed. - // We can also report only the first filter as not pushed, but this - // has a negative performance impact. - let report_pushed = !expr - .as_opt::() - .map(|op| *op == Operator::Eq) - .unwrap_or(false); - - // Only table filters may be optional, any complex filter is - // non-optional by definition. - bind_data - .has_non_optional_filter - .store(true, Ordering::Relaxed); - - debug!(%expr, report_pushed, "pushed down expression"); - bind_data.filter_exprs.push(expr); - Ok(report_pushed) - } - - /// Get column-wise statistics. Available only if we're reading a single - /// file. - fn statistics(bind_data: &Self::BindData, column_index: usize) -> Option { - let children = bind_data.data_source.children(); - // Otherwise we'd have to open all files eagerly which is a performance - // regression. 
Duckdb's Parquet reader only gets metadata for multiple - // files with a UNION BY NAME and we don't support it (yet) - // See duckdb/common/multi_file/multi_file_function.hpp#L691 - if children.len() != 1 { - return None; - } - let MultiLayoutChild::Opened(reader) = &children[0] else { - return None; - }; - let stats_sets = match reader.as_any().downcast_ref::() { - Some(inner) => inner.file_stats().stats_sets(), - None => return None, - }; - let stats_aggregate = ColumnStatisticsAggregate::new(&stats_sets[column_index]); - let dtype = bind_data.column_fields[column_index].dtype.clone(); - Some(ColumnStatistics::from(&stats_aggregate, dtype)) - } - - fn cardinality(bind_data: &Self::BindData) -> Cardinality { - match bind_data.data_source.row_count() { - Precision::Exact(v) | Precision::Inexact(v) => { - let has_non_optional_filter = - bind_data.has_non_optional_filter.load(Ordering::Relaxed); - // Post-filter estimate is always a heuristic. - Cardinality::Estimate(postfilter_cardinality(v, has_non_optional_filter)) - } - Precision::Absent => Cardinality::Unknown, - } - } - - fn partition_data( - global_init_data: &Self::GlobalState, - local_init_data: &mut Self::LocalState, - ) -> PartitionData { - PartitionData { - partition_index: local_init_data.partition_index, - file_index_column_pos: global_init_data.file_index_column_pos, - file_index: local_init_data.file_index, - } - } - - fn to_string(bind_data: &Self::BindData, map: &mut DuckdbStringMapRef) { - map.push("Function", "Vortex Scan"); - if !bind_data.filter_exprs.is_empty() { - let mut filters = bind_data.filter_exprs.iter().map(|f| format!("{f}")); - map.push("Filters", &filters.join("\n")); - } - } -} - -/// Extracts DuckDB column names and logical types from a Vortex struct DType. 
-fn extract_schema_from_dtype(dtype: &DType) -> VortexResult> { - let struct_dtype = dtype - .as_struct_fields_opt() - .ok_or_else(|| vortex_err!("Vortex file must contain a struct array at the top level"))?; - - let len = struct_dtype.names().len(); - let mut fields = Vec::with_capacity(len); - - for (field_name, field_dtype) in struct_dtype.names().iter().zip(struct_dtype.fields()) { - let logical_type = LogicalType::try_from(&field_dtype)?; - fields.push(DuckdbField { - name: field_name.to_string(), - logical_type, - dtype: field_dtype, - }); - } - Ok(fields) -} - -struct ProjectionWithVirtualColumns { - projection: Expression, - file_index_column_pos: Option, - file_row_number_column_pos: Option, -} - -fn extract_projection_expr( - projection_ids: Option<&[u64]>, - column_ids: &[u64], - column_fields: &[DuckdbField], -) -> ProjectionWithVirtualColumns { - // If projection ids are empty, use column_ids. - // See duckdb/src/planner/operator/logical_get.cpp#L168 - let (ids, has_projection_ids) = match projection_ids { - Some(ids) => (ids, true), - None => (column_ids, false), - }; - - let mut file_index_column_pos = None; - let mut file_row_number_column_pos = None; - let mut is_star = true; - let mut real_column_count = 0; - - // DuckDB uses u64 as column indices but Rust uses usize - for (column_pos, &column_id) in ids.iter().enumerate() { - let column_id = if has_projection_ids { - let column_id: usize = column_id.as_(); - column_ids[column_id] - } else { - column_id - }; - - if column_id == FILE_INDEX_COLUMN_IDX { - file_index_column_pos = Some(column_pos); - continue; - } - if column_id == FILE_ROW_NUMBER_COLUMN_IDX { - file_row_number_column_pos = Some(column_pos); - continue; - } - - // In SELECT * DuckDB requests all columns from 0 to column_fields in - // increasing order. After removing virtual columns, compare column_id - // with (0..column_fields.len()) range. 
- is_star &= column_id == real_column_count; - real_column_count += 1; - } - // Duckdb can request less columns than there are in table i.e. [0, 1] with - // 5 columns total. - is_star &= real_column_count == column_fields.len() as u64; - - let select = if is_star { - root() - } else { - let names = ids - .iter() - .map(|&column_id| { - if has_projection_ids { - let column_id: usize = column_id.as_(); - column_ids[column_id] - } else { - column_id - } - }) - .filter(|&col_id| !is_virtual_column(col_id)) - .map(|column_id| { - let column_id: usize = column_id.as_(); - Arc::from(column_fields[column_id].name.as_str()) - }) - .collect::(); - - select(names, root()) - }; - - // file_index column will be filled later when exporting the chunk. - let projection = if file_row_number_column_pos.is_some() { - // row_idx will be moved to correct position in scan(), prepend here - let row_idx_struct = pack([("file_row_number", row_idx())], false.into()); - merge([row_idx_struct, select]) - } else { - select - }; - - ProjectionWithVirtualColumns { - projection, - file_index_column_pos, - file_row_number_column_pos, - } -} - -struct FilterWithVirtualColumns { - filter: Option, - row_selection: Selection, - row_range: Option>, - file_selection: Selection, - file_range: Option>, - has_non_optional_filter: bool, -} - -/// Creates a table filter expression, row selection, and row range from the table filter set, -/// column metadata, additional filter expressions, and the top-level DType. 
-fn extract_table_filter_expr( - table_filter_set: Option<&TableFilterSetRef>, - column_ids: &[u64], - column_fields: &[DuckdbField], - additional_filters: &[Expression], - dtype: &DType, -) -> VortexResult { - let mut has_non_optional_filter = false; - - let mut table_filter_exprs: HashSet = if let Some(filter) = table_filter_set { - filter - .into_iter() - .filter(|(idx, _)| { - let idx_u: usize = idx.as_(); - !is_virtual_column(column_ids[idx_u]) - }) - .map(|(idx, ex)| { - has_non_optional_filter |= !matches!(ex.as_class(), TableFilterClass::Optional(_)); - - let idx_u: usize = idx.as_(); - let col_idx: usize = column_ids[idx_u].as_(); - let name = &column_fields.get(col_idx).vortex_expect("exists").name; - try_from_table_filter(ex, &col(name.as_str()), dtype) - }) - .collect::>>>()? - .unwrap_or_else(HashSet::new) - } else { - HashSet::new() - }; - - table_filter_exprs.extend(additional_filters.iter().cloned()); - - let mut file_selection = Selection::All; - let mut row_selection = Selection::All; - let mut row_range = None; - let mut file_range = None; - if let Some(filter) = table_filter_set { - for (idx, expression) in filter.into_iter() { - let idx: usize = idx.as_(); - if column_ids[idx] == FILE_ROW_NUMBER_COLUMN_IDX { - (row_selection, row_range) = try_from_virtual_column_filter(expression)?; - } - if column_ids[idx] == FILE_INDEX_COLUMN_IDX { - (file_selection, file_range) = try_from_virtual_column_filter(expression)?; - } - } - }; - - let out = FilterWithVirtualColumns { - filter: and_collect(table_filter_exprs), - row_selection, - row_range, - file_selection, - file_range, - has_non_optional_filter, - }; - Ok(out) -} - -#[cfg(test)] -mod tests { - use std::sync::atomic::AtomicU64; - use std::sync::atomic::Ordering::Relaxed; - - use vortex::dtype::DType; - use vortex::expr::merge; - use vortex::expr::pack; - use vortex::expr::root; - use vortex::layout::layouts::row_idx::row_idx; - - use super::progress; - use crate::datasource::DuckdbField; - use 
crate::datasource::FILE_INDEX_COLUMN_IDX; - use crate::datasource::FILE_ROW_NUMBER_COLUMN_IDX; - use crate::datasource::extract_projection_expr; - use crate::duckdb::LogicalType; - - #[test] - fn test_table_scan_progress() { - let bytes_total = AtomicU64::new(100); - let bytes_read = AtomicU64::new(0); - - assert_eq!(progress(&bytes_read, &bytes_total), 0.0); - - bytes_read.fetch_add(100, Relaxed); - assert_eq!(progress(&bytes_read, &bytes_total), 100.); - - bytes_total.fetch_add(100, Relaxed); - assert!((progress(&bytes_read, &bytes_total) - 50.).abs() < f64::EPSILON); - } - - #[test] - fn test_select_star() { - let ids = [0, 1, 2]; - let fields = [ - DuckdbField { - name: "".to_owned(), - logical_type: LogicalType::null(), - dtype: DType::Null, - }, - DuckdbField { - name: "".to_owned(), - logical_type: LogicalType::null(), - dtype: DType::Null, - }, - DuckdbField { - name: "".to_owned(), - logical_type: LogicalType::null(), - dtype: DType::Null, - }, - ]; - - assert_eq!( - extract_projection_expr(None, &ids, &fields).projection, - root() - ); - - let ids = [FILE_ROW_NUMBER_COLUMN_IDX, 0, 1, FILE_INDEX_COLUMN_IDX, 2]; - let exprs = extract_projection_expr(None, &ids, &fields); - let row_idx_struct = pack([("file_row_number", row_idx())], false.into()); - let root_with_virtual_cols = merge([row_idx_struct, root()]); - - assert_eq!(exprs.projection, root_with_virtual_cols); - assert_eq!(exprs.file_index_column_pos, Some(3)); - assert_eq!(exprs.file_row_number_column_pos, Some(0)); - - // projections can't be set in SELECT *. 
- assert_ne!( - extract_projection_expr(Some(&[0, 1]), &ids, &fields).projection, - root() - ); - - let ids = [0, 1]; - assert_ne!( - extract_projection_expr(None, &ids, &fields).projection, - root() - ); - - let ids = [0, 2, 2]; - assert_ne!( - extract_projection_expr(None, &ids, &fields).projection, - root() - ); - - let ids = [2, 1, 0]; - assert_ne!( - extract_projection_expr(None, &ids, &fields).projection, - root() - ); - } -} diff --git a/vortex-duckdb/src/duckdb/bind_input.rs b/vortex-duckdb/src/duckdb/bind_input.rs new file mode 100644 index 00000000000..1049a8065f6 --- /dev/null +++ b/vortex-duckdb/src/duckdb/bind_input.rs @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use crate::cpp; +use crate::duckdb::LogicalTypeRef; +use crate::duckdb::Value; +use crate::lifetime_wrapper; + +lifetime_wrapper!(BindInput, cpp::duckdb_vx_tfunc_bind_input, |_| {}); + +impl BindInputRef { + /// Returns the parameter at the given index. 
+ pub fn get_parameter(&self, index: usize) -> Option { + let value_ptr = + unsafe { cpp::duckdb_vx_tfunc_bind_input_get_parameter(self.as_ptr(), index as _) }; + if value_ptr.is_null() { + None + } else { + Some(unsafe { Value::own(value_ptr) }) + } + } +} + +lifetime_wrapper!(BindResult, cpp::duckdb_vx_tfunc_bind_result, |_| {}); + +impl BindResultRef { + pub fn add_result_column(&self, name: &str, logical_type: &LogicalTypeRef) { + unsafe { + cpp::duckdb_vx_tfunc_bind_result_add_column( + self.as_ptr(), + name.as_ptr().cast(), + name.len() as _, + logical_type.as_ptr(), + ) + } + } +} diff --git a/vortex-duckdb/src/duckdb/database.rs b/vortex-duckdb/src/duckdb/database.rs index dc203b4aac5..effacfc75cd 100644 --- a/vortex-duckdb/src/duckdb/database.rs +++ b/vortex-duckdb/src/duckdb/database.rs @@ -131,6 +131,14 @@ impl Database { } impl DatabaseRef { + pub fn register_table_functions(&self) -> VortexResult<()> { + duckdb_try!( + unsafe { cpp::duckdb_vx_register_table_functions(self.as_ptr()) }, + "Failed to register table functions" + ); + Ok(()) + } + /// Connects to the DuckDB database. 
pub fn connect(&self) -> VortexResult { Connection::connect(self) diff --git a/vortex-duckdb/src/duckdb/mod.rs b/vortex-duckdb/src/duckdb/mod.rs index c42fbdaf1e4..5617d723e8d 100644 --- a/vortex-duckdb/src/duckdb/mod.rs +++ b/vortex-duckdb/src/duckdb/mod.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +mod bind_input; mod client_context; mod config; mod connection; @@ -17,8 +18,9 @@ mod query_result; mod reusable_dict; mod scalar_function; mod selection_vector; +mod string_map; mod table_filter; -mod table_function; +mod table_init_input; mod value; mod vector; mod vector_buffer; @@ -26,6 +28,7 @@ mod vector_buffer; use std::ffi::c_void; use std::ptr; +pub use bind_input::*; pub use client_context::*; pub use config::*; pub use connection::*; @@ -41,8 +44,9 @@ pub use query_result::*; pub use reusable_dict::*; pub use scalar_function::*; pub use selection_vector::*; +pub use string_map::*; pub use table_filter::*; -pub use table_function::*; +pub use table_init_input::*; pub use value::*; pub use vector::*; pub use vector_buffer::*; diff --git a/vortex-duckdb/src/duckdb/string_map.rs b/vortex-duckdb/src/duckdb/string_map.rs new file mode 100644 index 00000000000..68eb3714d14 --- /dev/null +++ b/vortex-duckdb/src/duckdb/string_map.rs @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ffi::CString; + +use crate::cpp; + +// String map lifetime is managed by C++ code +crate::lifetime_wrapper!(DuckdbStringMap, cpp::duckdb_vx_string_map, |_| {}); +impl DuckdbStringMapRef { + pub fn push(&mut self, key: &str, value: &str) { + let key = CString::new(key).unwrap_or_else(|_| CString::default()); + let value = CString::new(value).unwrap_or_else(|_| CString::default()); + unsafe { + cpp::duckdb_vx_string_map_insert(self.as_ptr(), key.as_ptr(), value.as_ptr()); + } + } +} diff --git 
a/vortex-duckdb/src/duckdb/table_function/bind.rs b/vortex-duckdb/src/duckdb/table_function/bind.rs deleted file mode 100644 index f56ca65f974..00000000000 --- a/vortex-duckdb/src/duckdb/table_function/bind.rs +++ /dev/null @@ -1,76 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex::error::vortex_err; - -use crate::cpp; -use crate::duckdb::ClientContext; -use crate::duckdb::Data; -use crate::duckdb::LogicalTypeRef; -use crate::duckdb::TableFunction; -use crate::duckdb::Value; -use crate::duckdb::try_or_null; -use crate::lifetime_wrapper; - -/// The native bind callback for a table function. -pub(crate) unsafe extern "C-unwind" fn bind_callback( - ctx: cpp::duckdb_client_context, - bind_input: cpp::duckdb_vx_tfunc_bind_input, - bind_result: cpp::duckdb_vx_tfunc_bind_result, - error_out: *mut cpp::duckdb_vx_error, -) -> cpp::duckdb_vx_data { - let client_context = unsafe { ClientContext::borrow(ctx) }; - let bind_input = unsafe { BindInput::own(bind_input) }; - let mut bind_result = unsafe { BindResult::own(bind_result) }; - - try_or_null(error_out, || { - let bind_data = T::bind(client_context, &bind_input, &mut bind_result)?; - Ok(Data::from(Box::new(bind_data)).as_ptr()) - }) -} - -/// The native copy callback for bind data. -pub(crate) unsafe extern "C-unwind" fn bind_data_clone_callback( - bind_data: *const std::ffi::c_void, - error_out: *mut cpp::duckdb_vx_error, -) -> cpp::duckdb_vx_data { - try_or_null(error_out, || { - let bind_data = unsafe { - (bind_data as *const T::BindData) - .as_ref() - .ok_or(vortex_err!("bind_data is nullptr"))? - }; - let copied_data = bind_data.clone(); - Ok(Data::from(Box::new(copied_data)).as_ptr()) - }) -} - -lifetime_wrapper!(BindInput, cpp::duckdb_vx_tfunc_bind_input, |_| {}); - -impl BindInputRef { - /// Returns the parameter at the given index. 
- pub fn get_parameter(&self, index: usize) -> Option { - let value_ptr = - unsafe { cpp::duckdb_vx_tfunc_bind_input_get_parameter(self.as_ptr(), index as _) }; - if value_ptr.is_null() { - None - } else { - Some(unsafe { Value::own(value_ptr) }) - } - } -} - -lifetime_wrapper!(BindResult, cpp::duckdb_vx_tfunc_bind_result, |_| {}); - -impl BindResultRef { - pub fn add_result_column(&self, name: &str, logical_type: &LogicalTypeRef) { - unsafe { - cpp::duckdb_vx_tfunc_bind_result_add_column( - self.as_ptr(), - name.as_ptr().cast(), - name.len() as _, - logical_type.as_ptr(), - ) - } - } -} diff --git a/vortex-duckdb/src/duckdb/table_function/cardinality.rs b/vortex-duckdb/src/duckdb/table_function/cardinality.rs deleted file mode 100644 index 0e359532ecb..00000000000 --- a/vortex-duckdb/src/duckdb/table_function/cardinality.rs +++ /dev/null @@ -1,35 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::ffi::c_void; - -use vortex::error::VortexExpect; - -use crate::cpp; -use crate::duckdb::Cardinality; -use crate::duckdb::TableFunction; - -/// Native callback for the cardinality estimate of a table function. 
-pub(crate) unsafe extern "C-unwind" fn cardinality_callback( - bind_data: *mut c_void, - node_stats_out: *mut cpp::duckdb_vx_node_statistics, -) { - let bind_data = - unsafe { bind_data.cast::().as_ref() }.vortex_expect("bind_data null pointer"); - let node_stats = - unsafe { node_stats_out.as_mut() }.vortex_expect("node_stats_out null pointer"); - - match T::cardinality(bind_data) { - Cardinality::Unknown => {} - Cardinality::Estimate(c) => { - node_stats.has_estimated_cardinality = true; - node_stats.estimated_cardinality = c as _; - } - Cardinality::Maximum(c) => { - node_stats.has_max_cardinality = true; - node_stats.max_cardinality = c as _; - node_stats.has_estimated_cardinality = true; - node_stats.estimated_cardinality = c as _; - } - } -} diff --git a/vortex-duckdb/src/duckdb/table_function/init.rs b/vortex-duckdb/src/duckdb/table_function/init.rs deleted file mode 100644 index f62238b7cab..00000000000 --- a/vortex-duckdb/src/duckdb/table_function/init.rs +++ /dev/null @@ -1,116 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::ffi::c_void; -use std::fmt::Debug; -use std::fmt::Formatter; -use std::ptr; - -use vortex::error::VortexExpect; -use vortex::error::VortexResult; -use vortex::error::vortex_bail; - -use crate::cpp; -use crate::duckdb::ClientContext; -use crate::duckdb::ClientContextRef; -use crate::duckdb::Data; -use crate::duckdb::TableFilterSet; -use crate::duckdb::TableFilterSetRef; -use crate::duckdb::TableFunction; - -/// Native callback for the global initialization of a table function. 
-pub(crate) unsafe extern "C-unwind" fn init_global_callback( - init_input: *const cpp::duckdb_vx_tfunc_init_input, - error_out: *mut cpp::duckdb_vx_error, -) -> cpp::duckdb_vx_data { - let init_input = TableInitInput::new( - unsafe { init_input.as_ref() }.vortex_expect("init_input null pointer"), - ); - - match T::init_global(&init_input) { - Ok(init_data) => Data::from(Box::new(init_data)).as_ptr(), - Err(e) => { - // Set the error in the error output. - let msg = e.to_string(); - unsafe { error_out.write(cpp::duckdb_vx_error_create(msg.as_ptr().cast(), msg.len())) }; - ptr::null_mut::().cast() - } - } -} - -/// Native callback for the local initialization of a table function. -pub(crate) unsafe extern "C-unwind" fn init_local_callback( - global_init_data: *mut c_void, -) -> cpp::duckdb_vx_data { - let global_init_data = unsafe { global_init_data.cast::().as_ref() } - .vortex_expect("global_init_data null pointer"); - - let init_data = T::init_local(global_init_data); - Data::from(Box::new(init_data)).as_ptr() -} - -/// A typed wrapper for the input to a table function's initialization. 
-pub struct TableInitInput<'a, T: TableFunction> { - input: &'a cpp::duckdb_vx_tfunc_init_input, - phantom: std::marker::PhantomData, -} - -impl Debug for TableInitInput<'_, T> { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TableInitInput") - .field("table_function", &std::any::type_name::()) - .field("column_ids", &self.column_ids()) - .field("projection_ids", &self.projection_ids()) - .field("table_filter_set", &self.table_filter_set()) - .finish() - } -} - -impl<'a, T: TableFunction> TableInitInput<'a, T> { - fn new(input: &'a cpp::duckdb_vx_tfunc_init_input) -> Self { - Self { - input, - phantom: std::marker::PhantomData, - } - } - - pub fn bind_data(&self) -> &T::BindData { - unsafe { &*self.input.bind_data.cast::() } - } - - pub fn column_ids(&self) -> &[u64] { - unsafe { std::slice::from_raw_parts(self.input.column_ids, self.input.column_ids_count) } - } - - pub fn projection_ids(&self) -> Option<&[u64]> { - // Passed pointer is std::vector's .data(). However, C++ doesn't - // guarantee an empty vector's pointer is nullptr so we need to check - // both conditions - if self.input.projection_ids.is_null() || self.input.projection_ids_count == 0 { - return None; - } - Some(unsafe { - std::slice::from_raw_parts(self.input.projection_ids, self.input.projection_ids_count) - }) - } - - /// Returns the table filter set for the table function. - pub fn table_filter_set(&self) -> Option<&TableFilterSetRef> { - let ptr = self.input.filters; - if ptr.is_null() { - None - } else { - Some(unsafe { TableFilterSet::borrow(ptr) }) - } - } - - /// Returns the object cache from the client context for the table function. 
- pub fn client_context(&self) -> VortexResult<&ClientContextRef> { - unsafe { - if self.input.client_context.is_null() { - vortex_bail!("Client context is null"); - } - Ok(ClientContext::borrow(self.input.client_context)) - } - } -} diff --git a/vortex-duckdb/src/duckdb/table_function/mod.rs b/vortex-duckdb/src/duckdb/table_function/mod.rs deleted file mode 100644 index 0d70b170bf5..00000000000 --- a/vortex-duckdb/src/duckdb/table_function/mod.rs +++ /dev/null @@ -1,284 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::ffi::CStr; -use std::ffi::CString; -use std::ffi::c_void; -use std::fmt::Debug; -use std::ptr; - -use vortex::error::VortexExpect; -use vortex::error::VortexResult; -mod bind; -mod cardinality; -mod init; - -pub use bind::*; -pub use init::*; - -use crate::convert::can_push_expression; -use crate::cpp; -use crate::duckdb::DataChunk; -use crate::duckdb::DatabaseRef; -use crate::duckdb::Expression; -use crate::duckdb::LogicalType; -use crate::duckdb::Value; -use crate::duckdb::client_context::ClientContextRef; -use crate::duckdb::data_chunk::DataChunkRef; -use crate::duckdb::expr::ExpressionRef; -use crate::duckdb::table_function::cardinality::cardinality_callback; -use crate::duckdb::try_or; -use crate::duckdb_try; - -pub struct PartitionData { - pub partition_index: u64, - pub file_index_column_pos: Option, - pub file_index: usize, -} - -#[derive(Debug, Default)] -pub struct ColumnStatistics { - pub min: Option, - pub max: Option, - pub max_string_length: u64, - pub has_null: bool, -} - -// String map lifetime is managed by C++ code -crate::lifetime_wrapper!(DuckdbStringMap, cpp::duckdb_vx_string_map, |_| {}); -impl DuckdbStringMapRef { - pub fn push(&mut self, key: &str, value: &str) { - let key = CString::new(key).unwrap_or_else(|_| CString::default()); - let value = CString::new(value).unwrap_or_else(|_| CString::default()); - unsafe { - 
cpp::duckdb_vx_string_map_insert(self.as_ptr(), key.as_ptr(), value.as_ptr()); - } - } -} - -/// A trait that defines the supported operations for a table function in DuckDB. -/// -/// This trait does not yet cover the full C++ API, see table_function.hpp. -pub trait TableFunction: Sized + Debug { - type BindData: Send + Clone; - type GlobalState: Send + Sync; - type LocalState; - - /// Returns the parameters of the table function. - fn parameters() -> Vec { - // By default, we don't have any parameters. - vec![] - } - - /// This function is used for determining the schema of a table producing function and - /// returning bind data. - fn bind( - client_context: &ClientContextRef, - input: &BindInputRef, - result: &mut BindResultRef, - ) -> VortexResult; - - /// Report column statistics for a file or collections of files e.g. - /// registered as a VIEW. - fn statistics(bind_data: &Self::BindData, column_index: usize) -> Option; - - /// The function is called during query execution and is responsible for producing the output - fn scan( - init_local: &mut Self::LocalState, - init_global: &Self::GlobalState, - chunk: &mut DataChunkRef, - ) -> VortexResult<()>; - - /// Initialize the global operator state of the function. - /// - /// The global operator state is used to keep track of the progress in the table function and - /// is shared between all threads working on the table function. - fn init_global(input: &TableInitInput) -> VortexResult; - - /// Initialize the local operator state of the function. - /// - /// The local operator state is used to keep track of the progress in the table function and - /// is thread-local. - fn init_local(global: &Self::GlobalState) -> Self::LocalState; - - /// Return table scanning progress from 0. to 100. - fn table_scan_progress(global_state: &Self::GlobalState) -> f64; - - /// Pushes down a filter expression to the table function. 
- /// - /// Returns `true` if the filter was successfully pushed down (and stored on the bind data), - /// or `false` if the filter could not be pushed down. In which case, the filter will be - /// applied later in the query plan. - fn pushdown_complex_filter( - bind_data: &mut Self::BindData, - expr: &ExpressionRef, - ) -> VortexResult; - - /// Returns the cardinality estimate of the table function. - fn cardinality(bind_data: &Self::BindData) -> Cardinality; - - /// Returns the idx of the current partition being processed by a local threa. - /// This *must* be globally unique. - fn partition_data( - global_init_data: &Self::GlobalState, - local_init_data: &mut Self::LocalState, - ) -> PartitionData; - - /// Returns a vector of key-value pairs for EXPLAIN output - fn to_string(bind_data: &Self::BindData, map: &mut DuckdbStringMapRef); -} - -#[derive(Debug)] -pub enum Cardinality { - /// Completely unknown cardinality. - Unknown, - /// An estimate of the number of rows that will be returned by the table function. - Estimate(u64), - /// Will not return more than this number of rows. - Maximum(u64), -} - -impl DatabaseRef { - pub fn register_table_function(&self, name: &CStr) -> VortexResult<()> { - // Set up the parameters. 
- let parameters = T::parameters(); - let parameter_ptrs = parameters - .iter() - .map(|logical_type| logical_type.as_ptr()) - .collect::>(); - - let vtab = cpp::duckdb_vx_tfunc_vtab_t { - name: name.as_ptr(), - parameters: parameter_ptrs.as_ptr(), - parameter_count: parameters.len() as _, - bind: Some(bind_callback::), - bind_data_clone: Some(bind_data_clone_callback::), - init_global: Some(init_global_callback::), - init_local: Some(init_local_callback::), - function: Some(function::), - statistics: Some(statistics::), - cardinality: Some(cardinality_callback::), - pushdown_complex_filter: Some(pushdown_complex_filter_callback::), - to_string: Some(to_string_callback::), - table_scan_progress: Some(table_scan_progress_callback::), - get_partition_data: Some(get_partition_data_callback::), - }; - - duckdb_try!( - unsafe { cpp::duckdb_vx_tfunc_register(self.as_ptr(), &raw const vtab) }, - "Failed to register table function '{}'", - name.to_string_lossy() - ); - - Ok(()) - } -} - -/* - * Called before pushdown_complex_filter or a table filter expression call. - * As we support complex filter pushdown, Duckdb pushes expressions to Vortex. - * However, it doesn't know what type of expressions we can handle. Here we list - * all expressions that are quaranteed to be converted to Vortex expressions. - * - * If we return true here, and expression is in the list for - * pushdown_complex_filter, we must handle it, or query engine will break. - * - * Example: we don't support substr() expression so we tell Duckdb we can't - * push it. - * Example: optional filters may fail to parse on our side (we return - * Ok(None)), so we don't allow pushing these. 
- */ -#[unsafe(no_mangle)] -pub unsafe extern "C-unwind" fn duckdb_vx_pushdown_expression(expr: cpp::duckdb_vx_expr) -> bool { - let expr = unsafe { Expression::borrow(expr) }; - can_push_expression(expr) -} - -unsafe extern "C-unwind" fn to_string_callback( - bind_data: *mut c_void, - map: cpp::duckdb_vx_string_map, -) { - let bind_data = unsafe { &*(bind_data as *const T::BindData) }; - let map = unsafe { DuckdbStringMap::borrow_mut(map) }; - T::to_string(bind_data, map); -} - -unsafe extern "C-unwind" fn statistics( - bind_data: *const c_void, - column_index: usize, - stats_out: *mut cpp::duckdb_column_statistics, -) -> bool { - let stats_out = unsafe { &mut *stats_out }; - let bind_data = - unsafe { bind_data.cast::().as_ref() }.vortex_expect("bind_data null pointer"); - let Some(stats) = T::statistics(bind_data, column_index) else { - return false; - }; - stats_out.min = stats.min.map_or(ptr::null_mut(), |v| v.into_ptr()); - stats_out.max = stats.max.map_or(ptr::null_mut(), |v| v.into_ptr()); - stats_out.max_string_length = stats.max_string_length; - stats_out.has_null = stats.has_null; - true -} - -unsafe extern "C-unwind" fn table_scan_progress_callback( - global_state: *mut c_void, -) -> f64 { - let global_state = unsafe { global_state.cast::().as_ref() } - .vortex_expect("global_init_data null pointer"); - T::table_scan_progress(global_state) -} - -unsafe extern "C-unwind" fn get_partition_data_callback( - global_init_data: *mut c_void, - local_init_data: *mut c_void, - partition_data_out: *mut cpp::duckdb_vx_partition_data, -) { - let global_init_data = unsafe { global_init_data.cast::().as_ref() } - .vortex_expect("global_init_data null pointer"); - let local_init_data = unsafe { local_init_data.cast::().as_mut() } - .vortex_expect("local_init_data null pointer"); - let data = T::partition_data(global_init_data, local_init_data); - let out = unsafe { &mut *partition_data_out }; - - out.partition_index = data.partition_index; - out.file_index_column_pos = 
data.file_index_column_pos.unwrap_or(usize::MAX); - out.file_index = data.file_index; -} - -unsafe extern "C-unwind" fn pushdown_complex_filter_callback( - bind_data: *mut c_void, - expr: cpp::duckdb_vx_expr, - error_out: *mut cpp::duckdb_vx_error, -) -> bool { - let bind_data = - unsafe { bind_data.cast::().as_mut() }.vortex_expect("bind_data null pointer"); - let expr = unsafe { Expression::borrow(expr) }; - try_or(error_out, || T::pushdown_complex_filter(bind_data, expr)) -} - -unsafe extern "C-unwind" fn function( - global_init_data: *mut c_void, - local_init_data: *mut c_void, - output: cpp::duckdb_data_chunk, - error_out: *mut cpp::duckdb_vx_error, -) { - let global_init_data = unsafe { global_init_data.cast::().as_ref() } - .vortex_expect("global_init_data null pointer"); - let local_init_data = unsafe { local_init_data.cast::().as_mut() } - .vortex_expect("local_init_data null pointer"); - let data_chunk = unsafe { DataChunk::borrow_mut(output) }; - - match T::scan(local_init_data, global_init_data, data_chunk) { - Ok(()) => { - // The data chunk is already filled by the function. - // No need to do anything here. 
- } - Err(e) => unsafe { - error_out.write(cpp::duckdb_vx_error_create( - e.to_string().as_ptr().cast(), - e.to_string().len(), - )); - }, - } -} diff --git a/vortex-duckdb/src/duckdb/table_init_input.rs b/vortex-duckdb/src/duckdb/table_init_input.rs new file mode 100644 index 00000000000..f6ac05ae0b5 --- /dev/null +++ b/vortex-duckdb/src/duckdb/table_init_input.rs @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt::Debug; +use std::fmt::Formatter; +use std::fmt::Result; + +use crate::cpp; +use crate::duckdb::TableFilterSet; +use crate::duckdb::TableFilterSetRef; + +pub struct TableInitInput<'a> { + pub input: &'a cpp::duckdb_vx_tfunc_init_input, +} + +impl Debug for TableInitInput<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + f.debug_struct("TableInitInput") + .field("column_ids", &self.column_ids()) + .field("projection_ids", &self.projection_ids()) + .field("table_filter_set", &self.table_filter_set()) + .finish() + } +} + +impl<'a> TableInitInput<'a> { + pub fn new(input: &'a cpp::duckdb_vx_tfunc_init_input) -> Self { + Self { input } + } + + pub fn column_ids(&self) -> &[u64] { + unsafe { std::slice::from_raw_parts(self.input.column_ids, self.input.column_ids_count) } + } + + pub fn projection_ids(&self) -> Option<&[u64]> { + // Passed pointer is std::vector's .data(). However, C++ doesn't + // guarantee an empty vector's pointer is nullptr so we need to check + // both conditions + if self.input.projection_ids.is_null() || self.input.projection_ids_count == 0 { + return None; + } + Some(unsafe { + std::slice::from_raw_parts(self.input.projection_ids, self.input.projection_ids_count) + }) + } + + /// Returns the table filter set for the table function. 
+ pub fn table_filter_set(&self) -> Option<&TableFilterSetRef> { + let ptr = self.input.filters; + if ptr.is_null() { + None + } else { + Some(unsafe { TableFilterSet::borrow(ptr) }) + } + } +} diff --git a/vortex-duckdb/src/ffi.rs b/vortex-duckdb/src/ffi.rs new file mode 100644 index 00000000000..e7a35498d2f --- /dev/null +++ b/vortex-duckdb/src/ffi.rs @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ffi::c_void; +use std::ptr; + +use vortex::error::VortexExpect; + +use crate::convert::can_push_expression; +use crate::cpp; +use crate::duckdb::BindInput; +use crate::duckdb::BindResult; +use crate::duckdb::ClientContext; +use crate::duckdb::Data; +use crate::duckdb::DataChunk; +use crate::duckdb::DuckdbStringMap; +use crate::duckdb::Expression; +use crate::duckdb::TableInitInput; +use crate::duckdb::try_or; +use crate::duckdb::try_or_null; +use crate::table_function::Cardinality; +use crate::table_function::TableFunctionBind; +use crate::table_function::TableFunctionGlobal; +use crate::table_function::TableFunctionLocal; +use crate::table_function::bind; +use crate::table_function::cardinality; +use crate::table_function::get_partition_data; +use crate::table_function::init_global; +use crate::table_function::init_local; +use crate::table_function::pushdown_complex_filter; +use crate::table_function::scan; +use crate::table_function::statistics; +use crate::table_function::table_scan_progress; +use crate::table_function::to_string; + +#[unsafe(no_mangle)] +unsafe extern "C-unwind" fn duckdb_table_function_to_string( + bind_data: *mut c_void, + map: cpp::duckdb_vx_string_map, +) { + let bind_data = unsafe { bind_data.cast::().as_ref() } + .vortex_expect("bind_data null pointer"); + let map = unsafe { DuckdbStringMap::borrow_mut(map) }; + to_string(bind_data, map); +} + +#[unsafe(no_mangle)] +unsafe extern "C-unwind" fn duckdb_table_function_statistics( + bind_data: *const c_void, + 
column_index: usize, + stats_out: *mut cpp::duckdb_column_statistics, +) -> bool { + let stats_out = unsafe { &mut *stats_out }; + let bind_data = unsafe { bind_data.cast::().as_ref() } + .vortex_expect("bind_data null pointer"); + let Some(stats) = statistics(bind_data, column_index) else { + return false; + }; + stats_out.min = stats.min.map_or(ptr::null_mut(), |v| v.into_ptr()); + stats_out.max = stats.max.map_or(ptr::null_mut(), |v| v.into_ptr()); + stats_out.max_string_length = stats.max_string_length; + stats_out.has_null = stats.has_null; + true +} + +#[unsafe(no_mangle)] +unsafe extern "C-unwind" fn duckdb_table_function_scan_progress(global_state: *mut c_void) -> f64 { + let global_state = unsafe { global_state.cast::().as_ref() } + .vortex_expect("global_init_data null pointer"); + table_scan_progress(global_state) +} + +#[unsafe(no_mangle)] +unsafe extern "C-unwind" fn duckdb_table_function_get_partition_data( + global_init_data: *mut c_void, + local_init_data: *mut c_void, + partition_data_out: *mut cpp::duckdb_vx_partition_data, +) { + let global_init_data = unsafe { global_init_data.cast::().as_ref() } + .vortex_expect("global_init_data null pointer"); + let local_init_data = unsafe { local_init_data.cast::().as_mut() } + .vortex_expect("local_init_data null pointer"); + let data = get_partition_data(global_init_data, local_init_data); + let out = unsafe { &mut *partition_data_out }; + + out.partition_index = data.partition_index; + out.file_index_column_pos = data.file_index_column_pos.unwrap_or(usize::MAX); + out.file_index = data.file_index; +} + +#[unsafe(no_mangle)] +unsafe extern "C-unwind" fn duckdb_table_function_pushdown_complex_filter( + bind_data: *mut c_void, + expr: cpp::duckdb_vx_expr, + error_out: *mut cpp::duckdb_vx_error, +) -> bool { + let bind_data = unsafe { bind_data.cast::().as_mut() } + .vortex_expect("bind_data null pointer"); + let expr = unsafe { Expression::borrow(expr) }; + try_or(error_out, || 
pushdown_complex_filter(bind_data, expr)) +} + +#[unsafe(no_mangle)] +unsafe extern "C-unwind" fn duckdb_table_function_scan( + global_init_data: *mut c_void, + local_init_data: *mut c_void, + output: cpp::duckdb_data_chunk, + error_out: *mut cpp::duckdb_vx_error, +) { + let global_init_data = unsafe { global_init_data.cast::().as_ref() } + .vortex_expect("global_init_data null pointer"); + let local_init_data = unsafe { local_init_data.cast::().as_mut() } + .vortex_expect("local_init_data null pointer"); + let data_chunk = unsafe { DataChunk::borrow_mut(output) }; + + match scan(local_init_data, global_init_data, data_chunk) { + Ok(()) => { + // The data chunk is already filled by the function. + // No need to do anything here. + } + Err(e) => unsafe { + error_out.write(cpp::duckdb_vx_error_create( + e.to_string().as_ptr().cast(), + e.to_string().len(), + )); + }, + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C-unwind" fn duckdb_table_function_pushdown_expression( + expr: cpp::duckdb_vx_expr, +) -> bool { + can_push_expression(unsafe { Expression::borrow(expr) }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C-unwind" fn duckdb_table_function_cardinality( + bind_data: *mut c_void, + node_stats_out: *mut cpp::duckdb_vx_node_statistics, +) { + let bind_data = unsafe { bind_data.cast::().as_ref() } + .vortex_expect("bind_data null pointer"); + let node_stats = + unsafe { node_stats_out.as_mut() }.vortex_expect("node_stats_out null pointer"); + + match cardinality(bind_data) { + Cardinality::Unknown => {} + Cardinality::Estimate(c) => { + node_stats.has_estimated_cardinality = true; + node_stats.estimated_cardinality = c as _; + } + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C-unwind" fn duckdb_table_function_init_global( + init_input: *const cpp::duckdb_vx_tfunc_init_input, + error_out: *mut cpp::duckdb_vx_error, +) -> cpp::duckdb_vx_data { + let init_input = TableInitInput::new( + unsafe { init_input.as_ref() }.vortex_expect("init_input null pointer"), + 
); + + match init_global(&init_input) { + Ok(init_data) => Data::from(Box::new(init_data)).as_ptr(), + Err(e) => { + // Set the error in the error output. + let msg = e.to_string(); + unsafe { error_out.write(cpp::duckdb_vx_error_create(msg.as_ptr().cast(), msg.len())) }; + ptr::null_mut::().cast() + } + } +} + +#[unsafe(no_mangle)] +pub unsafe extern "C-unwind" fn duckdb_table_function_init_local( + global_init_data: *mut c_void, +) -> cpp::duckdb_vx_data { + let global_init_data = unsafe { global_init_data.cast::().as_ref() } + .vortex_expect("global_init_data null pointer"); + + let init_data = init_local(global_init_data); + Data::from(Box::new(init_data)).as_ptr() +} + +#[unsafe(no_mangle)] +pub unsafe extern "C-unwind" fn duckdb_table_function_bind( + ctx: cpp::duckdb_client_context, + bind_input: cpp::duckdb_vx_tfunc_bind_input, + bind_result: cpp::duckdb_vx_tfunc_bind_result, + error_out: *mut cpp::duckdb_vx_error, +) -> cpp::duckdb_vx_data { + let client_context = unsafe { ClientContext::borrow(ctx) }; + let bind_input = unsafe { BindInput::own(bind_input) }; + let mut bind_result = unsafe { BindResult::own(bind_result) }; + + try_or_null(error_out, || { + let bind_data = bind(client_context, &bind_input, &mut bind_result)?; + Ok(Data::from(Box::new(bind_data)).as_ptr()) + }) +} + +#[unsafe(no_mangle)] +pub unsafe extern "C-unwind" fn duckdb_table_function_bind_data_clone( + bind_data: *const c_void, + error_out: *mut cpp::duckdb_vx_error, +) -> cpp::duckdb_vx_data { + let bind_data = unsafe { bind_data.cast::().as_ref() } + .vortex_expect("bind_data null pointer"); + try_or_null(error_out, || { + let copied_data = bind_data.clone(); + Ok(Data::from(Box::new(copied_data)).as_ptr()) + }) +} diff --git a/vortex-duckdb/src/lib.rs b/vortex-duckdb/src/lib.rs index 410d241a766..6fdf0e523cd 100644 --- a/vortex-duckdb/src/lib.rs +++ b/vortex-duckdb/src/lib.rs @@ -21,15 +21,16 @@ use crate::duckdb::Database; use crate::duckdb::DatabaseRef; use 
crate::duckdb::LogicalType; use crate::duckdb::Value; -use crate::multi_file::VortexMultiFileScan; -use crate::multi_file::VortexMultiFileScanList; +mod column_statistics; mod convert; -mod datasource; pub mod duckdb; mod exporter; +mod ffi; mod filesystem; mod multi_file; +mod projection; +mod table_function; #[rustfmt::skip] #[path = "./cpp.rs"] @@ -70,11 +71,7 @@ pub fn initialize(db: &DatabaseRef) -> VortexResult<()> { LogicalType::varchar(), Value::from("vortex"), )?; - db.register_table_function::(c"vortex_scan")?; - db.register_table_function::(c"read_vortex")?; - // Register list overloads for multi-glob scanning (e.g., read_vortex(['a.vortex', 'b.vortex'])) - db.register_table_function::(c"vortex_scan")?; - db.register_table_function::(c"read_vortex")?; + db.register_table_functions()?; db.register_copy_function::(c"vortex", c"vortex") } diff --git a/vortex-duckdb/src/multi_file.rs b/vortex-duckdb/src/multi_file.rs index 3f99a854a22..165bcad6677 100644 --- a/vortex-duckdb/src/multi_file.rs +++ b/vortex-duckdb/src/multi_file.rs @@ -18,11 +18,9 @@ use vortex_utils::aliases::hash_map::HashMap; use crate::RUNTIME; use crate::SESSION; -use crate::datasource::DataSourceTableFunction; use crate::duckdb::BindInputRef; use crate::duckdb::ClientContextRef; use crate::duckdb::ExtractedValue; -use crate::duckdb::LogicalType; use crate::filesystem::resolve_filesystem; /// Parse a glob string into a [`Url`]. @@ -57,45 +55,8 @@ fn normalize_path(path: std::path::PathBuf) -> std::path::PathBuf { normalized } -/// Vortex multi-file scan table function (`vortex_scan` / `read_vortex`). -/// -/// Takes a file glob parameter and resolves it into a [`MultiFileDataSource`]. -/// All other table function logic is provided by the blanket [`DataSourceTableFunction`] -/// implementation. -#[derive(Debug)] -pub struct VortexMultiFileScan; - -/// Variant of [`VortexMultiFileScan`] that accepts a list of globs. 
-/// -/// This is registered as a separate overload to handle `read_vortex(['a.vortex', 'b.vortex'])`. -#[derive(Debug)] -pub struct VortexMultiFileScanList; - -impl DataSourceTableFunction for VortexMultiFileScan { - fn parameters() -> Vec { - vec![LogicalType::varchar()] - } - - fn bind(ctx: &ClientContextRef, input: &BindInputRef) -> VortexResult { - bind_multi_file_scan(ctx, input) - } -} - -impl DataSourceTableFunction for VortexMultiFileScanList { - fn parameters() -> Vec { - vec![ - LogicalType::list_type(LogicalType::varchar()) - .unwrap_or_else(|_| unreachable!("creating list type should not fail")), - ] - } - - fn bind(ctx: &ClientContextRef, input: &BindInputRef) -> VortexResult { - bind_multi_file_scan(ctx, input) - } -} - /// Shared bind logic for both single-glob and multi-glob variants. -fn bind_multi_file_scan( +pub fn bind_multi_file_scan( ctx: &ClientContextRef, input: &BindInputRef, ) -> VortexResult { diff --git a/vortex-duckdb/src/projection.rs b/vortex-duckdb/src/projection.rs new file mode 100644 index 00000000000..a27056e5f01 --- /dev/null +++ b/vortex-duckdb/src/projection.rs @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors +use std::ops::Range; +use std::sync::Arc; + +use num_traits::AsPrimitive as _; +use vortex::dtype::DType; +use vortex::dtype::FieldNames; +use vortex::error::VortexExpect; +use vortex::error::VortexResult; +use vortex::error::vortex_err; +use vortex::expr::Expression; +use vortex::expr::and_collect; +use vortex::expr::col; +use vortex::expr::merge; +use vortex::expr::pack; +use vortex::expr::root; +use vortex::expr::select; +use vortex::layout::layouts::row_idx::row_idx; +use vortex::scan::selection::Selection; +use vortex_utils::aliases::hash_set::HashSet; + +use crate::convert::try_from_table_filter; +use crate::convert::try_from_virtual_column_filter; +use crate::duckdb::LogicalType; +use crate::duckdb::TableFilterClass; +use 
crate::duckdb::TableFilterSetRef; + +// See MultiFileReader for constants + +/// "file_index" virtual column +static FILE_INDEX_COLUMN_IDX: u64 = 9223372036854775810; +/// "file_row_number" virtual column +static FILE_ROW_NUMBER_COLUMN_IDX: u64 = 9223372036854775809; + +/// See duckdb/src/common/constants.cpp +fn is_virtual_column(id: u64) -> bool { + id >= 9223372036854775808u64 +} + +#[derive(Debug, Clone)] +pub struct DuckdbField { + pub name: String, + pub logical_type: LogicalType, + pub dtype: DType, +} + +pub struct Projection { + pub projection: Expression, + pub file_index_column_pos: Option, + pub file_row_number_column_pos: Option, +} + +impl Projection { + pub fn new( + projection_ids: Option<&[u64]>, + column_ids: &[u64], + column_fields: &[DuckdbField], + ) -> Self { + // If projection ids are empty, use column_ids. + // See duckdb/src/planner/operator/logical_get.cpp#L168 + let (ids, has_projection_ids) = match projection_ids { + Some(ids) => (ids, true), + None => (column_ids, false), + }; + + let mut file_index_column_pos = None; + let mut file_row_number_column_pos = None; + let mut is_star = true; + let mut real_column_count = 0; + + // DuckDB uses u64 as column indices but Rust uses usize + for (column_pos, &column_id) in ids.iter().enumerate() { + let column_id = if has_projection_ids { + let column_id: usize = column_id.as_(); + column_ids[column_id] + } else { + column_id + }; + + if column_id == FILE_INDEX_COLUMN_IDX { + file_index_column_pos = Some(column_pos); + continue; + } + if column_id == FILE_ROW_NUMBER_COLUMN_IDX { + file_row_number_column_pos = Some(column_pos); + continue; + } + + // In SELECT * DuckDB requests all columns from 0 to column_fields in + // increasing order. After removing virtual columns, compare column_id + // with (0..column_fields.len()) range. + is_star &= column_id == real_column_count; + real_column_count += 1; + } + // Duckdb can request less columns than there are in table i.e. 
[0, 1] with + // 5 columns total. + is_star &= real_column_count == column_fields.len() as u64; + + let select = if is_star { + root() + } else { + let names = ids + .iter() + .map(|&column_id| { + if has_projection_ids { + let column_id: usize = column_id.as_(); + column_ids[column_id] + } else { + column_id + } + }) + .filter(|&col_id| !is_virtual_column(col_id)) + .map(|column_id| { + let column_id: usize = column_id.as_(); + Arc::from(column_fields[column_id].name.as_str()) + }) + .collect::(); + + select(names, root()) + }; + + // file_index column will be filled later when exporting the chunk. + let projection = if file_row_number_column_pos.is_some() { + // row_idx will be moved to correct position in scan(), prepend here + let row_idx_struct = pack([("file_row_number", row_idx())], false.into()); + merge([row_idx_struct, select]) + } else { + select + }; + + Self { + projection, + file_index_column_pos, + file_row_number_column_pos, + } + } +} + +pub struct Filter { + pub filter: Option, + pub row_selection: Selection, + pub row_range: Option>, + pub file_selection: Selection, + pub file_range: Option>, + pub has_non_optional_filter: bool, +} + +impl Filter { + /// Creates a table filter expression, row selection, and row range from the table filter set, + /// column metadata, additional filter expressions, and the top-level DType. 
+ pub fn new( + table_filter_set: Option<&TableFilterSetRef>, + column_ids: &[u64], + column_fields: &[DuckdbField], + additional_filters: &[Expression], + dtype: &DType, + ) -> VortexResult { + let mut has_non_optional_filter = false; + + let mut table_filter_exprs: HashSet = if let Some(filter) = table_filter_set { + filter + .into_iter() + .filter(|(idx, _)| { + let idx_u: usize = idx.as_(); + !is_virtual_column(column_ids[idx_u]) + }) + .map(|(idx, ex)| { + has_non_optional_filter |= + !matches!(ex.as_class(), TableFilterClass::Optional(_)); + + let idx_u: usize = idx.as_(); + let col_idx: usize = column_ids[idx_u].as_(); + let name = &column_fields.get(col_idx).vortex_expect("exists").name; + try_from_table_filter(ex, &col(name.as_str()), dtype) + }) + .collect::>>>()? + .unwrap_or_else(HashSet::new) + } else { + HashSet::new() + }; + + table_filter_exprs.extend(additional_filters.iter().cloned()); + + let mut file_selection = Selection::All; + let mut row_selection = Selection::All; + let mut row_range = None; + let mut file_range = None; + if let Some(filter) = table_filter_set { + for (idx, expression) in filter.into_iter() { + let idx: usize = idx.as_(); + if column_ids[idx] == FILE_ROW_NUMBER_COLUMN_IDX { + (row_selection, row_range) = try_from_virtual_column_filter(expression)?; + } + if column_ids[idx] == FILE_INDEX_COLUMN_IDX { + (file_selection, file_range) = try_from_virtual_column_filter(expression)?; + } + } + }; + + let out = Self { + filter: and_collect(table_filter_exprs), + row_selection, + row_range, + file_selection, + file_range, + has_non_optional_filter, + }; + Ok(out) + } +} + +pub fn extract_schema_from_dtype(dtype: &DType) -> VortexResult> { + let struct_dtype = dtype + .as_struct_fields_opt() + .ok_or_else(|| vortex_err!("Vortex file must contain a struct array at the top level"))?; + + let len = struct_dtype.names().len(); + let mut fields = Vec::with_capacity(len); + + for (field_name, field_dtype) in 
struct_dtype.names().iter().zip(struct_dtype.fields()) { + let logical_type = LogicalType::try_from(&field_dtype)?; + fields.push(DuckdbField { + name: field_name.to_string(), + logical_type, + dtype: field_dtype, + }); + } + Ok(fields) +} + +#[cfg(test)] +mod tests { + use vortex::dtype::DType; + use vortex::expr::merge; + use vortex::expr::pack; + use vortex::expr::root; + use vortex::layout::layouts::row_idx::row_idx; + + use super::*; + + #[test] + fn test_select_star() { + let ids = [0, 1, 2]; + let fields = [ + DuckdbField { + name: "".to_owned(), + logical_type: LogicalType::null(), + dtype: DType::Null, + }, + DuckdbField { + name: "".to_owned(), + logical_type: LogicalType::null(), + dtype: DType::Null, + }, + DuckdbField { + name: "".to_owned(), + logical_type: LogicalType::null(), + dtype: DType::Null, + }, + ]; + + assert_eq!(Projection::new(None, &ids, &fields).projection, root()); + + let ids = [FILE_ROW_NUMBER_COLUMN_IDX, 0, 1, FILE_INDEX_COLUMN_IDX, 2]; + let exprs = Projection::new(None, &ids, &fields); + let row_idx_struct = pack([("file_row_number", row_idx())], false.into()); + let root_with_virtual_cols = merge([row_idx_struct, root()]); + + assert_eq!(exprs.projection, root_with_virtual_cols); + assert_eq!(exprs.file_index_column_pos, Some(3)); + assert_eq!(exprs.file_row_number_column_pos, Some(0)); + + // projections can't be set in SELECT *. 
+ assert_ne!( + Projection::new(Some(&[0, 1]), &ids, &fields).projection, + root() + ); + + let ids = [0, 1]; + assert_ne!(Projection::new(None, &ids, &fields).projection, root()); + + let ids = [0, 2, 2]; + assert_ne!(Projection::new(None, &ids, &fields).projection, root()); + + let ids = [2, 1, 0]; + assert_ne!(Projection::new(None, &ids, &fields).projection, root()); + } +} diff --git a/vortex-duckdb/src/table_function.rs b/vortex-duckdb/src/table_function.rs new file mode 100644 index 00000000000..11c5851af27 --- /dev/null +++ b/vortex-duckdb/src/table_function.rs @@ -0,0 +1,529 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::cmp::max; +use std::fmt::Formatter; +use std::fmt::{self}; +use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicU64; +use std::sync::atomic::Ordering; + +use custom_labels::CURRENT_LABELSET; +use futures::StreamExt; +use itertools::Itertools; +use num_traits::AsPrimitive; +use static_assertions::assert_impl_all; +use tracing::debug; +use vortex::array::ArrayRef; +use vortex::array::Canonical; +use vortex::array::VortexSessionExecute as _; +use vortex::array::arrays::ScalarFn; +use vortex::array::arrays::Struct; +use vortex::array::arrays::StructArray; +use vortex::array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex::array::optimizer::ArrayOptimizer; +use vortex::error::VortexExpect; +use vortex::error::VortexResult; +use vortex::expr::Expression; +use vortex::expr::stats::Precision; +use vortex::file::v2::FileStatsLayoutReader; +use vortex::io::kanal_ext::KanalExt as _; +use vortex::io::runtime::BlockingRuntime as _; +use vortex::io::runtime::current::ThreadSafeIterator; +use vortex::layout::scan::multi::MultiLayoutChild; +use vortex::layout::scan::multi::MultiLayoutDataSource; +use vortex::metrics::tracing::get_global_labels; +use vortex::scalar_fn::fns::binary::Binary; +use vortex::scalar_fn::fns::operators::Operator; +use 
vortex::scalar_fn::fns::pack::Pack; +use vortex::scan::DataSource; +use vortex::scan::ScanRequest; +use vortex_utils::parallelism::get_available_parallelism; + +use crate::RUNTIME; +use crate::SESSION; +use crate::column_statistics::ColumnStatistics; +use crate::column_statistics::ColumnStatisticsAggregate; +use crate::convert::try_from_bound_expression; +use crate::duckdb::BindInputRef; +use crate::duckdb::BindResultRef; +use crate::duckdb::ClientContextRef; +use crate::duckdb::DataChunkRef; +use crate::duckdb::DuckdbStringMapRef; +use crate::duckdb::ExpressionRef; +use crate::duckdb::TableInitInput; +use crate::duckdb::Value; +use crate::exporter::ArrayExporter; +use crate::exporter::ConversionCache; +use crate::multi_file::bind_multi_file_scan; +use crate::projection::DuckdbField; +use crate::projection::Filter; +use crate::projection::Projection; +use crate::projection::extract_schema_from_dtype; + +pub struct TableFunctionBind { + data_source: Arc, + filter_exprs: Vec, + column_fields: Vec, + // There exists at least one non-optional table filter or at least one + // complex filter is pushed down. + has_non_optional_filter: AtomicBool, +} +assert_impl_all!(TableFunctionBind: Send, Clone); + +impl Clone for TableFunctionBind { + fn clone(&self) -> Self { + Self { + data_source: Arc::clone(&self.data_source), + // filter_exprs are consumed once in `init_global`. 
+ filter_exprs: vec![], + column_fields: self.column_fields.clone(), + has_non_optional_filter: AtomicBool::new( + self.has_non_optional_filter.load(Ordering::Relaxed), + ), + } + } +} + +impl fmt::Debug for TableFunctionBind { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("DataSourceBindData") + .field("column_fields", &self.column_fields) + .field( + "filter_exprs", + &self + .filter_exprs + .iter() + .map(|e| e.to_string()) + .collect::>(), + ) + .finish() + } +} + +impl<'a> TableInitInput<'a> { + pub fn bind_data(&self) -> &TableFunctionBind { + unsafe { &*self.input.bind_data.cast::() } + } +} + +type DataSourceIterator = ThreadSafeIterator)>>; + +pub struct TableFunctionGlobal { + iterator: DataSourceIterator, + batch_id: AtomicU64, + bytes_total: Arc, + bytes_read: AtomicU64, + file_index_column_pos: Option, + file_row_number_column_pos: Option, +} +assert_impl_all!(TableFunctionGlobal: Send, Sync); + +/// Per-thread scan state +pub struct TableFunctionLocal { + iterator: DataSourceIterator, + exporter: Option, + partition_index: u64, + file_index: usize, +} + +pub struct PartitionData { + pub partition_index: u64, + pub file_index_column_pos: Option, + pub file_index: usize, +} + +#[derive(Debug)] +pub enum Cardinality { + /// Unknown number of rows + Unknown, + /// An estimate of the number of rows. 
+    Estimate(u64),
+}
+
+/// Binds the table function.
+///
+/// Resolves the multi-file scan for `input`, derives the column schema from the data
+/// source dtype, and registers every column's name and logical type on `result`.
+///
+/// The returned bind data starts with no pushed-down filter expressions and with
+/// `has_non_optional_filter` cleared.
+///
+/// # Errors
+/// Propagates failures from binding the multi-file scan or extracting the schema.
+pub fn bind(
+    ctx: &ClientContextRef,
+    input: &BindInputRef,
+    result: &mut BindResultRef,
+) -> VortexResult {
+    let data_source = bind_multi_file_scan(ctx, input)?;
+    let column_fields = extract_schema_from_dtype(data_source.dtype())?;
+    for fields in &column_fields {
+        result.add_result_column(&fields.name, &fields.logical_type);
+    }
+    Ok(TableFunctionBind {
+        data_source: Arc::new(data_source),
+        filter_exprs: vec![],
+        column_fields,
+        has_non_optional_filter: AtomicBool::new(false),
+    })
+}
+
+/// Builds the global scan state.
+///
+/// Derives the projection and filter from `init_input`, starts the scan on the bound data
+/// source, and spawns one task per partition that pushes array chunks (paired with that
+/// partition's conversion cache) into a bounded channel. The receiving end of the channel
+/// becomes the iterator stored in the returned state.
+///
+/// If the derived filter is non-optional, the flag on the bind data is set so that
+/// `cardinality` can take it into account.
+///
+/// # Errors
+/// Returns an error if the filter cannot be built or the scan cannot be started. Errors
+/// raised by individual partitions are delivered through the channel instead.
+pub fn init_global(init_input: &TableInitInput) -> VortexResult {
+    debug!(input=?init_input, "table function global input");
+
+    let bind_data = init_input.bind_data();
+    let column_ids = init_input.column_ids();
+    let projection_ids = init_input.projection_ids();
+
+    let Projection {
+        projection,
+        file_index_column_pos,
+        file_row_number_column_pos,
+    } = Projection::new(projection_ids, column_ids, &bind_data.column_fields);
+
+    let Filter {
+        filter,
+        row_selection,
+        row_range,
+        file_selection,
+        file_range,
+        has_non_optional_filter,
+    } = Filter::new(
+        init_input.table_filter_set(),
+        column_ids,
+        &bind_data.column_fields,
+        &bind_data.filter_exprs,
+        bind_data.data_source.dtype(),
+    )?;
+
+    if has_non_optional_filter {
+        init_input
+            .bind_data()
+            .has_non_optional_filter
+            .store(true, Ordering::Relaxed);
+    }
+
+    debug!(
+        %projection,
+        filter = filter
+            .as_ref()
+            .map_or_else(|| "true".to_string(), |f| f.to_string()),
+        ?row_selection,
+        ?row_range,
+        ?file_selection,
+        ?file_range,
+        "table function scan input"
+    );
+
+    let request = ScanRequest {
+        projection,
+        filter,
+        // Ordering is only requested when the file row number column is projected.
+        ordered: file_row_number_column_pos.is_some(),
+        selection: row_selection,
+        row_range,
+        partition_selection: file_selection,
+        partition_range: file_range,
+        limit: None,
+    };
+
+    let scan = RUNTIME.block_on(bind_data.data_source.scan(request))?;
+
+    let num_workers = get_available_parallelism().unwrap_or(1);
+
+    // We create an async bounded channel so that all thread-local workers can pull the next
+    // available array chunk regardless of which partition it came from.
+    let (tx, rx) = kanal::bounded_async(num_workers * 2);
+
+    // We drive one partition per worker thread. Each partition is driven as a spawned task
+    // that pushes array chunks into the shared channel as they are produced. This spawning
+    // allows all worker threads to drive the polling of all partitions, and then return the
+    // first available array chunk.
+    let stream = scan
+        .partitions()
+        .map(move |partition| {
+            let tx = tx.clone();
+            RUNTIME.handle().spawn(async move {
+                let partition = match partition {
+                    Ok(partition) => partition,
+                    Err(e) => {
+                        // Best effort: a failed send only means the receiver is gone.
+                        let _ = tx.send(Err(e)).await;
+                        return;
+                    }
+                };
+
+                let cache = Arc::new(ConversionCache {
+                    file_index: partition.index(),
+                    ..Default::default()
+                });
+
+                let mut stream = match partition.execute() {
+                    Ok(s) => s,
+                    Err(e) => {
+                        // Best effort: a failed send only means the receiver is gone.
+                        let _ = tx.send(Err(e)).await;
+                        return;
+                    }
+                };
+                while let Some(item) = stream.next().await {
+                    if tx
+                        .send(item.map(|a| (a, Arc::clone(&cache))))
+                        .await
+                        .is_err()
+                    {
+                        // Exit early if the receiver has been dropped, which happens when the
+                        // scan is complete or if an error has occurred in another partition.
+                        return;
+                    }
+                }
+            })
+        })
+        .buffer_unordered(num_workers);
+
+    // Spawn a task to drive the partition stream and push array chunks into the channel.
+    RUNTIME.handle().spawn(stream.collect::<()>()).detach();
+
+    let iterator = RUNTIME.block_on_stream_thread_safe(|_handle| rx.into_stream());
+
+    Ok(TableFunctionGlobal {
+        iterator,
+        batch_id: AtomicU64::new(0),
+        bytes_total: Arc::new(AtomicU64::new(0)),
+        bytes_read: AtomicU64::new(0),
+        file_index_column_pos,
+        file_row_number_column_pos,
+    })
+}
+
+/// Creates the local scan state for one worker.
+///
+/// If `custom_labels::sys::current()` is null, a new label set is created and installed
+/// with `sys::replace`. The global labels are then copied into `CURRENT_LABELSET`, and a
+/// local state is returned that holds a clone of the shared iterator and no active
+/// exporter.
+pub fn init_local(global: &TableFunctionGlobal) -> TableFunctionLocal {
+    // NOTE(review): this `unsafe` block has no stated safety argument — presumably the
+    // label-set calls are sound because they only touch the calling thread's state;
+    // confirm against the `custom_labels` documentation.
+    unsafe {
+        use custom_labels::sys;
+
+        if sys::current().is_null() {
+            let ls = sys::new(0);
+            sys::replace(ls);
+        };
+    }
+
+    let global_labels = get_global_labels();
+
+    for (key, value) in global_labels {
+        CURRENT_LABELSET.set(key, value);
+    }
+
+    TableFunctionLocal {
+        iterator: global.iterator.clone(),
+        exporter: None,
+        partition_index: 0,
+        file_index: 0,
+    }
+}
+
+/// Exports the next batch of rows into `chunk`.
+///
+/// When the local state has no active exporter, the next array is pulled from the shared
+/// iterator, converted into a `StructArray`, and wrapped in a new exporter. The exporter
+/// is then asked to fill `chunk`; once it reports that it has no more data it is dropped
+/// and the loop moves on to the next array.
+///
+/// Returns `Ok(())` without exporting anything further once the iterator is exhausted.
+///
+/// # Errors
+/// Propagates errors delivered through the iterator as well as conversion and export
+/// failures.
+///
+/// # Panics
+/// Panics if the loop is left with an empty `chunk`.
+pub fn scan(
+    local_state: &mut TableFunctionLocal,
+    global_state: &TableFunctionGlobal,
+    chunk: &mut DataChunkRef,
+) -> VortexResult<()> {
+    loop {
+        if local_state.exporter.is_none() {
+            let mut ctx = SESSION.create_execution_ctx();
+            let Some(result) = local_state.iterator.next() else {
+                return Ok(());
+            };
+            let (array_result, conversion_cache) = result?;
+            let array_result = array_result.optimize_recursive(ctx.session())?;
+            local_state.file_index = conversion_cache.file_index;
+
+            // Three ways to obtain a struct array, cheapest first: the array already is
+            // one, it can be assembled directly from the children of a pack-style scalar
+            // function, or it has to be executed.
+            let array_result: StructArray = if let Some(array) = array_result.as_opt::() {
+                array.into_owned()
+            } else if let Some(array) = array_result.as_opt::()
+                && let Some(pack_options) = array.scalar_fn().as_opt::()
+            {
+                StructArray::new(
+                    pack_options.names.clone(),
+                    array.children(),
+                    array.len(),
+                    pack_options.nullability.into(),
+                )
+            } else {
+                array_result.execute::(&mut ctx)?.into_struct()
+            };
+
+            local_state.exporter = Some(ArrayExporter::try_new(
+                &array_result,
+                &conversion_cache,
+                ctx,
+            )?);
+            // Relaxed since there is no intra-instruction ordering required.
+            local_state.partition_index = global_state.batch_id.fetch_add(1, Ordering::Relaxed);
+        }
+
+        let exporter = local_state
+            .exporter
+            .as_mut()
+            .vortex_expect("error: exporter missing");
+        let has_more_data = exporter.export(
+            chunk,
+            global_state.file_index_column_pos,
+            global_state.file_row_number_column_pos,
+        )?;
+
+        // NOTE(review): `chunk.len()` is presumably a row count, while the counter is
+        // named `bytes_read` — confirm the units match whatever updates `bytes_total`.
+        global_state
+            .bytes_read
+            .fetch_add(chunk.len(), Ordering::Relaxed);
+
+        if !has_more_data {
+            // This exporter is fully consumed.
+            local_state.exporter = None;
+            local_state.partition_index = 0;
+        } else {
+            break;
+        }
+    }
+
+    assert!(!chunk.is_empty());
+
+    if let Some(pos) = global_state.file_index_column_pos {
+        chunk
+            .get_vector_mut(pos)
+            .reference_value(&Value::from(local_state.file_index as u64));
+    }
+
+    Ok(())
+}
+
+/// Scan progress as a percentage (0.0–100.0).
+pub fn table_scan_progress(global_state: &TableFunctionGlobal) -> f64 {
+    progress(&global_state.bytes_read, &global_state.bytes_total)
+}
+
+/// Tries to take over a complex filter expression from DuckDB.
+///
+/// Returns `Ok(false)` when the expression cannot be converted. Otherwise the converted
+/// expression is recorded on `bind_data`, the non-optional filter flag is set, and the
+/// return value states whether the filter is reported to DuckDB as pushed down. Equality
+/// filters are recorded but deliberately reported as not pushed (see the comment in the
+/// body).
+///
+/// # Errors
+/// Propagates conversion failures from `try_from_bound_expression`.
+pub fn pushdown_complex_filter(
+    bind_data: &mut TableFunctionBind,
+    expr: &ExpressionRef,
+) -> VortexResult {
+    debug!(%expr, "pushing down expression");
+
+    let Some(expr) = try_from_bound_expression(expr)? else {
+        debug!(%expr, "failed to push down expression");
+        return Ok(false);
+    };
+
+    // Duckdb calls pushdown_complex_filter during planning phase.
+    // If all filters are pushed down, duckdb enables a LEFT_DELIM_JOIN ->
+    // COMPARISON_JOIN (HASH_JOIN) optimization:
+    // duckdb/src/optimizer/deliminator.cpp: Deliminator::HasSelection,
+    // Deliminator::Optimize.
+    //
+    // This leads to a massive regression on tpch sf=10 q17 and other
+    // benchmarks.
+    //
+    // This bug is reported to Duckdb
+    // https://github.com/duckdb/duckdb/issues/22669
+    //
+    // As a hack, report equality filters as not pushed.
+    // We can also report only the first filter as not pushed, but this
+    // has a negative performance impact.
+    let report_pushed = !expr
+        .as_opt::()
+        .map(|op| *op == Operator::Eq)
+        .unwrap_or(false);
+
+    // Only table filters may be optional, any complex filter is
+    // non-optional by definition.
+    bind_data
+        .has_non_optional_filter
+        .store(true, Ordering::Relaxed);
+
+    debug!(%expr, report_pushed, "pushed down expression");
+    bind_data.filter_exprs.push(expr);
+    Ok(report_pushed)
+}
+
+/// Get column-wise statistics. Available only if we're reading a single file.
+///
+/// Returns `None` when the data source does not consist of exactly one opened child,
+/// when that child is not of the expected reader type, or when `column_index` does not
+/// address a known column.
+pub fn statistics(bind_data: &TableFunctionBind, column_index: usize) -> Option {
+    let children = bind_data.data_source.children();
+    // Otherwise we'd have to open all files eagerly which is a performance
+    // regression. Duckdb's Parquet reader only gets metadata for multiple
+    // files with a UNION BY NAME and we don't support it (yet)
+    // See duckdb/common/multi_file/multi_file_function.hpp#L691
+    if children.len() != 1 {
+        return None;
+    }
+    let MultiLayoutChild::Opened(reader) = &children[0] else {
+        return None;
+    };
+    let stats_sets = match reader.as_any().downcast_ref::() {
+        Some(inner) => inner.file_stats().stats_sets(),
+        None => return None,
+    };
+    // `column_index` is supplied by the caller: answer "no statistics" instead of
+    // panicking when it is out of range for either collection.
+    let stats_set = stats_sets.get(column_index)?;
+    let field = bind_data.column_fields.get(column_index)?;
+    let stats_aggregate = ColumnStatisticsAggregate::new(stats_set);
+    let dtype = field.dtype.clone();
+    Some(ColumnStatistics::from(&stats_aggregate, dtype))
+}
+
+/**
+ * Duckdb requires post-filter cardinality estimates, otherwise join planner
+ * may flip join sides which is a huge regression for some queries e.g. 1000x
+ * for tpcds 85.
+ *
+ * See duckdb/src/optimizer/join_order/relation_statistics_helper.cpp
+ *
+ * As we don't report distinct values (same as Parquet), the only heuristic
+ * duckdb uses is a 0.2 filter if there is any non-optional filter. We mimic it
+ * here.
+ */
+const DEFAULT_SELECTIVITY: f64 = 0.2;
+/// Estimates the number of rows this scan produces.
+///
+/// Without a non-optional filter the source row count is reported unchanged; otherwise
+/// it is scaled by `DEFAULT_SELECTIVITY` and floored at 1. Exact and inexact row counts
+/// are both reported as estimates; an absent row count yields `Cardinality::Unknown`.
+pub fn cardinality(bind_data: &TableFunctionBind) -> Cardinality {
+    match bind_data.data_source.row_count() {
+        Precision::Exact(v) | Precision::Inexact(v) => {
+            if !bind_data.has_non_optional_filter.load(Ordering::Relaxed) {
+                // Although we may have an exact upper bound here, reporting
+                // it as exact has a negative performance impact on tpcds as
+                // it's not a real post-filter calculation.
+                return Cardinality::Estimate(v);
+            }
+            let post_cardinality = v as f64 * DEFAULT_SELECTIVITY;
+            let post_cardinality: u64 = post_cardinality.as_();
+            Cardinality::Estimate(max(1, post_cardinality))
+        }
+        Precision::Absent => Cardinality::Unknown,
+    }
+}
+
+/// Packages the local state's `partition_index` and `file_index` together with the
+/// global `file_index_column_pos`.
+pub fn get_partition_data(
+    global_init_data: &TableFunctionGlobal,
+    local_init_data: &mut TableFunctionLocal,
+) -> PartitionData {
+    PartitionData {
+        partition_index: local_init_data.partition_index,
+        file_index_column_pos: global_init_data.file_index_column_pos,
+        file_index: local_init_data.file_index,
+    }
+}
+
+/// Populates `map` with a description of this scan: the function name and, when any
+/// were pushed down, the filter expressions joined by newlines.
+pub fn to_string(bind_data: &TableFunctionBind, map: &mut DuckdbStringMapRef) {
+    map.push("Function", "Vortex Scan");
+    if !bind_data.filter_exprs.is_empty() {
+        let mut filters = bind_data.filter_exprs.iter().map(|f| format!("{f}"));
+        map.push("Filters", &filters.join("\n"));
+    }
+}
+
+/// Computes `bytes_read / bytes_total` as a percentage in the range 0.0–100.0.
+///
+/// A total of zero is treated as 1 to avoid dividing by zero, and the result is capped
+/// at 100.0 so that a reader running ahead of the recorded total never reports more
+/// than completion.
+fn progress(bytes_read: &AtomicU64, bytes_total: &AtomicU64) -> f64 {
+    let read = bytes_read.load(Ordering::Relaxed);
+    let total = bytes_total.load(Ordering::Relaxed).max(1);
+    (read as f64 / total as f64 * 100.).min(100.)
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::atomic::AtomicU64;
+    use std::sync::atomic::Ordering::Relaxed;
+
+    use crate::table_function::progress;
+
+    #[test]
+    fn test_table_scan_progress() {
+        let bytes_total = AtomicU64::new(100);
+        let bytes_read = AtomicU64::new(0);
+
+        assert_eq!(progress(&bytes_read, &bytes_total), 0.0);
+
+        bytes_read.fetch_add(100, Relaxed);
+        assert_eq!(progress(&bytes_read, &bytes_total), 100.);
+
+        bytes_total.fetch_add(100, Relaxed);
+        assert!((progress(&bytes_read, &bytes_total) - 50.).abs() < f64::EPSILON);
+    }
+
+    #[test]
+    fn test_table_scan_progress_is_capped() {
+        // An unset total must not turn the percentage into an arbitrarily large number.
+        let bytes_total = AtomicU64::new(0);
+        let bytes_read = AtomicU64::new(2048);
+        assert_eq!(progress(&bytes_read, &bytes_total), 100.);
+
+        // Neither must a reader that has overtaken the recorded total.
+        bytes_total.fetch_add(1024, Relaxed);
+        assert_eq!(progress(&bytes_read, &bytes_total), 100.);
+    }
+}