Skip to content

Commit

Permalink
Trac #8450: intermediate complex expression in real functions make ma…
Browse files Browse the repository at this point in the history
…ny plot functions fail

All of the following plots fail
{{{
x, y = SR.var('x y')
contour_plot(abs(x+i*y), (x,-1,1), (y,-1,1))
density_plot(abs(x+i*y), (x,-1,1), (y,-1,1))
plot3d(abs(x+i*y), (x,-1,1),(y,-1,1))
streamline_plot(abs(x+i*y), (x,-1,1),(y,-1,1))
}}}
with
{{{
TypeError: unable to coerce to a real number
}}}
The culprit is the call to `setup_for_eval_on_grid` (from
`sage/plot/misc.py`) that tries to compile the symbolic expression with
`fast_float`. But since the expression involves an intermediate complex
number the compilation fails. This can be tested with any of the two
{{{
fast_float(abs(x + i*y), x, y)
fast_callable(abs(x + i*y), vars=[x,y])
}}}
The function compilation succeeds if we ask for a complex function
instead
{{{
fast_callable(abs(x + i*y), vars=[x,y], domain=complex)
}}}

See also [https://ask.sagemath.org/question/46275/typeerror-unable-to-
coerce-to-a-real-number/ this question on ask.sagemath.org].

URL: https://trac.sagemath.org/8450
Reported by: jason
Ticket author(s): Michael Orlitzky
Reviewer(s): Dima Pasechnik
  • Loading branch information
