Skip to content

Commit

Permalink
WIP: build_functions: Add backend toggle based on count_ops()
Browse files Browse the repository at this point in the history
  • Loading branch information
richardotis committed Aug 4, 2019
1 parent 34dc2c5 commit c9ccc93
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions pycalphad/codegen/sympydiff_utils.py
Expand Up @@ -3,12 +3,14 @@
"""
from pycalphad.core.cache import cacheit
from pycalphad.core.utils import wrap_symbol_symengine
from symengine import sympify, lambdify
from symengine import sympify, lambdify, count_ops
from collections import namedtuple


BuildFunctionsResult = namedtuple('BuildFunctionsResult', ['func', 'grad', 'hess'])

BACKEND_OPS_THRESHOLD = 50000


@cacheit
def build_functions(sympy_graph, variables, parameters=None, wrt=None, include_obj=True, include_grad=False, include_hess=False, cse=True):
Expand All @@ -23,14 +25,24 @@ def build_functions(sympy_graph, variables, parameters=None, wrt=None, include_o
func, grad, hess = None, None, None
inp = sympify(variables + parameters)
graph = sympify(sympy_graph)
if count_ops(graph) > BACKEND_OPS_THRESHOLD:
backend = 'lambda'
else:
backend = 'llvm'
# TODO: did not replace zoo with oo
if include_obj:
func = lambdify(inp, [graph], backend='lambda', cse=cse)
func = lambdify(inp, [graph], backend=backend, cse=cse)
if include_grad or include_hess:
grad_graphs = list(graph.diff(w) for w in wrt)
grad_ops = sum(count_ops(x) for x in grad_graphs)
if grad_ops > BACKEND_OPS_THRESHOLD:
grad_backend = 'lambda'
else:
grad_backend = 'llvm'
if include_grad:
grad = lambdify(inp, grad_graphs, backend='lambda', cse=cse)
grad = lambdify(inp, grad_graphs, backend=grad_backend, cse=cse)
if include_hess:
hess_graphs = list(list(g.diff(w) for w in wrt) for g in grad_graphs)
# Hessians are hard-coded to always use the lambda backend, for performance
hess = lambdify(inp, hess_graphs, backend='lambda', cse=cse)
return BuildFunctionsResult(func=func, grad=grad, hess=hess)

0 comments on commit c9ccc93

Please sign in to comment.