diff --git a/sympy/simplify/fu.py b/sympy/simplify/fu.py index 8cd518b3d4e1..bd66c2c16193 100644 --- a/sympy/simplify/fu.py +++ b/sympy/simplify/fu.py @@ -246,6 +246,13 @@ def f(rv): rv = fmap[type(rv)](S.Pi/2 - rv.args[0]) return rv + # touch numbers iside of trig functions to let them automatically update + rv = rv.replace( + lambda x: isinstance(x, TrigonometricFunction), + lambda x: x.replace( + lambda n: n.is_number and n.is_Mul, + lambda n: n.func(*n.args))) + return bottom_up(rv, f) @@ -273,7 +280,12 @@ def TR4(rv): 0 1 zoo 0 """ # special values at 0, pi/6, pi/4, pi/3, pi/2 already handled - return rv + return rv.replace( + lambda x: + isinstance(x, TrigonometricFunction) and + (r:=x.args[0]/pi).is_Rational and r.q in (1, 2, 3, 4, 6), + lambda x: + x.func(x.args[0].func(*x.args[0].args))) def _TR56(rv, f, g, h, max, pow): diff --git a/sympy/simplify/tests/test_fu.py b/sympy/simplify/tests/test_fu.py index 29f614d7ad6f..0b0a6fbc491b 100644 --- a/sympy/simplify/tests/test_fu.py +++ b/sympy/simplify/tests/test_fu.py @@ -1,15 +1,17 @@ from sympy.core.add import Add from sympy.core.mul import Mul from sympy.core.numbers import (I, Rational, pi) +from sympy.core.parameters import evaluate from sympy.core.singleton import S from sympy.core.symbol import (Dummy, Symbol, symbols) from sympy.functions.elementary.hyperbolic import (cosh, coth, csch, sech, sinh, tanh) from sympy.functions.elementary.miscellaneous import (root, sqrt) -from sympy.functions.elementary.trigonometric import (cos, cot, csc, sec, sin, tan) +from sympy.functions.elementary.trigonometric import (cos, cot, csc, sec, sin, tan, + TrigonometricFunction) from sympy.simplify.powsimp import powsimp from sympy.simplify.fu import ( L, TR1, TR10, TR10i, TR11, _TR11, TR12, TR12i, TR13, TR14, TR15, TR16, - TR111, TR2, TR2i, TR3, TR5, TR6, TR7, TR8, TR9, TRmorrie, _TR56 as T, + TR111, TR2, TR2i, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TRmorrie, _TR56 as T, TRpower, hyper_as_trig, fu, process_common_addends, trig_split, as_f_sign_1) from sympy.core.random import verify_numerically @@ -72,6 +74,17 @@ def test_TR3(): j = TR3(i) assert verify_numerically(i, j) and i.func != j.func + with evaluate(False): + eq = cos(9*pi/22) + assert eq.has(9*pi) and TR3(eq) == sin(pi/11) + + +def test_TR4(): + for i in [0, pi/6, pi/4, pi/3, pi/2]: + with evaluate(False): + eq = cos(i) + assert isinstance(eq, cos) and not isinstance(TR4(eq), cos) + def test__TR56(): h = lambda x: 1 - x