replace(oldtype, newtype) does unexpected and superfluous constructor calls to oldtype #24460

Closed
Costor opened this issue Jan 1, 2023 · 5 comments

@Costor
Contributor

Costor commented Jan 1, 2023

This refers to bullet point 1.1, replacement of types, in the documentation for sympy.core.basic.Basic.replace().

Situation

When replacing oldtype by newtype in an expression expr, expr.replace(oldtype, newtype) may call the constructor of oldtype and create further oldtype objects. Since oldtype is to be replaced by newtype, these objects are then immediately replaced by newtype objects, so these calls are superfluous at least. Moreover, they are unexpected and may have unwanted performance or other side effects.

Example

from sympy import Expr, Add, symbols
class MathObj(Expr): # create a new mathematical object in SymPy from Expr, just for demo
    def __new__(cls, a1, a2):
        obj = Expr.__new__(cls, a1, a2) # sets up self.func, self.args
        print(f"__new__(MathObj, {a1}, {a2})") # show that __new__ is called
        return obj
# create some term using MathObj
n, m = symbols("n m")
p = MathObj(n, MathObj(m,m)) + MathObj(MathObj(n,n), m) #prints 4 times __new__(MathObj,..)
print("doing p.replace(MathObj, Add):")

q = p.replace(MathObj, Add) # calls and prints 2 times __new__(MathObj,...) !!

print(q) # prints 3*m + 3*n

Analysis
This behaviour occurs if oldtype is nested within expr (as in the example above). In this case the statement rv = rv.func(*newargs) in the function walk() inside sympy.core.basic.Basic.replace() invokes the oldtype constructor. If rv.func == oldtype, it should invoke newtype(*newargs) instead.
However, as replace() has various usages that are all handled by this code, I'm unable to verify this.
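
The observation can be reproduced with a simplified sketch of the bottom-up traversal. The function walk_sketch below is a hypothetical stand-in written for this illustration; it is not SymPy's walk(), which handles several other kinds of queries. It assumes only the type-to-type case and records constructor calls in a list instead of printing them.

```python
from sympy import Add, Expr, symbols

calls = []  # records every MathObj constructor call


class MathObj(Expr):
    """Demo object, as in the example above, but recording calls."""

    def __new__(cls, a1, a2):
        calls.append((a1, a2))
        return Expr.__new__(cls, a1, a2)


def walk_sketch(expr, oldtype, newtype):
    """Hypothetical simplification of the bottom-up walk (type query only)."""
    args = expr.args
    if args:
        newargs = tuple(walk_sketch(a, oldtype, newtype) for a in args)
        if args != newargs:
            # The node is rebuilt with its ORIGINAL head. If that head is
            # oldtype, this is the extra constructor call in question.
            expr = expr.func(*newargs)
    if isinstance(expr, oldtype):
        expr = newtype(*expr.args)
    return expr


n, m = symbols("n m")
p = MathObj(n, MathObj(m, m)) + MathObj(MathObj(n, n), m)

calls.clear()  # ignore the four calls made while building p
q = walk_sketch(p, MathObj, Add)
print(len(calls), q)
```

In this sketch the two outer MathObj nodes are rebuilt once each after their inner MathObj has been replaced, which matches the two extra calls reported in the example.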

@oscarbenjamin
Contributor

> may have unwanted performance or other side effects.

If calling obj.func(*obj.args) has side effects then that's a bug. You're right that this is a performance issue though.

@oscarbenjamin
Contributor

I guess you're referring to this:

sympy/sympy/core/basic.py

Lines 1606 to 1608 in 40b1af7

newargs = tuple([walk(a, F) for a in args])
if args != newargs:
    rv = rv.func(*newargs)

I think it's not necessarily superfluous. Applying replace to the args might mean that rv.func(*newargs) evaluates to something different, which affects whether or not replace(oldtype, newtype) would still match at that node, e.g.:

>>> exp(I*pi*exp(pi)).replace(exp, cos)
-1
>>> cos(I*pi*cos(pi))
cosh(π)
>>> exp(I*pi*cos(pi))
-1
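
The same example can be traced step by step. This is only a sketch of the order of operations described above; the intermediate names inner, rebuilt, actual and naive are introduced here for illustration.

```python
from sympy import I, cos, cosh, exp, pi

expr = exp(I*pi*exp(pi))

# Step 1: the innermost exp is replaced first; cos(pi) evaluates at once.
inner = cos(pi)

# Step 2: the outer node is rebuilt with its original head (exp) and the
# new argument; this evaluates, so no exp is left to be replaced.
rebuilt = exp(I*pi*inner)

# What replace() actually returns.
actual = expr.replace(exp, cos)

# What a purely textual exp -> cos substitution would give.
naive = cos(I*pi*cos(pi))

print(inner, rebuilt, actual, naive)
```

Here rebuilt and actual agree, while naive is cosh(pi), which is the difference shown in the session above.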

@Costor
Contributor Author

Costor commented Jan 5, 2023

I see your point: this is an "incremental" replace, i.e. the replacement proceeds bottom-up, step by step. Once the arguments of a node have been replaced, the node is re-evaluated and the replace continues in the result.
The documentation is imho not clear about this: "Traverses an expression tree and performs replacement of matching subexpressions from the bottom to the top of the tree. The default approach is to do the replacement in a simultaneous fashion so changes made are targeted only once."
I have probably been misled by the word "simultaneous", which gave me the idea that all occurrences of oldtype would be replaced by newtype effectively "simultaneously". What is implemented here is "no repeat", i.e. a replacement will not be repeated in a changed term.
(Also, if I take the existing replace as an "incremental" replace, I would ask why the expression head is called only after all its arguments have been replaced, rather than each time a single argument has been replaced. The behaviour is difficult to predict.)

Is there a replacement function in SymPy that performs a usual "simultaneous replacement" and would give the result -1 in this example? (I have tried xreplace and subs to no avail.)

from sympy import symbols, Pow, Add, srepr
a = symbols("a")
p = Pow(a, -a)
print("srepr(p) = ", srepr(p))
print("p.replace(Pow, Add) = ", p.replace(Pow, Add))
p2 = Pow(Pow(a, -a), Pow(a, -a) - 1)
print("p2=", p2)
print("p2.replace(Pow, Add) expected: Add(Add(a, -a), Add(a, -a) - 1) = ",
                                      Add(Add(a, -a), Add(a, -a) - 1))
print("but get: ", p2.replace(Pow,Add)) # Pow(0,-1)
# because Pow(Add(a,-a), Add(a,-a)-1) is called by .replace()

@oscarbenjamin
Contributor

Maybe there should be an evaluate=False flag for replace:

In [64]: with evaluate(False):
    ...:     p3 = p2.replace(Pow, Add)
    ...: 

In [65]: p3
Out[65]: -a + a + -a + a - 1

In [66]: p3.doit()
Out[66]: -1

@Costor
Contributor Author

Costor commented Jan 5, 2023

That is really cool. The documentation of replace should describe this as the way to get a truly "static" replacement, in contrast to the dynamic incremental replacement.

(A "static" replacement option coded directly in replace would perhaps be slightly more efficient, but that's most probably not worth the effort and the risk of fiddling with core code.)
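
For reference, the evaluate(False) approach from the previous comment can be wrapped in a small helper. The name static_replace is hypothetical and not part of SymPy; the sketch assumes that a final doit() is wanted to evaluate the result once, after all replacements are done.

```python
from sympy import Add, Pow, evaluate, symbols


def static_replace(expr, oldtype, newtype):
    """Hypothetical helper: replace types without intermediate evaluation."""
    # With evaluation disabled, rebuilt nodes stay as they are instead of
    # being simplified between individual replacement steps.
    with evaluate(False):
        unevaluated = expr.replace(oldtype, newtype)
    # Evaluate once, after every oldtype has been replaced.
    return unevaluated.doit()


a = symbols("a")
p2 = Pow(Pow(a, -a), Pow(a, -a) - 1)
result = static_replace(p2, Pow, Add)
print(result)
```

For p2 this gives the -1 shown in the session above. Whether the trailing doit() is appropriate depends on the use case; dropping it leaves the unevaluated expression.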

@Costor Costor closed this as completed Jan 5, 2023