Skip to content

Commit

Permalink
fix(sql): fix wrong results returned from union queries with similar …
Browse files Browse the repository at this point in the history
…joins (#3768)
  • Loading branch information
puzpuzpuz committed Sep 29, 2023
1 parent c1a4b7f commit 4797f83
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3930,7 +3930,6 @@ private RecordCursorFactory generateTableQuery0(

boolean orderDescendingByDesignatedTimestampOnly = isOrderDescendingByDesignatedTimestampOnly(model);
if (withinExtracted != null) {

CharSequence preferredKeyColumn = null;

if (latestByColumnCount == 1) {
Expand Down
25 changes: 17 additions & 8 deletions core/src/main/java/io/questdb/griffin/SqlOptimiser.java
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,11 @@ public SqlOptimiser(
}

public void clear() {
clearForUnionModelInJoin();
contextPool.clear();
intHashSetPool.clear();
joinClausesSwap1.clear();
joinClausesSwap2.clear();
constNameToIndex.clear();
constNameToNode.clear();
constNameToToken.clear();
literalCollectorAIndexes.clear();
literalCollectorBIndexes.clear();
literalCollectorANames.clear();
Expand All @@ -158,6 +156,12 @@ public void clear() {
tempQueryModel = null;
}

public void clearForUnionModelInJoin() {
constNameToIndex.clear();
constNameToNode.clear();
constNameToToken.clear();
}

public CharSequence findColumnByAst(ObjList<ExpressionNode> groupByNodes, ObjList<CharSequence> groupByAlises, ExpressionNode node) {
for (int i = 0, max = groupByNodes.size(); i < max; i++) {
ExpressionNode n = groupByNodes.getQuick(i);
Expand All @@ -179,8 +183,8 @@ public int findColumnIdxByAst(ObjList<ExpressionNode> groupByNodes, ExpressionNo
}

private static boolean isOrderedByDesignatedTimestamp(QueryModel baseModel) {
return baseModel.getTimestamp() != null && baseModel.getOrderBy().size() == 1 &&
Chars.equals(baseModel.getOrderBy().getQuick(0).token, baseModel.getTimestamp().token);
return baseModel.getTimestamp() != null && baseModel.getOrderBy().size() == 1
&& Chars.equals(baseModel.getOrderBy().getQuick(0).token, baseModel.getTimestamp().token);
}

private static void linkDependencies(QueryModel model, int parent, int child) {
Expand Down Expand Up @@ -1648,7 +1652,6 @@ private void emitLiterals(
QueryModel validatingModel,
boolean analyticCall
) throws SqlException {

sqlNodeStack.clear();

// pre-order iterative tree traversal
Expand Down Expand Up @@ -1951,7 +1954,6 @@ private boolean hasAggregateQueryColumn(QueryModel model) {
}

private boolean hasAggregates(ExpressionNode node) {

sqlNodeStack.clear();

// pre-order iterative tree traversal
Expand Down Expand Up @@ -2702,6 +2704,7 @@ private void optimiseJoins(QueryModel model) throws SqlException {

m = model.getJoinModels().getQuick(i).getUnionModel();
if (m != null) {
clearForUnionModelInJoin();
optimiseJoins(m);
}
}
Expand Down Expand Up @@ -2806,7 +2809,13 @@ private void processEmittedJoinClauses(QueryModel model) {
*
* @param node expression n
*/
private void processJoinConditions(QueryModel parent, ExpressionNode node, boolean innerPredicate, QueryModel joinModel, int joinIndex) throws SqlException {
private void processJoinConditions(
QueryModel parent,
ExpressionNode node,
boolean innerPredicate,
QueryModel joinModel,
int joinIndex
) throws SqlException {
ExpressionNode n = node;
// pre-order traversal
sqlNodeStack.clear();
Expand Down
20 changes: 11 additions & 9 deletions core/src/main/java/io/questdb/griffin/WhereClauseParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -1814,15 +1814,17 @@ private void processArgument(
}
}

private boolean removeAndIntrinsics(AliasTranslator translator,
IntrinsicModel model,
ExpressionNode node,
RecordMetadata m,
FunctionParser functionParser,
RecordMetadata metadata,
SqlExecutionContext executionContext,
boolean latestByMultiColumn,
TableReader reader) throws SqlException {
private boolean removeAndIntrinsics(
AliasTranslator translator,
IntrinsicModel model,
ExpressionNode node,
RecordMetadata m,
FunctionParser functionParser,
RecordMetadata metadata,
SqlExecutionContext executionContext,
boolean latestByMultiColumn,
TableReader reader
) throws SqlException {
switch (intrinsicOps.get(node.token)) {
case INTRINSIC_OP_IN:
return analyzeIn(translator, model, node, m, functionParser, executionContext, latestByMultiColumn, reader);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ private void hashCursorB() {
}
// this is an optimisation to release TableReader in case "this"
// cursor lingers around. If there is exception or circuit breaker fault
// we will rely on close() method to release reader.
this.cursorB = Misc.free(this.cursorB);
// we will rely on close() method to release the reader.
cursorB = Misc.free(cursorB);
}

void of(RecordCursor cursorA, RecordCursor cursorB, SqlExecutionCircuitBreaker circuitBreaker) throws SqlException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import io.questdb.std.*;


//Metadata describing join conditions
// Metadata describing join conditions
public class JoinContext implements Mutable, Plannable {
public static final ObjectFactory<JoinContext> FACTORY = JoinContext::new;
private static final int TYPICAL_NUMBER_OF_JOIN_COLUMNS = 4;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ public RuntimeIntervalModel(LongList intervals) {

public RuntimeIntervalModel(LongList staticIntervals, ObjList<Function> dynamicRangeList) {
this.intervals = staticIntervals;

this.dynamicRangeList = dynamicRangeList;
}

Expand Down Expand Up @@ -102,7 +101,6 @@ public void toPlan(PlanSink sink) {
valTs(sink, intervals.getQuick(i + 1));
sink.val("\")");
}

} catch (SqlException e) {
LOG.error().$("Can't calculate intervals: ").$(e.getMessage()).$();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,30 @@ public boolean hasIntervalFilters() {
}

public void intersect(long lo, Function hi, short adjustment) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

IntervalUtils.addHiLoInterval(lo, 0, adjustment, IntervalDynamicIndicator.IS_HI_DYNAMIC, IntervalOperation.INTERSECT, staticIntervals);
dynamicRangeList.add(hi);
intervalApplied = true;
}

public void intersect(Function lo, long hi, short adjustment) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

IntervalUtils.addHiLoInterval(0, hi, adjustment, IntervalDynamicIndicator.IS_LO_DYNAMIC, IntervalOperation.INTERSECT, staticIntervals);
dynamicRangeList.add(lo);
intervalApplied = true;
}

public void intersect(long lo, long hi) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

if (dynamicRangeList.size() == 0) {
staticIntervals.add(lo, hi);
if (intervalApplied) {
Expand All @@ -111,7 +118,10 @@ public void intersect(long lo, long hi) {
}

public void intersectDynamicInterval(Function intervalStrFunction) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

IntervalUtils.addHiLoInterval(0L, 0L, IntervalOperation.INTERSECT_INTERVALS, staticIntervals);
dynamicRangeList.add(intervalStrFunction);
intervalApplied = true;
Expand All @@ -123,15 +133,20 @@ public void intersectEmpty() {
}

public void intersectEquals(Function function) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

IntervalUtils.addHiLoInterval(0, 0, (short) 0, IntervalDynamicIndicator.IS_LO_HI_DYNAMIC, IntervalOperation.INTERSECT, staticIntervals);
dynamicRangeList.add(function);
intervalApplied = true;
}

public void intersectIntervals(CharSequence seq, int lo, int lim, int position) throws SqlException {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

int size = staticIntervals.size();
IntervalUtils.parseIntervalEx(seq, lo, lim, position, staticIntervals, IntervalOperation.INTERSECT);
if (dynamicRangeList.size() == 0) {
Expand All @@ -147,7 +162,10 @@ public void intersectIntervals(CharSequence seq, int lo, int lim, int position)
}

public void intersectTimestamp(CharSequence seq, int lo, int lim, int position) throws SqlException {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

int size = staticIntervals.size();
IntervalUtils.parseSingleTimestamp(seq, lo, lim, position, staticIntervals, IntervalOperation.INTERSECT);
if (dynamicRangeList.size() == 0) {
Expand Down Expand Up @@ -217,15 +235,20 @@ public void setBetweenNegated(boolean isNegated) {
}

public void subtractEquals(Function function) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

IntervalUtils.addHiLoInterval(0, 0, (short) 0, IntervalDynamicIndicator.IS_LO_HI_DYNAMIC, IntervalOperation.SUBTRACT, staticIntervals);
dynamicRangeList.add(function);
intervalApplied = true;
}

public void subtractInterval(long lo, long hi) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

if (dynamicRangeList.size() == 0) {
int size = staticIntervals.size();
staticIntervals.add(lo, hi);
Expand All @@ -241,7 +264,10 @@ public void subtractInterval(long lo, long hi) {
}

public void subtractIntervals(CharSequence seq, int lo, int lim, int position) throws SqlException {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

int size = staticIntervals.size();
IntervalUtils.parseIntervalEx(seq, lo, lim, position, staticIntervals, IntervalOperation.SUBTRACT);
if (dynamicRangeList.size() == 0) {
Expand All @@ -258,14 +284,20 @@ public void subtractIntervals(CharSequence seq, int lo, int lim, int position) t
}

public void subtractRuntimeInterval(Function intervalStrFunction) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

IntervalUtils.addHiLoInterval(0L, 0L, IntervalOperation.SUBTRACT_INTERVALS, staticIntervals);
dynamicRangeList.add(intervalStrFunction);
intervalApplied = true;
}

public void union(long lo, long hi) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

if (dynamicRangeList.size() == 0) {
staticIntervals.add(lo, hi);
if (intervalApplied) {
Expand All @@ -278,7 +310,9 @@ public void union(long lo, long hi) {
}

private void intersectBetweenDynamic(Function funcValue1, Function funcValue2) {
if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

short operation = betweenNegated ? IntervalOperation.SUBTRACT_BETWEEN : IntervalOperation.INTERSECT_BETWEEN;
IntervalUtils.addHiLoInterval(0, 0, (short) 0, IntervalDynamicIndicator.IS_LO_SEPARATE_DYNAMIC, operation, staticIntervals);
Expand All @@ -301,7 +335,9 @@ private void intersectBetweenSemiDynamic(Function funcValue, long constValue) {
return;
}

if (isEmptySet()) return;
if (isEmptySet()) {
return;
}

short operation = betweenNegated ? IntervalOperation.SUBTRACT_BETWEEN : IntervalOperation.INTERSECT_BETWEEN;
IntervalUtils.addHiLoInterval(constValue, 0, (short) 0, IntervalDynamicIndicator.IS_HI_DYNAMIC, operation, staticIntervals);
Expand Down
72 changes: 72 additions & 0 deletions core/src/test/java/io/questdb/test/griffin/JoinTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1731,6 +1731,78 @@ public void testJoinConstantTrueFF() throws Exception {
testFullFat(this::testJoinConstantTrue0);
}

@Test
public void testJoinContextIsolationInIntersect() throws Exception {
assertMemoryLeak(() -> {
ddl(
"CREATE TABLE t (\n" +
" created timestamp,\n" +
" event short,\n" +
" origin short\n" +
") TIMESTAMP(created) PARTITION BY DAY;"
);
insert("INSERT INTO t VALUES ('2023-09-21T10:00:00.000000Z', 1, 1);");

// The important aspects here are T2.created = '2003-09-21T10:00:00.000000Z'
// in the first query and T2.created = T3.created in the second one. Due to this,
// transitive filters pass was mistakenly mutating where clause in the second query.
final String query1 = "SELECT count(1)\n" +
"FROM t as T1 CROSS JOIN t as T2\n" +
"WHERE T2.created > now() and T2.created = '2003-09-21T10:00:00.000000Z'";
final String query2 = "SELECT count(1)\n" +
"FROM t as T1 JOIN t as T2 on T1.created = T2.created JOIN t as T3 ON T2.created = T3.created\n" +
"WHERE T3.created < now()";

assertQuery("count\n0\n", query1, null, false, true);
assertQuery("count\n1\n", query2, null, false, true);

assertQuery(
"count\n",
query1 + " INTERSECT " + query2,
null,
false,
false
);
});
}

@Test
public void testJoinContextIsolationInUnion() throws Exception {
assertMemoryLeak(() -> {
ddl(
"CREATE TABLE t (\n" +
" created timestamp,\n" +
" event short,\n" +
" origin short\n" +
") TIMESTAMP(created) PARTITION BY DAY;"
);
insert("INSERT INTO t VALUES ('2023-09-21T10:00:00.000000Z', 1, 1);");
insert("INSERT INTO t VALUES ('2023-09-21T11:00:00.000000Z', 1, 1);");

// The important aspects here are T1.event = 0.0
// in the first query and T1.event = T2.event in the second one. Due to this,
// transitive filters pass was mistakenly mutating where clause in the second query.
final String query1 = "SELECT count(1)\n" +
"FROM t as T1 JOIN t as T2 ON T1.created = T2.created\n" +
"WHERE T1.event = 1.0";
final String query2 = "SELECT count(1)\n" +
"FROM t as T1 JOIN t as T2 ON T1.event = T2.event";

assertQuery("count\n2\n", query1, null, false, true);
assertQuery("count\n4\n", query2, null, false, true);

assertQuery(
"count\n" +
"2\n" +
"4\n",
query1 + " UNION " + query2,
null,
false,
false
);
});
}

@Test
public void testJoinInner() throws Exception {
assertMemoryLeak(() -> {
Expand Down

0 comments on commit 4797f83

Please sign in to comment.