In [None]:
import math
import math
import os

import numpy as np
import scipy.io
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from sklearn.decomposition import NMF

from simple import nmf_son


# Session-wide settings: print arrays with 3 decimals and fix the global NumPy RNG seed
# so the random NMF initialization further down is reproducible.
np.set_printoptions(precision=3)
np.random.seed(seed=42)

In [None]:
def normalized_similarity(W_ins):
    """Pairwise Euclidean distances between the columns of ``W_ins``, each row scaled to sum to 1.

    Despite the name, the entries are *distances*, not similarities. The diagonal is 0, and a
    larger value means column j is farther from column i.

    Parameters
    ----------
    W_ins : array_like, shape (m, r)
        Matrix whose r columns are compared with each other.

    Returns
    -------
    numpy.ndarray, shape (r, r)
        ``res[i, j] = ||W[:, i] - W[:, j]|| / sum_k ||W[:, i] - W[:, k]||``.
        If every column equals column i (e.g. r == 1), that row's sum is 0. The row is then
        returned as all zeros instead of NaN.
    """
    W = np.asarray(W_ins)
    # (m, r, 1) - (m, 1, r) -> (m, r, r); the norm over axis 0 yields every pairwise
    # column distance at once, replacing the O(r^2) Python double loop.
    dists = np.linalg.norm(W[:, :, None] - W[:, None, :], axis=0)
    row_sums = dists.sum(axis=1, keepdims=True)
    # Only divide where the row sum is positive, to avoid 0/0 -> NaN on degenerate rows.
    return np.divide(dists, row_sums, out=np.zeros_like(dists), where=row_sums > 0)

def plot_scores(fscores, gscores, lambda_vals):
    """Plot the f term, g term and total objective on three side-by-side panels.

    Panels: (1) linear scale, (2) log scale, (3) log scale of the gap to the final value.
    The total is computed as ``fscores + lambda_vals * gscores``.

    Parameters
    ----------
    fscores, gscores : array_like, shape (T,)
        Score histories, presumably one entry per recorded iteration (TODO confirm in nmf_son).
    lambda_vals : float or array_like broadcastable to shape (T,)
        Weight(s) applied to ``gscores`` in the total.
    """
    def plot_ax(ax, f, g, total, title):
        # Draw the three curves on one axis, total emphasized in thick black.
        ax.plot(total, color='black', linewidth=3, label='total')
        ax.plot(f, label='f')
        ax.plot(g, label='g')
        ax.set(title=title, xlabel='iteration', ylabel='score')
        ax.legend()

    fscores = np.asarray(fscores, dtype=float)
    gscores = np.asarray(gscores, dtype=float)
    total_score = fscores + np.asarray(lambda_vals, dtype=float) * gscores
    fig, axs = plt.subplots(1, 3, figsize=(20, 6))

    plot_ax(axs[0], fscores, gscores, total_score, 'scores')

    axs[1].set_yscale('log')
    plot_ax(axs[1], fscores, gscores, total_score, 'scores (log scale)')

    # Gap to the final value. The last entry is exactly 0 and cannot be drawn on a log axis
    # (the line would plunge off the plot), so it is dropped. x positions still match the
    # iteration index. Any negative gaps (non-monotone runs) are still clipped by the log scale.
    axs[2].set_yscale('log')
    plot_ax(axs[2],
            (fscores - fscores[-1])[:-1],
            (gscores - gscores[-1])[:-1],
            (total_score - total_score[-1])[:-1],
            'gap to final value (log scale)')

In [None]:
# Load the Urban data matrix X. Its columns are pixels (axis 1 is reshaped into a 2-D grid
# below). Its rows are indexed as bands further down (`wavelength` indexes axis 0 of X3d).
mat = scipy.io.loadmat('urban/Urban.mat')
X = mat['X']

m, n = X.shape  # (162, 94249)
num_col = int(round(math.sqrt(n)))  # 307; round() rather than truncation guards float sqrt error
# The reshape below assumes a square image; fail with a clear message instead of a reshape error.
assert num_col * num_col == n, f'expected a square image, but n={n} is not a perfect square'

# Column-major reshape: X3d[b, i, j] == X[b, i + j * num_col].
X3d = X.reshape(m, num_col, num_col, order='F')

