In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from nmf_son.base import nmf_son
from nmf_son.utils import load_results

# Notebook-wide numpy configuration: 3-decimal array printing and a fixed
# legacy global RNG seed so any np.random-based step is repeatable.
np.set_printoptions(precision=3)
np.random.seed(42)

In [None]:
def plot_scores(fscores, gscores, lambda_vals, reg_val):
    """Plot the f, g and total score traces for a single lambda setting.

    The total score is ``fscores + lambda_vals * gscores``. Two side-by-side
    panels are drawn, both with a logarithmic y-axis:

    * left  -- the raw traces;
    * right -- each trace minus its own last value (``score - score*``).

    Parameters
    ----------
    fscores, gscores : array-like supporting numpy arithmetic
        Score traces (presumably one entry per nmf_son iteration -- TODO confirm).
    lambda_vals : scalar or array-like broadcastable against gscores
        Weight(s) applied to ``gscores`` when forming the total.
    reg_val
        Regularisation value shown in the figure's suptitle.
    """
    def draw_traces(axis, f_trace, g_trace, total_trace):
        # The thick black total goes in first so the thin f/g lines stay visible on top.
        axis.plot(total_trace, color='black', linewidth=3, label='total')
        for trace, colour, name in ((f_trace, 'cyan', 'f'), (g_trace, 'yellow', 'g')):
            axis.plot(trace, color=colour, linewidth=1.5, label=name)
        axis.legend()

    total_score = fscores + lambda_vals * gscores
    fig, (raw_ax, gap_ax) = plt.subplots(1, 2, figsize=(20, 6))
    fig.suptitle(f'lambda = {reg_val}', fontsize=25)

    # The last point of every "gap" trace is exactly 0 (x[-1] - x[-1]),
    # which a log axis cannot represent.
    gap_traces = (fscores - fscores[-1], gscores - gscores[-1], total_score - total_score[-1])
    panels = (
        (raw_ax, (fscores, gscores, total_score), 'log scale'),
        (gap_ax, gap_traces, 'log(score - score*)'),
    )
    for axis, traces, title in panels:
        axis.set_yscale('log')
        draw_traces(axis, *traces)
        axis.set_title(title, fontsize=16)

def plot_matrices(W, H, img_size, comparison_idxs):
    """Show selected columns of W (top row) above the matching rows of H (bottom row).

    Parameters
    ----------
    W : ndarray
        Indexed as ``W[:, idx]``. Each selected column is drawn divided by its
        dot product with itself, i.e. its squared Euclidean norm.
    H : ndarray
        Indexed as ``H[idx, :]``. Each selected row is reshaped in column-major
        (Fortran) order to ``img_size`` and passed to ``Axes.plot``, which draws
        one line per column of the reshaped 2-D array.
    img_size : tuple of int
        Target shape for each H row; its product must equal ``H.shape[1]``.
    comparison_idxs : iterable of int
        0-based component indices to show (panel titles are 1-based). Any
        iterable is accepted, and a single index works too.
    """
    # Materialise once so generators/iterators work with len() and the loop.
    idxs = list(comparison_idxs)
    # squeeze=False keeps `axs` 2-D even for one index; otherwise
    # plt.subplots(2, 1) returns a 1-D array and `axs[0, i]` raises IndexError.
    fig, axs = plt.subplots(2, len(idxs), figsize=(20, 10), sharey='row', squeeze=False)

    for col, idx in enumerate(idxs):
        w_col = W[:, idx]
        # NOTE(review): dividing by ||w||^2 (not ||w||) leaves the plotted vector
        # with norm 1/||w||, not 1 -- confirm this scaling is intended.
        axs[0, col].plot(w_col / np.dot(w_col, w_col))
        axs[0, col].set_title(f'W({idx+1})')

        h_img = H[idx, :].reshape(img_size, order='F')
        axs[1, col].plot(h_img)
        axs[1, col].set_title(f'H({idx+1})')

