/
tensor.py
320 lines (273 loc) · 15.8 KB
/
tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
"""
Provides functions for doing general tensor operations, notably tensor contraction.
In particular, vorpy.tensor.contract works like numpy.einsum but works on dtype=object
(but is probably MUCH slower).
"""
import functools
import itertools
import numpy as np
import operator
import typing
import vorpy
class VorpyTensorException(vorpy.VorpyException):
    """
    Base class for all exceptions generated by the vorpy.tensor module (this doesn't
    necessarily include other exceptions generated by functions called within the
    vorpy.tensor module).

    Catch this to handle any error raised by vorpy.tensor's own validation logic
    (e.g. malformed contraction specifications in `contract`).
    """
    pass
class VorpyTensorProgrammerError(vorpy.VorpyProgrammerError):
    """
    Base class for all internal programmer error exceptions generated by the vorpy.tensor module.

    Raised only from internal sanity checks ("this should not happen" conditions);
    seeing one indicates a bug in vorpy.tensor itself, not in caller code.
    """
    pass
def is_shape (tup:typing.Tuple[int,...]) -> bool:
    """
    Returns True iff tup is a valid shape: a tuple (which implies finite-length) having
    nonnegative integer elements.  In particular, the empty tuple is a shape, which is
    the shape of a scalar value (i.e. a 0-tensor).

    Note: this now actually enforces the documented "integer elements" requirement;
    previously any element supporting `>= 0` (e.g. a float) was accepted.  numpy
    integer scalars are accepted as well, since shapes frequently originate from
    numpy (e.g. numpy.shape).
    """
    return all(isinstance(x, (int, np.integer)) and x >= 0 for x in tup)
def validate_shape_or_raise (tup:typing.Tuple[int,...], tup_name:str) -> None:
    """
    Validates that tup is a shape (see is_shape); a no-op on success.

    Raises TypeError, with a message naming tup via tup_name, if tup is not a
    valid shape.
    """
    if is_shape(tup):
        return
    raise TypeError(f'{tup_name} was expected to be a shape (tuple of nonnegative ints; could be the empty tuple), but it was actually {tup}')
def shape (tn):
    """
    Returns the shape of the tensor tn (simply delegates to numpy.shape).

    A scalar has shape tuple().  Note that multiindex_iterator(shape(tn)) can be
    used to iterate over the components of tn.  See also component.
    """
    return np.shape(tn)
def order_of_shape (sh):
    """
    Returns the tensor order corresponding to shape sh, which is simply len(sh).

    The order of a tensor is the number of indices needed to address one of its
    components; in particular, the order of the shape () is 0.
    """
    return len(sh)
def dimension_of_shape (sh):
    """
    Returns the dimension of the vector space defined by a tensor type having shape sh, which in
    particular is numpy.prod(sh).  The dimension of a tensor type can also be thought of as the
    number of independent components of an instance of that tensor type.  Note that if
    order_of_shape(sh) == 0, then dimension_of_shape(sh) == 1 (the empty product).

    Raises TypeError if any component of sh is negative (a negative component doesn't
    specify the dimension of a tensor factor).  This was previously an `assert`, which
    is silently stripped under `python -O`; raising TypeError matches the behavior of
    validate_shape_or_raise.
    """
    if any(s < 0 for s in sh):
        raise TypeError(f'sh (which was {sh}) must have all nonnegative components')
    # dtype=int guarantees an integer result even for the empty shape ().
    return np.prod(sh, dtype=int)
def order (tn):
    """
    Returns the tensor order of tn -- the number of indices needed to address one
    of its components.  Equivalent to order_of_shape(shape(tn)), i.e. the length
    of numpy.shape(tn); a scalar has order 0.
    """
    return len(np.shape(tn))
def dimension (tn):
    """
    Returns the dimension of the tensor type of tn -- its total number of components,
    i.e. numpy.size(tn), equivalent to dimension_of_shape(shape(tn)).  A scalar has
    dimension 1.
    """
    return np.size(tn)
