In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import util as f
from matplotlib.pyplot import cm

In [None]:
#NEEDED FOR SOME GRAPHS
import plotly.graph_objects as go
import pandas as pd

In [None]:
def mse1(u_e, v_e, u_t, v_t):
    """Sum of the per-vector mean squared errors between two (u, v) pairs.

    Each squared difference is averaged along dimension 0, and the u and v
    contributions are added together.
    """
    err_u = ((u_e - u_t) ** 2).mean(0)
    err_v = ((v_e - v_t) ** 2).mean(0)
    return err_u + err_v
def mse2(u_e, v_e, u_t, v_t):
    """Mean squared error between the outer products u_e v_e^T and u_t v_t^T.

    ``torch.tensordot(a, b, dims=0)`` is the outer product of ``a`` and ``b``.
    """
    outer_e = torch.tensordot(u_e, v_e, dims=0)
    outer_t = torch.tensordot(u_t, v_t, dims=0)
    return (outer_e - outer_t).pow(2).mean()
def mse3(u_e, v_e, u_t, v_t, Y, N, M, lambda_):
    """Mean squared residual between Y and the rescaled outer product of (u_t, v_t).

    Computes mean((Y - sqrt(lambda_/N) * u_t v_t^T)**2). ``u_e``, ``v_e`` and
    ``M`` are not used; they are accepted so all mse* helpers can be called
    with the same arguments.
    """
    scale = torch.sqrt(lambda_ / N)
    residual = Y - (scale * torch.tensordot(u_t, v_t, dims=0))
    return residual.pow(2).mean()

In [None]:
def main_gradient_2_sans_proj_normalisation(iteration, u_p, v_p, Y, u_, v_, N, M, lambda_, beta_u, beta_v, lambda_1, lambda_2, dt):
    """Run noisy gradient dynamics on (u, v), rescaling both vectors after every step.

    One step for u (v is symmetric with M, lambda_2, beta_v, f.gradient_v_2):

        u <- u - dt/lambda_1 * f.gradient_u_2(...)
               + sqrt(2/(lambda_1*beta_u)) * xi,     xi_i ~ N(0, dt)
               - dt*(N-1)/(N*lambda_1*beta_u) * u
        u <- sqrt(N) * u / ||u||

    Both gradients are evaluated at the previous (u_p, v_p). With
    beta = inf the noise and shrink terms are exactly zero.

    Parameters
    ----------
    iteration : int
        Number of steps.
    u_p, v_p : torch.Tensor
        Initial vectors (sizes N and M).
    Y : torch.Tensor
        Observation matrix, forwarded to the gradient helpers and to mse3.
    u_, v_ : torch.Tensor
        Reference vectors used to measure overlaps and errors.
    N, M, lambda_, beta_u, beta_v, lambda_1, lambda_2, dt : torch.Tensor
        0-dim tensors (``one_run`` builds them with ``torch.tensor``).

    Returns
    -------
    (np.ndarray, np.ndarray)
        - Overlaps, shape (iteration, 2): |overlap(u_, u)| and |overlap(v_, v)|
          after every step.
        - Errors, shape (iteration, 3): (mse1 on absolute values, mse2, mse3)
          after every step.
    """
    # Everything below is constant across iterations: compute it once instead
    # of re-evaluating the same tensor expressions `iteration` times.
    sqrt_dt = torch.sqrt(dt)
    noise_scale_u = torch.sqrt(2/(lambda_1*beta_u))
    noise_scale_v = torch.sqrt(2/(lambda_2*beta_v))
    shrink_u = (N-1)/(N*lambda_1*beta_u)
    shrink_v = (M-1)/(M*lambda_2*beta_v)
    sqrt_N = torch.sqrt(N)
    sqrt_M = torch.sqrt(M)
    abs_u_ref = torch.abs(u_)
    abs_v_ref = torch.abs(v_)

    res = []
    res_mse = []

    for _ in range(iteration):
        # u update: gradient step + Langevin noise - shrinkage
        u_1 = (1/lambda_1) * f.gradient_u_2(N, M, u_p, v_p, Y, lambda_) * dt
        u_2 = noise_scale_u * torch.empty(N).normal_(mean=0, std=sqrt_dt)
        u_3 = shrink_u*u_p*dt
        u_n = u_p - u_1 + u_2 - u_3

        # v update (uses the previous u_p, not the freshly updated u_n)
        v_1 = (1/lambda_2) * f.gradient_v_2(N, M, u_p, v_p, Y, lambda_) * dt
        v_2 = noise_scale_v * torch.empty(M).normal_(mean=0, std=sqrt_dt)
        v_3 = shrink_v*v_p*dt
        v_n = v_p - v_1 + v_2 - v_3

        # Normalisation: rescale to norms sqrt(N) and sqrt(M)
        u_n = u_n / torch.linalg.norm(u_n)
        v_n = v_n / torch.linalg.norm(v_n)
        u_n = u_n * sqrt_N
        v_n = v_n * sqrt_M

        # Re-assign for the next iteration
        u_p = u_n
        v_p = v_n

        # Store plain Python floats so np.array below builds a float array
        # instead of converting thousands of 0-dim tensors.
        res_u = torch.abs(f.overlap(u_, u_n, N)).item()
        res_v = torch.abs(f.overlap(v_, v_n, M)).item()
        res.append((res_u, res_v))

        res_mse.append((mse1(abs_u_ref, abs_v_ref, torch.abs(u_n), torch.abs(v_n)).item(),
                        mse2(u_, v_, u_n, v_n).item(),
                        mse3(u_, v_, u_n, v_n, Y, N, M, lambda_).item()))

    return np.array(res), np.array(res_mse)

