Skip to content

Commit

Permalink
test generator added
Browse files Browse the repository at this point in the history
  • Loading branch information
Upabjojr committed Jan 8, 2019
1 parent cb6fd17 commit 0e76468
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 72 deletions.
2 changes: 2 additions & 0 deletions symengine/utilities/matchpycpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ include_directories(BEFORE ${teuchos_BINARY_DIR})

#add_executable(output_test output_test.cpp)
#target_link_libraries(output_test symengine)

add_subdirectory(tests)
35 changes: 33 additions & 2 deletions symengine/utilities/matchpycpp/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,41 @@
#define COMMON_H

#include <symengine/basic.h>
#include <symengine/add.h>
#include <symengine/pow.h>
#include <symengine/mul.h>
#include <map>
#include <string>
#include <queue>

typedef std::map<std::string, SymEngine::RCP<const SymEngine::Basic>>
Substitution2;
using namespace std;
using namespace SymEngine;

typedef map<string, RCP<const Basic>> Substitution;
typedef deque<RCP<const Basic>> Deque;

int try_add_variable(Substitution &subst, string variable_name,
RCP<const Basic> &replacement)
{
if (subst.find(variable_name) == subst.end()) {
subst[variable_name] = replacement;
} else {
}
return 0;
}

Deque get_deque(RCP<const Basic> expr)
{
Deque d;
for (RCP<const Basic> i : expr->get_args()) {
d.push_back(i);
}
return d;
}

RCP<const Basic> x = symbol("x");
RCP<const Basic> y = symbol("y");
RCP<const Basic> z = symbol("z");
//RCP<const Basic> w = symbol("w");

#endif
118 changes: 48 additions & 70 deletions symengine/utilities/matchpycpp/cpp_code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from matchpy.matching.many_to_one import _EPS
from matchpy.utils import get_short_lambda_source

from symengine_printer import symengine_print

