Skip to content

Commit

Permalink
Merge pull request #22056 from Upabjojr/split_multiple_contractions_a…
Browse files Browse the repository at this point in the history
…dd_onearray

Convert array to matrix: preserving dimension fix
  • Loading branch information
Upabjojr committed Sep 9, 2021
2 parents 5771e11 + 0542b65 commit 9cb75ab
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 27 deletions.
70 changes: 57 additions & 13 deletions sympy/tensor/array/expressions/array_expressions.py
Expand Up @@ -4,6 +4,8 @@
from itertools import accumulate
from typing import Optional, List, Dict

import typing

from sympy import Expr, ImmutableDenseNDimArray, S, Symbol, ZeroMatrix, Basic, tensorproduct, Add, permutedims, \
Tuple, tensordiagonal, Lambda, Dummy, Function, MatrixExpr, NDimArray, Indexed, IndexedBase, default_sort_key, \
tensorcontraction, diagonalize_vector, Mul
Expand Down Expand Up @@ -923,20 +925,30 @@ def split_multiple_contractions(self):
editor = _EditArrayContraction(self)

contraction_indices = self.contraction_indices
if isinstance(self.expr, ArrayTensorProduct):
args = list(self.expr.args)
else:
args = [self.expr]
# TODO: unify API, best location in ArrayTensorProduct
subranks = [get_rank(i) for i in args]
# TODO: unify API
mapping = _get_mapping_from_subranks(subranks)
reverse_mapping = {v: k for k, v in mapping.items()}

onearray_insert = []

for indl, links in enumerate(contraction_indices):
if len(links) <= 2:
continue

# Check multiple contractions:
#
# Examples:
#
# * `A_ij b_j0 C_jk` ===> `A*DiagMatrix(b)*C \otimes OneArray(1)` with permutation (1 2)
#
# Care for:
# - matrix being diagonalized (i.e. `A_ii`)
# - vectors being diagonalized (i.e. `a_i0`)

# Multiple contractions can be split into matrix multiplications if
# not more than three arguments are non-diagonals or non-vectors.
#
# Vectors get diagonalized while diagonal matrices remain diagonal.
# The non-diagonal matrices can be at the beginning or at the end
# of the final matrix multiplication line.

positions = editor.get_mapping_for_index(indl)

# Also consider the case of diagonal matrices being contracted:
Expand All @@ -945,11 +957,12 @@ def split_multiple_contractions(self):
not_vectors: Tuple[_ArgE, int] = []
vectors: Tuple[_ArgE, int] = []
for arg_ind, rel_ind in positions:
mat = args[arg_ind]
other_arg_pos = 1-rel_ind
other_arg_abs = reverse_mapping[arg_ind, other_arg_pos]
arg = editor.args_with_ind[arg_ind]
if (((1 not in mat.shape)) or
mat = arg.element
abs_arg_start, abs_arg_end = editor.get_absolute_range(arg)
other_arg_pos = 1-rel_ind
other_arg_abs = abs_arg_start + other_arg_pos
if ((1 not in mat.shape) or
((current_dimension == 1) is True and mat.shape != (1, 1)) or
any(other_arg_abs in l for li, l in enumerate(contraction_indices) if li != indl)
):
Expand All @@ -976,9 +989,14 @@ def split_multiple_contractions(self):
new_index = editor.get_new_contraction_index()
assert v.indices.index(None) == 1 - rel_ind
v.indices[v.indices.index(None)] = new_index
onearray_insert.append(v)

last_vec, rel_ind = vectors_to_loop[-1]
last_vec.indices[rel_ind] = new_index

for v in onearray_insert:
editor.insert_after(v, _ArgE(OneArray(1), [None]))

return editor.to_array_contraction()

def flatten_contraction_of_diagonal(self):
Expand Down Expand Up @@ -1465,6 +1483,32 @@ def track_permutation_merge(self, destination: _ArgE, from_element: _ArgE):
self._track_permutation[index_destination].extend(self._track_permutation[index_element])
self._track_permutation.pop(index_element)

def get_absolute_free_range(self, arg: _ArgE) -> typing.Tuple[int, int]:
"""
Return the range of the free indices of the arg as absolute positions
among all free indices.
"""
counter = 0
for arg_with_ind in self.args_with_ind:
number_free_indices = len([i for i in arg_with_ind.indices if i is None])
if arg_with_ind == arg:
return counter, counter + number_free_indices
counter += number_free_indices
raise IndexError("argument not found")

def get_absolute_range(self, arg: _ArgE) -> typing.Tuple[int, int]:
"""
Return the absolute range of indices for arg, disregarding dummy
indices.
"""
counter = 0
for arg_with_ind in self.args_with_ind:
number_indices = len(arg_with_ind.indices)
if arg_with_ind == arg:
return counter, counter + number_indices
counter += number_indices
raise IndexError("argument not found")


def get_rank(expr):
if isinstance(expr, (MatrixExpr, MatrixElement)):
Expand Down
Expand Up @@ -273,8 +273,8 @@ def test_arrayexpr_split_multiple_contractions():
X = MatrixSymbol("X", k, k)

cg = ArrayContraction(ArrayTensorProduct(A.T, a, b, b.T, (A*X*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9))
assert cg.split_multiple_contractions().dummy_eq(ArrayContraction(ArrayTensorProduct(DiagMatrix(a), (A*X*b).applyfunc(cos), A.T, b, b.T), (0, 2), (1, 5), (3, 7, 8)))
# assert recognize_matrix_expression(cg)
expected = ArrayContraction(ArrayTensorProduct(A.T, DiagMatrix(a), OneArray(1), b, b.T, (A*X*b).applyfunc(cos)), (1, 3), (2, 9), (6, 7, 10))
assert cg.split_multiple_contractions().dummy_eq(expected)

# Check no overlap of lines:

Expand Down
Expand Up @@ -170,8 +170,8 @@ def test_arrayexpr_convert_array_to_diagonalized_vector():

cg = ArrayDiagonal(ArrayTensorProduct(I, x, A, B), (1, 2), (5, 6))
assert _array_diag2contr_diagmatrix(cg) == ArrayDiagonal(ArrayContraction(ArrayTensorProduct(I, OneArray(1), A, B, DiagMatrix(x)), (1, 7)), (5, 6))
# TODO: not yet working
# assert convert_array_to_matrix(cg)
# TODO: this is returning a wrong result:
# convert_array_to_matrix(cg)

cg = ArrayDiagonal(ArrayTensorProduct(x, I1), (1, 2))
assert isinstance(cg, ArrayDiagonal)
Expand Down Expand Up @@ -201,33 +201,35 @@ def test_arrayexpr_convert_array_to_diagonalized_vector():
assert convert_array_to_matrix(cg) == A * a

cg = ArrayContraction(ArrayTensorProduct(A, a, B), (1, 2, 4))
assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), B), (1, 2), (3, 4))
assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), OneArray(1), B), (1, 2), (3, 5))
assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * B