Release Manager committed Jan 30, 2022
2 parents 439907f + 6531609 commit 0222526
Show file tree
Hide file tree
Showing 5 changed files with 325 additions and 73 deletions.
20 changes: 19 additions & 1 deletion src/sage/functions/other.py
Expand Up @@ -859,10 +859,28 @@ def _evalf_(self, base, exp, parent=None):
sage: real_nth_root(Reals(100)(2), 2)
1.4142135623730950488016887242
"""
if hasattr(exp, 'real_part'):
# To allow complex "noise" while plotting, the fast_callable()
# interpreters used in plots will convert all intermediate
# expressions to CDF, returning only the final answer as a
# real number. However, for a symbolic function such as this,
# the "exp" argument is in fact an intermediate expression.
# Thus we are forced to deal with exponents of the form
# (n + 0*I), which a priori will throw a TypeError at the "%"
# below. Here we special-case only CDF and CC, leaving the
# python "complex" type unhandled: you have to try very hard
# to pass a python "complex" in as an exponent, and the extra
# effort/slowdown doesn't seem worth it.
if exp.imag_part().is_zero():
exp = exp.real_part()
else:
raise ValueError("exponent cannot be complex")
exp = ZZ(exp)

negative = base < 0

if negative:
if exp % 2 == 0:
if exp.mod(2) == 0:
raise ValueError('no real nth root of negative real number with even n')
base = -base

Expand Down
237 changes: 215 additions & 22 deletions src/sage/plot/misc.py
Expand Up @@ -13,12 +13,12 @@
# http://www.gnu.org/licenses/
#*****************************************************************************

from sage.ext.fast_eval import fast_float

from sage.structure.element import is_Vector, Expression

def setup_for_eval_on_grid(funcs, ranges, plot_points=None, return_vars=False):
"""
def setup_for_eval_on_grid(funcs,
ranges,
plot_points=None,
return_vars=False,
imaginary_tolerance=1e-8):
r"""
Calculate the necessary parameters to construct a list of points,
and make the functions fast_callable.
Expand All @@ -37,6 +37,11 @@ def setup_for_eval_on_grid(funcs, ranges, plot_points=None, return_vars=False):
- ``return_vars`` -- (default ``False``) If ``True``, return the variables,
in order.
- ``imaginary_tolerance`` -- (default: ``1e-8``); if an imaginary
number arises (due, for example, to numerical issues), this
tolerance specifies how large it has to be in magnitude before
we raise an error. In other words, imaginary parts smaller than
this are ignored in your plot points.
OUTPUT:
Expand All @@ -58,17 +63,17 @@ def setup_for_eval_on_grid(funcs, ranges, plot_points=None, return_vars=False):
sage: g(x,y)=x+y
sage: h(y)=-y
sage: sage.plot.misc.setup_for_eval_on_grid(f, [(0, 2),(1,3),(-4,1)], plot_points=5)
(<sage.ext...>, [(0.0, 2.0, 0.5), (1.0, 3.0, 0.5), (-4.0, 1.0, 1.25)])
(<sage...>, [(0.0, 2.0, 0.5), (1.0, 3.0, 0.5), (-4.0, 1.0, 1.25)])
sage: sage.plot.misc.setup_for_eval_on_grid([g,h], [(0, 2),(-1,1)], plot_points=5)
((<sage.ext...>, <sage.ext...>), [(0.0, 2.0, 0.5), (-1.0, 1.0, 0.5)])
((<sage...>, <sage...>), [(0.0, 2.0, 0.5), (-1.0, 1.0, 0.5)])
sage: sage.plot.misc.setup_for_eval_on_grid([sin,cos], [(-1,1)], plot_points=9)
((<sage.ext...>, <sage.ext...>), [(-1.0, 1.0, 0.25)])
((<sage...>, <sage...>), [(-1.0, 1.0, 0.25)])
sage: sage.plot.misc.setup_for_eval_on_grid([lambda x: x^2,cos], [(-1,1)], plot_points=9)
((<function <lambda> ...>, <sage.ext...>), [(-1.0, 1.0, 0.25)])
((<function <lambda> ...>, <sage...>), [(-1.0, 1.0, 0.25)])
sage: sage.plot.misc.setup_for_eval_on_grid([x+y], [(x,-1,1),(y,-2,2)])
((<sage.ext...>,), [(-1.0, 1.0, 2.0), (-2.0, 2.0, 4.0)])
((<sage...>,), [(-1.0, 1.0, 2.0), (-2.0, 2.0, 4.0)])
sage: sage.plot.misc.setup_for_eval_on_grid(x+y, [(x,-1,1),(y,-1,1)], plot_points=[4,9])
(<sage.ext...>, [(-1.0, 1.0, 0.6666666666666666), (-1.0, 1.0, 0.25)])
(<sage...>, [(-1.0, 1.0, 0.6666666666666666), (-1.0, 1.0, 0.25)])
sage: sage.plot.misc.setup_for_eval_on_grid(x+y, [(x,-1,1),(y,-1,1)], plot_points=[4,9,10])
Traceback (most recent call last):
...
Expand All @@ -86,7 +91,7 @@ def setup_for_eval_on_grid(funcs, ranges, plot_points=None, return_vars=False):
ValueError: At least one variable range has more than 3 entries: each should either have 2 or 3 entries, with one of the forms (xmin, xmax) or (x, xmin, xmax)
sage: sage.plot.misc.setup_for_eval_on_grid(x+y, [(y,1,-1),(x,-1,1)], plot_points=5)
(<sage.ext...>, [(1.0, -1.0, 0.5), (-1.0, 1.0, 0.5)])
(<sage...>, [(1.0, -1.0, 0.5), (-1.0, 1.0, 0.5)])
sage: sage.plot.misc.setup_for_eval_on_grid(x+y, [(x,1,-1),(x,-1,1)], plot_points=5)
Traceback (most recent call last):
...
Expand All @@ -96,9 +101,25 @@ def setup_for_eval_on_grid(funcs, ranges, plot_points=None, return_vars=False):
...
ValueError: plot start point and end point must be different
sage: sage.plot.misc.setup_for_eval_on_grid(x+y, [(x,1,-1),(y,-1,1)], return_vars=True)
(<sage.ext...>, [(1.0, -1.0, 2.0), (-1.0, 1.0, 2.0)], [x, y])
(<sage...>, [(1.0, -1.0, 2.0), (-1.0, 1.0, 2.0)], [x, y])
sage: sage.plot.misc.setup_for_eval_on_grid(x+y, [(y,1,-1),(x,-1,1)], return_vars=True)
(<sage.ext...>, [(1.0, -1.0, 2.0), (-1.0, 1.0, 2.0)], [y, x])
(<sage...>, [(1.0, -1.0, 2.0), (-1.0, 1.0, 2.0)], [y, x])
TESTS:
Ensure that we can plot expressions with intermediate complex
terms as in :trac:`8450`::
sage: x, y = SR.var('x y')
sage: contour_plot(abs(x+i*y), (x,-1,1), (y,-1,1))
Graphics object consisting of 1 graphics primitive
sage: density_plot(abs(x+i*y), (x,-1,1), (y,-1,1))
Graphics object consisting of 1 graphics primitive
sage: plot3d(abs(x+i*y), (x,-1,1),(y,-1,1))
Graphics3d Object
sage: streamline_plot(abs(x+i*y), (x,-1,1),(y,-1,1))
Graphics object consisting of 1 graphics primitive
"""
if max(map(len, ranges)) > 3:
raise ValueError("At least one variable range has more than 3 entries: each should either have 2 or 3 entries, with one of the forms (xmin, xmax) or (x, xmin, xmax)")
Expand Down Expand Up @@ -133,22 +154,58 @@ def setup_for_eval_on_grid(funcs, ranges, plot_points=None, return_vars=False):
if min(range_steps) == float(0):
raise ValueError("plot start point and end point must be different")

