In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sonnmf.main import sonnmf

In [None]:
def save_results(filepath, W, H, fscores, gscores, hscores, total_scores):
    """Write a factorization and its score histories to a compressed ``.npz`` file.

    The arrays are stored under the keys ``W``, ``H``, ``fscores``, ``gscores``,
    ``hscores`` and ``total_scores``.

    Parameters
    ----------
    filepath : str or path-like
        Destination file. Missing parent directories are created.
    W, H : array_like
        Factor matrices to store.
    fscores, gscores, hscores, total_scores : array_like
        Score histories to store.
    """
    # open(..., 'wb') raises FileNotFoundError when the directory is missing,
    # which would otherwise discard the result of a long run at save time.
    parent_dir = os.path.dirname(os.fspath(filepath))
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, 'wb') as fout:
        np.savez_compressed(
            fout,
            W=W,
            H=H,
            fscores=fscores,
            gscores=gscores,
            hscores=hscores,
            total_scores=total_scores,
        )

def load_results(filepath):
    """Read back the arrays written by ``save_results``.

    Parameters
    ----------
    filepath : str, path-like or file object
        ``.npz`` archive containing the keys ``W``, ``H``, ``fscores``,
        ``gscores``, ``hscores`` and ``total_scores``.

    Returns
    -------
    tuple
        ``(W, H, fscores, gscores, hscores, total_scores)``.

    Raises
    ------
    KeyError
        If one of the expected keys is missing from the archive.
    """
    # np.load returns an NpzFile that keeps the archive open until it is closed.
    # Each lookup below materializes a full in-memory array, so the values stay
    # valid after the context manager closes the file.
    with np.load(filepath) as data:
        return (
            data['W'],
            data['H'],
            data['fscores'],
            data['gscores'],
            data['hscores'],
            data['total_scores'],
        )

