In [None]:
# Reload edited modules (e.g. espm) automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from espm.datasets.base import generate_spim
from espm.estimators import SmoothNMF
from espm.measures import find_min_angle, find_min_MSE, ordered_mse, ordered_mae, ordered_r2, KLdiv


In [None]:
# Problem size and noise configuration.
C = 15      # number of columns of the synthetic dictionary G
L = 200     # number of samples along the spectral axis (rows of G)
nx = 50     # side length of the square ground-truth maps (nx x nx pixels)
P = nx**2   # number of pixels, kept consistent with nx
seed = 1    # random seed used when generating the synthetic data

n_poisson = 2 # Average Poisson count per pixel (this number is split over the L dimension)


In [None]:
def syntheticG(L=200, C=15, seed=None):
    """Generate a random non-negative dictionary of Gaussian-peak spectra.

    A pool of ``n_el = 45`` Gaussian peaks is drawn first. The peaks have
    centres uniform in [0, 1) and widths of ``1 / n_el`` plus a small positive
    perturbation. Each of the ``C`` columns is then a weighted sum of 2 to 4
    distinct peaks from that pool, with weights in [0.1, 1).

    Parameters
    ----------
    L : int
        Number of samples along the spectral axis (rows of the output).
    C : int
        Number of spectra (columns of the output).
    seed : int or None
        Passed to ``np.random.seed``. This reseeds NumPy's *global* random
        state, so later draws from ``np.random`` depend on it.

    Returns
    -------
    G : np.ndarray of shape (L, C)
    """
    np.random.seed(seed=seed)
    n_el = 45
    n_gauss = np.random.randint(2, 5, [C])
    # Exactly L evenly spaced samples in [0, 1). np.arange with a float step may
    # return L + 1 points because of rounding, which would break the
    # column assignment into G below.
    l = np.linspace(0, 1, num=L, endpoint=False)
    mu_gauss = np.random.rand(n_el)
    sigma_gauss = 1/n_el + np.abs(np.random.randn(n_el))/n_el/5

    G = np.zeros([L, C])

    def gauss(x, mu, sigma):
        # Unnormalised Gaussian: its peak height is 1 whatever sigma is.
        return np.exp(-(x - mu)**2 / (2 * sigma**2))

    for i, c in enumerate(n_gauss):
        inds = np.random.choice(n_el, size=[c], replace=False)
        for ind in inds:
            w = 0.1 + 0.9 * np.random.rand()
            G[:, i] += w * gauss(l, mu_gauss[ind], sigma_gauss[ind])
    return G

