Skip to content

Commit 80c079f

Browse files
committed
Replace date_trunc in comparison with a range predicate
Range predicates (ComparisonExpression, BetweenPredicate) can be transformed into a `TupleDomain` and thus help with predicate pushdown. Range-based `TupleDomain` representation is critical for connectors which have min/max-based metadata (like Iceberg data files and manifest lists), as ranges allow for intersection tests, something that is hard to do in a generic manner for `ConnectorExpression`.
1 parent f661c52 commit 80c079f

File tree

4 files changed

+343
-13
lines changed

4 files changed

+343
-13
lines changed

core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@
238238
import io.trino.sql.planner.iterative.rule.TransformUncorrelatedInPredicateSubqueryToSemiJoin;
239239
import io.trino.sql.planner.iterative.rule.TransformUncorrelatedSubqueryToJoin;
240240
import io.trino.sql.planner.iterative.rule.UnwrapCastInComparison;
241+
import io.trino.sql.planner.iterative.rule.UnwrapDateTruncInComparison;
241242
import io.trino.sql.planner.iterative.rule.UnwrapRowSubscript;
242243
import io.trino.sql.planner.iterative.rule.UnwrapSingleColumnRowInApply;
243244
import io.trino.sql.planner.iterative.rule.UseNonPartitionedJoinLookupSource;
@@ -370,6 +371,7 @@ public PlanOptimizers(
370371
.addAll(new UnwrapRowSubscript().rules())
371372
.addAll(new PushCastIntoRow().rules())
372373
.addAll(new UnwrapCastInComparison(plannerContext, typeAnalyzer).rules())
374+
.addAll(new UnwrapDateTruncInComparison(plannerContext, typeAnalyzer).rules())
373375
.addAll(new RemoveDuplicateConditions(metadata).rules())
374376
.addAll(new CanonicalizeExpressions(plannerContext, typeAnalyzer).rules())
375377
.addAll(new RemoveRedundantDateTrunc(plannerContext, typeAnalyzer).rules())

core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/UnwrapCastInComparison.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,12 @@ private static Instant getInstantWithTruncation(TimestampWithTimeZoneType type,
545545
.plus(longTimestampWithTimeZone.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND, ChronoUnit.NANOS);
546546
}
547547

548-
private static Expression falseIfNotNull(Expression argument)
548+
public static Expression falseIfNotNull(Expression argument)
549549
{
550550
return and(new IsNullPredicate(argument), new NullLiteral());
551551
}
552552

553-
private static Expression trueIfNotNull(Expression argument)
553+
public static Expression trueIfNotNull(Expression argument)
554554
{
555555
return or(new IsNotNullPredicate(argument), new NullLiteral());
556556
}
Lines changed: 328 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,328 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.sql.planner.iterative.rule;
15+
16+
import com.google.common.base.Enums;
17+
import com.google.common.base.Throwables;
18+
import com.google.common.collect.ImmutableList;
19+
import com.google.common.math.LongMath;
20+
import io.airlift.slice.Slice;
21+
import io.trino.Session;
22+
import io.trino.metadata.ResolvedFunction;
23+
import io.trino.spi.TrinoException;
24+
import io.trino.spi.function.InvocationConvention;
25+
import io.trino.spi.type.LongTimestamp;
26+
import io.trino.spi.type.TimestampType;
27+
import io.trino.spi.type.TimestampWithTimeZoneType;
28+
import io.trino.spi.type.Type;
29+
import io.trino.spi.type.VarcharType;
30+
import io.trino.sql.InterpretedFunctionInvoker;
31+
import io.trino.sql.PlannerContext;
32+
import io.trino.sql.planner.ExpressionInterpreter;
33+
import io.trino.sql.planner.LiteralEncoder;
34+
import io.trino.sql.planner.NoOpSymbolResolver;
35+
import io.trino.sql.planner.TypeAnalyzer;
36+
import io.trino.sql.planner.TypeProvider;
37+
import io.trino.sql.tree.BetweenPredicate;
38+
import io.trino.sql.tree.Cast;
39+
import io.trino.sql.tree.ComparisonExpression;
40+
import io.trino.sql.tree.Expression;
41+
import io.trino.sql.tree.ExpressionTreeRewriter;
42+
import io.trino.sql.tree.FunctionCall;
43+
import io.trino.sql.tree.IsNotNullPredicate;
44+
import io.trino.sql.tree.IsNullPredicate;
45+
import io.trino.sql.tree.NodeRef;
46+
import io.trino.sql.tree.NotExpression;
47+
import io.trino.sql.tree.NullLiteral;
48+
49+
import java.lang.invoke.MethodHandle;
50+
import java.time.LocalDate;
51+
import java.time.LocalDateTime;
52+
import java.time.ZoneOffset;
53+
import java.util.Locale;
54+
import java.util.Map;
55+
import java.util.Optional;
56+
57+
import static com.google.common.base.Verify.verify;
58+
import static io.trino.metadata.ResolvedFunction.extractFunctionName;
59+
import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
60+
import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.NEVER_NULL;
61+
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
62+
import static io.trino.spi.type.BooleanType.BOOLEAN;
63+
import static io.trino.spi.type.DateType.DATE;
64+
import static io.trino.spi.type.TimestampType.createTimestampType;
65+
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
66+
import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND;
67+
import static io.trino.sql.ExpressionUtils.isEffectivelyLiteral;
68+
import static io.trino.sql.ExpressionUtils.or;
69+
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
70+
import static io.trino.sql.planner.iterative.rule.UnwrapCastInComparison.falseIfNotNull;
71+
import static io.trino.sql.planner.iterative.rule.UnwrapCastInComparison.trueIfNotNull;
72+
import static io.trino.sql.tree.BooleanLiteral.TRUE_LITERAL;
73+
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN;
74+
import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL;
75+
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN;
76+
import static io.trino.sql.tree.ComparisonExpression.Operator.LESS_THAN_OR_EQUAL;
77+
import static io.trino.type.DateTimes.PICOSECONDS_PER_MICROSECOND;
78+
import static io.trino.type.DateTimes.scaleFactor;
79+
import static java.lang.Math.floorDiv;
80+
import static java.lang.Math.floorMod;
81+
import static java.lang.Math.toIntExact;
82+
import static java.math.RoundingMode.UNNECESSARY;
83+
import static java.util.Objects.requireNonNull;
84+
85+
/**
86+
* Given constant temporal unit U and a constant date/time expression t that's rounded to unit,
87+
* converts expression of the form
88+
* <pre>
89+
* date_trunc(unit, date_time) = t
90+
* </pre>
91+
* <p>
92+
* into
93+
* <pre>
94+
* date_time BETWEEN t AND (t + unit)
95+
* </pre>
96+
* <p>
97+
* It also applies to comparison operators other than equality and detects expressions that
98+
* are known to be true or false, e.g. {@code date_trunc('month', ...) = DATE '2005-09-10'}
99+
* is known to be false.
100+
*
101+
* @see UnwrapCastInComparison
102+
*/
103+
public class UnwrapDateTruncInComparison
104+
extends ExpressionRewriteRuleSet
105+
{
106+
public UnwrapDateTruncInComparison(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer)
107+
{
108+
super(createRewrite(plannerContext, typeAnalyzer));
109+
}
110+
111+
private static ExpressionRewriter createRewrite(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer)
112+
{
113+
requireNonNull(plannerContext, "plannerContext is null");
114+
requireNonNull(typeAnalyzer, "typeAnalyzer is null");
115+
116+
return (expression, context) -> unwrapDateTrunc(context.getSession(), plannerContext, typeAnalyzer, context.getSymbolAllocator().getTypes(), expression);
117+
}
118+
119+
private static Expression unwrapDateTrunc(Session session,
120+
PlannerContext plannerContext,
121+
TypeAnalyzer typeAnalyzer,
122+
TypeProvider types,
123+
Expression expression)
124+
{
125+
return ExpressionTreeRewriter.rewriteWith(new Visitor(plannerContext, typeAnalyzer, session, types), expression);
126+
}
127+
128+
private static class Visitor
129+
extends io.trino.sql.tree.ExpressionRewriter<Void>
130+
{
131+
private final PlannerContext plannerContext;
132+
private final TypeAnalyzer typeAnalyzer;
133+
private final Session session;
134+
private final TypeProvider types;
135+
private final InterpretedFunctionInvoker functionInvoker;
136+
private final LiteralEncoder literalEncoder;
137+
138+
public Visitor(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, Session session, TypeProvider types)
139+
{
140+
this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
141+
this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null");
142+
this.session = requireNonNull(session, "session is null");
143+
this.types = requireNonNull(types, "types is null");
144+
this.functionInvoker = new InterpretedFunctionInvoker(plannerContext.getFunctionManager());
145+
this.literalEncoder = new LiteralEncoder(plannerContext);
146+
}
147+
148+
@Override
149+
public Expression rewriteComparisonExpression(ComparisonExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
150+
{
151+
ComparisonExpression expression = (ComparisonExpression) treeRewriter.defaultRewrite((Expression) node, null);
152+
return unwrapDateTrunc(expression);
153+
}
154+
155+
// Simplify `date_trunc(unit, d) ? value`
156+
private Expression unwrapDateTrunc(ComparisonExpression expression)
157+
{
158+
// Expect date_trunc on the left side and value on the right side of the comparison.
159+
// This is provided by CanonicalizeExpressionRewriter.
160+
161+
if (!(expression.getLeft() instanceof FunctionCall call) ||
162+
!extractFunctionName(call.getName()).equals("date_trunc") ||
163+
call.getArguments().size() != 2) {
164+
return expression;
165+
}
166+
167+
Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(session, types, expression);
168+
Expression unitExpression = call.getArguments().get(0);
169+
if (!(expressionTypes.get(NodeRef.of(unitExpression)) instanceof VarcharType) || !isEffectivelyLiteral(plannerContext, session, unitExpression)) {
170+
return expression;
171+
}
172+
Slice unitName = (Slice) new ExpressionInterpreter(unitExpression, plannerContext, session, expressionTypes)
173+
.optimize(NoOpSymbolResolver.INSTANCE);
174+
if (unitName == null) {
175+
return expression;
176+
}
177+
178+
Expression argument = call.getArguments().get(1);
179+
Type argumentType = expressionTypes.get(NodeRef.of(argument));
180+
181+
Type rightType = expressionTypes.get(NodeRef.of(expression.getRight()));
182+
verify(argumentType.equals(rightType), "Mismatched types: %s and %s", argumentType, rightType);
183+
184+
Object right = new ExpressionInterpreter(expression.getRight(), plannerContext, session, expressionTypes)
185+
.optimize(NoOpSymbolResolver.INSTANCE);
186+
187+
if (right == null || right instanceof NullLiteral) {
188+
return switch (expression.getOperator()) {
189+
case EQUAL, NOT_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL -> new Cast(new NullLiteral(), toSqlType(BOOLEAN));
190+
case IS_DISTINCT_FROM -> new IsNotNullPredicate(argument);
191+
};
192+
}
193+
194+
if (right instanceof Expression) {
195+
return expression;
196+
}
197+
if (rightType instanceof TimestampWithTimeZoneType) {
198+
// Cannot replace with a range due to how date_trunc operates on value's local date/time.
199+
// I.e. unwrapping is possible only when values are all of some fixed zone and the zone is known.
200+
return expression;
201+
}
202+
203+
ResolvedFunction resolvedFunction = plannerContext.getMetadata().decodeFunction(call.getName());
204+
205+
Optional<SupportedUnit> unitIfSupported = Enums.getIfPresent(SupportedUnit.class, unitName.toStringUtf8().toUpperCase(Locale.ENGLISH)).toJavaUtil();
206+
if (unitIfSupported.isEmpty()) {
207+
return expression;
208+
}
209+
SupportedUnit unit = unitIfSupported.get();
210+
if (unit == SupportedUnit.DAY && rightType == DATE) {
211+
// case handled by CanonicalizeExpressionRewriter
212+
return expression;
213+
}
214+
215+
Object rangeLow = functionInvoker.invoke(resolvedFunction, session.toConnectorSession(), ImmutableList.of(unitName, right));
216+
int compare = compare(rightType, rangeLow, right);
217+
verify(compare <= 0, "Truncation of %s value %s resulted in a bigger value %s", rightType, right, rangeLow);
218+
boolean rightValueAtRangeLow = compare == 0;
219+
220+
return switch (expression.getOperator()) {
221+
case EQUAL -> {
222+
if (!rightValueAtRangeLow) {
223+
yield falseIfNotNull(argument);
224+
}
225+
yield between(argument, rightType, rangeLow, calculateRangeEndInclusive(rangeLow, rightType, unit));
226+
}
227+
case NOT_EQUAL -> {
228+
if (!rightValueAtRangeLow) {
229+
yield trueIfNotNull(argument);
230+
}
231+
yield new NotExpression(between(argument, rightType, rangeLow, calculateRangeEndInclusive(rangeLow, rightType, unit)));
232+
}
233+
case IS_DISTINCT_FROM -> {
234+
if (!rightValueAtRangeLow) {
235+
yield TRUE_LITERAL;
236+
}
237+
yield or(
238+
new IsNullPredicate(argument),
239+
new NotExpression(between(argument, rightType, rangeLow, calculateRangeEndInclusive(rangeLow, rightType, unit))));
240+
}
241+
case LESS_THAN -> {
242+
if (rightValueAtRangeLow) {
243+
yield new ComparisonExpression(LESS_THAN, argument, toExpression(rangeLow, rightType));
244+
}
245+
yield new ComparisonExpression(LESS_THAN_OR_EQUAL, argument, toExpression(calculateRangeEndInclusive(rangeLow, rightType, unit), rightType));
246+
}
247+
case LESS_THAN_OR_EQUAL -> new ComparisonExpression(LESS_THAN_OR_EQUAL, argument, toExpression(calculateRangeEndInclusive(rangeLow, rightType, unit), rightType));
248+
case GREATER_THAN -> new ComparisonExpression(GREATER_THAN, argument, toExpression(calculateRangeEndInclusive(rangeLow, rightType, unit), rightType));
249+
case GREATER_THAN_OR_EQUAL -> {
250+
if (rightValueAtRangeLow) {
251+
yield new ComparisonExpression(GREATER_THAN_OR_EQUAL, argument, toExpression(rangeLow, rightType));
252+
}
253+
yield new ComparisonExpression(GREATER_THAN, argument, toExpression(calculateRangeEndInclusive(rangeLow, rightType, unit), rightType));
254+
}
255+
};
256+
}
257+
258+
private Object calculateRangeEndInclusive(Object rangeStart, Type type, SupportedUnit rangeUnit)
259+
{
260+
if (type == DATE) {
261+
LocalDate date = LocalDate.ofEpochDay((long) rangeStart);
262+
LocalDate endExclusive = switch (rangeUnit) {
263+
case DAY -> throw new UnsupportedOperationException("Unsupported type and unit: %s, %s".formatted(type, rangeUnit));
264+
case MONTH -> date.plusMonths(1);
265+
case YEAR -> date.plusYears(1);
266+
};
267+
return endExclusive.toEpochDay() - 1;
268+
}
269+
if (type instanceof TimestampType timestampType) {
270+
if (timestampType.isShort()) {
271+
long epochMicros = (long) rangeStart;
272+
long epochSecond = floorDiv(epochMicros, MICROSECONDS_PER_SECOND);
273+
int microOfSecond = floorMod(epochMicros, MICROSECONDS_PER_SECOND);
274+
verify(microOfSecond == 0, "Unexpected micros, value should be rounded to %s: %s", rangeUnit, microOfSecond);
275+
LocalDateTime dateTime = LocalDateTime.ofEpochSecond(epochSecond, 0, ZoneOffset.UTC);
276+
LocalDateTime endExclusive = switch (rangeUnit) {
277+
case DAY -> dateTime.plusDays(1);
278+
case MONTH -> dateTime.plusMonths(1);
279+
case YEAR -> dateTime.plusYears(1);
280+
};
281+
long endExclusiveMicros = endExclusive.toEpochSecond(ZoneOffset.UTC) * MICROSECONDS_PER_SECOND
282+
+ LongMath.divide(endExclusive.getNano(), NANOSECONDS_PER_MICROSECOND, UNNECESSARY);
283+
return endExclusiveMicros - scaleFactor(timestampType.getPrecision(), 6);
284+
}
285+
LongTimestamp longTimestamp = (LongTimestamp) rangeStart;
286+
verify(longTimestamp.getPicosOfMicro() == 0, "Unexpected picos in %s, value not rounded to %s", rangeStart, rangeUnit);
287+
long endInclusiveMicros = (long) calculateRangeEndInclusive(longTimestamp.getEpochMicros(), createTimestampType(6), rangeUnit);
288+
return new LongTimestamp(endInclusiveMicros, toIntExact(PICOSECONDS_PER_MICROSECOND - scaleFactor(timestampType.getPrecision(), 12)));
289+
}
290+
throw new UnsupportedOperationException("Unsupported type: " + type);
291+
}
292+
293+
private BetweenPredicate between(Expression argument, Type type, Object minInclusive, Object maxInclusive)
294+
{
295+
return new BetweenPredicate(
296+
argument,
297+
toExpression(minInclusive, type),
298+
toExpression(maxInclusive, type));
299+
}
300+
301+
private Expression toExpression(Object value, Type type)
302+
{
303+
return literalEncoder.toExpression(session, value, type);
304+
}
305+
306+
private int compare(Type type, Object first, Object second)
307+
{
308+
requireNonNull(first, "first is null");
309+
requireNonNull(second, "second is null");
310+
// choice of placing unordered values first or last does not matter for this code
311+
MethodHandle comparisonOperator = plannerContext.getTypeOperators().getComparisonUnorderedLastOperator(type, InvocationConvention.simpleConvention(FAIL_ON_NULL, NEVER_NULL, NEVER_NULL));
312+
try {
313+
return toIntExact((long) comparisonOperator.invoke(first, second));
314+
}
315+
catch (Throwable throwable) {
316+
Throwables.throwIfUnchecked(throwable);
317+
throw new TrinoException(GENERIC_INTERNAL_ERROR, throwable);
318+
}
319+
}
320+
}
321+
322+
private enum SupportedUnit
323+
{
324+
DAY,
325+
MONTH,
326+
YEAR,
327+
}
328+
}

0 commit comments

Comments
 (0)