Skip to content

Commit

Permalink
Merge branch 'master' into myria-unique-aggs
Browse files Browse the repository at this point in the history
  • Loading branch information
dhalperi committed Sep 22, 2014
2 parents 78e7158 + c457ef2 commit ea226a8
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 7 deletions.
45 changes: 38 additions & 7 deletions raco/myrial/parser.py
Expand Up @@ -234,26 +234,51 @@ def check_name(name):
Parser.decomposable_aggs[logical] = da

@staticmethod
def add_udf(p, name, args, body_expr):
def add_nary_udf(p, name, args, emitters):
"""Add an n-ary user-defined function to the global function table.
:param p: The parser context
:param name: The name of the function
:type name: string
:param args: A list of function arguments
:type args: list of strings
:param emitter: The output expression(s)
:type body_expr: A list of NaryEmitArg instances
"""
if not all(isinstance(e, emitarg.NaryEmitArg) for e in emitters):
raise IllegalWildcardException(name, p.lineno(0))
if sum(len(x.sexprs) for x in emitters) != len(emitters):
raise NestedTupleExpressionException(p.lineno(0))
emit_exprs = [e.sexprs[0] for e in emitters]
Parser.add_udf(p, name, args, emit_exprs)

@staticmethod
def add_udf(p, name, args, body_exprs):
"""Add a user-defined function to the global function table.
:param p: The parser context
:param name: The name of the function
:type name: string
:param args: A list of function arguments
:type args: list of strings
:param body_expr: A scalar expression containing the function body
:type body_expr: raco.expression.Expression
:param body_exprs: A list of scalar expressions containing the body
:type body_exprs: list of raco.expression.Expression
"""
if name in Parser.udf_functions:
raise DuplicateFunctionDefinitionException(name, p.lineno(0))

if len(args) != len(set(args)):
raise DuplicateVariableException(name, p.lineno(0))

Parser.check_for_undefined(p, name, body_expr, args)
if len(body_exprs) == 1:
emit_op = body_exprs[0]
else:
emit_op = TupleExpression(body_exprs)

Parser.check_for_undefined(p, name, emit_op, args)

Parser.udf_functions[name] = Function(args, body_expr)
Parser.udf_functions[name] = Function(args, emit_op)
return emit_op

@staticmethod
def mangle(name):
Expand Down Expand Up @@ -355,13 +380,19 @@ def p_unreserved_id_list(p):
@staticmethod
def p_udf(p):
"""udf : DEF unreserved_id LPAREN optional_arg_list RPAREN COLON sexpr SEMI""" # noqa
Parser.add_udf(p, p[2], p[4], p[7])
Parser.add_udf(p, p[2], p[4], [p[7]])
p[0] = None

@staticmethod
def p_nary_udf(p):
"""udf : DEF unreserved_id LPAREN optional_arg_list RPAREN COLON table_literal SEMI""" # noqa
Parser.add_nary_udf(p, p[2], p[4], p[7])
p[0] = None

@staticmethod
def p_constant(p):
"""constant : CONST unreserved_id COLON sexpr SEMI"""
Parser.add_udf(p, p[2], [], p[4])
Parser.add_udf(p, p[2], [], [p[4]])
p[0] = None

@staticmethod
Expand Down
46 changes: 46 additions & 0 deletions raco/myrial/query_tests.py
Expand Up @@ -1245,6 +1245,52 @@ def test_duplicate_variable_udf(self):
with self.assertRaises(DuplicateVariableException):
self.check_result(query, collections.Counter())

def test_nary_udf(self):
query = """
DEF Foo(a,b): [a + b, a - b];
out = [FROM SCAN(%s) AS X EMIT id, Foo(salary, dept_id) as [x, y]];
STORE(out, OUTPUT);
""" % self.emp_key

expected = collections.Counter([(t[0], t[1] + t[3], t[3] - t[1])
for t in self.emp_table])
self.check_result(query, expected)

def test_nary_udf_name_count(self):
query = """
DEF Foo(a,b): [a + b, a - b];
out = [FROM SCAN(%s) AS X EMIT id, Foo(salary, dept_id) as [x, y, z]];
STORE(out, OUTPUT);
""" % self.emp_key

with self.assertRaises(IllegalColumnNamesException):
self.check_result(query, None)

def test_nary_udf_illegal_nesting(self):
query = """
DEF Foo(x): [x + 3, x - 3];
DEF Bar(a,b): [Foo(x), Foo(b)];
out = [FROM SCAN(%s) AS X EMIT id, Bar(salary, dept_id) as [x, y]];
STORE(out, OUTPUT);
""" % self.emp_key

with self.assertRaises(NestedTupleExpressionException):
self.check_result(query, None)

def test_nary_udf_illegal_wildcard(self):
query = """
DEF Foo(x): [x + 3, *];
out = [FROM SCAN(%s) AS X EMIT id, Foo(salary, dept_id) as [x, y]];
STORE(out, OUTPUT);
""" % self.emp_key

with self.assertRaises(IllegalWildcardException):
self.check_result(query, None)

def test_triangle_udf(self):
query = """
DEF Triangle(a,b): (a*b)//2;
Expand Down

0 comments on commit ea226a8

Please sign in to comment.