Skip to content
This repository was archived by the owner on Aug 29, 2024. It is now read-only.
23 changes: 22 additions & 1 deletion dask_sql/physical/utils/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import dask.dataframe as dd
import pandas as pd

from dask_sql.utils import make_pickable_without_dask_sql, new_temporary_column
from dask_sql.utils import make_pickable_without_dask_sql

try:
import dask_cudf
except ImportError:
dask_cudf = None


def apply_sort(
Expand All @@ -12,6 +17,22 @@ def apply_sort(
sort_ascending: List[bool],
sort_null_first: List[bool],
) -> dd.DataFrame:
# Try fast path for multi-column sorting before falling back to
# sort_partition_func. Tools like dask-cudf have a limited but fast
# multi-column sort implementation. We check if any sorting/null sorting
# is required. If so, we fall back to default sorting implementation
if (
dask_cudf is not None
and isinstance(df, dask_cudf.DataFrame)
and all(sort_ascending)
and not any(sort_null_first)
):
try:
df = df.sort_values(sort_columns, ignore_index=True)
return df.persist()
Copy link
Member

@VibhuJawa VibhuJawa Sep 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. We should not call .persist() on single-partition frames.

  2. Just curious: does .persist() ensure we don't trigger duplicate computations? As IIRC, .sort_values() is not lazy.

I wonder if this is a better pattern

df = df.persist()
df = df.sort_values(sort_columns, ignore_index=True).persist()

Copy link
Member Author

@charlesbluca charlesbluca Sep 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Agreed, I can add in a call to map_partitions in the single partition case.
  2. @quasiben might know better than me the implications of calling persist here; I would assume this is here mostly to match up with the persist call happening in the workaround:

return df.persist()

EDIT:

Just saw your edit - knowing that, it looks like the current pattern should be good (once we account for the single partition case) - should we still opt to persist before running sort_values?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just saw your edit - knowing that, it looks like the current pattern should be good (once we account for the single partition case) - should we still opt to persist before running sort_values?

Testing it again now; will update here. Sorry for the edit and confusion.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So i tested an example workflow with and without persisting first, and persisting before sorting indeed prevents duplicate computation.

Without Persisting (DASK PROFILE):

st = time.time()
with performance_report(filename="sort-without-persist.html"):
    df =  dask_cudf.read_parquet(get_fp("web_sales"),columns= columns).shuffle(['ws_sold_date_sk','ws_ship_date_sk'])
    df = df.sort_values(by=['ws_bill_cdemo_sk'],ignore_index=True).persist()
    df = wait(df);
    del df
et = time.time()
print(f"et -st = {et-st}")
et -st = 23.0989 

With Persisting (DASK PROFILE):

st = time.time()
with performance_report(filename="sort-with-persist.html"):
    df =  dask_cudf.read_parquet(get_fp("web_sales"),columns= columns).shuffle(['ws_sold_date_sk','ws_ship_date_sk'])
    df = df.persist().sort_values(by=['ws_bill_cdemo_sk'],ignore_index=True).persist()
    df = wait(df);
    del df
    
et = time.time()
print(f"et -st = {et-st}")
et -st = 16.24

The trade-off here is memory vs. duplicate computation. I think we might want to think more about this.

I wonder if a version of in-place sorting might prevent some memory overheads.

Anyways, we should think deeply about this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect the persist calls here are due to handling the multi-col sort on CPU. Once pandas-dev/pandas#43881 is resolved and Dask has a native multi-col sort we can probably remove them entirely. @charlesbluca is correct that I was originally intending to match the case when native multi-col sorting is not supported.

I think it's ok to safely remove persist in the initial try state and return the dataframe directly

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed these changes to the original PR:

dask-contrib#229

except ValueError:
pass

# Split the first column. We need to handle this one with set_index
first_sort_column = sort_columns[0]
first_sort_ascending = sort_ascending[0]
Expand Down
8 changes: 8 additions & 0 deletions tests/integration/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from dask.distributed import Client
from pandas.testing import assert_frame_equal

try:
import dask_cudf
except ImportError:
dask_cudf = None


@pytest.fixture()
def timeseries_df(c):
Expand Down Expand Up @@ -117,6 +122,9 @@ def c(
for df_name, df in dfs.items():
dask_df = dd.from_pandas(df, npartitions=3)
c.create_table(df_name, dask_df)
if dask_cudf is not None:
cudf_df = dask_cudf.from_dask_dataframe(dask_df)
c.create_table("cudf_" + df_name, cudf_df)

yield c

Expand Down
30 changes: 30 additions & 0 deletions tests/integration/test_dask_cudf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

pytest.importorskip("dask_cudf")

from cudf.testing._utils import assert_eq


def test_cudf_order_by(c):
    """ORDER BY on a dask-cudf backed table should match a manual sort_values."""
    # Result produced by the SQL-level ordering.
    ordered = c.sql(
        """
    SELECT
        *
    FROM cudf_user_table_1
    ORDER BY user_id
    """
    ).compute()

    # Same table fetched unordered, then sorted client-side for comparison.
    unordered = c.sql(
        """
    SELECT
        *
    FROM cudf_user_table_1
    """
    )
    expected = unordered.sort_values(by="user_id", ignore_index=True).compute()

    assert_eq(ordered, expected)
26 changes: 26 additions & 0 deletions tests/integration/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import pytest
from pandas.testing import assert_frame_equal

try:
import dask_cudf
except ImportError:
dask_cudf = None


def test_schemas(c):
df = c.sql("SHOW SCHEMAS")
Expand Down Expand Up @@ -36,6 +41,27 @@ def test_tables(c):
"string_table",
"datetime_table",
]
if dask_cudf is None
else [
"df_simple",
"cudf_df_simple",
"df",
"cudf_df",
"user_table_1",
"cudf_user_table_1",
"user_table_2",
"cudf_user_table_2",
"long_table",
"cudf_long_table",
"user_table_inf",
"cudf_user_table_inf",
"user_table_nan",
"cudf_user_table_nan",
"string_table",
"cudf_string_table",
"datetime_table",
"cudf_datetime_table",
]
}
)

Expand Down