In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from nmf_son.base import nmf_son
from nmf_son.utils import load_results
from sklearn.decomposition import NMF

# Reproducibility and display settings: print arrays with three decimals,
# and seed NumPy's global RNG so any np.random-based draws repeat run to run.
np.set_printoptions(precision=3)
np.random.seed(42)

In [None]:
# def plot_matrices(W, H, img_size, comparison_idxs):
#     fig, axs = plt.subplots(2, len(comparison_idxs), figsize=(20, 10), sharey='row')
#
#     for i, idx in enumerate(comparison_idxs):
#         axs[0, i].plot(W[:, idx] / np.dot(W[:, idx], W[:, idx]))
#         axs[0, i].set_title(f'W({idx+1})')
#
#         h_idx_3d = H[idx, :].reshape(img_size, order='F')
#         axs[1, i].plot(h_idx_3d)
#         axs[1, i].set_title(f'H({idx+1})')
#
# def plot_images(H, img_size, comparison_idxs):
#     fig, axs = plt.subplots(1, len(comparison_idxs), figsize=(20, 10))
#
#     for i, idx in enumerate(comparison_idxs):
#         h_idx_3d = H[idx, :].reshape(img_size, order='F')
#
#         axs[i].imshow(h_idx_3d)
#         img = axs[i].imshow(h_idx_3d, cmap='gray')
#         divider = make_axes_locatable(axs[i])
#
#         cax = divider.append_axes('right', size='5%', pad=0.1)
#         fig.colorbar(img, cax=cax, orientation='vertical')
#         axs[i].set_title(f'H({idx+1})')


def plot_scores(fscores, gscores, lambda_vals, plot_title=None, filename=None):
    """Plot the objective history of an NMF-SON run in two log-scale panels.

    The left panel shows the raw per-iteration values of the total objective
    ``F = f + lambda * g``, the data-fit term ``f(W, H)`` and the regularizer
    ``g(W)``. The right panel shows each series shifted by its own minimum,
    so the convergence tail is visible on a log axis.

    Parameters
    ----------
    fscores, gscores, lambda_vals : array-like
        Per-iteration histories of ``f``, ``g`` and ``lambda``. The first
        entry of each is dropped before plotting (presumably the value at
        initialization -- confirm against ``nmf_son``). The inputs are never
        modified.
    plot_title : str, optional
        Figure super-title.
    filename : str or path-like, optional
        If given, the figure is also saved to this path.

    Raises
    ------
    ValueError
        If fewer than two entries are recorded, so nothing is left to plot
        after dropping the first one.
    """
    # Slicing yields views of the caller's arrays. Every derived series below
    # is therefore built out-of-place; an in-place `-=` here would silently
    # overwrite the caller's score histories.
    fscores = np.asarray(fscores, dtype=float)[1:]
    gscores = np.asarray(gscores, dtype=float)[1:]
    lambda_vals = np.asarray(lambda_vals, dtype=float)[1:]
    if fscores.size == 0:
        raise ValueError('plot_scores needs at least two recorded iterations')
    total_score = fscores + lambda_vals * gscores

    fig, axs = plt.subplots(1, 2, figsize=(20, 6))
    if plot_title:
        fig.suptitle(plot_title, fontsize=25)

    axs[0].set_yscale('log')
    axs[0].plot(total_score, color='black', linewidth=3, label='$F(W, H)$')
    axs[0].plot(fscores, color='cyan', linewidth=1.5, label='$f(W, H)$')
    axs[0].plot(gscores, color='yellow', linewidth=1.5, label='$g(W)$')
    axs[0].set_xlabel('Iterations')
    axs[0].legend()

    # Shift each series so its minimum is 0. That point itself cannot be
    # shown on a log axis, but the approach to the minimum becomes visible.
    total_shifted = total_score - total_score.min()
    f_shifted = fscores - fscores.min()
    g_shifted = gscores - gscores.min()

    axs[1].set_yscale('log')
    axs[1].plot(total_shifted, color='black', linewidth=3, label='$F(W, H) - min(F(W, H))$')
    axs[1].plot(f_shifted, color='cyan', linewidth=1.5, label='$f(W, H) - min(f(W, H))$')
    # Balanced dollar signs: matplotlib only renders mathtext when the label
    # contains an even number of unescaped '$'.
    axs[1].plot(g_shifted, color='yellow', linewidth=1.5, label='$g(W) - min(g(W))$')
    axs[1].set_xlabel('Iterations')
    axs[1].legend()

    if filename:
        fig.savefig(filename)


