Skip to content

Commit

Permalink
[SPARK-23315][SQL] failed to get output from canonicalized data sourc…
Browse files Browse the repository at this point in the history
…e v2 related plans

## What changes were proposed in this pull request?

`DataSourceV2Relation`  keeps a `fullOutput` and resolves the real output on demand by column name lookup. i.e.
```
lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name =>
  fullOutput.find(_.name == name).get
}
```

This will be broken after we canonicalize the plan, because all attribute names become "None", see https://github.com/apache/spark/blob/v2.3.0-rc1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala#L42

To fix this, `DataSourceV2Relation` should just keep `output`, and update the `output` when doing column pruning.

## How was this patch tested?

a new test case

Author: Wenchen Fan <wenchen@databricks.com>

Closes apache#20485 from cloud-fan/canonicalize.
  • Loading branch information
cloud-fan authored and Robert Kruszewski committed Feb 12, 2018
1 parent 9ebaef0 commit 1cb0993
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

import java.util.Objects

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.sources.v2.reader._

/**
Expand All @@ -28,9 +28,9 @@ import org.apache.spark.sql.sources.v2.reader._
trait DataSourceReaderHolder {

/**
* The full output of the data source reader, without column pruning.
* The output of the data source reader, w.r.t. column pruning.
*/
def fullOutput: Seq[AttributeReference]
def output: Seq[Attribute]

/**
* The held data source reader.
Expand All @@ -46,7 +46,7 @@ trait DataSourceReaderHolder {
case s: SupportsPushDownFilters => s.pushedFilters().toSet
case _ => Nil
}
Seq(fullOutput, reader.getClass, reader.readSchema(), filters)
Seq(output, reader.getClass, filters)
}

def canEqual(other: Any): Boolean
Expand All @@ -61,8 +61,4 @@ trait DataSourceReaderHolder {
override def hashCode(): Int = {
metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
}

lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name =>
fullOutput.find(_.name == name).get
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.sources.v2.reader._

case class DataSourceV2Relation(
fullOutput: Seq[AttributeReference],
output: Seq[AttributeReference],
reader: DataSourceReader)
extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder {

Expand All @@ -37,7 +37,7 @@ case class DataSourceV2Relation(
}

override def newInstance(): DataSourceV2Relation = {
copy(fullOutput = fullOutput.map(_.newInstance()))
copy(output = output.map(_.newInstance()))
}
}

Expand All @@ -46,8 +46,8 @@ case class DataSourceV2Relation(
* to the non-streaming relation.
*/
class StreamingDataSourceV2Relation(
fullOutput: Seq[AttributeReference],
reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) {
output: Seq[AttributeReference],
reader: DataSourceReader) extends DataSourceV2Relation(output, reader) {
override def isStreaming: Boolean = true
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,12 @@ import org.apache.spark.sql.types.StructType
* Physical plan node for scanning data from a data source.
*/
case class DataSourceV2ScanExec(
fullOutput: Seq[AttributeReference],
output: Seq[AttributeReference],
@transient reader: DataSourceReader)
extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {

override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]

override def producedAttributes: AttributeSet = AttributeSet(fullOutput)

override def outputPartitioning: physical.Partitioning = reader match {
case s: SupportsReportPartitioning =>
new DataSourcePartitioning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,33 +81,44 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel

// TODO: add more push down rules.

pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
// After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
RemoveRedundantProject(filterPushed)
RemoveRedundantProject(columnPruned)
}

// TODO: nested fields pruning
private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = {
private def pushDownRequiredColumns(
plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = {
plan match {
case Project(projectList, child) =>
case p @ Project(projectList, child) =>
val required = projectList.flatMap(_.references)
pushDownRequiredColumns(child, AttributeSet(required))
p.copy(child = pushDownRequiredColumns(child, AttributeSet(required)))

case Filter(condition, child) =>
case f @ Filter(condition, child) =>
val required = requiredByParent ++ condition.references
pushDownRequiredColumns(child, required)
f.copy(child = pushDownRequiredColumns(child, required))

case relation: DataSourceV2Relation => relation.reader match {
case reader: SupportsPushDownRequiredColumns =>
// TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now
// it's possible that the mutable reader being updated by someone else, and we need to
// always call `reader.pruneColumns` here to correct it.
// assert(relation.output.toStructType == reader.readSchema(),
// "Schema of data source reader does not match the relation plan.")

val requiredColumns = relation.output.filter(requiredByParent.contains)
reader.pruneColumns(requiredColumns.toStructType)

case _ =>
val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
val newOutput = reader.readSchema().map(_.name).map(nameToAttr)
relation.copy(output = newOutput)

case _ => relation
}

// TODO: there may be more operators that can be used to calculate the required columns. We
// can add more and more in the future.
case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet))
case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import test.org.apache.spark.sql.sources.v2._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -316,6 +316,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
val reader4 = getReader(q4)
assert(reader4.requiredSchema.fieldNames === Seq("i"))
}

test("SPARK-23315: get output from canonicalized data source v2 related plans") {
def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = {
val logical = df.queryExecution.optimizedPlan.collect {
case d: DataSourceV2Relation => d
}.head
assert(logical.canonicalized.output.length == numOutput)

val physical = df.queryExecution.executedPlan.collect {
case d: DataSourceV2ScanExec => d
}.head
assert(physical.canonicalized.output.length == numOutput)
}

val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
checkCanonicalizedOutput(df, 2)
checkCanonicalizedOutput(df.select('i), 1)
}
}

class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
Expand Down

0 comments on commit 1cb0993

Please sign in to comment.