Skip to content

Commit

Permalink
Create a C++ PySymbol if Symbol is subclassed.
Browse files Browse the repository at this point in the history
  • Loading branch information
isuruf committed Jun 2, 2017
1 parent 5de6845 commit 83bd7ae
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 100 deletions.
6 changes: 2 additions & 4 deletions symengine/__init__.py
Expand Up @@ -3,14 +3,12 @@
have_mpfr, have_mpc, have_flint, have_piranha, have_llvm,
Integer, Rational, Float, Number, RealNumber, RealDouble,
ComplexDouble, Max, Min, DenseMatrix, Matrix, ImmutableMatrix,
ImmutableDenseMatrix, MutableDenseMatrix, MutableMatrix,
sin, cos, tan, cot, csc, sec, asin, acos, atan, acot, acsc, asec,
sinh, cosh, tanh, coth, sech, csch, asinh, acosh, atanh, acoth,
asech, acsch, atan2, exp, log, gamma, sqrt,
ImmutableDenseMatrix, MutableDenseMatrix, Basic,
Lambdify, LambdifyCSE, Lambdify as lambdify, DictBasic, symarray,
series, diff, zeros, eye, diag,ones, zeros,
add, expand, has_symbol, UndefFunction)
from .utilities import var, symbols
from .functions import *

if have_mpfr:
from .lib.symengine_wrapper import RealMPFR
Expand Down
22 changes: 5 additions & 17 deletions symengine/functions.py
@@ -1,17 +1,5 @@
import symengine
from types import ModuleType
import sys

functions = ModuleType(__name__ + ".functions")
sys.modules[functions.__name__] = functions

functions.sqrt = sqrt
functions.exp = exp

for name in ("""sin cos tan cot csc sec
asin acos atan acot acsc asec
sinh cosh tanh coth sech csch
asinh acosh atanh acoth asech acsch
gamma log atan2""").split():
setattr(functions, name, getattr(symengine, name))

from .lib.symengine_wrapper import (sin, cos, tan, cot, csc, sec,
asin, acos, atan, acot, acsc, asec,
sinh, cosh, tanh, coth, sech, csch,
asinh, acosh, atanh, acoth, asech, acsch,
gamma, log, atan2, sqrt, exp)
41 changes: 15 additions & 26 deletions symengine/lib/symengine_wrapper.pyx
Expand Up @@ -37,7 +37,7 @@ cdef c2py(RCP[const symengine.Basic] o):
elif (symengine.is_a_Symbol(deref(o))):
if (symengine.is_a_PySymbol(deref(o))):
return <object>(deref(symengine.rcp_static_cast_PySymbol(o)).get_py_object())
r = Symbol.__new__(Symbol)
r = Basic.__new__(Symbol)
elif (symengine.is_a_Constant(deref(o))):
r = Constant.__new__(Constant)
elif (symengine.is_a_PyFunction(deref(o))):
Expand Down Expand Up @@ -666,7 +666,8 @@ cdef class Basic(object):
return d

def coeff(self, x, n=1):
cdef Symbol _x = _sympify(x)
cdef Basic _x = _sympify(x)
require(_x, Symbol)
cdef Basic _n = _sympify(n)
return c2py(symengine.coeff(deref(self.thisptr), deref(_x.thisptr), deref(_n.thisptr)))

Expand Down Expand Up @@ -696,11 +697,12 @@ def series(ex, x=None, x0=0, n=6, as_deg_coef_pair=False):
if not syms:
return _ex

cdef Symbol _x
cdef Basic _x
if x is None:
_x = list(syms)[0]
else:
_x = _sympify(x)
require(_x, Symbol)
if not _x in syms:
return _ex

Expand Down Expand Up @@ -731,31 +733,25 @@ def series(ex, x=None, x0=0, n=6, as_deg_coef_pair=False):
return add(*l)


cdef class Symbol(Basic):
class Symbol(Basic):

"""
Symbol is a class to store a symbolic variable with a given name.
Note: Subclassing `Symbol` will not work properly. Use `PySymbol`
which is a subclass of `Symbol` for subclassing.
"""
def __cinit__(self, name = None):
if name is None:
return
self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8"))

def __init__(self, name = None):
return
def __init__(Basic self, name, *args, **kwargs):
if type(self) == Symbol:
self.thisptr = symengine.make_rcp_Symbol(name.encode("utf-8"))
else:
self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), <PyObject*>self)

def _sympy_(self):
cdef RCP[const symengine.Symbol] X = symengine.rcp_static_cast_Symbol(self.thisptr)
import sympy
return sympy.Symbol(str(deref(X).get_name().decode("utf-8")))
return sympy.Symbol(str(self))

