Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): Implement Predicate Pushdown for Calls #1712

Merged
merged 10 commits into from
Jun 3, 2024
182 changes: 132 additions & 50 deletions weave/trace_server/clickhouse_trace_server_batched.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,31 +269,37 @@ def calls_query_stats(self, req: tsi.CallsQueryStatsReq) -> tsi.CallsQueryStatsR
"""Returns a stats object for the given query. This is useful for counts or other
aggregate statistics that are not directly queryable from the calls themselves.
"""
conditions = []
having_conditions = []
start_event_conditions = []
end_event_conditions = []
param_builder = ParamBuilder()
raw_fields_used = set()

# First, apply the application filter
if req.filter:
filter_conds, fields_used = _process_calls_filter_to_conditions(
filter_to_conditions = _process_calls_filter_to_conditions(
req.filter, param_builder
)
raw_fields_used.update(fields_used)
conditions.extend(filter_conds)
raw_fields_used.update(filter_to_conditions.fields_used)
having_conditions.extend(filter_to_conditions.having_conditions)
start_event_conditions.extend(filter_to_conditions.start_event_conditions)
end_event_conditions.extend(filter_to_conditions.end_event_conditions)

# Next, apply the query filter
if req.query:
query_conds, fields_used = _process_query_to_conditions(
having_query_conds, fields_used = _process_query_to_conditions(
req.query, all_call_select_columns, all_call_json_columns, param_builder
)
raw_fields_used.update(fields_used)
conditions.extend(query_conds)
having_conditions.extend(having_query_conds)

# Perform the query against the database
stats = self._calls_query_stats_raw(
req.project_id,
columns=list(raw_fields_used),
conditions=conditions,
start_event_conditions=start_event_conditions,
end_event_conditions=end_event_conditions,
having_conditions=having_conditions,
parameters=param_builder.get_params(),
)

Expand All @@ -304,27 +310,33 @@ def calls_query_stream(
self, req: tsi.CallsQueryReq
) -> typing.Iterator[tsi.CallSchema]:
"""Returns a stream of calls that match the given query."""
conditions = []
having_conditions = []
start_event_conditions = []
end_event_conditions = []
param_builder = ParamBuilder()

# First, apply the application filter
if req.filter:
(filter_conds, _) = _process_calls_filter_to_conditions(
filter_to_conditions = _process_calls_filter_to_conditions(
req.filter, param_builder
)
conditions.extend(filter_conds)
having_conditions.extend(filter_to_conditions.having_conditions)
start_event_conditions.extend(filter_to_conditions.start_event_conditions)
end_event_conditions.extend(filter_to_conditions.end_event_conditions)

# Next, apply the query filter
if req.query:
query_conds, _ = _process_query_to_conditions(
having_query_conds, _ = _process_query_to_conditions(
req.query, all_call_select_columns, all_call_json_columns, param_builder
)
conditions.extend(query_conds)
having_conditions.extend(having_query_conds)

# Perform the query against the database
ch_call_dicts = self._select_calls_query_raw(
req.project_id,
conditions=conditions,
start_event_conditions=start_event_conditions,
end_event_conditions=end_event_conditions,
having_conditions=having_conditions,
parameters=param_builder.get_params(),
limit=req.limit,
offset=req.offset,
Expand All @@ -350,20 +362,27 @@ def calls_delete(self, req: tsi.CallsDeleteReq) -> tsi.CallsDeleteRes:
f"Cannot delete more than {MAX_DELETE_CALLS_COUNT} calls at once"
)

proj_cond = "project_id = {project_id: String}"
# Note: I think this project condition is redundant
proj_cond = "calls_merged.project_id = {project_id: String}"
proj_params = {"project_id": req.project_id}

# get all parents
parents = self._select_calls_query(
req.project_id,
conditions=[proj_cond, "id IN {ids: Array(String)}"],
start_event_conditions=[
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes should make deletion a bit faster

proj_cond,
"calls_merged.id IN {ids: Array(String)}",
],
parameters=proj_params | {"ids": req.call_ids},
)

# get all calls with trace_ids matching parents
all_calls = self._select_calls_query(
req.project_id,
conditions=[proj_cond, "trace_id IN {trace_ids: Array(String)}"],
start_event_conditions=[
proj_cond,
"calls_merged.trace_id IN {trace_ids: Array(String)}",
],
parameters=proj_params | {"trace_ids": [p.trace_id for p in parents]},
)

Expand Down Expand Up @@ -921,7 +940,7 @@ def _call_read(self, req: tsi.CallReadReq) -> SelectableCHCallSchema:
# Generate and run the query to get the call from the database
ch_calls = self._select_calls_query(
req.project_id,
conditions=["id = {id: String}"],
start_event_conditions=["calls_merged.id = {id: String}"],
limit=1,
parameters={"id": req.id},
)
Expand All @@ -936,7 +955,9 @@ def _select_calls_query(
self,
project_id: str,
columns: typing.Optional[typing.List[str]] = None,
conditions: typing.Optional[typing.List[str]] = None,
start_event_conditions: typing.Optional[typing.List[str]] = None,
end_event_conditions: typing.Optional[typing.List[str]] = None,
having_conditions: typing.Optional[typing.List[str]] = None,
order_by: typing.Optional[typing.List[typing.Tuple[str, str]]] = None,
offset: typing.Optional[int] = None,
limit: typing.Optional[int] = None,
Expand All @@ -945,7 +966,9 @@ def _select_calls_query(
dicts = self._select_calls_query_raw(
project_id,
columns=columns,
conditions=conditions,
start_event_conditions=start_event_conditions,
end_event_conditions=end_event_conditions,
having_conditions=having_conditions,
order_by=order_by,
offset=offset,
limit=limit,
Expand All @@ -960,7 +983,9 @@ def _select_calls_query_raw(
self,
project_id: str,
columns: typing.Optional[typing.List[str]] = None,
conditions: typing.Optional[typing.List[str]] = None,
start_event_conditions: typing.Optional[typing.List[str]] = None,
end_event_conditions: typing.Optional[typing.List[str]] = None,
having_conditions: typing.Optional[typing.List[str]] = None,
order_by: typing.Optional[typing.List[typing.Tuple[str, str]]] = None,
offset: typing.Optional[int] = None,
limit: typing.Optional[int] = None,
Expand Down Expand Up @@ -992,10 +1017,15 @@ def _select_calls_query_raw(
merged_cols.append(f"any({col}) AS {col}")
select_columns_part = ", ".join(merged_cols)

if not conditions:
conditions = ["1 = 1"]
if not having_conditions:
having_conditions = ["1 = 1"]

conditions_part = _combine_conditions(conditions, "AND")
having_conditions_part = _combine_conditions(having_conditions, "AND")

where_conditions_part = _make_calls_where_condition_from_event_conditions(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is one of the new clauses

start_event_conditions=start_event_conditions,
end_event_conditions=end_event_conditions,
)

order_by_part = "ORDER BY started_at ASC"
if order_by is not None:
Expand Down Expand Up @@ -1053,9 +1083,10 @@ def _select_calls_query_raw(
SELECT {select_columns_part}
FROM calls_merged
WHERE project_id = {{project_id: String}}
AND {where_conditions_part}
GROUP BY project_id, id
HAVING deleted_at IS NULL AND
{conditions_part}
{having_conditions_part}
{order_by_part}
{limit_part}
{offset_part}
Expand All @@ -1072,7 +1103,9 @@ def _calls_query_stats_raw(
self,
project_id: str,
columns: typing.Optional[typing.List[str]] = None,
conditions: typing.Optional[typing.List[str]] = None,
start_event_conditions: typing.Optional[typing.List[str]] = None,
end_event_conditions: typing.Optional[typing.List[str]] = None,
having_conditions: typing.Optional[typing.List[str]] = None,
parameters: typing.Optional[typing.Dict[str, typing.Any]] = None,
) -> typing.Dict:
"""Generates and executes a query to get stats for a calls query."""
Expand All @@ -1082,10 +1115,15 @@ def _calls_query_stats_raw(

parameters["project_id"] = project_id

if not conditions:
conditions = ["1 = 1"]
if not having_conditions:
having_conditions = ["1 = 1"]

conditions_part = _combine_conditions(conditions, "AND")
having_conditions_part = _combine_conditions(having_conditions, "AND")

where_conditions_part = _make_calls_where_condition_from_event_conditions(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another new clause

start_event_conditions=start_event_conditions,
end_event_conditions=end_event_conditions,
)

if columns == None:
columns = ["id"]
Expand All @@ -1112,8 +1150,9 @@ def _calls_query_stats_raw(
SELECT {select_columns_part}
FROM calls_merged
WHERE project_id = {{project_id: String}}
AND {where_conditions_part}
GROUP BY project_id, id
HAVING {conditions_part}
HAVING {having_conditions_part}
)
"""

Expand Down Expand Up @@ -1706,13 +1745,22 @@ def _transform_external_field_to_internal_field(
return field, param_builder, raw_fields_used


class FilterToConditions(BaseModel):
    # Result of translating a calls filter into SQL condition fragments,
    # grouped by the part of the calls query each fragment belongs to.

    # Fragments that callers append to the HAVING clause of the grouped query.
    having_conditions: list[str]
    # Fragments that callers forward as `start_event_conditions`; the WHERE
    # builder combines them with `isNotNull(started_at)`.
    start_event_conditions: list[str]
    # Fragments that callers forward as `end_event_conditions`; the WHERE
    # builder combines them with `isNotNull(ended_at)`.
    end_event_conditions: list[str]
    # Names of the raw columns referenced by the generated conditions.
    fields_used: set[str]


def _process_calls_filter_to_conditions(
filter: tsi._CallsFilter,
param_builder: typing.Optional[ParamBuilder] = None,
) -> tuple[list[str], set[str]]:
) -> FilterToConditions:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method has been refactored to return specific conditions where possible, using correct namespacing

"""Converts a CallsFilter to a list of conditions for a clickhouse query."""
param_builder = param_builder or ParamBuilder()
conditions = []
having_conditions: list[str] = []
start_event_conditions: list[str] = []
end_event_conditions: list[str] = []
raw_fields_used = set()

if filter.op_names:
Expand All @@ -1732,69 +1780,75 @@ def _process_calls_filter_to_conditions(

if non_wildcarded_names:
or_conditions.append(
f"op_name IN {_param_slot(param_builder.add_param(non_wildcarded_names), 'Array(String)')}"
f"calls_merged.op_name IN {_param_slot(param_builder.add_param(non_wildcarded_names), 'Array(String)')}"
)
raw_fields_used.add("op_name")

for name in wildcarded_names:
like_name = name[: -len(WILDCARD_ARTIFACT_VERSION_AND_PATH)] + ":%"
or_conditions.append(
f"op_name LIKE {_param_slot(param_builder.add_param(like_name), 'String')}"
f"calls_merged.op_name LIKE {_param_slot(param_builder.add_param(like_name), 'String')}"
)
raw_fields_used.add("op_name")

if or_conditions:
conditions.append(_combine_conditions(or_conditions, "OR"))
start_event_conditions.append(_combine_conditions(or_conditions, "OR"))

if filter.input_refs:
conditions.append(
f"hasAny(input_refs, {_param_slot(param_builder.add_param(filter.input_refs), 'Array(String)')})"
start_event_conditions.append(
f"hasAny(calls_merged.input_refs, {_param_slot(param_builder.add_param(filter.input_refs), 'Array(String)')})"
)
raw_fields_used.add("input_refs")

if filter.output_refs:
conditions.append(
f"hasAny(output_refs, {_param_slot(param_builder.add_param(filter.output_refs), 'Array(String)')})"
end_event_conditions.append(
f"hasAny(calls_merged.output_refs, {_param_slot(param_builder.add_param(filter.output_refs), 'Array(String)')})"
)
raw_fields_used.add("output_refs")

if filter.parent_ids:
conditions.append(
f"parent_id IN {_param_slot(param_builder.add_param(filter.parent_ids), 'Array(String)')}"
start_event_conditions.append(
f"calls_merged.parent_id IN {_param_slot(param_builder.add_param(filter.parent_ids), 'Array(String)')}"
)
raw_fields_used.add("parent_id")

if filter.trace_ids:
conditions.append(
f"trace_id IN {_param_slot(param_builder.add_param(filter.trace_ids), 'Array(String)')}"
start_event_conditions.append(
f"calls_merged.trace_id IN {_param_slot(param_builder.add_param(filter.trace_ids), 'Array(String)')}"
)
raw_fields_used.add("trace_id")

if filter.call_ids:
conditions.append(
f"id IN {_param_slot(param_builder.add_param(filter.call_ids), 'Array(String)')}"
start_event_conditions.append(
f"calls_merged.id IN {_param_slot(param_builder.add_param(filter.call_ids), 'Array(String)')}"
)
raw_fields_used.add("id")

if filter.trace_roots_only:
conditions.append("parent_id IS NULL")
start_event_conditions.append("calls_merged.parent_id IS NULL")
raw_fields_used.add("parent_id")

if filter.wb_user_ids:
conditions.append(
f"wb_user_id IN {_param_slot(param_builder.add_param(filter.wb_user_ids), 'Array(String)')})"
start_event_conditions.append(
f"calls_merged.wb_user_id IN {_param_slot(param_builder.add_param(filter.wb_user_ids), 'Array(String)')})"
)
raw_fields_used.add("wb_user_id")

if filter.wb_run_ids:
conditions.append(
f"wb_run_id IN {_param_slot(param_builder.add_param(filter.wb_run_ids), 'Array(String)')})"
start_event_conditions.append(
f"calls_merged.wb_run_id IN {_param_slot(param_builder.add_param(filter.wb_run_ids), 'Array(String)')})"
)
raw_fields_used.add("wb_run_id")

return conditions, raw_fields_used
return FilterToConditions(
having_conditions=having_conditions,
start_event_conditions=start_event_conditions,
end_event_conditions=end_event_conditions,
fields_used=raw_fields_used,
)


# TODO: Implement predicate pushdown just like `_process_calls_filter_to_conditions`
def _process_query_to_conditions(
query: tsi.Query,
all_columns: typing.Sequence[str],
Expand Down Expand Up @@ -1898,3 +1952,31 @@ def process_operand(operand: tsi_query.Operand) -> str:
conditions.append(filter_cond)

return conditions, raw_fields_used


def _make_calls_where_condition_from_event_conditions(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create an aggregate filter

start_event_conditions: typing.Optional[typing.List[str]] = None,
end_event_conditions: typing.Optional[typing.List[str]] = None,
) -> str:
event_conds = []
if start_event_conditions is not None and len(start_event_conditions) > 0:
conds = _combine_conditions(
["isNotNull(started_at)", *start_event_conditions], "AND"
)
event_conds.append(
f"calls_merged.id IN (SELECT id FROM calls_merged WHERE {conds})"
)

if end_event_conditions is not None and len(end_event_conditions) > 0:
conds = _combine_conditions(
["isNotNull(ended_at)", *end_event_conditions], "AND"
)
event_conds.append(
f"calls_merged.id IN (SELECT id FROM calls_merged WHERE {conds})"
)

where_conditions_part = "(1 = 1)"
if event_conds:
where_conditions_part = _combine_conditions(event_conds, "AND")

return where_conditions_part
Loading