In [None]:
def one_run(lambda_=2, N=500, M=500, beta_u=float("inf"), beta_v=float("inf"), lambda_1=1,
        lambda_2=1, dt=1/100, iteration=5000):
    """Generate one problem instance and run the normalised dynamics on it.

    Draws reference vectors (u_, v_), builds Y from them with
    ``f.generate_Y``, draws a fresh random starting point, and runs
    ``main_gradient_2_sans_proj_normalisation`` for ``iteration`` steps.

    Parameters
    ----------
    lambda_ : number
        Passed to ``f.generate_Y`` and to the gradient/error helpers.
    N, M : int
        Sizes of u and v (Y is built from vectors of these sizes).
    beta_u, beta_v : number
        Temperature parameters of the noise; ``inf`` disables the noise term.
    lambda_1, lambda_2 : number
        Learning-rate factors of the u and v updates.
    dt : float
        Time step.
    iteration : int
        Number of steps.

    Returns
    -------
    (np.ndarray, np.ndarray)
        Overlaps per step and errors per step, as returned by
        ``main_gradient_2_sans_proj_normalisation``.
    """
    # Turn every scalar into a 0-dim tensor so the update rule can apply
    # torch ops (torch.sqrt, ...) to it directly.
    N, M = torch.tensor(N), torch.tensor(M)
    lambda_ = torch.tensor(lambda_)
    beta_u, beta_v = torch.tensor(beta_u), torch.tensor(beta_v)
    lambda_1, lambda_2 = torch.tensor(lambda_1), torch.tensor(lambda_2)
    dt = torch.tensor(dt)

    # Reference vectors and the matrix Y built from them
    u_ref = f.generate_vector(N)
    v_ref = f.generate_vector(M)
    Y = f.generate_Y(N, M, u_ref, v_ref, lambda_)

    # Initial conditions
    u_start = f.generate_vector(N)
    v_start = f.generate_vector(M)

    return main_gradient_2_sans_proj_normalisation(
        iteration, u_start, v_start, Y, u_ref, v_ref, N, M, lambda_,
        beta_u, beta_v, lambda_1, lambda_2, dt)


In [None]:
# MAIN METHOD:

