Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make all_close more robust #26708

Merged
merged 5 commits into from
Jun 21, 2024
Merged

make all_close more robust #26708

merged 5 commits into from
Jun 21, 2024

Conversation

smichr
Copy link
Member

@smichr smichr commented Jun 15, 2024

Brief description of what is fixed or changed

all_close is more efficient in testing whether two expressions differ trivially in terms of literal numbers (not symbolic values).

Other comments

all_close is not a publically imported routine, but it is not underscore-named in numbers. It attempts to free the user from tedious analysis of two expressions to see if they are the same when, in fact, they might only differ trivially in numerical values used. It is currently only used in tests.

One thing that would make it fail in master is the inability to recognize that x and 1.0*x are trivially the same since they do not have the expected identical structure; one is a Symbol and the other a Mul.

It is also recognized in master that searching for matching terms between two Add is very inefficient since a search for a given term is made factor by factor:

args2 = list(expr2.args)
for arg1 in expr1.args:
    for i, arg2 in enumerate(args2):
        if _all_close(arg1, arg2, rtol, atol):
            args2.pop(i)

Here are a few simple tests:

from random import shuffle
from sympy.core.numbers import *
from time import time
args = [Mul(*[Dummy() for i in range(10)]) for i in range(50)]
a=Add(*args)
shuffle(args)
b=Add(*args,evaluate=False)
>>> a==b
False
>>> t=time();all_close(a,b);time()-t
0.078  # 0.03 in this PR

args = [Mul(*[Dummy() for i in range(10)]) for i in range(500)]
a=Add(*args)
shuffle(args)
b=Add(*args,evaluate=False)
>>> a==b
False
>>> t=time();all_close(a,b);time()-t
6.78  # 0.73 in this PR

It might make more sense to have this somewhere other than in the core.numbers since it really has more to do with testing (and testing expressions) than it does with testing numbers.

Release Notes

NO ENTRY

Instead of doing an O(n^2) search for a matching term,
an optimization is made to handle the exact matching
terms using coefficients_dict and then a fallback to
searching is made for any more complicated terms.
@sympy-bot
Copy link

sympy-bot commented Jun 15, 2024

Hi, I am the SymPy bot. I'm here to help you write a release notes entry. Please read the guide on how to write release notes.

  • No release notes entry will be added for this pull request.
Click here to see the pull request description that was parsed.
#### Brief description of what is fixed or changed

`all_close` is more efficient in testing whether two expressions differ trivially in terms of literal numbers (not symbolic values).

#### Other comments

`all_close` is not a publically imported routine, but it is not underscore-named in `numbers`. It attempts to free the user from tedious analysis of two expressions to see if they are the same when, in fact, they might only differ trivially in numerical values used. It is currently only used in tests.

One thing that would make it fail in master is the inability to recognize that `x` and `1.0*x` are trivially the same since they do not have the expected identical structure; one is a Symbol and the other a Mul.

It is also recognized in master that searching for matching terms between two Add is very inefficient since a search for a given term is made factor by factor:
```python
args2 = list(expr2.args)
for arg1 in expr1.args:
    for i, arg2 in enumerate(args2):
        if _all_close(arg1, arg2, rtol, atol):
            args2.pop(i)
```

Here are a few simple tests:
```python
from random import shuffle
from sympy.core.numbers import *
from time import time
args = [Mul(*[Dummy() for i in range(10)]) for i in range(50)]
a=Add(*args)
shuffle(args)
b=Add(*args,evaluate=False)
>>> a==b
False
>>> t=time();all_close(a,b);time()-t
0.078  # 0.03 in this PR

args = [Mul(*[Dummy() for i in range(10)]) for i in range(500)]
a=Add(*args)
shuffle(args)
b=Add(*args,evaluate=False)
>>> a==b
False
>>> t=time();all_close(a,b);time()-t
6.78  # 0.73 in this PR
```

It might make more sense to have this somewhere other than in the `core.numbers` since it really has more to do with testing (and testing expressions) than it does with testing numbers.

#### Release Notes

<!-- Write the release notes for this release below between the BEGIN and END
statements. The basic format is a bulleted list with the name of the subpackage
and the release note for this PR. For example:

