diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index b5aa7ce4e115b..89bbc1556c2d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -45,9 +45,12 @@ private[libsvm] class LibSVMOutputWriter(
 
   private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
 
-  override def write(row: Row): Unit = {
-    val label = row.get(0)
-    val vector = row.get(1).asInstanceOf[Vector]
+  // This `asInstanceOf` is safe because it's guaranteed by `LibSVMFileFormat.verifySchema`
+  private val udt = dataSchema(1).dataType.asInstanceOf[VectorUDT]
+
+  override def write(row: InternalRow): Unit = {
+    val label = row.getDouble(0)
+    val vector = udt.deserialize(row.getStruct(1, udt.sqlType.length))
     writer.write(label.toString)
     vector.foreachActive { case (i, v) =>
       writer.write(s" ${i + 1}:$v")
@@ -115,6 +118,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
+    verifySchema(dataSchema)
     new OutputWriterFactory {
       override def newInstance(
           path: String,
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 2517de59fed63..c701f3823842c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.ml.source.libsvm
 
-import java.io.File
+import java.io.{File, IOException}
 import java.nio.charset.StandardCharsets
 
 import com.google.common.io.Files
 
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{Row, SaveMode}
@@ -100,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("write libsvm data failed due to invalid schema") {
     val df = spark.read.format("text").load(path)
-    intercept[SparkException] {
+    intercept[IOException] {
       df.write.format("libsvm").save(path + "_2")
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 7e23260e65aaf..b7f3559b6559b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -466,7 +466,7 @@ case class DataSource(
       // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
       // not need to have the query as child, to avoid to analyze an optimized query,
       // because InsertIntoHadoopFsRelationCommand will be optimized first.
-      val columns = partitionColumns.map { name =>
+      val partitionAttributes = partitionColumns.map { name =>
        val plan = data.logicalPlan
        plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
          throw new AnalysisException(
@@ -485,7 +485,7 @@ case class DataSource(
         InsertIntoHadoopFsRelationCommand(
           outputPath = outputPath,
           staticPartitions = Map.empty,
-          partitionColumns = columns,
+          partitionColumns = partitionAttributes,
           bucketSpec = bucketSpec,
           fileFormat = format,
           options = options,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 1eb4541e2c103..16c5193eda8df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -64,18 +64,18 @@ object FileFormatWriter extends Logging {
       val outputWriterFactory: OutputWriterFactory,
       val allColumns: Seq[Attribute],
       val partitionColumns: Seq[Attribute],
-      val nonPartitionColumns: Seq[Attribute],
+      val dataColumns: Seq[Attribute],
       val bucketSpec: Option[BucketSpec],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
       val maxRecordsPerFile: Long)
     extends Serializable {
 
-    assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ nonPartitionColumns),
+    assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
       s"""
          |All columns: ${allColumns.mkString(", ")}
          |Partition columns: ${partitionColumns.mkString(", ")}
-         |Non-partition columns: ${nonPartitionColumns.mkString(", ")}
+         |Data columns: ${dataColumns.mkString(", ")}
        """.stripMargin)
   }
 
@@ -120,7 +120,7 @@ object FileFormatWriter extends Logging {
       outputWriterFactory = outputWriterFactory,
       allColumns = queryExecution.logical.output,
       partitionColumns = partitionColumns,
-      nonPartitionColumns = dataColumns,
+      dataColumns = dataColumns,
       bucketSpec = bucketSpec,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
@@ -246,9 +246,8 @@
 
       currentWriter = description.outputWriterFactory.newInstance(
         path = tmpFilePath,
-        dataSchema = description.nonPartitionColumns.toStructType,
+        dataSchema = description.dataColumns.toStructType,
         context = taskAttemptContext)
-      currentWriter.initConverter(dataSchema = description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -267,7 +266,7 @@
         }
 
         val internalRow = iter.next()
-        currentWriter.writeInternal(internalRow)
+        currentWriter.write(internalRow)
         recordsInFile += 1
       }
       releaseResources()
@@ -364,9 +363,8 @@
 
       currentWriter = description.outputWriterFactory.newInstance(
         path = path,
-        dataSchema = description.nonPartitionColumns.toStructType,
+        dataSchema = description.dataColumns.toStructType,
         context = taskAttemptContext)
-      currentWriter.initConverter(description.nonPartitionColumns.toStructType)
     }
 
     override def execute(iter: Iterator[InternalRow]): Set[String] = {
@@ -383,7 +381,7 @@
 
       // Returns the data columns to be written given an input row
       val getOutputRow = UnsafeProjection.create(
-        description.nonPartitionColumns, description.allColumns)
+        description.dataColumns, description.allColumns)
 
       // Returns the partition path given a partition key.
       val getPartitionStringFunc = UnsafeProjection.create(
@@ -392,7 +390,7 @@
       // Sorts the data before write, so that we only need one writer at the same time.
       val sorter = new UnsafeKVExternalSorter(
         sortingKeySchema,
-        StructType.fromAttributes(description.nonPartitionColumns),
+        StructType.fromAttributes(description.dataColumns),
         SparkEnv.get.blockManager,
         SparkEnv.get.serializerManager,
         TaskContext.get().taskMemoryManager().pageSizeBytes,
@@ -448,7 +446,7 @@
           newOutputWriter(currentKey, getPartitionStringFunc, fileCounter)
         }
 
-        currentWriter.writeInternal(sortedIterator.getValue)
+        currentWriter.write(sortedIterator.getValue)
         recordsInFile += 1
       }
       releaseResources()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 84ea58b68a936..423009e4eccea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -45,7 +45,7 @@ case class InsertIntoHadoopFsRelationCommand(
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
     options: Map[String, String],
-    @transient query: LogicalPlan,
+    query: LogicalPlan,
     mode: SaveMode,
     catalogTable: Option[CatalogTable],
     fileIndex: Option[FileIndex])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
index a73c8146c1b0d..868e5371426c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala
@@ -47,19 +47,6 @@ abstract class OutputWriterFactory extends Serializable {
       path: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter
-
-  /**
-   * Returns a new instance of [[OutputWriter]] that will write data to the given path.
-   * This method gets called by each task on executor to write InternalRows to
-   * format-specific files. Compared to the other `newInstance()`, this is a newer API that
-   * passes only the path that the writer must write to. The writer must write to the exact path
-   * and not modify it (do not add subdirectories, extensions, etc.). All other
-   * file-format-specific information needed to create the writer must be passed
-   * through the [[OutputWriterFactory]] implementation.
-   */
-  def newWriter(path: String): OutputWriter = {
-    throw new UnsupportedOperationException("newInstance with just path not supported")
-  }
 }
 
 
@@ -74,22 +61,11 @@ abstract class OutputWriter {
    * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned
    * tables, dynamic partition columns are not included in rows to be written.
    */
-  def write(row: Row): Unit
+  def write(row: InternalRow): Unit
 
   /**
    * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before
   * the task output is committed.
   */
   def close(): Unit
-
-  private var converter: InternalRow => Row = _
-
-  protected[sql] def initConverter(dataSchema: StructType) = {
-    converter =
-      CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row]
-  }
-
-  protected[sql] def writeInternal(row: InternalRow): Unit = {
-    write(converter(row))
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 23c07eb630d31..8c19be48c2118 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -221,9 +221,7 @@ private[csv] class CsvOutputWriter(
       row.get(ordinal, dt).toString
     }
 
-  override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
-  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+  override def write(row: InternalRow): Unit = {
     csvWriter.writeRow(rowToString(row), printHeader)
     printHeader = false
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index a9d8ddfe9d805..be1f94dbad912 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -159,9 +159,7 @@ private[json] class JsonOutputWriter(
   // create the Generator without separator inserted between 2 records
   private[this] val gen = new JacksonGenerator(dataSchema, writer, options)
 
-  override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
-  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+  override def write(row: InternalRow): Unit = {
     gen.write(row)
     gen.writeLineEnding()
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
index 5c0f8af17a232..8361762b09703 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
@@ -37,9 +37,7 @@ private[parquet] class ParquetOutputWriter(path: String, context: TaskAttemptCon
     }.getRecordWriter(context)
   }
 
-  override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
-  override def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row)
+  override def write(row: InternalRow): Unit = recordWriter.write(null, row)
 
   override def close(): Unit = recordWriter.close(context)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 897e535953331..6f6e3016864b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -132,9 +132,7 @@ class TextOutputWriter(
 
   private val writer = CodecStreams.createOutputStream(context, new Path(path))
 
-  override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal")
-
-  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+  override def write(row: InternalRow): Unit = {
     if (!row.isNullAt(0)) {
       val utf8string = row.getUTF8String(0)
       utf8string.writeTo(writer)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 0a7631f782193..f496c01ce9ff7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -239,10 +239,7 @@ private[orc] class OrcOutputWriter(
     ).asInstanceOf[RecordWriter[NullWritable, Writable]]
   }
 
-  override def write(row: Row): Unit =
-    throw new UnsupportedOperationException("call writeInternal")
-
-  override protected[sql] def writeInternal(row: InternalRow): Unit = {
+  override def write(row: InternalRow): Unit = {
     recordWriter.write(NullWritable.get(), serializer.serialize(row))
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
index abc7c8cc4db89..7501334f94dd2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestSource.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.sources
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
 import org.apache.spark.TaskContext
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
 import org.apache.spark.sql.types.StructType
 
@@ -42,14 +43,14 @@
       path: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter = {
-    new SimpleTextOutputWriter(path, context) {
+    new SimpleTextOutputWriter(path, dataSchema, context) {
       var failed = false
       TaskContext.get().addTaskFailureListener { (t: TaskContext, e: Throwable) =>
         failed = true
         SimpleTextRelation.callbackCalled = true
       }
 
-      override def write(row: Row): Unit = {
+      override def write(row: InternalRow): Unit = {
         if (SimpleTextRelation.failWriter) {
           sys.error("Intentional task writer failure for testing purpose.")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 5fdf6152590ce..1607c97cd6acb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -50,7 +50,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
       path: String,
       dataSchema: StructType,
       context: TaskAttemptContext): OutputWriter = {
-    new SimpleTextOutputWriter(path, context)
+    new SimpleTextOutputWriter(path, dataSchema, context)
   }
 
   override def getFileExtension(context: TaskAttemptContext): String = ""
@@ -117,13 +117,13 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
   }
 }
 
-class SimpleTextOutputWriter(path: String, context: TaskAttemptContext)
+class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext)
   extends OutputWriter {
 
   private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path))
 
-  override def write(row: Row): Unit = {
-    val serialized = row.toSeq.map { v =>
+  override def write(row: InternalRow): Unit = {
+    val serialized = row.toSeq(dataSchema).map { v =>
       if (v == null) "" else v.toString
     }.mkString(",")
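
Note: the sketch below is not part of the patch. It is a minimal example of what a custom OutputWriter looks like against the new InternalRow-based contract introduced above (write(InternalRow) plus a dataSchema passed in by the factory, with initConverter/writeInternal gone). The class name EchoOutputWriter is hypothetical; the write body mirrors the SimpleTextOutputWriter test writer from the patch, but it opens the file through public Hadoop APIs instead of the internal CodecStreams helper.

import java.io.OutputStreamWriter
import java.nio.charset.StandardCharsets

import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types.StructType

// Hypothetical writer: serializes each InternalRow as one comma-separated line.
class EchoOutputWriter(
    path: String,
    dataSchema: StructType,
    context: TaskAttemptContext) extends OutputWriter {

  // Open the target file through the Hadoop FileSystem bound to the task's configuration.
  private val outputPath = new Path(path)
  private val writer = new OutputStreamWriter(
    outputPath.getFileSystem(context.getConfiguration).create(outputPath),
    StandardCharsets.UTF_8)

  // The framework now hands the writer an InternalRow directly; field values are
  // extracted with the writer's own dataSchema instead of a Row converter.
  override def write(row: InternalRow): Unit = {
    val line = row.toSeq(dataSchema).map(v => if (v == null) "" else v.toString).mkString(",")
    writer.write(line)
    writer.write("\n")
  }

  override def close(): Unit = writer.close()
}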