Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 4 additions & 38 deletions sidemantic/adapters/rill.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from pathlib import Path
from typing import Any

import sqlglot
import yaml
from sqlglot import expressions as exp

from sidemantic.core.dimension import Dimension
from sidemantic.core.metric import Metric
Expand Down Expand Up @@ -203,46 +201,14 @@ def _parse_measure(self, measure_def: dict[str, Any]) -> Metric | None:
# "simple" = basic aggregation (None type), "derived" = calculation using other measures
metric_type = "derived"

# Use sqlglot to detect simple aggregations
agg_type = None
agg_sql = None
try:
parsed = sqlglot.parse_one(expression, read="duckdb")

# Check if this is a simple aggregation function
if isinstance(parsed, (exp.Sum, exp.Avg, exp.Count, exp.Min, exp.Max)):
# Map sqlglot aggregation types to Sidemantic agg types
if isinstance(parsed, exp.Sum):
agg_type = "sum"
elif isinstance(parsed, exp.Avg):
agg_type = "avg"
elif isinstance(parsed, exp.Count):
if parsed.args.get("distinct"):
agg_type = "count_distinct"
else:
agg_type = "count"
elif isinstance(parsed, exp.Min):
agg_type = "min"
elif isinstance(parsed, exp.Max):
agg_type = "max"

# Extract the aggregated column/expression
agg_arg = parsed.this
if agg_arg:
agg_sql = agg_arg.sql(dialect="duckdb")
elif isinstance(parsed, exp.Count):
# COUNT(*) case
agg_sql = None
except Exception:
# If parsing fails, treat as custom SQL expression
pass

