Skip to content

Commit

Permalink
Class API
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanKossaifi committed Oct 2, 2020
1 parent 1fe9902 commit 2fd1d1d
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 15 deletions.
11 changes: 6 additions & 5 deletions tensorly/decomposition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
tensor decomposition such as CANDECOMP-PARAFAC and Tucker.
"""

from .candecomp_parafac import parafac, non_negative_parafac, randomised_parafac, sample_khatri_rao
from ._tucker import tucker, partial_tucker, non_negative_tucker
from .candecomp_parafac import (parafac, non_negative_parafac, CPALS,
randomised_parafac, sample_khatri_rao)
from ._tucker import tucker, partial_tucker, non_negative_tucker, Tucker
from .robust_decomposition import robust_pca
from .mps_decomposition import matrix_product_state
from .parafac2 import parafac2
from .mps_decomposition import matrix_product_state, TensorTrain
from .parafac2 import parafac2, Parafac2ALS
from .symmetric_parafac import symmetric_parafac_power_iteration, symmetric_power_iteration

from ._cp_power import parafac_power_iteration, power_iteration, CPPower
8 changes: 4 additions & 4 deletions tensorly/decomposition/_tucker.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,16 +380,16 @@ def fit_transform(self, tensor):
random_state=self.random_state,
mask=self.mask,
verbose=self.verbose)
self.tucker_tensor_ = tucker_tensor
self.decomposition_ = tucker_tensor
return tucker_tensor[0]

def transform(self, tensor):
_, factors = self.tucker_tensor_
_, factors = self.decomposition_
return tlg.multi_mode_dot(tensor, factors, transpose=True)

def inverse_transform(self, tensor):
_, factors = self.tucker_tensor_
_, factors = self.decomposition_
return tlg.multi_mode_dot(tensor, factors)

def __repr__(self):
return f'Rank-{self.rank} Tucker decomposition.'
return f'Rank-{self.rank} Tucker decomposition via HOOI.'
7 changes: 4 additions & 3 deletions tensorly/decomposition/candecomp_parafac.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def randomised_parafac(tensor, rank, n_samples, n_iter_max=100, init='random', s
return KruskalTensor((weights, factors))


class CP:
class CPALS:
def __init__(self, rank, n_iter_max=100, tol=1e-08,
init='svd', svd='numpy_svd',
l2_reg=0,
Expand Down Expand Up @@ -829,7 +829,7 @@ def fit_transform(self, tensor):
KruskalTensor
decomposed tensor
"""
kruskal_tensor = parafac(tensor, rank=self.rank,
kruskal_tensor, errors = parafac(tensor, rank=self.rank,
n_iter_max=self.n_iter_max,
tol=self.tol,
init=self.init,
Expand All @@ -841,7 +841,8 @@ def fit_transform(self, tensor):
random_state=self.random_state,
verbose=self.verbose,
return_errors=True)
self.kruskal_tensor_ = kruskal_tensor
self.decomposition_ = kruskal_tensor
self.errors_ = errors
return kruskal_tensor

def fit(self, tensor):
Expand Down
4 changes: 2 additions & 2 deletions tensorly/decomposition/mps_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def __init__(self, rank, verbose=False):
self.verbose = verbose

def fit_transform(self, tensor):
self.tensor_train_ = matrix_product_state(tensor, rank=self.rank, verbose=self.verbose)
return self.tensor_train_
self.decomposition_ = matrix_product_state(tensor, rank=self.rank, verbose=self.verbose)
return self.decomposition_

def fit(self, tensor):
self.fit_transform(tensor)
Expand Down
129 changes: 129 additions & 0 deletions tensorly/decomposition/parafac2.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,132 @@ def parafac2(tensor_slices, rank, n_iter_max=100, init='random', svd='numpy_svd'
return parafac2_tensor, rec_errors
else:
return parafac2_tensor

class Parafac2ALS:

def __init__(self, rank, n_iter_max=100, init='random', svd='numpy_svd', normalize_factors=False,
tol=1e-8, random_state=None, verbose=False, return_errors=False, n_iter_parafac=5):
r"""PARAFAC2 decomposition [1]_ of a third order tensor via alternating least squares (ALS)
Computes a rank-`rank` PARAFAC2 decomposition of the third-order tensor defined by
`tensor_slices`. The decomposition is on the form :math:`(A [B_i] C)` such that the
i-th frontal slice, :math:`X_i`, of :math:`X` is given by
.. math::
X_i = B_i diag(a_i) C^T,
where :math:`diag(a_i)` is the diagonal matrix whose nonzero entries are equal to
the :math:`i`-th row of the :math:`I \times R` factor matrix :math:`A`, :math:`B_i`
is a :math:`J_i \times R` factor matrix such that the cross product matrix :math:`B_{i_1}^T B_{i_1}`
is constant for all :math:`i`, and :math:`C` is a :math:`K \times R` factor matrix.
To compute this decomposition, we reformulate the expression for :math:`B_i` such that
.. math::
B_i = P_i B,
where :math:`P_i` is a :math:`J_i \times R` orthogonal matrix and :math:`B` is a
:math:`R \times R` matrix.
An alternative formulation of the PARAFAC2 decomposition is that the tensor element
:math:`X_{ijk}` is given by
.. math::
X_{ijk} = \sum_{r=1}^R A_{ir} B_{ijr} C_{kr},
with the same constraints hold for :math:`B_i` as above.
Parameters
----------
tensor_slices : ndarray or list of ndarrays
Either a third order tensor or a list of second order tensors that may have different number of rows.
Note that the second mode factor matrices are allowed to change over the first mode, not the
third mode as some other implementations use (see note below).
rank : int
Number of components.
n_iter_max : int
Maximum number of iteration
init : {'svd', 'random', KruskalTensor, Parafac2Tensor}
Type of factor matrix initialization. See `initialize_factors`.
svd : str, default is 'numpy_svd'
function to use to compute the SVD, acceptable values in tensorly.SVD_FUNS
normalize_factors : bool (optional)
If True, aggregate the weights of each factor in a 1D-tensor
of shape (rank, ), which will contain the norms of the factors. Note that
there may be some inaccuracies in the component weights.
tol : float, optional
(Default: 1e-8) Relative reconstruction error tolerance. The
algorithm is considered to have found the global minimum when the
reconstruction error is less than `tol`.
random_state : {None, int, np.random.RandomState}
verbose : int, optional
Level of verbosity
return_errors : bool, optional
Activate return of iteration errors
n_iter_parafac: int, optional
Number of PARAFAC iterations to perform for each PARAFAC2 iteration
Returns
-------
Parafac2Tensor : (weight, factors, projection_matrices)
* weights : 1D array of shape (rank, )
all ones if normalize_factors is False (default),
weights of the (normalized) factors otherwise
* factors : List of factors of the CP decomposition element `i` is of shape
(tensor.shape[i], rank)
* projection_matrices : List of projection matrices used to create evolving
factors.
errors : list
A list of reconstruction errors at each iteration of the algorithms.
References
----------
.. [1] Kiers, H.A.L., ten Berge, J.M.F. and Bro, R. (1999),
PARAFAC2—Part I. A direct fitting algorithm for the PARAFAC2 model.
J. Chemometrics, 13: 275-294.
Notes
-----
This formulation of the PARAFAC2 decomposition is slightly different from the one in [1]_.
The difference lies in that here, the second mode changes over the first mode, whereas in
[1]_, the second mode changes over the third mode. We made this change since that means
that the function accept both lists of matrices and a single nd-array as input without
any reordering of the modes.
"""
self.rank = rank
self.n_iter_max=n_iter_max
self.init=init
self.svd=svd
self.normalize_factors=normalize_factors
self.tol=tol
self.random_state=random_state
self.verbose=verbose
self.return_errors=return_errors
self.n_iter_parafac = n_iter_parafac

def fit_transform(self, tensor):
self.decomposition_ = parafac2(tensor, rank = self.rank,
n_iter_max=self.n_iter_max,
init=self.init,
svd=self.svd,
normalize_factors=self.normalize_factors,
tol=self.tol,
random_state=self.random_state,
verbose=self.verbose,
return_errors=self.return_errors,
n_iter_parafac = self.n_iter_parafac)
return self.decomposition_

def fit(self, tensor):
self.fit_transform(tensor)
return self

def __repr__(self):
return f'Rank-{self.rank} PARAFAC2 decomposition.'



2 changes: 1 addition & 1 deletion tensorly/decomposition/symmetric_parafac.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def symmetric_power_iteration(tensor, n_repeat=10, n_iteration=10, verbose=False
modes = list(range(1, order))

for _ in range(n_repeat):
factor = tl.tensor(np.random.random_sample(size))
factor = tl.tensor(np.random.random_sample(size), **tl.context(tensor))

for _ in range(n_iteration):
for _ in range(order):
Expand Down

0 comments on commit 2fd1d1d

Please sign in to comment.