Class API for decompositions
JeanKossaifi committed Oct 3, 2020
1 parent 3402998 commit 3d7fa48
Showing 7 changed files with 181 additions and 76 deletions.
6 changes: 3 additions & 3 deletions tensorly/decomposition/__init__.py
@@ -3,11 +3,11 @@
tensor decomposition such as CANDECOMP-PARAFAC and Tucker.
"""

from .candecomp_parafac import (parafac, non_negative_parafac, CPALS,
from .candecomp_parafac import (parafac, non_negative_parafac, CP, RandomisedCP,
randomised_parafac, sample_khatri_rao)
from ._tucker import tucker, partial_tucker, non_negative_tucker, Tucker
from .robust_decomposition import robust_pca
from .mps_decomposition import matrix_product_state, TensorTrain
from .parafac2 import parafac2, Parafac2ALS
from .symmetric_parafac import symmetric_parafac_power_iteration, symmetric_power_iteration
from .parafac2 import parafac2, Parafac2
from .symmetric_parafac import symmetric_parafac_power_iteration, symmetric_power_iteration, SymmetricCP
from ._cp_power import parafac_power_iteration, power_iteration, CPPower
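With these exports, each decomposition becomes available as an estimator-style class. Below is a minimal usage sketch (not part of this commit) of the shared pattern; the random tensor, shapes and ranks are illustrative, and it assumes DecompositionMixin supplies the fit method removed from the individual classes in this diff:

import numpy as np
import tensorly as tl
from tensorly.decomposition import CP, Tucker

tensor = tl.tensor(np.random.random_sample((8, 9, 10)))  # illustrative data

cp = CP(rank=3).fit(tensor)                # fit is assumed to return the estimator itself
kruskal_tensor = cp.decomposition_         # result stored by fit_transform

tucker_tensor = Tucker(rank=[3, 3, 3]).fit_transform(tensor)  # or get the result directly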
7 changes: 2 additions & 5 deletions tensorly/decomposition/_cp_power.py
@@ -1,4 +1,5 @@
import tensorly as tl
from .._base_classes import DecompositionMixin
from tensorly.tenalg import outer
from tensorly.metrics.regression import standard_deviation
import numpy as np
@@ -113,7 +114,7 @@ def parafac_power_iteration(tensor, rank, n_repeat=10, n_iteration=10, verbose=0



class CPPower:
class CPPower(DecompositionMixin):
def __init__(self, rank, n_repeat=10, n_iteration=10, verbose=0):
"""CP Decomposition via Robust Tensor Power Iteration
@@ -165,9 +166,5 @@ def fit_transform(self, tensor):
self.decomposition_ = kruskal_tensor
return kruskal_tensor

def fit(self, tensor):
self.fit_transform(tensor)
return self

def __repr__(self):
return f'Rank-{self.rank} CP decomposition via Robust Tensor Power Iteration.'
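For reference, a small usage sketch of the class above (not part of the commit); the tensor and rank are illustrative, and fit/decomposition_ follow the DecompositionMixin pattern assumed throughout this diff:

import numpy as np
import tensorly as tl
from tensorly.decomposition import CPPower

tensor = tl.tensor(np.random.random_sample((10, 10, 10)))

# n_repeat initialisations, n_iteration power-iteration steps each (constructor defaults shown above)
model = CPPower(rank=3, n_repeat=10, n_iteration=10)
kruskal_tensor = model.fit_transform(tensor)   # also stored as model.decomposition_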
60 changes: 38 additions & 22 deletions tensorly/decomposition/_tucker.py
@@ -1,4 +1,5 @@
import tensorly as tl
from .._base_classes import DecompositionMixin
from ..base import unfold
from ..tenalg import multi_mode_dot, mode_dot
from ..tucker_tensor import tucker_to_tensor
@@ -310,26 +311,34 @@ def non_negative_tucker(tensor, rank, n_iter_max=10, init='svd', tol=10e-5,
return nn_core, nn_factors


class Tucker:
def __init__(self, rank=None, n_iter_max=100,
class Tucker(DecompositionMixin):
def __init__(self, rank=None, n_iter_max=100, non_negative=False,
init='svd', svd='numpy_svd', tol=10e-5,
random_state=None, mask=None, verbose=False):
"""Tucker decomposition via Higher Order Orthogonal Iteration (HOI)
"""Tucker decomposition
Decomposes `tensor` into a Tucker decomposition:
``tensor = [| core; factors[0], ...factors[-1] |]`` [1]_
Uses Higher Order Orthogonal Iteration (HOI) if non_negative=False
and iterative multiplicative updates if non_negative=True.
Parameters
----------
tensor : ndarray
ranks : None or int list
size of the core tensor, ``(len(ranks) == tensor.ndim)``
rank : None or int
number of components
non_negative : bool, default is False
if True, uses a non-negative Tucker via iterative multiplicative updates
otherwise, uses a Higher-Order Orthogonal Iteration.
n_iter_max : int
maximum number of iterations
init : {'svd', 'random'}, optional
svd : str, default is 'numpy_svd'
ignored if non_negative is True
function to use to compute the SVD,
acceptable values in tensorly.SVD_FUNS
tol : float, optional
@@ -353,6 +362,7 @@ def __init__(self, rank=None, n_iter_max=100,
SIAM REVIEW, vol. 51, n. 3, pp. 455-500, 2009.
"""
self.rank = rank
self.non_negative = non_negative
self.n_iter_max = n_iter_max
self.init = init
self.svd = svd
@@ -361,29 +371,35 @@
self.mask = mask
self.verbose = verbose

def fit(self, tensor):
self.fit_transform(tensor)
return self

def fit_transform(self, tensor):
tucker_tensor = tucker(tensor, rank=self.rank,
n_iter_max=self.n_iter_max,
init=self.init,
svd=self.svd,
tol=self.tol,
random_state=self.random_state,
mask=self.mask,
verbose=self.verbose)
if self.non_negative:
if self.mask is not None:
raise ValueError('mask is currently not supported for non-negative Tucker.')
tucker_tensor = non_negative_tucker(tensor, rank=self.rank,
n_iter_max=self.n_iter_max,
init=self.init,
tol=self.tol,
random_state=self.random_state,
verbose=self.verbose)
else:
tucker_tensor = tucker(tensor, rank=self.rank,
n_iter_max=self.n_iter_max,
init=self.init,
svd=self.svd,
tol=self.tol,
random_state=self.random_state,
mask=self.mask,
verbose=self.verbose)
self.decomposition_ = tucker_tensor
return tucker_tensor[0]
return tucker_tensor

def transform(self, tensor):
_, factors = self.decomposition_
return tlg.multi_mode_dot(tensor, factors, transpose=True)
# def transform(self, tensor):
# _, factors = self.decomposition_
# return tlg.multi_mode_dot(tensor, factors, transpose=True)

def inverse_transform(self, tensor):
_, factors = self.decomposition_
return tlg.multi_mode_dot(tensor, factors)
# def inverse_transform(self, tensor):
# _, factors = self.decomposition_
# return tlg.multi_mode_dot(tensor, factors)

def __repr__(self):
return f'Rank-{self.rank} Tucker decomposition via HOOI.'
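A usage sketch for the two code paths above (not part of the commit); the tensor and ranks are illustrative:

import numpy as np
import tensorly as tl
from tensorly.decomposition import Tucker

tensor = tl.tensor(np.random.random_sample((12, 11, 10)))

# default path: Higher-Order Orthogonal Iteration
tucker_tensor = Tucker(rank=[4, 4, 4]).fit_transform(tensor)

# non-negative path: dispatches to non_negative_tucker; note that mask is rejected here
nn_tucker_tensor = Tucker(rank=[4, 4, 4], non_negative=True, n_iter_max=200).fit_transform(tensor)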
118 changes: 97 additions & 21 deletions tensorly/decomposition/candecomp_parafac.py
@@ -2,6 +2,7 @@
import warnings

import tensorly as tl
from .._base_classes import DecompositionMixin
from ..random import check_random_state, random_kruskal
from ..base import unfold
from ..kruskal_tensor import (kruskal_to_tensor, KruskalTensor,
@@ -629,7 +630,7 @@ def sample_khatri_rao(matrices, n_samples, skip_matrix=None,


def randomised_parafac(tensor, rank, n_samples, n_iter_max=100, init='random', svd='numpy_svd',
tol=10e-9, max_stagnation=20, random_state=None, verbose=1):
tol=10e-9, max_stagnation=20, return_errors=False, random_state=None, verbose=1):
"""Randomised CP decomposition via sampled ALS
Parameters
@@ -651,6 +652,8 @@ def randomised_parafac(tensor, rank, n_samples, n_iter_max=100, init='random', s
if not zero, the maximum allowed number
of iterations with no decrease in fit
random_state : {None, int, np.random.RandomState}, default is None
return_errors : bool, default is False
if True, return a list of all errors
verbose : int, optional
level of verbosity
@@ -711,10 +714,13 @@ def randomised_parafac(tensor, rank, n_samples, n_iter_max=100, init='random', s
print('converged in {} iterations.'.format(iteration))
break

return KruskalTensor((weights, factors))
if return_errors:
return KruskalTensor((weights, factors)), rec_errors
else:
return KruskalTensor((weights, factors))


class CPALS:
class CP(DecompositionMixin):
def __init__(self, rank, n_iter_max=100, tol=1e-08,
init='svd', svd='numpy_svd',
l2_reg=0,
@@ -728,7 +734,7 @@ def __init__(self, rank, n_iter_max=100, tol=1e-08,
cvg_criterion='abs_rec_error',
random_state=None,
verbose=0):
"""Candecomp-Parafac decomposition
"""Candecomp-Parafac decomposition via Alternating-Least Square
Computes a rank-`rank` decomposition of `tensor` [1]_ such that,
@@ -743,6 +749,7 @@ def __init__(self, rank, n_iter_max=100, tol=1e-08,
Maximum number of iteration
init : {'svd', 'random'}, optional
Type of factor matrix initialization. See `initialize_factors`.
non_negative : bool, default is False
if True, fit a non-negative CP decomposition (dispatches to non_negative_parafac)
svd : str, default is 'numpy_svd'
function to use to compute the SVD, acceptable values in tensorly.SVD_FUNS
normalize_factors : if True, aggregate the weights of each factor in a 1D-tensor
@@ -806,6 +813,7 @@ def __init__(self, rank, n_iter_max=100, tol=1e-08,
self.l2_reg = l2_reg
self.init = init
self.linesearch = linesearch
self.non_negative = non_negative
self.svd = svd
self.normalize_factors = normalize_factors
self.orthogonalise = orthogonalise
@@ -829,25 +837,93 @@ def fit_transform(self, tensor):
KruskalTensor
decomposed tensor
"""
kruskal_tensor, errors = parafac(tensor, rank=self.rank,
n_iter_max=self.n_iter_max,
tol=self.tol,
init=self.init,
svd=self.svd,
normalize_factors=self.normalize_factors,
orthogonalise=self.orthogonalise,
mask=self.mask,
cvg_criterion=self.cvg_criterion,
random_state=self.random_state,
verbose=self.verbose,
return_errors=True)
if self.non_negative:
kruskal_tensor, errors = non_negative_parafac(tensor, rank=self.rank,
n_iter_max=self.n_iter_max,
tol=self.tol,
init=self.init,
svd=self.svd,
normalize_factors=self.normalize_factors,
orthogonalise=self.orthogonalise,
mask=self.mask,
cvg_criterion=self.cvg_criterion,
random_state=self.random_state,
verbose=self.verbose,
return_errors=True)
else:
kruskal_tensor, errors = parafac(tensor, rank=self.rank,
n_iter_max=self.n_iter_max,
tol=self.tol,
init=self.init,
svd=self.svd,
normalize_factors=self.normalize_factors,
orthogonalise=self.orthogonalise,
mask=self.mask,
linesearch = self.linesearch,
cvg_criterion=self.cvg_criterion,
random_state=self.random_state,
verbose=self.verbose,
return_errors=True)
self.decomposition_ = kruskal_tensor
self.errors_ = errors
return kruskal_tensor

def fit(self, tensor):
self.fit_transform(tensor)
return self
return self.decomposition_

def __repr__(self):
return f'Rank-{self.rank} CP decomposition.'
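A usage sketch for the CP class defined above (not part of the commit); data and rank are illustrative:

import numpy as np
import tensorly as tl
from tensorly.decomposition import CP

tensor = tl.tensor(np.random.random_sample((10, 8, 6)))

cp = CP(rank=4, n_iter_max=200, tol=1e-8)
kruskal_tensor = cp.fit_transform(tensor)   # KruskalTensor, also kept in cp.decomposition_
print(cp)                                   # Rank-4 CP decomposition.
final_error = cp.errors_[-1]                # reconstruction errors collected via return_errors=True

# with non_negative=True the same class dispatches to non_negative_parafac
nn_cp = CP(rank=4, non_negative=True).fit(tensor)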



class RandomisedCP(DecompositionMixin):

def __init__(self, rank, n_samples, n_iter_max=100, init='random', svd='numpy_svd',
tol=10e-9, max_stagnation=20, random_state=None, verbose=1):
"""Randomised CP decomposition via sampled ALS
Parameters
----------
tensor : ndarray
rank : int
number of components
n_samples : int
number of samples per ALS step
n_iter_max : int
maximum number of iterations
init : {'svd', 'random'}, optional
svd : str, default is 'numpy_svd'
function to use to compute the SVD, acceptable values in tensorly.SVD_FUNS
tol : float, optional
tolerance: the algorithm stops when the variation in
the reconstruction error is less than the tolerance
max_stagnation : int, optional, default is 20
if not zero, the maximum allowed number
of iterations with no decrease in fit
random_state : {None, int, np.random.RandomState}, default is None
verbose : int, optional
level of verbosity
Returns
-------
factors : ndarray list
list of factors of the CP decomposition
element `i` is of shape ``(tensor.shape[i], rank)``
References
----------
.. [3] Casey Battaglino, Grey Ballard and Tamara G. Kolda,
"A Practical Randomized CP Tensor Decomposition",
"""
self.rank=rank
self.n_samples=n_samples
self.n_iter_max=n_iter_max
self.init=init
self.svd=svd
self.tol=tol
self.max_stagnation=max_stagnation
self.random_state=random_state
self.verbose=verbose

def fit_transform(self, tensor):
self.decomposition_, self.errors_ = randomised_parafac(tensor, rank=self.rank, n_samples=self.n_samples,
n_iter_max=self.n_iter_max, init=self.init, svd=self.svd, tol=self.tol, return_errors=True,
max_stagnation=self.max_stagnation, random_state=self.random_state, verbose=self.verbose)
return self.decomposition_
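And a sketch for RandomisedCP (not part of the commit); n_samples and the tensor are illustrative:

import numpy as np
import tensorly as tl
from tensorly.decomposition import RandomisedCP

tensor = tl.tensor(np.random.random_sample((20, 20, 20)))

model = RandomisedCP(rank=5, n_samples=200, random_state=0, verbose=0)
kruskal_tensor = model.fit_transform(tensor)   # model.errors_ is populated via return_errors=True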
13 changes: 2 additions & 11 deletions tensorly/decomposition/mps_decomposition.py
@@ -1,4 +1,5 @@
import tensorly as tl
from .._base_classes import DecompositionMixin

def matrix_product_state(input_tensor, rank, verbose=False):
"""MPS decomposition via recursive SVD
@@ -83,7 +84,7 @@ def matrix_product_state(input_tensor, rank, verbose=False):
return factors


class TensorTrain:
class TensorTrain(DecompositionMixin):
def __init__(self, rank, verbose=False):
"""MPS decomposition via recursive SVD
@@ -115,13 +116,3 @@ def __init__(self, rank, verbose=False):
def fit_transform(self, tensor):
self.decomposition_ = matrix_product_state(tensor, rank=self.rank, verbose=self.verbose)
return self.decomposition_

def fit(self, tensor):
self.fit_transform(tensor)
return self

def __repr__(self):
return f'Rank-{self.rank} Tensor-Train decomposition.'
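A usage sketch for the TensorTrain class (not part of the commit); the tensor and the TT-ranks are illustrative (boundary ranks are 1):

import numpy as np
import tensorly as tl
from tensorly.decomposition import TensorTrain

tensor = tl.tensor(np.random.random_sample((4, 5, 6, 7)))

tt = TensorTrain(rank=[1, 3, 4, 3, 1])        # one TT-rank per split, 1 at the boundaries
factors = tt.fit_transform(tensor)            # list of third-order cores, also in tt.decomposition_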



17 changes: 3 additions & 14 deletions tensorly/decomposition/parafac2.py
@@ -1,4 +1,5 @@
import tensorly as tl
from .._base_classes import DecompositionMixin
from tensorly.random import random_parafac2
from tensorly import backend as T
from . import parafac
@@ -279,7 +280,8 @@ def parafac2(tensor_slices, rank, n_iter_max=100, init='random', svd='numpy_svd'
else:
return parafac2_tensor

class Parafac2ALS:

class Parafac2(DecompositionMixin):

def __init__(self, rank, n_iter_max=100, init='random', svd='numpy_svd', normalize_factors=False,
tol=1e-8, random_state=None, verbose=False, n_iter_parafac=5):
@@ -355,9 +357,6 @@ def __init__(self, rank, n_iter_max=100, init='random', svd='numpy_svd', normali
* projection_matrices : List of projection matrices used to create evolving
factors.
errors : list
A list of reconstruction errors at each iteration of the algorithms.
References
----------
.. [1] Kiers, H.A.L., ten Berge, J.M.F. and Bro, R. (1999),
@@ -394,13 +393,3 @@ def fit_transform(self, tensor):
return_errors=True,
n_iter_parafac = self.n_iter_parafac)
return self.decomposition_

def fit(self, tensor):
self.fit_transform(tensor)
return self

def __repr__(self):
return f'Rank-{self.rank} PARAFAC2 decomposition.'
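Finally, a usage sketch for the Parafac2 class (not part of the commit); the slices below are illustrative random matrices sharing their second dimension, as PARAFAC2 expects:

import numpy as np
import tensorly as tl
from tensorly.decomposition import Parafac2

slices = [tl.tensor(np.random.random_sample((n_rows, 8))) for n_rows in (15, 12, 18)]

model = Parafac2(rank=3, n_iter_max=100, random_state=0)
parafac2_tensor = model.fit_transform(slices)  # (weights, factors, projections), also model.decomposition_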


