Skip to content

Commit

Permalink
New module .codegen.ffunctions (fortran functions)
Browse files Browse the repository at this point in the history
  • Loading branch information
bjodah committed Mar 21, 2017
1 parent 723f6c1 commit ceb9c4d
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 18 deletions.
8 changes: 8 additions & 0 deletions doc/src/modules/codegen.rst
Expand Up @@ -551,8 +551,16 @@ available with ``autowrap``.
There are other facilities available with Sympy to do efficient numeric
computation. See :ref:`this<numeric_computation>` page for a comparison among them.


Special (finite precision arithmetic) math functions
----------------------------------------------------

.. automodule:: sympy.codegen.cfunctions
:members:


Fortran specific functions
--------------------------

.. automodule:: sympy.codegen.ffunctions
:members:
71 changes: 71 additions & 0 deletions sympy/codegen/ffunctions.py
@@ -0,0 +1,71 @@
"""
Functions with corresponding implementations in Fortran.
The functions defined in this module allows the user to express functions such as ``dsign``
as a SymPy function for symbolic manipulation.
"""
from sympy.core.function import Function
from sympy.core.numbers import Float

class FFunction(Function):
_required_standard = 77

def _fcode(self, printer):
name = self.__class__.__name__
if printer._settings['standard'] < self._required_standard:
raise NotImplementedError("%s requires Fortran %d or newer" %
(name, self._required_standard))
return '{0}({1})'.format(name, ', '.join(map(printer._print, self.args)))

class F95Function(FFunction):
_required_standard = 95


class isign(FFunction):
""" Fortran sign intrinsic with for integer arguments. """
nargs = 2


class dsign(FFunction):
""" Fortran sign intrinsic with for double precision arguments. """
nargs = 2


class cmplx(FFunction):
""" Fortran complex conversion function. """
nargs = 2 # may be extended to (2, 3) at a later point


class kind(FFunction):
""" Fortran kind function. """
nargs = 1


class merge(F95Function):
""" Fortran merge function """
nargs = 3


class _literal(Float):
_token = None
_decimals = None

def _fcode(self, printer):
mantissa, sgnd_ex = ('%.{0}e'.format(self._decimals) % self).split('e')
mantissa = mantissa.strip('0').rstrip('.')
ex_sgn, ex_num = sgnd_ex[0], sgnd_ex[1:].lstrip('0')
ex_sgn = '' if ex_sgn == '+' else ex_sgn
return (mantissa or '0') + self._token + ex_sgn + (ex_num or '0')


class literal_sp(_literal):
""" Fortran single precision real literal """
_token = 'e'
_decimals = 9


class literal_dp(_literal):
""" Fortran double precision real literal """
_token = 'd'
_decimals = 17
29 changes: 29 additions & 0 deletions sympy/codegen/tests/test_ffunctions.py
@@ -0,0 +1,29 @@
from sympy import Symbol
from sympy.codegen.ffunctions import isign, dsign, cmplx, kind, literal_dp
from sympy.printing.fcode import fcode


def test_isign():
x = Symbol('x', integer=True)
assert isign(1, x) == isign(1, x)
assert fcode(isign(1, x), standard=95, source_format='free') == 'isign(1, x)'


def test_dsign():
x = Symbol('x')
assert dsign(1, x) == dsign(1, x)
assert fcode(dsign(literal_dp(1), x), standard=95, source_format='free') == 'dsign(1d0, x)'


def test_cmplx():
x = Symbol('x')
assert cmplx(1, x) == cmplx(1, x)


def test_kind():
x = Symbol('x')
assert kind(x) == kind(x)


def test_literal_dp():
assert fcode(literal_dp(0), source_format='free') == '0d0'
57 changes: 57 additions & 0 deletions sympy/core/tests/test_args.py
Expand Up @@ -3761,30 +3761,37 @@ def test_sympy__codegen__ast__Assignment():
from sympy.codegen.ast import Assignment
assert _test_args(Assignment(x, y))


def test_sympy__codegen__cfunctions__expm1():
from sympy.codegen.cfunctions import expm1
assert _test_args(expm1(x))