* solvers
  * Added a new solver for logarithmic equations.

* functions
  * Fixed a bug with log of integers. Formerly, `log(-x)` incorrectly gave `-log(x)`.

* physics.units
  * Corrected a semantical error in the conversion between volt and statvolt which
    reported the volt as being larger than the statvolt.

or if no release note(s) should be included use:

NO ENTRY

See https://github.com/sympy/sympy/wiki/Writing-Release-Notes for more
information on how to write release notes. The bot will check your release
notes automatically to see if they are formatted correctly. -->

<!-- BEGIN RELEASE NOTES -->
NO ENTRY
<!-- END RELEASE NOTES -->

Copy link

Benchmark results from GitHub Actions

Lower numbers are good, higher numbers are bad. A ratio less than 1
means a speed up and greater than 1 means a slowdown. Green lines
beginning with + are slowdowns (the PR is slower then master or
master is slower than the previous release). Red lines beginning
with - are speedups.

Significantly changed benchmark results (PR vs master)

Significantly changed benchmark results (master vs previous release)

Full benchmark results can be found as artifacts in GitHub Actions
(click on checks at the top of the PR).

@sylee957
Copy link
Member

sylee957 commented Jun 17, 2024

Why are $x$ and $1.0.00000001x$ approximately equal, but not $0$ and $0.0.00000001x$ approximately equal?

from sympy import *
from sympy.core.numbers import all_close

eps = 0.00000001
x = Symbol('x')
print(all_close(1*x, (1 + eps)*x))
print(all_close(0*x, (0 + eps)*x))

I don't exactly know if this is intended by the design,
however, it's difficult to understand the difference between them by the caulus.

Even between $1.0001x$ and $1x$, the difference eventually grows large if $x$ takes values of billions or trillions,
so in that context, it's hard to understand why the difference of 1*x and (1 + eps)*x is different than the difference between 0*x and epx*x.

If the difference between 1.00001x and 1x can be reduced to the difference between 0.00001x and 0, I may just come up with different idea like:
we can subtract two symbolic expressions at first, and check the difference of coefficients
could be much more easier.

@smichr
Copy link
Member Author

smichr commented Jun 17, 2024

Why are ... [1e-10*x and 0 different]?

They are different because they have different structure, one is a Mul and the other a constant, one depends on x and the other does not. The purpose of the function is to test whether coefficients of like terms is small according to the rtol and atol parameters.

sympy/core/numbers.py Outdated Show resolved Hide resolved
sympy/core/numbers.py Outdated Show resolved Hide resolved
@smichr
Copy link
Member Author

smichr commented Jun 18, 2024

we can subtract two symbolic expressions at first, and check the difference of coefficients
could be much more easier.

There are two parts to the numerical test: abs(num1 - num2) <= atol + rtol*abs(num2). If num2 is 0 then rtol doesn't matter and the answer should be based on atol. So if x**2 + 1e-10*x + 1 and x**2 + 1.0 are being compared, I think you would suggest that we compare 1e-10 to 0 and, if atol is greater than 1e-10 then report that the two expressions are nearly the same. But there is a 3rd part to the advertised function: the structure must be the same. So in this case, even though numerics would have passed, structure fails.

The only exception I am trying to provide in this PR is to recognize the multiplicative and additive identities. The multiplicative is dropped when it is Rational instead of Float (so 1*x -> x); the additive identity is always dropped because Float(0) == 0 (so x + 0.0 -> x).

Click to see code that would improve performance while requiring strict structural agreement

