In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sonnmf.main import sonnmf

In [None]:
def save_results(filepath, W, H, fscores, gscores, hscores, total_scores):
    """Write one factorisation result to ``filepath`` as a compressed ``.npz`` archive.

    Each argument is stored under its own parameter name (``W``, ``H``,
    ``fscores``, ``gscores``, ``hscores``, ``total_scores``), which are the keys
    ``load_results`` reads back. The file is opened here in binary mode and the
    handle is passed to NumPy, so ``filepath`` is used verbatim (NumPy only
    appends ``.npz`` when given a path string).
    """
    arrays = {
        'W': W,
        'H': H,
        'fscores': fscores,
        'gscores': gscores,
        'hscores': hscores,
        'total_scores': total_scores,
    }
    with open(filepath, mode='wb') as handle:
        np.savez_compressed(handle, **arrays)

def load_results(filepath):
    """Read back an archive containing the arrays written by ``save_results``.

    Parameters
    ----------
    filepath : str
        Path to a ``.npz`` archive with keys ``W``, ``H``, ``fscores``,
        ``gscores``, ``hscores`` and ``total_scores``.

    Returns
    -------
    tuple
        ``(W, H, fscores, gscores, hscores, total_scores)`` as NumPy arrays.

    Raises
    ------
    KeyError
        If any of the expected keys is missing from the archive.
    """
    keys = ('W', 'H', 'fscores', 'gscores', 'hscores', 'total_scores')
    # Use the NpzFile as a context manager so its file handle is closed on
    # return; indexing it reads each array fully into memory, so the returned
    # arrays stay valid after the archive is closed.
    with np.load(filepath) as data:
        return tuple(data[key] for key in keys)

def plot_3d(X, Wt, W, filepath=None):
    """3-D scatter of the data together with the true and estimated columns of W.

    Only the first three rows of each matrix are plotted, one point per column
    (each matrix is indexed as ``A[0, :]``, ``A[1, :]``, ``A[2, :]``).

    Parameters
    ----------
    X : 2-D array
        Data matrix; its columns are drawn as light-blue circles.
    Wt : 2-D array
        Reference ("True W") matrix; columns drawn as large red circles.
    W : 2-D array
        Estimated matrix; columns drawn as large black crosses.
    filepath : str or None, optional
        If truthy, the figure is saved to this path and then closed;
        otherwise it is displayed with ``plt.show()``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(X[0, :], X[1, :], X[2, :], c='lightblue', marker='o', label='Data points')
    ax.scatter(Wt[0, :], Wt[1, :], Wt[2, :], c='red', marker='o', s=144, label='True W')
    ax.scatter(W[0, :], W[1, :], W[2, :], c='black', marker='x', s=144, label='Estimated W')
    ax.set_xlabel('X1')
    ax.set_ylabel('X2')
    ax.set_zlabel('X3')
    ax.legend()
    ax.grid(True)
    if filepath:
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(filepath)
        # Close it so repeated calls (e.g. one per parameter setting) do not
        # accumulate open figures in memory.
        plt.close(fig)
    else:
        plt.show()

In [None]:
# Input/output locations, relative to the notebook's working directory.
# The `{}` slots are filled later with str.format: rank r for the
# initialisation cache, and (r, lam, gamma) for the result/image files.
data_filepath, ini_filepath, save_filepath, image_filepath = (
    '../datasets/jasper_small_2.npz',
    '../saved_models/jasper_small_2/r{}_ini.npz',
    '../saved_models/jasper_small_2/r{}_l{}_g{}.npz',
    '../images/jasper_small_2/r{}_l{}_g{}.jpg',
)

In [None]:
# The dataset archive stores the data matrix under key 'X'. Open it in a
# `with` block so the NpzFile handle is closed once the array has been read.
with np.load(data_filepath) as dataset:
    M = dataset['X']

# m rows, n columns; the unpacking requires M to be 2-D.
m, n = M.shape

In [None]:
# Iteration cap passed to every sonnmf run, and the factorisation rank r
# (inner dimension of W (m x r) and H (r x n)), set equal to n here.
max_iters, r = 10_000, n

In [None]:
# Random initialisation shared by every (lambda, gamma) run below. It is cached
# on disk per rank r so re-running the notebook reuses the same starting point.
ini_path = ini_filepath.format(r)
if os.path.exists(ini_path):
    with np.load(ini_path) as data:
        ini_W = data['ini_W']
        ini_H = data['ini_H']
else:
    # Seeded so that even the first-time draw is reproducible.
    ini_seed = 0
    ini_rng = np.random.default_rng(ini_seed)
    ini_W = ini_rng.random((m, r))
    ini_H = ini_rng.random((r, n))
    # Save under the *formatted* path (the one checked above) so the cache is
    # actually found next time; create the directory on a fresh checkout.
    os.makedirs(os.path.dirname(ini_path) or '.', exist_ok=True)
    with open(ini_path, 'wb') as fout:
        np.savez_compressed(fout, ini_W=ini_W, ini_H=ini_H)

In [None]:
# Grid of regularisation weights; every (gamma, lam) pair is run once.
lams = [0.001, 1000, 0.1, 1, 10]
gammas = [0.001, 1000, 0.1, 1, 10]

# Create the results directory up front: otherwise save_results would raise
# only after the first (expensive) sonnmf run had already finished.
os.makedirs(os.path.dirname(save_filepath.format(r, 0, 0)) or '.', exist_ok=True)

for g in gammas:
    for l in lams:
        # Fresh copies of the shared initialisation, in case sonnmf updates its
        # inputs in place.
        W, H, fscores, gscores, hscores, total_scores = sonnmf(
            M, ini_W.copy(), ini_H.copy(),
            lam=l, gamma=g, itermax=max_iters, early_stop=True, verbose=False,
        )
        save_results(save_filepath.format(r, l, g), W, H, fscores, gscores, hscores, total_scores)
        # Progress marker: (gamma, lambda) of the run just saved.
        print(g, l)