def main(size_of_one_sample, list_value_1, list_value_2):
    """Sweep every (beta_u, beta_v) pair and collect overlap / MSE results.

    Pairs are visited row-major: outer loop over ``list_value_1`` (beta_u),
    inner loop over ``list_value_2`` (beta_v). For each pair, ``one_run`` is
    called ``size_of_one_sample`` times. A progress percentage is printed
    (carriage-return terminated) after every run.

    Returns
    -------
    tuple of four np.ndarray, each with one entry per pair and, inside it,
    one entry per run:
        - overlaps at every iteration,
        - overlaps at the last iteration,
        - MSEs at every iteration,
        - MSEs at the last iteration.
    """
    overlap_curves, final_overlaps = [], []
    mse_curves, final_mses = [], []

    total_runs = len(list_value_1) * len(list_value_2) * size_of_one_sample
    done = 0

    pairs = ((b_u, b_v) for b_u in list_value_1 for b_v in list_value_2)
    for b_u, b_v in pairs:
        sample_overlaps, sample_mses = [], []

        for _ in range(size_of_one_sample):
            overlap_run, mse_run = one_run(beta_u=b_u, beta_v=b_v)
            sample_overlaps.append(overlap_run)
            sample_mses.append(mse_run)

            done += 1
            print(f"progress {100*done/total_runs:.2f}%", end="\r")

        # Keep the full trajectories and, separately, their last entries.
        overlap_curves.append(sample_overlaps)
        final_overlaps.append(np.array([curve[-1] for curve in sample_overlaps]))
        mse_curves.append(sample_mses)
        final_mses.append(np.array([curve[-1] for curve in sample_mses]))

    return (np.array(overlap_curves), np.array(final_overlaps),
            np.array(mse_curves), np.array(final_mses))

In [None]:
# MAIN CELL: DO THE COMPUTATION
# Sweep a 30 x 30 log-spaced grid of (beta_u, beta_v) values in [10**-1, 10**0.5]
# with `size_of_one_sample` independent runs per grid point
# (30 * 30 * 10 = 9000 runs, each with one_run's default number of iterations).

# Seed the RNGs so that Restart & Run All reproduces the same figures:
# torch's RNG drives the noise in the update rule; numpy's legacy global RNG is
# seeded too in case `util` draws from it (TODO confirm).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Other grids that were tried: np.arange(0.1, 3, 0.1) or np.array([0.5, 1, 2])
list_value_1 = np.logspace(-1, 0.5, 30)   # beta_u values
list_value_2 = np.logspace(-1, 0.5, 30)   # beta_v values

size_of_one_sample = 10

(overlap_at_each_iterations, overlap_at_final_iteration,
 mse_at_each_iterations, mse_at_final_iteration) = main(size_of_one_sample, list_value_1, list_value_2)

In [None]:
#########################################
# The first sequence of graphs are plain matplotlib (pyplot) graphs.
# The later graphs use plotly: they are interactive and can be saved as .html,
# so whenever you want a new screenshot you can re-open the .html file and
# take a new picture of the graph.
#
# With plotly you can resize a graph, zoom in/out, and, whenever there is a
# legend, click an entry to show or hide that specific plot.
#
# To save the plotly graphs, enter the destination directory (ending with a
# path separator, since file names are appended directly) in this variable,
# e.g. FILE_PATH = "C:/Users/Admin/Desktop/"
FILE_PATH = ""

# Then uncomment the last line of each cell that contains a plotly graph:
# fig.write_html("... .html")

In [None]:
# ENTER THE NAMES OF THE TWO SWEPT PARAMETERS (used in titles and axis labels):
name_value_1, name_value_2 = "beta_u", "beta_v"

In [None]:
# GRAPH: STABILISATION OF THE OVERLAPS
# One panel per (beta_u, beta_v) combination; each panel overlays the
# (u, v) overlap curves of every run of the sample.

n_combinations = len(list_value_1) * len(list_value_2)
# squeeze=False keeps `axs` 2-D even for a single combination, where
# plt.subplots would otherwise return a bare Axes that cannot be indexed.
fig, axs = plt.subplots(n_combinations, figsize=(20, 200), facecolor='white', squeeze=False)

for i, ax in enumerate(axs[:, 0]):
    ax.grid()
    ax.set_ylabel("overlap")
    ax.set_title(f"{name_value_1}={list_value_1[i//len(list_value_2)]}"
                 + f" & {name_value_2}={list_value_2[i%len(list_value_2)]}")
    for run_in_sample in range(size_of_one_sample):
        ax.plot(overlap_at_each_iterations[i][run_in_sample])