def all_close(expr1, expr2, rtol=1e-5, atol=1e-8):
    """Return True if expr1 and expr2 are numerically close.

    The expressions must have the same structure, but any Rational, Integer, or
    Float numbers they contain are compared approximately using rtol and atol.
    Any other parts of expressions are compared exactly.

    Relative tolerance is measured with respect to expr2 so when used in
    testing expr2 should be the expected correct answer.

    Examples
    ========

    >>> from sympy import exp
    >>> from sympy.abc import x, y
    >>> from sympy.core.numbers import all_close
    >>> expr1 = 0.1*exp(x - y)
    >>> expr2 = exp(x - y)/10
    >>> expr1
    0.1*exp(x - y)
    >>> expr2
    exp(x - y)/10
    >>> expr1 == expr2
    False
    >>> all_close(expr1, expr2)
    True
    """
    NUM_TYPES = (Rational, Float)

    def _all_close(obj1, obj2):
        if type(obj1) == type(obj2) and isinstance(obj1, (list, tuple)):
            if len(obj1) != len(obj2):
                return False
            return all(_all_close(e1, e2) for e1, e2 in zip(obj1, obj2))
        num1 = isinstance(obj1, NUM_TYPES)
        num2 = isinstance(obj2, NUM_TYPES)
        if num1 != num2:
            return False
        if num1:
            return bool(abs(obj1 - obj2) <= atol + rtol*abs(obj2))
        return _all_close_expr(_sympify(obj1), _sympify(obj2))

    def _all_close_expr(expr1, expr2):
        if expr1.is_Atom:
            return expr1 == expr2
        if expr1.func != expr2.func or len(expr1.args) != len(expr2.args):
            return False
        if expr1.is_Add or expr1.is_Mul:
            return _all_close_ac(expr1, expr2)
		args = zip(expr1.args, expr2.args)
		return all(_all_close_expr(a1, a2) for a1, a2 in args)

    def _all_close_ac(expr1, expr2):
        # compare expressions with associative commutative operators for
        # approximate equality by seeing that all terms have equivalent
        # coefficients (which are always Rational or Float)
        if expr1.is_Mul:
            # as_coeff_mul automatically will supply coeff of 1
            c1, e1 = expr1.as_coeff_mul(rational=False)
            c2, e2 = expr2.as_coeff_mul(rational=False)
            if not _close_num(c1, c2):
                return False
            s1 = set(e1)
            s2 = set(e2)
            common = s1 & s2
            s1 -= common
            s2 -= common
            if not s1:
                return True
            if not any(i.has(Float) for j in (s1, s2) for i in j):
                return False
            # factors might not be matching, e.g.
            # x != x**1.0, exp(x) != exp(1.0*x), etc...
            s1 = [i.as_base_exp() for i in ordered(s1)]
            s2 = [i.as_base_exp() for i in ordered(s2)]
            unmatched = list(range(len(s1)))
            for be1 in s1:
                for i in unmatched:
                    be2 = s2[i]
                    if _all_close(be1, be2):
                        unmatched.remove(i)
                        break
                else:
                    return False
            return not(unmatched)
        assert expr1.is_Add
        cd1 = expr1.as_coefficients_dict()
        cd2 = expr2.as_coefficients_dict()
        if 1 in cd1 and 1 in cd2 and _close_num(cd1.pop(1), cd2.pop(2)):
			# there are now no keys that are Numbers
			pass
		else:
			return False
        for k in list(cd1):
            if k in cd2:
                if not _close_num(cd1.pop(k), cd2.pop(k)):
                    return False
            # k (or a close version in cd2) might have
            # Floats in a factor of the term which will
            # be handled below
        else:
            if not cd1:
                return True
        for k1 in cd1:
            for k2 in cd2:
                if _all_close_expr(k1, k2):
                    # found a matching key
                    # XXX there could be a corner case where
                    # more than 1 might match and the numbers are
                    # such that one is better than the other
                    # that is not being considered here
                    if not _close_num(cd1[k1], cd2[k2]):
                        return False
                    break
            else:
                # no key matched
                return False
        return True

    return _all_close(expr1, expr2)

@smichr smichr merged commit 6d6e89c into sympy:master Jun 21, 2024
48 checks passed
@smichr smichr deleted the allclose branch June 21, 2024 14:27
Comment on lines +4284 to +4289
>>> all_close(x, x + 1e-10)
True
>>> all_close(x, 1.0*x)
True
>>> all_close(x, 1.0*x + 1e-10)
True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is definitely not what I wanted when creating this function or whenever I have used it. Its purpose was to be used in tests where these sorts of differences should generally show up as test failures.

That now means that I want a different function for use in tests because this one does not do what is actually needed there any more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants