Skip to content

Commit

Permalink
PR Review
Browse files Browse the repository at this point in the history
  • Loading branch information
caglayantuna committed Feb 25, 2021
1 parent d921527 commit dcd3ffc
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 194 deletions.
33 changes: 12 additions & 21 deletions tensorly/decomposition/_nn_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ._base_decomposition import DecompositionMixin
from ..random import random_cp
from ..base import unfold
from ..tenalg.proximal import soft_thresholding,hals_nnls_approx,hals_nnls_exact
from ..tenalg.proximal import soft_thresholding,hals_nnls
from ..cp_tensor import (cp_to_tensor, CPTensor,
unfolding_dot_khatri_rao, cp_norm,
cp_normalize, validate_cp_rank)
Expand Down Expand Up @@ -292,7 +292,7 @@ def non_negative_parafac(tensor, rank, n_iter_max=100, init='svd', svd='numpy_sv


def non_negative_parafac_hals(tensor, rank, n_iter_max=100, init="svd", svd='numpy_svd', tol=1e-7,
sparsity_coefficients=[], fixed_modes=[],hals='approx',
sparsity_coefficients=None, fixed_modes=None,exact=False,
verbose=False, return_errors=False):
"""
Non-negative CP decomposition via HALS
Expand Down Expand Up @@ -339,9 +339,6 @@ def non_negative_parafac_hals(tensor, rank, n_iter_max=100, init="svd", svd='num
toc: list
A list with accumulated time at each iterations
fixed_modes = [], normalize = [False, False, False],
verbose = True, return_errors = False)
References
----------
[1]: N. Gillis and F. Glineur, Accelerated Multiplicative Updates and
Expand All @@ -356,9 +353,9 @@ def non_negative_parafac_hals(tensor, rank, n_iter_max=100, init="svd", svd='num
norm_tensor = tl.norm(tensor, 2)

n_modes = tl.ndim(tensor)
if sparsity_coefficients is None or isinstance(sparsity_coefficients, float):
sparsity_coefficients = [sparsity_coefficients]*n_modes
if sparsity_coefficients == None or len(sparsity_coefficients) != n_modes:
sparsity_coefficients = [None for i in range(n_modes)]

if fixed_modes is None:
fixed_modes = []

Expand Down Expand Up @@ -389,16 +386,10 @@ def non_negative_parafac_hals(tensor, rank, n_iter_max=100, init="svd", svd='num
else:
mttkrp = unfolding_dot_khatri_rao(tensor, (None, factors), mode)


# Call the hals resolution with nnls, optimizing the current mode
if hals=='approx':
factors[mode] = tl.transpose(
hals_nnls_approx(tl.transpose(mttkrp), pseudo_inverse, tl.transpose(factors[mode]),
n_iter_max=100,sparsity_coefficient=sparsity_coefficients[mode])[0])
elif hals=='exact':
factors[mode] = tl.transpose(
hals_nnls_exact(tl.transpose(mttkrp), pseudo_inverse, tl.transpose(factors[mode]),
n_iter_max=5000)[0])
factors[mode] = tl.transpose(
hals_nnls(tl.transpose(mttkrp), pseudo_inverse, tl.transpose(factors[mode]),
n_iter_max=100,sparsity_coefficient=sparsity_coefficients[mode],exact=exact)[0])

if tol:
factors_norm = cp_norm((weights, factors))
Expand Down Expand Up @@ -557,7 +548,7 @@ def __repr__(self):



class CPNN_Hals(DecompositionMixin):
class CPNN_HALS(DecompositionMixin):
"""Non-Negative Candecomp-Parafac decomposition via Alternating-Least Square
Computes a rank-`rank` decomposition of `tensor` [1]_ such that,
Expand Down Expand Up @@ -635,12 +626,12 @@ def __init__(self, rank, n_iter_max=100, tol=1e-08,
fixed_modes=[],
normalize_factors=False,
sparsity=None,
hals='approx',
exact=False,
mask=None, svd_mask_repeats=5,
cvg_criterion='abs_rec_error',
random_state=None,
verbose=0):
self.hals=hals
self.exact=exact
self.rank = rank
self.n_iter_max = n_iter_max
self.tol = tol
Expand Down Expand Up @@ -673,7 +664,7 @@ def fit_transform(self, tensor):
tol=self.tol,
init=self.init,
svd=self.svd,
hals=self.hals,
exact=self.exact,
verbose=self.verbose,
return_errors=True)

