Skip to content

Commit

Permalink
Pushing comparison containing argument on join in a different way to …
Browse files Browse the repository at this point in the history
…ensure dependencies

Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>
  • Loading branch information
frouioui committed Sep 15, 2021
1 parent 10e4478 commit d1bff26
Showing 1 changed file with 138 additions and 58 deletions.
196 changes: 138 additions & 58 deletions go/vt/vtgate/planbuilder/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,82 +405,162 @@ func stripDownQuery(from, to sqlparser.SelectStatement) error {
}

func pushJoinPredicate(ctx planningContext, exprs []sqlparser.Expr, tree queryTree) (queryTree, error) {
if len(exprs) == 0 {
return tree, nil
}
switch node := tree.(type) {
case *routeTree:
plan := node.clone().(*routeTree)
err := plan.addPredicate(ctx, exprs...)
return pushJoinPredicateOnRoute(ctx, exprs, node)
case *joinTree:
return pushJoinPredicateOnJoin(ctx, exprs, node)
case *derivedTree:
return pushJoinPredicateOnDerived(ctx, exprs, node)
case *vindexTree:
// vindexFunc cannot accept predicates from the other side of a join
return node, nil
default:
panic(fmt.Sprintf("BUG: unknown type %T", node))
}
}

func pushJoinPredicateOnRoute(ctx planningContext, exprs []sqlparser.Expr, node *routeTree) (queryTree, error) {
plan := node.clone().(*routeTree)
err := plan.addPredicate(ctx, exprs...)
if err != nil {
return nil, err
}
return plan, nil
}

