Skip to content

Commit

Permalink
removed unnecessary aggregation for certain cases of window functions
Browse files (browse the repository at this point in the history)
  • Loading branch information
westandskif committed Jun 30, 2024
1 parent aa3133e commit 8940873
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 32 deletions.
19 changes: 12 additions & 7 deletions src/convtools/_ordering.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,6 +9,7 @@
GetAttr,
GetItem,
InputArg,
LabelConversion,
NaiveConversion,
)
from ._utils import Code
Expand Down Expand Up @@ -43,12 +44,16 @@ class SortingKeyConversion(BaseConversion):

def __init__(self, keys, common_conv=None):
super().__init__()
self.common_conv = (
None
if common_conv is None
else self.ensure_conversion(common_conv)
)
self.keys = [self.ensure_conversion(key) for key in keys]
if common_conv is not None and len(keys) == 1:
self.keys = [self.ensure_conversion(common_conv).pipe(keys[0])]
self.common_conv = None
else:
self.keys = [self.ensure_conversion(key) for key in keys]
self.common_conv = (
None
if common_conv is None
else self.ensure_conversion(common_conv)
)

_any_ordering_hints = (
BaseConversion.OutputHints.ORDERING_NONE_FIRST
Expand Down Expand Up @@ -197,7 +202,7 @@ def __init__(self, key=None, reverse=False):
super().__init__()
self.sorted_kwargs = {}
if key is not None:
if callable(key):
if callable(key) or isinstance(key, LabelConversion):
self.sorted_kwargs["key"] = self.ensure_conversion(key)
else:
self.sorted_kwargs["key"] = self.ensure_conversion(
Expand Down
85 changes: 60 additions & 25 deletions src/convtools/_window.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -279,39 +279,49 @@ def _gen_code_and_update_ctx(self, code_input, ctx):
if self.order_by is None:
if self.partition_by is None:
ordering_preservation_needed = False
c_reduce_clause = (
ReduceFuncs.Array(This, default=NaiveConversion(()))
c_agg = (
If(
CallFunc(isinstance, This, list),
This,
CallFunc(list, This),
)
.pipe(frames_finder)
.iter(c_frame_data_handler)
)
else:
ordering_preservation_needed = True
c_reduce_clause = CallFunc(
zip,
ReduceFuncs.Array(
LabelConversion(self._label_next).call(),
default=NaiveConversion(()),
),
ReduceFuncs.Array(This, default=NaiveConversion(()))
.pipe(frames_finder)
.iter(c_frame_data_handler),
c_agg = (
GroupBy(self.partition_by)
.aggregate(
CallFunc(
zip,
ReduceFuncs.Array(
LabelConversion(self._label_next).call(),
default=NaiveConversion(()),
),
ReduceFuncs.Array(
This, default=NaiveConversion(())
)
.pipe(frames_finder)
.iter(c_frame_data_handler),
)
)
.flatten()
)
else:
elif self.partition_by is None:
ordering_preservation_needed = True
label_sorting_key = self.gen_random_name("sorting_key", ctx)
labels[label_sorting_key] = SortingKeyConversion(
self.order_by, common_conv=GetItem(1)
)
c_reduce_clause = (
ReduceFuncs.Array(
(LabelConversion(self._label_next).call(), This),
default=list,
)
.pipe(
This.call_method(
"sort", key=LabelConversion(label_sorting_key)
).or_(This)
c_agg = (
CallFunc(
zip,
CallFunc(count),
This,
)
.as_type(list)
.sort(key=LabelConversion(label_sorting_key))
.pipe(
CallFunc(
zip,
Expand All @@ -323,12 +333,36 @@ def _gen_code_and_update_ctx(self, code_input, ctx):
)
)
)

if self.partition_by is None:
c_agg = Aggregate(c_reduce_clause)
else:
ordering_preservation_needed = True
label_sorting_key = self.gen_random_name("sorting_key", ctx)
labels[label_sorting_key] = SortingKeyConversion(
self.order_by, common_conv=GetItem(1)
)
c_agg = (
GroupBy(self.partition_by).aggregate(c_reduce_clause).flatten()
GroupBy(self.partition_by)
.aggregate(
ReduceFuncs.Array(
(LabelConversion(self._label_next).call(), This),
default=list,
)
.pipe(
This.call_method(
"sort", key=LabelConversion(label_sorting_key)
).or_(This)
)
.pipe(
CallFunc(
zip,
This.iter(GetItem(0)),
This.iter(GetItem(1))
.as_type(list)
.pipe(frames_finder)
.iter(c_frame_data_handler),
)
)
)
.flatten()
)

conv = self.conv.add_label(labels)
Expand Down Expand Up @@ -757,6 +791,7 @@ class FrameData:


def row_index():
"""Row index within the partition."""
return FrameData.ROW_INDEX


Expand Down

0 comments on commit 8940873

Please sign in to comment.