def component (tn, multiindex):
    """
    Returns the component of tn addressed by the given multiindex.  If tn is a
    scalar (order 0), tn itself is returned and multiindex is ignored.
    """
    # Anything with a `shape` attribute (e.g. an ndarray) supports multiindexing;
    # plain scalars (int, float, sympy expressions, etc.) do not.
    if hasattr(tn, 'shape'):
        return tn[multiindex]
    return tn
# NOTE: This is present (in possibly better form) in vorpy.symbolic. TODO: Consolidate.
def multiindex_iterator (multiindex_shape):
    """
    Returns an iterator over all multiindices for the given shape, varying the last
    index fastest (row-major order).  For example, multiindex_iterator(shape(T)) can
    be used to iterate over the components of T.  For the empty shape (), yields the
    single empty multiindex ().
    """
    index_ranges = (range(dim) for dim in multiindex_shape)
    return itertools.product(*index_ranges)
def identity_tensor (operand_shape:typing.Tuple[int,...], *, dtype:typing.Any) -> np.ndarray:
    """
    Returns the identity operator on the tensor space whose elements have shape
    operand_shape; the returned tensor has shape operand_shape+operand_shape.

    For example, if operand_shape is (2,3,4) and I = identity_tensor(operand_shape, dtype=...),
    then I has shape (2,3,4,2,3,4), and for any T of shape (2,3,4),
    np.einsum('ijkpqr,pqr', I, T) == T.  More generally for a tensor T,
    np.dot(identity_tensor(T.shape).reshape(T.size,T.size), T.reshape(T.size)).reshape(T.shape) == T
    i.e. where you "flatten" the operand space into a 1-tensor and apply identity_tensor to it.
    """
    # The identity on the flattened operand space is an ordinary identity matrix;
    # reshaping it to operand_shape+operand_shape gives the tensor form.
    operand_dim = dimension_of_shape(operand_shape)
    flat_identity = np.eye(operand_dim, dtype=dtype)
    return flat_identity.reshape(operand_shape + operand_shape)
def diagonal_tensor (diagonal_t:np.ndarray) -> np.ndarray:
    """
    Returns a tensor which operates on a space having shape diagonal_t.shape, scaling each
    component by the corresponding element of diagonal_t.  The return value has shape s+s,
    where s denotes diagonal_t.shape.

    Fix: the local variable was previously named `shape`, shadowing this module's
    `shape` function; renamed to avoid the shadowing.  Behavior is unchanged.
    """
    s = np.shape(diagonal_t)
    # np.diagflat flattens diagonal_t and places it on the diagonal of a square matrix;
    # reshaping to s+s yields the tensor form of that diagonal operator.
    return np.diagflat(diagonal_t).reshape(s + s)
def operand_shape_of (T:np.ndarray) -> typing.Tuple[int,...]:
    """
    Returns the shape of the vector space that the tensor T operates on, assuming T is a
    linear operator (i.e. T has shape s+s for some shape s, where s is the operand shape).

    Raises VorpyTensorException if T is not a linear operator (odd order, or shape not of
    the form s+s).  Previously this raised the generic `Exception`; using the module's
    exception base class is consistent with `contract` and still reaches any caller
    catching `Exception`.
    """
    # Equivalent to order(T); T is used as an ndarray (T.shape) throughout.
    o = len(T.shape)
    if o % 2 != 0:
        raise VorpyTensorException(f'expected T to be a tensor having even order, but its order was {o}')
    half = o // 2
    operand_shape = T.shape[:half]
    expected_shape = operand_shape + operand_shape
    if T.shape != expected_shape:
        raise VorpyTensorException(f'expected T to have shape {expected_shape} (i.e. be a linear operator, operating on shape {operand_shape}) but it had shape {T.shape}')
    return operand_shape