def plot_combined_H(H, img_size, split=False, filename=None):
    """Show every row of H as an image tile, stitched into one grayscale figure.

    Parameters
    ----------
    H : ndarray, shape (rank, img_size[0] * img_size[1])
        Each row is reshaped column-major (``order='F'``) into an
        ``img_size`` image.
    img_size : tuple of int
        ``(n_rows, n_cols)`` of a single tile.
    split : bool
        If True, lay the tiles out on two rows, with the first
        ``ceil(rank / 2)`` tiles on top; otherwise use a single row. For an
        odd rank, the missing bottom-right tile is filled with NaN, which
        imshow leaves blank and which is ignored by the color scaling.
    filename : str or path-like, optional
        If given, the figure is also saved to this path.

    Raises
    ------
    ValueError
        If H is not 2-D or its column count does not match ``img_size``.
        Without this check, ``reshape(-1, ...)`` could silently produce
        wrongly-sized tiles.
    """
    H = np.asarray(H)
    n_pix = img_size[0] * img_size[1]
    if H.ndim != 2 or H.shape[1] != n_pix:
        raise ValueError(
            f'H must have shape (rank, {n_pix}) for img_size={tuple(img_size)}, got {H.shape}'
        )
    # With order='F', H3d[k] equals H[k, :].reshape(img_size, order='F').
    H3d = H.reshape(-1, img_size[0], img_size[1], order='F')

    if split:
        n_top = (H3d.shape[0] + 1) // 2  # ceil(rank / 2)
        bottom_tiles = list(H3d[n_top:])
        # Pad the bottom row to the top row's width so vstack works for odd ranks.
        blank = np.full(H3d.shape[1:], np.nan)
        bottom_tiles += [blank] * (n_top - len(bottom_tiles))
        large_mat = np.vstack([np.hstack(H3d[:n_top]), np.hstack(bottom_tiles)])
    else:
        large_mat = np.hstack(H3d)

    plt.figure(figsize=(20, 10))
    plt.imshow(large_mat, cmap='gray')
    plt.colorbar()

    if filename:
        plt.savefig(filename)


def plot_W_mats(W, split=False, filename=None):
    """Plot each column of W as a line plot in its own subplot.

    All panels share the y-range ``[0, max(W)]`` so components are directly
    comparable.

    Parameters
    ----------
    W : ndarray, shape (m, rank)
        One column per component.
    split : bool
        If True, lay the panels out on two rows, with the first
        ``ceil(rank / 2)`` columns of W on the top row; otherwise use a
        single row. For an odd rank, the unused last bottom-row slot is
        hidden.
    filename : str or path-like, optional
        If given, the figure is also saved to this path.
    """
    W = np.asarray(W)
    rank = W.shape[1]
    wmax = np.max(W)

    if split:
        n_rows, n_cols = 2, (rank + 1) // 2  # ceil(rank / 2)
        figsize = (20, 10)
    else:
        n_rows, n_cols = 1, rank
        figsize = (20, 5)
    # squeeze=False keeps `axs` 2-D even for a single row/column (rank 1, or
    # rank 2 with split=True), so the indexing below works for every rank.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    panels = axs.ravel()  # row-major: top row first, then bottom row

    for k in range(rank):
        ax = panels[k]
        ax.plot(W[:, k])
        ax.set_ylim([0, wmax])
        # Braces make multi-digit indices a single mathtext subscript
        # (W_{10} rather than W_1 followed by 0).
        ax.set_title(f'$W_{{{k + 1}}}$')
    for ax in panels[rank:]:
        ax.set_visible(False)

    if filename:
        fig.savefig(filename)

