planner: fix bug that planner generates wrong 2 phase aggregate plan for TiFlash (#34779) (#34932)

close #34682
ti-srebot committed May 25, 2022
1 parent f876236 commit 3ed672f
Showing 5 changed files with 177 additions and 1 deletion.
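At a high level, the patch makes LogicalAggregation.tryToGetMppHashAggs check whether the aggregation's functions are already in FinalMode (which is what aggregation pushdown through a join leaves above the join) and, if so, stop after generating the 1-phase MPP aggregate instead of also proposing a 2-phase split. A minimal, self-contained Go sketch of that guard pattern, using simplified stand-in types rather than the real planner API in the diff below:

// Illustration only: AggFuncMode, aggFunc and logicalAgg are simplified
// stand-ins for the planner's aggregation modes and LogicalAggregation,
// not TiDB's actual types.
package main

import "fmt"

type AggFuncMode int

const (
	CompleteMode AggFuncMode = iota
	Partial1Mode
	FinalMode
)

type aggFunc struct{ mode AggFuncMode }

type logicalAgg struct{ funcs []aggFunc }

// possibleMppAggs mirrors the shape of the fix: an aggregate whose functions
// are already final-stage must not be split into another partial+final pair,
// so only the 1-phase variant is produced for it.
func possibleMppAggs(agg logicalAgg) []string {
	hasFinalAgg := len(agg.funcs) > 0 && agg.funcs[0].mode == FinalMode

	plans := []string{"1-phase MPP hash agg"}
	if hasFinalAgg {
		return plans // early exit, as the patch does
	}
	return append(plans, "2-phase MPP hash agg", "TiDB-side merge agg")
}

func main() {
	fmt.Println(possibleMppAggs(logicalAgg{funcs: []aggFunc{{mode: FinalMode}}}))
	fmt.Println(possibleMppAggs(logicalAgg{funcs: []aggFunc{{mode: CompleteMode}}}))
}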
27 changes: 27 additions & 0 deletions executor/tiflash_test.go
@@ -1199,3 +1199,30 @@ func TestTiflashPartitionTableScan(t *testing.T) {
tk.MustQuery("select count(*) from t where a < 12;").Check(testkit.Rows("2"))
wg.Wait()
}

func TestAggPushDownCountStar(t *testing.T) {
store, clean := testkit.CreateMockStore(t, withMockTiFlash(2))
defer clean()
tk := testkit.NewTestKit(t, store)

tk.MustExec("use test")
tk.MustExec("drop table if exists c")
tk.MustExec("drop table if exists o")
tk.MustExec("create table c(c_id bigint primary key)")
tk.MustExec("create table o(o_id bigint primary key, c_id bigint not null)")
tk.MustExec("alter table c set tiflash replica 1")
tb := external.GetTableByName(t, tk, "test", "c")
err := domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
tk.MustExec("alter table o set tiflash replica 1")
tb = external.GetTableByName(t, tk, "test", "o")
err = domain.GetDomain(tk.Session()).DDL().UpdateTableReplicaInfo(tk.Session(), tb.Meta().ID, true)
require.NoError(t, err)
tk.MustExec("insert into c values(1),(2),(3),(4),(5)")
tk.MustExec("insert into o values(1,1),(2,1),(3,2),(4,2),(5,2)")

tk.MustExec("set @@tidb_enforce_mpp=1")
tk.MustExec("set @@tidb_opt_agg_push_down=1")

tk.MustQuery("select count(*) from c, o where c.c_id=o.c_id").Check(testkit.Rows("5"))
}
54 changes: 54 additions & 0 deletions planner/core/enforce_mpp_test.go
@@ -384,3 +384,57 @@ func TestEnforceMPPWarning4(t *testing.T) {
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}

// Test agg push down for MPP mode
func TestMPP2PhaseAggPushDown(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
tk := testkit.NewTestKit(t, store)

// test table
tk.MustExec("use test")
tk.MustExec("drop table if exists c")
tk.MustExec("drop table if exists o")
tk.MustExec("create table c(c_id bigint)")
tk.MustExec("create table o(o_id bigint, c_id bigint not null)")

// Create virtual tiflash replica info.
dom := domain.GetDomain(tk.Session())
is := dom.InfoSchema()
db, exists := is.SchemaByName(model.NewCIStr("test"))
require.True(t, exists)
for _, tblInfo := range db.Tables {
if tblInfo.Name.L == "c" || tblInfo.Name.L == "o" {
tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{
Count: 1,
Available: true,
}
}
}

var input []string
var output []struct {
SQL string
Plan []string
Warn []string
}
enforceMPPSuiteData := plannercore.GetEnforceMPPSuiteData()
enforceMPPSuiteData.GetTestCases(t, &input, &output)
for i, tt := range input {
testdata.OnRecord(func() {
output[i].SQL = tt
})
if strings.HasPrefix(tt, "set") || strings.HasPrefix(tt, "UPDATE") {
tk.MustExec(tt)
continue
}
testdata.OnRecord(func() {
output[i].SQL = tt
output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
output[i].Warn = testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings())
})
res := tk.MustQuery(tt)
res.Check(testkit.Rows(output[i].Plan...))
require.Equal(t, output[i].Warn, testdata.ConvertSQLWarnToStrings(tk.Session().GetSessionVars().StmtCtx.GetWarnings()))
}
}
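TestMPP2PhaseAggPushDown follows TiDB's data-driven test convention: the SQL cases come from planner/core/testdata/enforce_mpp_suite_in.json (added below), the expected plans and warnings live in enforce_mpp_suite_out.json, and testdata.OnRecord runs its callback only when the suite is executed in record mode, regenerating that golden output. A stripped-down sketch of the record-or-compare idea, with a plain bool flag standing in for the testdata package's record mode (not the actual testdata implementation):

package main

import (
	"fmt"
	"reflect"
)

type caseOutput struct {
	SQL  string
	Plan []string
}

// runCase either records the observed plan as the new golden output or
// compares it against the previously recorded one.
func runCase(sql string, actualPlan []string, expected *caseOutput, record bool) error {
	if record {
		expected.SQL = sql
		expected.Plan = actualPlan
		return nil
	}
	if !reflect.DeepEqual(expected.Plan, actualPlan) {
		return fmt.Errorf("plan mismatch for %q", sql)
	}
	return nil
}

func main() {
	golden := caseOutput{SQL: "EXPLAIN select 1", Plan: []string{"Projection_1"}}
	err := runCase("EXPLAIN select 1", []string{"Projection_1"}, &golden, false)
	fmt.Println(err) // <nil>
}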
13 changes: 12 additions & 1 deletion planner/core/exhaust_physical_plans.go
@@ -2589,6 +2589,11 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
if prop.MPPPartitionTp == property.BroadcastType {
return nil
}

// Is this aggregate a final stage aggregate?
// Final agg can't be split into multi-stage aggregate
hasFinalAgg := len(la.AggFuncs) > 0 && la.AggFuncs[0].Mode == aggregation.FinalMode

if len(la.GroupByItems) > 0 {
partitionCols := la.GetPotentialPartitionKeys()
// Trying to match the required partitions.
@@ -2612,6 +2617,11 @@
hashAggs = append(hashAggs, agg)
}

// Final agg can't be split into multi-stage aggregate, so exit early
if hasFinalAgg {
return
}

// 2-phase agg
childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, MPPPartitionTp: property.AnyType, RejectSort: true}
agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
@@ -2628,7 +2638,7 @@ func (la *LogicalAggregation) tryToGetMppHashAggs(prop *property.PhysicalPropert
agg.MppRunMode = MppTiDB
hashAggs = append(hashAggs, agg)
}
} else {
} else if !hasFinalAgg {
// TODO: support scalar agg in MPP, merge the final result to one node
childProp := &property.PhysicalProperty{TaskTp: property.MppTaskType, ExpectedCnt: math.MaxFloat64, RejectSort: true}
agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), childProp)
@@ -2671,6 +2681,7 @@ func (la *LogicalAggregation) getHashAggs(prop *property.PhysicalProperty) []Phy
if prop.IsFlashProp() {
taskTypes = []property.TaskType{prop.TaskTp}
}

for _, taskTp := range taskTypes {
if taskTp == property.MppTaskType {
mppAggs := la.tryToGetMppHashAggs(prop)
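For intuition on why the hasFinalAgg early exit above matters: with tidb_opt_agg_push_down=1, the aggregate pushed below the join produces partial counts per group, and the aggregate left above the join is already a final stage that sums those partial counts. If the planner split that final stage again into partial+final, the new lower stage could end up counting the partial rows instead of summing them, which is the kind of wrong result the guard prevents. A tiny self-contained illustration of that arithmetic (a simplified sketch, not a trace of the planner's actual behavior):

package main

import "fmt"

func main() {
	// Partial counts produced below the join by aggregation pushdown,
	// e.g. count(1) grouped by c_id: 2 orders for c_id=1, 3 for c_id=2.
	partialCounts := []int64{2, 3}

	// Correct final stage: sum the partial counts.
	var correct int64
	for _, c := range partialCounts {
		correct += c
	}

	// If the final stage were wrongly split again, a fresh "partial" count
	// would tally its input rows instead of summing the partial values.
	wrong := int64(len(partialCounts))

	fmt.Println("final stage (sum of partials):", correct) // 5, matching the new executor test
	fmt.Println("re-split count of partial rows:", wrong)  // 2
}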
9 changes: 9 additions & 0 deletions planner/core/testdata/enforce_mpp_suite_in.json
@@ -85,5 +85,14 @@
"explain select a from t where t.a>1 or t.a not in (select a from t); -- now it's supported -- 8. anti left outer semi join",
"explain select a from t where t.a not in (select a from s where t.a<1); -- 9. non left join has left conditions"
]
},
{
"name": "TestMPP2PhaseAggPushDown",
"cases": [
"set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;",
"EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate",
"EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
"EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column"
]
}
]
75 changes: 75 additions & 0 deletions planner/core/testdata/enforce_mpp_suite_out.json
@@ -634,5 +634,80 @@
]
}
]
},
{
"Name": "TestMPP2PhaseAggPushDown",
"Cases": [
{
"SQL": "set @@tidb_allow_mpp=1;set @@tidb_enforce_mpp=1;set @@tidb_opt_agg_push_down=1;",
"Plan": null,
"Warn": null
},
{
"SQL": "EXPLAIN select count(*) from c, o where c.c_id=o.c_id; -- 1. test agg push down, scalar aggregate",
"Plan": [
"HashAgg_13 1.00 root funcs:count(Column#7)->Column#6",
"└─TableReader_35 9990.00 root data:ExchangeSender_34",
" └─ExchangeSender_34 9990.00 mpp[tiflash] ExchangeType: PassThrough",
" └─HashJoin_14 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_26(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_25 8000.00 mpp[tiflash] ExchangeType: Broadcast",
" │ └─Projection_24 8000.00 mpp[tiflash] Column#7, test.o.c_id",
" │ └─HashAgg_19 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#7, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_23 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender_22 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.c_id, collate: binary]",
" │ └─TableFullScan_21 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_18(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_17 10000.00 mpp[tiflash] table:c keep order:false, stats:pseudo"
],
"Warn": null
},
{
"SQL": "EXPLAIN select o.o_id, count(*) from c, o where c.c_id=o.c_id group by o.o_id; -- 2. test agg push down, group by non-join column",
"Plan": [
"TableReader_78 8000.00 root data:ExchangeSender_77",
"└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection_10 8000.00 mpp[tiflash] test.o.o_id, Column#6",
" └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.o_id",
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.o_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.o_id",
" └─ExchangeReceiver_71 9990.00 mpp[tiflash] ",
" └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.o_id, collate: binary]",
" └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_27(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_26 8000.00 mpp[tiflash] ExchangeType: Broadcast",
" │ └─Projection_25 8000.00 mpp[tiflash] Column#7, Column#8, test.o.o_id, test.o.c_id",
" │ └─HashAgg_20 8000.00 mpp[tiflash] group by:test.o.c_id, test.o.o_id, funcs:count(1)->Column#7, funcs:firstrow(test.o.o_id)->Column#8, funcs:firstrow(test.o.o_id)->test.o.o_id, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_24 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender_23 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.o_id, collate: binary], [name: test.o.c_id, collate: binary]",
" │ └─TableFullScan_22 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c keep order:false, stats:pseudo"
],
"Warn": null
},
{
"SQL": "EXPLAIN select o.c_id, count(*) from c, o where c.c_id=o.c_id group by o.c_id; -- 3. test agg push down, group by join column",
"Plan": [
"TableReader_78 8000.00 root data:ExchangeSender_77",
"└─ExchangeSender_77 8000.00 mpp[tiflash] ExchangeType: PassThrough",
" └─Projection_10 8000.00 mpp[tiflash] test.o.c_id, Column#6",
" └─Projection_76 8000.00 mpp[tiflash] Column#6, test.o.c_id",
" └─HashAgg_75 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(Column#7)->Column#6, funcs:firstrow(Column#8)->test.o.c_id",
" └─ExchangeReceiver_71 9990.00 mpp[tiflash] ",
" └─ExchangeSender_70 9990.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.c_id, collate: binary]",
" └─HashJoin_69 9990.00 mpp[tiflash] inner join, equal:[eq(test.c.c_id, test.o.c_id)]",
" ├─ExchangeReceiver_27(Build) 8000.00 mpp[tiflash] ",
" │ └─ExchangeSender_26 8000.00 mpp[tiflash] ExchangeType: Broadcast",
" │ └─Projection_25 8000.00 mpp[tiflash] Column#7, Column#8, test.o.c_id",
" │ └─HashAgg_20 8000.00 mpp[tiflash] group by:test.o.c_id, funcs:count(1)->Column#7, funcs:firstrow(test.o.c_id)->Column#8, funcs:firstrow(test.o.c_id)->test.o.c_id",
" │ └─ExchangeReceiver_24 10000.00 mpp[tiflash] ",
" │ └─ExchangeSender_23 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.o.c_id, collate: binary]",
" │ └─TableFullScan_22 10000.00 mpp[tiflash] table:o keep order:false, stats:pseudo",
" └─Selection_19(Probe) 9990.00 mpp[tiflash] not(isnull(test.c.c_id))",
" └─TableFullScan_18 10000.00 mpp[tiflash] table:c keep order:false, stats:pseudo"
],
"Warn": null
}
]
}
]
