Skip to content

Commit

Permalink
executor,planner: Relax projection column expression push down check …
Browse files Browse the repository at this point in the history
…conditions (#52502)

close #52501
  • Loading branch information
yibin87 committed Apr 15, 2024
1 parent bb84d1f commit 3d82fc5
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 268 deletions.
43 changes: 32 additions & 11 deletions pkg/expression/expr_to_pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,25 @@ func ExpressionsToPBList(ctx EvalContext, exprs []Expression, client kv.Client)
return
}

// ProjectionExpressionsToPBList converts PhysicalProjection's expressions to tipb.Expr list for new plan.
// It doesn't check type for top level column expression, since top level column expression doesn't imply any calculations
func ProjectionExpressionsToPBList(ctx EvalContext, exprs []Expression, client kv.Client) (pbExpr []*tipb.Expr, err error) {
pc := PbConverter{client: client, ctx: ctx}
for _, expr := range exprs {
var v *tipb.Expr
if column, ok := expr.(*Column); ok {
v = pc.columnToPBExpr(column, false)
} else {
v = pc.ExprToPB(expr)
}
if v == nil {
return nil, plannererrors.ErrInternal.GenWithStack("expression %v cannot be pushed down", expr)
}
pbExpr = append(pbExpr, v)
}
return
}

