Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convexity check for multivariate functions #26450

Merged
merged 17 commits into from
May 15, 2024
1 change: 1 addition & 0 deletions .mailmap
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ Elias Basler <e.e.basler@protonmail.com>
Elisha Hollander <just4now666666@gmail.com> donno2048 <just4now666666@gmail.com>
Elliot Marshall <Marshall2389@gmail.com> <marshall2389@gmail.com>
Elrond der Elbenfuerst <elrond+sympy.org@samba-tng.org>
Emile Fourcini <emile.fourcin1@gmail.com> Emile <emile.fourcin1@gmail.com>
Emma Hogan <ehogan@gemini.edu>
Enric Florit <efz1005@gmail.com>
Eric Demer <demer@mailbox.org>
Expand Down
2 changes: 1 addition & 1 deletion AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -1266,4 +1266,4 @@ Augusto Borges <borges.augustoar@gmail.com>
Han Wei Ang <ang.h.w@u.nus.edu>
Pablo <48098178+PabloRuizCuevas@users.noreply.github.com>
Congxu Yang <u7189828@anu.edu.au>
Saicharan <62512681+saicharan2804@users.noreply.github.com>
Saicharan <62512681+saicharan2804@users.noreply.github.com>
51 changes: 18 additions & 33 deletions sympy/calculus/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from sympy.core.numbers import (E, I, Rational, oo, pi)
from sympy.core.relational import Eq
from sympy.core.singleton import S
from sympy.core.symbol import (Dummy, Symbol, symbols)
from sympy.core.symbol import (Dummy, Symbol)
from sympy.functions.elementary.complexes import (Abs, re)
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.integers import frac
Expand All @@ -23,12 +23,11 @@
from sympy.sets.fancysets import ImageSet
from sympy.sets.conditionset import ConditionSet
from sympy.testing.pytest import XFAIL, raises, _both_exp_pow, slow
from sympy.abc import x
from sympy.abc import x, y

a = Symbol('a', real=True)

