Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support sortedTake in beam runner #1949

Merged
merged 3 commits into from
Sep 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import com.twitter.scalding.typed._
import com.twitter.scalding.typed.functions.{
FilterKeysToFilter,
FlatMapValuesToFlatMap,
MapValuesToMap
MapValuesToMap,
ScaldingPriorityQueueMonoid
}

object BeamPlanner {
Expand Down Expand Up @@ -65,7 +66,12 @@ object BeamPlanner {
config.getMapSideAggregationThreshold match {
case None => op
case Some(count) =>
op.mapSideAggregator(count, sg)
// Semigroup is invariant on T. We cannot pattern match as it is a Semigroup[PriorityQueue[T]]
if (sg.isInstanceOf[ScaldingPriorityQueueMonoid[_]]) {
op
} else {
op.mapSideAggregator(count, sg)
}
}
case (ReduceStepPipe(ir @ IdentityReduce(_, _, _, _, _)), rec) =>
def go[K, V1, V2](ir: IdentityReduce[K, V1, V2]): BeamOp[(K, V2)] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,21 @@ package com.twitter.scalding.beam_backend
import com.twitter.algebird.Semigroup
import com.twitter.scalding.Config
import com.twitter.scalding.beam_backend.BeamFunctions._
import com.twitter.scalding.serialization.Externalizer
import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup
import com.twitter.scalding.typed.functions.{EmptyGuard, MapValueStream, SumAll}
import com.twitter.scalding.typed.functions.{
EmptyGuard,
MapValueStream,
ScaldingPriorityQueueMonoid,
SumAll
}
import com.twitter.scalding.typed.{CoGrouped, TypedSource}
import java.lang
import java.util.{Comparator, PriorityQueue}
import org.apache.beam.sdk.Pipeline
import org.apache.beam.sdk.coders.{Coder, IterableCoder, KvCoder}
import org.apache.beam.sdk.transforms.DoFn.ProcessElement
import org.apache.beam.sdk.transforms.Top.TopCombineFn
import org.apache.beam.sdk.transforms._
import org.apache.beam.sdk.transforms.join.{
CoGbkResult,
Expand Down Expand Up @@ -50,6 +58,11 @@ sealed abstract class BeamOp[+A] {
parDo(FlatMapFn(f))
}

/**
 * Wraps a [[Comparator]] that may not itself be serializable so it can be
 * shipped inside a Beam transform (e.g. as the comparator of a
 * `TopCombineFn`). Serialization of the underlying comparator is delegated
 * to scalding's [[Externalizer]].
 */
private final case class SerializableComparator[T](comp: Comparator[T]) extends Comparator[T] {
  // Externalizer takes care of (de)serializing the wrapped comparator.
  private[this] val serializedCmp = Externalizer(comp)

  override def compare(o1: T, o2: T): Int =
    serializedCmp.get.compare(o1, o2)
}

object BeamOp extends Serializable {
implicit private def fakeClassTag[A]: ClassTag[A] = ClassTag(classOf[AnyRef]).asInstanceOf[ClassTag[A]]

Expand All @@ -59,6 +72,30 @@ object BeamOp extends Serializable {
)(implicit ordK: Ordering[K], kryoCoder: KryoCoder): PCollection[KV[K, java.lang.Iterable[U]]] = {
reduceFn match {
case ComposedMapGroup(f, g) => planMapGroup(planMapGroup(pcoll, f), g)
case EmptyGuard(MapValueStream(SumAll(pqm: ScaldingPriorityQueueMonoid[v]))) =>
val vCollection = pcoll.asInstanceOf[PCollection[KV[K, java.lang.Iterable[PriorityQueue[v]]]]]

vCollection.apply(MapElements.via(
new SimpleFunction[KV[K, java.lang.Iterable[PriorityQueue[v]]], KV[K, java.lang.Iterable[U]]]() {
private final val topCombineFn = new TopCombineFn[v, SerializableComparator[v]](
pqm.count,
SerializableComparator[v](pqm.ordering.reverse)
)

override def apply(input: KV[K, lang.Iterable[PriorityQueue[v]]]): KV[K, java.lang.Iterable[U]] = {
@inline def flattenedValues: Stream[v] =
input.getValue.asScala.toStream.flatMap(_.asScala.toStream)

val outputs: java.util.List[v] = topCombineFn.apply(flattenedValues.asJava)
// We are building the PriorityQueue back as output type U is PriorityQueue[v]
val pqs = pqm.build(outputs.asScala)
nownikhil marked this conversation as resolved.
Show resolved Hide resolved
KV.of(input.getKey, Iterable(pqs.asInstanceOf[U]).asJava)
}
})
).setCoder(KvCoder.of(
OrderedSerializationCoder(ordK, kryoCoder),
IterableCoder.of(kryoCoder))
)
case EmptyGuard(MapValueStream(sa: SumAll[V])) =>
pcoll
.apply(Combine.groupedValues(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
package com.twitter.scalding.beam_backend

import com.twitter.algebird.mutable.PriorityQueueMonoid
import com.twitter.algebird.{AveragedValue, Semigroup}
import com.twitter.scalding.{Config, TextLine, TypedPipe}
import java.io.File
import java.nio.file.Paths
import java.util.PriorityQueue
import org.apache.beam.sdk.options.{PipelineOptions, PipelineOptionsFactory}
import org.scalatest.{BeforeAndAfter, FunSuite}
import scala.io.Source
import scala.util.Try

class BeamBackendTests extends FunSuite with BeforeAndAfter {

Expand Down Expand Up @@ -102,23 +103,28 @@ class BeamBackendTests extends FunSuite with BeforeAndAfter {
)
}

test("priorityQueue operations"){
/**
* @note we are not extending support for `sortedTake` and `sortedReverseTake`, since both of them use
* [[com.twitter.algebird.mutable.PriorityQueueMonoid.plus]], which mutates its input value in the pipeline,
* and Beam does not allow mutating inputs during a transformation
*/
val test = Try {
beamMatchesSeq(
TypedPipe
.from(Seq(5, 3, 2, 0, 1, 4))
.map(x => x.toDouble)
.groupAll
.sortedReverseTake(3),
Seq(5, 4, 3)
test("sortedTake"){
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that bufferedTake also uses the problematic monoid internally. Worth testing:

implicit val mon: PriorityQueueMonoid[V1] = new PriorityQueueMonoid[V1](n)(fakeOrdering)

beamMatchesSeq(
TypedPipe
.from(Seq(5, 3, 2, 0, 1, 4))
.map(x => x.toDouble)
.groupAll
.sortedReverseTake(3)
.flatMap(_._2),
Seq(5.0, 4.0, 3.0)
)
}
assert(test.isFailure)
}

test("bufferedTake") {
  // NOTE(review): with groupAll the key is Unit, so there is a single reduce
  // key; without map-side aggregation every value is shipped to the reducers,
  // which then discard all but 100. Tracked as a follow-up in #1952.
  beamMatchesSeq(
    TypedPipe
      .from(1 to 50)
      .groupAll
      .bufferedTake(100)
      .map(_._2),
    1 to 50,
    Config(Map("cascading.aggregateby.threshold" -> "100"))
  )
}

test("SumByLocalKeys"){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import com.twitter.algebird.{
Aggregator
}

import com.twitter.algebird.mutable.PriorityQueueMonoid
import com.twitter.scalding.typed.functions.ScaldingPriorityQueueMonoid

import java.util.PriorityQueue

Expand Down Expand Up @@ -391,7 +391,7 @@ trait ReduceOperations[+Self <: ReduceOperations[Self]] extends java.io.Serializ
/**
 * Takes at most `k` elements per group using a bounded priority queue built
 * on the map side, then emits them as a single sorted list into the (single)
 * output field.
 *
 * @param f    input/output field mapping; the output side must be exactly one field
 * @param k    maximum number of elements to keep
 * @param conv converter from tuples to `T`
 * @param ord  ordering used both by the queue bound and the final sort
 */
def sortedTake[T](f: (Fields, Fields), k: Int)(implicit conv: TupleConverter[T], ord: Ordering[T]): Self = {
  assert(f._2.size == 1, "output field size must be 1")
  // ScaldingPriorityQueueMonoid exposes count/ordering so backends can
  // recognize and specially handle this aggregation.
  implicit val mon: ScaldingPriorityQueueMonoid[T] = new ScaldingPriorityQueueMonoid[T](k)
  mapPlusMap(f) { (tup: T) => mon.build(tup) } { (queue: PriorityQueue[T]) =>
    queue.iterator.asScala.toList.sorted
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License.
package com.twitter.scalding.typed

import com.twitter.algebird.Semigroup
import com.twitter.algebird.mutable.PriorityQueueMonoid
import com.twitter.scalding.typed.functions._
import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup
import scala.collection.JavaConverters._
Expand Down Expand Up @@ -659,7 +658,7 @@ final case class UnsortedIdentityReduce[K, V1, V2](
// If you care which items you take, you should sort by a random number
// or the value itself.
val fakeOrdering: Ordering[V1] = Ordering.by { v: V1 => v.hashCode }
implicit val mon: PriorityQueueMonoid[V1] = new PriorityQueueMonoid[V1](n)(fakeOrdering)
implicit val mon: ScaldingPriorityQueueMonoid[V1] = new ScaldingPriorityQueueMonoid[V1](n)(fakeOrdering)
// Do the heap-sort on the mappers:
val pretake: TypedPipe[(K, V1)] = mapped.mapValues { v: V1 => mon.build(v) }
.sumByLocalKeys
Expand Down Expand Up @@ -745,7 +744,7 @@ final case class IdentityValueSortedReduce[K, V1, V2](
// This means don't take anything, which is legal, but strange
filterKeys(Constant(false))
} else {
implicit val mon: PriorityQueueMonoid[V1] = new PriorityQueueMonoid[V1](n)(valueSort)
implicit val mon: ScaldingPriorityQueueMonoid[V1] = new ScaldingPriorityQueueMonoid[V1](n)(valueSort)
// Do the heap-sort on the mappers:
val pretake: TypedPipe[(K, V1)] = mapped.mapValues { v: V1 => mon.build(v) }
.sumByLocalKeys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import java.io.Serializable
import scala.collection.JavaConverters._

import com.twitter.algebird.{ Fold, Semigroup, Ring, Aggregator }
import com.twitter.algebird.mutable.PriorityQueueMonoid

import com.twitter.scalding.typed.functions._

Expand Down Expand Up @@ -79,7 +78,7 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] extends Se
// If you care which items you take, you should sort by a random number
// or the value itself.
val fakeOrdering: Ordering[T] = Ordering.by { v: T => v.hashCode }
implicit val mon = new PriorityQueueMonoid(n)(fakeOrdering)
implicit val mon = new ScaldingPriorityQueueMonoid(n)(fakeOrdering)
mapValues(mon.build(_))
// Do the heap-sort on the mappers:
.sum
Expand Down Expand Up @@ -213,7 +212,7 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] extends Se
* to fit in memory.
*/
def sortedTake[U >: T](k: Int)(implicit ord: Ordering[U]): This[K, Seq[U]] = {
val mon = new PriorityQueueMonoid[U](k)(ord)
val mon = new ScaldingPriorityQueueMonoid[U](k)(ord)
mapValues(mon.build(_))
.sum(mon) // results in a PriorityQueue
// scala can't infer the type, possibly due to the view bound on TypedPipe
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.twitter.scalding.typed.functions

import com.twitter.algebird.mutable.PriorityQueueMonoid

/**
 * A `PriorityQueueMonoid` subclass that exposes its construction parameters
 * (`count` and the implicit `Ordering`) as public vals, so that backends
 * (e.g. the Beam planner) can detect this monoid via `isInstanceOf` /
 * pattern matching and recover the bound and ordering it was built with.
 */
class ScaldingPriorityQueueMonoid[K](
val count: Int
)(implicit val ordering: Ordering[K]) extends PriorityQueueMonoid[K](count)(ordering)