cg = ArrayContraction(ArrayTensorProduct(A, a, B), (0, 2, 4))
assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), B), (0, 2), (3, 4))
assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), OneArray(1), B), (0, 2), (3, 5))
assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * B

cg = ArrayContraction(ArrayTensorProduct(A, a, b, a.T, B), (0, 2, 4, 7, 9))
assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), DiagMatrix(b),
DiagMatrix(a), B),
(0, 2), (3, 4), (5, 7), (6, 9))
assert convert_array_to_matrix(cg).doit() == A.T * DiagMatrix(a) * DiagMatrix(b) * DiagMatrix(a) * B.T
assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), OneArray(1),
DiagMatrix(b), OneArray(1), DiagMatrix(a), OneArray(1), B),
(0, 2), (3, 5), (6, 9), (8, 12))
assert convert_array_to_matrix(cg) == A.T * DiagMatrix(a) * DiagMatrix(b) * DiagMatrix(a) * B.T

cg = ArrayContraction(ArrayTensorProduct(I1, I1, I1), (1, 2, 4))
assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(I1, I1, I1), (1, 2), (3, 4))
assert convert_array_to_matrix(cg).doit() == Identity(1)
assert cg.split_multiple_contractions() == ArrayContraction(ArrayTensorProduct(I1, I1, OneArray(1), I1), (1, 2), (3, 5))
assert convert_array_to_matrix(cg) == 1

cg = ArrayContraction(ArrayTensorProduct(I, I, I, I, A), (1, 2, 8), (5, 6, 9))
assert convert_array_to_matrix(cg.split_multiple_contractions()).doit() == A

cg = ArrayContraction(ArrayTensorProduct(A, a, C, a, B), (1, 2, 4), (5, 6, 8))
expected = ArrayContraction(ArrayTensorProduct(DiagMatrix(a), DiagMatrix(a), C, A, B), (0, 4), (1, 7), (2, 5), (3, 8))
expected = ArrayContraction(ArrayTensorProduct(A, DiagMatrix(a), OneArray(1), C, DiagMatrix(a), OneArray(1), B), (1, 3), (2, 5), (6, 7), (8, 10))
assert cg.split_multiple_contractions() == expected
assert convert_array_to_matrix(cg) == A * DiagMatrix(a) * C * DiagMatrix(a) * B

cg = ArrayContraction(ArrayTensorProduct(a, I1, b, I1, (a.T*b).applyfunc(cos)), (1, 2, 8), (5, 6, 9))
assert cg.split_multiple_contractions().dummy_eq(ArrayContraction(ArrayTensorProduct((a.T * b).applyfunc(cos), I1, I1, a, b), (0, 2), (1, 4), (3, 7), (5, 9)))
expected = ArrayContraction(ArrayTensorProduct(a, I1, OneArray(1), b, I1, OneArray(1), (a.T*b).applyfunc(cos)),
(1, 3), (2, 10), (6, 8), (7, 11))
assert cg.split_multiple_contractions().dummy_eq(expected)
assert convert_array_to_matrix(cg).doit().dummy_eq(MatMul(a, (a.T * b).applyfunc(cos), b.T))


Expand Down

0 comments on commit 9cb75ab

Please sign in to comment.