diff --git a/symengine/parser.cpp b/symengine/parser.cpp index 42e9139d32..df0af45679 100644 --- a/symengine/parser.cpp +++ b/symengine/parser.cpp @@ -9,7 +9,10 @@ #include #include #include -#include +#include +#include +#include +#include namespace SymEngine { @@ -21,21 +24,23 @@ class expressionParser {'*', 3}, {'/', 4}, {'^', 5} }; std::vector operatorClose; + std::string S; + unsigned int Slen; public: - RCP parse_string(std::string &s, uint l, uint h) + RCP parse_string(unsigned int l, unsigned int h) { RCP result; bool result_set = false; bool expr_is_symbol = false; std::string expr; - for (uint iter = l; iter < h; ++iter) + for (unsigned int iter = l; iter < h; ++iter) { - if (is_operator(s, iter)) + if (is_operator(iter)) { - if (s[iter] != '(') + if (S[iter] != '(') { if (!result_set) { @@ -46,43 +51,44 @@ class expressionParser } } - switch(s[iter]) + switch(S[iter]) { case '+': - result = add(result, parse_string(s, iter+1, operatorClose[iter])); - iter = operatorClose[iter]-1; - break; - case '(': - result = parse_string(s, iter+1, operatorClose[iter]); + result = add(result, parse_string(iter+1, operatorClose[iter])); iter = operatorClose[iter]-1; break; case '*': - result = mul(result, parse_string(s, iter+1, operatorClose[iter])); + result = mul(result, parse_string(iter+1, operatorClose[iter])); iter = operatorClose[iter]-1; break; case '-': - result = sub(result, parse_string(s, iter+1, operatorClose[iter])); + result = sub(result, parse_string(iter+1, operatorClose[iter])); iter = operatorClose[iter]-1; break; case '/': - result = div(result, parse_string(s, iter+1, operatorClose[iter])); + result = div(result, parse_string(iter+1, operatorClose[iter])); iter = operatorClose[iter]-1; break; case '^': - result = pow(result, parse_string(s, iter+1, operatorClose[iter])); + result = pow(result, parse_string(iter+1, operatorClose[iter])); + iter = operatorClose[iter]-1; + break; + case '(': + result = functionify(iter, expr); iter = operatorClose[iter]-1; break; case ')': continue; } + result_set = true; expr_is_symbol = false; } else { - expr += s[iter]; + expr += S[iter]; - int ascii = s[iter] - '0'; + int ascii = S[iter] - '0'; if (ascii < 0 or ascii > 9) expr_is_symbol = true; @@ -101,27 +107,27 @@ class expressionParser RCP parse(std::string &s) { - std::string copy; - std::stack rBracket; - std::stack > opStack; + std::stack rBracket; + std::stack > opStack; + S = ""; - for (uint i = 0; i < s.length(); ++i) + for (unsigned int i = 0; i < s.length(); ++i) { if (s[i] == ' ') continue; - copy += s[i]; + S += s[i]; } - uint newLength = copy.length(); + Slen = S.length(); operatorClose.clear(); - operatorClose.resize(newLength); - opStack.push(std::make_pair(-1, newLength)); + operatorClose.resize(Slen); + opStack.push(std::make_pair(-1, Slen)); - for (int i = newLength-1; i >= 0; i--) + for (int i = Slen-1; i >= 0; i--) { - if (is_operator(copy, i)) + if (is_operator(i)) { - char x = copy[i]; + char x = S[i]; if(x == '(') { while(opStack.top().second != rBracket.top()) @@ -145,16 +151,26 @@ class expressionParser } } } - return parse_string(copy, 0, newLength); + return parse_string(0, Slen); } - bool is_operator(std::string& s, int iter) + bool is_operator(int iter) { - if (iter >= 0 and iter < (int)s.length()) - if (OPERATORS.find(s[iter]) != OPERATORS.end()) + if (iter >= 0 and iter < (int)Slen) + if (OPERATORS.find(S[iter]) != OPERATORS.end()) return true; return false; } + + RCP functionify(unsigned int iter, std::string expr) + { + RCP inner = parse_string(iter+1, operatorClose[iter]); + + if(expr == "") return inner; + if(expr == "sin") return sin(inner); + + throw std::runtime_error("Unknown function " + expr); + } }; } // SymEngine diff --git a/symengine/tests/basic/test_parser.cpp b/symengine/tests/basic/test_parser.cpp index cc468cca48..24d67d78f3 100644 --- a/symengine/tests/basic/test_parser.cpp +++ b/symengine/tests/basic/test_parser.cpp @@ -122,3 +122,16 @@ TEST_CASE("Parsing: symbols", "[parser]") res = p.parse(s); REQUIRE(eq(*res, *div(mul(y, integer(3)), add(x, integer(1))))); } + +TEST_CASE("Parsing: functions", "[parser]") +{ + std::string s; + expressionParser p; + RCP res; + RCP x = symbol("x"); + RCP y = symbol("y"); + + s = "sin(x)"; + res = p.parse(s); + REQUIRE(eq(*res, *sin(x))); +}