Skip to content

Commit

Permalink
planner: Add HashJoin<-Receiver specific physicalPlan column pruner (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
yibin87 committed Nov 2, 2022
1 parent 5b0be9a commit e245b84
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 22 deletions.
46 changes: 46 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4657,6 +4657,52 @@ func TestMppJoinDecimal(t *testing.T) {
}
}

// TestMppJoinExchangeColumnPrune runs the recorded integration cases for this
// test against two tables marked as having an available TiFlash replica, with
// MPP enabled and both broadcast-join thresholds set to 1, and compares the
// rows returned by each input statement with the recorded plan.
func TestMppJoinExchangeColumnPrune(t *testing.T) {
	store := testkit.CreateMockStore(t)
	tk := testkit.NewTestKit(t, store)

	// Schema setup, executed in order.
	for _, stmt := range []string{
		"use test",
		"drop table if exists t",
		"drop table if exists tt",
		"create table t (c1 int, c2 int, c3 int NOT NULL, c4 int NOT NULL, c5 int)",
		"create table tt (b1 int)",
		"analyze table t",
		"analyze table tt",
	} {
		tk.MustExec(stmt)
	}

	// Create virtual tiflash replica info.
	infoSchema := domain.GetDomain(tk.Session()).InfoSchema()
	testDB, found := infoSchema.SchemaByName(model.NewCIStr("test"))
	require.True(t, found)
	for _, tbl := range testDB.Tables {
		switch tbl.Name.L {
		case "t", "tt":
			tbl.TiFlashReplica = &model.TiFlashReplicaInfo{
				Count:     1,
				Available: true,
			}
		}
	}

	// Session settings, executed in order.
	for _, stmt := range []string{
		"set @@tidb_allow_mpp=1;",
		"set @@session.tidb_broadcast_join_threshold_size = 1",
		"set @@session.tidb_broadcast_join_threshold_count = 1",
	} {
		tk.MustExec(stmt)
	}

	var (
		input  []string
		output []struct {
			SQL  string
			Plan []string
		}
	)
	suiteData := core.GetIntegrationSuiteData()
	suiteData.LoadTestCases(t, &input, &output)
	for idx, sql := range input {
		// In record mode, refresh the expected output from the current result.
		testdata.OnRecord(func() {
			output[idx].SQL = sql
			output[idx].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(sql).Rows())
		})
		tk.MustQuery(sql).Check(testkit.Rows(output[idx].Plan...))
	}
}

