Differentiating a function with a matrix parameter fails with ShapeError #24229

Open
pbsds opened this issue Nov 7, 2022 · 4 comments

@pbsds

pbsds commented Nov 7, 2022

Reproduction:
(Screenshot of the reproduction; the same code is posted inline in a later comment.)

JupyterLab traceback:
---------------------------------------------------------------------------
ShapeError                                Traceback (most recent call last)
Input In [6], in <cell line: 4>()
      2 display( f )
      3 display( Derivative(fx, x0) )
----> 4 display( Derivative(fx, x0).doit() )

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/core/function.py:1612, in Derivative.doit(self, **hints)
   1610     expr = expr.doit(**hints)
   1611 hints['evaluate'] = True
-> 1612 rv = self.func(expr, *self.variable_count, **hints)
   1613 if rv!= self and rv.has(Derivative):
   1614     rv =  rv.doit(**hints)

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/core/function.py:1442, in Derivative.__new__(cls, expr, *variables, **kwargs)
   1435     if not old_v.is_scalar and not hasattr(
   1436             old_v, '_eval_derivative'):
   1437         # special hack providing evaluation for classes
   1438         # that have defined is_scalar=True but have no
   1439         # _eval_derivative defined
   1440         expr *= old_v.diff(old_v)
-> 1442 obj = cls._dispatch_eval_derivative_n_times(expr, v, count)
   1443 if obj is not None and obj.is_zero:
   1444     return obj

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/core/function.py:1903, in Derivative._dispatch_eval_derivative_n_times(cls, expr, v, count)
   1897 @classmethod
   1898 def _dispatch_eval_derivative_n_times(cls, expr, v, count):
   1899     # Evaluate the derivative `n` times.  If
   1900     # `_eval_derivative_n_times` is not overridden by the current
   1901     # object, the default in `Basic` will call a loop over
   1902     # `_eval_derivative`:
-> 1903     return expr._eval_derivative_n_times(v, count)

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/core/basic.py:1785, in Basic._eval_derivative_n_times(self, s, n)
   1783 obj = self
   1784 for i in range(n):
-> 1785     obj2 = obj._eval_derivative(s)
   1786     if obj == obj2 or obj2 is None:
   1787         break

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/core/function.py:610, in Function._eval_derivative(self, s)
    608     except ArgumentIndexError:
    609         df = Function.fdiff(self, i)
--> 610     l.append(df * da)
    611 return Add(*l)

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/core/decorators.py:236, in _SympifyWrapper.make_wrapped.<locals>._func(self, other)
    234 if not isinstance(other, expectedcls):
    235     return retval
--> 236 return func(self, other)

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/core/decorators.py:105, in call_highest_priority.<locals>.priority_decorator.<locals>.binary_op_wrapper(self, other)
    103         f = getattr(other, method_name, None)
    104         if f is not None:
--> 105             return f(self)
    106 return func(self, other)

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/core/decorators.py:106, in call_highest_priority.<locals>.priority_decorator.<locals>.binary_op_wrapper(self, other)
    104         if f is not None:
    105             return f(self)
--> 106 return func(self, other)

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/matrices/common.py:2958, in MatrixArithmetic.__rmul__(self, other)
   2956 @call_highest_priority('__mul__')
   2957 def __rmul__(self, other):
-> 2958     return self.rmultiply(other)

File ~/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages/sympy/matrices/common.py:2979, in MatrixArithmetic.rmultiply(self, other, dotprodsimp)
   2975 if (hasattr(other, 'shape') and len(other.shape) == 2 and
   2976     (getattr(other, 'is_Matrix', True) or
   2977      getattr(other, 'is_MatrixLike', True))):
   2978     if self.shape[0] != other.shape[1]:
-> 2979         raise ShapeError("Matrix size mismatch.")
   2981 # honest SymPy matrices defer to their class's routine
   2982 if getattr(other, 'is_Matrix', False):

ShapeError: Matrix size mismatch.
`rich` traceback:
╭──────────────────────────── Traceback (most recent call last) ────────────────────────────╮
│ /run/user/1000/ipykernel_724116/3085794013.py:4 in <cell line: 4>                         │
│                                                                                           │
│ [Errno 2] No such file or directory: '/run/user/1000/ipykernel_724116/3085794013.py'      │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/core/function.py:1612 in doit                                                      │
│                                                                                           │
│   1609 │   │   if hints.get('deep', True):                                                │
│   1610 │   │   │   expr = expr.doit(**hints)                                              │
│   1611 │   │   hints['evaluate'] = True                                                   │
│ ❱ 1612 │   │   rv = self.func(expr, *self.variable_count, **hints)                        │
│   1613 │   │   if rv!= self and rv.has(Derivative):                                       │
│   1614 │   │   │   rv =  rv.doit(**hints)                                                 │
│   1615 │   │   return rv                                                                  │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/core/function.py:1442 in __new__                                                   │
│                                                                                           │
│   1439 │   │   │   │   │   # _eval_derivative defined                                     │
│   1440 │   │   │   │   │   expr *= old_v.diff(old_v)                                      │
│   1441 │   │   │                                                                          │
│ ❱ 1442 │   │   │   obj = cls._dispatch_eval_derivative_n_times(expr, v, count)            │
│   1443 │   │   │   if obj is not None and obj.is_zero:                                    │
│   1444 │   │   │   │   return obj                                                         │
│   1445                                                                                    │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/core/function.py:1903 in _dispatch_eval_derivative_n_times                         │
│                                                                                           │
│   1900 │   │   # `_eval_derivative_n_times` is not overridden by the current              │
│   1901 │   │   # object, the default in `Basic` will call a loop over                     │
│   1902 │   │   # `_eval_derivative`:                                                      │
│ ❱ 1903 │   │   return expr._eval_derivative_n_times(v, count)                             │
│   1904                                                                                    │
│   1905                                                                                    │
│   1906 def _derivative_dispatch(expr, *variables, **kwargs):                              │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/core/basic.py:1785 in _eval_derivative_n_times                                     │
│                                                                                           │
│   1782 │   │   if isinstance(n, (int, Integer)):                                          │
│   1783 │   │   │   obj = self                                                             │
│   1784 │   │   │   for i in range(n):                                                     │
│ ❱ 1785 │   │   │   │   obj2 = obj._eval_derivative(s)                                     │
│   1786 │   │   │   │   if obj == obj2 or obj2 is None:                                    │
│   1787 │   │   │   │   │   break                                                          │
│   1788 │   │   │   │   obj = obj2                                                         │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/core/function.py:610 in _eval_derivative                                           │
│                                                                                           │
│    607 │   │   │   │   df = self.fdiff(i)                                                 │
│    608 │   │   │   except ArgumentIndexError:                                             │
│    609 │   │   │   │   df = Function.fdiff(self, i)                                       │
│ ❱  610 │   │   │   l.append(df * da)                                                      │
│    611 │   │   return Add(*l)                                                             │
│    612 │                                                                                  │
│    613 │   def _eval_is_commutative(self):                                                │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/core/decorators.py:236 in _func                                                    │
│                                                                                           │
│   233 │   │   │   │   │   return retval                                                   │
│   234 │   │   │   if not isinstance(other, expectedcls):                                  │
│   235 │   │   │   │   return retval                                                       │
│ ❱ 236 │   │   │   return func(self, other)                                                │
│   237 │   │                                                                               │
│   238 │   │   return _func                                                                │
│   239                                                                                     │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/core/decorators.py:105 in binary_op_wrapper                                        │
│                                                                                           │
│   102 │   │   │   │   if other._op_priority > self._op_priority:                          │
│   103 │   │   │   │   │   f = getattr(other, method_name, None)                           │
│   104 │   │   │   │   │   if f is not None:                                               │
│ ❱ 105 │   │   │   │   │   │   return f(self)                                              │
│   106 │   │   │   return func(self, other)                                                │
│   107 │   │   return binary_op_wrapper                                                    │
│   108 │   return priority_decorator                                                       │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/core/decorators.py:106 in binary_op_wrapper                                        │
│                                                                                           │
│   103 │   │   │   │   │   f = getattr(other, method_name, None)                           │
│   104 │   │   │   │   │   if f is not None:                                               │
│   105 │   │   │   │   │   │   return f(self)                                              │
│ ❱ 106 │   │   │   return func(self, other)                                                │
│   107 │   │   return binary_op_wrapper                                                    │
│   108 │   return priority_decorator                                                       │
│   109                                                                                     │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/matrices/common.py:2958 in __rmul__                                                │
│                                                                                           │
│   2955 │                                                                                  │
│   2956 │   @call_highest_priority('__mul__')                                              │
│   2957 │   def __rmul__(self, other):                                                     │
│ ❱ 2958 │   │   return self.rmultiply(other)                                               │
│   2959 │                                                                                  │
│   2960 │   def rmultiply(self, other, dotprodsimp=None):                                  │
│   2961 │   │   """Same as __rmul__() but with optional simplification.                    │
│                                                                                           │
│ /home/pbsds/.cache/pypoetry/virtualenvs/ifield-P6Ko3Gy1-py3.9/lib/python3.9/site-packages │
│ /sympy/matrices/common.py:2979 in rmultiply                                               │
│                                                                                           │
│   2976 │   │   │   (getattr(other, 'is_Matrix', True) or                                  │
│   2977 │   │   │    getattr(other, 'is_MatrixLike', True))):                              │
│   2978 │   │   │   if self.shape[0] != other.shape[1]:                                    │
│ ❱ 2979 │   │   │   │   raise ShapeError("Matrix size mismatch.")                          │
│   2980 │   │                                                                              │
│   2981 │   │   # honest SymPy matrices defer to their class's routine                     │
│   2982 │   │   if getattr(other, 'is_Matrix', False):                                     │
╰───────────────────────────────────────────────────────────────────────────────────────────╯
ShapeError: Matrix size mismatch.
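The failing frame is `MatrixArithmetic.rmultiply`, which checks that the left operand's column count matches the right operand's row count before multiplying. The traceback does not print the shapes involved, so the following is only a sketch of that conformability rule, assuming both factors of `df * da` end up as 3x1 columns:

```python
from sympy import Matrix
from sympy.matrices import ShapeError

col = Matrix([1, 2, 3])    # 3x1, standing in for one factor of df * da (assumed shape)
d_col = Matrix([1, 0, 0])  # 3x1, e.g. the derivative of (x0, x1, x2) w.r.t. x0

try:
    col * d_col            # (3x1) * (3x1): inner dimensions 1 and 3 differ
except ShapeError as exc:
    print("ShapeError:", exc)

# Contracting with the transpose gives the 1x1 dot product instead
print(col.T * d_col)       # Matrix([[1]])
```

If that assumption holds, the chain-rule term would need the transposed (dot-product) form rather than a plain matrix product.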

My current workaround is to use the first example together with a lot of substitutions, but I'd rather be able to use the latter form.
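The screenshot with the first example is not available, so the exact workaround is unknown. As a minimal sketch of one way around the error, assuming the multivariate function can take the components as separate scalar arguments, the entries of the component matrix can be unpacked into `f`:

```python
from sympy import symbols, Function, Derivative
from sympy.physics.vector import ReferenceFrame

N = ReferenceFrame('N')
x0, x1, x2 = symbols("x:3")
X = x0*N.x + x1*N.y + x2*N.z

f = Function("f")
# Unpack the 3x1 component matrix so f receives three scalar arguments,
# f(x0, x1, x2), instead of a single matrix argument
fx = f(*X.to_matrix(N))

d = Derivative(fx, x0).doit()
print(d)  # Derivative(f(x0, x1, x2), x0)
```

Because `f` then has ordinary scalar arguments, `Function._eval_derivative` only multiplies scalar factors and the matrix multiplication that raises the ShapeError is never reached.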

@oscarbenjamin
Contributor

Please include code inline in GitHub rather than attaching a screenshot.

@ThePauliPrinciple
Contributor

If you mean to apply this function element-wise, you can do something along the lines of:

from sympy import symbols, Function, Derivative
from sympy.physics.vector import ReferenceFrame

f = Function('f')
N = ReferenceFrame('N')
x0, x1, x2 = symbols('x:3')
X = x0*N.x + x1*N.y + x2*N.z

# Differentiating the 3x1 component matrix works entry by entry
print(Derivative(X.to_matrix(N), x0))
print(Derivative(X.to_matrix(N), x0).doit())

# Apply f to each entry, then differentiate the resulting matrix
print(X.to_matrix(N).applyfunc(f))
print(Derivative(X.to_matrix(N).applyfunc(f), x0))
print(Derivative(X.to_matrix(N).applyfunc(f), x0).doit())

@pbsds
Author

pbsds commented Nov 11, 2022

> Please include code inline in GitHub rather than attaching a screenshot.

from sympy import symbols, Function, Derivative
from sympy.physics.vector import ReferenceFrame

N = ReferenceFrame('N')
x0, x1, x2 = symbols("x:3")
X = x0*N.x + x1*N.y + x2*N.z

# f takes the whole 3x1 component matrix as its single argument
fx = Function("f")(X.to_matrix(N))
print(fx)
print(Derivative(fx, x0))
print(Derivative(fx, x0).doit())  # raises ShapeError: Matrix size mismatch.

> If you mean to apply this function element-wise, you can do something along the lines of...

Neat! However, my use case is not element-wise application of a function; it is a multivariate function.
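To make the distinction concrete, here is a small sketch that uses a plain `Matrix` of symbols in place of `X.to_matrix(N)` (an assumption made only to keep the example short). `applyfunc` gives a matrix of three single-variable applications, while unpacking the entries gives one scalar function of three variables, whose gradient can be assembled from its partial derivatives:

```python
from sympy import symbols, Function, Matrix

x0, x1, x2 = symbols("x:3")
xs = (x0, x1, x2)
f = Function("f")
M = Matrix(xs)                  # 3x1 column, standing in for X.to_matrix(N)

elementwise = M.applyfunc(f)    # Matrix([[f(x0)], [f(x1)], [f(x2)]])
multivariate = f(*M)            # scalar f(x0, x1, x2)

# Gradient of the multivariate f: one partial derivative per component
grad = Matrix([multivariate.diff(v) for v in xs])
print(elementwise.shape)        # (3, 1)
print(grad)
```

Each entry of `grad` is an unevaluated `Derivative` of the undefined `f`, which is the kind of object the matrix-argument version would have to produce internally.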

@ThePauliPrinciple
Contributor

ThePauliPrinciple commented Nov 12, 2022

I think the problem here is actually that it uses the chain rule and then tries to evaluate this derivative:

Derivative(f(X.to_matrix(N)),X.to_matrix(N)).doit()

and that case is not handled by SymPy.
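Written out for the example above, with $X = (x_0, x_1, x_2)^T$ standing for the component matrix `X.to_matrix(N)`, the chain rule this evaluation needs is a sum over the components of $X$, i.e. a dot product:

```latex
\frac{d}{dx_0} f(X)
  = \sum_{i=0}^{2} \frac{\partial f}{\partial X_i}\,\frac{\partial X_i}{\partial x_0}
  = \left(\frac{\partial f}{\partial X}\right)^{\!T} \frac{\partial X}{\partial x_0},
\qquad
\frac{\partial X}{\partial x_0} = \begin{pmatrix} 1 \\ 0 \\ 0 \end{pmatrix}
\;\Longrightarrow\;
\frac{d}{dx_0} f(X) = \frac{\partial f}{\partial X_0}.
```

Both $\partial f/\partial X$ and $\partial X/\partial x_0$ are column-shaped here, so the product only conforms with the transpose. That SymPy's intermediate `df` actually has this column shape is an inference from the ShapeError, not something the traceback shows.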
