Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into unsafe-shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 17, 2015
2 parents cbea80b + 031d7d4 commit 7876f31
Show file tree
Hide file tree
Showing 13 changed files with 115 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ case class UnresolvedRelation(
/**
* Holds the name of an attribute that has yet to be resolved.
*/
case class UnresolvedAttribute(nameParts: Seq[String])
extends Attribute with trees.LeafNode[Expression] {
case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute {

def name: String =
nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
Expand Down Expand Up @@ -96,7 +95,7 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E
* Represents all of the input attributes to a given relational operator, for example in
* "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis.
*/
trait Star extends NamedExpression with trees.LeafNode[Expression] {
abstract class Star extends LeafExpression with NamedExpression {
self: Product =>

override def name: String = throw new UnresolvedException(this, "name")
Expand Down Expand Up @@ -151,7 +150,7 @@ case class UnresolvedStar(table: Option[String]) extends Star {
* @param names the names to be associated with each output of computing [[child]].
*/
case class MultiAlias(child: Expression, names: Seq[String])
extends NamedExpression with trees.UnaryNode[Expression] {
extends UnaryExpression with NamedExpression {

override def name: String = throw new UnresolvedException(this, "name")

Expand Down Expand Up @@ -210,8 +209,7 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
/**
* Holds the expression that has yet to be aliased.
*/
case class UnresolvedAlias(child: Expression) extends NamedExpression
with trees.UnaryNode[Expression] {
case class UnresolvedAlias(child: Expression) extends UnaryExpression with NamedExpression {

override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
* the layout of intermediate tuples, BindReferences should be run after all such transformations.
*/
case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
extends NamedExpression with trees.LeafNode[Expression] {
extends LeafExpression with NamedExpression {

override def toString: String = s"input[$ordinal]"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ case object Descending extends SortDirection
* An expression that can be used to sort a tuple. This class extends expression primarily so that
* transformations over expression will descend into its child.
*/
case class SortOrder(child: Expression, direction: SortDirection) extends Expression
with trees.UnaryNode[Expression] {
case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression {

/** Sort order is not foldable because we don't have an eval for it. */
override def foldable: Boolean = false

override def dataType: DataType = child.dataType
override def nullable: Boolean = child.nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@ package org.apache.spark.sql.catalyst.expressions
import com.clearspring.analytics.stream.cardinality.HyperLogLog

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet

abstract class AggregateExpression extends Expression {
trait AggregateExpression extends Expression {
self: Product =>

/**
* Aggregate expressions should not be foldable.
*/
override def foldable: Boolean = false

/**
* Creates a new instance that can be used to compute this aggregate expression for a group
* of input rows/
Expand Down Expand Up @@ -60,7 +64,7 @@ case class SplitEvaluation(
* An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples.
* These partial evaluations can then be combined to compute the actual answer.
*/
abstract class PartialAggregate extends AggregateExpression {
trait PartialAggregate extends AggregateExpression {
self: Product =>

/**
Expand All @@ -74,7 +78,7 @@ abstract class PartialAggregate extends AggregateExpression {
* [[AggregateExpression]] with an algorithm that will be used to compute one specific result.
*/
abstract class AggregateFunction
extends AggregateExpression with Serializable with trees.LeafNode[Expression] {
extends LeafExpression with AggregateExpression with Serializable {
self: Product =>

/** Base should return the generic aggregate expression that this function is computing */
Expand All @@ -91,7 +95,7 @@ abstract class AggregateFunction
}
}

case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
case class Min(child: Expression) extends UnaryExpression with PartialAggregate {

override def nullable: Boolean = true
override def dataType: DataType = child.dataType
Expand Down Expand Up @@ -124,7 +128,7 @@ case class MinFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def eval(input: InternalRow): Any = currentMin.value
}

case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
case class Max(child: Expression) extends UnaryExpression with PartialAggregate {

override def nullable: Boolean = true
override def dataType: DataType = child.dataType
Expand Down Expand Up @@ -157,7 +161,7 @@ case class MaxFunction(expr: Expression, base: AggregateExpression) extends Aggr
override def eval(input: InternalRow): Any = currentMax.value
}

case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
case class Count(child: Expression) extends UnaryExpression with PartialAggregate {

override def nullable: Boolean = false
override def dataType: LongType.type = LongType
Expand Down Expand Up @@ -310,7 +314,7 @@ private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] {
}

case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
extends UnaryExpression with AggregateExpression {

override def nullable: Boolean = false
override def dataType: DataType = HyperLogLogUDT
Expand Down Expand Up @@ -340,7 +344,7 @@ case class ApproxCountDistinctPartitionFunction(
}

case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
extends UnaryExpression with AggregateExpression {

override def nullable: Boolean = false
override def dataType: LongType.type = LongType
Expand Down Expand Up @@ -368,7 +372,7 @@ case class ApproxCountDistinctMergeFunction(
}

case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends PartialAggregate with trees.UnaryNode[Expression] {
extends UnaryExpression with PartialAggregate {

override def nullable: Boolean = false
override def dataType: LongType.type = LongType
Expand All @@ -386,7 +390,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this)
}

case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
case class Average(child: Expression) extends UnaryExpression with PartialAggregate {

override def prettyName: String = "avg"

Expand Down Expand Up @@ -479,7 +483,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
}
}

case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
case class Sum(child: Expression) extends UnaryExpression with PartialAggregate {

override def nullable: Boolean = true

Expand Down Expand Up @@ -606,8 +610,7 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression)
}
}

case class SumDistinct(child: Expression)
extends PartialAggregate with trees.UnaryNode[Expression] {
case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate {

def this() = this(null)
override def nullable: Boolean = true
Expand Down Expand Up @@ -701,7 +704,7 @@ case class CombineSetsAndSumFunction(
}
}

case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
case class First(child: Expression) extends UnaryExpression with PartialAggregate {
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
override def toString: String = s"FIRST($child)"
Expand Down Expand Up @@ -729,7 +732,7 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
override def eval(input: InternalRow): Any = result
}

case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
case class Last(child: Expression) extends UnaryExpression with PartialAggregate {
override def references: AttributeSet = child.references
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@ import org.apache.spark.sql.types._
* requested. The attributes produced by this function will be automatically copied anytime rules
* result in changes to the Generator or its children.
*/
abstract class Generator extends Expression {
self: Product =>
trait Generator extends Expression { self: Product =>

// TODO ideally we should return the type of ArrayType(StructType),
// however, we don't keep the output field names in the Generator.
override def dataType: DataType = throw new UnsupportedOperationException

override def foldable: Boolean = false

override def nullable: Boolean = false

/**
Expand Down Expand Up @@ -99,8 +100,9 @@ case class UserDefinedGenerator(
/**
* Given an input array produces a sequence of rows for each value in the array.
*/
case class Explode(child: Expression)
extends Generator with trees.UnaryNode[Expression] {
case class Explode(child: Expression) extends UnaryExpression with Generator {

override def children: Seq[Expression] = child :: Nil

override def checkInputDataTypes(): TypeCheckResult = {
if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
Expand All @@ -127,6 +129,4 @@ case class Explode(child: Expression)
else inputMap.map { case (k, v) => InternalRow(k, v) }
}
}

override def toString: String = s"explode($child)"
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,13 @@ object NamedExpression {
*/
case class ExprId(id: Long)

abstract class NamedExpression extends Expression {
self: Product =>
/**
* An [[Expression]] that is named.
*/
trait NamedExpression extends Expression { self: Product =>

/** We should never fold named expressions in order to not remove the alias. */
override def foldable: Boolean = false

def name: String
def exprId: ExprId
Expand Down Expand Up @@ -78,8 +83,7 @@ abstract class NamedExpression extends Expression {
}
}

abstract class Attribute extends NamedExpression {
self: Product =>
abstract class Attribute extends LeafExpression with NamedExpression { self: Product =>

override def references: AttributeSet = AttributeSet(this)

Expand Down Expand Up @@ -110,7 +114,7 @@ case class Alias(child: Expression, name: String)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifiers: Seq[String] = Nil,
val explicitMetadata: Option[Metadata] = None)
extends NamedExpression with trees.UnaryNode[Expression] {
extends UnaryExpression with NamedExpression {

// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
override lazy val resolved =
Expand Down Expand Up @@ -172,7 +176,8 @@ case class AttributeReference(
nullable: Boolean = true,
override val metadata: Metadata = Metadata.empty)(
val exprId: ExprId = NamedExpression.newExprId,
val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] {
val qualifiers: Seq[String] = Nil)
extends Attribute {

/**
* Returns true iff the expression id is the same for both attributes.
Expand Down Expand Up @@ -242,7 +247,7 @@ case class AttributeReference(
* A place holder used when printing expressions without debugging information such as the
* expression id or the unresolved indicator.
*/
case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
case class PrettyAttribute(name: String) extends Attribute {

override def toString: String = name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
case l: Literal => l

// Fold expressions that are foldable.
case e if e.foldable => Literal.create(e.eval(null), e.dataType)
case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)

// Fold "literal in (item1, item2, ..., literal, ...)" into true directly.
case In(Literal(v, _), list) if list.exists {
Expand All @@ -361,7 +361,7 @@ object OptimizeIn extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(null))
val hSet = list.map(e => e.eval(EmptyRow))
InSet(v, HashSet() ++ hSet)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.trees


abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
Expand Down Expand Up @@ -277,15 +276,21 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
/**
* A logical plan node with no children.
*/
abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
abstract class LeafNode extends LogicalPlan {
self: Product =>

override def children: Seq[LogicalPlan] = Nil
}

/**
* A logical plan node with single child.
*/
abstract class UnaryNode extends LogicalPlan with trees.UnaryNode[LogicalPlan] {
abstract class UnaryNode extends LogicalPlan {
self: Product =>

def child: LogicalPlan

override def children: Seq[LogicalPlan] = child :: Nil
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,19 +452,3 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
s"$nodeName(${args.mkString(",")})"
}
}


/**
* A [[TreeNode]] with no children.
*/
trait LeafNode[BaseType <: TreeNode[BaseType]] {
def children: Seq[BaseType] = Nil
}

/**
* A [[TreeNode]] with a single [[child]].
*/
trait UnaryNode[BaseType <: TreeNode[BaseType]] {
def child: BaseType
def children: Seq[BaseType] = child :: Nil
}
Loading

0 comments on commit 7876f31

Please sign in to comment.