Reference from https://github.com/jasoncoding13/nmf/blob/master/nmf/nmfs.py

In [1]:
import numpy as np
from common_alg import or_x_set, or_y_set, guas_set, salt_set, block_set

In [2]:
def random_init(n_components, X):
    n_features, n_samples = X.shape
    avg = np.sqrt(X.mean() / n_components)
    rng = np.random.RandomState(13)
    D = avg * rng.randn(n_features, n_components)
    R = avg * rng.randn(n_components, n_samples)
    np.abs(D, out=D)
    np.abs(R, out=R)
    return D, R

In [3]:
class BaseNMF():
    
    def __init__(self,n_components,init='random',tol=1e-4,max_iter=200,skip_iter=10):
        self.n_components = n_components
        self.init = init
        self.tol = tol
        self.max_iter = max_iter
        self.skip_iter = skip_iter
    
    def _init(self, X):
        D, R = random_init(self.n_components, X)
        return D,R
  
    def fit(self, X):
        D, R = self._init(X)
        losses = [self._compute_loss(X, D, R)]
        for iter_ in range(self.max_iter):
            D, R = self._update(X, D, R)
            # check converagence
            if iter_ % self.skip_iter == 0:
                losses.append(self._compute_loss(X, D, R))
                criterion = abs(losses[-1] - losses[-2]) / losses[-2]
                if criterion < 1e-3:
                    break
        return D, R
    def _compute_loss(self, X, D, R):
        return None

    def _update_weight(self, X, D, R):
        return None

    def _update(self, X, D, R):
        # update W
        W = self._update_weight(X, D, R)
        # update D
        denominator_D = (W * D.dot(R)).dot(R.T)
        denominator_D[denominator_D == 0] = np.finfo(np.float32).eps
        D = D * ((W * X).dot(R.T)) / denominator_D
        # update R
        denominator_R = D.T.dot(W * D.dot(R))
        denominator_R[denominator_R == 0] = np.finfo(np.float32).eps
        R = R * (D.T.dot(W * X)) / denominator_R
        return D, R

In [4]:
class L1NMF(BaseNMF):
    """L1-NMF
    """
    def _compute_loss(self, X, D, R):
        return np.sum(np.abs(X - D.dot(R)))
        

    def _update_weight(self, X, D, R):
        eps = X.var() / D.shape[1]
        return 1 / (np.sqrt(np.square(X - D.dot(R))) + eps ** 2)


In [5]:
def RRE(clean,R,D,name):
    rre = []
    for i in range(5):
        R_n = R[i]
        D_n = D[i]
        error = np.linalg.norm(clean-D_n.dot(R_n))/np.linalg.norm(clean)
        rre.append(error)
    print(f' the rre of {name} in L1-Norm Base NMF is {sum(rre)/len(rre)}')

In [6]:
def get_result(clean_set,data_set,name):
    mdoel = L1NMF(n_components = 40)
    d_s=[]
    r_s=[]
    for i in range(5):
        D,R = mdoel.fit(data_set[i])
        d_s.append(D)
        r_s.append(R)
    RRE(clean_set,r_s,d_s,name)

In [7]:
get_result(or_x_set,salt_set,'salt_set')
get_result(or_x_set,guas_set,'guas_set')
get_result(or_x_set,block_set,'block_set')

 the rre of salt_set in L1-Norm Base NMF is 0.5081227713547113
 the rre of guas_set in L1-Norm Base NMF is 0.5170950112939028
 the rre of block_set in L1-Norm Base NMF is 0.5151392202148571


In [8]:
class L21NMF(BaseNMF):
    def _compute_loss(self, X, D, R):
        return np.sum(np.sqrt(np.sum(np.square(X - D.dot(R)), axis=0)))

    def _update_weight(self, X, D, R):
        return 1 / np.sqrt(np.sum(np.square(X - D.dot(R)), axis=0))


In [9]:
def RRE_21(clean,R,D,name):
    rre = []
    for i in range(5):
        R_n = R[i]
        D_n = D[i]
        error = np.linalg.norm(clean-D_n.dot(R_n))/np.linalg.norm(clean)
        rre.append(error)
    print(f' the rre of {name} in L21-Norm Base NMF is {sum(rre)/len(rre)}')

In [10]:
def get_result_21(clean_set,data_set,name):
    mdoel = L21NMF(n_components = 40)
    d_s=[]
    r_s=[]
    for i in range(5):
        D,R = mdoel.fit(data_set[i])
        d_s.append(D)
        r_s.append(R)
    RRE_21(clean_set,r_s,d_s,name)

In [11]:
get_result_21(or_x_set,salt_set,'salt_set')
get_result_21(or_x_set,guas_set,'guas_set')
get_result_21(or_x_set,block_set,'block_set')

 the rre of salt_set in L21-Norm Base NMF is 0.5077859696839705
 the rre of guas_set in L21-Norm Base NMF is 0.516946246856734
 the rre of block_set in L21-Norm Base NMF is 0.5148012158799901
