-
Notifications
You must be signed in to change notification settings - Fork 281
/
mps_tensor.py
117 lines (90 loc) · 3.9 KB
/
mps_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Core operations on tensors in Matrix Product State (MPS) format, also known as Tensor-Train (TT)
"""
import tensorly as tl
def _validate_mps_tensor(mps_tensor):
    """Validate a tensor given in MPS (Tensor-Train) format

    Checks that every factor is a third order tensor, that the ranks of
    consecutive factors match and that the boundary ranks are equal to 1.

    Parameters
    ----------
    mps_tensor : sequence of 3D-arrays
        MPS factors (tt-cores); the k-th factor is read as having shape
        ``(rank_k, dim_k, rank_{k+1})``

    Returns
    -------
    shape : tuple
        size of the second mode of each factor, in order
    rank : tuple
        size of the first mode of each factor, followed by the size of the
        last mode of the last factor (``len(mps_tensor) + 1`` entries)

    Raises
    ------
    ValueError
        if fewer than two factors are given, if a factor is not of order 3,
        if two consecutive factors have mismatching ranks, or if the first
        (respectively last) rank is not 1
    """
    factors = mps_tensor
    n_factors = len(factors)

    if n_factors < 2:
        raise ValueError('A Matrix-Product-State (ttrain) tensor should be composed '
                         'of at least two factors and a core. '
                         'However, {} factor was given.'.format(n_factors))

    rank = []
    shape = []
    previous_next_rank = None
    for index, factor in enumerate(factors):
        # Check the order *before* unpacking the shape: unpacking a shape that
        # does not have exactly three entries would itself raise an
        # uninformative ValueError and hide the message below.
        if tl.ndim(factor) != 3:
            raise ValueError('MPS expresses a tensor as third order factors (tt-cores).\n'
                             'However, tl.ndim(factors[{}]) = {}'.format(
                                 index, tl.ndim(factor)))

        current_rank, current_shape, next_rank = tl.shape(factor)

        # Consecutive factors should have matching ranks
        if index and previous_next_rank != current_rank:
            raise ValueError('Consecutive factors should have matching ranks\n'
                             ' -- e.g. tl.shape(factors[0])[2] == tl.shape(factors[1])[0]\n'
                             'However, tl.shape(factor[{}])[2] == {} but'
                             ' tl.shape(factor[{}])[0] == {} '.format(
                                 index - 1, previous_next_rank, index, current_rank))

        # Check for boundary conditions
        if index == 0 and current_rank != 1:
            raise ValueError('Boundary conditions dictate factor[0].shape[0] == 1. '
                             'However, got factor[0].shape[0] = {}.'.format(
                                 current_rank))
        if index == n_factors - 1 and next_rank != 1:
            # Report the index of the offending (last) factor, i.e. n_factors - 1
            raise ValueError('Boundary conditions dictate factor[-1].shape[2] == 1. '
                             'However, got factor[{}].shape[2] = {}.'.format(
                                 n_factors - 1, next_rank))

        shape.append(current_shape)
        rank.append(current_rank)
        previous_next_rank = next_rank

    # Add last rank (boundary condition)
    rank.append(next_rank)

    return tuple(shape), tuple(rank)
def mps_to_tensor(factors):
    """Rebuild the full tensor represented by a list of MPS/TT factors

    The factors are contracted from left to right: the running result is kept
    as a matrix whose columns index the rank shared with the next factor.

    Parameters
    ----------
    factors: list of 3D-arrays
        MPS factors (known as core in TT terminology)

    Returns
    -------
    output_tensor: ndarray
        tensor whose MPS/TT decomposition was given by 'factors'
    """
    # Size of the middle mode of every core: the shape of the full tensor
    mode_sizes = [core.shape[1] for core in factors]

    # Start from the first core, flattened to (first mode size, remaining entries)
    result = tl.reshape(factors[0], (mode_sizes[0], -1))

    for core in factors[1:]:
        left_rank, _, right_rank = core.shape
        # Matricize the core so that its first mode contracts with the
        # columns of the running result
        core_matrix = tl.reshape(core, (left_rank, -1))
        contracted = tl.dot(result, core_matrix)
        # Move everything but the trailing rank into the rows
        result = tl.reshape(contracted, (-1, right_rank))

    return tl.reshape(result, mode_sizes)
def mps_to_unfolded(factors, mode):
    """Unfolding of the full tensor represented by MPS (Tensor-Train) factors

    The full tensor is first reassembled from 'factors' and then unfolded
    along 'mode'.

    Parameters
    ----------
    factors: list of 3D-arrays
        MPS factors
    mode: int
        mode along which the reassembled tensor is unfolded

    Returns
    -------
    2-D array
        unfolding matrix at mode given by 'mode'
    """
    full_tensor = mps_to_tensor(factors)
    return tl.unfold(full_tensor, mode)
def mps_to_vec(factors):
    """Vectorization of the full tensor represented by MPS factors

    The full tensor is first reassembled from 'factors' and then vectorized.

    Parameters
    ----------
    factors: list of 3D-arrays
        MPS factors

    Returns
    -------
    1-D array
        vectorized format of tensor defined by 'factors'
    """
    full_tensor = mps_to_tensor(factors)
    return tl.tensor_to_vec(full_tensor)