18
18
19
19
package org .apache .flink .table .planner .plan .rules .logical ;
20
20
21
+ import org .apache .flink .api .java .tuple .Tuple2 ;
21
22
import org .apache .flink .table .api .TableException ;
22
23
import org .apache .flink .table .planner .plan .nodes .logical .FlinkLogicalJoin ;
23
24
import org .apache .flink .table .planner .plan .nodes .physical .stream .StreamPhysicalMultiJoin ;
24
25
import org .apache .flink .table .planner .plan .utils .IntervalJoinUtil ;
25
26
26
27
import org .apache .calcite .plan .RelOptRuleCall ;
28
+ import org .apache .calcite .plan .RelOptTable ;
27
29
import org .apache .calcite .plan .RelOptUtil ;
28
30
import org .apache .calcite .plan .RelRule ;
29
31
import org .apache .calcite .plan .hep .HepRelVertex ;
35
37
import org .apache .calcite .rel .core .JoinRelType ;
36
38
import org .apache .calcite .rel .logical .LogicalJoin ;
37
39
import org .apache .calcite .rel .logical .LogicalSnapshot ;
40
+ import org .apache .calcite .rel .logical .LogicalTableScan ;
41
+ import org .apache .calcite .rel .metadata .RelColumnOrigin ;
42
+ import org .apache .calcite .rel .metadata .RelMetadataQuery ;
38
43
import org .apache .calcite .rel .rules .CoreRules ;
39
44
import org .apache .calcite .rel .rules .FilterMultiJoinMergeRule ;
40
45
import org .apache .calcite .rel .rules .MultiJoin ;
41
46
import org .apache .calcite .rel .rules .ProjectMultiJoinMergeRule ;
42
47
import org .apache .calcite .rel .rules .TransformationRule ;
43
- import org .apache .calcite .rel .type .RelDataTypeField ;
44
48
import org .apache .calcite .rex .RexBuilder ;
49
+ import org .apache .calcite .rex .RexCall ;
45
50
import org .apache .calcite .rex .RexInputRef ;
46
51
import org .apache .calcite .rex .RexNode ;
47
52
import org .apache .calcite .rex .RexUtil ;
52
57
import org .apache .calcite .util .Pair ;
53
58
import org .immutables .value .Value ;
54
59
55
- import javax .annotation .Nullable ;
56
-
57
60
import java .util .ArrayList ;
58
61
import java .util .Collections ;
59
62
import java .util .HashMap ;
63
+ import java .util .HashSet ;
60
64
import java .util .List ;
61
65
import java .util .Map ;
66
+ import java .util .Set ;
67
+ import java .util .stream .Collectors ;
62
68
63
69
/**
64
70
* Flink Planner rule to flatten a tree of {@link Join}s into a single {@link MultiJoin} with N
@@ -300,14 +306,7 @@ private List<RelNode> combineInputs(
300
306
ImmutableIntList leftKeys = joinInfo .leftKeys ;
301
307
ImmutableIntList rightKeys = joinInfo .rightKeys ;
302
308
303
- if (canCombine (
304
- left ,
305
- leftKeys ,
306
- join .getJoinType (),
307
- join .getJoinType ().generatesNullsOnLeft (),
308
- true ,
309
- inputNullGenFieldList ,
310
- 0 )) {
309
+ if (canCombine (left , join )) {
311
310
final MultiJoin leftMultiJoin = (MultiJoin ) left ;
312
311
for (int i = 0 ; i < leftMultiJoin .getInputs ().size (); i ++) {
313
312
newInputs .add (leftMultiJoin .getInput (i ));
@@ -322,14 +321,7 @@ private List<RelNode> combineInputs(
322
321
joinFieldRefCountsList .add (new int [left .getRowType ().getFieldCount ()]);
323
322
}
324
323
325
- if (canCombine (
326
- right ,
327
- rightKeys ,
328
- join .getJoinType (),
329
- join .getJoinType ().generatesNullsOnRight (),
330
- false ,
331
- inputNullGenFieldList ,
332
- left .getRowType ().getFieldCount ())) {
324
+ if (canCombine (right , join )) {
333
325
final MultiJoin rightMultiJoin = (MultiJoin ) right ;
334
326
for (int i = 0 ; i < rightMultiJoin .getInputs ().size (); i ++) {
335
327
newInputs .add (rightMultiJoin .getInput (i ));
@@ -369,27 +361,19 @@ private void combineJoinInfo(
369
361
JoinInfo joinInfo = joinRel .analyzeCondition ();
370
362
ImmutableIntList leftKeys = joinInfo .leftKeys ;
371
363
final RexBuilder rexBuilder = joinRel .getCluster ().getRexBuilder ();
372
- boolean leftCombined =
373
- canCombine (
374
- left ,
375
- leftKeys ,
376
- joinType ,
377
- joinType .generatesNullsOnLeft (),
378
- true ,
379
- inputNullGenFieldList ,
380
- 0 );
364
+ boolean leftCombined = canCombine (left , joinRel );
381
365
switch (joinType ) {
382
366
case LEFT :
383
367
if (leftCombined ) {
384
- copyJoinInfo ((MultiJoin ) left , joinSpecs , 0 , null , null );
368
+ copyJoinInfo ((MultiJoin ) left , joinSpecs );
385
369
} else {
386
370
joinSpecs .add (Pair .of (JoinRelType .INNER , rexBuilder .makeLiteral (true )));
387
371
}
388
372
joinSpecs .add (Pair .of (joinType , joinRel .getCondition ()));
389
373
break ;
390
374
case INNER :
391
375
if (leftCombined ) {
392
- copyJoinInfo ((MultiJoin ) left , joinSpecs , 0 , null , null );
376
+ copyJoinInfo ((MultiJoin ) left , joinSpecs );
393
377
} else {
394
378
joinSpecs .add (Pair .of (JoinRelType .INNER , rexBuilder .makeLiteral (true )));
395
379
}
@@ -408,45 +392,13 @@ private void combineJoinInfo(
408
392
*
409
393
* @param multiJoin the source MultiJoin
410
394
* @param destJoinSpecs the list where the join types and conditions will be copied
411
- * @param adjustmentAmount if > 0, the amount the RexInputRefs in the join conditions need to
412
- * be adjusted by
413
- * @param srcFields the source fields that the original join conditions are referencing
414
- * @param destFields the destination fields that the new join conditions
415
395
*/
416
- private void copyJoinInfo (
417
- MultiJoin multiJoin ,
418
- List <Pair <JoinRelType , RexNode >> destJoinSpecs ,
419
- int adjustmentAmount ,
420
- @ Nullable List <RelDataTypeField > srcFields ,
421
- @ Nullable List <RelDataTypeField > destFields ) {
396
+ private void copyJoinInfo (MultiJoin multiJoin , List <Pair <JoinRelType , RexNode >> destJoinSpecs ) {
422
397
// getOuterJoinConditions returns all join conditions, since that's how we use it
423
398
final List <Pair <JoinRelType , RexNode >> srcJoinSpecs =
424
399
Pair .zip (multiJoin .getJoinTypes (), multiJoin .getOuterJoinConditions ());
425
400
426
- if (adjustmentAmount == 0 ) {
427
- destJoinSpecs .addAll (srcJoinSpecs );
428
- } else {
429
- assert srcFields != null ;
430
- assert destFields != null ;
431
- int nFields = srcFields .size ();
432
- int [] adjustments = new int [nFields ];
433
- for (int idx = 0 ; idx < nFields ; idx ++) {
434
- adjustments [idx ] = adjustmentAmount ;
435
- }
436
- for (Pair <JoinRelType , RexNode > src : srcJoinSpecs ) {
437
- destJoinSpecs .add (
438
- Pair .of (
439
- src .left ,
440
- src .right == null
441
- ? null
442
- : src .right .accept (
443
- new RelOptUtil .RexInputConverter (
444
- multiJoin .getCluster ().getRexBuilder (),
445
- srcFields ,
446
- destFields ,
447
- adjustments ))));
448
- }
449
- }
401
+ destJoinSpecs .addAll (srcJoinSpecs );
450
402
}
451
403
452
404
/**
@@ -474,14 +426,7 @@ private List<RexNode> combineJoinFilters(
474
426
if ((joinType != JoinRelType .LEFT )) {
475
427
filters .add (join .getCondition ());
476
428
}
477
- if (canCombine (
478
- left ,
479
- leftKeys ,
480
- joinType ,
481
- joinType .generatesNullsOnLeft (),
482
- true ,
483
- inputNullGenFieldList ,
484
- 0 )) {
429
+ if (canCombine (left , join )) {
485
430
filters .add (((MultiJoin ) left ).getJoinFilter ());
486
431
}
487
432
@@ -497,55 +442,148 @@ private List<RexNode> combineJoinFilters(
497
442
* href="https://issues.apache.org/jira/browse/FLINK-37890">FLINK-37890</a>.
498
443
*
499
444
* @param input input into a join
500
- * @param nullGenerating true if the input is null generating
501
445
* @return true if the input can be combined into a parent MultiJoin
502
446
*/
503
- private boolean canCombine (
504
- RelNode input ,
505
- ImmutableIntList joinKeys ,
506
- JoinRelType joinType ,
507
- boolean nullGenerating ,
508
- boolean isLeft ,
509
- List <Boolean > inputNullGenFieldList ,
510
- int beginIndex ) {
447
+ private boolean canCombine (RelNode input , Join origJoin ) {
511
448
if (input instanceof MultiJoin ) {
512
449
MultiJoin join = (MultiJoin ) input ;
513
- if (join .isFullOuterJoin () || nullGenerating ) {
450
+
451
+ if (join .isFullOuterJoin ()) {
514
452
return false ;
515
453
}
516
454
517
- if (joinType == JoinRelType .LEFT ) {
518
- if (!isLeft ) {
519
- return false ;
520
- } else {
521
- for (int joinKey : joinKeys ) {
522
- if (inputNullGenFieldList .get (joinKey + beginIndex )) {
523
- return false ;
524
- }
525
- }
526
- }
527
- } else if (joinType == JoinRelType .RIGHT ) {
528
- if (isLeft ) {
529
- return false ;
530
- } else {
531
- for (int joinKey : joinKeys ) {
532
- if (inputNullGenFieldList .get (joinKey + beginIndex )) {
533
- return false ;
534
- }
535
- }
536
- }
537
- } else if (joinType == JoinRelType .INNER ) {
538
- for (int joinKey : joinKeys ) {
539
- if (inputNullGenFieldList .get (joinKey + beginIndex )) {
540
- return false ;
541
- }
455
+ return haveCommonJoinKey (origJoin , join );
456
+ } else {
457
+ return false ;
458
+ }
459
+ }
460
+
461
+ /**
462
+ * Checks if original join and child multi-join have common join keys to decide if we can merge
463
+ * them into a single MultiJoin with one more input.
464
+ *
465
+ * @param origJoin original Join
466
+ * @param otherJoin child MultiJoin
467
+ * @return true if original Join and child multi-join have at least one common JoinKey
468
+ */
469
+ private boolean haveCommonJoinKey (Join origJoin , MultiJoin otherJoin ) {
470
+ Set <String > origJoinKeys = getJoinKeys (origJoin );
471
+ Set <String > otherJoinKeys = getJoinKeys (otherJoin );
472
+
473
+ origJoinKeys .retainAll (otherJoinKeys );
474
+
475
+ return !origJoinKeys .isEmpty ();
476
+ }
477
+
478
+ /**
479
+ * Returns a set of join keys as strings following this format [table_name.field_name].
480
+ *
481
+ * @param join Join or MultiJoin node
482
+ * @return set of all the join keys (keys from join conditions)
483
+ */
484
+ public Set <String > getJoinKeys (RelNode join ) {
485
+ Set <String > joinKeys = new HashSet <>();
486
+ List <RexCall > conditions = Collections .emptyList ();
487
+ List <RelNode > inputs = join .getInputs ();
488
+
489
+ if (join instanceof Join ) {
490
+ conditions = collectConjunctions (((Join ) join ).getCondition ());
491
+ } else if (join instanceof MultiJoin ) {
492
+ conditions =
493
+ ((MultiJoin ) join )
494
+ .getOuterJoinConditions ().stream ()
495
+ .flatMap (cond -> collectConjunctions (cond ).stream ())
496
+ .collect (Collectors .toList ());
497
+ }
498
+
499
+ RelMetadataQuery mq = join .getCluster ().getMetadataQuery ();
500
+
501
+ for (RexCall condition : conditions ) {
502
+ for (RexNode operand : condition .getOperands ()) {
503
+ if (operand instanceof RexInputRef ) {
504
+ addJoinKeysByOperand ((RexInputRef ) operand , inputs , mq , joinKeys );
542
505
}
543
- } else {
544
- return false ;
545
506
}
546
- return true ;
507
+ }
508
+
509
+ return joinKeys ;
510
+ }
511
+
512
+ /**
513
+ * Retrieves conjunctions from joinCondition.
514
+ *
515
+ * @param joinCondition join condition
516
+ * @return List of RexCalls representing conditions
517
+ */
518
+ private List <RexCall > collectConjunctions (RexNode joinCondition ) {
519
+ return RelOptUtil .conjunctions (joinCondition ).stream ()
520
+ .map (rexNode -> (RexCall ) rexNode )
521
+ .collect (Collectors .toList ());
522
+ }
523
+
524
+ /**
525
+ * Appends join key's string representation to the set of join keys.
526
+ *
527
+ * @param ref input ref to the operand
528
+ * @param inputs List of node's inputs
529
+ * @param mq RelMetadataQuery needed to retrieve column origins
530
+ * @param joinKeys set to which the resolved join keys are added
531
+ */
532
+ private void addJoinKeysByOperand (
533
+ RexInputRef ref , List <RelNode > inputs , RelMetadataQuery mq , Set <String > joinKeys ) {
534
+ int inputRefIndex = ref .getIndex ();
535
+ Tuple2 <RelNode , Integer > targetInputAndIdx = getTargetInputAndIdx (inputRefIndex , inputs );
536
+ RelNode targetInput = targetInputAndIdx .f0 ;
537
+ int idxInTargetInput = targetInputAndIdx .f1 ;
538
+
539
+ Set <RelColumnOrigin > origins = mq .getColumnOrigins (targetInput , idxInTargetInput );
540
+ if (origins != null ) {
541
+ for (RelColumnOrigin origin : origins ) {
542
+ RelOptTable originTable = origin .getOriginTable ();
543
+ List <String > qualifiedName = originTable .getQualifiedName ();
544
+ String fieldName =
545
+ originTable
546
+ .getRowType ()
547
+ .getFieldList ()
548
+ .get (origin .getOriginColumnOrdinal ())
549
+ .getName ();
550
+ joinKeys .add (qualifiedName .get (qualifiedName .size () - 1 ) + "." + fieldName );
551
+ }
552
+ }
553
+ }
554
+
555
+ /**
556
+ * Finds the table scan that contains the referenced input field (join key).
557
+ *
558
+ * @param inputRefIndex index of the required field
559
+ * @param inputs inputs of the node
560
+ * @return the target input together with the index of the required field within that input
561
+ */
562
+ private Tuple2 <RelNode , Integer > getTargetInputAndIdx (int inputRefIndex , List <RelNode > inputs ) {
563
+ RelNode targetInput = null ;
564
+ int idxInTargetInput = 0 ;
565
+ int inputFieldEnd = 0 ;
566
+ for (RelNode input : inputs ) {
567
+ inputFieldEnd += input .getRowType ().getFieldCount ();
568
+ if (inputRefIndex < inputFieldEnd ) {
569
+ targetInput = input ;
570
+ int targetInputStartIdx = inputFieldEnd - input .getRowType ().getFieldCount ();
571
+ idxInTargetInput = inputRefIndex - targetInputStartIdx ;
572
+ break ;
573
+ }
574
+ }
575
+
576
+ targetInput =
577
+ (targetInput instanceof HepRelVertex )
578
+ ? ((HepRelVertex ) targetInput ).getCurrentRel ()
579
+ : targetInput ;
580
+
581
+ assert targetInput != null ;
582
+
583
+ if (targetInput instanceof LogicalTableScan ) {
584
+ return new Tuple2 <>(targetInput , idxInTargetInput );
547
585
} else {
548
- return false ;
586
+ return getTargetInputAndIdx ( idxInTargetInput , targetInput . getInputs ()) ;
549
587
}
550
588
}
551
589
0 commit comments