Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,7 @@ cdef class Basic(object):
if (len(f) != 1):
raise RuntimeError("Variable w.r.t should be given")
return self._diff(f.pop())
return diff(self, *args)
return _diff(self, *args)

def subs_dict(Basic self not None, *args):
warnings.warn("subs_dict() is deprecated. Use subs() instead", DeprecationWarning)
Expand Down Expand Up @@ -3687,7 +3687,7 @@ cdef class DenseMatrixBase(MatrixBase):
return R

def diff(self, *args):
return diff(self, *args)
return _diff(self, *args)

#TODO: implement this in C++
def subs(self, *args):
Expand Down Expand Up @@ -4063,15 +4063,23 @@ def module_cleanup():
import atexit
atexit.register(module_cleanup)


def diff(expr, *args):
cdef Basic ex = sympify(expr)
if isinstance(expr, MatrixBase):
# Don't sympify matrices so that mutable matrices
# return mutable matrices
return _diff(expr, *args)
return _diff(sympify(expr), *args)


def _diff(expr, *args):
cdef Basic prev
cdef Basic b
cdef size_t i
cdef size_t length = len(args)

if not length:
return ex
return expr

cdef size_t l = 0
cdef Basic cur_arg, next_arg
Expand All @@ -4083,20 +4091,20 @@ def diff(expr, *args):

if l + 1 == length:
# No next argument, differentiate with no integer argument
return ex._diff(cur_arg)
return expr._diff(cur_arg)

next_arg = sympify(args[l + 1])
# Check if the next arg was derivative order
if isinstance(next_arg, Integer):
i = int(next_arg)
for _ in range(i):
ex = ex._diff(cur_arg)
expr = expr._diff(cur_arg)
l += 2
if l == length:
return ex
return expr
cur_arg = sympify(args[l])
else:
ex = ex._diff(cur_arg)
expr = expr._diff(cur_arg)
l += 1
cur_arg = next_arg

Expand Down
8 changes: 8 additions & 0 deletions symengine/tests/test_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,14 @@ def test_cross():
DenseMatrix(1, 2, [1, 1]).cross(DenseMatrix(1, 2, [1, 1])))


def test_diff():
x = symbols("x")
M = DenseMatrix(1, 2, [x**2, x])
result = M.diff(x)
assert isinstance(result, DenseMatrix)
assert result == DenseMatrix(1, 2, [2*x, 1])


def test_immutablematrix():
A = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
assert A.shape == (3, 3)
Expand Down