From c9ccc93b882bff7ad868e8b952687718ee806dfa Mon Sep 17 00:00:00 2001 From: Richard Otis Date: Sun, 4 Aug 2019 14:59:42 -0700 Subject: [PATCH] WIP: build_functions: Add backend toggle based on count_ops() --- pycalphad/codegen/sympydiff_utils.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/pycalphad/codegen/sympydiff_utils.py b/pycalphad/codegen/sympydiff_utils.py index 84b6e03d9..3f1158a11 100644 --- a/pycalphad/codegen/sympydiff_utils.py +++ b/pycalphad/codegen/sympydiff_utils.py @@ -3,12 +3,14 @@ """ from pycalphad.core.cache import cacheit from pycalphad.core.utils import wrap_symbol_symengine -from symengine import sympify, lambdify +from symengine import sympify, lambdify, count_ops from collections import namedtuple BuildFunctionsResult = namedtuple('BuildFunctionsResult', ['func', 'grad', 'hess']) +BACKEND_OPS_THRESHOLD = 50000 + @cacheit def build_functions(sympy_graph, variables, parameters=None, wrt=None, include_obj=True, include_grad=False, include_hess=False, cse=True): @@ -23,14 +25,24 @@ def build_functions(sympy_graph, variables, parameters=None, wrt=None, include_o func, grad, hess = None, None, None inp = sympify(variables + parameters) graph = sympify(sympy_graph) + if count_ops(graph) > BACKEND_OPS_THRESHOLD: + backend = 'lambda' + else: + backend = 'llvm' # TODO: did not replace zoo with oo if include_obj: - func = lambdify(inp, [graph], backend='lambda', cse=cse) + func = lambdify(inp, [graph], backend=backend, cse=cse) if include_grad or include_hess: grad_graphs = list(graph.diff(w) for w in wrt) + grad_ops = sum(count_ops(x) for x in grad_graphs) + if grad_ops > BACKEND_OPS_THRESHOLD: + grad_backend = 'lambda' + else: + grad_backend = 'llvm' if include_grad: - grad = lambdify(inp, grad_graphs, backend='lambda', cse=cse) + grad = lambdify(inp, grad_graphs, backend=grad_backend, cse=cse) if include_hess: hess_graphs = list(list(g.diff(w) for w in wrt) for g in grad_graphs) + # Hessians are hard-coded to always use the lambda backend, for performance hess = lambdify(inp, hess_graphs, backend='lambda', cse=cse) return BuildFunctionsResult(func=func, grad=grad, hess=hess)