Skip to content

Commit

Permalink
Support accessing all counters within JobTest (fix #2007)
Browse files Browse the repository at this point in the history
  • Loading branch information
clairemcginty committed Sep 16, 2019
1 parent c073a8a commit 204aa2a
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 16 deletions.
118 changes: 106 additions & 12 deletions scio-test/src/main/scala/com/spotify/scio/testing/JobTest.scala
Expand Up @@ -77,12 +77,22 @@ object JobTest {
input: Map[String, JobInputSource[_]] = Map.empty,
output: Map[String, SCollection[_] => Any] = Map.empty,
distCaches: Map[DistCacheIO[_], _] = Map.empty,
counters: Map[beam.Counter, Long => Any] = Map.empty,
distributions: Map[beam.Distribution, beam.DistributionResult => Any] = Map.empty,
gauges: Map[beam.Gauge, beam.GaugeResult => Any] = Map.empty,
counters: Set[MetricsAssertion[beam.Counter, Long]] = Set.empty,
distributions: Set[MetricsAssertion[beam.Distribution, beam.DistributionResult]] = Set.empty,
gauges: Set[MetricsAssertion[beam.Gauge, beam.GaugeResult]] = Set.empty,
wasRunInvoked: Boolean = false
)

/**
 * A metric check registered on the builder.
 *
 * @tparam M kind of Beam metric the check applies to
 * @tparam V type of the value handed to the assertion function
 */
private sealed trait MetricsAssertion[M <: beam.Metric, V]

/** A check on the value of one specific `metric`. */
private final case class SingleMetricAssertion[M <: beam.Metric, V](metric: M, assert: V => Any)
    extends MetricsAssertion[M, V]

/** A check on all metrics of kind `M` at once, keyed by metric name. */
private final case class AllMetricsAssertion[M <: beam.Metric, V](
  assert: Map[beam.MetricName, V] => Any
) extends MetricsAssertion[M, V]

class Builder(private var state: BuilderState) {

/** Test ID for input and output wiring. */
Expand Down Expand Up @@ -165,8 +175,31 @@ object JobTest {
* @param assertion assertion for the counter result's committed value
*/
def counter(counter: beam.Counter)(assertion: Long => Any): Builder = {
  // Only one single-counter assertion may be registered per counter. Assertions over
  // all counters (see `counters`) never count as duplicates.
  val alreadyRegistered = state.counters.exists {
    case SingleMetricAssertion(registered, _) => registered == counter
    case AllMetricsAssertion(_)               => false
  }
  require(!alreadyRegistered, "Duplicate test counter: " + counter.getName)

  state = state.copy(counters = state.counters + SingleMetricAssertion(counter, assertion))
  this
}

/**
 * Register an assertion over every [[org.apache.beam.sdk.metrics.Counter Counter]] of the
 * pipeline being tested, rather than over one specific counter.
 * @param assertion check applied to the committed values of all job counters, keyed by
 *                  metric name
 */
def counters(assertion: Map[beam.MetricName, Long] => Any): Builder = {
  val everyCounter = AllMetricsAssertion[beam.Counter, Long](assertion)
  state = state.copy(counters = state.counters + everyCounter)
  this
}

Expand All @@ -180,10 +213,33 @@ object JobTest {
distribution: beam.Distribution
)(assertion: beam.DistributionResult => Any): Builder = {
require(
!state.distributions.contains(distribution),
!state.distributions.exists {
case a: SingleMetricAssertion[beam.Distribution, beam.DistributionResult] =>
a.metric == distribution
case _ => false
},
"Duplicate test distribution: " + distribution.getName
)
state = state.copy(distributions = state.distributions + (distribution -> assertion))

state = state.copy(
distributions = state.distributions +
SingleMetricAssertion(distribution, assertion)
)
this
}

/**
 * Register an assertion over every
 * [[org.apache.beam.sdk.metrics.Distribution Distribution]] of the pipeline being tested,
 * rather than over one specific distribution.
 * @param assertion check applied to the committed results of all job distributions, keyed
 *                  by metric name
 */
def distributions(assertion: Map[beam.MetricName, beam.DistributionResult] => Any): Builder = {
  val everyDistribution =
    AllMetricsAssertion[beam.Distribution, beam.DistributionResult](assertion)
  state = state.copy(distributions = state.distributions + everyDistribution)
  this
}

Expand All @@ -193,8 +249,29 @@ object JobTest {
* @param assertion assertion for the gauge result's committed value
*/
def gauge(gauge: beam.Gauge)(assertion: beam.GaugeResult => Any): Builder = {
  // Only one single-gauge assertion may be registered per gauge. Assertions over all
  // gauges (see `gauges`) never count as duplicates.
  val alreadyRegistered = state.gauges.exists {
    case SingleMetricAssertion(registered, _) => registered == gauge
    case AllMetricsAssertion(_)               => false
  }
  require(!alreadyRegistered, "Duplicate test gauge: " + gauge.getName)

  state = state.copy(gauges = state.gauges + SingleMetricAssertion(gauge, assertion))
  this
}

/**
 * Register an assertion over every [[org.apache.beam.sdk.metrics.Gauge Gauge]] of the
 * pipeline being tested, rather than over one specific gauge.
 * @param assertion check applied to the committed results of all job gauges, keyed by
 *                  metric name
 */
def gauges(assertion: Map[beam.MetricName, beam.GaugeResult] => Any): Builder = {
  val everyGauge = AllMetricsAssertion[beam.Gauge, beam.GaugeResult](assertion)
  state = state.copy(gauges = state.gauges + everyGauge)
  this
}

Expand All @@ -217,12 +294,29 @@ object JobTest {
/**
 * Build the function that applies every registered counter, gauge and distribution
 * assertion to a [[ScioResult]], and pass it to `TestDataManager.tearDown` for this test ID.
 */
def tearDown(): Unit = {
  // NOTE(review): `committed.get` presumably unwraps an Option and throws when no
  // committed value is present -- confirm against ScioResult's metric value type.
  val metricsFn = (result: ScioResult) => {
    state.counters.foreach {
      case SingleMetricAssertion(counter, check) =>
        check(result.counter(counter).committed.get)
      case AllMetricsAssertion(check) =>
        check(result.allCounters.map { case (name, value) => name -> value.committed.get })
    }
    state.gauges.foreach {
      case SingleMetricAssertion(gauge, check) =>
        check(result.gauge(gauge).committed.get)
      case AllMetricsAssertion(check) =>
        check(result.allGauges.map { case (name, value) => name -> value.committed.get })
    }
    state.distributions.foreach {
      case SingleMetricAssertion(distribution, check) =>
        check(result.distribution(distribution).committed.get)
      case AllMetricsAssertion(check) =>
        check(result.allDistributions.map { case (name, value) => name -> value.committed.get })
    }
  }
  TestDataManager.tearDown(testId, metricsFn)
}
Expand Down
Expand Up @@ -30,15 +30,16 @@ import com.spotify.scio.io._
import com.spotify.scio.util.MockedPrintStream
import org.apache.avro.generic.GenericRecord
import org.apache.beam.sdk.Pipeline.PipelineExecutionException
import org.apache.beam.sdk.io.FileIO
import org.apache.beam.sdk.io.FileIO.ReadMatches.DirectoryTreatment
import org.apache.beam.sdk.metrics.DistributionResult
import org.apache.beam.sdk.transforms.PTransform
import org.apache.beam.sdk.values.PCollection
import org.apache.beam.sdk.{io => beam}
import org.joda.time.Instant
import org.scalatest.exceptions.TestFailedException

import scala.io.Source
import org.apache.beam.sdk.transforms.PTransform
import org.apache.beam.sdk.values.PCollection
import org.apache.beam.sdk.io.FileIO
import org.apache.beam.sdk.io.FileIO.ReadMatches.DirectoryTreatment

// scalastyle:off file.size.limit

Expand Down Expand Up @@ -1089,17 +1090,29 @@ class JobTestTest extends PipelineSpec {
.counter(MetricsJob.counter) { x =>
x shouldBe 10
}
.counters(_ should contain(MetricsJob.counter.getName -> 10))
.distribution(MetricsJob.distribution) { d =>
d.getCount shouldBe 10
d.getMin shouldBe 1
d.getMax shouldBe 10
d.getSum shouldBe 55
d.getMean shouldBe 5.5
}
.distributions(
_ should contain(
MetricsJob.distribution.getName ->
DistributionResult.create(55, 10, 1, 10)
)
)
.gauge(MetricsJob.gauge) { g =>
g.getValue should be >= 1L
g.getValue should be <= 10L
}
.gauges(_.map {
case (_, result) =>
result.getValue should be >= 1L
result.getValue should be <= 10L
})
.run()
}

Expand Down

0 comments on commit 204aa2a

Please sign in to comment.