In [1]:
from pycalphad import Database, Model
from pycalphad.core.utils import make_callable
import pycalphad.variables as v
import numba
import numpy as np

In [2]:
# Thermodynamic database (Al-Ni, Dupin 2001 per the filename) and the phase model
# that every later cell differentiates.  Kept as named constants so the notebook
# can be pointed at another database/phase in one place.
TDB_PATH = 'Al-Ni/Al-Ni-Dupin-2001.tdb'
COMPONENTS = ['AL', 'NI', 'VA']
PHASE_NAME = 'FCC_L12'

dbf = Database(TDB_PATH)
mod = Model(dbf, COMPONENTS, PHASE_NAME)

In [3]:
from sympy import lambdify
from pycalphad.core.utils import NumPyPrinter
@numba.jit('float64(boolean, float64, float64)', nopython=True, nogil=True)
def where(condition, x, y):
    """Scalar, nopython-compiled replacement for ``numpy.where``.

    Yields ``x`` when ``condition`` is true, ``y`` otherwise.  It is injected as
    the ``where`` name in the ``lambdify`` modules used by
    ``make_gradient_from_graph`` (presumably because the printed expressions
    call ``where(...)`` -- confirm against NumPyPrinter output).  Note that, as a
    plain function call, both ``x`` and ``y`` are evaluated before selection.
    """
    return x if condition else y

In [4]:

# Random benchmark inputs, one array per entry of mod.variables (x1 is scaled to
# [0, 2000) -- presumably temperature; TODO confirm the variable order).  The
# shapes broadcast to (5, 5, 6000), which becomes the loop shape of the gufuncs.
# A fixed seed keeps the printed gradients/Hessians reproducible on Restart & Run All.
SEED = 0
rng = np.random.RandomState(SEED)
x1 = rng.rand(5, 1, 1)*2000
x2 = rng.rand(5, 5, 1)
x3 = rng.rand(5, 5, 1)
x4 = rng.rand(5, 5, 6000)
x5 = rng.rand(5, 5, 6000)
x6 = rng.rand(5, 5, 6000)

In [5]:
from sympy import lambdify, count_ops
import numba
def _jit_scalar_function(expr, wrt):
    """Compile a SymPy expression into a nopython Numba function of float64 scalars.

    Parameters
    ----------
    expr : sympy expression
        Expression to compile.
    wrt : tuple of sympy symbols
        Positional-argument order of the compiled function; every argument and
        the return value are typed float64.

    Returns
    -------
    numba dispatcher
        ``f(*floats) -> float64`` evaluating ``expr``.
    """
    # `where` is the scalar Numba helper defined earlier in this notebook; it
    # overrides the `where` name in the lambdified code so it stays nopython-safe.
    func = lambdify(wrt, expr, dummify=True,
                    modules=[{'where': where}, 'numpy'], printer=NumPyPrinter)
    signature = 'float64({0})'.format(','.join(['float64'] * len(wrt)))
    return numba.jit(signature, nopython=True, nogil=True)(func)


