Skip to content

Commit

Permalink
Merge pull request #1816 from twitter/oscar/improve_memory_backend
Browse files Browse the repository at this point in the history
Improve the memory backend usability and testing
  • Loading branch information
johnynek committed Feb 20, 2018
2 parents cc25318 + a64332a commit fe12eb6
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ class AtomicBox[T <: AnyRef](init: T) {
def lazySet(t: T): Unit =
ref.lazySet(t)

/**
 * Store `t` in this box by delegating to `ref.set`.
 *
 * NOTE(review): `ref` is declared outside this view; it is presumably an
 * AtomicReference, in which case this is an immediately-visible write
 * (unlike `lazySet`) -- confirm against the class header.
 */
def set(t: T): Unit =
ref.set(t)

/**
 * Put `t` into this box and return the result of `ref.getAndSet(t)`.
 *
 * NOTE(review): assuming `ref` is an AtomicReference, the returned value is
 * the one that was stored immediately before this call -- confirm.
 */
def swap(t: T): T =
ref.getAndSet(t)

/**
* use a pure function to update the state.
* fn may be called more than once
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,104 @@ import scala.util.{Failure, Success, Try}

import Execution.{ ToWrite, Writer }

final class MemoryMode private (srcs: HMap[TypedSource, Iterable], sinks: HMap[TypedSink, ({type A[T]=AtomicBox[Option[Iterable[T]]]})#A]) extends Mode {
final class MemoryMode private (srcs: HMap[TypedSource, MemorySource], sinks: HMap[TypedSink, MemorySink]) extends Mode {

def newWriter(): Writer =
new MemoryWriter(this)

def addSource[T](src: TypedSource[T], ts: Iterable[T]): MemoryMode =
def addSource[T](src: TypedSource[T], ts: MemorySource[T]): MemoryMode =
new MemoryMode(srcs + (src -> ts), sinks)

def addSink[T](sink: TypedSink[T]): MemoryMode =
new MemoryMode(srcs, sinks + (sink -> new AtomicBox[Option[Iterable[T]]](None)))
def addSourceFn[T](src: TypedSource[T])(fn: ConcurrentExecutionContext => Future[Iterator[T]]): MemoryMode =
new MemoryMode(srcs + (src -> MemorySource.Fn(fn)), sinks)

def addSourceIterable[T](src: TypedSource[T], iter: Iterable[T]): MemoryMode =
new MemoryMode(srcs + (src -> MemorySource.FromIterable(iter)), sinks)

def addSink[T](sink: TypedSink[T], msink: MemorySink[T]): MemoryMode =
new MemoryMode(srcs, sinks + (sink -> msink))

/**
* This has a side effect of mutating this MemoryMode
*/
def writeSink[T](t: TypedSink[T], iter: Iterable[T]): Unit =
sinks(t).lazySet(Some(iter))

def readSink[T](t: TypedSink[T]): Option[Iterable[T]] =
sinks.get(t).flatMap(_.get)
/**
 * Hand `iter` to the [[MemorySink]] registered for `t`.
 *
 * @return the future produced by the sink's `write`, or a failed future
 *         when no sink has been registered for `t`
 */
def writeSink[T](t: TypedSink[T], iter: Iterable[T])(implicit ec: ConcurrentExecutionContext): Future[Unit] = {
  // A `def` so the message (and the `take(10)` preview) is only built when
  // the sink is actually missing.
  def missingSink: Future[Unit] =
    Future.failed(new Exception(s"missing sink for $t, with first 10 values to write: ${iter.take(10).toList.toString}..."))

  sinks.get(t).fold(missingSink)(_.write(iter))
}

def readSource[T](t: TypedSource[T]): Option[Iterable[T]] =
srcs.get(t)
/**
 * Read the data wired to the source `t`.
 *
 * @return the future produced by the registered [[MemorySource]]'s `read()`,
 *         or a failed future when nothing has been registered for `t`
 */
def readSource[T](t: TypedSource[T])(implicit ec: ConcurrentExecutionContext): Future[Iterator[T]] =
  srcs
    .get(t)
    .map(_.read())
    .getOrElse(Future.failed(new Exception(s"Source: $t not wired. Please provide an input with MemoryMode.addSource")))
}