# Let the Metric class handle aggregation parsing via its model_validator.
# This properly handles complex expressions like SUM(x) / SUM(y) and
# COUNT(DISTINCT col) using sqlglot.
return Metric(
name=name,
label=label,
description=description,
agg=agg_type,
sql=agg_sql if agg_type else expression,
sql=expression, # Pass full expression, Metric will parse aggregations
type=metric_type,
value_format_name=value_format_name,
window_order=window_order,
Expand Down
14 changes: 14 additions & 0 deletions sidemantic/core/dependency_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,20 @@ def extract_metric_dependencies(metric_obj, graph=None, model_context=None) -> s
deps.add(metric_obj.sql)
return deps

# Check if this is an expression metric with inline aggregations
# (e.g., SUM(x) / SUM(y), COUNT(DISTINCT col) * 1.0)
# These don't have measure dependencies - the aggregations are inline
try:
parsed = sqlglot.parse_one(metric_obj.sql)
# Check if the expression contains any aggregation functions
agg_types = (exp.Sum, exp.Avg, exp.Count, exp.Min, exp.Max, exp.Median)
has_inline_agg = any(parsed.find_all(*agg_types))
if has_inline_agg and not metric_obj.type:
# Expression metric with inline aggregations - no measure dependencies
return deps
except Exception:
pass

# Extract column references from expression
refs = extract_column_references(metric_obj.sql)

Expand Down
76 changes: 58 additions & 18 deletions sidemantic/core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ def handle_expr_and_parse_agg(cls, data):

1. Converts expr= to sql= for backwards compatibility
2. Parses aggregation functions from SQL (e.g., SUM(amount) -> agg=sum, sql=amount)
"""
import re

Uses sqlglot to properly parse expressions and handle nested parentheses.
Only extracts aggregation from SIMPLE expressions (single aggregation function).
Complex expressions like SUM(x) / SUM(y) are preserved as-is.
"""
if isinstance(data, dict):
# Step 1: Handle expr alias
expr_val = data.get("expr")
Expand All @@ -72,23 +74,61 @@ def handle_expr_and_parse_agg(cls, data):
# Parse if sql is provided and agg is not set
# Allow parsing for simple metrics (no type) OR cumulative metrics (to support AVG/COUNT windows)
if sql_val and not agg_val and (not type_val or type_val == "cumulative"):
# Match aggregation functions at the start: SUM(expr), COUNT(expr), etc.
agg_pattern = r"^\s*(SUM|COUNT|AVG|MIN|MAX|MEDIAN|COUNT_DISTINCT)\s*\((.*)\)\s*$"
match = re.match(agg_pattern, sql_val, re.IGNORECASE)

if match:
agg_func = match.group(1).lower()
inner_expr = match.group(2).strip()

# Extract DISTINCT for COUNT(DISTINCT col)
if agg_func == "count":
distinct_match = re.match(r"^\s*DISTINCT\s+(.+)$", inner_expr, re.IGNORECASE)
if distinct_match:
try:
import sqlglot
from sqlglot import expressions as exp

parsed = sqlglot.parse_one(sql_val, read="duckdb")

# Only extract if the TOP-LEVEL expression is a simple aggregation
# This prevents breaking expressions like SUM(x) / SUM(y)
agg_map = {
exp.Sum: "sum",
exp.Avg: "avg",
exp.Min: "min",
exp.Max: "max",
exp.Median: "median",
}

agg_func = None
inner_expr = None

# Check for standard aggregations
for agg_class, agg_name in agg_map.items():
if isinstance(parsed, agg_class):
agg_func = agg_name
if parsed.this:
inner_expr = parsed.this.sql(dialect="duckdb")
break

# Handle COUNT specially (need to detect DISTINCT)
if isinstance(parsed, exp.Count):
# Check if the argument is a Distinct expression
if isinstance(parsed.this, exp.Distinct):
agg_func = "count_distinct"
inner_expr = distinct_match.group(1).strip()

data["agg"] = agg_func
data["sql"] = inner_expr
# Extract all expressions from inside Distinct
# e.g., COUNT(DISTINCT a, b) -> "a, b"
if parsed.this.expressions:
inner_expr = ", ".join(e.sql(dialect="duckdb") for e in parsed.this.expressions)
else:
inner_expr = parsed.this.sql(dialect="duckdb")
else:
agg_func = "count"
if parsed.this:
inner_expr = parsed.this.sql(dialect="duckdb")
# COUNT(*) case - inner_expr stays None

if agg_func:
data["agg"] = agg_func
if inner_expr is not None:
data["sql"] = inner_expr
elif agg_func == "count":
# COUNT(*) has no inner expression - reset sql to None
data["sql"] = None

except Exception:
# If sqlglot parsing fails, leave the expression as-is
pass

return data

Expand Down
59 changes: 47 additions & 12 deletions sidemantic/sql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,13 +1398,15 @@ def _build_main_select(
# Complex metric types (derived, ratio) can be built inline
# Note: cumulative, time_comparison, conversion are handled via special query generators
# and won't appear in this code path
if measure.type in ["derived", "ratio"]:
# Also handle "expression metrics" - metrics with inline aggregations like SUM(x)/SUM(y)
is_expression_metric = not measure.type and not measure.agg and measure.sql
if measure.type in ["derived", "ratio"] or is_expression_metric:
# Use complex metric builder
metric_expr = self._build_metric_sql(measure, model_name)
metric_expr = self._wrap_with_fill_nulls(metric_expr, measure)
select_exprs.append(f"{metric_expr} AS {alias}")
elif not measure.agg:
# Complex types that need special handling (shouldn't reach here normally)
# Unknown metric type that needs special handling
raise ValueError(
f"Metric '{measure.name}' with type '{measure.type}' cannot be queried directly. "
f"Use generate() instead of _build_main_select() for this metric type."
Expand Down Expand Up @@ -1736,7 +1738,12 @@ def _build_metric_sql(self, metric, model_context: str | None = None) -> str:

# Check if this is a SQL expression metric (has inline aggregations)
# These metrics already contain complete SQL and shouldn't have dependencies replaced
has_inline_agg = any(agg in formula.upper() for agg in ["COUNT(", "SUM(", "AVG(", "MIN(", "MAX("])
try:
parsed = sqlglot.parse_one(formula, read=self.dialect)
agg_types = (exp.Sum, exp.Avg, exp.Count, exp.Min, exp.Max, exp.Median)
has_inline_agg = any(parsed.find_all(*agg_types))
except Exception:
has_inline_agg = False

if has_inline_agg:
# This is a SQL expression metric with inline aggregations.
Expand Down Expand Up @@ -2009,7 +2016,12 @@ def _generate_with_window_functions(
cumulative_metrics.append(m)
# Add the base measure/metric to base_metrics
if metric.sql:
base_metrics.append(metric.sql)
base_ref = metric.sql
# Qualify unqualified references with the model name
if "." not in base_ref and "." in m:
model_name = m.split(".")[0]
base_ref = f"{model_name}.{base_ref}"
base_metrics.append(base_ref)
elif metric and metric.type == "time_comparison":
# Validate required fields
if not metric.base_metric:
Expand Down Expand Up @@ -2076,7 +2088,16 @@ def _generate_with_window_functions(

# Add cumulative metrics with window functions
for m in cumulative_metrics:
metric = self.graph.get_metric(m)
# Handle both qualified (model.measure) and unqualified references
if "." in m:
model_name, measure_name = m.split(".", 1)
model = self.graph.get_model(model_name)
metric = model.get_metric(measure_name) if model else None
# Use just the measure name as the alias (not model.measure)
metric_alias = measure_name
else:
metric = self.graph.get_metric(m)
metric_alias = m
if not metric or (not metric.sql and not metric.window_expression):
continue

Expand Down Expand Up @@ -2107,7 +2128,7 @@ def _generate_with_window_functions(
if metric.window_expression:
order_col = time_dim
frame = metric.window_frame or "ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"
window_expr = f"{metric.window_expression} OVER (ORDER BY {order_col} {frame}) AS {m}"
window_expr = f"{metric.window_expression} OVER (ORDER BY {order_col} {frame}) AS {metric_alias}"
select_exprs.append(window_expr)
continue

Expand All @@ -2118,8 +2139,22 @@ def _generate_with_window_functions(
# It's a direct measure reference - extract just the measure name
base_alias = base_ref.split(".")[1]
else:
# It's a metric reference - check if it exists and get its underlying measure
base_metric = self.graph.get_metric(base_ref)
# It's an unqualified reference - check model first, then graph-level
base_metric = None
# Get model name from the cumulative metric reference
cum_model_name = m.split(".")[0] if "." in m else None
if cum_model_name:
cum_model = self.graph.get_model(cum_model_name)
if cum_model:
base_metric = cum_model.get_metric(base_ref)

# Fallback to graph-level metric
if not base_metric:
try:
base_metric = self.graph.get_metric(base_ref)
except KeyError:
pass

if base_metric and base_metric.sql:
# Use the underlying measure name
if "." in base_metric.sql:
Expand All @@ -2145,20 +2180,20 @@ def _generate_with_window_functions(
grain = metric.grain_to_date
partition = self._date_trunc(grain, time_dim)

window_expr = f"{agg_func}({base_col}) OVER (PARTITION BY {partition} ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {m}"
window_expr = f"{agg_func}({base_col}) OVER (PARTITION BY {partition} ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {metric_alias}"
elif metric.window:
# Parse window (e.g., "7 days")
window_parts = metric.window.split()
if len(window_parts) == 2:
num, unit = window_parts
# For date-based windows, use RANGE
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} RANGE BETWEEN INTERVAL '{num} {unit}' PRECEDING AND CURRENT ROW) AS {m}"
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} RANGE BETWEEN INTERVAL '{num} {unit}' PRECEDING AND CURRENT ROW) AS {metric_alias}"
else:
# Fallback to rows
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {m}"
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {metric_alias}"
else:
# Running total (unbounded window)
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {m}"
window_expr = f"{agg_func}({base_col}) OVER (ORDER BY {time_dim} ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS {metric_alias}"

select_exprs.append(window_expr)

Expand Down
Loading
Loading