options = {}
eov = False # eov = "expect one value"
if nargs == 1:
options['expect_one_var'] = True

if is_Vector(funcs):
funcs = list(funcs)
eov = True

from sage.ext.fast_callable import fast_callable
def try_make_fast(f):
# If "f" supports fast_callable(), use it. We can't guarantee
# that our arguments will actually support fast_callable()
# because, for example, the user may already have done it
# himself, and the result of fast_callable() can't be
# fast-callabled again.
from sage.rings.complex_double import CDF
from sage.ext.interpreters.wrapper_cdf import Wrapper_cdf

if hasattr(f, '_fast_callable_'):
ff = fast_callable(f, vars=vars, expect_one_var=eov, domain=CDF)
return FastCallablePlotWrapper(ff, imag_tol=imaginary_tolerance)
elif isinstance(f, Wrapper_cdf):
# Already a fast-callable, just wrap it. This can happen
# if, for example, a symolic expression is passed to a
# higher-level plot() function that converts it to a
# fast-callable with expr._plot_fast_callable() before
# we ever see it.
return FastCallablePlotWrapper(f, imag_tol=imaginary_tolerance)
elif hasattr(f, '__call__'):
# This will catch python functions, among other things. We don't
# wrap these yet because we don't know what type they'll return.
return f
else:
# Convert things like ZZ(0) into constant functions.
from sage.symbolic.ring import SR
ff = fast_callable(SR(f),
vars=vars,
expect_one_var=eov,
domain=CDF)
return FastCallablePlotWrapper(ff, imag_tol=imaginary_tolerance)

# Handle vectors, lists, tuples, etc.
if hasattr(funcs, "__iter__"):
funcs = tuple( try_make_fast(f) for f in funcs )
else:
funcs = try_make_fast(funcs)

#TODO: raise an error if there is a function/method in funcs that takes more values than we have ranges