In [None]:
# Preview one slice of the cube (index along axis 0 of X3d) as a grayscale image.
wavelength = 80  # an index into axis 0 of X3d, not a physical wavelength value
plt.imshow(X3d[wavelength, :, :], cmap='gray')
plt.title(f'X3d[{wavelength}, :, :]')
plt.colorbar();  # trailing ';' suppresses the Colorbar object repr in the cell output

In [None]:
# img = X3d[wavelength, :, :].copy()
# img[80: 120, 190: 230] = 1000 # roof
# plt.imshow(img, cmap='gray')
# plt.colorbar()

In [None]:
# medX3d = X3d[:, 80: 120, 190: 230] # (162, 40, 40)
# medX = medX3d.reshape(m, -1, order='F') # (162, 1600)
# with open(f'urban/2022_11_19/med/X.npz', 'wb') as fout:
#     np.savez_compressed(fout, X=medX)

In [None]:
# Draw a random initialization (W_ini, H_ini) with entries in [0, 1) for rank-`rank` NMF
# and save it, so later runs can start from exactly the same point.
m, n = X.shape
rank = 10

# Re-seed (same seed as the setup cell) so that re-running this cell on its own reproduces
# the same initialization. On a fresh Restart & Run All the values are unchanged, because
# none of the cells above draw from np.random.
np.random.seed(42)
W_ini = np.random.rand(m, rank)
H_ini = np.random.rand(rank, n)

ini_filepath = f'urban/2022_11_19/full/r{rank}_ini.npz'
# The output directory may not exist yet; the checkpoint cell below writes to it too.
os.makedirs(os.path.dirname(ini_filepath), exist_ok=True)
with open(ini_filepath, 'wb') as fout:
    np.savez_compressed(fout, W=W_ini, H=H_ini)

In [None]:
# Run nmf_son from (W_ini, H_ini) in segments, saving a checkpoint each time the cumulative
# iteration count reaches an entry of `it_ckpts`. Each segment resumes from the previous
# segment's last iterate (Wl, Hl), and the score histories are stitched together.
it_ckpts = [100, 200, 500, 1000, 2000]
save_filepath = 'urban/2022_11_19/full/r{}-l{}-it{}.npz'
reg_val = 10

# Each segment runs `ckpt - prev_ckpt` iterations, so checkpoints must be positive and increasing.
assert all(prev < cur for prev, cur in zip([0] + it_ckpts, it_ckpts)), \
    'it_ckpts must be positive and strictly increasing'

Wl, Hl = W_ini, H_ini
fscores = gscores = lambda_vals = None
prev_ckpt = 0
for ckpt in it_ckpts:
    # Pass copies, as the original first call did, so nmf_son cannot modify W_ini/H_ini in place.
    Wb, Hb, Wl, Hl, new_fscores, new_gscores, new_lambda_vals = nmf_son(
        X, Wl.copy(), Hl.copy(), _lambda=reg_val, itermax=ckpt - prev_ckpt, scale_lambda=True)
    # NOTE(review): scale_lambda=True is re-applied at every resume. Also, if Wb/Hb are the best
    # iterates seen, they are only the best of this segment, not of the whole run — confirm in nmf_son.

    if fscores is None:
        fscores, gscores, lambda_vals = new_fscores, new_gscores, new_lambda_vals
    else:
        # Drop the previous segment's last entry before appending. Presumably the resumed run
        # records the score of its starting point, which duplicates it — TODO confirm in nmf_son.
        fscores = np.append(fscores[:-1], new_fscores)
        gscores = np.append(gscores[:-1], new_gscores)
        lambda_vals = np.append(lambda_vals[:-1], new_lambda_vals)

    with open(save_filepath.format(rank, reg_val, ckpt), 'wb') as fout:
        np.savez_compressed(fout, Wb=Wb, Hb=Hb, Wl=Wl, Hl=Hl, fscores=fscores, gscores=gscores, lambda_vals=lambda_vals)
    print(rank, reg_val, ckpt, 'complete')
    prev_ckpt = ckpt

In [None]:
# Convergence curves of the stitched score histories from the checkpointed run.
plot_scores(fscores=fscores, gscores=gscores, lambda_vals=lambda_vals)

In [None]:
# Row-normalized pairwise distances between the columns of the last W iterate.
normalized_similarity(W_ins=Wl)