object MemoryMode {
def empty: MemoryMode = new MemoryMode(HMap.empty, HMap.empty[TypedSink, ({type A[T]=AtomicBox[Option[Iterable[T]]]})#A])
def empty: MemoryMode = new MemoryMode(HMap.empty, HMap.empty)
}

/**
 * An input that memory mode can read from: each call to `read()` yields a
 * future of an iterator over the source's values.
 */
trait MemorySource[A] {
  /** Produce the values of this source, possibly asynchronously on `ec`. */
  def read()(implicit ec: ConcurrentExecutionContext): Future[Iterator[A]]
}

/** The built-in [[MemorySource]] implementations. */
object MemorySource {
  /**
   * A source backed by an in-memory collection; every `read()` returns an
   * already-completed future holding a fresh iterator over `iter`.
   */
  case class FromIterable[A](iter: Iterable[A]) extends MemorySource[A] {
    def read()(implicit ec: ConcurrentExecutionContext): Future[Iterator[A]] =
      Future.successful(iter.iterator)
  }

  /**
   * A source backed by an arbitrary function; `read()` simply applies `toFn`
   * to the execution context it is given.
   */
  case class Fn[A](toFn: ConcurrentExecutionContext => Future[Iterator[A]]) extends MemorySource[A] {
    def read()(implicit ec: ConcurrentExecutionContext): Future[Iterator[A]] =
      toFn(ec)
  }
}

/**
 * An output that memory mode can write to. `MemoryMode.writeSink` looks up
 * the sink registered for a `TypedSink` and forwards the data to `write`.
 */
trait MemorySink[A] {
/**
 * Accept `data` for this sink.
 *
 * @return a future that reports the outcome of the write
 */
def write(data: Iterable[A])(implicit ec: ConcurrentExecutionContext): Future[Unit]
}

object MemorySink {
  /**
   * A sink that keeps the written data in local memory, from where it can be
   * read out through a future.
   *
   * The sink holds a single promise, so it accepts one write between resets:
   * call `reset()` before each further write (i.e. it only works for a
   * single write per Execution).
   */
  class LocalVar[A] extends MemorySink[A] {
    private[this] val box: AtomicBox[Promise[Iterable[A]]] = new AtomicBox(Promise[Iterable[A]]())

    /**
     * A future that completes when a write arrives. If `reset()` is called
     * before any write happened, this future fails instead.
     */
    def read(): Future[Iterable[A]] = box.get().future

    /**
     * Swap in a fresh promise, making the sink safe for another write, and
     * return what had been written so far.
     *
     * @return `Some(data)` if a write completed the previous promise,
     *         `None` if it was never written (its future is then failed so
     *         readers do not block forever)
     */
    def reset(): Option[Iterable[A]] = {
      val current = box.swap(Promise[Iterable[A]]())
      current.future.value match {
        case Some(Success(res)) =>
          Some(res)
        case Some(Failure(err)) =>
          // `current` has just been removed from the box, and the only place
          // a promise is failed is the branch below.
          throw new IllegalStateException("We should never reach this because, we only complete with failure below", err)
        case None =>
          // A writer that fetched `current` before the swap may still
          // complete it concurrently, so `failure` could throw here. Use
          // tryFailure and, if the writer won the race, return its data.
          if (current.tryFailure(new Exception(s"sink never written to before reset() called $this"))) None
          else current.future.value.flatMap(_.toOption)
      }
    }

    /**
     * Complete the current promise with `data`.
     *
     * @return a future that fails with IllegalStateException when the sink
     *         was already written to since the last reset
     */
    def write(data: Iterable[A])(implicit ec: ConcurrentExecutionContext): Future[Unit] =
      Future {
        // Complete the promise directly instead of via `box.update`: update
        // expects a pure function and may call it more than once, which is
        // not safe for a side-effecting `success`.
        box.get().success(data)
        ()
      }
  }
}

object MemoryPlanner {
Expand Down Expand Up @@ -73,13 +145,13 @@ object MemoryPlanner {
Op.Materialize(this)
}
object Op {
def source[I](i: Iterable[I]): Op[I] = Source(Try(i))
def source[I](i: Iterable[I]): Op[I] = Source(_ => Future.successful(i.iterator))
def empty[I]: Op[I] = source(Nil)

final case class Source[I](input: Try[Iterable[I]]) extends Op[I] {
final case class Source[I](input: ConcurrentExecutionContext => Future[Iterator[I]]) extends Op[I] {

def result(implicit cec: ConcurrentExecutionContext): Future[ArrayBuffer[I]] =
Future.fromTry(input).map(ArrayBuffer.concat(_))
input(cec).map(ArrayBuffer.empty[I] ++= _)
}

// Here we need to make a copy on each result
Expand Down Expand Up @@ -253,7 +325,7 @@ object MemoryPlanner {
}

object Memo {
def empty: Memo = Memo(HMap.empty)
val empty: Memo = Memo(HMap.empty)
}

}
Expand Down Expand Up @@ -390,12 +462,7 @@ class MemoryWriter(mem: MemoryMode) extends Writer {
(m2, Op.Concat(op1, op2))

case SourcePipe(src) =>
(m, Op.Source(
mem.readSource(src) match {
case Some(iter) => Success(iter)
case None => Failure(new Exception(s"Source: $src not wired. Please provide an input with MemoryMode.addSource"))
}
))
(m, Op.Source({ cec => mem.readSource(src)(cec) }))

case slk@SumByLocalKeys(_, _) =>
def sum[K, V](sblk: SumByLocalKeys[K, V]) = {
Expand Down Expand Up @@ -598,16 +665,14 @@ class MemoryWriter(mem: MemoryMode) extends Writer {
state.forced.get(opt) match {
case Some(iterf) =>
val action = () => {
iterf.foreach(mem.writeSink(sink, _))
iterf.map(_ => ())
iterf.flatMap(mem.writeSink(sink, _))
}
(state, action :: acts)
case None =>
val (nextM, op) = plan(state.memo, opt) // linter:disable:UndesirableTypeInference
val action = () => {
val arrayBufferF = op.result
arrayBufferF.foreach { mem.writeSink(sink, _) }
arrayBufferF.map(_ => ())
arrayBufferF.flatMap(mem.writeSink(sink, _))
}
(state.copy(memo = nextM), action :: acts)
}
Expand All @@ -632,11 +697,8 @@ class MemoryWriter(mem: MemoryMode) extends Writer {
case Some(f) => f.map(TypedPipe.from(_))
}

def getSource[A](src: TypedSource[A]): Future[Iterable[A]] =
mem.readSource(src) match {
case Some(iter) => Future.successful(iter)
case None => Future.failed(new Exception(s"Source: $src not connected"))
}
/**
 * Read `src` from the memory mode and materialize the resulting iterator
 * into a List so the values can be traversed more than once.
 */
private def getSource[A](src: TypedSource[A])(implicit cec: ConcurrentExecutionContext): Future[Iterable[A]] =
  for (rows <- mem.readSource(src)) yield rows.toList

/**
* This should only be called after a call to execute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,24 @@ class MemoryTest extends FunSuite with PropertyChecks {
assert(mkv.get.toMap == lkv.get.toMap)
}

/**
 * Evaluate `a` exactly once, measuring how long it takes, and return its
 * result. The measurement is only reported if the println below is enabled.
 */
private def timeit[A](msg: String, a: => A): A = {
  val startNanos = System.nanoTime()
  val result = a
  val elapsedMs = (System.nanoTime() - startNanos) / 1e6
  // uncomment this for some poor version of benchmarking,
  // but scalding in-memory mode seems about 3-100x faster
  //
  // println(s"$msg: $elapsedMs ms")
  result
}

private def sortMatch[A: Ordering](ex: Execution[Iterable[A]]) = {
val mm = MemoryMode.empty

val mkv = ex.waitFor(Config.empty, mm)
val mkv = timeit("scalding", ex.waitFor(Config.empty, mm))

val lkv = ex.waitFor(Config.empty, Local(true))
val lkv = timeit("cascading", ex.waitFor(Config.empty, Local(true)))
assert(mkv.get.toList.sorted == lkv.get.toList.sorted)
}

Expand Down Expand Up @@ -64,4 +76,34 @@ class MemoryTest extends FunSuite with PropertyChecks {
implicit val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = 50)
forAll(genWithIterableSources) { pipe => sortMatch(pipe.toIterableExecution) }
}

test("writing gives the same result as toIterableExecution") {
  import TypedPipeGen.genWithIterableSources
  // we can afford to test a lot more in just memory mode because it is faster than cascading
  implicit val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = 500)
  forAll(genWithIterableSources) { pipe =>
    val sink = new MemorySink.LocalVar[Int]

    val ex1 = pipe.writeExecution(SinkT("my_sink"))
    val ex2 = pipe.toIterableExecution

    val mm = MemoryMode.empty.addSink(SinkT("my_sink"), sink)
    val res1 = ex1.waitFor(Config.empty, mm)
    val res2 = ex2.waitFor(Config.empty, MemoryMode.empty)

    // Force the write result so a failed write execution surfaces with its
    // own error instead of as an opaque empty-Option failure on the sink.
    res1.get

    assert(sink.reset().get.toList.sorted == res2.get.toList.sorted)
  }
}

test("using sources work") {
  // the same input feeds both the job and the locally computed expectation
  val input = 0 to 10000
  val srctag = SourceT[Int]("some_source")
  val mode = MemoryMode.empty.addSourceIterable(srctag, input)

  // sum the inputs grouped by their value mod 31
  val job = TypedPipe.from(srctag).map { i => (i % 31, i) }.sumByKey.toIterableExecution
  val jobRes = job.waitFor(Config.empty, mode)

  val expected = input.groupBy(_ % 31).mapValues(_.sum).toList.sorted
  assert(jobRes.get.toList.sorted == expected)
}
}

0 comments on commit fe12eb6

Please sign in to comment.