Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added .xreplace() to Vector and Dyadic #20446

Merged
merged 5 commits into from Jan 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 46 additions & 0 deletions sympy/physics/vector/dyadic.py
Expand Up @@ -539,6 +539,52 @@ def _eval_evalf(self, prec):
new_args.append(tuple(new_inlist))
return Dyadic(new_args)

def xreplace(self, rule):
"""
Replace occurrences of objects within the measure numbers of the Dyadic.

Parameters
==========

rule : dict-like
Expresses a replacement rule.

Returns
=======

Dyadic
Result of the replacement.

Examples
========

>>> from sympy import symbols, pi
>>> from sympy.physics.vector import ReferenceFrame, outer
>>> N = ReferenceFrame('N')
>>> D = outer(N.x, N.x)
>>> x, y, z = symbols('x y z')
>>> ((1 + x*y) * D).xreplace({x: pi})
(pi*y + 1)*(N.x|N.x)
>>> ((1 + x*y) * D).xreplace({x: pi, y: 2})
(1 + 2*pi)*(N.x|N.x)
sidhu1012 marked this conversation as resolved.
Show resolved Hide resolved

Replacements occur only if an entire node in the expression tree is
matched:

>>> ((x*y + z) * D).xreplace({x*y: pi})
(z + pi)*(N.x|N.x)
>>> ((x*y*z) * D).xreplace({x*y: pi})
x*y*z*(N.x|N.x)

"""

new_args = []
for inlist in self.args:
new_inlist = list(inlist)
new_inlist[0] = new_inlist[0].xreplace(rule)
new_args.append(tuple(new_inlist))
return Dyadic(new_args)

def _check_dyadic(other):
if not isinstance(other, Dyadic):
raise TypeError('A Dyadic must be supplied')
Expand Down
15 changes: 14 additions & 1 deletion sympy/physics/vector/tests/test_dyadic.py
@@ -1,5 +1,5 @@
from sympy import sin, cos, symbols, pi, Float, ImmutableMatrix as Matrix
from sympy.physics.vector import ReferenceFrame, Vector, dynamicsymbols
from sympy.physics.vector import ReferenceFrame, Vector, dynamicsymbols, outer
from sympy.physics.vector.dyadic import _check_dyadic
from sympy.testing.pytest import raises

Expand Down Expand Up @@ -104,3 +104,16 @@ def test_dyadic_evalf():
s = symbols('s')
a = 5 * s * pi* (N.x | N.x)
assert a.evalf(2) == Float('5', 2) * Float('3.1416', 2) * s * (N.x | N.x)

def test_dyadic_xreplace():
x, y, z = symbols('x y z')
N = ReferenceFrame('N')
D = outer(N.x, N.x)
v = x*y * D
assert v.xreplace({x : cos(x)}) == cos(x)*y * D
assert v.xreplace({x*y : pi}) == pi * D
v = (x*y)**z * D
assert v.xreplace({(x*y)**z : 1}) == D
assert v.xreplace({x:1, z:0}) == D
raises(TypeError, lambda: v.xreplace())
raises(TypeError, lambda: v.xreplace([x, y]))
10 changes: 10 additions & 0 deletions sympy/physics/vector/tests/test_vector.py
Expand Up @@ -177,3 +177,13 @@ def test_vector_evalf():
assert v.evalf(2) == Float('3.1416', 2) * A.x
v = pi * A.x + 5 * a * A.y - b * A.z
assert v.evalf(3) == Float('3.1416', 3) * A.x + Float('5', 3) * a * A.y - b * A.z

def test_vector_xreplace():
x, y, z = symbols('x y z')
v = x**2 * A.x + x*y * A.y + x*y*z * A.z
assert v.xreplace({x : cos(x)}) == cos(x)**2 * A.x + y*cos(x) * A.y + y*z*cos(x) * A.z
assert v.xreplace({x*y : pi}) == x**2 * A.x + pi * A.y + x*y*z * A.z
assert v.xreplace({x*y*z : 1}) == x**2*A.x + x*y*A.y + A.z
assert v.xreplace({x:1, z:0}) == A.x + y * A.y
raises(TypeError, lambda: v.xreplace())
raises(TypeError, lambda: v.xreplace([x, y]))
43 changes: 43 additions & 0 deletions sympy/physics/vector/vector.py
Expand Up @@ -719,6 +719,49 @@ def _eval_evalf(self, prec):
new_args.append([mat.evalf(n=prec_to_dps(prec)), frame])
return Vector(new_args)

def xreplace(self, rule):
"""
Replace occurrences of objects within the measure numbers of the vector.

Parameters
==========

rule : dict-like
Expresses a replacement rule.

Returns
=======

Vector
Result of the replacement.

Examples
========

>>> from sympy import symbols, pi
>>> from sympy.physics.vector import ReferenceFrame
>>> A = ReferenceFrame('A')
>>> x, y, z = symbols('x y z')
>>> ((1 + x*y) * A.x).xreplace({x: pi})
(pi*y + 1)*A.x
>>> ((1 + x*y) * A.x).xreplace({x: pi, y: 2})
(1 + 2*pi)*A.x

Replacements occur only if an entire node in the expression tree is
matched:

>>> ((x*y + z) * A.x).xreplace({x*y: pi})
(z + pi)*A.x
>>> ((x*y*z) * A.x).xreplace({x*y: pi})
x*y*z*A.x

"""

new_args = []
for mat, frame in self.args:
mat = mat.xreplace(rule)
new_args.append([mat, frame])
return Vector(new_args)

class VectorTypeError(TypeError):

Expand Down