def as_linear_operator (T:np.ndarray) -> np.ndarray:
    """
    Reshapes tensor T into the equivalent square matrix.

    T must have shape operand_shape+operand_shape (see operand_shape_of, which raises
    if it does not); the returned matrix is n-by-n where n is the dimension of the
    operand space.
    """
    op_shape = operand_shape_of(T)
    n = dimension_of_shape(op_shape)
    return T.reshape(n, n)
def _positions_of_all_occurrences_of_char (s, c):
for pos,ch in enumerate(s):
if ch == c:
yield pos
def contract (contraction_string, *tensors, **kwargs):
    """
    This is meant to do the same thing as numpy.einsum, except that it can handle dtype=object
    (but is probably MUCH slower).

    Parameters:
    -   contraction_string: einsum-style index specification, one comma-delimited index
        string per tensor argument (e.g. 'ij,jk' for matrix multiply).  Unlike
        numpy.einsum, the '->' output syntax is NOT supported.
    -   tensors: the tensor arguments, one per comma-delimited index string.
    -   kwargs:
        -   'dtype' (required): dtype of the result (e.g. float, object).
        -   'output' (optional): string of the free indices in the desired output
            order; must contain exactly the free indices, each once.  If omitted,
            the free indices are used in sorted order.

    Returns the contracted tensor, except that a 0-tensor result is coerced to scalar.
    Raises VorpyTensorException for any malformed or inconsistent specification.
    """
    if '->' in contraction_string:
        raise VorpyTensorException('The -> syntax supported in numpy.einsum is not supported in vorpy.tensor.contract; use the `output` kwarg instead')
    output_index_string = kwargs.get('output', None)
    if 'dtype' not in kwargs:
        raise VorpyTensorException('Must specify the \'dtype\' keyword argument (e.g. dtype=float, dtype=object, etc).')
    dtype = kwargs['dtype']
    # Accumulates human-readable sub-messages for whichever validation phase fails.
    error_messages = []

    #
    # Starting here is just checking that the contraction is well-defined, such as checking
    # the summation semantics of the contracted and free indices, checking that the contracted
    # slots' dimensions match, etc.
    #

    # Verify that the indices in the contraction string match the orders of the tensor arguments.
    index_strings = contraction_string.split(',')
    if len(index_strings) != len(tensors):
        raise VorpyTensorException('There must be the same number of comma-delimited index strings (which in this case is {0}) as tensor arguments (which in this case is {1}).'.format(len(index_strings), len(tensors)))
    all_index_counts_matched = True
    for i,(index_string,tensor) in enumerate(zip(index_strings,tensors)):
        if len(index_string) != order(tensor):
            error_messages.append('the number of indices in {0}th index string \'{1}\' (which in this case is {2}) did not match the order of the corresponding tensor argument (which in this case is {3})'.format(i, index_string, len(index_string), order(tensor)))
            all_index_counts_matched = False
    if not all_index_counts_matched:
        raise VorpyTensorException('At least one index string had a number of indices that did not match the order of its corresponding tensor argument. In particular, {0}.'.format(', '.join(error_messages)))
    # Determine which indices are to be contracted (defined as any indices occurring more than once)
    # and determine the free indices (defined as any indices occurring exactly once).
    indices = frozenset(c for c in contraction_string if c != ',')
    contraction_indices = frozenset(c for c in indices if contraction_string.count(c) > 1)
    free_indices = indices - contraction_indices # Set subtraction
    # If the 'output' keyword argument wasn't specified, use the alphabetization of free_indices
    # as the output indices.
    if output_index_string == None:
        output_indices = free_indices
        output_index_string = ''.join(sorted(list(free_indices)))
    # Otherwise, perform some verification on output_index_string.
    else:
        # If the 'output' keyword argument was specified (stored in output_index_string),
        # then verify that it's well-defined, in that that output_index_string contains
        # unique characters.
        output_indices = frozenset(output_index_string)
        output_indices_are_unique = True
        for index in output_indices:
            if output_index_string.count(index) > 1:
                error_messages.append('index \'{0}\' occurs more than once'.format(index))
                output_indices_are_unique = False
        if not output_indices_are_unique:
            raise VorpyTensorException('The characters of the output keyword argument (which in this case is \'{0}\') must be unique. In particular, {1}.'.format(output_index_string, ', '.join(error_messages)))
        # Verify that free_indices and output_index_string contain exactly the same characters.
        if output_indices != free_indices:
            raise VorpyTensorException('The output indices (which in this case are \'{0}\') must be precisely the free indices (which in this case are \'{1}\').'.format(''.join(sorted(output_indices)), ''.join(sorted(free_indices))))
    # Verify that the dimensions of each of contraction_indices match, while constructing
    # an indexed list of the dimensions of the contracted slots.
    contraction_index_string = ''.join(sorted(list(contraction_indices)))
    contracted_indices_dimensions_match = True
    for contraction_index in contraction_index_string:
        # Collect (argument index, slot index, slot dimension) for every slot that this
        # contraction index addresses, across all tensor arguments.
        indexed_slots_and_dims = []
        for arg_index,(index_string,tensor) in enumerate(zip(index_strings,tensors)):
            for slot_index in _positions_of_all_occurrences_of_char(index_string,contraction_index):
                indexed_slots_and_dims.append((arg_index,slot_index,tensor.shape[slot_index]))
        # All slots contracted by the same index must agree on dimension.
        distinct_dims = frozenset(dim for arg_index,slot_index,dim in indexed_slots_and_dims)
        if len(distinct_dims) > 1:
            slot_indices = ','.join('{0}th'.format(slot_index) for _,slot_index,_ in indexed_slots_and_dims)
            arg_indices = ','.join('{0}th'.format(arg_index) for arg_index,_,_ in indexed_slots_and_dims)
            dims = ','.join('{0}'.format(dim) for _,_,dim in indexed_slots_and_dims)
            error_messages.append('index \'{0}\' is used to contract the {1} slots respectively of the {2} tensor arguments whose respective slots have non-matching dimensions {3}'.format(contraction_index, slot_indices, arg_indices, dims))
            contracted_indices_dimensions_match = False
    if not contracted_indices_dimensions_match:
        raise VorpyTensorException('The dimensions of at least one set of contracted tensor slots did not match. In particular, {0}.'.format(', '.join(error_messages)))
    # Maps an index string to the tuple of the dimensions of the slots those indices
    # address (looking up each index's first occurrence among the tensor arguments).
    def dims_of_index_string (index_string):
        # NOTE: the loop variable below intentionally shadows the outer `index_string`
        # parameter; it only lives inside this helper.
        def tensor_and_slot_in_which_index_occurs (index):
            for index_string,tensor in zip(index_strings,tensors):
                slot = index_string.find(index)
                if slot >= 0:
                    return tensor,slot
            raise VorpyTensorProgrammerError('This should never happen.')
        lookup = tuple(tensor_and_slot_in_which_index_occurs(index) for index in index_string)
        return tuple(tensor.shape[slot] for tensor,slot in lookup)
    contraction_dims = dims_of_index_string(contraction_index_string)
    output_dims = dims_of_index_string(output_index_string)

    #
    # Starting here is the actual contraction computation
    #

    # For a given tensor argument's index string, builds a function mapping a pair
    # (contracted component indices, output component indices) to the multiindex
    # addressing that tensor's component.  `lookups` records, per slot, which of the
    # two tuples to read from (0 = contracted, 1 = output) and at what position.
    def component_indices_function (index_string):
        is_contraction_index = tuple(index in contraction_index_string for index in index_string)
        lookups = tuple((0 if is_contraction_index[i] else 1, contraction_index_string.index(index) if is_contraction_index[i] else output_index_string.index(index)) for i,index in enumerate(index_string))
        # Internal sanity check: each lookup must recover the original index character.
        index_string_pair = (contraction_index_string, output_index_string)
        for i,lookup in enumerate(lookups):
            if index_string[i] != index_string_pair[lookup[0]][lookup[1]]:
                raise VorpyTensorProgrammerError('This should not happen')
        def component_indices_of (contracted_and_output_indices_tuple):
            # Internal sanity checks on the shape of the incoming index tuple pair.
            if len(lookups) != len(index_string):
                raise VorpyTensorProgrammerError('This should not happen')
            if len(contracted_and_output_indices_tuple) != 2:
                raise VorpyTensorProgrammerError('This should not happen')
            if len(contracted_and_output_indices_tuple[0]) != len(contraction_index_string):
                raise VorpyTensorProgrammerError('This should not happen')
            if len(contracted_and_output_indices_tuple[1]) != len(output_index_string):
                raise VorpyTensorProgrammerError('This should not happen')
            retval = tuple(contracted_and_output_indices_tuple[lookup[0]][lookup[1]] for lookup in lookups)
            return retval
        # Round-trip check: feeding the index strings themselves through the lookup
        # must reproduce this tensor's index string.
        test_output = ''.join(component_indices_of((contraction_index_string, output_index_string)))
        if test_output != index_string:
            raise VorpyTensorProgrammerError('This should not happen')
        return component_indices_of
    component_indices_functions = tuple(component_indices_function(index_string) for index_string in index_strings)
    # Product over all tensor arguments of the components they contribute at the given
    # (contracted, output) index pair.  Uses functools.reduce with operator.mul (rather
    # than e.g. np.prod) so that dtype=object components multiply correctly.
    def product_of_components_of_tensors (contracted_and_output_indices_tuple):
        return functools.reduce(
            operator.mul,
            tuple(
                component(tensor,component_indices_function(contracted_and_output_indices_tuple))
                for tensor,component_indices_function in zip(tensors,component_indices_functions)
            ),
            1,
        )
    # One output component: sum of the products over all contracted multiindices.
    def computed_component (output_component_indices):
        return sum(product_of_components_of_tensors((contraction_component_indices, output_component_indices)) for contraction_component_indices in multiindex_iterator(contraction_dims))
    # multiindex_iterator yields in row-major order, which matches ndarray's default
    # (C-order) layout, so the flat buffer of computed components lines up with output_dims.
    retval = np.ndarray(output_dims, dtype=dtype, buffer=np.array([computed_component(output_component_indices) for output_component_indices in multiindex_iterator(output_dims)]))
    # If the result is a 0-tensor, then coerce it to the scalar type.
    if retval.shape == tuple():
        retval = retval[tuple()]
    return retval