# tight_layout must run after titles/labels are set, otherwise it cannot
# account for them and they overlap neighbouring panels.
fig.tight_layout()
plt.show()

In [None]:
# GRAPH: STABILISATION OF THE MEAN SQUARED ERRORS
# One panel per (beta_u, beta_v) combination; each panel overlays the
# (mse1, mse2, mse3) curves of every run of the sample.

n_combinations = len(list_value_1) * len(list_value_2)
# squeeze=False keeps `axs` 2-D even for a single combination, where
# plt.subplots would otherwise return a bare Axes that cannot be indexed.
fig, axs = plt.subplots(n_combinations, figsize=(20, 200), facecolor='white', squeeze=False)

for i, ax in enumerate(axs[:, 0]):
    ax.grid()
    ax.set_ylabel("mse")
    ax.set_title(f"{name_value_1}={list_value_1[i//len(list_value_2)]}"
                 + f" & {name_value_2}={list_value_2[i%len(list_value_2)]}")
    for run_in_sample in range(size_of_one_sample):
        ax.plot(mse_at_each_iterations[i][run_in_sample])

# tight_layout must run after titles/labels are set, otherwise it cannot
# account for them and they overlap neighbouring panels.
fig.tight_layout()
plt.show()

In [None]:
# OVERLAP AS A FUNCTION OF THE STEP
# Average over the runs of a sample (axis 1) and then over the (u, v) pair,
# leaving one curve per (beta_u, beta_v) combination.

print(overlap_at_each_iterations.shape)
overlap_at_each_iterations_per_sample = overlap_at_each_iterations.mean(axis=1).mean(axis=2)

n_curves = len(list_value_1) * len(list_value_2)
palette = cm.rainbow(np.linspace(0, 1, n_curves))

plt.figure(facecolor='white')
plt.grid()
plt.ylabel("overlap")
plt.xlabel("iterations")

for i, c in zip(range(n_curves), palette):
    curve_label = (f"{name_value_1}={list_value_1[i//len(list_value_2)]}"
                   + f" & {name_value_2}={list_value_2[i%len(list_value_2)]}")
    plt.plot(overlap_at_each_iterations_per_sample[i], label=curve_label, color=c)
plt.legend(loc='upper right', bbox_to_anchor=(2, 1))

In [None]:
# MSE AS A FUNCTION OF THE STEP
# Same as the cell above but with the mean squared error (mse2 is plotted).
#
# The per-metric curves are stored under *_curves names on purpose: binding
# them to `mse1`, `mse2`, `mse3` would shadow the metric functions, and
# re-running the computation cell afterwards would then fail with
# "'numpy.ndarray' object is not callable".

# mean over runs -> (n_combinations, iterations, 3); .T -> (3, iterations, n_combinations)
mse1_curves, mse2_curves, mse3_curves = mse_at_each_iterations.mean(1).T

n_curves = len(list_value_1) * len(list_value_2)
palette = cm.rainbow(np.linspace(0, 1, n_curves))

plt.figure(facecolor='white')
plt.grid()
plt.ylabel("mse2")
plt.xlabel("iterations")

for i, c in zip(range(n_curves), palette):
    curve_label = (f"{name_value_1}={list_value_1[i//len(list_value_2)]}"
                   + f" & {name_value_2}={list_value_2[i%len(list_value_2)]}")
    plt.plot(mse2_curves[:, i], label=curve_label, color=c)
plt.legend(loc='upper right', bbox_to_anchor=(2, 1))

In [None]:
# CONTOUR PLOT OF THE MEAN FINAL OVERLAP
# Mean over the runs of a sample and over the (u, v) pair, reshaped to the
# grid: rows follow list_value_1, columns follow list_value_2.

overlap_mean = overlap_at_final_iteration.mean(axis=1).mean(axis=1)
overlap_mean = overlap_mean.reshape(len(list_value_1), len(list_value_2))