def test_sympy__codegen__cfunctions__log1p():
from sympy.codegen.cfunctions import log1p
assert _test_args(log1p(x))


def test_sympy__codegen__cfunctions__exp2():
from sympy.codegen.cfunctions import exp2
assert _test_args(exp2(x))


def test_sympy__codegen__cfunctions__log2():
from sympy.codegen.cfunctions import log2
assert _test_args(log2(x))


def test_sympy__codegen__cfunctions__fma():
from sympy.codegen.cfunctions import fma
assert _test_args(fma(x, y, z))


def test_sympy__codegen__cfunctions__log10():
from sympy.codegen.cfunctions import log10
assert _test_args(log10(x))


def test_sympy__codegen__cfunctions__Sqrt():
from sympy.codegen.cfunctions import Sqrt
assert _test_args(Sqrt(x))
Expand All @@ -3798,6 +3805,56 @@ def test_sympy__codegen__cfunctions__hypot():
assert _test_args(hypot(x, y))


def test_sympy__codegen__ffunctions__FFunction():
from sympy.codegen.ffunctions import FFunction
assert _test_args(FFunction('f'))


def test_sympy__codegen__ffunctions__F95Function():
from sympy.codegen.ffunctions import F95Function
assert _test_args(F95Function('f'))


def test_sympy__codegen__ffunctions__isign():
from sympy.codegen.ffunctions import isign
assert _test_args(isign(1, x))


def test_sympy__codegen__ffunctions__dsign():
from sympy.codegen.ffunctions import dsign
assert _test_args(dsign(1, x))


def test_sympy__codegen__ffunctions__cmplx():
from sympy.codegen.ffunctions import cmplx
assert _test_args(cmplx(x, y))


def test_sympy__codegen__ffunctions__kind():
from sympy.codegen.ffunctions import kind
assert _test_args(kind(x))


def test_sympy__codegen__ffunctions__merge():
from sympy.codegen.ffunctions import merge
assert _test_args(merge(1, 2, Eq(x, 0)))


def test_sympy__codegen__ffunctions___literal():
from sympy.codegen.ffunctions import _literal
assert _test_args(_literal(1))


def test_sympy__codegen__ffunctions__literal_sp():
from sympy.codegen.ffunctions import literal_sp
assert _test_args(literal_sp(1))


def test_sympy__codegen__ffunctions__literal_dp():
from sympy.codegen.ffunctions import literal_dp
assert _test_args(literal_dp(1))


def test_sympy__vector__coordsysrect__CoordSysCartesian():
from sympy.vector.coordsysrect import CoordSysCartesian
assert _test_args(CoordSysCartesian('C'))
Expand Down
34 changes: 24 additions & 10 deletions sympy/printing/fcode.py
Expand Up @@ -24,8 +24,10 @@
from sympy.core import S, Add, N
from sympy.core.compatibility import string_types, range
from sympy.core.function import Function
from sympy.core.relational import Eq
from sympy.sets import Range
from sympy.codegen.ast import Assignment
from sympy.codegen.ffunctions import isign, dsign, cmplx, merge, literal_dp
from sympy.printing.codeprinter import CodePrinter
from sympy.printing.precedence import precedence

Expand All @@ -43,11 +45,11 @@
"log": "log",
"exp": "exp",
"erf": "erf",
"Abs": "Abs",
"sign": "sign",
"Abs": "abs",
"conjugate": "conjg"
}


class FCodePrinter(CodePrinter):
"""A printer to convert sympy expressions to strings of Fortran code"""
printmethod = "_fcode"
Expand Down Expand Up @@ -106,14 +108,6 @@ def _get_statement(self, codestring):

def _get_comment(self, text):
return "! {0}".format(text)
#issue 12267
def _print_sign(self,func):
if func.args[0].is_integer:
return "merge(0, isign(1, {0}), {0} == 0)".format(self._print(func.args[0]))
elif func.args[0].is_complex:
return "merge(cmplx(0d0, 0d0), {0}/abs({0}), abs({0}) == 0d0)".format(self._print(func.args[0]))
else:
return "merge(0d0, dsign(1d0, {0}), {0} == 0d0)".format(self._print(func.args[0]))