// PbConverter supplies methods to convert TiDB expressions to TiPB.
type PbConverter struct {
client kv.Client
Expand All @@ -69,7 +88,7 @@ func (pc PbConverter) ExprToPB(expr Expression) *tipb.Expr {
case *CorrelatedColumn:
return pc.conOrCorColToPBExpr(expr)
case *Column:
return pc.columnToPBExpr(x)
return pc.columnToPBExpr(x, true)
case *ScalarFunction:
return pc.scalarFuncToPBExpr(x)
}
Expand Down Expand Up @@ -190,20 +209,22 @@ func FieldTypeFromPB(ft *tipb.FieldType) *types.FieldType {
return ft1
}

func (pc PbConverter) columnToPBExpr(column *Column) *tipb.Expr {
func (pc PbConverter) columnToPBExpr(column *Column, checkType bool) *tipb.Expr {
if !pc.client.IsRequestTypeSupported(kv.ReqTypeSelect, int64(tipb.ExprType_ColumnRef)) {
return nil
}
switch column.GetType().GetType() {
case mysql.TypeBit:
if !IsPushDownEnabled(ast.TypeStr(column.GetType().GetType()), kv.TiKV) {
return nil
}
case mysql.TypeSet, mysql.TypeGeometry, mysql.TypeUnspecified:
return nil
case mysql.TypeEnum:
if !IsPushDownEnabled("enum", kv.UnSpecified) {
if checkType {
switch column.GetType().GetType() {
case mysql.TypeBit:
if !IsPushDownEnabled(ast.TypeStr(mysql.TypeBit), kv.TiKV) {
return nil
}
case mysql.TypeSet, mysql.TypeGeometry, mysql.TypeUnspecified:
return nil
case mysql.TypeEnum:
if !IsPushDownEnabled("enum", kv.UnSpecified) {
return nil
}
}
}

Expand Down
17 changes: 17 additions & 0 deletions pkg/expression/expr_to_pb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1913,3 +1913,20 @@ func TestPanicIfPbCodeUnspecified(t *testing.T) {
pc := PbConverter{client: new(mock.Client), ctx: mock.NewContext()}
require.PanicsWithError(t, "unspecified PbCode: *expression.builtinBitAndSig", func() { pc.ExprToPB(fn) })
}

func TestProjectionColumn2Pb(t *testing.T) {
var colExprs []Expression
ctx := mock.NewContext()
client := new(mock.Client)

colExprs = append(colExprs, genColumn(mysql.TypeSet, 1))
colExprs = append(colExprs, genColumn(mysql.TypeShort, 2))
colExprs = append(colExprs, genColumn(mysql.TypeLong, 3))

// TypeSet column can't be converted to PB by default
_, err := ExpressionsToPBList(ctx, colExprs, client)
require.Error(t, err)

_, err = ProjectionExpressionsToPBList(ctx, colExprs, client)
require.NoError(t, err)
}
4 changes: 2 additions & 2 deletions pkg/expression/infer_pushdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ func canExprPushDown(ctx PushDownContext, expr Expression, storeType kv.StoreTyp
}
switch x := expr.(type) {
case *CorrelatedColumn:
return pc.conOrCorColToPBExpr(expr) != nil && pc.columnToPBExpr(&x.Column) != nil
return pc.conOrCorColToPBExpr(expr) != nil && pc.columnToPBExpr(&x.Column, true) != nil
case *Constant:
return pc.conOrCorColToPBExpr(expr) != nil
case *Column:
return pc.columnToPBExpr(x) != nil
return pc.columnToPBExpr(x, true) != nil
case *ScalarFunction:
return canScalarFuncPushDown(ctx, x, storeType)
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/planner/core/casetest/mpp/mpp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ func TestMppJoinExchangeColumnPrune(t *testing.T) {
}

tk.MustExec("set @@tidb_allow_mpp=1;")
tk.MustExec("set @@tidb_enforce_mpp=1;")
tk.MustExec("set @@session.tidb_broadcast_join_threshold_size = 1")
tk.MustExec("set @@session.tidb_broadcast_join_threshold_count = 1")

Expand Down Expand Up @@ -707,6 +708,7 @@ func TestMppFineGrainedJoinAndAgg(t *testing.T) {
}

tk.MustExec("set @@tidb_allow_mpp=1;")
tk.MustExec("set @@tidb_enforce_mpp=1;")
tk.MustExec("set @@session.tidb_broadcast_join_threshold_size = 1")
tk.MustExec("set @@session.tidb_broadcast_join_threshold_count = 1")

Expand Down
152 changes: 3 additions & 149 deletions pkg/planner/core/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,7 @@ func doOptimize(
if err != nil {
return nil, nil, 0, err
}
finalPlan, err := postOptimize(ctx, sctx, physical)
if err != nil {
return nil, nil, 0, err
}
finalPlan := postOptimize(ctx, sctx, physical)

if sessVars.StmtCtx.EnableOptimizerCETrace {
refineCETrace(sctx)
Expand Down Expand Up @@ -412,13 +409,9 @@ func mergeContinuousSelections(p base.PhysicalPlan) {
}
}

func postOptimize(ctx context.Context, sctx base.PlanContext, plan base.PhysicalPlan) (base.PhysicalPlan, error) {
func postOptimize(ctx context.Context, sctx base.PlanContext, plan base.PhysicalPlan) base.PhysicalPlan {
// some cases from update optimize will require avoiding projection elimination.
// see comments ahead of call of DoOptimize in function of buildUpdate().
err := prunePhysicalColumns(sctx, plan)
if err != nil {
return nil, err
}
plan = eliminatePhysicalProjection(plan)
plan = InjectExtraProjection(plan)
mergeContinuousSelections(plan)
Expand All @@ -430,7 +423,7 @@ func postOptimize(ctx context.Context, sctx base.PlanContext, plan base.Physical
disableReuseChunkIfNeeded(sctx, plan)
tryEnableLateMaterialization(sctx, plan)
generateRuntimeFilter(sctx, plan)
return plan, nil
return plan
}

func generateRuntimeFilter(sctx base.PlanContext, plan base.PhysicalPlan) {
Expand All @@ -449,145 +442,6 @@ func generateRuntimeFilter(sctx base.PlanContext, plan base.PhysicalPlan) {
zap.Duration("Cost", time.Since(startRFGenerator)))
}

// prunePhysicalColumns currently only work for MPP(HashJoin<-Exchange).
// Here add projection instead of pruning columns directly for safety considerations.
// And projection is cheap here for it saves the network cost and work in memory.
func prunePhysicalColumns(sctx base.PlanContext, plan base.PhysicalPlan) error {
if tableReader, ok := plan.(*PhysicalTableReader); ok {
if _, isExchangeSender := tableReader.tablePlan.(*PhysicalExchangeSender); isExchangeSender {
err := prunePhysicalColumnsInternal(sctx, tableReader.tablePlan)
if err != nil {
return err
}
}
} else {
for _, child := range plan.Children() {
return prunePhysicalColumns(sctx, child)
}
}
return nil
}

func (p *PhysicalHashJoin) extractUsedCols(parentUsedCols []*expression.Column) (leftCols []*expression.Column, rightCols []*expression.Column) {
for _, eqCond := range p.EqualConditions {
parentUsedCols = append(parentUsedCols, expression.ExtractColumns(eqCond)...)
}
for _, neCond := range p.NAEqualConditions {
parentUsedCols = append(parentUsedCols, expression.ExtractColumns(neCond)...)
}
for _, leftCond := range p.LeftConditions {
parentUsedCols = append(parentUsedCols, expression.ExtractColumns(leftCond)...)
}
for _, rightCond := range p.RightConditions {
parentUsedCols = append(parentUsedCols, expression.ExtractColumns(rightCond)...)
}
for _, otherCond := range p.OtherConditions {
parentUsedCols = append(parentUsedCols, expression.ExtractColumns(otherCond)...)
}
lChild := p.children[0]
rChild := p.children[1]
for _, col := range parentUsedCols {
if lChild.Schema().Contains(col) {
leftCols = append(leftCols, col)
} else if rChild.Schema().Contains(col) {
rightCols = append(rightCols, col)
}
}
return leftCols, rightCols
}

func prunePhysicalColumnForHashJoinChild(sctx base.PlanContext, hashJoin *PhysicalHashJoin, joinUsedCols []*expression.Column, sender *PhysicalExchangeSender) error {
var err error
evalCtx := sctx.GetExprCtx().GetEvalCtx()
joinUsed := expression.GetUsedList(evalCtx, joinUsedCols, sender.Schema())
hashCols := make([]*expression.Column, len(sender.HashCols))
for i, mppCol := range sender.HashCols {
hashCols[i] = mppCol.Col
}
hashUsed := expression.GetUsedList(evalCtx, hashCols, sender.Schema())

needPrune := false
usedExprs := make([]expression.Expression, len(sender.Schema().Columns))
prunedSchema := sender.Schema().Clone()
for i := len(joinUsed) - 1; i >= 0; i-- {
usedExprs[i] = sender.Schema().Columns[i]
if !joinUsed[i] && !hashUsed[i] {
needPrune = true
usedExprs = append(usedExprs[:i], usedExprs[i+1:]...)
prunedSchema.Columns = append(prunedSchema.Columns[:i], prunedSchema.Columns[i+1:]...)
}
}

if needPrune && len(sender.children) > 0 {
ch := sender.children[0]
proj := PhysicalProjection{
Exprs: usedExprs,
}.Init(sctx, ch.StatsInfo(), ch.QueryBlockOffset())

proj.SetSchema(prunedSchema)
proj.SetChildren(ch)
sender.children[0] = proj

// Resolve Indices from bottom to up
err = proj.ResolveIndicesItself()
if err != nil {
return err
}
err = sender.ResolveIndicesItself()
if err != nil {
return err
}
err = hashJoin.ResolveIndicesItself()
if err != nil {
return err
}
}
return err
}

func prunePhysicalColumnsInternal(sctx base.PlanContext, plan base.PhysicalPlan) error {
var err error
switch x := plan.(type) {
case *PhysicalHashJoin:
schemaColumns := x.Schema().Clone().Columns
leftCols, rightCols := x.extractUsedCols(schemaColumns)
matchPattern := false
for i := 0; i <= 1; i++ {
// Pattern: HashJoin <- ExchangeReceiver <- ExchangeSender
matchPattern = false
var exchangeSender *PhysicalExchangeSender
if receiver, ok := x.children[i].(*PhysicalExchangeReceiver); ok {
exchangeSender, matchPattern = receiver.children[0].(*PhysicalExchangeSender)
}

if matchPattern {
if i == 0 {
err = prunePhysicalColumnForHashJoinChild(sctx, x, leftCols, exchangeSender)
} else {
err = prunePhysicalColumnForHashJoinChild(sctx, x, rightCols, exchangeSender)
}
if err != nil {
return nil
}
}

/// recursively travel the physical plan
err = prunePhysicalColumnsInternal(sctx, x.children[i])
if err != nil {
return nil
}
}
default:
for _, child := range x.Children() {
err = prunePhysicalColumnsInternal(sctx, child)
if err != nil {
return err
}
}
}
return nil
}

// tryEnableLateMaterialization tries to push down some filter conditions to the table scan operator
// @brief: push down some filter conditions to the table scan operator
// @param: sctx: session context
Expand Down

0 comments on commit 3d82fc5

Please sign in to comment.