# meshgrid(columns, rows) so x/y line up with overlap_mean's (rows, columns) layout
x, y = np.meshgrid(list_value_2, list_value_1)

fig, ax = plt.subplots(1, 1, facecolor='white')
c = ax.contourf(x, y, overlap_mean)
ax.set_title("MEAN FINAL OVERLAP")
ax.set_xlabel(f"{name_value_2}")
ax.set_ylabel(f"{name_value_1}")
fig.colorbar(c, ax=ax)
plt.show()

In [None]:
## HEAT MAP OF THE VARIANCES ACROSS RUNS
# For each (beta_u, beta_v) combination: variance over the runs of a sample of
# (a) the final overlap averaged over u and v, and (b) the final mse2.

grid_shape = (len(list_value_1), len(list_value_2))

variance_map = np.var(overlap_at_final_iteration.mean(axis=2), axis=1).reshape(grid_shape)

# Pick mse2 (index 1 of the last axis) directly. Unpacking all three metrics
# into `mse*_mean` would overwrite the grid-shaped means used by the plotly
# MSE contour with per-run arrays if this cell were re-run after them.
var_mse = np.var(mse_at_final_iteration[:, :, 1], axis=1).reshape(grid_shape)

# meshgrid(columns, rows) so x/y line up with the (rows, columns) maps
x, y = np.meshgrid(list_value_2, list_value_1)

fig, ax = plt.subplots(1, 2, figsize=(12, 5), facecolor='white')

bar0 = ax[0].contourf(x, y, variance_map)
ax[0].set_title("VARIANCE OVERLAP")
ax[0].set_xlabel(f"{name_value_2}")
ax[0].set_ylabel(f"{name_value_1}")

bar1 = ax[1].contourf(x, y, var_mse)
ax[1].set_title("VARIANCE MSE 2")
ax[1].set_xlabel(f"{name_value_2}")
ax[1].set_ylabel(f"{name_value_1}")

plt.colorbar(bar0, ax=ax[0])
plt.colorbar(bar1, ax=ax[1])

plt.show()

In [None]:
#########################################
# The following cells use Plotly for interactive plots

In [None]:
# Grid-shaped means for the plotly contours: rows follow list_value_1,
# columns follow list_value_2.
grid_shape = (len(list_value_1), len(list_value_2))

# Final overlap averaged over runs, then over the (u, v) pair.
overlap_mean = overlap_at_final_iteration.mean(axis=1).mean(axis=1).reshape(grid_shape)

# Final (mse1, mse2, mse3) averaged over runs, one grid per metric.
mse1_mean, mse2_mean, mse3_mean = (
    metric.reshape(grid_shape) for metric in mse_at_final_iteration.mean(1).T
)

In [None]:
# CONTOUR PLOT OF THE MEAN FINAL OVERLAP (PLOTLY)
fig = go.Figure(go.Contour(z=overlap_mean, x=list_value_2, y=list_value_1,
                           name="overlap", colorbar=dict(title='Overlap')))

# Axis titles come from the parameter names entered above, so the figure stays
# correct when a different pair of parameters is swept. (The placeholder legend
# title was dropped: a single contour trace shows no legend.)
linear_ticks = dict(tickmode='linear', tick0=0, dtick=0.5)
fig.update_layout(
    xaxis=dict(linear_ticks, title=dict(text=name_value_2)),
    yaxis=dict(linear_ticks, title=dict(text=name_value_1)),
    width=500, height=500,
)

fig.show()
#fig.write_html(f"{FILE_PATH}contour_overlap.html")

In [None]:
# CONTOUR PLOTS OF THE MEAN FINAL MSEs (PLOTLY)
# REMINDER:
# This graph displays the three MSEs.
# Double-click an MSE in the legend to display only that MSE;
# a single click just hides it.

fig = go.Figure()
for label, z in (("mse1", mse1_mean), ("mse2", mse2_mean), ("mse3", mse3_mean)):
    fig.add_trace(go.Contour(z=z, x=list_value_2, y=list_value_1,
                             name=label, colorscale='viridis'))

