Skip to content

Commit

Permalink
fix(pat-5034): python - include order-by when computing group-by
Browse files Browse the repository at this point in the history
  • Loading branch information
Ajith Mascarenhas committed Dec 19, 2023
1 parent c06f6a3 commit c0b4f4b
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 67 deletions.
129 changes: 67 additions & 62 deletions static/python/src/backends/dpm_agent/dpm_agent_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import asyncio
import atexit
import base64
import json
import logging
from typing import Dict, List, Union
Expand Down Expand Up @@ -193,6 +190,71 @@ def make_dpm_order_by_expression(ordering) -> DpmAgentQuery.OrderByExpression:
)


def make_dpm_agent_query(query) -> DpmAgentQuery:
"""
Makes a query message from the table expression to send to dpm-agent.
Args:
query: Table expression
Returns:
Query RPC message to send to dpm-agent.
"""
dpm_agent_query = DpmAgentQuery()
dpm_agent_query.id.packageId = query.package_id
dpm_agent_query.clientVersion.CopyFrom(
ClientVersion(
client=ClientVersion.PYTHON,
codeVersion=CODE_VERSION,
datasetVersion=query.dataset_version,
)
)
dpm_agent_query.selectFrom = query.name

filter_expr, selection, ordering, limit_to = (
query.filter_expr,
query.selection,
query.ordering,
query.limit_to,
)

dpm_select_exprs = (
list(map(make_dpm_select_expression, selection)) if selection else None
)
if dpm_select_exprs:
dpm_agent_query.select.extend(dpm_select_exprs)

if filter_expr:
dpm_agent_query.filter.CopyFrom(make_dpm_boolean_expression(filter_expr))

selection = selection or []
ordering = ordering or []
# NB: we cannot use the selection expressions themselves as elements in a
# set as they are not hashable, and their __eq__ magic method has been
# overridden to support their usage in query filter expressions. We use the
# set of field expression names to determine if an order-by expression is
# already present in the selection set.
selection_set = set([x.name for x in selection])
expanded_selection = selection[:]
expanded_selection.extend([x for x, _ in ordering if x.name not in selection_set])
if expanded_selection and any(
isinstance(x, AggregateFieldExpr) for x in expanded_selection
):
grouping = [
x for x in expanded_selection if not isinstance(x, AggregateFieldExpr)
]
if grouping:
dpm_agent_query.groupBy.extend(map(make_dpm_group_by_expression, grouping))

if ordering and len(ordering) > 0:
dpm_agent_query.orderBy.extend(map(make_dpm_order_by_expression, ordering))

if limit_to > 0:
dpm_agent_query.limit = limit_to

return dpm_agent_query


