-
Notifications
You must be signed in to change notification settings - Fork 62
/
atoms.py
162 lines (127 loc) · 4.87 KB
/
atoms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
""" Sympy primitives for representing atoms of ga expressions """
from typing import Union
from sympy import Symbol, AtomicExpr, S, Basic, sympify, MatrixExpr
from sympy import Determinant as _Determinant
from sympy.core import numbers
from sympy.core.function import AppliedUndef, UndefinedFunction
from sympy.printing.pretty.stringpict import prettyForm, stringPict
from sympy.printing.pretty.pretty_symbology import U
# Names exported by a star-import of this module.
__all__ = [
'BasisVectorSymbol',
'BasisBladeSymbol',
'BasisBladeNoWedgeSymbol',
'BasisBaseSymbol',
'DotProductSymbol',
]
def _all_same(items):
    """ Return True when every element of `items` equals its first element.

    An empty sequence yields True. `items` must support indexing with ``[0]``.
    """
    for candidate in items:
        if not candidate == items[0]:
            return False
    return True
class BasisVectorSymbol(Symbol):
    """ Symbol used for a single basis vector; it is flagged as non-commutative """
    is_commutative = False

    def _latex(self, print_obj):
        """ Render the symbol in bold using the given LaTeX printer """
        try:
            rendered = print_obj._print_Symbol(self, style="bold")
        except TypeError:
            # this sympy's `_print_Symbol` does not accept `style=`, so wrap
            # the plain rendering in \mathbf ourselves
            plain = print_obj._print_Symbol(self)
            rendered = r"\mathbf{{{}}}".format(plain)
        return rendered
class _GradedSymbol(AtomicExpr):
    """ Common base class for the graded symbols

    The constructor collapses the degenerate cases: called with no symbols
    it returns the scalar `S.One`, and called with a single symbol it returns
    that symbol unchanged. Only two or more symbols produce an instance of
    this class. This may change in future.
    """
    # the zero-symbol (scalar) case is handed off to `S.One` by `__new__`, so
    # this non-commutative flag never ends up on an object standing for a scalar
    is_commutative = False

    def __new__(cls, *args: BasisVectorSymbol) -> Union[
        numbers.One,
        BasisVectorSymbol,
        "_GradedSymbol"
    ]:
        count = len(args)
        if count == 0:
            return S.One
        if count == 1:
            return args[0]
        return super().__new__(cls, *args)
class _JoinedPrinterMixin(Basic):
    """ Mixin that prints `Basic.args` separated by an operator symbol.

    Subclasses must define one separator per printer: `_op_sympystr` (a
    string), `_op_pretty` (a pretty-printer form) and `_op_latex` (a string).
    """

    def _sympystr(self, printer):
        separator = self._op_sympystr
        return separator.join([printer._print(arg) for arg in self.args])

    def _pretty(self, printer):
        pieces = []
        for arg in self.args:
            # the separator goes between operands, never before the first one
            if pieces:
                pieces.append(self._op_pretty)
            pieces.append(printer._print(arg))
        return prettyForm(*stringPict.next(*pieces))

    def _latex(self, printer):
        separator = self._op_latex
        return separator.join([printer._print(arg) for arg in self.args])
class BasisBaseSymbol(_GradedSymbol, _JoinedPrinterMixin):
    r""" A base of a non-orthogonal algebra, for example :math:`e_1 e_2` """
    # separators used by the joined printers: LaTeX output simply juxtaposes
    # the vectors, the other two printers put a `*` between them
    _op_latex = ''
    _op_pretty = prettyForm('*')
    _op_sympystr = '*'
class BasisBladeSymbol(_GradedSymbol, _JoinedPrinterMixin):
    r""" A basis blade, for example :math:`e_1 \wedge e_2` """
    # separators used by the joined printers: `\wedge` in LaTeX output and
    # `^` for the other two printers
    _op_latex = r'\wedge '
    _op_pretty = prettyForm('^')
    _op_sympystr = '^'
class BasisBladeNoWedgeSymbol(BasisBladeSymbol):
    r""" A basis blade rendered in shortened form, for example :math:`e_{12}` """

    def _split_name(self):
        """ Return ``(root, subscripts)`` assembled from the argument names.

        Each argument name must split on ``'_'`` into exactly two parts, and
        the first parts must all be equal; otherwise `ValueError` is raised.
        The second parts are concatenated in argument order.
        """
        roots = []
        subscripts = []
        for vec in self.args:
            parts = vec.name.split('_')
            if len(parts) != 2:
                raise ValueError('Incompatible basis vector {} for wedgeless printing'.format(vec))
            root, subscript = parts
            subscripts.append(subscript)
            roots.append(root)

        if not _all_same(roots):
            raise ValueError('No unique root symbol to use for wedgeless printing')
        return roots[0], ''.join(subscripts)

    def __common_printer(self, printer):
        # build a temporary basis vector carrying the merged name and let the
        # printer render that instead of this blade
        root, sub = self._split_name()
        merged_name = "{}_{}".format(root, sub)
        return printer._print(BasisVectorSymbol(merged_name))

    # the same implementation serves all three printers
    _sympystr = _pretty = _latex = __common_printer
class DotProductSymbol(AtomicExpr):
    """ Symbolic stand-in for a dot product of two operands, like :class:`sympy.DotProduct` """
    is_real = True

    def _sympystr(self, printer):
        lhs, rhs = self.args
        left_text = printer._print(lhs)
        right_text = printer._print(rhs)
        return "({}.{})".format(left_text, right_text)

    def _latex(self, printer):
        lhs, rhs = self.args
        left_text = printer._print(lhs)
        right_text = printer._print(rhs)
        return r"\left ({}\cdot {}\right ) ".format(left_text, right_text)

    def _pretty(self, printer):
        lhs, rhs = self.args
        # operands with the dot operator glyph placed between them
        pieces = (
            printer._print(lhs),
            printer._print(U('DOT OPERATOR')),
            printer._print(rhs),
        )
        joined = prettyForm(*stringPict.next(*pieces))
        # wrap the whole product in parentheses
        return prettyForm(*joined.parens())
class MatrixFunction(UndefinedFunction):
    """ Like a MatrixSymbol, but for functions.

    ``MatrixFunction(name, m, n)`` creates a new undefined-function class
    whose bases are ``AppliedUndef`` and ``MatrixExpr`` and whose ``shape``
    attribute is ``(m, n)``.

    Both dimensions go through ``sympify(..., strict=True)``; a value that
    sympify refuses in strict mode propagates its error to the caller.
    Extra keyword arguments are forwarded to ``UndefinedFunction.__new__``.
    """
    def __new__(mcl, name, m, n, **kwargs):
        cls = super().__new__(mcl, name, (AppliedUndef, MatrixExpr), {}, **kwargs)
        # the first dimension comes from `m`; it was previously taken from
        # `n`, which made every function square and left `m` unused
        cls.shape = sympify(m, strict=True), sympify(n, strict=True)
        return cls
# workaround until sympy/sympy#19354 is merged
if _Determinant.is_commutative is True:
    # sympy's own class already reports itself as commutative: re-export it
    Determinant = _Determinant
else:
    class Determinant(_Determinant):
        # override the flag that sympy's class does not set to True
        is_commutative = True