From b5e1c29f829c08d50341c4d029203ce30ec7d1b8 Mon Sep 17 00:00:00 2001 From: azevaykin Date: Wed, 26 Nov 2025 20:06:37 +0300 Subject: [PATCH 1/6] Implement brute force vector search pushdown --- ydb/core/base/kmeans_clusters.cpp | 70 +++- ydb/core/base/kmeans_clusters.h | 3 + ydb/core/kqp/common/kqp_yql.cpp | 45 ++- ydb/core/kqp/common/kqp_yql.h | 10 + .../kqp/executer_actor/kqp_tasks_graph.cpp | 10 + ydb/core/kqp/opt/kqp_opt_build_txs.cpp | 24 +- ydb/core/kqp/opt/physical/kqp_opt_phy.cpp | 9 +- ydb/core/kqp/opt/physical/kqp_opt_phy_impl.h | 2 + .../kqp/opt/physical/kqp_opt_phy_limit.cpp | 365 ++++++++++++++++++ ydb/core/kqp/opt/physical/kqp_opt_phy_rules.h | 3 + .../kqp/query_compiler/kqp_query_compiler.cpp | 96 ++++- ydb/core/kqp/runtime/kqp_read_actor.cpp | 4 + ydb/core/kqp/ut/knn/kqp_knn_ut.cpp | 246 ++++++++++++ ydb/core/kqp/ut/knn/ya.make | 27 ++ ydb/core/kqp/ut/ya.make | 1 + ydb/core/protos/kqp_physical.proto | 4 + ydb/core/protos/tx_datashard.proto | 3 + .../tx/datashard/datashard__read_iterator.cpp | 7 +- 18 files changed, 914 insertions(+), 15 deletions(-) create mode 100644 ydb/core/kqp/ut/knn/kqp_knn_ut.cpp create mode 100644 ydb/core/kqp/ut/knn/ya.make diff --git a/ydb/core/base/kmeans_clusters.cpp b/ydb/core/base/kmeans_clusters.cpp index 535b0ab7b2c4..472b4a27f094 100644 --- a/ydb/core/base/kmeans_clusters.cpp +++ b/ydb/core/base/kmeans_clusters.cpp @@ -48,7 +48,7 @@ namespace { return Ydb::Table::VectorIndexSettings::METRIC_UNSPECIFIED; } }; - + Ydb::Table::VectorIndexSettings_Metric ParseSimilarity(const TString& similarity_, TString& error) { const TString similarity = to_lower(similarity_); if (similarity == "cosine") @@ -60,7 +60,7 @@ namespace { return Ydb::Table::VectorIndexSettings::METRIC_UNSPECIFIED; } }; - + Ydb::Table::VectorIndexSettings_VectorType ParseVectorType(const TString& vectorType_, TString& error) { const TString vectorType = to_lower(vectorType_); if (vectorType == "float") @@ -462,6 +462,58 @@ std::unique_ptr CreateClusters(const Ydb::Table::VectorIndexSettings& } } +std::unique_ptr CreateClustersAutoDetect(Ydb::Table::VectorIndexSettings settings, const TStringBuf& targetVector, ui32 maxRounds, TString& error) { + if (targetVector.empty()) { + error = "Target vector is empty"; + return nullptr; + } + + // Auto-detect vector type and dimension from target vector + const ui8 formatByte = static_cast(targetVector.back()); + + switch (formatByte) { + case EFormat::FloatVector: + settings.set_vector_type(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_FLOAT); + if (targetVector.size() < HeaderLen + sizeof(float)) { + error = "Target vector too short for float type"; + return nullptr; + } + settings.set_vector_dimension((targetVector.size() - HeaderLen) / sizeof(float)); + break; + case EFormat::Uint8Vector: + settings.set_vector_type(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_UINT8); + if (targetVector.size() < HeaderLen + sizeof(ui8)) { + error = "Target vector too short for uint8 type"; + return nullptr; + } + settings.set_vector_dimension((targetVector.size() - HeaderLen) / sizeof(ui8)); + break; + case EFormat::Int8Vector: + settings.set_vector_type(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_INT8); + if (targetVector.size() < HeaderLen + sizeof(i8)) { + error = "Target vector too short for int8 type"; + return nullptr; + } + settings.set_vector_dimension((targetVector.size() - HeaderLen) / sizeof(i8)); + break; + case EFormat::BitVector: + settings.set_vector_type(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_BIT); + if (targetVector.size() < 2 + HeaderLen) { + error = "Target vector too short for bit type"; + return nullptr; + } + // For bit vectors: size = ceil(dim/8) + 1 (padding info) + 1 (format byte) + // padding = targetVector[size - 2], actual bits = (size - 2) * 8 - padding + settings.set_vector_dimension((targetVector.size() - 2) * 8 - static_cast(targetVector[targetVector.size() - 2])); + break; + default: + error = TStringBuilder() << "Unknown vector format byte: " << static_cast(formatByte); + return nullptr; + } + + return CreateClusters(settings, maxRounds, error); +} + bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& error) { error = ""; @@ -474,16 +526,16 @@ bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& e return false; } - if (!ValidateSettingInRange("levels", - settings.has_levels() ? std::optional(settings.levels()) : std::nullopt, + if (!ValidateSettingInRange("levels", + settings.has_levels() ? std::optional(settings.levels()) : std::nullopt, MinLevels, MaxLevels, error)) { return false; } - if (!ValidateSettingInRange("clusters", - settings.has_clusters() ? std::optional(settings.clusters()) : std::nullopt, + if (!ValidateSettingInRange("clusters", + settings.has_clusters() ? std::optional(settings.clusters()) : std::nullopt, MinClusters, MaxClusters, error)) { @@ -500,7 +552,7 @@ bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& e } if (settings.settings().vector_dimension() * settings.clusters() > MaxVectorDimensionMultiplyClusters) { - error = TStringBuilder() << "Invalid vector_dimension*clusters: " << settings.settings().vector_dimension() << "*" << settings.clusters() + error = TStringBuilder() << "Invalid vector_dimension*clusters: " << settings.settings().vector_dimension() << "*" << settings.clusters() << " should be less than " << MaxVectorDimensionMultiplyClusters; return false; } @@ -528,8 +580,8 @@ bool ValidateSettings(const Ydb::Table::VectorIndexSettings& settings, TString& return false; } - if (!ValidateSettingInRange("vector_dimension", - settings.has_vector_dimension() ? std::optional(settings.vector_dimension()) : std::nullopt, + if (!ValidateSettingInRange("vector_dimension", + settings.has_vector_dimension() ? std::optional(settings.vector_dimension()) : std::nullopt, MinVectorDimension, MaxVectorDimension, error)) { diff --git a/ydb/core/base/kmeans_clusters.h b/ydb/core/base/kmeans_clusters.h index c3c9bcaa05ef..7e47890d2b16 100644 --- a/ydb/core/base/kmeans_clusters.h +++ b/ydb/core/base/kmeans_clusters.h @@ -47,6 +47,9 @@ class IClusters { std::unique_ptr CreateClusters(const Ydb::Table::VectorIndexSettings& settings, ui32 maxRounds, TString& error); +// Auto-detect vector type and dimension from target vector when settings have dimension=0 +std::unique_ptr CreateClustersAutoDetect(Ydb::Table::VectorIndexSettings settings, const TStringBuf& targetVector, ui32 maxRounds, TString& error); + bool ValidateSettings(const Ydb::Table::VectorIndexSettings& settings, TString& error); bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& error); bool FillSetting(Ydb::Table::KMeansTreeSettings& settings, const TString& name, const TString& value, TString& error); diff --git a/ydb/core/kqp/common/kqp_yql.cpp b/ydb/core/kqp/common/kqp_yql.cpp index 7d3cb9224dd6..ea0f13290210 100644 --- a/ydb/core/kqp/common/kqp_yql.cpp +++ b/ydb/core/kqp/common/kqp_yql.cpp @@ -182,7 +182,18 @@ TKqpReadTableSettings ParseInternal(const TCoNameValueTupleList& node) { for(const auto& kv: lv) { settings.IndexSelectionInfo.emplace(kv.Name().Value(), kv.Value().Cast().Value()); } - + } else if (name == TKqpReadTableSettings::VectorTopKColumnSettingName) { + YQL_ENSURE(tuple.Value().Maybe()); + settings.VectorTopKColumn = tuple.Value().Cast().Value(); + } else if (name == TKqpReadTableSettings::VectorTopKMetricSettingName) { + YQL_ENSURE(tuple.Value().Maybe()); + settings.VectorTopKMetric = tuple.Value().Cast().Value(); + } else if (name == TKqpReadTableSettings::VectorTopKTargetSettingName) { + YQL_ENSURE(tuple.Value().IsValid()); + settings.VectorTopKTarget = tuple.Value().Cast().Ptr(); + } else if (name == TKqpReadTableSettings::VectorTopKLimitSettingName) { + YQL_ENSURE(tuple.Value().IsValid()); + settings.VectorTopKLimit = tuple.Value().Cast().Ptr(); } else { YQL_ENSURE(false, "Unknown KqpReadTable setting name '" << name << "'"); } @@ -317,6 +328,38 @@ NNodes::TCoNameValueTupleList TKqpReadTableSettings::BuildNode(TExprContext& ctx .Done()); } + if (VectorTopKColumn) { + settings.emplace_back( + Build(ctx, pos) + .Name().Build(VectorTopKColumnSettingName) + .Value().Build(VectorTopKColumn) + .Done()); + } + + if (VectorTopKMetric) { + settings.emplace_back( + Build(ctx, pos) + .Name().Build(VectorTopKMetricSettingName) + .Value().Build(VectorTopKMetric) + .Done()); + } + + if (VectorTopKTarget) { + settings.emplace_back( + Build(ctx, pos) + .Name().Build(VectorTopKTargetSettingName) + .Value(VectorTopKTarget) + .Done()); + } + + if (VectorTopKLimit) { + settings.emplace_back( + Build(ctx, pos) + .Name().Build(VectorTopKLimitSettingName) + .Value(VectorTopKLimit) + .Done()); + } + return Build(ctx, pos) .Add(settings) .Done(); diff --git a/ydb/core/kqp/common/kqp_yql.h b/ydb/core/kqp/common/kqp_yql.h index ceb671181f56..73f2b70c743a 100644 --- a/ydb/core/kqp/common/kqp_yql.h +++ b/ydb/core/kqp/common/kqp_yql.h @@ -130,6 +130,10 @@ struct TKqpReadTableSettings: public TSortingOperator { static constexpr TStringBuf TabletIdName = "TabletId"; static constexpr TStringBuf PointPrefixLenSettingName = "PointPrefixLen"; static constexpr TStringBuf IndexSelectionDebugInfoSettingName = "IndexSelectionDebugInfo"; + static constexpr TStringBuf VectorTopKColumnSettingName = "VectorTopKColumn"; + static constexpr TStringBuf VectorTopKMetricSettingName = "VectorTopKMetric"; + static constexpr TStringBuf VectorTopKTargetSettingName = "VectorTopKTarget"; + static constexpr TStringBuf VectorTopKLimitSettingName = "VectorTopKLimit"; TVector SkipNullKeys; TExprNode::TPtr ItemsLimit; @@ -139,6 +143,12 @@ struct TKqpReadTableSettings: public TSortingOperator { ui64 PointPrefixLen = 0; THashMap IndexSelectionInfo; + // Vector top-K pushdown settings for brute force vector search + TString VectorTopKColumn; + TString VectorTopKMetric; + TExprNode::TPtr VectorTopKTarget; + TExprNode::TPtr VectorTopKLimit; + void AddSkipNullKey(const TString& key); void SetItemsLimit(const TExprNode::TPtr& expr) { ItemsLimit = expr; } diff --git a/ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp b/ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp index 14a9b61c38de..fdb0970052b9 100644 --- a/ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp +++ b/ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp @@ -2670,6 +2670,16 @@ TMaybe TKqpTasksGraph::BuildScanTasksFromSource(TStageInfo& stageInfo, b settings->SetItemsLimit(itemsLimit); } + if (source.HasVectorTopK()) { + const auto& in = source.GetVectorTopK(); + auto& out = *settings->MutableVectorTopK(); + out.SetColumn(in.GetColumn()); + *out.MutableSettings() = in.GetSettings(); + auto target = ExtractPhyValue(stageInfo, in.GetTargetVector(), TxAlloc->HolderFactory, TxAlloc->TypeEnv, NUdf::TUnboxedValuePod()); + out.SetTargetVector(TString(target.AsStringRef())); + out.SetLimit((ui32)ExtractPhyValue(stageInfo, in.GetLimit(), TxAlloc->HolderFactory, TxAlloc->TypeEnv, NUdf::TUnboxedValuePod()).Get()); + } + auto& lockTxId = GetMeta().LockTxId; if (lockTxId) { settings->SetLockTxId(*lockTxId); diff --git a/ydb/core/kqp/opt/kqp_opt_build_txs.cpp b/ydb/core/kqp/opt/kqp_opt_build_txs.cpp index 2e9503b43b93..f1855c38cb35 100644 --- a/ydb/core/kqp/opt/kqp_opt_build_txs.cpp +++ b/ydb/core/kqp/opt/kqp_opt_build_txs.cpp @@ -407,6 +407,16 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase { argsMap.emplace(inputArg.Raw(), makeParameterBinding(maybeBinding.Cast(), input.Pos()).Ptr()); } + // Also scan the program body for TKqpTxResultBinding (for VectorTopK precompute settings) + VisitExpr(stage.Program().Body().Ptr(), + [&](const TExprNode::TPtr& node) { + TExprBase expr(node); + if (auto binding = expr.Maybe()) { + sourceReplaceMap.emplace(node.Get(), makeParameterBinding(binding.Cast(), node->Pos()).Ptr()); + } + return true; + }); + auto inputs = Build(ctx, stage.Pos()) .Add(newInputs) .Done(); @@ -415,7 +425,7 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase { .Inputs(ctx.ReplaceNodes(ctx.ReplaceNodes(inputs.Ptr(), stagesMap), sourceReplaceMap)) .Program() .Args(newArgs) - .Body(ctx.ReplaceNodes(stage.Program().Body().Ptr(), argsMap)) + .Body(ctx.ReplaceNodes(ctx.ReplaceNodes(stage.Program().Body().Ptr(), argsMap), sourceReplaceMap)) .Build() .Settings(stage.Settings()) .Outputs(stage.Outputs()) @@ -491,6 +501,18 @@ TVector PrecomputeInputs(const TDqStage& stage) { }); } } + + // Also scan the program body for precomputes in read settings (for VectorTopK pushdown) + VisitExpr(stage.Program().Body().Ptr(), + [&] (const TExprNode::TPtr& ptr) { + TExprBase node(ptr); + if (auto maybePrecompute = node.Maybe()) { + result.push_back(maybePrecompute.Cast()); + return false; + } + return true; + }); + return result; } diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp b/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp index b17c3c6f5a70..aa708f97c7a3 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp @@ -45,6 +45,7 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase { AddHandler(0, IsSort, HNDL(RemoveRedundantSortOverReadTable)); AddHandler(0, &TCoTake::Match, HNDL(ApplyLimitToReadTable)); AddHandler(0, &TCoTopSort::Match, HNDL(ApplyLimitToOlapReadTable)); + AddHandler(0, &TCoTopSort::Match, HNDL(ApplyVectorTopKToReadTable)); AddHandler(0, &TCoFlatMap::Match, HNDL(PushOlapFilter)); AddHandler(0, &TCoFlatMap::Match, HNDL(PushOlapProjections)); AddHandler(0, &TCoAggregateCombine::Match, HNDL(PushAggregateCombineToStage)); @@ -206,7 +207,7 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase { else { TExprBase output = KqpBuildStreamIdxLookupJoinStagesKeepSorted(node, ctx, TypesCtx, true); DumpAppliedRule("BuildStreamIdxLookupJoinStagesKeepSorted", node.Ptr(), output.Ptr(), ctx); - return output; + return output; } } @@ -260,6 +261,12 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase { return output; } + TMaybeNode ApplyVectorTopKToReadTable(TExprBase node, TExprContext& ctx) { + TExprBase output = KqpApplyVectorTopKToReadTable(node, ctx, KqpCtx); + DumpAppliedRule("ApplyVectorTopKToReadTable", node.Ptr(), output.Ptr(), ctx); + return output; + } + TMaybeNode PushOlapFilter(TExprBase node, TExprContext& ctx) { TExprBase output = KqpPushOlapFilter(node, ctx, KqpCtx, TypesCtx, *TypeAnnTransformer.Get()); DumpAppliedRule("PushOlapFilter", node.Ptr(), output.Ptr(), ctx); diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy_impl.h b/ydb/core/kqp/opt/physical/kqp_opt_phy_impl.h index 242f2d625e6d..c3410acfd88b 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy_impl.h +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy_impl.h @@ -19,6 +19,8 @@ NYql::NNodes::TMaybeNode BuildLookupKeysPrecompu NYql::NNodes::TCoAtomList BuildColumnsList(const THashSet& columns, NYql::TPositionHandle pos, NYql::TExprContext& ctx); +NYql::NNodes::TExprBase KqpPrecomputeParameter(NYql::NNodes::TExprBase param, NYql::TExprContext& ctx); + NYql::NNodes::TCoAtomList BuildColumnsList(const TVector& columns, NYql::TPositionHandle pos, NYql::TExprContext& ctx); diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy_limit.cpp b/ydb/core/kqp/opt/physical/kqp_opt_phy_limit.cpp index 69ed31b6004a..5e1bc3b44903 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy_limit.cpp +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy_limit.cpp @@ -3,6 +3,8 @@ #include +#include + namespace NKikimr::NKqp::NOpt { using namespace NYql; @@ -186,5 +188,368 @@ TExprBase KqpApplyLimitToOlapReadTable(TExprBase node, TExprContext& ctx, const .Done(); } +namespace { + +// Helper function to extract info from a Knn::*Distance Apply node +// argSubstitutions maps lambda arguments to their corresponding input expressions +// Returns: {column name, method name, target expression} or nullopt if not a valid Knn distance call +TMaybe> ExtractKnnDistanceFromApply( + const TCoApply& apply, + const TNodeMap& argSubstitutions = {}) { + auto udf = apply.Callable().Maybe(); + if (!udf) { + return {}; + } + + const auto methodName = TString(udf.Cast().MethodName().Value()); + if (!methodName.StartsWith("Knn.")) { + return {}; + } + + // Find the member (column) and target expression + // Knn distance functions have exactly 2 arguments: column and target vector + TMaybe columnName; + TExprNode::TPtr targetExpr; + size_t argCount = 0; + + for (const auto& arg : apply.Args()) { + // Skip the callable itself (first element in Args) + if (arg.Raw() == apply.Callable().Raw()) { + continue; + } + argCount++; + + // Try to resolve the argument through substitutions first + TExprNode::TPtr resolvedArg = arg.Ptr(); + auto it = argSubstitutions.find(arg.Raw()); + if (it != argSubstitutions.end()) { + resolvedArg = it->second; + } + + // Check if the (possibly resolved) argument is a Member - that's the column + TExprBase resolvedExpr(resolvedArg); + if (auto member = resolvedExpr.Maybe()) { + columnName = TString(member.Cast().Name().Value()); + } else { + // Not a column reference - it's the target expression + targetExpr = resolvedArg; + } + } + + // Knn distance functions should have exactly 2 arguments + if (!columnName || !targetExpr || argCount != 2) { + return {}; + } + + return std::make_tuple(*columnName, methodName, targetExpr); +} + +// Try to find a Knn distance expression in a struct literal by column name +// Returns the expression if found, nullopt otherwise +TMaybeNode FindColumnExprInStruct(const TExprBase& structExpr, const TStringBuf& columnName) { + if (auto asStruct = structExpr.Maybe()) { + for (const auto& item : asStruct.Cast()) { + if (auto tuple = item.Maybe()) { + if (tuple.Cast().Name().Value() == columnName) { + return tuple.Cast().Value(); + } + } + } + } + return {}; +} + +// Check if the lambda body is a Knn::*Distance function call and extract information +// Returns: {column name, metric name, target expression} or nullopt if not a Knn distance function +// Handles both direct Apply and FlatMap-wrapped cases (for nullable columns) +TMaybe> ExtractKnnDistanceInfo(const TExprBase& lambdaBody) { + // Case 1: Direct Apply - Knn::Distance(Member(input, 'emb'), expr) + if (auto apply = lambdaBody.Maybe()) { + return ExtractKnnDistanceFromApply(apply.Cast()); + } + + // Case 2: FlatMap-wrapped (for optional target expression like String::HexDecode) + // FlatMap(, lambda(arg): Knn::Distance(Member(lambda_arg, 'emb'), arg)) + // Where is the actual target expression and 'arg' is bound to it + if (auto flatMap = lambdaBody.Maybe()) { + auto input = flatMap.Cast().Input(); + auto lambda = flatMap.Cast().Lambda(); + auto innerBody = lambda.Body(); + + // Build substitution map: lambda arg -> FlatMap input + TNodeMap argSubstitutions; + if (lambda.Args().Size() == 1) { + argSubstitutions[lambda.Args().Arg(0).Raw()] = input.Ptr(); + } + + // Check if the inner body is a Just(Apply(...)) + if (auto just = innerBody.Maybe()) { + if (auto apply = just.Cast().Input().Maybe()) { + return ExtractKnnDistanceFromApply(apply.Cast(), argSubstitutions); + } + } + + // Check if the inner body is directly an Apply + if (auto apply = innerBody.Maybe()) { + return ExtractKnnDistanceFromApply(apply.Cast(), argSubstitutions); + } + + // Case 3: Nested FlatMap (for cases with multiple nullables) + // FlatMap(expr, lambda(arg1): FlatMap(input, lambda(arg2): Knn::Distance(...))) + if (auto innerFlatMap = innerBody.Maybe()) { + auto innerInput = innerFlatMap.Cast().Input(); + auto innerLambda = innerFlatMap.Cast().Lambda(); + auto innerInnerBody = innerLambda.Body(); + + // Add substitution for inner lambda arg + if (innerLambda.Args().Size() == 1) { + argSubstitutions[innerLambda.Args().Arg(0).Raw()] = innerInput.Ptr(); + } + + if (auto just = innerInnerBody.Maybe()) { + if (auto apply = just.Cast().Input().Maybe()) { + return ExtractKnnDistanceFromApply(apply.Cast(), argSubstitutions); + } + } + if (auto apply = innerInnerBody.Maybe()) { + return ExtractKnnDistanceFromApply(apply.Cast(), argSubstitutions); + } + } + } + + return {}; +} + +// Get the metric enum value from the method name +TString GetMetricFromMethodName(const TString& methodName, bool isAsc) { + if (methodName == "Knn.CosineDistance" && isAsc) { + return "CosineDistance"; + } + if (methodName == "Knn.CosineSimilarity" && !isAsc) { + return "CosineSimilarity"; + } + if (methodName == "Knn.InnerProductSimilarity" && !isAsc) { + return "InnerProductSimilarity"; + } + if (methodName == "Knn.ManhattanDistance" && isAsc) { + return "ManhattanDistance"; + } + if (methodName == "Knn.EuclideanDistance" && isAsc) { + return "EuclideanDistance"; + } + return {}; +} + +} // anonymous namespace + +TExprBase KqpApplyVectorTopKToReadTable(TExprBase node, TExprContext& ctx, const TKqpOptimizeContext& kqpCtx) { + if (!node.Maybe()) { + return node; + } + auto topSort = node.Cast(); + + auto input = topSort.Input(); + + // Check if input is directly TKqpReadTableRanges + bool isReadTableRanges = input.Maybe().IsValid(); + + // Also check if input is TDqCnUnionAll with a stage containing TKqpReadTableRanges + TMaybeNode maybeCnUnionAll; + TMaybeNode maybeStage; + TMaybeNode maybeReadTableRanges; + + // Track if there's a FlatMap between TopSort and read (for ORDER BY alias case) + TMaybeNode maybeFlatMap; + TExprBase flatMapInput = input; + + // Check for FlatMap (used when ORDER BY references an alias/computed column) + // The FlatMap could be directly under TopSort or inside a TDqCnUnionAll stage + if (input.Maybe()) { + maybeFlatMap = input.Cast(); + flatMapInput = maybeFlatMap.Cast().Input(); + } + + if (!isReadTableRanges && flatMapInput.Maybe()) { + maybeCnUnionAll = flatMapInput.Cast(); + maybeStage = maybeCnUnionAll.Cast().Output().Stage().Maybe(); + if (maybeStage) { + auto stage = maybeStage.Cast(); + auto stageBody = stage.Program().Body(); + + // Check if the stage program body is TKqpReadTableRanges directly + maybeReadTableRanges = stageBody.Maybe(); + if (maybeReadTableRanges) { + isReadTableRanges = true; + input = maybeReadTableRanges.Cast(); + } + + // Also check if the stage body is a FlatMap over TKqpReadTableRanges + // This happens when ORDER BY references an alias computed in the FlatMap + if (!isReadTableRanges && stageBody.Maybe()) { + auto stageFlatMap = stageBody.Cast(); + maybeReadTableRanges = stageFlatMap.Input().Maybe(); + if (maybeReadTableRanges) { + isReadTableRanges = true; + input = maybeReadTableRanges.Cast(); + // Use the FlatMap from the stage for alias resolution + if (!maybeFlatMap) { + maybeFlatMap = stageFlatMap; + } + } + } + } + } + + if (!isReadTableRanges) { + return node; + } + + auto& tableDesc = kqpCtx.Tables->ExistingTable(kqpCtx.Cluster, GetReadTablePath(input, true)); + + // Only for datashard tables + if (tableDesc.Metadata->Kind != EKikimrTableKind::Datashard) { + return node; + } + + auto settings = GetReadTableSettings(input, true); + if (settings.VectorTopKColumn) { + return node; // already set + } + + // Check sort direction - only handle single direction (either all forward or all reverse) + ESortDirection direction = GetSortDirection(topSort.SortDirections()); + if (direction != ESortDirection::Forward && direction != ESortDirection::Reverse) { + return node; + } + const bool isAsc = (direction == ESortDirection::Forward); + + // Extract Knn distance info from the key selector lambda + auto lambdaBody = topSort.KeySelectorLambda().Body(); + auto knnInfo = ExtractKnnDistanceInfo(lambdaBody); + + // If not found directly, check if it's a Member accessing a computed column from FlatMap + if (!knnInfo && maybeFlatMap) { + if (auto member = lambdaBody.Maybe()) { + auto columnName = member.Cast().Name().Value(); + // Look for this column in the FlatMap's lambda body + auto flatMapLambda = maybeFlatMap.Cast().Lambda(); + auto flatMapBody = flatMapLambda.Body(); + + // The FlatMap body might be Just(AsStruct(...)) or AsStruct(...) + TExprBase structExpr = flatMapBody; + if (auto just = flatMapBody.Maybe()) { + structExpr = just.Cast().Input(); + } + + auto maybeColumnExpr = FindColumnExprInStruct(structExpr, columnName); + if (maybeColumnExpr.IsValid()) { + knnInfo = ExtractKnnDistanceInfo(maybeColumnExpr.Cast()); + } + } + } + + if (!knnInfo) { + return node; + } + + auto [columnName, methodName, targetExpr] = *knnInfo; + + // Check if the metric is valid for the sort direction + TString metric = GetMetricFromMethodName(methodName, isAsc); + if (metric.empty()) { + return node; + } + + // Check if the column exists in the table + if (!tableDesc.Metadata->Columns.contains(columnName)) { + return node; + } + + YQL_CLOG(TRACE, ProviderKqp) << "-- applying vector top-K pushdown for column " << columnName << " with metric " << metric; + + // Set the vector top-K settings + settings.VectorTopKColumn = columnName; + settings.VectorTopKMetric = metric; + + // Target expression - wrap in precompute if not a simple type + TExprBase targetExprBase(targetExpr); + if (targetExprBase.Maybe() || targetExprBase.Maybe()) { + settings.VectorTopKTarget = targetExpr; + } else { + // Wrap non-simple expressions in precompute (TDqPhyPrecompute) + settings.VectorTopKTarget = KqpPrecomputeParameter(targetExprBase, ctx).Ptr(); + } + + // Limit expression - wrap in precompute if not a simple type + auto limitExpr = topSort.Count(); + if (limitExpr.Maybe() || limitExpr.Maybe()) { + settings.VectorTopKLimit = limitExpr.Ptr(); + } else { + settings.VectorTopKLimit = KqpPrecomputeParameter(limitExpr, ctx).Ptr(); + } + + auto newReadNode = BuildReadNode(node.Pos(), ctx, input, settings); + + // If we found the read inside a stage, we need to rebuild the stage and connection + if (maybeStage) { + auto stage = maybeStage.Cast(); + auto stageBody = stage.Program().Body(); + + // Determine the new stage body + TExprBase newStageBody = newReadNode; + + // If the stage body was a FlatMap over the read, preserve the FlatMap structure and type + if (auto orderedFlatMap = stageBody.Maybe()) { + if (orderedFlatMap.Cast().Input().Maybe()) { + // Rebuild as OrderedFlatMap to preserve ordering semantics + newStageBody = Build(ctx, orderedFlatMap.Cast().Pos()) + .Input(newReadNode) + .Lambda(orderedFlatMap.Cast().Lambda()) + .Done(); + } + } else if (auto flatMap = stageBody.Maybe()) { + if (flatMap.Cast().Input().Maybe()) { + // Rebuild as regular FlatMap + newStageBody = Build(ctx, flatMap.Cast().Pos()) + .Input(newReadNode) + .Lambda(flatMap.Cast().Lambda()) + .Done(); + } + } + + // Rebuild the stage with the new body + auto newStage = Build(ctx, stage.Pos()) + .Inputs(stage.Inputs()) + .Program() + .Args(stage.Program().Args()) + .Body(newStageBody) + .Build() + .Settings(stage.Settings()) + .Done(); + + // Rebuild the connection + auto newCnUnionAll = Build(ctx, maybeCnUnionAll.Cast().Pos()) + .Output() + .Stage(newStage) + .Index(maybeCnUnionAll.Cast().Output().Index()) + .Build() + .Done(); + + return Build(ctx, topSort.Pos()) + .Input(newCnUnionAll) + .Count(topSort.Count()) + .SortDirections(topSort.SortDirections()) + .KeySelectorLambda(topSort.KeySelectorLambda()) + .Done(); + } + + return Build(ctx, topSort.Pos()) + .Input(newReadNode) + .Count(topSort.Count()) + .SortDirections(topSort.SortDirections()) + .KeySelectorLambda(topSort.KeySelectorLambda()) + .Done(); +} + } // namespace NKikimr::NKqp::NOpt diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy_rules.h b/ydb/core/kqp/opt/physical/kqp_opt_phy_rules.h index 373cc251b13e..c1d6976a0d25 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy_rules.h +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy_rules.h @@ -62,6 +62,9 @@ NYql::NNodes::TExprBase KqpApplyLimitToReadTable(NYql::NNodes::TExprBase node, N NYql::NNodes::TExprBase KqpApplyLimitToOlapReadTable(NYql::NNodes::TExprBase node, NYql::TExprContext& ctx, const TKqpOptimizeContext& kqpCtx); +NYql::NNodes::TExprBase KqpApplyVectorTopKToReadTable(NYql::NNodes::TExprBase node, NYql::TExprContext& ctx, + const TKqpOptimizeContext& kqpCtx); + NYql::NNodes::TExprBase KqpPushOlapFilter(NYql::NNodes::TExprBase node, NYql::TExprContext& ctx, const TKqpOptimizeContext& kqpCtx, NYql::TTypeAnnotationContext& typesCtx, NYql::IGraphTransformer &typeAnnTransformer); diff --git a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp index bbd73426875f..00ddb13d01fc 100644 --- a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp +++ b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp @@ -37,6 +37,86 @@ using namespace NYql::NNodes; namespace { +// Helper function to set VectorTopK metric from string +void SetVectorTopKMetric(Ydb::Table::VectorIndexSettings* indexSettings, const TString& metric) { + if (metric == "CosineDistance") { + indexSettings->set_metric(Ydb::Table::VectorIndexSettings::DISTANCE_COSINE); + } else if (metric == "CosineSimilarity") { + indexSettings->set_metric(Ydb::Table::VectorIndexSettings::SIMILARITY_COSINE); + } else if (metric == "InnerProductSimilarity") { + indexSettings->set_metric(Ydb::Table::VectorIndexSettings::SIMILARITY_INNER_PRODUCT); + } else if (metric == "ManhattanDistance") { + indexSettings->set_metric(Ydb::Table::VectorIndexSettings::DISTANCE_MANHATTAN); + } else if (metric == "EuclideanDistance") { + indexSettings->set_metric(Ydb::Table::VectorIndexSettings::DISTANCE_EUCLIDEAN); + } +} + +// Helper function to set VectorTopK target vector expression +void SetVectorTopKTarget(NKqpProto::TKqpPhyValue* targetProto, const TExprNode::TPtr& targetExpr) { + TExprBase expr(targetExpr); + if (expr.Maybe()) { + auto* literal = targetProto->MutableLiteralValue(); + literal->MutableType()->SetKind(NKikimrMiniKQL::ETypeKind::Data); + literal->MutableType()->MutableData()->SetScheme(NScheme::NTypeIds::String); + literal->MutableValue()->SetText(TString(expr.Cast().Literal().Value())); + } else if (expr.Maybe()) { + targetProto->MutableParamValue()->SetParamName(expr.Cast().Name().StringValue()); + } else { + YQL_ENSURE(false, "Unexpected VectorTopKTarget callable " << expr.Ref().Content()); + } +} + +// Helper function to set VectorTopK limit expression +void SetVectorTopKLimit(NKqpProto::TKqpPhyValue* limitProto, const TExprNode::TPtr& limitExpr) { + TExprBase expr(limitExpr); + if (expr.Maybe()) { + auto* literal = limitProto->MutableLiteralValue(); + literal->MutableType()->SetKind(NKikimrMiniKQL::ETypeKind::Data); + literal->MutableType()->MutableData()->SetScheme(NScheme::NTypeIds::Uint64); + literal->MutableValue()->SetUint64(FromString(expr.Cast().Literal().Value())); + } else if (expr.Maybe()) { + limitProto->MutableParamValue()->SetParamName(expr.Cast().Name().StringValue()); + } else { + YQL_ENSURE(false, "Unexpected VectorTopKLimit callable " << expr.Ref().Content()); + } +} + +// Helper function to fill VectorTopK settings +template +void FillVectorTopKSettings( + NKqpProto::TKqpPhyVectorTopK& vectorTopK, + const TKqpReadTableSettings& settings, + const TColumnsRange& columns) +{ + // Find column index + ui32 columnIdx = 0; + bool columnFound = false; + for (const auto& col : columns) { + if (col.Value() == settings.VectorTopKColumn) { + columnFound = true; + break; + } + columnIdx++; + } + YQL_ENSURE(columnFound, "VectorTopK column " << settings.VectorTopKColumn << " not found in read columns"); + vectorTopK.SetColumn(columnIdx); + + // Set the metric settings + auto* indexSettings = vectorTopK.MutableSettings(); + SetVectorTopKMetric(indexSettings, settings.VectorTopKMetric); + + // Default vector settings - actual type will be determined from data + indexSettings->set_vector_type(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_FLOAT); + indexSettings->set_vector_dimension(0); + + // Set target vector + SetVectorTopKTarget(vectorTopK.MutableTargetVector(), settings.VectorTopKTarget); + + // Set limit + SetVectorTopKLimit(vectorTopK.MutableLimit(), settings.VectorTopKLimit); +} + NKqpProto::TKqpPhyTx::EType GetPhyTxType(const EPhysicalTxType& type) { switch (type) { case EPhysicalTxType::Compute: return NKqpProto::TKqpPhyTx::TYPE_COMPUTE; @@ -393,7 +473,7 @@ void FillReadRange(const TKqpWideReadTable& read, const TKikimrTableMetadata& ta } template -void FillReadRanges(const TReader& read, const TKikimrTableMetadata&, TProto& readProto) { +void FillReadRanges(const TReader& read, const TKikimrTableMetadata& /*tableMeta*/, TProto& readProto) { auto ranges = read.Ranges().template Maybe(); if (ranges.IsValid()) { @@ -431,6 +511,13 @@ void FillReadRanges(const TReader& read, const TKikimrTableMetadata&, TProto& re } } + // Handle VectorTopK settings for brute force vector search + if constexpr (std::is_same_v) { + if (settings.VectorTopKColumn) { + FillVectorTopKSettings(*readProto.MutableVectorTopK(), settings, read.Columns()); + } + } + readProto.SetReverse(settings.IsReverse()); } @@ -1166,6 +1253,11 @@ class TKqpQueryCompiler : public IKqpQueryCompiler { YQL_ENSURE(false, "Unexpected ItemsLimit callable " << expr.Ref().Content()); } } + + // Handle VectorTopK settings for brute force vector search + if (readSettings.VectorTopKColumn) { + FillVectorTopKSettings(*readProto.MutableVectorTopK(), readSettings, settings.Columns().Cast()); + } } else { YQL_ENSURE(false, "Unsupported source type"); } @@ -1375,7 +1467,7 @@ class TKqpQueryCompiler : public IKqpQueryCompiler { if (indexDescription.Type == TIndexDescription::EType::GlobalSync || indexDescription.Type == TIndexDescription::EType::GlobalSyncUnique) { const auto& implTable = tableMeta->ImplTables[index]; - + if (settingsProto.GetType() == NKikimrKqp::TKqpTableSinkSettings::MODE_UPDATE) { if (std::any_of(implTable->Columns.begin(), implTable->Columns.end(), [&](const auto& column) { return columnsSet.contains(column.first) && !mainKeyColumnsSet.contains(column.first); diff --git a/ydb/core/kqp/runtime/kqp_read_actor.cpp b/ydb/core/kqp/runtime/kqp_read_actor.cpp index 118fd31835fc..c3e2e92934ee 100644 --- a/ydb/core/kqp/runtime/kqp_read_actor.cpp +++ b/ydb/core/kqp/runtime/kqp_read_actor.cpp @@ -893,6 +893,10 @@ class TKqpReadActor : public TActorBootstrapped, public NYql::NDq record.SetLockNodeId(Settings->GetLockNodeId()); } + if (Settings->HasVectorTopK()) { + *record.MutableVectorTopK() = Settings->GetVectorTopK(); + } + CA_LOG_D(TStringBuilder() << "Send EvRead to shardId: " << state->TabletId << ", tablePath: " << Settings->GetTable().GetTablePath() << ", ranges: " << DebugPrintRanges(KeyColumnTypes, ev->Ranges, *AppData()->TypeRegistry) << ", limit: " << limit diff --git a/ydb/core/kqp/ut/knn/kqp_knn_ut.cpp b/ydb/core/kqp/ut/knn/kqp_knn_ut.cpp new file mode 100644 index 000000000000..297d73a2b5a4 --- /dev/null +++ b/ydb/core/kqp/ut/knn/kqp_knn_ut.cpp @@ -0,0 +1,246 @@ +#include + +#include + +#include + +#include + +namespace NKikimr { +namespace NKqp { + +using namespace NYdb; +using namespace NYdb::NTable; + +Y_UNIT_TEST_SUITE(KqpKnn) { + + TSession CreateTableForVectorSearch(TTableClient& db, bool nullable, const TString& dataCol = "data") { + auto session = db.CreateSession().GetValueSync().GetSession(); + + { + auto tableBuilder = db.GetTableBuilder(); + if (nullable) { + tableBuilder + .AddNullableColumn("pk", EPrimitiveType::Int64) + .AddNullableColumn("emb", EPrimitiveType::String) + .AddNullableColumn(dataCol, EPrimitiveType::String); + } else { + tableBuilder + .AddNonNullableColumn("pk", EPrimitiveType::Int64) + .AddNonNullableColumn("emb", EPrimitiveType::String) + .AddNonNullableColumn(dataCol, EPrimitiveType::String); + } + tableBuilder.SetPrimaryKeyColumns({"pk"}); + tableBuilder.BeginPartitioningSettings() + .SetMinPartitionsCount(3) + .EndPartitioningSettings(); + auto partitions = TExplicitPartitions{} + .AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(4).EndTuple().Build()) + .AppendSplitPoints(TValueBuilder{}.BeginTuple().AddElement().OptionalInt64(6).EndTuple().Build()); + tableBuilder.SetPartitionAtKeys(partitions); + auto result = session.CreateTable("/Root/TestTable", tableBuilder.Build()).ExtractValueSync(); + UNIT_ASSERT_VALUES_EQUAL(result.IsTransportError(), false); + UNIT_ASSERT_VALUES_EQUAL_C(result.GetStatus(), EStatus::SUCCESS, result.GetIssues().ToString()); + } + + { + const TString query1 = TStringBuilder() + << "UPSERT INTO `/Root/TestTable` (pk, emb, " << dataCol << ") VALUES " + << "(0, \"\x03\x30\x02\", \"0\")," + "(1, \"\x13\x31\x02\", \"1\")," + "(2, \"\x23\x32\x02\", \"2\")," + "(3, \"\x53\x33\x02\", \"3\")," + "(4, \"\x43\x34\x02\", \"4\")," + "(5, \"\x50\x60\x02\", \"5\")," + "(6, \"\x61\x11\x02\", \"6\")," + "(7, \"\x12\x62\x02\", \"7\")," + "(8, \"\x75\x76\x02\", \"8\")," + "(9, \"\x76\x76\x02\", \"9\");"; + + auto result = session.ExecuteDataQuery(Q_(query1), TTxControl::BeginTx(TTxSettings::SerializableRW()).CommitTx()) + .ExtractValueSync(); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + } + return session; + } + + Y_UNIT_TEST_TWIN(VectorSearchKnnPushdown, Nullable) { + auto setting = NKikimrKqp::TKqpSetting(); + auto serverSettings = TKikimrSettings() + .SetUseRealThreads(false) + .SetKqpSettings({setting}); + + TKikimrRunner kikimr(serverSettings); + auto runtime = kikimr.GetTestServer().GetRuntime(); + runtime->SetLogPriority(NKikimrServices::TX_DATASHARD, NActors::NLog::PRI_TRACE); + + auto db = kikimr.RunCall([&] { return kikimr.GetTableClient(); }); + auto session = kikimr.RunCall([&] { return CreateTableForVectorSearch(db, Nullable, "___data"); }); + + ui64 expectedLimit = 3; + auto captureEvents = [&](TTestActorRuntimeBase&, TAutoPtr& ev) { + if (ev->GetTypeRewrite() == TEvDataShard::TEvRead::EventType) { + auto& read = ev->Get()->Record; + UNIT_ASSERT(read.HasVectorTopK()); + auto& topK = read.GetVectorTopK(); + UNIT_ASSERT(topK.GetTargetVector() == "\x67\x71\x02"); + UNIT_ASSERT_VALUES_EQUAL(topK.GetLimit(), expectedLimit); + } + return false; + }; + runtime->SetEventFilter(captureEvents); + + auto runQuery = [&](const TString& query) { + auto result = kikimr.RunCall([&] { + return session.ExecuteDataQuery(query, + TTxControl::BeginTx(TTxSettings::SerializableRW()).CommitTx()).ExtractValueSync(); + }); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + }; + + auto runQueryWithParams = [&](const TString& query, TParams params) { + auto result = kikimr.RunCall([&] { + return session.ExecuteDataQuery(query, + TTxControl::BeginTx(TTxSettings::SerializableRW()).CommitTx(), params).ExtractValueSync(); + }); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + }; + + // Explicit columns + runQuery(Q_(R"( + $TargetEmbedding = String::HexDecode("677102"); + SELECT pk, emb, ___data FROM `/Root/TestTable` + ORDER BY Knn::CosineDistance(emb, $TargetEmbedding) + LIMIT 3 + )")); + + // Implicit columns + runQuery(Q_(R"( + $TargetEmbedding = String::HexDecode("677102"); + SELECT * FROM `/Root/TestTable` + ORDER BY Knn::CosineDistance(emb, $TargetEmbedding) + LIMIT 3 + )")); + + // Inner product similarity + runQuery(Q_(R"( + $TargetEmbedding = String::HexDecode("677102"); + SELECT * FROM `/Root/TestTable` + ORDER BY Knn::InnerProductSimilarity(emb, $TargetEmbedding) DESC + LIMIT 3 + )")); + + // Cosine similarity (DESC) + runQuery(Q_(R"( + $TargetEmbedding = String::HexDecode("677102"); + SELECT * FROM `/Root/TestTable` + ORDER BY Knn::CosineSimilarity(emb, $TargetEmbedding) DESC + LIMIT 3 + )")); + + // Manhattan distance (ASC) + runQuery(Q_(R"( + $TargetEmbedding = String::HexDecode("677102"); + SELECT * FROM `/Root/TestTable` + ORDER BY Knn::ManhattanDistance(emb, $TargetEmbedding) + LIMIT 3 + )")); + + // Euclidean distance (ASC) + runQuery(Q_(R"( + $TargetEmbedding = String::HexDecode("677102"); + SELECT * FROM `/Root/TestTable` + ORDER BY Knn::EuclideanDistance(emb, $TargetEmbedding) + LIMIT 3 + )")); + + // Parameters + runQueryWithParams(Q_(R"( + DECLARE $TargetEmbedding AS String; + SELECT * FROM `/Root/TestTable` + ORDER BY Knn::CosineDistance(emb, $TargetEmbedding) + LIMIT 3 + )"), db.GetParamsBuilder() + .AddParam("$TargetEmbedding") + .String(TString("\x67\x71\x02", 3)) + .Build() + .Build()); + + // LIMIT 1 (minimum) + expectedLimit = 1; + runQuery(Q_(R"( + $TargetEmbedding = String::HexDecode("677102"); + SELECT * FROM `/Root/TestTable` + ORDER BY Knn::CosineDistance(emb, $TargetEmbedding) + LIMIT 1 + )")); + + // Larger LIMIT + expectedLimit = 100; + runQuery(Q_(R"( + $TargetEmbedding = String::HexDecode("677102"); + SELECT * FROM `/Root/TestTable` + ORDER BY Knn::CosineDistance(emb, $TargetEmbedding) + LIMIT 100 + )")); + + // Verify actual results - check that top 3 PKs are correct + // Target vector is 0x67, 0x71 (103, 113) + // Cosine distances calculated: + // pk=8 (117, 118): 0.000882 - closest + // pk=5 (80, 96): 0.000985 + // pk=9 (118, 118): 0.001070 + expectedLimit = 3; + { + TString query(Q_(R"( + $TargetEmbedding = String::HexDecode("677102"); + SELECT pk, Knn::CosineDistance(emb, $TargetEmbedding) AS distance FROM `/Root/TestTable` + ORDER BY distance + LIMIT 3 + )")); + + auto result = kikimr.RunCall([&] { + return session.ExecuteDataQuery(query, + TTxControl::BeginTx(TTxSettings::SerializableRW()).CommitTx()).ExtractValueSync(); + }); + UNIT_ASSERT_C(result.IsSuccess(), result.GetIssues().ToString()); + + // Extract PKs and distances from result + std::vector> results; + auto parser = result.GetResultSetParser(0); + while (parser.TryNextRow()) { + i64 pk; + if constexpr (Nullable) { + auto pkOpt = parser.ColumnParser("pk").GetOptionalInt64(); + UNIT_ASSERT(pkOpt.has_value()); + pk = *pkOpt; + } else { + pk = parser.ColumnParser("pk").GetInt64(); + } + auto distanceOpt = parser.ColumnParser("distance").GetOptionalFloat(); + UNIT_ASSERT(distanceOpt.has_value()); + float distance = *distanceOpt; + results.push_back({pk, distance}); + } + + UNIT_ASSERT_VALUES_EQUAL(results.size(), 3u); + UNIT_ASSERT_VALUES_EQUAL(results[0].first, 8); + UNIT_ASSERT_VALUES_EQUAL(results[1].first, 5); + UNIT_ASSERT_VALUES_EQUAL(results[2].first, 9); + + // Check exact distance values (with small epsilon for float comparison) + auto checkDistance = [](float actual, float expected) { + UNIT_ASSERT_C(std::abs(actual - expected) < 0.0001f, + "Expected distance " << expected << ", got " << actual); + }; + checkDistance(results[0].second, 0.000882f); + checkDistance(results[1].second, 0.000985f); + checkDistance(results[2].second, 0.001070f); + } + } + +} + +} +} + diff --git a/ydb/core/kqp/ut/knn/ya.make b/ydb/core/kqp/ut/knn/ya.make new file mode 100644 index 000000000000..96d851ddfe78 --- /dev/null +++ b/ydb/core/kqp/ut/knn/ya.make @@ -0,0 +1,27 @@ +UNITTEST_FOR(ydb/core/kqp) + +FORK_SUBTESTS() +SPLIT_FACTOR(50) + +IF (SANITIZER_TYPE OR WITH_VALGRIND) + SIZE(LARGE) + TAG(ya:fat) +ELSE() + SIZE(MEDIUM) +ENDIF() + +SRCS( + kqp_knn_ut.cpp +) + +PEERDIR( + ydb/core/kqp + ydb/core/kqp/ut/common + ydb/library/yql/udfs/common/knn + yql/essentials/sql/pg_dummy +) + +YQL_LAST_ABI_VERSION() + +END() + diff --git a/ydb/core/kqp/ut/ya.make b/ydb/core/kqp/ut/ya.make index d8b85f8b7f89..cc9ce645d636 100644 --- a/ydb/core/kqp/ut/ya.make +++ b/ydb/core/kqp/ut/ya.make @@ -11,6 +11,7 @@ RECURSE_FOR_TESTS( indexes idx_test join + knn olap opt perf diff --git a/ydb/core/protos/kqp_physical.proto b/ydb/core/protos/kqp_physical.proto index 6a14980c6300..1c4e137f4f66 100644 --- a/ydb/core/protos/kqp_physical.proto +++ b/ydb/core/protos/kqp_physical.proto @@ -224,6 +224,8 @@ message TKqpPhyOpReadRanges { TKqpPhyValue ItemsLimit = 2; // Reverse sign, i.e. if user ask ORDER BY ... DESC we need to read table in reverse direction bool Reverse = 3; + // Vector top-K pushdown settings for brute force vector search + TKqpPhyVectorTopK VectorTopK = 4; } message TKqpPhyTableOperation { @@ -411,6 +413,8 @@ message TKqpReadRangesSource { repeated string SkipNullKeys = 8; uint64 SequentialInFlightShards = 9; bool IsTableImmutable = 10; + // Vector top-K pushdown for brute force vector search + TKqpPhyVectorTopK VectorTopK = 11; } message TKqpExternalSource { diff --git a/ydb/core/protos/tx_datashard.proto b/ydb/core/protos/tx_datashard.proto index f38a55288792..8a8f92646bac 100644 --- a/ydb/core/protos/tx_datashard.proto +++ b/ydb/core/protos/tx_datashard.proto @@ -276,6 +276,9 @@ message TKqpReadRangesSourceSettings { optional NKikimrKqp.EIsolationLevel IsolationLevel = 24 [default = ISOLATION_LEVEL_UNDEFINED]; optional string Database = 25; + + // Vector top-K pushdown for brute force vector search + optional NKikimrKqp.TReadVectorTopK VectorTopK = 26; } // Takes input rows with a vector column, resolves the leaf cluster ID using the given diff --git a/ydb/core/tx/datashard/datashard__read_iterator.cpp b/ydb/core/tx/datashard/datashard__read_iterator.cpp index 896279b72724..8ec4b1c2d20f 100644 --- a/ydb/core/tx/datashard/datashard__read_iterator.cpp +++ b/ydb/core/tx/datashard/datashard__read_iterator.cpp @@ -2213,7 +2213,12 @@ class TDataShard::TReadOperation : public TOperation, public IReadOperation { } else if (!topK.HasTargetVector()) { error = "Target vector is not specified"; } else { - topState->KMeans = NKMeans::CreateClusters(topK.GetSettings(), 0, error); + // Use auto-detect if vector_dimension is 0 (brute force search without index) + if (topK.GetSettings().vector_dimension() == 0) { + topState->KMeans = NKMeans::CreateClustersAutoDetect(topK.GetSettings(), topK.GetTargetVector(), 0, error); + } else { + topState->KMeans = NKMeans::CreateClusters(topK.GetSettings(), 0, error); + } if (!topState->KMeans && error == "") { error = "CreateClusters failed"; } From f1ab585cea9a63d9720d2ea2b6b048fa43822419 Mon Sep 17 00:00:00 2001 From: azevaykin <145343289+azevaykin@users.noreply.github.com> Date: Thu, 27 Nov 2025 09:08:47 +0300 Subject: [PATCH 2/6] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- ydb/core/base/kmeans_clusters.cpp | 6 +++--- ydb/core/kqp/query_compiler/kqp_query_compiler.cpp | 2 ++ ydb/core/tx/datashard/datashard__read_iterator.cpp | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ydb/core/base/kmeans_clusters.cpp b/ydb/core/base/kmeans_clusters.cpp index 472b4a27f094..3c304220d8c3 100644 --- a/ydb/core/base/kmeans_clusters.cpp +++ b/ydb/core/base/kmeans_clusters.cpp @@ -502,9 +502,9 @@ std::unique_ptr CreateClustersAutoDetect(Ydb::Table::VectorIndexSetti error = "Target vector too short for bit type"; return nullptr; } - // For bit vectors: size = ceil(dim/8) + 1 (padding info) + 1 (format byte) - // padding = targetVector[size - 2], actual bits = (size - 2) * 8 - padding - settings.set_vector_dimension((targetVector.size() - 2) * 8 - static_cast(targetVector[targetVector.size() - 2])); + // For bit vectors: size = HeaderLen + ceil(dim/8) + 1 (padding info) + 1 (format byte) + // padding = targetVector[size - 2], actual bits = (targetVector.size() - HeaderLen - 1) * 8 - padding + settings.set_vector_dimension((targetVector.size() - HeaderLen - 1) * 8 - static_cast(targetVector[targetVector.size() - 2])); break; default: error = TStringBuilder() << "Unknown vector format byte: " << static_cast(formatByte); diff --git a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp index 00ddb13d01fc..39e2a4557e22 100644 --- a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp +++ b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp @@ -49,6 +49,8 @@ void SetVectorTopKMetric(Ydb::Table::VectorIndexSettings* indexSettings, const T indexSettings->set_metric(Ydb::Table::VectorIndexSettings::DISTANCE_MANHATTAN); } else if (metric == "EuclideanDistance") { indexSettings->set_metric(Ydb::Table::VectorIndexSettings::DISTANCE_EUCLIDEAN); + } else { + YQL_ENSURE(false, "Unrecognized VectorTopK metric: " << metric); } } diff --git a/ydb/core/tx/datashard/datashard__read_iterator.cpp b/ydb/core/tx/datashard/datashard__read_iterator.cpp index 8ec4b1c2d20f..9ef9d4a8d662 100644 --- a/ydb/core/tx/datashard/datashard__read_iterator.cpp +++ b/ydb/core/tx/datashard/datashard__read_iterator.cpp @@ -2219,7 +2219,7 @@ class TDataShard::TReadOperation : public TOperation, public IReadOperation { } else { topState->KMeans = NKMeans::CreateClusters(topK.GetSettings(), 0, error); } - if (!topState->KMeans && error == "") { + if (!topState->KMeans && error.empty()) { error = "CreateClusters failed"; } if (topState->KMeans && !topState->KMeans->IsExpectedFormat(topK.GetTargetVector())) { From 545dac78533009520d0c66cc8eaae46a1a98f643 Mon Sep 17 00:00:00 2001 From: azevaykin Date: Thu, 27 Nov 2025 10:26:32 +0300 Subject: [PATCH 3/6] Added defensive handling for TKqpTxResultBinding nodes in the VectorTopK helper functions --- .../kqp/query_compiler/kqp_query_compiler.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp index 39e2a4557e22..e80444bcd879 100644 --- a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp +++ b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -64,6 +65,13 @@ void SetVectorTopKTarget(NKqpProto::TKqpPhyValue* targetProto, const TExprNode:: literal->MutableValue()->SetText(TString(expr.Cast().Literal().Value())); } else if (expr.Maybe()) { targetProto->MutableParamValue()->SetParamName(expr.Cast().Name().StringValue()); + } else if (auto maybeBinding = expr.Maybe()) { + // TKqpTxResultBinding should have been replaced with TCoParameter by kqp_opt_build_txs, + // but handle it defensively by constructing the expected parameter name + auto binding = maybeBinding.Cast(); + TString paramName = TStringBuilder() << ParamNamePrefix + << "tx_result_binding_" << binding.TxIndex().Value() << "_" << binding.ResultIndex().Value(); + targetProto->MutableParamValue()->SetParamName(paramName); } else { YQL_ENSURE(false, "Unexpected VectorTopKTarget callable " << expr.Ref().Content()); } @@ -79,6 +87,13 @@ void SetVectorTopKLimit(NKqpProto::TKqpPhyValue* limitProto, const TExprNode::TP literal->MutableValue()->SetUint64(FromString(expr.Cast().Literal().Value())); } else if (expr.Maybe()) { limitProto->MutableParamValue()->SetParamName(expr.Cast().Name().StringValue()); + } else if (auto maybeBinding = expr.Maybe()) { + // TKqpTxResultBinding should have been replaced with TCoParameter by kqp_opt_build_txs, + // but handle it defensively by constructing the expected parameter name + auto binding = maybeBinding.Cast(); + TString paramName = TStringBuilder() << ParamNamePrefix + << "tx_result_binding_" << binding.TxIndex().Value() << "_" << binding.ResultIndex().Value(); + limitProto->MutableParamValue()->SetParamName(paramName); } else { YQL_ENSURE(false, "Unexpected VectorTopKLimit callable " << expr.Ref().Content()); } From f98406bcd64db7db394fb0d9c03d05ae5f806761 Mon Sep 17 00:00:00 2001 From: azevaykin Date: Thu, 27 Nov 2025 11:21:13 +0300 Subject: [PATCH 4/6] Only apply to full table scans --- .../kqp/opt/physical/kqp_opt_phy_limit.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy_limit.cpp b/ydb/core/kqp/opt/physical/kqp_opt_phy_limit.cpp index 5e1bc3b44903..8946f8a599d4 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy_limit.cpp +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy_limit.cpp @@ -404,6 +404,24 @@ TExprBase KqpApplyVectorTopKToReadTable(TExprBase node, TExprContext& ctx, const return node; } + // Only apply to full table scans (no WHERE clause filtering) + // When there's a WHERE clause, the Ranges() will not be TCoVoid + auto readTableRanges = input.Cast(); + if (!TCoVoid::Match(readTableRanges.Ranges().Raw())) { + return node; + } + + // If there's a FlatMap, check if it's filtering (WHERE on non-key columns) + // We allow FlatMaps that add computed columns (ORDER BY alias), but reject filtering FlatMaps + // Filtering FlatMaps use OptionalIf/ListIf, while projection FlatMaps use Just/SingleAsList + if (maybeFlatMap) { + auto flatMapBody = maybeFlatMap.Cast().Lambda().Body(); + // Check if the FlatMap body is conditional (filtering) + if (flatMapBody.Maybe() || flatMapBody.Maybe()) { + return node; + } + } + auto& tableDesc = kqpCtx.Tables->ExistingTable(kqpCtx.Cluster, GetReadTablePath(input, true)); // Only for datashard tables From 70896268face199aee3d3c45c5e247d4f6c9b067 Mon Sep 17 00:00:00 2001 From: azevaykin Date: Thu, 27 Nov 2025 15:27:10 +0300 Subject: [PATCH 5/6] Fix cost test --- ydb/core/kqp/ut/cost/kqp_cost_ut.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ydb/core/kqp/ut/cost/kqp_cost_ut.cpp b/ydb/core/kqp/ut/cost/kqp_cost_ut.cpp index cff110ef477d..7a91ec74b810 100644 --- a/ydb/core/kqp/ut/cost/kqp_cost_ut.cpp +++ b/ydb/core/kqp/ut/cost/kqp_cost_ut.cpp @@ -391,7 +391,7 @@ Y_UNIT_TEST_SUITE(KqpCost) { )", name.c_str())); auto result = session.ExecuteDataQuery(query, TTxControl::BeginTx().CommitTx(), GetDataQuerySettings()).ExtractValueSync(); - + UNIT_ASSERT_VALUES_EQUAL_C(result.GetStatus(), EStatus::SUCCESS, result.GetIssues().ToOneLineString()); Cerr << name << ":" << Endl; @@ -400,7 +400,7 @@ Y_UNIT_TEST_SUITE(KqpCost) { auto checkSelect = [&](auto query, TMap expectedReadsByTable) { auto result = session.ExecuteDataQuery(query, TTxControl::BeginTx().CommitTx(), GetDataQuerySettings()).ExtractValueSync(); - + UNIT_ASSERT_VALUES_EQUAL_C(result.GetStatus(), EStatus::SUCCESS, result.GetIssues().ToOneLineString()); auto stats = NYdb::TProtoAccessor::GetProto(*result.GetStats()); @@ -419,14 +419,14 @@ Y_UNIT_TEST_SUITE(KqpCost) { UNIT_ASSERT_VALUES_EQUAL_C(expectedReadsByTable, readsByTable, query); }; - { // 5x. SELECT VIEW PRIMARY KEY + { // 5x. SELECT VIEW PRIMARY KEY (with brute force vector search pushdown) // SELECT Key checkSelect(Q_(R"( SELECT Key FROM `/Root/Vectors` VIEW PRIMARY KEY ORDER BY Knn::CosineDistance(Embedding, "pQ\x02") LIMIT 10; )"), { - {"/Root/Vectors", 100} // full scan + {"/Root/Vectors", 10} // brute force vector search pushdown returns only LIMIT rows }); // SELECT Key, Value --- same stats @@ -435,7 +435,7 @@ Y_UNIT_TEST_SUITE(KqpCost) { ORDER BY Knn::CosineDistance(Embedding, "pQ\x02") LIMIT 10; )"), { - {"/Root/Vectors", 100} + {"/Root/Vectors", 10} }); } @@ -1120,7 +1120,7 @@ Y_UNIT_TEST_SUITE(KqpCost) { UNIT_ASSERT_VALUES_EQUAL(lhs.Reads, rhs.Reads); UNIT_ASSERT_VALUES_EQUAL(lhs.Deletes, rhs.Deletes); } - + Y_UNIT_TEST_QUAD(WriteRow, isSink, isOlap) { if (isOlap) { From 6979adb5dd8eb2f5eb45057f22c6755baa662fcf Mon Sep 17 00:00:00 2001 From: azevaykin Date: Fri, 28 Nov 2025 17:30:27 +0300 Subject: [PATCH 6/6] Refactor VectorTopK precompute handling to reuse common helpers for TDqPhyPrecompute and TKqpTxResultBinding collection --- ydb/core/kqp/opt/kqp_opt_build_txs.cpp | 88 +++++++++---------- .../kqp/query_compiler/kqp_query_compiler.cpp | 15 ---- 2 files changed, 40 insertions(+), 63 deletions(-) diff --git a/ydb/core/kqp/opt/kqp_opt_build_txs.cpp b/ydb/core/kqp/opt/kqp_opt_build_txs.cpp index f1855c38cb35..b269729366e6 100644 --- a/ydb/core/kqp/opt/kqp_opt_build_txs.cpp +++ b/ydb/core/kqp/opt/kqp_opt_build_txs.cpp @@ -378,20 +378,26 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase { return parameter; }; + // Helper to collect TKqpTxResultBinding nodes and replace them with parameters TNodeOnNodeOwnedMap sourceReplaceMap; + auto collectBindings = [&](const TExprNode::TPtr& root) { + VisitExpr(root, + [&](const TExprNode::TPtr& node) { + TExprBase expr(node); + if (auto binding = expr.Maybe()) { + sourceReplaceMap.emplace(node.Get(), makeParameterBinding(binding.Cast(), node->Pos()).Ptr()); + } + return true; + }); + }; + for (ui32 i = 0; i < stage.Inputs().Size(); ++i) { const auto& input = stage.Inputs().Item(i); const auto& inputArg = stage.Program().Args().Arg(i); + // Scan inputs that may contain TKqpTxResultBinding if (input.Maybe() || input.Maybe()) { - VisitExpr(input.Ptr(), - [&](const TExprNode::TPtr& node) { - TExprBase expr(node); - if (auto binding = expr.Maybe()) { - sourceReplaceMap.emplace(node.Get(), makeParameterBinding(binding.Cast(), node->Pos()).Ptr()); - } - return true; - }); + collectBindings(input.Ptr()); } auto maybeBinding = input.Maybe(); @@ -407,15 +413,8 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase { argsMap.emplace(inputArg.Raw(), makeParameterBinding(maybeBinding.Cast(), input.Pos()).Ptr()); } - // Also scan the program body for TKqpTxResultBinding (for VectorTopK precompute settings) - VisitExpr(stage.Program().Body().Ptr(), - [&](const TExprNode::TPtr& node) { - TExprBase expr(node); - if (auto binding = expr.Maybe()) { - sourceReplaceMap.emplace(node.Get(), makeParameterBinding(binding.Cast(), node->Pos()).Ptr()); - } - return true; - }); + // Scan program body for TKqpTxResultBinding (e.g. in TKqpReadTableRanges VectorTopK settings) + collectBindings(stage.Program().Body().Ptr()); auto inputs = Build(ctx, stage.Pos()) .Add(newInputs) @@ -473,45 +472,38 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase { TVector PrecomputeInputs(const TDqStage& stage) { TVector result; + + // Helper to collect precomputes from an expression tree + auto collectPrecomputes = [&result](const TExprNode::TPtr& root, bool checkConnections = false) { + VisitExpr(root, + [&](const TExprNode::TPtr& ptr) { + TExprBase node(ptr); + if (auto maybePrecompute = node.Maybe()) { + result.push_back(maybePrecompute.Cast()); + return false; + } + if (checkConnections) { + if (auto maybeConnection = node.Maybe()) { + YQL_ENSURE(false, "unexpected connection in source"); + } + } + return true; + }); + }; + + // Scan stage inputs for precomputes for (const auto& input : stage.Inputs()) { if (auto maybePrecompute = input.Maybe()) { result.push_back(maybePrecompute.Cast()); } else if (auto maybeSource = input.Maybe()) { - VisitExpr(maybeSource.Cast().Ptr(), - [&] (const TExprNode::TPtr& ptr) { - TExprBase node(ptr); - if (auto maybePrecompute = node.Maybe()) { - result.push_back(maybePrecompute.Cast()); - return false; - } - if (auto maybeConnection = node.Maybe()) { - YQL_ENSURE(false, "unexpected connection in source"); - } - return true; - }); + collectPrecomputes(maybeSource.Cast().Ptr(), /* checkConnections */ true); } else if (auto maybeStreamLookup = input.Maybe()) { - VisitExpr(maybeStreamLookup.Cast().Settings().Ptr(), - [&] (const TExprNode::TPtr& ptr) { - TExprBase node(ptr); - if (auto maybePrecompute = node.Maybe()) { - result.push_back(maybePrecompute.Cast()); - return false; - } - return true; - }); + collectPrecomputes(maybeStreamLookup.Cast().Settings().Ptr()); } } - // Also scan the program body for precomputes in read settings (for VectorTopK pushdown) - VisitExpr(stage.Program().Body().Ptr(), - [&] (const TExprNode::TPtr& ptr) { - TExprBase node(ptr); - if (auto maybePrecompute = node.Maybe()) { - result.push_back(maybePrecompute.Cast()); - return false; - } - return true; - }); + // Scan program body for precomputes (e.g. in TKqpReadTableRanges VectorTopK settings) + collectPrecomputes(stage.Program().Body().Ptr()); return result; } diff --git a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp index e80444bcd879..39e2a4557e22 100644 --- a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp +++ b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include #include @@ -65,13 +64,6 @@ void SetVectorTopKTarget(NKqpProto::TKqpPhyValue* targetProto, const TExprNode:: literal->MutableValue()->SetText(TString(expr.Cast().Literal().Value())); } else if (expr.Maybe()) { targetProto->MutableParamValue()->SetParamName(expr.Cast().Name().StringValue()); - } else if (auto maybeBinding = expr.Maybe()) { - // TKqpTxResultBinding should have been replaced with TCoParameter by kqp_opt_build_txs, - // but handle it defensively by constructing the expected parameter name - auto binding = maybeBinding.Cast(); - TString paramName = TStringBuilder() << ParamNamePrefix - << "tx_result_binding_" << binding.TxIndex().Value() << "_" << binding.ResultIndex().Value(); - targetProto->MutableParamValue()->SetParamName(paramName); } else { YQL_ENSURE(false, "Unexpected VectorTopKTarget callable " << expr.Ref().Content()); } @@ -87,13 +79,6 @@ void SetVectorTopKLimit(NKqpProto::TKqpPhyValue* limitProto, const TExprNode::TP literal->MutableValue()->SetUint64(FromString(expr.Cast().Literal().Value())); } else if (expr.Maybe()) { limitProto->MutableParamValue()->SetParamName(expr.Cast().Name().StringValue()); - } else if (auto maybeBinding = expr.Maybe()) { - // TKqpTxResultBinding should have been replaced with TCoParameter by kqp_opt_build_txs, - // but handle it defensively by constructing the expected parameter name - auto binding = maybeBinding.Cast(); - TString paramName = TStringBuilder() << ParamNamePrefix - << "tx_result_binding_" << binding.TxIndex().Value() << "_" << binding.ResultIndex().Value(); - limitProto->MutableParamValue()->SetParamName(paramName); } else { YQL_ENSURE(false, "Unexpected VectorTopKLimit callable " << expr.Ref().Content()); }