In [None]:
import os
import numpy as np
from nmf_methods.nmf_son.new import new as nmf_son_new
from nmf_methods.nmf_son.utils import save_results, load_results, plot_scores, plot_W_mats, plot_separate_H, plot_combined_H, merge_images, plot_and_merge
from sklearn.decomposition import NMF

# Print arrays with 3 decimals, and seed NumPy's global RNG so any freshly drawn
# initial factors (np.random.rand below) are reproducible across kernel restarts.
np.set_printoptions(precision=3)
np.random.seed(seed=42)

In [None]:
# RUN=True: fit every model and save it. RUN=False: load the saved results and plot them.
RUN = True

# Options forwarded to the NMF-SON solver.
EARLY_STOP, VERBOSE, SCALE_REG = True, False, True

In [None]:
max_iters = 10_000  # iteration cap shared by every solver below; also embedded in saved filenames

In [None]:
# Provenance (kept for reference, not executed): how jasper_small_3.npz was cut out of the full Jasper cube.
# NOTE(review): as written, jasper_small_3d has shape (k, 10, 10) and reshape(-1, 10, 10) keeps it 3-D,
# but the loading cell unpacks `m, n = M.shape`, which requires a 2-D matrix. Presumably the saved file
# was produced with reshape(-1, 100, order='F') — confirm before regenerating the dataset.
# jasper_full = np.load('../../experimental/datasets/jasper_full.npz')['X']
# jasper_3d = jasper_full.reshape(-1, 100, 100, order='F')
# # jasper_3d[:, 10: 20, 30: 40] # jasper small 2
# # jasper_3d[:, 10: 20, 19: 29] # jasper small 3
# jasper_small_3d = jasper_3d[:, 10: 20, 19: 29] # jasper small 3
# jasper_small = jasper_small_3d.reshape(-1, 10, 10, order='F')
# with open('../../experimental/datasets/jasper_small_3.npz', 'wb') as fp:
#     np.savez_compressed(fp, X=jasper_small)

In [None]:
# import matplotlib.pyplot as plt
# import matplotlib.patches as patches
#
# # Plot the matrix with imshow
# img = plt.imshow(jasper_3d[80, :, :], cmap='gray')
#
# # Add a rectangular box around the region of interest
# # rect = patches.Rectangle((60, 0), 40, 50, linewidth=2, edgecolor='r', facecolor='none') # jasper_small
# # rect = patches.Rectangle((30, 10), 10, 10, linewidth=2, edgecolor='r', facecolor='none') # jasper_small_2
# rect = patches.Rectangle((19, 10), 10, 10, linewidth=2, edgecolor='r', facecolor='none') # jasper_small_3
# plt.gca().add_patch(rect)
# plt.tight_layout()
#
# # Show the plot
# plt.savefig('../../experimental/images/jasper_small_3/outlined_region.png')

In [None]:
# Load the cropped Jasper data matrix. Using `with` closes the underlying .npz file handle
# once the array has been read into memory (np.load on .npz returns a lazily-reading NpzFile).
# NOTE(review): M must be 2-D for the unpacking below — presumably (bands x pixels); confirm.
with np.load('../../experimental/datasets/jasper_small_3.npz') as jasper_npz:
    M = jasper_npz['X']
m, n = M.shape

### vanilla NMF (r = 2)

In [None]:
# Settings for the rank-2 vanilla NMF baseline: plot layout plus cache file locations.
r_true = 2
imgsize, figsize = (10, 10), (16, 8)
fontsize, num_rows = 10, 1

ini_filepath = '../../experimental/saved_models/jasper_small_3/' + f'r{r_true}_ini.npz'
save_filepath = '../../experimental/saved_models/jasper_small_3/' + f'vanilla_r{r_true}_mit{max_iters}.npz'

In [None]:
# Reuse cached random initial factors for rank r_true when present, so every run starts from
# the same point. Otherwise draw new ones from the seeded global RNG and cache them.
if os.path.exists(ini_filepath):
    # `with` closes the .npz handle once both arrays have been read into memory.
    with np.load(ini_filepath) as data:
        ini_W = data['ini_W']
        ini_H = data['ini_H']