def plot_3d(X, Wt, W, filepath=None):
    """Draw a 3-D scatter of the data together with the true and estimated W.

    Only rows 0, 1 and 2 of each array are used as coordinates; every column is
    drawn as one point.

    Parameters
    ----------
    X : 2-D array with at least three rows
        Data points, drawn in light blue.
    Wt : 2-D array with at least three rows
        Reference ("True W") points, drawn as large red circles.
    W : 2-D array with at least three rows
        Estimated points, drawn as large black crosses.
    filepath : str or path-like, optional
        When truthy, the figure is written to this path and then closed.
        Otherwise the figure is shown with ``plt.show()``.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Labels are attached to the artists so the legend does not depend on the
    # order in which the scatter calls happen to be made.
    ax.scatter(X[0, :], X[1, :], X[2, :], c='lightblue', marker='o', label='Data points')
    ax.scatter(Wt[0, :], Wt[1, :], Wt[2, :], c='red', marker='o', s=144, label='True W')
    ax.scatter(W[0, :], W[1, :], W[2, :], c='black', marker='x', s=144, label='Estimated W')

    ax.set_xlabel('X1')
    ax.set_ylabel('X2')
    ax.set_zlabel('X3')
    ax.legend()
    ax.grid(True)

    if filepath:
        parent_dir = os.path.dirname(os.fspath(filepath))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(filepath)
        # pyplot keeps figures alive until they are closed; without this, calling
        # plot_3d in a loop accumulates one open figure per call.
        plt.close(fig)
    else:
        plt.show()

In [None]:
# Input dataset (.npz archive); the loading cell reads its 'X' entry.
data_filepath = '../datasets/jasper_small_2.npz'
# Cache for the random initial factors; formatted with the rank r.
ini_filepath = '../saved_models/jasper_small_2/r{}_ini.npz'
# Per-run checkpoint; formatted with (r, lambda, gamma, iteration checkpoint).
save_filepath = '../saved_models/jasper_small_2/r{}_l{}_g{}_it{}.npz'
# Figure output template with the same four placeholders as save_filepath.
# NOTE(review): not referenced by any code visible in this notebook.
image_filepath = '../images/jasper_small_2/r{}_l{}_g{}_it{}.jpg'

In [None]:
# Earlier variant for an archive that also carries ground-truth factors,
# kept for reference:
# data = np.load(data_filepath)
# M = data['M']
# Wt = data['W_true']
# Ht = data['H_true']

# np.load on an .npz archive returns an NpzFile that keeps the file open until
# closed. The lookup materializes the array in memory, so M remains valid after
# the context manager closes the archive.
with np.load(data_filepath) as dataset:
    M = dataset['X']

# M is unpacked as a 2-D matrix: m rows, n columns.
m, n = M.shape

In [None]:
# NOTE(review): max_iters is not referenced by any cell visible in this notebook;
# the sweep takes its iteration budgets from it_checkpoints instead.
max_iters = 1000
# Factorization rank, set equal to the number of columns of M.
# It sizes ini_W as (m, r) and ini_H as (r, n) and is part of every output filename.
r = n

In [None]:
# Reuse the cached random initialization for this rank when one exists, so that
# every run of the notebook starts from the same factors. Otherwise draw a new
# one and cache it.
# NOTE(review): the cache is keyed by r only; a cached file created for a
# different M shape would be loaded as-is.
ini_path = ini_filepath.format(r)

if os.path.exists(ini_path):
    # Close the archive once both arrays have been read into memory.
    with np.load(ini_path) as data:
        ini_W = data['ini_W']
        ini_H = data['ini_H']
else:
    ini_W = np.random.rand(m, r)
    ini_H = np.random.rand(r, n)

    # open(..., 'wb') fails if the cache directory has not been created yet.
    ini_dir = os.path.dirname(ini_path)
    if ini_dir:
        os.makedirs(ini_dir, exist_ok=True)

    with open(ini_path, 'wb') as fout:
        np.savez_compressed(fout, ini_W=ini_W, ini_H=ini_H)

In [None]:
# Grid of (lambda, gamma) values to sweep, and the cumulative iteration counts
# at which a checkpoint is written for every grid point.
lams = [0.001, 1000, 0.1, 1, 10]
gammas = [0.001, 1000, 0.1, 1, 10]
it_checkpoints = [1000, 2000, 5000, 10000]

for stage, iters in enumerate(it_checkpoints):
    for g in gammas:
        for lam in lams:
            out_path = save_filepath.format(r, lam, g, iters)

            if stage:
                # Later checkpoints: warm-start from the previous checkpoint of
                # the same (lambda, gamma) pair and run only the remaining iterations.
                prev_iters = it_checkpoints[stage - 1]
                prev_W, prev_H, *prev_histories = load_results(
                    save_filepath.format(r, lam, g, prev_iters)
                )
                start_W, start_H, budget = prev_W, prev_H, iters - prev_iters
            else:
                # First checkpoint: start from the shared initial factors.
                prev_histories = None
                start_W, start_H, budget = ini_W, ini_H, iters

            W, H, fscores, gscores, hscores, total_scores = sonnmf(
                M,
                start_W.copy(),
                start_H.copy(),
                lam=lam,
                gamma=g,
                itermax=budget,
                W_update_iters=10,
                early_stop=True,
                verbose=False,
            )

            if prev_histories is not None:
                # Append the new scores to the stored ones, dropping the last
                # stored entry of each history first.
                # NOTE(review): presumably the continued run's first recorded
                # score repeats that entry - confirm against sonnmf.
                fscores, gscores, hscores, total_scores = (
                    np.concatenate((old_scores[:-1], new_scores))
                    for old_scores, new_scores in zip(
                        prev_histories, (fscores, gscores, hscores, total_scores)
                    )
                )

            save_results(out_path, W, H, fscores, gscores, hscores, total_scores)
            print(iters, g, lam)

In [None]:
# lam = 0.01
# gamma = 0.5

In [None]:
# import time
# import csv
#
# W_update_iters_list = [1, 5, 10, 50, 100]
# maxit = 1000
#
# with open('results.csv', 'w', newline='') as csvfile:
#         fieldnames = ['W_update_iters', 'time', 'lam', 'gamma', 'itermax', 'H_update_iters']
#         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
#         writer.writeheader()

In [None]:
# tscores = list()
# for wits in W_update_iters_list:
#     start_time = time.time()
#     W, H, fscores, gscores, hscores, total_scores = sonnmf(M, ini_W.copy(), ini_H.copy(), lam, gamma, itermax=maxit, H_update_iters=1, W_update_iters=wits, early_stop=True, verbose=False)
#     elapsed_time = time.time() - start_time
#
#     plot_3d(M, Wt, W, f'W{wits}H1IT{maxit}.jpg')
#
#     # Save the current results to a CSV file
#     with open('results.csv', 'a', newline='') as csvfile:
#         writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
#         writer.writerow({'W_update_iters': wits, 'time': elapsed_time, 'lam': lam, 'gamma': gamma, 'itermax': maxit, 'H_update_iters': 1})
#
#     tscores.append(total_scores)

In [None]:
# for i, wits in enumerate(W_update_iters_list):
#     plt.plot(tscores[i], label=f'W_its={wits}')
# plt.yscale('log')
# plt.legend()

In [None]:
# for i, wits in enumerate(W_update_iters_list):
#     plt.plot(tscores[i][-199:] - tscores[i][-200:-1], label=f'W_its={wits}')
# # plt.yscale('log')
# plt.title('$F_{shifted} - F$')
# plt.xlabel('iterations')
# plt.legend()
# # plt.savefig('scores_for_w_its.jpg')