Skip to content

Commit

Permalink
xreplace complex infinity for all expressions in build_functions
Browse files Browse the repository at this point in the history
  • Loading branch information
bocklund committed Mar 28, 2020
1 parent 8c3dc51 commit faf16e9
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions pycalphad/codegen/sympydiff_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"""
from pycalphad.core.cache import cacheit
from pycalphad.core.utils import wrap_symbol_symengine
from symengine import sympify, lambdify
from symengine import sympify, lambdify, zoo, oo
from collections import namedtuple

BuildFunctionsResult = namedtuple('BuildFunctionsResult', ['func', 'grad', 'hess'])
Expand Down Expand Up @@ -107,16 +107,20 @@ def build_functions(sympy_graph, variables, parameters=None, wrt=None,
variables = tuple(variables)
parameters = tuple(parameters)
func, grad, hess = None, None, None
inp = sympify(variables + parameters)
graph = sympify(sympy_graph)
# TODO: did not replace zoo with oo
# Replace complex infinity (zoo) with real infinity because SymEngine
# cannot lambdify complex infinity. We also replace in the derivatives in
# case some differentiation would produce a complex infinity. The
# replacement is assumed to be cheap enough that it's safer to replace the
# complex values and pay the minor time penalty.
inp = sympify([f.xreplace({zoo: oo}) for f in variables + parameters])
graph = sympify(sympy_graph).xreplace({zoo: oo})
func = lambdify(inp, [graph], **_get_lambidfy_options(func_options))
if include_grad or include_hess:
grad_graphs = list(graph.diff(w) for w in wrt)
grad_graphs = list(graph.diff(w).xreplace({zoo: oo}) for w in wrt)
if include_grad:
grad = lambdify(inp, grad_graphs, **_get_lambidfy_options(grad_options))
if include_hess:
hess_graphs = list(list(g.diff(w) for w in wrt) for g in grad_graphs)
hess_graphs = list(list(g.diff(w).xreplace({zoo: oo}) for w in wrt) for g in grad_graphs)
hess = lambdify(inp, hess_graphs, **_get_lambidfy_options(hess_options))
return BuildFunctionsResult(func=func, grad=grad, hess=hess)

Expand Down Expand Up @@ -159,13 +163,18 @@ def build_constraint_functions(variables, constraints, parameters=None, func_opt
wrt = variables
parameters = tuple(parameters)
constraint_func, jacobian_func, hessian_func = None, None, None
inp = sympify(variables + parameters)
graph = sympify(constraints)
# Replace complex infinity (zoo) with real infinity because SymEngine
# cannot lambdify complex infinity. We also replace in the derivatives in
# case some differentiation would produce a complex infinity. The
# replacement is assumed to be cheap enough that it's safer to replace the
# complex values and pay the minor time penalty.
inp = sympify([f.xreplace({zoo: oo}) for f in variables + parameters])
graph = sympify([f.xreplace({zoo: oo}) for f in constraints])
constraint_func = lambdify(inp, [graph], **_get_lambidfy_options(func_options))

grad_graphs = list(list(c.diff(w) for w in wrt) for c in graph)
grad_graphs = list(list(c.diff(w).xreplace({zoo: oo}) for w in wrt) for c in graph)
jacobian_func = lambdify(inp, grad_graphs, **_get_lambidfy_options(jac_options))

hess_graphs = list(list(list(g.diff(w) for w in wrt) for g in c) for c in grad_graphs)
hess_graphs = list(list(list(g.diff(w).xreplace({zoo: oo}) for w in wrt) for g in c) for c in grad_graphs)
hessian_func = lambdify(inp, hess_graphs, **_get_lambidfy_options(hess_options))
return ConstraintFunctions(cons_func=constraint_func, cons_jac=jacobian_func, cons_hess=hessian_func)

0 comments on commit faf16e9

Please sign in to comment.