COLLAPSE_IF_RE = re.compile(
r'\n(?P<indent1>\s*)if (?P<cond1>[^\n]+):\n+\1(?P<indent2>\s+)'
r'(?P<comment>(?:\#[^\n]*\n+\1\3)*)'
Expand Down Expand Up @@ -50,6 +52,9 @@ def get_var_name(self, prefix):
self._var_number += 1
return prefix + str(self._var_number)

def prepend_code(self):
code = """
"""
def generate_code(self, func_name='match_root', add_imports=True):
self._imports.add('#include <deque>')
self._imports.add('#include <iostream>')
Expand All @@ -62,41 +67,8 @@ def generate_code(self, func_name='match_root', add_imports=True):
self._imports.add('#include <symengine/basic.h>')
self._imports.add('#include <symengine/pow.h>')

self._imports.add('#include "generator_trick.h"')

self.add_line('')
self.add_line('using namespace std;')
self.add_line('using namespace SymEngine;')
self.add_line('typedef map<string, RCP<const Basic>> Substitution;')
self.add_line('typedef deque<RCP<const Basic>> Deque;')
self.add_line('')

self._code += """
int try_add_variable(Substitution &subst, string variable_name,
RCP<const Basic> &replacement)
{
if (subst.find(variable_name) == subst.end()) {
subst[variable_name] = replacement;
} else {
}
return 0;
}
Deque get_deque(RCP<const Basic> expr)
{
Deque d;
for (RCP<const Basic> i : expr->get_args()) {
d.push_back(i);
}
return d;
}
RCP<const Basic> x = symbol("x");
RCP<const Basic> y = symbol("y");
RCP<const Basic> z = symbol("z");
RCP<const Basic> w = symbol("w");
self._imports.add('#include "common.h"')

"""
self.add_line('tuple<int, Substitution> {}(RCP<const Basic> subject)'.format(func_name))
self.indent()
self.add_line('Deque {};'.format(self._subjects[-1]))
Expand All @@ -121,7 +93,7 @@ def generate_self(self, state):
self._imports.add('#include "bipartite.h"')
self._imports.add('#include "common.h"')
generator = type(self)(state.matcher.automaton)
generator.indent()
generator.indent(bracket=False)
global_code, code = generator.generate_code(func_name='get_match_iter', add_imports=False)
self._global_code.append(global_code)
patterns = self.commutative_patterns(state.matcher.patterns)
Expand All @@ -136,50 +108,62 @@ class CommutativeMatcher{0} : public CommutativeMatcher
{{
public:
{8}static CommutativeMatcher{0} *_instance = NULL;
{8}patterns = {1};
{8}Deque subjects = {2};
{8}subjects_by_id = {7};
{8}BipartiteGraph bipartite;
{8}associative = {3};
{8}int max_optional_count = {4};
{8}{5} anonymous_patterns;
{8}CommutativeMatcher{0}
{8}static patterns;
{8}static Deque subjects;
{8}static subjects_by_id;
{8}static BipartiteGraph bipartite;
{8}static associative;
{8}static int max_optional_count;
{8}static anonymous_patterns;
{8}CommutativeMatcher{0}()
{8}{{
{8}{8}add_subject(NULL);
{8}}}
{8}@staticmethod
{8}CommutativeMatcher{0} get()
{8}static CommutativeMatcher{0} *get()
{8}{{
{8}{8}if (CommutativeMatcher{0}._instance == NULL)
{8}{8}{8}CommutativeMatcher{0}._instance = new CommutativeMatcher{0}();
{8}{8}return CommutativeMatcher{0}._instance;
{8}{8}{8}this->_instance = new CommutativeMatcher{0}();
{8}{8}return this->_instance;
{8}}}
{8}static tuple<int, Substitution>
{6}'''.strip().format(
{6}
}};
static CommutativeMatcher{0}::patterns = {1};
static CommutativeMatcher{0}::Deque subjects = {2};
static CommutativeMatcher{0}::subjects_by_id = {7};
static CommutativeMatcher{0}::BipartiteGraph bipartite;
static CommutativeMatcher{0}::associative = {3};
static CommutativeMatcher{0}::int max_optional_count = {4};
static CommutativeMatcher{0}::anonymous_patterns = {5};
'''.strip().format(
state.number, patterns, subjects, associative, max_optional_count, anonymous_patterns, code,
subjects_by_id, self._indentation
)
)
self.add_line('matcher = CommutativeMatcher{}.get()'.format(state.number))
self.add_line('CommutativeMatcher{0} *matcher = CommutativeMatcher{0}::get();'.format(state.number))
tmp = self.get_var_name('tmp')
self.add_line('{} = {}'.format(tmp, self._subjects[-1]))
self.add_line('{} = []'.format(self._subjects[-1]))
self.add_line('for s in {}:'.format(tmp))
self.indent()
self.add_line('matcher.add_subject(s)')
self.add_line('RCP<const Basic> {} = {};'.format(tmp, self._subjects[-1]))
self.add_line('{} = {{}};'.format(self._subjects[-1]))
self.add_line('for (auto &s : {}->get_args()) {{'.format(tmp))
self.indent(bracket=False)
self.add_line('matcher->add_subject(s);')
subjects = self._subjects.pop()
self.dedent()
self.add_line(
'for pattern_index, subst{} in matcher.match({}, subst{}):'.format(self._substs + 1, tmp, self._substs)
'for (auto &p : matcher->match({}, subst{})) {{'.format(tmp, self._substs)
)
self._substs += 1
self.indent()
self.indent(bracket=False)
self.add_line("pattern_index = p.first;")
self.add_line("Substitution subst{} = p.second;".format(self._substs))
for pattern_index, transitions in state.transitions.items():
self.add_line('if pattern_index == {}:'.format(pattern_index))
self.indent()
self.add_line('if (pattern_index == {}) {{'.format(pattern_index))
self.indent(bracket=False)
patterns, variables = next((p, v) for i, p, v in state.matcher.patterns.values() if i == pattern_index)
variables = set(v[0][0] for v in variables)
pvars = iter(get_variables(state.matcher.automaton.patterns[i][0].expression) for i in patterns)
Expand All @@ -192,6 +176,7 @@ class CommutativeMatcher{0} : public CommutativeMatcher
self.generate_constraints(constraints, transitions)
self.dedent()
self.dedent()
self.add_line("delete matcher;")
self._substs -= 1
self._subjects.append(subjects)
else:
Expand Down Expand Up @@ -275,17 +260,13 @@ def generate_transition_code(self, transition):
self.exit_variable_assignment()
exit_func(value)

def get_args(self, operation, operation_type):
return 'op_iter({})'.format(operation)

def push_subjects(self, value, operation):
self._subjects.append(self.get_var_name('subjects'))
self.add_line('Deque {} = get_deque({});'.format(self._subjects[-1], value))
#self.get_args(value, operation)))

def push_subst(self):
new_subst = self.get_var_name('subst')
self.add_line('subst{} = Substitution(subst{})'.format(self._substs + 1, self._substs))
self.add_line('Substitution subst{} = Substitution(subst{});'.format(self._substs + 1, self._substs))
self._substs += 1

def enter_eps(self, _):
Expand Down Expand Up @@ -345,7 +326,7 @@ def exit_symbol_wildcard(self, value):
self.dedent()

def enter_fixed_wildcard(self, wildcard):
self.add_line('if (len({}) >= 1) {{'.format(self._subjects[-1]))
self.add_line('if ({}.size() >= 1) {{'.format(self._subjects[-1]))
self.indent(bracket=False)
tmp = self.get_var_name('tmp')
self.add_line('RCP<const Basic> {} = {}.front();'.format(tmp, self._subjects[-1]))
Expand Down Expand Up @@ -398,10 +379,7 @@ def enter_symbol(self, symbol):
return tmp

def symbol_repr(self, symbol):
# TODO: transform this into a SymEnginePrinter module in SymPy
if isinstance(symbol, sympy.Pow):
return 'pow({0}, {1})'.format(*map(self.symbol_repr, symbol.args))
return repr(symbol)
return symengine_print(symbol)

def exit_symbol(self, value):
self.add_line('{}.push_front({});'.format(self._subjects[-1], value))
Expand Down
82 changes: 82 additions & 0 deletions symengine/utilities/matchpycpp/generate_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
from sympy.integrals.rubi.symbol import WC
import matchpy
from cpp_code_generation import CppCodeGenerator
from symengine_printer import symengine_print
from sympy import *

x, y, z = symbols("x y z")
a, b, c = symbols("a b c")
f = Function("f")
w = WC("w")


collection_of_expressions = [
([x], {x: True, y: False}),
([x**y], {x**y: True, x**z: False}),
([x**y, w], {x**y: True, x: True, x+y: True}),
#([x+y], {}),
]

def generate_matchpy_matcher(pattern_list):
matcher = matchpy.ManyToOneMatcher()
for pattern in pattern_list:
matcher.add(matchpy.Pattern(pattern))
return matcher


def generate_cpp_code(matcher):
cg = CppCodeGenerator(matcher)
a, b = cg.generate_code()
return a, b


def export_code_to_file(filename, a, b):
fout = open(os.path.join("tests", filename), "w")
fout.write(a)
fout.write("\n\n")
fout.write(b)
return fout


def add_main_with_tests(fout, test_cases):
fout.write("""
int main() {
tuple<int, Substitution> ret;
""")
for test_case, result in test_cases.items():
fout.write(" ret = match_root({0});\n".format(symengine_print(test_case)))
if result:
fout.write(" assert(get<0>(ret) >= 0);\n")
else:
fout.write(" assert(get<0>(ret) == -1);\n")
fout.write("}\n")


def generate_tests():
fout = open(os.path.join("tests", "CMakeLists.txt"), "w")
fout.write("""\
project(matchpycpp_tests)
include_directories(BEFORE ${symengine_SOURCE_DIR})
include_directories(BEFORE ${symengine_BINARY_DIR})
include_directories(BEFORE ${teuchos_SOURCE_DIR})
include_directories(BEFORE ${teuchos_BINARY_DIR})
""")
for i, (pattern_list, test_cases) in enumerate(collection_of_expressions):
matcher = generate_matchpy_matcher(pattern_list)
a, b = generate_cpp_code(matcher)
filename = "test_case{0:03}".format(i)
filenamecpp = "{}.cpp".format(filename)
fwrite = export_code_to_file(filenamecpp, a, b)
add_main_with_tests(fwrite, test_cases)
fwrite.close()
fout.write("\n")
fout.write("add_executable({0} {1})\n".format(filename, filenamecpp))
fout.write("target_link_libraries({0} symengine)\n".format(filename))
fout.close()


if __name__ == "__main__":
generate_tests()
52 changes: 52 additions & 0 deletions symengine/utilities/matchpycpp/generator_trick.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#ifndef GENERATOR_TRICK_H
#define GENERATOR_TRICK_H

#include <deque>
#include <functional>
#include <memory>

template <typename T>
class GeneratorTrick
{
public:
GeneratorTrick()
{
generator_stop = false;
}
virtual ~GeneratorTrick(){};

std::shared_ptr<T> next()
{
if (current == nullptr) {
start();
}
while (true) {
if (generator_stop) {
break;
}
current();
if (yield_queue.size() > 0) {
std::shared_ptr<T> front
= std::make_shared<T>(yield_queue.front());
yield_queue.pop_front();
return front;
}
}
return nullptr;
}

void yield(T value)
{
yield_queue.push_back(value);
}

protected:
bool generator_stop;
std::function<void()> current;

private:
std::deque<T> yield_queue;
virtual void start() = 0;
};

#endif

0 comments on commit 0e76468

Please sign in to comment.