Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions ydb/core/base/kmeans_clusters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace {
return Ydb::Table::VectorIndexSettings::METRIC_UNSPECIFIED;
}
};

Ydb::Table::VectorIndexSettings_Metric ParseSimilarity(const TString& similarity_, TString& error) {
const TString similarity = to_lower(similarity_);
if (similarity == "cosine")
Expand All @@ -60,7 +60,7 @@ namespace {
return Ydb::Table::VectorIndexSettings::METRIC_UNSPECIFIED;
}
};

Ydb::Table::VectorIndexSettings_VectorType ParseVectorType(const TString& vectorType_, TString& error) {
const TString vectorType = to_lower(vectorType_);
if (vectorType == "float")
Expand Down Expand Up @@ -462,6 +462,58 @@ std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings&
}
}

// Builds clusters after deriving the vector type and dimension from a serialized target vector.
//
// The last byte of `targetVector` is read as the format tag. `settings` is taken by value, so
// the caller's settings are not modified; the local copy has its vector type and dimension
// overwritten and is then passed on to CreateClusters().
//
// Parameters:
//   settings     - base index settings; vector type and dimension are replaced here.
//   targetVector - serialized vector whose trailing byte is the format tag.
//   maxRounds    - forwarded unchanged to CreateClusters().
//   error        - receives a description when detection fails.
//
// Returns nullptr (with `error` set) when the vector is empty, too short for its format,
// has a payload that is not a whole number of elements, carries an invalid bit padding
// count, or has an unknown format tag. Otherwise returns the result of CreateClusters().
std::unique_ptr<IClusters> CreateClustersAutoDetect(Ydb::Table::VectorIndexSettings settings, const TStringBuf& targetVector, ui32 maxRounds, TString& error) {
    if (targetVector.empty()) {
        error = "Target vector is empty";
        return nullptr;
    }

    const size_t totalSize = targetVector.size();
    const ui8 formatByte = static_cast<ui8>(targetVector.back());

    // Shared handling for formats that store fixed-width elements followed by the header:
    // the payload must hold at least one element and a whole number of elements.
    const auto applyFixedWidth = [&](Ydb::Table::VectorIndexSettings_VectorType type, size_t elementSize, const char* typeName) {
        if (totalSize < HeaderLen + elementSize) {
            error = TStringBuilder() << "Target vector too short for " << typeName << " type";
            return false;
        }
        const size_t payloadSize = totalSize - HeaderLen;
        if (payloadSize % elementSize != 0) {
            // Previously the remainder was silently dropped by integer division, which hid
            // malformed input behind a smaller-than-expected dimension.
            error = TStringBuilder() << "Target vector payload size " << payloadSize
                << " is not a multiple of " << typeName << " element size " << elementSize;
            return false;
        }
        settings.set_vector_type(type);
        settings.set_vector_dimension(payloadSize / elementSize);
        return true;
    };

    switch (formatByte) {
        case EFormat::FloatVector:
            if (!applyFixedWidth(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_FLOAT, sizeof(float), "float")) {
                return nullptr;
            }
            break;
        case EFormat::Uint8Vector:
            if (!applyFixedWidth(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_UINT8, sizeof(ui8), "uint8")) {
                return nullptr;
            }
            break;
        case EFormat::Int8Vector:
            if (!applyFixedWidth(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_INT8, sizeof(i8), "int8")) {
                return nullptr;
            }
            break;
        case EFormat::BitVector: {
            if (totalSize < 2 + HeaderLen) {
                error = "Target vector too short for bit type";
                return nullptr;
            }
            // Layout used by the formula below: data bytes, then one padding-count byte,
            // then HeaderLen trailing header bytes. The padding count is the number of
            // unused bits in the data bytes, so dimension = dataBits - padding.
            const size_t dataBits = (totalSize - HeaderLen - 1) * 8;
            const ui8 padding = static_cast<ui8>(targetVector[totalSize - 2]);
            // With ceil(dim/8) data bytes the padding can only be 0..7; anything else is
            // malformed and could also make the unsigned subtraction wrap around.
            if (padding >= 8 || padding > dataBits) {
                error = TStringBuilder() << "Invalid padding in bit target vector: " << static_cast<int>(padding);
                return nullptr;
            }
            settings.set_vector_type(Ydb::Table::VectorIndexSettings::VECTOR_TYPE_BIT);
            settings.set_vector_dimension(dataBits - padding);
            break;
        }
        default:
            error = TStringBuilder() << "Unknown vector format byte: " << static_cast<int>(formatByte);
            return nullptr;
    }

    return CreateClusters(settings, maxRounds, error);
}

bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& error) {
error = "";

Expand All @@ -474,16 +526,16 @@ bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& e
return false;
}

if (!ValidateSettingInRange("levels",
settings.has_levels() ? std::optional<ui64>(settings.levels()) : std::nullopt,
if (!ValidateSettingInRange("levels",
settings.has_levels() ? std::optional<ui64>(settings.levels()) : std::nullopt,
MinLevels, MaxLevels,
error))
{
return false;
}

if (!ValidateSettingInRange("clusters",
settings.has_clusters() ? std::optional<ui64>(settings.clusters()) : std::nullopt,
if (!ValidateSettingInRange("clusters",
settings.has_clusters() ? std::optional<ui64>(settings.clusters()) : std::nullopt,
MinClusters, MaxClusters,
error))
{
Expand All @@ -500,7 +552,7 @@ bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& e
}

if (settings.settings().vector_dimension() * settings.clusters() > MaxVectorDimensionMultiplyClusters) {
error = TStringBuilder() << "Invalid vector_dimension*clusters: " << settings.settings().vector_dimension() << "*" << settings.clusters()
error = TStringBuilder() << "Invalid vector_dimension*clusters: " << settings.settings().vector_dimension() << "*" << settings.clusters()
<< " should be less than " << MaxVectorDimensionMultiplyClusters;
return false;
}
Expand Down Expand Up @@ -528,8 +580,8 @@ bool ValidateSettings(const Ydb::Table::VectorIndexSettings& settings, TString&
return false;
}

if (!ValidateSettingInRange("vector_dimension",
settings.has_vector_dimension() ? std::optional<ui64>(settings.vector_dimension()) : std::nullopt,
if (!ValidateSettingInRange("vector_dimension",
settings.has_vector_dimension() ? std::optional<ui64>(settings.vector_dimension()) : std::nullopt,
MinVectorDimension, MaxVectorDimension,
error))
{
Expand Down
3 changes: 3 additions & 0 deletions ydb/core/base/kmeans_clusters.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class IClusters {

std::unique_ptr<IClusters> CreateClusters(const Ydb::Table::VectorIndexSettings& settings, ui32 maxRounds, TString& error);

// Creates clusters with the vector type and dimension auto-detected from the serialized
// target vector (its trailing byte is read as the format tag).
// `settings` is passed by value: the implementation overwrites vector type and dimension
// in its own copy unconditionally, so the caller's object is untouched.
// NOTE(review): presumably meant for callers whose settings have dimension=0 — the function
// itself does not check that; confirm against call sites.
// Returns nullptr and fills `error` if detection or cluster creation fails.
std::unique_ptr<IClusters> CreateClustersAutoDetect(Ydb::Table::VectorIndexSettings settings, const TStringBuf& targetVector, ui32 maxRounds, TString& error);

bool ValidateSettings(const Ydb::Table::VectorIndexSettings& settings, TString& error);
bool ValidateSettings(const Ydb::Table::KMeansTreeSettings& settings, TString& error);
bool FillSetting(Ydb::Table::KMeansTreeSettings& settings, const TString& name, const TString& value, TString& error);
Expand Down
45 changes: 44 additions & 1 deletion ydb/core/kqp/common/kqp_yql.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,18 @@ TKqpReadTableSettings ParseInternal(const TCoNameValueTupleList& node) {
for(const auto& kv: lv) {
settings.IndexSelectionInfo.emplace(kv.Name().Value(), kv.Value().Cast<TCoAtom>().Value());
}

} else if (name == TKqpReadTableSettings::VectorTopKColumnSettingName) {
YQL_ENSURE(tuple.Value().Maybe<TCoAtom>());
settings.VectorTopKColumn = tuple.Value().Cast<TCoAtom>().Value();
} else if (name == TKqpReadTableSettings::VectorTopKMetricSettingName) {
YQL_ENSURE(tuple.Value().Maybe<TCoAtom>());
settings.VectorTopKMetric = tuple.Value().Cast<TCoAtom>().Value();
} else if (name == TKqpReadTableSettings::VectorTopKTargetSettingName) {
YQL_ENSURE(tuple.Value().IsValid());
settings.VectorTopKTarget = tuple.Value().Cast().Ptr();
} else if (name == TKqpReadTableSettings::VectorTopKLimitSettingName) {
YQL_ENSURE(tuple.Value().IsValid());
settings.VectorTopKLimit = tuple.Value().Cast().Ptr();
} else {
YQL_ENSURE(false, "Unknown KqpReadTable setting name '" << name << "'");
}
Expand Down Expand Up @@ -317,6 +328,38 @@ NNodes::TCoNameValueTupleList TKqpReadTableSettings::BuildNode(TExprContext& ctx
.Done());
}

if (VectorTopKColumn) {
settings.emplace_back(
Build<TCoNameValueTuple>(ctx, pos)
.Name().Build(VectorTopKColumnSettingName)
.Value<TCoAtom>().Build(VectorTopKColumn)
.Done());
}

if (VectorTopKMetric) {
settings.emplace_back(
Build<TCoNameValueTuple>(ctx, pos)
.Name().Build(VectorTopKMetricSettingName)
.Value<TCoAtom>().Build(VectorTopKMetric)
.Done());
}

if (VectorTopKTarget) {
settings.emplace_back(
Build<TCoNameValueTuple>(ctx, pos)
.Name().Build(VectorTopKTargetSettingName)
.Value(VectorTopKTarget)
.Done());
}

if (VectorTopKLimit) {
settings.emplace_back(
Build<TCoNameValueTuple>(ctx, pos)
.Name().Build(VectorTopKLimitSettingName)
.Value(VectorTopKLimit)
.Done());
}

return Build<TCoNameValueTupleList>(ctx, pos)
.Add(settings)
.Done();
Expand Down
10 changes: 10 additions & 0 deletions ydb/core/kqp/common/kqp_yql.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ struct TKqpReadTableSettings: public TSortingOperator<ERequestSorting::NONE> {
static constexpr TStringBuf TabletIdName = "TabletId";
static constexpr TStringBuf PointPrefixLenSettingName = "PointPrefixLen";
static constexpr TStringBuf IndexSelectionDebugInfoSettingName = "IndexSelectionDebugInfo";
static constexpr TStringBuf VectorTopKColumnSettingName = "VectorTopKColumn";
static constexpr TStringBuf VectorTopKMetricSettingName = "VectorTopKMetric";
static constexpr TStringBuf VectorTopKTargetSettingName = "VectorTopKTarget";
static constexpr TStringBuf VectorTopKLimitSettingName = "VectorTopKLimit";

TVector<TString> SkipNullKeys;
TExprNode::TPtr ItemsLimit;
Expand All @@ -139,6 +143,12 @@ struct TKqpReadTableSettings: public TSortingOperator<ERequestSorting::NONE> {
ui64 PointPrefixLen = 0;
THashMap<TString, TString> IndexSelectionInfo;

// Vector top-K pushdown settings for brute force vector search
TString VectorTopKColumn;
TString VectorTopKMetric;
TExprNode::TPtr VectorTopKTarget;
TExprNode::TPtr VectorTopKLimit;

void AddSkipNullKey(const TString& key);
void SetItemsLimit(const TExprNode::TPtr& expr) { ItemsLimit = expr; }

Expand Down
10 changes: 10 additions & 0 deletions ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2670,6 +2670,16 @@ TMaybe<size_t> TKqpTasksGraph::BuildScanTasksFromSource(TStageInfo& stageInfo, b
settings->SetItemsLimit(itemsLimit);
}

if (source.HasVectorTopK()) {
const auto& in = source.GetVectorTopK();
auto& out = *settings->MutableVectorTopK();
out.SetColumn(in.GetColumn());
*out.MutableSettings() = in.GetSettings();
auto target = ExtractPhyValue(stageInfo, in.GetTargetVector(), TxAlloc->HolderFactory, TxAlloc->TypeEnv, NUdf::TUnboxedValuePod());
out.SetTargetVector(TString(target.AsStringRef()));
out.SetLimit((ui32)ExtractPhyValue(stageInfo, in.GetLimit(), TxAlloc->HolderFactory, TxAlloc->TypeEnv, NUdf::TUnboxedValuePod()).Get<ui64>());
}

auto& lockTxId = GetMeta().LockTxId;
if (lockTxId) {
settings->SetLockTxId(*lockTxId);
Expand Down
74 changes: 44 additions & 30 deletions ydb/core/kqp/opt/kqp_opt_build_txs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,20 +378,26 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase {
return parameter;
};

// Helper to collect TKqpTxResultBinding nodes and replace them with parameters
TNodeOnNodeOwnedMap sourceReplaceMap;
auto collectBindings = [&](const TExprNode::TPtr& root) {
VisitExpr(root,
[&](const TExprNode::TPtr& node) {
TExprBase expr(node);
if (auto binding = expr.Maybe<TKqpTxResultBinding>()) {
sourceReplaceMap.emplace(node.Get(), makeParameterBinding(binding.Cast(), node->Pos()).Ptr());
}
return true;
});
};

for (ui32 i = 0; i < stage.Inputs().Size(); ++i) {
const auto& input = stage.Inputs().Item(i);
const auto& inputArg = stage.Program().Args().Arg(i);

// Scan inputs that may contain TKqpTxResultBinding
if (input.Maybe<TDqSource>() || input.Maybe<TKqpCnStreamLookup>()) {
VisitExpr(input.Ptr(),
[&](const TExprNode::TPtr& node) {
TExprBase expr(node);
if (auto binding = expr.Maybe<TKqpTxResultBinding>()) {
sourceReplaceMap.emplace(node.Get(), makeParameterBinding(binding.Cast(), node->Pos()).Ptr());
}
return true;
});
collectBindings(input.Ptr());
}

auto maybeBinding = input.Maybe<TKqpTxResultBinding>();
Expand All @@ -407,6 +413,9 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase {
argsMap.emplace(inputArg.Raw(), makeParameterBinding(maybeBinding.Cast(), input.Pos()).Ptr());
}

// Scan program body for TKqpTxResultBinding (e.g. in TKqpReadTableRanges VectorTopK settings)
collectBindings(stage.Program().Body().Ptr());

auto inputs = Build<TExprList>(ctx, stage.Pos())
.Add(newInputs)
.Done();
Expand All @@ -415,7 +424,7 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase {
.Inputs(ctx.ReplaceNodes(ctx.ReplaceNodes(inputs.Ptr(), stagesMap), sourceReplaceMap))
.Program()
.Args(newArgs)
.Body(ctx.ReplaceNodes(stage.Program().Body().Ptr(), argsMap))
.Body(ctx.ReplaceNodes(ctx.ReplaceNodes(stage.Program().Body().Ptr(), argsMap), sourceReplaceMap))
.Build()
.Settings(stage.Settings())
.Outputs(stage.Outputs())
Expand Down Expand Up @@ -463,34 +472,39 @@ class TKqpBuildTxTransformer : public TSyncTransformerBase {

TVector<TDqPhyPrecompute> PrecomputeInputs(const TDqStage& stage) {
TVector<TDqPhyPrecompute> result;

// Helper to collect precomputes from an expression tree
auto collectPrecomputes = [&result](const TExprNode::TPtr& root, bool checkConnections = false) {
VisitExpr(root,
[&](const TExprNode::TPtr& ptr) {
TExprBase node(ptr);
if (auto maybePrecompute = node.Maybe<TDqPhyPrecompute>()) {
result.push_back(maybePrecompute.Cast());
return false;
}
if (checkConnections) {
if (auto maybeConnection = node.Maybe<TDqConnection>()) {
YQL_ENSURE(false, "unexpected connection in source");
}
}
return true;
});
};

// Scan stage inputs for precomputes
for (const auto& input : stage.Inputs()) {
if (auto maybePrecompute = input.Maybe<TDqPhyPrecompute>()) {
result.push_back(maybePrecompute.Cast());
} else if (auto maybeSource = input.Maybe<TDqSource>()) {
VisitExpr(maybeSource.Cast().Ptr(),
[&] (const TExprNode::TPtr& ptr) {
TExprBase node(ptr);
if (auto maybePrecompute = node.Maybe<TDqPhyPrecompute>()) {
result.push_back(maybePrecompute.Cast());
return false;
}
if (auto maybeConnection = node.Maybe<TDqConnection>()) {
YQL_ENSURE(false, "unexpected connection in source");
}
return true;
});
collectPrecomputes(maybeSource.Cast().Ptr(), /* checkConnections */ true);
} else if (auto maybeStreamLookup = input.Maybe<TKqpCnStreamLookup>()) {
VisitExpr(maybeStreamLookup.Cast().Settings().Ptr(),
[&] (const TExprNode::TPtr& ptr) {
TExprBase node(ptr);
if (auto maybePrecompute = node.Maybe<TDqPhyPrecompute>()) {
result.push_back(maybePrecompute.Cast());
return false;
}
return true;
});
collectPrecomputes(maybeStreamLookup.Cast().Settings().Ptr());
}
}

// Scan program body for precomputes (e.g. in TKqpReadTableRanges VectorTopK settings)
collectPrecomputes(stage.Program().Body().Ptr());

return result;
}

Expand Down
9 changes: 8 additions & 1 deletion ydb/core/kqp/opt/physical/kqp_opt_phy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase {
AddHandler(0, IsSort, HNDL(RemoveRedundantSortOverReadTable));
AddHandler(0, &TCoTake::Match, HNDL(ApplyLimitToReadTable));
AddHandler(0, &TCoTopSort::Match, HNDL(ApplyLimitToOlapReadTable));
AddHandler(0, &TCoTopSort::Match, HNDL(ApplyVectorTopKToReadTable));
AddHandler(0, &TCoFlatMap::Match, HNDL(PushOlapFilter));
AddHandler(0, &TCoFlatMap::Match, HNDL(PushOlapProjections));
AddHandler(0, &TCoAggregateCombine::Match, HNDL(PushAggregateCombineToStage));
Expand Down Expand Up @@ -206,7 +207,7 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase {
else {
TExprBase output = KqpBuildStreamIdxLookupJoinStagesKeepSorted(node, ctx, TypesCtx, true);
DumpAppliedRule("BuildStreamIdxLookupJoinStagesKeepSorted", node.Ptr(), output.Ptr(), ctx);
return output;
return output;
}
}

Expand Down Expand Up @@ -260,6 +261,12 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase {
return output;
}

// Optimizer handler: delegates the rewrite of `node` to KqpApplyVectorTopKToReadTable,
// reports the before/after pair through DumpAppliedRule and returns the rewritten node.
TMaybeNode<TExprBase> ApplyVectorTopKToReadTable(TExprBase node, TExprContext& ctx) {
    TExprBase rewritten = KqpApplyVectorTopKToReadTable(node, ctx, KqpCtx);

    // Log under the rule name so applied optimizations can be traced.
    DumpAppliedRule("ApplyVectorTopKToReadTable", node.Ptr(), rewritten.Ptr(), ctx);

    return rewritten;
}

TMaybeNode<TExprBase> PushOlapFilter(TExprBase node, TExprContext& ctx) {
TExprBase output = KqpPushOlapFilter(node, ctx, KqpCtx, TypesCtx, *TypeAnnTransformer.Get());
DumpAppliedRule("PushOlapFilter", node.Ptr(), output.Ptr(), ctx);
Expand Down
2 changes: 2 additions & 0 deletions ydb/core/kqp/opt/physical/kqp_opt_phy_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ NYql::NNodes::TMaybeNode<NYql::NNodes::TDqPhyPrecompute> BuildLookupKeysPrecompu
NYql::NNodes::TCoAtomList BuildColumnsList(const THashSet<TStringBuf>& columns, NYql::TPositionHandle pos,
NYql::TExprContext& ctx);

NYql::NNodes::TExprBase KqpPrecomputeParameter(NYql::NNodes::TExprBase param, NYql::TExprContext& ctx);

NYql::NNodes::TCoAtomList BuildColumnsList(const TVector<TStringBuf>& columns, NYql::TPositionHandle pos,
NYql::TExprContext& ctx);

Expand Down
Loading
Loading