def _declare_number_const(self, name, value):
return "parameter ({0} = {1})".format(name, value)
Expand All @@ -136,6 +130,18 @@ def _get_loop_opening_ending(self, indices):
close_lines.append("end do")
return open_lines, close_lines

def _print_sign(self, expr):
from sympy import Abs
arg, = expr.args
if arg.is_integer:
new_expr = merge(0, isign(1, arg), Eq(arg, 0))
elif arg.is_complex:
new_expr = merge(cmplx(literal_dp(0), literal_dp(0)), arg/Abs(arg), Eq(Abs(arg), literal_dp(0)))
else:
new_expr = merge(literal_dp(0), dsign(literal_dp(1), arg), Eq(arg, literal_dp(0)))
return self._print(new_expr)


def _print_Piecewise(self, expr):
if expr.args[-1].cond != True:
# We need the last conditional to be a True, otherwise the resulting
Expand Down Expand Up @@ -292,6 +298,14 @@ def _print_For(self, expr):
'end do').format(target=target, start=start, stop=stop,
step=step, body=body)

def _print_Equality(self, expr):
lhs, rhs = expr.args
return ' == '.join(map(self._print, (lhs, rhs)))

def _print_Unequality(self, expr):
lhs, rhs = expr.args
return ' /= '.join(map(self._print, (lhs, rhs)))

def _pad_leading_columns(self, lines):
result = []
for line in lines:
Expand Down
16 changes: 9 additions & 7 deletions sympy/printing/tests/test_fcode.py
Expand Up @@ -20,14 +20,16 @@ class nint(Function):
def _fcode(self, printer):
return "nint(%s)" % printer._print(self.args[0])
assert fcode(nint(x)) == " nint(x)"
#issue 12267
def test_fcode_sign():


def test_fcode_sign(): #issue 12267
x=symbols('x')
y=symbols('y', integer=True)
z=symbols('z', complex=True)
assert fcode(sign(x), source_format='free') == "merge(0d0, dsign(1d0, x), x == 0d0)"
assert fcode(sign(y), source_format='free') == "merge(0, isign(1, y), y == 0)"
assert fcode(sign(z), source_format='free') == "merge(cmplx(0d0, 0d0), z/abs(z), abs(z) == 0d0)"
assert fcode(sign(x), standard=95, source_format='free') == "merge(0d0, dsign(1d0, x), x == 0d0)"
assert fcode(sign(y), standard=95, source_format='free') == "merge(0, isign(1, y), y == 0)"
assert fcode(sign(z), standard=95, source_format='free') == "merge(cmplx(0d0, 0d0), z/abs(z), abs(z) == 0d0)"
raises(NotImplementedError, lambda: fcode(sign(x)))


def test_fcode_Pow():
Expand Down Expand Up @@ -351,8 +353,8 @@ def test_fcode_Xlogical():

def test_fcode_Relational():
x, y = symbols("x y")
assert fcode(Relational(x, y, "=="), source_format="free") == "Eq(x, y)"
assert fcode(Relational(x, y, "!="), source_format="free") == "Ne(x, y)"
assert fcode(Relational(x, y, "=="), source_format="free") == "x == y"
assert fcode(Relational(x, y, "!="), source_format="free") == "x /= y"
assert fcode(Relational(x, y, ">="), source_format="free") == "x >= y"
assert fcode(Relational(x, y, "<="), source_format="free") == "x <= y"
assert fcode(Relational(x, y, ">"), source_format="free") == "x > y"
Expand Down
2 changes: 1 addition & 1 deletion sympy/utilities/tests/test_codegen.py
Expand Up @@ -751,7 +751,7 @@ def test_intrinsic_math_codegen():
'REAL*8 function test_abs(x)\n'
'implicit none\n'
'REAL*8, intent(in) :: x\n'
'test_abs = Abs(x)\n'
'test_abs = abs(x)\n'
'end function\n'
'REAL*8 function test_acos(x)\n'
'implicit none\n'
Expand Down

0 comments on commit ceb9c4d

Please sign in to comment.