[SPARK-24073][SQL] Rename DataReaderFactory to InputPartition.
Renames:
* `DataReaderFactory` to `InputPartition`
* `DataReader` to `InputPartitionReader`
* `createDataReaderFactories` to `planInputPartitions`
* `createUnsafeDataReaderFactories` to `planUnsafeInputPartitions`
* `createBatchDataReaderFactories` to `planBatchInputPartitions`

This fixes the changes in SPARK-23219, which renamed ReadTask to
DataReaderFactory. The intent of that change was to make the read and
write APIs match (the write side uses DataWriterFactory), but the
underlying problem is that the two classes are not equivalent.

ReadTask/DataReader function as Iterable/Iterator. One InputPartition is
a specific partition of the data to be read, in contrast to
DataWriterFactory where the same factory instance is used in all write
tasks. InputPartition's purpose is to manage the lifecycle of the
associated reader, which is now called InputPartitionReader, with an
explicit create operation to mirror the close operation. This was no
longer clear from the API because DataReaderFactory appeared to be more
generic than it is, and it was not clear why a set of them is produced
for a read.
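
As a minimal sketch of the renamed API (the RangeInputPartition and
RangeInputPartitionReader names below are hypothetical, not part of this
commit), one partition and its reader look roughly like this:

```java
import java.io.IOException;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;

// Hypothetical partition: produces the integers in [start, end).
class RangeInputPartition implements InputPartition<Row> {
  private final int start;
  private final int end;

  RangeInputPartition(int start, int end) {
    this.start = start;
    this.end = end;
  }

  // Like Iterable.iterator(): each call creates a fresh reader for this partition.
  @Override
  public InputPartitionReader<Row> createPartitionReader() {
    return new RangeInputPartitionReader(start, end);
  }
}

// Like Iterator, plus Closeable so the reader's lifecycle has an explicit end.
class RangeInputPartitionReader implements InputPartitionReader<Row> {
  private final int end;
  private int current;

  RangeInputPartitionReader(int start, int end) {
    this.current = start - 1;
    this.end = end;
  }

  @Override
  public boolean next() {
    current += 1;
    return current < end;
  }

  @Override
  public Row get() {
    return RowFactory.create(current);
  }

  @Override
  public void close() throws IOException {
    // nothing to release in this sketch
  }
}
```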

Existing tests have been updated to use the new names.

Author: Ryan Blue <blue@apache.org>

Closes apache#21145 from rdblue/SPARK-24073-revert-data-reader-factory-rename.
rdblue committed Aug 29, 2018
1 parent 6fbee72 commit 8f986c3
Showing 13 changed files with 123 additions and 123 deletions.

DataSourceReader.java
@@ -31,8 +31,8 @@
* {@link ReadSupport#createReader(DataSourceOptions)} or
* {@link ReadSupportWithSchema#createReader(StructType, DataSourceOptions)}.
* It can mix in various query optimization interfaces to speed up the data scan. The actual scan
* logic is delegated to {@link DataReaderFactory}s that are returned by
* {@link #createDataReaderFactories()}.
* logic is delegated to {@link InputPartition}s that are returned by
* {@link #planInputPartitions()}.
*
* There are mainly 3 kinds of query optimizations:
* 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column
@@ -62,8 +62,8 @@ public interface DataSourceReader {
StructType readSchema();

/**
* Returns a list of reader factories. Each factory is responsible for creating a data reader to
* output data for one RDD partition. That means the number of factories returned here is same as
* Returns a list of read tasks. Each task is responsible for creating a data reader to
* output data for one RDD partition. That means the number of tasks returned here is same as
* the number of RDD partitions this scan outputs.
*
* Note that, this may not be a full scan if the data source reader mixes in other optimization
@@ -73,5 +73,5 @@ public interface DataSourceReader {
* If this method fails (by throwing an exception), the action would fail and no Spark job was
* submitted.
*/
List<DataReaderFactory<Row>> createDataReaderFactories();
List<InputPartition<Row>> planInputPartitions();
}
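
As a further hedged illustration (also not part of this commit), a
DataSourceReader under the new method name plans one InputPartition per
output RDD partition; a reader reusing the hypothetical RangeInputPartition
sketched above might look like this:

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Hypothetical reader: two partitions, so the scan yields a two-partition RDD.
// Assumes RangeInputPartition from the earlier sketch is in the same package.
class RangeReader implements DataSourceReader {
  @Override
  public StructType readSchema() {
    return new StructType().add("value", DataTypes.IntegerType);
  }

  @Override
  public List<InputPartition<Row>> planInputPartitions() {
    return Arrays.asList(
        new RangeInputPartition(0, 5),
        new RangeInputPartition(5, 10));
  }
}
```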

InputPartition.java
@@ -22,20 +22,20 @@
import org.apache.spark.annotation.InterfaceStability;

/**
* A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is
* An input partition returned by {@link DataSourceReader#planInputPartitions()} and is
* responsible for creating the actual data reader. The relationship between
* {@link DataReaderFactory} and {@link DataReader}
* {@link InputPartition} and {@link InputPartitionReader}
* is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}.
*
* Note that, the reader factory will be serialized and sent to executors, then the data reader
* will be created on executors and do the actual reading. So {@link DataReaderFactory} must be
* serializable and {@link DataReader} doesn't need to be.
* Note that input partitions will be serialized and sent to executors, then the partition reader
* will be created on executors and do the actual reading. So {@link InputPartition} must be
* serializable and {@link InputPartitionReader} doesn't need to be.
*/
@InterfaceStability.Evolving
public interface DataReaderFactory<T> extends Serializable {
public interface InputPartition<T> extends Serializable {

/**
* The preferred locations where the data reader returned by this reader factory can run faster,
* The preferred locations where the data reader returned by this partition can run faster,
* but Spark does not guarantee to run the data reader on these locations.
* The implementations should make sure that it can be run on any location.
* The location is a string representing the host name.
@@ -57,5 +57,5 @@ default String[] preferredLocations() {
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
DataReader<T> createDataReader();
InputPartitionReader<T> createPartitionReader();
}
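
A small, hypothetical follow-on to the sketch above (not part of this
commit): a partition may advertise locality by overriding
preferredLocations, which Spark treats as a hint only.

```java
// Hypothetical variant of the RangeInputPartition sketched earlier that also
// reports locality. Spark prefers these hosts but may run the reader anywhere.
class LocalizedRangeInputPartition extends RangeInputPartition {
  private final String[] hosts;

  LocalizedRangeInputPartition(int start, int end, String[] hosts) {
    super(start, end);
    this.hosts = hosts;
  }

  @Override
  public String[] preferredLocations() {
    return hosts;
  }
}
```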

InputPartitionReader.java
@@ -23,15 +23,15 @@
import org.apache.spark.annotation.InterfaceStability;

/**
* A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for
* A data reader returned by {@link InputPartition#createPartitionReader()} and is responsible for
* outputting data for a RDD partition.
*
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
* source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source
* readers that mix in {@link SupportsScanUnsafeRow}.
*/
@InterfaceStability.Evolving
public interface DataReader<T> extends Closeable {
public interface InputPartitionReader<T> extends Closeable {

/**
* Proceed to next record, returns false if there is no more records.

SupportsScanUnsafeRow.java
@@ -33,14 +33,14 @@
public interface SupportsScanUnsafeRow extends DataSourceReader {

@Override
default List<DataReaderFactory<Row>> createDataReaderFactories() {
default List<InputPartition<Row>> planInputPartitions() {
throw new IllegalStateException(
"createDataReaderFactories not supported by default within SupportsScanUnsafeRow");
"planInputPartitions not supported by default within SupportsScanUnsafeRow");
}

/**
* Similar to {@link DataSourceV2Reader#createDataReaderFactories()},
* Similar to {@link DataSourceReader#planInputPartitions()},
* but returns data in unsafe row format.
*/
List<DataReaderFactory<UnsafeRow>> createUnsafeRowReaderFactories();
List<InputPartition<UnsafeRow>> planUnsafeInputPartitions();
}

DataSourceRDD.scala
@@ -17,29 +17,29 @@

package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
import org.apache.spark.sql.sources.v2.reader.InputPartition

class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: DataReaderFactory[T])
class DataSourceRDDPartition[T : ClassTag](val index: Int, val inputPartition: InputPartition[T])
extends Partition with Serializable

class DataSourceRDD[T : ClassTag](
sc: SparkContext,
@transient private val readerFactories: java.util.List[DataReaderFactory[T]])
@transient private val readerFactories: Seq[InputPartition[T]])
extends RDD[T](sc, Nil) {

override protected def getPartitions: Array[Partition] = {
readerFactories.asScala.zipWithIndex.map {
readerFactories.zipWithIndex.map {
case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
}.toArray
}

override def compute(split: Partition, context: TaskContext): Iterator[T] = {
val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.createDataReader()
val reader = split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition
.createPartitionReader()
context.addTaskCompletionListener(_ => reader.close())
val iter = new Iterator[T] {
private[this] var valuePrepared = false
@@ -63,6 +63,6 @@ class DataSourceRDD[T : ClassTag](
}

override def getPreferredLocations(split: Partition): Seq[String] = {
split.asInstanceOf[DataSourceRDDPartition[T]].readerFactory.preferredLocations()
split.asInstanceOf[DataSourceRDDPartition[T]].inputPartition.preferredLocations()
}
}

DataSourceReaderHolder.scala
@@ -17,9 +17,7 @@

package org.apache.spark.sql.execution.datasources.v2

import java.util.Objects

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.sources.v2.reader._

/**
@@ -48,17 +46,4 @@ trait DataSourceReaderHolder {
}
Seq(output, reader.getClass, reader.readSchema(), filters)
}

def canEqual(other: Any): Boolean

override def equals(other: Any): Boolean = other match {
case other: DataSourceReaderHolder =>
canEqual(other) && metadata.length == other.metadata.length &&
metadata.zip(other.metadata).forall { case (l, r) => l == r }
case _ => false
}

override def hashCode(): Int = {
metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
}
}

DataSourceV2ScanExec.scala
@@ -43,22 +43,34 @@ case class DataSourceV2ScanExec(

override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]

override def references: AttributeSet = AttributeSet.empty
// TODO: unify the equal/hashCode implementation for all data source v2 query plans.
override def equals(other: Any): Boolean = other match {
case other: DataSourceV2ScanExec =>
output == other.output && reader.getClass == other.reader.getClass && options == other.options
case _ => false
}

override def hashCode(): Int = {
Seq(output, source, options).hashCode()
}

private lazy val partitions: Seq[InputPartition[UnsafeRow]] = reader match {
case r: SupportsScanUnsafeRow => r.planUnsafeInputPartitions().asScala
case _ =>
reader.planInputPartitions().asScala.map {
new RowToUnsafeRowInputPartition(_, reader.readSchema()): InputPartition[UnsafeRow]
}
}

private lazy val inputRDD: RDD[InternalRow] = reader match {
case _ =>
new DataSourceRDD(sparkContext, partitions).asInstanceOf[RDD[InternalRow]]
}

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

override protected def doExecute(): RDD[InternalRow] = {
val readTasks: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
case _ =>
reader.createDataReaderFactories().asScala.map {
new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
}.asJava
}

val inputRDD = new DataSourceRDD(sparkContext, readTasks)
.asInstanceOf[RDD[InternalRow]]
val numOutputRows = longMetric("numOutputRows")
inputRDD.map { r =>
numOutputRows += 1
@@ -67,19 +79,22 @@
}
}

class RowToUnsafeRowDataReaderFactory(rowReaderFactory: DataReaderFactory[Row], schema: StructType)
extends DataReaderFactory[UnsafeRow] {
class RowToUnsafeRowInputPartition(partition: InputPartition[Row], schema: StructType)
extends InputPartition[UnsafeRow] {

override def preferredLocations: Array[String] = rowReaderFactory.preferredLocations
override def preferredLocations: Array[String] = partition.preferredLocations

override def createDataReader: DataReader[UnsafeRow] = {
new RowToUnsafeDataReader(
rowReaderFactory.createDataReader, RowEncoder.apply(schema).resolveAndBind())
override def createPartitionReader: InputPartitionReader[UnsafeRow] = {
new RowToUnsafeInputPartitionReader(
partition.createPartitionReader, RowEncoder.apply(schema).resolveAndBind())
}
}

class RowToUnsafeDataReader(rowReader: DataReader[Row], encoder: ExpressionEncoder[Row])
extends DataReader[UnsafeRow] {
class RowToUnsafeInputPartitionReader(
val rowReader: InputPartitionReader[Row],
encoder: ExpressionEncoder[Row])

extends InputPartitionReader[UnsafeRow] {

override def next: Boolean = rowReader.next


JavaAdvancedDataSourceV2.java
@@ -79,8 +79,8 @@ public Filter[] pushedFilters() {
}

@Override
public List<DataReaderFactory<Row>> createDataReaderFactories() {
List<DataReaderFactory<Row>> res = new ArrayList<>();
public List<InputPartition<Row>> planInputPartitions() {
List<InputPartition<Row>> res = new ArrayList<>();

Integer lowerBound = null;
for (Filter filter : filters) {
@@ -94,33 +94,33 @@ public List<DataReaderFactory<Row>> createDataReaderFactories() {
}

if (lowerBound == null) {
res.add(new JavaAdvancedDataReaderFactory(0, 5, requiredSchema));
res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema));
res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema));
} else if (lowerBound < 4) {
res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 5, requiredSchema));
res.add(new JavaAdvancedDataReaderFactory(5, 10, requiredSchema));
res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema));
res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema));
} else if (lowerBound < 9) {
res.add(new JavaAdvancedDataReaderFactory(lowerBound + 1, 10, requiredSchema));
res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema));
}

return res;
}
}

static class JavaAdvancedDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
static class JavaAdvancedInputPartition implements InputPartition<Row>, InputPartitionReader<Row> {
private int start;
private int end;
private StructType requiredSchema;

JavaAdvancedDataReaderFactory(int start, int end, StructType requiredSchema) {
JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) {
this.start = start;
this.end = end;
this.requiredSchema = requiredSchema;
}

@Override
public DataReader<Row> createDataReader() {
return new JavaAdvancedDataReaderFactory(start - 1, end, requiredSchema);
public InputPartitionReader<Row> createPartitionReader() {
return new JavaAdvancedInputPartition(start - 1, end, requiredSchema);
}

@Override

JavaSchemaRequiredDataSource.java
@@ -24,7 +24,7 @@
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.types.StructType;

public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupportWithSchema {
@@ -42,7 +42,7 @@ public StructType readSchema() {
}

@Override
public List<DataReaderFactory<Row>> createDataReaderFactories() {
public List<InputPartition<Row>> planInputPartitions() {
return java.util.Collections.emptyList();
}
}

JavaSimpleDataSourceV2.java
@@ -25,8 +25,8 @@
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.reader.DataReader;
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
import org.apache.spark.sql.sources.v2.reader.InputPartitionReader;
import org.apache.spark.sql.sources.v2.reader.InputPartition;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.types.StructType;

@@ -41,25 +41,25 @@ public StructType readSchema() {
}

@Override
public List<DataReaderFactory<Row>> createDataReaderFactories() {
public List<InputPartition<Row>> planInputPartitions() {
return java.util.Arrays.asList(
new JavaSimpleDataReaderFactory(0, 5),
new JavaSimpleDataReaderFactory(5, 10));
new JavaSimpleInputPartition(0, 5),
new JavaSimpleInputPartition(5, 10));
}
}

static class JavaSimpleDataReaderFactory implements DataReaderFactory<Row>, DataReader<Row> {
static class JavaSimpleInputPartition implements InputPartition<Row>, InputPartitionReader<Row> {
private int start;
private int end;

JavaSimpleDataReaderFactory(int start, int end) {
JavaSimpleInputPartition(int start, int end) {
this.start = start;
this.end = end;
}

@Override
public DataReader<Row> createDataReader() {
return new JavaSimpleDataReaderFactory(start - 1, end);
public InputPartitionReader<Row> createPartitionReader() {
return new JavaSimpleInputPartition(start - 1, end);
}

@Override
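
To round out the hypothetical sketches above (none of this is part of the
commit), the DataSourceV2 entry point only needs ReadSupport to hand back
the reader. Spark then calls planInputPartitions() on the driver,
serializes each InputPartition to an executor, and creates the
InputPartitionReader there, as described in the InputPartition Javadoc
above.

```java
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;

// Hypothetical entry point wiring up the RangeReader sketched earlier.
// Usable as: spark.read().format(RangeDataSource.class.getName()).load()
public class RangeDataSource implements DataSourceV2, ReadSupport {
  @Override
  public DataSourceReader createReader(DataSourceOptions options) {
    return new RangeReader();
  }
}
```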
