[SPARK-19085][SQL] cleanup OutputWriterFactory and OutputWriter
## What changes were proposed in this pull request?

`OutputWriterFactory`/`OutputWriter` are internal interfaces and we can remove some unnecessary APIs:
1. `OutputWriterFactory.newWriter(path: String)`: no one calls it and no one implements it.
2. `OutputWriter.write(row: Row)`: during execution we only call `writeInternal`, which is odd given that `OutputWriter` is already an internal interface. We should rename `writeInternal` to `write`, remove `def write(row: Row)` along with its related converter code, and have all implementations implement `def write(row: InternalRow)` directly (see the sketch right after this list).
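
For reference, a condensed sketch of what the two interfaces end up looking like after this change (doc comments trimmed; the full diff below is authoritative):

```scala
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType

abstract class OutputWriterFactory extends Serializable {
  // The only remaining factory method; `newWriter(path: String)` is gone.
  def newInstance(
      path: String,
      dataSchema: StructType,
      context: TaskAttemptContext): OutputWriter
}

abstract class OutputWriter {
  // Implementations now consume InternalRows directly; the Row-based `write`
  // and the InternalRow => Row converter machinery are removed.
  def write(row: InternalRow): Unit

  // Invoked on the executor side after all rows are persisted, before the
  // task output is committed.
  def close(): Unit
}
```

A concrete writer such as the text data source then just implements `write(row: InternalRow)` and writes the row's `UTF8String` payload straight to the output stream, with no intermediate `Row` conversion step.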

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes apache#16479 from cloud-fan/hive-writer.
cloud-fan authored and uzadude committed Jan 27, 2017
1 parent 30f7bd2 commit 6e99290
Showing 13 changed files with 37 additions and 69 deletions.
@@ -45,9 +45,12 @@ private[libsvm] class LibSVMOutputWriter(

private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))

override def write(row: Row): Unit = {
val label = row.get(0)
val vector = row.get(1).asInstanceOf[Vector]
// This `asInstanceOf` is safe because it's guaranteed by `LibSVMFileFormat.verifySchema`
private val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT]

override def write(row: InternalRow): Unit = {
val label = row.getDouble(0)
val vector = udt.deserialize(row.getStruct(1, udt.sqlType.length))
writer.write(label.toString)
vector.foreachActive { case (i, v) =>
writer.write(s" ${i + 1}:$v")
@@ -115,6 +118,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
verifySchema(dataSchema)
new OutputWriterFactory {
override def newInstance(
path: String,
@@ -17,12 +17,12 @@

package org.apache.spark.ml.source.libsvm

import java.io.File
import java.io.{File, IOException}
import java.nio.charset.StandardCharsets

import com.google.common.io.Files

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SaveMode}
@@ -100,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {

test("write libsvm data failed due to invalid schema") {
val df = spark.read.format("text").load(path)
intercept[SparkException] {
intercept[IOException] {
df.write.format("libsvm").save(path + "_2")
}
}
@@ -466,7 +466,7 @@ case class DataSource(
// SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
// not need to have the query as child, to avoid to analyze an optimized query,
// because InsertIntoHadoopFsRelationCommand will be optimized first.
val columns = partitionColumns.map { name =>
val partitionAttributes = partitionColumns.map { name =>
val plan = data.logicalPlan
plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
throw new AnalysisException(
@@ -485,7 +485,7 @@
InsertIntoHadoopFsRelationCommand(
outputPath = outputPath,
staticPartitions = Map.empty,
partitionColumns = columns,
partitionColumns = partitionAttributes,
bucketSpec = bucketSpec,
fileFormat = format,
options = options,
@@ -64,18 +64,18 @@ object FileFormatWriter extends Logging {
val outputWriterFactory: OutputWriterFactory,
val allColumns: Seq[Attribute],
val partitionColumns: Seq[Attribute],
val nonPartitionColumns: Seq[Attribute],
val dataColumns: Seq[Attribute],
val bucketSpec: Option[BucketSpec],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long)
extends Serializable {

assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
s"""
|All columns: ${allColumns.mkString(", ")}
|Partition columns: ${partitionColumns.mkString(", ")}
|Non-partition columns: ${nonPartitionColumns.mkString(", ")}
|Data columns: ${dataColumns.mkString(", ")}
""".stripMargin)
}

@@ -120,7 +120,7 @@ object FileFormatWriter extends Logging {
outputWriterFactory = outputWriterFactory,
allColumns = queryExecution.logical.output,
partitionColumns = partitionColumns,
nonPartitionColumns = dataColumns,
dataColumns = dataColumns,
bucketSpec = bucketSpec,
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
@@ -246,9 +246,8 @@

currentWriter = description.outputWriterFactory.newInstance(
path = tmpFilePath,
dataSchema = description.nonPartitionColumns.toStructType,
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
}

override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -267,7 +266,7 @@
}

val internalRow = iter.next()
currentWriter.writeInternal(internalRow)
currentWriter.write(internalRow)
recordsInFile += 1
}
releaseResources()
@@ -364,9 +363,8 @@

currentWriter = description.outputWriterFactory.newInstance(
path = path,
dataSchema = description.nonPartitionColumns.toStructType,
dataSchema = description.dataColumns.toStructType,
context = taskAttemptContext)
currentWriter.initConverter(description.nonPartitionColumns.toStructType)
}

override def execute(iter: Iterator[InternalRow]): Set[String] = {
Expand All @@ -383,7 +381,7 @@ object FileFormatWriter extends Logging {

// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(
description.nonPartitionColumns, description.allColumns)
description.dataColumns, description.allColumns)

// Returns the partition path given a partition key.
val getPartitionStringFunc = UnsafeProjection.create(
Expand All @@ -392,7 +390,7 @@ object FileFormatWriter extends Logging {
// Sorts the data before write, so that we only need one writer at the same time.
val sorter = new UnsafeKVExternalSorter(
sortingKeySchema,
StructType.fromAttributes(description.nonPartitionColumns),
StructType.fromAttributes(description.dataColumns),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
TaskContext.get().taskMemoryManager().pageSizeBytes,
@@ -448,7 +446,7 @@
newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
}

currentWriter.writeInternal(sortedIterator.getValue)
currentWriter.write(sortedIterator.getValue)
recordsInFile += 1
}
releaseResources()
@@ -45,7 +45,7 @@ case class InsertIntoHadoopFsRelationCommand(
bucketSpec: Option[BucketSpec],
fileFormat: FileFormat,
options: Map[String, String],
@transient query: LogicalPlan,
query: LogicalPlan,
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex])
@@ -47,19 +47,6 @@ abstract class OutputWriterFactory extends Serializable {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter

/**
* Returns a new instance of [[OutputWriter]] that will write data to the given path.
* This method gets called by each task on executor to write InternalRows to
* format-specific files. Compared to the other `newInstance()`, this is a newer API that
* passes only the path that the writer must write to. The writer must write to the exact path
* and not modify it (do not add subdirectories, extensions, etc.). All other
* file-format-specific information needed to create the writer must be passed
* through the [[OutputWriterFactory]] implementation.
*/
def newWriter(path: String): OutputWriter = {
throw new UnsupportedOperationException("newInstance with just path not supported")
}
}


@@ -74,22 +61,11 @@ abstract class OutputWriter {
* Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
* tables, dynamic partition columns are not included in rows to be written.
*/
def write(row: Row): Unit
def write(row: InternalRow): Unit

/**
* Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before
* the task output is committed.
*/
def close(): Unit

private var converter: InternalRow => Row = _

protected[sql] def initConverter(dataSchema: StructType) = {
converter =
CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]
}

protected[sql] def writeInternal(row: InternalRow): Unit = {
write(converter(row))
}
}
@@ -221,9 +221,7 @@ private[csv] class CsvOutputWriter(
row.get(ordinal, dt).toString
}

override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")

override protected[sql] def writeInternal(row: InternalRow): Unit = {
override def write(row: InternalRow): Unit = {
csvWriter.writeRow(rowToString(row), printHeader)
printHeader = false
}
@@ -159,9 +159,7 @@ private[json] class JsonOutputWriter(
// create the Generator without separator inserted between 2 records
private[this] val gen = new JacksonGenerator(dataSchema, writer, options)

override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")

override protected[sql] def writeInternal(row: InternalRow): Unit = {
override def write(row: InternalRow): Unit = {
gen.write(row)
gen.writeLineEnding()
}
@@ -37,9 +37,7 @@ private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptCon
}.getRecordWriter(context)
}

override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")

override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
override def write(row: InternalRow): Unit = recordWriter.write(null, row)

override def close(): Unit = recordWriter.close(context)
}
@@ -132,9 +132,7 @@ class TextOutputWriter(

private val writer = CodecStreams.createOutputStream(context, new Path(path))

override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")

override protected[sql] def writeInternal(row: InternalRow): Unit = {
override def write(row: InternalRow): Unit = {
if (!row.isNullAt(0)) {
val utf8string = row.getUTF8String(0)
utf8string.writeTo(writer)
@@ -239,10 +239,7 @@ private[orc] class OrcOutputWriter(
).asInstanceOf[RecordWriter[NullWritable, Writable]]
}

override def write(row: Row): Unit =
throw new UnsupportedOperationException("call writeInternal")

override protected[sql] def writeInternal(row: InternalRow): Unit = {
override def write(row: InternalRow): Unit = {
recordWriter.write(NullWritable.get(), serializer.serialize(row))
}

@@ -20,7 +20,8 @@ package org.apache.spark.sql.sources
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.TaskContext
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

@@ -42,14 +43,14 @@ class CommitFailureTestSource extends SimpleTextSource {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new SimpleTextOutputWriter(path, context) {
new SimpleTextOutputWriter(path, dataSchema, context) {
var failed = false
TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
failed = true
SimpleTextRelation.callbackCalled = true
}

override def write(row: Row): Unit = {
override def write(row: InternalRow): Unit = {
if (SimpleTextRelation.failWriter) {
sys.error("Intentional task writer failure for testing purpose.")

@@ -50,7 +50,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new SimpleTextOutputWriter(path, context)
new SimpleTextOutputWriter(path, dataSchema, context)
}

override def getFileExtension(context: TaskAttemptContext): String = ""
@@ -117,13 +117,13 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
}
}

class SimpleTextOutputWriter(path: String, context: TaskAttemptContext)
class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext)
extends OutputWriter {

private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))

override def write(row: Row): Unit = {
val serialized = row.toSeq.map { v =>
override def write(row: InternalRow): Unit = {
val serialized = row.toSeq(dataSchema).map { v =>
if (v == null) "" else v.toString
}.mkString(",")
