Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save metrics on close + close only once + check acc registration #296

Merged
merged 12 commits into from
Oct 18, 2016
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.trueaccord.scalapb.{ScalaPbPlugin => PB}

val dataflowSdkVersion = "1.8.0"
val algebirdVersion = "0.12.2"
val autoServiceVersion = "1.0-rc2"
val avroVersion = "1.7.7"
val bigQueryVersion = "v2-rev317-1.22.0"
val bigtableVersion = "0.9.3"
Expand Down Expand Up @@ -200,7 +201,8 @@ lazy val scioCore: Project = Project(
"com.twitter" % "chill-protobuf" % chillVersion,
"commons-io" % "commons-io" % commonsIoVersion,
"org.apache.commons" % "commons-math3" % commonsMath3Version,
"com.fasterxml.jackson.module" %% "jackson-module-scala" % jacksonScalaModuleVersion
"com.fasterxml.jackson.module" %% "jackson-module-scala" % jacksonScalaModuleVersion,
"com.google.auto.service" % "auto-service" % autoServiceVersion
)
).dependsOn(
scioBigQuery
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ package object bigtable {
def bigTable(projectId: String,
instanceId: String,
tableId: String,
scan: Scan = null): SCollection[Result] = self.pipelineOp {
scan: Scan = null): SCollection[Result] = self.requireNotClosed {
val _scan: Scan = if (scan != null) scan else new Scan()
val config = new bt.CloudBigtableScanConfiguration.Builder()
.withProjectId(projectId)
Expand All @@ -61,7 +61,8 @@ package object bigtable {
}

/** Get an SCollection for a Bigtable table. */
def bigTable(config: bt.CloudBigtableScanConfiguration): SCollection[Result] = self.pipelineOp {
def bigTable(config: bt.CloudBigtableScanConfiguration): SCollection[Result] =
self.requireNotClosed {
if (self.isTest) {
val input = BigtableInput(
config.getProjectId, config.getInstanceId, config.getTableId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,8 @@ public interface ScioOptions extends PipelineOptions {
@Description("Scala version")
String getScalaVersion();
void setScalaVersion(String version);

@Description("Filename to save metrics to.")
String getMetricsLocation();
void setMetricsLocation(String version);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright 2016 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.options;

import com.google.auto.service.AutoService;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptionsRegistrar;
import com.google.common.collect.ImmutableList;

@AutoService(PipelineOptionsRegistrar.class)
public class ScioOptionsRegistrar implements PipelineOptionsRegistrar {
@Override
public Iterable<Class<? extends PipelineOptions>> getPipelineOptions() {
return ImmutableList.<Class<? extends PipelineOptions>>of(ScioOptions.class);
}
}
61 changes: 40 additions & 21 deletions scio-core/src/main/scala/com/spotify/scio/ScioContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
}

/** Close the context. No operation can be performed once the context is closed. */
def close(): ScioResult = {
def close(): ScioResult = requireNotClosed {
if (_queryJobs.nonEmpty) {
bigQueryClient.waitForJobs(_queryJobs: _*)
}
Expand All @@ -321,14 +321,24 @@ class ScioContext private[scio] (val options: PipelineOptions,
Future.successful(result.getState)
}

new ScioResult(result, finalState, _accumulators.values.toSeq, pipeline)
val scioResult = new ScioResult(result, finalState, _accumulators.values.toSeq, pipeline)
val metricsLocation = optionsAs[ScioOptions].getMetricsLocation
if (metricsLocation != null) {
import scala.concurrent.ExecutionContext.Implicits.global
// force immediate execution on completed pipeline
finalState.value match {
case Some(_) => scioResult.saveMetrics(metricsLocation)
case None => finalState.onComplete(_ => scioResult.saveMetrics(metricsLocation))
}
}
scioResult
}

/** Whether the context is closed. */
def isClosed: Boolean = _isClosed

/** Ensure an operation is called before the pipeline is closed. */
private[scio] def pipelineOp[T](body: => T): T = {
private[scio] def requireNotClosed[T](body: => T): T = {
require(!this.isClosed, "ScioContext already closed")
body
}
Expand Down Expand Up @@ -384,7 +394,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
* Get an SCollection for an object file.
* @group input
*/
def objectFile[T: ClassTag](path: String): SCollection[T] = pipelineOp {
def objectFile[T: ClassTag](path: String): SCollection[T] = requireNotClosed {
if (this.isTest) {
this.getTestInput(ObjectFileIO[T](path))
} else {
Expand All @@ -403,7 +413,8 @@ class ScioContext private[scio] (val options: PipelineOptions,
* Get an SCollection for an Avro file.
* @group input
*/
def avroFile[T: ClassTag](path: String, schema: Schema = null): SCollection[T] = pipelineOp {
def avroFile[T: ClassTag](path: String, schema: Schema = null): SCollection[T] =
requireNotClosed {
if (this.isTest) {
this.getTestInput(AvroIO[T](path))
} else {
Expand All @@ -430,7 +441,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
* @group input
*/
def bigQuerySelect(sqlQuery: String,
flattenResults: Boolean = false): SCollection[TableRow] = pipelineOp {
flattenResults: Boolean = false): SCollection[TableRow] = requireNotClosed {
if (this.isTest) {
this.getTestInput(BigQueryIO(sqlQuery))
} else {
Expand All @@ -445,7 +456,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
* Get an SCollection for a BigQuery table.
* @group input
*/
def bigQueryTable(table: TableReference): SCollection[TableRow] = pipelineOp {
def bigQueryTable(table: TableReference): SCollection[TableRow] = requireNotClosed {
val tableSpec: String = gio.BigQueryIO.toTableSpec(table)
if (this.isTest) {
this.getTestInput(BigQueryIO(tableSpec))
Expand All @@ -466,7 +477,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
* @group input
*/
def datastore(projectId: String, query: Query, namespace: String = null): SCollection[Entity] =
pipelineOp {
requireNotClosed {
if (this.isTest) {
this.getTestInput(DatastoreIO(projectId, query, namespace))
} else {
Expand All @@ -484,7 +495,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
*/
def pubsubSubscription(sub: String,
idLabel: String = null,
timestampLabel: String = null): SCollection[String] = pipelineOp {
timestampLabel: String = null): SCollection[String] = requireNotClosed {
if (this.isTest) {
this.getTestInput(PubsubIO(sub))
} else {
Expand All @@ -505,7 +516,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
*/
def pubsubTopic(topic: String,
idLabel: String = null,
timestampLabel: String = null): SCollection[String] = pipelineOp {
timestampLabel: String = null): SCollection[String] = requireNotClosed {
if (this.isTest) {
this.getTestInput(PubsubIO(topic))
} else {
Expand All @@ -524,7 +535,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
* Get an SCollection of TableRow for a JSON file.
* @group input
*/
def tableRowJsonFile(path: String): SCollection[TableRow] = pipelineOp {
def tableRowJsonFile(path: String): SCollection[TableRow] = requireNotClosed {
if (this.isTest) {
this.getTestInput(TableRowJsonIO(path))
} else {
Expand All @@ -539,7 +550,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
*/
def textFile(path: String,
compressionType: gio.TextIO.CompressionType = gio.TextIO.CompressionType.AUTO)
: SCollection[String] = pipelineOp {
: SCollection[String] = requireNotClosed {
if (this.isTest) {
this.getTestInput(TextIO(path))
} else {
Expand All @@ -559,7 +570,8 @@ class ScioContext private[scio] (val options: PipelineOptions,
* for examples.
* @group accumulator
*/
def maxAccumulator[T](n: String)(implicit at: AccumulatorType[T]): Accumulator[T] = pipelineOp {
def maxAccumulator[T](n: String)(implicit at: AccumulatorType[T]): Accumulator[T] =
requireNotClosed {
require(!_accumulators.contains(n), s"Accumulator '$n' already exists")
val acc = new Accumulator[T] {
override val name: String = n
Expand All @@ -576,7 +588,8 @@ class ScioContext private[scio] (val options: PipelineOptions,
* for examples.
* @group accumulator
*/
def minAccumulator[T](n: String)(implicit at: AccumulatorType[T]): Accumulator[T] = pipelineOp {
def minAccumulator[T](n: String)(implicit at: AccumulatorType[T]): Accumulator[T] =
requireNotClosed {
require(!_accumulators.contains(n), s"Accumulator '$n' already exists")
val acc = new Accumulator[T] {
override val name: String = n
Expand All @@ -593,7 +606,8 @@ class ScioContext private[scio] (val options: PipelineOptions,
* for examples.
* @group accumulator
*/
def sumAccumulator[T](n: String)(implicit at: AccumulatorType[T]): Accumulator[T] = pipelineOp {
def sumAccumulator[T](n: String)(implicit at: AccumulatorType[T]): Accumulator[T] =
requireNotClosed {
require(!_accumulators.contains(n), s"Accumulator '$n' already exists")
val acc = new Accumulator[T] {
override val name: String = n
Expand All @@ -603,6 +617,9 @@ class ScioContext private[scio] (val options: PipelineOptions,
acc
}

private[scio] def containsAccumulator(acc: Accumulator[_]): Boolean =
_accumulators.contains(acc.name)

// =======================================================================
// In-memory collections
// =======================================================================
Expand All @@ -616,7 +633,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
* Distribute a local Scala Iterable to form an SCollection.
* @group in_memory
*/
def parallelize[T: ClassTag](elems: Iterable[T]): SCollection[T] = pipelineOp {
def parallelize[T: ClassTag](elems: Iterable[T]): SCollection[T] = requireNotClosed {
val coder = pipeline.getCoderRegistry.getScalaCoder[T]
wrap(this.applyInternal(Create.of(elems.asJava).withCoder(coder)))
.setName(truncate(elems.toString()))
Expand All @@ -626,7 +643,8 @@ class ScioContext private[scio] (val options: PipelineOptions,
* Distribute a local Scala Map to form an SCollection.
* @group in_memory
*/
def parallelize[K: ClassTag, V: ClassTag](elems: Map[K, V]): SCollection[(K, V)] = pipelineOp {
def parallelize[K: ClassTag, V: ClassTag](elems: Map[K, V]): SCollection[(K, V)] =
requireNotClosed {
val coder = pipeline.getCoderRegistry.getScalaKvCoder[K, V]
wrap(this.applyInternal(Create.of(elems.asJava).withCoder(coder)))
.map(kv => (kv.getKey, kv.getValue))
Expand All @@ -638,7 +656,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
* @group in_memory
*/
def parallelizeTimestamped[T: ClassTag](elems: Iterable[(T, Instant)])
: SCollection[T] = pipelineOp {
: SCollection[T] = requireNotClosed {
val coder = pipeline.getCoderRegistry.getScalaCoder[T]
val v = elems.map(t => TimestampedValue.of(t._1, t._2))
wrap(this.applyInternal(Create.timestamped(v.asJava).withCoder(coder)))
Expand All @@ -650,7 +668,7 @@ class ScioContext private[scio] (val options: PipelineOptions,
* @group in_memory
*/
def parallelizeTimestamped[T: ClassTag](elems: Iterable[T], timestamps: Iterable[Instant])
: SCollection[T] = pipelineOp {
: SCollection[T] = requireNotClosed {
val coder = pipeline.getCoderRegistry.getScalaCoder[T]
val v = elems.zip(timestamps).map(t => TimestampedValue.of(t._1, t._2))
wrap(this.applyInternal(Create.timestamped(v.asJava).withCoder(coder)))
Expand Down Expand Up @@ -685,7 +703,7 @@ class DistCacheScioContext private[scio] (self: ScioContext) {
* }}}
* @group dist_cache
*/
def distCache[F](uri: String)(initFn: File => F): DistCache[F] = self.pipelineOp {
def distCache[F](uri: String)(initFn: File => F): DistCache[F] = self.requireNotClosed {
if (self.isTest) {
new MockDistCache(testDistCache(DistCacheIO(uri)))
} else {
Expand All @@ -699,7 +717,8 @@ class DistCacheScioContext private[scio] (self: ScioContext) {
* @param initFn function to initialized the distributed files
* @group dist_cache
*/
def distCache[F](uris: Seq[String])(initFn: Seq[File] => F): DistCache[F] = self.pipelineOp {
def distCache[F](uris: Seq[String])(initFn: Seq[File] => F): DistCache[F] =
self.requireNotClosed {
if (self.isTest) {
new MockDistCache(testDistCache(DistCacheIO(uris.mkString("\t"))))
} else {
Expand Down
68 changes: 37 additions & 31 deletions scio-core/src/main/scala/com/spotify/scio/ScioResult.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ package com.spotify.scio

import java.nio.ByteBuffer

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import com.google.cloud.dataflow.sdk.PipelineResult.State
import com.google.cloud.dataflow.sdk.options.{ApplicationNameOptions, DataflowPipelineOptions}
import com.google.cloud.dataflow.sdk.runners.{AggregatorPipelineExtractor, AggregatorValues}
import com.google.cloud.dataflow.sdk.transforms.Aggregator
import com.google.cloud.dataflow.sdk.util.{IOChannelUtils, MimeTypes}
import com.google.cloud.dataflow.sdk.{Pipeline, PipelineResult}
import com.spotify.scio.util.ScioUtil
import com.spotify.scio.values.Accumulator

import scala.collection.JavaConverters._
Expand All @@ -53,41 +52,45 @@ class ScioResult private[scio] (val internal: PipelineResult,

/** Get the total value of an accumulator. */
def accumulatorTotalValue[T](acc: Accumulator[T]): T = {
require(accumulators.contains(acc), "Accumulator not present in the result")
acc.combineFn(getAggregatorValues(acc).map(_.getTotalValue(acc.combineFn)).asJava)
}

/** Get the values of an accumulator at each step it was used. */
def accumulatorValuesAtSteps[T](acc: Accumulator[T]): Map[String, T] =
def accumulatorValuesAtSteps[T](acc: Accumulator[T]): Map[String, T] = {
require(accumulators.contains(acc), "Accumulator not present in the result")
getAggregatorValues(acc).flatMap(_.getValuesAtSteps.asScala).toMap
}

/** Save metrics of the finished pipeline to a file. */
def saveMetrics(filename: String): Unit = {
require(isCompleted, "Pipeline has to be finished to save metrics.")

val mapper = new ObjectMapper()
mapper.registerModule(DefaultScalaModule)

val mapper = ScioUtil.getScalaJsonMapper
val out = IOChannelUtils.create(filename, MimeTypes.TEXT)

try {
out.write(ByteBuffer.wrap(mapper.writeValueAsBytes(getMetrics)))
} finally {
if (out != null) {
out.close()
}
}

def getMetrics: MetricSchema.Metrics = {
import MetricSchema._

val totalValues = accumulators
.map(acc => AccumulatorValue(acc.name, accumulatorTotalValue(acc)))

val stepsValues = accumulators
.map(acc => AccumulatorStepsValue(acc.name,
accumulatorValuesAtSteps(acc).map(a => AccumulatorStepValue(a._1, a._2))))
val stepsValues = accumulators.map(acc => AccumulatorStepsValue(acc.name,
accumulatorValuesAtSteps(acc).map(a => AccumulatorStepValue(a._1, a._2))))

val options = this.pipeline.getOptions
val metrics = Metrics(scioVersion,
scalaVersion,
options.as(classOf[ApplicationNameOptions]).getAppName,
options.as(classOf[DataflowPipelineOptions]).getJobName,
AccumulatorMetrics(totalValues, stepsValues))
out.write(ByteBuffer.wrap(mapper.writeValueAsBytes(metrics)))
} finally {
if (out != null) {
out.close()
}
Metrics(scioVersion,
scalaVersion,
options.as(classOf[ApplicationNameOptions]).getAppName,
options.as(classOf[DataflowPipelineOptions]).getJobName,
this.state.toString,
AccumulatorMetrics(totalValues, stepsValues))
}
}

Expand All @@ -96,13 +99,16 @@ class ScioResult private[scio] (val internal: PipelineResult,

}

private[scio] case class Metrics(version: String,
scalaVersion: String,
jobName: String,
jobId: String,
accumulators: AccumulatorMetrics)
private[scio] case class AccumulatorMetrics(total: Iterable[AccumulatorValue],
steps: Iterable[AccumulatorStepsValue])
private[scio] case class AccumulatorValue(name: String, value: Any)
private[scio] case class AccumulatorStepValue(name: String, value: Any)
private[scio] case class AccumulatorStepsValue(name: String, steps: Iterable[AccumulatorStepValue])
private[scio] object MetricSchema {
case class Metrics(version: String,
scalaVersion: String,
jobName: String,
jobId: String,
state: String,
accumulators: AccumulatorMetrics)
case class AccumulatorMetrics(total: Iterable[AccumulatorValue],
steps: Iterable[AccumulatorStepsValue])
case class AccumulatorValue(name: String, value: Any)
case class AccumulatorStepValue(name: String, value: Any)
case class AccumulatorStepsValue(name: String, steps: Iterable[AccumulatorStepValue])
}
Loading