class DpmAgentClient:
"""DpmAgentClient uses a gRPC client to compile and execute queries by using
the `dpm-agent` which routes the queries to the specific source specified in
Expand All @@ -211,63 +273,6 @@ def __init__(
# ValueError: metadata was invalid
self.metadata = [(b"dpm-auth-token", bytes(self.dpm_auth_token, "utf-8"))]

async def make_dpm_agent_query(self, query) -> DpmAgentQuery:
"""
Makes a query message from the table expression to send to dpm-agent.
Args:
query: Table expression
Returns:
Query RPC message to send to dpm-agent.
"""
dpm_agent_query = DpmAgentQuery()
dpm_agent_query.id.packageId = query.package_id
dpm_agent_query.clientVersion.CopyFrom(
ClientVersion(
client=ClientVersion.PYTHON,
codeVersion=CODE_VERSION,
datasetVersion=query.dataset_version,
)
)
dpm_agent_query.selectFrom = query.name

filter_expr, selection, ordering, limit_to = (
query.filter_expr,
query.selection,
query.ordering,
query.limit_to,
)

selections = (
list(map(make_dpm_select_expression, selection)) if selection else None
)
if selections:
dpm_agent_query.select.extend(selections)

if filter_expr:
dpm_agent_query.filter.CopyFrom(make_dpm_boolean_expression(filter_expr))

if selection and any(
isinstance(field_expr, AggregateFieldExpr) for field_expr in selection
):
grouping = filter(
lambda field_expr: not isinstance(field_expr, AggregateFieldExpr),
selection,
)
if grouping:
dpm_agent_query.groupBy.extend(
map(make_dpm_group_by_expression, grouping)
)

if ordering and len(ordering) > 0:
dpm_agent_query.orderBy.extend(map(make_dpm_order_by_expression, ordering))

if limit_to > 0:
dpm_agent_query.limit = limit_to

return dpm_agent_query

async def compile(self, query) -> str:
"""
Compiles table expression using dpm-agent.
Expand All @@ -278,7 +283,7 @@ async def compile(self, query) -> str:
Returns:
Resolves to the compiled query string obtained from dpm-agent, or rejects on error.
"""
dpm_agent_query = await self.make_dpm_agent_query(query)
dpm_agent_query = make_dpm_agent_query(query)
dpm_agent_query.dryRun = True
response = self.client.ExecuteQuery(dpm_agent_query, metadata=self.metadata)
return response.queryString
Expand All @@ -293,7 +298,7 @@ async def execute(self, query) -> List[Dict]:
Returns:
Resolves to the executed query results obtained from dpm-agent, or rejects on error.
"""
dpm_agent_query = await self.make_dpm_agent_query(query)
dpm_agent_query = make_dpm_agent_query(query)
response = self.client.ExecuteQuery(dpm_agent_query, metadata=self.metadata)

try:
Expand Down
4 changes: 1 addition & 3 deletions static/python/src/backends/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import logging
import os
import platform
from enum import Enum
from typing import Optional
from urllib.parse import urlparse

from ..backends.dpm_agent.dpm_agent_client import DpmAgentClient, make_client
from ..backends.dpm_agent.dpm_agent_client import make_client
from .env import get_env
from .interface import Backend

