diff --git a/scio-test/src/main/scala/com/spotify/scio/testing/JobTest.scala b/scio-test/src/main/scala/com/spotify/scio/testing/JobTest.scala index 175d4562be..58ec7b0030 100644 --- a/scio-test/src/main/scala/com/spotify/scio/testing/JobTest.scala +++ b/scio-test/src/main/scala/com/spotify/scio/testing/JobTest.scala @@ -77,12 +77,22 @@ object JobTest { input: Map[String, JobInputSource[_]] = Map.empty, output: Map[String, SCollection[_] => Any] = Map.empty, distCaches: Map[DistCacheIO[_], _] = Map.empty, - counters: Map[beam.Counter, Long => Any] = Map.empty, - distributions: Map[beam.Distribution, beam.DistributionResult => Any] = Map.empty, - gauges: Map[beam.Gauge, beam.GaugeResult => Any] = Map.empty, + counters: Set[MetricsAssertion[beam.Counter, Long]] = Set.empty, + distributions: Set[MetricsAssertion[beam.Distribution, beam.DistributionResult]] = Set.empty, + gauges: Set[MetricsAssertion[beam.Gauge, beam.GaugeResult]] = Set.empty, wasRunInvoked: Boolean = false ) + private sealed trait MetricsAssertion[M <: beam.Metric, V] + private final case class SingleMetricAssertion[M <: beam.Metric, V]( + metric: M, + assert: V => Any + ) extends MetricsAssertion[M, V] + + private final case class AllMetricsAssertion[M <: beam.Metric, V]( + assert: Map[beam.MetricName, V] => Any + ) extends MetricsAssertion[M, V] + class Builder(private var state: BuilderState) { /** Test ID for input and output wiring. */ @@ -165,8 +175,31 @@ object JobTest { * @param assertion assertion for the counter result's committed value */ def counter(counter: beam.Counter)(assertion: Long => Any): Builder = { - require(!state.counters.contains(counter), "Duplicate test counter: " + counter.getName) - state = state.copy(counters = state.counters + (counter -> assertion)) + require( + !state.counters.exists { + case a: SingleMetricAssertion[beam.Counter, Long] => a.metric == counter + case _ => false + }, + "Duplicate test counter: " + counter.getName + ) + + state = state.copy( + counters = state.counters + + SingleMetricAssertion(counter, assertion) + ) + this + } + + /** + * Evaluate all [[org.apache.beam.sdk.metrics.Counter Counters]] in the pipeline being tested. + * @param assertion assertion on the collection of all job counters' committed values + */ + def counters(assertion: Map[beam.MetricName, Long] => Any): Builder = { + state = state.copy( + counters = state.counters + + AllMetricsAssertion[beam.Counter, Long](assertion) + ) + this } @@ -180,10 +213,33 @@ object JobTest { distribution: beam.Distribution )(assertion: beam.DistributionResult => Any): Builder = { require( - !state.distributions.contains(distribution), + !state.distributions.exists { + case a: SingleMetricAssertion[beam.Distribution, beam.DistributionResult] => + a.metric == distribution + case _ => false + }, "Duplicate test distribution: " + distribution.getName ) - state = state.copy(distributions = state.distributions + (distribution -> assertion)) + + state = state.copy( + distributions = state.distributions + + SingleMetricAssertion(distribution, assertion) + ) + this + } + + /** + * Evaluate all [[org.apache.beam.sdk.metrics.Distribution Distributions]] in the + * pipeline being tested. + * @param assertion assertion on the collection of all job distribution results' + * committed values + */ + def distributions(assertion: Map[beam.MetricName, beam.DistributionResult] => Any): Builder = { + state = state.copy( + distributions = state.distributions + + AllMetricsAssertion[beam.Distribution, beam.DistributionResult](assertion) + ) + this } @@ -193,8 +249,29 @@ object JobTest { * @param assertion assertion for the gauge result's committed value */ def gauge(gauge: beam.Gauge)(assertion: beam.GaugeResult => Any): Builder = { - require(!state.gauges.contains(gauge), "Duplicate test gauge: " + gauge.getName) - state = state.copy(gauges = state.gauges + (gauge -> assertion)) + require( + !state.gauges.exists { + case a: SingleMetricAssertion[beam.Gauge, beam.GaugeResult] => + a.metric == gauge + case _ => false + }, + "Duplicate test gauge: " + gauge.getName + ) + + state = state.copy(gauges = state.gauges + SingleMetricAssertion(gauge, assertion)) + this + } + + /** + * Evaluate all [[org.apache.beam.sdk.metrics.Gauge Gauges]] in the pipeline being tested. + * @param assertion assertion on the collection of all job gauge results' committed values + */ + def gauges(assertion: Map[beam.MetricName, beam.GaugeResult] => Any): Builder = { + state = state.copy( + gauges = state.gauges + + AllMetricsAssertion[beam.Gauge, beam.GaugeResult](assertion) + ) + this } @@ -217,12 +294,29 @@ object JobTest { def tearDown(): Unit = { val metricsFn = (result: ScioResult) => { state.counters.foreach { - case (k, v) => v(result.counter(k).committed.get) + case a: SingleMetricAssertion[beam.Counter, Long] => + a.assert(result.counter(a.metric).committed.get) + case a: AllMetricsAssertion[beam.Counter, Long] => + a.assert(result.allCounters.map { c => + c._1 -> c._2.committed.get + }) + } + state.gauges.foreach { + case a: SingleMetricAssertion[beam.Gauge, beam.GaugeResult] => + a.assert(result.gauge(a.metric).committed.get) + case a: AllMetricsAssertion[beam.Gauge, beam.GaugeResult] => + a.assert(result.allGauges.map { c => + c._1 -> c._2.committed.get + }) } state.distributions.foreach { - case (k, v) => v(result.distribution(k).committed.get) + case a: SingleMetricAssertion[beam.Distribution, beam.DistributionResult] => + a.assert(result.distribution(a.metric).committed.get) + case a: AllMetricsAssertion[beam.Distribution, beam.DistributionResult] => + a.assert(result.allDistributions.map { c => + c._1 -> c._2.committed.get + }) } - state.gauges.foreach { case (k, v) => v(result.gauge(k).committed.get) } } TestDataManager.tearDown(testId, metricsFn) } diff --git a/scio-test/src/test/scala/com/spotify/scio/testing/JobTestTest.scala b/scio-test/src/test/scala/com/spotify/scio/testing/JobTestTest.scala index ae5652339e..bab5b7878d 100644 --- a/scio-test/src/test/scala/com/spotify/scio/testing/JobTestTest.scala +++ b/scio-test/src/test/scala/com/spotify/scio/testing/JobTestTest.scala @@ -30,15 +30,16 @@ import com.spotify.scio.io._ import com.spotify.scio.util.MockedPrintStream import org.apache.avro.generic.GenericRecord import org.apache.beam.sdk.Pipeline.PipelineExecutionException +import org.apache.beam.sdk.io.FileIO +import org.apache.beam.sdk.io.FileIO.ReadMatches.DirectoryTreatment +import org.apache.beam.sdk.metrics.DistributionResult +import org.apache.beam.sdk.transforms.PTransform +import org.apache.beam.sdk.values.PCollection import org.apache.beam.sdk.{io => beam} import org.joda.time.Instant import org.scalatest.exceptions.TestFailedException import scala.io.Source -import org.apache.beam.sdk.transforms.PTransform -import org.apache.beam.sdk.values.PCollection -import org.apache.beam.sdk.io.FileIO -import org.apache.beam.sdk.io.FileIO.ReadMatches.DirectoryTreatment // scalastyle:off file.size.limit @@ -1089,6 +1090,7 @@ class JobTestTest extends PipelineSpec { .counter(MetricsJob.counter) { x => x shouldBe 10 } + .counters(_ should contain(MetricsJob.counter.getName -> 10)) .distribution(MetricsJob.distribution) { d => d.getCount shouldBe 10 d.getMin shouldBe 1 @@ -1096,10 +1098,21 @@ class JobTestTest extends PipelineSpec { d.getSum shouldBe 55 d.getMean shouldBe 5.5 } + .distributions( + _ should contain( + MetricsJob.distribution.getName -> + DistributionResult.create(55, 10, 1, 10) + ) + ) .gauge(MetricsJob.gauge) { g => g.getValue should be >= 1L g.getValue should be <= 10L } + .gauges(_.map { + case (_, result) => + result.getValue should be >= 1L + result.getValue should be <= 10L + }) .run() }