In [None]:
# Iteration budget for the nmf_son runs (passed as `itermax`).
max_iter = 3_000

## Urban Small

In [None]:
# X = np.load('../datasets/urban_small.npz')['X']
# rank = 6
# img_size = (20, 10)
# data = np.load(f'../saved_models/urban_small_r{rank}_ini.npz')
# ini_W = data['W']
# ini_H = data['H']
# save_filepath = '../saved_models/urban_small_tuning/r{}_l{}.npz'
#
# reg_vals = [1e-7, 1e-6, 1e-5, 0.0001, 0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 3, 3.5, 4.5, 5, 8, 10, 20, 50, 100, 1000, 10000]

In [None]:
# for reg_val in reg_vals:
#     Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = nmf_son(X, ini_W.copy(), ini_H.copy(), _lambda=reg_val, itermax=max_iter)
#     with open(save_filepath.format(rank, reg_val), 'wb') as fout:
#         np.savez_compressed(fout, Wb=Wb, Hb=Hb, Wl=Wl, Hl=Hl, fscores=fscores, gscores=gscores, lambda_vals=lambda_vals)
#         print(reg_val, 'done')

In [None]:
# true_rank = 2
# m, n = X.shape
# model = NMF(n_components=true_rank, init='random', random_state=42, tol=1e-5, max_iter=3000)
# vanillaW = model.fit_transform(X=X)
# vanillaH = model.components_
# plot_W_mats(vanillaW)
# plot_combined_H(vanillaH, img_size)

In [None]:
# model = NMF(n_components=rank, init='custom', random_state=42, tol=1e-5, max_iter=3000)
# vanillaW = model.fit_transform(X=X, W=ini_W.copy(), H=ini_H.copy())
# vanillaH = model.components_
# plot_combined_H(vanillaH, img_size)
# plot_W_mats(vanillaW)

In [None]:
# for reg_val in reg_vals:
#     Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = load_results(save_filepath.format(rank, reg_val))
#     plot_scores(fscores, gscores, lambda_vals, f'$\lambda = {reg_val}$')
#     plot_W_mats(Wl)
#     plot_combined_H(Hl, img_size)

## Jasper

In [None]:
# Jasper: data matrix, working rank, per-tile image shape, and the shared
# initialization used by the nmf_son and sklearn runs below.
# `with` closes the NpzFile handles that np.load keeps open for .npz archives.
with np.load('../datasets/jasper_full.npz') as dataset:
    X = dataset['X']
rank = 8
img_size = (100, 100)
with np.load(f'../saved_models/jasper_full_r{rank}_ini.npz') as data:
    ini_W = data['W']
    ini_H = data['H']
save_filepath = '../saved_models/jasper_tuning/r{}_l{}.npz'

reg_vals = [1e-8, 1e-7, 1e-6, 1e-5, 0.0001, 0.001, 0.01, 0.1, 0.5, 1, 2, 5, 10, 20, 50, 100, 1000]

# Sanity checks:
# - sklearn's init='custom' needs W of shape (m, rank) and H of shape (rank, n).
# - plot_combined_H reshapes each H row into an img_size tile.
assert ini_W.shape == (X.shape[0], rank), ini_W.shape
assert ini_H.shape == (rank, X.shape[1]), ini_H.shape
assert X.shape[1] == img_size[0] * img_size[1], X.shape

In [None]:
# for reg_val in reg_vals:
#     Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = nmf_son(X, ini_W.copy(), ini_H.copy(), _lambda=reg_val, itermax=max_iter)
#     with open(save_filepath.format(rank, reg_val), 'wb') as fout:
#         np.savez_compressed(fout, Wb=Wb, Hb=Hb, Wl=Wl, Hl=Hl, fscores=fscores, gscores=gscores, lambda_vals=lambda_vals)
#         print(reg_val, 'done')

