From a18444dbd7ccfc05b189dcb2005c85a1048cc8de Mon Sep 17 00:00:00 2001 From: Tim Zhou <5866950+ttzhou@users.noreply.github.com> Date: Fri, 22 Mar 2024 08:14:13 -0400 Subject: [PATCH] Feat: add expressions for CORR, COVAR_SAMP, COVAR_POP (#3193) * Add more Snowflake aggregate functions * MEDIAN -> PERCENTILE_CONT * CORR * COVAR_POP * COVAR_SAMP * Make test more concise; fix lints * Address review comments: shorten class names * Keep forgetting to setup pre-commit --- sqlglot/dataframe/sql/functions.py | 6 ++--- sqlglot/dialects/snowflake.py | 3 +++ sqlglot/expressions.py | 12 +++++++++ tests/dialects/test_snowflake.py | 40 ++++++++++++++++++++++++++++++ 4 files changed, 58 insertions(+), 3 deletions(-) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index ca117e135..b4dd2c6d9 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -356,15 +356,15 @@ def coalesce(*cols: ColumnOrName) -> Column: def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "CORR", col2) + return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2) def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "COVAR_POP", col2) + return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2) def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column: - return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2) + return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2) def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column: diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index c13c15ab2..a138eb003 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -368,6 +368,9 @@ class Parser(parser.Parser): this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1)) ), "LISTAGG": exp.GroupConcat.from_arg_list, + "MEDIAN": lambda args: exp.PercentileCont( + this=seq_get(args, 0), expression=exp.Literal.number(0.5) + ), "NULLIFZERO": _build_if_from_nullifzero, "OBJECT_CONSTRUCT": _build_object_construct, "REGEXP_REPLACE": _build_regexp_replace, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 883ab41ac..6c914f7d5 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5745,6 +5745,10 @@ class Upper(Func): _sql_names = ["UPPER", "UCASE"] +class Corr(Binary, AggFunc): + pass + + class Variance(AggFunc): _sql_names = ["VARIANCE", "VARIANCE_SAMP", "VAR_SAMP"] @@ -5753,6 +5757,14 @@ class VariancePop(AggFunc): _sql_names = ["VARIANCE_POP", "VAR_POP"] +class CovarSamp(Binary, AggFunc): + pass + + +class CovarPop(Binary, AggFunc): + pass + + class Week(Func): arg_types = {"this": True, "mode": False} diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 4d7d97c5a..495159c97 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -435,6 +435,46 @@ def test_snowflake(self): "sqlite": "SELECT MIN(c1), MIN(c2) FROM test", }, ) + for suffix in ( + "", + " OVER ()", + ): + self.validate_all( + f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", + read={ + "snowflake": f"SELECT MEDIAN(x){suffix}", + "postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", + }, + write={ + "": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x NULLS LAST){suffix}", + "duckdb": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", + "postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", + "snowflake": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", + }, + ) + self.validate_all( + f"SELECT MEDIAN(x){suffix}", + write={ + "": f"SELECT PERCENTILE_CONT(x, 0.5){suffix}", + "duckdb": f"SELECT QUANTILE_CONT(x, 0.5){suffix}", + "postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", + "snowflake": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}", + }, + ) + for func in ( + "CORR", + "COVAR_POP", + "COVAR_SAMP", + ): + self.validate_all( + f"SELECT {func}(y, x){suffix}", + write={ + "": f"SELECT {func}(y, x){suffix}", + "duckdb": f"SELECT {func}(y, x){suffix}", + "postgres": f"SELECT {func}(y, x){suffix}", + "snowflake": f"SELECT {func}(y, x){suffix}", + }, + ) self.validate_all( "TO_CHAR(x, y)", read={