diff --git a/symengine/utilities/matchpycpp/CMakeLists.txt b/symengine/utilities/matchpycpp/CMakeLists.txt index 62cc8efde1..4c041ffb96 100644 --- a/symengine/utilities/matchpycpp/CMakeLists.txt +++ b/symengine/utilities/matchpycpp/CMakeLists.txt @@ -7,3 +7,5 @@ include_directories(BEFORE ${teuchos_BINARY_DIR}) #add_executable(output_test output_test.cpp) #target_link_libraries(output_test symengine) + +add_subdirectory(tests) diff --git a/symengine/utilities/matchpycpp/common.h b/symengine/utilities/matchpycpp/common.h index 07957a1adb..29f0ce55ba 100644 --- a/symengine/utilities/matchpycpp/common.h +++ b/symengine/utilities/matchpycpp/common.h @@ -2,10 +2,41 @@ #define COMMON_H #include +#include +#include +#include #include #include +#include -typedef std::map> - Substitution2; +using namespace std; +using namespace SymEngine; + +typedef map> Substitution; +typedef deque> Deque; + +int try_add_variable(Substitution &subst, string variable_name, + RCP &replacement) +{ + if (subst.find(variable_name) == subst.end()) { + subst[variable_name] = replacement; + } else { + } + return 0; +} + +Deque get_deque(RCP expr) +{ + Deque d; + for (RCP i : expr->get_args()) { + d.push_back(i); + } + return d; +} + +RCP x = symbol("x"); +RCP y = symbol("y"); +RCP z = symbol("z"); +//RCP w = symbol("w"); #endif diff --git a/symengine/utilities/matchpycpp/cpp_code_generation.py b/symengine/utilities/matchpycpp/cpp_code_generation.py index 981a170112..4d19a1e84d 100644 --- a/symengine/utilities/matchpycpp/cpp_code_generation.py +++ b/symengine/utilities/matchpycpp/cpp_code_generation.py @@ -9,6 +9,8 @@ from matchpy.matching.many_to_one import _EPS from matchpy.utils import get_short_lambda_source +from symengine_printer import symengine_print + COLLAPSE_IF_RE = re.compile( r'\n(?P\s*)if (?P[^\n]+):\n+\1(?P\s+)' r'(?P(?:\#[^\n]*\n+\1\3)*)' @@ -50,6 +52,9 @@ def get_var_name(self, prefix): self._var_number += 1 return prefix + str(self._var_number) + def prepend_code(self): + code = """ +""" def generate_code(self, func_name='match_root', add_imports=True): self._imports.add('#include ') self._imports.add('#include ') @@ -62,41 +67,8 @@ def generate_code(self, func_name='match_root', add_imports=True): self._imports.add('#include ') self._imports.add('#include ') - self._imports.add('#include "generator_trick.h"') - - self.add_line('') - self.add_line('using namespace std;') - self.add_line('using namespace SymEngine;') - self.add_line('typedef map> Substitution;') - self.add_line('typedef deque> Deque;') - self.add_line('') - - self._code += """ -int try_add_variable(Substitution &subst, string variable_name, - RCP &replacement) -{ - if (subst.find(variable_name) == subst.end()) { - subst[variable_name] = replacement; - } else { - } - return 0; -} - -Deque get_deque(RCP expr) -{ - Deque d; - for (RCP i : expr->get_args()) { - d.push_back(i); - } - return d; -} - -RCP x = symbol("x"); -RCP y = symbol("y"); -RCP z = symbol("z"); -RCP w = symbol("w"); + self._imports.add('#include "common.h"') -""" self.add_line('tuple {}(RCP subject)'.format(func_name)) self.indent() self.add_line('Deque {};'.format(self._subjects[-1])) @@ -121,7 +93,7 @@ def generate_self(self, state): self._imports.add('#include "bipartite.h"') self._imports.add('#include "common.h"') generator = type(self)(state.matcher.automaton) - generator.indent() + generator.indent(bracket=False) global_code, code = generator.generate_code(func_name='get_match_iter', add_imports=False) self._global_code.append(global_code) patterns = self.commutative_patterns(state.matcher.patterns) @@ -136,50 +108,62 @@ class CommutativeMatcher{0} : public CommutativeMatcher {{ public: {8}static CommutativeMatcher{0} *_instance = NULL; -{8}patterns = {1}; -{8}Deque subjects = {2}; -{8}subjects_by_id = {7}; -{8}BipartiteGraph bipartite; -{8}associative = {3}; -{8}int max_optional_count = {4}; -{8}{5} anonymous_patterns; - -{8}CommutativeMatcher{0} +{8}static patterns; +{8}static Deque subjects; +{8}static subjects_by_id; +{8}static BipartiteGraph bipartite; +{8}static associative; +{8}static int max_optional_count; +{8}static anonymous_patterns; + +{8}CommutativeMatcher{0}() {8}{{ {8}{8}add_subject(NULL); {8}}} -{8}@staticmethod -{8}CommutativeMatcher{0} get() +{8}static CommutativeMatcher{0} *get() {8}{{ {8}{8}if (CommutativeMatcher{0}._instance == NULL) -{8}{8}{8}CommutativeMatcher{0}._instance = new CommutativeMatcher{0}(); -{8}{8}return CommutativeMatcher{0}._instance; +{8}{8}{8}this->_instance = new CommutativeMatcher{0}(); +{8}{8}return this->_instance; {8}}} -{8}static tuple -{6}'''.strip().format( +{6} +}}; + + +static CommutativeMatcher{0}::patterns = {1}; +static CommutativeMatcher{0}::Deque subjects = {2}; +static CommutativeMatcher{0}::subjects_by_id = {7}; +static CommutativeMatcher{0}::BipartiteGraph bipartite; +static CommutativeMatcher{0}::associative = {3}; +static CommutativeMatcher{0}::int max_optional_count = {4}; +static CommutativeMatcher{0}::anonymous_patterns = {5}; + +'''.strip().format( state.number, patterns, subjects, associative, max_optional_count, anonymous_patterns, code, subjects_by_id, self._indentation ) ) - self.add_line('matcher = CommutativeMatcher{}.get()'.format(state.number)) + self.add_line('CommutativeMatcher{0} *matcher = CommutativeMatcher{0}::get();'.format(state.number)) tmp = self.get_var_name('tmp') - self.add_line('{} = {}'.format(tmp, self._subjects[-1])) - self.add_line('{} = []'.format(self._subjects[-1])) - self.add_line('for s in {}:'.format(tmp)) - self.indent() - self.add_line('matcher.add_subject(s)') + self.add_line('RCP {} = {};'.format(tmp, self._subjects[-1])) + self.add_line('{} = {{}};'.format(self._subjects[-1])) + self.add_line('for (auto &s : {}->get_args()) {{'.format(tmp)) + self.indent(bracket=False) + self.add_line('matcher->add_subject(s);') subjects = self._subjects.pop() self.dedent() self.add_line( - 'for pattern_index, subst{} in matcher.match({}, subst{}):'.format(self._substs + 1, tmp, self._substs) + 'for (auto &p : matcher->match({}, subst{})) {{'.format(tmp, self._substs) ) self._substs += 1 - self.indent() + self.indent(bracket=False) + self.add_line("pattern_index = p.first;") + self.add_line("Substitution subst{} = p.second;".format(self._substs)) for pattern_index, transitions in state.transitions.items(): - self.add_line('if pattern_index == {}:'.format(pattern_index)) - self.indent() + self.add_line('if (pattern_index == {}) {{'.format(pattern_index)) + self.indent(bracket=False) patterns, variables = next((p, v) for i, p, v in state.matcher.patterns.values() if i == pattern_index) variables = set(v[0][0] for v in variables) pvars = iter(get_variables(state.matcher.automaton.patterns[i][0].expression) for i in patterns) @@ -192,6 +176,7 @@ class CommutativeMatcher{0} : public CommutativeMatcher self.generate_constraints(constraints, transitions) self.dedent() self.dedent() + self.add_line("delete matcher;") self._substs -= 1 self._subjects.append(subjects) else: @@ -275,17 +260,13 @@ def generate_transition_code(self, transition): self.exit_variable_assignment() exit_func(value) - def get_args(self, operation, operation_type): - return 'op_iter({})'.format(operation) - def push_subjects(self, value, operation): self._subjects.append(self.get_var_name('subjects')) self.add_line('Deque {} = get_deque({});'.format(self._subjects[-1], value)) - #self.get_args(value, operation))) def push_subst(self): new_subst = self.get_var_name('subst') - self.add_line('subst{} = Substitution(subst{})'.format(self._substs + 1, self._substs)) + self.add_line('Substitution subst{} = Substitution(subst{});'.format(self._substs + 1, self._substs)) self._substs += 1 def enter_eps(self, _): @@ -345,7 +326,7 @@ def exit_symbol_wildcard(self, value): self.dedent() def enter_fixed_wildcard(self, wildcard): - self.add_line('if (len({}) >= 1) {{'.format(self._subjects[-1])) + self.add_line('if ({}.size() >= 1) {{'.format(self._subjects[-1])) self.indent(bracket=False) tmp = self.get_var_name('tmp') self.add_line('RCP {} = {}.front();'.format(tmp, self._subjects[-1])) @@ -398,10 +379,7 @@ def enter_symbol(self, symbol): return tmp def symbol_repr(self, symbol): - # TODO: transform this into a SymEnginePrinter module in SymPy - if isinstance(symbol, sympy.Pow): - return 'pow({0}, {1})'.format(*map(self.symbol_repr, symbol.args)) - return repr(symbol) + return symengine_print(symbol) def exit_symbol(self, value): self.add_line('{}.push_front({});'.format(self._subjects[-1], value)) diff --git a/symengine/utilities/matchpycpp/generate_tests.py b/symengine/utilities/matchpycpp/generate_tests.py new file mode 100644 index 0000000000..91f89703fc --- /dev/null +++ b/symengine/utilities/matchpycpp/generate_tests.py @@ -0,0 +1,82 @@ +import os +from sympy.integrals.rubi.symbol import WC +import matchpy +from cpp_code_generation import CppCodeGenerator +from symengine_printer import symengine_print +from sympy import * + +x, y, z = symbols("x y z") +a, b, c = symbols("a b c") +f = Function("f") +w = WC("w") + + +collection_of_expressions = [ + ([x], {x: True, y: False}), + ([x**y], {x**y: True, x**z: False}), + ([x**y, w], {x**y: True, x: True, x+y: True}), + #([x+y], {}), +] + +def generate_matchpy_matcher(pattern_list): + matcher = matchpy.ManyToOneMatcher() + for pattern in pattern_list: + matcher.add(matchpy.Pattern(pattern)) + return matcher + + +def generate_cpp_code(matcher): + cg = CppCodeGenerator(matcher) + a, b = cg.generate_code() + return a, b + + +def export_code_to_file(filename, a, b): + fout = open(os.path.join("tests", filename), "w") + fout.write(a) + fout.write("\n\n") + fout.write(b) + return fout + + +def add_main_with_tests(fout, test_cases): + fout.write(""" +int main() { + tuple ret; + +""") + for test_case, result in test_cases.items(): + fout.write(" ret = match_root({0});\n".format(symengine_print(test_case))) + if result: + fout.write(" assert(get<0>(ret) >= 0);\n") + else: + fout.write(" assert(get<0>(ret) == -1);\n") + fout.write("}\n") + + +def generate_tests(): + fout = open(os.path.join("tests", "CMakeLists.txt"), "w") + fout.write("""\ +project(matchpycpp_tests) + +include_directories(BEFORE ${symengine_SOURCE_DIR}) +include_directories(BEFORE ${symengine_BINARY_DIR}) +include_directories(BEFORE ${teuchos_SOURCE_DIR}) +include_directories(BEFORE ${teuchos_BINARY_DIR}) +""") + for i, (pattern_list, test_cases) in enumerate(collection_of_expressions): + matcher = generate_matchpy_matcher(pattern_list) + a, b = generate_cpp_code(matcher) + filename = "test_case{0:03}".format(i) + filenamecpp = "{}.cpp".format(filename) + fwrite = export_code_to_file(filenamecpp, a, b) + add_main_with_tests(fwrite, test_cases) + fwrite.close() + fout.write("\n") + fout.write("add_executable({0} {1})\n".format(filename, filenamecpp)) + fout.write("target_link_libraries({0} symengine)\n".format(filename)) + fout.close() + + +if __name__ == "__main__": + generate_tests() diff --git a/symengine/utilities/matchpycpp/generator_trick.h b/symengine/utilities/matchpycpp/generator_trick.h new file mode 100644 index 0000000000..ab260676e8 --- /dev/null +++ b/symengine/utilities/matchpycpp/generator_trick.h @@ -0,0 +1,52 @@ +#ifndef GENERATOR_TRICK_H +#define GENERATOR_TRICK_H + +#include +#include +#include + +template +class GeneratorTrick +{ +public: + GeneratorTrick() + { + generator_stop = false; + } + virtual ~GeneratorTrick(){}; + + std::shared_ptr next() + { + if (current == nullptr) { + start(); + } + while (true) { + if (generator_stop) { + break; + } + current(); + if (yield_queue.size() > 0) { + std::shared_ptr front + = std::make_shared(yield_queue.front()); + yield_queue.pop_front(); + return front; + } + } + return nullptr; + } + + void yield(T value) + { + yield_queue.push_back(value); + } + +protected: + bool generator_stop; + std::function current; + +private: + std::deque yield_queue; + virtual void start() = 0; +}; + +#endif diff --git a/symengine/utilities/matchpycpp/symengine_printer.py b/symengine/utilities/matchpycpp/symengine_printer.py new file mode 100644 index 0000000000..2a3367d21d --- /dev/null +++ b/symengine/utilities/matchpycpp/symengine_printer.py @@ -0,0 +1,20 @@ +from sympy.printing.pycode import PythonCodePrinter + + +class SymEnginePrinter(PythonCodePrinter): + + def _print_Pow(self, expr): + return "pow({0}, {1})".format(self._print(expr.base), self._print(expr.exp)) + + def _print_Add(self, expr): + if len(expr.args) != 2: + raise NotImplementedError + return "add({}, {})".format( + self._print(expr.args[0]), + self._print(expr.args[1]), + ) + + +def symengine_print(expr): + printer = SymEnginePrinter() + return printer.doprint(expr) diff --git a/symengine/utilities/matchpycpp/tests/.gitkeep b/symengine/utilities/matchpycpp/tests/.gitkeep new file mode 100644 index 0000000000..e69de29bb2