# Axis titles come from the parameter names entered above, so the figure stays
# correct when a different pair of parameters is swept.
linear_ticks = dict(tickmode='linear', tick0=0, dtick=0.5)
fig.update_layout(
    xaxis=dict(linear_ticks, title=dict(text=name_value_2)),
    yaxis=dict(linear_ticks, title=dict(text=name_value_1)),
    legend=dict(title=dict(text="MSE"), orientation="h",
                yanchor="bottom", y=1, xanchor="right", x=1),
    width=500, height=500,
)

# Contour traces hide their legend entry by default; show them so each MSE can be toggled.
fig.update_traces(showlegend=True)

fig.show()
#fig.write_html(f"{FILE_PATH}contour_mse.html")

In [None]:
######## GRAPH WITH PLOTLY!

# STABILISATION OF THE OVERLAP (PLOTLY)
# CHANGE selected_run TO INSPECT ANOTHER COMBINATION!
# `selected_run` indexes a (beta_u, beta_v) combination, row-major over
# list_value_1 x list_value_2 (as in the title below), not a single run.
selected_run = 0

# All runs of the selected combination; the last axis holds (u overlap, v overlap).
runs_overlap = overlap_at_each_iterations[selected_run]
overlap_u, overlap_v = runs_overlap.mean(axis=0).T
# Error bars are drawn in data coordinates, so use the standard deviation
# across runs (same units as the overlap) rather than the variance.
std_u, std_v = runs_overlap.std(axis=0).T

x = list(range(len(overlap_u)))

fig = go.Figure(data=[
    go.Scatter(x=x, y=overlap_u, name="u overlap",
               error_y=dict(type='data', array=std_u, visible=True)),
    go.Scatter(x=x, y=overlap_v, name="v overlap",
               error_y=dict(type='data', array=std_v, visible=True)),
])

fig.update_layout(
    title=f"Stabilisation overlap {name_value_1}={list_value_1[selected_run//len(list_value_2)]}"
          + f" & {name_value_2}={list_value_2[selected_run%len(list_value_2)]}",
    xaxis_title="Iterations",
    yaxis_title="Overlap",
)

fig.show()
#fig.write_html(f"{FILE_PATH}stabilisation_u_v.html")

In [None]:
######## GRAPH WITH PLOTLY!

# STABILISATION OF THE MSEs AND OF THE OVERLAP (PLOTLY)
# `selected_run` indexes a (beta_u, beta_v) combination, row-major over
# list_value_1 x list_value_2 (as in the title below), not a single run.
selected_run = 0

runs_overlap = overlap_at_each_iterations[selected_run]  # last axis: (u, v) overlaps
runs_mse = mse_at_each_iterations[selected_run]          # last axis: (mse1, mse2, mse3)

# Overlap averaged over u and v, then mean/std across runs.
overlap_per_run = runs_overlap.mean(axis=2)
overlap_curve = overlap_per_run.mean(axis=0)
std_overlap = overlap_per_run.std(axis=0)

# Stored under *_curve names on purpose: rebinding `mse1`/`mse2`/`mse3` would
# shadow the metric functions and break any later re-run of the computation.
mse1_curve, mse2_curve, mse3_curve = runs_mse.mean(axis=0).T
# Error bars are in data coordinates, so use the standard deviation across
# runs (same units as the plotted values) rather than the variance.
std_m1, std_m2, std_m3 = runs_mse.std(axis=0).T

x = list(range(len(mse1_curve)))

traces = [
    go.Scatter(x=x, y=curve, name=label,
               error_y=dict(type='data', array=err, visible=True))
    for label, curve, err in (
        ("mse1", mse1_curve, std_m1),
        ("mse2", mse2_curve, std_m2),
        ("mse3", mse3_curve, std_m3),
        ("overlap", overlap_curve, std_overlap),
    )
]
fig = go.Figure(data=traces)

fig.update_layout(
    title=f"{name_value_1}={list_value_1[selected_run//len(list_value_2)]}"
          + f" & {name_value_2}={list_value_2[selected_run%len(list_value_2)]}",
    xaxis_title="Iterations",
)

fig.show()
#fig.write_html(f"{FILE_PATH}stabilisation_mse_overlap.html")