else:
    ini_W = np.random.rand(m, r_true)
    ini_H = np.random.rand(r_true, n)
    # Create the cache directory on first use; open(..., 'wb') would otherwise raise FileNotFoundError.
    os.makedirs(os.path.dirname(ini_filepath) or '.', exist_ok=True)
    with open(ini_filepath, 'wb') as fout:
        np.savez_compressed(fout, ini_W=ini_W, ini_H=ini_H)

# Guard against a stale cache (e.g. the dataset was regenerated with a different size).
assert ini_W.shape == (m, r_true) and ini_H.shape == (r_true, n), (
    f'cached initial factors {ini_W.shape}, {ini_H.shape} do not match M {(m, n)} with rank {r_true}'
)

In [None]:
# RUN=True: fit scikit-learn NMF (rank r_true) from the cached initial factors and save W, H.
# RUN=False: load the previously saved W, H and render/merge the figures.
if RUN:
    model = NMF(n_components=r_true, init='custom', random_state=42, max_iter=max_iters)
    # Defensive copies so the cached initial factors stay pristine for later cells.
    W = model.fit_transform(X=M, W=ini_W.copy(), H=ini_H.copy())
    H = model.components_
    # Create the output directory on first use; open(..., 'wb') would otherwise fail.
    os.makedirs(os.path.dirname(save_filepath) or '.', exist_ok=True)
    with open(save_filepath, 'wb') as fout:
        np.savez_compressed(fout, W=W, H=H)
else:
    # `with` closes the .npz handle once W and H are in memory.
    with np.load(save_filepath) as data2:
        W = data2['W']
        H = data2['H']
    # NOTE(review): the merged image name omits `mit{max_iters}`, unlike the other three — intentional?
    img_filenames = [
        f'../../experimental/images/jasper_small_3/w_vanilla_r{r_true}_mit{max_iters}.png',
        f'../../experimental/images/jasper_small_3/seph_vanilla_r{r_true}_mit{max_iters}.png',
        f'../../experimental/images/jasper_small_3/combh_vanilla_r{r_true}_mit{max_iters}.png',
        f'../../experimental/images/jasper_small_3/r{r_true}_vanilla.png',
    ]
    # delete=True presumably removes the intermediate images after merging — confirm in utils.
    plot_and_merge(W, H, imgsize=imgsize, figsize=figsize, fontsize=fontsize, filenames=img_filenames, num_rows=num_rows, delete=True)

### vanilla NMF (r = n)

In [None]:
# Settings for the full-rank (r = n) vanilla NMF run. imgsize and fontsize carry over from the r = 2 settings.
r = n
figsize, num_rows = (32, 32), 2

ini_filepath = '../../experimental/saved_models/jasper_small_3/' + f'r{r}_ini.npz'
save_filepath = '../../experimental/saved_models/jasper_small_3/' + f'vanilla_r{r}_mit{max_iters}.npz'

In [None]:
# Reuse cached random initial factors for rank r when present, so every run starts from
# the same point. Otherwise draw new ones from the seeded global RNG and cache them.
if os.path.exists(ini_filepath):
    # `with` closes the .npz handle once both arrays have been read into memory.
    with np.load(ini_filepath) as data:
        ini_W = data['ini_W']
        ini_H = data['ini_H']
else:
    ini_W = np.random.rand(m, r)
    ini_H = np.random.rand(r, n)
    # Create the cache directory on first use; open(..., 'wb') would otherwise raise FileNotFoundError.
    os.makedirs(os.path.dirname(ini_filepath) or '.', exist_ok=True)
    with open(ini_filepath, 'wb') as fout:
        np.savez_compressed(fout, ini_W=ini_W, ini_H=ini_H)

# Guard against a stale cache (e.g. the dataset was regenerated with a different size).
assert ini_W.shape == (m, r) and ini_H.shape == (r, n), (
    f'cached initial factors {ini_W.shape}, {ini_H.shape} do not match M {(m, n)} with rank {r}'
)

