From 4797f835cb2db99e3a03a602e2c79621cbe6ed0c Mon Sep 17 00:00:00 2001 From: Andrei Pechkurov <37772591+puzpuzpuz@users.noreply.github.com> Date: Fri, 29 Sep 2023 16:57:46 +0300 Subject: [PATCH] fix(sql): fix wrong results returned from union queries with similar joins (#3768) --- .../io/questdb/griffin/SqlCodeGenerator.java | 1 - .../java/io/questdb/griffin/SqlOptimiser.java | 25 ++++--- .../io/questdb/griffin/WhereClauseParser.java | 20 +++--- .../engine/union/IntersectRecordCursor.java | 4 +- .../io/questdb/griffin/model/JoinContext.java | 2 +- .../griffin/model/RuntimeIntervalModel.java | 2 - .../model/RuntimeIntervalModelBuilder.java | 64 +++++++++++++---- .../io/questdb/test/griffin/JoinTest.java | 72 +++++++++++++++++++ .../questdb/test/griffin/SqlParserTest.java | 40 ++++++++++- 9 files changed, 191 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/io/questdb/griffin/SqlCodeGenerator.java b/core/src/main/java/io/questdb/griffin/SqlCodeGenerator.java index dd444b3c628c..34f3174a8ac7 100644 --- a/core/src/main/java/io/questdb/griffin/SqlCodeGenerator.java +++ b/core/src/main/java/io/questdb/griffin/SqlCodeGenerator.java @@ -3930,7 +3930,6 @@ private RecordCursorFactory generateTableQuery0( boolean orderDescendingByDesignatedTimestampOnly = isOrderDescendingByDesignatedTimestampOnly(model); if (withinExtracted != null) { - CharSequence preferredKeyColumn = null; if (latestByColumnCount == 1) { diff --git a/core/src/main/java/io/questdb/griffin/SqlOptimiser.java b/core/src/main/java/io/questdb/griffin/SqlOptimiser.java index 01aeb15f8e67..682cf2483770 100644 --- a/core/src/main/java/io/questdb/griffin/SqlOptimiser.java +++ b/core/src/main/java/io/questdb/griffin/SqlOptimiser.java @@ -133,13 +133,11 @@ public SqlOptimiser( } public void clear() { + clearForUnionModelInJoin(); contextPool.clear(); intHashSetPool.clear(); joinClausesSwap1.clear(); joinClausesSwap2.clear(); - constNameToIndex.clear(); - constNameToNode.clear(); - 
constNameToToken.clear(); literalCollectorAIndexes.clear(); literalCollectorBIndexes.clear(); literalCollectorANames.clear(); @@ -158,6 +156,12 @@ public void clear() { tempQueryModel = null; } + public void clearForUnionModelInJoin() { + constNameToIndex.clear(); + constNameToNode.clear(); + constNameToToken.clear(); + } + public CharSequence findColumnByAst(ObjList groupByNodes, ObjList groupByAlises, ExpressionNode node) { for (int i = 0, max = groupByNodes.size(); i < max; i++) { ExpressionNode n = groupByNodes.getQuick(i); @@ -179,8 +183,8 @@ public int findColumnIdxByAst(ObjList groupByNodes, ExpressionNo } private static boolean isOrderedByDesignatedTimestamp(QueryModel baseModel) { - return baseModel.getTimestamp() != null && baseModel.getOrderBy().size() == 1 && - Chars.equals(baseModel.getOrderBy().getQuick(0).token, baseModel.getTimestamp().token); + return baseModel.getTimestamp() != null && baseModel.getOrderBy().size() == 1 + && Chars.equals(baseModel.getOrderBy().getQuick(0).token, baseModel.getTimestamp().token); } private static void linkDependencies(QueryModel model, int parent, int child) { @@ -1648,7 +1652,6 @@ private void emitLiterals( QueryModel validatingModel, boolean analyticCall ) throws SqlException { - sqlNodeStack.clear(); // pre-order iterative tree traversal @@ -1951,7 +1954,6 @@ private boolean hasAggregateQueryColumn(QueryModel model) { } private boolean hasAggregates(ExpressionNode node) { - sqlNodeStack.clear(); // pre-order iterative tree traversal @@ -2702,6 +2704,7 @@ private void optimiseJoins(QueryModel model) throws SqlException { m = model.getJoinModels().getQuick(i).getUnionModel(); if (m != null) { + clearForUnionModelInJoin(); optimiseJoins(m); } } @@ -2806,7 +2809,13 @@ private void processEmittedJoinClauses(QueryModel model) { * * @param node expression n */ - private void processJoinConditions(QueryModel parent, ExpressionNode node, boolean innerPredicate, QueryModel joinModel, int joinIndex) throws SqlException { + 
private void processJoinConditions( + QueryModel parent, + ExpressionNode node, + boolean innerPredicate, + QueryModel joinModel, + int joinIndex + ) throws SqlException { ExpressionNode n = node; // pre-order traversal sqlNodeStack.clear(); diff --git a/core/src/main/java/io/questdb/griffin/WhereClauseParser.java b/core/src/main/java/io/questdb/griffin/WhereClauseParser.java index 62676f0206ec..4c14a6b81fd9 100644 --- a/core/src/main/java/io/questdb/griffin/WhereClauseParser.java +++ b/core/src/main/java/io/questdb/griffin/WhereClauseParser.java @@ -1814,15 +1814,17 @@ private void processArgument( } } - private boolean removeAndIntrinsics(AliasTranslator translator, - IntrinsicModel model, - ExpressionNode node, - RecordMetadata m, - FunctionParser functionParser, - RecordMetadata metadata, - SqlExecutionContext executionContext, - boolean latestByMultiColumn, - TableReader reader) throws SqlException { + private boolean removeAndIntrinsics( + AliasTranslator translator, + IntrinsicModel model, + ExpressionNode node, + RecordMetadata m, + FunctionParser functionParser, + RecordMetadata metadata, + SqlExecutionContext executionContext, + boolean latestByMultiColumn, + TableReader reader + ) throws SqlException { switch (intrinsicOps.get(node.token)) { case INTRINSIC_OP_IN: return analyzeIn(translator, model, node, m, functionParser, executionContext, latestByMultiColumn, reader); diff --git a/core/src/main/java/io/questdb/griffin/engine/union/IntersectRecordCursor.java b/core/src/main/java/io/questdb/griffin/engine/union/IntersectRecordCursor.java index af385eba13a7..34aca6e2cdac 100644 --- a/core/src/main/java/io/questdb/griffin/engine/union/IntersectRecordCursor.java +++ b/core/src/main/java/io/questdb/griffin/engine/union/IntersectRecordCursor.java @@ -127,8 +127,8 @@ private void hashCursorB() { } // this is an optimisation to release TableReader in case "this" // cursor lingers around. 
If there is exception or circuit breaker fault - // we will rely on close() method to release reader. - this.cursorB = Misc.free(this.cursorB); + // we will rely on close() method to release the reader. + cursorB = Misc.free(cursorB); } void of(RecordCursor cursorA, RecordCursor cursorB, SqlExecutionCircuitBreaker circuitBreaker) throws SqlException { diff --git a/core/src/main/java/io/questdb/griffin/model/JoinContext.java b/core/src/main/java/io/questdb/griffin/model/JoinContext.java index 4ac07ef38b0a..7695e2ae85e0 100644 --- a/core/src/main/java/io/questdb/griffin/model/JoinContext.java +++ b/core/src/main/java/io/questdb/griffin/model/JoinContext.java @@ -29,7 +29,7 @@ import io.questdb.std.*; -//Metadata describing join conditions +// Metadata describing join conditions public class JoinContext implements Mutable, Plannable { public static final ObjectFactory FACTORY = JoinContext::new; private static final int TYPICAL_NUMBER_OF_JOIN_COLUMNS = 4; diff --git a/core/src/main/java/io/questdb/griffin/model/RuntimeIntervalModel.java b/core/src/main/java/io/questdb/griffin/model/RuntimeIntervalModel.java index 5797fb24b62c..cbe4c93b6224 100644 --- a/core/src/main/java/io/questdb/griffin/model/RuntimeIntervalModel.java +++ b/core/src/main/java/io/questdb/griffin/model/RuntimeIntervalModel.java @@ -51,7 +51,6 @@ public RuntimeIntervalModel(LongList intervals) { public RuntimeIntervalModel(LongList staticIntervals, ObjList dynamicRangeList) { this.intervals = staticIntervals; - this.dynamicRangeList = dynamicRangeList; } @@ -102,7 +101,6 @@ public void toPlan(PlanSink sink) { valTs(sink, intervals.getQuick(i + 1)); sink.val("\")"); } - } catch (SqlException e) { LOG.error().$("Can't calculate intervals: ").$(e.getMessage()).$(); } diff --git a/core/src/main/java/io/questdb/griffin/model/RuntimeIntervalModelBuilder.java b/core/src/main/java/io/questdb/griffin/model/RuntimeIntervalModelBuilder.java index 1a56d7d38f54..0934d9d6d076 100644 --- 
a/core/src/main/java/io/questdb/griffin/model/RuntimeIntervalModelBuilder.java +++ b/core/src/main/java/io/questdb/griffin/model/RuntimeIntervalModelBuilder.java @@ -81,7 +81,9 @@ public boolean hasIntervalFilters() { } public void intersect(long lo, Function hi, short adjustment) { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } IntervalUtils.addHiLoInterval(lo, 0, adjustment, IntervalDynamicIndicator.IS_HI_DYNAMIC, IntervalOperation.INTERSECT, staticIntervals); dynamicRangeList.add(hi); @@ -89,7 +91,9 @@ public void intersect(long lo, Function hi, short adjustment) { } public void intersect(Function lo, long hi, short adjustment) { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } IntervalUtils.addHiLoInterval(0, hi, adjustment, IntervalDynamicIndicator.IS_LO_DYNAMIC, IntervalOperation.INTERSECT, staticIntervals); dynamicRangeList.add(lo); @@ -97,7 +101,10 @@ public void intersect(Function lo, long hi, short adjustment) { } public void intersect(long lo, long hi) { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } + if (dynamicRangeList.size() == 0) { staticIntervals.add(lo, hi); if (intervalApplied) { @@ -111,7 +118,10 @@ public void intersect(long lo, long hi) { } public void intersectDynamicInterval(Function intervalStrFunction) { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } + IntervalUtils.addHiLoInterval(0L, 0L, IntervalOperation.INTERSECT_INTERVALS, staticIntervals); dynamicRangeList.add(intervalStrFunction); intervalApplied = true; @@ -123,7 +133,9 @@ public void intersectEmpty() { } public void intersectEquals(Function function) { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } IntervalUtils.addHiLoInterval(0, 0, (short) 0, IntervalDynamicIndicator.IS_LO_HI_DYNAMIC, IntervalOperation.INTERSECT, staticIntervals); dynamicRangeList.add(function); @@ -131,7 +143,10 @@ public void intersectEquals(Function function) { } public void intersectIntervals(CharSequence seq, int lo, int 
lim, int position) throws SqlException { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } + int size = staticIntervals.size(); IntervalUtils.parseIntervalEx(seq, lo, lim, position, staticIntervals, IntervalOperation.INTERSECT); if (dynamicRangeList.size() == 0) { @@ -147,7 +162,10 @@ public void intersectIntervals(CharSequence seq, int lo, int lim, int position) } public void intersectTimestamp(CharSequence seq, int lo, int lim, int position) throws SqlException { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } + int size = staticIntervals.size(); IntervalUtils.parseSingleTimestamp(seq, lo, lim, position, staticIntervals, IntervalOperation.INTERSECT); if (dynamicRangeList.size() == 0) { @@ -217,7 +235,9 @@ public void setBetweenNegated(boolean isNegated) { } public void subtractEquals(Function function) { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } IntervalUtils.addHiLoInterval(0, 0, (short) 0, IntervalDynamicIndicator.IS_LO_HI_DYNAMIC, IntervalOperation.SUBTRACT, staticIntervals); dynamicRangeList.add(function); @@ -225,7 +245,10 @@ public void subtractEquals(Function function) { } public void subtractInterval(long lo, long hi) { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } + if (dynamicRangeList.size() == 0) { int size = staticIntervals.size(); staticIntervals.add(lo, hi); @@ -241,7 +264,10 @@ public void subtractInterval(long lo, long hi) { } public void subtractIntervals(CharSequence seq, int lo, int lim, int position) throws SqlException { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } + int size = staticIntervals.size(); IntervalUtils.parseIntervalEx(seq, lo, lim, position, staticIntervals, IntervalOperation.SUBTRACT); if (dynamicRangeList.size() == 0) { @@ -258,14 +284,20 @@ public void subtractIntervals(CharSequence seq, int lo, int lim, int position) t } public void subtractRuntimeInterval(Function intervalStrFunction) { - if (isEmptySet()) return; + if (isEmptySet()) { 
+ return; + } + IntervalUtils.addHiLoInterval(0L, 0L, IntervalOperation.SUBTRACT_INTERVALS, staticIntervals); dynamicRangeList.add(intervalStrFunction); intervalApplied = true; } public void union(long lo, long hi) { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } + if (dynamicRangeList.size() == 0) { staticIntervals.add(lo, hi); if (intervalApplied) { @@ -278,7 +310,9 @@ public void union(long lo, long hi) { } private void intersectBetweenDynamic(Function funcValue1, Function funcValue2) { - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } short operation = betweenNegated ? IntervalOperation.SUBTRACT_BETWEEN : IntervalOperation.INTERSECT_BETWEEN; IntervalUtils.addHiLoInterval(0, 0, (short) 0, IntervalDynamicIndicator.IS_LO_SEPARATE_DYNAMIC, operation, staticIntervals); @@ -301,7 +335,9 @@ private void intersectBetweenSemiDynamic(Function funcValue, long constValue) { return; } - if (isEmptySet()) return; + if (isEmptySet()) { + return; + } short operation = betweenNegated ? 
IntervalOperation.SUBTRACT_BETWEEN : IntervalOperation.INTERSECT_BETWEEN; IntervalUtils.addHiLoInterval(constValue, 0, (short) 0, IntervalDynamicIndicator.IS_HI_DYNAMIC, operation, staticIntervals); diff --git a/core/src/test/java/io/questdb/test/griffin/JoinTest.java b/core/src/test/java/io/questdb/test/griffin/JoinTest.java index 0fa5fba4eb11..d451ace4b93c 100644 --- a/core/src/test/java/io/questdb/test/griffin/JoinTest.java +++ b/core/src/test/java/io/questdb/test/griffin/JoinTest.java @@ -1731,6 +1731,78 @@ public void testJoinConstantTrueFF() throws Exception { testFullFat(this::testJoinConstantTrue0); } + @Test + public void testJoinContextIsolationInIntersect() throws Exception { + assertMemoryLeak(() -> { + ddl( + "CREATE TABLE t (\n" + + " created timestamp,\n" + + " event short,\n" + + " origin short\n" + + ") TIMESTAMP(created) PARTITION BY DAY;" + ); + insert("INSERT INTO t VALUES ('2023-09-21T10:00:00.000000Z', 1, 1);"); + + // The important aspects here are T2.created = '2003-09-21T10:00:00.000000Z' + // in the first query and T2.created = T3.created in the second one. Due to this, + // transitive filters pass was mistakenly mutating where clause in the second query. 
+ final String query1 = "SELECT count(1)\n" + + "FROM t as T1 CROSS JOIN t as T2\n" + + "WHERE T2.created > now() and T2.created = '2003-09-21T10:00:00.000000Z'"; + final String query2 = "SELECT count(1)\n" + + "FROM t as T1 JOIN t as T2 on T1.created = T2.created JOIN t as T3 ON T2.created = T3.created\n" + + "WHERE T3.created < now()"; + + assertQuery("count\n0\n", query1, null, false, true); + assertQuery("count\n1\n", query2, null, false, true); + + assertQuery( + "count\n", + query1 + " INTERSECT " + query2, + null, + false, + false + ); + }); + } + + @Test + public void testJoinContextIsolationInUnion() throws Exception { + assertMemoryLeak(() -> { + ddl( + "CREATE TABLE t (\n" + + " created timestamp,\n" + + " event short,\n" + + " origin short\n" + + ") TIMESTAMP(created) PARTITION BY DAY;" + ); + insert("INSERT INTO t VALUES ('2023-09-21T10:00:00.000000Z', 1, 1);"); + insert("INSERT INTO t VALUES ('2023-09-21T11:00:00.000000Z', 1, 1);"); + + // The important aspects here are T1.event = 1.0 + // in the first query and T1.event = T2.event in the second one. Due to this, + // transitive filters pass was mistakenly mutating where clause in the second query. 
+ final String query1 = "SELECT count(1)\n" + + "FROM t as T1 JOIN t as T2 ON T1.created = T2.created\n" + + "WHERE T1.event = 1.0"; + final String query2 = "SELECT count(1)\n" + + "FROM t as T1 JOIN t as T2 ON T1.event = T2.event"; + + assertQuery("count\n2\n", query1, null, false, true); + assertQuery("count\n4\n", query2, null, false, true); + + assertQuery( + "count\n" + + "2\n" + + "4\n", + query1 + " UNION " + query2, + null, + false, + false + ); + }); + } + @Test public void testJoinInner() throws Exception { assertMemoryLeak(() -> { diff --git a/core/src/test/java/io/questdb/test/griffin/SqlParserTest.java b/core/src/test/java/io/questdb/test/griffin/SqlParserTest.java index 657552e32da7..2c5b57962c2e 100644 --- a/core/src/test/java/io/questdb/test/griffin/SqlParserTest.java +++ b/core/src/test/java/io/questdb/test/griffin/SqlParserTest.java @@ -2991,6 +2991,42 @@ public void testEraseColumnPrefixInJoin() throws Exception { ); } + @Test + public void testEraseColumnPrefixInJoinWithNestedUnion() throws Exception { + assertQuery( + "select-choose c.customerId customerId, o.customerId customerId1, o.x x from (select [customerId] from customers c left join select [customerId, x] from (select-choose [customerId, x] customerId, x from (select [customerId, x] from (select-choose [customerId, x] customerId, x from (select [customerId, x] from orders) union select-choose [customerId, x] customerId, x from (select [customerId, x] from orders)) o where x = 10 and customerId = 100) o) o on customerId = c.customerId where customerId = 100) c", + "customers c" + + " left join ((orders union orders) o where o.x = 10) o on c.customerId = o.customerId" + + " where c.customerId = 100", + modelOf("customers").col("customerId", ColumnType.INT), + modelOf("orders") + .col("customerId", ColumnType.INT) + .col("x", ColumnType.INT) + ); + } + + @Test + public void testEraseColumnPrefixInJoinWithOuterUnion() throws Exception { + assertQuery( + "select-choose customerId from 
(select-choose [c.customerId customerId] c.customerId customerId from (select [customerId] from customers c left join select [customerId] from (select-choose [customerId] customerId, x from (select [customerId, x] from orders o where x = 10 and customerId = 100) o) o on customerId = c.customerId where customerId = 100) c)" + + " union all" + + " select-choose customerId from (select-choose [c.customerId customerId] c.customerId customerId from (select [customerId] from customers c left join (select [customerId] from orders o where customerId = 100) o on o.customerId = c.customerId where customerId = 100) c)", + "(select c.customerId" + + " from customers c" + + " left join (orders o where o.x = 10) o on c.customerId = o.customerId" + + " where c.customerId = 100)" + + " union all" + + " (select c.customerId " + + " from customers c" + + " left join orders o on c.customerId = o.customerId" + + " where c.customerId = 100)", + modelOf("customers").col("customerId", ColumnType.INT), + modelOf("orders") + .col("customerId", ColumnType.INT) + .col("x", ColumnType.INT) + ); + } + @Test public void testExcelODBCQ2() throws Exception { assertQuery( @@ -4395,8 +4431,8 @@ public void testJoinGroupByFilter() throws Exception { "(select country, sum(quantity) sum " + "from orders o " + "join customers c on c.customerId = o.customerId " + - "join orderDetails d on o.orderId = d.orderId" + - " where country ~ '^Z') where sum > 2", + "join orderDetails d on o.orderId = d.orderId " + + "where country ~ '^Z') where sum > 2", modelOf("orders").col("customerId", ColumnType.INT).col("orderId", ColumnType.INT).col("quantity", ColumnType.DOUBLE), modelOf("customers").col("customerId", ColumnType.INT).col("country", ColumnType.SYMBOL), modelOf("orderDetails").col("orderId", ColumnType.INT).col("comment", ColumnType.STRING)