Skip to content

Commit

Permalink
Merge ff6de46 into e74b530
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Jul 13, 2018
2 parents e74b530 + ff6de46 commit 47a4410
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 16 deletions.
9 changes: 4 additions & 5 deletions qctoolkit/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
import warnings
import functools
import array
import itertools

import sympy
import numpy

from qctoolkit.serialization import AnonymousSerializable
from qctoolkit.utils.sympy import sympify, to_numpy, recursive_substitution, evaluate_lambdified,\
get_most_simple_representation
get_most_simple_representation, get_variables

__all__ = ["Expression", "ExpressionVariableMissingException", "ExpressionScalar", "ExpressionVector"]

Expand Down Expand Up @@ -125,9 +126,7 @@ def __init__(self, expression_vector: Sequence):
if isinstance(expression_vector, sympy.NDimArray):
expression_vector = to_numpy(expression_vector)
self._expression_vector = self.sympify_vector(expression_vector)
variables = {str(x)
for expr in self._expression_vector.ravel()
for x in expr.free_symbols}
variables = set(itertools.chain.from_iterable(map(get_variables, self._expression_vector.flat)))
self._variables = tuple(variables)

@property
Expand Down Expand Up @@ -220,7 +219,7 @@ def __init__(self, ex: Union[str, Number, sympy.Expr]) -> None:
self._original_expression = ex
self._sympified_expression = sympify(ex)

self._variables = tuple(str(var) for var in self._sympified_expression.free_symbols)
self._variables = get_variables(self._sympified_expression)

@property
def underlying_expression(self) -> sympy.Expr:
Expand Down
19 changes: 15 additions & 4 deletions qctoolkit/utils/sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import numpy


__all__ = ["sympify", "substitute_with_eval", "to_numpy"]
__all__ = ["sympify", "substitute_with_eval", "to_numpy", "get_variables", "get_free_symbols", "recursive_substitution",
"evaluate_lambdified", "get_most_simple_representation"]


Sympifyable = Union[str, Number, sympy.Expr, numpy.str_]
Expand Down Expand Up @@ -119,13 +120,23 @@ def get_most_simple_representation(expression: sympy.Expr) -> Union[str, int, fl
return str(expression)


def get_free_symbols(expression: sympy.Expr) -> Sequence[sympy.Symbol]:
return tuple(symbol
for symbol in expression.free_symbols
if not isinstance(symbol, sympy.Indexed))


def get_variables(expression: sympy.Expr) -> Sequence[str]:
return tuple(map(str, get_free_symbols(expression)))


def substitute_with_eval(expression: sympy.Expr,
substitutions: Dict[str, Union[sympy.Expr, numpy.ndarray, str]]) -> sympy.Expr:
"""Substitutes only sympy.Symbols. Workaround for numpy like array behaviour. ~Factor 3 slower compared to subs"""
substitutions = {k: v if isinstance(v, sympy.Expr) else sympify(v)
for k, v in substitutions.items()}

for symbol in expression.free_symbols:
for symbol in get_free_symbols(expression):
symbol_name = str(symbol)
if symbol_name not in substitutions:
substitutions[symbol_name] = symbol
Expand All @@ -146,14 +157,14 @@ def _recursive_substitution(expression: sympy.Expr,
func = numpy_compatible_mul
else:
func = expression.func
substitutions = {s: substitutions.get(s, s) for s in expression.free_symbols}
substitutions = {s: substitutions.get(s, s) for s in get_free_symbols(expression)}
return func(*(_recursive_substitution(arg, substitutions) for arg in expression.args))


def recursive_substitution(expression: sympy.Expr,
substitutions: Dict[str, Union[sympy.Expr, numpy.ndarray, str]]) -> sympy.Expr:
substitutions = {sympy.Symbol(k): sympify(v) for k, v in substitutions.items()}
for s in expression.free_symbols:
for s in get_free_symbols(expression):
substitutions.setdefault(s, s)
return _recursive_substitution(expression, substitutions)

Expand Down
6 changes: 6 additions & 0 deletions tests/expression_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ def test_variables(self) -> None:
received = sorted(e.variables)
self.assertEqual(expected, received)

def test_variables_indexed(self):
e = ExpressionScalar('a[i] * c')
expected = sorted(['a', 'i', 'c'])
received = sorted(e.variables)
self.assertEqual(expected, received)

def test_evaluate_variable_missing(self) -> None:
e = ExpressionScalar('a * b + c')
params = {
Expand Down
34 changes: 27 additions & 7 deletions tests/utils/sympy_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
a_ = IndexedBase(a)
b_ = IndexedBase(b)

from qctoolkit.utils.sympy import sympify as qc_sympify, substitute_with_eval, recursive_substitution, Len, evaluate_lambdified, evaluate_compiled, get_most_simple_representation
from qctoolkit.utils.sympy import sympify as qc_sympify, substitute_with_eval, recursive_substitution, Len,\
evaluate_lambdified, evaluate_compiled, get_most_simple_representation, get_variables, get_free_symbols


################################################### SUBSTITUTION #######################################################
Expand Down Expand Up @@ -214,15 +215,34 @@ def substitute(self, expression: sympy.Expr, substitutions: dict):
return recursive_substitution(expression, substitutions).doit()


class GetFreeSymbolsTests(TestCase):
def assert_symbol_sets_equal(self, expected, actual):
self.assertEqual(len(expected), len(actual))
self.assertEqual(set(expected), set(actual))

def test_get_free_symbols(self):
expr = a * b / 5
self.assert_symbol_sets_equal([a, b], get_free_symbols(expr))

def test_get_free_symbols_indexed(self):
expr = a_[i] * IndexedBase(a*b)[j]
self.assert_symbol_sets_equal({a, b, i, j}, set(get_free_symbols(expr)))

def test_get_variables(self):
expr = a * b / 5
self.assertEqual({'a', 'b'}, set(get_variables(expr)))

def test_get_variables_indexed(self):
expr = a_[i] * IndexedBase(a*b)[j]
self.assertEqual({'a', 'b', 'i', 'j'}, set(get_variables(expr)))


class EvaluationTests(TestCase):
def evaluate(self, expression: Union[sympy.Expr, np.ndarray], parameters):
def get_variables(expr: sympy.Expr):
return {str(s) for s in expr.free_symbols}
if isinstance(expression, np.ndarray):
vectorized = np.vectorize(get_variables)
get_variables = lambda expr: set.union(*vectorized(expr))

variables = get_variables(expression)
variables = set.union(*map(set, map(get_variables, expression.flat)))
else:
variables = get_variables(expression)
return evaluate_lambdified(expression, variables=list(variables), parameters=parameters, lambdified=None)[0]

def test_eval_simple(self):
Expand Down

0 comments on commit 47a4410

Please sign in to comment.