In [1]:
from sympy.matrices import Matrix
from sympy import MatrixSymbol
from sympy import Symbol
import sympy as sp
from sympy.printing.fortran import fcode
import sympy.printing as printing

In [2]:
def apply_cse_and_return_fcode(expr,assign_to,source_format='free',standard=95, skip_cse=False):
    """Apply common subexpression simplification and return Fortran code
    
    Input:
        expr - a list of expressions we would like to evaluate
        assign_to - the symbols we would like to assign them to (either a tuple of strings or sp.symbols)
    
    Output:
        String containing subroutine or function body
    """
    if skip_cse:
        rexpr = sp.simplify(expr)
        ret = ''
    else:
        rvar, rexpr = sp.cse(sp.simplify(expr),
            order='none',
            list=False)

        ret = '  real(dp) :: ' + ','.join(str(var) for (var, _) in rvar) + '\n\n'

        for var, var_expr in rvar:
            ret += '  ' + fcode(var_expr,assign_to=var,source_format=source_format,standard=standard) + '\n'
        
        ret += '\n'

    for var, var_expr in zip(assign_to,rexpr):
        ret += '  ' + fcode(var_expr,assign_to=var,source_format=source_format,standard=standard) + '\n'

    return ret

In [3]:
def binary_overload(fun, is_hyper_dual=True, skip_cse=False, v_type='dual'):
    """Generate Fortran code for hyper-dual-number overloads of binary functions

    Parameters
    ------------
        fun: Function
            The binary function to overload
        is_hyper_dual: Logical
            If True, the function is overloaded for hyper-dual numbers instead of dual numbers
        skip_cse: Logical
            Skip common subexpression elimination if True
    Returns
    ------------
        code: String
            The generated Fortran code
        expr: Tuple
            The sympy expressions that were generated

    """
    u = Symbol('u_x')
    v = Symbol('v_x')

    f = fun(u, v)

    # n = Symbol('n')

    du = Symbol('u_dx')
    if v_type == 'dual':
        dv = Symbol('v_dx')
    else:
        dv = 0

    if is_hyper_dual:
        du2 = Symbol('u_dx2')
        if v_type == 'dual':
            dv2 = Symbol('v_dx2')
        else:
            dv2 = 0

        ddu = Symbol('u_ddx')
        if v_type == 'dual':
            ddv = Symbol('v_ddx')
        else:
            ddv = 0

    x = Matrix([u, v])
    dx = Matrix([du, dv])
    if is_hyper_dual:
        dx2 = Matrix([du2, dv2])
        ddx = Matrix([ddu, ddv])

    df = (sp.diff(f, x).T*dx)[0]
    if is_hyper_dual:
        df2 = (sp.diff(f, x).T*dx2)[0]
        ddf = (dx.T*sp.hessian(f, x)*dx2 + sp.diff(f, x).T*ddx)[0]

    if is_hyper_dual:
        expr = (f, df, df2, ddf)
        assign_to = ('res%x', 'res%dx', 'res%dx2', 'res%ddx')
    else:
        expr = (f, df)
        assign_to = ('res%x', 'res%dx')

    code = apply_cse_and_return_fcode(expr, assign_to, source_format='free',standard=95, skip_cse=skip_cse)

    code = code.replace('u_x', f'u%x')
    code = code.replace('v_x', f'v%x')
    code = code.replace('u_dx', f'u%dx')
    code = code.replace('v_dx', f'v%dx')
    if is_hyper_dual:
        code = code.replace('u_dx2', f'u%dx2')
        code = code.replace('v_dx2', f'v%dx2')
        code = code.replace('u_ddx', f'u%ddx')
        code = code.replace('v_ddx', f'v%ddx')

    return code, expr

    # print(ret)

# Div (/) operator

### Dual

In [6]:

def div(u, v):
    return u/v

ret, exprs = binary_overload(div, is_hyper_dual=True, skip_cse=True)

for expr in exprs:
    display(expr)

print(ret)


u_x/v_x

u_dx/v_x - u_x*v_dx/v_x**2

u_dx2/v_x - u_x*v_dx2/v_x**2

u_ddx/v_x - u_dx2*v_dx/v_x**2 - u_x*v_ddx/v_x**2 + v_dx2*(-u_dx/v_x**2 + 2*u_x*v_dx/v_x**3)

  res%x = u%x/v%x
  res%dx = u%dx/v%x - u%x*v%dx/v%x**2
  res%dx2 = u%dx2/v%x - u%x*v%dx2/v%x**2
  res%ddx = u%ddx/v%x - u%dx2*v%dx/v%x**2 - u%x*v%ddx/v%x**2 + v%dx2*( &
      -u%dx/v%x**2 + 2*u%x*v%dx/v%x**3)



In [7]:
ret, exprs = binary_overload(div, is_hyper_dual=True, skip_cse=True, v_type='integer')

for expr in exprs:
    display(expr)

print(ret)

u_x/v_x

u_dx/v_x

u_dx2/v_x

u_ddx/v_x

  res%x = u%x/v%x
  res%dx = u%dx/v%x
  res%dx2 = u%dx2/v%x
  res%ddx = u%ddx/v%x



In [8]:
n = Symbol('n')
X = MatrixSymbol('X', n, 1)

NameError: name 'MatrixSymbol' is not defined