func TestMppAggTopNWithJoin(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
Expand Down
151 changes: 148 additions & 3 deletions planner/core/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,10 @@ func DoOptimize(ctx context.Context, sctx sessionctx.Context, flag uint64, logic
if err != nil {
return nil, 0, err
}
finalPlan := postOptimize(sctx, physical)
finalPlan, err := postOptimize(sctx, physical)
if err != nil {
return nil, 0, err
}

if sctx.GetSessionVars().StmtCtx.EnableOptimizerCETrace {
refineCETrace(sctx)
Expand Down Expand Up @@ -372,9 +375,13 @@ func mergeContinuousSelections(p PhysicalPlan) {
}
}

func postOptimize(sctx sessionctx.Context, plan PhysicalPlan) PhysicalPlan {
func postOptimize(sctx sessionctx.Context, plan PhysicalPlan) (PhysicalPlan, error) {
// some cases from update optimize will require avoiding projection elimination.
// see comments ahead of call of DoOptimize in function of buildUpdate().
err := prunePhysicalColumns(sctx, plan)
if err != nil {
return nil, err
}
plan = eliminatePhysicalProjection(plan)
plan = InjectExtraProjection(plan)
mergeContinuousSelections(plan)
Expand All @@ -383,7 +390,145 @@ func postOptimize(sctx sessionctx.Context, plan PhysicalPlan) PhysicalPlan {
handleFineGrainedShuffle(sctx, plan)
checkPlanCacheable(sctx, plan)
propagateProbeParents(plan, nil)
return plan
return plan, nil
}

// prunePhysicalColumns currently only works for MPP(HashJoin<-Exchange).
// It adds a projection instead of pruning columns directly, for safety
// considerations. The projection is cheap here: it saves network cost and
// works purely in memory.
//
// The plan tree is walked from the root. When a PhysicalTableReader whose
// tablePlan is a PhysicalExchangeSender is found, pruning is delegated to
// prunePhysicalColumnsInternal for that table plan; a table reader's own
// children are not descended into. For any other node every child is visited.
// The first error encountered is returned.
func prunePhysicalColumns(sctx sessionctx.Context, plan PhysicalPlan) error {
	if tableReader, ok := plan.(*PhysicalTableReader); ok {
		if _, isExchangeSender := tableReader.tablePlan.(*PhysicalExchangeSender); isExchangeSender {
			return prunePhysicalColumnsInternal(sctx, tableReader.tablePlan)
		}
		return nil
	}
	// Visit every child. Returning unconditionally from inside the loop would
	// stop after the first child and leave sibling subtrees unprocessed.
	for _, child := range plan.Children() {
		if err := prunePhysicalColumns(sctx, child); err != nil {
			return err
		}
	}
	return nil
}

// extractUsedCols gathers the columns referenced by parentUsedCols and by all
// of the join's conditions (equal, null-aware equal, left, right and other),
// then splits them by which child's schema contains them. The left child
// (children[0]) is checked first; a column found in neither child's schema is
// dropped. Duplicates are not removed.
func (p *PhysicalHashJoin) extractUsedCols(parentUsedCols []*expression.Column) (leftCols []*expression.Column, rightCols []*expression.Column) {
	// collect appends every column referenced by cond to the working list.
	collect := func(cond expression.Expression) {
		parentUsedCols = append(parentUsedCols, expression.ExtractColumns(cond)...)
	}
	for _, cond := range p.EqualConditions {
		collect(cond)
	}
	for _, cond := range p.NAEqualConditions {
		collect(cond)
	}
	for _, cond := range p.LeftConditions {
		collect(cond)
	}
	for _, cond := range p.RightConditions {
		collect(cond)
	}
	for _, cond := range p.OtherConditions {
		collect(cond)
	}

	left, right := p.children[0], p.children[1]
	for _, usedCol := range parentUsedCols {
		switch {
		case left.Schema().Contains(usedCol):
			leftCols = append(leftCols, usedCol)
		case right.Schema().Contains(usedCol):
			rightCols = append(rightCols, usedCol)
		}
	}
	return leftCols, rightCols
}

// prunePhysicalColumnForHashJoinChild inserts a PhysicalProjection below
// sender that keeps only the sender-schema columns which are either listed in
// joinUsedCols or used as one of the sender's hash (partition) columns.
//
// Nothing is changed when every column is still needed or when the sender has
// no child. After inserting the projection, indices are re-resolved bottom-up
// for the projection, the sender and the hash join, and the first error from
// that resolution is returned.
func prunePhysicalColumnForHashJoinChild(sctx sessionctx.Context, hashJoin *PhysicalHashJoin, joinUsedCols []*expression.Column, sender *PhysicalExchangeSender) error {
	joinUsed := expression.GetUsedList(joinUsedCols, sender.Schema())

	partitionCols := make([]*expression.Column, 0, len(sender.HashCols))
	for _, mppCol := range sender.HashCols {
		partitionCols = append(partitionCols, mppCol.Col)
	}
	hashUsed := expression.GetUsedList(partitionCols, sender.Schema())

	// Keep, in schema order, every column needed by the join or the exchange:
	// the projection expressions reference the sender's current columns, while
	// the projection schema is built from a clone of the sender's schema.
	senderCols := sender.Schema().Columns
	prunedSchema := sender.Schema().Clone()
	keptExprs := make([]expression.Expression, 0, len(senderCols))
	keptCols := make([]*expression.Column, 0, len(prunedSchema.Columns))
	needPrune := false
	for idx := range joinUsed {
		if !joinUsed[idx] && !hashUsed[idx] {
			needPrune = true
			continue
		}
		keptExprs = append(keptExprs, senderCols[idx])
		keptCols = append(keptCols, prunedSchema.Columns[idx])
	}
	prunedSchema.Columns = keptCols

	if !needPrune || len(sender.children) == 0 {
		return nil
	}

	child := sender.children[0]
	proj := PhysicalProjection{
		Exprs: keptExprs,
	}.Init(sctx, child.statsInfo(), child.SelectBlockOffset())
	proj.SetSchema(prunedSchema)
	proj.SetChildren(child)
	sender.children[0] = proj

	// Resolve indices from bottom to top.
	if err := proj.ResolveIndicesItself(); err != nil {
		return err
	}
	if err := sender.ResolveIndicesItself(); err != nil {
		return err
	}
	if err := hashJoin.ResolveIndicesItself(); err != nil {
		return err
	}
	return nil
}

// prunePhysicalColumnsInternal recursively walks plan looking for the pattern
// HashJoin <- ExchangeReceiver <- ExchangeSender. For each join child that
// matches, it calls prunePhysicalColumnForHashJoinChild with the columns that
// side of the join actually uses. Both join children are then traversed
// recursively; for any other operator every child is traversed.
//
// The first error encountered is returned to the caller.
func prunePhysicalColumnsInternal(sctx sessionctx.Context, plan PhysicalPlan) error {
	switch x := plan.(type) {
	case *PhysicalHashJoin:
		// Start from a clone of the join's output columns so that collecting
		// the used columns cannot alias the schema's own column slice.
		schemaColumns := x.Schema().Clone().Columns
		leftCols, rightCols := x.extractUsedCols(schemaColumns)
		for i := 0; i <= 1; i++ {
			// Pattern: HashJoin <- ExchangeReceiver <- ExchangeSender
			if receiver, ok := x.children[i].(*PhysicalExchangeReceiver); ok {
				if exchangeSender, ok := receiver.children[0].(*PhysicalExchangeSender); ok {
					usedCols := leftCols
					if i == 1 {
						usedCols = rightCols
					}
					// Propagate the error; returning nil here would hide it.
					if err := prunePhysicalColumnForHashJoinChild(sctx, x, usedCols, exchangeSender); err != nil {
						return err
					}
				}
			}

			// Recursively traverse the physical plan.
			if err := prunePhysicalColumnsInternal(sctx, x.children[i]); err != nil {
				return err
			}
		}
	default:
		for _, child := range x.Children() {
			if err := prunePhysicalColumnsInternal(sctx, child); err != nil {
				return err
			}
		}
	}
	return nil
}

// Only for MPP(Window<-[Sort]<-ExchangeReceiver<-ExchangeSender).
Expand Down
103 changes: 103 additions & 0 deletions planner/core/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"reflect"
"testing"

"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/planner/property"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -288,3 +290,104 @@ func TestHandleFineGrainedShuffle(t *testing.T) {
hashSender1.children = []PhysicalPlan{tableScan1}
start(partWindow, expStreamCount, 3, 0)
}

// TestPrunePhysicalColumns tests prunePhysicalColumns on a hand-built MPP plan.
// It checks that a projection keeping only the join/hash column is inserted
// below the left exchange sender, that indices are re-resolved, and that the
// right side, which has nothing to prune, is left untouched.
func TestPrunePhysicalColumns(t *testing.T) {
	sctx := MockContext()
	newCol := func() *expression.Column {
		return &expression.Column{
			UniqueID: sctx.GetSessionVars().AllocPlanColumnID(),
			RetType:  types.NewFieldType(mysql.TypeLonglong),
		}
	}
	col0 := newCol()
	col1 := newCol()
	col2 := newCol()
	col3 := newCol()

	// Join[col2, col3; col2==col3] <- ExchangeReceiver[col0, col1, col2] <- ExchangeSender[col0, col1, col2] <- Selection[col0, col1, col2; col0 < col1] <- TableScan[col0, col1, col2]
	//      <- ExchangeReceiver1[col3] <- ExchangeSender1[col3] <- TableScan1[col3]
	tableReader := &PhysicalTableReader{}
	passSender := &PhysicalExchangeSender{
		ExchangeType: tipb.ExchangeType_PassThrough,
	}
	hashJoin := &PhysicalHashJoin{}
	recv := &PhysicalExchangeReceiver{}
	recv1 := &PhysicalExchangeReceiver{}
	hashSender := &PhysicalExchangeSender{
		ExchangeType: tipb.ExchangeType_Hash,
	}
	hashSender1 := &PhysicalExchangeSender{
		ExchangeType: tipb.ExchangeType_Hash,
	}
	tableScan := &PhysicalTableScan{}
	tableScan1 := &PhysicalTableScan{}

	tableReader.tablePlan = passSender
	passSender.children = []PhysicalPlan{hashJoin}
	hashJoin.children = []PhysicalPlan{recv, recv1}
	selection := &PhysicalSelection{}

	cond, err := expression.NewFunction(sctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), col2, col3)
	require.NoError(t, err)
	sf, isSF := cond.(*expression.ScalarFunction)
	require.True(t, isSF)
	hashJoin.EqualConditions = append(hashJoin.EqualConditions, sf)
	hashJoin.LeftJoinKeys = append(hashJoin.LeftJoinKeys, col2)
	hashJoin.RightJoinKeys = append(hashJoin.RightJoinKeys, col3)
	hashJoinSchema := make([]*expression.Column, 0)
	hashJoinSchema = append(hashJoinSchema, col3)
	hashJoin.SetSchema(expression.NewSchema(hashJoinSchema...))

	selection.SetChildren(tableScan)
	hashSender.SetChildren(selection)
	var partitionCols = make([]*property.MPPPartitionColumn, 0, 1)
	partitionCols = append(partitionCols, &property.MPPPartitionColumn{
		Col:       col2,
		CollateID: property.GetCollateIDByNameForPartition(col2.GetType().GetCollate()),
	})
	hashSender.HashCols = partitionCols
	recv.SetChildren(hashSender)
	tableScan.Schema().Columns = append(tableScan.Schema().Columns, col0, col1, col2)

	hashSender1.SetChildren(tableScan1)
	recv1.SetChildren(hashSender1)
	tableScan1.Schema().Columns = append(tableScan1.Schema().Columns, col3)

	// The pruner reports failures through its error return; do not drop it.
	err = prunePhysicalColumns(sctx, tableReader)
	require.NoError(t, err)

	// Optimized Plan:
	// Join[col2, col3; col2==col3] <- ExchangeReceiver[col2] <- ExchangeSender[col2;col2] <- Projection[col2] <- Selection[col0, col1, col2; col0 < col1] <- TableScan[col0, col1, col2]
	//      <- ExchangeReceiver1[col3] <- ExchangeSender1[col3] <- TableScan1[col3]
	require.Len(t, recv.Schema().Columns, 1)
	require.True(t, recv.Schema().Contains(col2))
	require.False(t, recv.Schema().Contains(col0))
	require.False(t, recv.Schema().Contains(col1))
	require.Len(t, recv.children[0].Children(), 1)

	proj, isProj := recv.children[0].Children()[0].(*PhysicalProjection)
	require.True(t, isProj, "expected a PhysicalProjection below the exchange sender")
	// Assert on the projection's own schema, not the receiver's.
	require.True(t, proj.Schema().Contains(col2))
	require.False(t, proj.Schema().Contains(col0))
	require.False(t, proj.Schema().Contains(col1))
	// Check PhysicalProj resolved index
	require.Len(t, proj.Exprs, 1)
	projCol, isCol := proj.Exprs[0].(*expression.Column)
	require.True(t, isCol)
	require.Equal(t, 2, projCol.Index)

	// Check resolved indices
	require.Equal(t, 0, hashJoin.LeftJoinKeys[0].Index)
	require.Equal(t, 0, hashSender.HashCols[0].Col.Index)

	// Check recv1, no changes
	require.Len(t, recv1.Schema().Columns, 1)
	require.True(t, recv1.Schema().Contains(col3))
}

0 comments on commit e245b84

Please sign in to comment.