In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from nmf_son.base import nmf_son
from nmf_son.utils import load_results
from sklearn.decomposition import NMF

# Seed NumPy's legacy global random generator so any np.random draws are repeatable.
np.random.seed(42)
# Show printed arrays with 3 digits of precision to keep output compact.
np.set_printoptions(precision=3)

In [None]:
def plot_scores(fscores, gscores, lambda_vals, reg_val):
    """Plot the f, g and combined score curves side by side on log-scaled axes.

    The combined curve is ``fscores + lambda_vals * gscores``. The left panel
    shows the three curves as they are; the right panel shows each curve with
    its own final entry subtracted.

    Args:
        fscores: sequence of f scores (assumed to be a NumPy array, since it is
            used in element-wise arithmetic and indexed with ``[-1]``).
        gscores: sequence of g scores, same assumptions as ``fscores``.
        lambda_vals: weight(s) multiplied with ``gscores``.
        reg_val: value shown in the figure title.
    """
    total_score = fscores + lambda_vals * gscores

    fig, axs = plt.subplots(1, 2, figsize=(20, 6))
    fig.suptitle(f'lambda = {reg_val}', fontsize=25)

    # (label, colour, line width) in drawing order; this order also fixes the legend order.
    line_specs = (
        ('total', 'black', 3),
        ('f', 'cyan', 1.5),
        ('g', 'yellow', 1.5),
    )
    raw_curves = (total_score, fscores, gscores)
    panel_specs = (
        ('log scale', False),
        ('log(score - score*)', True),
    )

    for ax, (panel_title, subtract_final) in zip(axs, panel_specs):
        ax.set_yscale('log')
        if subtract_final:
            # Offset every curve by its own last value.
            curves = [curve - curve[-1] for curve in raw_curves]
        else:
            curves = raw_curves
        for curve, (label, colour, width) in zip(curves, line_specs):
            ax.plot(curve, color=colour, linewidth=width, label=label)
        ax.legend()
        ax.set_title(panel_title, fontsize=16)

def plot_matrices(W, H, img_size, comparison_idxs):
    """Plot selected columns of W (top row) and rows of H (bottom row).

    Args:
        W: 2-D array; column ``idx`` is plotted in the top row after being
            divided by its dot product with itself.
        H: 2-D array; row ``idx`` is reshaped to ``img_size`` (column-major)
            and passed to ``Axes.plot`` in the bottom row.
        img_size: target shape for each row of H.
        comparison_idxs: sized iterable of component indices to show; one
            subplot column is created per index.
    """
    # squeeze=False keeps `axs` two-dimensional even when only one component
    # is requested; otherwise `axs[0, i]` raises IndexError for a single index.
    fig, axs = plt.subplots(2, len(comparison_idxs), figsize=(20, 10), sharey='row', squeeze=False)

    for i, idx in enumerate(comparison_idxs):
        w_col = W[:, idx]
        # NOTE(review): this divides by the squared Euclidean norm (w.w), not
        # the norm itself - confirm that is the intended scaling.
        axs[0, i].plot(w_col / np.dot(w_col, w_col))
        axs[0, i].set_title(f'W({idx+1})')

        h_img = H[idx, :].reshape(img_size, order='F')
        axs[1, i].plot(h_img)
        axs[1, i].set_title(f'H({idx+1})')

def plot_images(H, img_size, comparison_idxs):
    """Show selected rows of H as grayscale images, each with its own colorbar.

    Args:
        H: 2-D array; row ``idx`` is reshaped to ``img_size`` (column-major)
            and displayed with ``imshow``.
        img_size: target shape for each row of H.
        comparison_idxs: sized iterable of component indices to show; one
            subplot is created per index.
    """
    # squeeze=False guarantees an array of Axes even for a single component
    # (plt.subplots(1, 1) would otherwise return a bare Axes object).
    fig, axs = plt.subplots(1, len(comparison_idxs), figsize=(20, 10), squeeze=False)
    axs = axs[0]

    for i, idx in enumerate(comparison_idxs):
        h_img = H[idx, :].reshape(img_size, order='F')

        # Draw the image once; the handle is reused for the colorbar.
        img = axs[i].imshow(h_img, cmap='gray')

        # Carve a narrow axes out of the right side of the image for the colorbar.
        divider = make_axes_locatable(axs[i])
        cax = divider.append_axes('right', size='5%', pad=0.1)
        fig.colorbar(img, cax=cax, orientation='vertical')
        axs[i].set_title(f'H({idx+1})')

