
Commit

Expose state in summarize
Li Jin authored and icexelloss committed Jul 2, 2018
1 parent 42ef0ec commit 5b49da1
Showing 5 changed files with 96 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/tests/ts/flint/test_dataframe.py
@@ -2402,4 +2402,4 @@ def foo(df):
tests_utils.assert_same(result5, expected5)

def test_preview(tests_utils, price):
tests_utils.assert_same(price.limit(10).toPandas(), price.preview())
tests_utils.assert_same(price.limit(10).toPandas(), price.preview())
15 changes: 15 additions & 0 deletions python/ts/flint/dataframe.py
@@ -1454,6 +1454,21 @@ def summarize(self, summarizer, key=None):
tsrdd = self.timeSeriesRDD.summarize(composed_summarizer._jsummarizer(self._sc), scala_key)
return TimeSeriesDataFrame._from_tsrdd(tsrdd, self.sql_ctx)

    def summarizeState(self, summarizer, key=None):
        """
        Undocumented function for the bravest.
        Returns a Java map from key to summarize state (also a Java object).
        This function can be changed/removed/broken without notice.
        Use at your own risk.
        """
        scala_key = utils.list_to_seq(self._sc, key)
        composed_summarizer = summarizers.compose(self._sc, summarizer)
        with traceback_utils.SCCallSiteSync(self._sc) as css:
            result = self.timeSeriesRDD.summarizeState(composed_summarizer._jsummarizer(self._sc), scala_key)
        return result