if return_vars:
return (fast_float(funcs, *vars, **options),
return (funcs,
[tuple(_range + [range_step])
for _range, range_step in zip(ranges, range_steps)],
vars)
else:
return (fast_float(funcs, *vars, **options),
return (funcs,
[tuple(_range + [range_step])
for _range, range_step in zip(ranges, range_steps)])

Expand Down Expand Up @@ -192,6 +249,7 @@ def unify_arguments(funcs):
if not isinstance(funcs, (list, tuple)):
funcs = [funcs]

from sage.structure.element import Expression
for f in funcs:
if isinstance(f, Expression) and f.is_callable():
f_args = set(f.arguments())
Expand Down Expand Up @@ -384,3 +442,138 @@ def get_matplotlib_linestyle(linestyle, return_type):
"'dashed', 'dotted', dashdot', 'None'}, "
"respectively {'-', '--', ':', '-.', ''}"%
(linestyle))


class FastCallablePlotWrapper:
r"""
A class to alter the return types of the fast-callable functions
used during plotting.
When plotting symbolic expressions and functions, we generally
first convert them to a faster form with :func:`fast_callable`.
That function takes a ``domain`` parameter that forces the end
(and all intermediate) results of evaluation to a specific type.
Though we always want the end result to be of type ``float``,
correctly choosing the ``domain`` presents some problems:
* ``float`` is a bad choice because it's common for real
functions to have complex terms in them. Moreover, when one is
generating plots programmatically, precision issues can
produce terms like ``1.0 + 1e-12*I`` that are hard to avoid if
calling ``real()`` on everything is infeasible.
* ``complex`` has essentially the same problem as ``float``.
There are several symbolic functions like :func:`min_symbolic`,
:func:`max_symbolic`, and :func:`floor` that are unable to
operate on complex numbers.
* ``None`` leaves the types of the inputs/outputs alone, but due
to the lack of a specialized interpreter, slows down plotting
by an unacceptable amount.
* ``CDF`` has none of the other issues, because ``CDF`` has its
own specialized interpreter, a lexicographic ordering (for
min/max), and supports :func:`floor`. However, none of the
plotting functions can handle complex numbers, so using
``CDF`` would require us to wrap every evaluation in a
``CDF``-to-``float`` conversion routine within the plotting
infrastructure. This slows things down less than a domain of
``None`` does, but is unattractive mainly because of how
invasive it would be to "fix" the output everywhere.
Creating a new fast-callable interpreter that has different input
and output types solves most of the problems with a ``CDF``
domain, but :func:`fast_callable` and the interpreter classes in
:mod:`sage.ext.interpreters` are not really written with that in
mind. The ``domain`` parameter to :func:`fast_callable`, for
example, is expecting a single Sage ring that corresponds to one
interpreter. You can make it accept, for example, a string like
"CDF-to-float", but the hacks required to make that work feel
wrong.
Thus we arrive at this solution: a class to wrap the result of
:func:`fast_callable`. Whenever we need to support intermediate
complex terms in a plot function, we can set ``domain=CDF`` while
creating its fast-callable incarnation, and then wrap the result in
this class. The ``__call__`` method of this class then ensures
that the ``CDF`` output is converted to a ``float``. Since
plotting tries to ignore unplottable points, this job is easier
than it would be in a more general context: we simply return
``nan`` whenever the result has a nontrivial imaginary part.
EXAMPLES:
The ``float`` incarnation of "not a number" is returned instead
of an error being thrown if the answer is complex::
sage: from sage.plot.misc import FastCallablePlotWrapper
sage: f = sqrt(x)
sage: ff = fast_callable(f, vars=[x], domain=CDF)
sage: fff = FastCallablePlotWrapper(ff, imag_tol=1e-8)
sage: fff(1)
1.0
sage: fff(-1)
nan
"""
def __init__(self, ff, imag_tol):
r"""
Construct a ``FastCallablePlotWrapper``.
INPUT:
- ``ff`` -- a fast-callable wrapper over ``CDF``; an instance of
:class:`sage.ext.interpreters.Wrapper_cdf`, usually constructed
with :func:`fast_callable`.
- ``imag_tol`` -- float; how big of an imaginary part we're willing
to ignore before returning ``nan``.
OUTPUT:
An instance of ``FastCallablePlotWrapper`` that can be called
just like ``ff``, but that always returns a ``float``, even if
it is ``nan``.
EXAMPLES:
The wrapper will ignore an imaginary part smaller in magnitude
than ``imag_tol``::
sage: from sage.plot.misc import FastCallablePlotWrapper
sage: f = x
sage: ff = fast_callable(f, vars=[x], domain=CDF)
sage: fff = FastCallablePlotWrapper(ff, imag_tol=1e-8)
sage: fff(I*1e-9)
0.0
sage: fff = FastCallablePlotWrapper(ff, imag_tol=1e-12)
sage: fff(I*1e-9)
nan
"""
self._ff = ff
self._imag_tol = imag_tol

def __call__(self, *args):
r"""
Evaluate the underlying fast-callable and convert the result to
``float``.
TESTS:
Evaluation never fails and always returns a ``float``::
sage: from sage.plot.misc import FastCallablePlotWrapper
sage: f = x
sage: ff = fast_callable(f, vars=[x], domain=CDF)
sage: fff = FastCallablePlotWrapper(ff, imag_tol=1e-8)
sage: type(fff(CDF.random_element())) is float
True
"""
z = self._ff(*args)

if abs(z.imag()) < self._imag_tol:
return float(z.real())
else:
return float("nan")

0 comments on commit 0222526

Please sign in to comment.