def test_function_range():
x, y, a, b = symbols('x y a b')
assert function_range(sin(x), x, Interval(-pi/2, pi/2)
) == Interval(-1, 1)
assert function_range(sin(x), x, Interval(0, pi)
Expand Down Expand Up @@ -73,7 +72,6 @@ def test_function_range1():


def test_continuous_domain():
x = Symbol('x')
assert continuous_domain(sin(x), x, Interval(0, 2*pi)) == Interval(0, 2*pi)
assert continuous_domain(tan(x), x, Interval(0, 2*pi)) == \
Union(Interval(0, pi/2, False, True), Interval(pi/2, pi*Rational(3, 2), True, True),
Expand Down Expand Up @@ -186,10 +184,6 @@ def test_not_empty_in():

@_both_exp_pow
def test_periodicity():
x = Symbol('x')
y = Symbol('y')
z = Symbol('z', real=True)

assert periodicity(sin(2*x), x) == pi
assert periodicity((-2)*tan(4*x), x) == pi/4
assert periodicity(sin(x)**2, x) == 2*pi
Expand Down Expand Up @@ -221,14 +215,14 @@ def test_periodicity():

assert periodicity(exp(x), x) is None
assert periodicity(exp(I*x), x) == 2*pi
assert periodicity(exp(I*z), z) == 2*pi
assert periodicity(exp(z), z) is None
assert periodicity(exp(log(sin(z) + I*cos(2*z)), evaluate=False), z) == 2*pi
assert periodicity(exp(log(sin(2*z) + I*cos(z)), evaluate=False), z) == 2*pi
assert periodicity(exp(sin(z)), z) == 2*pi
assert periodicity(exp(2*I*z), z) == pi
assert periodicity(exp(z + I*sin(z)), z) is None
assert periodicity(exp(cos(z/2) + sin(z)), z) == 4*pi
assert periodicity(exp(I*a), a) == 2*pi
assert periodicity(exp(a), a) is None
assert periodicity(exp(log(sin(a) + I*cos(2*a)), evaluate=False), a) == 2*pi
assert periodicity(exp(log(sin(2*a) + I*cos(a)), evaluate=False), a) == 2*pi
assert periodicity(exp(sin(a)), a) == 2*pi
assert periodicity(exp(2*I*a), a) == pi
assert periodicity(exp(a + I*sin(a)), a) is None
assert periodicity(exp(cos(a/2) + sin(a)), a) == 4*pi
assert periodicity(log(x), x) is None
assert periodicity(exp(x)**sin(x), x) is None
assert periodicity(sin(x)**y, y) is None
Expand Down Expand Up @@ -261,9 +255,6 @@ def test_periodicity():


def test_periodicity_check():
x = Symbol('x')
y = Symbol('y')

assert periodicity(tan(x), x, check=True) == pi
assert periodicity(sin(x) + cos(x), x, check=True) == 2*pi
assert periodicity(sec(x), x) == 2*pi
Expand All @@ -285,13 +276,13 @@ def test_is_convex():
assert is_convex(x**2, x, domain=Interval(0, oo)) == True
assert is_convex(1/x**3, x, domain=Interval.Lopen(0, oo)) == True
assert is_convex(-1/x**3, x, domain=Interval.Ropen(-oo, 0)) == True
assert is_convex(log(x), x) == False
raises(NotImplementedError, lambda: is_convex(log(x), x, a))
assert is_convex(log(x) ,x) == False
assert is_convex(x**2+y**2, x, y) == True
assert is_convex(cos(x) + cos(y), x) == False
assert is_convex(8*x**2 - 2*y**2, x, y) == False


def test_stationary_points():
x, y = symbols('x y')

assert stationary_points(sin(x), x, Interval(-pi/2, pi/2)
) == {-pi/2, pi/2}
assert stationary_points(sin(x), x, Interval.Ropen(0, pi/4)
Expand Down Expand Up @@ -324,7 +315,6 @@ def test_stationary_points():


def test_maximum():
x, y = symbols('x y')
assert maximum(sin(x), x) is S.One
assert maximum(sin(x), x, Interval(0, 1)) == sin(1)
assert maximum(tan(x), x) is oo
Expand Down Expand Up @@ -357,8 +347,6 @@ def test_maximum():


def test_minimum():
x, y = symbols('x y')

assert minimum(sin(x), x) is S.NegativeOne
assert minimum(sin(x), x, Interval(1, 4)) == sin(4)
assert minimum(tan(x), x) is -oo
Expand Down Expand Up @@ -386,22 +374,19 @@ def test_minimum():


def test_issue_19869():
t = symbols('t')
assert (maximum(sqrt(3)*(t - 1)/(3*sqrt(t**2 + 1)), t)
assert (maximum(sqrt(3)*(x - 1)/(3*sqrt(x**2 + 1)), x)
) == sqrt(3)/3


def test_issue_16469():
x = Symbol("x", real=True)
f = abs(x)
assert function_range(f, x, S.Reals) == Interval(0, oo, False, True)
f = abs(a)
assert function_range(f, a, S.Reals) == Interval(0, oo, False, True)


@_both_exp_pow
def test_issue_18747():
assert periodicity(exp(pi*I*(x/4+S.Half/2)), x) == 8
assert periodicity(exp(pi*I*(x/4 + S.Half/2)), x) == 8


def test_issue_25942():
x = Symbol("x")
assert (acos(x) > pi/3).as_set() == Interval.Ropen(-1, S(1)/2)
10 changes: 3 additions & 7 deletions sympy/calculus/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sympy.sets.conditionset import ConditionSet
from sympy.utilities import filldedent
from sympy.utilities.iterables import iterable
from sympy.matrices.dense import hessian


def continuous_domain(f, symbol, domain):
Expand Down Expand Up @@ -745,18 +746,13 @@ def is_convex(f, *syms, domain=S.Reals):
.. [5] https://en.wikipedia.org/wiki/Concave_function

"""

if len(syms) > 1:
raise NotImplementedError(
"The check for the convexity of multivariate functions is not implemented yet.")

if len(syms) > 1 :
return hessian(f, syms).is_positive_semidefinite
from sympy.solvers.inequalities import solve_univariate_inequality

f = _sympify(f)
var = syms[0]
if any(s in domain for s in singularities(f, var)):
return False

condition = f.diff(var, 2) < 0
if solve_univariate_inequality(condition, var, False, domain):
return False
Expand Down
Loading