Skip to content

Commit

Permalink
Merge 3d5fefa into 1776726
Browse files Browse the repository at this point in the history
  • Loading branch information
bmyerz committed Mar 4, 2016
2 parents 1776726 + 3d5fefa commit b4887f6
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 7 deletions.
5 changes: 3 additions & 2 deletions raco/backends/myria/myria.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from raco.algebra import Shuffle
from raco.algebra import convertcondition
from raco.backends import Language, Algebra
from raco.backends.sql.catalog import SQLCatalog
from raco.backends.sql.catalog import SQLCatalog, PostgresSQLFunctionProvider
from raco.catalog import Catalog
from raco.datastructure.UnionFind import UnionFind
from raco.expression import UnnamedAttributeRef
Expand Down Expand Up @@ -1478,7 +1478,8 @@ def __init__(self, dialect=None, push_grouping=False):
def fire(self, expr):
if isinstance(expr, (algebra.Scan, algebra.ScanTemp)):
return expr
cat = SQLCatalog(push_grouping=self.push_grouping)
cat = SQLCatalog(provider=PostgresSQLFunctionProvider(),
push_grouping=self.push_grouping)
try:
sql_plan = cat.get_sql(expr)
sql_string = sql_plan.compile(dialect=self.dialect)
Expand Down
46 changes: 41 additions & 5 deletions raco/backends/sql/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import raco.scheme as scheme
import raco.types as types
from raco.representation import RepresentationProperties
import abc


type_to_raco = {Integer: types.LONG_TYPE,
Expand All @@ -30,10 +31,44 @@
types.DATETIME_TYPE: DateTime}


class SQLFunctionProvider(object):
"""Interface for translating function names. For Raco functions
not understood by SQLAlchemy, like stdev, we cannot rely
on SQLAlchemy's compiler to translate function
names to the given dialect.
For functions not understood by SQLAlchemy, the SQLAlchemy compiler
just emits them verbatim."""

@abc.abstractmethod
def convert_unary_expr(self, expr, input):
pass


class _DefaultSQLFunctionProvider(SQLFunctionProvider):
def convert_unary_expr(self, expr, input):
# just use the function name without complaining
fname = expr.__class__.__name__.lower()
return getattr(func, fname)(input)


class PostgresSQLFunctionProvider(SQLFunctionProvider):
def convert_unary_expr(self, expr, input):
fname = expr.__class__.__name__.lower()

# replacements
if fname == "stdev":
return func.stddev_samp(input)

# Warning: may create some functions not available in Postgres
return getattr(func, fname)(input)


class SQLCatalog(Catalog):
def __init__(self, engine=None, push_grouping=False):
def __init__(self, engine=None, push_grouping=False,
provider=_DefaultSQLFunctionProvider()):
self.engine = engine
self.push_grouping = push_grouping
self.provider = provider
self.metadata = MetaData()

@staticmethod
Expand Down Expand Up @@ -107,10 +142,11 @@ def _convert_zeroary_expr(self, cols, expr, input_scheme):

def _convert_unary_expr(self, cols, expr, input_scheme):
input = self._convert_expr(cols, expr.input, input_scheme)
if isinstance(expr, expression.MAX):
return func.max(input)
if isinstance(expr, expression.MIN):
return func.min(input)

c = self.provider.convert_unary_expr(expr, input)
if c is not None:
return c

raise NotImplementedError("expression {} to sql".format(type(expr)))

def _convert_binary_expr(self, cols, expr, input_scheme):
Expand Down
49 changes: 49 additions & 0 deletions raco/myrial/optimizer_tests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import collections
import random
import sys
import re

from raco.algebra import *
from raco.expression import NamedAttributeRef as AttRef
from raco.expression import UnnamedAttributeRef as AttIndex
from raco.expression import StateVar
from raco.expression import aggregate

from raco.backends.myria import (
MyriaShuffleConsumer, MyriaShuffleProducer, MyriaHyperShuffleProducer,
Expand Down Expand Up @@ -1125,3 +1127,50 @@ def test_push_half_groupby_into_sql(self):
expected = dict(((k, v), 1) for k, v in temp.items())

self.assertEquals(result, expected)

def _check_aggregate_functions_pushed(
self,
func,
expected,
override=False):
if override:
agg = func
else:
agg = "{func}(r.i)".format(func=func)

query = """
r = scan({part});
t = select r.h, {agg} from r;
store(t, OUTPUT);""".format(part=self.part_key, agg=agg)

lp = self.get_logical_plan(query)
pp = self.logical_to_physical(lp, push_sql=True,
push_sql_grouping=True)

self.assertEquals(self.get_count(pp, MyriaQueryScan), 1)

for op in pp.walk():
if isinstance(op, MyriaQueryScan):
self.assertTrue(re.search(expected, op.sql))

def test_aggregate_AVG_pushed(self):
"""AVG is translated properly for postgresql. This is
a function not in SQLAlchemy"""
self._check_aggregate_functions_pushed(
aggregate.AVG.__name__, 'avg')

def test_aggregate_STDDEV_pushed(self):
"""STDEV is translated properly for postgresql. This is
a function that is named differently in Raco and postgresql"""
self._check_aggregate_functions_pushed(
aggregate.STDEV.__name__, 'stddev_samp')

def test_aggregate_COUNTALL_pushed(self):
"""COUNTALL is translated properly for postgresql. This is
a function that is expressed differently in Raco and postgresql"""

# MyriaL parses count(*) to Raco COUNTALL. And COUNTALL
# should currently (under the no nulls semantics of Raco/Myria)
# translate to COUNT(something)
self._check_aggregate_functions_pushed(
'count(*)', r'count[(][a-zA-Z.]+[)]', True)

0 comments on commit b4887f6

Please sign in to comment.