
Commit

keep the same memory for np.dot(W, H)
TomDLT committed Nov 15, 2017
1 parent 1145083 commit 7ccc2be
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions sklearn/decomposition/nmf.py
@@ -231,7 +231,7 @@ def _beta_divergence(X, W, H, beta, square_root=False):
     return res


-def _special_dot_X(W, H, X):
+def _special_dot_X(W, H, X, out=None):
     """Computes np.dot(W, H) in a special way:
     - If X is sparse, np.dot(W, H) is computed only where X is non zero,
@@ -247,11 +247,11 @@ def _special_dot_X(W, H, X):
         WH = sp.coo_matrix((dot_vals, (ii, jj)), shape=X.shape)
         return WH.tocsr()
     elif isinstance(X, np.ma.masked_array):
-        WH = np.ma.masked_array(np.dot(W, H), mask=X.mask)
-        WH.unshare_mask()
+        WH = np.ma.masked_array(np.dot(W, H, out=out), mask=X.mask)
+        WH._sharedmask = False
         return WH
     else:
-        return np.dot(W, H)
+        return np.dot(W, H, out=out)


 def _safe_dot(X, Ht):
@@ -608,7 +608,8 @@ def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,


 def _multiplicative_update_w(X, W, H, beta_loss, l1_reg_W, l2_reg_W, gamma,
-                             H_sum=None, HHt=None, XHt=None, update_H=True):
+                             WH, H_sum=None, HHt=None, XHt=None,
+                             update_H=True):
     """update W in Multiplicative Update NMF"""
     X_mask = X.mask if isinstance(X, np.ma.masked_array) else False

@@ -629,13 +630,13 @@ def _multiplicative_update_w(X, W, H, beta_loss, l1_reg_W, l2_reg_W, gamma,
                 HHt = np.dot(H, H.T)
             denominator = np.dot(W, HHt)
         else:
-            WH = _special_dot_X(W, H, X)
+            WH = _special_dot_X(W, H, X, out=WH)
             denominator = _safe_dot(WH, H.T)

     else:
         # Numerator
         # if X is sparse, compute WH only where X is non zero
-        WH_safe_X = _special_dot_X(W, H, X)
+        WH_safe_X = _special_dot_X(W, H, X, out=WH)
         if sp.issparse(X):
             WH_safe_X_data = WH_safe_X.data
             X_data = X.data
@@ -718,7 +719,8 @@ def _multiplicative_update_w(X, W, H, beta_loss, l1_reg_W, l2_reg_W, gamma,
     return delta_W, H_sum, HHt, XHt


-def _multiplicative_update_h(X, W, H, beta_loss, l1_reg_H, l2_reg_H, gamma):
+def _multiplicative_update_h(X, W, H, beta_loss, l1_reg_H, l2_reg_H, gamma,
+                             WH):
     """update H in Multiplicative Update NMF"""
     X_mask = X.mask if isinstance(X, np.ma.masked_array) else False

@@ -728,12 +730,12 @@ def _multiplicative_update_h(X, W, H, beta_loss, l1_reg_H, l2_reg_H, gamma):
             denominator = np.dot(np.dot(W.T, W), H)
         else:
             numerator = _safe_dot(W.T, X)
-            WH = _special_dot_X(W, H, X)
+            WH = _special_dot_X(W, H, X, out=WH)
             denominator = _safe_dot(W.T, WH)

     else:
         # Numerator
-        WH_safe_X = _special_dot_X(W, H, X)
+        WH_safe_X = _special_dot_X(W, H, X, out=WH)
         if sp.issparse(X):
             WH_safe_X_data = WH_safe_X.data
             X_data = X.data
@@ -895,6 +897,9 @@ def _fit_multiplicative_update(X, W, H, beta_loss='frobenius',
     else:
         gamma = 1.

+    # allocate memory for the product np.dot(W, H)
+    WH = np.empty(X.shape) if not sp.issparse(X) else None
+
     # transform in a numpy masked array if X contains missing (NaN) values
     if not sp.issparse(X):
         X_mask = np.isnan(X)
@@ -910,7 +915,7 @@ def _fit_multiplicative_update(X, W, H, beta_loss='frobenius',
         # update W
         # H_sum, HHt and XHt are saved and reused if not update_H
         delta_W, H_sum, HHt, XHt = _multiplicative_update_w(
-            X, W, H, beta_loss, l1_reg_W, l2_reg_W, gamma,
+            X, W, H, beta_loss, l1_reg_W, l2_reg_W, gamma, WH,
             H_sum, HHt, XHt, update_H)
         W *= delta_W

@@ -921,7 +926,7 @@ def _fit_multiplicative_update(X, W, H, beta_loss='frobenius',
         # update H
         if update_H:
             delta_H = _multiplicative_update_h(X, W, H, beta_loss, l1_reg_H,
-                                               l2_reg_H, gamma)
+                                               l2_reg_H, gamma, WH)
             H *= delta_H

             # These values will be recomputed since H changed
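In short, the commit preallocates a single dense buffer of shape X.shape and has every subsequent np.dot(W, H) write into it through NumPy's out= argument, instead of allocating a fresh (n_samples, n_features) array on each iteration of the multiplicative-update loop; for sparse X the buffer stays None, since _special_dot_X only evaluates the product on the non-zero entries and returns a sparse matrix. A minimal sketch of the reuse pattern (the shapes, the loop, and the random data below are illustrative, not the scikit-learn solver):

import numpy as np

rng = np.random.RandomState(0)
n_samples, n_features, n_components = 200, 100, 10
W = rng.rand(n_samples, n_components)
H = rng.rand(n_components, n_features)

# allocate the (n_samples, n_features) product buffer once...
WH = np.empty((n_samples, n_features))

for _ in range(50):
    # ...and let np.dot write into it on every iteration, instead of
    # allocating a new array of the same size each time
    np.dot(W, H, out=WH)
    # a real solver would now use WH to form the multiplicative updates

The out= array must already have the matching shape and dtype, which is why the buffer is sized from X.shape up front.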

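A second, smaller change sits in the masked-array branch: WH.unshare_mask() becomes WH._sharedmask = False. unshare_mask() copies the mask before marking it as unshared, so it costs one more allocation per call, whereas flipping the flag skips the copy; that reading of the intent is an assumption (the commit message does not spell it out), and it relies on the mask never being modified in place afterwards. A tiny illustration:

import numpy as np

X = np.ma.masked_array(np.arange(6, dtype=float),
                       mask=[False, True, False, False, True, False])

# passing mask=X.mask typically shares the mask object between X and WH,
# which is what the unshare step is guarding against
WH = np.ma.masked_array(np.zeros(6), mask=X.mask)

# before: copy the mask, then mark it as unshared
# WH.unshare_mask()
# after: mark it as unshared without copying it
WH._sharedmask = False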