# NMF on a toy dataset

In [None]:
# IPython setup: enable the autoreload extension so edited modules are
# re-imported before each cell runs, and render matplotlib figures inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

## Imports and function definitions

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from espm.datasets.base import generate_spim
from espm.estimators import SmoothNMF
from espm.measures import find_min_angle, find_min_MSE, ordered_mse, ordered_mae, ordered_r2
from espm.datasets.toy import syntheticG, load_toy_images, create_toy_problem


In [None]:
# Sizes handed to the toy-problem generators below.
# L is also the number of samples on the spectral axis used for plotting.
L = 200
C = 15
# NOTE(review): P is not used in the visible cells -- presumably the pixel
# count of a 100 x 100 image; confirm before relying on it.
P = 100 * 100
seed = 0

# Average Poisson count per pixel (this total is split along the L dimension).
n_poisson = 600


In [None]:


def plot_results(Ddot, D, Hdotflat, Hflat):
    """Plot true vs. reconstructed maps and spectra, and print the metrics.

    The figure has three rows and one column per true component:
    row 0 shows the true maps, row 1 the matched reconstructed maps and
    row 2 the true and reconstructed spectra overlaid.

    Parameters
    ----------
    Ddot : array
        True spectra, one component per column (``Ddot.T[i]`` is plotted).
    D : array
        Reconstructed spectra, one component per column; ``len(D)`` sets the
        number of samples on the spectral axis.
    Hdotflat : array
        True maps, one flattened map per row.
    Hflat : array
        Reconstructed maps, one flattened map per row.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the comparison grid.

    Notes
    -----
    Reads the notebook-level ``shape_2d`` to reshape flattened maps into
    images, so that name must be defined before calling this function.
    """
    fontsize = 30
    scale = 15
    aspect_ratio = 1.4
    cmap = plt.cm.gray_r
    vmin, vmax = 0, 1
    n_rows = 3  # true maps / reconstructed maps / spectra

    # Number of components comes from the arguments, not from a notebook
    # global, so the function works for any input it is given.
    K = len(Hdotflat)
    L = len(D)

    # Match every true component with its closest reconstructed one.
    angles, true_inds = find_min_angle(Ddot.T, D.T, unique=True, get_ind=True)
    mse = ordered_mse(Hdotflat, Hflat, true_inds)
    mae = ordered_mae(Hdotflat, Hflat, true_inds)
    r2 = ordered_r2(Hdotflat, Hflat, true_inds)

    # The grid is indexed as axes[row, component], hence n_rows x K.
    # squeeze=False keeps `axes` two-dimensional even when K == 1.
    fig, axes = plt.subplots(
        n_rows,
        K,
        figsize=(scale / n_rows * K * aspect_ratio, scale),
        squeeze=False,
    )
    x = np.linspace(0, 1, num=L)
    hidden_ticks = dict(
        axis="both", labelleft=False, labelbottom=False, left=False, bottom=False
    )

    for i in range(K):
        matched = true_inds[i]

        axes[2, i].plot(x, Ddot.T[i, :], 'bo', label='truth', linewidth=4)
        axes[2, i].plot(x, D[:, matched], 'r-', label='reconstructed', markersize=3.5)
        axes[2, i].set_title("{:.2f} deg".format(angles[i]), fontsize=fontsize - 2)
        axes[2, i].set_xlim(0, 1)

        axes[1, i].imshow(
            Hflat[matched, :].reshape(shape_2d), vmin=vmin, vmax=vmax, cmap=cmap
        )
        # NOTE(review): r2 is indexed by the matched index here while angles
        # is indexed by i -- confirm the ordering returned by ordered_r2.
        axes[1, i].set_title("R2: {:.2f}".format(r2[matched]), fontsize=fontsize - 2)
        axes[1, i].tick_params(**hidden_ticks)

        im = axes[0, i].imshow(
            Hdotflat[i].reshape(shape_2d), vmin=vmin, vmax=vmax, cmap=cmap
        )
        axes[0, i].set_title("Phase {}".format(i), fontsize=fontsize)
        axes[0, i].tick_params(**hidden_ticks)

    # One legend on the first spectra panel is enough.
    axes[2, 0].legend()

    rows = ["True maps", "Reconstructed maps", "Spectra"]
    for ax, row in zip(axes[:, 0], rows):
        ax.set_ylabel(row, rotation=90, fontsize=fontsize)

    # Leave room on the right and place the colorbar there.
    fig.subplots_adjust(right=0.84)
    cbar_ax = fig.add_axes([0.85, 0.5, 0.01, 0.3])
    fig.colorbar(im, cax=cbar_ax)

    print("angles : ", angles)
    print("mse : ", mse)
    print("mae : ", mae)
    print("r2 : ", r2)

    return fig

    
        