def tensor_power_of_vector (V, p):
    r"""
    Returns the pth tensor power of vector V.  This is a tensor having order p,
    which looks like V \otimes ... \otimes V (with p factors).  If p is zero, then
    this returns np.array(1) (an order-0 tensor).

    Raises VorpyTensorException if V is not a 1-tensor or if p is negative.
    Fix: both error paths previously raised `FancyException`, a name that is not
    defined anywhere, so hitting either would have produced a NameError instead
    of the intended exception.

    TODO: Implement this for tensors of arbitrary order (especially including 0-tensors).
    """
    # Equivalent to vorpy.tensor.order(V) without routing through the package self-import.
    V_order = len(np.shape(V))
    if V_order != 1:
        raise VorpyTensorException(f'Expected V to be a vector (i.e. a 1-tensor), but it was actually a {V_order}-tensor')
    if p < 0:
        raise VorpyTensorException(f'Expected p to be a nonnegative integer, but it was actually {p}')
    if p == 0:
        return np.array(1) # TODO: Should this be an actual scalar?
    elif p == 1:
        return V
    else:
        V_dim = V.shape[0]
        # Recurrence: V^{\otimes p} = V \otimes V^{\otimes (p-1)}, realized as an
        # outer product of V with the flattened (p-1)th power, then reshaped.
        V_to_the_p_minus_1 = tensor_power_of_vector(V, p-1)
        retval_shape = (V_dim,)*p
        return np.outer(V, V_to_the_p_minus_1.reshape(-1)).reshape(*retval_shape)