Skip to content

Commit

Permalink
[SPARK-18711][SQL] should disable subexpression elimination for Lambd…
Browse files Browse the repository at this point in the history
…aVariable

## What changes were proposed in this pull request?

This is kind of a long-standing bug, it's hidden until apache#15780 , which may add `AssertNotNull` on top of `LambdaVariable` and thus enables subexpression elimination.

However, subexpression elimination will evaluate the common expressions at the beginning, which is invalid for `LambdaVariable`. `LambdaVariable` usually represents loop variable, which can't be evaluated ahead of the loop.

This PR skips expressions containing `LambdaVariable` when doing subexpression elimination.

## How was this patch tested?

updated test in `DatasetAggregatorSuite`

Author: Wenchen Fan <wenchen@databricks.com>

Closes apache#16143 from cloud-fan/aggregator.
  • Loading branch information
cloud-fan authored and uzadude committed Jan 27, 2017
1 parent 7afdae9 commit 9c62d18
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable

/**
* This class is used to compute equality of (sub)expression trees. Expressions can be added
Expand Down Expand Up @@ -72,7 +73,10 @@ class EquivalentExpressions {
root: Expression,
ignoreLeaf: Boolean = true,
skipReferenceToExpressions: Boolean = true): Unit = {
val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
val skip = (root.isInstanceOf[LeafExpression] && ignoreLeaf) ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
root.find(_.isInstanceOf[LambdaVariable]).isDefined
// There are some special expressions that we should not recurse into children.
// 1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
// 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
Expand Down
Expand Up @@ -92,13 +92,13 @@ object NameAgg extends Aggregator[AggData, String, String] {
}


object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[(Int, Int)]] {
def zero: Seq[Int] = Nil
def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
def finish(r: Seq[Int]): Seq[Int] = r
def finish(r: Seq[Int]): Seq[(Int, Int)] = r.map(i => i -> i)
override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
override def outputEncoder: Encoder[Seq[(Int, Int)]] = ExpressionEncoder()
}


Expand Down Expand Up @@ -281,7 +281,7 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {

checkDataset(
ds.groupByKey(_.b).agg(SeqAgg.toColumn),
"a" -> Seq(1, 2)
"a" -> Seq(1 -> 1, 2 -> 2)
)
}

Expand Down

0 comments on commit 9c62d18

Please sign in to comment.