In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from nmf_methods.nmf_son.new import new as nmf_son_new
from nmf_methods.nmf_son.utils import save_results, load_results
from sklearn.decomposition import NMF
from PIL import Image

# Seed NumPy's global RNG so the np.random.rand draws made later in the notebook are repeatable.
np.random.seed(42)
# Show floats in printed arrays with three digits of precision.
np.set_printoptions(precision=3)

In [None]:
def plot_scores(fscores, gscores, lambda_vals, plot_title=None, filename=None):
    """Plot the recorded objective terms of a run in two side-by-side panels.

    The total objective is formed element-wise as ``F = f + lambda * g``.
    The left panel shows F, f and g on a logarithmic y-axis; the right panel
    shows each curve minus its own minimum, again on a logarithmic y-axis.

    Args:
        fscores: 1-D sequence of f(W, H) values, one per recorded step.
        gscores: 1-D sequence of g(W) values, same length as ``fscores``.
        lambda_vals: 1-D sequence of the weights that multiply ``gscores``.
        plot_title: optional figure title, passed to ``fig.suptitle``.
        filename: if given, the figure is saved there and then closed;
            otherwise it is left open (e.g. for inline display).

    The first entry of every input is skipped (presumably the values at the
    initial point -- confirm against the solver). The inputs are never
    modified.
    """
    # np.array(...) copies, so the slices below are views of private buffers.
    # Slicing the caller's arrays directly and shifting them in place would
    # overwrite the caller's score histories.
    f_vals = np.array(fscores, dtype=float)[1:]
    g_vals = np.array(gscores, dtype=float)[1:]
    lam_vals = np.array(lambda_vals, dtype=float)[1:]
    total = f_vals + lam_vals * g_vals

    # (values, colour, line width, raw label, shifted label), in drawing order.
    curves = [
        (total, 'black', 3, '$F(W, H)$', '$F(W, H) - min(F(W, H))$'),
        (f_vals, 'cyan', 1.5, '$f(W, H)$', '$f(W, H) - min(f(W, H))$'),
        (g_vals, 'yellow', 1.5, '$g(W)$', '$g(W) - min(g(W))$'),
    ]

    fig, axs = plt.subplots(1, 2, figsize=(20, 5))
    if plot_title:
        fig.suptitle(plot_title, fontsize=25)

    raw_ax = axs[0]
    raw_ax.set_yscale('log')
    for values, colour, width, raw_label, _ in curves:
        raw_ax.plot(values, color=colour, linewidth=width, label=raw_label)
    raw_ax.set_xlabel('Iterations')
    raw_ax.legend()

    shifted_ax = axs[1]
    shifted_ax.set_yscale('log')
    for values, colour, width, _, shifted_label in curves:
        # .min() raises on an empty array, so leave an empty curve as it is.
        shifted = values - values.min() if values.size else values
        shifted_ax.plot(shifted, color=colour, linewidth=width, label=shifted_label)
    shifted_ax.set_xlabel('Iterations')
    shifted_ax.legend()

    if filename:
        fig.savefig(filename)
        # Close this figure explicitly rather than whichever one is current.
        plt.close(fig)


