diff --git a/sidemantic/sql/generator.py b/sidemantic/sql/generator.py
index 90c7d7e2..83a14224 100644
--- a/sidemantic/sql/generator.py
+++ b/sidemantic/sql/generator.py
@@ -242,10 +242,14 @@ def metric_needs_window(m):
                 except Exception:
                     pass
 
+        # Extract columns needed for metric-level filters (before building CTEs)
+        metric_filter_cols_by_model = self._extract_metric_filter_columns(metrics)
+
         # Build CTEs for all models with pushed-down filters
         cte_sqls = []
         for model_name in all_models:
             model_filters = pushdown_filters.get(model_name, [])
+            metric_filter_cols = metric_filter_cols_by_model.get(model_name)
             cte_sql = self._build_model_cte(
                 model_name,
                 parsed_dims,
@@ -253,6 +257,7 @@ def metric_needs_window(m):
                 model_filters if model_filters else None,
                 order_by=order_by,
                 all_models=all_models,
+                metric_filter_columns=metric_filter_cols,
             )
             cte_sqls.append(cte_sql)
 
@@ -472,12 +477,76 @@ def _classify_filters_for_pushdown(
 
         return pushdown_filters, main_query_filters
 
+    def _extract_metric_filter_columns(self, metrics: list[str]) -> dict[str, set[str]]:
+        """Extract columns referenced in metric-level filters.
+
+        Only table-qualified columns that resolve to the metric's model
+        (``{model}.col``, ``<model>.col`` or ``<model>_cte.col``) are collected;
+        unqualified columns and filters that fail to parse are skipped. For
+        graph-level metrics the model of the first model-qualified dependency
+        is used.
+
+        Args:
+            metrics: List of metric references (e.g., ["orders.revenue", "bookings.gross_value"])
+
+        Returns:
+            Dict mapping model_name -> set of column names needed for metric filters
+        """
+        columns_by_model: dict[str, set[str]] = {}
+
+        for metric_ref in metrics:
+            if "." in metric_ref:
+                # model.measure format; split once so extra dots cannot break unpacking
+                model_name, measure_name = metric_ref.split(".", 1)
+                model = self.graph.get_model(model_name)
+                if model:
+                    measure = model.get_metric(measure_name)
+                    if measure and measure.filters:
+                        if model_name not in columns_by_model:
+                            columns_by_model[model_name] = set()
+                        for f in measure.filters:
+                            # Replace {model} placeholder for parsing
+                            aliased_filter = f.replace("{model}", f"{model_name}_cte")
+                            try:
+                                parsed = sqlglot.parse_one(aliased_filter, dialect=self.dialect)
+                                for col in parsed.find_all(exp.Column):
+                                    if col.table in (model_name, f"{model_name}_cte"):
+                                        columns_by_model[model_name].add(col.name)
+                            except Exception:
+                                pass  # best-effort: a filter that cannot be parsed adds no columns
+            else:
+                # Just metric name - try graph-level metric
+                try:
+                    metric = self.graph.get_metric(metric_ref)
+                    if metric and metric.filters:
+                        deps = metric.get_dependencies(self.graph)
+                        for dep in deps:
+                            if "." in dep:
+                                dep_model_name = dep.split(".")[0]
+                                if dep_model_name not in columns_by_model:
+                                    columns_by_model[dep_model_name] = set()
+                                for f in metric.filters:
+                                    aliased_filter = f.replace("{model}", f"{dep_model_name}_cte")
+                                    try:
+                                        parsed = sqlglot.parse_one(aliased_filter, dialect=self.dialect)
+                                        for col in parsed.find_all(exp.Column):
+                                            if col.table in (dep_model_name, f"{dep_model_name}_cte"):
+                                                columns_by_model[dep_model_name].add(col.name)
+                                    except Exception:
+                                        pass  # best-effort: a filter that cannot be parsed adds no columns
+                                break  # Only use first dependency's model
+                except KeyError:
+                    pass
+
+        return columns_by_model
+
     def _find_needed_dimensions(
         self,
         model_name: str,
         dimensions: list[tuple[str, str | None]],
         filters: list[str] | None,
         order_by: list[str] | None,
+        metric_filter_columns: set[str] | None = None,
     ) -> set[str]:
         """Find which dimensions from this model are actually needed.
 
@@ -486,6 +555,7 @@ def _find_needed_dimensions(
             dimensions: Parsed dimension references from query
             filters: Filter expressions
             order_by: Order by fields
+            metric_filter_columns: Columns needed for metric-level filters
 
         Returns:
             Set of dimension names needed for this model
@@ -519,6 +589,10 @@ def _find_needed_dimensions(
                     if model_part == model_name:
                         needed.add(dim_part)
 
+        # Columns needed for metric-level filters
+        if metric_filter_columns:
+            needed.update(metric_filter_columns)
+
         return needed
 
     def _build_model_cte(
@@ -529,6 +603,7 @@ def _build_model_cte(
         filters: list[str] | None = None,
         order_by: list[str] | None = None,
         all_models: set[str] | None = None,
+        metric_filter_columns: set[str] | None = None,
     ) -> str:
         """Build CTE SQL for a model with optional filter pushdown.
 
@@ -539,6 +614,7 @@ def _build_model_cte(
             filters: Filters to push down into this CTE (optional)
             order_by: Order by fields (for determining needed dimensions)
             all_models: All models in query (for determining if joins needed)
+            metric_filter_columns: Columns needed for metric-level filters
 
         Returns:
             CTE SQL string
@@ -548,7 +624,9 @@ def _build_model_cte(
         needs_joins = len(all_models) > 1
 
         # Find which dimensions are actually needed
-        needed_dimensions = self._find_needed_dimensions(model_name, dimensions, filters, order_by)
+        needed_dimensions = self._find_needed_dimensions(
+            model_name, dimensions, filters, order_by, metric_filter_columns
+        )
 
         # Build SELECT columns
         select_cols = []
diff --git a/tests/metrics/test_filters.py b/tests/metrics/test_filters.py
index f501d484..b0af1d05 100644
--- a/tests/metrics/test_filters.py
+++ b/tests/metrics/test_filters.py
@@ -159,3 +159,87 @@ def test_metric_filter_with_time_dimension(layer):
     # Should contain both filters
     assert "orders_cte.status = 'completed'" in sql
     assert "CURRENT_DATE - 30" in sql or "CURRENT_DATE-30" in sql  # SQLGlot might format differently
+
+
+def test_metric_filter_column_not_in_query_dimensions(layer):
+    """Test that metric filter columns are included in CTE even when not in query dimensions.
+
+    This is a regression test for the bug where filters like:
+        filters: ["state IN ('confirmed', 'completed')"]
+    would fail because the 'state' column wasn't added to the CTE SELECT list
+    when 'state' wasn't explicitly requested as a dimension in the query.
+    """
+    bookings = Model(
+        name="bookings",
+        table="wide_bookings",
+        primary_key="booking_id",
+        dimensions=[
+            Dimension(name="state", type="categorical"),
+            Dimension(name="region", type="categorical"),
+        ],
+        metrics=[
+            Metric(
+                name="gross_booking_value",
+                agg="sum",
+                sql="gross_booking_value",
+                filters=["{model}.state IN ('confirmed', 'completed', 'cancelled')"],
+            ),
+        ],
+    )
+
+    layer.add_model(bookings)
+
+    # Query the metric WITHOUT including 'state' in dimensions
+    # This should still work because the filter needs 'state' in the CTE
+    sql = layer.compile(metrics=["bookings.gross_booking_value"], dimensions=["bookings.region"])
+
+    print("Generated SQL:")
+    print(sql)
+
+    # The 'state' column must be in the CTE for the filter to work
+    # Check that state appears in the CTE SELECT (before FROM)
+    cte_match = sql.split("FROM")[0]  # Get the CTE SELECT part
+    assert "state" in cte_match, f"'state' column should be in CTE SELECT for filter to work. CTE: {cte_match}"
+
+    # The filter should be in the WHERE clause
+    assert (
+        "state IN ('confirmed', 'completed', 'cancelled')" in sql
+        or "state IN ('cancelled', 'completed', 'confirmed')" in sql
+    )  # Order might vary
+
+
+def test_metric_filter_multiple_columns_not_in_dimensions(layer):
+    """Test multiple filter columns are included in CTE when not in query dimensions."""
+    orders = Model(
+        name="orders",
+        table="orders_table",
+        primary_key="order_id",
+        dimensions=[
+            Dimension(name="status", type="categorical"),
+            Dimension(name="payment_method", type="categorical"),
+            Dimension(name="region", type="categorical"),
+        ],
+        metrics=[
+            Metric(
+                name="card_completed_revenue",
+                agg="sum",
+                sql="amount",
+                filters=[
+                    "{model}.status = 'completed'",
+                    "{model}.payment_method IN ('visa', 'mastercard')",
+                ],
+            ),
+        ],
+    )
+
+    layer.add_model(orders)
+
+    # Query with only 'region' dimension - both status and payment_method need to be in CTE
+    sql = layer.compile(metrics=["orders.card_completed_revenue"], dimensions=["orders.region"])
+
+    print("Generated SQL:")
+    print(sql)
+
+    cte_match = sql.split("FROM")[0]
+    assert "status" in cte_match, f"'status' should be in CTE SELECT. CTE: {cte_match}"
+    assert "payment_method" in cte_match, f"'payment_method' should be in CTE SELECT. CTE: {cte_match}"