diff --git a/python/tests/ts/flint/test_dataframe.py b/python/tests/ts/flint/test_dataframe.py index 17c36f9b..00fd4c87 100644 --- a/python/tests/ts/flint/test_dataframe.py +++ b/python/tests/ts/flint/test_dataframe.py @@ -2402,4 +2402,4 @@ def foo(df): tests_utils.assert_same(result5, expected5) def test_preview(tests_utils, price): - tests_utils.assert_same(price.limit(10).toPandas(), price.preview()) \ No newline at end of file + tests_utils.assert_same(price.limit(10).toPandas(), price.preview()) diff --git a/python/ts/flint/dataframe.py b/python/ts/flint/dataframe.py index d76427c7..ed124ca0 100644 --- a/python/ts/flint/dataframe.py +++ b/python/ts/flint/dataframe.py @@ -1454,6 +1454,21 @@ def summarize(self, summarizer, key=None): tsrdd = self.timeSeriesRDD.summarize(composed_summarizer._jsummarizer(self._sc), scala_key) return TimeSeriesDataFrame._from_tsrdd(tsrdd, self.sql_ctx) + def summarizeState(self, summarizer, key=None): + """ + Undocumented function for the bravest. + + Returns a Java map from key to summarize state (also Java object). + This function can be changed/removed/broken without notice. + + Use at your own risk. + """ + scala_key = utils.list_to_seq(self._sc, key) + composed_summarizer = summarizers.compose(self._sc, summarizer) + with traceback_utils.SCCallSiteSync(self._sc) as css: + result = self.timeSeriesRDD.summarizeState(composed_summarizer._jsummarizer(self._sc), scala_key) + return result + def addSummaryColumns(self, summarizer, key=None): """Computes the running aggregate statistics of a table. For a given row R, the new columns will be the summarization of all diff --git a/src/main/scala/com/twosigma/flint/rdd/OrderedRDD.scala b/src/main/scala/com/twosigma/flint/rdd/OrderedRDD.scala index b14c97e4..3d92b01b 100644 --- a/src/main/scala/com/twosigma/flint/rdd/OrderedRDD.scala +++ b/src/main/scala/com/twosigma/flint/rdd/OrderedRDD.scala @@ -426,6 +426,16 @@ class OrderedRDD[K: ClassTag, V: ClassTag]( depth: Int ): Map[SK, V2] = Summarize(self, summarizer, windowFn, skFn, depth) + /** + * Summarize and return both state and output + */ + def summarizeState[SK, U, V2]( + summarizer: OverlappableSummarizer[V, U, V2], + windowFn: K => (K, K), + skFn: V => SK, + depth: Int + ): Map[SK, (U, V2)] = Summarize.summarizeState(self, summarizer, windowFn, skFn, depth) + /** * Similar to [[org.apache.spark.rdd.RDD.zipWithIndex]], it zips values of this [[OrderedRDD]] with * its element indices. The ordering of this [[OrderedRDD]] will be preserved. diff --git a/src/main/scala/com/twosigma/flint/rdd/function/summarize/Summarize.scala b/src/main/scala/com/twosigma/flint/rdd/function/summarize/Summarize.scala index 3ca7592b..d4d29b65 100644 --- a/src/main/scala/com/twosigma/flint/rdd/function/summarize/Summarize.scala +++ b/src/main/scala/com/twosigma/flint/rdd/function/summarize/Summarize.scala @@ -122,28 +122,15 @@ protected[flint] object Summarize { } /** - * Apply an [[OverlappableSummarizer]] to an [[OrderedRDD]]. - * - * @param rdd An [[OrderedRDD]] of tuples (K, V) - * @param summarizer An [[OverlappableSummarizer]] expected to apply - * @param windowFn A function expected to expand the range of a partition. - * Consider a partition of `rdd` with a range [b, e). The function expands - * the range to [b1, e1) where b1 is the left windowFn(b) and e1 is the right - * of windowFn(e). The `summarizer` will be applied to an expanded partition - * that includes all rows failing into [b1, e1). - * @param skFn A function that extracts the secondary keys from V such that the summarizer will be - * applied per secondary key level in the order of K. - * @param depth The depth of tree for merging partial summarized results across different partitions - * in a a multi-level tree aggregation fashion. - * @return the summarized results. + * Return the un rendered state of summarization. */ - def apply[K: ClassTag: Ordering, SK, V: ClassTag, U, V2]( + def summarizeStateInternal[K: ClassTag: Ordering, SK, V: ClassTag, U, V2]( rdd: OrderedRDD[K, V], summarizer: OverlappableSummarizer[V, U, V2], windowFn: K => (K, K), skFn: V => SK, depth: Int - ): Map[SK, V2] = { + ): Map[SK, U] = { if (rdd.getNumPartitions == 0) { Map.empty } else { @@ -167,10 +154,50 @@ protected[flint] object Summarize { (sk, summarizer.merge(u1.getOrElse(sk, summarizer.zero()), u2.getOrElse(sk, summarizer.zero()))) } - TreeReduce(partiallySummarized)(mergeOp, depth).map { - case (sk, v) => (sk, summarizer.render(v)) - } + TreeReduce(partiallySummarized)(mergeOp, depth) + } + } + + /** + * Apply an [[OverlappableSummarizer]] to an [[OrderedRDD]]. + * + * @param rdd An [[OrderedRDD]] of tuples (K, V) + * @param summarizer An [[OverlappableSummarizer]] expected to apply + * @param windowFn A function expected to expand the range of a partition. + * Consider a partition of `rdd` with a range [b, e). The function expands + * the range to [b1, e1) where b1 is the left windowFn(b) and e1 is the right + * of windowFn(e). The `summarizer` will be applied to an expanded partition + * that includes all rows failing into [b1, e1). + * @param skFn A function that extracts the secondary keys from V such that the summarizer will be + * applied per secondary key level in the order of K. + * @param depth The depth of tree for merging partial summarized results across different partitions + * in a a multi-level tree aggregation fashion. + * @return the summarized results. + */ + def apply[K: ClassTag: Ordering, SK, V: ClassTag, U, V2]( + rdd: OrderedRDD[K, V], + summarizer: OverlappableSummarizer[V, U, V2], + windowFn: K => (K, K), + skFn: V => SK, + depth: Int + ): Map[SK, V2] = { + summarizeStateInternal(rdd, summarizer, windowFn, skFn, depth).map { + case (sk, v) => (sk, summarizer.render(v)) } } + /** + * Return both state and rendered output of summarization. + */ + def summarizeState[K: ClassTag: Ordering, SK, V: ClassTag, U, V2]( + rdd: OrderedRDD[K, V], + summarizer: OverlappableSummarizer[V, U, V2], + windowFn: K => (K, K), + skFn: V => SK, + depth: Int + ): Map[SK, (U, V2)] = { + summarizeStateInternal(rdd, summarizer, windowFn, skFn, depth).map { + case (sk, v) => (sk, (v, summarizer.render(v))) + } + } } diff --git a/src/main/scala/com/twosigma/flint/timeseries/TimeSeriesRDD.scala b/src/main/scala/com/twosigma/flint/timeseries/TimeSeriesRDD.scala index 4011956e..7cfb8a37 100644 --- a/src/main/scala/com/twosigma/flint/timeseries/TimeSeriesRDD.scala +++ b/src/main/scala/com/twosigma/flint/timeseries/TimeSeriesRDD.scala @@ -1163,6 +1163,16 @@ trait TimeSeriesRDD extends Serializable { */ def summarize(summarizer: SummarizerFactory, key: Seq[String] = Seq.empty): TimeSeriesRDD + /** + * Undocumented function for the bravest. + * + * Returns a Java map from key to summarize state (also Java object). + * This function can be changed/removed/broken without notice. + * + * Use at your own risk. + */ + def summarizeState(summarizerFactory: SummarizerFactory, key: Seq[String] = Seq.empty): Map[Seq[Any], (Any, Any)] + private[flint] def summarize(summarizer: SummarizerFactory, key: String): TimeSeriesRDD = summarize(summarizer, Option(key).toSeq) @@ -1766,6 +1776,20 @@ class TimeSeriesRDDImpl private[timeseries] ( def summarize(summarizerFactory: SummarizerFactory, key: Seq[String] = Seq.empty): TimeSeriesRDD = summarizeInternal(summarizerFactory, key, 2) + def summarizeState(summarizerFactory: SummarizerFactory, key: Seq[String] = Seq.empty): Map[Seq[Any], (Any, Any)] = { + val depth = 2 + val pruned = TimeSeriesRDD.pruneColumns(this, summarizerFactory.requiredColumns, key) + val summarizer = summarizerFactory(pruned.schema) + val keyGetter = pruned.safeGetAsAny(key) + val summarized = summarizerFactory match { + case factory: OverlappableSummarizerFactory => + pruned.orderedRdd.summarizeState( + summarizer.asInstanceOf[OverlappableSummarizer], factory.window.of, keyGetter, depth + ) + } + summarized + } + def addSummaryColumns(summarizer: SummarizerFactory, key: Seq[String] = Seq.empty): TimeSeriesRDD = { val sum = summarizer(schema) val reductionsRdd = orderedRdd.summarizations(sum, safeGetAsAny(key))