def make_gradient_from_graph(graph, wrt):
    """Build Numba gufuncs that evaluate the gradient and Hessian of ``graph``.

    Every first and (upper-triangle) second derivative is taken symbolically
    with SymPy and JIT-compiled as its own scalar function.  A small driver is then
    generated as source text and wrapped with ``numba.guvectorize`` so that it
    broadcasts over arrays of inputs.

    Parameters
    ----------
    graph : sympy expression
        Expression to differentiate (e.g. ``mod.ast``).
    wrt : iterable of sympy symbols
        Variables to differentiate with respect to.  Their order fixes both the
        positional-argument order of the returned gufuncs and the index order
        of their outputs.

    Returns
    -------
    grad_func : numba gufunc
        ``grad_func(x_0, ..., x_{k-1}, lengthfix, out)`` with layout
        ``(),...,(),(n)->(n)``; writes d(graph)/d(wrt[i]) into ``out[..., i]``.
    hess_func : numba gufunc
        ``hess_func(x_0, ..., x_{k-1}, lengthfix, out)`` with layout
        ``(),...,(),(n)->(n,n)``; writes each second derivative into both
        ``out[..., i, j]`` and ``out[..., j, i]``.

    ``lengthfix`` only supplies the core dimension ``n`` and should have length
    ``len(wrt)``: the generated drivers write indices ``0 .. len(wrt)-1``, so a
    shorter array would be indexed out of range.

    Raises
    ------
    ValueError
        If ``wrt`` is empty (there would be nothing to differentiate and the
        generated drivers would have no arguments).
    """
    wrt = tuple(wrt)
    nvars = len(wrt)
    if nvars == 0:
        raise ValueError('make_gradient_from_graph needs at least one variable in `wrt`')

    # The compiled per-entry functions live in the namespace the generated
    # drivers are exec'd in, so the drivers can call them by name.
    namespace = {}
    hess_indices = []
    for i in range(nvars):
        grad_expr = graph.diff(wrt[i])
        namespace['grad_{0}'.format(i)] = _jit_scalar_function(grad_expr, wrt)
        # The Hessian is symmetric, so only entries with j >= i are compiled.
        # The '_' separator keeps names unique: the old 'hess_{i}{j}' scheme
        # mapped e.g. (1, 111) and (11, 11) to the same name 'hess_1111'.
        for j in range(i, nvars):
            namespace['hess_{0}_{1}'.format(i, j)] = _jit_scalar_function(grad_expr.diff(wrt[j]), wrt)
            hess_indices.append((i, j))

    # Numba needs statically known arguments and callees, so the drivers are
    # generated as source text.  The scalar `()` arguments are declared
    # `float64[:]` in the guvectorize signatures below, hence the `[0]`.
    call_args = ','.join(['_x{0}'.format(i) for i in range(nvars)])
    call_passed_args = ','.join(['_x{0}[0]'.format(i) for i in range(nvars)])

    grad_lines = ['def grad_func({0}, lengthfix, result):'.format(call_args)]
    grad_lines += ['    result[..., {0}] = grad_{0}({1})'.format(i, call_passed_args)
                   for i in range(nvars)]

    hess_lines = ['def hess_func({0}, lengthfix, result):'.format(call_args)]
    hess_lines += ['    result[{0},{1}] = result[{1}, {0}] = hess_{0}_{1}({2})'.format(i, j, call_passed_args)
                   for i, j in hess_indices]

    # exec(code, namespace) parses on both Python 2.7 and Python 3.  The old
    # `exec code in namespace` statement is a SyntaxError on Python 3 when the
    # cell is parsed, so the try/except SyntaxError around it never ran.
    exec(compile('\n'.join(grad_lines), '<string>', 'exec'), namespace)
    exec(compile('\n'.join(hess_lines), '<string>', 'exec'), namespace)

    scalar_layout = ','.join(['()'] * nvars)
    grad_func = numba.guvectorize([','.join(['float64[:]'] * (nvars + 2))],
                                  scalar_layout + ',(n)->(n)', nopython=True)(namespace['grad_func'])
    hess_func = numba.guvectorize([','.join(['float64[:]'] * (nvars + 1)) + ',float64[:,:]'],
                                  scalar_layout + ',(n)->(n,n)', nopython=True)(namespace['hess_func'])
    return grad_func, hess_func

In [6]:
# Differentiate mod.ast w.r.t. every model variable and JIT-compile each gradient
# entry and each upper-triangle Hessian entry as a separate function.  This is the
# slow step of the notebook (about 2.5 minutes in the recorded run below).
%time grad_func, hess_func = make_gradient_from_graph(mod.ast, mod.variables)

CPU times: user 2min 29s, sys: 634 ms, total: 2min 29s
Wall time: 2min 29s


In [7]:
%time grad_func(x1, x2, x3, x4, x5, x6, np.empty(6), np.zeros(x6.shape + (6,)))

CPU times: user 339 ms, sys: 1 ms, total: 340 ms
Wall time: 341 ms


