Make partition metrics accessible via V1WriteCommand
Introduce `driverSidePartitionMetrics` on `BasicWriteJobStatsTracker`, expose it through `DataWritingCommand.partitionMetrics`, and drop the `SparkListenerPartitionTaskEvent` firing.
Steve Vaughan Jr committed Feb 20, 2024
1 parent ddd5dbb · commit ff3fcec
Showing 6 changed files with 100 additions and 59 deletions.
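
The net effect of the change is that per-partition write statistics are accumulated on the driver and read straight off the executed `V1WriteCommand` (see the updated `SQLMetricsTestUtils` at the end of this diff), instead of being posted as a `SparkListenerPartitionTaskEvent`. Below is a minimal sketch of that access pattern, assuming only the internal Spark APIs touched in this diff; the listener name is illustrative, not part of the commit.

```scala
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.datasources.V1WriteCommand
import org.apache.spark.sql.util.QueryExecutionListener

// Illustrative listener: capture V1 write commands as queries finish and read
// their driver-side partition metrics (partition path -> files/bytes/rows).
class PartitionMetricsListener extends QueryExecutionListener {

  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    qe.executedPlan match {
      case exec: DataWritingCommandExec =>
        exec.cmd match {
          case write: V1WriteCommand =>
            write.partitionMetrics.foreach { case (partition, stats) =>
              println(s"$partition: ${stats.numFiles} files, " +
                s"${stats.numBytes} bytes, ${stats.numRows} rows")
            }
          case _ => // not a V1 write command
        }
      case _ => // not a data-writing command
    }
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = ()
}

// Registration mirrors the test utility added below:
// spark.listenerManager.register(new PartitionMetricsListener())
```
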
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.command

import java.net.URI

import scala.collection.mutable

import org.apache.hadoop.conf.Configuration

import org.apache.spark.SparkContext
@@ -27,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{CTEInChildren, LogicalPlan, UnaryCommand}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, PartitionTaskStats}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.SerializableConfiguration
@@ -53,10 +55,12 @@ trait DataWritingCommand extends UnaryCommand with CTEInChildren {
DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames)

lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
lazy val partitionMetrics: mutable.Map[String, PartitionTaskStats] =
BasicWriteJobStatsTracker.partitionMetrics

def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = {
val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
new BasicWriteJobStatsTracker(serializableHadoopConf, metrics, partitionMetrics)
}

def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row]
@@ -26,9 +26,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.SparkListenerEvent
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker._
@@ -56,12 +54,6 @@ case class BasicWritePartitionTaskStats(
numRows: Long)
extends PartitionTaskStats


@DeveloperApi
case class SparkListenerPartitionTaskEvent(stats: Map[String, PartitionTaskStats])
extends SparkListenerEvent


/**
* Simple [[WriteTaskStatsTracker]] implementation that produces [[BasicWriteTaskStats]].
*/
@@ -236,13 +228,16 @@ class BasicWriteTaskStatsTracker(
class BasicWriteJobStatsTracker(
serializableHadoopConf: SerializableConfiguration,
@transient val driverSideMetrics: Map[String, SQLMetric],
@transient val driverSidePartitionMetrics: mutable.Map[String, PartitionTaskStats],
taskCommitTimeMetric: SQLMetric)
extends WriteJobStatsTracker {

def this(
serializableHadoopConf: SerializableConfiguration,
metrics: Map[String, SQLMetric]) = {
this(serializableHadoopConf, metrics - TASK_COMMIT_TIME, metrics(TASK_COMMIT_TIME))
metrics: Map[String, SQLMetric],
partitionMetrics: mutable.Map[String, PartitionTaskStats]) = {
this(serializableHadoopConf, metrics - TASK_COMMIT_TIME, partitionMetrics,
metrics(TASK_COMMIT_TIME))
}

override def newTaskInstance(): WriteTaskStatsTracker = {
@@ -256,8 +251,6 @@ class BasicWriteJobStatsTracker(
var numFiles: Long = 0L
var totalNumBytes: Long = 0L
var totalNumOutput: Long = 0L
val partitionsStats: mutable.Map[String, BasicWritePartitionTaskStats]
= mutable.Map.empty.withDefaultValue(BasicWritePartitionTaskStats(0, 0L, 0L))

val basicStats = stats.map(_.asInstanceOf[BasicWriteTaskStats])

@@ -269,8 +262,8 @@

summary.partitionsStats.foreach(s => {
val path = partitionsMap.getOrElse(s._1, "")
val current = partitionsStats(path)
partitionsStats(path) = BasicWritePartitionTaskStats(
val current = driverSidePartitionMetrics(path)
driverSidePartitionMetrics(path) = BasicWritePartitionTaskStats(
current.numFiles + s._2.numFiles,
current.numBytes + s._2.numBytes,
current.numRows + s._2.numRows
@@ -284,8 +277,6 @@
driverSideMetrics(NUM_OUTPUT_ROWS_KEY).add(totalNumOutput)
driverSideMetrics(NUM_PARTS_KEY).add(partitionsSet.size)

sparkContext.listenerBus.post(SparkListenerPartitionTaskEvent(partitionsStats.toMap))

val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, driverSideMetrics.values.toList)
}
@@ -312,4 +303,7 @@ object BasicWriteJobStatsTracker {
JOB_COMMIT_TIME -> SQLMetrics.createTimingMetric(sparkContext, "job commit time")
)
}

def partitionMetrics: mutable.Map[String, PartitionTaskStats] =
mutable.Map.empty.withDefaultValue(BasicWritePartitionTaskStats(0, 0L, 0L))
}
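
For reference, the driver-side aggregation in `processStats` above relies on the `withDefaultValue` map returned by `BasicWriteJobStatsTracker.partitionMetrics`, so each partition path can be updated with a plain read-add-write. The following self-contained sketch illustrates that accumulation pattern; it uses a local stand-in case class rather than `BasicWritePartitionTaskStats`, and the partition paths and numbers are made up.

```scala
import scala.collection.mutable

// Local stand-in for PartitionTaskStats / BasicWritePartitionTaskStats,
// used only to illustrate the aggregation done in processStats.
case class PartStats(numFiles: Int, numBytes: Long, numRows: Long)

// Same shape as BasicWriteJobStatsTracker.partitionMetrics: a missing key
// reads as all-zero stats, so every update is a simple read-add-write.
val driverSidePartitionMetrics: mutable.Map[String, PartStats] =
  mutable.Map.empty.withDefaultValue(PartStats(0, 0L, 0L))

// Per-task, per-partition stats reported back to the driver (made-up values).
val taskStats = Seq(
  "ds=2024-02-20/country=US" -> PartStats(1, 1024L, 2L),
  "ds=2024-02-20/country=US" -> PartStats(1, 2048L, 2L),
  "ds=2024-02-20/country=CA" -> PartStats(1, 512L, 2L))

taskStats.foreach { case (path, s) =>
  val current = driverSidePartitionMetrics(path) // PartStats(0, 0, 0) on first hit
  driverSidePartitionMetrics(path) = PartStats(
    current.numFiles + s.numFiles,
    current.numBytes + s.numBytes,
    current.numRows + s.numRows)
}

assert(driverSidePartitionMetrics("ds=2024-02-20/country=US") == PartStats(2, 3072L, 4L))
assert(driverSidePartitionMetrics("ds=2024-02-20/country=CA") == PartStats(1, 512L, 2L))
```
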
@@ -174,6 +174,7 @@ case class InsertIntoHadoopFsRelationCommand(
qualifiedOutputPath
}

val statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf))
val updatedPartitionPaths =
FileFormatWriter.write(
sparkSession = sparkSession,
@@ -185,11 +186,10 @@
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
bucketSpec = bucketSpec,
statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
statsTrackers = statsTrackers,
options = options,
numStaticPartitionCols = staticPartitions.size)


// update metastore partition metadata
if (updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty
&& partitionColumns.length == staticPartitions.size) {
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2

import java.util.UUID

import scala.collection.mutable
import scala.jdk.CollectionConverters._

import org.apache.hadoop.conf.Configuration
@@ -32,7 +33,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, Write}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription}
import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, PartitionTaskStats, WriteJobDescription}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
@@ -124,8 +125,11 @@ trait FileWrite extends Write {
prepareWrite(sparkSession.sessionState.conf, job, caseInsensitiveOptions, schema)
val allColumns = toAttributes(schema)
val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics
val partitionMetrics: mutable.Map[String, PartitionTaskStats]
= BasicWriteJobStatsTracker.partitionMetrics
val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics)
val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics,
partitionMetrics)
// TODO: after partitioning is supported in V2:
// 1. filter out partition columns in `dataColumns`.
// 2. Don't use Seq.empty for `partitionColumns`.
@@ -140,7 +140,8 @@ class FileStreamSink(

private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = {
val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics)
new BasicWriteJobStatsTracker(serializableHadoopConf, BasicWriteJobStatsTracker.metrics,
BasicWriteJobStatsTracker.partitionMetrics)
}

override def addBatch(batchId: Long, data: DataFrame): Unit = {
@@ -19,19 +19,20 @@ package org.apache.spark.sql.execution.metric

import java.io.File

import scala.collection.mutable
import scala.collection.mutable.HashMap
import scala.collection.mutable.ListBuffer

import org.apache.spark.TestUtils
import org.apache.spark.TestUtils.withListener
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerTaskEnd}
import org.apache.spark.sql.DataFrame
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.{SparkPlan, SparkPlanInfo}
import org.apache.spark.sql.execution.datasources.SparkListenerPartitionTaskEvent
import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanInfo}
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.V1WriteCommand
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SQLAppStatusStore}
import org.apache.spark.sql.internal.SQLConf.WHOLESTAGE_CODEGEN_ENABLED
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.util.QueryExecutionListener


trait SQLMetricsTestUtils extends SQLTestUtils {
@@ -107,43 +108,78 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
assert(totalNumBytes > 0)
}

private class CaptureWriteCommand extends QueryExecutionListener {

val v1WriteCommands: mutable.Buffer[V1WriteCommand] = mutable.Buffer[V1WriteCommand]()

override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
if (qe.executedPlan.isInstanceOf[ExecutedCommandExec] ||
qe.executedPlan.isInstanceOf[DataWritingCommandExec]) {
qe.optimizedPlan match {
case _: V1WriteCommand =>
val executedPlanCmd = qe.executedPlan.asInstanceOf[DataWritingCommandExec].cmd
v1WriteCommands += executedPlanCmd.asInstanceOf[V1WriteCommand]

// All other commands
case _ =>
logDebug(f"Query execution data is not currently supported for query: " +
f"${qe.toString} with plan class ${qe.executedPlan.getClass.getName} " +
f" and executed plan : ${qe.executedPlan}")
}
}
}

override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}

}

protected def withQueryExecutionListener[L <: QueryExecutionListener]
(spark: SparkSession, listener: L)
(body: L => Unit): Unit = {
spark.listenerManager.register(listener)
try {
body(listener)
}
finally {
spark.listenerManager.unregister(listener)
}
}


protected def testMetricsNonDynamicPartition(
dataFormat: String,
tableName: String): Unit = {
withTable(tableName) {
Seq((1, 2)).toDF("i", "j")
.write.format(dataFormat).mode("overwrite").saveAsTable(tableName)
val listener = new CaptureWriteCommand()
withQueryExecutionListener(spark, listener) { _ =>
withTable(tableName) {
Seq((1, 2)).toDF("i", "j")
.write.format(dataFormat).mode("overwrite").saveAsTable(tableName)

val tableLocation =
new File(spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location)
val tableLocation =
new File(spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location)

// 2 files, 100 rows, 0 dynamic partition.
verifyWriteDataMetrics(Seq(2, 0, 100)) {
(0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2)
.write.format(dataFormat).mode("overwrite").insertInto(tableName)
// 2 files, 100 rows, 0 dynamic partition.
verifyWriteDataMetrics(Seq(2, 0, 100)) {
(0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2)
.write.format(dataFormat).mode("overwrite").insertInto(tableName)
}
assert(TestUtils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2)
}
assert(TestUtils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2)
}

// Verify that there were 2 write commands for the entire write process: this test creates the
// table and then performs a repartitioned insert.
assert(listener.v1WriteCommands.length == 2)
assert(listener.v1WriteCommands.forall(
v1WriteCommand => v1WriteCommand.partitionMetrics.isEmpty))
}

protected def testMetricsDynamicPartition(
provider: String,
dataFormat: String,
tableName: String): Unit = {
val events: ListBuffer[SparkListenerPartitionTaskEvent]
= ListBuffer[SparkListenerPartitionTaskEvent]()
val listener: SparkListener = new SparkListener {

override def onOtherEvent(event: SparkListenerEvent): Unit = {
event match {
case partitionStats: SparkListenerPartitionTaskEvent =>
events += partitionStats
case _ => // ignore other events
}
}

}
withListener(sparkContext, listener) { _ =>
val listener = new CaptureWriteCommand()
withQueryExecutionListener(spark, listener) { _ =>
withTable(tableName) {
withTempPath { dir =>
spark.sql(
@@ -172,16 +208,18 @@ trait SQLMetricsTestUtils extends SQLTestUtils {
}
}

// Verify that there was a single event for the entire write process
assert(events.length == 1)
val event = events.head
// Verify that there was a single write command for the entire write process
assert(listener.v1WriteCommands.length == 1)
val v1WriteCommand = listener.v1WriteCommands.head

// Verify the number of partitions
assert(event.stats.keySet.size == 40)
assert(v1WriteCommand.partitionMetrics.keySet.size == 40)
// Verify the number of files per partition
assert(event.stats.values.forall(partitionStats => partitionStats.numFiles == 1))
assert(v1WriteCommand.partitionMetrics.values.forall(
partitionStats => partitionStats.numFiles == 1))
// Verify the number of rows per partition
assert(event.stats.values.forall(partitionStats => partitionStats.numRows == 2))
assert(v1WriteCommand.partitionMetrics.values.forall(
partitionStats => partitionStats.numRows == 2))
}

/**
