About implementing special functions

Leonid Kovalev edited this page Dec 28, 2017 · 19 revisions


Intro

On this page I'd like to describe the process of implementing new special functions, listing all the points one usually has to take care of.

Warning: This information is still work in progress. Some arguments may be incomplete or wrong.

Getting started

The class Function is the base class for all functions. It is defined in the module sympy.core.function. First we create a new file named after the new function. Then we import the base class with

from sympy.core.function import Function

Now we can subclass it. The example is based on the Gaussian error function erf(z). We start a new Python class with:

class erf(Function):
    """
    The Gauss error function.
    """

I added a small first docstring. (See below for what a good docstring should look like.)

TODO:

  • How to export the new function (add it to __init__.py)
  • Better to add it to the file where similar functions already live
  • Common baseclass for families of functions (trigonometric, bessel, ...)

Basic properties

The new function erf(z) takes only one (complex) argument. We tell SymPy this with the following class attribute:

    nargs = 1

This assignment has to be made at the top level of the class body, at the same indentation as the def statements of its methods.

At this point we can go into a Python shell and call the new function.

>>> from sympy import Symbol
>>> z = Symbol("z")
>>> erf(z)
erf(z)

However there is not much more we can do with it yet.
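A self-contained version of the class so far can be tried in a script. The sketch below defines a local erf class (SymPy's own erf is deliberately not imported, so nothing is shadowed). With no eval method the call stays unevaluated, the argument ends up in .args, and nargs = 1 makes SymPy reject a call with two arguments.

```python
from sympy import Function, Symbol


class erf(Function):
    """
    The Gauss error function (local stand-in; sympy.erf is not imported here).
    """
    nargs = 1  # exactly one argument


z = Symbol("z")
e = erf(z)  # no eval() yet, so the call stays unevaluated
print(e)
print(e.args)  # always a tuple, here with one element

try:
    erf(z, z)  # rejected because nargs = 1
except TypeError as exc:
    print("TypeError:", exc)
```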

The arguments (in this case z) are stored in the attribute self.args, which is always a tuple, even if it has only one element.

We can always retrieve the argument z with self.args[0]. If the function took two arguments, f(x, y), we would use self.args[0] and self.args[1] to get back x and y.

TODO:

  • Are the args sorted? (Reminds me of an atan2 fix to finish.) // Anurag: In my opinion the arguments are not sorted; args[0] stores the first argument passed, and so on.
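The question in the TODO can be checked directly. The sketch below uses a hypothetical two-argument class pair (not part of SymPy) and passes y before x; for a plain Function subclass without its own canonicalization in eval, .args keeps the call order.

```python
from sympy import Function, symbols


class pair(Function):
    """Hypothetical two-argument function, used only to inspect .args."""
    nargs = 2


x, y = symbols("x y")
p = pair(y, x)  # pass y first on purpose
print(p.args)  # kept in call order, not sorted
first, second = p.args  # same as p.args[0] and p.args[1]
```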

Function evaluation

To let our function evaluate for given arguments (numbers or symbolic expressions), we have to add the following method to the class:

    @classmethod
    def eval(cls, arg):
        ...

Excerpt from the docstring of eval():

The eval() method is called when the class cls is about to be instantiated and it should return either some simplified instance (possible of some other class) [...]

Now we can put all our knowledge about erf(z) into this method. For example, we know that erf(0) is 0, hence we write:

        # Value at zero
        if arg is S.Zero:
            return S.Zero

The code in the eval method will usually take the form of a big nested if tree.
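Putting the pieces together, a minimal runnable sketch of the class with this first branch could look as follows (again a local erf, not SymPy's). The arguments are sympified before eval() is called, so a plain 0 arrives as S.Zero and the identity test works; when no branch matches, eval() falls through and implicitly returns None, which leaves the call unevaluated.

```python
from sympy import Function, S, Symbol


class erf(Function):
    """Gauss error function (local sketch)."""
    nargs = 1

    @classmethod
    def eval(cls, arg):
        # Value at zero; arg has already been sympified, so 0 is S.Zero
        if arg is S.Zero:
            return S.Zero
        # No branch matched: returning None leaves erf(arg) unevaluated


z = Symbol("z")
print(erf(0))
print(erf(z))
```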

Special values

We can now add further well known function values for:

  • Values at single points in C
  • Values at (real) infinities oo and -oo

For erf we also know the values at both infinities, so we extend the if tree:

        elif arg is S.Infinity:
            return S.One
        elif arg is S.NegativeInfinity:
            return S.NegativeOne

This works well for finitely many special points z in the complex plane. For a function with a more complicated structure, the code becomes more involved: for example, cos(2*k*pi) with k an integer should always evaluate to 1.
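One way to handle such an infinite family of points is a structural test instead of a list of values. The sketch below uses a hypothetical class mycos (this is not how sympy.cos is written): it asks whether arg/(2*pi) is known to be an integer and otherwise leaves the call unevaluated.

```python
from sympy import Function, S, Symbol, pi


class mycos(Function):
    """Hypothetical cosine stand-in showing a structural check in eval()."""
    nargs = 1

    @classmethod
    def eval(cls, arg):
        # cos(2*pi*k) = 1 for every integer k: test whether arg/(2*pi)
        # is known to be an integer instead of listing points one by one
        k = arg/(2*pi)
        if k.is_integer:
            return S.One
        # is_integer is False or None (unknown): leave unevaluated


n = Symbol("n", integer=True)
x = Symbol("x")
print(mycos(2*pi*n))
print(mycos(4*pi))
print(mycos(x))
```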

Symmetries

  • Even/odd functions
  • Complex mirror symmetry
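For an odd function such as erf, one common pattern is to pull a leading minus sign out of the argument in eval() using Expr.could_extract_minus_sign(). The sketch below is an illustration with a local erf class, not a copy of SymPy's implementation; an even function would return cls(-arg) instead.

```python
from sympy import Function, S, Symbol


class erf(Function):
    """Gauss error function (local sketch): odd, erf(-z) = -erf(z)."""
    nargs = 1

    @classmethod
    def eval(cls, arg):
        if arg is S.Zero:
            return S.Zero
        # Odd symmetry: move a leading minus sign out of the argument.
        # -arg no longer reports an extractable sign, so this does not recurse again.
        if arg.could_extract_minus_sign():
            return -cls(-arg)


z = Symbol("z")
print(erf(-z))
print(erf(-z) + erf(z))
```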

Complex mirror symmetry

A function f has what is called "mirror symmetry" if, for every complex number z, conjugate(f(z)) = f(conjugate(z)).

We can implement this rule (read from left to right) with the code snippet:

    def _eval_conjugate(self):
        return self.func(self.args[0].conjugate())

It's better to implement the symmetry rule this way than the other way around (i.e., in .eval()) because it can then work its way down recursively. For example, f(2 + I).conjugate() would fully evaluate to f(2 - I).
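A runnable sketch of this rule on a local erf class follows. conjugate(expr) asks expr._eval_conjugate() first, so both the method form and the function form pick up the rule. Only add it to functions that really satisfy the identity; erf does.

```python
from sympy import Function, I, Symbol, conjugate


class erf(Function):
    """Gauss error function (local sketch) with mirror symmetry."""
    nargs = 1

    def _eval_conjugate(self):
        # conjugate(erf(z)) -> erf(conjugate(z)); conjugating the argument
        # lets the rule work its way down into it
        return self.func(self.args[0].conjugate())


z = Symbol("z")
print(erf(2 + I).conjugate())
print(conjugate(erf(z)))
```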

evalf to a floating point representation

If you named your function after its mpmath counterpart, evaluation should work automatically. This should be the case for the vast majority of functions implemented in SymPy. If your function is essentially equivalent to an mpmath function but has a different name (hopefully for a good reason!), you can add a translation in sympy/utilities/lambdify.py. This is the case, for example, for trigonometric integrals like Shi(x). Finally, if your function has no mpmath counterpart, you can implement an _eval_evalf(self, prec) method, where prec is the binary precision (that is, the requested number of significant binary digits). This is done, for example, for the RootOf class in sympy/polys/rootoftools.py.
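The sketch below shows the first and last of these cases. A local class named erf has no evalf code of its own, yet evalf() finds mpmath.erf by name. The hypothetical toy function halfsq(x) = x**2/2 has no mpmath counterpart, so it supplies _eval_evalf(self, prec); converting prec (binary) to decimal digits with mpmath.libmp.prec_to_dps is a choice made for this sketch only. Symbolic arguments are left alone by returning None.

```python
import mpmath
from mpmath.libmp import prec_to_dps
from sympy import Function, S, Symbol


class erf(Function):
    """Local erf sketch: no _eval_evalf of its own."""
    nargs = 1


class halfsq(Function):
    """Hypothetical toy function halfsq(x) = x**2/2 with no mpmath counterpart."""
    nargs = 1

    def _eval_evalf(self, prec):
        # prec is the binary precision requested by evalf()
        x = self.args[0]
        if not x.is_number:
            return None  # leave symbolic arguments unevaluated
        return (x**2/2).evalf(prec_to_dps(prec))


# The class name "erf" matches mpmath.erf, so evalf() finds it by name
print(erf(S(1)).evalf(20))
print(mpmath.erf(1))
print(halfsq(S(3)).evalf())
print(halfsq(Symbol("t")).evalf())  # stays halfsq(t)
```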

Differentiation

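The fdiff method returns the derivative of the function with respect to one of its arguments. For example, the derivative of sine is cosine:

from sympy import cos
from sympy.core.function import ArgumentIndexError

def fdiff(self, argindex=1):  # defined inside the class body of sin
    if argindex == 1:
        return cos(self.args[0])  # d/dx sin(x) = cos(x)
    else:
        raise ArgumentIndexError(self, argindex)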
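Applied to the running example, a local erf class only needs to supply the outer derivative 2*exp(-z**2)/sqrt(pi); diff() multiplies by the derivative of the argument itself, so the chain rule comes for free, as the second call in this sketch shows.

```python
from sympy import Function, Symbol, diff, exp, pi, sqrt
from sympy.core.function import ArgumentIndexError


class erf(Function):
    """Gauss error function (local sketch) with its first derivative."""
    nargs = 1

    def fdiff(self, argindex=1):
        # d/dz erf(z) = 2*exp(-z**2)/sqrt(pi)
        if argindex == 1:
            return 2*exp(-self.args[0]**2)/sqrt(pi)
        raise ArgumentIndexError(self, argindex)


z = Symbol("z")
print(diff(erf(z), z))
print(diff(erf(z**2), z))  # chain rule applied by SymPy
```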

Multiple arguments

The parameter argindex is the (1-based) index of the argument with respect to which the derivative is taken. So for a function of two variables, fdiff would be implemented as

def fdiff(self, argindex=1):
    x, y = self.args
    if argindex == 1:
        return ...  # derivative with respect to x (first argument)
    elif argindex == 2:
        return ...  # derivative with respect to y (second argument)
    else:
        raise ArgumentIndexError(self, argindex)  # wrong index
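A concrete instance of this template, using a hypothetical unevaluated function mul2(x, y) that stands for x*y, is sketched below. When both argument slots depend on the differentiation variable, SymPy adds up the contributions of each slot.

```python
from sympy import Function, diff, symbols
from sympy.core.function import ArgumentIndexError


class mul2(Function):
    """Hypothetical two-argument function standing for x*y, kept unevaluated."""
    nargs = 2

    def fdiff(self, argindex=1):
        x, y = self.args
        if argindex == 1:
            return y  # d/dx (x*y)
        elif argindex == 2:
            return x  # d/dy (x*y)
        else:
            raise ArgumentIndexError(self, argindex)  # wrong index


x, y, t = symbols("x y t")
print(diff(mul2(x, y), x))
print(diff(mul2(x, y), y))
print(diff(mul2(t**2, t), t))  # both slots depend on t
```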

Series expansions

  • Which methods to implement?

Taylor series expansions

If we know a closed-form expression for the n-th term of the Taylor series expansion, we can implement the method taylor_term(n, x, *previous_terms). The example shows the implementation from the Fresnel integral fresnelc:

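    # needs: from sympy import S, factorial, pi, sympify
    @staticmethod
    def taylor_term(n, x, *previous_terms):
        if n < 0:
            return S.Zero
        else:
            x = sympify(x)
            # n-th term of C(x): (-1)**n*(pi/2)**(2*n)*x**(4*n + 1)/((4*n + 1)*(2*n)!)
            return x*(-x**4)**n*(S(2)**(-2*n)*pi**(2*n))/((4*n + 1)*factorial(2*n))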

This is then helpful for expressing Taylor expansions like:

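from sympy import Sum, Symbol, fresnelc, oo

z = Symbol("z")
n = Symbol("n", integer=True, nonnegative=True)  # nonnegative, so that n < 0 is decidable
Sum(fresnelc(z).taylor_term(n, z), (n, 0, oo))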
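The same idea for the running example: erf(x) = 2/sqrt(pi) * sum over n of (-1)**n x**(2n+1)/(n!(2n+1)). The sketch below gives a local erf class a closed-form taylor_term, checks the first four terms against the series of SymPy's built-in erf (imported under another name purely as a reference), and builds the symbolic Sum as above.

```python
from sympy import Function, S, Sum, Symbol, expand, factorial, oo, pi, sqrt, sympify
from sympy import erf as sympy_erf  # built-in erf, used only as a reference


class erf(Function):
    """Gauss error function (local sketch) with a closed-form Taylor term."""
    nargs = 1

    @staticmethod
    def taylor_term(n, x, *previous_terms):
        # n-th term: 2/sqrt(pi) * (-1)**n * x**(2n+1) / (n! * (2n+1))
        if n < 0:
            return S.Zero
        x = sympify(x)
        return 2*S.NegativeOne**n*x**(2*n + 1)/(sqrt(pi)*factorial(n)*(2*n + 1))


z = Symbol("z")
partial = sum(erf.taylor_term(k, z) for k in range(4))  # terms up to z**7
reference = sympy_erf(z).series(z, 0, 8).removeO()
print(partial)

n = Symbol("n", integer=True, nonnegative=True)
series_sum = Sum(erf.taylor_term(n, z), (n, 0, oo))
print(series_sum)
```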

Integration

Limit calculation

  • Simple cases

  • Extensions to allow a function to be tractable by the Gruntz algorithm

Printing

  • The str printer should print the function in a way that recreates it in Python. The pretty printers should print it in mathematical notation. The most important ones are the Unicode and ASCII pretty printers and the LaTeX printer; support in the other printers is nice too. See http://docs.sympy.org/dev/modules/printing.html. Note that printing for built-in SymPy functions belongs in the printer classes (sympy/printing), not on the objects themselves. TODO: Add some docs on how to update the printers, especially the Unicode pretty printer.

  • Code generators. If there is a nice way to implement the code generation for a given language, it can be added to the codegen printers.
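Printers dispatch on the class name: for an expression of class C they look for a method called _print_C. The sketch below uses a hypothetical function myerf and a LatexPrinter subclass to stay self-contained; for a function added to SymPy itself, the method would go into the printer class in sympy/printing instead. The default str output already recreates the object in Python.

```python
from sympy import Function, Symbol
from sympy.printing.latex import LatexPrinter


class myerf(Function):
    """Hypothetical function used only to illustrate printing."""
    nargs = 1


class MyLatexPrinter(LatexPrinter):
    # The printer looks up a method named "_print_" + class name
    def _print_myerf(self, expr):
        return r"\operatorname{myerf}\left(%s\right)" % self._print(expr.args[0])


z = Symbol("z")
print(str(myerf(z)))  # valid Python that recreates the object
print(MyLatexPrinter().doprint(myerf(z)))
```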

Tests

  • Everything above should be tested (obviously). Follow the examples of tests for other functions. Random numerical tests can be useful for testing mathematical equivalences (see existing tests for examples).
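A sketch of tests in the plain-assert style run by pytest, for a local erf class: one test checks the special values, the other is a random numerical test comparing the symbolic derivative from fdiff with a central finite difference (numerical values come from evalf, which uses mpmath.erf). The seed, step size and tolerance are choices made for this sketch.

```python
import random

from sympy import Function, S, Symbol, diff, exp, pi, sqrt
from sympy.core.function import ArgumentIndexError


class erf(Function):
    """Gauss error function (local sketch) with eval and fdiff."""
    nargs = 1

    @classmethod
    def eval(cls, arg):
        if arg is S.Zero:
            return S.Zero
        elif arg is S.Infinity:
            return S.One
        elif arg is S.NegativeInfinity:
            return S.NegativeOne

    def fdiff(self, argindex=1):
        if argindex == 1:
            return 2*exp(-self.args[0]**2)/sqrt(pi)
        raise ArgumentIndexError(self, argindex)


def test_erf_special_values():
    assert erf(0) == 0
    assert erf(S.Infinity) == 1
    assert erf(S.NegativeInfinity) == -1


def test_erf_fdiff_numerically():
    # Compare the symbolic derivative with a central finite difference
    # at a few random real points
    x = Symbol("x")
    dexpr = diff(erf(x), x)
    rng = random.Random(0)
    h = 1e-5
    for _ in range(5):
        x0 = rng.uniform(-2, 2)
        numeric = (float(erf(x0 + h).evalf()) - float(erf(x0 - h).evalf()))/(2*h)
        symbolic = float(dexpr.subs(x, x0))
        assert abs(numeric - symbolic) < 1e-6
```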

TODO

  • f(x) is an instance of AppliedUndef (for an undefined function f = Function('f'))
  • f is an instance of UndefinedFunction
  • sin(x) is an instance of Function
  • sin is an instance of FunctionClass
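These four relations can be checked directly; all names used here are real SymPy classes from sympy.core.function.

```python
from sympy import Function, Symbol, sin
from sympy.core.function import AppliedUndef, FunctionClass, UndefinedFunction

x = Symbol("x")
f = Function("f")  # undefined function

print(isinstance(f(x), AppliedUndef))
print(isinstance(f, UndefinedFunction))
print(isinstance(sin(x), Function))
print(isinstance(sin, FunctionClass))
```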