Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion sidemantic/sql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,17 +242,22 @@ def metric_needs_window(m):
except Exception:
pass

# Extract columns needed for metric-level filters (before building CTEs)
metric_filter_cols_by_model = self._extract_metric_filter_columns(metrics)

# Build CTEs for all models with pushed-down filters
cte_sqls = []
for model_name in all_models:
model_filters = pushdown_filters.get(model_name, [])
metric_filter_cols = metric_filter_cols_by_model.get(model_name)
cte_sql = self._build_model_cte(
model_name,
parsed_dims,
metrics,
model_filters if model_filters else None,
order_by=order_by,
all_models=all_models,
metric_filter_columns=metric_filter_cols,
)
cte_sqls.append(cte_sql)

Expand Down Expand Up @@ -472,12 +477,70 @@ def _classify_filters_for_pushdown(

return pushdown_filters, main_query_filters

def _extract_metric_filter_columns(self, metrics: list[str]) -> dict[str, set[str]]:
    """Extract columns referenced in metric-level filters.

    Each filter is parsed after substituting the ``{model}`` placeholder with
    the model's CTE alias (``<model>_cte``). Only columns qualified with that
    alias, or with the bare model name, are collected; unqualified columns
    are ignored.

    Args:
        metrics: List of metric references (e.g., ["orders.revenue", "bookings.gross_value"]).
            A reference without a dot is looked up as a graph-level metric.

    Returns:
        Dict mapping model_name -> set of column names needed for metric filters.
        A model gets an entry (possibly an empty set) as soon as one of its
        requested metrics has filters.
    """
    columns_by_model: dict[str, set[str]] = {}

    def _collect(model_name: str, filters: list[str]) -> None:
        """Parse ``filters`` and record the columns that belong to ``model_name``."""
        cte_alias = f"{model_name}_cte"
        columns = columns_by_model.setdefault(model_name, set())
        for filter_sql in filters:
            try:
                parsed = sqlglot.parse_one(filter_sql.replace("{model}", cte_alias), dialect=self.dialect)
            except Exception:
                # Best effort: a filter that cannot be parsed here simply
                # contributes no columns instead of aborting SQL generation.
                continue
            # Compare against the exact qualifier names. Stripping every "_cte"
            # substring from the qualifier would never match a model whose own
            # name contains "_cte" and could falsely match unrelated qualifiers.
            columns.update(
                col.name
                for col in parsed.find_all(exp.Column)
                if col.table and col.table in (model_name, cte_alias)
            )

    for metric_ref in metrics:
        if "." in metric_ref:
            # model.measure format
            model_name, measure_name = metric_ref.split(".")
            model = self.graph.get_model(model_name)
            measure = model.get_metric(measure_name) if model else None
            if measure and measure.filters:
                _collect(model_name, measure.filters)
            continue

        # Just a metric name - try a graph-level metric
        try:
            metric = self.graph.get_metric(metric_ref)
            if not (metric and metric.filters):
                continue
            # Only the first model-qualified dependency decides which model
            # the filters are attributed to.
            first_dep = next((dep for dep in metric.get_dependencies(self.graph) if "." in dep), None)
        except KeyError:
            # A KeyError from the lookups above means no columns for this metric.
            continue

        if first_dep is not None:
            _collect(first_dep.split(".")[0], metric.filters)

    return columns_by_model

def _find_needed_dimensions(
self,
model_name: str,
dimensions: list[tuple[str, str | None]],
filters: list[str] | None,
order_by: list[str] | None,
metric_filter_columns: set[str] | None = None,
) -> set[str]:
"""Find which dimensions from this model are actually needed.

Expand All @@ -486,6 +549,7 @@ def _find_needed_dimensions(
dimensions: Parsed dimension references from query
filters: Filter expressions
order_by: Order by fields
metric_filter_columns: Columns needed for metric-level filters

Returns:
Set of dimension names needed for this model
Expand Down Expand Up @@ -519,6 +583,10 @@ def _find_needed_dimensions(
if model_part == model_name:
needed.add(dim_part)

# Columns needed for metric-level filters
if metric_filter_columns:
needed.update(metric_filter_columns)

return needed

def _build_model_cte(
Expand All @@ -529,6 +597,7 @@ def _build_model_cte(
filters: list[str] | None = None,
order_by: list[str] | None = None,
all_models: set[str] | None = None,
metric_filter_columns: set[str] | None = None,
) -> str:
"""Build CTE SQL for a model with optional filter pushdown.

Expand All @@ -539,6 +608,7 @@ def _build_model_cte(
filters: Filters to push down into this CTE (optional)
order_by: Order by fields (for determining needed dimensions)
all_models: All models in query (for determining if joins needed)
metric_filter_columns: Columns needed for metric-level filters

Returns:
CTE SQL string
Expand All @@ -548,7 +618,9 @@ def _build_model_cte(
needs_joins = len(all_models) > 1

# Find which dimensions are actually needed
needed_dimensions = self._find_needed_dimensions(model_name, dimensions, filters, order_by)
needed_dimensions = self._find_needed_dimensions(
model_name, dimensions, filters, order_by, metric_filter_columns
)

# Build SELECT columns
select_cols = []
Expand Down
84 changes: 84 additions & 0 deletions tests/metrics/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,87 @@ def test_metric_filter_with_time_dimension(layer):
# Should contain both filters
assert "orders_cte.status = 'completed'" in sql
assert "CURRENT_DATE - 30" in sql or "CURRENT_DATE-30" in sql # SQLGlot might format differently


def test_metric_filter_column_not_in_query_dimensions(layer):
    """Metric filter columns must reach the CTE even when they are not queried dimensions.

    Regression test: a metric-level filter such as
        filters: ["{model}.state IN ('confirmed', 'completed', 'cancelled')"]
    used to fail because 'state' was not added to the CTE SELECT list
    unless the query explicitly requested 'state' as a dimension.
    """
    state_filter = "{model}.state IN ('confirmed', 'completed', 'cancelled')"
    layer.add_model(
        Model(
            name="bookings",
            table="wide_bookings",
            primary_key="booking_id",
            dimensions=[
                Dimension(name="state", type="categorical"),
                Dimension(name="region", type="categorical"),
            ],
            metrics=[
                Metric(
                    name="gross_booking_value",
                    agg="sum",
                    sql="gross_booking_value",
                    filters=[state_filter],
                ),
            ],
        )
    )

    # Only 'region' is requested; 'state' is needed solely by the metric filter.
    sql = layer.compile(
        metrics=["bookings.gross_booking_value"],
        dimensions=["bookings.region"],
    )

    print("Generated SQL:")
    print(sql)

    # Everything before the first FROM is the CTE's SELECT list.
    select_part = sql.partition("FROM")[0]
    assert "state" in select_part, f"'state' column should be in CTE SELECT for filter to work. CTE: {select_part}"

    # The filter itself must be present; the value order might vary.
    accepted_filters = (
        "state IN ('confirmed', 'completed', 'cancelled')",
        "state IN ('cancelled', 'completed', 'confirmed')",
    )
    assert any(candidate in sql for candidate in accepted_filters)


def test_metric_filter_multiple_columns_not_in_dimensions(layer):
    """Every column used by a metric's filters is selected in the CTE, even if not queried."""
    dimension_names = ("status", "payment_method", "region")
    layer.add_model(
        Model(
            name="orders",
            table="orders_table",
            primary_key="order_id",
            dimensions=[Dimension(name=dim_name, type="categorical") for dim_name in dimension_names],
            metrics=[
                Metric(
                    name="card_completed_revenue",
                    agg="sum",
                    sql="amount",
                    filters=[
                        "{model}.status = 'completed'",
                        "{model}.payment_method IN ('visa', 'mastercard')",
                    ],
                ),
            ],
        )
    )

    # Only 'region' is requested; status and payment_method come from the filters.
    sql = layer.compile(
        metrics=["orders.card_completed_revenue"],
        dimensions=["orders.region"],
    )

    print("Generated SQL:")
    print(sql)

    # Everything before the first FROM is the CTE's SELECT list.
    select_part = sql.partition("FROM")[0]
    assert "status" in select_part, f"'status' should be in CTE SELECT. CTE: {select_part}"
    assert "payment_method" in select_part, f"'payment_method' should be in CTE SELECT. CTE: {select_part}"