### Notes
- base implementation changed from alternating columns to matrix


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from nmf_methods.nmf_son.base import nmf_son as nmf_son_base
from nmf_methods.nmf_son.new import nmf_son_constrained_H
from nmf_methods.nmf_son.new import nmf_son_new
from nmf_methods.nmf_son.utils import save_results, load_results

np.random.seed(42)
np.set_printoptions(precision=3)

## Toy Example

In [None]:
def create_toy_ex(n):
    """Generate a small synthetic factorization problem ``M = W @ H``.

    ``W`` is a random 2 x 3 matrix drawn with ``np.random.rand``. Every column
    of ``H`` is drawn from a Dirichlet(0.33, 0.33, 0.33) distribution and is
    redrawn until none of its entries reaches the threshold ``0.88``.

    Parameters
    ----------
    n : int
        Number of columns of ``H`` (and therefore of ``M``).

    Returns
    -------
    M : ndarray, shape (2, n)
        The product ``W @ H``.
    W : ndarray, shape (2, 3)
    H : ndarray, shape (3, n)
        Columns sum to one and all entries are strictly below ``0.88``.
    """
    W = np.random.rand(2, 3)
    # Start from all ones so that every column is rejected (and therefore
    # sampled) on the first pass.
    H = np.ones((3, n))
    thres = 0.88
    while True:
        bad_cols = np.flatnonzero((H >= thres).any(axis=0))
        # Decide on the *number* of offending columns. Testing the truthiness
        # of the index values would treat a lone offending column 0 as
        # "nothing left to fix" and leave it above the threshold.
        if bad_cols.size == 0:
            break
        H[:, bad_cols] = np.random.dirichlet((0.33, 0.33, 0.33), bad_cols.size).T

    M = W @ H
    return M, W, H

def find_min(arrs):
    """Return the smallest value found in a collection of sequences.

    Parameters
    ----------
    arrs : iterable of iterables
        For example a 2-D array (iterated row by row) or a list of lists.

    Returns
    -------
    The smallest element over all inner sequences, or ``inf`` when ``arrs``
    itself is empty.

    Raises
    ------
    ValueError
        If one of the inner sequences is empty (raised by the builtin ``min``).
    """
    # ``np.inf`` instead of ``np.Inf``: the capitalised alias was removed in
    # NumPy 2.0 and raises AttributeError there.
    min_val = np.inf
    for arr in arrs:
        arr_min = min(arr)  # computed once per sequence
        if min_val > arr_min:
            min_val = arr_min
    return min_val

def plot_mats(ax, M, W, W_true):
    """Scatter the data points and both sets of basis columns on ``ax``.

    Only the first two rows of each matrix are used, as x and y coordinates.

    Parameters
    ----------
    ax : object with a matplotlib-style ``plot`` method
        Axes to draw on.
    M : ndarray
        Data matrix; its columns are drawn as black dots.
    W : ndarray
        Estimated basis; each column is drawn as a red marker. The marker
        shapes cycle when ``W`` has more columns than there are shapes.
    W_true : ndarray
        Reference basis; each column is drawn as a blue cross.
    """
    symbols = ['o', 'x', 'v', 's', '.']
    # Iterate over the columns of W, not over the marker list: a W with fewer
    # than five columns used to raise IndexError, and columns beyond the
    # fifth were silently left out.
    for i in range(W.shape[1]):
        marker = symbols[i % len(symbols)]
        ax.plot(W[0, i], W[1, i], f'r{marker}', markersize=5, linewidth=2)
    ax.plot(M[0, :], M[1, :], 'k.')

    for j in range(W_true.shape[1]):
        ax.plot(W_true[0, j], W_true[1, j], 'bx', markersize=5, linewidth=2)

