From 71066530b3d1e50f7e7aa3a9de4bbf0afd80c92f Mon Sep 17 00:00:00 2001 From: Nikhil Goyal Date: Tue, 21 Sep 2021 15:25:46 -0700 Subject: [PATCH 1/3] Support `sortedTake` in beam runner We flatten all the input PriorityQueues and construct a new PQ using the monoid provided. TESTS: Updated failing unit test --- .../scalding/beam_backend/BeamOp.scala | 21 +++++++++++++++ .../beam_backend/BeamBackendTests.scala | 26 +++++++------------ 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala index 207cd560d..530f4a94f 100644 --- a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala +++ b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala @@ -1,12 +1,14 @@ package com.twitter.scalding.beam_backend import com.twitter.algebird.Semigroup +import com.twitter.algebird.mutable.PriorityQueueMonoid import com.twitter.scalding.Config import com.twitter.scalding.beam_backend.BeamFunctions._ import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup import com.twitter.scalding.typed.functions.{EmptyGuard, MapValueStream, SumAll} import com.twitter.scalding.typed.{CoGrouped, TypedSource} import java.lang +import java.util.PriorityQueue import org.apache.beam.sdk.Pipeline import org.apache.beam.sdk.coders.{Coder, IterableCoder, KvCoder} import org.apache.beam.sdk.transforms.DoFn.ProcessElement @@ -59,6 +61,25 @@ object BeamOp extends Serializable { )(implicit ordK: Ordering[K], kryoCoder: KryoCoder): PCollection[KV[K, java.lang.Iterable[U]]] = { reduceFn match { case ComposedMapGroup(f, g) => planMapGroup(planMapGroup(pcoll, f), g) + case EmptyGuard(MapValueStream(SumAll(pqm: PriorityQueueMonoid[V]))) => + pcoll.apply(MapElements.via( + new SimpleFunction[KV[K, java.lang.Iterable[V]], KV[K, java.lang.Iterable[U]]]() { + override def apply(input: KV[K, lang.Iterable[V]]): KV[K, java.lang.Iterable[U]] = { + // We are not using plus method defined in PriorityQueueMonoid as it is mutating + // input Priority Queues. We create a new PQ from the individual ones. + // We didn't use Top PTransformation in beam as it is not needed, also + // we cannot access `max` defined in PQ monoid. + val flattenedValues = input.getValue.asScala.flatMap { value => + value.asInstanceOf[PriorityQueue[V]].iterator().asScala + } + val mergedPQ = pqm.build(flattenedValues) + KV.of(input.getKey, Iterable(mergedPQ.asInstanceOf[U]).asJava) + } + }) + ).setCoder(KvCoder.of( + OrderedSerializationCoder(ordK, kryoCoder), + IterableCoder.of(kryoCoder)) + ) case EmptyGuard(MapValueStream(sa: SumAll[V])) => pcoll .apply(Combine.groupedValues( diff --git a/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala b/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala index 5095c0bfd..0cc18da2a 100644 --- a/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala +++ b/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala @@ -7,7 +7,6 @@ import java.nio.file.Paths import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory} import org.scalatest.{BeforeAndAfter, FunSuite} import scala.io.Source -import scala.util.Try class BeamBackendTests extends FunSuite with BeforeAndAfter { @@ -102,23 +101,16 @@ class BeamBackendTests extends FunSuite with BeforeAndAfter { ) } - test("priorityQueue operations"){ - /** - * @note we are not extending support for `sortedTake` and `sortedReverseTake`, since both of them uses - * [[com.twitter.algebird.mutable.PriorityQueueMonoid.plus]] which mutates input value in pipeline - * and Beam does not allow mutations to input during transformation - */ - val test = Try { - beamMatchesSeq( - TypedPipe - .from(Seq(5, 3, 2, 0, 1, 4)) - .map(x => x.toDouble) - .groupAll - .sortedReverseTake(3), - Seq(5, 4, 3) + test("sortedTake"){ + beamMatchesSeq( + TypedPipe + .from(Seq(5, 3, 2, 0, 1, 4)) + .map(x => x.toDouble) + .groupAll + .sortedReverseTake(3) + .flatMap(_._2), + Seq(5.0, 4.0, 3.0) ) - } - assert(test.isFailure) } test("SumByLocalKeys"){ From 86831193a7f0c29798431e0b2f2870d71dd3041b Mon Sep 17 00:00:00 2001 From: Nikhil Goyal Date: Wed, 22 Sep 2021 13:18:45 -0700 Subject: [PATCH 2/3] Added extension of PriorityQueueMonoid Added ScaldingPriorityQueueMonoid which exposes count which we later use in TopCombineFn. Added unit test for bufferedTake Disabled map side aggregation when using ScaldingPriorityQueueMonoid --- .../scalding/beam_backend/BeamBackend.scala | 10 +++- .../scalding/beam_backend/BeamOp.scala | 46 ++++++++++++------- .../beam_backend/BeamBackendTests.scala | 14 ++++++ .../twitter/scalding/ReduceOperations.scala | 4 +- .../com/twitter/scalding/typed/Grouped.scala | 5 +- .../twitter/scalding/typed/KeyedList.scala | 5 +- .../ScaldingPriorityQueueMonoid.scala | 7 +++ 7 files changed, 65 insertions(+), 26 deletions(-) create mode 100644 scalding-core/src/main/scala/com/twitter/scalding/typed/functions/ScaldingPriorityQueueMonoid.scala diff --git a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamBackend.scala b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamBackend.scala index cb6d35936..0ff91be69 100644 --- a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamBackend.scala +++ b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamBackend.scala @@ -10,7 +10,8 @@ import com.twitter.scalding.typed._ import com.twitter.scalding.typed.functions.{ FilterKeysToFilter, FlatMapValuesToFlatMap, - MapValuesToMap + MapValuesToMap, + ScaldingPriorityQueueMonoid } object BeamPlanner { @@ -65,7 +66,12 @@ object BeamPlanner { config.getMapSideAggregationThreshold match { case None => op case Some(count) => - op.mapSideAggregator(count, sg) + // Semigroup is invariant on T. We cannot pattern match as it is a Semigroup[PriorityQueue[T]] + if (sg.isInstanceOf[ScaldingPriorityQueueMonoid[_]]) { + op + } else { + op.mapSideAggregator(count, sg) + } } case (ReduceStepPipe(ir @ IdentityReduce(_, _, _, _, _)), rec) => def go[K, V1, V2](ir: IdentityReduce[K, V1, V2]): BeamOp[(K, V2)] = { diff --git a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala index 530f4a94f..75f932adb 100644 --- a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala +++ b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala @@ -1,17 +1,22 @@ package com.twitter.scalding.beam_backend import com.twitter.algebird.Semigroup -import com.twitter.algebird.mutable.PriorityQueueMonoid import com.twitter.scalding.Config import com.twitter.scalding.beam_backend.BeamFunctions._ import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup -import com.twitter.scalding.typed.functions.{EmptyGuard, MapValueStream, SumAll} +import com.twitter.scalding.typed.functions.{ + EmptyGuard, + MapValueStream, + ScaldingPriorityQueueMonoid, + SumAll +} import com.twitter.scalding.typed.{CoGrouped, TypedSource} import java.lang -import java.util.PriorityQueue +import java.util.{Comparator, PriorityQueue} import org.apache.beam.sdk.Pipeline import org.apache.beam.sdk.coders.{Coder, IterableCoder, KvCoder} import org.apache.beam.sdk.transforms.DoFn.ProcessElement +import org.apache.beam.sdk.transforms.Top.TopCombineFn import org.apache.beam.sdk.transforms._ import org.apache.beam.sdk.transforms.join.{ CoGbkResult, @@ -52,6 +57,10 @@ sealed abstract class BeamOp[+A] { parDo(FlatMapFn(f)) } +private final case class SerializableComparator[T](comp: Comparator[T]) extends Comparator[T] { + override def compare(o1: T, o2: T): Int = comp.compare(o1, o2) +} + object BeamOp extends Serializable { implicit private def fakeClassTag[A]: ClassTag[A] = ClassTag(classOf[AnyRef]).asInstanceOf[ClassTag[A]] @@ -61,19 +70,24 @@ object BeamOp extends Serializable { )(implicit ordK: Ordering[K], kryoCoder: KryoCoder): PCollection[KV[K, java.lang.Iterable[U]]] = { reduceFn match { case ComposedMapGroup(f, g) => planMapGroup(planMapGroup(pcoll, f), g) - case EmptyGuard(MapValueStream(SumAll(pqm: PriorityQueueMonoid[V]))) => - pcoll.apply(MapElements.via( - new SimpleFunction[KV[K, java.lang.Iterable[V]], KV[K, java.lang.Iterable[U]]]() { - override def apply(input: KV[K, lang.Iterable[V]]): KV[K, java.lang.Iterable[U]] = { - // We are not using plus method defined in PriorityQueueMonoid as it is mutating - // input Priority Queues. We create a new PQ from the individual ones. - // We didn't use Top PTransformation in beam as it is not needed, also - // we cannot access `max` defined in PQ monoid. - val flattenedValues = input.getValue.asScala.flatMap { value => - value.asInstanceOf[PriorityQueue[V]].iterator().asScala - } - val mergedPQ = pqm.build(flattenedValues) - KV.of(input.getKey, Iterable(mergedPQ.asInstanceOf[U]).asJava) + case EmptyGuard(MapValueStream(SumAll(pqm: ScaldingPriorityQueueMonoid[v]))) => + val vCollection = pcoll.asInstanceOf[PCollection[KV[K, java.lang.Iterable[PriorityQueue[v]]]]] + + vCollection.apply(MapElements.via( + new SimpleFunction[KV[K, java.lang.Iterable[PriorityQueue[v]]], KV[K, java.lang.Iterable[U]]]() { + override def apply(input: KV[K, lang.Iterable[PriorityQueue[v]]]): KV[K, java.lang.Iterable[U]] = { + + val topCombineFn = new TopCombineFn[v, SerializableComparator[v]]( + pqm.count, + SerializableComparator[v](pqm.ordering.reverse) + ) + + @inline def flattenedValues: Stream[v] = + input.getValue.asScala.toStream.flatMap(_.asScala.toStream) + + val outputs: java.util.List[v] = topCombineFn.apply(flattenedValues.asJava) + val pqs = pqm.build(outputs.asScala) + KV.of(input.getKey, Iterable(pqs.asInstanceOf[U]).asJava) } }) ).setCoder(KvCoder.of( diff --git a/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala b/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala index 0cc18da2a..515aafcc6 100644 --- a/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala +++ b/scalding-beam/src/test/scala/com/twitter/scalding/beam_backend/BeamBackendTests.scala @@ -1,9 +1,11 @@ package com.twitter.scalding.beam_backend +import com.twitter.algebird.mutable.PriorityQueueMonoid import com.twitter.algebird.{AveragedValue, Semigroup} import com.twitter.scalding.{Config, TextLine, TypedPipe} import java.io.File import java.nio.file.Paths +import java.util.PriorityQueue import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory} import org.scalatest.{BeforeAndAfter, FunSuite} import scala.io.Source @@ -113,6 +115,18 @@ class BeamBackendTests extends FunSuite with BeforeAndAfter { ) } + test("bufferedTake"){ + beamMatchesSeq( + TypedPipe + .from(1 to 50) + .groupAll + .bufferedTake(100) + .map(_._2), + 1 to 50, + Config(Map("cascading.aggregateby.threshold" -> "100")) + ) + } + test("SumByLocalKeys"){ beamMatchesSeq( TypedPipe diff --git a/scalding-core/src/main/scala/com/twitter/scalding/ReduceOperations.scala b/scalding-core/src/main/scala/com/twitter/scalding/ReduceOperations.scala index 7645cac59..1429098df 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/ReduceOperations.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/ReduceOperations.scala @@ -28,7 +28,7 @@ import com.twitter.algebird.{ Aggregator } -import com.twitter.algebird.mutable.PriorityQueueMonoid +import com.twitter.scalding.typed.functions.ScaldingPriorityQueueMonoid import java.util.PriorityQueue @@ -391,7 +391,7 @@ trait ReduceOperations[+Self <: ReduceOperations[Self]] extends java.io.Serializ def sortedTake[T](f: (Fields, Fields), k: Int)(implicit conv: TupleConverter[T], ord: Ordering[T]): Self = { assert(f._2.size == 1, "output field size must be 1") - implicit val mon: PriorityQueueMonoid[T] = new PriorityQueueMonoid[T](k) + implicit val mon: ScaldingPriorityQueueMonoid[T] = new ScaldingPriorityQueueMonoid[T](k) mapPlusMap(f) { (tup: T) => mon.build(tup) } { (lout: PriorityQueue[T]) => lout.iterator.asScala.toList.sorted } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala index d8a8eb094..d263b5d5d 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala @@ -16,7 +16,6 @@ limitations under the License. package com.twitter.scalding.typed import com.twitter.algebird.Semigroup -import com.twitter.algebird.mutable.PriorityQueueMonoid import com.twitter.scalding.typed.functions._ import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup import scala.collection.JavaConverters._ @@ -659,7 +658,7 @@ final case class UnsortedIdentityReduce[K, V1, V2]( // If you care which items you take, you should sort by a random number // or the value itself. val fakeOrdering: Ordering[V1] = Ordering.by { v: V1 => v.hashCode } - implicit val mon: PriorityQueueMonoid[V1] = new PriorityQueueMonoid[V1](n)(fakeOrdering) + implicit val mon: ScaldingPriorityQueueMonoid[V1] = new ScaldingPriorityQueueMonoid[V1](n)(fakeOrdering) // Do the heap-sort on the mappers: val pretake: TypedPipe[(K, V1)] = mapped.mapValues { v: V1 => mon.build(v) } .sumByLocalKeys @@ -745,7 +744,7 @@ final case class IdentityValueSortedReduce[K, V1, V2]( // This means don't take anything, which is legal, but strange filterKeys(Constant(false)) } else { - implicit val mon: PriorityQueueMonoid[V1] = new PriorityQueueMonoid[V1](n)(valueSort) + implicit val mon: ScaldingPriorityQueueMonoid[V1] = new ScaldingPriorityQueueMonoid[V1](n)(valueSort) // Do the heap-sort on the mappers: val pretake: TypedPipe[(K, V1)] = mapped.mapValues { v: V1 => mon.build(v) } .sumByLocalKeys diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala index cf9f35775..0973e3406 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala @@ -19,7 +19,6 @@ import java.io.Serializable import scala.collection.JavaConverters._ import com.twitter.algebird.{ Fold, Semigroup, Ring, Aggregator } -import com.twitter.algebird.mutable.PriorityQueueMonoid import com.twitter.scalding.typed.functions._ @@ -79,7 +78,7 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] extends Se // If you care which items you take, you should sort by a random number // or the value itself. val fakeOrdering: Ordering[T] = Ordering.by { v: T => v.hashCode } - implicit val mon = new PriorityQueueMonoid(n)(fakeOrdering) + implicit val mon = new ScaldingPriorityQueueMonoid(n)(fakeOrdering) mapValues(mon.build(_)) // Do the heap-sort on the mappers: .sum @@ -213,7 +212,7 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] extends Se * to fit in memory. */ def sortedTake[U >: T](k: Int)(implicit ord: Ordering[U]): This[K, Seq[U]] = { - val mon = new PriorityQueueMonoid[U](k)(ord) + val mon = new ScaldingPriorityQueueMonoid[U](k)(ord) mapValues(mon.build(_)) .sum(mon) // results in a PriorityQueue // scala can't infer the type, possibly due to the view bound on TypedPipe diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/ScaldingPriorityQueueMonoid.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/ScaldingPriorityQueueMonoid.scala new file mode 100644 index 000000000..51077806b --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/ScaldingPriorityQueueMonoid.scala @@ -0,0 +1,7 @@ +package com.twitter.scalding.typed.functions + +import com.twitter.algebird.mutable.PriorityQueueMonoid + +class ScaldingPriorityQueueMonoid[K]( + val count: Int +)(implicit val ordering: Ordering[K]) extends PriorityQueueMonoid[K](count)(ordering) From 49af384bd38ec02156cbabd7c65174041d52e37f Mon Sep 17 00:00:00 2001 From: Nikhil Goyal Date: Thu, 23 Sep 2021 19:22:21 -0700 Subject: [PATCH 3/3] Use externalizer in SerializableComparator --- .../twitter/scalding/beam_backend/BeamOp.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala index 75f932adb..bc27d524a 100644 --- a/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala +++ b/scalding-beam/src/main/scala/com/twitter/scalding/beam_backend/BeamOp.scala @@ -3,6 +3,7 @@ package com.twitter.scalding.beam_backend import com.twitter.algebird.Semigroup import com.twitter.scalding.Config import com.twitter.scalding.beam_backend.BeamFunctions._ +import com.twitter.scalding.serialization.Externalizer import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup import com.twitter.scalding.typed.functions.{ EmptyGuard, @@ -58,7 +59,8 @@ sealed abstract class BeamOp[+A] { } private final case class SerializableComparator[T](comp: Comparator[T]) extends Comparator[T] { - override def compare(o1: T, o2: T): Int = comp.compare(o1, o2) + private[this] val extCmp = Externalizer(comp) + override def compare(o1: T, o2: T): Int = extCmp.get.compare(o1, o2) } object BeamOp extends Serializable { @@ -75,17 +77,17 @@ object BeamOp extends Serializable { vCollection.apply(MapElements.via( new SimpleFunction[KV[K, java.lang.Iterable[PriorityQueue[v]]], KV[K, java.lang.Iterable[U]]]() { - override def apply(input: KV[K, lang.Iterable[PriorityQueue[v]]]): KV[K, java.lang.Iterable[U]] = { - - val topCombineFn = new TopCombineFn[v, SerializableComparator[v]]( - pqm.count, - SerializableComparator[v](pqm.ordering.reverse) - ) + private final val topCombineFn = new TopCombineFn[v, SerializableComparator[v]]( + pqm.count, + SerializableComparator[v](pqm.ordering.reverse) + ) + override def apply(input: KV[K, lang.Iterable[PriorityQueue[v]]]): KV[K, java.lang.Iterable[U]] = { @inline def flattenedValues: Stream[v] = input.getValue.asScala.toStream.flatMap(_.asScala.toStream) val outputs: java.util.List[v] = topCombineFn.apply(flattenedValues.asJava) + // We are building the PriorityQueue back as output type U is PriorityQueue[v] val pqs = pqm.build(outputs.asScala) KV.of(input.getKey, Iterable(pqs.asInstanceOf[U]).asJava) }