## Create a synthetic problem

In [None]:
# Generate the synthetic matrix G and plot its first three columns.
G = syntheticG(L, C, seed)
# Abscissa with exactly L samples in [0, 1). linspace is used instead of
# arange with a float step, whose length can be off by one due to rounding
# and would then no longer match the number of rows of G.
l = np.linspace(0, 1, num=L, endpoint=False)
plt.plot(l, G[:, :3])
plt.title("Spectral response of each elements")


In [None]:

# Load the toy maps and show them side by side on a shared grey scale.
Hdot = load_toy_images()
vmin, vmax = 0, 1
cmap = plt.cm.gray_r
plt.figure(figsize=(10, 3))
for i, hdot in enumerate(Hdot):
    # One column per map, so the cell keeps working if the number of toy
    # images is not exactly three.
    plt.subplot(1, len(Hdot), i + 1)
    plt.imshow(hdot, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.axis("off")
    plt.title(f"Map {i+1}")


In [None]:
# Build the full toy problem from L, C and n_poisson. This rebinds G and Hdot
# from the previous cells; shape_2d and K are reused below for the estimator
# parameters and by plot_results.
# NOTE(review): `seed` is not passed here -- confirm whether
# create_toy_problem accepts one if reproducibility is required.
G, Wdot, Ddot, Hdot, Hdotflat, Ydot, Y, shape_2d, K = create_toy_problem(L, C, n_poisson)

## Solve the problem

In [None]:

# Settings passed by keyword to the estimator; G is left unset (None).
mu = 0
lambda_L = 0
force_simplex = False
Gused = None

# Remaining estimator options, collected in one dict and expanded below.
params = {
    "tol": 1e-6,
    "max_iter": 50,
    "hspy_comp": False,
    "verbose": 1,
    "epsilon_reg": 1,
    "linesearch": False,
    "shape_2d": shape_2d,
    "n_components": K,
}

estimator = SmoothNMF(
    G=Gused,
    mu=mu,
    lambda_L=lambda_L,
    force_simplex=force_simplex,
    **params,
)

# Fit on the noisy data Y, then read the fitted attributes off the estimator.
D = estimator.fit_transform(Y)
W = estimator.W_
Hflat = estimator.H_
# One 2D image per row of Hflat.
H = Hflat.reshape((Hflat.shape[0], *shape_2d))

In [None]:
        
# Match true and estimated spectra, then report the angles between them.
angles, true_inds = find_min_angle(Ddot.T, D.T, get_ind=True, unique=True)
print("angles : ", angles)

# Error on the maps, using the matching found above.
mse = ordered_mse(Hdotflat, Hflat, true_inds)
print("mse : ", mse)

In [None]:
# Draw the truth/reconstruction comparison figure (this also prints the
# angles, mse, mae and r2 values) and display it.
fig = plot_results(Ddot, D, Hdotflat, Hflat)
plt.show()


In [None]:
# Histograms (100 bins) of Ydot and Y, both scaled by n_poisson.
# The trailing semicolons suppress the notebook's echo of the return values.
plt.figure(figsize=(10, 4))
plt.subplot(1,2,1)
plt.hist(Ydot.flatten()*n_poisson, 100);
# NOTE(review): this selects subplot (1,2,1) again, so the second histogram
# targets the same panel position as the first and the right half of the
# figure stays unused. alpha=0.5 suggests an overlay was intended; if two
# panels were meant, this should be plt.subplot(1,2,2) -- confirm.
plt.subplot(1,2,1)
plt.hist(Y.flatten()*n_poisson, 100, alpha=0.5);