func pushJoinPredicateOnDerived(ctx planningContext, exprs []sqlparser.Expr, node *derivedTree) (queryTree, error) {
plan := node.clone().(*derivedTree)

newExpressions := make([]sqlparser.Expr, 0, len(exprs))
for _, expr := range exprs {
tblInfo, err := ctx.semTable.TableInfoForExpr(expr)
if err != nil {
return nil, err
}
return plan, nil

case *joinTree:
node = node.clone().(*joinTree)

// we break up the predicates so that colnames from the LHS are replaced by arguments
var rhsPreds []sqlparser.Expr
var lhsColumns []*sqlparser.ColName
var lhsVarsName []string
lhsSolves := node.lhs.tableID()
for _, expr := range exprs {
bvName, cols, predicate, err := breakPredicateInLHSandRHS(expr, ctx.semTable, lhsSolves)
if err != nil {
return nil, err
}
lhsColumns = append(lhsColumns, cols...)
lhsVarsName = append(lhsVarsName, bvName...)
rhsPreds = append(rhsPreds, predicate)
}
if lhsColumns != nil && lhsVarsName != nil {
idxs, err := node.pushOutputColumns(lhsColumns, ctx.semTable)
if err != nil {
return nil, err
}
for i, idx := range idxs {
node.vars[lhsVarsName[i]] = idx
}
rewritten, err := semantics.RewriteDerivedExpression(expr, tblInfo)
if err != nil {
return nil, err
}
newExpressions = append(newExpressions, rewritten)
}

rhsPlan, err := pushJoinPredicate(ctx, rhsPreds, node.rhs)
newInner, err := pushJoinPredicate(ctx, newExpressions, plan.inner)
if err != nil {
return nil, err
}

plan.inner = newInner
return plan, nil
}

func pushJoinPredicateOnJoin(ctx planningContext, exprs []sqlparser.Expr, node *joinTree) (queryTree, error) {
node = node.clone().(*joinTree)

var rhsPreds []sqlparser.Expr
var lhsColumns []*sqlparser.ColName
var lhsVarsName []string

for _, expr := range exprs {
// we are pushing argument expression in a different way, if one side
// of the comparison is an argument (*sqlparser.Argument) coming from
// then we can push the expression to either left or right hand side
// (depending on who solves the expression). such expression do not
// need to be "outputed" and sent to the RHS of the join as we would
// usually do.
newNode, err := pushArgumentsOnJoin(ctx, expr, node)
if err != nil {
return nil, err
}
if newNode != nil {
// we are getting a new node from pushArgumentsOnJoin, meaning
// we do not need to break the predicate between LHS and RHS, thus we
// continue onto the following expression.
node = newNode
continue
}

return &joinTree{
lhs: node.lhs,
rhs: rhsPlan,
outer: node.outer,
vars: node.vars,
}, nil
case *derivedTree:
plan := node.clone().(*derivedTree)

newExpressions := make([]sqlparser.Expr, 0, len(exprs))
for _, expr := range exprs {
tblInfo, err := ctx.semTable.TableInfoForExpr(expr)
if err != nil {
return nil, err
}
rewritten, err := semantics.RewriteDerivedExpression(expr, tblInfo)
if err != nil {
return nil, err
}
newExpressions = append(newExpressions, rewritten)
bvName, cols, predicate, err := breakPredicateInLHSandRHS(expr, ctx.semTable, node.lhs.tableID())
if err != nil {
return nil, err
}
lhsColumns = append(lhsColumns, cols...)
lhsVarsName = append(lhsVarsName, bvName...)
rhsPreds = append(rhsPreds, predicate)
}

newInner, err := pushJoinPredicate(ctx, newExpressions, plan.inner)
if lhsColumns != nil && lhsVarsName != nil {
idxs, err := node.pushOutputColumns(lhsColumns, ctx.semTable)
if err != nil {
return nil, err
}
for i, idx := range idxs {
node.vars[lhsVarsName[i]] = idx
}
}

plan.inner = newInner
return plan, nil
case *vindexTree:
// vindexFunc cannot accept predicates from the other side of a join
return node, nil
default:
panic(fmt.Sprintf("BUG: unknown type %T", node))
rhsPlan, err := pushJoinPredicate(ctx, rhsPreds, node.rhs)
if err != nil {
return nil, err
}
return &joinTree{
lhs: node.lhs,
rhs: rhsPlan,
outer: node.outer,
vars: node.vars,
}, nil
}

func pushArgumentsOnJoin(ctx planningContext, expr sqlparser.Expr, node *joinTree) (*joinTree, error) {
cmp, isCmp := expr.(*sqlparser.ComparisonExpr)
if !isCmp {
return nil, nil
}

solvedByLeft, solvedByRight := isComparisonExprSolvedByJoinTree(ctx, cmp, node)
if !solvedByLeft && !solvedByRight {
return nil, nil
}

var nodeToReplace queryTree
if solvedByLeft {
nodeToReplace = node.lhs
} else if solvedByRight {
nodeToReplace = node.rhs
}

newNode, err := pushJoinPredicate(ctx, []sqlparser.Expr{expr}, nodeToReplace)
if err != nil {
return nil, err
}

if solvedByLeft {
node.lhs = newNode
} else if solvedByRight {
node.rhs = newNode
}
return node, nil
}

func isComparisonExprSolvedByJoinTree(ctx planningContext, cmp *sqlparser.ComparisonExpr, node *joinTree) (bool, bool) {
var argExpr sqlparser.Expr
_, isLeftArg := cmp.Left.(sqlparser.Argument)
_, isRightArg := cmp.Right.(sqlparser.Argument)
if isLeftArg {
argExpr = cmp.Right
} else if isRightArg {
argExpr = cmp.Left
} else {
return false, false
}

argDeps := ctx.semTable.RecursiveDeps(argExpr)
solvedByLeft := argDeps.IsSolvedBy(node.lhs.tableID())
solvedByRight := argDeps.IsSolvedBy(node.rhs.tableID())
return solvedByLeft, solvedByRight
}

func breakPredicateInLHSandRHS(
Expand Down

0 comments on commit d1bff26

Please sign in to comment.