Skip to content

Commit 615e19b

Browse files
[FLINK-37890] [table-planner] Add commonJoinKey check inside JoinToMultiJoinRule (#26702)
* [FLINK-37890][table-planner] Add restore and semantic tests [FLINK-37890][table-planner] canCombine and tests refactor [FLINK-37890][table-planner] Add commonJoinKey check inside JoinToMultiJoinRule * [FLINK-37890][table-planner] Refactor JoinToMultiJoinRule
1 parent 3aee6ac commit 615e19b

File tree

8 files changed

+928
-112
lines changed

8 files changed

+928
-112
lines changed

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/JoinToMultiJoinRule.java

Lines changed: 149 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818

1919
package org.apache.flink.table.planner.plan.rules.logical;
2020

21+
import org.apache.flink.api.java.tuple.Tuple2;
2122
import org.apache.flink.table.api.TableException;
2223
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalJoin;
2324
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalMultiJoin;
2425
import org.apache.flink.table.planner.plan.utils.IntervalJoinUtil;
2526

2627
import org.apache.calcite.plan.RelOptRuleCall;
28+
import org.apache.calcite.plan.RelOptTable;
2729
import org.apache.calcite.plan.RelOptUtil;
2830
import org.apache.calcite.plan.RelRule;
2931
import org.apache.calcite.plan.hep.HepRelVertex;
@@ -35,13 +37,16 @@
3537
import org.apache.calcite.rel.core.JoinRelType;
3638
import org.apache.calcite.rel.logical.LogicalJoin;
3739
import org.apache.calcite.rel.logical.LogicalSnapshot;
40+
import org.apache.calcite.rel.logical.LogicalTableScan;
41+
import org.apache.calcite.rel.metadata.RelColumnOrigin;
42+
import org.apache.calcite.rel.metadata.RelMetadataQuery;
3843
import org.apache.calcite.rel.rules.CoreRules;
3944
import org.apache.calcite.rel.rules.FilterMultiJoinMergeRule;
4045
import org.apache.calcite.rel.rules.MultiJoin;
4146
import org.apache.calcite.rel.rules.ProjectMultiJoinMergeRule;
4247
import org.apache.calcite.rel.rules.TransformationRule;
43-
import org.apache.calcite.rel.type.RelDataTypeField;
4448
import org.apache.calcite.rex.RexBuilder;
49+
import org.apache.calcite.rex.RexCall;
4550
import org.apache.calcite.rex.RexInputRef;
4651
import org.apache.calcite.rex.RexNode;
4752
import org.apache.calcite.rex.RexUtil;
@@ -52,13 +57,14 @@
5257
import org.apache.calcite.util.Pair;
5358
import org.immutables.value.Value;
5459

55-
import javax.annotation.Nullable;
56-
5760
import java.util.ArrayList;
5861
import java.util.Collections;
5962
import java.util.HashMap;
63+
import java.util.HashSet;
6064
import java.util.List;
6165
import java.util.Map;
66+
import java.util.Set;
67+
import java.util.stream.Collectors;
6268

6369
/**
6470
* Flink Planner rule to flatten a tree of {@link Join}s into a single {@link MultiJoin} with N
@@ -300,14 +306,7 @@ private List<RelNode> combineInputs(
300306
ImmutableIntList leftKeys = joinInfo.leftKeys;
301307
ImmutableIntList rightKeys = joinInfo.rightKeys;
302308

303-
if (canCombine(
304-
left,
305-
leftKeys,
306-
join.getJoinType(),
307-
join.getJoinType().generatesNullsOnLeft(),
308-
true,
309-
inputNullGenFieldList,
310-
0)) {
309+
if (canCombine(left, join)) {
311310
final MultiJoin leftMultiJoin = (MultiJoin) left;
312311
for (int i = 0; i < leftMultiJoin.getInputs().size(); i++) {
313312
newInputs.add(leftMultiJoin.getInput(i));
@@ -322,14 +321,7 @@ private List<RelNode> combineInputs(
322321
joinFieldRefCountsList.add(new int[left.getRowType().getFieldCount()]);
323322
}
324323

325-
if (canCombine(
326-
right,
327-
rightKeys,
328-
join.getJoinType(),
329-
join.getJoinType().generatesNullsOnRight(),
330-
false,
331-
inputNullGenFieldList,
332-
left.getRowType().getFieldCount())) {
324+
if (canCombine(right, join)) {
333325
final MultiJoin rightMultiJoin = (MultiJoin) right;
334326
for (int i = 0; i < rightMultiJoin.getInputs().size(); i++) {
335327
newInputs.add(rightMultiJoin.getInput(i));
@@ -369,27 +361,19 @@ private void combineJoinInfo(
369361
JoinInfo joinInfo = joinRel.analyzeCondition();
370362
ImmutableIntList leftKeys = joinInfo.leftKeys;
371363
final RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
372-
boolean leftCombined =
373-
canCombine(
374-
left,
375-
leftKeys,
376-
joinType,
377-
joinType.generatesNullsOnLeft(),
378-
true,
379-
inputNullGenFieldList,
380-
0);
364+
boolean leftCombined = canCombine(left, joinRel);
381365
switch (joinType) {
382366
case LEFT:
383367
if (leftCombined) {
384-
copyJoinInfo((MultiJoin) left, joinSpecs, 0, null, null);
368+
copyJoinInfo((MultiJoin) left, joinSpecs);
385369
} else {
386370
joinSpecs.add(Pair.of(JoinRelType.INNER, rexBuilder.makeLiteral(true)));
387371
}
388372
joinSpecs.add(Pair.of(joinType, joinRel.getCondition()));
389373
break;
390374
case INNER:
391375
if (leftCombined) {
392-
copyJoinInfo((MultiJoin) left, joinSpecs, 0, null, null);
376+
copyJoinInfo((MultiJoin) left, joinSpecs);
393377
} else {
394378
joinSpecs.add(Pair.of(JoinRelType.INNER, rexBuilder.makeLiteral(true)));
395379
}
@@ -408,45 +392,13 @@ private void combineJoinInfo(
408392
*
409393
* @param multiJoin the source MultiJoin
410394
* @param destJoinSpecs the list where the join types and conditions will be copied
411-
* @param adjustmentAmount if &gt; 0, the amount the RexInputRefs in the join conditions need to
412-
* be adjusted by
413-
* @param srcFields the source fields that the original join conditions are referencing
414-
* @param destFields the destination fields that the new join conditions
415395
*/
416-
private void copyJoinInfo(
417-
MultiJoin multiJoin,
418-
List<Pair<JoinRelType, RexNode>> destJoinSpecs,
419-
int adjustmentAmount,
420-
@Nullable List<RelDataTypeField> srcFields,
421-
@Nullable List<RelDataTypeField> destFields) {
396+
private void copyJoinInfo(MultiJoin multiJoin, List<Pair<JoinRelType, RexNode>> destJoinSpecs) {
422397
// getOuterJoinConditions are return all join conditions since that's how we use it
423398
final List<Pair<JoinRelType, RexNode>> srcJoinSpecs =
424399
Pair.zip(multiJoin.getJoinTypes(), multiJoin.getOuterJoinConditions());
425400

426-
if (adjustmentAmount == 0) {
427-
destJoinSpecs.addAll(srcJoinSpecs);
428-
} else {
429-
assert srcFields != null;
430-
assert destFields != null;
431-
int nFields = srcFields.size();
432-
int[] adjustments = new int[nFields];
433-
for (int idx = 0; idx < nFields; idx++) {
434-
adjustments[idx] = adjustmentAmount;
435-
}
436-
for (Pair<JoinRelType, RexNode> src : srcJoinSpecs) {
437-
destJoinSpecs.add(
438-
Pair.of(
439-
src.left,
440-
src.right == null
441-
? null
442-
: src.right.accept(
443-
new RelOptUtil.RexInputConverter(
444-
multiJoin.getCluster().getRexBuilder(),
445-
srcFields,
446-
destFields,
447-
adjustments))));
448-
}
449-
}
401+
destJoinSpecs.addAll(srcJoinSpecs);
450402
}
451403

452404
/**
@@ -474,14 +426,7 @@ private List<RexNode> combineJoinFilters(
474426
if ((joinType != JoinRelType.LEFT)) {
475427
filters.add(join.getCondition());
476428
}
477-
if (canCombine(
478-
left,
479-
leftKeys,
480-
joinType,
481-
joinType.generatesNullsOnLeft(),
482-
true,
483-
inputNullGenFieldList,
484-
0)) {
429+
if (canCombine(left, join)) {
485430
filters.add(((MultiJoin) left).getJoinFilter());
486431
}
487432

@@ -497,55 +442,148 @@ private List<RexNode> combineJoinFilters(
497442
* href="https://issues.apache.org/jira/browse/FLINK-37890">FLINK-37890</a>.
498443
*
499444
* @param input input into a join
500-
* @param nullGenerating true if the input is null generating
501445
* @return true if the input can be combined into a parent MultiJoin
502446
*/
503-
private boolean canCombine(
504-
RelNode input,
505-
ImmutableIntList joinKeys,
506-
JoinRelType joinType,
507-
boolean nullGenerating,
508-
boolean isLeft,
509-
List<Boolean> inputNullGenFieldList,
510-
int beginIndex) {
447+
private boolean canCombine(RelNode input, Join origJoin) {
511448
if (input instanceof MultiJoin) {
512449
MultiJoin join = (MultiJoin) input;
513-
if (join.isFullOuterJoin() || nullGenerating) {
450+
451+
if (join.isFullOuterJoin()) {
514452
return false;
515453
}
516454

517-
if (joinType == JoinRelType.LEFT) {
518-
if (!isLeft) {
519-
return false;
520-
} else {
521-
for (int joinKey : joinKeys) {
522-
if (inputNullGenFieldList.get(joinKey + beginIndex)) {
523-
return false;
524-
}
525-
}
526-
}
527-
} else if (joinType == JoinRelType.RIGHT) {
528-
if (isLeft) {
529-
return false;
530-
} else {
531-
for (int joinKey : joinKeys) {
532-
if (inputNullGenFieldList.get(joinKey + beginIndex)) {
533-
return false;
534-
}
535-
}
536-
}
537-
} else if (joinType == JoinRelType.INNER) {
538-
for (int joinKey : joinKeys) {
539-
if (inputNullGenFieldList.get(joinKey + beginIndex)) {
540-
return false;
541-
}
455+
return haveCommonJoinKey(origJoin, join);
456+
} else {
457+
return false;
458+
}
459+
}
460+
461+
/**
462+
* Checks if original join and child multi-join have common join keys to decide if we can merge
463+
* them into a single MultiJoin with one more input.
464+
*
465+
* @param origJoin original Join
466+
* @param otherJoin child MultiJoin
467+
* @return true if original Join and child multi-join have at least one common JoinKey
468+
*/
469+
private boolean haveCommonJoinKey(Join origJoin, MultiJoin otherJoin) {
470+
Set<String> origJoinKeys = getJoinKeys(origJoin);
471+
Set<String> otherJoinKeys = getJoinKeys(otherJoin);
472+
473+
origJoinKeys.retainAll(otherJoinKeys);
474+
475+
return !origJoinKeys.isEmpty();
476+
}
477+
478+
/**
479+
* Returns a set of join keys as strings following this format [table_name.field_name].
480+
*
481+
* @param join Join or MultiJoin node
482+
* @return set of all the join keys (keys from join conditions)
483+
*/
484+
public Set<String> getJoinKeys(RelNode join) {
485+
Set<String> joinKeys = new HashSet<>();
486+
List<RexCall> conditions = Collections.emptyList();
487+
List<RelNode> inputs = join.getInputs();
488+
489+
if (join instanceof Join) {
490+
conditions = collectConjunctions(((Join) join).getCondition());
491+
} else if (join instanceof MultiJoin) {
492+
conditions =
493+
((MultiJoin) join)
494+
.getOuterJoinConditions().stream()
495+
.flatMap(cond -> collectConjunctions(cond).stream())
496+
.collect(Collectors.toList());
497+
}
498+
499+
RelMetadataQuery mq = join.getCluster().getMetadataQuery();
500+
501+
for (RexCall condition : conditions) {
502+
for (RexNode operand : condition.getOperands()) {
503+
if (operand instanceof RexInputRef) {
504+
addJoinKeysByOperand((RexInputRef) operand, inputs, mq, joinKeys);
542505
}
543-
} else {
544-
return false;
545506
}
546-
return true;
507+
}
508+
509+
return joinKeys;
510+
}
511+
512+
/**
513+
* Retrieves conjunctions from joinCondition.
514+
*
515+
* @param joinCondition join condition
516+
* @return List of RexCalls representing conditions
517+
*/
518+
private List<RexCall> collectConjunctions(RexNode joinCondition) {
519+
return RelOptUtil.conjunctions(joinCondition).stream()
520+
.map(rexNode -> (RexCall) rexNode)
521+
.collect(Collectors.toList());
522+
}
523+
524+
/**
525+
* Appends join key's string representation to the set of join keys.
526+
*
527+
* @param ref input ref to the operand
528+
* @param inputs List of node's inputs
529+
* @param mq RelMetadataQuery needed to retrieve column origins
530+
* @param joinKeys Set of join keys to be added
531+
*/
532+
private void addJoinKeysByOperand(
533+
RexInputRef ref, List<RelNode> inputs, RelMetadataQuery mq, Set<String> joinKeys) {
534+
int inputRefIndex = ref.getIndex();
535+
Tuple2<RelNode, Integer> targetInputAndIdx = getTargetInputAndIdx(inputRefIndex, inputs);
536+
RelNode targetInput = targetInputAndIdx.f0;
537+
int idxInTargetInput = targetInputAndIdx.f1;
538+
539+
Set<RelColumnOrigin> origins = mq.getColumnOrigins(targetInput, idxInTargetInput);
540+
if (origins != null) {
541+
for (RelColumnOrigin origin : origins) {
542+
RelOptTable originTable = origin.getOriginTable();
543+
List<String> qualifiedName = originTable.getQualifiedName();
544+
String fieldName =
545+
originTable
546+
.getRowType()
547+
.getFieldList()
548+
.get(origin.getOriginColumnOrdinal())
549+
.getName();
550+
joinKeys.add(qualifiedName.get(qualifiedName.size() - 1) + "." + fieldName);
551+
}
552+
}
553+
}
554+
555+
/**
556+
* Get real table that contains needed input ref (join key).
557+
*
558+
* @param inputRefIndex index of the required field
559+
* @param inputs inputs of the node
560+
* @return target input + idx of the required field as target input's
561+
*/
562+
private Tuple2<RelNode, Integer> getTargetInputAndIdx(int inputRefIndex, List<RelNode> inputs) {
563+
RelNode targetInput = null;
564+
int idxInTargetInput = 0;
565+
int inputFieldEnd = 0;
566+
for (RelNode input : inputs) {
567+
inputFieldEnd += input.getRowType().getFieldCount();
568+
if (inputRefIndex < inputFieldEnd) {
569+
targetInput = input;
570+
int targetInputStartIdx = inputFieldEnd - input.getRowType().getFieldCount();
571+
idxInTargetInput = inputRefIndex - targetInputStartIdx;
572+
break;
573+
}
574+
}
575+
576+
targetInput =
577+
(targetInput instanceof HepRelVertex)
578+
? ((HepRelVertex) targetInput).getCurrentRel()
579+
: targetInput;
580+
581+
assert targetInput != null;
582+
583+
if (targetInput instanceof LogicalTableScan) {
584+
return new Tuple2<>(targetInput, idxInTargetInput);
547585
} else {
548-
return false;
586+
return getTargetInputAndIdx(idxInTargetInput, targetInput.getInputs());
549587
}
550588
}
551589

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinRestoreTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public List<TableTestProgram> programs() {
3737
MultiJoinTestPrograms.MULTI_JOIN_THREE_WAY_INNER_JOIN_WITH_RESTORE,
3838
MultiJoinTestPrograms.MULTI_JOIN_THREE_WAY_LEFT_OUTER_JOIN_WITH_RESTORE,
3939
MultiJoinTestPrograms.MULTI_JOIN_FOUR_WAY_COMPLEX_WITH_RESTORE,
40+
MultiJoinTestPrograms.MULTI_JOIN_FOUR_WAY_NO_COMMON_JOIN_KEY_RESTORE,
4041
MultiJoinTestPrograms.MULTI_JOIN_WITH_TIME_ATTRIBUTES_MATERIALIZATION_WITH_RESTORE);
4142
}
4243
}

flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/nodes/exec/stream/MultiJoinSemanticTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public List<TableTestProgram> programs() {
3434
MultiJoinTestPrograms.MULTI_JOIN_THREE_WAY_LEFT_OUTER_JOIN_UPDATING,
3535
MultiJoinTestPrograms.MULTI_JOIN_THREE_WAY_LEFT_OUTER_JOIN_WITH_WHERE,
3636
MultiJoinTestPrograms.MULTI_JOIN_FOUR_WAY_COMPLEX,
37-
MultiJoinTestPrograms.MULTI_JOIN_WITH_TIME_ATTRIBUTES_MATERIALIZATION);
37+
MultiJoinTestPrograms.MULTI_JOIN_WITH_TIME_ATTRIBUTES_MATERIALIZATION,
38+
MultiJoinTestPrograms.MULTI_JOIN_FOUR_WAY_NO_COMMON_JOIN_KEY);
3839
}
3940
}

0 commit comments

Comments
 (0)