def build_simple_images(nx = 40):
    """Build two complementary nx-by-nx ground-truth phase maps.

    ``im1`` contains four patterns:
    - a block of ones in the top-left;
    - a smaller square at 0.5 in the top-right;
    - a ramp increasing row by row in the bottom-left;
    - a disk of ones in the bottom-right.

    ``im0 = 1 - im1``, so the two maps sum to one at every pixel.

    Parameters
    ----------
    nx : int
        Side length of the maps, in pixels.

    Returns
    -------
    Hdot : np.ndarray of shape (2, nx, nx)
        The stacked maps ``[im0, im1]``.
    """
    im1 = np.zeros(shape=[nx, nx])
    # 1st quarter: block of ones (top-left).
    im1[:nx//2, :nx//2] = 1
    # 2nd quarter: small square at half intensity (top-right).
    im1[nx//6:nx//3, -nx//3:-nx//6] = 0.5
    # 3rd quarter: ramp 0, 3/nx, 2*3/nx, ... down the rows (bottom-left).
    # The ramp length is taken from the slice height itself. np.arange(0, 1, 3/nx)
    # yields the same values, but its length depends on float rounding and
    # can disagree with the slice.
    n_ramp_rows = im1[-nx//3:].shape[0]
    ramp = np.arange(n_ramp_rows) * (3 / nx)
    im1[-nx//3:, :nx//2] = ramp[:, np.newaxis]
    # 4th quarter: disk of ones (bottom-right).
    x, y = np.meshgrid(np.arange(nx), np.arange(nx))
    mask = (x - nx + nx//4)**2 + (y - nx + nx//4)**2 < (nx//6)**2
    im1[mask] = 1

    im0 = 1 - im1

    Hdot = np.array([im0, im1])

    return Hdot


def create_simple_problem(L, n_poisson,  C, nx=40, seed=None):
    """Simulate a noisy spectrum image ``Y`` from ``G @ Wdot @ Hdotflat``.

    Parameters
    ----------
    L : int
        Number of samples along the spectral axis.
    n_poisson : float
        Poisson scaling. ``Y = Poisson(n_poisson * Ydot) / n_poisson``, so
        ``Y`` has mean ``Ydot`` and variance ``Ydot / n_poisson``.
    C : int
        Number of columns of the synthetic dictionary ``G``.
    nx : int
        Side length of the square ground-truth maps.
    seed : int or None
        Seed for NumPy's global random state.

    Returns
    -------
    tuple
        ``(G, Wdot, Ddot, Hdot, Hdotflat, Ydot, Y, shape_2d, K)``.
    """
    np.random.seed(seed)
    G = syntheticG(L, C, seed=seed)

    Hdot = build_simple_images(nx)
    shape_2d = Hdot.shape[1:]
    n_maps = len(Hdot)
    Hdotflat = Hdot.reshape(n_maps, -1)

    # Sparse positive mixing weights: zero every entry not above the mean,
    # then rescale by the mean of the surviving entries and their count.
    raw = np.abs(np.random.laplace(size=[C, n_maps]))
    kept = raw > raw.mean()
    Wdot = np.where(kept, raw, 0)
    Wdot = Wdot / Wdot[kept].mean() * n_maps / kept.sum()

    Ddot = G @ Wdot
    Ydot = Ddot @ Hdotflat
    Y = 1/n_poisson * np.random.poisson(n_poisson * Ydot)
    return G, Wdot, Ddot, Hdot, Hdotflat, Ydot, Y, shape_2d, n_maps

def plot_results(Ddot, D, Hdotflat, Hflat, shape_2d=None):
    """Plot ground-truth vs reconstructed spectra and maps, then print metrics.

    Reconstructed components are matched to the ground-truth ones with
    ``find_min_angle(..., unique=True)``. Row ``i`` of the figure shows true
    phase ``i`` in three panels:
    - its spectrum overlaid with the matched reconstructed spectrum;
    - the matched reconstructed map;
    - the ground-truth map.

    Parameters
    ----------
    Ddot : np.ndarray of shape (L, K)
        Ground-truth spectra, one per column.
    D : np.ndarray with L rows
        Reconstructed spectra, one per column.
    Hdotflat : np.ndarray of shape (K, n_pixels)
        Flattened ground-truth maps.
    Hflat : np.ndarray with n_pixels columns
        Flattened reconstructed maps, one per row.
    shape_2d : tuple of int, optional
        Image shape used to unflatten the maps. If omitted, the maps are
        assumed square, which requires ``n_pixels`` to be a perfect square.

    Raises
    ------
    ValueError
        If ``shape_2d`` is omitted and the maps are not square.
    """
    fontsize = 30
    scale = 15
    aspect_ratio = 1.4
    cmap = plt.cm.gray_r
    vmin, vmax = 0, 1
    # One figure row per ground-truth phase.
    K = len(Hdotflat)
    L = len(D)
    if shape_2d is None:
        n_pixels = Hdotflat.shape[1]
        side = int(round(np.sqrt(n_pixels)))
        if side * side != n_pixels:
            raise ValueError("shape_2d must be given for non-square maps")
        shape_2d = (side, side)

    angles, true_inds = find_min_angle(Ddot.T, D.T, unique=True, get_ind=True)
    mse = ordered_mse(Hdotflat, Hflat, true_inds)
    mae = ordered_mae(Hdotflat, Hflat, true_inds)
    r2 = ordered_r2(Hdotflat, Hflat, true_inds)

    # squeeze=False keeps `axes` 2-D even when K == 1.
    fig, axes = plt.subplots(K, 3, figsize=(scale / K * 3 * aspect_ratio, scale), squeeze=False)
    x = np.linspace(0, 1, num=L)
    no_ticks = dict(axis="both", labelleft=False, labelbottom=False, left=False, bottom=False)
    for i in range(K):
        j = true_inds[i]  # reconstructed component matched to true phase i
        ax_spec, ax_rec, ax_gt = axes[i]

        ax_spec.plot(x, Ddot.T[i, :], 'bo', label='truth', linewidth=4)
        ax_spec.plot(x, D[:, j], 'r-', label='reconstructed', markersize=3.5)
        ax_spec.set_title("{:.2f} deg".format(angles[i]), fontsize=fontsize - 2)
        ax_spec.set_xlim(0, 1)
        ax_spec.legend()
        ax_spec.set_ylabel(f"Phase {i}", rotation=90, fontsize=fontsize)

        ax_rec.imshow(Hflat[j, :].reshape(shape_2d), vmin=vmin, vmax=vmax, cmap=cmap)
        # The ordered_* metrics are indexed by true phase i, like `angles`.
        # Indexing them with j would show another phase's score.
        # NOTE(review): ordering assumed from espm.measures; confirm.
        ax_rec.set_title("Reconstruction - R2: {:.2f}".format(r2[i]), fontsize=fontsize - 2)
        ax_rec.tick_params(**no_ticks)

        im = ax_gt.imshow(Hdotflat[i].reshape(shape_2d), vmin=vmin, vmax=vmax, cmap=cmap)
        ax_gt.set_title("GT - Phase {}".format(i), fontsize=fontsize)
        ax_gt.tick_params(**no_ticks)

    fig.subplots_adjust(right=0.84)
    # Shared colour bar placed to the right of the panels.
    cbar_ax = fig.add_axes([0.85, 0.5, 0.01, 0.3])
    fig.colorbar(im, cax=cbar_ax)

    plt.show()
    print("angles : ", angles)
    print("mse : ", mse)
    print("mae : ", mae)
    print("r2 : ", r2)

    
        

# Create a synthetic problem

In [None]:
G = syntheticG(L, C, seed)
# Exactly L samples in [0, 1); np.arange with a float step can yield L + 1.
l = np.linspace(0, 1, num=L, endpoint=False)
plt.plot(l, G[:, :2])
plt.title("Spectral response of the first two elements");


In [None]:

Hdot = build_simple_images(nx=nx)
vmin, vmax = 0, 1
cmap = plt.cm.gray_r
# One panel per ground-truth map; the grid is sized from the data, not hard-coded.
fig, axes = plt.subplots(1, len(Hdot), figsize=(10, 3), squeeze=False)
for i, (ax, hdot) in enumerate(zip(axes[0], Hdot)):
    ax.imshow(hdot, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.axis("off")
    ax.set_title(f"Map {i+1}")
plt.show()


In [None]:
# Pass the configured seed so the noisy data, and every result below, is reproducible.
G, Wdot, Ddot, Hdot, Hdotflat, Ydot, Y, shape_2d, K = create_simple_problem(L=L, n_poisson=n_poisson, C=C, nx=nx, seed=seed)
plt.plot(Ddot)
plt.title("Ground-truth spectra (columns of Ddot = G @ Wdot)")
plt.show()

Wdot.shape, G.shape, L, C, nx

# Solve the problem

In [None]:

# Baseline run with mu = 0 and lambda_L = 0 (the sweeps below vary them).
mu = 0
lambda_L = 0
force_simplex = True
Gused = None

# Keyword arguments shared by every SmoothNMF fit in this notebook.
params = {
    "tol": 1e-6,
    "max_iter": 500,
    "hspy_comp": False,
    "verbose": 0,
    "epsilon_reg": 1,
    "linesearch": False,
    "shape_2d": shape_2d,
    "n_components": K,
}

estimator = SmoothNMF(mu=mu, lambda_L=lambda_L, G=Gused, force_simplex=force_simplex, **params)
D = estimator.fit_transform(Y)
Hflat = estimator.H_
W = estimator.W_
H = Hflat.reshape([Hflat.shape[0], *shape_2d])
plt.plot(estimator.losses_)

In [None]:
        
# Match each true spectrum to its closest reconstructed one, then score the maps.
angles, true_inds = find_min_angle(Ddot.T, D.T, unique=True, get_ind=True)
mse = ordered_mse(Hdotflat, Hflat, true_inds)
# NOTE(review): one_experiment() passes average=True to KLdiv, while this call
# uses KLdiv's default, so the two KL values may not be directly comparable.
KL = KLdiv(Y, D, Hflat)
r2 = ordered_r2(Hdotflat, Hflat, true_inds)

for label, value in [("angles: ", angles), ("mse: ", mse), ("KL divergence: ", KL), ("R2: ", r2)]:
    print(label, value)

In [None]:
# Spectra, reconstructed maps and ground-truth maps, one figure row per phase.
plot_results(Ddot, D, Hdotflat, Hflat)

In [None]:
# plt.figure(figsize=(10, 4))
# plt.subplot(1,2,1)
# plt.hist(Ydot.flatten()*n_poisson, 100);
# plt.subplot(1,2,1)
# plt.hist(Y.flatten()*n_poisson, 100, alpha=0.5);

In [None]:
def one_experiment(mu, lambda_L, force_simplex, Gused, params, Y):
    """Fit one SmoothNMF model on ``Y``, score it against the ground truth and print the scores.

    The ground truth is read from the notebook globals ``Ddot`` (spectra) and
    ``Hdotflat`` (flattened maps) produced by ``create_simple_problem``.

    Parameters
    ----------
    mu, lambda_L, force_simplex
        Forwarded unchanged to ``SmoothNMF``.
    Gused
        Forwarded to ``SmoothNMF`` as its ``G`` argument.
    params : dict
        Extra keyword arguments for ``SmoothNMF``.
    Y : np.ndarray
        Data matrix passed to ``fit_transform``.

    Returns
    -------
    tuple
        ``(angles, mse, KL, r2)``. Each element is:
        - ``angles``: spectral angles from ``find_min_angle``;
        - ``mse``: map MSEs from ``ordered_mse``;
        - ``KL``: the KL divergence with ``average=True``;
        - ``r2``: map R2 scores from ``ordered_r2``.
    """
    estimator = SmoothNMF(mu=mu, lambda_L=lambda_L, G=Gused, force_simplex=force_simplex, **params)
    D = estimator.fit_transform(Y)
    Hflat = estimator.H_
    # Pair each true spectrum with a distinct reconstructed one; the map
    # metrics below use that pairing.
    angles, true_inds = find_min_angle(Ddot.T, D.T, unique=True, get_ind=True)
    mse = ordered_mse(Hdotflat, Hflat, true_inds)
    KL = KLdiv(Y, D, Hflat, average=True)
    r2 = ordered_r2(Hdotflat, Hflat, true_inds)

    print("angles: ", angles)
    print("mse: ", mse)
    print("KL divergence: ", KL)
    print("R2: ", r2)
    return angles, mse, KL, r2

In [None]:

# Sweep mu with lambda_L held fixed: one fit per value, metrics collected in order.
lambda_L = 1
force_simplex = True
Gused = None

mus = np.array([0, 0.5, 1, 2, 3, 5, 10])

metrics_mus = []
for i, mu in enumerate(mus):
    header = f"Experiment -- {i+1}/{len(mus)} --"
    settings = f"  lambda={lambda_L}, mu={mu}, Gused: {Gused is not None}, simplex: {force_simplex}"
    print(header)
    print(settings)
    result = one_experiment(mu, lambda_L, force_simplex, Gused, params, Y)
    metrics_mus.append(result)


In [None]:

# Average the per-phase metrics of each run of the mu sweep.
r2_mus = np.mean(np.array([e[3] for e in metrics_mus]), axis=1)
angles_mus = np.mean(np.array([e[0] for e in metrics_mus]), axis=1)
KL_mus = np.array([e[2] for e in metrics_mus])

fig, (ax_kl, ax_tradeoff) = plt.subplots(1, 2, figsize=(10, 4))
# mu = 0 is part of the sweep, so shift by one to allow a log axis.
ax_kl.plot(mus + 1, KL_mus, "-x")
ax_kl.set_xscale("log")
ax_kl.set_xlabel(r"$\mu + 1$")
ax_kl.set_ylabel("KL divergence (averaged)")

ax_tradeoff.scatter(r2_mus, angles_mus)
ax_tradeoff.plot(r2_mus, angles_mus, linestyle='dashed')
# Label offsets are in screen points so they stay readable for any data range.
# Distinct loop names avoid clobbering the global `r2`.
for mu_value, r2_value, angle_value in zip(mus, r2_mus, angles_mus):
    ax_tradeoff.annotate(str(mu_value), xy=(r2_value, angle_value), xytext=(4, 4),
                         textcoords="offset points", size=10, va='center', ha='left')
# Axis limits are left to autoscale. Limits tuned to one random realisation
# can hide points of another.
ax_tradeoff.set_title("Avg. R2 and angle error for different mu")
ax_tradeoff.set_xlabel(r'$R^2(H, \tilde{H})$ [-]')
ax_tradeoff.set_ylabel(r'$\theta(W, \tilde{W})$ [deg]')
ax_tradeoff.grid()
plt.show()

In [None]:
# Sweep lambda_L with mu held fixed: one fit per value, metrics collected in order.
force_simplex = True
Gused = None
mu = 0

# A wider range tried earlier: [0, 10, 20, 50, 100, 200, 500, 1000].
lambdas = np.array([0, 0.1, 0.2, 0.5, 1, 2, 5])
metrics_lambda = []
for i, lambda_L in enumerate(lambdas):
    header = f"Experiment -- {i+1}/{len(lambdas)} --"
    settings = f"  lambda={lambda_L}, mu={mu}, Gused={Gused is not None}, simplex={force_simplex}"
    print(header)
    print(settings)
    result = one_experiment(mu, lambda_L, force_simplex, Gused, params, Y)
    metrics_lambda.append(result)


In [None]:

# Average the per-phase metrics of each run of the lambda_L sweep.
r2_lambdas = np.mean(np.array([e[3] for e in metrics_lambda]), axis=1)
angles_lambdas = np.mean(np.array([e[0] for e in metrics_lambda]), axis=1)
KL_lambdas = np.array([e[2] for e in metrics_lambda])

fig, (ax_kl, ax_tradeoff) = plt.subplots(1, 2, figsize=(10, 4))
# lambda_L = 0 is part of the sweep, so shift by one to allow a log axis.
ax_kl.plot(lambdas + 1, KL_lambdas, "-x")
ax_kl.set_xscale("log")
ax_kl.set_xlabel(r"$\lambda_L + 1$")
ax_kl.set_ylabel("KL divergence (averaged)")

ax_tradeoff.scatter(r2_lambdas, angles_lambdas)
ax_tradeoff.plot(r2_lambdas, angles_lambdas, linestyle='dashed')
# Offset labels (in screen points) so they do not sit on top of the markers.
# Distinct loop names avoid clobbering the global `r2`.
for lam_value, r2_value, angle_value in zip(lambdas, r2_lambdas, angles_lambdas):
    ax_tradeoff.annotate(str(lam_value), xy=(r2_value, angle_value), xytext=(4, 4),
                         textcoords="offset points", size=10, va='center', ha='left')
ax_tradeoff.set_title("Avg. R2 and angle error for different lambda")
ax_tradeoff.set_xlabel(r'$R^2(H, \tilde{H})$ [-]')
ax_tradeoff.set_ylabel(r'$\theta(W, \tilde{W})$ [deg]')
ax_tradeoff.grid()
plt.show()




In [None]:
# KL divergence (computed with average=True in one_experiment) of each run in the lambda_L sweep.
KL_lambdas