array([[[[  1.42907095e+03,  -4.99213737e+04,  -7.70312405e+04,
           -2.78291675e+05,  -2.88806916e+05,  -7.14257185e+04],
         [  1.22674823e+03,  -5.85084903e+04,  -5.62875954e+04,
           -2.38297377e+05,  -2.46886197e+05,  -6.59177112e+04],
         [  9.46422288e+02,  -2.05268579e+04,  -2.31703195e+04,
           -1.79064904e+05,  -1.81666961e+05,  -3.24274073e+04],
         ..., 
         [  1.31827788e+03,  -4.35055436e+04,  -6.70149392e+04,
           -2.54631661e+05,  -2.63945208e+05,  -6.26772339e+04],
         [  8.60420625e+02,  -1.07536356e+04,  -1.33131716e+04,
           -1.61217941e+05,  -1.62292850e+05,  -1.96229186e+04],
         [  8.54265977e+02,  -1.28082732e+04,  -1.46817503e+04,
           -1.59289389e+05,  -1.60636815e+05,  -2.76918114e+04]],

        [[  1.39427932e+03,  -5.92039278e+04,  -3.13226247e+04,
           -2.85437147e+05,  -2.76707605e+05,  -4.02089132e+04],
         [  1.38234626e+03,  -3.89737289e+04,  -6.26200478e+04,
           -2.73

In [8]:
result = np.zeros(x6.shape + (6,6))
%time hess_func(x1, x2, x3, x4, x5, x6, np.empty(6), result)
print(result)

CPU times: user 1.02 s, sys: 10 ms, total: 1.03 s
Wall time: 1.03 s
[[[[[  1.45361401e+03  -2.29986420e+01  -1.73589195e+01  -1.02005939e+04
      -1.01999411e+04   2.74441554e+03]
    [ -2.29986420e+01  -4.74392218e+04  -6.93471628e+04  -1.77374376e+05
      -4.11677105e+05  -9.88653984e+04]
    [ -1.73589195e+01  -6.93471628e+04  -5.72915605e+04  -3.96437720e+05
      -2.04358468e+05  -1.53607366e+05]
    [ -1.02005939e+04  -1.77374376e+05  -3.96437720e+05  -2.46217342e+05
      -5.19359773e+05  -5.35529146e+05]
    [ -1.01999411e+04  -4.11677105e+05  -2.04358468e+05  -5.19359773e+05
      -2.68361172e+05  -5.56156449e+05]
    [  2.74441554e+03  -9.88653984e+04  -1.53607366e+05  -5.35529146e+05
      -5.56156449e+05  -1.31204415e+05]]

   [[  1.24397101e+03  -1.36274063e+01  -1.09969754e+01  -8.72771452e+03
      -8.72540495e+03   2.18926887e+03]
    [ -1.36274063e+01  -3.68949488e+04  -7.54660758e+04  -1.48501843e+05
      -3.56051190e+05  -1.21523230e+05]
    [ -1.09969754e+01  -7.

In [9]:
from pycalphad.core.autograd_utils import build_functions

# Second implementation from pycalphad.core.autograd_utils.build_functions,
# presumably of the same derivatives, timed against the Numba gufuncs above.
%time obj, ag_grad_func, ag_hess_func = build_functions(mod.ast, tuple(mod.variables))

# NOTE(review): the recorded output of this call does not match the Numba
# grad_func output of In [7] on the same x1..x6.  For example, the first
# component at the first point is ~1.43e3 there and ~-4.15e1 here.  The 2nd and
# 3rd components differ by the same amount at each point shown, which may hint at
# a term the two paths handle differently.  TODO: determine which result is
# correct (e.g. np.testing.assert_allclose against a finite-difference check)
# before relying on either.
%time ag_grad_func(x1, x2, x3, x4, x5, x6)

CPU times: user 284 ms, sys: 7 ms, total: 291 ms
Wall time: 269 ms
CPU times: user 977 ms, sys: 335 ms, total: 1.31 s
Wall time: 1.31 s


array([[[[ -4.14694549e+01,  -3.61785980e+04,  -6.32884649e+04,
           -9.55165965e+03,  -2.00669003e+04,  -6.68322263e+04],
         [ -3.26658020e+01,  -4.67477798e+04,  -4.45268849e+04,
           -8.31665009e+03,  -1.69054706e+04,  -6.19867199e+04],
         [ -1.74967683e+01,  -1.15145430e+04,  -1.41580047e+04,
           -2.82906340e+03,  -5.43112025e+03,  -2.94150612e+04],
         ..., 
         [ -3.63379166e+01,  -3.08448636e+04,  -5.43542592e+04,
           -7.05203000e+03,  -1.63655766e+04,  -5.84454298e+04],
         [ -1.28675825e+01,  -2.58563846e+03,  -5.14517446e+03,
           -1.49273430e+03,  -2.56764327e+03,  -1.68927837e+04],
         [ -1.29359535e+01,  -4.69783075e+03,  -6.57130781e+03,
           -6.89665189e+02,  -2.03709073e+03,  -2.49809141e+04]],

        [[ -3.44059708e+01,  -4.58522834e+04,  -1.79709803e+04,
           -2.43457044e+04,  -1.56161621e+04,  -3.57461558e+04],
         [ -3.65948937e+01,  -2.57143586e+04,  -4.93606774e+04,
           -1.39