Skip to content

Commit

Permalink
fix(sqlglot): Convert metric syntax back to dialect-specific after pa…
Browse files Browse the repository at this point in the history
…rsing it (#273)
  • Loading branch information
Vitor-Avila committed Mar 20, 2024
1 parent ed91b13 commit 8a459f6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/preset_cli/cli/superset/sync/dbt/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_metric_expression(metric_name: str, metrics: Dict[str, MetricSchema]) ->
)
token.replace(parent_expression)

return expression.sql()
return expression.sql(dialect=metric["dialect"])

sorted_metric = dict(sorted(metric.items()))
raise Exception(f"Unable to generate metric expression from: {sorted_metric}")
Expand Down Expand Up @@ -285,7 +285,7 @@ def convert_query_to_projection(sql: str, dialect: MFSQLEngine) -> str:
)
metric_expression.set("this", case_expression)

return metric_expression.sql()
return metric_expression.sql(dialect=DIALECT_MAP.get(dialect))


def convert_metric_flow_to_superset(
Expand Down
16 changes: 14 additions & 2 deletions tests/cli/superset/sync/dbt/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_get_metric_expression_derived_legacy() -> None:
result = get_metric_expression(unique_id, metrics)
assert (
result
== "SAFE_DIVIDE(SUM(CASE WHEN \"product_line\" = 'Classic Cars' THEN price_each * 0.80 ELSE price_each * 0.70 END), SUM(price_each))"
== "SAFE_DIVIDE(SUM(IF(`product_line` = 'Classic Cars', price_each * 0.80, price_each * 0.70)), SUM(price_each))"
)


Expand Down Expand Up @@ -683,7 +683,19 @@ def test_convert_query_to_projection() -> None:
""",
MFSQLEngine.BIGQUERY,
)
== "CAST(SUM(CASE WHEN is_food_item = 1 THEN product_price ELSE 0 END) AS DOUBLE) / CAST(NULLIF(SUM(product_price), 0) AS DOUBLE)"
== "CAST(SUM(CASE WHEN is_food_item = 1 THEN product_price ELSE 0 END) AS FLOAT64) / CAST(NULLIF(SUM(product_price), 0) AS FLOAT64)"
)

assert (
convert_query_to_projection(
"""
SELECT
AVG(DATE_DIFF(start_date, end_date, DAY)) AS avg_time_diff
FROM `dbt-tutorial-347100`.`dbt_beto`.`order_items` order_item_src_98
""",
MFSQLEngine.BIGQUERY,
)
== "AVG(DATE_DIFF(start_date, end_date, DAY))"
)

with pytest.raises(ValueError) as excinfo:
Expand Down

0 comments on commit 8a459f6

Please sign in to comment.