def addSummaryColumns(self, summarizer, key=None):
"""Computes the running aggregate statistics of a table. For a
given row R, the new columns will be the summarization of all
Expand Down
10 changes: 10 additions & 0 deletions src/main/scala/com/twosigma/flint/rdd/OrderedRDD.scala
@@ -426,6 +426,16 @@ class OrderedRDD[K: ClassTag, V: ClassTag](
depth: Int
): Map[SK, V2] = Summarize(self, summarizer, windowFn, skFn, depth)

/**
* Summarize and return both state and output
*/
def summarizeState[SK, U, V2](
summarizer: OverlappableSummarizer[V, U, V2],
windowFn: K => (K, K),
skFn: V => SK,
depth: Int
): Map[SK, (U, V2)] = Summarize.summarizeState(self, summarizer, windowFn, skFn, depth)

/**
* Similar to [[org.apache.spark.rdd.RDD.zipWithIndex]], it zips values of this [[OrderedRDD]] with
* its element indices. The ordering of this [[OrderedRDD]] will be preserved.
@@ -122,28 +122,15 @@ protected[flint] object Summarize {
}

/**
* Apply an [[OverlappableSummarizer]] to an [[OrderedRDD]].
*
* @param rdd An [[OrderedRDD]] of tuples (K, V)
* @param summarizer An [[OverlappableSummarizer]] expected to be applied
* @param windowFn A function expected to expand the range of a partition.
* Consider a partition of `rdd` with a range [b, e). The function expands
* the range to [b1, e1) where b1 is the left of windowFn(b) and e1 is the right
* of windowFn(e). The `summarizer` will be applied to an expanded partition
* that includes all rows falling into [b1, e1).
* @param skFn A function that extracts the secondary keys from V such that the summarizer will be
* applied per secondary key level in the order of K.
* @param depth The depth of tree for merging partial summarized results across different partitions
* in a multi-level tree aggregation fashion.
* @return the summarized results.
* Return the unrendered state of summarization.
*/
def apply[K: ClassTag: Ordering, SK, V: ClassTag, U, V2](
def summarizeStateInternal[K: ClassTag: Ordering, SK, V: ClassTag, U, V2](
rdd: OrderedRDD[K, V],
summarizer: OverlappableSummarizer[V, U, V2],
windowFn: K => (K, K),
skFn: V => SK,
depth: Int
): Map[SK, V2] = {
): Map[SK, U] = {
if (rdd.getNumPartitions == 0) {
Map.empty
} else {
@@ -167,10 +154,50 @@
(sk, summarizer.merge(u1.getOrElse(sk, summarizer.zero()), u2.getOrElse(sk, summarizer.zero())))
}

TreeReduce(partiallySummarized)(mergeOp, depth).map {
case (sk, v) => (sk, summarizer.render(v))
}
TreeReduce(partiallySummarized)(mergeOp, depth)
}
}

/**
* Apply an [[OverlappableSummarizer]] to an [[OrderedRDD]].
*
* @param rdd An [[OrderedRDD]] of tuples (K, V)
* @param summarizer An [[OverlappableSummarizer]] expected to be applied
* @param windowFn A function expected to expand the range of a partition.
* Consider a partition of `rdd` with a range [b, e). The function expands
* the range to [b1, e1) where b1 is the left of windowFn(b) and e1 is the right
* of windowFn(e). The `summarizer` will be applied to an expanded partition
* that includes all rows falling into [b1, e1).
* @param skFn A function that extracts the secondary keys from V such that the summarizer will be
* applied per secondary key level in the order of K.
* @param depth The depth of tree for merging partial summarized results across different partitions
* in a multi-level tree aggregation fashion.
* @return the summarized results.
*/
def apply[K: ClassTag: Ordering, SK, V: ClassTag, U, V2](
rdd: OrderedRDD[K, V],
summarizer: OverlappableSummarizer[V, U, V2],
windowFn: K => (K, K),
skFn: V => SK,
depth: Int
): Map[SK, V2] = {
summarizeStateInternal(rdd, summarizer, windowFn, skFn, depth).map {
case (sk, v) => (sk, summarizer.render(v))
}
}

/**
* Return both state and rendered output of summarization.
*/
def summarizeState[K: ClassTag: Ordering, SK, V: ClassTag, U, V2](
rdd: OrderedRDD[K, V],
summarizer: OverlappableSummarizer[V, U, V2],
windowFn: K => (K, K),
skFn: V => SK,
depth: Int
): Map[SK, (U, V2)] = {
summarizeStateInternal(rdd, summarizer, windowFn, skFn, depth).map {
case (sk, v) => (sk, (v, summarizer.render(v)))
}
}
}
24 changes: 24 additions & 0 deletions src/main/scala/com/twosigma/flint/timeseries/TimeSeriesRDD.scala
@@ -1163,6 +1163,16 @@ trait TimeSeriesRDD extends Serializable {
*/
def summarize(summarizer: SummarizerFactory, key: Seq[String] = Seq.empty): TimeSeriesRDD

/**
* Undocumented function for the bravest.
*
* Returns a Java map from key to summarize state (also a Java object).
* This function can be changed/removed/broken without notice.
*
* Use at your own risk.
*/
def summarizeState(summarizerFactory: SummarizerFactory, key: Seq[String] = Seq.empty): Map[Seq[Any], (Any, Any)]

private[flint] def summarize(summarizer: SummarizerFactory, key: String): TimeSeriesRDD =
summarize(summarizer, Option(key).toSeq)

@@ -1766,6 +1776,20 @@ class TimeSeriesRDDImpl private[timeseries] (
def summarize(summarizerFactory: SummarizerFactory, key: Seq[String] = Seq.empty): TimeSeriesRDD =
summarizeInternal(summarizerFactory, key, 2)

def summarizeState(summarizerFactory: SummarizerFactory, key: Seq[String] = Seq.empty): Map[Seq[Any], (Any, Any)] = {
val depth = 2
val pruned = TimeSeriesRDD.pruneColumns(this, summarizerFactory.requiredColumns, key)
val summarizer = summarizerFactory(pruned.schema)
val keyGetter = pruned.safeGetAsAny(key)
val summarized = summarizerFactory match {
case factory: OverlappableSummarizerFactory =>
pruned.orderedRdd.summarizeState(
summarizer.asInstanceOf[OverlappableSummarizer], factory.window.of, keyGetter, depth
)
}
summarized
}

def addSummaryColumns(summarizer: SummarizerFactory, key: Seq[String] = Seq.empty): TimeSeriesRDD = {
val sum = summarizer(schema)
val reductionsRdd = orderedRdd.summarizations(sum, safeGetAsAny(key))