Expand Down
197 changes: 24 additions & 173 deletions tensorly/tenalg/proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def procrustes(matrix):
"""
U, _, V = tl.partial_svd(matrix, n_eigenvecs=min(matrix.shape))
return tl.dot(U, V)
def hals_nnls_approx(UtM, UtU, V, n_iter_max=500,delta=10e-8,
sparsity_coefficient=None, normalize = False,nonzero_rows=False):
def hals_nnls(UtM, UtU, V=None, n_iter_max=500,delta=10e-8,
sparsity_coefficient=None, normalize = False,nonzero_rows=False,exact=False):

"""
Non Negative Least Squares (NNLS)
Expand Down Expand Up @@ -140,7 +140,7 @@ def hals_nnls_approx(UtM, UtU, V, n_iter_max=500,delta=10e-8,
Pre-computed product of the transposed of U and M, used in the update rule
UtU: r-by-r array
Pre-computed product of the transposed of U and U, used in the update rule
in_V: r-by-n initialization matrix (mutable)
V: r-by-n initialization matrix (mutable)
Initialized V array
By default, is initialized with one non-zero entry per column
corresponding to the closest column of U of the corresponding column of M.
Expand All @@ -156,7 +156,7 @@ def hals_nnls_approx(UtM, UtU, V, n_iter_max=500,delta=10e-8,
The coefficient controling the sparisty level in the objective function.
If set to None, the problem is solved unconstrained.
Default: None
nonzero_row: boolean
nonzero_rows: boolean
True if the lines of the V matrix can't be zero,
False if they can be zero
Default: False
Expand Down Expand Up @@ -188,60 +188,40 @@ def hals_nnls_approx(UtM, UtU, V, n_iter_max=500,delta=10e-8,
if V is None: # checks if V is empty
V = tl.solve(UtU, UtM)

V[V < 0] = 0
V=tl.clip(V,a_min=0,a_max=None)
# Scaling
scale = tl.sum(UtM * V) / tl.sum(
UtU * tl.dot(V, tl.transpose(V)))
V = tl.dot(scale, V)
else:
V = in_V
if exact:
n_iter_max=5000
delta = 10e-12

rho = 1
eps0 = 0
cnt = 1
eps = 1


while eps >= delta * eps0 and cnt <= 1 + 0.5* rho and cnt <= maxiter:
for iteration in range(n_iter_max):
nodelta = 0
for k in range(rank):

if UtU[k, k]:
if sparsity_coefficient is not None: # Modifying the objective function for sparsification
if tl.get_backend() == 'pytorch':
import torch
deltaV = torch.maximum((UtM[k, :] - UtU[k, :] @ V - sparsity_coefficient * tl.ones(n_col_M)) / UtU[k, k],
-V[k, :])
else:
deltaV = tl.max([(UtM[k, :] - UtU[k, :] @ V - sparsity_coefficient * tl.ones(n_col_M)) / UtU[k, k],

deltaV = tl.max([(UtM[k, :] - UtU[k, :] @ V - sparsity_coefficient * tl.ones(n_col_M)) / UtU[k, k],
-V[k, :]],axis=0)
if tl.get_backend()=='tensorflow':
import tensorflow as tf
V=tf.Variable(V,dtype='float')
V[k, :].assign(V[k, :] + deltaV)
else:
V[k, :] = V[k, :] + deltaV
tl.index_update(V, tl.index[k, :], V[k, :] + deltaV)

else: # without sparsity

if tl.get_backend() == 'pytorch':
import torch
deltaV = torch.maximum((UtM[k, :] - tl.dot(UtU[k, :], V)) / UtU[k, k],
-V[k, :])
else:
deltaV = tl.max([(UtM[k, :] - tl.dot(UtU[k, :], V)) / UtU[k, k],
deltaV = tl.max([(UtM[k, :] - tl.dot(UtU[k, :], V)) / UtU[k, k],
-V[k, :]], axis=0)
if tl.get_backend()=='tensorflow':
import tensorflow as tf
V=tf.Variable(V,dtype='float')
V[k, :].assign(V[k, :] + deltaV)
else:
V[k, :] = V[k, :] + deltaV
tl.index_update(V,tl.index[k, :],V[k, :] + deltaV)

nodelta = nodelta + tl.dot(deltaV, tl.transpose(deltaV))

# Safety procedure, if columns aren't allow to be zero
if nonzero_row and (V[k, :] == 0).all():
if nonzero_rows and tl.all(V[k, :] == 0):
V[k, :] = 1e-16 * tl.max(V)

elif nonzero_rows:
Expand All @@ -254,147 +234,18 @@ def hals_nnls_approx(UtM, UtU, V, n_iter_max=500,delta=10e-8,
else:
sqrt_n = 1/n_col_M ** (1/2)
V[k,:] = [sqrt_n for i in range(n_col_M)]
if cnt == 1:
if iteration == 1:
eps0 = nodelta

rho_up=tl.shape(V)[0]*tl.shape(V)[1]+tl.shape(V)[1]*rank
rho_down=tl.shape(V)[0]*rank+tl.shape(V)[0]
rho=1+(rho_up/rho_down)
eps = nodelta
cnt += 1

return V, eps, cnt, rho
def hals_nnls_exact(UtM, UtU, in_V, maxiter,delta=10e-12,sparsity_coefficient=None):
"""
Non Negative Least Squares (NNLS)
Computes an exact solution of a nonnegative least
squares problem (NNLS) with an exact block-coordinate descent scheme.
M is m by n, U is m by r, V is r by n.
All matrices are nonnegative componentwise.
The NNLS unconstrained problem, as defined in [1], solve the following problem:
min_{V >= 0} ||M-UV||_F^2
The matrix V is updated linewise.
The update rule of the k-th line of V (V[k,:]) for this resolution is::
V[k,:]_(j+1) = V[k,:]_(j) + (UtM[k,:] - UtU[k,:] V_(j))/UtU[k,k]
with j the update iteration.
This function is made for being used repetively inside an
outer-loop alternating algorithm, for instance for computing nonnegative
matrix Factorization or tensor factorization.
Parameters
----------
UtM: r-by-n array
Pre-computed product of the transposed of U and M, used in the update rule
UtU: r-by-r array
Pre-computed product of the transposed of U and U, used in the update rule
in_V: r-by-n initialization matrix (mutable)
Initialized V array
By default, is initialized with one non-zero entry per column
corresponding to the closest column of U of the corresponding column of M.
maxiter: Postivie integer
Upper bound on the number of iterations
Default: 500
delta : float in [0,1]
early stop criterion, while err_k > delta*err_0. Set small for
almost exact nnls solution, or larger (e.g. 1e-2) for inner loops
of a PARAFAC computation.
Default: 10e-12
sparsity_coefficient: float or None
The coefficient controling the sparisty level in the objective function.
If set to None, the problem is solved unconstrained.
Default: None
Returns
-------
V: array
a r-by-n nonnegative matrix \approx argmin_{V >= 0} ||M-UV||_F^2
eps: float
number of loops authorized by the error stop criterion
cnt: integer
final number of update iteration performed
References
----------
[1]: N. Gillis and F. Glineur, Accelerated Multiplicative Updates and
Hierarchical ALS Algorithms for Nonnegative Matrix Factorization,
Neural Computation 24 (4): 1085-1105, 2012.
[2] J. Eggert, and E. Korner. "Sparse coding and NMF."
2004 IEEE International Joint Conference on Neural Networks
(IEEE Cat. No. 04CH37541). Vol. 4. IEEE, 2004.
"""

r, n = tl.shape(UtM)
if not in_V.size: # checks if V is empty
V = tl.solve(UtU, UtM)

V[V < 0] = 0
# Scaling
scale = tl.sum(UtM * V) / tl.sum(
UtU * tl.dot(V, tl.transpose(V)))
V = tl.dot(scale, V)
else:
V = in_V.copy()

eps0 = 0
cnt = 1
eps = 1

while eps >= delta * eps0 and cnt <= maxiter:
nodelta = 0
for k in range(r):

if UtU[k, k] != 0:
if sparsity_coefficient != None: # Modifying the objective function for sparsification
if tl.get_backend() == 'pytorch':
import torch
deltaV = torch.maximum(
(UtM[k, :] - UtU[k, :] @ V - sparsity_coefficient * tl.ones(n)) / UtU[k, k],
-V[k, :])
else:
deltaV = tl.max([(UtM[k, :] - UtU[k, :] @ V - sparsity_coefficient * tl.ones(n)) / UtU[k, k],
-V[k, :]], axis=0)
if tl.get_backend() == 'tensorflow':
import tensorflow as tf
V = tf.Variable(V, dtype='float')
V[k, :].assign(V[k, :] + deltaV)
else:
V[k, :] = V[k, :] + deltaV

else: # without sparsity

if tl.get_backend() == 'pytorch':
import torch
deltaV = torch.maximum((UtM[k, :] - tl.dot(UtU[k, :], V)) / UtU[k, k],
-V[k, :])
else:
deltaV = tl.max([(UtM[k, :] - tl.dot(UtU[k, :], V)) / UtU[k, k],
-V[k, :]], axis=0)
if tl.get_backend() == 'tensorflow':
import tensorflow as tf
V = tf.Variable(V, dtype='float')
V[k, :].assign(V[k, :] + deltaV)
else:
V[k, :] = V[k, :] + deltaV

nodelta = nodelta + tl.dot(deltaV, tl.transpose(deltaV))


if cnt == 1:
eps0 = nodelta

eps = nodelta
cnt += 1

return V, eps, cnt
if exact:
if eps < delta * eps0:
break
else:
if eps < delta * eps0 or iteration > 1 + 0.5 * rho:
break

return V, eps, iteration, rho

0 comments on commit dcd3ffc

Please sign in to comment.