def plot_scores_comp(f_arr, g_arr, t_arr, vers):
    """Draw three side-by-side log-scale convergence panels.

    From left to right the panels show ``t_arr`` (labelled ``F(W, H)``),
    ``f_arr`` (``f(W, H)``) and ``g_arr`` (``g(W)``). In every panel the
    entries 0, 1 and 2 of the array are drawn as black, cyan and yellow
    curves and labelled with ``vers[0]``, ``vers[1]`` and ``vers[2]``.

    Parameters
    ----------
    f_arr, g_arr, t_arr : indexable
        Each must provide at least three curves (indices 0..2).
    vers : sequence of str
        Legend labels, one per curve.
    """
    panels = (
        ('$F(W, H)$', t_arr),
        ('$f(W, H)$', f_arr),
        ('$g(W)$', g_arr),
    )
    # (colour, line width) for curve 0, 1 and 2; the first curve is drawn
    # thicker than the others.
    curve_styles = (('black', 3), ('cyan', 1.5), ('yellow', 1.5))

    fig, axs = plt.subplots(1, 3, figsize=(20, 6))
    for ax, (ylabel, curves) in zip(axs, panels):
        ax.set_yscale('log')
        ax.set_xlabel('Iterations')
        ax.set_ylabel(ylabel)
        for k, (colour, width) in enumerate(curve_styles):
            ax.plot(curves[k], color=colour, linewidth=width, label=vers[k])
        ax.legend()

In [None]:
# Synthetic problem with 30 data points; W and H are the generating factors.
M, W, H = create_toy_ex(30)
m, n = M.shape

r = 5  # inner dimension of the initial factors (the generator uses 3 columns in W)
lam = 0.05  # handed to every solver as `_lambda`
itermax = 1000  # handed to every solver as `itermax`

# Solver switches, forwarded unchanged to each nmf_son_* call below.
EARLY_STOP = True
VERBOSE = False
SCALE_REG = True

In [None]:
# One shared random initialization; every solver call receives its own copy.
ini_W = np.random.rand(m, r)
ini_H = np.random.rand(r, n)

In [None]:
%%time
# Baseline solver, run on private copies of the shared initial factors.
W_base, H_base, fscores_base, gscores_base, lvals_base = nmf_son_base(M, ini_W.copy(), ini_H.copy(), _lambda=lam, itermax=itermax, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)

In [None]:
%%time
# Constrained-H variant, same data, initialization and settings as the baseline.
W_conH, H_conH, fscores_conH, gscores_conH, lvals_conH = nmf_son_constrained_H(M, ini_W.copy(), ini_H.copy(), _lambda=lam, itermax=itermax, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)

In [None]:
%%time
# "New" variant, same data, initialization and settings as the baseline.
W_new, H_new, fscores_new, gscores_new, lvals_new = nmf_son_new(M, ini_W.copy(), ini_H.copy(), _lambda=lam, itermax=itermax, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)

In [None]:
# Persist the five outputs of each toy run, one .npz file per solver.
save_results('../experimental/saved_models/toy_base.npz', W_base, H_base, fscores_base, gscores_base, lvals_base)
save_results('../experimental/saved_models/toy_conh.npz', W_conH, H_conH, fscores_conH, gscores_conH, lvals_conH)
save_results('../experimental/saved_models/toy_new.npz', W_new, H_new, fscores_new, gscores_new, lvals_new)

In [None]:
vers = ['base', 'constrainedH', 'new']

# Stack the three solvers' histories row-wise (one row per solver), leaving
# out the first recorded entry of each history.
f_arr, g_arr, lvals_arr = (
    np.array([history[1:] for history in per_solver])
    for per_solver in (
        (fscores_base, fscores_conH, fscores_new),
        (gscores_base, gscores_conH, gscores_new),
        (lvals_base, lvals_conH, lvals_new),
    )
)
# Combined score: f plus the recorded lvals times g, element-wise.
t_arr = f_arr + lvals_arr * g_arr

plot_scores_comp(f_arr, g_arr, t_arr, vers)

In [None]:
# Shift every score array in place so that its overall minimum becomes 0,
# then redraw the comparison.
# NOTE(review): the panels use a log y-axis, so the entry that becomes exactly
# 0 cannot be shown there — confirm this is intended.
f_arr -= find_min(f_arr)
g_arr -= find_min(g_arr)
t_arr -= find_min(t_arr)

plot_scores_comp(f_arr, g_arr, t_arr, vers)

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(20, 6))

# One panel per basis estimate, left to right; every panel also shows the
# data M and the generating W for reference.
toy_panels = [
    ('Initial', ini_W),
    ('Base', W_base),
    ('Base w. constrained H', W_conH),
    ('New', W_new),
]
for ax, (title, W_est) in zip(axs, toy_panels):
    ax.set_title(title)
    plot_mats(ax, M, W_est, W)

## Real data (small Urban)

