From 3ed672faa82612d75f28bc73117febe1b53f605a Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Thu, 26 May 2022 01:02:46 +0800 Subject: [PATCH] planner: fix bug that planner generates wrong 2 phase aggregate plan for TiFlash (#34779) (#34932) close pingcap/tidb#34682 --- executor/tiflash_test.go | 27 +++++++ planner/core/enforce_mpp_test.go | 54 +++++++++++++ planner/core/exhaust_physical_plans.go | 13 +++- .../core/testdata/enforce_mpp_suite_in.json | 9 +++ .../core/testdata/enforce_mpp_suite_out.json | 75 +++++++++++++++++++ 5 files changed, 177 insertions(+), 1 deletion(-) diff --git a/executor/tiflash_test.go b/executor/tiflash_test.go index 63d6ec5353c1..dc5ac512c3d5 100644 --- a/executor/tiflash_test.go +++ b/executor/tiflash_test.go @@ -1199,3 +1199,30 @@ func TestTiflashPartitionTableScan(t *testing.T) { tk.MustQuery("select count(*) from t where a < 12;").Check(testkit.Rows("2")) wg.Wait() } + +func TestAggPushDownCountStar(t *testing.T) { + store, clean := testkit.CreateMockStore(t, withMockTiFlash(2)) + defer clean() + tk := testkit.NewTestKit(t, store) + + tk.MustExec("use test") + tk.MustExec("drop table if exists c") + tk.MustExec("drop table if exists o") + tk.MustExec("create table c(c_id bigint primary key)") + tk.MustExec("create table o(o_id bigint primary key, c_id bigint not null)") + tk.MustExec("alter table c set tiflash replica 1") + tb := external.GetTableByName(t, tk, "test", "c") + err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("alter table o set tiflash replica 1") + tb = external.GetTableByName(t, tk, "test", "o") + err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true) + require.NoError(t, err) + tk.MustExec("insert into c values(1),(2),(3),(4),(5)") + tk.MustExec("insert into o values(1,1),(2,1),(3,2),(4,2),(5,2)") + + tk.MustExec("set @@tidb_enforce_mpp=1") + tk.MustExec("set @@tidb_opt_agg_push_down=1") + + tk.MustQuery("select count(*) from c, o where c.c_id=o.c_id").Check(testkit.Rows("5")) +} diff --git a/planner/core/enforce_mpp_test.go b/planner/core/enforce_mpp_test.go index f9034f0ea5d6..1b7f1792ea60 100644 --- a/planner/core/enforce_mpp_test.go +++ b/planner/core/enforce_mpp_test.go @@ -384,3 +384,57 @@ func TestEnforceMPPWarning4(t *testing.T) { require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) } } + +// Test agg push down for MPP mode +func TestMPP2PhaseAggPushDown(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + + // test table + tk.MustExec("use test") + tk.MustExec("drop table if exists c") + tk.MustExec("drop table if exists o") + tk.MustExec("create table c(c_id bigint)") + tk.MustExec("create table o(o_id bigint, c_id bigint not null)") + + // Create virtual tiflash replica info. 
+ dom := domain.GetDomain(tk.Session()) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + require.True(t, exists) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + var input []string + var output []struct { + SQL string + Plan []string + Warn []string + } + enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData() + enforceMPPSuiteData.GetTestCases(t, &input, &output) + for i, tt := range input { + testdata.OnRecord(func() { + output[i].SQL = tt + }) + if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") { + tk.MustExec(tt) + continue + } + testdata.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())) + } +} diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 191aaa57b2b2..b4670f1146c2 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -2589,6 +2589,11 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert if prop.MPPPartitionTp == property.BroadcastType { return nil } + + // Is this aggregate a final stage aggregate? + // Final agg can't be split into multi-stage aggregate + hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode + if len(la.GroupByItems) > 0 { partitionCols := la.GetPotentialPartitionKeys() // trying to match the required parititions. @@ -2612,6 +2617,11 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert hashAggs = append(hashAggs, agg) } + // Final agg can't be split into multi-stage aggregate, so exit early + if hasFinalAgg { + return + } + // 2-phase agg childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.AnyType, RejectSort: true} agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp) @@ -2628,7 +2638,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert agg.MppRunMode = MppTiDB hashAggs = append(hashAggs, agg) } - } else { + } else if !hasFinalAgg { // TODO: support scalar agg in MPP, merge the final result to one node childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, RejectSort: true} agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp) @@ -2671,6 +2681,7 @@ func (la *LogicalAggregation) getHashAggs(prop *property.PhysicalProperty) []Phy if prop.IsFlashProp() { taskTypes = []property.TaskType{prop.TaskTp} } + for _, taskTp := range taskTypes { if taskTp == property.MppTaskType { mppAggs := la.tryToGetMppHashAggs(prop) diff --git a/planner/core/testdata/enforce_mpp_suite_in.json b/planner/core/testdata/enforce_mpp_suite_in.json index 41b3aac9920a..3c70fa18e5a5 100644 --- a/planner/core/testdata/enforce_mpp_suite_in.json +++ b/planner/core/testdata/enforce_mpp_suite_in.json @@ -85,5 +85,14 @@ "explain select a from t where t.a>1 or t.a not in (select a from t); -- now it's supported -- 8. 
anti left outer semi join", "explain select a from t where t.a not in (select a from s where t.a<1); -- 9. non left join has left conditions" ] + }, + { + "name": "TestMPP2PhaseAggPushDown", + "cases": [ + "set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;", + "EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate", + "EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column", + "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column" + ] } ] diff --git a/planner/core/testdata/enforce_mpp_suite_out.json b/planner/core/testdata/enforce_mpp_suite_out.json index 744398186067..4ef7f843bf65 100644 --- a/planner/core/testdata/enforce_mpp_suite_out.json +++ b/planner/core/testdata/enforce_mpp_suite_out.json @@ -634,5 +634,80 @@ ] } ] + }, + { + "Name": "TestMPP2PhaseAggPushDown", + "Cases": [ + { + "SQL": "set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;", + "Plan": null, + "Warn": null + }, + { + "SQL": "EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate", + "Plan": [ + "HashAgg_13 1.00 root funcs:count(Column#7)->Column#6", + "└─TableReader_35 9990.00 root data:ExchangeSender_34", + " └─ExchangeSender_34 9990.00 mpp[tiflash] ExchangeType: PassThrough", + " └─HashJoin_14 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]", + " ├─ExchangeReceiver_26(Build) 8000.00 mpp[tiflash] ", + " │ └─ExchangeSender_25 8000.00 mpp[tiflash] ExchangeType: Broadcast", + " │ └─Projection_24 8000.00 mpp[tiflash] Column#7, test.o.c_id", + " │ └─HashAgg_19 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#7, funcs:firstrow(test.o.c_id)->test.o.c_id", + " │ └─ExchangeReceiver_23 10000.00 mpp[tiflash] ", + " │ └─ExchangeSender_22 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.c_id, collate: binary]", + " │ └─TableFullScan_21 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo", + " └─Selection_18(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))", + " └─TableFullScan_17 10000.00 mpp[tiflash] table:c keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. 
test agg push down, group by non-join column", + "Plan": [ + "TableReader_78 8000.00 root data:ExchangeSender_77", + "└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_10 8000.00 mpp[tiflash] test.o.o_id, Column#6", + " └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.o_id", + " └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id", + " └─ExchangeReceiver_71 9990.00 mpp[tiflash] ", + " └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.o_id, collate: binary]", + " └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]", + " ├─ExchangeReceiver_27(Build) 8000.00 mpp[tiflash] ", + " │ └─ExchangeSender_26 8000.00 mpp[tiflash] ExchangeType: Broadcast", + " │ └─Projection_25 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id", + " │ └─HashAgg_20 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id", + " │ └─ExchangeReceiver_24 10000.00 mpp[tiflash] ", + " │ └─ExchangeSender_23 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]", + " │ └─TableFullScan_22 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo", + " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))", + " └─TableFullScan_18 10000.00 mpp[tiflash] table:c keep order:false, stats:pseudo" + ], + "Warn": null + }, + { + "SQL": "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column", + "Plan": [ + "TableReader_78 8000.00 root data:ExchangeSender_77", + "└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Projection_10 8000.00 mpp[tiflash] test.o.c_id, Column#6", + " └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.c_id", + " └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id", + " └─ExchangeReceiver_71 9990.00 mpp[tiflash] ", + " └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.c_id, collate: binary]", + " └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]", + " ├─ExchangeReceiver_27(Build) 8000.00 mpp[tiflash] ", + " │ └─ExchangeSender_26 8000.00 mpp[tiflash] ExchangeType: Broadcast", + " │ └─Projection_25 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id", + " │ └─HashAgg_20 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id", + " │ └─ExchangeReceiver_24 10000.00 mpp[tiflash] ", + " │ └─ExchangeSender_23 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.c_id, collate: binary]", + " │ └─TableFullScan_22 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo", + " └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))", + " └─TableFullScan_18 10000.00 mpp[tiflash] table:c keep order:false, stats:pseudo" + ], + "Warn": null + } + ] } ]
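
Note on the change itself: the decisive piece of this patch is the hasFinalAgg guard added to tryToGetMppHashAggs. When aggregate push-down across a join has already produced a final-mode aggregation, the planner must not split that aggregation again into a 1-phase or 2-phase MPP hash aggregate. The following is a minimal, self-contained Go sketch of that guard only, using illustrative stand-in types rather than TiDB's planner and aggregation structures; it mirrors the condition `len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode` from the diff above and is not the actual implementation.

// sketch.go - illustrative only; types and names here are assumptions, not TiDB's API.
package main

import "fmt"

// AggFunctionMode stands in for aggregation.AggFunctionMode; only the modes
// needed for the illustration are defined.
type AggFunctionMode int

const (
	CompleteMode AggFunctionMode = iota
	FinalMode
)

// aggFunc is a stand-in for an aggregate function descriptor.
type aggFunc struct {
	name string
	mode AggFunctionMode
}

// canSplitIntoMultiStage reports whether an aggregation may still be split
// into partial/final stages for MPP. A final-stage aggregate (e.g. one
// produced by agg push-down through a join, as in pingcap/tidb#34682) must be
// kept as-is, which is what the early return and the `else if !hasFinalAgg`
// branch in the patch enforce.
func canSplitIntoMultiStage(aggFuncs []aggFunc) bool {
	hasFinalAgg := len(aggFuncs) > 0 && aggFuncs[0].mode == FinalMode
	return !hasFinalAgg
}

func main() {
	pushedDown := []aggFunc{{name: "count", mode: FinalMode}}
	fresh := []aggFunc{{name: "count", mode: CompleteMode}}
	fmt.Println(canSplitIntoMultiStage(pushedDown)) // false: keep single-stage plan
	fmt.Println(canSplitIntoMultiStage(fresh))      // true: 1-phase and 2-phase candidates allowed
}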