diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/CoGrouped.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/CoGrouped.scala index 37473c06f0..5a7191f538 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/CoGrouped.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/CoGrouped.scala @@ -74,18 +74,31 @@ trait CoGroupable[K, +R] extends HasReducers with java.io.Serializable { def cogroup[R1, R2](smaller: CoGroupable[K, R1])(fn: (K, Iterator[R], Iterable[R1]) => Iterator[R2]): CoGrouped[K, R2] = { val self = this val leftSeqCount = self.inputs.size - 1 + val jf = joinFunction // avoid capturing `this` in the closure below + val smallerJf = smaller.joinFunction new CoGrouped[K, R2] { val inputs = self.inputs ++ smaller.inputs val reducers = (self.reducers.toIterable ++ smaller.reducers.toIterable).reduceOption(_ max _) def keyOrdering = smaller.keyOrdering + /** + * Avoid capturing anything below as it will need to be serialized and sent to + * all the reducers. + */ def joinFunction = { (k: K, leftMost: Iterator[CTuple], joins: Seq[Iterable[CTuple]]) => - val joinedLeft = self.joinFunction(k, leftMost, joins.take(leftSeqCount)) - - val smallerIns = joins.drop(leftSeqCount) + val (leftSeq, rightSeq) = joins.splitAt(leftSeqCount) + val joinedLeft = jf(k, leftMost, leftSeq) + + // Only do this once, for all calls to iterator below + val smallerHead = rightSeq.head + val smallerTail = rightSeq.tail + // TODO: it might make sense to cache this in memory as an IndexedSeq and not + // recompute it on every value for the left if the smallerJf is non-trivial + // we could see how long it is, and possible switch to a cached version the + // second time through if it is small enough val joinedRight = new Iterable[R1] { - def iterator = smaller.joinFunction(k, smallerIns.head.iterator, smallerIns.tail) + def iterator = smallerJf(k, smallerHead.iterator, smallerTail) } fn(k, joinedLeft, joinedRight)