def _sage_(self):
cdef RCP[const symengine.Symbol] X = symengine.rcp_static_cast_Symbol(self.thisptr)
import sage.all as sage
return sage.SR.symbol(str(deref(X).get_name().decode("utf-8")))
return sage.SR.symbol(str(self))

@property
def name(self):
Expand All @@ -770,14 +766,6 @@ cdef class Symbol(Basic):
return True


cdef class PySymbol(Symbol):
def __init__(self, name, *args, **kwargs):
super(PySymbol, self).__init__(name)
if name is None:
return
self.thisptr = symengine.make_rcp_PySymbol(name.encode("utf-8"), <PyObject*>self)


def symarray(prefix, shape, **kwargs):
""" Creates an nd-array of symbols
Expand Down Expand Up @@ -3296,7 +3284,8 @@ def LambdifyCSE(args, exprs, real=True, cse=None, concatenate=None):

def has_symbol(obj, symbol=None):
cdef Basic b = _sympify(obj)
cdef Symbol s = _sympify(symbol)
cdef Basic s = _sympify(symbol)
require(s, Symbol)
if (not symbol):
return not b.free_symbols.empty()
else:
Expand Down
44 changes: 1 addition & 43 deletions symengine/sympy_compat.py
@@ -1,44 +1,2 @@
from .lib import symengine_wrapper as symengine
from .utilities import var, symbols
from .compatibility import with_metaclass
from .lib.symengine_wrapper import (sympify, sympify as S,
SympifyError, sqrt, I, E, pi, MutableDenseMatrix,
ImmutableDenseMatrix, DenseMatrix, Matrix, Derivative, exp,
nextprime, mod_inverse, primitive_root, Lambdify as lambdify,
symarray, diff, eye, diag, ones, zeros, expand, Subs,
FunctionSymbol as AppliedUndef, Max, Min, Integer, Rational,
Float, Number, Add, Mul, Pow, sin, cos, tan, cot, csc, sec,
asin, acos, atan, acot, acsc, asec, sinh, cosh, tanh, coth, sech, csch,
asinh, acosh, atanh, acoth, asech, acsch, gamma, log, atan2)
from types import ModuleType
import sys


class BasicMeta(type):
def __instancecheck__(self, instance):
return isinstance(instance, self._classes)


class Basic(with_metaclass(BasicMeta, object)):
_classes = (symengine.Basic,)
pass


class Symbol(symengine.PySymbol, Basic):
_classes = (symengine.Symbol,)
pass


functions = ModuleType(__name__ + ".functions")
sys.modules[functions.__name__] = functions

functions.sqrt = sqrt
functions.exp = exp

for name in ("""sin cos tan cot csc sec
asin acos atan acot acsc asec
sinh cosh tanh coth sech csch
asinh acosh atanh acoth asech acsch
gamma log atan2""").split():
setattr(functions, name, getattr(symengine, name))
from symengine import *

20 changes: 10 additions & 10 deletions symengine/tests/test_sympy_compat.py
@@ -1,6 +1,6 @@
from symengine.sympy_compat import (Integer, Rational, S, Basic, Add, Mul,
Pow, symbols, Symbol, log, sin, cos, sech, csch, zeros, atan2, Number, Float,
symengine, Min, Max)
Min, Max, RealDouble, have_mpfr)
from symengine.utilities import raises


Expand Down Expand Up @@ -34,28 +34,28 @@ def test_Float():
assert isinstance(A, Float)
assert isinstance(B, Float)
assert isinstance(C, Float)
assert isinstance(A, symengine.RealDouble)
assert isinstance(B, symengine.RealDouble)
assert isinstance(C, symengine.RealDouble)
assert isinstance(A, RealDouble)
assert isinstance(B, RealDouble)
assert isinstance(C, RealDouble)
raises(ValueError, lambda: Float("1.23", dps = 3, precision = 10))
raises(ValueError, lambda: Float(A, dps = 3, precision = 16))
if symengine.have_mpfr:
if have_mpfr:
A = Float("1.23", dps = 16)
B = Float("1.23", precision = 56)
assert A == B
assert isinstance(A, Float)
assert isinstance(B, Float)
assert isinstance(A, symengine.RealMPFR)
assert isinstance(B, symengine.RealMPFR)
assert isinstance(A, RealMPFR)
assert isinstance(B, RealMPFR)
A = Float(C, dps = 16)
assert A == B
assert isinstance(A, Float)
assert isinstance(A, symengine.RealMPFR)
assert isinstance(A, RealMPFR)
A = Float(A, precision = 53)
assert A == C
assert isinstance(A, Float)
assert isinstance(A, symengine.RealDouble)
if not symengine.have_mpfr:
assert isinstance(A, RealDouble)
if not have_mpfr:
raises(ValueError, lambda: Float("1.23", precision = 58))


Expand Down

0 comments on commit 83bd7ae

Please sign in to comment.