In [1]:
from fortran_gen import unary_overload, binary_overload
import sympy as sp
from jinja2 import Template

In [None]:
# Jinja2 template for a fypp macro for a binary function for dual numbers.
# Renders one Fortran function `<interface_name>_<usn>_<vsn>(u, v) result(res)`.
# Jinja2 placeholders ({{...}}): interface_name, usn, vsn, u_type, v_type,
# res_type, x_body and dx_body.
# `${elemental_purity}$` and the `#:if` / `#:endif` lines are not Jinja2 syntax;
# presumably they pass through rendering untouched and are resolved later by
# fypp -- TODO confirm. The `#:if test_coverage == True` branch adds a statement
# that increments a per-function `<interface_name>_<usn>_<vsn>_counter`.
binary_macro_d = '''
${elemental_purity}$ function {{interface_name}}_{{usn}}_{{vsn}}(u, v) result(res)
    {{u_type}}, intent(in) :: u
    {{v_type}}, intent(in) :: v
    {{res_type}} :: res
    {{x_body}}
    {{dx_body}}
  #:if test_coverage == True
    {{interface_name}}_{{usn}}_{{vsn}}_counter = {{interface_name}}_{{usn}}_{{vsn}}_counter + 1
  #:endif
end function
'''
# Jinja2 template for a fypp macro for a binary function for hyper-dual numbers.
# Same shape as the dual-number template, with two additions: a local
# `integer :: j` declaration and an extra {{ddx_body}} placeholder after
# {{dx_body}}.
binary_macro_hd = '''
${elemental_purity}$ function {{interface_name}}_{{usn}}_{{vsn}}(u, v) result(res)
    {{u_type}}, intent(in) :: u
    {{v_type}}, intent(in) :: v
    {{res_type}} :: res
    integer :: j
    {{x_body}}
    {{dx_body}}
    {{ddx_body}}
  #:if test_coverage == True
    {{interface_name}}_{{usn}}_{{vsn}}_counter = {{interface_name}}_{{usn}}_{{vsn}}_counter + 1
  #:endif
end function
'''
# Compile both template strings once so they can be rendered repeatedly.
# NOTE(review): neither compiled template is referenced by the other cells
# visible in this notebook -- confirm whether they are still needed.
binary_macro_d_tm = Template(binary_macro_d)
binary_macro_hd_tm = Template(binary_macro_hd)



In [2]:

def unary_add(u):
    """Unary plus: hand the operand back unchanged."""
    operand = u
    return operand
def unary_minus(u):
    """Unary minus: return the negation of ``u``."""
    negated = -u
    return negated
def add(u, v):
    """Binary addition: return ``u + v``."""
    total = u + v
    return total
def minus(u, v):
    """Binary subtraction: return ``u - v``."""
    difference = u - v
    return difference
def mult(u, v):
    """Binary multiplication: return ``u * v``."""
    product = u * v
    return product
def div(u, v):
    """Binary true division: return ``u / v``."""
    quotient = u / v
    return quotient
def pow(u, v):
    """Binary exponentiation: return ``u ** v``.

    Defining this at module level shadows the two-/three-argument builtin
    ``pow`` for the rest of the notebook.
    """
    power = u ** v
    return power

# Number flavours to generate, paired element-wise: the is_hyper_dual flag
# passed to the *_overload helpers and the short suffix used in macro names.
is_hyper_dual_v = [False, True]
dual_type_short_name_v = ['d', 'hd']
# Forwarded unchanged as the skip_cse argument of every *_overload call.
skip_cse = True

funs = add, minus, mult, div

# 'add' and 'minus' also get a unary overload, emitted before the binary ones.
unary_counterparts = {'add': unary_add, 'minus': unary_minus}
# Operand-type pairs (u, v), in the order the binary overloads are emitted.
binary_num_types = (
    ('dual', 'dual'),
    ('dual', 'real'),
    ('real', 'dual'),
    ('dual', 'integer'),
    ('integer', 'dual'),
)

