diff --git a/ydb/core/kqp/common/kqp_tx.cpp b/ydb/core/kqp/common/kqp_tx.cpp index 5fb07bf759eb..be0671f3f332 100644 --- a/ydb/core/kqp/common/kqp_tx.cpp +++ b/ydb/core/kqp/common/kqp_tx.cpp @@ -392,6 +392,7 @@ bool HasUncommittedChangesRead(THashSet& modifiedTables, cons case NKqpProto::TKqpPhyConnection::kResult: case NKqpProto::TKqpPhyConnection::kValue: case NKqpProto::TKqpPhyConnection::kMerge: + case NKqpProto::TKqpPhyConnection::kDqSourceStreamLookup: case NKqpProto::TKqpPhyConnection::TYPE_NOT_SET: break; } diff --git a/ydb/core/kqp/compute_actor/kqp_compute_actor.cpp b/ydb/core/kqp/compute_actor/kqp_compute_actor.cpp index 361f8094d979..08611b7b31fb 100644 --- a/ydb/core/kqp/compute_actor/kqp_compute_actor.cpp +++ b/ydb/core/kqp/compute_actor/kqp_compute_actor.cpp @@ -6,19 +6,20 @@ #include #include #include -#include #include #include #include #include -#include +#include #include +#include +#include +#include #include -#include +#include #include #include -#include -#include +#include namespace NKikimr { namespace NMiniKQL { @@ -90,6 +91,7 @@ NYql::NDq::IDqAsyncIoFactory::TPtr CreateKqpAsyncIoFactory( RegisterKqpWriteActor(*factory, counters); RegisterSequencerActorFactory(*factory, counters); RegisterKqpVectorResolveActor(*factory, counters); + NYql::NDq::RegisterDqInputTransformLookupActorFactory(*factory); if (federatedQuerySetup) { auto s3HttpRetryPolicy = NYql::GetHTTPDefaultRetryPolicy(NYql::THttpRetryPolicyOptions{.RetriedCurlCodes = NYql::FqRetriedCurlCodes()}); diff --git a/ydb/core/kqp/compute_actor/ya.make b/ydb/core/kqp/compute_actor/ya.make index 95544a251b24..bf56cc428a19 100644 --- a/ydb/core/kqp/compute_actor/ya.make +++ b/ydb/core/kqp/compute_actor/ya.make @@ -25,12 +25,13 @@ PEERDIR( ydb/library/formats/arrow/protos ydb/library/formats/arrow/common ydb/library/yql/dq/actors/compute + ydb/library/yql/dq/actors/input_transforms + ydb/library/yql/dq/comp_nodes ydb/library/yql/providers/generic/actors ydb/library/yql/providers/pq/async_io 
ydb/library/yql/providers/s3/actors_factory ydb/library/yql/providers/solomon/actors yql/essentials/public/issue - ydb/library/yql/dq/comp_nodes ) GENERATE_ENUM_SERIALIZATION(kqp_compute_state.h) diff --git a/ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp b/ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp index 8d69bec1df6c..6e71ec85987d 100644 --- a/ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp +++ b/ydb/core/kqp/executer_actor/kqp_tasks_graph.cpp @@ -518,6 +518,47 @@ void TKqpTasksGraph::BuildVectorResolveChannels(const TStageInfo& stageInfo, ui3 inputStageInfo, outputIndex, enableSpilling, logFunc); } +void TKqpTasksGraph::BuildDqSourceStreamLookupChannels(const TStageInfo& stageInfo, ui32 inputIndex, const TStageInfo& inputStageInfo, + ui32 outputIndex, const NKqpProto::TKqpPhyCnDqSourceStreamLookup& dqSourceStreamLookup, const TChannelLogFunc& logFunc) { + YQL_ENSURE(stageInfo.Tasks.size() == 1); + + auto* settings = GetMeta().Allocate(); + settings->SetLeftLabel(dqSourceStreamLookup.GetLeftLabel()); + settings->SetRightLabel(dqSourceStreamLookup.GetRightLabel()); + settings->SetJoinType(dqSourceStreamLookup.GetJoinType()); + settings->SetNarrowInputRowType(dqSourceStreamLookup.GetConnectionInputRowType()); + settings->SetNarrowOutputRowType(dqSourceStreamLookup.GetConnectionOutputRowType()); + settings->SetCacheLimit(dqSourceStreamLookup.GetCacheLimit()); + settings->SetCacheTtlSeconds(dqSourceStreamLookup.GetCacheTtlSeconds()); + settings->SetMaxDelayedRows(dqSourceStreamLookup.GetMaxDelayedRows()); + settings->SetIsMultiget(dqSourceStreamLookup.GetIsMultiGet()); + + const auto& leftJointKeys = dqSourceStreamLookup.GetLeftJoinKeyNames(); + settings->MutableLeftJoinKeyNames()->Assign(leftJointKeys.begin(), leftJointKeys.end()); + + const auto& rightJointKeys = dqSourceStreamLookup.GetRightJoinKeyNames(); + settings->MutableRightJoinKeyNames()->Assign(rightJointKeys.begin(), rightJointKeys.end()); + + auto& streamLookupSource = 
*settings->MutableRightSource(); + streamLookupSource.SetSerializedRowType(dqSourceStreamLookup.GetLookupRowType()); + const auto& compiledSource = dqSourceStreamLookup.GetLookupSource(); + streamLookupSource.SetProviderName(compiledSource.GetType()); + *streamLookupSource.MutableLookupSource() = compiledSource.GetSettings(); + + TTransform dqSourceStreamLookupTransform = { + .Type = "StreamLookupInputTransform", + .InputType = dqSourceStreamLookup.GetInputStageRowType(), + .OutputType = dqSourceStreamLookup.GetOutputStageRowType(), + }; + YQL_ENSURE(dqSourceStreamLookupTransform.Settings.PackFrom(*settings)); + + for (const auto taskId : stageInfo.Tasks) { + GetTask(taskId).Inputs[inputIndex].Transform = dqSourceStreamLookupTransform; + } + + BuildUnionAllChannels(*this, stageInfo, inputIndex, inputStageInfo, outputIndex, /* enableSpilling */ false, logFunc); +} + void TKqpTasksGraph::BuildKqpStageChannels(TStageInfo& stageInfo, ui64 txId, bool enableSpilling, bool enableShuffleElimination) { auto& stage = stageInfo.Meta.GetStage(stageInfo.Id); @@ -710,6 +751,12 @@ void TKqpTasksGraph::BuildKqpStageChannels(TStageInfo& stageInfo, ui64 txId, boo break; } + case NKqpProto::TKqpPhyConnection::kDqSourceStreamLookup: { + BuildDqSourceStreamLookupChannels(stageInfo, inputIdx, inputStageInfo, outputIdx, + input.GetDqSourceStreamLookup(), log); + break; + } + default: YQL_ENSURE(false, "Unexpected stage input type: " << (ui32)input.GetTypeCase()); } @@ -1370,6 +1417,8 @@ void TKqpTasksGraph::FillInputDesc(NYql::NDqProto::TTaskInput& inputDesc, const } transformProto->MutableSettings()->PackFrom(*input.Meta.VectorResolveSettings); + } else { + *transformProto->MutableSettings() = input.Transform->Settings; } } } @@ -1725,6 +1774,7 @@ bool TKqpTasksGraph::BuildComputeTasks(TStageInfo& stageInfo, const ui32 nodesCo case NKqpProto::TKqpPhyConnection::kMap: case NKqpProto::TKqpPhyConnection::kParallelUnionAll: case NKqpProto::TKqpPhyConnection::kVectorResolve: + case 
NKqpProto::TKqpPhyConnection::kDqSourceStreamLookup: break; default: YQL_ENSURE(false, "Unexpected connection type: " << (ui32)input.GetTypeCase() << Endl); diff --git a/ydb/core/kqp/executer_actor/kqp_tasks_graph.h b/ydb/core/kqp/executer_actor/kqp_tasks_graph.h index b559526ec559..7ae57f5df01f 100644 --- a/ydb/core/kqp/executer_actor/kqp_tasks_graph.h +++ b/ydb/core/kqp/executer_actor/kqp_tasks_graph.h @@ -422,6 +422,8 @@ class TKqpTasksGraph : public NYql::NDq::TDqTasksGraphPqGateway) { InitPqProvider(); } + TypesCtx->StreamLookupJoin = true; } InitPgProvider(); diff --git a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp index cb44be02dfab..22fb922a0dcd 100644 --- a/ydb/core/kqp/opt/logical/kqp_opt_log.cpp +++ b/ydb/core/kqp/opt/logical/kqp_opt_log.cpp @@ -41,6 +41,7 @@ class TKqpLogicalOptTransformer : public TOptimizeTransformerBase { AddHandler(0, &TCoTake::Match, HNDL(RewriteTakeSortToTopSort)); AddHandler(0, &TCoFlatMap::Match, HNDL(RewriteSqlInToEquiJoin)); AddHandler(0, &TCoFlatMap::Match, HNDL(RewriteSqlInCompactToJoin)); + AddHandler(0, &TCoEquiJoin::Match, HNDL(RewriteStreamEquiJoinWithLookup)); AddHandler(0, &TCoEquiJoin::Match, HNDL(OptimizeEquiJoinWithCosts)); AddHandler(0, &TCoEquiJoin::Match, HNDL(RewriteEquiJoin)); AddHandler(0, &TDqJoin::Match, HNDL(JoinToIndexLookup)); @@ -167,6 +168,12 @@ class TKqpLogicalOptTransformer : public TOptimizeTransformerBase { return output; } + TMaybeNode RewriteStreamEquiJoinWithLookup(TExprBase node, TExprContext& ctx) { + TExprBase output = DqRewriteStreamEquiJoinWithLookup(node, ctx, TypesCtx); + DumpAppliedRule("KqpRewriteStreamEquiJoinWithLookup", node.Ptr(), output.Ptr(), ctx); + return output; + } + TMaybeNode OptimizeEquiJoinWithCosts(TExprBase node, TExprContext& ctx) { auto maxDPhypDPTableSize = Config->MaxDPHypDPTableSize.Get().GetOrElse(TDqSettings::TDefault::MaxDPHypDPTableSize); auto optLevel = 
Config->CostBasedOptimizationLevel.Get().GetOrElse(Config->DefaultCostBasedOptimizationLevel); diff --git a/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp b/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp index a130f5eb9ef6..32b99a6f1a36 100644 --- a/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp +++ b/ydb/core/kqp/opt/physical/kqp_opt_phy.cpp @@ -71,6 +71,7 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase { AddHandler(0, &TCoExtendBase::Match, HNDL(BuildExtendStage)); AddHandler(0, &TDqJoin::Match, HNDL(RewriteRightJoinToLeft)); AddHandler(0, &TDqJoin::Match, HNDL(RewriteLeftPureJoin)); + AddHandler(0, &TDqJoin::Match, HNDL(RewriteStreamLookupJoin)); AddHandler(0, &TDqJoin::Match, HNDL(BuildJoin)); AddHandler(0, &TDqPrecompute::Match, HNDL(BuildPrecompute)); AddHandler(0, &TCoLMap::Match, HNDL(PushLMapToStage)); @@ -507,6 +508,14 @@ class TKqpPhysicalOptTransformer : public TOptimizeTransformerBase { return output; } + TMaybeNode RewriteStreamLookupJoin(TExprBase node, TExprContext& ctx) { + TMaybeNode output = DqRewriteStreamLookupJoin(node, ctx); + if (output) { + DumpAppliedRule("RewriteStreamLookupJoin", node.Ptr(), output.Cast().Ptr(), ctx); + } + return output; + } + template TMaybeNode BuildJoin(TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TGetParents& getParents) diff --git a/ydb/core/kqp/provider/yql_kikimr_datasource.cpp b/ydb/core/kqp/provider/yql_kikimr_datasource.cpp index 87e92b3e3851..f24e34923a9f 100644 --- a/ydb/core/kqp/provider/yql_kikimr_datasource.cpp +++ b/ydb/core/kqp/provider/yql_kikimr_datasource.cpp @@ -666,7 +666,8 @@ class TKikimrDataSource : public TDataProviderBase { node.IsCallable(TDqReadWrap::CallableName()) || node.IsCallable(TDqReadWideWrap::CallableName()) || node.IsCallable(TDqReadBlockWideWrap::CallableName()) || - node.IsCallable(TDqSource::CallableName()) + node.IsCallable(TDqSource::CallableName()) || + node.IsCallable(TDqLookupSourceWrap::CallableName()) ) ) { diff --git 
a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp index 76b7c3a63f43..84fc1a5d2c3d 100644 --- a/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp +++ b/ydb/core/kqp/query_compiler/kqp_query_compiler.cpp @@ -1,5 +1,6 @@ #include "kqp_query_compiler.h" +#include #include #include #include @@ -8,24 +9,23 @@ #include #include #include -#include - -#include #include +#include #include - -#include #include #include #include -#include -#include -#include -#include +#include #include #include + +#include #include #include +#include +#include +#include +#include namespace NKikimr { @@ -589,6 +589,14 @@ TIssues ApplyOverridePlannerSettings(const TString& overridePlannerJson, NKqpPro return issues; } +TStringBuf RemoveJoinAliases(TStringBuf keyName) { + if (const auto idx = keyName.find_last_of('.'); idx != TString::npos) { + return keyName.substr(idx + 1); + } + + return keyName; +} + class TKqpQueryCompiler : public IKqpQueryCompiler { public: TKqpQueryCompiler(const TString& cluster, const TIntrusivePtr tablesData, @@ -795,7 +803,7 @@ class TKqpQueryCompiler : public IKqpQueryCompiler { auto connection = input.Cast(); auto& protoInput = *stageProto.AddInputs(); - FillConnection(connection, stagesMap, protoInput, ctx, tablesMap, physicalStageByID); + FillConnection(connection, stagesMap, protoInput, ctx, tablesMap, physicalStageByID, &stage, inputIndex); protoInput.SetInputIndex(inputIndex); } } @@ -1017,7 +1025,7 @@ class TKqpQueryCompiler : public IKqpQueryCompiler { auto& resultProto = *txProto.AddResults(); auto& connectionProto = *resultProto.MutableConnection(); - FillConnection(connection, stagesMap, connectionProto, ctx, tablesMap, physicalStageByID); + FillConnection(connection, stagesMap, connectionProto, ctx, tablesMap, physicalStageByID, nullptr, 0); const TTypeAnnotationNode* itemType = nullptr; switch (connectionProto.GetTypeCase()) { @@ -1452,7 +1460,9 @@ class TKqpQueryCompiler : public 
IKqpQueryCompiler { NKqpProto::TKqpPhyConnection& connectionProto, TExprContext& ctx, THashMap>& tablesMap, - THashMap& physicalStageByID + THashMap& physicalStageByID, + const TDqPhyStage* stage, + ui32 inputIndex ) { auto inputStageIndex = stagesMap.FindPtr(connection.Output().Stage().Ref().UniqueId()); YQL_ENSURE(inputStageIndex, "stage #" << connection.Output().Stage().Ref().UniqueId() << " not found in stages map: " @@ -1819,6 +1829,59 @@ class TKqpQueryCompiler : public IKqpQueryCompiler { return; } + if (auto maybeDqSourceStreamLookup = connection.Maybe()) { + const auto streamLookup = maybeDqSourceStreamLookup.Cast(); + const auto lookupSourceWrap = streamLookup.RightInput().Cast(); + + const TStringBuf dataSourceCategory = lookupSourceWrap.DataSource().Category(); + const auto provider = TypesCtx.DataSourceMap.find(dataSourceCategory); + YQL_ENSURE(provider != TypesCtx.DataSourceMap.end(), "Unsupported data source category: \"" << dataSourceCategory << "\""); + NYql::IDqIntegration* dqIntegration = provider->second->GetDqIntegration(); + YQL_ENSURE(dqIntegration, "Unsupported dq source for provider: \"" << dataSourceCategory << "\""); + + auto& dqSourceLookupCn = *connectionProto.MutableDqSourceStreamLookup(); + auto& lookupSource = *dqSourceLookupCn.MutableLookupSource(); + auto& lookupSourceSettings = *lookupSource.MutableSettings(); + auto& lookupSourceType = *lookupSource.MutableType(); + dqIntegration->FillLookupSourceSettings(lookupSourceWrap.Ref(), lookupSourceSettings, lookupSourceType); + YQL_ENSURE(!lookupSourceSettings.type_url().empty(), "Data source provider \"" << dataSourceCategory << "\" did't fill dq source settings for its dq source node"); + YQL_ENSURE(lookupSourceType, "Data source provider \"" << dataSourceCategory << "\" did't fill dq source settings type for its dq source node"); + + const auto& streamLookupOutput = streamLookup.Output(); + const auto connectionInputRowType = GetSeqItemType(streamLookupOutput.Ref().GetTypeAnn()); + 
YQL_ENSURE(connectionInputRowType->GetKind() == ETypeAnnotationKind::Struct); + const auto connectionOutputRowType = GetSeqItemType(streamLookup.Ref().GetTypeAnn()); + YQL_ENSURE(connectionOutputRowType->GetKind() == ETypeAnnotationKind::Struct); + YQL_ENSURE(stage); + dqSourceLookupCn.SetConnectionInputRowType(NYql::NCommon::GetSerializedTypeAnnotation(connectionInputRowType)); + dqSourceLookupCn.SetConnectionOutputRowType(NYql::NCommon::GetSerializedTypeAnnotation(connectionOutputRowType)); + dqSourceLookupCn.SetLookupRowType(NYql::NCommon::GetSerializedTypeAnnotation(lookupSourceWrap.RowType().Ref().GetTypeAnn())); + dqSourceLookupCn.SetInputStageRowType(NYql::NCommon::GetSerializedTypeAnnotation(GetSeqItemType(streamLookupOutput.Stage().Program().Ref().GetTypeAnn()))); + dqSourceLookupCn.SetOutputStageRowType(NYql::NCommon::GetSerializedTypeAnnotation(GetSeqItemType(stage->Program().Args().Arg(inputIndex).Ref().GetTypeAnn()))); + + const TString leftLabel(streamLookup.LeftLabel()); + dqSourceLookupCn.SetLeftLabel(leftLabel); + dqSourceLookupCn.SetRightLabel(streamLookup.RightLabel().StringValue()); + dqSourceLookupCn.SetJoinType(streamLookup.JoinType().StringValue()); + dqSourceLookupCn.SetCacheLimit(FromString(streamLookup.MaxCachedRows())); + dqSourceLookupCn.SetCacheTtlSeconds(FromString(streamLookup.TTL())); + dqSourceLookupCn.SetMaxDelayedRows(FromString(streamLookup.MaxDelayedRows())); + + if (const auto maybeMultiget = streamLookup.IsMultiget()) { + dqSourceLookupCn.SetIsMultiGet(FromString(maybeMultiget.Cast())); + } + + for (const auto& key : streamLookup.LeftJoinKeyNames()) { + *dqSourceLookupCn.AddLeftJoinKeyNames() = leftLabel ? 
RemoveJoinAliases(key) : key; + } + + for (const auto& key : streamLookup.RightJoinKeyNames()) { + *dqSourceLookupCn.AddRightJoinKeyNames() = RemoveJoinAliases(key); + } + + return; + } + YQL_ENSURE(false, "Unexpected connection type: " << connection.CallableName()); } diff --git a/ydb/core/kqp/ut/federated_query/datastreams/datastreams_ut.cpp b/ydb/core/kqp/ut/federated_query/datastreams/datastreams_ut.cpp index e607a32f960f..d39d65700557 100644 --- a/ydb/core/kqp/ut/federated_query/datastreams/datastreams_ut.cpp +++ b/ydb/core/kqp/ut/federated_query/datastreams/datastreams_ut.cpp @@ -34,6 +34,27 @@ struct TScriptQuerySettings { TDuration Timeout = TDuration::Seconds(30); }; +struct TColumn { + TString Name; + Ydb::Type::PrimitiveTypeId Type; +}; + +struct TMockConnectorTableDescriptionSettings { + TString TableName; + std::vector Columns; + ui64 DescribeCount = 1; + ui64 ListSplitsCount = 1; + bool ValidateListSplitsArgs = true; +}; + +struct TMockConnectorReadSplitsSettings { + TString TableName; + std::vector Columns; + ui64 NumberReadSplits; + bool ValidateReadSplitsArgs = true; + std::function()> ResultFactory; +}; + class TStreamingTestFixture : public NUnitTest::TBaseFixture { using TBase = NUnitTest::TBaseFixture; @@ -348,6 +369,24 @@ class TStreamingTestFixture : public NUnitTest::TBaseFixture { )); } + void CreateYdbSource(const TString& ydbSourceName) { + ExecQuery(fmt::format(R"( + UPSERT OBJECT ydb_source_secret (TYPE SECRET) WITH (value = "{token}"); + CREATE EXTERNAL DATA SOURCE `{ydb_source}` WITH ( + SOURCE_TYPE = "Ydb", + LOCATION = "{ydb_location}", + DATABASE_NAME = "{ydb_database_name}", + AUTH_METHOD = "TOKEN", + TOKEN_SECRET_NAME = "ydb_source_secret", + USE_TLS = "FALSE" + );)", + "ydb_source"_a = ydbSourceName, + "ydb_location"_a = YDB_ENDPOINT, + "ydb_database_name"_a = YDB_DATABASE, + "token"_a = BUILTIN_ACL_ROOT + )); + } + // Script executions (using query client SDK) TOperation::TOperationId ExecScript(const TString& query, 
std::optional settings = std::nullopt, bool waitRunning = true) { @@ -562,7 +601,7 @@ class TStreamingTestFixture : public NUnitTest::TBaseFixture { }); } - // Utils + // Mock PQ utils static IMockPqReadSession::TPtr WaitMockPqReadSession(IMockPqGateway::TPtr gateway, const TString& topic) { return WaitForPqMockSession(TEST_OPERATION_TIMEOUT, "read", [gateway, topic]() { @@ -610,6 +649,93 @@ class TStreamingTestFixture : public NUnitTest::TBaseFixture { ReadMockPqMessages(session, {message}); } + // Mock Connector utils + + static NYql::TGenericDataSourceInstance GetMockConnectorSourceInstance() { + NYql::TGenericDataSourceInstance dataSourceInstance; + dataSourceInstance.set_kind(NYql::YDB); + dataSourceInstance.set_database(YDB_DATABASE); + dataSourceInstance.set_use_tls(false); + dataSourceInstance.set_protocol(NYql::NATIVE); + + auto& endpoint = *dataSourceInstance.mutable_endpoint(); + TIpPort port; + NHttp::CrackAddress(YDB_ENDPOINT, *endpoint.mutable_host(), port); + endpoint.set_port(port); + + auto& iamToken = *dataSourceInstance.mutable_credentials()->mutable_token(); + iamToken.set_type("IAM"); + iamToken.set_value(BUILTIN_ACL_ROOT); + + return dataSourceInstance; + } + + template + static void FillMockConnectorRequestColumns(TRequestBuilder& builder, const std::vector& columns) { + for (const auto& column : columns) { + builder.Column(column.Name, column.Type); + } + } + + // Should be called at most once + static void SetupMockConnectorTableDescription(std::shared_ptr mockClient, const TMockConnectorTableDescriptionSettings& settings) { + TTypeMappingSettings typeMappingSettings; + typeMappingSettings.set_date_time_format(STRING_FORMAT); + + auto describeTableBuilder = mockClient->ExpectDescribeTable(); + describeTableBuilder + .Table(settings.TableName) + .DataSourceInstance(GetMockConnectorSourceInstance()) + .TypeMappingSettings(typeMappingSettings); + + auto listSplitsBuilder = mockClient->ExpectListSplits(); + listSplitsBuilder + 
.ValidateArgs(settings.ValidateListSplitsArgs) + .Select() + .DataSourceInstance(GetMockConnectorSourceInstance()) + .Table(settings.TableName); + + for (ui64 i = 0; i < settings.DescribeCount; ++i) { + auto responseBuilder = describeTableBuilder.Response(); + FillMockConnectorRequestColumns(responseBuilder, settings.Columns); + } + + for (ui64 i = 0; i < settings.ListSplitsCount; ++i) { + auto responseBuilder = listSplitsBuilder.Result() + .AddResponse(NYql::NConnector::NewSuccess()) + .Description("some binary description") + .Select() + .DataSourceInstance(GetMockConnectorSourceInstance()) + .What(); + FillMockConnectorRequestColumns(responseBuilder, settings.Columns); + } + } + + // Should be called at most once + static void SetupMockConnectorTableData(std::shared_ptr mockClient, const TMockConnectorReadSplitsSettings& settings) { + auto readSplitsBuilder = mockClient->ExpectReadSplits(); + + { + auto columnsBuilder = readSplitsBuilder + .Filtering(TReadSplitsRequest::FILTERING_OPTIONAL) + .ValidateArgs(settings.ValidateReadSplitsArgs) + .Split() + .Description("some binary description") + .Select() + .Table(settings.TableName) + .DataSourceInstance(GetMockConnectorSourceInstance()) + .What(); + FillMockConnectorRequestColumns(columnsBuilder, settings.Columns); + } + + for (ui64 i = 0; i < settings.NumberReadSplits; ++i) { + readSplitsBuilder.Result() + .AddResponse(settings.ResultFactory(), NYql::NConnector::NewSuccess()); + } + } + + // Utils + static void WaitFor(TDuration timeout, const TString& description, std::function callback) { TInstant start = TInstant::Now(); TString errorString; @@ -1112,6 +1238,7 @@ Y_UNIT_TEST_SUITE(KqpFederatedQueryDatastreams) { WaitCheckpointUpdate(executionId); auto readSession = WaitMockPqReadSession(pqGateway, inputTopicName); + auto writeSession = WaitMockPqWriteSession(pqGateway, outputTopicName); readSession->AddDataReceivedEvent(1, R"({"key": "key1", "value": "value1"})"); readSession->AddDataReceivedEvent(2, 
R"({"key": "key2", "value": "value2"})"); readSession->AddDataReceivedEvent(3, R"({"key": "key3", "value": "value3"})"); @@ -1122,7 +1249,8 @@ Y_UNIT_TEST_SUITE(KqpFederatedQueryDatastreams) { WaitCheckpointUpdate(executionId); WaitMockPqReadSession(pqGateway, inputTopicName)->AddDataReceivedEvent(4, R"({"key": "key4", "value": "value4"})"); - ReadMockPqMessage(WaitMockPqWriteSession(pqGateway, outputTopicName), "key4value4"); + writeSession = WaitMockPqWriteSession(pqGateway, outputTopicName); + ReadMockPqMessage(writeSession, "key4value4"); CancelScriptExecution(operationId); } @@ -1683,24 +1811,9 @@ Y_UNIT_TEST_SUITE(KqpStreamingQueriesDdl) { CreateTopic(outputTopicName); constexpr char pqSourceName[] = "pqSourceName"; - CreatePqSource(pqSourceName); - constexpr char ydbSourceName[] = "ydbSourceName"; - ExecQuery(fmt::format(R"( - CREATE OBJECT secret_name (TYPE SECRET) WITH (value = "{token}"); - CREATE EXTERNAL DATA SOURCE `{ydb_source}` WITH ( - SOURCE_TYPE = "Ydb", - LOCATION = "{ydb_location}", - DATABASE_NAME = "{ydb_database_name}", - AUTH_METHOD = "TOKEN", - TOKEN_SECRET_NAME = "secret_name", - USE_TLS = "FALSE" - );)", - "ydb_source"_a = ydbSourceName, - "ydb_location"_a = YDB_ENDPOINT, - "ydb_database_name"_a = YDB_DATABASE, - "token"_a = BUILTIN_ACL_ROOT - )); + CreatePqSource(pqSourceName); + CreateYdbSource(ydbSourceName); constexpr char ydbTable[] = "lookup"; ExecExternalQuery(fmt::format(R"( @@ -1713,75 +1826,30 @@ Y_UNIT_TEST_SUITE(KqpStreamingQueriesDdl) { )); { // Prepare connector mock - NYql::TGenericDataSourceInstance dataSourceInstance; - dataSourceInstance.set_kind(NYql::YDB); - dataSourceInstance.set_database(YDB_DATABASE); - dataSourceInstance.set_use_tls(false); - dataSourceInstance.set_protocol(NYql::NATIVE); - - auto& endpoint = *dataSourceInstance.mutable_endpoint(); - TIpPort port; - NHttp::CrackAddress(YDB_ENDPOINT, *endpoint.mutable_host(), port); - endpoint.set_port(port); - - auto& iamToken = 
*dataSourceInstance.mutable_credentials()->mutable_token(); - iamToken.set_type("IAM"); - iamToken.set_value(BUILTIN_ACL_ROOT); - - TTypeMappingSettings typeMappingSettings; - typeMappingSettings.set_date_time_format(STRING_FORMAT); - - auto describeTableBuilder = connectorClient->ExpectDescribeTable(); - describeTableBuilder - .Table(ydbTable) - .DataSourceInstance(dataSourceInstance) - .TypeMappingSettings(typeMappingSettings); - - auto listSplitsBuilder = connectorClient->ExpectListSplits(); - listSplitsBuilder.Select() - .DataSourceInstance(dataSourceInstance) - .Table(ydbTable); + const std::vector columns = { + {"fqdn", Ydb::Type::STRING}, + {"payload", Ydb::Type::STRING} + }; + SetupMockConnectorTableDescription(connectorClient, { + .TableName = ydbTable, + .Columns = columns, + .DescribeCount = 2, + .ListSplitsCount = 2 + }); const std::vector fqdnColumn = {"host1.example.com", "host2.example.com", "host3.example.com"}; const std::vector payloadColumn = {"P1", "P2", "P3"}; - auto readSplitsBuilder = connectorClient->ExpectReadSplits(); - readSplitsBuilder - .Filtering(TReadSplitsRequest::FILTERING_OPTIONAL) - .Split() - .Description("some binary description") - .Select() - .Table(ydbTable) - .DataSourceInstance(dataSourceInstance) - .What() - .Column("fqdn", Ydb::Type::STRING) - .Column("payload", Ydb::Type::STRING); - - const auto builtResults = [&]() { - describeTableBuilder.Response() - .Column("fqdn", Ydb::Type::STRING) - .Column("payload", Ydb::Type::STRING); - - listSplitsBuilder.Result() - .AddResponse(NYql::NConnector::NewSuccess()) - .Description("some binary description") - .Select() - .DataSourceInstance(dataSourceInstance) - .What() - .Column("fqdn", Ydb::Type::STRING) - .Column("payload", Ydb::Type::STRING); - - readSplitsBuilder.Result() - .AddResponse( - MakeRecordBatch( - MakeArray("fqdn", fqdnColumn, arrow::binary()), - MakeArray("payload", payloadColumn, arrow::binary()) - ), - NYql::NConnector::NewSuccess() + 
SetupMockConnectorTableData(connectorClient, { + .TableName = ydbTable, + .Columns = columns, + .NumberReadSplits = 2, + .ResultFactory = [&]() { + return MakeRecordBatch( + MakeArray("fqdn", fqdnColumn, arrow::binary()), + MakeArray("payload", payloadColumn, arrow::binary()) ); - }; - - builtResults(); - builtResults(); // Streaming queries compiled twice, also in test results requested twice due to retry + } + }); } constexpr char queryName[] = "streamingQuery"; @@ -1834,6 +1902,131 @@ Y_UNIT_TEST_SUITE(KqpStreamingQueriesDdl) { WaitMockPqReadSession(pqGateway, inputTopicName)->AddDataReceivedEvent(sampleMessages); ReadMockPqMessages(WaitMockPqWriteSession(pqGateway, outputTopicName), sampleResult); } + + Y_UNIT_TEST_F(StreamingQueryWithStreamLookupJoin, TStreamingTestFixture) { + const auto connectorClient = SetupMockConnectorClient(); + const auto pqGateway = SetupMockPqGateway(); + + constexpr char inputTopicName[] = "sljInputTopicName"; + constexpr char outputTopicName[] = "sljOutputTopicName"; + CreateTopic(inputTopicName); + CreateTopic(outputTopicName); + + constexpr char pqSourceName[] = "pqSourceName"; + constexpr char ydbSourceName[] = "ydbSourceName"; + CreatePqSource(pqSourceName); + CreateYdbSource(ydbSourceName); + + constexpr char ydbTable[] = "lookup"; + ExecExternalQuery(fmt::format(R"( + CREATE TABLE `{table}` ( + fqdn String, + payload String, + PRIMARY KEY (fqdn) + ))", + "table"_a = ydbTable + )); + + { // Prepare connector mock + const std::vector columns = { + {"fqdn", Ydb::Type::STRING}, + {"payload", Ydb::Type::STRING} + }; + SetupMockConnectorTableDescription(connectorClient, { + .TableName = ydbTable, + .Columns = columns, + .DescribeCount = 2, + .ListSplitsCount = 5, + .ValidateListSplitsArgs = false + }); + + ui64 readSplitsCount = 0; + const std::vector fqdnColumn = {"host1.example.com", "host2.example.com", "host3.example.com"}; + SetupMockConnectorTableData(connectorClient, { + .TableName = ydbTable, + .Columns = columns, + 
.NumberReadSplits = 3, + .ValidateReadSplitsArgs = false, + .ResultFactory = [&]() { + readSplitsCount += 1; + const auto payloadColumn = readSplitsCount < 3 + ? std::vector{"P1", "P2", "P3"} + : std::vector{"P4", "P5", "P6"}; + + return MakeRecordBatch( + MakeArray("fqdn", fqdnColumn, arrow::binary()), + MakeArray("payload", payloadColumn, arrow::binary()) + ); + } + }); + } + + constexpr char queryName[] = "streamingQuery"; + ExecQuery(fmt::format(R"( + CREATE STREAMING QUERY `{query_name}` AS + DO BEGIN + $ydb_lookup = SELECT * FROM `{ydb_source}`.`{ydb_table}`; + + $pq_source = SELECT * FROM `{pq_source}`.`{input_topic}` WITH ( + FORMAT = "json_each_row", + SCHEMA ( + time Int32 NOT NULL, + event String, + host String + ) + ); + + $joined = SELECT l.payload AS payload, p.* FROM $pq_source AS p + LEFT JOIN /*+ streamlookup(TTL 1) */ ANY $ydb_lookup AS l + ON (l.fqdn = p.host); + + INSERT INTO `{pq_source}`.`{output_topic}` + SELECT Unwrap(event || "-" || payload) FROM $joined + END DO;)", + "query_name"_a = queryName, + "pq_source"_a = pqSourceName, + "ydb_source"_a = ydbSourceName, + "ydb_table"_a = ydbTable, + "input_topic"_a = inputTopicName, + "output_topic"_a = outputTopicName + )); + + CheckScriptExecutionsCount(1, 1); + + auto readSession = WaitMockPqReadSession(pqGateway, inputTopicName); + const std::vector sampleMessages = { + {0, R"({"time": 0, "event": "A", "host": "host1.example.com"})"}, + {1, R"({"time": 1, "event": "B", "host": "host3.example.com"})"}, + {2, R"({"time": 2, "event": "A", "host": "host1.example.com"})"}, + }; + readSession->AddDataReceivedEvent(sampleMessages); + + auto writeSession = WaitMockPqWriteSession(pqGateway, outputTopicName); + const std::vector sampleResult = {"A-P1", "B-P3", "A-P1"}; + ReadMockPqMessages(writeSession, sampleResult); + + readSession->AddCloseSessionEvent(EStatus::UNAVAILABLE, {NIssue::TIssue("Test pq session failure")}); + + readSession = WaitMockPqReadSession(pqGateway, inputTopicName); + 
readSession->AddDataReceivedEvent(sampleMessages); + writeSession = WaitMockPqWriteSession(pqGateway, outputTopicName); + ReadMockPqMessages(writeSession, sampleResult); + + Sleep(TDuration::Seconds(2)); + readSession->AddDataReceivedEvent(sampleMessages); + ReadMockPqMessages(writeSession, {"A-P4", "B-P6", "A-P4"}); + + CheckScriptExecutionsCount(1, 1); + const auto results = ExecQuery( + "SELECT ast_compressed FROM `.metadata/script_executions`;" + ); + UNIT_ASSERT_VALUES_EQUAL(results.size(), 1); + CheckScriptResult(results[0], 1, 1, [](TResultSetParser& result) { + const auto& ast = result.ColumnParser(0).GetOptionalString(); + UNIT_ASSERT(ast); + UNIT_ASSERT_STRING_CONTAINS(*ast, "DqCnStreamLookup"); + }); + } } } // namespace NKikimr::NKqp diff --git a/ydb/core/protos/kqp_physical.proto b/ydb/core/protos/kqp_physical.proto index 2573510b6d70..6f4c37f84a47 100644 --- a/ydb/core/protos/kqp_physical.proto +++ b/ydb/core/protos/kqp_physical.proto @@ -335,6 +335,34 @@ message TKqpPhyCnSequencer { bytes OutputType = 5; } +message TKqpPhyCnDqSourceStreamLookup { + // + // |<- InputStageRowType + // [maybe wide DQ channel] + // |<- ConnectionInputRowType |<- LookupRowType + // ------------------+ + // |<- ConnectionOutputRowType + // [maybe wide DQ channel] + // |<- OutputStageRowType + // + + bytes InputStageRowType = 1; + bytes OutputStageRowType = 2; + bytes LookupRowType = 3; + bytes ConnectionInputRowType = 4; + bytes ConnectionOutputRowType = 5; + TKqpExternalSource LookupSource = 6; + string LeftLabel = 7; + string RightLabel = 8; + string JoinType = 9; + repeated string LeftJoinKeyNames = 10; + repeated string RightJoinKeyNames = 11; + uint64 CacheLimit = 12; + uint64 CacheTtlSeconds = 13; + uint64 MaxDelayedRows = 14; + bool IsMultiGet = 15; +} + message TKqpPhyConnection { uint32 StageIndex = 1; uint32 OutputIndex = 2; @@ -354,6 +382,7 @@ message TKqpPhyConnection { TKqpPhyCnSequencer Sequencer = 14; TKqpPhyCnParallelUnionAll ParallelUnionAll = 15; 
TKqpPhyCnVectorResolve VectorResolve = 16; + TKqpPhyCnDqSourceStreamLookup DqSourceStreamLookup = 17; }; } diff --git a/ydb/library/yql/dq/opt/dq_opt_join.cpp b/ydb/library/yql/dq/opt/dq_opt_join.cpp index 8b4f8856b748..63fbb13b38ea 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_join.cpp @@ -1,12 +1,14 @@ #include "dq_opt_join.h" #include "dq_opt_phy.h" +#include +#include +#include #include #include -#include -#include -#include #include +#include +#include namespace NYql::NDq { @@ -2067,4 +2069,119 @@ TExprBase DqBuildHashJoin( .Done(); } +namespace { + +bool IsStreamLookup(const TCoEquiJoinTuple& joinTuple) { + for (const auto& outer : joinTuple.Options()) { + for (const auto& inner : outer.Cast()) { + if (auto maybeForceStreamLookupOption = inner.Maybe()) { + if (maybeForceStreamLookupOption.Cast().StringValue() == "forceStreamLookup") { + return true; + } + } + } + } + return false; +} + +IDqOptimization* GetDqOptCallback(const TExprBase& providerRead, TTypeAnnotationContext& typeCtx) { + if (providerRead.Ref().ChildrenSize() > 1 && TCoDataSource::Match(providerRead.Ref().Child(1))) { + auto dataSourceName = providerRead.Ref().Child(1)->Child(0)->Content(); + auto datasource = typeCtx.DataSourceMap.FindPtr(dataSourceName); + YQL_ENSURE(datasource); + return (*datasource)->GetDqOptimization(); + } + return nullptr; +} + +TDqLookupSourceWrap LookupSourceFromSource(TDqSourceWrap source, TExprContext& ctx) { + return Build(ctx, source.Pos()) + .Input(source.Input()) + .DataSource(source.DataSource()) + .RowType(source.RowType()) + .Settings(source.Settings()) + .Done(); +} + +TDqLookupSourceWrap LookupSourceFromRead(TDqReadWrap read, TExprContext& ctx, TTypeAnnotationContext& typeCtx) { // temp replace with yt source + IDqOptimization* dqOptimization = GetDqOptCallback(read.Input(), typeCtx); + YQL_ENSURE(dqOptimization); + auto lookupSourceWrap = dqOptimization->RewriteLookupRead(read.Input().Ptr(), ctx); + 
YQL_ENSURE(lookupSourceWrap, "Lookup read is not supported"); + return TDqLookupSourceWrap(lookupSourceWrap); +} + +// Recursively walk join tree and replace right-side of StreamLookupJoin +ui32 RewriteStreamJoinTuple(ui32 idx, const TCoEquiJoin& equiJoin, const TCoEquiJoinTuple& joinTuple, std::vector& args, TExprContext& ctx, TTypeAnnotationContext& typeCtx, bool& changed) { + // recursion depth O(args.size()) + Y_ENSURE(idx < args.size()); + + // handle left side + if (!joinTuple.LeftScope().Maybe()) { + idx = RewriteStreamJoinTuple(idx, equiJoin, joinTuple.LeftScope().Cast(), args, ctx, typeCtx, changed); + } else { + ++idx; + } + + // handle right side + if (!joinTuple.RightScope().Maybe()) { + return RewriteStreamJoinTuple(idx, equiJoin, joinTuple.RightScope().Cast(), args, ctx, typeCtx, changed); + } + + Y_ENSURE(idx < args.size()); + + if (!IsStreamLookup(joinTuple)) { + return idx + 1; + } + + auto right = equiJoin.Arg(idx).Cast(); + auto rightList = right.List(); + if (auto maybeExtractMembers = rightList.Maybe()) { + rightList = maybeExtractMembers.Cast().Input(); + } + + TExprNode::TPtr lookupSourceWrap; + if (auto maybeSource = rightList.Maybe()) { + lookupSourceWrap = LookupSourceFromSource(maybeSource.Cast(), ctx).Ptr(); + } else if (auto maybeRead = rightList.Maybe()) { + lookupSourceWrap = LookupSourceFromRead(maybeRead.Cast(), ctx, typeCtx).Ptr(); + } else { + return idx + 1; + } + + changed = true; + args[idx] = + Build(ctx, joinTuple.Pos()) + .List(lookupSourceWrap) + .Scope(right.Scope()) + .Done().Ptr(); + + return idx + 1; +} + +} // anonymous namespace + +TExprBase DqRewriteStreamEquiJoinWithLookup(const TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typeCtx) { + const auto equiJoin = node.Cast(); + auto argCount = equiJoin.ArgCount(); + const auto joinTuple = equiJoin.Arg(argCount - 2).Cast(); + std::vector args(argCount); + bool changed = false; + auto rightIdx = RewriteStreamJoinTuple(0u, equiJoin, joinTuple, args, ctx, 
typeCtx, changed); + Y_ENSURE(rightIdx + 2 == argCount); + + if (!changed) { + return node; + } + + // fill copies of remaining args + for (ui32 i = 0; i < argCount; ++i) { + if (!args[i]) { + args[i] = equiJoin.Arg(i).Ptr(); + } + } + + return Build(ctx, node.Pos()).Add(std::move(args)).Done(); +} + } // namespace NYql::NDq diff --git a/ydb/library/yql/dq/opt/dq_opt_join.h b/ydb/library/yql/dq/opt/dq_opt_join.h index b3eea0b87f15..57047bad61f1 100644 --- a/ydb/library/yql/dq/opt/dq_opt_join.h +++ b/ydb/library/yql/dq/opt/dq_opt_join.h @@ -49,5 +49,7 @@ bool DqCollectJoinRelationsWithStats( const NNodes::TCoEquiJoin& equiJoin, const std::function>&, TStringBuf, const TExprNode::TPtr, const std::shared_ptr&)>& collector); +NNodes::TExprBase DqRewriteStreamEquiJoinWithLookup(const NNodes::TExprBase& node, TExprContext& ctx, TTypeAnnotationContext& typeCtx); + } // namespace NDq } // namespace NYql diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.cpp b/ydb/library/yql/dq/opt/dq_opt_phy.cpp index eddd9b265120..9874fc525681 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.cpp +++ b/ydb/library/yql/dq/opt/dq_opt_phy.cpp @@ -3463,5 +3463,132 @@ TMaybeNode DqUnorderedOverStageInput(TExprBase node, TExprContext& ct return TExprBase(res); } +namespace { + +bool ValidateStreamLookupJoinFlags(const TDqJoin& join, TExprContext& ctx) { + bool leftAny = false; + bool rightAny = false; + if (const auto maybeFlags = join.Flags()) { + for (auto&& flag: maybeFlags.Cast()) { + auto&& name = flag.StringValue(); + if (name == "LeftAny"sv) { + leftAny = true; + continue; + } else if (name == "RightAny"sv) { + rightAny = true; + continue; + } + } + if (leftAny) { + ctx.AddError(TIssue(ctx.GetPosition(maybeFlags.Cast().Pos()), "Streamlookup ANY LEFT join is not implemented")); + return false; + } + } + + if (!rightAny) { + if (false) { // Temporary change to warning to allow for smooth transition + ctx.AddError(TIssue(ctx.GetPosition(join.Pos()), "Streamlookup: must be LEFT JOIN
/*+streamlookup(...)*/ ANY")); + return false; + } else { + ctx.AddWarning(TIssue(ctx.GetPosition(join.Pos()), "(Deprecation) Streamlookup: must be LEFT JOIN /*+streamlookup(...)*/ ANY")); + } + } + + return true; +} + +} // anonymous namespace + +TMaybeNode DqRewriteStreamLookupJoin(TExprBase node, TExprContext& ctx) { + const auto join = node.Cast(); + if (join.JoinAlgo().StringValue() != "StreamLookupJoin") { + return node; + } + + const auto left = join.LeftInput().Maybe(); + if (!left) { + return node; + } + + if (!ValidateStreamLookupJoinFlags(join, ctx)) { + return {}; + } + + TExprNode::TPtr ttl; + TExprNode::TPtr maxCachedRows; + TExprNode::TPtr maxDelayedRows; + TExprNode::TPtr isMultiget; + if (const auto maybeOptions = join.JoinAlgoOptions()) { + for (auto&& option: maybeOptions.Cast()) { + auto&& name = option.Name().Value(); + if (name == "TTL"sv) { + ttl = option.Value().Cast().Ptr(); + } else if (name == "MaxCachedRows"sv) { + maxCachedRows = option.Value().Cast().Ptr(); + } else if (name == "MaxDelayedRows"sv) { + maxDelayedRows = option.Value().Cast().Ptr(); + } else if (name == "MultiGet"sv) { + isMultiget = option.Value().Cast().Ptr(); + } + } + } + + const auto pos = node.Pos(); + + if (!ttl) { + ttl = ctx.NewAtom(pos, 300); + } + + if (!maxCachedRows) { + maxCachedRows = ctx.NewAtom(pos, 1'000'000); + } + + if (!maxDelayedRows) { + maxDelayedRows = ctx.NewAtom(pos, 1'000'000); + } + + auto rightInput = join.RightInput().Ptr(); + if (auto maybe = TExprBase(rightInput).Maybe()) { + rightInput = maybe.Cast().Input().Ptr(); + } + + auto leftLabel = join.LeftLabel().Maybe() ? 
join.LeftLabel().Cast().Ptr() : ctx.NewAtom(pos, ""); + Y_ENSURE(join.RightLabel().Maybe()); + auto cn = Build(ctx, pos) + .Output(left.Output().Cast()) + .LeftLabel(leftLabel) + .RightInput(rightInput) + .RightLabel(join.RightLabel().Cast()) + .JoinKeys(join.JoinKeys()) + .JoinType(join.JoinType()) + .LeftJoinKeyNames(join.LeftJoinKeyNames()) + .RightJoinKeyNames(join.RightJoinKeyNames()) + .TTL(ttl) + .MaxCachedRows(maxCachedRows) + .MaxDelayedRows(maxDelayedRows); + + if (isMultiget) { + cn.IsMultiget(isMultiget); + } + + auto lambda = Build(ctx, pos) + .Args({"stream"}) + .Body("stream") + .Done(); + const auto stage = Build(ctx, pos) + .Inputs() + .Add(cn.Done()) + .Build() + .Program(lambda) + .Settings(TDqStageSettings().BuildNode(ctx, pos)) + .Done(); + + return Build(ctx, pos) + .Output() + .Stage(stage) + .Index().Build("0") + .Build() + .Done(); +} } // namespace NYql::NDq diff --git a/ydb/library/yql/dq/opt/dq_opt_phy.h b/ydb/library/yql/dq/opt/dq_opt_phy.h index 70fea30826d3..aa1d63366866 100644 --- a/ydb/library/yql/dq/opt/dq_opt_phy.h +++ b/ydb/library/yql/dq/opt/dq_opt_phy.h @@ -180,5 +180,6 @@ NNodes::TExprBase DqPushUnorderedToStage(NNodes::TExprBase node, TExprContext& c NNodes::TMaybeNode DqUnorderedOverStageInput(NNodes::TExprBase node, TExprContext& ctx, IOptimizationContext& optCtx, const TTypeAnnotationContext& typeAnnCtx, const TParentsMap& parentsMap, bool allowStageMultiUsage); +NNodes::TMaybeNode DqRewriteStreamLookupJoin(NNodes::TExprBase node, TExprContext& ctx); } // namespace NYql::NDq diff --git a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp index 1457a800c673..70a5f6e7c4fb 100644 --- a/ydb/library/yql/providers/dq/opt/logical_optimize.cpp +++ b/ydb/library/yql/providers/dq/opt/logical_optimize.cpp @@ -26,23 +26,6 @@ using namespace NYql; using namespace NYql::NDq; using namespace NYql::NNodes; -namespace { - -bool IsStreamLookup(const TCoEquiJoinTuple& joinTuple) { - 
for (const auto& outer: joinTuple.Options()) { - for (const auto& inner: outer.Cast()) { - if (auto maybeForceStreamLookupOption = inner.Maybe()) { - if (maybeForceStreamLookupOption.Cast().StringValue() == "forceStreamLookup") { - return true; - } - } - } - } - return false; -} - -} - /** * DQ Specific cost function and join applicability cost function */ @@ -223,81 +206,8 @@ class TDqsLogicalOptProposalTransformer : public TOptimizeTransformerBase { return node; } - TDqLookupSourceWrap LookupSourceFromSource(TDqSourceWrap source, TExprContext& ctx) { - return Build(ctx, source.Pos()) - .Input(source.Input()) - .DataSource(source.DataSource()) - .RowType(source.RowType()) - .Settings(source.Settings()) - .Done(); - } - - TDqLookupSourceWrap LookupSourceFromRead(TDqReadWrap read, TExprContext& ctx){ //temp replace with yt source - IDqOptimization* dqOptimization = GetDqOptCallback(read.Input()); - YQL_ENSURE(dqOptimization); - auto lookupSourceWrap = dqOptimization->RewriteLookupRead(read.Input().Ptr(), ctx); - YQL_ENSURE(lookupSourceWrap, "Lookup read is not supported"); - return TDqLookupSourceWrap(lookupSourceWrap); - } - - // Recursively walk join tree and replace right-side of StreamLookupJoin - ui32 RewriteStreamJoinTuple(ui32 idx, const TCoEquiJoin& equiJoin, const TCoEquiJoinTuple& joinTuple, std::vector& args, TExprContext& ctx, bool& changed) { - // recursion depth O(args.size()) - Y_ENSURE(idx < args.size()); - // handle left side - if (!joinTuple.LeftScope().Maybe()) { - idx = RewriteStreamJoinTuple(idx, equiJoin, joinTuple.LeftScope().Cast(), args, ctx, changed); - } else { - ++idx; - } - // handle right side - if (!joinTuple.RightScope().Maybe()) { - return RewriteStreamJoinTuple(idx, equiJoin, joinTuple.RightScope().Cast(), args, ctx, changed); - } - Y_ENSURE(idx < args.size()); - if (!IsStreamLookup(joinTuple)) { - return idx + 1; - } - auto right = equiJoin.Arg(idx).Cast(); - auto rightList = right.List(); - if (auto maybeExtractMembers = 
rightList.Maybe()) { - rightList = maybeExtractMembers.Cast().Input(); - } - TExprNode::TPtr lookupSourceWrap; - if (auto maybeSource = rightList.Maybe()) { - lookupSourceWrap = LookupSourceFromSource(maybeSource.Cast(), ctx).Ptr(); - } else if (auto maybeRead = rightList.Maybe()) { - lookupSourceWrap = LookupSourceFromRead(maybeRead.Cast(), ctx).Ptr(); - } else { - return idx + 1; - } - changed = true; - args[idx] = - Build(ctx, joinTuple.Pos()) - .List(lookupSourceWrap) - .Scope(right.Scope()) - .Done().Ptr(); - return idx + 1; - } - TMaybeNode RewriteStreamEquiJoinWithLookup(TExprBase node, TExprContext& ctx) { - const auto equiJoin = node.Cast(); - auto argCount = equiJoin.ArgCount(); - const auto joinTuple = equiJoin.Arg(argCount - 2).Cast(); - std::vector args(argCount); - bool changed = false; - auto rightIdx = RewriteStreamJoinTuple(0u, equiJoin, joinTuple, args, ctx, changed); - Y_ENSURE(rightIdx + 2 == argCount); - if (!changed) { - return node; - } - // fill copies of remaining args - for (ui32 i = 0; i < argCount; ++i) { - if (!args[i]) { - args[i] = equiJoin.Arg(i).Ptr(); - } - } - return Build(ctx, node.Pos()).Add(std::move(args)).Done(); + return DqRewriteStreamEquiJoinWithLookup(node, ctx, TypesCtx); } TMaybeNode OptimizeEquiJoinWithCosts(TExprBase node, TExprContext& ctx) { @@ -446,16 +356,6 @@ class TDqsLogicalOptProposalTransformer : public TOptimizeTransformerBase { "Distinct is not supported for aggregation with hop"); } - IDqOptimization* GetDqOptCallback(const TExprBase& providerRead) const { - if (providerRead.Ref().ChildrenSize() > 1 && TCoDataSource::Match(providerRead.Ref().Child(1))) { - auto dataSourceName = providerRead.Ref().Child(1)->Child(0)->Content(); - auto datasource = TypesCtx.DataSourceMap.FindPtr(dataSourceName); - YQL_ENSURE(datasource); - return (*datasource)->GetDqOptimization(); - } - return nullptr; - } - private: TDqConfiguration::TPtr Config; TTypeAnnotationContext& TypesCtx; diff --git 
a/ydb/library/yql/providers/dq/opt/physical_optimize.cpp b/ydb/library/yql/providers/dq/opt/physical_optimize.cpp index 8792d4979e25..f305be8c72e3 100644 --- a/ydb/library/yql/providers/dq/opt/physical_optimize.cpp +++ b/ydb/library/yql/providers/dq/opt/physical_optimize.cpp @@ -264,121 +264,8 @@ class TDqsPhysicalOptProposalTransformer : public TOptimizeTransformerBase { return DqRewriteLeftPureJoin(node, ctx, *getParents(), IsGlobal); } - bool ValidateStreamLookupJoinFlags(const TDqJoin& join, TExprContext& ctx) { - bool leftAny = false; - bool rightAny = false; - if (const auto maybeFlags = join.Flags()) { - for (auto&& flag: maybeFlags.Cast()) { - auto&& name = flag.StringValue(); - if (name == "LeftAny"sv) { - leftAny = true; - continue; - } else if (name == "RightAny"sv) { - rightAny = true; - continue; - } - } - if (leftAny) { - ctx.AddError(TIssue(ctx.GetPosition(maybeFlags.Cast().Pos()), "Streamlookup ANY LEFT join is not implemented")); - return false; - } - } - if (!rightAny) { - if (false) { // Tempoarily change to waring to allow for smooth transition - ctx.AddError(TIssue(ctx.GetPosition(join.Pos()), "Streamlookup: must be LEFT JOIN /*+streamlookup(...)*/ ANY")); - return false; - } else { - ctx.AddWarning(TIssue(ctx.GetPosition(join.Pos()), "(Deprecation) Streamlookup: must be LEFT JOIN /*+streamlookup(...)*/ ANY")); - } - } - return true; - } - TMaybeNode RewriteStreamLookupJoin(TExprBase node, TExprContext& ctx) { - const auto join = node.Cast(); - if (join.JoinAlgo().StringValue() != "StreamLookupJoin") { - return node; - } - - const auto pos = node.Pos(); - const auto left = join.LeftInput().Maybe(); - if (!left) { - return node; - } - - if (!ValidateStreamLookupJoinFlags(join, ctx)) { - return {}; - } - - TExprNode::TPtr ttl; - TExprNode::TPtr maxCachedRows; - TExprNode::TPtr maxDelayedRows; - TExprNode::TPtr isMultiget; - if (const auto maybeOptions = join.JoinAlgoOptions()) { - for (auto&& option: maybeOptions.Cast()) { - auto&& name = 
option.Name().Value(); - if (name == "TTL"sv) { - ttl = option.Value().Cast().Ptr(); - } else if (name == "MaxCachedRows"sv) { - maxCachedRows = option.Value().Cast().Ptr(); - } else if (name == "MaxDelayedRows"sv) { - maxDelayedRows = option.Value().Cast().Ptr(); - } else if (name == "MultiGet"sv) { - isMultiget = option.Value().Cast().Ptr(); - } - } - } - - if (!ttl) { - ttl = ctx.NewAtom(pos, 300); - } - if (!maxCachedRows) { - maxCachedRows = ctx.NewAtom(pos, 1'000'000); - } - if (!maxDelayedRows) { - maxDelayedRows = ctx.NewAtom(pos, 1'000'000); - } - auto rightInput = join.RightInput().Ptr(); - if (auto maybe = TExprBase(rightInput).Maybe()) { - rightInput = maybe.Cast().Input().Ptr(); - } - auto leftLabel = join.LeftLabel().Maybe() ? join.LeftLabel().Cast().Ptr() : ctx.NewAtom(pos, ""); - Y_ENSURE(join.RightLabel().Maybe()); - auto cn = Build(ctx, pos) - .Output(left.Output().Cast()) - .LeftLabel(leftLabel) - .RightInput(rightInput) - .RightLabel(join.RightLabel().Cast()) - .JoinKeys(join.JoinKeys()) - .JoinType(join.JoinType()) - .LeftJoinKeyNames(join.LeftJoinKeyNames()) - .RightJoinKeyNames(join.RightJoinKeyNames()) - .TTL(ttl) - .MaxCachedRows(maxCachedRows) - .MaxDelayedRows(maxDelayedRows); - - if (isMultiget) { - cn.IsMultiget(isMultiget); - } - - auto lambda = Build(ctx, pos) - .Args({"stream"}) - .Body("stream") - .Done(); - const auto stage = Build(ctx, pos) - .Inputs() - .Add(cn.Done()) - .Build() - .Program(lambda) - .Settings(TDqStageSettings().BuildNode(ctx, pos)) - .Done(); - - return Build(ctx, pos) - .Output() - .Stage(stage) - .Index().Build("0") - .Build() - .Done(); + return DqRewriteStreamLookupJoin(node, ctx); } template diff --git a/ydb/library/yql/providers/generic/connector/libcpp/ut_helpers/connector_client_mock.h b/ydb/library/yql/providers/generic/connector/libcpp/ut_helpers/connector_client_mock.h index 152125db4e27..897a4a156aba 100644 --- a/ydb/library/yql/providers/generic/connector/libcpp/ut_helpers/connector_client_mock.h 
+++ b/ydb/library/yql/providers/generic/connector/libcpp/ut_helpers/connector_client_mock.h @@ -68,6 +68,11 @@ namespace NYql::NConnector::NTest { return google::protobuf::util::MessageDifferencer::Equals(arg, expected); } + MATCHER_P(RequestRelaxedMatcher, expected, "") { + Y_UNUSED(arg); + return true; + } + #define MATCH_RESULT_WITH_INPUT(INPUT, RESULT_SET, GETTER) \ { \ for (const auto& val : INPUT) { \ @@ -689,13 +694,21 @@ namespace NYql::NConnector::NTest { return *this; } + auto& ValidateArgs(bool validate) { + ValidateArgs_ = validate; + return *this; + } + private: void SetExpectation() { if (ResponseResults_.empty()) { Result(); } - auto& expectBuilder = EXPECT_CALL(*Mock_, ListSplitsImpl(ProtobufRequestMatcher(*Result_))); + auto& expectBuilder = ValidateArgs_ + ? EXPECT_CALL(*Mock_, ListSplitsImpl(ProtobufRequestMatcher(*Result_))) + : EXPECT_CALL(*Mock_, ListSplitsImpl(RequestRelaxedMatcher(*Result_))); + for (auto response : ResponseResults_) { expectBuilder.WillOnce(Return(TIteratorResult{ResponseStatus_, response})); } @@ -705,6 +718,7 @@ namespace NYql::NConnector::NTest { TConnectorClientMock* Mock_ = nullptr; std::vector ResponseResults_; NYdbGrpc::TGrpcStatus ResponseStatus_ {}; + bool ValidateArgs_ = true; }; template @@ -767,6 +781,11 @@ namespace NYql::NConnector::NTest { return *this; } + auto& ValidateArgs(bool validate) { + ValidateArgs_ = validate; + return *this; + } + void FillWithDefaults() { Format(NApi::TReadSplitsRequest::ARROW_IPC_STREAMING); } @@ -777,7 +796,10 @@ namespace NYql::NConnector::NTest { Result(); } - auto& expectBuilder = EXPECT_CALL(*Mock_, ReadSplitsImpl(ProtobufRequestMatcher(*Result_))); + auto& expectBuilder = ValidateArgs_ + ? 
EXPECT_CALL(*Mock_, ReadSplitsImpl(ProtobufRequestMatcher(*Result_))) + : EXPECT_CALL(*Mock_, ReadSplitsImpl(RequestRelaxedMatcher(*Result_))); + for (auto response : ResponseResults_) { expectBuilder.WillOnce(Return(TIteratorResult{ResponseStatus_, response})); } @@ -787,6 +809,7 @@ namespace NYql::NConnector::NTest { TConnectorClientMock* Mock_ = nullptr; std::vector ResponseResults_; NYdbGrpc::TGrpcStatus ResponseStatus_ {}; + bool ValidateArgs_ = true; }; TDescribeTableExpectationBuilder ExpectDescribeTable() { diff --git a/ydb/tests/fq/generic/streaming/test_join.py b/ydb/tests/fq/generic/streaming/test_join.py index 78fdbba9038a..d86034c65a46 100644 --- a/ydb/tests/fq/generic/streaming/test_join.py +++ b/ydb/tests/fq/generic/streaming/test_join.py @@ -539,7 +539,7 @@ def freeze(json): e.Data as data, u.id as lookup from $input as e - left join {streamlookup} ydb_conn_{table_name}.{table_name} as u + left join {streamlookup} any ydb_conn_{table_name}.{table_name} as u on(AsList(e.Data) = u.data) -- MultiGet true ; @@ -582,7 +582,7 @@ def freeze(json): u.data as lookup from $input as e - left join {streamlookup} ydb_conn_{table_name}.{table_name} as u + left join {streamlookup} any ydb_conn_{table_name}.{table_name} as u on(e.user = u.id) -- MultiGet true ; @@ -656,7 +656,7 @@ def freeze(json): u.data as lookup from $input as e - left join {streamlookup} ydb_conn_{table_name}.{table_name} as u + left join {streamlookup} any ydb_conn_{table_name}.{table_name} as u on(e.user = u.id) -- MultiGet true ; @@ -713,7 +713,7 @@ def freeze(json): $enriched = select a, b, c, d, e, f, za, yb, yc, zd from $input as e - left join {streamlookup} $listified as u + left join {streamlookup} any $listified as u on(e.za = u.a AND e.yb = u.b) -- MultiGet true ; @@ -760,9 +760,9 @@ def freeze(json): $enriched = select u.a as la, u.b as lb, u.c as lc, u2.a as sa, u2.b as sb, u2.c as sc, lza, lyb, sza, syb, yc from $input as e - left join {streamlookup} $listified as u + left join 
{streamlookup} any $listified as u on(e.lza = u.a AND e.lyb = u.b) - left join /*+streamlookup()*/ $listified as u2 + left join /*+streamlookup()*/ any $listified as u2 on(e.sza = u2.a AND e.syb = u2.b) -- MultiGet true ;