Skip to content

Commit

Permalink
Feat: add expressions for CORR, COVAR_SAMP, COVAR_POP (#3193)
Browse files Browse the repository at this point in the history
* Add more Snowflake aggregate functions

* MEDIAN -> PERCENTILE_CONT
* CORR
* COVAR_POP
* COVAR_SAMP

* Make test more concise; fix lints

* Address review comments: shorten class names

* Keep forgetting to set up pre-commit
  • Loading branch information
ttzhou committed Mar 22, 2024
1 parent 0dd9ba5 commit a18444d
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 3 deletions.
6 changes: 3 additions & 3 deletions sqlglot/dataframe/sql/functions.py
Expand Up @@ -356,15 +356,15 @@ def coalesce(*cols: ColumnOrName) -> Column:


# Return a Column that applies the CORR aggregate to col1 and col2.
# Diff view (this file reports 3 additions & 3 deletions): the first return,
# which invokes CORR as an anonymous function by name, is the removed line;
# the second return, which builds a typed expression.Corr node with col2
# passed as `expression`, is its replacement.
def corr(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "CORR", col2)
return Column.invoke_expression_over_column(col1, expression.Corr, expression=col2)


# Return a Column that applies the COVAR_POP aggregate to col1 and col2.
# Diff view (this file reports 3 additions & 3 deletions): the first return,
# which invokes COVAR_POP as an anonymous function by name, is the removed
# line; the second return, which builds a typed expression.CovarPop node with
# col2 passed as `expression`, is its replacement.
def covar_pop(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "COVAR_POP", col2)
return Column.invoke_expression_over_column(col1, expression.CovarPop, expression=col2)


# Return a Column that applies the COVAR_SAMP aggregate to col1 and col2.
# Diff view (this file reports 3 additions & 3 deletions): the first return,
# which invokes COVAR_SAMP as an anonymous function by name, is the removed
# line; the second return, which builds a typed expression.CovarSamp node with
# col2 passed as `expression`, is its replacement.
def covar_samp(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "COVAR_SAMP", col2)
return Column.invoke_expression_over_column(col1, expression.CovarSamp, expression=col2)


def first(col: ColumnOrName, ignorenulls: t.Optional[bool] = None) -> Column:
Expand Down
3 changes: 3 additions & 0 deletions sqlglot/dialects/snowflake.py
Expand Up @@ -368,6 +368,9 @@ class Parser(parser.Parser):
this=seq_get(args, 0), unit=_map_date_part(seq_get(args, 1))
),
"LISTAGG": exp.GroupConcat.from_arg_list,
"MEDIAN": lambda args: exp.PercentileCont(
this=seq_get(args, 0), expression=exp.Literal.number(0.5)
),
"NULLIFZERO": _build_if_from_nullifzero,
"OBJECT_CONSTRUCT": _build_object_construct,
"REGEXP_REPLACE": _build_regexp_replace,
Expand Down
12 changes: 12 additions & 0 deletions sqlglot/expressions.py
Expand Up @@ -5745,6 +5745,10 @@ class Upper(Func):
_sql_names = ["UPPER", "UCASE"]


# Expression node for the CORR aggregate, newly added by this commit.
# It inherits from both Binary and AggFunc and defines nothing of its own.
# The dataframe `corr` helper builds it from col1 with col2 as `expression`.
# NOTE(review): the "CORR" SQL name is presumably derived from the class name
# by the Func machinery — confirm against the Func base class.
class Corr(Binary, AggFunc):
pass


class Variance(AggFunc):
_sql_names = ["VARIANCE", "VARIANCE_SAMP", "VAR_SAMP"]

Expand All @@ -5753,6 +5757,14 @@ class VariancePop(AggFunc):
_sql_names = ["VARIANCE_POP", "VAR_POP"]


# Expression node for the COVAR_SAMP aggregate, newly added by this commit.
# It inherits from both Binary and AggFunc and defines nothing of its own.
# The dataframe `covar_samp` helper builds it from col1 with col2 as
# `expression`.
# NOTE(review): the "COVAR_SAMP" SQL name is presumably derived from the class
# name by the Func machinery — confirm against the Func base class.
class CovarSamp(Binary, AggFunc):
pass


# Expression node for the COVAR_POP aggregate, newly added by this commit.
# It inherits from both Binary and AggFunc and defines nothing of its own.
# The dataframe `covar_pop` helper builds it from col1 with col2 as
# `expression`.
# NOTE(review): the "COVAR_POP" SQL name is presumably derived from the class
# name by the Func machinery — confirm against the Func base class.
class CovarPop(Binary, AggFunc):
pass


class Week(Func):
arg_types = {"this": True, "mode": False}

Expand Down
40 changes: 40 additions & 0 deletions tests/dialects/test_snowflake.py
Expand Up @@ -435,6 +435,46 @@ def test_snowflake(self):
"sqlite": "SELECT MIN(c1), MIN(c2) FROM test",
},
)
for suffix in (
"",
" OVER ()",
):
self.validate_all(
f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
read={
"snowflake": f"SELECT MEDIAN(x){suffix}",
"postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
},
write={
"": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x NULLS LAST){suffix}",
"duckdb": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
"postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
"snowflake": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
},
)
self.validate_all(
f"SELECT MEDIAN(x){suffix}",
write={
"": f"SELECT PERCENTILE_CONT(x, 0.5){suffix}",
"duckdb": f"SELECT QUANTILE_CONT(x, 0.5){suffix}",
"postgres": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
"snowflake": f"SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x){suffix}",
},
)
for func in (
"CORR",
"COVAR_POP",
"COVAR_SAMP",
):
self.validate_all(
f"SELECT {func}(y, x){suffix}",
write={
"": f"SELECT {func}(y, x){suffix}",
"duckdb": f"SELECT {func}(y, x){suffix}",
"postgres": f"SELECT {func}(y, x){suffix}",
"snowflake": f"SELECT {func}(y, x){suffix}",
},
)
self.validate_all(
"TO_CHAR(x, y)",
read={
Expand Down

0 comments on commit a18444d

Please sign in to comment.