Expand Down
4 changes: 2 additions & 2 deletions static/python/src/field_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(
A binary boolean expression. Can be combined with other boolean expressions
using `and`, `or` methods.
"""
super().__init__(field, alias)
super().__init__(f"({op}({field.name}, {other.name}))", alias)
self.field = field
self.op = op
self.other = other
Expand All @@ -105,7 +105,7 @@ def __init__(self, field: FieldExpr, op: UnaryOperator) -> None:
field: The field expression to perform the unary operation on.
op: The unary operator to apply to the field expression.
"""
super().__init__(("(" + op + "(" + field.name + "))"))
super().__init__(f"({op}({field.name}))")
self.field = field
self.op = op

Expand Down
196 changes: 196 additions & 0 deletions static/python/src/test/dpm_query_builder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
from datetime import date
import pytest

from ..table import Table
from ..field import DateField, Field, StringField
from ..version import CODE_VERSION
from ..backends.dpm_agent.dpm_agent_client import make_dpm_agent_query
from ..backends.dpm_agent.dpm_agent_pb2 import ClientVersion, Query as DpmAgentQuery


@pytest.fixture
def id():
return StringField("id")


@pytest.fixture
def name():
return StringField("name")


@pytest.fixture
def price():
return Field("price")


@pytest.fixture
def created_on():
return DateField("created_on")


@pytest.fixture
def table(id, name, price, created_on):
backend = {}
return Table(
backend=backend,
package_id="pkg-123",
dataset_name="ds=456",
dataset_version="0.1.0",
source="test",
name="testTable",
fields=[id, name, price, created_on],
)


def test_returns_expected_query_message_for_query_with_selections(table):
query = table.select("id", "name").limit(10)
dpm_agent_query = make_dpm_agent_query(query)
assert dpm_agent_query

want = DpmAgentQuery(
id={"packageId": "pkg-123"},
clientVersion={
"client": ClientVersion.Client.PYTHON,
"datasetVersion": "0.1.0",
"codeVersion": "0.1.0",
},
selectFrom="testTable",
select=[
{"argument": {"field": {"fieldName": "id"}}},
{"argument": {"field": {"fieldName": "name"}}},
],
limit=10,
)
assert dpm_agent_query == want


def test_returns_expected_query_message_for_query_with_selections_and_filter(
table, name, created_on
):
query = (
table.select("id", "name")
.filter(name.like("%bah%") & (created_on < date(2023, 1, 1)))
.limit(10)
)
dpm_agent_query = make_dpm_agent_query(query)
assert dpm_agent_query

want = DpmAgentQuery(
id={"packageId": "pkg-123"},
clientVersion={
"client": ClientVersion.Client.PYTHON,
"datasetVersion": "0.1.0",
"codeVersion": "0.1.0",
},
selectFrom="testTable",
select=[
{"argument": {"field": {"fieldName": "id"}}},
{"argument": {"field": {"fieldName": "name"}}},
],
filter={
"op": DpmAgentQuery.BooleanExpression.AND,
"arguments": [
{
"condition": {
"op": DpmAgentQuery.BooleanExpression.LIKE,
"arguments": [
{"field": {"fieldName": "name"}},
{"literal": {"string": "%bah%"}},
],
}
},
{
"condition": {
"op": DpmAgentQuery.BooleanExpression.LT,
"arguments": [
{"field": {"fieldName": "created_on"}},
{"literal": {"string": "2023-01-01"}},
],
}
},
],
},
limit=10,
)
assert dpm_agent_query == want


def test_returns_expected_query_message_for_query_with_selections_filter_aggs_order(
table, name, created_on, price
):
query = (
table.select("id", "name", price.avg().with_alias("avg_price"))
.filter(name.like("%bah%") & (created_on < date(2023, 1, 1)))
.order_by(["avg_price", "DESC"], [created_on, "ASC"])
.limit(10)
)
dpm_agent_query = make_dpm_agent_query(query)
assert dpm_agent_query

want = DpmAgentQuery(
id={"packageId": "pkg-123"},
clientVersion={
"client": ClientVersion.Client.PYTHON,
"datasetVersion": "0.1.0",
"codeVersion": "0.1.0",
},
selectFrom="testTable",
select=[
{"argument": {"field": {"fieldName": "id"}}},
{"argument": {"field": {"fieldName": "name"}}},
{
"argument": {
"aggregate": {
"op": DpmAgentQuery.AggregateExpression.MEAN,
"argument": {"field": {"fieldName": "price"}},
}
},
"alias": "avg_price",
},
],
filter={
"op": DpmAgentQuery.BooleanExpression.AND,
"arguments": [
{
"condition": {
"op": DpmAgentQuery.BooleanExpression.LIKE,
"arguments": [
{"field": {"fieldName": "name"}},
{"literal": {"string": "%bah%"}},
],
}
},
{
"condition": {
"op": DpmAgentQuery.BooleanExpression.LT,
"arguments": [
{"field": {"fieldName": "created_on"}},
{"literal": {"string": "2023-01-01"}},
],
}
},
],
},
groupBy=[
{"field": {"fieldName": "id"}},
{"field": {"fieldName": "name"}},
{"field": {"fieldName": "created_on"}},
],
orderBy=[
{
"argument": {
"aggregate": {
"op": DpmAgentQuery.AggregateExpression.MEAN,
"argument": {"field": {"fieldName": "price"}},
},
},
"direction": DpmAgentQuery.OrderByExpression.DESC,
},
{
"argument": {"field": {"fieldName": "created_on"}},
"direction": DpmAgentQuery.OrderByExpression.ASC,
},
],
limit=10,
)
assert dpm_agent_query == want

0 comments on commit c0b4f4b

Please sign in to comment.