def plot_separate_H(H, img_size, figsize, fontsize, normalize_row=False, split=False, filename=None):
    """Draw every row of H as its own grey-scale image with a colour bar.

    Args:
        H: 2-D array; each row is reshaped column-major (``order='F'``) to
            ``img_size``, so ``H.shape[1]`` must equal
            ``img_size[0] * img_size[1]``.
        img_size: ``(height, width)`` of one image.
        figsize: figure size forwarded to ``plt.subplots``.
        fontsize: font size of the per-image titles.
        normalize_row: divide each row by its Euclidean norm before plotting.
            This works on a copy, so the caller's ``H`` is left untouched;
            all-zero rows stay zero instead of turning into NaN.
        split: lay the images out on two rows instead of one. With an odd
            number of images the first row gets the extra one and the spare
            panel is left blank.
        filename: if given, the figure is saved there and then closed.
    """
    rank = H.shape[0]
    if normalize_row:
        row_norms = np.linalg.norm(H, axis=1, keepdims=True)
        row_norms[row_norms == 0] = 1  # avoid 0/0 for empty rows
        H = H / row_norms  # new array: do not write into the caller's H
    H3d = H.reshape(-1, img_size[0], img_size[1], order='F')

    n_rows = 2 if split else 1
    n_cols = -(-rank // n_rows)  # ceiling division keeps every image on the grid
    # squeeze=False always yields a 2-D array of axes, even for a single panel.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    # ravel() walks the grid row by row: first row left to right, then the second.
    for cnt, ax in enumerate(axs.ravel()):
        if cnt >= rank:
            ax.axis('off')  # spare panel of an odd split layout
            continue
        img = ax.imshow(H3d[cnt, :, :], cmap='gray')
        ax.set_title(f'$h^{{{cnt + 1}}}$', fontsize=fontsize)
        ax.axis('off')
        divider = make_axes_locatable(ax)
        cax = divider.append_axes('right', size='5%', pad=0.1)
        fig.colorbar(img, cax=cax, orientation='vertical')

    fig.tight_layout()
    if filename:
        fig.savefig(filename)
        plt.close(fig)


def plot_combined_H(H, img_size, figsize, normalize_row=False, split=False, filename=None):
    """Tile the images stored in the rows of H into one picture with a shared colour bar.

    Args:
        H: 2-D array; each row is reshaped column-major (``order='F'``) to
            ``img_size``.
        img_size: ``(height, width)`` of one image.
        figsize: figure size forwarded to ``plt.subplots``.
        normalize_row: divide each row by its Euclidean norm first. This
            works on a copy, so the caller's ``H`` is left untouched;
            all-zero rows stay zero.
        split: stack the tiles on two rows instead of one. With an odd number
            of tiles the first row gets the extra one and the second row is
            padded with a NaN tile (drawn blank) so both rows are equally wide.
        filename: if given, the figure is saved there and then closed.
    """
    if normalize_row:
        row_norms = np.linalg.norm(H, axis=1, keepdims=True)
        row_norms[row_norms == 0] = 1  # avoid 0/0 for empty rows
        H = H / row_norms  # new array: do not write into the caller's H

    H3d = H.reshape(-1, img_size[0], img_size[1], order='F')

    if split:
        top_count = -(-H3d.shape[0] // 2)  # ceiling: the first row takes the odd tile
        top_tiles = list(H3d[:top_count])
        bottom_tiles = list(H3d[top_count:])
        if len(bottom_tiles) < top_count:
            # np.vstack needs rows of equal width; fill the gap with a blank tile.
            bottom_tiles.append(np.full(H3d.shape[1:], np.nan))
        large_mat = np.vstack([np.hstack(top_tiles), np.hstack(bottom_tiles)])
    else:
        large_mat = np.hstack(list(H3d))

    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(large_mat, cmap='gray')
    # Colour bar in its own inset axes just to the right of the image.
    cax = ax.inset_axes([1.05, 0, 0.05, 1])
    fig.colorbar(im, cax=cax)

    fig.tight_layout()

    if filename:
        fig.savefig(filename)
        plt.close(fig)


def plot_W_mats(W, figsize, fontsize, split=False, filename=None, scale_y=False, log_scale=False, plot_title=None):
    """Plot every column of W as a line plot in its own panel.

    Args:
        W: 2-D array; column ``k`` is drawn in panel ``k`` and titled ``w_{k+1}``.
        figsize: figure size forwarded to ``plt.subplots``.
        fontsize: font size of the per-panel titles.
        split: lay the panels out on two rows instead of one. With an odd
            number of columns the first row gets the extra panel and the
            spare one is left blank, so no column is dropped.
        filename: if given, the figure is saved there and then closed.
        scale_y: give all panels the same y-range, ``[min(0, W.min()), W.max()]``.
        log_scale: use a logarithmic y-axis.
        plot_title: optional figure title, passed to ``fig.suptitle``.
    """
    rank = W.shape[1]
    # The shared limits are only needed (and only computed) when requested.
    y_limits = [min(0, np.min(W)), np.max(W)] if scale_y else None

    n_rows = 2 if split else 1
    n_cols = -(-rank // n_rows)  # ceiling division keeps every column on the grid
    # squeeze=False always yields a 2-D array of axes, even for a single panel.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    for cnt, ax in enumerate(axs.ravel()):
        if cnt >= rank:
            ax.axis('off')  # spare panel of an odd split layout
            continue
        ax.plot(W[:, cnt], linewidth=3)
        if y_limits is not None:
            ax.set_ylim(y_limits)
        if log_scale:
            ax.set_yscale('log')
        ax.set_title(f'$w_{{{cnt + 1}}}$', fontsize=fontsize)
        ax.set_xlabel('Bands')
        ax.set_ylabel('Reflectance')

    fig.tight_layout()
    if plot_title:
        fig.suptitle(plot_title, fontsize=25)
    if filename:
        fig.savefig(filename)
        plt.close(fig)


def merge_images(images_list, filename, delete_images=False):
    """Stack image files vertically into one image and save it.

    Every image wider than the narrowest one is scaled down to that width
    (keeping its aspect ratio); the images are then pasted top to bottom in
    the order given.

    Args:
        images_list: paths of the images to merge; must not be empty.
        filename: path the merged image is saved to.
        delete_images: if True, remove the input files after the merged
            image has been written.

    Raises:
        ValueError: if ``images_list`` is empty.
    """
    paths = list(images_list)
    if not paths:
        raise ValueError('images_list must contain at least one image path')

    opened = []
    try:
        for path in paths:
            opened.append(Image.open(path))

        min_img_width = min(img.width for img in opened)

        tiles = []
        for img in opened:
            if img.width > min_img_width:
                scaled_height = int(img.height / img.width * min_img_width)
                # Image.ANTIALIAS was an alias of LANCZOS and was removed in Pillow 10.
                img = img.resize((min_img_width, scaled_height), Image.LANCZOS)
            tiles.append(img)

        total_height = sum(tile.height for tile in tiles)
        img_merge = Image.new(tiles[0].mode, (min_img_width, total_height))

        y = 0
        for tile in tiles:
            img_merge.paste(tile, (0, y))
            y += tile.height

        img_merge.save(filename)
    finally:
        # Release the file handles before any deletion below.
        for img in opened:
            img.close()

    if delete_images:
        for fp in paths:
            os.remove(fp)

In [None]:
# RUN switches every experiment cell below: when True the cell fits the model and
# saves the result; when False it loads the saved result and writes the plots instead.
RUN = True

# Forwarded unchanged to nmf_son_new as early_stop, verbose and scale_reg.
EARLY_STOP = True
VERBOSE = False
SCALE_REG = True

In [None]:
# Iteration cap: passed to sklearn's NMF as max_iter and to nmf_son_new as itermax;
# it is also part of the saved-model and image file names.
max_iters = 10000

In [None]:
# Data matrix to factorise. np.load returns an NpzFile that holds the archive
# open until it is closed, so read the array inside a with-block.
with np.load('../../experimental/datasets/jasper_small_2.npz') as dataset:
    M = dataset['X']
m, n = M.shape

# (height, width) each row of H is reshaped to by the H plotting helpers.
# img_size = (50, 40)
img_size = (10, 10)
# Figure sizes handed to the W plots and to the H plots respectively.
w_plot_size = (32, 8)
h_plot_size = (32, 8)

In [None]:
# # dataset creation
# import matplotlib.patches as patches
#
# M = np.load('../../experimental/datasets/jasper_full.npz')['X']
# M3d = M.reshape(-1, 100, 100, order='F')
# img = M3d[80, :, :].copy()
#
# fig, ax = plt.subplots()
# ax.imshow(img, cmap='gray')
# # rect = patches.Rectangle((60, 0), 40, 50, linewidth=2, edgecolor='r', facecolor='none')
# rect = patches.Rectangle((30, 10), 10, 10, linewidth=2, edgecolor='r', facecolor='none')
# ax.add_patch(rect)
#
# # jasper_small_3d = M3d[:, :50, 60:]
# jasper_small_3d = M3d[:, 10: 20, 30: 40]
# fig, ax = plt.subplots()
# ax.imshow(jasper_small_3d[80], cmap='gray')
#
# jasper_small = jasper_small_3d.reshape(m, -1, order='F')
# with open('../../experimental/datasets/jasper_small_2.npz', 'wb') as fout:
#     np.savez_compressed(fout, X=jasper_small)

### vanilla NMF (r = 4)

In [None]:
r_true = 4
ini_filepath = f'../../experimental/saved_models/jasper_small/r{r_true}_ini.npz'
save_filepath = f'../../experimental/saved_models/jasper_small/vanilla_r{r_true}_mit{max_iters}.npz'
# NOTE(review): M is loaded from jasper_small_2 while these paths point at jasper_small --
# confirm the stored initial factors match M's shape.

if RUN:
    # Fit sklearn's NMF from the stored initial factors and save the result.
    # NpzFile keeps the archive open until closed, so read the arrays in a with-block.
    with np.load(ini_filepath) as data:
        ini_W = data['ini_W']
        ini_H = data['ini_H']

    model = NMF(n_components=r_true, init='custom', random_state=42, max_iter=max_iters)
    W = model.fit_transform(X=M, W=ini_W.copy(), H=ini_H.copy())
    H = model.components_
    with open(save_filepath, 'wb') as fout:
        np.savez_compressed(fout, W=W, H=H)
else:
    # Load the saved factors, write the three plots, then merge them into one image.
    with np.load(save_filepath) as data2:
        W = data2['W']
        H = data2['H']

    # Each intermediate image path is written once and shared by the plot call and the merge.
    w_png = f'../../experimental/images/jasper_small/w_vanilla_r{r_true}_mit{max_iters}.png'
    seph_png = f'../../experimental/images/jasper_small/seph_vanilla_r{r_true}_mit{max_iters}.png'
    combh_png = f'../../experimental/images/jasper_small/combh_vanilla_r{r_true}_mit{max_iters}.png'

    # The title follows r_true instead of repeating the number by hand.
    plot_W_mats(W, figsize=(16, 4), fontsize=15, scale_y=False, plot_title=f'vanilla nmf (r = {r_true})', filename=w_png)
    plot_separate_H(H, img_size, figsize=(16, 4), fontsize=15, normalize_row=False, split=False, filename=seph_png)
    plot_combined_H(H, img_size, figsize=(16, 4), normalize_row=False, split=False, filename=combh_png)
    merge_images([w_png, seph_png, combh_png], f'../../experimental/images/jasper_small/r{r_true}_vanilla.png', delete_images=True)

### vanilla NMF (r = 20)

In [None]:
# Rank used by the following cells: n_components of the vanilla NMF fit and
# part of the saved-model and image file names.
r = 20

# ini_filepath holds the initial factors (ini_W, ini_H); save_filepath is where
# the fitted vanilla NMF factors are written to and read from.
ini_filepath = f'../../experimental/saved_models/jasper_small/r{r}_ini.npz'
save_filepath = f'../../experimental/saved_models/jasper_small/vanilla_r{r}_mit{max_iters}.npz'

In [None]:
# Read the stored initial factors. np.load returns an NpzFile that keeps the
# archive open until closed, so the arrays are pulled out inside a with-block.
with np.load(ini_filepath) as data:
    ini_W = data['ini_W']
    ini_H = data['ini_H']

In [None]:
if RUN:
    # Fit sklearn's NMF from the initial factors loaded earlier and save the result.
    model = NMF(n_components=r, init='custom', random_state=42, max_iter=max_iters)
    W = model.fit_transform(X=M, W=ini_W.copy(), H=ini_H.copy())
    H = model.components_
    with open(save_filepath, 'wb') as fout:
        np.savez_compressed(fout, W=W, H=H)
else:
    # Load the saved factors, write the three plots, then merge them into one image.
    # NpzFile keeps the archive open until closed, so read the arrays in a with-block.
    with np.load(save_filepath) as data2:
        W = data2['W']
        H = data2['H']

    # Each intermediate image path is written once and shared by the plot call and the merge.
    w_png = f'../../experimental/images/jasper_small/w_vanilla_r{r}_mit{max_iters}.png'
    seph_png = f'../../experimental/images/jasper_small/seph_vanilla_r{r}_mit{max_iters}.png'
    combh_png = f'../../experimental/images/jasper_small/combh_vanilla_r{r}_mit{max_iters}.png'

    # The title follows r, like the file names, instead of repeating the number by hand.
    plot_W_mats(W, figsize=w_plot_size, fontsize=15, split=True, scale_y=False, plot_title=f'vanilla nmf (r = {r})', filename=w_png)
    plot_separate_H(H, img_size, figsize=h_plot_size, fontsize=15, normalize_row=False, split=True, filename=seph_png)
    plot_combined_H(H, img_size, figsize=h_plot_size, normalize_row=False, split=True, filename=combh_png)
    merge_images([w_png, seph_png, combh_png], f'../../experimental/images/jasper_small/r{r}_vanilla.png', delete_images=True)

### nmf-son with random initialization

In [None]:
# lambda_vals = [1e-7, 1e-6, 1e-5, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10, 50, 100, 500, 1000]
lambda_vals = [1500, 2000, 5000, 10000]

save_filepath = '../../experimental/saved_models/jasper_small/r{}_l{}_mit{}.npz'

# One pass per regularisation weight. With RUN set, nmf_son_new is started from the
# initial factors already in memory and its outputs are saved; otherwise the saved
# outputs are loaded, plotted, and the three factor plots are merged into one image.
for _lam in lambda_vals:
    results_path = save_filepath.format(r, _lam, max_iters)
    if not RUN:
        W, H, fscores, gscores, lvals = load_results(results_path)
        plot_scores(fscores, gscores, lvals, plot_title=_lam)

        # Intermediate images; merge_images removes them after merging.
        w_png = f'../../experimental/images/jasper_small/w_r{r}_l{_lam}_mit{max_iters}.png'
        seph_png = f'../../experimental/images/jasper_small/seph_r{r}_l{_lam}_mit{max_iters}.png'
        combh_png = f'../../experimental/images/jasper_small/combh_r{r}_l{_lam}_mit{max_iters}.png'

        plot_W_mats(W, figsize=w_plot_size, fontsize=15, split=True, scale_y=False, filename=w_png)
        plot_separate_H(H, img_size, figsize=h_plot_size, fontsize=15, normalize_row=False, split=True, filename=seph_png)
        plot_combined_H(H, img_size, figsize=h_plot_size, normalize_row=False, split=True, filename=combh_png)
        merge_images([w_png, seph_png, combh_png], f'../../experimental/images/jasper_small/random/r{r}_l{_lam}_mit{max_iters}_thres.png', delete_images=True)
    else:
        W, H, fscores, gscores, lvals = nmf_son_new(
            M,
            ini_W.copy(),
            ini_H.copy(),
            lam=_lam,
            itermax=max_iters,
            early_stop=EARLY_STOP,
            verbose=VERBOSE,
            scale_reg=SCALE_REG,
        )
        save_results(results_path, W, H, fscores, gscores, lvals)
    # Progress marker: the weight that was just processed.
    print(_lam)

### nmf-son with vanilla nmf initialization

In [None]:
# # lambda_vals = [1e-7, 1e-6, 1e-5, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10, 50, 100, 500, 1000]
#
# ini_filepath = f'../../experimental/saved_models/jasper_small/vanilla_r{r}_mit{max_iters}.npz'
# save_filepath = '../../experimental/saved_models/jasper_small/r{}_vl{}_mit{}.npz'
#
# data = np.load(ini_filepath)
# ini_W = data['W']
# ini_H = data['H']
#
# for _lam in lambda_vals:
#     if RUN:
#         W, H, fscores, gscores, lvals = nmf_son_new(M, ini_W.copy(), ini_H.copy(), lam=_lam, itermax=max_iters, early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)
#         save_results(save_filepath.format(r, _lam, max_iters), W, H, fscores, gscores, lvals)
#     else:
#         W, H, fscores, gscores, lvals = load_results(save_filepath.format(r, _lam, max_iters))
#         plot_scores(fscores, gscores, lvals, plot_title=_lam)
#         plot_W_mats(W, figsize=w_plot_size, fontsize=15, split=True, scale_y=False, filename=f'../../experimental/images/jasper_small/w_r{r}_vl{_lam}_mit{max_iters}.png')
#         plot_separate_H(H, img_size, figsize=h_plot_size, fontsize=15, normalize_row=False, split=True, filename=f'../../experimental/images/jasper_small/seph_r{r}_vl{_lam}_mit{max_iters}.png')
#         plot_combined_H(H, img_size, figsize=h_plot_size, normalize_row=False, split=True, filename=f'../../experimental/images/jasper_small/combh_r{r}_vl{_lam}_mit{max_iters}.png')
#         merge_images([f'../../experimental/images/jasper_small/w_r{r}_vl{_lam}_mit{max_iters}.png', f'../../experimental/images/jasper_small/seph_r{r}_vl{_lam}_mit{max_iters}.png', f'../../experimental/images/jasper_small/combh_r{r}_vl{_lam}_mit{max_iters}.png'], f'../../experimental/images/jasper_small/vanilla/r{r}_vl{_lam}_mit{max_iters}.png', delete_images=True)
#     print(_lam)

### nmf-son with r = n (rank equal to the number of columns of M)

In [None]:
# Set the rank to n, the second dimension of M (from `m, n = M.shape`).
r = n
# Regularisation weights for the sweep below; each is passed to nmf_son_new as lam.
lambda_vals = [0.001, 0.01, 0.1, 1, 10, 100, 1000, 10000]

# ini_filepath stores the initial factors for this rank; save_filepath is a
# template filled in with (r, lambda, max_iters) for every run of the sweep.
ini_filepath = f'../../experimental/saved_models/jasper_small_2/r{r}_ini.npz'
save_filepath = '../../experimental/saved_models/jasper_small_2/r{}_l{}_mit{}.npz'

In [None]:
# Reuse the stored initial factors when they exist so repeated runs start from
# the same point; otherwise draw new ones and store them.
if os.path.exists(ini_filepath):
    # NpzFile keeps the archive open until closed, so read the arrays in a with-block.
    with np.load(ini_filepath) as data:
        ini_W = data['ini_W']
        ini_H = data['ini_H']
else:
    ini_W = np.random.rand(m, r)
    ini_H = np.random.rand(r, n)
    # open(..., 'wb') does not create missing directories, so create the target first.
    ini_dir = os.path.dirname(ini_filepath)
    if ini_dir:
        os.makedirs(ini_dir, exist_ok=True)
    with open(ini_filepath, 'wb') as fout:
        np.savez_compressed(fout, ini_W=ini_W, ini_H=ini_H)

In [None]:
# One pass per regularisation weight. With RUN set, nmf_son_new is started from the
# initial factors already in memory and its outputs are saved; otherwise the saved
# outputs are loaded, plotted, and the three factor plots are merged into one image.
for _lam in lambda_vals:
    results_path = save_filepath.format(r, _lam, max_iters)
    if not RUN:
        W, H, fscores, gscores, lvals = load_results(results_path)
        plot_scores(fscores, gscores, lvals, plot_title=_lam)

        # Intermediate images; merge_images removes them after merging.
        w_png = f'../../experimental/images/jasper_small_2/w_r{r}_l{_lam}_mit{max_iters}.png'
        seph_png = f'../../experimental/images/jasper_small_2/seph_r{r}_l{_lam}_mit{max_iters}.png'
        combh_png = f'../../experimental/images/jasper_small_2/combh_r{r}_l{_lam}_mit{max_iters}.png'

        plot_W_mats(W, figsize=w_plot_size, fontsize=15, split=True, scale_y=False, filename=w_png)
        plot_separate_H(H, img_size, figsize=h_plot_size, fontsize=15, normalize_row=False, split=True, filename=seph_png)
        plot_combined_H(H, img_size, figsize=h_plot_size, normalize_row=False, split=True, filename=combh_png)
        merge_images([w_png, seph_png, combh_png], f'../../experimental/images/jasper_small_2/random/r{r}_l{_lam}_mit{max_iters}_thres.png', delete_images=True)
    else:
        W, H, fscores, gscores, lvals = nmf_son_new(
            M,
            ini_W.copy(),
            ini_H.copy(),
            lam=_lam,
            itermax=max_iters,
            early_stop=EARLY_STOP,
            verbose=VERBOSE,
            scale_reg=SCALE_REG,
        )
        save_results(results_path, W, H, fscores, gscores, lvals)
    # Progress marker: the weight that was just processed.
    print(_lam)