def plot_images(H, img_size, comparison_idxs):
    """Render selected rows of H as grayscale images, each with its own colorbar.

    Parameters
    ----------
    H : ndarray
        Indexed as ``H[idx, :]``. Each selected row is reshaped in column-major
        (Fortran) order to ``img_size`` before display.
    img_size : tuple of int
        Image shape; its product must equal ``H.shape[1]``.
    comparison_idxs : iterable of int
        0-based row indices to show (panel titles are 1-based). Any iterable is
        accepted, and a single index works too.
    """
    # Materialise once so generators/iterators work with len() and the loop.
    idxs = list(comparison_idxs)
    # squeeze=False keeps `axs` 2-D; with one index plt.subplots(1, 1) would
    # otherwise return a bare Axes, and `axs[i]` would fail.
    fig, axs = plt.subplots(1, len(idxs), figsize=(20, 10), squeeze=False)

    for col, idx in enumerate(idxs):
        ax = axs[0, col]
        h_img = H[idx, :].reshape(img_size, order='F')

        # Draw once, in gray. A second default-colormap imshow underneath would
        # be completely hidden and only add an extra artist.
        img = ax.imshow(h_img, cmap='gray')

        # Attach a slim colorbar axis to the right of this panel only.
        cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.1)
        fig.colorbar(img, cax=cax, orientation='vertical')
        ax.set_title(f'H({idx+1})')

In [None]:
# Load the data matrix X. np.load on an .npz returns an NpzFile that keeps the
# archive open; the context manager closes it once X has been read into memory.
with np.load('../datasets/urban_small.npz') as urban_npz:
    X = urban_npz['X']

In [None]:
# Initial factors (rank 6, judging by the file name). Every nmf_son run below
# starts from copies of them. The context manager closes the NpzFile handle;
# the arrays read out of it are plain in-memory ndarrays and remain usable.
with np.load('../saved_models/urban/urban_small_r6_ini.npz') as data:
    ini_W = data['W']
    ini_H = data['H']

In [None]:
# Template for each saved run: r{rank}_l{lambda}_tol{tol}.npz. The training
# loop replaces '.' with '_' in the lambda and tol parts before formatting.
save_filepath = '../saved_models/urban/tol_testing/r{}_l{}_tol{}.npz'
rank = 6

In [None]:
# Sweep grid: the training loop fits every tolerance against every lambda.
tols = [1e-3, 1e-4, 1e-5, 1e-6]
# Keep 2 as an int: str(reg_val) is embedded in the saved file names ('2', not '2.0').
reg_vals = [9e-5, 2]
# NOTE(review): max_iters is not referenced by the visible training loop,
# which passes itermax=5000 to nmf_son directly.
max_iters = 5

In [None]:
# Fit nmf_son for every (tol, lambda) pair. Each run starts from copies of the
# same initial factors, and all returned arrays go into one compressed .npz.
for tol in tols:
    for reg_val in reg_vals:
        tmp_filepath = save_filepath.format(rank, str(reg_val).replace('.', '_'), str(tol).replace('.', '_'))

        # Ensure the output folder exists *before* the potentially long fit;
        # otherwise open(..., 'wb') raises FileNotFoundError only after the
        # run has finished, and its result is lost.
        out_dir = os.path.dirname(tmp_filepath)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # NOTE(review): itermax is hard-coded to 5000 here; the `max_iters`
        # config value is not used by this loop.
        Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = nmf_son(X, ini_W.copy(), ini_H.copy(), _lambda=reg_val, itermax=5000, early_stop=tol)
        with open(tmp_filepath, 'wb') as fout:
            np.savez_compressed(fout, Wb=Wb, Hb=Hb, Wl=Wl, Hl=Hl, fscores=fscores, gscores=gscores, lambda_vals=lambda_vals)

        print(reg_val, tol, 'complete')

In [None]:
# comp_idxs = range(rank)
# img_size = (20, 10)
#
# for reg_val in reg_vals:
#     for tol in tols:
#         Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = load_results(save_filepath.format(rank, str(reg_val).replace('.', '_'), str(tol).replace('.', '_')))
#         plot_scores(fscores, gscores, lambda_vals, reg_val)
#         plot_matrices(Wl, Hl, img_size, comp_idxs)
#         plot_images(Hl, img_size, comp_idxs)