In [None]:
# Load the data matrix. np.load on a .npz returns an NpzFile that keeps the
# archive open, so read it inside a context manager to release the handle.
with np.load('../datasets/jasper_full.npz') as dataset:
    X = dataset['X']
rank = 8

# Initial factors saved for this rank.
with np.load(f'../saved_models/jasper/jasper_full_r{rank}_ini.npz') as data:
    ini_W = data['W']
    ini_H = data['H']

# Previously explored regularisation grids, kept for reference:
# reg_vals = [1e-6, 1e-5, 1e-4, 0.001, 0.001, 0.01, 0.1, 0.5, 1, 2, 5, 10]
# reg_vals = [1e-8, 1e-7, 8, 10, 15, 20, 50, 80, 100, 1000, 1e4]
reg_vals = [1e-12, 1e-10, 25, 30, 35, 40, 45]
# Path template for per-lambda results; filled with (rank, reg_val).
save_filepath = '../saved_models/jasper/lambda_tuning/r{}_l{}.npz'

In [None]:
# X = np.load('../datasets/urban_full.npz')['X']
# rank = 10
#
# data = np.load(f'../saved_models/urban/urban_full_r{rank}_ini.npz')
# ini_W = data['W']
# ini_H = data['H']
#
# # reg_vals = [0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 0.75, 1, 2, 5, 10, 25, 50, 100, 200, 500, 1000]
# # reg_vals = [1e-8, 2e-8, 4e-8, 6e-7, 1e-7, 2e-7, 4e-7, 7e-7, 1e-6, 2e-6, 5e-6, 1e-5, 1.5e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 8e-4, 1e-3]
# # reg_vals = [5e-5, 5.5e-5, 6e-5, 6.5e-5, 7e-5, 7.5e-5, 8e-5, 8.5e-5, 9e-5, 9.5e-5, 1e-4]
# # reg_vals = [1, 1.25, 1.5, 1.75, 2, 2.25, 2.5, 2.75, 3, 3.25, 3.5, 3.75, 4, 4.25, 4.5, 4.75, 5]
# # reg_vals = [0.01, 0.1, 0.5, 1, 2, 5, 8, 10, 100]
# reg_vals = [100, 10, 8, 5]
# save_filepath = '../saved_models/urban/lambda_tuning/r{}_l{}.npz'

In [None]:
# Iteration cap; only referenced by the (currently commented-out) fitting cells,
# as `itermax` for nmf_son and `max_iter` for sklearn's NMF.
max_iter = 3000

In [None]:
# for reg_val in reg_vals:
#     Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = nmf_son(X, ini_W.copy(), ini_H.copy(), _lambda=reg_val, itermax=max_iter)
#     with open(save_filepath.format(rank, reg_val), 'wb') as fout:
#         np.savez_compressed(fout, Wb=Wb, Hb=Hb, Wl=Wl, Hl=Hl, fscores=fscores, gscores=gscores, lambda_vals=lambda_vals)

In [None]:
# Plot every component: indices 0 .. rank-1.
comp_idxs = range(rank)
# Shape each row of H is reshaped to by the plotting helpers.
# presumably the spatial size of the dataset image - confirm when switching datasets
img_size = (100, 100)

In [None]:
# model = NMF(n_components=rank, init='random', random_state=42, max_iter=max_iter)
# vanillaW = model.fit_transform(X)
# vanillaH = model.components_
#
# plot_matrices(vanillaW, vanillaH, img_size, range(rank))

In [None]:
# For every regularisation value, load the saved run and plot its score
# curves, factor profiles and H images.
for reg_val in reg_vals:
    # Build the path from the shared template so it follows `rank` and the
    # selected dataset instead of a hard-coded jasper / r8 location.
    results_path = save_filepath.format(rank, reg_val)
    Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = load_results(results_path)
    plot_scores(fscores, gscores, lambda_vals, reg_val)
    plot_matrices(Wl, Hl, img_size, comp_idxs)
    plot_images(Hl, img_size, comp_idxs)

In [None]:
# import os
#
# for k in range(len(it_ckpts)):
#     os.rename(save_filepath.format(rank, str(reg_vals[5]).replace('.', '_'), it_ckpts[k]),
#               save_filepath.format(rank, '7.5e-05'.replace('.', '_'), it_ckpts[k]))

In [None]:
# total_scores = fscores + np.r_[np.nan, lambda_vals[1:]] * gscores
# sum((abs(total_scores[1:] - total_scores[:-1]) / total_scores[:-1]) >= 1e-5)
# abs(total_scores[1:] - total_scores[:-1]) / total_scores[:-1]