# One output file per operator; each file holds the '<name>_d' macro followed
# by the '<name>_hd' macro.
for fun in funs:
    code = []
    for is_hyper_dual, sn in zip(is_hyper_dual_v, dual_type_short_name_v):
        code.append(f'#:def {fun.__name__}_{sn}(dual_type, real_kind, is_pure=True, test_coverage=False)')
        # NOTE(review): in fypp a line starting with '#!' is a comment, so this
        # emitted line looks inert; '#:set' would be the directive form --
        # confirm which one is intended before changing the string.
        code.append('    #!set elemental_purity = dnad.elemental_purity(is_pure, test_coverage)')
        unary_fun = unary_counterparts.get(fun.__name__)
        if unary_fun is not None:
            code, exprs = unary_overload(unary_fun, is_hyper_dual=is_hyper_dual, skip_cse=skip_cse, u_type='dual', code=code)
        for num_types in binary_num_types:
            code, exprs = binary_overload(fun, is_hyper_dual=is_hyper_dual, skip_cse=skip_cse, num_types=num_types, code=code)
        code.append('#:enddef')

    # Context manager so the file is closed even if a write raises.
    with open(f'macros/implementations/{fun.__name__}_gen.fypp', 'w') as fid:
        fid.writelines(line + '\n' for line in code)

# Power operator: only a dual base is overloaded, with integer, real and dual
# exponents (emitted in that order) for each number flavour.
fun = pow
code = []
for is_hyper_dual, sn in zip(is_hyper_dual_v, dual_type_short_name_v):
    code.append(f'#:def {fun.__name__}_{sn}(dual_type, real_kind, is_pure=True, test_coverage=False)')
    # NOTE(review): in fypp a line starting with '#!' is a comment, so this
    # emitted line looks inert; '#:set' would be the directive form --
    # confirm which one is intended before changing the string.
    code.append('    #!set elemental_purity = dnad.elemental_purity(is_pure, test_coverage)')
    for num_types in (('dual', 'integer'), ('dual', 'real'), ('dual', 'dual')):
        code, exprs = binary_overload(fun, is_hyper_dual=is_hyper_dual, skip_cse=skip_cse, num_types=num_types, code=code)
    code.append('#:enddef')

# Context manager so the file is closed even if a write raises.
with open(f'macros/implementations/{fun.__name__}_gen.fypp', 'w') as fid:
    fid.writelines(line + '\n' for line in code)


# Square root: one unary overload per number flavour, wrapped in a
# '<name>_d' / '<name>_hd' fypp macro.
fun = sp.sqrt
code = []
for is_hyper_dual, sn in zip(is_hyper_dual_v, dual_type_short_name_v):
    code.append(f'#:def {fun.__name__}_{sn}(dual_type, real_kind)')
    code, exprs = unary_overload(fun, is_hyper_dual=is_hyper_dual, skip_cse=skip_cse, u_type='dual', code=code)
    code.append('#:enddef')

# The generated text contains hard-coded 'd0' rational literals, e.g.
#     res%dx = (1.0d0/2.0d0)*u%dx/sqrt(u%x)
#     do j = 1, ${size_dx}$
#         res%ddx(:, j) = (1.0d0/2.0d0)*u%ddx(:, j)/sqrt(u%x) - 1.0d0/4.0d0*u%dx*u%dx(j)/ &
#   u%x**(3.0d0/2.0d0)
#     end do
# Swap each one for a decimal literal carrying the ${real_kind}$ kind suffix.
# Pairs are (text to find, replacement) and are applied in this order.
literal_fixes = (
    ('(1.0d0/2.0d0)', '0.5_${real_kind}$'),
    ('1.0d0/4.0d0', '0.25_${real_kind}$'),
    ('(3.0d0/2.0d0)', '1.5_${real_kind}$'),
)
for il, line in enumerate(code):
    for old, new in literal_fixes:
        line = line.replace(old, new)
    code[il] = line

# Context manager so the file is closed even if a write raises.
with open(f'macros/implementations/{fun.__name__}_gen.fypp', 'w') as fid:
    fid.writelines(line + '\n' for line in code)


# Logarithm: one unary overload per number flavour, wrapped in a
# '<name>_d' / '<name>_hd' fypp macro.
fun = sp.log
code = []
for is_hyper_dual, sn in zip(is_hyper_dual_v, dual_type_short_name_v):
    code.append(f'#:def {fun.__name__}_{sn}(dual_type, real_kind)')
    code, exprs = unary_overload(fun, is_hyper_dual=is_hyper_dual, skip_cse=skip_cse, u_type='dual', code=code)
    code.append('#:enddef')

# Context manager so the file is closed even if a write raises.
with open(f'macros/implementations/{fun.__name__}_gen.fypp', 'w') as fid:
    fid.writelines(line + '\n' for line in code)
