diff --git a/raco/myrial/parser.py b/raco/myrial/parser.py index 47f97166..8ccc0349 100644 --- a/raco/myrial/parser.py +++ b/raco/myrial/parser.py @@ -234,7 +234,26 @@ def check_name(name): Parser.decomposable_aggs[logical] = da @staticmethod - def add_udf(p, name, args, body_expr): + def add_nary_udf(p, name, args, emitters): + """Add an n-ary user-defined function to the global function table. + + :param p: The parser context + :param name: The name of the function + :type name: string + :param args: A list of function arguments + :type args: list of strings + :param emitters: The output expression(s) + :type emitters: A list of NaryEmitArg instances + """ + if not all(isinstance(e, emitarg.NaryEmitArg) for e in emitters): + raise IllegalWildcardException(name, p.lineno(0)) + if sum(len(x.sexprs) for x in emitters) != len(emitters): + raise NestedTupleExpressionException(p.lineno(0)) + emit_exprs = [e.sexprs[0] for e in emitters] + Parser.add_udf(p, name, args, emit_exprs) + + @staticmethod + def add_udf(p, name, args, body_exprs): """Add a user-defined function to the global function table. 
:param p: The parser context @@ -242,8 +261,8 @@ def add_udf(p, name, args, body_expr): :type name: string :param args: A list of function arguments :type args: list of strings - :param body_expr: A scalar expression containing the function body - :type body_expr: raco.expression.Expression + :param body_exprs: A list of scalar expressions containing the body + :type body_exprs: list of raco.expression.Expression """ if name in Parser.udf_functions: raise DuplicateFunctionDefinitionException(name, p.lineno(0)) @@ -251,9 +270,15 @@ def add_udf(p, name, args, body_expr): if len(args) != len(set(args)): raise DuplicateVariableException(name, p.lineno(0)) - Parser.check_for_undefined(p, name, body_expr, args) + if len(body_exprs) == 1: + emit_op = body_exprs[0] + else: + emit_op = TupleExpression(body_exprs) + + Parser.check_for_undefined(p, name, emit_op, args) - Parser.udf_functions[name] = Function(args, body_expr) + Parser.udf_functions[name] = Function(args, emit_op) + return emit_op @staticmethod def mangle(name): @@ -355,13 +380,19 @@ def p_unreserved_id_list(p): @staticmethod def p_udf(p): """udf : DEF unreserved_id LPAREN optional_arg_list RPAREN COLON sexpr SEMI""" # noqa - Parser.add_udf(p, p[2], p[4], p[7]) + Parser.add_udf(p, p[2], p[4], [p[7]]) + p[0] = None + + @staticmethod + def p_nary_udf(p): + """udf : DEF unreserved_id LPAREN optional_arg_list RPAREN COLON table_literal SEMI""" # noqa + Parser.add_nary_udf(p, p[2], p[4], p[7]) p[0] = None @staticmethod def p_constant(p): """constant : CONST unreserved_id COLON sexpr SEMI""" - Parser.add_udf(p, p[2], [], p[4]) + Parser.add_udf(p, p[2], [], [p[4]]) p[0] = None @staticmethod diff --git a/raco/myrial/query_tests.py b/raco/myrial/query_tests.py index b4cc371a..18f5efbb 100644 --- a/raco/myrial/query_tests.py +++ b/raco/myrial/query_tests.py @@ -1245,6 +1245,52 @@ def test_duplicate_variable_udf(self): with self.assertRaises(DuplicateVariableException): self.check_result(query, collections.Counter()) + 
def test_nary_udf(self): + query = """ + DEF Foo(a,b): [a + b, a - b]; + + out = [FROM SCAN(%s) AS X EMIT id, Foo(salary, dept_id) as [x, y]]; + STORE(out, OUTPUT); + """ % self.emp_key + + expected = collections.Counter([(t[0], t[1] + t[3], t[3] - t[1]) + for t in self.emp_table]) + self.check_result(query, expected) + + def test_nary_udf_name_count(self): + query = """ + DEF Foo(a,b): [a + b, a - b]; + + out = [FROM SCAN(%s) AS X EMIT id, Foo(salary, dept_id) as [x, y, z]]; + STORE(out, OUTPUT); + """ % self.emp_key + + with self.assertRaises(IllegalColumnNamesException): + self.check_result(query, None) + + def test_nary_udf_illegal_nesting(self): + query = """ + DEF Foo(x): [x + 3, x - 3]; + DEF Bar(a,b): [Foo(x), Foo(b)]; + + out = [FROM SCAN(%s) AS X EMIT id, Bar(salary, dept_id) as [x, y]]; + STORE(out, OUTPUT); + """ % self.emp_key + + with self.assertRaises(NestedTupleExpressionException): + self.check_result(query, None) + + def test_nary_udf_illegal_wildcard(self): + query = """ + DEF Foo(x): [x + 3, *]; + + out = [FROM SCAN(%s) AS X EMIT id, Foo(salary, dept_id) as [x, y]]; + STORE(out, OUTPUT); + """ % self.emp_key + + with self.assertRaises(IllegalWildcardException): + self.check_result(query, None) + def test_triangle_udf(self): query = """ DEF Triangle(a,b): (a*b)//2;