[SPARK-39259][SQL][3.2] Evaluate timestamps consistently in subqueries
### What changes were proposed in this pull request?

Apply the optimizer rule ComputeCurrentTime consistently across subqueries.

This is a backport of apache#36654.

### Why are the changes needed?

At the moment, timestamp functions like now() can return different values within a single query if subqueries are involved.
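For illustration, a hypothetical repro sketch (not part of the patch; it assumes a local `SparkSession` named `spark`). Before this change the optimizer could substitute different timestamp literals for the outer query and the subquery, so the two columns below could disagree; after it they always match:

```scala
// Hypothetical repro: both columns should carry the exact same instant.
val df = spark.sql(
  "SELECT now() AS outer_now, (SELECT now()) AS subquery_now")
df.show(truncate = false)
```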

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

A new unit test was added.

Closes apache#36753 from olaky/SPARK-39259-spark_3_2.

Lead-authored-by: Ole Sasse <ole.sasse@databricks.com>
Co-authored-by: Josh Rosen <joshrosen@databricks.com>
Co-authored-by: Dongjoon Hyun <dongjoon@apache.org>
Signed-off-by: Max Gekk <max.gekk@gmail.com>
3 people committed Jun 9, 2022
1 parent cea3f3e commit 8361e79
Showing 3 changed files with 103 additions and 44 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

@@ -17,15 +17,17 @@

 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.collection.mutable
+import java.time.{Instant, LocalDateTime}
 
 import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ}
+import org.apache.spark.sql.catalyst.trees.TreePatternBits
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, instantToMicros, localDateTimeToMicros}
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -76,29 +78,30 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
  */
 object ComputeCurrentTime extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    val currentDates = mutable.Map.empty[String, Literal]
-    val timeExpr = CurrentTimestamp()
-    val timestamp = timeExpr.eval(EmptyRow).asInstanceOf[Long]
-    val currentTime = Literal.create(timestamp, timeExpr.dataType)
+    val instant = Instant.now()
+    val currentTimestampMicros = instantToMicros(instant)
+    val currentTime = Literal.create(currentTimestampMicros, TimestampType)
     val timezone = Literal.create(conf.sessionLocalTimeZone, StringType)
-    val localTimestamps = mutable.Map.empty[String, Literal]
 
-    plan.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
-      case currentDate @ CurrentDate(Some(timeZoneId)) =>
-        currentDates.getOrElseUpdate(timeZoneId, {
-          Literal.create(currentDate.eval().asInstanceOf[Int], DateType)
-        })
-      case CurrentTimestamp() | Now() => currentTime
-      case CurrentTimeZone() => timezone
-      case localTimestamp @ LocalTimestamp(Some(timeZoneId)) =>
-        localTimestamps.getOrElseUpdate(timeZoneId, {
-          Literal.create(localTimestamp.eval().asInstanceOf[Long], TimestampNTZType)
-        })
+    def transformCondition(treePatternbits: TreePatternBits): Boolean = {
+      treePatternbits.containsPattern(CURRENT_LIKE)
+    }
+
+    plan.transformDownWithSubqueriesAndPruning(transformCondition) {
+      case subQuery =>
+        subQuery.transformAllExpressionsWithPruning(transformCondition) {
+          case cd: CurrentDate =>
+            Literal.create(DateTimeUtils.microsToDays(currentTimestampMicros, cd.zoneId), DateType)
+          case CurrentTimestamp() | Now() => currentTime
+          case CurrentTimeZone() => timezone
+          case localTimestamp: LocalTimestamp =>
+            val asDateTime = LocalDateTime.ofInstant(instant, localTimestamp.zoneId)
+            Literal.create(localDateTimeToMicros(asDateTime), TimestampNTZType)
+        }
     }
   }
 }
 
 
 /**
  * Replaces the expression of CurrentDatabase with the current database name.
  * Replaces the expression of CurrentCatalog with the current catalog name.
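The heart of the new rule above is that the wall clock is read exactly once per query (`Instant.now()`) and every current-time flavor is derived from that single instant. A minimal sketch of that derivation, using only the `DateTimeUtils` helpers that appear in the diff (the zone IDs are arbitrary examples):

```scala
import java.time.{Instant, LocalDateTime, ZoneId}

import org.apache.spark.sql.catalyst.util.DateTimeUtils._

val instant = Instant.now()                        // read the clock exactly once
val micros = instantToMicros(instant)              // current_timestamp() / now()
val days = microsToDays(micros, ZoneId.of("UTC"))  // current_date() for a zone
val ntzMicros = localDateTimeToMicros(             // localtimestamp() for a zone
  LocalDateTime.ofInstant(instant, ZoneId.of("Asia/Kolkata")))
```

Since `microsToDays` and `localDateTimeToMicros` are pure functions of the instant and the zone, every occurrence in the main query and in all of its subqueries now resolves to the same moment, which also removes the need for the per-timezone mutable caches of the old code.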
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

@@ -473,20 +473,33 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
    * When the partial function does not apply to a given node, it is left unchanged.
    */
   def transformDownWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
+    transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f)
+  }
+
+  /**
+   * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries.
+   * Returns a copy of this node where the given partial function has been recursively applied
+   * first to this node, then this node's subqueries and finally this node's children.
+   * When the partial function does not apply to a given node, it is left unchanged.
+   */
+  def transformDownWithSubqueriesAndPruning(
+      cond: TreePatternBits => Boolean,
+      ruleId: RuleId = UnknownRuleId)
+      (f: PartialFunction[PlanType, PlanType]): PlanType = {
     val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
       override def isDefinedAt(x: PlanType): Boolean = true
 
       override def apply(plan: PlanType): PlanType = {
         val transformed = f.applyOrElse[PlanType, PlanType](plan, identity)
         transformed transformExpressionsDown {
           case planExpression: PlanExpression[PlanType] =>
-            val newPlan = planExpression.plan.transformDownWithSubqueries(f)
+            val newPlan = planExpression.plan.transformDownWithSubqueriesAndPruning(cond, ruleId)(f)
             planExpression.withNewPlan(newPlan)
         }
       }
     }
 
-    transformDown(g)
+    transformDownWithPruning(cond, ruleId)(g)
   }
 
   /**
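A short usage sketch for the new method (it mirrors the call in `ComputeCurrentTime` above; `replaceNow` and `currentTime` are illustrative names, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{CurrentTimestamp, Literal, Now}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE

// Rewrites now()/current_timestamp() in the plan and in every subquery plan,
// skipping subtrees whose pattern bits cannot contain a CURRENT_LIKE expression.
def replaceNow(plan: LogicalPlan, currentTime: Literal): LogicalPlan =
  plan.transformDownWithSubqueriesAndPruning(_.containsPattern(CURRENT_LIKE)) {
    case subQuery =>
      subQuery.transformAllExpressionsWithPruning(_.containsPattern(CURRENT_LIKE)) {
        case CurrentTimestamp() | Now() => currentTime
      }
  }
```

The pruning condition receives each node's `TreePatternBits`, so unlike plain `transformDownWithSubqueries` the traversal can skip whole plan subtrees, subquery plans included, that cannot match.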
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala

@@ -19,10 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer

 import java.time.{LocalDateTime, ZoneId}
 
+import scala.collection.JavaConverters.mapAsScalaMap
+import scala.concurrent.duration._
+
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, Literal, LocalTimestamp}
+import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, CurrentTimeZone, InSubquery, ListQuery, Literal, LocalTimestamp, Now}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.internal.SQLConf
@@ -41,11 +44,7 @@ class ComputeCurrentTimeSuite extends PlanTest {
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = (System.currentTimeMillis() + 1) * 1000
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Long]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Long]
-      e
-    }
+    val lits = literals[Long](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -59,11 +58,7 @@
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = DateTimeUtils.currentDate(ZoneId.systemDefault())
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Int]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Int]
-      e
-    }
+    val lits = literals[Int](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
@@ -73,13 +68,9 @@
test("SPARK-33469: Add current_timezone function") {
val in = Project(Seq(Alias(CurrentTimeZone(), "c")()), LocalRelation())
val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
val lits = new scala.collection.mutable.ArrayBuffer[String]
plan.transformAllExpressions { case e: Literal =>
lits += e.value.asInstanceOf[UTF8String].toString
e
}
val lits = literals[UTF8String](plan)
assert(lits.size == 1)
assert(lits.head == SQLConf.get.sessionLocalTimeZone)
assert(lits.head == UTF8String.fromString(SQLConf.get.sessionLocalTimeZone))
}

test("analyzer should replace localtimestamp with literals") {
@@ -92,14 +83,66 @@
     val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
     val max = DateTimeUtils.localDateTimeToMicros(LocalDateTime.now(zoneId))
 
-    val lits = new scala.collection.mutable.ArrayBuffer[Long]
-    plan.transformAllExpressions { case e: Literal =>
-      lits += e.value.asInstanceOf[Long]
-      e
-    }
+    val lits = literals[Long](plan)
     assert(lits.size == 2)
     assert(lits(0) >= min && lits(0) <= max)
     assert(lits(1) >= min && lits(1) <= max)
     assert(lits(0) == lits(1))
   }
 
+  test("analyzer should use equal timestamps across subqueries") {
+    val timestampInSubQuery = Project(Seq(Alias(LocalTimestamp(), "timestamp1")()), LocalRelation())
+    val listSubQuery = ListQuery(timestampInSubQuery)
+    val valueSearchedInSubQuery = Seq(Alias(LocalTimestamp(), "timestamp2")())
+    val inFilterWithSubQuery = InSubquery(valueSearchedInSubQuery, listSubQuery)
+    val input = Project(Nil, Filter(inFilterWithSubQuery, LocalRelation()))
+
+    val plan = Optimize.execute(input.analyze).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size == 3) // transformDownWithSubqueries covers the inner timestamp twice
+    assert(lits.toSet.size == 1)
+  }
+
+  test("analyzer should use consistent timestamps for different timezones") {
+    val localTimestamps = mapAsScalaMap(ZoneId.SHORT_IDS)
+      .map { case (zoneId, _) => Alias(LocalTimestamp(Some(zoneId)), zoneId)() }.toSeq
+    val input = Project(localTimestamps, LocalRelation())
+
+    val plan = Optimize.execute(input).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size === localTimestamps.size)
+    // there are timezones with a 30 or 45 minute offset
+    val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
+    assert(offsetsFromQuarterHour.size == 1)
+  }
+
+  test("analyzer should use consistent timestamps for different timestamp functions") {
+    val differentTimestamps = Seq(
+      Alias(CurrentTimestamp(), "currentTimestamp")(),
+      Alias(Now(), "now")(),
+      Alias(LocalTimestamp(Some("PLT")), "localTimestampWithTimezone")()
+    )
+    val input = Project(differentTimestamps, LocalRelation())
+
+    val plan = Optimize.execute(input).asInstanceOf[Project]
+
+    val lits = literals[Long](plan)
+    assert(lits.size === differentTimestamps.size)
+    // there are timezones with a 30 or 45 minute offset
+    val offsetsFromQuarterHour = lits.map( _ % Duration(15, MINUTES).toMicros).toSet
+    assert(offsetsFromQuarterHour.size == 1)
+  }
+
+  private def literals[T](plan: LogicalPlan): scala.collection.mutable.ArrayBuffer[T] = {
+    val literals = new scala.collection.mutable.ArrayBuffer[T]
+    plan.transformWithSubqueries { case subQuery =>
+      subQuery.transformAllExpressions { case expression: Literal =>
+        literals += expression.value.asInstanceOf[T]
+        expression
+      }
+    }
+    literals
+  }
 }