In [None]:
def plot_seperate_H(H, img_size, figsize, fontsize, normalize_row=False, split=False, filename=None):
    """Show every row of ``H`` as its own grayscale image with a colorbar.

    Parameters
    ----------
    H : ndarray, shape (rank, img_size[0] * img_size[1])
        One flattened image per row. The caller's array is not modified.
    img_size : sequence of two ints
        Image shape; each row is unflattened with ``order='F'``.
    figsize : tuple
        Forwarded to ``plt.subplots``.
    fontsize : int or float
        Font size of the per-image titles.
    normalize_row : bool, default False
        Divide every row by its Euclidean norm before plotting.
    split : bool, default False
        Lay the images out on two rows instead of one. With an odd rank the
        first row holds the extra image and the spare cell is left blank.
    filename : str, optional
        If given, the current figure is saved to this path.
    """
    rank = H.shape[0]
    if normalize_row:
        # Out-of-place division: ``H /= ...`` rescaled the caller's matrix as
        # a hidden side effect (and fails outright on integer arrays).
        H = H / np.linalg.norm(H, axis=1, keepdims=True)
    H3d = H.reshape(-1, img_size[0], img_size[1], order='F')

    n_rows = 2 if split else 1
    n_cols = -(-rank // n_rows)  # ceiling division, so odd ranks still fit
    # squeeze=False always returns a 2-D grid of axes, so one indexing scheme
    # covers a single row, a single column and a single image alike.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    for cnt in range(rank):
        ax = axs[divmod(cnt, n_cols)]
        img = ax.imshow(H3d[cnt, :, :], cmap='gray')
        # Braces around the index keep multi-digit superscripts together.
        ax.set_title(f'$h^{{{cnt + 1}}}$', fontsize=fontsize)
        ax.axis('off')
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.1)
        fig.colorbar(img, cax=cax, orientation='vertical')

    # Hide the spare cell that appears when an odd rank is split in two rows.
    for ax in axs.flat[rank:]:
        ax.axis('off')

    # plt.tight_layout()
    if filename:
        plt.savefig(filename)


def plot_combined_H(H, img_size, figsize, normalize_row=False, split=False, filename=None):
    """Show all rows of ``H`` tiled into one grayscale image with a colorbar.

    Parameters
    ----------
    H : ndarray, shape (rank, img_size[0] * img_size[1])
        One flattened image per row. The caller's array is not modified.
    img_size : sequence of two ints
        Image shape; each row is unflattened with ``order='F'``.
    figsize : tuple
        Forwarded to ``plt.figure``.
    normalize_row : bool, default False
        Divide every row by its Euclidean norm before plotting.
    split : bool, default False
        Tile the images on two rows instead of one. With an odd rank the
        second row is completed with a NaN-filled tile.
    filename : str, optional
        If given, the current figure is saved to this path.
    """
    if normalize_row:
        # Out-of-place division: ``H /= ...`` rescaled the caller's matrix as
        # a hidden side effect (and fails outright on integer arrays).
        H = H / np.linalg.norm(H, axis=1, keepdims=True)
    H3d = H.reshape(-1, img_size[0], img_size[1], order='F')

    tiles = list(H3d)
    if split:
        n_top = -(-len(tiles) // 2)  # ceiling division
        top, bottom = tiles[:n_top], tiles[n_top:]
        if len(bottom) < len(top):
            # Odd rank: both rows must have the same width for vstack, so
            # complete the second row with a NaN-filled tile.
            bottom.append(np.full(H3d.shape[1:], np.nan))
        large_mat = np.vstack([np.hstack(top), np.hstack(bottom)])
    else:
        large_mat = np.hstack(tiles)

    plt.figure(figsize=figsize)
    ax = plt.axes()
    im = plt.imshow(large_mat, cmap='gray')

    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.1)

    plt.colorbar(im, cax=cax)

    if filename:
        plt.savefig(filename)