In [None]:
# RUN=True: fit scikit-learn NMF (rank r) from the cached initial factors and save W, H.
# RUN=False: load the previously saved W, H and render/merge the figures.
if RUN:
    model = NMF(n_components=r, init='custom', random_state=42, max_iter=max_iters)
    # Defensive copies so the cached initial factors stay pristine for the NMF-SON sweep.
    W = model.fit_transform(X=M, W=ini_W.copy(), H=ini_H.copy())
    H = model.components_
    # Create the output directory on first use; open(..., 'wb') would otherwise fail.
    os.makedirs(os.path.dirname(save_filepath) or '.', exist_ok=True)
    with open(save_filepath, 'wb') as fout:
        np.savez_compressed(fout, W=W, H=H)
else:
    # `with` closes the .npz handle once W and H are in memory.
    with np.load(save_filepath) as data2:
        W = data2['W']
        H = data2['H']
    # NOTE(review): the merged image name omits `mit{max_iters}`, unlike the other three — intentional?
    img_filenames = [
        f'../../experimental/images/jasper_small_3/w_vanilla_r{r}_mit{max_iters}.png',
        f'../../experimental/images/jasper_small_3/seph_vanilla_r{r}_mit{max_iters}.png',
        f'../../experimental/images/jasper_small_3/combh_vanilla_r{r}_mit{max_iters}.png',
        f'../../experimental/images/jasper_small_3/r{r}_vanilla.png',
    ]
    # delete=True presumably removes the intermediate images after merging — confirm in utils.
    plot_and_merge(W, H, imgsize=imgsize, figsize=figsize, fontsize=fontsize, filenames=img_filenames, num_rows=num_rows, delete=True)

### NMF-SON with random initialization (r = n, reusing the r = n initial factors)

In [None]:
# Sweep the NMF-SON regularization weight lambda, starting every run from the same rank-r
# random initial factors (r = n, set in the full-rank settings cell).
# RUN=True fits and saves each model; RUN=False reloads each model and plots it.
lambda_vals = [0.0001, 0.01, 0.1, 1, 10, 100, 1000]
# Filename template, filled with (rank, lambda, max_iters) for each run.
save_filepath = '../../experimental/saved_models/jasper_small_3/r{}_l{}_mit{}.npz'

# This cell reuses r, ini_W and ini_H from earlier cells. Fail fast if they are out of sync
# (e.g. cells were run out of order and the rank-2 factors are still in scope); otherwise the
# results would be fitted with the wrong rank and saved under a mislabelled filename.
assert ini_W.shape == (m, r) and ini_H.shape == (r, n), (
    f'initial factors {ini_W.shape}, {ini_H.shape} do not match rank r={r} for M {(m, n)}'
)

for _lam in lambda_vals:
    model_path = save_filepath.format(r, _lam, max_iters)
    if RUN:
        # Copies keep the shared initial factors unchanged between lambda values.
        W, H, fscores, gscores, lvals = nmf_son_new(M, ini_W.copy(), ini_H.copy(), lam=_lam, itermax=max_iters,
                                                    early_stop=EARLY_STOP, verbose=VERBOSE, scale_reg=SCALE_REG)
        save_results(model_path, W, H, fscores, gscores, lvals)
    else:
        W, H, fscores, gscores, lvals = load_results(model_path)
        plot_scores(fscores, gscores, lvals, plot_title=_lam)

        img_filenames = [
            f'../../experimental/images/jasper_small_3/w_r{r}_l{_lam}_mit{max_iters}.png',
            f'../../experimental/images/jasper_small_3/seph_r{r}_l{_lam}_mit{max_iters}.png',
            f'../../experimental/images/jasper_small_3/combh_r{r}_l{_lam}_mit{max_iters}.png',
            f'../../experimental/images/jasper_small_3/r{r}_l{_lam}_mit{max_iters}.png',
        ]
        plot_and_merge(W, H, imgsize=imgsize, figsize=figsize, fontsize=fontsize, filenames=img_filenames, num_rows=num_rows, delete=True)
    # Progress marker: lambda value just finished.
    print(_lam)