-
Notifications
You must be signed in to change notification settings - Fork 62
/
test_lt.py
147 lines (114 loc) · 4.46 KB
/
test_lt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import unittest
import pytest
from sympy import symbols, S, Matrix
from galgebra.ga import Ga
from galgebra.lt import Mlt
class TestLt(unittest.TestCase):
    """Tests for linear transformations created through ``Ga.lt``."""

    @staticmethod
    def _make_ga():
        """Return a fresh algebra with basis ``a b`` and metric ``g=[1, 1]``."""
        return Ga('a b', g=[1, 1], coords=symbols('x, y', real=True))

    # reproduce gh-105
    def test_lt_matrix(self):
        """A transformation built from a list prints and converts to a matrix."""
        ga = self._make_ga()
        e_a, e_b = ga.mv()
        transform = ga.lt([e_a + e_b, 2*e_a - e_b])

        expected_text = 'Lt(a) = a + b\nLt(b) = 2*a - b'
        expected_matrix = 'Matrix([[1, 2], [1, -1]])'
        assert str(transform) == expected_text
        assert str(transform.matrix()) == expected_matrix

    def test_lt_function(self):
        """Check construction from a callable, including rejected callables."""
        ga = self._make_ga()
        e_a, e_b = ga.mv()

        def not_linear(v):
            return v * v

        def not_vector(v):
            return v + S.One

        def ok(v):
            return (v | e_b) * e_a + 2*v

        # Each rejected callable is paired with the text expected in the
        # ValueError message.
        for rejected, fragment in ((not_linear, 'linear'),
                                   (not_vector, 'vector')):
            with pytest.raises(ValueError, match=fragment):
                ga.lt(rejected)

        f = ga.lt(ok)
        u = ga.mv('x', 'vector')
        w = ga.mv('y', 'vector')

        # The transformation agrees with the callable on vectors, and is
        # compared against the wedge of the images on a bivector and on a
        # scalar-plus-bivector sum.
        assert f(u) == ok(u)
        assert f(u ^ w) == ok(u) ^ ok(w)
        assert f(1 + 2*(u ^ w)) == 1 + 2*(ok(u) ^ ok(w))

    def test_deprecations(self):
        """Accessing the legacy attributes emits ``DeprecationWarning``."""
        ga = self._make_ga()

        from_matrix = ga.lt([[1, 2], [3, 4]])
        with pytest.warns(DeprecationWarning):
            assert from_matrix.X == from_matrix.Ga.coord_vec
        with pytest.warns(DeprecationWarning):
            assert from_matrix.coords == from_matrix.Ga.coords

        antisymmetric = ga.lt('L', mode='a')
        with pytest.warns(DeprecationWarning):
            assert antisymmetric.mode == 'a'
        with pytest.warns(DeprecationWarning):
            assert not antisymmetric.fct_flg

        symmetric_fct = ga.lt('L', mode='s', f=True)
        with pytest.warns(DeprecationWarning):
            assert symmetric_fct.mode == 's'
        with pytest.warns(DeprecationWarning):
            assert symmetric_fct.fct_flg
class TestMlt(unittest.TestCase):
    """Tests for ``Mlt`` objects."""

    def test_basic(self):
        """An ``Mlt`` wrapping a function behaves like that function."""
        # from TensorDef.py
        st4d, _g0, _g1, _g2, _g3 = Ga.build(
            'gamma*t|x|y|z',
            g=[1, -1, -1, -1],
            coords=symbols('t x y z', real=True),
        )

        biv = st4d.mv('T', 'bivector')

        def TA(a1, a2):
            return biv | (a1 ^ a2)

        tensor = Mlt(TA, st4d)

        # tests begin
        v1 = st4d.mv('a1', 'vector')
        v2 = st4d.mv('a2', 'vector')
        v3 = st4d.mv('a3', 'vector')
        v4 = st4d.mv('a4', 'vector')

        # calling the Mlt is like calling the function
        assert tensor(v1, v2) == TA(v1, v2)

        # for addition, argument slots are reused
        total = tensor + tensor
        difference = tensor - tensor
        assert total(v1, v2) == tensor(v1, v2) + tensor(v1, v2)
        assert difference(v1, v2) == tensor(v1, v2) - tensor(v1, v2)

        # for multiplication, argument slots are chained
        geometric = tensor * tensor
        outer = tensor ^ tensor
        inner = tensor | tensor
        assert geometric(v1, v2, v3, v4) == TA(v1, v2) * tensor(v3, v4)
        assert outer(v1, v2, v3, v4) == TA(v1, v2) ^ tensor(v3, v4)
        assert inner(v1, v2, v3, v4) == TA(v1, v2) | tensor(v3, v4)

        # Test linearity properties. Note that this behavior is implied by our
        # test that the Mlt and TA are equivalent above, but it does exercise
        # `Mlt.__call__` with compound expressions as arguments.
        alpha = st4d.mv('alpha', 'scalar')
        extra = st4d.mv('b', 'vector')
        assert tensor(alpha * v1, v2) == alpha * tensor(v1, v2)
        assert tensor(v1, alpha * v2) == alpha * tensor(v1, v2)
        assert tensor(v1 + extra, v2) == tensor(v1, v2) + tensor(extra, v2)
        assert tensor(v1, v2 + extra) == tensor(v1, v2) + tensor(v1, extra)

    def test_from_str(self):
        """An ``Mlt`` built from a name introduces one new symbol per slot combination."""
        ga, _e1, _e2 = Ga.build(
            'e*1|2', coords=symbols('x y', real=True), g=[1, 1])

        u = ga.mv('a1', 'vector')
        w = ga.mv('a2', 'vector')
        ux, uy = u.get_coefs(1)
        wx, wy = w.get_coefs(1)

        def by_sort_key(sym):
            return sym.sort_key()

        # one-d
        rank1 = Mlt('T', ga, nargs=1)
        value = rank1(u)
        # Two new symbols created
        fresh = value.free_symbols - {ux, uy}
        Tx, Ty = sorted(fresh, key=by_sort_key)
        assert value == (
            Tx * ux +
            Ty * uy
        )

        # two-d
        rank2 = Mlt('T', ga, nargs=2)
        value = rank2(u, w)
        # four new symbols created
        fresh = value.free_symbols - {ux, uy, wx, wy}
        Txx, Txy, Tyx, Tyy = sorted(fresh, key=by_sort_key)
        assert value == (
            Txx * ux * wx +
            Txy * ux * wy +
            Tyx * uy * wx +
            Tyy * uy * wy
        )

    def test_deprecations(self):
        """The legacy basis-index helper emits ``DeprecationWarning``."""
        ga = Ga('e*a|b', g=[1, 1])
        # NOTE(review): `extact_basis_indexes` (sic) appears to be the
        # deprecated spelling under test -- confirm before "correcting" it.
        with pytest.warns(DeprecationWarning):
            assert Mlt.extact_basis_indexes(ga) == ['a', 'b']