In [None]:
# Baseline: plain sklearn NMF with random init at `true_rank` components
# (presumably the number of ground-truth sources in Jasper -- confirm).
# It reuses the notebook's `max_iter` instead of a duplicated literal, so the
# iteration budget is configured in one place.
true_rank = 4
m, n = X.shape
model = NMF(n_components=true_rank, init='random', random_state=42, tol=1e-5, max_iter=max_iter)
vanillaW = model.fit_transform(X=X)
vanillaH = model.components_
plot_W_mats(vanillaW)
plot_combined_H(vanillaH, img_size)

In [None]:
# Baseline: plain sklearn NMF at the working `rank`, started from the same
# ini_W / ini_H as the nmf_son runs (copies, so the shared init is untouched).
# It reuses the notebook's `max_iter` rather than a duplicated literal.
model = NMF(n_components=rank, init='custom', random_state=42, tol=1e-5, max_iter=max_iter)
vanillaW = model.fit_transform(X=X, W=ini_W.copy(), H=ini_H.copy())
vanillaH = model.components_
plot_W_mats(vanillaW, split=True)
plot_combined_H(vanillaH, img_size, split=True)

In [None]:
# For every regularization value: load the saved run, then plot its objective
# history, the columns of Wl, and the rows of Hl as image tiles.
for reg_val in reg_vals:
    Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = load_results(save_filepath.format(rank, reg_val))
    # Raw f-string: '\l' is not a valid Python escape (SyntaxWarning on 3.12+),
    # and the backslash must reach mathtext unchanged.
    plot_scores(fscores, gscores, lambda_vals, rf'$\lambda = {reg_val}$')
    plot_W_mats(Wl, split=True)
    plot_combined_H(Hl, img_size, split=True)
    # Render this iteration's figures now, so dozens don't accumulate in one
    # cell (matplotlib warns past 20 open figures). The inline backend also
    # closes them on show.
    plt.show()

## Urban

In [None]:
# X = np.load('../datasets/urban_full.npz')['X']
# rank = 10
# img_size = (307, 307)
# data = np.load(f'../saved_models/urban/urban_full_r{rank}_ini.npz')
# ini_W = data['W']
# ini_H = data['H']
# save_filepath = '../saved_models/urban_tuning/r{}_l{}.npz'
#
# reg_vals = [0.01, 0.1, 1, 10, 100]

In [None]:
# for reg_val in reg_vals:
#     Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = nmf_son(X, ini_W.copy(), ini_H.copy(), _lambda=reg_val, itermax=max_iter)
#     with open(save_filepath.format(rank, reg_val), 'wb') as fout:
#         np.savez_compressed(fout, Wb=Wb, Hb=Hb, Wl=Wl, Hl=Hl, fscores=fscores, gscores=gscores, lambda_vals=lambda_vals)
#         print(reg_val, 'done')

In [None]:
# true_rank = 6
# m, n = X.shape
# model = NMF(n_components=true_rank, init='random', random_state=42, tol=1e-5, max_iter=3000)
# vanillaW = model.fit_transform(X=X)
# vanillaH = model.components_
# plot_W_mats(vanillaW)
# plot_combined_H(vanillaH, img_size)

In [None]:
# model = NMF(n_components=rank, init='custom', random_state=42, tol=1e-5, max_iter=3000)
# vanillaW = model.fit_transform(X=X, W=ini_W.copy(), H=ini_H.copy())
# vanillaH = model.components_
# plot_W_mats(vanillaW)
# plot_combined_H(vanillaH, img_size)

In [None]:
# for reg_val in reg_vals:
#     Wb, Hb, Wl, Hl, fscores, gscores, lambda_vals = load_results(save_filepath.format(rank, reg_val))
#     plot_scores(fscores, gscores, lambda_vals, f'$\lambda = {reg_val}$')
#     plot_W_mats(Wl, split=True)
#     plot_combined_H(Hl, img_size, split=True)