In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sklearn.decomposition import NMF

from nmf_son.base import nmf_son
from nmf_son.utils import calculate_gscore, normalized_similarity, load_results

# Notebook-wide NumPy settings: show arrays with 3 decimals, and fix the global
# RNG seed so any NumPy-random draws are reproducible across runs.
np.set_printoptions(precision=3)
np.random.seed(42)

In [None]:
def plot_scores(fscores, gscores, lambda_vals, reg_val):
    """Plot the per-iteration score history of one NMF-SON run.

    Draws two side-by-side panels, both with a log-scaled y-axis:
      * left:  the raw scores ``f``, ``g`` and ``total = f + lambda_vals * g``;
      * right: each score minus its final value (``score - score*``), using the
        last recorded iterate as the reference value ``score*``.

    Parameters
    ----------
    fscores, gscores : array_like
        Per-iteration values of the f and g terms (same length).
    lambda_vals : array_like or float
        Weight(s) multiplying ``gscores`` elementwise to form the total score.
    reg_val : float
        The lambda of this run; shown in the figure title.
    """
    # Accept lists as well as arrays: the arithmetic below needs ndarray semantics.
    fscores = np.asarray(fscores, dtype=float)
    gscores = np.asarray(gscores, dtype=float)
    lambda_vals = np.asarray(lambda_vals, dtype=float)

    def plot_ax(ax, f, g, total):
        ax.plot(total, color='black', linewidth=3, label='total')
        ax.plot(f, label='f')
        ax.plot(g, label='g')
        ax.set_xlabel('iteration')
        ax.legend()

    total_score = fscores + lambda_vals * gscores
    fig, axs = plt.subplots(1, 2, figsize=(20, 6))
    fig.suptitle(f'lambda = {reg_val}', fontsize=25)

    axs[0].set_yscale('log')
    plot_ax(axs[0], fscores, gscores, total_score)
    axs[0].set_title('log scale', fontsize=16)

    # Subtracting the last value makes the final point exactly 0, which a log axis
    # cannot represent (it is clipped and draws a spurious plunge), so drop it.
    # Earlier points keep their original iteration index on the x-axis.
    # NOTE(review): f and g individually need not be monotone, so some gaps may
    # still be <= 0 and get clipped by the log scale.
    axs[1].set_yscale('log')
    plot_ax(
        axs[1],
        (fscores - fscores[-1])[:-1],
        (gscores - gscores[-1])[:-1],
        (total_score - total_score[-1])[:-1],
    )
    axs[1].set_title('log(score - score*)', fontsize=16)

def plot_matrices(W, H, img_size, comparison_idxs, share_y=False):
    """Line-plot selected columns of W and the matching rows of H.

    Top row: column ``W[:, idx]`` scaled to unit L2 norm, so shapes are comparable
    across columns (especially with ``share_y=True``).
    Bottom row: row ``H[idx, :]`` reshaped to ``img_size`` in Fortran (column-major)
    order and passed to ``ax.plot``; for a 2-D ``img_size`` matplotlib draws one
    line per column of the reshaped array.

    Parameters
    ----------
    W, H : ndarray
        Factor matrices; ``W`` is indexed by column and ``H`` by row.
    img_size : tuple of int
        Target shape for each H row; ``prod(img_size)`` must equal ``H.shape[1]``.
    comparison_idxs : iterable of int
        Component indices to show, one subplot column each. Titles are 1-based.
    share_y : bool, default False
        If True, subplots within the same row share their y-axis.
    """
    comparison_idxs = list(comparison_idxs)  # allow range/generators; len() is needed below
    # squeeze=False keeps axs 2-D even when a single index is compared, so the
    # axs[row, col] indexing below also works for one column.
    fig, axs = plt.subplots(2, len(comparison_idxs), figsize=(20, 10),
                            sharey='row' if share_y else False, squeeze=False)

    for col, idx in enumerate(comparison_idxs):
        w = W[:, idx]
        w_norm = np.linalg.norm(w)
        # Unit-norm scaling (the original divided by the squared norm w.w, which
        # shrinks large columns instead of normalizing them). Leave zero columns as is.
        axs[0, col].plot(w / w_norm if w_norm > 0 else w)
        axs[0, col].set_title(f'W({idx+1})')

        h_img = H[idx, :].reshape(img_size, order='F')
        axs[1, col].plot(h_img)
        axs[1, col].set_title(f'H({idx+1})')

def plot_images(H, img_size, comparison_idxs):
    """Show selected rows of H as grayscale images, each with its own colorbar.

    Parameters
    ----------
    H : ndarray
        Factor matrix; row ``H[idx, :]`` is reshaped to ``img_size`` in Fortran
        (column-major) order before display.
    img_size : tuple of int
        Image shape; ``prod(img_size)`` must equal ``H.shape[1]``.
    comparison_idxs : iterable of int
        Row indices to show, one subplot each. Titles are 1-based.
    """
    comparison_idxs = list(comparison_idxs)  # allow range/generators; len() is needed below
    # squeeze=False keeps axs 2-D, so a single index does not yield a bare Axes
    # object (which cannot be indexed).
    fig, axs = plt.subplots(1, len(comparison_idxs), figsize=(20, 10), squeeze=False)

    for ax, idx in zip(axs[0], comparison_idxs):
        h_img = H[idx, :].reshape(img_size, order='F')

        # A single imshow call: the original drew the image twice (default colormap,
        # then gray on top), which only wasted work.
        img = ax.imshow(h_img, cmap='gray')

        # Attach a thin colorbar axis to the right of this image.
        cax = make_axes_locatable(ax).append_axes('right', size='5%', pad=0.1)
        fig.colorbar(img, cax=cax, orientation='vertical')
        ax.set_title(f'H({idx+1})')

In [None]:
# Load the data matrix X from the urban_small dataset. np.load on an .npz returns
# a lazily-read NpzFile that holds an open file handle; the context manager closes it.
with np.load('../datasets/urban_small.npz') as urban_npz:
    X = urban_npz['X']

In [None]:
# Shared rank-6 initialization (W, H). Every lambda in the sweep starts from it, so
# runs differ only in lambda. The `with` block closes the NpzFile's file handle;
# indexing it fully reads each array, so ini_W / ini_H stay valid after closing.
with np.load('../saved_models/urban/urban_small_r6_ini.npz') as data:
    ini_W = data['W']
    ini_H = data['H']

In [None]:
# Checkpoint path template, filled as (rank, lambda tag, iteration count).
save_filepath = '../saved_models/urban/lambda_tuning/r{}_l{}_it{}.npz'
# Factorization rank; should agree with the r6 initialization file loaded earlier.
rank = 6

In [None]:
# Previous lambda sweeps, kept for reference:
# reg_vals = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 1, 2, 5, 10, 25, 50, 100, 200, 500, 1000]
# reg_vals = [1.5, 1.8, 1.85, 1.9, 1.95, 1.97, 2, 2.03, 2.05, 2.1, 2.15, 2.2, 2.5, 3, 4]

# Lambda values for the current sweep, in strictly increasing order.
# (Fixed typo: the 4th entry was 6e-7, which broke the 1/2/4/6e-8 progression.)
reg_vals = [1e-8, 2e-8, 4e-8, 6e-8, 1e-7, 2e-7, 4e-7, 7e-7, 1e-6, 2e-6, 5e-6, 1e-5, 1.5e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 8e-4, 1e-3]
# Iteration counts at which every run is checkpointed. The training loop resumes
# each run from the previous checkpoint for it_ckpts[k] - it_ckpts[k-1] iterations,
# so this list must be strictly increasing.
it_ckpts = [200, 500, 1000, 2000, 3000, 4000, 5000]

assert all(a < b for a, b in zip(reg_vals, reg_vals[1:])), 'reg_vals must be strictly increasing'
assert all(a < b for a, b in zip(it_ckpts, it_ckpts[1:])), 'it_ckpts must be strictly increasing'

In [None]:
# lambda -> path of that run's most recently saved checkpoint (None until first save).
reg_file_dict = dict.fromkeys(reg_vals)

In [None]:
# Run NMF-SON for every lambda in reg_vals, saving a checkpoint at each iteration
# count in it_ckpts. All lambdas reach checkpoint k before any moves on to k+1.
# After each save, reg_file_dict[reg_val] holds the path of that run's latest file.
for k, ckpt in enumerate(it_ckpts):
    prev_ckpt = it_ckpts[k - 1] if k > 0 else 0
    for reg_val in reg_vals:
        if k == 0:
            # First checkpoint: start from the shared initialization. The copies keep
            # ini_W / ini_H intact in case nmf_son modifies its inputs in place.
            Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = nmf_son(X, ini_W.copy(), ini_H.copy(), _lambda=reg_val, itermax=ckpt)
        else:
            # Later checkpoints: reload the previous checkpoint and continue from its
            # last iterate (Wl, Hl) for ckpt - prev_ckpt more iterations. This branch is
            # chosen by k rather than by reg_file_dict, so re-running this cell without
            # re-running the cell that resets reg_file_dict cannot resume from a stale
            # file with a negative iteration budget.
            Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = load_results(reg_file_dict[reg_val])

            Wb, Hb, Wl, Hl, new_fscores, new_gscores, new_lambda_vals = nmf_son(X, Wl, Hl, _lambda=reg_val, itermax=ckpt - prev_ckpt)

            # The last stored entry is dropped before appending. Presumably the resumed
            # run's first entry repeats it (score at the starting point); TODO confirm in nmf_son.
            # NOTE(review): Wb/Hb now come from the resumed segment only; confirm that
            # this is the intended "best" iterate.
            fscores = np.append(fscores[:-1], new_fscores)
            gscores = np.append(gscores[:-1], new_gscores)
            lambda_vals = np.append(lambda_vals[:-1], new_lambda_vals)

        tmp_filepath = save_filepath.format(rank, str(reg_val).replace('.', '_'), ckpt)
        # Make sure the output directory exists before the first write.
        os.makedirs(os.path.dirname(tmp_filepath) or '.', exist_ok=True)
        with open(tmp_filepath, 'wb') as fout:
            np.savez_compressed(fout, Wb=Wb, Hb=Hb, Wl=Wl, Hl=Hl, fscores=fscores, gscores=gscores, lambda_vals=lambda_vals)
        reg_file_dict[reg_val] = tmp_filepath
        print(reg_val, ckpt, 'complete')

In [None]:
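# Optional visualization of saved runs (disabled). Uncomment to plot, for every lambda
# in reg_vals, the score history and factors saved at `iters` iterations.
# `iters` must be one of it_ckpts, since checkpoints are only saved at those counts.
# `img_size` must satisfy prod(img_size) == H.shape[1] for the reshape in the plot helpers.
# comp_idxs = range(rank)
# img_size = (20, 10)
# iters = 1000
#
# for reg_val in reg_vals:
#     Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = load_results(save_filepath.format(rank, str(reg_val).replace('.', '_'), iters))
#     plot_scores(fscores, gscores, lambda_vals, reg_val)
#     plot_matrices(Wl, Hl, img_size, comp_idxs, share_y=True)
#     plot_images(Hl, img_size, comp_idxs)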