def plot_W_mats(W, figsize, fontsize, split=False, filename=None, scale_y=False, log_scale=False, plot_title=None):
    """Plot every column of ``W`` as a line in its own subplot.

    Parameters
    ----------
    W : ndarray, shape (bands, rank)
        One curve per column.
    figsize : tuple
        Forwarded to ``plt.subplots``.
    fontsize : int or float
        Font size of the per-subplot titles.
    split : bool, default False
        Lay the subplots out on two rows instead of one. With an odd rank the
        first row holds the extra subplot and the spare cell is left blank.
    filename : str, optional
        If given, the figure is saved to this path.
    scale_y : bool, default False
        Use the same y-limits ``[min(0, W.min()), W.max()]`` in every subplot.
    log_scale : bool, default False
        Use a logarithmic y-axis.
    plot_title : str, optional
        Figure-level title.
    """
    rank = W.shape[1]
    wmin = np.min(W)
    wmax = np.max(W)

    n_rows = 2 if split else 1
    n_cols = -(-rank // n_rows)  # ceiling division, so odd ranks still fit
    # squeeze=False always returns a 2-D grid of axes, so one indexing scheme
    # covers a single row, a single column and a single subplot alike.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    for cnt in range(rank):
        ax = axs[divmod(cnt, n_cols)]
        ax.plot(W[:, cnt], linewidth=3)
        if scale_y:
            ax.set_ylim([min(0, wmin), wmax])
        if log_scale:
            ax.set_yscale('log')
        # Braces around the index keep multi-digit subscripts together.
        ax.set_title(f'$w_{{{cnt + 1}}}$', fontsize=fontsize)
        ax.set_xlabel('Bands')
        ax.set_ylabel('Reflectance')

    # Hide the spare cell that appears when an odd rank is split in two rows.
    for ax in axs.flat[rank:]:
        ax.axis('off')

    plt.tight_layout()

    if plot_title:
        fig.suptitle(plot_title, fontsize=25)
    if filename:
        fig.savefig(filename)

In [None]:
# Data matrix stored under the key 'X' of the small Urban archive.
M = np.load('../experimental/datasets/urban_small.npz')['X']
m, n = M.shape

r = 6  # inner dimension of the initial factors
lam = 1  # handed to every solver as `_lambda`
itermax = 1000  # handed to every solver as `itermax`

# Solver switches, forwarded unchanged to each nmf_son_* call below.
EARLY_STOP = True
VERBOSE = False
SCALE_REG = True

In [None]:
# One shared random initialization; every solver call receives its own copy.
ini_W = np.random.rand(m, r)
ini_H = np.random.rand(r, n)

In [None]:
%%time
# Baseline solver, run on private copies of the shared initial factors.
W_base, H_base, fscores_base, gscores_base, lvals_base = nmf_son_base(M, ini_W.copy(), ini_H.copy(), _lambda=lam, itermax=itermax, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)

In [None]:
%%time
# Constrained-H variant, same data, initialization and settings as the baseline.
W_conH, H_conH, fscores_conH, gscores_conH, lvals_conH = nmf_son_constrained_H(M, ini_W.copy(), ini_H.copy(), _lambda=lam, itermax=itermax, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)

In [None]:
%%time
# "New" variant, same data, initialization and settings as the baseline.
W_new, H_new, fscores_new, gscores_new, lvals_new = nmf_son_new(M, ini_W.copy(), ini_H.copy(), _lambda=lam, itermax=itermax, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)

In [None]:
# Persist the five outputs of each run on the small Urban data, one .npz file
# per solver.
save_results('../experimental/saved_models/urban_sm_base.npz', W_base, H_base, fscores_base, gscores_base, lvals_base)
save_results('../experimental/saved_models/urban_sm_conh.npz', W_conH, H_conH, fscores_conH, gscores_conH, lvals_conH)
save_results('../experimental/saved_models/urban_sm_new.npz', W_new, H_new, fscores_new, gscores_new, lvals_new)

In [None]:
vers = ['base', 'constrainedH', 'new']
# One row per solver; the first recorded entry of each history is left out.
f_arr = np.array([fscores_base[1:], fscores_conH[1:], fscores_new[1:]])
g_arr = np.array([gscores_base[1:], gscores_conH[1:], gscores_new[1:]])
lvals_arr = np.array([lvals_base[1:], lvals_conH[1:], lvals_new[1:]])
# Weight g with the lvals history returned by each solver, exactly as the toy
# comparison does. The constant `lam` was used here before, which left
# `lvals_arr` computed but unused.
# NOTE(review): assumes lvals holds the per-iteration weight the solvers
# actually applied (relevant with scale_reg=True) — confirm against nmf_son.
t_arr = f_arr + lvals_arr * g_arr

plot_scores_comp(f_arr, g_arr, t_arr, vers)

In [None]:
# Shift every score array in place so that its overall minimum becomes 0,
# then redraw the comparison.
# NOTE(review): the panels use a log y-axis, so the entry that becomes exactly
# 0 cannot be shown there — confirm this is intended.
f_arr -= find_min(f_arr)
g_arr -= find_min(g_arr)
t_arr -= find_min(t_arr)

plot_scores_comp(f_arr, g_arr, t_arr, vers)

In [None]:
# Baseline solver: columns of W as curves on a shared y-range, rows of H as
# 20 x 10 images with row normalization.
plot_W_mats(W_base, figsize=(28, 6), fontsize=25, scale_y=True)
plot_seperate_H(H_base, (20, 10), figsize=(20, 5), fontsize=15, normalize_row=True)

In [None]:
# Constrained-H variant, same plots.
plot_W_mats(W_conH, figsize=(28, 6), fontsize=25, scale_y=True)
plot_seperate_H(H_conH, (20, 10), figsize=(20, 5), fontsize=15, normalize_row=True)

In [None]:
# "New" variant, same plots.
plot_W_mats(W_new, figsize=(28, 6), fontsize=25, scale_y=True)
plot_seperate_H(H_new, (20, 10), figsize=(20, 5), fontsize=15, normalize_row=True)

In [None]:
# Regularization values iterated over by the sweep cells below (1e-7 ... 1000).
lambda_vals = [1e-7, 1e-6, 1e-5, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10, 50, 100, 500, 1000]
max_iters = 10000  # used as `itermax` and in the result file names of the sweeps

In [None]:
# Solver switches for the sweeps (same values as in the earlier experiments).
EARLY_STOP = True
VERBOSE = False
SCALE_REG = True

In [None]:
# toy
M, W, H = create_toy_ex(30)
m, n = M.shape
r = 5

ini_filepath = f'../experimental/saved_models/toy/r{r}_ini.npz'
save_filepath = '../experimental/saved_models/toy/r{}_l{}_mit{}.npz'


ini_W = np.random.rand(m, r)
ini_H = np.random.rand(r, n)

# Store the initialization so that the sweep can be reproduced later.
with open(ini_filepath, 'wb') as fout:
    np.savez_compressed(fout, ini_W=ini_W, ini_H=ini_H)

for _lam in lambda_vals:
    # NOTE: if the two lines below are re-enabled they rebind W and H, so the
    # plot_mats call further down would no longer receive the generating W.
    # W, H, fscores, gscores, lvals = nmf_son_new(M, ini_W.copy(), ini_H.copy(), _lambda=_lam, itermax=max_iters, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)
    # save_results(save_filepath.format(r, _lam, max_iters), W, H, fscores, gscores, lvals)

    # plt.subplots(1, 1) returns a single Axes object rather than an array, so
    # it must not be indexed (`axs[0]` raised TypeError).
    fig, ax = plt.subplots(1, 1, figsize=(20, 6))

    ax.set_title('Initial')
    plot_mats(ax, M, ini_W, W)

In [None]:
# # jasper
# M = np.load('../experimental/datasets/jasper_full.npz')['X']
# m, n = M.shape
# r = 8
#
# ini_filepath = f'../experimental/saved_models/jasper/r{r}_ini.npz'
# save_filepath = '../experimental/saved_models/jasper/r{}_l{}_mit{}.npz'
#
#
# ini_W = np.random.rand(m, r)
# ini_H = np.random.rand(r, n)
#
# with open(ini_filepath, 'wb') as fout:
#     np.savez_compressed(fout, ini_W=ini_W, ini_H=ini_H)
#
# for _lam in lambda_vals:
#     W, H, fscores, gscores, lvals = nmf_son_new(M, ini_W.copy(), ini_H.copy(), _lambda=_lam, itermax=max_iters, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)
#     save_results(save_filepath.format(r, _lam, max_iters), W, H, fscores, gscores, lvals)
#     print(_lam)

In [None]:
# # urban
# M = np.load('../experimental/datasets/urban_full.npz')['X']
# m, n = M.shape
# r = 10
#
# ini_filepath = f'../experimental/saved_models/urban/r{r}_ini.npz'
# save_filepath = '../experimental/saved_models/urban/r{}_l{}_it{}.npz'
#
#
# ini_W = np.random.rand(m, r)
# ini_H = np.random.rand(r, n)
#
# with open(ini_filepath, 'wb') as fout:
#     np.savez_compressed(fout, ini_W=ini_W, ini_H=ini_H)
#
# for _lam in lambda_vals:
#     W, H, fscores, gscores, lvals = nmf_son_new(M, ini_W.copy(), ini_H.copy(), _lambda=_lam, itermax=max_iters, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)
#     save_results(save_filepath.format(r, _lam, max_iters), W, H, fscores, gscores, lvals)
#     print(_lam)